diff --git a/infer/lib/infer_pack/models.py b/infer/lib/infer_pack/models.py index 76b4a64..3e8ce1a 100644 --- a/infer/lib/infer_pack/models.py +++ b/infer/lib/infer_pack/models.py @@ -1,4 +1,4 @@ -from typing import Optional, List +from typing import Optional, List, Union import torch from torch import nn @@ -33,11 +33,11 @@ class SynthesizerTrnMsNSFsid(nn.Module): upsample_kernel_sizes: List[int], spk_embed_dim: int, gin_channels: int, - sr: Optional[str | int], + sr: Optional[Union[str, int]], encoder_dim: int, use_f0: bool, ): - super(SynthesizerTrnMs256NSFsid, self).__init__() + super().__init__() if isinstance(sr, str): sr = { "32k": 32000, @@ -175,7 +175,7 @@ class SynthesizerTrnMsNSFsid(nn.Module): phone_lengths: torch.Tensor, sid: torch.Tensor, pitch: Optional[torch.Tensor] = None, - nsff0: Optional[torch.Tensor] = None, + pitchf: Optional[torch.Tensor] = None, # nsff0 skip_head: Optional[torch.Tensor] = None, return_length: Optional[torch.Tensor] = None, # return_length2: Optional[torch.Tensor] = None, @@ -191,17 +191,17 @@ class SynthesizerTrnMsNSFsid(nn.Module): z = self.flow(z_p, x_mask, g=g, reverse=True) z = z[:, :, dec_head : dec_head + length] x_mask = x_mask[:, :, dec_head : dec_head + length] - if nsff0 is not None: - nsff0 = nsff0[:, head : head + length] + if pitchf is not None: + pitchf = pitchf[:, head : head + length] else: m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths) z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask z = self.flow(z_p, x_mask, g=g, reverse=True) del z_p, m_p, logs_p - if nsff0 is not None: + if pitchf is not None: o = self.dec( z * x_mask, - nsff0, + pitchf, g=g, # n_res=return_length2, ) @@ -235,7 +235,7 @@ class SynthesizerTrnMs256NSFsid(SynthesizerTrnMsNSFsid): upsample_kernel_sizes: List[int], spk_embed_dim: int, gin_channels: int, - sr: str | int, + sr: Union[str, int], ): super().__init__( spec_channels, @@ -281,7 +281,7 @@ class SynthesizerTrnMs768NSFsid(SynthesizerTrnMsNSFsid): upsample_kernel_sizes: List[int], spk_embed_dim: int, gin_channels: int, - sr: str | int, + sr: Union[str, int], ): super().__init__( spec_channels,