mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-05 09:10:25 +08:00
optimize(train): move discriminators into rvc
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.getLogger("numba").setLevel(logging.WARNING)
|
||||
@@ -55,8 +56,9 @@ from infer.lib.train.data_utils import (
|
||||
TextAudioLoaderMultiNSFsid,
|
||||
)
|
||||
|
||||
from rvc.discriminators import MultiPeriodDiscriminator
|
||||
|
||||
if hps.version == "v1":
|
||||
from infer.lib.infer_pack.models import MultiPeriodDiscriminator
|
||||
from infer.lib.infer_pack.models import SynthesizerTrnMs256NSFsid as RVC_Model_f0
|
||||
from infer.lib.infer_pack.models import (
|
||||
SynthesizerTrnMs256NSFsid_nono as RVC_Model_nof0,
|
||||
@@ -65,7 +67,6 @@ else:
|
||||
from infer.lib.infer_pack.models import (
|
||||
SynthesizerTrnMs768NSFsid as RVC_Model_f0,
|
||||
SynthesizerTrnMs768NSFsid_nono as RVC_Model_nof0,
|
||||
MultiPeriodDiscriminatorV2 as MultiPeriodDiscriminator,
|
||||
)
|
||||
|
||||
from infer.lib.train.losses import (
|
||||
@@ -180,7 +181,12 @@ def run(rank, n_gpus, hps, logger: logging.Logger):
|
||||
)
|
||||
if torch.cuda.is_available():
|
||||
net_g = net_g.cuda(rank)
|
||||
net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm)
|
||||
has_xpu = bool(hasattr(torch, "xpu") and torch.xpu.is_available())
|
||||
net_d = MultiPeriodDiscriminator(
|
||||
hps.version,
|
||||
use_spectral_norm=hps.model.use_spectral_norm,
|
||||
has_xpu=has_xpu,
|
||||
)
|
||||
if torch.cuda.is_available():
|
||||
net_d = net_d.cuda(rank)
|
||||
optim_g = torch.optim.AdamW(
|
||||
@@ -298,7 +304,7 @@ def run(rank, n_gpus, hps, logger: logging.Logger):
|
||||
|
||||
|
||||
def train_and_evaluate(
|
||||
rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers, cache
|
||||
rank, epoch, hps, nets: Tuple[RVC_Model_f0, MultiPeriodDiscriminator], optims, schedulers, scaler, loaders, logger, writers, cache
|
||||
):
|
||||
net_g, net_d = nets
|
||||
optim_g, optim_d = optims
|
||||
|
||||
Reference in New Issue
Block a user