from typing import Optional, List, Union

import torch
from torch import nn
from torch.nn.utils import parametrize

from .encoders import TextEncoder, PosteriorEncoder
from .generators import Generator
from .nsf import NSFGenerator
from .residuals import ResidualCouplingBlock
from .utils import (
    slice_on_last_dim,
    rand_slice_segments_on_last_dim,
)


class SynthesizerTrnMsNSFsid(nn.Module):
    """Speaker-conditioned synthesizer built from encoders, a flow and a decoder.

    Sub-modules (registered in this order):

    * ``enc_p`` -- ``TextEncoder`` producing ``(m_p, logs_p, x_mask)``.
    * ``dec``   -- ``NSFGenerator`` (pitch-driven) when ``use_f0`` is true,
      otherwise a plain ``Generator``.
    * ``enc_q`` -- ``PosteriorEncoder`` over ``spec_channels`` input channels.
    * ``flow``  -- ``ResidualCouplingBlock``.
    * ``emb_g`` -- speaker embedding table (``spk_embed_dim`` x ``gin_channels``)
      whose output conditions ``enc_q``, ``flow`` and ``dec``.

    ``sr`` may be an int (Hz) or one of the strings ``"32k"``, ``"40k"``,
    ``"48k"``; any other string raises ``KeyError``.
    """

    def __init__(
        self,
        spec_channels: int,
        segment_size: int,
        inter_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: float,
        resblock: str,
        resblock_kernel_sizes: List[int],
        resblock_dilation_sizes: List[List[int]],
        upsample_rates: List[int],
        upsample_initial_channel: int,
        upsample_kernel_sizes: List[int],
        spk_embed_dim: int,
        gin_channels: int,
        sr: Union[str, int],
        encoder_dim: int,
        use_f0: bool,
    ):
        super().__init__()
        if isinstance(sr, str):
            # Map shorthand sample-rate names to Hz; unknown names raise KeyError.
            sr = {
                "32k": 32000,
                "40k": 40000,
                "48k": 48000,
            }[sr]
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = float(p_dropout)
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.gin_channels = gin_channels
        self.spk_embed_dim = spk_embed_dim
        self.enc_p = TextEncoder(
            encoder_dim,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            float(p_dropout),
            f0=use_f0,
        )
        if use_f0:
            self.dec = NSFGenerator(
                inter_channels,
                resblock,
                resblock_kernel_sizes,
                resblock_dilation_sizes,
                upsample_rates,
                upsample_initial_channel,
                upsample_kernel_sizes,
                gin_channels=gin_channels,
                sr=sr,
            )
        else:
            self.dec = Generator(
                inter_channels,
                resblock,
                resblock_kernel_sizes,
                resblock_dilation_sizes,
                upsample_rates,
                upsample_initial_channel,
                upsample_kernel_sizes,
                gin_channels=gin_channels,
            )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
        )
        self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)

    def remove_weight_norm(self):
        """Delegate weight-norm removal to ``dec``, ``flow`` and ``enc_q``.

        ``enc_q`` is only touched if the attribute still exists (callers may
        delete it, e.g. for inference-only models).
        """
        self.dec.remove_weight_norm()
        self.flow.remove_weight_norm()
        if hasattr(self, "enc_q"):
            self.enc_q.remove_weight_norm()

    def __prepare_scriptable__(self):
        """Strip ``weight`` parametrizations before ``torch.jit.script``.

        The parametrized weights live on submodules (the convolution layers
        inside ``dec``, ``flow`` and ``enc_q``), not on those containers
        themselves, so every submodule is inspected. The parametrized value is
        baked into a plain parameter (``remove_parametrizations`` default
        ``leave_parametrized=True``).

        Returns:
            ``self``, as required by the ``__prepare_scriptable__`` protocol.
        """
        roots = [self.dec, self.flow]
        if hasattr(self, "enc_q"):
            roots.append(self.enc_q)
        for root in roots:
            # Snapshot the module list first: remove_parametrizations mutates
            # the module tree (drops the ``parametrizations`` child and swaps
            # the module's class), which must not happen mid-iteration.
            for module in list(root.modules()):
                if parametrize.is_parametrized(module, "weight"):
                    parametrize.remove_parametrizations(module, "weight")
        return self

    @torch.jit.ignore()
    def forward(
        self,
        phone: torch.Tensor,
        phone_lengths: torch.Tensor,
        y: torch.Tensor,
        y_lengths: torch.Tensor,
        ds: Optional[torch.Tensor] = None,
        pitch: Optional[torch.Tensor] = None,
        pitchf: Optional[torch.Tensor] = None,
    ):
        """Training pass: encode, run the flow, and decode a random latent slice.

        Args:
            phone, phone_lengths: inputs to the prior encoder ``enc_p``.
            y, y_lengths: inputs to the posterior encoder ``enc_q``.
            ds: speaker ids for ``emb_g`` (required despite the ``None`` default).
            pitch: passed to ``enc_p``.
            pitchf: pitch curve; only used when ``dec`` is an ``NSFGenerator``.

        Returns:
            ``(o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q))``
            where ``o`` is the decoder output for a ``segment_size`` slice of
            ``z`` starting at ``ids_slice``.

        Raises:
            KeyError: if the decoder cannot be dispatched (e.g. an
                ``NSFGenerator`` decoder with ``pitchf`` missing).
        """
        # ds holds speaker ids (an earlier note said [bs, 1]).
        # NOTE(review): emb_g(ds).unsqueeze(-1) is [b, gin_channels, 1] only
        # when ds is 1-D [b]; the trailing axis is time and broadcasts.
        # pitch is presumably [bs, t] per an earlier debug note -- confirm.
        embg = self.emb_g(ds).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=embg)
        z_p = self.flow(z, y_mask, g=embg)
        z_slice, ids_slice = rand_slice_segments_on_last_dim(
            z, y_lengths, self.segment_size
        )
        if pitchf is not None and isinstance(self.dec, NSFGenerator):
            # Slice the pitch curve with the same offsets as the latent.
            pitchf = slice_on_last_dim(pitchf, ids_slice, self.segment_size)
            o = self.dec(z_slice, pitchf, g=embg)  # type: ignore
        elif isinstance(self.dec, Generator):
            o = self.dec(z_slice, g=embg)
        else:
            raise KeyError(f"unknown dec type: {type(self.dec).__name__}")
        return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

    @torch.jit.export
    def infer(
        self,
        phone: torch.Tensor,
        phone_lengths: torch.Tensor,
        sid: torch.Tensor,
        pitch: Optional[torch.Tensor] = None,
        pitchf: Optional[torch.Tensor] = None,  # pitch curve for the NSF decoder ("nsff0")
        skip_head: Optional[int] = None,
        return_length: Optional[int] = None,
        return_length2: Optional[int] = None,
    ):
        """Inference: sample from the prior, invert the flow, and decode.

        The prior sample is ``(m_p + exp(logs_p) * N(0, 1) * 0.66666) * x_mask``.

        When both ``skip_head`` and ``return_length`` are given, only a window
        is decoded:

        * ``enc_p`` receives ``flow_head = max(skip_head - 24, 0)``, so at most
          24 positions before the window are kept as context for the flow.
        * The flow output and mask are then cut to
          ``[dec_head, dec_head + return_length)``.
        * ``pitchf`` is cut to ``[skip_head, skip_head + return_length)``.

        NOTE(review): how ``enc_p`` uses its 4th argument is defined elsewhere
        -- confirm it skips the first ``flow_head`` positions.

        Args:
            sid: speaker ids for ``emb_g``.
            return_length2: forwarded to the decoder as ``n_res``.

        Returns:
            The decoder output.

        Raises:
            KeyError: if the decoder cannot be dispatched (e.g. an
                ``NSFGenerator`` decoder with ``pitchf`` missing).
        """
        g = self.emb_g(sid).unsqueeze(-1)
        if skip_head is not None and return_length is not None:
            head = int(skip_head)
            length = int(return_length)
            flow_head = head - 24
            if flow_head < 0:
                flow_head = 0
            # Offset of the requested window inside the flow output.
            dec_head = head - flow_head
            m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths, flow_head)
            z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
            z = self.flow(z_p, x_mask, g=g, reverse=True)
            z = z[:, :, dec_head : dec_head + length]
            x_mask = x_mask[:, :, dec_head : dec_head + length]
            if pitchf is not None:
                pitchf = pitchf[:, head : head + length]
        else:
            m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
            z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
            z = self.flow(z_p, x_mask, g=g, reverse=True)
        # Release intermediates before the (memory-heavy) decoder call.
        del z_p, m_p, logs_p
        if pitchf is not None and isinstance(self.dec, NSFGenerator):
            o = self.dec(
                z * x_mask,
                pitchf,
                g=g,
                n_res=return_length2,
            )
        elif isinstance(self.dec, Generator):
            o = self.dec(z * x_mask, g=g, n_res=return_length2)
        else:
            raise KeyError(f"unknown dec type: {type(self.dec).__name__}")
        del x_mask, z
        return o  # , x_mask, (z, z_p, m_p, logs_p)


class SynthesizerTrnMs256NSFsid(SynthesizerTrnMsNSFsid):
    """``SynthesizerTrnMsNSFsid`` with ``encoder_dim=256`` and the NSF (f0) decoder."""

    def __init__(
        self,
        spec_channels: int,
        segment_size: int,
        inter_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: float,
        resblock: str,
        resblock_kernel_sizes: List[int],
        resblock_dilation_sizes: List[List[int]],
        upsample_rates: List[int],
        upsample_initial_channel: int,
        upsample_kernel_sizes: List[int],
        spk_embed_dim: int,
        gin_channels: int,
        sr: Union[str, int],
    ):
        super().__init__(
            spec_channels,
            segment_size,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            spk_embed_dim,
            gin_channels,
            sr,
            256,
            True,
        )


class SynthesizerTrnMs768NSFsid(SynthesizerTrnMsNSFsid):
    """``SynthesizerTrnMsNSFsid`` with ``encoder_dim=768`` and the NSF (f0) decoder."""

    def __init__(
        self,
        spec_channels: int,
        segment_size: int,
        inter_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: float,
        resblock: str,
        resblock_kernel_sizes: List[int],
        resblock_dilation_sizes: List[List[int]],
        upsample_rates: List[int],
        upsample_initial_channel: int,
        upsample_kernel_sizes: List[int],
        spk_embed_dim: int,
        gin_channels: int,
        sr: Union[str, int],
    ):
        super().__init__(
            spec_channels,
            segment_size,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            spk_embed_dim,
            gin_channels,
            sr,
            768,
            True,
        )


class SynthesizerTrnMs256NSFsid_nono(SynthesizerTrnMsNSFsid):
    """``SynthesizerTrnMsNSFsid`` with ``encoder_dim=256`` and the plain (no f0) decoder."""

    def __init__(
        self,
        spec_channels: int,
        segment_size: int,
        inter_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: float,
        resblock: str,
        resblock_kernel_sizes: List[int],
        resblock_dilation_sizes: List[List[int]],
        upsample_rates: List[int],
        upsample_initial_channel: int,
        upsample_kernel_sizes: List[int],
        spk_embed_dim: int,
        gin_channels: int,
        sr: Union[str, int],
    ):
        super().__init__(
            spec_channels,
            segment_size,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            spk_embed_dim,
            gin_channels,
            sr,
            256,
            False,
        )


class SynthesizerTrnMs768NSFsid_nono(SynthesizerTrnMsNSFsid):
    """``SynthesizerTrnMsNSFsid`` with ``encoder_dim=768`` and the plain (no f0) decoder."""

    def __init__(
        self,
        spec_channels: int,
        segment_size: int,
        inter_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: float,
        resblock: str,
        resblock_kernel_sizes: List[int],
        resblock_dilation_sizes: List[List[int]],
        upsample_rates: List[int],
        upsample_initial_channel: int,
        upsample_kernel_sizes: List[int],
        spk_embed_dim: int,
        gin_channels: int,
        sr: Union[str, int],
    ):
        super().__init__(
            spec_channels,
            segment_size,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            spk_embed_dim,
            gin_channels,
            sr,
            768,
            False,
        )