From 62e6e598ae1105086c3d21cdbd5c0dbb4c299db2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Sun, 9 Jun 2024 14:33:20 +0900 Subject: [PATCH] optimize(infer): move PosteriorEncoder into rvc --- infer/lib/infer_pack/models.py | 68 +---------------------------- infer/lib/infer_pack/models_onnx.py | 7 +-- rvc/encoders.py | 65 ++++++++++++++++++++++++++- 3 files changed, 67 insertions(+), 73 deletions(-) diff --git a/infer/lib/infer_pack/models.py b/infer/lib/infer_pack/models.py index b923baf..b8d21e1 100644 --- a/infer/lib/infer_pack/models.py +++ b/infer/lib/infer_pack/models.py @@ -1,5 +1,5 @@ import math -from typing import Optional, Tuple, List +from typing import Optional, List import torch from torch import nn @@ -8,82 +8,18 @@ from torch.nn import functional as F from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm from rvc import residuals -from rvc.norms import WN from rvc.residuals import ResidualCouplingBlock from rvc.utils import ( get_padding, call_weight_data_normal_if_Conv, - sequence_mask, slice_on_last_dim, rand_slice_segments_on_last_dim, ) -from rvc.encoders import TextEncoder +from rvc.encoders import TextEncoder, PosteriorEncoder has_xpu = bool(hasattr(torch, "xpu") and torch.xpu.is_available()) -class PosteriorEncoder(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - hidden_channels: int, - kernel_size: int, - dilation_rate: int, - n_layers: int, - gin_channels=0, - ): - super(PosteriorEncoder, self).__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.hidden_channels = hidden_channels - self.kernel_size = kernel_size - self.dilation_rate = dilation_rate - self.n_layers = n_layers - self.gin_channels = gin_channels - - self.pre = nn.Conv1d(in_channels, hidden_channels, 1) - self.enc = WN( - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - gin_channels=gin_channels, - ) - 
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - - def __call__( - self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - super().__call__(x, x_lengths, g=g) - - def forward( - self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - x_mask = torch.unsqueeze( - sequence_mask(x_lengths, x.size(2)), - 1, - ).to(x.dtype) - x = self.pre(x) * x_mask - x = self.enc(x, x_mask, g=g) - stats = self.proj(x) * x_mask - m, logs = torch.split(stats, self.out_channels, dim=1) - z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask - return z, m, logs, x_mask - - def remove_weight_norm(self): - self.enc.remove_weight_norm() - - def __prepare_scriptable__(self): - for hook in self.enc._forward_pre_hooks.values(): - if ( - hook.__module__ == "torch.nn.utils.weight_norm" - and hook.__class__.__name__ == "WeightNorm" - ): - torch.nn.utils.remove_weight_norm(self.enc) - return self - - class Generator(torch.nn.Module): def __init__( self, diff --git a/infer/lib/infer_pack/models_onnx.py b/infer/lib/infer_pack/models_onnx.py index 9fbcc53..dcc0853 100644 --- a/infer/lib/infer_pack/models_onnx.py +++ b/infer/lib/infer_pack/models_onnx.py @@ -1,12 +1,9 @@ import torch from torch import nn -from .models import ( - PosteriorEncoder, - GeneratorNSF, -) +from .models import GeneratorNSF -from rvc.encoders import TextEncoder +from rvc.encoders import TextEncoder, PosteriorEncoder from rvc.residuals import ResidualCouplingBlock diff --git a/rvc/encoders.py b/rvc/encoders.py index c51828f..147118f 100644 --- a/rvc/encoders.py +++ b/rvc/encoders.py @@ -1,11 +1,11 @@ import math -from typing import Tuple +from typing import Tuple, Optional import torch from torch import nn from .attentions import MultiHeadAttention, FFN -from .norms import LayerNorm +from .norms import LayerNorm, WN from 
.utils import sequence_mask @@ -160,3 +160,64 @@ class TextEncoder(nn.Module): stats: torch.Tensor = self.proj(x) * x_mask m, logs = torch.split(stats, self.out_channels, dim=1) return m, logs, x_mask + +class PosteriorEncoder(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: int, + kernel_size: int, + dilation_rate: int, + n_layers: int, + gin_channels=0, + ): + super(PosteriorEncoder, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = WN( + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + ) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def __call__( + self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + return super().__call__(x, x_lengths, g=g) + + def forward( + self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + x_mask = torch.unsqueeze( + sequence_mask(x_lengths, x.size(2)), + 1, + ).to(x.dtype) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask + return z, m, logs, x_mask + + def remove_weight_norm(self): + self.enc.remove_weight_norm() + + def __prepare_scriptable__(self): + for hook in self.enc._forward_pre_hooks.values(): + if ( + hook.__module__ == "torch.nn.utils.weight_norm" + and hook.__class__.__name__ == "WeightNorm" + ): + torch.nn.utils.remove_weight_norm(self.enc) + return self