mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-05 01:10:22 +08:00
fix(train): parameter issue
This commit is contained in:
2
gui.py
2
gui.py
@@ -144,8 +144,6 @@ if __name__ == "__main__":
|
||||
self.input_devices_indices = None
|
||||
self.output_devices_indices = None
|
||||
self.stream = None
|
||||
if not self.config.nocheck:
|
||||
self.check_assets()
|
||||
self.update_devices()
|
||||
self.launcher()
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
|
||||
import codecs
|
||||
import numpy as np
|
||||
@@ -445,6 +446,9 @@ class HParams:
|
||||
def values(self):
|
||||
return self.__dict__.values()
|
||||
|
||||
def copy(self):
|
||||
return deepcopy(self)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.__dict__)
|
||||
|
||||
|
||||
@@ -46,7 +46,6 @@ from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from rvc.layers import utils
|
||||
from infer.lib.train.data_utils import (
|
||||
DistributedBucketSampler,
|
||||
TextAudioCollate,
|
||||
@@ -77,6 +76,11 @@ from infer.lib.train.losses import (
|
||||
from infer.lib.train.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
|
||||
from infer.lib.train.process_ckpt import save_small_model
|
||||
|
||||
from rvc.layers.utils import (
|
||||
slice_on_last_dim,
|
||||
total_grad_norm,
|
||||
)
|
||||
|
||||
global_step = 0
|
||||
|
||||
|
||||
@@ -118,7 +122,7 @@ def main():
|
||||
children[i].join()
|
||||
|
||||
|
||||
def run(rank, n_gpus, hps, logger: logging.Logger):
|
||||
def run(rank, n_gpus, hps: utils.HParams, logger: logging.Logger):
|
||||
global global_step
|
||||
if rank == 0:
|
||||
# logger = utils.get_logger(hps.model_dir)
|
||||
@@ -163,20 +167,20 @@ def run(rank, n_gpus, hps, logger: logging.Logger):
|
||||
persistent_workers=True,
|
||||
prefetch_factor=8,
|
||||
)
|
||||
mdl = hps.copy().model
|
||||
del mdl.use_spectral_norm
|
||||
if hps.if_f0 == 1:
|
||||
net_g = RVC_Model_f0(
|
||||
hps.data.filter_length // 2 + 1,
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
**hps.model,
|
||||
is_half=hps.train.fp16_run,
|
||||
**mdl,
|
||||
sr=hps.sample_rate,
|
||||
)
|
||||
else:
|
||||
net_g = RVC_Model_nof0(
|
||||
hps.data.filter_length // 2 + 1,
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
**hps.model,
|
||||
is_half=hps.train.fp16_run,
|
||||
**mdl,
|
||||
)
|
||||
if torch.cuda.is_available():
|
||||
net_g = net_g.cuda(rank)
|
||||
@@ -459,7 +463,7 @@ def train_and_evaluate(
|
||||
hps.data.mel_fmin,
|
||||
hps.data.mel_fmax,
|
||||
)
|
||||
y_mel = utils.slice_on_last_dim(
|
||||
y_mel = slice_on_last_dim(
|
||||
mel, ids_slice, hps.train.segment_size // hps.data.hop_length
|
||||
)
|
||||
with autocast(enabled=False):
|
||||
@@ -475,7 +479,7 @@ def train_and_evaluate(
|
||||
)
|
||||
if hps.train.fp16_run == True:
|
||||
y_hat_mel = y_hat_mel.half()
|
||||
wave = utils.slice_on_last_dim(
|
||||
wave = slice_on_last_dim(
|
||||
wave, ids_slice * hps.data.hop_length, hps.train.segment_size
|
||||
) # slice
|
||||
|
||||
@@ -488,7 +492,7 @@ def train_and_evaluate(
|
||||
optim_d.zero_grad()
|
||||
scaler.scale(loss_disc).backward()
|
||||
scaler.unscale_(optim_d)
|
||||
grad_norm_d = utils.total_grad_norm(net_d.parameters())
|
||||
grad_norm_d = total_grad_norm(net_d.parameters())
|
||||
scaler.step(optim_d)
|
||||
|
||||
with autocast(enabled=hps.train.fp16_run):
|
||||
@@ -503,7 +507,7 @@ def train_and_evaluate(
|
||||
optim_g.zero_grad()
|
||||
scaler.scale(loss_gen_all).backward()
|
||||
scaler.unscale_(optim_g)
|
||||
grad_norm_g = utils.total_grad_norm(net_g.parameters())
|
||||
grad_norm_g = total_grad_norm(net_g.parameters())
|
||||
scaler.step(optim_g)
|
||||
scaler.update()
|
||||
|
||||
|
||||
3
web.py
3
web.py
@@ -24,7 +24,6 @@ import torch, platform
|
||||
import numpy as np
|
||||
import gradio as gr
|
||||
import faiss
|
||||
import fairseq
|
||||
import pathlib
|
||||
import json
|
||||
from time import sleep
|
||||
@@ -72,7 +71,9 @@ if config.dml == True:
|
||||
res = x.clone().detach()
|
||||
return res
|
||||
|
||||
import fairseq
|
||||
fairseq.modules.grad_multiply.GradMultiply.forward = forward_dml
|
||||
|
||||
i18n = I18nAuto()
|
||||
logger.info(i18n)
|
||||
# 判断是否有能用来训练和加速推理的N卡
|
||||
|
||||
Reference in New Issue
Block a user