from typing import Optional, List

import torch
from torch import nn
from torch.nn import Conv1d, Conv2d
from torch.nn import functional as F
from torch.nn.utils import spectral_norm, weight_norm

from rvc import residuals
from rvc.residuals import ResidualCouplingBlock
from rvc.utils import (
    get_padding,
    slice_on_last_dim,
    rand_slice_segments_on_last_dim,
)
from rvc.encoders import TextEncoder, PosteriorEncoder
from rvc.generators import Generator
from rvc.nsf import NSFGenerator

has_xpu = bool(hasattr(torch, "xpu") and torch.xpu.is_available())

# Sample-rate aliases accepted by the NSF synthesizers' ``sr`` argument.
_SAMPLE_RATES = {
    "32k": 32000,
    "40k": 40000,
    "48k": 48000,
}


def _remove_weight_norm_hooks(module: nn.Module) -> None:
    """Call ``remove_weight_norm(module)`` for each WeightNorm pre-hook on it.

    Only hooks registered directly on ``module`` are inspected; hooks on its
    sub-modules are not visited here.
    """
    # Iterate over a snapshot: ``remove_weight_norm`` deletes the hook from
    # ``module._forward_pre_hooks``, and mutating a dict while iterating it
    # can raise RuntimeError.
    for hook in list(module._forward_pre_hooks.values()):
        # The hook we want to remove is an instance of the WeightNorm class, so
        # normally we would use `isinstance(...)`, but that class is not
        # accessible because of shadowing, so we check the module name directly.
        # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
        if (
            hook.__module__ == "torch.nn.utils.weight_norm"
            and hook.__class__.__name__ == "WeightNorm"
        ):
            torch.nn.utils.remove_weight_norm(module)


class SynthesizerTrnMsNSFsid(nn.Module):
    """Synthesizer with a pitch-conditioned text encoder and an NSF decoder.

    Wires together ``TextEncoder`` (``enc_p``), ``NSFGenerator`` (``dec``),
    ``PosteriorEncoder`` (``enc_q``), ``ResidualCouplingBlock`` (``flow``) and
    a speaker embedding table (``emb_g``).

    Args:
        sr: Sample rate, either an integer or one of ``"32k"``, ``"40k"``,
            ``"48k"``. Any other string raises ``KeyError``.
        text_encoder_in_channels: First argument forwarded to ``TextEncoder``
            (256 or 768 in the subclasses defined in this file).
        p_dropout: Converted with ``float()`` before use.

    The remaining arguments are stored on the instance and forwarded to the
    sub-modules unchanged.
    """

    def __init__(
        self,
        spec_channels: int,
        segment_size: int,
        inter_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: float,
        resblock: str,
        resblock_kernel_sizes: List[int],
        resblock_dilation_sizes: List[List[int]],
        upsample_rates: List[int],
        upsample_initial_channel: int,
        upsample_kernel_sizes: List[int],
        spk_embed_dim: int,
        gin_channels: int,
        sr: str | int,
        text_encoder_in_channels: int,
    ):
        # Zero-argument super(): naming a subclass here (as the code once did)
        # breaks construction of this class and of every subclass.
        super().__init__()
        if isinstance(sr, str):
            sr = _SAMPLE_RATES[sr]
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = float(p_dropout)
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.gin_channels = gin_channels
        self.spk_embed_dim = spk_embed_dim
        self.enc_p = TextEncoder(
            text_encoder_in_channels,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            float(p_dropout),
        )
        self.dec = NSFGenerator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
            sr=sr,
        )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
        )
        self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)

    def remove_weight_norm(self):
        """Remove weight norm from the decoder, the flow and, if present, enc_q."""
        self.dec.remove_weight_norm()
        self.flow.remove_weight_norm()
        # enc_q may have been deleted from the instance, hence the guard.
        if hasattr(self, "enc_q"):
            self.enc_q.remove_weight_norm()

    def __prepare_scriptable__(self):
        """Strip WeightNorm pre-hooks from dec/flow/enc_q and return ``self``."""
        _remove_weight_norm_hooks(self.dec)
        _remove_weight_norm_hooks(self.flow)
        if hasattr(self, "enc_q"):
            _remove_weight_norm_hooks(self.enc_q)
        return self

    @torch.jit.ignore
    def forward(
        self,
        phone: torch.Tensor,
        phone_lengths: torch.Tensor,
        pitch: torch.Tensor,
        pitchf: torch.Tensor,
        y: torch.Tensor,
        y_lengths: torch.Tensor,
        ds: Optional[torch.Tensor] = None,
    ):
        """Run the posterior/flow/decoder path on a random segment of ``y``.

        ``ds`` holds the speaker ids (the original note gives its shape as
        [bs, 1] -- TODO confirm). It is fed straight into the embedding table,
        so it must be provided even though the signature defaults it to None.

        Returns:
            ``(o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q))``
            where ``o`` is the decoder output for the sliced segment.
        """
        # Original note: [b, 256, 1]; the trailing 1 is the time axis and is
        # broadcast.
        g = self.emb_g(ds).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
        z_p = self.flow(z, y_mask, g=g)
        z_slice, ids_slice = rand_slice_segments_on_last_dim(
            z, y_lengths, self.segment_size
        )
        # Slice pitchf with the same ids so it lines up with z_slice.
        pitchf = slice_on_last_dim(pitchf, ids_slice, self.segment_size)
        o = self.dec(z_slice, pitchf, g=g)
        return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

    @torch.jit.export
    def infer(
        self,
        phone: torch.Tensor,
        phone_lengths: torch.Tensor,
        pitch: torch.Tensor,
        nsff0: torch.Tensor,
        sid: torch.Tensor,
        skip_head: Optional[torch.Tensor] = None,
        return_length: Optional[torch.Tensor] = None,
    ):
        """Synthesize from the prior; optionally return only a window.

        When both ``skip_head`` and ``return_length`` are given (tensors read
        with ``.item()``), the latent, mask and ``nsff0`` are cropped to
        ``return_length`` frames starting at ``skip_head`` before decoding.

        Returns:
            ``(o, x_mask, (z, z_p, m_p, logs_p))``.
        """
        g = self.emb_g(sid).unsqueeze(-1)
        if skip_head is not None and return_length is not None:
            assert isinstance(skip_head, torch.Tensor)
            assert isinstance(return_length, torch.Tensor)
            head = int(skip_head.item())
            length = int(return_length.item())
            # The encoder is given an offset up to 24 frames before ``head``
            # (presumably as left context for the flow -- TODO confirm); those
            # extra frames are dropped again via ``dec_head`` below.
            flow_head = torch.clamp(skip_head - 24, min=0)
            dec_head = head - int(flow_head.item())
            m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths, flow_head)
            z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
            z = self.flow(z_p, x_mask, g=g, reverse=True)
            z = z[:, :, dec_head : dec_head + length]
            x_mask = x_mask[:, :, dec_head : dec_head + length]
            nsff0 = nsff0[:, head : head + length]
        else:
            m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
            z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
            z = self.flow(z_p, x_mask, g=g, reverse=True)
        o = self.dec(
            z * x_mask,
            nsff0,
            g=g,
        )
        return o, x_mask, (z, z_p, m_p, logs_p)


class SynthesizerTrnMs256NSFsid(SynthesizerTrnMsNSFsid):
    """``SynthesizerTrnMsNSFsid`` with ``text_encoder_in_channels`` fixed to 256."""

    def __init__(
        self,
        spec_channels: int,
        segment_size: int,
        inter_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: float,
        resblock: str,
        resblock_kernel_sizes: List[int],
        resblock_dilation_sizes: List[List[int]],
        upsample_rates: List[int],
        upsample_initial_channel: int,
        upsample_kernel_sizes: List[int],
        spk_embed_dim: int,
        gin_channels: int,
        sr: str | int,
    ):
        super().__init__(
            spec_channels,
            segment_size,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            spk_embed_dim,
            gin_channels,
            sr,
            text_encoder_in_channels=256,
        )


class SynthesizerTrnMs768NSFsid(SynthesizerTrnMsNSFsid):
    """``SynthesizerTrnMsNSFsid`` with ``text_encoder_in_channels`` fixed to 768."""

    def __init__(
        self,
        spec_channels: int,
        segment_size: int,
        inter_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: float,
        resblock: str,
        resblock_kernel_sizes: List[int],
        resblock_dilation_sizes: List[List[int]],
        upsample_rates: List[int],
        upsample_initial_channel: int,
        upsample_kernel_sizes: List[int],
        spk_embed_dim: int,
        gin_channels: int,
        sr: str | int,
    ):
        super().__init__(
            spec_channels,
            segment_size,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            spk_embed_dim,
            gin_channels,
            sr,
            text_encoder_in_channels=768,
        )


class SynthesizerTrnMs256NSFsid_nono(nn.Module):
    """Synthesizer without pitch input: ``TextEncoder(f0=False)`` + ``Generator``.

    ``enc_p`` is built with 256 input channels. ``sr`` is accepted but not used
    anywhere in this class.
    """

    def __init__(
        self,
        spec_channels: int,
        segment_size: int,
        inter_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: float,
        resblock: str,
        resblock_kernel_sizes: List[int],
        resblock_dilation_sizes: List[List[int]],
        upsample_rates: List[int],
        upsample_initial_channel: int,
        upsample_kernel_sizes: List[int],
        spk_embed_dim: int,
        gin_channels: int,
        sr=None,
    ):
        super().__init__()
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = float(p_dropout)
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.gin_channels = gin_channels
        self.spk_embed_dim = spk_embed_dim
        self.enc_p = TextEncoder(
            256,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            float(p_dropout),
            f0=False,
        )
        self.dec = Generator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
        )
        self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)

    def remove_weight_norm(self):
        """Remove weight norm from the decoder, the flow and, if present, enc_q."""
        self.dec.remove_weight_norm()
        self.flow.remove_weight_norm()
        # enc_q may have been deleted from the instance, hence the guard.
        if hasattr(self, "enc_q"):
            self.enc_q.remove_weight_norm()

    def __prepare_scriptable__(self):
        """Strip WeightNorm pre-hooks from dec/flow/enc_q and return ``self``."""
        _remove_weight_norm_hooks(self.dec)
        _remove_weight_norm_hooks(self.flow)
        if hasattr(self, "enc_q"):
            _remove_weight_norm_hooks(self.enc_q)
        return self

    @torch.jit.ignore
    def forward(self, phone, phone_lengths, y, y_lengths, ds):
        """Run the posterior/flow/decoder path on a random segment of ``y``.

        ``ds`` holds the speaker ids (the original note gives its shape as
        [bs, 1] -- TODO confirm).

        Returns:
            ``(o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q))``.
        """
        # Original note: [b, 256, 1]; the trailing 1 is the time axis and is
        # broadcast.
        g = self.emb_g(ds).unsqueeze(-1)
        # No pitch for this variant: the encoder is called with None.
        m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
        z_p = self.flow(z, y_mask, g=g)
        z_slice, ids_slice = rand_slice_segments_on_last_dim(
            z, y_lengths, self.segment_size
        )
        o = self.dec(z_slice, g=g)
        return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

    @torch.jit.export
    def infer(
        self,
        phone: torch.Tensor,
        phone_lengths: torch.Tensor,
        sid: torch.Tensor,
        skip_head: Optional[torch.Tensor] = None,
        return_length: Optional[torch.Tensor] = None,
    ):
        """Synthesize from the prior; optionally return only a window.

        When both ``skip_head`` and ``return_length`` are given (tensors read
        with ``.item()``), the latent and mask are cropped to
        ``return_length`` frames before decoding.

        Returns:
            ``(o, x_mask, (z, z_p, m_p, logs_p))``.
        """
        g = self.emb_g(sid).unsqueeze(-1)
        if skip_head is not None and return_length is not None:
            assert isinstance(skip_head, torch.Tensor)
            assert isinstance(return_length, torch.Tensor)
            head = int(skip_head.item())
            length = int(return_length.item())
            # The encoder is given an offset up to 24 frames before ``head``
            # (presumably as left context for the flow -- TODO confirm); those
            # extra frames are dropped again via ``dec_head`` below.
            flow_head = torch.clamp(skip_head - 24, min=0)
            dec_head = head - int(flow_head.item())
            m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths, flow_head)
            z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
            z = self.flow(z_p, x_mask, g=g, reverse=True)
            z = z[:, :, dec_head : dec_head + length]
            x_mask = x_mask[:, :, dec_head : dec_head + length]
        else:
            m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
            z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
            z = self.flow(z_p, x_mask, g=g, reverse=True)
        o = self.dec(
            z * x_mask,
            g=g,
        )
        return o, x_mask, (z, z_p, m_p, logs_p)


class SynthesizerTrnMs768NSFsid_nono(SynthesizerTrnMs256NSFsid_nono):
    """Pitch-less synthesizer whose ``enc_p`` takes 768 input channels.

    ``sr`` is accepted but not used.
    """

    def __init__(
        self,
        spec_channels: int,
        segment_size: int,
        inter_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: float,
        resblock: str,
        resblock_kernel_sizes: List[int],
        resblock_dilation_sizes: List[List[int]],
        upsample_rates: List[int],
        upsample_initial_channel: int,
        upsample_kernel_sizes: List[int],
        spk_embed_dim: int,
        gin_channels: int,
        sr=None,
    ):
        super().__init__(
            spec_channels,
            segment_size,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            spk_embed_dim,
            gin_channels,
        )
        # Replace the 256-channel encoder built by the parent with a
        # 768-channel one.
        del self.enc_p
        self.enc_p = TextEncoder(
            768,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            float(p_dropout),
            f0=False,
        )


class _MultiPeriodDiscriminatorBase(torch.nn.Module):
    """One ``DiscriminatorS`` followed by one ``DiscriminatorP`` per period."""

    def __init__(self, periods: List[int], use_spectral_norm=False):
        super().__init__()
        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        discs = discs + [
            DiscriminatorP(period, use_spectral_norm=use_spectral_norm)
            for period in periods
        ]
        self.discriminators = nn.ModuleList(discs)

    def forward(self, y, y_hat):
        """Run every sub-discriminator on ``y`` and then on ``y_hat``.

        Returns:
            ``(y_d_rs, y_d_gs, fmap_rs, fmap_gs)``: four lists with one entry
            per sub-discriminator -- outputs and feature maps for ``y`` (``r``)
            and for ``y_hat`` (``g``).
        """
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for disc in self.discriminators:
            y_d_r, fmap_r = disc(y)
            y_d_g, fmap_g = disc(y_hat)
            y_d_rs.append(y_d_r)
            y_d_gs.append(y_d_g)
            fmap_rs.append(fmap_r)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class MultiPeriodDiscriminator(_MultiPeriodDiscriminatorBase):
    """Multi-period discriminator using periods 2, 3, 5, 7, 11 and 17."""

    def __init__(self, use_spectral_norm=False):
        super().__init__([2, 3, 5, 7, 11, 17], use_spectral_norm=use_spectral_norm)


class MultiPeriodDiscriminatorV2(_MultiPeriodDiscriminatorBase):
    """Multi-period discriminator using periods 2, 3, 5, 7, 11, 17, 23 and 37."""

    def __init__(self, use_spectral_norm=False):
        super().__init__(
            [2, 3, 5, 7, 11, 17, 23, 37], use_spectral_norm=use_spectral_norm
        )


class DiscriminatorS(torch.nn.Module):
    """Stack of normalized 1-D convolutions applied directly to the input."""

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """Return ``(flattened_output, fmap)``.

        ``fmap`` holds the activation after every conv layer, including the
        final ``conv_post`` output.
        """
        fmap = []

        for conv in self.convs:
            x = conv(x)
            x = F.leaky_relu(x, residuals.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class DiscriminatorP(torch.nn.Module):
    """2-D conv discriminator over the input folded into rows of ``period``."""

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super().__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        pad = (get_padding(kernel_size, 1), 0)
        # Four strided layers widen 1 -> 1024 channels; a final stride-1 layer
        # keeps 1024 channels.
        channels = [1, 32, 128, 512, 1024]
        convs = [
            norm_f(
                Conv2d(
                    c_in,
                    c_out,
                    (kernel_size, 1),
                    (stride, 1),
                    padding=pad,
                )
            )
            for c_in, c_out in zip(channels[:-1], channels[1:])
        ]
        convs.append(
            norm_f(
                Conv2d(
                    1024,
                    1024,
                    (kernel_size, 1),
                    1,
                    padding=pad,
                )
            )
        )
        self.convs = nn.ModuleList(convs)
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        """Return ``(flattened_output, fmap)`` for a 3-D input ``(b, c, t)``.

        The last axis is reflect-padded up to a multiple of ``period`` and then
        viewed as ``(b, c, t // period, period)`` before the 2-D convolutions.
        """
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            if has_xpu and x.dtype == torch.bfloat16:
                # On XPU, bfloat16 input is padded in float16 and cast back
                # (presumably a workaround for missing bfloat16 support in
                # reflect padding there -- TODO confirm).
                x = F.pad(x.to(dtype=torch.float16), (0, n_pad), "reflect").to(
                    dtype=torch.bfloat16
                )
            else:
                x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for conv in self.convs:
            x = conv(x)
            x = F.leaky_relu(x, residuals.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap