From e33ef19200e68e0f62e3ce5972b7084723e40232 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Mon, 10 Jun 2024 21:34:35 +0900 Subject: [PATCH] optimize(infer.synthesizer): all modules inherit from one --- infer/lib/infer_pack/models.py | 291 +++++++++++---------------------- infer/lib/rtrvc.py | 28 ++-- infer/modules/train/train.py | 24 +-- infer/modules/vc/pipeline.py | 13 +- tools/cmd/infer-pm-index256.py | 2 +- 5 files changed, 127 insertions(+), 231 deletions(-) diff --git a/infer/lib/infer_pack/models.py b/infer/lib/infer_pack/models.py index 16d19c4..304c367 100644 --- a/infer/lib/infer_pack/models.py +++ b/infer/lib/infer_pack/models.py @@ -33,8 +33,9 @@ class SynthesizerTrnMsNSFsid(nn.Module): upsample_kernel_sizes: List[int], spk_embed_dim: int, gin_channels: int, - sr: str | int, + sr: Optional[str | int], encoder_dim: int, + use_f0: bool, ): super(SynthesizerTrnMs256NSFsid, self).__init__() if isinstance(sr, str): @@ -59,8 +60,8 @@ class SynthesizerTrnMsNSFsid(nn.Module): self.upsample_kernel_sizes = upsample_kernel_sizes self.segment_size = segment_size self.gin_channels = gin_channels - # self.hop_length = hop_length# self.spk_embed_dim = spk_embed_dim + self.enc_p = TextEncoder( encoder_dim, inter_channels, @@ -70,18 +71,31 @@ class SynthesizerTrnMsNSFsid(nn.Module): n_layers, kernel_size, float(p_dropout), + f0=use_f0, ) - self.dec = NSFGenerator( - inter_channels, - resblock, - resblock_kernel_sizes, - resblock_dilation_sizes, - upsample_rates, - upsample_initial_channel, - upsample_kernel_sizes, - gin_channels=gin_channels, - sr=sr, - ) + if use_f0: + self.dec = NSFGenerator( + inter_channels, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=gin_channels, + sr=sr, + ) + else: + self.dec = Generator( + inter_channels, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=gin_channels, + ) self.enc_q = PosteriorEncoder( spec_channels, inter_channels, @@ -133,11 +147,11 @@ class SynthesizerTrnMsNSFsid(nn.Module): self, phone: torch.Tensor, phone_lengths: torch.Tensor, - pitch: torch.Tensor, - pitchf: torch.Tensor, y: torch.Tensor, y_lengths: torch.Tensor, ds: Optional[torch.Tensor] = None, + pitch: Optional[torch.Tensor] = None, + pitchf: Optional[torch.Tensor] = None, ): # 这里ds是id,[bs,1] # print(1,pitch.shape)#[bs,t] g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的 @@ -147,10 +161,11 @@ class SynthesizerTrnMsNSFsid(nn.Module): z_slice, ids_slice = rand_slice_segments_on_last_dim( z, y_lengths, self.segment_size ) - # print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length) - pitchf = slice_on_last_dim(pitchf, ids_slice, self.segment_size) - # print(-2,pitchf.shape,z_slice.shape) - o = self.dec(z_slice, pitchf, g=g) + if pitchf is not None: + pitchf = slice_on_last_dim(pitchf, ids_slice, self.segment_size) + o = self.dec(z_slice, pitchf, g=g) + else: + o = self.dec(z_slice, g=g) return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q) @torch.jit.export @@ -158,17 +173,15 @@ class SynthesizerTrnMsNSFsid(nn.Module): self, phone: torch.Tensor, phone_lengths: torch.Tensor, - pitch: torch.Tensor, - nsff0: torch.Tensor, sid: torch.Tensor, + pitch: Optional[torch.Tensor] = None, + nsff0: Optional[torch.Tensor] = None, skip_head: Optional[torch.Tensor] = None, return_length: Optional[torch.Tensor] = None, # return_length2: Optional[torch.Tensor] = None, ): g = self.emb_g(sid).unsqueeze(-1) if skip_head is not None and return_length is not None: - assert isinstance(skip_head, torch.Tensor) - assert isinstance(return_length, torch.Tensor) head = int(skip_head.item()) length = int(return_length.item()) flow_head = torch.clamp(skip_head - 24, min=0) @@ -178,18 +191,28 @@ class SynthesizerTrnMsNSFsid(nn.Module): z = self.flow(z_p, x_mask, g=g, reverse=True) z = z[:, :, dec_head : dec_head + length] x_mask = x_mask[:, :, dec_head : dec_head + length] - nsff0 = nsff0[:, head : head + length] + if nsff0 is not None: + nsff0 = nsff0[:, head : head + length] else: m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths) z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask z = self.flow(z_p, x_mask, g=g, reverse=True) - o = self.dec( - z * x_mask, - nsff0, - g=g, - # n_res=return_length2, - ) - return o, x_mask, (z, z_p, m_p, logs_p) + del z_p, m_p, logs_p + if nsff0 is not None: + o = self.dec( + z * x_mask, + nsff0, + g=g, + # n_res=return_length2, + ) + else: + o = self.dec( + z * x_mask, + g=g, + # n_res=return_length2 + ) + del x_mask, z + return o # , x_mask, (z, z_p, m_p, logs_p) class SynthesizerTrnMs256NSFsid(SynthesizerTrnMsNSFsid): @@ -234,6 +257,7 @@ class SynthesizerTrnMs256NSFsid(SynthesizerTrnMsNSFsid): gin_channels, sr, 256, + True, ) @@ -279,10 +303,11 @@ class SynthesizerTrnMs768NSFsid(SynthesizerTrnMsNSFsid): gin_channels, sr, 768, + True, ) -class SynthesizerTrnMs256NSFsid_nono(nn.Module): +class SynthesizerTrnMs256NSFsid_nono(SynthesizerTrnMsNSFsid): def __init__( self, spec_channels: int, @@ -304,162 +329,7 @@ class SynthesizerTrnMs256NSFsid_nono(nn.Module): gin_channels: int, sr=None, ): - super(SynthesizerTrnMs256NSFsid_nono, self).__init__() - self.spec_channels = spec_channels - self.inter_channels = inter_channels - self.hidden_channels = hidden_channels - self.filter_channels = filter_channels - self.n_heads = n_heads - self.n_layers = n_layers - self.kernel_size = kernel_size - self.p_dropout = float(p_dropout) - self.resblock = resblock - self.resblock_kernel_sizes = resblock_kernel_sizes - self.resblock_dilation_sizes = resblock_dilation_sizes - self.upsample_rates = upsample_rates - self.upsample_initial_channel = upsample_initial_channel - self.upsample_kernel_sizes = upsample_kernel_sizes - self.segment_size = segment_size - self.gin_channels = gin_channels - # self.hop_length = hop_length# - self.spk_embed_dim = spk_embed_dim - self.enc_p = TextEncoder( - 256, - inter_channels, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - float(p_dropout), - f0=False, - ) - self.dec = Generator( - inter_channels, - resblock, - resblock_kernel_sizes, - resblock_dilation_sizes, - upsample_rates, - upsample_initial_channel, - upsample_kernel_sizes, - gin_channels=gin_channels, - ) - self.enc_q = PosteriorEncoder( - spec_channels, - inter_channels, - hidden_channels, - 5, - 1, - 16, - gin_channels=gin_channels, - ) - self.flow = ResidualCouplingBlock( - inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels - ) - self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels) - - def remove_weight_norm(self): - self.dec.remove_weight_norm() - self.flow.remove_weight_norm() - if hasattr(self, "enc_q"): - self.enc_q.remove_weight_norm() - - def __prepare_scriptable__(self): - for hook in self.dec._forward_pre_hooks.values(): - # The hook we want to remove is an instance of WeightNorm class, so - # normally we would do `if isinstance(...)` but this class is not accessible - # because of shadowing, so we check the module name directly. - # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3 - if ( - hook.__module__ == "torch.nn.utils.weight_norm" - and hook.__class__.__name__ == "WeightNorm" - ): - torch.nn.utils.remove_weight_norm(self.dec) - for hook in self.flow._forward_pre_hooks.values(): - if ( - hook.__module__ == "torch.nn.utils.weight_norm" - and hook.__class__.__name__ == "WeightNorm" - ): - torch.nn.utils.remove_weight_norm(self.flow) - if hasattr(self, "enc_q"): - for hook in self.enc_q._forward_pre_hooks.values(): - if ( - hook.__module__ == "torch.nn.utils.weight_norm" - and hook.__class__.__name__ == "WeightNorm" - ): - torch.nn.utils.remove_weight_norm(self.enc_q) - return self - - @torch.jit.ignore - def forward(self, phone, phone_lengths, y, y_lengths, ds): # 这里ds是id,[bs,1] - g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的 - m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths) - z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g) - z_p = self.flow(z, y_mask, g=g) - z_slice, ids_slice = rand_slice_segments_on_last_dim( - z, y_lengths, self.segment_size - ) - o = self.dec(z_slice, g=g) - return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q) - - @torch.jit.export - def infer( - self, - phone: torch.Tensor, - phone_lengths: torch.Tensor, - sid: torch.Tensor, - skip_head: Optional[torch.Tensor] = None, - return_length: Optional[torch.Tensor] = None, - # return_length2: Optional[torch.Tensor] = None, - ): - g = self.emb_g(sid).unsqueeze(-1) - if skip_head is not None and return_length is not None: - assert isinstance(skip_head, torch.Tensor) - assert isinstance(return_length, torch.Tensor) - head = int(skip_head.item()) - length = int(return_length.item()) - flow_head = torch.clamp(skip_head - 24, min=0) - dec_head = head - int(flow_head.item()) - m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths, flow_head) - z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask - z = self.flow(z_p, x_mask, g=g, reverse=True) - z = z[:, :, dec_head : dec_head + length] - x_mask = x_mask[:, :, dec_head : dec_head + length] - else: - m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths) - z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask - z = self.flow(z_p, x_mask, g=g, reverse=True) - o = self.dec( - z * x_mask, - g=g, - # n_res=return_length2 - ) - return o, x_mask, (z, z_p, m_p, logs_p) - - -class SynthesizerTrnMs768NSFsid_nono(SynthesizerTrnMs256NSFsid_nono): - def __init__( - self, - spec_channels: int, - segment_size: int, - inter_channels: int, - hidden_channels: int, - filter_channels: int, - n_heads: int, - n_layers: int, - kernel_size: int, - p_dropout: int, - resblock: str, - resblock_kernel_sizes: List[int], - resblock_dilation_sizes: List[List[int]], - upsample_rates: List[int], - upsample_initial_channel: int, - upsample_kernel_sizes: List[int], - spk_embed_dim: int, - gin_channels: int, - sr=None, - ): - super(SynthesizerTrnMs768NSFsid_nono, self).__init__( + super().__init__( spec_channels, segment_size, inter_channels, @@ -477,16 +347,51 @@ class SynthesizerTrnMs768NSFsid_nono(SynthesizerTrnMs256NSFsid_nono): upsample_kernel_sizes, spk_embed_dim, gin_channels, + 256, + False, ) - del self.enc_p - self.enc_p = TextEncoder( - 768, + + +class SynthesizerTrnMs768NSFsid_nono(SynthesizerTrnMsNSFsid): + def __init__( + self, + spec_channels: int, + segment_size: int, + inter_channels: int, + hidden_channels: int, + filter_channels: int, + n_heads: int, + n_layers: int, + kernel_size: int, + p_dropout: int, + resblock: str, + resblock_kernel_sizes: List[int], + resblock_dilation_sizes: List[List[int]], + upsample_rates: List[int], + upsample_initial_channel: int, + upsample_kernel_sizes: List[int], + spk_embed_dim: int, + gin_channels: int, + sr=None, + ): + super().__init__( + spec_channels, + segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, - float(p_dropout), - f0=False, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + spk_embed_dim, + gin_channels, + 768, + False, ) diff --git a/infer/lib/rtrvc.py b/infer/lib/rtrvc.py index 49f6140..dfe7e28 100644 --- a/infer/lib/rtrvc.py +++ b/infer/lib/rtrvc.py @@ -399,6 +399,7 @@ class RVC: p_len = input_wav.shape[0] // 160 factor = pow(2, self.formant_shift / 12) return_length2 = int(np.ceil(return_length * factor)) + cache_pitch = cache_pitchf = None if self.if_f0 == 1: f0_extractor_frame = block_frame_16k + 800 if f0method == "rmvpe": @@ -424,25 +425,18 @@ class RVC: p_len = torch.LongTensor([p_len]).to(self.device) sid = torch.LongTensor([0]).to(self.device) skip_head = torch.LongTensor([skip_head]) - return_length2 = torch.LongTensor([return_length2]) + # return_length2 = torch.LongTensor([return_length2]) return_length = torch.LongTensor([return_length]) with torch.no_grad(): - if self.if_f0 == 1: - infered_audio, _, _ = self.net_g.infer( - feats, - p_len, - cache_pitch, - cache_pitchf, - sid, - skip_head, - return_length, - return_length2, - ) - else: - infered_audio, _, _ = self.net_g.infer( - feats, p_len, sid, skip_head, return_length, return_length2 - ) - infered_audio = infered_audio.squeeze(1).float() + infered_audio = self.net_g.infer( + feats, + p_len, + sid, + pitch=cache_pitch, + pitchf=cache_pitchf, + skip_head=skip_head, + return_length=return_length, + ).squeeze(1).float() upp_res = int(np.floor(factor * self.tgt_sr // 100)) if upp_res != self.tgt_sr // 100: if upp_res not in self.resample_kernel: diff --git a/infer/modules/train/train.py b/infer/modules/train/train.py index 1ba6b18..7881687 100644 --- a/infer/modules/train/train.py +++ b/infer/modules/train/train.py @@ -415,6 +415,7 @@ def train_and_evaluate( for batch_idx, info in data_iterator: # Data ## Unpack + pitch = pitchf = None if hps.if_f0 == 1: ( phone, @@ -444,22 +445,13 @@ def train_and_evaluate( # Calculate with autocast(enabled=hps.train.fp16_run): - if hps.if_f0 == 1: - ( - y_hat, - ids_slice, - x_mask, - z_mask, - (z, z_p, m_p, logs_p, m_q, logs_q), - ) = net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid) - else: - ( - y_hat, - ids_slice, - x_mask, - z_mask, - (z, z_p, m_p, logs_p, m_q, logs_q), - ) = net_g(phone, phone_lengths, spec, spec_lengths, sid) + ( + y_hat, + ids_slice, + x_mask, + z_mask, + (z, z_p, m_p, logs_p, m_q, logs_q), + ) = net_g(phone, phone_lengths, spec, spec_lengths, sid, pitch, pitchf) mel = spec_to_mel_torch( spec, hps.data.filter_length, diff --git a/infer/modules/vc/pipeline.py b/infer/modules/vc/pipeline.py index 68b45b6..f3ec6ba 100644 --- a/infer/modules/vc/pipeline.py +++ b/infer/modules/vc/pipeline.py @@ -290,10 +290,15 @@ class Pipeline(object): feats = feats.to(feats0.dtype) p_len = torch.tensor([p_len], device=self.device).long() with torch.no_grad(): - hasp = pitch is not None and pitchf is not None - arg = (feats, p_len, pitch, pitchf, sid) if hasp else (feats, p_len, sid) - audio1 = (net_g.infer(*arg)[0][0, 0]).data.cpu().float().numpy() - del arg + audio1 = ( + net_g.infer( + feats, + p_len, + sid, + pitch=pitch, + pitchf=pitchf, + )[0, 0] + ).data.cpu().float().numpy() del feats, p_len, padding_mask if torch.cuda.is_available(): torch.cuda.empty_cache() diff --git a/tools/cmd/infer-pm-index256.py b/tools/cmd/infer-pm-index256.py index d6b3b74..9a84c44 100644 --- a/tools/cmd/infer-pm-index256.py +++ b/tools/cmd/infer-pm-index256.py @@ -183,7 +183,7 @@ for idx, name in enumerate( pitchf = torch.FloatTensor(pitchf).unsqueeze(0).to(device) with torch.no_grad(): audio = ( - net_g.infer(feats, p_len, pitch, pitchf, sid)[0][0, 0] + net_g.infer(feats, p_len, sid, pitch=pitch, pitchf=pitchf)[0, 0] .data.cpu() .float() .numpy()