diff --git a/infer/lib/infer_pack/models.py b/infer/lib/infer_pack/models.py index eece2f9..7908c23 100644 --- a/infer/lib/infer_pack/models.py +++ b/infer/lib/infer_pack/models.py @@ -17,9 +17,10 @@ from infer.lib.infer_pack.commons import get_padding, init_weights has_xpu = bool(hasattr(torch, "xpu") and torch.xpu.is_available()) -class TextEncoder256(nn.Module): +class TextEncoder(nn.Module): def __init__( self, + in_channels, out_channels, hidden_channels, filter_channels, @@ -29,7 +30,7 @@ class TextEncoder256(nn.Module): p_dropout, f0=True, ): - super(TextEncoder256, self).__init__() + super(TextEncoder, self).__init__() self.out_channels = out_channels self.hidden_channels = hidden_channels self.filter_channels = filter_channels @@ -37,7 +38,7 @@ class TextEncoder256(nn.Module): self.n_layers = n_layers self.kernel_size = kernel_size self.p_dropout = float(p_dropout) - self.emb_phone = nn.Linear(256, hidden_channels) + self.emb_phone = nn.Linear(in_channels, hidden_channels) self.lrelu = nn.LeakyReLU(0.1, inplace=True) if f0 == True: self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256 @@ -51,9 +52,7 @@ class TextEncoder256(nn.Module): ) self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - def forward( - self, phone: torch.Tensor, pitch: Optional[torch.Tensor], lengths: torch.Tensor - ): + def forward(self, phone: torch.Tensor, pitch: torch.Tensor, lengths: torch.Tensor, skip_head: Optional[torch.Tensor] = None): if pitch is None: x = self.emb_phone(phone) else: @@ -65,60 +64,12 @@ class TextEncoder256(nn.Module): x.dtype ) x = self.encoder(x * x_mask, x_mask) + if skip_head is not None: + assert isinstance(skip_head, torch.Tensor) + head = int(skip_head.item()) + x = x[:, :, head : ] + x_mask = x_mask[:, :, head : ] stats = self.proj(x) * x_mask - - m, logs = torch.split(stats, self.out_channels, dim=1) - return m, logs, x_mask - - -class TextEncoder768(nn.Module): - def __init__( - self, - out_channels, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout, - f0=True, - ): - super(TextEncoder768, self).__init__() - self.out_channels = out_channels - self.hidden_channels = hidden_channels - self.filter_channels = filter_channels - self.n_heads = n_heads - self.n_layers = n_layers - self.kernel_size = kernel_size - self.p_dropout = float(p_dropout) - self.emb_phone = nn.Linear(768, hidden_channels) - self.lrelu = nn.LeakyReLU(0.1, inplace=True) - if f0 == True: - self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256 - self.encoder = attentions.Encoder( - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - float(p_dropout), - ) - self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - - def forward(self, phone: torch.Tensor, pitch: torch.Tensor, lengths: torch.Tensor): - if pitch is None: - x = self.emb_phone(phone) - else: - x = self.emb_phone(phone) + self.emb_pitch(pitch) - x = x * math.sqrt(self.hidden_channels) # [b, t, h] - x = self.lrelu(x) - x = torch.transpose(x, 1, -1) # [b, h, t] - x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to( - x.dtype - ) - x = self.encoder(x * x_mask, x_mask) - stats = self.proj(x) * x_mask - m, logs = torch.split(stats, self.out_channels, dim=1) return m, logs, x_mask @@ -682,7 +633,8 @@ class SynthesizerTrnMs256NSFsid(nn.Module): self.gin_channels = gin_channels # self.hop_length = hop_length# self.spk_embed_dim = spk_embed_dim - self.enc_p = TextEncoder256( + self.enc_p = TextEncoder( + 256, inter_channels, hidden_channels, filter_channels, @@ -792,22 +744,28 @@ class SynthesizerTrnMs256NSFsid(nn.Module): return_length: Optional[torch.Tensor] = None, ): g = self.emb_g(sid).unsqueeze(-1) - m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths) - z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask - z = self.flow(z_p, x_mask, g=g, reverse=True) if skip_head is not None and return_length is not None: assert isinstance(skip_head, torch.Tensor) assert isinstance(return_length, torch.Tensor) head = int(skip_head.item()) length = int(return_length.item()) - z = z[:, :, head : head + length] - x_mask = x_mask[:, :, head : head + length] - nsff0 = nsff0[:, head : head + length] + flow_head = torch.clamp(skip_head - 24, min=0) + dec_head = head - int(flow_head.item()) + m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths, flow_head) + z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask + z = self.flow(z_p, x_mask, g=g, reverse=True) + z = z[:, :, dec_head : dec_head + length] + x_mask = x_mask[:, :, dec_head : dec_head + length] + nsff0 = nsff0[:, head : head + length] + else: + m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths) + z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask + z = self.flow(z_p, x_mask, g=g, reverse=True) o = self.dec(z * x_mask, nsff0, g=g) return o, x_mask, (z, z_p, m_p, logs_p) -class SynthesizerTrnMs768NSFsid(nn.Module): +class SynthesizerTrnMs768NSFsid(SynthesizerTrnMs256NSFsid): def __init__( self, spec_channels, @@ -830,28 +788,30 @@ class SynthesizerTrnMs768NSFsid(nn.Module): sr, **kwargs ): - super(SynthesizerTrnMs768NSFsid, self).__init__() - if isinstance(sr, str): - sr = sr2sr[sr] - self.spec_channels = spec_channels - self.inter_channels = inter_channels - self.hidden_channels = hidden_channels - self.filter_channels = filter_channels - self.n_heads = n_heads - self.n_layers = n_layers - self.kernel_size = kernel_size - self.p_dropout = float(p_dropout) - self.resblock = resblock - self.resblock_kernel_sizes = resblock_kernel_sizes - self.resblock_dilation_sizes = resblock_dilation_sizes - self.upsample_rates = upsample_rates - self.upsample_initial_channel = upsample_initial_channel - self.upsample_kernel_sizes = upsample_kernel_sizes - self.segment_size = segment_size - self.gin_channels = gin_channels - # self.hop_length = hop_length# - self.spk_embed_dim = spk_embed_dim - self.enc_p = TextEncoder768( + super(SynthesizerTrnMs768NSFsid, self).__init__( + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + spk_embed_dim, + gin_channels, + sr, + **kwargs + ) + del self.enc_p + self.enc_p = TextEncoder( + 768, inter_channels, hidden_channels, filter_channels, @@ -860,113 +820,6 @@ class SynthesizerTrnMs768NSFsid(nn.Module): kernel_size, float(p_dropout), ) - self.dec = GeneratorNSF( - inter_channels, - resblock, - resblock_kernel_sizes, - resblock_dilation_sizes, - upsample_rates, - upsample_initial_channel, - upsample_kernel_sizes, - gin_channels=gin_channels, - sr=sr, - is_half=kwargs["is_half"], - ) - self.enc_q = PosteriorEncoder( - spec_channels, - inter_channels, - hidden_channels, - 5, - 1, - 16, - gin_channels=gin_channels, - ) - self.flow = ResidualCouplingBlock( - inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels - ) - self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels) - logger.debug( - "gin_channels: " - + str(gin_channels) - + ", self.spk_embed_dim: " - + str(self.spk_embed_dim) - ) - - def remove_weight_norm(self): - self.dec.remove_weight_norm() - self.flow.remove_weight_norm() - if hasattr(self, "enc_q"): - self.enc_q.remove_weight_norm() - - def __prepare_scriptable__(self): - for hook in self.dec._forward_pre_hooks.values(): - # The hook we want to remove is an instance of WeightNorm class, so - # normally we would do `if isinstance(...)` but this class is not accessible - # because of shadowing, so we check the module name directly. - # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3 - if ( - hook.__module__ == "torch.nn.utils.weight_norm" - and hook.__class__.__name__ == "WeightNorm" - ): - torch.nn.utils.remove_weight_norm(self.dec) - for hook in self.flow._forward_pre_hooks.values(): - if ( - hook.__module__ == "torch.nn.utils.weight_norm" - and hook.__class__.__name__ == "WeightNorm" - ): - torch.nn.utils.remove_weight_norm(self.flow) - if hasattr(self, "enc_q"): - for hook in self.enc_q._forward_pre_hooks.values(): - if ( - hook.__module__ == "torch.nn.utils.weight_norm" - and hook.__class__.__name__ == "WeightNorm" - ): - torch.nn.utils.remove_weight_norm(self.enc_q) - return self - - @torch.jit.ignore - def forward( - self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds - ): # 这里ds是id,[bs,1] - # print(1,pitch.shape)#[bs,t] - g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的 - m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths) - z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g) - z_p = self.flow(z, y_mask, g=g) - z_slice, ids_slice = commons.rand_slice_segments( - z, y_lengths, self.segment_size - ) - # print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length) - pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size) - # print(-2,pitchf.shape,z_slice.shape) - o = self.dec(z_slice, pitchf, g=g) - return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q) - - @torch.jit.export - def infer( - self, - phone: torch.Tensor, - phone_lengths: torch.Tensor, - pitch: torch.Tensor, - nsff0: torch.Tensor, - sid: torch.Tensor, - skip_head: Optional[torch.Tensor] = None, - return_length: Optional[torch.Tensor] = None, - ): - g = self.emb_g(sid).unsqueeze(-1) - m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths) - z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask - z = self.flow(z_p, x_mask, g=g, reverse=True) - if skip_head is not None and return_length is not None: - assert isinstance(skip_head, torch.Tensor) - assert isinstance(return_length, torch.Tensor) - head = int(skip_head.item()) - length = int(return_length.item()) - z = z[:, :, head : head + length] - x_mask = x_mask[:, :, head : head + length] - nsff0 = nsff0[:, head : head + length] - o = self.dec(z * x_mask, nsff0, g=g) - return o, x_mask, (z, z_p, m_p, logs_p) class SynthesizerTrnMs256NSFsid_nono(nn.Module): @@ -1011,7 +864,8 @@ class SynthesizerTrnMs256NSFsid_nono(nn.Module): self.gin_channels = gin_channels # self.hop_length = hop_length# self.spk_embed_dim = spk_embed_dim - self.enc_p = TextEncoder256( + self.enc_p = TextEncoder( + 256, inter_channels, hidden_channels, filter_channels, @@ -1105,21 +959,27 @@ class SynthesizerTrnMs256NSFsid_nono(nn.Module): return_length: Optional[torch.Tensor] = None, ): g = self.emb_g(sid).unsqueeze(-1) - m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths) - z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask - z = self.flow(z_p, x_mask, g=g, reverse=True) if skip_head is not None and return_length is not None: assert isinstance(skip_head, torch.Tensor) assert isinstance(return_length, torch.Tensor) head = int(skip_head.item()) length = int(return_length.item()) - z = z[:, :, head : head + length] - x_mask = x_mask[:, :, head : head + length] + flow_head = torch.clamp(skip_head - 24, min=0) + dec_head = head - int(flow_head.item()) + m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths, flow_head) + z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask + z = self.flow(z_p, x_mask, g=g, reverse=True) + z = z[:, :, dec_head : dec_head + length] + x_mask = x_mask[:, :, dec_head : dec_head + length] + else: + m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths) + z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask + z = self.flow(z_p, x_mask, g=g, reverse=True) o = self.dec(z * x_mask, g=g) return o, x_mask, (z, z_p, m_p, logs_p) -class SynthesizerTrnMs768NSFsid_nono(nn.Module): +class SynthesizerTrnMs768NSFsid_nono(SynthesizerTrnMs256NSFsid_nono): def __init__( self, spec_channels, @@ -1142,26 +1002,30 @@ class SynthesizerTrnMs768NSFsid_nono(nn.Module): sr=None, **kwargs ): - super(SynthesizerTrnMs768NSFsid_nono, self).__init__() - self.spec_channels = spec_channels - self.inter_channels = inter_channels - self.hidden_channels = hidden_channels - self.filter_channels = filter_channels - self.n_heads = n_heads - self.n_layers = n_layers - self.kernel_size = kernel_size - self.p_dropout = float(p_dropout) - self.resblock = resblock - self.resblock_kernel_sizes = resblock_kernel_sizes - self.resblock_dilation_sizes = resblock_dilation_sizes - self.upsample_rates = upsample_rates - self.upsample_initial_channel = upsample_initial_channel - self.upsample_kernel_sizes = upsample_kernel_sizes - self.segment_size = segment_size - self.gin_channels = gin_channels - # self.hop_length = hop_length# - self.spk_embed_dim = spk_embed_dim - self.enc_p = TextEncoder768( + super(SynthesizerTrnMs768NSFsid_nono, self).__init__( + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + spk_embed_dim, + gin_channels, + sr, + **kwargs + ) + del self.enc_p + self.enc_p = TextEncoder( + 768, inter_channels, hidden_channels, filter_channels, @@ -1171,102 +1035,6 @@ class SynthesizerTrnMs768NSFsid_nono(nn.Module): float(p_dropout), f0=False, ) - self.dec = Generator( - inter_channels, - resblock, - resblock_kernel_sizes, - resblock_dilation_sizes, - upsample_rates, - upsample_initial_channel, - upsample_kernel_sizes, - gin_channels=gin_channels, - ) - self.enc_q = PosteriorEncoder( - spec_channels, - inter_channels, - hidden_channels, - 5, - 1, - 16, - gin_channels=gin_channels, - ) - self.flow = ResidualCouplingBlock( - inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels - ) - self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels) - logger.debug( - "gin_channels: " - + str(gin_channels) - + ", self.spk_embed_dim: " - + str(self.spk_embed_dim) - ) - - def remove_weight_norm(self): - self.dec.remove_weight_norm() - self.flow.remove_weight_norm() - if hasattr(self, "enc_q"): - self.enc_q.remove_weight_norm() - - def __prepare_scriptable__(self): - for hook in self.dec._forward_pre_hooks.values(): - # The hook we want to remove is an instance of WeightNorm class, so - # normally we would do `if isinstance(...)` but this class is not accessible - # because of shadowing, so we check the module name directly. - # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3 - if ( - hook.__module__ == "torch.nn.utils.weight_norm" - and hook.__class__.__name__ == "WeightNorm" - ): - torch.nn.utils.remove_weight_norm(self.dec) - for hook in self.flow._forward_pre_hooks.values(): - if ( - hook.__module__ == "torch.nn.utils.weight_norm" - and hook.__class__.__name__ == "WeightNorm" - ): - torch.nn.utils.remove_weight_norm(self.flow) - if hasattr(self, "enc_q"): - for hook in self.enc_q._forward_pre_hooks.values(): - if ( - hook.__module__ == "torch.nn.utils.weight_norm" - and hook.__class__.__name__ == "WeightNorm" - ): - torch.nn.utils.remove_weight_norm(self.enc_q) - return self - - @torch.jit.ignore - def forward(self, phone, phone_lengths, y, y_lengths, ds): # 这里ds是id,[bs,1] - g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的 - m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths) - z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g) - z_p = self.flow(z, y_mask, g=g) - z_slice, ids_slice = commons.rand_slice_segments( - z, y_lengths, self.segment_size - ) - o = self.dec(z_slice, g=g) - return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q) - - @torch.jit.export - def infer( - self, - phone: torch.Tensor, - phone_lengths: torch.Tensor, - sid: torch.Tensor, - skip_head: Optional[torch.Tensor] = None, - return_length: Optional[torch.Tensor] = None, - ): - g = self.emb_g(sid).unsqueeze(-1) - m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths) - z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask - z = self.flow(z_p, x_mask, g=g, reverse=True) - if skip_head is not None and return_length is not None: - assert isinstance(skip_head, torch.Tensor) - assert isinstance(return_length, torch.Tensor) - head = int(skip_head.item()) - length = int(return_length.item()) - z = z[:, :, head : head + length] - x_mask = x_mask[:, :, head : head + length] - o = self.dec(z * x_mask, g=g) - return o, x_mask, (z, z_p, m_p, logs_p) class MultiPeriodDiscriminator(torch.nn.Module):