mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-05 09:10:25 +08:00
fix(rtrvc): skip head unimplemented
This commit is contained in:
@@ -138,7 +138,7 @@ class RVC:
|
||||
self,
|
||||
input_wav: torch.Tensor,
|
||||
block_frame_16k: int,
|
||||
skip_head: torch.Tensor,
|
||||
skip_head: int,
|
||||
return_length: int,
|
||||
f0method: Union[tuple, str],
|
||||
inp_f0: Optional[np.ndarray] = None,
|
||||
@@ -241,8 +241,6 @@ class RVC:
|
||||
feats = feats.to(feats0.dtype)
|
||||
p_len = torch.LongTensor([p_len]).to(self.device)
|
||||
sid = torch.LongTensor([0]).to(self.device)
|
||||
skip_head = torch.LongTensor([skip_head])
|
||||
return_length = torch.LongTensor([return_length])
|
||||
with torch.no_grad():
|
||||
infered_audio = (
|
||||
self.net_g.infer(
|
||||
@@ -253,6 +251,7 @@ class RVC:
|
||||
pitchf=cache_pitchf,
|
||||
skip_head=skip_head,
|
||||
return_length=return_length,
|
||||
return_length2=return_length2,
|
||||
)
|
||||
.squeeze(1)
|
||||
.float()
|
||||
|
||||
@@ -123,13 +123,13 @@ class TextEncoder(nn.Module):
|
||||
phone: torch.Tensor,
|
||||
pitch: torch.Tensor,
|
||||
lengths: torch.Tensor,
|
||||
# skip_head: Optional[torch.Tensor] = None,
|
||||
skip_head: Optional[int] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
return super().__call__(
|
||||
phone,
|
||||
pitch,
|
||||
lengths,
|
||||
# skip_head=skip_head,
|
||||
skip_head=skip_head,
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -137,7 +137,7 @@ class TextEncoder(nn.Module):
|
||||
phone: torch.Tensor,
|
||||
pitch: torch.Tensor,
|
||||
lengths: torch.Tensor,
|
||||
# skip_head: Optional[torch.Tensor] = None,
|
||||
skip_head: Optional[int] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
x = self.emb_phone(phone)
|
||||
if pitch is not None:
|
||||
@@ -150,13 +150,10 @@ class TextEncoder(nn.Module):
|
||||
1,
|
||||
).to(x.dtype)
|
||||
x = self.encoder(x * x_mask, x_mask)
|
||||
"""
|
||||
if skip_head is not None:
|
||||
assert isinstance(skip_head, torch.Tensor)
|
||||
head = int(skip_head.item())
|
||||
head = int(skip_head)
|
||||
x = x[:, :, head:]
|
||||
x_mask = x_mask[:, :, head:]
|
||||
"""
|
||||
stats: torch.Tensor = self.proj(x) * x_mask
|
||||
m, logs = torch.split(stats, self.out_channels, dim=1)
|
||||
return m, logs, x_mask
|
||||
|
||||
@@ -61,23 +61,21 @@ class Generator(torch.nn.Module):
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None,
|
||||
# n_res: Optional[torch.Tensor] = None,
|
||||
n_res: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
return super().__call__(x, g=g)
|
||||
return super().__call__(x, g=g, n_res=n_res)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None,
|
||||
# n_res: Optional[torch.Tensor] = None,
|
||||
n_res: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
if n_res is not None:
|
||||
assert isinstance(n_res, torch.Tensor)
|
||||
n = int(n_res.item())
|
||||
n = int(n_res)
|
||||
if n != x.shape[-1]:
|
||||
x = F.interpolate(x, size=n, mode="linear")
|
||||
"""
|
||||
|
||||
x = self.conv_pre(x)
|
||||
if g is not None:
|
||||
x = x + self.cond(g)
|
||||
|
||||
@@ -136,28 +136,27 @@ class NSFGenerator(torch.nn.Module):
|
||||
x: torch.Tensor,
|
||||
f0: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None,
|
||||
# n_res: Optional[torch.Tensor] = None,
|
||||
n_res: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
return super().__call__(x, f0, g=g)
|
||||
return super().__call__(x, f0, g=g, n_res=n_res)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
f0: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None,
|
||||
# n_res: Optional[torch.Tensor] = None,
|
||||
n_res: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
har_source = self.m_source(f0, self.upp)
|
||||
har_source = har_source.transpose(1, 2)
|
||||
"""
|
||||
|
||||
if n_res is not None:
|
||||
assert isinstance(n_res, torch.Tensor)
|
||||
n = int(n_res.item())
|
||||
if n * self.upp != har_source.shape[-1]:
|
||||
har_source = F.interpolate(har_source, size=n * self.upp, mode="linear")
|
||||
if n != x.shape[-1]:
|
||||
x = F.interpolate(x, size=n, mode="linear")
|
||||
"""
|
||||
n_res = int(n_res)
|
||||
if n_res * self.upp != har_source.shape[-1]:
|
||||
har_source = F.interpolate(har_source, size=n_res * self.upp, mode="linear")
|
||||
if n_res != x.shape[-1]:
|
||||
x = F.interpolate(x, size=n_res, mode="linear")
|
||||
|
||||
x = self.conv_pre(x)
|
||||
if g is not None:
|
||||
x = x + self.cond(g)
|
||||
|
||||
@@ -177,17 +177,18 @@ class SynthesizerTrnMsNSFsid(nn.Module):
|
||||
sid: torch.Tensor,
|
||||
pitch: Optional[torch.Tensor] = None,
|
||||
pitchf: Optional[torch.Tensor] = None, # nsff0
|
||||
skip_head: Optional[torch.Tensor] = None,
|
||||
return_length: Optional[torch.Tensor] = None,
|
||||
# return_length2: Optional[torch.Tensor] = None,
|
||||
skip_head: Optional[int] = None,
|
||||
return_length: Optional[int] = None,
|
||||
return_length2: Optional[int] = None,
|
||||
):
|
||||
g = self.emb_g(sid).unsqueeze(-1)
|
||||
if skip_head is not None and return_length is not None:
|
||||
head = int(skip_head.item())
|
||||
length = int(return_length.item())
|
||||
flow_head = torch.clamp(skip_head - 24, min=0)
|
||||
dec_head = head - int(flow_head.item())
|
||||
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
|
||||
head = int(skip_head)
|
||||
length = int(return_length)
|
||||
flow_head = head - 24
|
||||
if flow_head < 0: flow_head = 0
|
||||
dec_head = head - flow_head
|
||||
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths, head)
|
||||
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
|
||||
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
||||
z = z[:, :, dec_head : dec_head + length]
|
||||
@@ -204,13 +205,13 @@ class SynthesizerTrnMsNSFsid(nn.Module):
|
||||
z * x_mask,
|
||||
pitchf,
|
||||
g=g,
|
||||
# n_res=return_length2,
|
||||
n_res=return_length2,
|
||||
)
|
||||
else:
|
||||
o = self.dec(
|
||||
z * x_mask,
|
||||
g=g,
|
||||
# n_res=return_length2
|
||||
n_res=return_length2
|
||||
)
|
||||
del x_mask, z
|
||||
return o # , x_mask, (z, z_p, m_p, logs_p)
|
||||
|
||||
Reference in New Issue
Block a user