1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-05 01:10:22 +08:00

optimize(infer): move PosteriorEncoder into rvc

This commit is contained in:
源文雨
2024-06-09 14:33:20 +09:00
parent 00cd60b47f
commit 62e6e598ae
3 changed files with 67 additions and 73 deletions

View File

@@ -1,5 +1,5 @@
import math
from typing import Optional, Tuple, List
from typing import Optional, List
import torch
from torch import nn
@@ -8,82 +8,18 @@ from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
from rvc import residuals
from rvc.norms import WN
from rvc.residuals import ResidualCouplingBlock
from rvc.utils import (
get_padding,
call_weight_data_normal_if_Conv,
sequence_mask,
slice_on_last_dim,
rand_slice_segments_on_last_dim,
)
from rvc.encoders import TextEncoder
from rvc.encoders import TextEncoder, PosteriorEncoder
has_xpu = bool(hasattr(torch, "xpu") and torch.xpu.is_available())
class PosteriorEncoder(nn.Module):
    """Encode an input sequence into a sampled latent plus its statistics.

    The input goes through a 1x1 convolution (``pre``), a ``WN`` stack
    (``enc``) and a second 1x1 convolution (``proj``) that emits
    ``2 * out_channels`` channels.  These are split along the channel axis
    into a mean ``m`` and a log-scale ``logs``, from which ``z`` is sampled
    as ``m + noise * exp(logs)``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: int,
        kernel_size: int,
        dilation_rate: int,
        n_layers: int,
        gin_channels=0,
    ):
        """Build the pre-projection, the ``WN`` encoder and the stats projection.

        Args:
            in_channels: Channel count of the input fed to ``pre``.
            out_channels: Channel count of each of ``m`` and ``logs``.
            hidden_channels: Channel count used between ``pre`` and ``proj``.
            kernel_size: Forwarded to ``WN``.
            dilation_rate: Forwarded to ``WN``.
            n_layers: Forwarded to ``WN``.
            gin_channels: Forwarded to ``WN`` as ``gin_channels``.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        # Twice out_channels: one half becomes the mean, the other the log-scale.
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def __call__(
        self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Typed wrapper around ``nn.Module.__call__``; see ``forward``."""
        # The result must be returned explicitly, otherwise every call to the
        # module evaluates to None instead of the (z, m, logs, x_mask) tuple.
        return super().__call__(x, x_lengths, g=g)

    def forward(
        self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the encoder and draw a latent sample.

        Args:
            x: Input tensor; its third dimension (``x.size(2)``) is used as
                the mask length, and it is fed to a ``Conv1d``.
            x_lengths: Per-item lengths passed to ``sequence_mask``
                (presumably the valid length of each batch item; confirm
                against ``rvc.utils.sequence_mask``).
            g: Optional conditioning tensor forwarded to ``WN``.

        Returns:
            ``(z, m, logs, x_mask)`` where ``z`` is the masked sample,
            ``m`` and ``logs`` are the two halves of the projected stats,
            and ``x_mask`` is the mask with a singleton dim inserted at
            position 1, cast to ``x.dtype``.
        """
        x_mask = torch.unsqueeze(
            sequence_mask(x_lengths, x.size(2)),
            1,
        ).to(x.dtype)
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        # Sample around the mean with scale exp(logs); mask again so that
        # masked-out positions stay zero.
        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask

    def remove_weight_norm(self):
        """Delegate weight-norm removal to the wrapped ``WN`` encoder."""
        self.enc.remove_weight_norm()

    def __prepare_scriptable__(self):
        """Strip a weight-norm hook registered on ``self.enc``; return ``self``."""
        # Iterate over a snapshot: remove_weight_norm deletes the hook from the
        # very dict being walked, which would otherwise invalidate the iterator.
        for hook in list(self.enc._forward_pre_hooks.values()):
            if (
                hook.__module__ == "torch.nn.utils.weight_norm"
                and hook.__class__.__name__ == "WeightNorm"
            ):
                torch.nn.utils.remove_weight_norm(self.enc)
                # remove_weight_norm handles the single "weight" hook; a second
                # call would raise because nothing is left to remove.
                break
        return self
class Generator(torch.nn.Module):
def __init__(
self,

View File

@@ -1,12 +1,9 @@
import torch
from torch import nn
from .models import (
PosteriorEncoder,
GeneratorNSF,
)
from .models import GeneratorNSF
from rvc.encoders import TextEncoder
from rvc.encoders import TextEncoder, PosteriorEncoder
from rvc.residuals import ResidualCouplingBlock

View File

@@ -1,11 +1,11 @@
import math
from typing import Tuple
from typing import Tuple, Optional
import torch
from torch import nn
from .attentions import MultiHeadAttention, FFN
from .norms import LayerNorm
from .norms import LayerNorm, WN
from .utils import sequence_mask
@@ -160,3 +160,64 @@ class TextEncoder(nn.Module):
stats: torch.Tensor = self.proj(x) * x_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
return m, logs, x_mask
class PosteriorEncoder(nn.Module):
    """Encode an input sequence into a sampled latent plus its statistics.

    The input goes through a 1x1 convolution (``pre``), a ``WN`` stack
    (``enc``) and a second 1x1 convolution (``proj``) that emits
    ``2 * out_channels`` channels.  These are split along the channel axis
    into a mean ``m`` and a log-scale ``logs``, from which ``z`` is sampled
    as ``m + noise * exp(logs)``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: int,
        kernel_size: int,
        dilation_rate: int,
        n_layers: int,
        gin_channels=0,
    ):
        """Build the pre-projection, the ``WN`` encoder and the stats projection.

        Args:
            in_channels: Channel count of the input fed to ``pre``.
            out_channels: Channel count of each of ``m`` and ``logs``.
            hidden_channels: Channel count used between ``pre`` and ``proj``.
            kernel_size: Forwarded to ``WN``.
            dilation_rate: Forwarded to ``WN``.
            n_layers: Forwarded to ``WN``.
            gin_channels: Forwarded to ``WN`` as ``gin_channels``.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        # Twice out_channels: one half becomes the mean, the other the log-scale.
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def __call__(
        self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Typed wrapper around ``nn.Module.__call__``; see ``forward``."""
        # The result must be returned explicitly, otherwise every call to the
        # module evaluates to None instead of the (z, m, logs, x_mask) tuple.
        return super().__call__(x, x_lengths, g=g)

    def forward(
        self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the encoder and draw a latent sample.

        Args:
            x: Input tensor; its third dimension (``x.size(2)``) is used as
                the mask length, and it is fed to a ``Conv1d``.
            x_lengths: Per-item lengths passed to ``sequence_mask``
                (presumably the valid length of each batch item; confirm
                against ``.utils.sequence_mask``).
            g: Optional conditioning tensor forwarded to ``WN``.

        Returns:
            ``(z, m, logs, x_mask)`` where ``z`` is the masked sample,
            ``m`` and ``logs`` are the two halves of the projected stats,
            and ``x_mask`` is the mask with a singleton dim inserted at
            position 1, cast to ``x.dtype``.
        """
        x_mask = torch.unsqueeze(
            sequence_mask(x_lengths, x.size(2)),
            1,
        ).to(x.dtype)
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        # Sample around the mean with scale exp(logs); mask again so that
        # masked-out positions stay zero.
        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask

    def remove_weight_norm(self):
        """Delegate weight-norm removal to the wrapped ``WN`` encoder."""
        self.enc.remove_weight_norm()

    def __prepare_scriptable__(self):
        """Strip a weight-norm hook registered on ``self.enc``; return ``self``."""
        # Iterate over a snapshot: remove_weight_norm deletes the hook from the
        # very dict being walked, which would otherwise invalidate the iterator.
        for hook in list(self.enc._forward_pre_hooks.values()):
            if (
                hook.__module__ == "torch.nn.utils.weight_norm"
                and hook.__class__.__name__ == "WeightNorm"
            ):
                torch.nn.utils.remove_weight_norm(self.enc)
                # remove_weight_norm handles the single "weight" hook; a second
                # call would raise because nothing is left to remove.
                break
        return self