From b23ea7c6e77c43e339adce5bc4b1680980484a6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Mon, 10 Jun 2024 01:10:57 +0900 Subject: [PATCH] optimize(train): move discriminators into rvc --- configs/config.json | 22 +++- infer/lib/infer_pack/models.py | 185 +-------------------------------- infer/modules/train/train.py | 14 ++- rvc/discriminators.py | 148 ++++++++++++++++++++++++++ 4 files changed, 181 insertions(+), 188 deletions(-) create mode 100644 rvc/discriminators.py diff --git a/configs/config.json b/configs/config.json index 7f2142f..1aaa63f 100644 --- a/configs/config.json +++ b/configs/config.json @@ -1 +1,21 @@ -{"pth_path": "", "index_path": "", "sg_hostapi": "MME", "sg_wasapi_exclusive": false, "sg_input_device": "", "sg_output_device": "", "sr_type": "sr_device", "threhold": -60.0, "pitch": 12.0, "formant": 0.0, "rms_mix_rate": 0.5, "index_rate": 0.0, "block_time": 0.15, "crossfade_length": 0.08, "extra_time": 2.0, "n_cpu": 4.0, "use_jit": false, "use_pv": false, "f0method": "fcpe"} \ No newline at end of file +{ + "pth_path": "", + "index_path": "", + "sg_hostapi": "MME", + "sg_wasapi_exclusive": false, + "sg_input_device": "", + "sg_output_device": "", + "sr_type": "sr_device", + "threhold": -60.0, + "pitch": 12.0, + "formant": 0.0, + "rms_mix_rate": 0.5, + "index_rate": 0.0, + "block_time": 0.15, + "crossfade_length": 0.08, + "extra_time": 2.0, + "n_cpu": 4.0, + "use_jit": false, + "use_pv": false, + "f0method": "fcpe" +} \ No newline at end of file diff --git a/infer/lib/infer_pack/models.py b/infer/lib/infer_pack/models.py index 52ab241..16d19c4 100644 --- a/infer/lib/infer_pack/models.py +++ b/infer/lib/infer_pack/models.py @@ -2,14 +2,9 @@ from typing import Optional, List import torch from torch import nn -from torch.nn import Conv1d, Conv2d -from torch.nn import functional as F -from torch.nn.utils import spectral_norm, weight_norm -from rvc import residuals from rvc.residuals import ResidualCouplingBlock from rvc.utils import ( - get_padding, slice_on_last_dim, rand_slice_segments_on_last_dim, ) @@ -17,8 +12,6 @@ from rvc.encoders import TextEncoder, PosteriorEncoder from rvc.generators import Generator from rvc.nsf import NSFGenerator -has_xpu = bool(hasattr(torch, "xpu") and torch.xpu.is_available()) - class SynthesizerTrnMsNSFsid(nn.Module): def __init__( @@ -41,7 +34,7 @@ class SynthesizerTrnMsNSFsid(nn.Module): spk_embed_dim: int, gin_channels: int, sr: str | int, - text_encoder_in_channels: int, + encoder_dim: int, ): super(SynthesizerTrnMs256NSFsid, self).__init__() if isinstance(sr, str): @@ -69,7 +62,7 @@ class SynthesizerTrnMsNSFsid(nn.Module): # self.hop_length = hop_length# self.spk_embed_dim = spk_embed_dim self.enc_p = TextEncoder( - text_encoder_in_channels, + encoder_dim, inter_channels, hidden_channels, filter_channels, @@ -497,177 +490,3 @@ class SynthesizerTrnMs768NSFsid_nono(SynthesizerTrnMs256NSFsid_nono): float(p_dropout), f0=False, ) - - -class MultiPeriodDiscriminator(torch.nn.Module): - def __init__(self, use_spectral_norm=False): - super(MultiPeriodDiscriminator, self).__init__() - periods = [2, 3, 5, 7, 11, 17] - # periods = [3, 5, 7, 11, 17, 23, 37] - - discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] - discs = discs + [ - DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods - ] - self.discriminators = nn.ModuleList(discs) - - def forward(self, y, y_hat): - y_d_rs = [] # - y_d_gs = [] - fmap_rs = [] - fmap_gs = [] - for i, d in enumerate(self.discriminators): - y_d_r, fmap_r = d(y) - y_d_g, fmap_g = d(y_hat) - # for j in range(len(fmap_r)): - # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape) - y_d_rs.append(y_d_r) - y_d_gs.append(y_d_g) - fmap_rs.append(fmap_r) - fmap_gs.append(fmap_g) - - return y_d_rs, y_d_gs, fmap_rs, fmap_gs - - -class MultiPeriodDiscriminatorV2(torch.nn.Module): - def __init__(self, use_spectral_norm=False): - super(MultiPeriodDiscriminatorV2, self).__init__() - # periods = [2, 3, 5, 7, 11, 17] - periods = [2, 3, 5, 7, 11, 17, 23, 37] - - discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] - discs = discs + [ - DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods - ] - self.discriminators = nn.ModuleList(discs) - - def forward(self, y, y_hat): - y_d_rs = [] # - y_d_gs = [] - fmap_rs = [] - fmap_gs = [] - for i, d in enumerate(self.discriminators): - y_d_r, fmap_r = d(y) - y_d_g, fmap_g = d(y_hat) - # for j in range(len(fmap_r)): - # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape) - y_d_rs.append(y_d_r) - y_d_gs.append(y_d_g) - fmap_rs.append(fmap_r) - fmap_gs.append(fmap_g) - - return y_d_rs, y_d_gs, fmap_rs, fmap_gs - - -class DiscriminatorS(torch.nn.Module): - def __init__(self, use_spectral_norm=False): - super(DiscriminatorS, self).__init__() - norm_f = weight_norm if use_spectral_norm == False else spectral_norm - self.convs = nn.ModuleList( - [ - norm_f(Conv1d(1, 16, 15, 1, padding=7)), - norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), - norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), - norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), - norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), - norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), - ] - ) - self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) - - def forward(self, x): - fmap = [] - - for l in self.convs: - x = l(x) - x = F.leaky_relu(x, residuals.LRELU_SLOPE) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - x = torch.flatten(x, 1, -1) - - return x, fmap - - -class DiscriminatorP(torch.nn.Module): - def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): - super(DiscriminatorP, self).__init__() - self.period = period - self.use_spectral_norm = use_spectral_norm - norm_f = weight_norm if use_spectral_norm == False else spectral_norm - self.convs = nn.ModuleList( - [ - norm_f( - Conv2d( - 1, - 32, - (kernel_size, 1), - (stride, 1), - padding=(get_padding(kernel_size, 1), 0), - ) - ), - norm_f( - Conv2d( - 32, - 128, - (kernel_size, 1), - (stride, 1), - padding=(get_padding(kernel_size, 1), 0), - ) - ), - norm_f( - Conv2d( - 128, - 512, - (kernel_size, 1), - (stride, 1), - padding=(get_padding(kernel_size, 1), 0), - ) - ), - norm_f( - Conv2d( - 512, - 1024, - (kernel_size, 1), - (stride, 1), - padding=(get_padding(kernel_size, 1), 0), - ) - ), - norm_f( - Conv2d( - 1024, - 1024, - (kernel_size, 1), - 1, - padding=(get_padding(kernel_size, 1), 0), - ) - ), - ] - ) - self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) - - def forward(self, x): - fmap = [] - - # 1d to 2d - b, c, t = x.shape - if t % self.period != 0: # pad first - n_pad = self.period - (t % self.period) - if has_xpu and x.dtype == torch.bfloat16: - x = F.pad(x.to(dtype=torch.float16), (0, n_pad), "reflect").to( - dtype=torch.bfloat16 - ) - else: - x = F.pad(x, (0, n_pad), "reflect") - t = t + n_pad - x = x.view(b, c, t // self.period, self.period) - - for l in self.convs: - x = l(x) - x = F.leaky_relu(x, residuals.LRELU_SLOPE) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - x = torch.flatten(x, 1, -1) - - return x, fmap diff --git a/infer/modules/train/train.py b/infer/modules/train/train.py index ef40b6a..c511fcd 100644 --- a/infer/modules/train/train.py +++ b/infer/modules/train/train.py @@ -1,6 +1,7 @@ import os import sys import logging +from typing import Tuple logger = logging.getLogger(__name__) logging.getLogger("numba").setLevel(logging.WARNING) @@ -55,8 +56,9 @@ from infer.lib.train.data_utils import ( TextAudioLoaderMultiNSFsid, ) +from rvc.discriminators import MultiPeriodDiscriminator + if hps.version == "v1": - from infer.lib.infer_pack.models import MultiPeriodDiscriminator from infer.lib.infer_pack.models import SynthesizerTrnMs256NSFsid as RVC_Model_f0 from infer.lib.infer_pack.models import ( SynthesizerTrnMs256NSFsid_nono as RVC_Model_nof0, @@ -65,7 +67,6 @@ else: from infer.lib.infer_pack.models import ( SynthesizerTrnMs768NSFsid as RVC_Model_f0, SynthesizerTrnMs768NSFsid_nono as RVC_Model_nof0, - MultiPeriodDiscriminatorV2 as MultiPeriodDiscriminator, ) from infer.lib.train.losses import ( @@ -180,7 +181,12 @@ def run(rank, n_gpus, hps, logger: logging.Logger): ) if torch.cuda.is_available(): net_g = net_g.cuda(rank) - net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm) + has_xpu = bool(hasattr(torch, "xpu") and torch.xpu.is_available()) + net_d = MultiPeriodDiscriminator( + hps.version, + use_spectral_norm=hps.model.use_spectral_norm, + has_xpu=has_xpu, + ) if torch.cuda.is_available(): net_d = net_d.cuda(rank) optim_g = torch.optim.AdamW( @@ -298,7 +304,7 @@ def run(rank, n_gpus, hps, logger: logging.Logger): def train_and_evaluate( - rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers, cache + rank, epoch, hps, nets: Tuple[RVC_Model_f0, MultiPeriodDiscriminator], optims, schedulers, scaler, loaders, logger, writers, cache ): net_g, net_d = nets optim_g, optim_d = optims diff --git a/rvc/discriminators.py b/rvc/discriminators.py new file mode 100644 index 0000000..784265a --- /dev/null +++ b/rvc/discriminators.py @@ -0,0 +1,148 @@ +from typing import List, Tuple + +import torch +from torch import nn +from torch.nn import Conv1d, Conv2d +from torch.nn import functional as F +from torch.nn.utils import spectral_norm, weight_norm + +from .residuals import LRELU_SLOPE +from .utils import get_padding + + +class MultiPeriodDiscriminator(torch.nn.Module): + """ + version: 'v1' or 'v2' + """ + def __init__(self, version: str, use_spectral_norm: bool = False, has_xpu: bool = False): + super(MultiPeriodDiscriminator, self).__init__() + periods = (2, 3, 5, 7, 11, 17) if version == "v1" else (2, 3, 5, 7, 11, 17, 23, 37) + + self.discriminators = nn.ModuleList([ + DiscriminatorS(use_spectral_norm=use_spectral_norm), + *( + DiscriminatorP(i, use_spectral_norm=use_spectral_norm, has_xpu=has_xpu) for i in periods + ) + ]) + + def __call__(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]: + return super().__call__(y, y_hat) + + def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]: + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for d in self.discriminators: + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + y_d_gs.append(y_d_g) + fmap_rs.append(fmap_r) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm: bool = False): + super(DiscriminatorS, self).__init__() + norm_f = spectral_norm if use_spectral_norm else weight_norm + + self.convs = nn.ModuleList( + [ + norm_f(Conv1d(1, 16, 15, 1, padding=7)), + norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def __call__(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: + fmap = [] + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class DiscriminatorP(torch.nn.Module): + def __init__( + self, + period: int, + kernel_size: int = 5, + stride: int = 3, + use_spectral_norm: bool = False, + has_xpu: bool = False, + ): + super(DiscriminatorP, self).__init__() + self.period = period + self.has_xpu = has_xpu + norm_f = spectral_norm if use_spectral_norm else weight_norm + sequence = (1, 32, 128, 512, 1024) + convs_padding = (get_padding(kernel_size, 1), 0) + + self.convs = nn.ModuleList() + for i in range(len(sequence)-1): + self.convs.append(norm_f( + Conv2d( + sequence[i], + sequence[i + 1], + (kernel_size, 1), + (stride, 1), + padding=convs_padding, + ) + )) + self.convs.append(norm_f( + Conv2d( + 1024, + 1024, + (kernel_size, 1), + 1, + padding=convs_padding, + ) + )) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def __call__(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + if self.has_xpu and x.dtype == torch.bfloat16: + x = F.pad(x.to(dtype=torch.float16), (0, n_pad), "reflect").to( + dtype=torch.bfloat16 + ) + else: + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap