from typing import Optional, List

import torch
from torch import nn

from rvc.residuals import ResidualCouplingBlock
from rvc.utils import (
    slice_on_last_dim,
    rand_slice_segments_on_last_dim,
)
from rvc.encoders import TextEncoder, PosteriorEncoder
from rvc.generators import Generator
from rvc.nsf import NSFGenerator

# Sample-rate aliases accepted for the ``sr`` constructor argument.
_SAMPLE_RATES = {
    "32k": 32000,
    "40k": 40000,
    "48k": 48000,
}


def _remove_weight_norm_hook(module: nn.Module) -> None:
    """Strip weight normalisation from ``module`` if it carries a WeightNorm hook.

    Only hooks registered directly on ``module`` are inspected.
    """
    for hook in module._forward_pre_hooks.values():
        # The hook we want to remove is an instance of WeightNorm class, so
        # normally we would do `if isinstance(...)` but this class is not accessible
        # because of shadowing, so we check the module name directly.
        # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
        if (
            hook.__module__ == "torch.nn.utils.weight_norm"
            and hook.__class__.__name__ == "WeightNorm"
        ):
            torch.nn.utils.remove_weight_norm(module)
            # remove_weight_norm deletes the hook from the dict being iterated;
            # continuing the loop would raise "dictionary changed size during
            # iteration", so stop after the removal.
            break


class SynthesizerTrnMsNSFsid(nn.Module):
    """Speaker-conditioned synthesizer shared by the 256/768-dim variants.

    Wires together a prior encoder (``enc_p``), a posterior encoder
    (``enc_q``), a residual coupling flow (``flow``), a speaker embedding
    (``emb_g``) and a decoder (``dec``). The decoder is an ``NSFGenerator``
    when ``use_f0`` is true and a plain ``Generator`` otherwise.
    """

    def __init__(
        self,
        spec_channels: int,
        segment_size: int,
        inter_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: float,
        resblock: str,
        resblock_kernel_sizes: List[int],
        resblock_dilation_sizes: List[List[int]],
        upsample_rates: List[int],
        upsample_initial_channel: int,
        upsample_kernel_sizes: List[int],
        spk_embed_dim: int,
        gin_channels: int,
        sr: Optional[str | int],
        encoder_dim: int,
        use_f0: bool,
    ):
        """Build all sub-modules.

        Args:
            spec_channels: First argument of the posterior encoder.
            segment_size: Length of the random slice decoded in ``forward``.
            inter_channels: Channel count shared by encoders, flow and decoder.
            hidden_channels: Hidden width of the encoders and the flow.
            filter_channels, n_heads, n_layers, kernel_size: Forwarded to
                ``TextEncoder``.
            p_dropout: Dropout value; converted with ``float()`` before use.
            resblock, resblock_kernel_sizes, resblock_dilation_sizes,
            upsample_rates, upsample_initial_channel, upsample_kernel_sizes:
                Forwarded to the decoder.
            spk_embed_dim: Number of rows in the speaker embedding table.
            gin_channels: Width of the speaker embedding / conditioning input.
            sr: Sample rate, either an int or one of "32k", "40k", "48k".
                Only handed to the decoder when ``use_f0`` is true.
            encoder_dim: First argument of ``TextEncoder``.
            use_f0: Select the pitch-aware ``NSFGenerator`` decoder.

        Raises:
            KeyError: If ``sr`` is a string other than the known aliases.
        """
        # Zero-argument super(): the previous explicit form named a subclass,
        # which made every instantiation fail.
        super().__init__()
        if isinstance(sr, str):
            sr = _SAMPLE_RATES[sr]

        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = float(p_dropout)
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.gin_channels = gin_channels
        self.spk_embed_dim = spk_embed_dim

        self.enc_p = TextEncoder(
            encoder_dim,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            float(p_dropout),
            f0=use_f0,
        )
        if use_f0:
            self.dec = NSFGenerator(
                inter_channels,
                resblock,
                resblock_kernel_sizes,
                resblock_dilation_sizes,
                upsample_rates,
                upsample_initial_channel,
                upsample_kernel_sizes,
                gin_channels=gin_channels,
                sr=sr,
            )
        else:
            self.dec = Generator(
                inter_channels,
                resblock,
                resblock_kernel_sizes,
                resblock_dilation_sizes,
                upsample_rates,
                upsample_initial_channel,
                upsample_kernel_sizes,
                gin_channels=gin_channels,
            )
        # The positional literals below are fixed hyperparameters whose meaning
        # is defined by PosteriorEncoder / ResidualCouplingBlock.
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
        )
        self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)

    def remove_weight_norm(self):
        """Ask the decoder, the flow and (if present) ``enc_q`` to drop weight norm."""
        self.dec.remove_weight_norm()
        self.flow.remove_weight_norm()
        # enc_q may have been deleted from the instance elsewhere, hence the guard.
        if hasattr(self, "enc_q"):
            self.enc_q.remove_weight_norm()

    def __prepare_scriptable__(self):
        """Remove WeightNorm hooks registered on ``dec``, ``flow`` and ``enc_q``.

        Returns:
            ``self``, with the hooks removed.
        """
        _remove_weight_norm_hook(self.dec)
        _remove_weight_norm_hook(self.flow)
        if hasattr(self, "enc_q"):
            _remove_weight_norm_hook(self.enc_q)
        return self

    @torch.jit.ignore
    def forward(
        self,
        phone: torch.Tensor,
        phone_lengths: torch.Tensor,
        y: torch.Tensor,
        y_lengths: torch.Tensor,
        ds: Optional[torch.Tensor] = None,
        pitch: Optional[torch.Tensor] = None,
        pitchf: Optional[torch.Tensor] = None,
    ):
        """Training pass: encode, run the flow and decode a random segment.

        Args:
            phone, phone_lengths: Inputs of the prior encoder.
            y, y_lengths: Inputs of the posterior encoder.
            ds: Speaker ids. Must be provided despite the ``None`` default,
                because it is passed straight to the embedding.
            pitch: Optional pitch input for the prior encoder.
            pitchf: Optional pitch input for the decoder. When given it is
                sliced with the same indices as the latent.

        Returns:
            ``(o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q))``.
        """
        # Original author's note: ds is the speaker id, [bs, 1].
        # Original author's note: g is [b, 256, 1]; the trailing 1 is the time
        # axis and gets broadcast.
        g = self.emb_g(ds).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
        z_p = self.flow(z, y_mask, g=g)
        z_slice, ids_slice = rand_slice_segments_on_last_dim(
            z, y_lengths, self.segment_size
        )
        if pitchf is not None:
            pitchf = slice_on_last_dim(pitchf, ids_slice, self.segment_size)
            o = self.dec(z_slice, pitchf, g=g)
        else:
            o = self.dec(z_slice, g=g)
        return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

    @torch.jit.export
    def infer(
        self,
        phone: torch.Tensor,
        phone_lengths: torch.Tensor,
        sid: torch.Tensor,
        pitch: Optional[torch.Tensor] = None,
        nsff0: Optional[torch.Tensor] = None,
        skip_head: Optional[torch.Tensor] = None,
        return_length: Optional[torch.Tensor] = None,
    ):
        """Inference pass: sample from the prior, invert the flow and decode.

        Args:
            phone, phone_lengths: Inputs of the prior encoder.
            sid: Speaker ids looked up in the speaker embedding.
            pitch: Optional pitch input for the prior encoder.
            nsff0: Optional pitch input for the decoder.
            skip_head: Optional one-element tensor. Number of leading frames
                to drop from the output.
            return_length: Optional one-element tensor. Number of frames to
                keep after ``skip_head``.

        Windowing is applied only when both ``skip_head`` and
        ``return_length`` are given.

        Returns:
            The decoder output.
        """
        g = self.emb_g(sid).unsqueeze(-1)
        if skip_head is not None and return_length is not None:
            head = int(skip_head.item())
            length = int(return_length.item())
            # Start the encoder up to 24 frames before the requested head
            # (clamped at 0); dec_head is the offset of the requested head
            # relative to that earlier start.
            flow_head = torch.clamp(skip_head - 24, min=0)
            dec_head = head - int(flow_head.item())
            m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths, flow_head)
            # 0.66666 is a fixed scale applied to the sampled noise.
            z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
            z = self.flow(z_p, x_mask, g=g, reverse=True)
            z = z[:, :, dec_head : dec_head + length]
            x_mask = x_mask[:, :, dec_head : dec_head + length]
            if nsff0 is not None:
                nsff0 = nsff0[:, head : head + length]
        else:
            m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
            z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
            z = self.flow(z_p, x_mask, g=g, reverse=True)
        # Drop references to intermediates before decoding.
        del z_p, m_p, logs_p
        if nsff0 is not None:
            o = self.dec(
                z * x_mask,
                nsff0,
                g=g,
            )
        else:
            o = self.dec(
                z * x_mask,
                g=g,
            )
        del x_mask, z
        return o


class SynthesizerTrnMs256NSFsid(SynthesizerTrnMsNSFsid):
    """Pitch-aware synthesizer with ``encoder_dim=256`` and ``use_f0=True``."""

    def __init__(
        self,
        spec_channels: int,
        segment_size: int,
        inter_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: float,
        resblock: str,
        resblock_kernel_sizes: List[int],
        resblock_dilation_sizes: List[List[int]],
        upsample_rates: List[int],
        upsample_initial_channel: int,
        upsample_kernel_sizes: List[int],
        spk_embed_dim: int,
        gin_channels: int,
        sr: str | int,
    ):
        """Forward every argument to the base class, fixing the variant flags."""
        super().__init__(
            spec_channels,
            segment_size,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            spk_embed_dim,
            gin_channels,
            sr,
            256,
            True,
        )


class SynthesizerTrnMs768NSFsid(SynthesizerTrnMsNSFsid):
    """Pitch-aware synthesizer with ``encoder_dim=768`` and ``use_f0=True``."""

    def __init__(
        self,
        spec_channels: int,
        segment_size: int,
        inter_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: float,
        resblock: str,
        resblock_kernel_sizes: List[int],
        resblock_dilation_sizes: List[List[int]],
        upsample_rates: List[int],
        upsample_initial_channel: int,
        upsample_kernel_sizes: List[int],
        spk_embed_dim: int,
        gin_channels: int,
        sr: str | int,
    ):
        """Forward every argument to the base class, fixing the variant flags."""
        super().__init__(
            spec_channels,
            segment_size,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            spk_embed_dim,
            gin_channels,
            sr,
            768,
            True,
        )


class SynthesizerTrnMs256NSFsid_nono(SynthesizerTrnMsNSFsid):
    """Pitch-free synthesizer with ``encoder_dim=256`` and ``use_f0=False``."""

    def __init__(
        self,
        spec_channels: int,
        segment_size: int,
        inter_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: float,
        resblock: str,
        resblock_kernel_sizes: List[int],
        resblock_dilation_sizes: List[List[int]],
        upsample_rates: List[int],
        upsample_initial_channel: int,
        upsample_kernel_sizes: List[int],
        spk_embed_dim: int,
        gin_channels: int,
        sr=None,
    ):
        """Forward every argument to the base class, fixing the variant flags."""
        super().__init__(
            spec_channels,
            segment_size,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            spk_embed_dim,
            gin_channels,
            # sr must be passed so that 256/False bind to encoder_dim/use_f0;
            # the base class only hands sr to the decoder when use_f0 is true.
            sr,
            256,
            False,
        )


class SynthesizerTrnMs768NSFsid_nono(SynthesizerTrnMsNSFsid):
    """Pitch-free synthesizer with ``encoder_dim=768`` and ``use_f0=False``."""

    def __init__(
        self,
        spec_channels: int,
        segment_size: int,
        inter_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: float,
        resblock: str,
        resblock_kernel_sizes: List[int],
        resblock_dilation_sizes: List[List[int]],
        upsample_rates: List[int],
        upsample_initial_channel: int,
        upsample_kernel_sizes: List[int],
        spk_embed_dim: int,
        gin_channels: int,
        sr=None,
    ):
        """Forward every argument to the base class, fixing the variant flags."""
        super().__init__(
            spec_channels,
            segment_size,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            spk_embed_dim,
            gin_channels,
            # sr must be passed so that 768/False bind to encoder_dim/use_f0;
            # the base class only hands sr to the decoder when use_f0 is true.
            sr,
            768,
            False,
        )