mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-07 19:00:23 +08:00
fix(infer): argument mismatch
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
from typing import Optional, List
|
from typing import Optional, List, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -33,11 +33,11 @@ class SynthesizerTrnMsNSFsid(nn.Module):
|
|||||||
upsample_kernel_sizes: List[int],
|
upsample_kernel_sizes: List[int],
|
||||||
spk_embed_dim: int,
|
spk_embed_dim: int,
|
||||||
gin_channels: int,
|
gin_channels: int,
|
||||||
sr: Optional[str | int],
|
sr: Optional[Union[str, int]],
|
||||||
encoder_dim: int,
|
encoder_dim: int,
|
||||||
use_f0: bool,
|
use_f0: bool,
|
||||||
):
|
):
|
||||||
super(SynthesizerTrnMs256NSFsid, self).__init__()
|
super().__init__()
|
||||||
if isinstance(sr, str):
|
if isinstance(sr, str):
|
||||||
sr = {
|
sr = {
|
||||||
"32k": 32000,
|
"32k": 32000,
|
||||||
@@ -175,7 +175,7 @@ class SynthesizerTrnMsNSFsid(nn.Module):
|
|||||||
phone_lengths: torch.Tensor,
|
phone_lengths: torch.Tensor,
|
||||||
sid: torch.Tensor,
|
sid: torch.Tensor,
|
||||||
pitch: Optional[torch.Tensor] = None,
|
pitch: Optional[torch.Tensor] = None,
|
||||||
nsff0: Optional[torch.Tensor] = None,
|
pitchf: Optional[torch.Tensor] = None, # nsff0
|
||||||
skip_head: Optional[torch.Tensor] = None,
|
skip_head: Optional[torch.Tensor] = None,
|
||||||
return_length: Optional[torch.Tensor] = None,
|
return_length: Optional[torch.Tensor] = None,
|
||||||
# return_length2: Optional[torch.Tensor] = None,
|
# return_length2: Optional[torch.Tensor] = None,
|
||||||
@@ -191,17 +191,17 @@ class SynthesizerTrnMsNSFsid(nn.Module):
|
|||||||
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
||||||
z = z[:, :, dec_head : dec_head + length]
|
z = z[:, :, dec_head : dec_head + length]
|
||||||
x_mask = x_mask[:, :, dec_head : dec_head + length]
|
x_mask = x_mask[:, :, dec_head : dec_head + length]
|
||||||
if nsff0 is not None:
|
if pitchf is not None:
|
||||||
nsff0 = nsff0[:, head : head + length]
|
pitchf = pitchf[:, head : head + length]
|
||||||
else:
|
else:
|
||||||
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
|
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
|
||||||
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
|
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
|
||||||
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
||||||
del z_p, m_p, logs_p
|
del z_p, m_p, logs_p
|
||||||
if nsff0 is not None:
|
if pitchf is not None:
|
||||||
o = self.dec(
|
o = self.dec(
|
||||||
z * x_mask,
|
z * x_mask,
|
||||||
nsff0,
|
pitchf,
|
||||||
g=g,
|
g=g,
|
||||||
# n_res=return_length2,
|
# n_res=return_length2,
|
||||||
)
|
)
|
||||||
@@ -235,7 +235,7 @@ class SynthesizerTrnMs256NSFsid(SynthesizerTrnMsNSFsid):
|
|||||||
upsample_kernel_sizes: List[int],
|
upsample_kernel_sizes: List[int],
|
||||||
spk_embed_dim: int,
|
spk_embed_dim: int,
|
||||||
gin_channels: int,
|
gin_channels: int,
|
||||||
sr: str | int,
|
sr: Union[str, int],
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
spec_channels,
|
spec_channels,
|
||||||
@@ -281,7 +281,7 @@ class SynthesizerTrnMs768NSFsid(SynthesizerTrnMsNSFsid):
|
|||||||
upsample_kernel_sizes: List[int],
|
upsample_kernel_sizes: List[int],
|
||||||
spk_embed_dim: int,
|
spk_embed_dim: int,
|
||||||
gin_channels: int,
|
gin_channels: int,
|
||||||
sr: str | int,
|
sr: Union[str, int],
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
spec_channels,
|
spec_channels,
|
||||||
|
|||||||
Reference in New Issue
Block a user