From 87433c5bd976849b6253d9919aad574e942e2c5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Mon, 10 Jun 2024 21:50:43 +0900 Subject: [PATCH] fix(infer): argument mismatch --- infer/lib/infer_pack/models.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/infer/lib/infer_pack/models.py b/infer/lib/infer_pack/models.py index 76b4a64..3e8ce1a 100644 --- a/infer/lib/infer_pack/models.py +++ b/infer/lib/infer_pack/models.py @@ -1,4 +1,4 @@ -from typing import Optional, List +from typing import Optional, List, Union import torch from torch import nn @@ -33,11 +33,11 @@ class SynthesizerTrnMsNSFsid(nn.Module): upsample_kernel_sizes: List[int], spk_embed_dim: int, gin_channels: int, - sr: Optional[str | int], + sr: Optional[Union[str, int]], encoder_dim: int, use_f0: bool, ): - super(SynthesizerTrnMs256NSFsid, self).__init__() + super().__init__() if isinstance(sr, str): sr = { "32k": 32000, @@ -175,7 +175,7 @@ class SynthesizerTrnMsNSFsid(nn.Module): phone_lengths: torch.Tensor, sid: torch.Tensor, pitch: Optional[torch.Tensor] = None, - nsff0: Optional[torch.Tensor] = None, + pitchf: Optional[torch.Tensor] = None, # nsff0 skip_head: Optional[torch.Tensor] = None, return_length: Optional[torch.Tensor] = None, # return_length2: Optional[torch.Tensor] = None, @@ -191,17 +191,17 @@ class SynthesizerTrnMsNSFsid(nn.Module): z = self.flow(z_p, x_mask, g=g, reverse=True) z = z[:, :, dec_head : dec_head + length] x_mask = x_mask[:, :, dec_head : dec_head + length] - if nsff0 is not None: - nsff0 = nsff0[:, head : head + length] + if pitchf is not None: + pitchf = pitchf[:, head : head + length] else: m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths) z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask z = self.flow(z_p, x_mask, g=g, reverse=True) del z_p, m_p, logs_p - if nsff0 is not None: + if pitchf is not None: o = self.dec( z * x_mask, - nsff0, + pitchf, g=g, # n_res=return_length2, ) @@ -235,7 +235,7 @@ class SynthesizerTrnMs256NSFsid(SynthesizerTrnMsNSFsid): upsample_kernel_sizes: List[int], spk_embed_dim: int, gin_channels: int, - sr: str | int, + sr: Union[str, int], ): super().__init__( spec_channels, @@ -281,7 +281,7 @@ class SynthesizerTrnMs768NSFsid(SynthesizerTrnMsNSFsid): upsample_kernel_sizes: List[int], spk_embed_dim: int, gin_channels: int, - sr: str | int, + sr: Union[str, int], ): super().__init__( spec_channels,