From add4642b7e77d7ca48c820cc06dc58a1c9e8b97a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Sun, 16 Jun 2024 18:20:53 +0900 Subject: [PATCH] fix(train): parameter issue --- gui.py | 2 -- infer/lib/train/utils.py | 4 ++++ infer/modules/train/train.py | 24 ++++++++++++++---------- web.py | 3 ++- 4 files changed, 20 insertions(+), 13 deletions(-) diff --git a/gui.py b/gui.py index 43aa71a..c460e7b 100644 --- a/gui.py +++ b/gui.py @@ -144,8 +144,6 @@ if __name__ == "__main__": self.input_devices_indices = None self.output_devices_indices = None self.stream = None - if not self.config.nocheck: - self.check_assets() self.update_devices() self.launcher() diff --git a/infer/lib/train/utils.py b/infer/lib/train/utils.py index 619624e..88e60e3 100644 --- a/infer/lib/train/utils.py +++ b/infer/lib/train/utils.py @@ -4,6 +4,7 @@ import json import logging import os import sys +from copy import deepcopy import codecs import numpy as np @@ -444,6 +445,9 @@ class HParams: def values(self): return self.__dict__.values() + + def copy(self): + return deepcopy(self) def __len__(self): return len(self.__dict__) diff --git a/infer/modules/train/train.py b/infer/modules/train/train.py index 4e30005..8d32e4a 100644 --- a/infer/modules/train/train.py +++ b/infer/modules/train/train.py @@ -46,7 +46,6 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter -from rvc.layers import utils from infer.lib.train.data_utils import ( DistributedBucketSampler, TextAudioCollate, @@ -77,6 +76,11 @@ from infer.lib.train.losses import ( from infer.lib.train.mel_processing import mel_spectrogram_torch, spec_to_mel_torch from infer.lib.train.process_ckpt import save_small_model +from rvc.layers.utils import ( + slice_on_last_dim, + total_grad_norm, +) + global_step = 0 @@ -118,7 +122,7 @@ def main(): children[i].join() -def run(rank, n_gpus, hps, logger: logging.Logger): +def run(rank, n_gpus, hps: utils.HParams, logger: logging.Logger): global global_step if rank == 0: # logger = utils.get_logger(hps.model_dir) @@ -163,20 +167,20 @@ def run(rank, n_gpus, hps, logger: logging.Logger): persistent_workers=True, prefetch_factor=8, ) + mdl = hps.copy().model + del mdl.use_spectral_norm if hps.if_f0 == 1: net_g = RVC_Model_f0( hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, - **hps.model, - is_half=hps.train.fp16_run, + **mdl, sr=hps.sample_rate, ) else: net_g = RVC_Model_nof0( hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, - **hps.model, - is_half=hps.train.fp16_run, + **mdl, ) if torch.cuda.is_available(): net_g = net_g.cuda(rank) @@ -459,7 +463,7 @@ def train_and_evaluate( hps.data.mel_fmin, hps.data.mel_fmax, ) - y_mel = utils.slice_on_last_dim( + y_mel = slice_on_last_dim( mel, ids_slice, hps.train.segment_size // hps.data.hop_length ) with autocast(enabled=False): @@ -475,7 +479,7 @@ def train_and_evaluate( ) if hps.train.fp16_run == True: y_hat_mel = y_hat_mel.half() - wave = utils.slice_on_last_dim( + wave = slice_on_last_dim( wave, ids_slice * hps.data.hop_length, hps.train.segment_size ) # slice @@ -488,7 +492,7 @@ def train_and_evaluate( optim_d.zero_grad() scaler.scale(loss_disc).backward() scaler.unscale_(optim_d) - grad_norm_d = utils.total_grad_norm(net_d.parameters()) + grad_norm_d = total_grad_norm(net_d.parameters()) scaler.step(optim_d) with autocast(enabled=hps.train.fp16_run): @@ -503,7 +507,7 @@ def train_and_evaluate( optim_g.zero_grad() scaler.scale(loss_gen_all).backward() scaler.unscale_(optim_g) - grad_norm_g = utils.total_grad_norm(net_g.parameters()) + grad_norm_g = total_grad_norm(net_g.parameters()) scaler.step(optim_g) scaler.update() diff --git a/web.py b/web.py index b747215..bb2874b 100644 --- a/web.py +++ b/web.py @@ -24,7 +24,6 @@ import torch, platform import numpy as np import gradio as gr import faiss -import fairseq import pathlib import json from time import sleep @@ -72,7 +71,9 @@ if config.dml == True: res = x.clone().detach() return res + import fairseq fairseq.modules.grad_multiply.GradMultiply.forward = forward_dml + i18n = I18nAuto() logger.info(i18n) # 判断是否有能用来训练和加速推理的N卡