1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-05 01:10:22 +08:00

fix(infer): argument mismatch

This commit is contained in:
源文雨
2024-06-10 21:50:43 +09:00
parent fe7a2bf41a
commit 87433c5bd9

View File

@@ -1,4 +1,4 @@
-from typing import Optional, List
+from typing import Optional, List, Union
import torch
from torch import nn
@@ -33,11 +33,11 @@ class SynthesizerTrnMsNSFsid(nn.Module):
upsample_kernel_sizes: List[int],
spk_embed_dim: int,
gin_channels: int,
-sr: Optional[str | int],
+sr: Optional[Union[str, int]],
encoder_dim: int,
use_f0: bool,
):
-super(SynthesizerTrnMs256NSFsid, self).__init__()
+super().__init__()
if isinstance(sr, str):
sr = {
"32k": 32000,
@@ -175,7 +175,7 @@ class SynthesizerTrnMsNSFsid(nn.Module):
phone_lengths: torch.Tensor,
sid: torch.Tensor,
pitch: Optional[torch.Tensor] = None,
-nsff0: Optional[torch.Tensor] = None,
+pitchf: Optional[torch.Tensor] = None, # nsff0
skip_head: Optional[torch.Tensor] = None,
return_length: Optional[torch.Tensor] = None,
# return_length2: Optional[torch.Tensor] = None,
@@ -191,17 +191,17 @@ class SynthesizerTrnMsNSFsid(nn.Module):
z = self.flow(z_p, x_mask, g=g, reverse=True)
z = z[:, :, dec_head : dec_head + length]
x_mask = x_mask[:, :, dec_head : dec_head + length]
-if nsff0 is not None:
-nsff0 = nsff0[:, head : head + length]
+if pitchf is not None:
+pitchf = pitchf[:, head : head + length]
else:
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
z = self.flow(z_p, x_mask, g=g, reverse=True)
del z_p, m_p, logs_p
-if nsff0 is not None:
+if pitchf is not None:
o = self.dec(
z * x_mask,
-nsff0,
+pitchf,
g=g,
# n_res=return_length2,
)
@@ -235,7 +235,7 @@ class SynthesizerTrnMs256NSFsid(SynthesizerTrnMsNSFsid):
upsample_kernel_sizes: List[int],
spk_embed_dim: int,
gin_channels: int,
-sr: str | int,
+sr: Union[str, int],
):
super().__init__(
spec_channels,
@@ -281,7 +281,7 @@ class SynthesizerTrnMs768NSFsid(SynthesizerTrnMsNSFsid):
upsample_kernel_sizes: List[int],
spk_embed_dim: int,
gin_channels: int,
-sr: str | int,
+sr: Union[str, int],
):
super().__init__(
spec_channels,