1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-08 12:00:49 +08:00

fix(rtrvc): skip head unimplemented

This commit is contained in:
源文雨
2024-06-16 16:46:59 +09:00
parent df83554ac1
commit 0d5cd347bc
5 changed files with 32 additions and 38 deletions

View File

@@ -123,13 +123,13 @@ class TextEncoder(nn.Module):
phone: torch.Tensor,
pitch: torch.Tensor,
lengths: torch.Tensor,
# skip_head: Optional[torch.Tensor] = None,
skip_head: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return super().__call__(
phone,
pitch,
lengths,
# skip_head=skip_head,
skip_head=skip_head,
)
def forward(
@@ -137,7 +137,7 @@ class TextEncoder(nn.Module):
phone: torch.Tensor,
pitch: torch.Tensor,
lengths: torch.Tensor,
# skip_head: Optional[torch.Tensor] = None,
skip_head: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
x = self.emb_phone(phone)
if pitch is not None:
@@ -150,13 +150,10 @@ class TextEncoder(nn.Module):
1,
).to(x.dtype)
x = self.encoder(x * x_mask, x_mask)
"""
if skip_head is not None:
assert isinstance(skip_head, torch.Tensor)
head = int(skip_head.item())
head = int(skip_head)
x = x[:, :, head:]
x_mask = x_mask[:, :, head:]
"""
stats: torch.Tensor = self.proj(x) * x_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
return m, logs, x_mask