1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-05 01:10:22 +08:00

optimize(infer): move PosteriorEncoder into rvc

This commit is contained in:
源文雨
2024-06-09 14:33:20 +09:00
parent 00cd60b47f
commit 62e6e598ae
3 changed files with 67 additions and 73 deletions

View File

@@ -1,5 +1,5 @@
import math import math
from typing import Optional, Tuple, List from typing import Optional, List
import torch import torch
from torch import nn from torch import nn
@@ -8,82 +8,18 @@ from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
from rvc import residuals from rvc import residuals
from rvc.norms import WN
from rvc.residuals import ResidualCouplingBlock from rvc.residuals import ResidualCouplingBlock
from rvc.utils import ( from rvc.utils import (
get_padding, get_padding,
call_weight_data_normal_if_Conv, call_weight_data_normal_if_Conv,
sequence_mask,
slice_on_last_dim, slice_on_last_dim,
rand_slice_segments_on_last_dim, rand_slice_segments_on_last_dim,
) )
from rvc.encoders import TextEncoder from rvc.encoders import TextEncoder, PosteriorEncoder
has_xpu = bool(hasattr(torch, "xpu") and torch.xpu.is_available()) has_xpu = bool(hasattr(torch, "xpu") and torch.xpu.is_available())
class PosteriorEncoder(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
hidden_channels: int,
kernel_size: int,
dilation_rate: int,
n_layers: int,
gin_channels=0,
):
super(PosteriorEncoder, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.gin_channels = gin_channels
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
self.enc = WN(
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=gin_channels,
)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def __call__(
self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
super().__call__(x, x_lengths, g=g)
def forward(
self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
x_mask = torch.unsqueeze(
sequence_mask(x_lengths, x.size(2)),
1,
).to(x.dtype)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
stats = self.proj(x) * x_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
return z, m, logs, x_mask
def remove_weight_norm(self):
self.enc.remove_weight_norm()
def __prepare_scriptable__(self):
for hook in self.enc._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.enc)
return self
class Generator(torch.nn.Module): class Generator(torch.nn.Module):
def __init__( def __init__(
self, self,

View File

@@ -1,12 +1,9 @@
import torch import torch
from torch import nn from torch import nn
from .models import ( from .models import GeneratorNSF
PosteriorEncoder,
GeneratorNSF,
)
from rvc.encoders import TextEncoder from rvc.encoders import TextEncoder, PosteriorEncoder
from rvc.residuals import ResidualCouplingBlock from rvc.residuals import ResidualCouplingBlock

View File

@@ -1,11 +1,11 @@
import math import math
from typing import Tuple from typing import Tuple, Optional
import torch import torch
from torch import nn from torch import nn
from .attentions import MultiHeadAttention, FFN from .attentions import MultiHeadAttention, FFN
from .norms import LayerNorm from .norms import LayerNorm, WN
from .utils import sequence_mask from .utils import sequence_mask
@@ -160,3 +160,64 @@ class TextEncoder(nn.Module):
stats: torch.Tensor = self.proj(x) * x_mask stats: torch.Tensor = self.proj(x) * x_mask
m, logs = torch.split(stats, self.out_channels, dim=1) m, logs = torch.split(stats, self.out_channels, dim=1)
return m, logs, x_mask return m, logs, x_mask
class PosteriorEncoder(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
hidden_channels: int,
kernel_size: int,
dilation_rate: int,
n_layers: int,
gin_channels=0,
):
super(PosteriorEncoder, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.gin_channels = gin_channels
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
self.enc = WN(
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=gin_channels,
)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def __call__(
self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
super().__call__(x, x_lengths, g=g)
def forward(
self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
x_mask = torch.unsqueeze(
sequence_mask(x_lengths, x.size(2)),
1,
).to(x.dtype)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
stats = self.proj(x) * x_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
return z, m, logs, x_mask
def remove_weight_norm(self):
self.enc.remove_weight_norm()
def __prepare_scriptable__(self):
for hook in self.enc._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.enc)
return self