1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-05 01:10:22 +08:00

optimize(train): move discriminators into rvc

This commit is contained in:
源文雨
2024-06-10 01:10:57 +09:00
parent 360318b2f5
commit b23ea7c6e7
4 changed files with 181 additions and 188 deletions

View File

@@ -1 +1,21 @@
{"pth_path": "", "index_path": "", "sg_hostapi": "MME", "sg_wasapi_exclusive": false, "sg_input_device": "", "sg_output_device": "", "sr_type": "sr_device", "threhold": -60.0, "pitch": 12.0, "formant": 0.0, "rms_mix_rate": 0.5, "index_rate": 0.0, "block_time": 0.15, "crossfade_length": 0.08, "extra_time": 2.0, "n_cpu": 4.0, "use_jit": false, "use_pv": false, "f0method": "fcpe"}
{
"pth_path": "",
"index_path": "",
"sg_hostapi": "MME",
"sg_wasapi_exclusive": false,
"sg_input_device": "",
"sg_output_device": "",
"sr_type": "sr_device",
"threhold": -60.0,
"pitch": 12.0,
"formant": 0.0,
"rms_mix_rate": 0.5,
"index_rate": 0.0,
"block_time": 0.15,
"crossfade_length": 0.08,
"extra_time": 2.0,
"n_cpu": 4.0,
"use_jit": false,
"use_pv": false,
"f0method": "fcpe"
}

View File

@@ -2,14 +2,9 @@ from typing import Optional, List
import torch
from torch import nn
from torch.nn import Conv1d, Conv2d
from torch.nn import functional as F
from torch.nn.utils import spectral_norm, weight_norm
from rvc import residuals
from rvc.residuals import ResidualCouplingBlock
from rvc.utils import (
get_padding,
slice_on_last_dim,
rand_slice_segments_on_last_dim,
)
@@ -17,8 +12,6 @@ from rvc.encoders import TextEncoder, PosteriorEncoder
from rvc.generators import Generator
from rvc.nsf import NSFGenerator
has_xpu = bool(hasattr(torch, "xpu") and torch.xpu.is_available())
class SynthesizerTrnMsNSFsid(nn.Module):
def __init__(
@@ -41,7 +34,7 @@ class SynthesizerTrnMsNSFsid(nn.Module):
spk_embed_dim: int,
gin_channels: int,
sr: str | int,
text_encoder_in_channels: int,
encoder_dim: int,
):
super(SynthesizerTrnMs256NSFsid, self).__init__()
if isinstance(sr, str):
@@ -69,7 +62,7 @@ class SynthesizerTrnMsNSFsid(nn.Module):
# self.hop_length = hop_length#
self.spk_embed_dim = spk_embed_dim
self.enc_p = TextEncoder(
text_encoder_in_channels,
encoder_dim,
inter_channels,
hidden_channels,
filter_channels,
@@ -497,177 +490,3 @@ class SynthesizerTrnMs768NSFsid_nono(SynthesizerTrnMs256NSFsid_nono):
float(p_dropout),
f0=False,
)
class MultiPeriodDiscriminator(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(MultiPeriodDiscriminator, self).__init__()
periods = [2, 3, 5, 7, 11, 17]
# periods = [3, 5, 7, 11, 17, 23, 37]
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
discs = discs + [
DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
]
self.discriminators = nn.ModuleList(discs)
def forward(self, y, y_hat):
y_d_rs = [] #
y_d_gs = []
fmap_rs = []
fmap_gs = []
for i, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
# for j in range(len(fmap_r)):
# print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
y_d_rs.append(y_d_r)
y_d_gs.append(y_d_g)
fmap_rs.append(fmap_r)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class MultiPeriodDiscriminatorV2(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(MultiPeriodDiscriminatorV2, self).__init__()
# periods = [2, 3, 5, 7, 11, 17]
periods = [2, 3, 5, 7, 11, 17, 23, 37]
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
discs = discs + [
DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
]
self.discriminators = nn.ModuleList(discs)
def forward(self, y, y_hat):
y_d_rs = [] #
y_d_gs = []
fmap_rs = []
fmap_gs = []
for i, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
# for j in range(len(fmap_r)):
# print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
y_d_rs.append(y_d_r)
y_d_gs.append(y_d_g)
fmap_rs.append(fmap_r)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class DiscriminatorS(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(DiscriminatorS, self).__init__()
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
self.convs = nn.ModuleList(
[
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
]
)
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
def forward(self, x):
fmap = []
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, residuals.LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class DiscriminatorP(torch.nn.Module):
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
super(DiscriminatorP, self).__init__()
self.period = period
self.use_spectral_norm = use_spectral_norm
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
self.convs = nn.ModuleList(
[
norm_f(
Conv2d(
1,
32,
(kernel_size, 1),
(stride, 1),
padding=(get_padding(kernel_size, 1), 0),
)
),
norm_f(
Conv2d(
32,
128,
(kernel_size, 1),
(stride, 1),
padding=(get_padding(kernel_size, 1), 0),
)
),
norm_f(
Conv2d(
128,
512,
(kernel_size, 1),
(stride, 1),
padding=(get_padding(kernel_size, 1), 0),
)
),
norm_f(
Conv2d(
512,
1024,
(kernel_size, 1),
(stride, 1),
padding=(get_padding(kernel_size, 1), 0),
)
),
norm_f(
Conv2d(
1024,
1024,
(kernel_size, 1),
1,
padding=(get_padding(kernel_size, 1), 0),
)
),
]
)
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
def forward(self, x):
fmap = []
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
if has_xpu and x.dtype == torch.bfloat16:
x = F.pad(x.to(dtype=torch.float16), (0, n_pad), "reflect").to(
dtype=torch.bfloat16
)
else:
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, residuals.LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap

View File

@@ -1,6 +1,7 @@
import os
import sys
import logging
from typing import Tuple
logger = logging.getLogger(__name__)
logging.getLogger("numba").setLevel(logging.WARNING)
@@ -55,8 +56,9 @@ from infer.lib.train.data_utils import (
TextAudioLoaderMultiNSFsid,
)
from rvc.discriminators import MultiPeriodDiscriminator
if hps.version == "v1":
from infer.lib.infer_pack.models import MultiPeriodDiscriminator
from infer.lib.infer_pack.models import SynthesizerTrnMs256NSFsid as RVC_Model_f0
from infer.lib.infer_pack.models import (
SynthesizerTrnMs256NSFsid_nono as RVC_Model_nof0,
@@ -65,7 +67,6 @@ else:
from infer.lib.infer_pack.models import (
SynthesizerTrnMs768NSFsid as RVC_Model_f0,
SynthesizerTrnMs768NSFsid_nono as RVC_Model_nof0,
MultiPeriodDiscriminatorV2 as MultiPeriodDiscriminator,
)
from infer.lib.train.losses import (
@@ -180,7 +181,12 @@ def run(rank, n_gpus, hps, logger: logging.Logger):
)
if torch.cuda.is_available():
net_g = net_g.cuda(rank)
net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm)
has_xpu = bool(hasattr(torch, "xpu") and torch.xpu.is_available())
net_d = MultiPeriodDiscriminator(
hps.version,
use_spectral_norm=hps.model.use_spectral_norm,
has_xpu=has_xpu,
)
if torch.cuda.is_available():
net_d = net_d.cuda(rank)
optim_g = torch.optim.AdamW(
@@ -298,7 +304,7 @@ def run(rank, n_gpus, hps, logger: logging.Logger):
def train_and_evaluate(
rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers, cache
rank, epoch, hps, nets: Tuple[RVC_Model_f0, MultiPeriodDiscriminator], optims, schedulers, scaler, loaders, logger, writers, cache
):
net_g, net_d = nets
optim_g, optim_d = optims

148
rvc/discriminators.py Normal file
View File

@@ -0,0 +1,148 @@
from typing import List, Tuple
import torch
from torch import nn
from torch.nn import Conv1d, Conv2d
from torch.nn import functional as F
from torch.nn.utils import spectral_norm, weight_norm
from .residuals import LRELU_SLOPE
from .utils import get_padding
class MultiPeriodDiscriminator(torch.nn.Module):
"""
version: 'v1' or 'v2'
"""
def __init__(self, version: str, use_spectral_norm: bool = False, has_xpu: bool = False):
super(MultiPeriodDiscriminator, self).__init__()
periods = (2, 3, 5, 7, 11, 17) if version == "v1" else (2, 3, 5, 7, 11, 17, 23, 37)
self.discriminators = nn.ModuleList([
DiscriminatorS(use_spectral_norm=use_spectral_norm),
*(
DiscriminatorP(i, use_spectral_norm=use_spectral_norm, has_xpu=has_xpu) for i in periods
)
])
def __call__(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
return super().__call__(y, y_hat)
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for d in self.discriminators:
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
y_d_rs.append(y_d_r)
y_d_gs.append(y_d_g)
fmap_rs.append(fmap_r)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class DiscriminatorS(torch.nn.Module):
def __init__(self, use_spectral_norm: bool = False):
super(DiscriminatorS, self).__init__()
norm_f = spectral_norm if use_spectral_norm else weight_norm
self.convs = nn.ModuleList(
[
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
]
)
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
def __call__(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
return super().__call__(x)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
fmap = []
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class DiscriminatorP(torch.nn.Module):
def __init__(
self,
period: int,
kernel_size: int = 5,
stride: int = 3,
use_spectral_norm: bool = False,
has_xpu: bool = False,
):
super(DiscriminatorP, self).__init__()
self.period = period
self.has_xpu = has_xpu
norm_f = spectral_norm if use_spectral_norm else weight_norm
sequence = (1, 32, 128, 512, 1024)
convs_padding = (get_padding(kernel_size, 1), 0)
self.convs = nn.ModuleList()
for i in range(len(sequence)-1):
self.convs.append(norm_f(
Conv2d(
sequence[i],
sequence[i + 1],
(kernel_size, 1),
(stride, 1),
padding=convs_padding,
)
))
self.convs.append(norm_f(
Conv2d(
1024,
1024,
(kernel_size, 1),
1,
padding=convs_padding,
)
))
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
def __call__(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
return super().__call__(x)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
fmap = []
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
if self.has_xpu and x.dtype == torch.bfloat16:
x = F.pad(x.to(dtype=torch.float16), (0, n_pad), "reflect").to(
dtype=torch.bfloat16
)
else:
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap