mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-06 01:30:24 +08:00
fix(train): parameter issue
This commit is contained in:
@@ -46,7 +46,6 @@ from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from rvc.layers import utils
|
||||
from infer.lib.train.data_utils import (
|
||||
DistributedBucketSampler,
|
||||
TextAudioCollate,
|
||||
@@ -77,6 +76,11 @@ from infer.lib.train.losses import (
|
||||
from infer.lib.train.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
|
||||
from infer.lib.train.process_ckpt import save_small_model
|
||||
|
||||
from rvc.layers.utils import (
|
||||
slice_on_last_dim,
|
||||
total_grad_norm,
|
||||
)
|
||||
|
||||
global_step = 0
|
||||
|
||||
|
||||
@@ -118,7 +122,7 @@ def main():
|
||||
children[i].join()
|
||||
|
||||
|
||||
def run(rank, n_gpus, hps, logger: logging.Logger):
|
||||
def run(rank, n_gpus, hps: utils.HParams, logger: logging.Logger):
|
||||
global global_step
|
||||
if rank == 0:
|
||||
# logger = utils.get_logger(hps.model_dir)
|
||||
@@ -163,20 +167,20 @@ def run(rank, n_gpus, hps, logger: logging.Logger):
|
||||
persistent_workers=True,
|
||||
prefetch_factor=8,
|
||||
)
|
||||
mdl = hps.copy().model
|
||||
del mdl.use_spectral_norm
|
||||
if hps.if_f0 == 1:
|
||||
net_g = RVC_Model_f0(
|
||||
hps.data.filter_length // 2 + 1,
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
**hps.model,
|
||||
is_half=hps.train.fp16_run,
|
||||
**mdl,
|
||||
sr=hps.sample_rate,
|
||||
)
|
||||
else:
|
||||
net_g = RVC_Model_nof0(
|
||||
hps.data.filter_length // 2 + 1,
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
**hps.model,
|
||||
is_half=hps.train.fp16_run,
|
||||
**mdl,
|
||||
)
|
||||
if torch.cuda.is_available():
|
||||
net_g = net_g.cuda(rank)
|
||||
@@ -459,7 +463,7 @@ def train_and_evaluate(
|
||||
hps.data.mel_fmin,
|
||||
hps.data.mel_fmax,
|
||||
)
|
||||
y_mel = utils.slice_on_last_dim(
|
||||
y_mel = slice_on_last_dim(
|
||||
mel, ids_slice, hps.train.segment_size // hps.data.hop_length
|
||||
)
|
||||
with autocast(enabled=False):
|
||||
@@ -475,7 +479,7 @@ def train_and_evaluate(
|
||||
)
|
||||
if hps.train.fp16_run == True:
|
||||
y_hat_mel = y_hat_mel.half()
|
||||
wave = utils.slice_on_last_dim(
|
||||
wave = slice_on_last_dim(
|
||||
wave, ids_slice * hps.data.hop_length, hps.train.segment_size
|
||||
) # slice
|
||||
|
||||
@@ -488,7 +492,7 @@ def train_and_evaluate(
|
||||
optim_d.zero_grad()
|
||||
scaler.scale(loss_disc).backward()
|
||||
scaler.unscale_(optim_d)
|
||||
grad_norm_d = utils.total_grad_norm(net_d.parameters())
|
||||
grad_norm_d = total_grad_norm(net_d.parameters())
|
||||
scaler.step(optim_d)
|
||||
|
||||
with autocast(enabled=hps.train.fp16_run):
|
||||
@@ -503,7 +507,7 @@ def train_and_evaluate(
|
||||
optim_g.zero_grad()
|
||||
scaler.scale(loss_gen_all).backward()
|
||||
scaler.unscale_(optim_g)
|
||||
grad_norm_g = utils.total_grad_norm(net_g.parameters())
|
||||
grad_norm_g = total_grad_norm(net_g.parameters())
|
||||
scaler.step(optim_g)
|
||||
scaler.update()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user