1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-05 01:10:22 +08:00

fix(train): parameter issue

This commit is contained in:
源文雨
2024-06-16 18:20:53 +09:00
parent 1410bd4d15
commit add4642b7e
4 changed files with 20 additions and 13 deletions

2
gui.py
View File

@@ -144,8 +144,6 @@ if __name__ == "__main__":
self.input_devices_indices = None
self.output_devices_indices = None
self.stream = None
if not self.config.nocheck:
self.check_assets()
self.update_devices()
self.launcher()

View File

@@ -4,6 +4,7 @@ import json
import logging
import os
import sys
from copy import deepcopy
import codecs
import numpy as np
@@ -445,6 +446,9 @@ class HParams:
def values(self):
return self.__dict__.values()
def copy(self):
return deepcopy(self)
def __len__(self):
return len(self.__dict__)

View File

@@ -46,7 +46,6 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from rvc.layers import utils
from infer.lib.train.data_utils import (
DistributedBucketSampler,
TextAudioCollate,
@@ -77,6 +76,11 @@ from infer.lib.train.losses import (
from infer.lib.train.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from infer.lib.train.process_ckpt import save_small_model
from rvc.layers.utils import (
slice_on_last_dim,
total_grad_norm,
)
global_step = 0
@@ -118,7 +122,7 @@ def main():
children[i].join()
def run(rank, n_gpus, hps, logger: logging.Logger):
def run(rank, n_gpus, hps: utils.HParams, logger: logging.Logger):
global global_step
if rank == 0:
# logger = utils.get_logger(hps.model_dir)
@@ -163,20 +167,20 @@ def run(rank, n_gpus, hps, logger: logging.Logger):
persistent_workers=True,
prefetch_factor=8,
)
mdl = hps.copy().model
del mdl.use_spectral_norm
if hps.if_f0 == 1:
net_g = RVC_Model_f0(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
**hps.model,
is_half=hps.train.fp16_run,
**mdl,
sr=hps.sample_rate,
)
else:
net_g = RVC_Model_nof0(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
**hps.model,
is_half=hps.train.fp16_run,
**mdl,
)
if torch.cuda.is_available():
net_g = net_g.cuda(rank)
@@ -459,7 +463,7 @@ def train_and_evaluate(
hps.data.mel_fmin,
hps.data.mel_fmax,
)
y_mel = utils.slice_on_last_dim(
y_mel = slice_on_last_dim(
mel, ids_slice, hps.train.segment_size // hps.data.hop_length
)
with autocast(enabled=False):
@@ -475,7 +479,7 @@ def train_and_evaluate(
)
if hps.train.fp16_run == True:
y_hat_mel = y_hat_mel.half()
wave = utils.slice_on_last_dim(
wave = slice_on_last_dim(
wave, ids_slice * hps.data.hop_length, hps.train.segment_size
) # slice
@@ -488,7 +492,7 @@ def train_and_evaluate(
optim_d.zero_grad()
scaler.scale(loss_disc).backward()
scaler.unscale_(optim_d)
grad_norm_d = utils.total_grad_norm(net_d.parameters())
grad_norm_d = total_grad_norm(net_d.parameters())
scaler.step(optim_d)
with autocast(enabled=hps.train.fp16_run):
@@ -503,7 +507,7 @@ def train_and_evaluate(
optim_g.zero_grad()
scaler.scale(loss_gen_all).backward()
scaler.unscale_(optim_g)
grad_norm_g = utils.total_grad_norm(net_g.parameters())
grad_norm_g = total_grad_norm(net_g.parameters())
scaler.step(optim_g)
scaler.update()

3
web.py
View File

@@ -24,7 +24,6 @@ import torch, platform
import numpy as np
import gradio as gr
import faiss
import fairseq
import pathlib
import json
from time import sleep
@@ -72,7 +71,9 @@ if config.dml == True:
res = x.clone().detach()
return res
import fairseq
fairseq.modules.grad_multiply.GradMultiply.forward = forward_dml
i18n = I18nAuto()
logger.info(i18n)
# 判断是否有能用来训练和加速推理的N卡