1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-05 09:10:25 +08:00

fix(rtrvc): skip head unimplemented

This commit is contained in:
源文雨
2024-06-16 16:46:59 +09:00
parent df83554ac1
commit 0d5cd347bc
5 changed files with 32 additions and 38 deletions

View File

@@ -138,7 +138,7 @@ class RVC:
self,
input_wav: torch.Tensor,
block_frame_16k: int,
skip_head: torch.Tensor,
skip_head: int,
return_length: int,
f0method: Union[tuple, str],
inp_f0: Optional[np.ndarray] = None,
@@ -241,8 +241,6 @@ class RVC:
feats = feats.to(feats0.dtype)
p_len = torch.LongTensor([p_len]).to(self.device)
sid = torch.LongTensor([0]).to(self.device)
skip_head = torch.LongTensor([skip_head])
return_length = torch.LongTensor([return_length])
with torch.no_grad():
infered_audio = (
self.net_g.infer(
@@ -253,6 +251,7 @@ class RVC:
pitchf=cache_pitchf,
skip_head=skip_head,
return_length=return_length,
return_length2=return_length2,
)
.squeeze(1)
.float()

View File

@@ -123,13 +123,13 @@ class TextEncoder(nn.Module):
phone: torch.Tensor,
pitch: torch.Tensor,
lengths: torch.Tensor,
# skip_head: Optional[torch.Tensor] = None,
skip_head: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return super().__call__(
phone,
pitch,
lengths,
# skip_head=skip_head,
skip_head=skip_head,
)
def forward(
@@ -137,7 +137,7 @@ class TextEncoder(nn.Module):
phone: torch.Tensor,
pitch: torch.Tensor,
lengths: torch.Tensor,
# skip_head: Optional[torch.Tensor] = None,
skip_head: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
x = self.emb_phone(phone)
if pitch is not None:
@@ -150,13 +150,10 @@ class TextEncoder(nn.Module):
1,
).to(x.dtype)
x = self.encoder(x * x_mask, x_mask)
"""
if skip_head is not None:
assert isinstance(skip_head, torch.Tensor)
head = int(skip_head.item())
head = int(skip_head)
x = x[:, :, head:]
x_mask = x_mask[:, :, head:]
"""
stats: torch.Tensor = self.proj(x) * x_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
return m, logs, x_mask

View File

@@ -61,23 +61,21 @@ class Generator(torch.nn.Module):
self,
x: torch.Tensor,
g: Optional[torch.Tensor] = None,
# n_res: Optional[torch.Tensor] = None,
n_res: Optional[int] = None,
) -> torch.Tensor:
return super().__call__(x, g=g)
return super().__call__(x, g=g, n_res=n_res)
def forward(
self,
x: torch.Tensor,
g: Optional[torch.Tensor] = None,
# n_res: Optional[torch.Tensor] = None,
n_res: Optional[int] = None,
):
"""
if n_res is not None:
assert isinstance(n_res, torch.Tensor)
n = int(n_res.item())
n = int(n_res)
if n != x.shape[-1]:
x = F.interpolate(x, size=n, mode="linear")
"""
x = self.conv_pre(x)
if g is not None:
x = x + self.cond(g)

View File

@@ -136,28 +136,27 @@ class NSFGenerator(torch.nn.Module):
x: torch.Tensor,
f0: torch.Tensor,
g: Optional[torch.Tensor] = None,
# n_res: Optional[torch.Tensor] = None,
n_res: Optional[int] = None,
) -> torch.Tensor:
return super().__call__(x, f0, g=g)
return super().__call__(x, f0, g=g, n_res=n_res)
def forward(
self,
x: torch.Tensor,
f0: torch.Tensor,
g: Optional[torch.Tensor] = None,
# n_res: Optional[torch.Tensor] = None,
n_res: Optional[int] = None,
) -> torch.Tensor:
har_source = self.m_source(f0, self.upp)
har_source = har_source.transpose(1, 2)
"""
if n_res is not None:
assert isinstance(n_res, torch.Tensor)
n = int(n_res.item())
if n * self.upp != har_source.shape[-1]:
har_source = F.interpolate(har_source, size=n * self.upp, mode="linear")
if n != x.shape[-1]:
x = F.interpolate(x, size=n, mode="linear")
"""
n_res = int(n_res)
if n_res * self.upp != har_source.shape[-1]:
har_source = F.interpolate(har_source, size=n_res * self.upp, mode="linear")
if n_res != x.shape[-1]:
x = F.interpolate(x, size=n_res, mode="linear")
x = self.conv_pre(x)
if g is not None:
x = x + self.cond(g)

View File

@@ -177,17 +177,18 @@ class SynthesizerTrnMsNSFsid(nn.Module):
sid: torch.Tensor,
pitch: Optional[torch.Tensor] = None,
pitchf: Optional[torch.Tensor] = None, # nsff0
skip_head: Optional[torch.Tensor] = None,
return_length: Optional[torch.Tensor] = None,
# return_length2: Optional[torch.Tensor] = None,
skip_head: Optional[int] = None,
return_length: Optional[int] = None,
return_length2: Optional[int] = None,
):
g = self.emb_g(sid).unsqueeze(-1)
if skip_head is not None and return_length is not None:
head = int(skip_head.item())
length = int(return_length.item())
flow_head = torch.clamp(skip_head - 24, min=0)
dec_head = head - int(flow_head.item())
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
head = int(skip_head)
length = int(return_length)
flow_head = head - 24
if flow_head < 0: flow_head = 0
dec_head = head - flow_head
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths, head)
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
z = self.flow(z_p, x_mask, g=g, reverse=True)
z = z[:, :, dec_head : dec_head + length]
@@ -204,13 +205,13 @@ class SynthesizerTrnMsNSFsid(nn.Module):
z * x_mask,
pitchf,
g=g,
# n_res=return_length2,
n_res=return_length2,
)
else:
o = self.dec(
z * x_mask,
g=g,
# n_res=return_length2
n_res=return_length2
)
del x_mask, z
return o # , x_mask, (z, z_p, m_p, logs_p)