1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-05 09:10:25 +08:00

optimize(train): move discriminators into rvc

This commit is contained in:
源文雨
2024-06-10 01:10:57 +09:00
parent 360318b2f5
commit b23ea7c6e7
4 changed files with 181 additions and 188 deletions

View File

@@ -1,6 +1,7 @@
import os
import sys
import logging
from typing import Tuple
logger = logging.getLogger(__name__)
logging.getLogger("numba").setLevel(logging.WARNING)
@@ -55,8 +56,9 @@ from infer.lib.train.data_utils import (
TextAudioLoaderMultiNSFsid,
)
from rvc.discriminators import MultiPeriodDiscriminator
if hps.version == "v1":
from infer.lib.infer_pack.models import MultiPeriodDiscriminator
from infer.lib.infer_pack.models import SynthesizerTrnMs256NSFsid as RVC_Model_f0
from infer.lib.infer_pack.models import (
SynthesizerTrnMs256NSFsid_nono as RVC_Model_nof0,
@@ -65,7 +67,6 @@ else:
from infer.lib.infer_pack.models import (
SynthesizerTrnMs768NSFsid as RVC_Model_f0,
SynthesizerTrnMs768NSFsid_nono as RVC_Model_nof0,
MultiPeriodDiscriminatorV2 as MultiPeriodDiscriminator,
)
from infer.lib.train.losses import (
@@ -180,7 +181,12 @@ def run(rank, n_gpus, hps, logger: logging.Logger):
)
if torch.cuda.is_available():
net_g = net_g.cuda(rank)
net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm)
has_xpu = bool(hasattr(torch, "xpu") and torch.xpu.is_available())
net_d = MultiPeriodDiscriminator(
hps.version,
use_spectral_norm=hps.model.use_spectral_norm,
has_xpu=has_xpu,
)
if torch.cuda.is_available():
net_d = net_d.cuda(rank)
optim_g = torch.optim.AdamW(
@@ -298,7 +304,7 @@ def run(rank, n_gpus, hps, logger: logging.Logger):
def train_and_evaluate(
rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers, cache
rank, epoch, hps, nets: Tuple[RVC_Model_f0, MultiPeriodDiscriminator], optims, schedulers, scaler, loaders, logger, writers, cache
):
net_g, net_d = nets
optim_g, optim_d = optims