From 5eed789fe7a99b6520c6fc3c5ada4ad9f3be68a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Fri, 7 Jun 2024 00:42:35 +0900 Subject: [PATCH] optimize(rvc): move commons to rvc.utils - remove redundant attentions_onnx - shrink models_onnx - add some type note to rvc.utils --- infer/lib/infer_pack/attentions.py | 21 +- infer/lib/infer_pack/attentions_onnx.py | 459 -------------- infer/lib/infer_pack/commons.py | 172 ------ infer/lib/infer_pack/models.py | 134 ++-- infer/lib/infer_pack/models_onnx.py | 776 +----------------------- infer/lib/infer_pack/modules.py | 12 +- infer/modules/train/train.py | 10 +- rvc/utils.py | 79 +++ 8 files changed, 186 insertions(+), 1477 deletions(-) delete mode 100644 infer/lib/infer_pack/attentions_onnx.py delete mode 100644 infer/lib/infer_pack/commons.py create mode 100644 rvc/utils.py diff --git a/infer/lib/infer_pack/attentions.py b/infer/lib/infer_pack/attentions.py index 2cc745a..fb43440 100644 --- a/infer/lib/infer_pack/attentions.py +++ b/infer/lib/infer_pack/attentions.py @@ -1,13 +1,10 @@ -import copy import math from typing import Optional -import numpy as np import torch from torch import nn from torch.nn import functional as F -from infer.lib.infer_pack import commons, modules from infer.lib.infer_pack.modules import LayerNorm @@ -76,7 +73,7 @@ class Encoder(nn.Module): x = x * x_mask return x - +""" class Decoder(nn.Module): def __init__( self, @@ -138,11 +135,9 @@ class Decoder(nn.Module): self.norm_layers_2.append(LayerNorm(hidden_channels)) def forward(self, x, x_mask, h, h_mask): - """ - x: decoder input - h: encoder output - """ - self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to( + # x: decoder input + # h: encoder output + self_attn_mask = utils.subsequent_mask(x_mask.size(2)).to( device=x.device, dtype=x.dtype ) encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) @@ -161,7 +156,7 @@ class Decoder(nn.Module): x = self.norm_layers_2[i](x + y) x = x * x_mask return x - +""" class MultiHeadAttention(nn.Module): def __init__( @@ -342,7 +337,7 @@ class MultiHeadAttention(nn.Module): x_flat = F.pad( x_flat, # commons.convert_pad_shape([[0, 0], [0, 0], [0, int(length) - 1]]) - [0, int(length) - 1, 0, 0, 0, 0], + [0, length - 1, 0, 0, 0, 0], ) # Reshape and slice out the padded elements. @@ -361,9 +356,9 @@ class MultiHeadAttention(nn.Module): x = F.pad( x, # commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, int(length) - 1]]) - [0, int(length) - 1, 0, 0, 0, 0, 0, 0], + [0, length - 1, 0, 0, 0, 0, 0, 0], ) - x_flat = x.view([batch, heads, int(length**2) + int(length * (length - 1))]) + x_flat = x.view([batch, heads, (length**2) + (length * (length - 1))]) # add 0's in the beginning that will skew the elements after reshape x_flat = F.pad( x_flat, diff --git a/infer/lib/infer_pack/attentions_onnx.py b/infer/lib/infer_pack/attentions_onnx.py deleted file mode 100644 index 934c58b..0000000 --- a/infer/lib/infer_pack/attentions_onnx.py +++ /dev/null @@ -1,459 +0,0 @@ -import copy -import math -from typing import Optional - -import numpy as np -import torch -from torch import nn -from torch.nn import functional as F - -from infer.lib.infer_pack import commons, modules -from infer.lib.infer_pack.modules import LayerNorm - - -class Encoder(nn.Module): - def __init__( - self, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size=1, - p_dropout=0.0, - window_size=10, - **kwargs - ): - super(Encoder, self).__init__() - self.hidden_channels = hidden_channels - self.filter_channels = filter_channels - self.n_heads = n_heads - self.n_layers = int(n_layers) - self.kernel_size = kernel_size - self.p_dropout = p_dropout - self.window_size = window_size - - self.drop = nn.Dropout(p_dropout) - self.attn_layers = nn.ModuleList() - self.norm_layers_1 = nn.ModuleList() - self.ffn_layers = nn.ModuleList() - self.norm_layers_2 = nn.ModuleList() - for i in range(self.n_layers): - self.attn_layers.append( - MultiHeadAttention( - hidden_channels, - hidden_channels, - n_heads, - p_dropout=p_dropout, - window_size=window_size, - ) - ) - self.norm_layers_1.append(LayerNorm(hidden_channels)) - self.ffn_layers.append( - FFN( - hidden_channels, - hidden_channels, - filter_channels, - kernel_size, - p_dropout=p_dropout, - ) - ) - self.norm_layers_2.append(LayerNorm(hidden_channels)) - - def forward(self, x, x_mask): - attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) - x = x * x_mask - zippep = zip( - self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2 - ) - for attn_layers, norm_layers_1, ffn_layers, norm_layers_2 in zippep: - y = attn_layers(x, x, attn_mask) - y = self.drop(y) - x = norm_layers_1(x + y) - - y = ffn_layers(x, x_mask) - y = self.drop(y) - x = norm_layers_2(x + y) - x = x * x_mask - return x - - -class Decoder(nn.Module): - def __init__( - self, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size=1, - p_dropout=0.0, - proximal_bias=False, - proximal_init=True, - **kwargs - ): - super(Decoder, self).__init__() - self.hidden_channels = hidden_channels - self.filter_channels = filter_channels - self.n_heads = n_heads - self.n_layers = n_layers - self.kernel_size = kernel_size - self.p_dropout = p_dropout - self.proximal_bias = proximal_bias - self.proximal_init = proximal_init - - self.drop = nn.Dropout(p_dropout) - self.self_attn_layers = nn.ModuleList() - self.norm_layers_0 = nn.ModuleList() - self.encdec_attn_layers = nn.ModuleList() - self.norm_layers_1 = nn.ModuleList() - self.ffn_layers = nn.ModuleList() - self.norm_layers_2 = nn.ModuleList() - for i in range(self.n_layers): - self.self_attn_layers.append( - MultiHeadAttention( - hidden_channels, - hidden_channels, - n_heads, - p_dropout=p_dropout, - proximal_bias=proximal_bias, - proximal_init=proximal_init, - ) - ) - self.norm_layers_0.append(LayerNorm(hidden_channels)) - self.encdec_attn_layers.append( - MultiHeadAttention( - hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout - ) - ) - self.norm_layers_1.append(LayerNorm(hidden_channels)) - self.ffn_layers.append( - FFN( - hidden_channels, - hidden_channels, - filter_channels, - kernel_size, - p_dropout=p_dropout, - causal=True, - ) - ) - self.norm_layers_2.append(LayerNorm(hidden_channels)) - - def forward(self, x, x_mask, h, h_mask): - """ - x: decoder input - h: encoder output - """ - self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to( - device=x.device, dtype=x.dtype - ) - encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) - x = x * x_mask - for i in range(self.n_layers): - y = self.self_attn_layers[i](x, x, self_attn_mask) - y = self.drop(y) - x = self.norm_layers_0[i](x + y) - - y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) - y = self.drop(y) - x = self.norm_layers_1[i](x + y) - - y = self.ffn_layers[i](x, x_mask) - y = self.drop(y) - x = self.norm_layers_2[i](x + y) - x = x * x_mask - return x - - -class MultiHeadAttention(nn.Module): - def __init__( - self, - channels, - out_channels, - n_heads, - p_dropout=0.0, - window_size=None, - heads_share=True, - block_length=None, - proximal_bias=False, - proximal_init=False, - ): - super(MultiHeadAttention, self).__init__() - assert channels % n_heads == 0 - - self.channels = channels - self.out_channels = out_channels - self.n_heads = n_heads - self.p_dropout = p_dropout - self.window_size = window_size - self.heads_share = heads_share - self.block_length = block_length - self.proximal_bias = proximal_bias - self.proximal_init = proximal_init - self.attn = None - - self.k_channels = channels // n_heads - self.conv_q = nn.Conv1d(channels, channels, 1) - self.conv_k = nn.Conv1d(channels, channels, 1) - self.conv_v = nn.Conv1d(channels, channels, 1) - self.conv_o = nn.Conv1d(channels, out_channels, 1) - self.drop = nn.Dropout(p_dropout) - - if window_size is not None: - n_heads_rel = 1 if heads_share else n_heads - rel_stddev = self.k_channels**-0.5 - self.emb_rel_k = nn.Parameter( - torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) - * rel_stddev - ) - self.emb_rel_v = nn.Parameter( - torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) - * rel_stddev - ) - - nn.init.xavier_uniform_(self.conv_q.weight) - nn.init.xavier_uniform_(self.conv_k.weight) - nn.init.xavier_uniform_(self.conv_v.weight) - if proximal_init: - with torch.no_grad(): - self.conv_k.weight.copy_(self.conv_q.weight) - self.conv_k.bias.copy_(self.conv_q.bias) - - def forward( - self, x: torch.Tensor, c: torch.Tensor, attn_mask: Optional[torch.Tensor] = None - ): - q = self.conv_q(x) - k = self.conv_k(c) - v = self.conv_v(c) - - x, _ = self.attention(q, k, v, mask=attn_mask) - - x = self.conv_o(x) - return x - - def attention( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - mask: Optional[torch.Tensor] = None, - ): - # reshape [b, d, t] -> [b, n_h, t, d_k] - b, d, t_s = key.size() - t_t = query.size(2) - query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) - key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) - value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) - - scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) - if self.window_size is not None: - assert ( - t_s == t_t - ), "Relative attention is only available for self-attention." - key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) - rel_logits = self._matmul_with_relative_keys( - query / math.sqrt(self.k_channels), key_relative_embeddings - ) - scores_local = self._relative_position_to_absolute_position(rel_logits) - scores = scores + scores_local - if self.proximal_bias: - assert t_s == t_t, "Proximal bias is only available for self-attention." - scores = scores + self._attention_bias_proximal(t_s).to( - device=scores.device, dtype=scores.dtype - ) - if mask is not None: - scores = scores.masked_fill(mask == 0, -1e4) - if self.block_length is not None: - assert ( - t_s == t_t - ), "Local attention is only available for self-attention." - block_mask = ( - torch.ones_like(scores) - .triu(-self.block_length) - .tril(self.block_length) - ) - scores = scores.masked_fill(block_mask == 0, -1e4) - p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] - p_attn = self.drop(p_attn) - output = torch.matmul(p_attn, value) - if self.window_size is not None: - relative_weights = self._absolute_position_to_relative_position(p_attn) - value_relative_embeddings = self._get_relative_embeddings( - self.emb_rel_v, t_s - ) - output = output + self._matmul_with_relative_values( - relative_weights, value_relative_embeddings - ) - output = ( - output.transpose(2, 3).contiguous().view(b, d, t_t) - ) # [b, n_h, t_t, d_k] -> [b, d, t_t] - return output, p_attn - - def _matmul_with_relative_values(self, x, y): - """ - x: [b, h, l, m] - y: [h or 1, m, d] - ret: [b, h, l, d] - """ - ret = torch.matmul(x, y.unsqueeze(0)) - return ret - - def _matmul_with_relative_keys(self, x, y): - """ - x: [b, h, l, d] - y: [h or 1, m, d] - ret: [b, h, l, m] - """ - ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) - return ret - - def _get_relative_embeddings(self, relative_embeddings, length: int): - max_relative_position = 2 * self.window_size + 1 - # Pad first before slice to avoid using cond ops. - pad_length: int = max(length - (self.window_size + 1), 0) - slice_start_position = max((self.window_size + 1) - length, 0) - slice_end_position = slice_start_position + 2 * length - 1 - if pad_length > 0: - padded_relative_embeddings = F.pad( - relative_embeddings, - # commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), - [0, 0, pad_length, pad_length, 0, 0], - ) - else: - padded_relative_embeddings = relative_embeddings - used_relative_embeddings = padded_relative_embeddings[ - :, slice_start_position:slice_end_position - ] - return used_relative_embeddings - - def _relative_position_to_absolute_position(self, x): - """ - x: [b, h, l, 2*l-1] - ret: [b, h, l, l] - """ - batch, heads, length, _ = x.size() - # Concat columns of pad to shift from relative to absolute indexing. - x = F.pad( - x, - # commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]) - [0, 1, 0, 0, 0, 0, 0, 0], - ) - - # Concat extra elements so to add up to shape (len+1, 2*len-1). - x_flat = x.view([batch, heads, length * 2 * length]) - x_flat = F.pad( - x_flat, - # commons.convert_pad_shape([[0, 0], [0, 0], [0, int(length) - 1]]) - [0, length - 1, 0, 0, 0, 0], - ) - - # Reshape and slice out the padded elements. - x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[ - :, :, :length, length - 1 : - ] - return x_final - - def _absolute_position_to_relative_position(self, x): - """ - x: [b, h, l, l] - ret: [b, h, l, 2*l-1] - """ - batch, heads, length, _ = x.size() - # padd along column - x = F.pad( - x, - # commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, int(length) - 1]]) - [0, length - 1, 0, 0, 0, 0, 0, 0], - ) - x_flat = x.view([batch, heads, (length**2) + (length * (length - 1))]) - # add 0's in the beginning that will skew the elements after reshape - x_flat = F.pad( - x_flat, - # commons.convert_pad_shape([[0, 0], [0, 0], [int(length), 0]]) - [length, 0, 0, 0, 0, 0], - ) - x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] - return x_final - - def _attention_bias_proximal(self, length: int): - """Bias for self-attention to encourage attention to close positions. - Args: - length: an integer scalar. - Returns: - a Tensor with shape [1, 1, length, length] - """ - r = torch.arange(length, dtype=torch.float32) - diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) - return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) - - -class FFN(nn.Module): - def __init__( - self, - in_channels, - out_channels, - filter_channels, - kernel_size, - p_dropout=0.0, - activation: str = None, - causal=False, - ): - super(FFN, self).__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.filter_channels = filter_channels - self.kernel_size = kernel_size - self.p_dropout = p_dropout - self.activation = activation - self.causal = causal - self.is_activation = True if activation == "gelu" else False - # if causal: - # self.padding = self._causal_padding - # else: - # self.padding = self._same_padding - - self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) - self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) - self.drop = nn.Dropout(p_dropout) - - def padding(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor: - if self.causal: - padding = self._causal_padding(x * x_mask) - else: - padding = self._same_padding(x * x_mask) - return padding - - def forward(self, x: torch.Tensor, x_mask: torch.Tensor): - x = self.conv_1(self.padding(x, x_mask)) - if self.is_activation: - x = x * torch.sigmoid(1.702 * x) - else: - x = torch.relu(x) - x = self.drop(x) - - x = self.conv_2(self.padding(x, x_mask)) - return x * x_mask - - def _causal_padding(self, x): - if self.kernel_size == 1: - return x - pad_l: int = self.kernel_size - 1 - pad_r: int = 0 - # padding = [[0, 0], [0, 0], [pad_l, pad_r]] - x = F.pad( - x, - # commons.convert_pad_shape(padding) - [pad_l, pad_r, 0, 0, 0, 0], - ) - return x - - def _same_padding(self, x): - if self.kernel_size == 1: - return x - pad_l: int = (self.kernel_size - 1) // 2 - pad_r: int = self.kernel_size // 2 - # padding = [[0, 0], [0, 0], [pad_l, pad_r]] - x = F.pad( - x, - # commons.convert_pad_shape(padding) - [pad_l, pad_r, 0, 0, 0, 0], - ) - return x diff --git a/infer/lib/infer_pack/commons.py b/infer/lib/infer_pack/commons.py deleted file mode 100644 index 4ec6c24..0000000 --- a/infer/lib/infer_pack/commons.py +++ /dev/null @@ -1,172 +0,0 @@ -from typing import List, Optional -import math - -import numpy as np -import torch -from torch import nn -from torch.nn import functional as F - - -def init_weights(m, mean=0.0, std=0.01): - classname = m.__class__.__name__ - if classname.find("Conv") != -1: - m.weight.data.normal_(mean, std) - - -def get_padding(kernel_size, dilation=1): - return int((kernel_size * dilation - dilation) / 2) - - -# def convert_pad_shape(pad_shape): -# l = pad_shape[::-1] -# pad_shape = [item for sublist in l for item in sublist] -# return pad_shape - - -def kl_divergence(m_p, logs_p, m_q, logs_q): - """KL(P||Q)""" - kl = (logs_q - logs_p) - 0.5 - kl += ( - 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) - ) - return kl - - -def rand_gumbel(shape): - """Sample from the Gumbel distribution, protect from overflows.""" - uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 - return -torch.log(-torch.log(uniform_samples)) - - -def rand_gumbel_like(x): - g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) - return g - - -def slice_segments(x, ids_str, segment_size=4): - ret = torch.zeros_like(x[:, :, :segment_size]) - for i in range(x.size(0)): - idx_str = ids_str[i] - idx_end = idx_str + segment_size - ret[i] = x[i, :, idx_str:idx_end] - return ret - - -def slice_segments2(x, ids_str, segment_size=4): - ret = torch.zeros_like(x[:, :segment_size]) - for i in range(x.size(0)): - idx_str = ids_str[i] - idx_end = idx_str + segment_size - ret[i] = x[i, idx_str:idx_end] - return ret - - -def rand_slice_segments(x, x_lengths=None, segment_size=4): - b, d, t = x.size() - if x_lengths is None: - x_lengths = t - ids_str_max = x_lengths - segment_size + 1 - ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) - ret = slice_segments(x, ids_str, segment_size) - return ret, ids_str - - -def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): - position = torch.arange(length, dtype=torch.float) - num_timescales = channels // 2 - log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( - num_timescales - 1 - ) - inv_timescales = min_timescale * torch.exp( - torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment - ) - scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) - signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) - signal = F.pad(signal, [0, 0, 0, channels % 2]) - signal = signal.view(1, channels, length) - return signal - - -def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): - b, channels, length = x.size() - signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) - return x + signal.to(dtype=x.dtype, device=x.device) - - -def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): - b, channels, length = x.size() - signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) - return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) - - -def subsequent_mask(length): - mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) - return mask - - -@torch.jit.script -def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): - n_channels_int = n_channels[0] - in_act = input_a + input_b - t_act = torch.tanh(in_act[:, :n_channels_int, :]) - s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) - acts = t_act * s_act - return acts - - -# def convert_pad_shape(pad_shape): -# l = pad_shape[::-1] -# pad_shape = [item for sublist in l for item in sublist] -# return pad_shape - - -def convert_pad_shape(pad_shape: List[List[int]]) -> List[int]: - return torch.tensor(pad_shape).flip(0).reshape(-1).int().tolist() - - -def shift_1d(x): - x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] - return x - - -def sequence_mask(length: torch.Tensor, max_length: Optional[int] = None): - if max_length is None: - max_length = length.max() - x = torch.arange(max_length, dtype=length.dtype, device=length.device) - return x.unsqueeze(0) < length.unsqueeze(1) - - -def generate_path(duration, mask): - """ - duration: [b, 1, t_x] - mask: [b, 1, t_y, t_x] - """ - device = duration.device - - b, _, t_y, t_x = mask.shape - cum_duration = torch.cumsum(duration, -1) - - cum_duration_flat = cum_duration.view(b * t_x) - path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) - path = path.view(b, t_x, t_y) - path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] - path = path.unsqueeze(1).transpose(2, 3) * mask - return path - - -def clip_grad_value_(parameters, clip_value, norm_type=2): - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - parameters = list(filter(lambda p: p.grad is not None, parameters)) - norm_type = float(norm_type) - if clip_value is not None: - clip_value = float(clip_value) - - total_norm = 0 - for p in parameters: - param_norm = p.grad.data.norm(norm_type) - total_norm += param_norm.item() ** norm_type - if clip_value is not None: - p.grad.data.clamp_(min=-clip_value, max=clip_value) - total_norm = total_norm ** (1.0 / norm_type) - return total_norm diff --git a/infer/lib/infer_pack/models.py b/infer/lib/infer_pack/models.py index a1a27e2..89f3142 100644 --- a/infer/lib/infer_pack/models.py +++ b/infer/lib/infer_pack/models.py @@ -1,17 +1,18 @@ import math import logging -from typing import Optional +from typing import Optional, Tuple, List + +from rvc import utils logger = logging.getLogger(__name__) -import numpy as np import torch from torch import nn -from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d +from torch.nn import Conv1d, Conv2d, ConvTranspose1d from torch.nn import functional as F from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm -from infer.lib.infer_pack import attentions, commons, modules -from infer.lib.infer_pack.commons import get_padding, init_weights +from infer.lib.infer_pack import attentions, modules +from rvc.utils import get_padding, call_weight_data_normal_if_Conv has_xpu = bool(hasattr(torch, "xpu") and torch.xpu.is_available()) @@ -51,13 +52,25 @@ class TextEncoder(nn.Module): ) self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + def __call__( + self, + phone: torch.Tensor, + pitch: torch.Tensor, + lengths: torch.Tensor, + # skip_head: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return super().__call__( + phone, pitch, lengths, + # skip_head=skip_head, + ) + def forward( self, phone: torch.Tensor, pitch: torch.Tensor, lengths: torch.Tensor, - skip_head: Optional[torch.Tensor] = None, - ): + # skip_head: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if pitch is None: x = self.emb_phone(phone) else: @@ -65,15 +78,19 @@ class TextEncoder(nn.Module): x = x * math.sqrt(self.hidden_channels) # [b, t, h] x = self.lrelu(x) x = torch.transpose(x, 1, -1) # [b, h, t] - x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to( - x.dtype - ) + x_mask = torch.unsqueeze( + utils.sequence_mask( + lengths, x.size(2), + ), 1, + ).to(x.dtype) x = self.encoder(x * x_mask, x_mask) + """ if skip_head is not None: assert isinstance(skip_head, torch.Tensor) head = int(skip_head.item()) x = x[:, :, head:] x_mask = x_mask[:, :, head:] + """ stats = self.proj(x) * x_mask m, logs = torch.split(stats, self.out_channels, dim=1) return m, logs, x_mask @@ -125,7 +142,7 @@ class ResidualCouplingBlock(nn.Module): for flow in self.flows: x, _ = flow(x, x_mask, g=g, reverse=reverse) else: - for flow in self.flows[::-1]: + for flow in reversed(self.flows): x, _ = flow.forward(x, x_mask, g=g, reverse=reverse) return x @@ -175,12 +192,19 @@ class PosteriorEncoder(nn.Module): ) self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + def __call__( + self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + super().__call__(x, x_lengths, g = g) + def forward( self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None - ): - x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( - x.dtype - ) + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + x_mask = torch.unsqueeze( + utils.sequence_mask( + x_lengths, x.size(2), + ), 1, + ).to(x.dtype) x = self.pre(x) * x_mask x = self.enc(x, x_mask, g=g) stats = self.proj(x) * x_mask @@ -244,7 +268,7 @@ class Generator(torch.nn.Module): self.resblocks.append(resblock(ch, k, d)) self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) - self.ups.apply(init_weights) + self.ups.apply(call_weight_data_normal_if_Conv) if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) @@ -253,13 +277,15 @@ class Generator(torch.nn.Module): self, x: torch.Tensor, g: Optional[torch.Tensor] = None, - n_res: Optional[torch.Tensor] = None, + # n_res: Optional[torch.Tensor] = None, ): + """ if n_res is not None: assert isinstance(n_res, torch.Tensor) n = int(n_res.item()) if n != x.shape[-1]: x = F.interpolate(x, size=n, mode="linear") + """ x = self.conv_pre(x) if g is not None: x = x + self.cond(g) @@ -529,7 +555,7 @@ class GeneratorNSF(torch.nn.Module): self.resblocks.append(resblock(ch, k, d)) self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) - self.ups.apply(init_weights) + self.ups.apply(call_weight_data_normal_if_Conv) if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) @@ -543,10 +569,11 @@ class GeneratorNSF(torch.nn.Module): x, f0, g: Optional[torch.Tensor] = None, - n_res: Optional[torch.Tensor] = None, + # n_res: Optional[torch.Tensor] = None, ): har_source, noi_source, uv = self.m_source(f0, self.upp) har_source = har_source.transpose(1, 2) + """ if n_res is not None: assert isinstance(n_res, torch.Tensor) n = int(n_res.item()) @@ -554,6 +581,7 @@ class GeneratorNSF(torch.nn.Module): har_source = F.interpolate(har_source, size=n * self.upp, mode="linear") if n != x.shape[-1]: x = F.interpolate(x, size=n, mode="linear") + """ x = self.conv_pre(x) if g is not None: x = x + self.cond(g) @@ -611,39 +639,35 @@ class GeneratorNSF(torch.nn.Module): return self -sr2sr = { - "32k": 32000, - "40k": 40000, - "48k": 48000, -} - - class SynthesizerTrnMs256NSFsid(nn.Module): def __init__( self, spec_channels, - segment_size, - inter_channels, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout, - resblock, - resblock_kernel_sizes, - resblock_dilation_sizes, - upsample_rates, - upsample_initial_channel, - upsample_kernel_sizes, - spk_embed_dim, - gin_channels, - sr, + segment_size: int, + inter_channels: int, + hidden_channels: int, + filter_channels: int, + n_heads: int, + n_layers: int, + kernel_size: int, + p_dropout: int, + resblock: str, + resblock_kernel_sizes: List[int], + resblock_dilation_sizes: List[List[int]], + upsample_rates: List[int], + upsample_initial_channel: int, + upsample_kernel_sizes: List[int], + spk_embed_dim: int, + gin_channels: int, + sr: str | int, **kwargs ): super(SynthesizerTrnMs256NSFsid, self).__init__() - if isinstance(sr, str): - sr = sr2sr[sr] + if isinstance(sr, str): sr = { + "32k": 32000, + "40k": 40000, + "48k": 48000, + }[sr] self.spec_channels = spec_channels self.inter_channels = inter_channels self.hidden_channels = hidden_channels @@ -752,11 +776,11 @@ class SynthesizerTrnMs256NSFsid(nn.Module): m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths) z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g) z_p = self.flow(z, y_mask, g=g) - z_slice, ids_slice = commons.rand_slice_segments( + z_slice, ids_slice = utils.rand_slice_segments( z, y_lengths, self.segment_size ) # print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length) - pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size) + pitchf = utils.slice_on_last_dim(pitchf, ids_slice, self.segment_size) # print(-2,pitchf.shape,z_slice.shape) o = self.dec(z_slice, pitchf, g=g) return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q) @@ -771,7 +795,7 @@ class SynthesizerTrnMs256NSFsid(nn.Module): sid: torch.Tensor, skip_head: Optional[torch.Tensor] = None, return_length: Optional[torch.Tensor] = None, - return_length2: Optional[torch.Tensor] = None, + # return_length2: Optional[torch.Tensor] = None, ): g = self.emb_g(sid).unsqueeze(-1) if skip_head is not None and return_length is not None: @@ -791,7 +815,10 @@ class SynthesizerTrnMs256NSFsid(nn.Module): m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths) z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask z = self.flow(z_p, x_mask, g=g, reverse=True) - o = self.dec(z * x_mask, nsff0, g=g, n_res=return_length2) + o = self.dec( + z * x_mask, nsff0, g=g, + # n_res=return_length2, + ) return o, x_mask, (z, z_p, m_p, logs_p) @@ -973,7 +1000,7 @@ class SynthesizerTrnMs256NSFsid_nono(nn.Module): m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths) z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g) z_p = self.flow(z, y_mask, g=g) - z_slice, ids_slice = commons.rand_slice_segments( + z_slice, ids_slice = utils.rand_slice_segments( z, y_lengths, self.segment_size ) o = self.dec(z_slice, g=g) @@ -987,7 +1014,7 @@ class SynthesizerTrnMs256NSFsid_nono(nn.Module): sid: torch.Tensor, skip_head: Optional[torch.Tensor] = None, return_length: Optional[torch.Tensor] = None, - return_length2: Optional[torch.Tensor] = None, + #return_length2: Optional[torch.Tensor] = None, ): g = self.emb_g(sid).unsqueeze(-1) if skip_head is not None and return_length is not None: @@ -1006,7 +1033,10 @@ class SynthesizerTrnMs256NSFsid_nono(nn.Module): m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths) z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask z = self.flow(z_p, x_mask, g=g, reverse=True) - o = self.dec(z * x_mask, g=g, n_res=return_length2) + o = self.dec( + z * x_mask, g=g, + # n_res=return_length2 + ) return o, x_mask, (z, z_p, m_p, logs_p) diff --git a/infer/lib/infer_pack/models_onnx.py b/infer/lib/infer_pack/models_onnx.py index b06bb9a..63c2caa 100644 --- a/infer/lib/infer_pack/models_onnx.py +++ b/infer/lib/infer_pack/models_onnx.py @@ -1,594 +1,7 @@ -import math -import logging -from typing import Optional - -logger = logging.getLogger(__name__) - -import numpy as np import torch from torch import nn -from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d -from torch.nn import functional as F -from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm -from infer.lib.infer_pack import commons, modules -from infer.lib.infer_pack.commons import get_padding, init_weights -import infer.lib.infer_pack.attentions_onnx as attentions - -has_xpu = bool(hasattr(torch, "xpu") and torch.xpu.is_available()) - - -class TextEncoder(nn.Module): - def __init__( - self, - in_channels, - out_channels, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout, - f0=True, - ): - super(TextEncoder, self).__init__() - self.out_channels = out_channels - self.hidden_channels = hidden_channels - self.filter_channels = filter_channels - self.n_heads = n_heads - self.n_layers = n_layers - self.kernel_size = kernel_size - self.p_dropout = float(p_dropout) - self.emb_phone = nn.Linear(in_channels, hidden_channels) - self.lrelu = nn.LeakyReLU(0.1, inplace=True) - if f0 == True: - self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256 - self.encoder = attentions.Encoder( - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - float(p_dropout), - ) - self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - - def forward( - self, - phone: torch.Tensor, - pitch: torch.Tensor, - lengths: torch.Tensor, - skip_head: Optional[torch.Tensor] = None, - ): - if pitch is None: - x = self.emb_phone(phone) - else: - x = self.emb_phone(phone) + self.emb_pitch(pitch) - x = x * math.sqrt(self.hidden_channels) # [b, t, h] - x = self.lrelu(x) - x = torch.transpose(x, 1, -1) # [b, h, t] - x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to( - x.dtype - ) - x = self.encoder(x * x_mask, x_mask) - stats = self.proj(x) * x_mask - m, logs = torch.split(stats, self.out_channels, dim=1) - return m, logs, x_mask - - -class ResidualCouplingBlock(nn.Module): - def __init__( - self, - channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - n_flows=4, - gin_channels=0, - ): - super(ResidualCouplingBlock, self).__init__() - self.channels = channels - self.hidden_channels = hidden_channels - self.kernel_size = kernel_size - self.dilation_rate = dilation_rate - self.n_layers = n_layers - self.n_flows = n_flows - self.gin_channels = gin_channels - - self.flows = nn.ModuleList() - for i in range(n_flows): - self.flows.append( - modules.ResidualCouplingLayer( - channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - gin_channels=gin_channels, - mean_only=True, - ) - ) - self.flows.append(modules.Flip()) - - def forward( - self, - x: torch.Tensor, - x_mask: torch.Tensor, - g: Optional[torch.Tensor] = None, - reverse: bool = False, - ): - if not reverse: - for flow in self.flows: - x, _ = flow(x, x_mask, g=g, reverse=reverse) - else: - for flow in reversed(self.flows): - x, _ = flow.forward(x, x_mask, g=g, reverse=reverse) - return x - - def remove_weight_norm(self): - for i in range(self.n_flows): - self.flows[i * 2].remove_weight_norm() - - def __prepare_scriptable__(self): - for i in range(self.n_flows): - for hook in self.flows[i * 2]._forward_pre_hooks.values(): - if ( - hook.__module__ == "torch.nn.utils.weight_norm" - and hook.__class__.__name__ == "WeightNorm" - ): - torch.nn.utils.remove_weight_norm(self.flows[i * 2]) - - return self - - -class PosteriorEncoder(nn.Module): - def __init__( - self, - in_channels, - out_channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - gin_channels=0, - ): - super(PosteriorEncoder, self).__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.hidden_channels = hidden_channels - self.kernel_size = kernel_size - self.dilation_rate = dilation_rate - self.n_layers = n_layers - self.gin_channels = gin_channels - - self.pre = nn.Conv1d(in_channels, hidden_channels, 1) - self.enc = modules.WN( - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - gin_channels=gin_channels, - ) - self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - - def forward( - self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None - ): - x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( - x.dtype - ) - x = self.pre(x) * x_mask - x = self.enc(x, x_mask, g=g) - stats = self.proj(x) * x_mask - m, logs = torch.split(stats, self.out_channels, dim=1) - z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask - return z, m, logs, x_mask - - def remove_weight_norm(self): - self.enc.remove_weight_norm() - - def __prepare_scriptable__(self): - for hook in self.enc._forward_pre_hooks.values(): - if ( - hook.__module__ == "torch.nn.utils.weight_norm" - and hook.__class__.__name__ == "WeightNorm" - ): - torch.nn.utils.remove_weight_norm(self.enc) - return self - - -class Generator(torch.nn.Module): - def __init__( - self, - initial_channel, - resblock, - resblock_kernel_sizes, - resblock_dilation_sizes, - upsample_rates, - upsample_initial_channel, - upsample_kernel_sizes, - gin_channels=0, - ): - super(Generator, self).__init__() - self.num_kernels = len(resblock_kernel_sizes) - self.num_upsamples = len(upsample_rates) - self.conv_pre = Conv1d( - initial_channel, upsample_initial_channel, 7, 1, padding=3 - ) - resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2 - - self.ups = nn.ModuleList() - for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): - self.ups.append( - weight_norm( - ConvTranspose1d( - upsample_initial_channel // (2**i), - upsample_initial_channel // (2 ** (i + 1)), - k, - u, - padding=(k - u) // 2, - ) - ) - ) - - self.resblocks = nn.ModuleList() - for i in range(len(self.ups)): - ch = upsample_initial_channel // (2 ** (i + 1)) - for j, (k, d) in enumerate( - zip(resblock_kernel_sizes, resblock_dilation_sizes) - ): - self.resblocks.append(resblock(ch, k, d)) - - self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) - self.ups.apply(init_weights) - - if gin_channels != 0: - self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) - - def forward(self, x: torch.Tensor, g: Optional[torch.Tensor] = None): - x = self.conv_pre(x) - if g is not None: - x = x + self.cond(g) - - for i in range(self.num_upsamples): - x = F.leaky_relu(x, modules.LRELU_SLOPE) - x = self.ups[i](x) - xs = None - for j in range(self.num_kernels): - if xs is None: - xs = self.resblocks[i * self.num_kernels + j](x) - else: - xs += self.resblocks[i * self.num_kernels + j](x) - x = xs / self.num_kernels - x = F.leaky_relu(x) - x = self.conv_post(x) - x = torch.tanh(x) - - return x - - def __prepare_scriptable__(self): - for l in self.ups: - for hook in l._forward_pre_hooks.values(): - # The hook we want to remove is an instance of WeightNorm class, so - # normally we would do `if isinstance(...)` but this class is not accessible - # because of shadowing, so we check the module name directly. - # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3 - if ( - hook.__module__ == "torch.nn.utils.weight_norm" - and hook.__class__.__name__ == "WeightNorm" - ): - torch.nn.utils.remove_weight_norm(l) - - for l in self.resblocks: - for hook in l._forward_pre_hooks.values(): - if ( - hook.__module__ == "torch.nn.utils.weight_norm" - and hook.__class__.__name__ == "WeightNorm" - ): - torch.nn.utils.remove_weight_norm(l) - return self - - def remove_weight_norm(self): - for l in self.ups: - remove_weight_norm(l) - for l in self.resblocks: - l.remove_weight_norm() - - -class SineGen(torch.nn.Module): - """Definition of sine generator - SineGen(samp_rate, harmonic_num = 0, - sine_amp = 0.1, noise_std = 0.003, - voiced_threshold = 0, - flag_for_pulse=False) - samp_rate: sampling rate in Hz - harmonic_num: number of harmonic overtones (default 0) - sine_amp: amplitude of sine-wavefrom (default 0.1) - noise_std: std of Gaussian noise (default 0.003) - voiced_thoreshold: F0 threshold for U/V classification (default 0) - flag_for_pulse: this SinGen is used inside PulseGen (default False) - Note: when flag_for_pulse is True, the first time step of a voiced - segment is always sin(torch.pi) or cos(0) - """ - - def __init__( - self, - samp_rate, - harmonic_num=0, - sine_amp=0.1, - noise_std=0.003, - voiced_threshold=0, - flag_for_pulse=False, - ): - super(SineGen, self).__init__() - self.sine_amp = sine_amp - self.noise_std = noise_std - self.harmonic_num = harmonic_num - self.dim = self.harmonic_num + 1 - self.sampling_rate = samp_rate - self.voiced_threshold = voiced_threshold - - def _f02uv(self, f0): - # generate uv signal - uv = torch.ones_like(f0) - uv = uv * (f0 > self.voiced_threshold) - if uv.device.type == "privateuseone": # for DirectML - uv = uv.float() - return uv - - def forward(self, f0: torch.Tensor, upp: int): - """sine_tensor, uv = forward(f0) - input F0: tensor(batchsize=1, length, dim=1) - f0 for unvoiced steps should be 0 - output sine_tensor: tensor(batchsize=1, length, dim) - output uv: tensor(batchsize=1, length, 1) - """ - with torch.no_grad(): - f0 = f0[:, None].transpose(1, 2) - f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device) - # fundamental component - f0_buf[:, :, 0] = f0[:, :, 0] - for idx in range(self.harmonic_num): - f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * ( - idx + 2 - ) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic - rad_values = ( - f0_buf / self.sampling_rate - ) % 1 ###%1意味着n_har的乘积无法后处理优化 - rand_ini = torch.rand( - f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device - ) - rand_ini[:, 0] = 0 - rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini - tmp_over_one = torch.cumsum( - rad_values, 1 - ) # % 1 #####%1意味着后面的cumsum无法再优化 - tmp_over_one *= upp - tmp_over_one = F.interpolate( - tmp_over_one.transpose(2, 1), - scale_factor=float(upp), - mode="linear", - align_corners=True, - ).transpose(2, 1) - rad_values = F.interpolate( - rad_values.transpose(2, 1), scale_factor=float(upp), mode="nearest" - ).transpose( - 2, 1 - ) ####### - tmp_over_one %= 1 - tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0 - cumsum_shift = torch.zeros_like(rad_values) - cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 - sine_waves = torch.sin( - torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * torch.pi - ) - sine_waves = sine_waves * self.sine_amp - uv = self._f02uv(f0) - uv = F.interpolate( - uv.transpose(2, 1), scale_factor=float(upp), mode="nearest" - ).transpose(2, 1) - noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 - noise = noise_amp * torch.randn_like(sine_waves) - sine_waves = sine_waves * uv + noise - return sine_waves, uv, noise - - -class SourceModuleHnNSF(torch.nn.Module): - """SourceModule for hn-nsf - SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, - add_noise_std=0.003, voiced_threshod=0) - sampling_rate: sampling_rate in Hz - harmonic_num: number of harmonic above F0 (default: 0) - sine_amp: amplitude of sine source signal (default: 0.1) - add_noise_std: std of additive Gaussian noise (default: 0.003) - note that amplitude of noise in unvoiced is decided - by sine_amp - voiced_threshold: threhold to set U/V given F0 (default: 0) - Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) - F0_sampled (batchsize, length, 1) - Sine_source (batchsize, length, 1) - noise_source (batchsize, length 1) - uv (batchsize, length, 1) - """ - - def __init__( - self, - sampling_rate, - harmonic_num=0, - sine_amp=0.1, - add_noise_std=0.003, - voiced_threshod=0, - is_half=True, - ): - super(SourceModuleHnNSF, self).__init__() - - self.sine_amp = sine_amp - self.noise_std = add_noise_std - self.is_half = is_half - # to produce sine waveforms - self.l_sin_gen = SineGen( - sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod - ) - - # to merge source harmonics into a single excitation - self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) - self.l_tanh = torch.nn.Tanh() - # self.ddtype:int = -1 - - def forward(self, x: torch.Tensor, upp: int = 1): - # if self.ddtype ==-1: - # self.ddtype = self.l_linear.weight.dtype - sine_wavs, uv, _ = self.l_sin_gen(x, upp) - # print(x.dtype,sine_wavs.dtype,self.l_linear.weight.dtype) - # if self.is_half: - # sine_wavs = sine_wavs.half() - # sine_merge = self.l_tanh(self.l_linear(sine_wavs.to(x))) - # print(sine_wavs.dtype,self.ddtype) - # if sine_wavs.dtype != self.l_linear.weight.dtype: - sine_wavs = sine_wavs.to(dtype=self.l_linear.weight.dtype) - sine_merge = self.l_tanh(self.l_linear(sine_wavs)) - return sine_merge, None, None # noise, uv - - -class GeneratorNSF(torch.nn.Module): - def __init__( - self, - initial_channel, - resblock, - resblock_kernel_sizes, - resblock_dilation_sizes, - upsample_rates, - upsample_initial_channel, - upsample_kernel_sizes, - gin_channels, - sr, - is_half=False, - ): - super(GeneratorNSF, self).__init__() - self.num_kernels = len(resblock_kernel_sizes) - self.num_upsamples = len(upsample_rates) - - self.f0_upsamp = torch.nn.Upsample(scale_factor=math.prod(upsample_rates)) - self.m_source = SourceModuleHnNSF( - sampling_rate=sr, harmonic_num=0, is_half=is_half - ) - self.noise_convs = nn.ModuleList() - self.conv_pre = Conv1d( - initial_channel, upsample_initial_channel, 7, 1, padding=3 - ) - resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2 - - self.ups = nn.ModuleList() - for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): - c_cur = upsample_initial_channel // (2 ** (i + 1)) - self.ups.append( - weight_norm( - ConvTranspose1d( - upsample_initial_channel // (2**i), - upsample_initial_channel // (2 ** (i + 1)), - k, - u, - padding=(k - u) // 2, - ) - ) - ) - if i + 1 < len(upsample_rates): - stride_f0 = math.prod(upsample_rates[i + 1 :]) - self.noise_convs.append( - Conv1d( - 1, - c_cur, - kernel_size=stride_f0 * 2, - stride=stride_f0, - padding=stride_f0 // 2, - ) - ) - else: - self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1)) - - self.resblocks = nn.ModuleList() - for i in range(len(self.ups)): - ch = upsample_initial_channel // (2 ** (i + 1)) - for j, (k, d) in enumerate( - zip(resblock_kernel_sizes, resblock_dilation_sizes) - ): - self.resblocks.append(resblock(ch, k, d)) - - self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) - self.ups.apply(init_weights) - - if gin_channels != 0: - self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) - - self.upp = math.prod(upsample_rates) - - self.lrelu_slope = modules.LRELU_SLOPE - - def forward(self, x, f0, g: Optional[torch.Tensor] = None): - har_source, noi_source, uv = self.m_source(f0, self.upp) - har_source = har_source.transpose(1, 2) - x = self.conv_pre(x) - if g is not None: - x = x + self.cond(g) - # torch.jit.script() does not support direct indexing of torch modules - # That's why I wrote this - for i, (ups, noise_convs) in enumerate(zip(self.ups, self.noise_convs)): - if i < self.num_upsamples: - x = F.leaky_relu(x, self.lrelu_slope) - x = ups(x) - x_source = noise_convs(har_source) - x = x + x_source - xs: Optional[torch.Tensor] = None - l = [i * self.num_kernels + j for j in range(self.num_kernels)] - for j, resblock in enumerate(self.resblocks): - if j in l: - if xs is None: - xs = resblock(x) - else: - xs += resblock(x) - # This assertion cannot be ignored! \ - # If ignored, it will cause torch.jit.script() compilation errors - assert isinstance(xs, torch.Tensor) - x = xs / self.num_kernels - x = F.leaky_relu(x) - x = self.conv_post(x) - x = torch.tanh(x) - return x - - def remove_weight_norm(self): - for l in self.ups: - remove_weight_norm(l) - for l in self.resblocks: - l.remove_weight_norm() - - def __prepare_scriptable__(self): - for l in self.ups: - for hook in l._forward_pre_hooks.values(): - # The hook we want to remove is an instance of WeightNorm class, so - # normally we would do `if isinstance(...)` but this class is not accessible - # because of shadowing, so we check the module name directly. - # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3 - if ( - hook.__module__ == "torch.nn.utils.weight_norm" - and hook.__class__.__name__ == "WeightNorm" - ): - torch.nn.utils.remove_weight_norm(l) - for l in self.resblocks: - for hook in self.resblocks._forward_pre_hooks.values(): - if ( - hook.__module__ == "torch.nn.utils.weight_norm" - and hook.__class__.__name__ == "WeightNorm" - ): - torch.nn.utils.remove_weight_norm(l) - return self - - -sr2sr = { - "32k": 32000, - "40k": 40000, - "48k": 48000, -} +from .attentions import TextEncoder, ResidualCouplingBlock, PosteriorEncoder, Generator, SineGen, SourceModuleHnNSF, GeneratorNSF class SynthesizerTrnMsNSFsidM(nn.Module): @@ -616,8 +29,11 @@ class SynthesizerTrnMsNSFsidM(nn.Module): **kwargs ): super(SynthesizerTrnMsNSFsidM, self).__init__() - if isinstance(sr, str): - sr = sr2sr[sr] + if isinstance(sr, str): sr = { + "32k": 32000, + "40k": 40000, + "48k": 48000, + }[sr] self.spec_channels = spec_channels self.inter_channels = inter_channels self.hidden_channels = hidden_channels @@ -671,12 +87,6 @@ class SynthesizerTrnMsNSFsidM(nn.Module): inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels ) self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels) - logger.debug( - "gin_channels: " - + str(gin_channels) - + ", self.spk_embed_dim: " - + str(self.spk_embed_dim) - ) self.speaker_map = None def remove_weight_norm(self): @@ -705,177 +115,3 @@ class SynthesizerTrnMsNSFsidM(nn.Module): z = self.flow(z_p, x_mask, g=g, reverse=True) o = self.dec((z * x_mask)[:, :, :max_len], nsff0, g=g) return o - - -class MultiPeriodDiscriminator(torch.nn.Module): - def __init__(self, use_spectral_norm=False): - super(MultiPeriodDiscriminator, self).__init__() - periods = [2, 3, 5, 7, 11, 17] - # periods = [3, 5, 7, 11, 17, 23, 37] - - discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] - discs = discs + [ - DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods - ] - self.discriminators = nn.ModuleList(discs) - - def forward(self, y, y_hat): - y_d_rs = [] # - y_d_gs = [] - fmap_rs = [] - fmap_gs = [] - for i, d in enumerate(self.discriminators): - y_d_r, fmap_r = d(y) - y_d_g, fmap_g = d(y_hat) - # for j in range(len(fmap_r)): - # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape) - y_d_rs.append(y_d_r) - y_d_gs.append(y_d_g) - fmap_rs.append(fmap_r) - fmap_gs.append(fmap_g) - - return y_d_rs, y_d_gs, fmap_rs, fmap_gs - - -class MultiPeriodDiscriminatorV2(torch.nn.Module): - def __init__(self, use_spectral_norm=False): - super(MultiPeriodDiscriminatorV2, self).__init__() - # periods = [2, 3, 5, 7, 11, 17] - periods = [2, 3, 5, 7, 11, 17, 23, 37] - - discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] - discs = discs + [ - DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods - ] - self.discriminators = nn.ModuleList(discs) - - def forward(self, y, y_hat): - y_d_rs = [] # - y_d_gs = [] - fmap_rs = [] - fmap_gs = [] - for i, d in enumerate(self.discriminators): - y_d_r, fmap_r = d(y) - y_d_g, fmap_g = d(y_hat) - # for j in range(len(fmap_r)): - # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape) - y_d_rs.append(y_d_r) - y_d_gs.append(y_d_g) - fmap_rs.append(fmap_r) - fmap_gs.append(fmap_g) - - return y_d_rs, y_d_gs, fmap_rs, fmap_gs - - -class DiscriminatorS(torch.nn.Module): - def __init__(self, use_spectral_norm=False): - super(DiscriminatorS, self).__init__() - norm_f = weight_norm if use_spectral_norm == False else spectral_norm - self.convs = nn.ModuleList( - [ - norm_f(Conv1d(1, 16, 15, 1, padding=7)), - norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), - norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), - norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), - norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), - norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), - ] - ) - self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) - - def forward(self, x): - fmap = [] - - for l in self.convs: - x = l(x) - x = F.leaky_relu(x, modules.LRELU_SLOPE) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - x = torch.flatten(x, 1, -1) - - return x, fmap - - -class DiscriminatorP(torch.nn.Module): - def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): - super(DiscriminatorP, self).__init__() - self.period = period - self.use_spectral_norm = use_spectral_norm - norm_f = weight_norm if use_spectral_norm == False else spectral_norm - self.convs = nn.ModuleList( - [ - norm_f( - Conv2d( - 1, - 32, - (kernel_size, 1), - (stride, 1), - padding=(get_padding(kernel_size, 1), 0), - ) - ), - norm_f( - Conv2d( - 32, - 128, - (kernel_size, 1), - (stride, 1), - padding=(get_padding(kernel_size, 1), 0), - ) - ), - norm_f( - Conv2d( - 128, - 512, - (kernel_size, 1), - (stride, 1), - padding=(get_padding(kernel_size, 1), 0), - ) - ), - norm_f( - Conv2d( - 512, - 1024, - (kernel_size, 1), - (stride, 1), - padding=(get_padding(kernel_size, 1), 0), - ) - ), - norm_f( - Conv2d( - 1024, - 1024, - (kernel_size, 1), - 1, - padding=(get_padding(kernel_size, 1), 0), - ) - ), - ] - ) - self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) - - def forward(self, x): - fmap = [] - - # 1d to 2d - b, c, t = x.shape - if t % self.period != 0: # pad first - n_pad = self.period - (t % self.period) - if has_xpu and x.dtype == torch.bfloat16: - x = F.pad(x.to(dtype=torch.float16), (0, n_pad), "reflect").to( - dtype=torch.bfloat16 - ) - else: - x = F.pad(x, (0, n_pad), "reflect") - t = t + n_pad - x = x.view(b, c, t // self.period, self.period) - - for l in self.convs: - x = l(x) - x = F.leaky_relu(x, modules.LRELU_SLOPE) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - x = torch.flatten(x, 1, -1) - - return x, fmap diff --git a/infer/lib/infer_pack/modules.py b/infer/lib/infer_pack/modules.py index 51aeaf0..0587c18 100644 --- a/infer/lib/infer_pack/modules.py +++ b/infer/lib/infer_pack/modules.py @@ -10,8 +10,8 @@ from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d from torch.nn import functional as F from torch.nn.utils import remove_weight_norm, weight_norm -from infer.lib.infer_pack import commons -from infer.lib.infer_pack.commons import get_padding, init_weights +from rvc import utils +from rvc.utils import get_padding, call_weight_data_normal_if_Conv from infer.lib.infer_pack.transforms import piecewise_rational_quadratic_transform LRELU_SLOPE = 0.1 @@ -204,7 +204,7 @@ class WN(torch.nn.Module): else: g_l = torch.zeros_like(x_in) - acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) + acts = utils.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) acts = self.drop(acts) res_skip_acts = res_skip_layer(acts) @@ -286,7 +286,7 @@ class ResBlock1(torch.nn.Module): ), ] ) - self.convs1.apply(init_weights) + self.convs1.apply(call_weight_data_normal_if_Conv) self.convs2 = nn.ModuleList( [ @@ -322,7 +322,7 @@ class ResBlock1(torch.nn.Module): ), ] ) - self.convs2.apply(init_weights) + self.convs2.apply(call_weight_data_normal_if_Conv) self.lrelu_slope = LRELU_SLOPE def forward(self, x: torch.Tensor, x_mask: Optional[torch.Tensor] = None): @@ -391,7 +391,7 @@ class ResBlock2(torch.nn.Module): ), ] ) - self.convs.apply(init_weights) + self.convs.apply(call_weight_data_normal_if_Conv) self.lrelu_slope = LRELU_SLOPE def forward(self, x, x_mask: Optional[torch.Tensor] = None): diff --git a/infer/modules/train/train.py b/infer/modules/train/train.py index 6b37080..a6f3550 100644 --- a/infer/modules/train/train.py +++ b/infer/modules/train/train.py @@ -46,7 +46,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter -from infer.lib.infer_pack import commons +from rvc import utils from infer.lib.train.data_utils import ( DistributedBucketSampler, TextAudioCollate, @@ -452,7 +452,7 @@ def train_and_evaluate( hps.data.mel_fmin, hps.data.mel_fmax, ) - y_mel = commons.slice_segments( + y_mel = utils.slice_on_last_dim( mel, ids_slice, hps.train.segment_size // hps.data.hop_length ) with autocast(enabled=False): @@ -468,7 +468,7 @@ def train_and_evaluate( ) if hps.train.fp16_run == True: y_hat_mel = y_hat_mel.half() - wave = commons.slice_segments( + wave = utils.slice_on_last_dim( wave, ids_slice * hps.data.hop_length, hps.train.segment_size ) # slice @@ -481,7 +481,7 @@ def train_and_evaluate( optim_d.zero_grad() scaler.scale(loss_disc).backward() scaler.unscale_(optim_d) - grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None) + grad_norm_d = utils.clip_grad_value_(net_d.parameters(), None) scaler.step(optim_d) with autocast(enabled=hps.train.fp16_run): @@ -496,7 +496,7 @@ def train_and_evaluate( optim_g.zero_grad() scaler.scale(loss_gen_all).backward() scaler.unscale_(optim_g) - grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None) + grad_norm_g = utils.clip_grad_value_(net_g.parameters(), None) scaler.step(optim_g) scaler.update() diff --git a/rvc/utils.py b/rvc/utils.py new file mode 100644 index 0000000..4cf5cef --- /dev/null +++ b/rvc/utils.py @@ -0,0 +1,79 @@ +from typing import List, Optional, Tuple + +import torch + +def call_weight_data_normal_if_Conv(m: torch.nn.Module): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + mean=0.0 + std=0.01 + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size: int, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def slice_on_last_dim( + x: torch.Tensor, start_indices: List[int], segment_size=4, + ) -> torch.Tensor: + new_shape = x.shape + new_shape[-1] = segment_size + ret = torch.empty(new_shape) + for i in range(x.size(0)): + idx_str = start_indices[i] + idx_end = idx_str + segment_size + ret[i, ..., :] = x[i, ..., idx_str:idx_end] + return ret + + +def rand_slice_segments( + x: torch.Tensor, x_lengths: int = None, segment_size=4, + ) -> Tuple[torch.Tensor, List[int]]: + b, _, t = x.size() + if x_lengths is None: x_lengths = t + ids_str_max = x_lengths - segment_size + 1 + ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_on_last_dim(x, ids_str, segment_size) + return ret, ids_str + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +def convert_pad_shape(pad_shape: List[List[int]]) -> List[int]: + return torch.tensor(pad_shape).flip(0).reshape(-1).int().tolist() + + +def sequence_mask( + length: torch.Tensor, max_length: Optional[int] = None, + ) -> torch.BoolTensor: + if max_length is None: + max_length = int(length.max()) + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def clip_grad_value_(parameters, clip_value, norm_type=2): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + if clip_value is not None: + clip_value = float(clip_value) + + total_norm = 0 + for p in parameters: + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item() ** norm_type + if clip_value is not None: + p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm = total_norm ** (1.0 / norm_type) + return total_norm