mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-05 01:10:22 +08:00
215 lines
7.5 KiB
Python
215 lines
7.5 KiB
Python
from typing import Optional, List, Tuple
|
|
|
|
import torch
|
|
from torch import nn
|
|
from torch.nn import Conv1d, ConvTranspose1d
|
|
from torch.nn import functional as F
|
|
from torch.nn.utils import remove_weight_norm, weight_norm
|
|
|
|
from .residuals import ResBlock1, ResBlock2, LRELU_SLOPE
|
|
from .utils import call_weight_data_normal_if_Conv
|
|
|
|
|
|
class Generator(torch.nn.Module):
    """Convolutional upsampling network producing a single-channel signal.

    The input features go through a pre-convolution, optionally receive an
    additive conditioning term, and are then upsampled by a stack of
    weight-normalised transposed convolutions.  After every upsampling stage
    the outputs of ``num_kernels`` residual blocks are averaged.  A final
    convolution followed by ``tanh`` yields one output channel.

    Args:
        initial_channel: Number of channels of the input features.
        resblock: ``"1"`` selects ``ResBlock1``; any other value selects
            ``ResBlock2``.
        resblock_kernel_sizes: Kernel size for each residual block of a stage.
        resblock_dilation_sizes: Dilations for each residual block of a stage
            (paired element-wise with ``resblock_kernel_sizes``).
        upsample_rates: Stride of each transposed convolution.
        upsample_initial_channel: Channel count after the pre-convolution; it
            is halved (integer division) at every upsampling stage.
        upsample_kernel_sizes: Kernel size of each transposed convolution
            (paired element-wise with ``upsample_rates``).
        gin_channels: Channels of the conditioning input ``g``.  When ``0`` no
            conditioning layer is created.
    """

    def __init__(
        self,
        initial_channel: int,
        resblock: str,
        resblock_kernel_sizes: List[int],
        resblock_dilation_sizes: List[List[int]],
        upsample_rates: List[int],
        upsample_initial_channel: int,
        upsample_kernel_sizes: List[int],
        gin_channels: int = 0,
    ):
        super().__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)

        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )

        # One weight-normalised transposed convolution per upsampling stage;
        # stage i maps C // 2**i channels to C // 2**(i + 1) channels.
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        # Residual blocks are stored flat: stage i owns the slice
        # [i * num_kernels, (i + 1) * num_kernels).
        self.resblocks = nn.ModuleList()
        resblock_module = ResBlock1 if resblock == "1" else ResBlock2
        ch = 0
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(resblock_module(ch, k, d))

        # `ch` is the channel count of the last stage at this point.
        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(call_weight_data_normal_if_Conv)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def __call__(
        self,
        x: torch.Tensor,
        g: Optional[torch.Tensor] = None,
        n_res: Optional[int] = None,
    ) -> torch.Tensor:
        """Typed pass-through to ``nn.Module.__call__`` (which runs ``forward``)."""
        return super().__call__(x, g=g, n_res=n_res)

    def forward(
        self,
        x: torch.Tensor,
        g: Optional[torch.Tensor] = None,
        n_res: Optional[int] = None,
    ):
        """Run the generator.

        Args:
            x: Input features; the last dimension is the one resampled and
                upsampled.
            g: Optional conditioning tensor added after the pre-convolution.
                Passing it requires the module to have been built with
                ``gin_channels != 0`` (otherwise ``self.cond`` does not exist).
            n_res: When given and different from ``x.shape[-1]``, ``x`` is
                linearly interpolated to this length before processing.

        Returns:
            The ``tanh``-squashed output of the final convolution (1 channel).
        """
        if n_res is not None:
            n = int(n_res)
            if n != x.shape[-1]:
                x = F.interpolate(x, size=n, mode="linear")

        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            # Average the residual blocks that belong to this stage.
            n = i * self.num_kernels
            xs = self.resblocks[n](x)
            for j in range(1, self.num_kernels):
                xs += self.resblocks[n + j](x)
            x = xs / self.num_kernels

        # NOTE: this activation uses F.leaky_relu's default slope, not
        # LRELU_SLOPE; kept as-is so existing behaviour is unchanged.
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def __prepare_scriptable__(self):
        """Strip weight-norm hooks found directly on ``ups``/``resblocks`` items.

        Returns:
            ``self``, after removing the weight-norm reparameterisation from
            every inspected module that carries such a hook.
        """
        for group in (self.ups, self.resblocks):
            for layer in group:
                # Iterate over a snapshot: remove_weight_norm deletes entries
                # from `_forward_pre_hooks`, and mutating the dict while
                # iterating it raises RuntimeError as soon as another hook
                # follows the removed one.
                for hook in list(layer._forward_pre_hooks.values()):
                    # The hook we want to remove is an instance of the
                    # WeightNorm class, so normally we would use isinstance,
                    # but the class is not accessible because of shadowing,
                    # so we check the module and class names directly.
                    # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
                    if (
                        hook.__module__ == "torch.nn.utils.weight_norm"
                        and hook.__class__.__name__ == "WeightNorm"
                    ):
                        torch.nn.utils.remove_weight_norm(layer, name=hook.name)
        return self

    def remove_weight_norm(self):
        """Remove weight normalisation from all upsamplers and residual blocks."""
        for layer in self.ups:
            # Resolves to the module-level torch function, not this method.
            remove_weight_norm(layer)
        for block in self.resblocks:
            block.remove_weight_norm()
|
|
|
|
|
|
class SineGenerator(torch.nn.Module):
    """Definition of sine generator

    SineGenerator(samp_rate, harmonic_num = 0,
                  sine_amp = 0.1, noise_std = 0.003,
                  voiced_threshold = 0)

    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    """

    def __init__(
        self,
        samp_rate: int,
        harmonic_num: int = 0,
        sine_amp: float = 0.1,
        noise_std: float = 0.003,
        voiced_threshold: int = 0,
    ):
        super().__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        # Fundamental plus `harmonic_num` overtones.
        self.dim = harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

    def _f02sine(self, f0: torch.Tensor, upp: int):
        """Turn a frame-level F0 contour into sample-level sine waves.

        f0: (batchsize, length, 1) as passed by ``forward``
        upp: number of output samples generated per F0 frame

        Returns a tensor of shape (batchsize, length * upp, dim) where dim
        indicates fundamental tone and overtones.
        """
        # Phase (in cycles) reached at each of the `upp` sub-steps of a frame.
        a = torch.arange(1, upp + 1, dtype=f0.dtype, device=f0.device)
        rad = f0 / self.sampling_rate * a
        # Phase advanced over each full frame (all but the last), shifted by
        # 0.5 before the fmod so the remainder is centred around zero, then
        # accumulated so every frame starts where the previous one ended.
        rad2 = torch.fmod(rad[:, :-1, -1:].float() + 0.5, 1.0) - 0.5
        rad_acc = rad2.cumsum(dim=1).fmod(1.0).to(f0)
        # The first frame has no preceding phase, hence the leading zero pad.
        rad += F.pad(rad_acc, (0, 0, 1, 0), mode="constant")
        rad = rad.reshape(f0.shape[0], -1, 1)
        b = torch.arange(1, self.dim + 1, dtype=f0.dtype, device=f0.device).reshape(
            1, 1, -1
        )
        # Out-of-place on purpose: `rad` has a trailing dim of 1 and must be
        # broadcast to `dim` harmonics, which an in-place op cannot do.
        rad = rad * b
        # Random initial phase for the overtones; the fundamental keeps 0.
        rand_ini = torch.rand(1, 1, self.dim, device=f0.device)
        rand_ini[..., 0] = 0
        # Cast so the result keeps f0's dtype, as the in-place add used to.
        rad = rad + rand_ini.to(rad.dtype)
        sines = torch.sin(2 * torch.pi * rad)
        return sines

    def __call__(
        self, f0: torch.Tensor, upp: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Typed pass-through to ``nn.Module.__call__`` (which runs ``forward``)."""
        return super().__call__(f0, upp)

    def forward(
        self, f0: torch.Tensor, upp: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """sine_tensor, uv, noise = forward(f0, upp)

        input F0: tensor(batchsize, length); a trailing axis is added here
                  f0 for unvoiced steps should be 0
        input upp: number of output samples per F0 frame
        output sine_tensor: tensor(batchsize, length * upp, dim)
        output uv: tensor(batchsize, length * upp, 1)
        output noise: tensor(batchsize, length * upp, dim)

        Everything is computed under ``torch.no_grad()``.
        """
        with torch.no_grad():
            f0 = f0.unsqueeze(-1)
            sine_waves = self._f02sine(f0, upp) * self.sine_amp
            uv = self._f02uv(f0)
            # Repeat each frame's voiced flag `upp` times along the time axis.
            uv: torch.Tensor = F.interpolate(
                uv.transpose(2, 1), scale_factor=float(upp), mode="nearest"
            ).transpose(2, 1)
            # Voiced steps get noise of std `noise_std`; unvoiced steps get
            # noise of std `sine_amp / 3`.
            noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
            noise = noise_amp * torch.randn_like(sine_waves)
            # Unvoiced steps carry noise only.
            sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise

    def _f02uv(self, f0):
        """Return a mask shaped like ``f0``: 1 where f0 > voiced_threshold, else 0."""
        # generate uv signal
        uv = torch.ones_like(f0)
        uv = uv * (f0 > self.voiced_threshold)
        if uv.device.type == "privateuseone":  # for DirectML
            uv = uv.float()
        return uv
|