mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-05 01:10:22 +08:00
139 lines
4.7 KiB
Python
139 lines
4.7 KiB
Python
from typing import Optional
|
|
|
|
import torch
|
|
from torch import nn
|
|
from torch.nn import functional as F
|
|
|
|
from .utils import activate_add_tanh_sigmoid_multiply
|
|
|
|
from torch.nn.utils.parametrize import is_parametrized, remove_parametrizations
|
|
|
|
|
|
class LayerNorm(nn.Module):
    """Layer normalisation applied over dimension 1 of the input.

    ``forward`` swaps dimension 1 with the last dimension, runs
    ``F.layer_norm`` over that (now trailing) axis of size ``channels`` using
    the learnable ``gamma`` (scale) and ``beta`` (shift), and swaps the two
    dimensions back so the output has the same layout as the input.
    """

    def __init__(self, channels: int, eps: float = 1e-5):
        """Create the affine parameters.

        Args:
            channels: Size of dimension 1 of the tensors given to ``forward``.
            eps: Value forwarded to ``F.layer_norm`` as its ``eps`` argument.
        """
        super().__init__()
        self.channels = channels
        self.eps = eps

        # Scale starts at one and shift at zero.
        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x: torch.Tensor):
        """Normalise ``x`` along dimension 1 and return it in its original layout."""
        # F.layer_norm normalises trailing dimensions, so move dim 1 to the end.
        channels_last = x.transpose(1, -1)
        normalized = F.layer_norm(
            channels_last,
            normalized_shape=(self.channels,),
            weight=self.gamma,
            bias=self.beta,
            eps=self.eps,
        )
        return normalized.transpose(1, -1)
|
|
|
|
|
|
class WN(torch.nn.Module):
|
|
def __init__(
|
|
self,
|
|
hidden_channels: int,
|
|
kernel_size: int,
|
|
dilation_rate: int,
|
|
n_layers: int,
|
|
gin_channels: int = 0,
|
|
p_dropout: float = 0,
|
|
):
|
|
super(WN, self).__init__()
|
|
assert kernel_size % 2 == 1
|
|
self.hidden_channels = hidden_channels
|
|
self.kernel_size = (kernel_size,)
|
|
self.dilation_rate = dilation_rate
|
|
self.n_layers = n_layers
|
|
self.gin_channels = gin_channels
|
|
self.p_dropout = float(p_dropout)
|
|
|
|
self.in_layers = torch.nn.ModuleList()
|
|
self.res_skip_layers = torch.nn.ModuleList()
|
|
self.drop = nn.Dropout(float(p_dropout))
|
|
|
|
if gin_channels != 0:
|
|
cond_layer = torch.nn.Conv1d(
|
|
gin_channels, 2 * hidden_channels * n_layers, 1
|
|
)
|
|
self.cond_layer = torch.nn.utils.parametrizations.weight_norm(cond_layer, name="weight")
|
|
|
|
for i in range(n_layers):
|
|
dilation = dilation_rate**i
|
|
padding = int((kernel_size * dilation - dilation) / 2)
|
|
in_layer = torch.nn.Conv1d(
|
|
hidden_channels,
|
|
2 * hidden_channels,
|
|
kernel_size,
|
|
dilation=dilation,
|
|
padding=padding,
|
|
)
|
|
in_layer = torch.nn.utils.parametrizations.weight_norm(in_layer, name="weight")
|
|
self.in_layers.append(in_layer)
|
|
|
|
# last one is not necessary
|
|
if i < n_layers - 1:
|
|
res_skip_channels = 2 * hidden_channels
|
|
else:
|
|
res_skip_channels = hidden_channels
|
|
|
|
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
|
|
res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer, name="weight")
|
|
self.res_skip_layers.append(res_skip_layer)
|
|
|
|
def __call__(
|
|
self,
|
|
x: torch.Tensor,
|
|
x_mask: torch.Tensor,
|
|
g: Optional[torch.Tensor] = None,
|
|
) -> torch.Tensor:
|
|
return super().__call__(x, x_mask, g=g)
|
|
|
|
def forward(
|
|
self,
|
|
x: torch.Tensor,
|
|
x_mask: torch.Tensor,
|
|
g: Optional[torch.Tensor] = None,
|
|
) -> torch.Tensor:
|
|
output = torch.zeros_like(x)
|
|
|
|
if g is not None:
|
|
g = self.cond_layer(g)
|
|
|
|
for i, (in_layer, res_skip_layer) in enumerate(
|
|
zip(self.in_layers, self.res_skip_layers)
|
|
):
|
|
x_in: torch.Tensor = in_layer(x)
|
|
if g is not None:
|
|
cond_offset = i * 2 * self.hidden_channels
|
|
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
|
|
else:
|
|
g_l = torch.zeros_like(x_in)
|
|
|
|
acts = activate_add_tanh_sigmoid_multiply(x_in, g_l, self.hidden_channels)
|
|
acts: torch.Tensor = self.drop(acts)
|
|
|
|
res_skip_acts: torch.Tensor = res_skip_layer(acts)
|
|
if i < self.n_layers - 1:
|
|
res_acts = res_skip_acts[:, : self.hidden_channels, :]
|
|
x = (x + res_acts) * x_mask
|
|
output = output + res_skip_acts[:, self.hidden_channels :, :]
|
|
else:
|
|
output = output + res_skip_acts
|
|
return output * x_mask
|
|
|
|
def remove_weight_norm(self):
|
|
if self.gin_channels != 0:
|
|
remove_parametrizations(self.cond_layer, "weight")
|
|
for l in self.in_layers:
|
|
remove_parametrizations(l, "weight")
|
|
for l in self.res_skip_layers:
|
|
remove_parametrizations(l, "weight")
|
|
|
|
def __prepare_scriptable__(self):
|
|
if self.gin_channels != 0:
|
|
if is_parametrized(self.cond_layer, "weight"):
|
|
remove_parametrizations(self.cond_layer, "weight")
|
|
for l in self.in_layers:
|
|
if is_parametrized(l, "weight"):
|
|
remove_parametrizations(l, "weight")
|
|
for l in self.res_skip_layers:
|
|
if is_parametrized(l, "weight"):
|
|
remove_parametrizations(l, "weight")
|
|
return self
|