mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-06 01:30:24 +08:00
optimize(rvc): gather residuals
This commit is contained in:
@@ -9,6 +9,7 @@ from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
|
||||
from rvc import residuals
|
||||
|
||||
from rvc.norms import WN
|
||||
from rvc.residuals import ResidualCouplingBlock
|
||||
from rvc.utils import (
|
||||
get_padding,
|
||||
call_weight_data_normal_if_Conv,
|
||||
@@ -21,92 +22,6 @@ from rvc.encoders import TextEncoder
|
||||
has_xpu = bool(hasattr(torch, "xpu") and torch.xpu.is_available())
|
||||
|
||||
|
||||
class ResidualCouplingBlock(nn.Module):
    """Chain of residual coupling layers, each followed by a ``Flip``.

    ``self.flows`` holds ``2 * n_flows`` modules laid out as
    ``[coupling, flip, coupling, flip, ...]``, so the coupling layers sit at
    the even indices.  ``forward`` runs the chain front-to-back, or
    back-to-front when ``reverse`` is true.

    Args:
        channels, hidden_channels, kernel_size, dilation_rate, n_layers:
            forwarded positionally to every
            ``residuals.ResidualCouplingLayer``.
        n_flows: number of (coupling layer, flip) pairs to create.
        gin_channels: forwarded to every coupling layer as ``gin_channels``.
    """

    class Flip(nn.Module):
        """
        torch.jit.script() Compiled functions
        can't take variable number of arguments or
        use keyword-only arguments with defaults
        """

        def forward(
            self,
            x: torch.Tensor,
            x_mask: torch.Tensor,
            g: Optional[torch.Tensor] = None,
            reverse: bool = False,
        ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
            """Reverse ``x`` along dim 1 and return it with a zero tensor.

            ``x_mask`` and ``g`` are accepted only so the signature matches
            the other flows in the chain; they are not read here.
            """
            x = torch.flip(x, [1])
            if not reverse:
                # One zero per entry of dim 0, matching x's dtype and device.
                logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
                return x, logdet
            else:
                return x, torch.zeros([1], device=x.device)

    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        n_flows=4,
        gin_channels=0,
    ):
        super(ResidualCouplingBlock, self).__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for _ in range(n_flows):
            # Coupling layer lands on an even index, its Flip on the next odd
            # index; remove_weight_norm / __prepare_scriptable__ rely on this.
            self.flows.append(
                residuals.ResidualCouplingLayer(
                    channels,
                    hidden_channels,
                    kernel_size,
                    dilation_rate,
                    n_layers,
                    gin_channels=gin_channels,
                    mean_only=True,
                )
            )
            self.flows.append(self.Flip())

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        g: Optional[torch.Tensor] = None,
        reverse: bool = False,
    ):
        """Run ``x`` through every flow and return the transformed tensor.

        The second value returned by each flow is discarded.  When
        ``reverse`` is true the flows are visited in reversed order and each
        one is called with ``reverse=True``.
        """
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        else:
            for flow in reversed(self.flows):
                # NOTE: calls .forward directly, so module hooks are bypassed
                # on this path (kept exactly as in the original code).
                x, _ = flow.forward(x, x_mask, g=g, reverse=reverse)
        return x

    def remove_weight_norm(self):
        """Call ``remove_weight_norm()`` on every coupling layer (even index)."""
        for i in range(self.n_flows):
            self.flows[i * 2].remove_weight_norm()

    def __prepare_scriptable__(self):
        """Strip ``WeightNorm`` forward-pre-hooks from the coupling layers.

        Returns:
            ``self``, after the hooks have been removed.
        """
        for i in range(self.n_flows):
            layer = self.flows[i * 2]
            # Iterate over a snapshot: remove_weight_norm deletes the matching
            # entry from layer._forward_pre_hooks, and mutating that dict
            # while iterating it raises RuntimeError as soon as another hook
            # follows the removed one.
            for hook in list(layer._forward_pre_hooks.values()):
                if (
                    hook.__module__ == "torch.nn.utils.weight_norm"
                    and hook.__class__.__name__ == "WeightNorm"
                ):
                    # Remove the parametrisation this hook actually manages
                    # instead of always assuming the default name "weight".
                    torch.nn.utils.remove_weight_norm(layer, name=hook.name)

        return self
|
||||
|
||||
|
||||
class PosteriorEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -425,15 +340,15 @@ class SourceModuleHnNSF(torch.nn.Module):
|
||||
class GeneratorNSF(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
initial_channel,
|
||||
resblock,
|
||||
resblock_kernel_sizes,
|
||||
resblock_dilation_sizes,
|
||||
upsample_rates,
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
gin_channels,
|
||||
sr,
|
||||
initial_channel: int,
|
||||
resblock: str,
|
||||
resblock_kernel_sizes: List[int],
|
||||
resblock_dilation_sizes: List[List[int]],
|
||||
upsample_rates: List[int],
|
||||
upsample_initial_channel: int,
|
||||
upsample_kernel_sizes: List[int],
|
||||
gin_channels: int,
|
||||
sr: int,
|
||||
):
|
||||
super(GeneratorNSF, self).__init__()
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
@@ -479,7 +394,7 @@ class GeneratorNSF(torch.nn.Module):
|
||||
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
ch: int = upsample_initial_channel // (2 ** (i + 1))
|
||||
for j, (k, d) in enumerate(
|
||||
zip(resblock_kernel_sizes, resblock_dilation_sizes)
|
||||
):
|
||||
@@ -817,7 +732,7 @@ class SynthesizerTrnMs256NSFsid_nono(nn.Module):
|
||||
p_dropout,
|
||||
resblock: str,
|
||||
resblock_kernel_sizes,
|
||||
resblock_dilation_sizes,
|
||||
resblock_dilation_sizes: List[List[int]],
|
||||
upsample_rates,
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
|
||||
@@ -2,12 +2,12 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
from .models import (
|
||||
ResidualCouplingBlock,
|
||||
PosteriorEncoder,
|
||||
GeneratorNSF,
|
||||
)
|
||||
|
||||
from rvc.encoders import TextEncoder
|
||||
from rvc.residuals import ResidualCouplingBlock
|
||||
|
||||
|
||||
class SynthesizerTrnMsNSFsidM(nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user