mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-08 12:00:49 +08:00
fix(rtrvc): skip head unimplemented
This commit is contained in:
@@ -123,13 +123,13 @@ class TextEncoder(nn.Module):
|
||||
phone: torch.Tensor,
|
||||
pitch: torch.Tensor,
|
||||
lengths: torch.Tensor,
|
||||
# skip_head: Optional[torch.Tensor] = None,
|
||||
skip_head: Optional[int] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
return super().__call__(
|
||||
phone,
|
||||
pitch,
|
||||
lengths,
|
||||
# skip_head=skip_head,
|
||||
skip_head=skip_head,
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -137,7 +137,7 @@ class TextEncoder(nn.Module):
|
||||
phone: torch.Tensor,
|
||||
pitch: torch.Tensor,
|
||||
lengths: torch.Tensor,
|
||||
# skip_head: Optional[torch.Tensor] = None,
|
||||
skip_head: Optional[int] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
x = self.emb_phone(phone)
|
||||
if pitch is not None:
|
||||
@@ -150,13 +150,10 @@ class TextEncoder(nn.Module):
|
||||
1,
|
||||
).to(x.dtype)
|
||||
x = self.encoder(x * x_mask, x_mask)
|
||||
"""
|
||||
if skip_head is not None:
|
||||
assert isinstance(skip_head, torch.Tensor)
|
||||
head = int(skip_head.item())
|
||||
head = int(skip_head)
|
||||
x = x[:, :, head:]
|
||||
x_mask = x_mask[:, :, head:]
|
||||
"""
|
||||
stats: torch.Tensor = self.proj(x) * x_mask
|
||||
m, logs = torch.split(stats, self.out_channels, dim=1)
|
||||
return m, logs, x_mask
|
||||
|
||||
Reference in New Issue
Block a user