From 49488dcae9daf8b4e4c9c64fb51a22262563b9f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Fri, 7 Jun 2024 19:33:45 +0900 Subject: [PATCH] optimize(rvc.utils): more type defs & rename --- infer/lib/infer_pack/attentions.py | 44 +++++++---------------------- infer/lib/infer_pack/models.py | 20 ++++++------- infer/lib/infer_pack/models_onnx.py | 5 +--- infer/lib/infer_pack/modules.py | 7 ++--- infer/modules/train/train.py | 4 +-- rvc/utils.py | 36 ++++++++++------------- 6 files changed, 41 insertions(+), 75 deletions(-) diff --git a/infer/lib/infer_pack/attentions.py b/infer/lib/infer_pack/attentions.py index f18212f..8185477 100644 --- a/infer/lib/infer_pack/attentions.py +++ b/infer/lib/infer_pack/attentions.py @@ -18,7 +18,6 @@ class Encoder(nn.Module): kernel_size=1, p_dropout=0.0, window_size=10, - **kwargs ): super(Encoder, self).__init__() self.hidden_channels = hidden_channels @@ -55,8 +54,11 @@ class Encoder(nn.Module): ) ) self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def __call__(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor: + return super().__call__(x, x_mask) - def forward(self, x, x_mask): + def forward(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor: attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) x = x * x_mask zippep = zip( @@ -86,7 +88,6 @@ class Decoder(nn.Module): p_dropout=0.0, proximal_bias=False, proximal_init=True, - **kwargs ): super(Decoder, self).__init__() self.hidden_channels = hidden_channels @@ -311,7 +312,6 @@ class MultiHeadAttention(nn.Module): if pad_length > 0: padded_relative_embeddings = F.pad( relative_embeddings, - # commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), [0, 0, pad_length, pad_length, 0, 0], ) else: @@ -328,19 +328,11 @@ class MultiHeadAttention(nn.Module): """ batch, heads, length, _ = x.size() # Concat columns of pad to shift from relative to absolute indexing. - x = F.pad( - x, - # commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]) - [0, 1, 0, 0, 0, 0, 0, 0], - ) + x = F.pad(x, [0, 1, 0, 0, 0, 0, 0, 0], ) # Concat extra elements so to add up to shape (len+1, 2*len-1). x_flat = x.view([batch, heads, length * 2 * length]) - x_flat = F.pad( - x_flat, - # commons.convert_pad_shape([[0, 0], [0, 0], [0, int(length) - 1]]) - [0, length - 1, 0, 0, 0, 0], - ) + x_flat = F.pad(x_flat, [0, length - 1, 0, 0, 0, 0]) # Reshape and slice out the padded elements. x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[ @@ -355,18 +347,10 @@ class MultiHeadAttention(nn.Module): """ batch, heads, length, _ = x.size() # padd along column - x = F.pad( - x, - # commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, int(length) - 1]]) - [0, length - 1, 0, 0, 0, 0, 0, 0], - ) + x = F.pad(x, [0, length - 1, 0, 0, 0, 0, 0, 0]) x_flat = x.view([batch, heads, (length**2) + (length * (length - 1))]) # add 0's in the beginning that will skew the elements after reshape - x_flat = F.pad( - x_flat, - # commons.convert_pad_shape([[0, 0], [0, 0], [int(length), 0]]) - [length, 0, 0, 0, 0, 0], - ) + x_flat = F.pad(x_flat, [length, 0, 0, 0, 0, 0]) x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] return x_final @@ -435,11 +419,7 @@ class FFN(nn.Module): pad_l: int = self.kernel_size - 1 pad_r: int = 0 # padding = [[0, 0], [0, 0], [pad_l, pad_r]] - x = F.pad( - x, - # commons.convert_pad_shape(padding) - [pad_l, pad_r, 0, 0, 0, 0], - ) + x = F.pad(x, [pad_l, pad_r, 0, 0, 0, 0]) return x def _same_padding(self, x): @@ -448,9 +428,5 @@ class FFN(nn.Module): pad_l: int = (self.kernel_size - 1) // 2 pad_r: int = self.kernel_size // 2 # padding = [[0, 0], [0, 0], [pad_l, pad_r]] - x = F.pad( - x, - # commons.convert_pad_shape(padding) - [pad_l, pad_r, 0, 0, 0, 0], - ) + x = F.pad(x, [pad_l, pad_r, 0, 0, 0, 0]) return x diff --git a/infer/lib/infer_pack/models.py b/infer/lib/infer_pack/models.py index d9bca42..f8d5771 100644 --- a/infer/lib/infer_pack/models.py +++ b/infer/lib/infer_pack/models.py @@ -95,7 +95,7 @@ class TextEncoder(nn.Module): x = x[:, :, head:] x_mask = x_mask[:, :, head:] """ - stats = self.proj(x) * x_mask + stats: torch.Tensor = self.proj(x) * x_mask m, logs = torch.split(stats, self.out_channels, dim=1) return m, logs, x_mask @@ -169,12 +169,12 @@ class ResidualCouplingBlock(nn.Module): class PosteriorEncoder(nn.Module): def __init__( self, - in_channels, - out_channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, + in_channels: int, + out_channels: int, + hidden_channels: int, + kernel_size: int, + dilation_rate: int, + n_layers: int, gin_channels=0, ): super(PosteriorEncoder, self).__init__() @@ -648,7 +648,7 @@ class GeneratorNSF(torch.nn.Module): class SynthesizerTrnMs256NSFsid(nn.Module): def __init__( self, - spec_channels, + spec_channels: int, segment_size: int, inter_channels: int, hidden_channels: int, @@ -783,7 +783,7 @@ class SynthesizerTrnMs256NSFsid(nn.Module): m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths) z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g) z_p = self.flow(z, y_mask, g=g) - z_slice, ids_slice = utils.rand_slice_segments(z, y_lengths, self.segment_size) + z_slice, ids_slice = utils.rand_slice_segments_on_last_dim(z, y_lengths, self.segment_size) # print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length) pitchf = utils.slice_on_last_dim(pitchf, ids_slice, self.segment_size) # print(-2,pitchf.shape,z_slice.shape) @@ -1007,7 +1007,7 @@ class SynthesizerTrnMs256NSFsid_nono(nn.Module): m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths) z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g) z_p = self.flow(z, y_mask, g=g) - z_slice, ids_slice = utils.rand_slice_segments(z, y_lengths, self.segment_size) + z_slice, ids_slice = utils.rand_slice_segments_on_last_dim(z, y_lengths, self.segment_size) o = self.dec(z_slice, g=g) return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q) diff --git a/infer/lib/infer_pack/models_onnx.py b/infer/lib/infer_pack/models_onnx.py index 00a42a8..26ad8e8 100644 --- a/infer/lib/infer_pack/models_onnx.py +++ b/infer/lib/infer_pack/models_onnx.py @@ -5,9 +5,6 @@ from .attentions import ( TextEncoder, ResidualCouplingBlock, PosteriorEncoder, - Generator, - SineGen, - SourceModuleHnNSF, GeneratorNSF, ) @@ -15,7 +12,7 @@ from .attentions import ( class SynthesizerTrnMsNSFsidM(nn.Module): def __init__( self, - spec_channels, + spec_channels: int, segment_size, inter_channels, hidden_channels, diff --git a/infer/lib/infer_pack/modules.py b/infer/lib/infer_pack/modules.py index 0587c18..593c301 100644 --- a/infer/lib/infer_pack/modules.py +++ b/infer/lib/infer_pack/modules.py @@ -136,7 +136,7 @@ class DDSConv(nn.Module): class WN(torch.nn.Module): def __init__( self, - hidden_channels, + hidden_channels: int, kernel_size, dilation_rate, n_layers, @@ -189,7 +189,6 @@ class WN(torch.nn.Module): self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None ): output = torch.zeros_like(x) - n_channels_tensor = torch.IntTensor([self.hidden_channels]) if g is not None: g = self.cond_layer(g) @@ -197,14 +196,14 @@ class WN(torch.nn.Module): for i, (in_layer, res_skip_layer) in enumerate( zip(self.in_layers, self.res_skip_layers) ): - x_in = in_layer(x) + x_in: torch.Tensor = in_layer(x) if g is not None: cond_offset = i * 2 * self.hidden_channels g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] else: g_l = torch.zeros_like(x_in) - acts = utils.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) + acts = utils.activate_add_tanh_sigmoid_multiply(x_in, g_l, self.hidden_channels) acts = self.drop(acts) res_skip_acts = res_skip_layer(acts) diff --git a/infer/modules/train/train.py b/infer/modules/train/train.py index a6f3550..ef40b6a 100644 --- a/infer/modules/train/train.py +++ b/infer/modules/train/train.py @@ -481,7 +481,7 @@ def train_and_evaluate( optim_d.zero_grad() scaler.scale(loss_disc).backward() scaler.unscale_(optim_d) - grad_norm_d = utils.clip_grad_value_(net_d.parameters(), None) + grad_norm_d = utils.total_grad_norm(net_d.parameters()) scaler.step(optim_d) with autocast(enabled=hps.train.fp16_run): @@ -496,7 +496,7 @@ def train_and_evaluate( optim_g.zero_grad() scaler.scale(loss_gen_all).backward() scaler.unscale_(optim_g) - grad_norm_g = utils.clip_grad_value_(net_g.parameters(), None) + grad_norm_g = utils.total_grad_norm(net_g.parameters()) scaler.step(optim_g) scaler.update() diff --git a/rvc/utils.py b/rvc/utils.py index 94b52e5..4a94f3f 100644 --- a/rvc/utils.py +++ b/rvc/utils.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Iterator import torch @@ -11,7 +11,7 @@ def call_weight_data_normal_if_Conv(m: torch.nn.Module): m.weight.data.normal_(mean, std) -def get_padding(kernel_size: int, dilation=1): +def get_padding(kernel_size: int, dilation=1) -> int: return int((kernel_size * dilation - dilation) / 2) @@ -30,7 +30,7 @@ def slice_on_last_dim( return ret -def rand_slice_segments( +def rand_slice_segments_on_last_dim( x: torch.Tensor, x_lengths: int = None, segment_size=4, @@ -45,19 +45,16 @@ def rand_slice_segments( @torch.jit.script -def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): - n_channels_int = n_channels[0] +def activate_add_tanh_sigmoid_multiply( + input_a: torch.Tensor, input_b: torch.Tensor, n_channels: int +) -> torch.Tensor: in_act = input_a + input_b - t_act = torch.tanh(in_act[:, :n_channels_int, :]) - s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + t_act = torch.tanh(in_act[:, :n_channels, :]) + s_act = torch.sigmoid(in_act[:, n_channels:, :]) acts = t_act * s_act return acts -def convert_pad_shape(pad_shape: List[List[int]]) -> List[int]: - return torch.tensor(pad_shape).flip(0).reshape(-1).int().tolist() - - def sequence_mask( length: torch.Tensor, max_length: Optional[int] = None, @@ -68,19 +65,16 @@ def sequence_mask( return x.unsqueeze(0) < length.unsqueeze(1) -def clip_grad_value_(parameters, clip_value, norm_type=2): - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - parameters = list(filter(lambda p: p.grad is not None, parameters)) +def total_grad_norm( + parameters: Iterator[torch.nn.Parameter], norm_type: float=2.0, +) -> float: norm_type = float(norm_type) - if clip_value is not None: - clip_value = float(clip_value) + total_norm = 0.0 - total_norm = 0 for p in parameters: + if p.grad is None: continue param_norm = p.grad.data.norm(norm_type) - total_norm += param_norm.item() ** norm_type - if clip_value is not None: - p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm += float(param_norm.item()) ** norm_type total_norm = total_norm ** (1.0 / norm_type) + return total_norm