1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-05 01:10:22 +08:00

optimize(infer): move nsf & gens into rvc

This commit is contained in:
源文雨
2024-06-09 15:35:48 +09:00
parent 2ce493e07c
commit 5790ea7a73
5 changed files with 545 additions and 471 deletions

View File

@@ -1,425 +1,26 @@
import math
from typing import Optional, List
import torch
from torch import nn
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
from torch.nn import Conv1d, Conv2d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
from torch.nn.utils import spectral_norm, weight_norm
from rvc import residuals
from rvc.residuals import ResidualCouplingBlock
from rvc.utils import (
get_padding,
call_weight_data_normal_if_Conv,
slice_on_last_dim,
rand_slice_segments_on_last_dim,
)
from rvc.encoders import TextEncoder, PosteriorEncoder
from rvc.generators import Generator
from rvc.nsf import NSFGenerator
has_xpu = bool(hasattr(torch, "xpu") and torch.xpu.is_available())
class Generator(torch.nn.Module):
def __init__(
self,
initial_channel,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
gin_channels=0,
):
super(Generator, self).__init__()
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.conv_pre = Conv1d(
initial_channel, upsample_initial_channel, 7, 1, padding=3
)
resblock = residuals.ResBlock1 if resblock == "1" else residuals.ResBlock2
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
self.ups.append(
weight_norm(
ConvTranspose1d(
upsample_initial_channel // (2**i),
upsample_initial_channel // (2 ** (i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate(
zip(resblock_kernel_sizes, resblock_dilation_sizes)
):
self.resblocks.append(resblock(ch, k, d))
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
self.ups.apply(call_weight_data_normal_if_Conv)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
def forward(
self,
x: torch.Tensor,
g: Optional[torch.Tensor] = None,
# n_res: Optional[torch.Tensor] = None,
):
"""
if n_res is not None:
assert isinstance(n_res, torch.Tensor)
n = int(n_res.item())
if n != x.shape[-1]:
x = F.interpolate(x, size=n, mode="linear")
"""
x = self.conv_pre(x)
if g is not None:
x = x + self.cond(g)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, residuals.LRELU_SLOPE)
x = self.ups[i](x)
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def __prepare_scriptable__(self):
for l in self.ups:
for hook in l._forward_pre_hooks.values():
# The hook we want to remove is an instance of WeightNorm class, so
# normally we would do `if isinstance(...)` but this class is not accessible
# because of shadowing, so we check the module name directly.
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
for l in self.resblocks:
for hook in l._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
return self
def remove_weight_norm(self):
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()
class SineGen(torch.nn.Module):
"""Definition of sine generator
SineGen(samp_rate, harmonic_num = 0,
sine_amp = 0.1, noise_std = 0.003,
voiced_threshold = 0,
flag_for_pulse=False)
samp_rate: sampling rate in Hz
harmonic_num: number of harmonic overtones (default 0)
sine_amp: amplitude of sine-wavefrom (default 0.1)
noise_std: std of Gaussian noise (default 0.003)
voiced_thoreshold: F0 threshold for U/V classification (default 0)
flag_for_pulse: this SinGen is used inside PulseGen (default False)
Note: when flag_for_pulse is True, the first time step of a voiced
segment is always sin(torch.pi) or cos(0)
"""
def __init__(
self,
samp_rate,
harmonic_num=0,
sine_amp=0.1,
noise_std=0.003,
voiced_threshold=0,
flag_for_pulse=False,
):
super(SineGen, self).__init__()
self.sine_amp = sine_amp
self.noise_std = noise_std
self.harmonic_num = harmonic_num
self.dim = self.harmonic_num + 1
self.sampling_rate = samp_rate
self.voiced_threshold = voiced_threshold
def _f02uv(self, f0):
# generate uv signal
uv = torch.ones_like(f0)
uv = uv * (f0 > self.voiced_threshold)
if uv.device.type == "privateuseone": # for DirectML
uv = uv.float()
return uv
def forward(self, f0: torch.Tensor, upp: int):
"""sine_tensor, uv = forward(f0)
input F0: tensor(batchsize=1, length, dim=1)
f0 for unvoiced steps should be 0
output sine_tensor: tensor(batchsize=1, length, dim)
output uv: tensor(batchsize=1, length, 1)
"""
with torch.no_grad():
f0 = f0[:, None].transpose(1, 2)
f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
# fundamental component
f0_buf[:, :, 0] = f0[:, :, 0]
for idx in range(self.harmonic_num):
f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
idx + 2
) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
rad_values = (
f0_buf / self.sampling_rate
) % 1 ###%1意味着n_har的乘积无法后处理优化
rand_ini = torch.rand(
f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
)
rand_ini[:, 0] = 0
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
tmp_over_one = torch.cumsum(
rad_values, 1
) # % 1 #####%1意味着后面的cumsum无法再优化
tmp_over_one *= upp
tmp_over_one = F.interpolate(
tmp_over_one.transpose(2, 1),
scale_factor=float(upp),
mode="linear",
align_corners=True,
).transpose(2, 1)
rad_values = F.interpolate(
rad_values.transpose(2, 1), scale_factor=float(upp), mode="nearest"
).transpose(
2, 1
) #######
tmp_over_one %= 1
tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
cumsum_shift = torch.zeros_like(rad_values)
cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
sine_waves = torch.sin(
torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * torch.pi
)
sine_waves = sine_waves * self.sine_amp
uv = self._f02uv(f0)
uv = F.interpolate(
uv.transpose(2, 1), scale_factor=float(upp), mode="nearest"
).transpose(2, 1)
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
noise = noise_amp * torch.randn_like(sine_waves)
sine_waves = sine_waves * uv + noise
return sine_waves, uv, noise
class SourceModuleHnNSF(torch.nn.Module):
"""SourceModule for hn-nsf
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
add_noise_std=0.003, voiced_threshod=0)
sampling_rate: sampling_rate in Hz
harmonic_num: number of harmonic above F0 (default: 0)
sine_amp: amplitude of sine source signal (default: 0.1)
add_noise_std: std of additive Gaussian noise (default: 0.003)
note that amplitude of noise in unvoiced is decided
by sine_amp
voiced_threshold: threhold to set U/V given F0 (default: 0)
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
F0_sampled (batchsize, length, 1)
Sine_source (batchsize, length, 1)
noise_source (batchsize, length 1)
uv (batchsize, length, 1)
"""
def __init__(
self,
sampling_rate,
harmonic_num=0,
sine_amp=0.1,
add_noise_std=0.003,
voiced_threshod=0,
):
super(SourceModuleHnNSF, self).__init__()
self.sine_amp = sine_amp
self.noise_std = add_noise_std
# to produce sine waveforms
self.l_sin_gen = SineGen(
sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
)
# to merge source harmonics into a single excitation
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
self.l_tanh = torch.nn.Tanh()
def forward(self, x: torch.Tensor, upp: int = 1):
sine_wavs, _, _ = self.l_sin_gen(x, upp)
sine_wavs = sine_wavs.to(dtype=self.l_linear.weight.dtype)
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
return sine_merge, None, None # noise, uv
class GeneratorNSF(torch.nn.Module):
def __init__(
self,
initial_channel: int,
resblock: str,
resblock_kernel_sizes: List[int],
resblock_dilation_sizes: List[List[int]],
upsample_rates: List[int],
upsample_initial_channel: int,
upsample_kernel_sizes: List[int],
gin_channels: int,
sr: int,
):
super(GeneratorNSF, self).__init__()
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.f0_upsamp = torch.nn.Upsample(scale_factor=math.prod(upsample_rates))
self.m_source = SourceModuleHnNSF(sampling_rate=sr, harmonic_num=0)
self.noise_convs = nn.ModuleList()
self.conv_pre = Conv1d(
initial_channel, upsample_initial_channel, 7, 1, padding=3
)
resblock = residuals.ResBlock1 if resblock == "1" else residuals.ResBlock2
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
c_cur = upsample_initial_channel // (2 ** (i + 1))
self.ups.append(
weight_norm(
ConvTranspose1d(
upsample_initial_channel // (2**i),
upsample_initial_channel // (2 ** (i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
)
if i + 1 < len(upsample_rates):
stride_f0 = math.prod(upsample_rates[i + 1 :])
self.noise_convs.append(
Conv1d(
1,
c_cur,
kernel_size=stride_f0 * 2,
stride=stride_f0,
padding=stride_f0 // 2,
)
)
else:
self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch: int = upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate(
zip(resblock_kernel_sizes, resblock_dilation_sizes)
):
self.resblocks.append(resblock(ch, k, d))
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
self.ups.apply(call_weight_data_normal_if_Conv)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
self.upp = math.prod(upsample_rates)
self.lrelu_slope = residuals.LRELU_SLOPE
def forward(
self,
x,
f0,
g: Optional[torch.Tensor] = None,
# n_res: Optional[torch.Tensor] = None,
):
har_source, noi_source, uv = self.m_source(f0, self.upp)
har_source = har_source.transpose(1, 2)
"""
if n_res is not None:
assert isinstance(n_res, torch.Tensor)
n = int(n_res.item())
if n * self.upp != har_source.shape[-1]:
har_source = F.interpolate(har_source, size=n * self.upp, mode="linear")
if n != x.shape[-1]:
x = F.interpolate(x, size=n, mode="linear")
"""
x = self.conv_pre(x)
if g is not None:
x = x + self.cond(g)
# torch.jit.script() does not support direct indexing of torch modules
# That's why I wrote this
for i, (ups, noise_convs) in enumerate(zip(self.ups, self.noise_convs)):
if i < self.num_upsamples:
x = F.leaky_relu(x, self.lrelu_slope)
x = ups(x)
x_source = noise_convs(har_source)
x = x + x_source
xs: Optional[torch.Tensor] = None
l = [i * self.num_kernels + j for j in range(self.num_kernels)]
for j, resblock in enumerate(self.resblocks):
if j in l:
if xs is None:
xs = resblock(x)
else:
xs += resblock(x)
# This assertion cannot be ignored! \
# If ignored, it will cause torch.jit.script() compilation errors
assert isinstance(xs, torch.Tensor)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()
def __prepare_scriptable__(self):
for l in self.ups:
for hook in l._forward_pre_hooks.values():
# The hook we want to remove is an instance of WeightNorm class, so
# normally we would do `if isinstance(...)` but this class is not accessible
# because of shadowing, so we check the module name directly.
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
for l in self.resblocks:
for hook in self.resblocks._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
return self
class SynthesizerTrnMs256NSFsid(nn.Module):
class SynthesizerTrnMsNSFsid(nn.Module):
def __init__(
self,
spec_channels: int,
@@ -440,6 +41,7 @@ class SynthesizerTrnMs256NSFsid(nn.Module):
spk_embed_dim: int,
gin_channels: int,
sr: str | int,
text_encoder_in_channels: int,
):
super(SynthesizerTrnMs256NSFsid, self).__init__()
if isinstance(sr, str):
@@ -467,7 +69,7 @@ class SynthesizerTrnMs256NSFsid(nn.Module):
# self.hop_length = hop_length#
self.spk_embed_dim = spk_embed_dim
self.enc_p = TextEncoder(
256,
text_encoder_in_channels,
inter_channels,
hidden_channels,
filter_channels,
@@ -476,7 +78,7 @@ class SynthesizerTrnMs256NSFsid(nn.Module):
kernel_size,
float(p_dropout),
)
self.dec = GeneratorNSF(
self.dec = NSFGenerator(
inter_channels,
resblock,
resblock_kernel_sizes,
@@ -597,29 +199,29 @@ class SynthesizerTrnMs256NSFsid(nn.Module):
return o, x_mask, (z, z_p, m_p, logs_p)
class SynthesizerTrnMs768NSFsid(SynthesizerTrnMs256NSFsid):
class SynthesizerTrnMs256NSFsid(SynthesizerTrnMsNSFsid):
def __init__(
self,
spec_channels,
segment_size,
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
spk_embed_dim,
gin_channels,
sr,
spec_channels: int,
segment_size: int,
inter_channels: int,
hidden_channels: int,
filter_channels: int,
n_heads: int,
n_layers: int,
kernel_size: int,
p_dropout: int,
resblock: str,
resblock_kernel_sizes: List[int],
resblock_dilation_sizes: List[List[int]],
upsample_rates: List[int],
upsample_initial_channel: int,
upsample_kernel_sizes: List[int],
spk_embed_dim: int,
gin_channels: int,
sr: str | int,
):
super(SynthesizerTrnMs768NSFsid, self).__init__(
super().__init__(
spec_channels,
segment_size,
inter_channels,
@@ -638,42 +240,76 @@ class SynthesizerTrnMs768NSFsid(SynthesizerTrnMs256NSFsid):
spk_embed_dim,
gin_channels,
sr,
256,
)
del self.enc_p
self.enc_p = TextEncoder(
768,
class SynthesizerTrnMs768NSFsid(SynthesizerTrnMsNSFsid):
def __init__(
self,
spec_channels: int,
segment_size: int,
inter_channels: int,
hidden_channels: int,
filter_channels: int,
n_heads: int,
n_layers: int,
kernel_size: int,
p_dropout: int,
resblock: str,
resblock_kernel_sizes: List[int],
resblock_dilation_sizes: List[List[int]],
upsample_rates: List[int],
upsample_initial_channel: int,
upsample_kernel_sizes: List[int],
spk_embed_dim: int,
gin_channels: int,
sr: str | int,
):
super().__init__(
spec_channels,
segment_size,
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
float(p_dropout),
p_dropout,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
spk_embed_dim,
gin_channels,
sr,
768,
)
class SynthesizerTrnMs256NSFsid_nono(nn.Module):
def __init__(
self,
spec_channels,
segment_size,
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
spec_channels: int,
segment_size: int,
inter_channels: int,
hidden_channels: int,
filter_channels: int,
n_heads: int,
n_layers: int,
kernel_size: int,
p_dropout: int,
resblock: str,
resblock_kernel_sizes,
resblock_kernel_sizes: List[int],
resblock_dilation_sizes: List[List[int]],
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
spk_embed_dim,
gin_channels,
sr=None,
**kwargs
upsample_rates: List[int],
upsample_initial_channel: int,
upsample_kernel_sizes: List[int],
spk_embed_dim: int,
gin_channels: int,
sr = None,
):
super(SynthesizerTrnMs256NSFsid_nono, self).__init__()
self.spec_channels = spec_channels
@@ -811,25 +447,24 @@ class SynthesizerTrnMs256NSFsid_nono(nn.Module):
class SynthesizerTrnMs768NSFsid_nono(SynthesizerTrnMs256NSFsid_nono):
def __init__(
self,
spec_channels,
segment_size,
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
spk_embed_dim,
gin_channels,
sr=None,
**kwargs
spec_channels: int,
segment_size: int,
inter_channels: int,
hidden_channels: int,
filter_channels: int,
n_heads: int,
n_layers: int,
kernel_size: int,
p_dropout: int,
resblock: str,
resblock_kernel_sizes: List[int],
resblock_dilation_sizes: List[List[int]],
upsample_rates: List[int],
upsample_initial_channel: int,
upsample_kernel_sizes: List[int],
spk_embed_dim: int,
gin_channels: int,
sr = None,
):
super(SynthesizerTrnMs768NSFsid_nono, self).__init__(
spec_channels,
@@ -849,8 +484,6 @@ class SynthesizerTrnMs768NSFsid_nono(SynthesizerTrnMs256NSFsid_nono):
upsample_kernel_sizes,
spk_embed_dim,
gin_channels,
sr,
**kwargs
)
del self.enc_p
self.enc_p = TextEncoder(

View File

@@ -1,8 +1,7 @@
import torch
from torch import nn
from .models import GeneratorNSF
from rvc.nsf import NSFGenerator
from rvc.encoders import TextEncoder, PosteriorEncoder
from rvc.residuals import ResidualCouplingBlock
@@ -66,7 +65,7 @@ class SynthesizerTrnMsNSFsidM(nn.Module):
kernel_size,
float(p_dropout),
)
self.dec = GeneratorNSF(
self.dec = NSFGenerator(
inter_channels,
resblock,
resblock_kernel_sizes,

View File

@@ -226,6 +226,9 @@ class MultiHeadAttention(nn.Module):
class FFN(nn.Module):
"""
Feed-Forward Network
"""
def __init__(
self,
in_channels: int,

225
rvc/generators.py Normal file
View File

@@ -0,0 +1,225 @@
from typing import Optional, List, Tuple
import torch
from torch import nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm
from .residuals import ResBlock1, ResBlock2, LRELU_SLOPE
from .utils import call_weight_data_normal_if_Conv
class Generator(torch.nn.Module):
def __init__(
self,
initial_channel: int,
resblock: str,
resblock_kernel_sizes: List[int],
resblock_dilation_sizes: List[List[int]],
upsample_rates: List[int],
upsample_initial_channel: int,
upsample_kernel_sizes: List[int],
gin_channels: int = 0,
):
super(Generator, self).__init__()
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.conv_pre = Conv1d(
initial_channel, upsample_initial_channel, 7, 1, padding=3
)
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
self.ups.append(
weight_norm(
ConvTranspose1d(
upsample_initial_channel // (2**i),
upsample_initial_channel // (2 ** (i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
)
self.resblocks = nn.ModuleList()
resblock_module = ResBlock1 if resblock == "1" else ResBlock2
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
self.resblocks.append(resblock_module(ch, k, d))
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
self.ups.apply(call_weight_data_normal_if_Conv)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
def __call__(
self,
x: torch.Tensor,
g: Optional[torch.Tensor] = None,
# n_res: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return super().__call__(x, g=g)
def forward(
self,
x: torch.Tensor,
g: Optional[torch.Tensor] = None,
# n_res: Optional[torch.Tensor] = None,
):
"""
if n_res is not None:
assert isinstance(n_res, torch.Tensor)
n = int(n_res.item())
if n != x.shape[-1]:
x = F.interpolate(x, size=n, mode="linear")
"""
x = self.conv_pre(x)
if g is not None:
x = x + self.cond(g)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, LRELU_SLOPE)
x = self.ups[i](x)
n = i * self.num_kernels
xs = self.resblocks[n](x)
for j in range(1, self.num_kernels):
xs += self.resblocks[n + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def __prepare_scriptable__(self):
for l in self.ups:
for hook in l._forward_pre_hooks.values():
# The hook we want to remove is an instance of WeightNorm class, so
# normally we would do `if isinstance(...)` but this class is not accessible
# because of shadowing, so we check the module name directly.
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
for l in self.resblocks:
for hook in l._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
return self
def remove_weight_norm(self):
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()
class SineGenerator(torch.nn.Module):
"""Definition of sine generator
SineGenerator(samp_rate, harmonic_num = 0,
sine_amp = 0.1, noise_std = 0.003,
voiced_threshold = 0,
flag_for_pulse=False)
samp_rate: sampling rate in Hz
harmonic_num: number of harmonic overtones (default 0)
sine_amp: amplitude of sine-wavefrom (default 0.1)
noise_std: std of Gaussian noise (default 0.003)
voiced_thoreshold: F0 threshold for U/V classification (default 0)
flag_for_pulse: this SinGen is used inside PulseGen (default False)
Note: when flag_for_pulse is True, the first time step of a voiced
segment is always sin(torch.pi) or cos(0)
"""
def __init__(
self,
samp_rate: int,
harmonic_num: int = 0,
sine_amp: float = 0.1,
noise_std: float = 0.003,
voiced_threshold: int = 0,
):
super(SineGenerator, self).__init__()
self.sine_amp = sine_amp
self.noise_std = noise_std
self.harmonic_num = harmonic_num
self.dim = harmonic_num + 1
self.sampling_rate = samp_rate
self.voiced_threshold = voiced_threshold
def __call__(self, f0: torch.Tensor, upp: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return super().__call__(f0, upp)
def forward(self, f0: torch.Tensor, upp: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""sine_tensor, uv = forward(f0)
input F0: tensor(batchsize=1, length, dim=1)
f0 for unvoiced steps should be 0
output sine_tensor: tensor(batchsize=1, length, dim)
output uv: tensor(batchsize=1, length, 1)
"""
with torch.no_grad():
f0 = f0[:, None].transpose(1, 2)
f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
# fundamental component
f0_buf[:, :, 0] = f0[:, :, 0]
for idx in range(self.harmonic_num):
f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
idx + 2
) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
rad_values = (
f0_buf / self.sampling_rate
) % 1 ###%1意味着n_har的乘积无法后处理优化
rand_ini = torch.rand(
f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
)
rand_ini[:, 0] = 0
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
tmp_over_one = torch.cumsum(
rad_values, 1
) # % 1 #####%1意味着后面的cumsum无法再优化
tmp_over_one *= upp
tmp_over_one: torch.Tensor = F.interpolate(
tmp_over_one.transpose(2, 1),
scale_factor = float(upp),
mode="linear",
align_corners=True,
).transpose(2, 1)
rad_values: torch.Tensor = F.interpolate(
rad_values.transpose(2, 1), scale_factor=float(upp), mode="nearest"
).transpose(
2, 1
) #######
tmp_over_one %= 1
tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
cumsum_shift = torch.zeros_like(rad_values)
cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
sine_waves = torch.sin(
torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * torch.pi
)
sine_waves = sine_waves * self.sine_amp
uv = self._f02uv(f0)
uv: torch.Tensor = F.interpolate(
uv.transpose(2, 1), scale_factor=float(upp), mode="nearest"
).transpose(2, 1)
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
noise = noise_amp * torch.randn_like(sine_waves)
sine_waves = sine_waves * uv + noise
return sine_waves, uv, noise
def _f02uv(self, f0):
# generate uv signal
uv = torch.ones_like(f0)
uv = uv * (f0 > self.voiced_threshold)
if uv.device.type == "privateuseone": # for DirectML
uv = uv.float()
return uv

214
rvc/nsf.py Normal file
View File

@@ -0,0 +1,214 @@
from typing import Optional, List
import math
import torch
from torch import nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm
from .generators import SineGenerator
from .residuals import ResBlock1, ResBlock2, LRELU_SLOPE
from .utils import call_weight_data_normal_if_Conv
class SourceModuleHnNSF(torch.nn.Module):
"""SourceModule for hn-nsf
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
add_noise_std=0.003, voiced_threshod=0)
sampling_rate: sampling_rate in Hz
harmonic_num: number of harmonic above F0 (default: 0)
sine_amp: amplitude of sine source signal (default: 0.1)
add_noise_std: std of additive Gaussian noise (default: 0.003)
note that amplitude of noise in unvoiced is decided
by sine_amp
voiced_threshold: threhold to set U/V given F0 (default: 0)
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
F0_sampled (batchsize, length, 1)
Sine_source (batchsize, length, 1)
noise_source (batchsize, length 1)
uv (batchsize, length, 1)
"""
def __init__(
self,
sampling_rate: int,
harmonic_num: int = 0,
sine_amp: float = 0.1,
add_noise_std: float = 0.003,
voiced_threshod: int = 0,
):
super(SourceModuleHnNSF, self).__init__()
self.sine_amp = sine_amp
self.noise_std = add_noise_std
# to produce sine waveforms
self.l_sin_gen = SineGenerator(
sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
)
# to merge source harmonics into a single excitation
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
self.l_tanh = torch.nn.Tanh()
def __call__(self, x: torch.Tensor, upp: int = 1) -> torch.Tensor:
return super().__call__(x, upp=upp)
def forward(self, x: torch.Tensor, upp: int = 1) -> torch.Tensor:
sine_wavs, _, _ = self.l_sin_gen(x, upp)
sine_wavs = sine_wavs.to(dtype=self.l_linear.weight.dtype)
sine_merge: torch.Tensor = self.l_tanh(self.l_linear(sine_wavs))
return sine_merge #, None, None # noise, uv
class NSFGenerator(torch.nn.Module):
def __init__(
self,
initial_channel: int,
resblock: str,
resblock_kernel_sizes: List[int],
resblock_dilation_sizes: List[List[int]],
upsample_rates: List[int],
upsample_initial_channel: int,
upsample_kernel_sizes: List[int],
gin_channels: int,
sr: int,
):
super(NSFGenerator, self).__init__()
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.f0_upsamp = torch.nn.Upsample(scale_factor=math.prod(upsample_rates))
self.m_source = SourceModuleHnNSF(sampling_rate=sr, harmonic_num=0)
self.noise_convs = nn.ModuleList()
self.conv_pre = Conv1d(
initial_channel, upsample_initial_channel, 7, 1, padding=3
)
resblock = ResBlock1 if resblock == "1" else ResBlock2
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
c_cur = upsample_initial_channel // (2 ** (i + 1))
self.ups.append(
weight_norm(
ConvTranspose1d(
upsample_initial_channel // (2**i),
upsample_initial_channel // (2 ** (i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
)
if i + 1 < len(upsample_rates):
stride_f0 = math.prod(upsample_rates[i + 1 :])
self.noise_convs.append(
Conv1d(
1,
c_cur,
kernel_size=stride_f0 * 2,
stride=stride_f0,
padding=stride_f0 // 2,
)
)
else:
self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch: int = upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate(
zip(resblock_kernel_sizes, resblock_dilation_sizes)
):
self.resblocks.append(resblock(ch, k, d))
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
self.ups.apply(call_weight_data_normal_if_Conv)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
self.upp = math.prod(upsample_rates)
self.lrelu_slope = LRELU_SLOPE
def __call__(
self,
x: torch.Tensor,
f0: torch.Tensor,
g: Optional[torch.Tensor] = None,
# n_res: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return super().__call__(x, f0, g=g)
def forward(
self,
x: torch.Tensor,
f0: torch.Tensor,
g: Optional[torch.Tensor] = None,
# n_res: Optional[torch.Tensor] = None,
) -> torch.Tensor:
har_source = self.m_source(f0, self.upp)
har_source = har_source.transpose(1, 2)
"""
if n_res is not None:
assert isinstance(n_res, torch.Tensor)
n = int(n_res.item())
if n * self.upp != har_source.shape[-1]:
har_source = F.interpolate(har_source, size=n * self.upp, mode="linear")
if n != x.shape[-1]:
x = F.interpolate(x, size=n, mode="linear")
"""
x = self.conv_pre(x)
if g is not None:
x = x + self.cond(g)
# torch.jit.script() does not support direct indexing of torch modules
# That's why I wrote this
for i, (ups, noise_convs) in enumerate(zip(self.ups, self.noise_convs)):
if i < self.num_upsamples:
x = F.leaky_relu(x, self.lrelu_slope)
x = ups(x)
x_source = noise_convs(har_source)
x = x + x_source
xs: Optional[torch.Tensor] = None
l = [i * self.num_kernels + j for j in range(self.num_kernels)]
for j, resblock in enumerate(self.resblocks):
if j in l:
if xs is None:
xs = resblock(x)
else:
xs += resblock(x)
# This assertion cannot be ignored! \
# If ignored, it will cause torch.jit.script() compilation errors
assert isinstance(xs, torch.Tensor)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()
def __prepare_scriptable__(self):
for l in self.ups:
for hook in l._forward_pre_hooks.values():
# The hook we want to remove is an instance of WeightNorm class, so
# normally we would do `if isinstance(...)` but this class is not accessible
# because of shadowing, so we check the module name directly.
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
for l in self.resblocks:
for hook in self.resblocks._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
return self