1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-09 04:29:50 +08:00

fix(train): parameter issue

This commit is contained in:
源文雨
2024-06-16 18:20:53 +09:00
parent 1410bd4d15
commit add4642b7e
4 changed files with 20 additions and 13 deletions

2
gui.py
View File

@@ -144,8 +144,6 @@ if __name__ == "__main__":
self.input_devices_indices = None self.input_devices_indices = None
self.output_devices_indices = None self.output_devices_indices = None
self.stream = None self.stream = None
if not self.config.nocheck:
self.check_assets()
self.update_devices() self.update_devices()
self.launcher() self.launcher()

View File

@@ -4,6 +4,7 @@ import json
import logging import logging
import os import os
import sys import sys
from copy import deepcopy
import codecs import codecs
import numpy as np import numpy as np
@@ -445,6 +446,9 @@ class HParams:
def values(self): def values(self):
return self.__dict__.values() return self.__dict__.values()
def copy(self):
return deepcopy(self)
def __len__(self): def __len__(self):
return len(self.__dict__) return len(self.__dict__)

View File

@@ -46,7 +46,6 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from rvc.layers import utils
from infer.lib.train.data_utils import ( from infer.lib.train.data_utils import (
DistributedBucketSampler, DistributedBucketSampler,
TextAudioCollate, TextAudioCollate,
@@ -77,6 +76,11 @@ from infer.lib.train.losses import (
from infer.lib.train.mel_processing import mel_spectrogram_torch, spec_to_mel_torch from infer.lib.train.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from infer.lib.train.process_ckpt import save_small_model from infer.lib.train.process_ckpt import save_small_model
from rvc.layers.utils import (
slice_on_last_dim,
total_grad_norm,
)
global_step = 0 global_step = 0
@@ -118,7 +122,7 @@ def main():
children[i].join() children[i].join()
def run(rank, n_gpus, hps, logger: logging.Logger): def run(rank, n_gpus, hps: utils.HParams, logger: logging.Logger):
global global_step global global_step
if rank == 0: if rank == 0:
# logger = utils.get_logger(hps.model_dir) # logger = utils.get_logger(hps.model_dir)
@@ -163,20 +167,20 @@ def run(rank, n_gpus, hps, logger: logging.Logger):
persistent_workers=True, persistent_workers=True,
prefetch_factor=8, prefetch_factor=8,
) )
mdl = hps.copy().model
del mdl.use_spectral_norm
if hps.if_f0 == 1: if hps.if_f0 == 1:
net_g = RVC_Model_f0( net_g = RVC_Model_f0(
hps.data.filter_length // 2 + 1, hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length, hps.train.segment_size // hps.data.hop_length,
**hps.model, **mdl,
is_half=hps.train.fp16_run,
sr=hps.sample_rate, sr=hps.sample_rate,
) )
else: else:
net_g = RVC_Model_nof0( net_g = RVC_Model_nof0(
hps.data.filter_length // 2 + 1, hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length, hps.train.segment_size // hps.data.hop_length,
**hps.model, **mdl,
is_half=hps.train.fp16_run,
) )
if torch.cuda.is_available(): if torch.cuda.is_available():
net_g = net_g.cuda(rank) net_g = net_g.cuda(rank)
@@ -459,7 +463,7 @@ def train_and_evaluate(
hps.data.mel_fmin, hps.data.mel_fmin,
hps.data.mel_fmax, hps.data.mel_fmax,
) )
y_mel = utils.slice_on_last_dim( y_mel = slice_on_last_dim(
mel, ids_slice, hps.train.segment_size // hps.data.hop_length mel, ids_slice, hps.train.segment_size // hps.data.hop_length
) )
with autocast(enabled=False): with autocast(enabled=False):
@@ -475,7 +479,7 @@ def train_and_evaluate(
) )
if hps.train.fp16_run == True: if hps.train.fp16_run == True:
y_hat_mel = y_hat_mel.half() y_hat_mel = y_hat_mel.half()
wave = utils.slice_on_last_dim( wave = slice_on_last_dim(
wave, ids_slice * hps.data.hop_length, hps.train.segment_size wave, ids_slice * hps.data.hop_length, hps.train.segment_size
) # slice ) # slice
@@ -488,7 +492,7 @@ def train_and_evaluate(
optim_d.zero_grad() optim_d.zero_grad()
scaler.scale(loss_disc).backward() scaler.scale(loss_disc).backward()
scaler.unscale_(optim_d) scaler.unscale_(optim_d)
grad_norm_d = utils.total_grad_norm(net_d.parameters()) grad_norm_d = total_grad_norm(net_d.parameters())
scaler.step(optim_d) scaler.step(optim_d)
with autocast(enabled=hps.train.fp16_run): with autocast(enabled=hps.train.fp16_run):
@@ -503,7 +507,7 @@ def train_and_evaluate(
optim_g.zero_grad() optim_g.zero_grad()
scaler.scale(loss_gen_all).backward() scaler.scale(loss_gen_all).backward()
scaler.unscale_(optim_g) scaler.unscale_(optim_g)
grad_norm_g = utils.total_grad_norm(net_g.parameters()) grad_norm_g = total_grad_norm(net_g.parameters())
scaler.step(optim_g) scaler.step(optim_g)
scaler.update() scaler.update()

3
web.py
View File

@@ -24,7 +24,6 @@ import torch, platform
import numpy as np import numpy as np
import gradio as gr import gradio as gr
import faiss import faiss
import fairseq
import pathlib import pathlib
import json import json
from time import sleep from time import sleep
@@ -72,7 +71,9 @@ if config.dml == True:
res = x.clone().detach() res = x.clone().detach()
return res return res
import fairseq
fairseq.modules.grad_multiply.GradMultiply.forward = forward_dml fairseq.modules.grad_multiply.GradMultiply.forward = forward_dml
i18n = I18nAuto() i18n = I18nAuto()
logger.info(i18n) logger.info(i18n)
# 判断是否有能用来训练和加速推理的N卡 # 判断是否有能用来训练和加速推理的N卡