mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-09 04:29:50 +08:00
fix(train): parameter issue
This commit is contained in:
2
gui.py
2
gui.py
@@ -144,8 +144,6 @@ if __name__ == "__main__":
|
|||||||
self.input_devices_indices = None
|
self.input_devices_indices = None
|
||||||
self.output_devices_indices = None
|
self.output_devices_indices = None
|
||||||
self.stream = None
|
self.stream = None
|
||||||
if not self.config.nocheck:
|
|
||||||
self.check_assets()
|
|
||||||
self.update_devices()
|
self.update_devices()
|
||||||
self.launcher()
|
self.launcher()
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
import codecs
|
import codecs
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -445,6 +446,9 @@ class HParams:
|
|||||||
def values(self):
|
def values(self):
|
||||||
return self.__dict__.values()
|
return self.__dict__.values()
|
||||||
|
|
||||||
|
def copy(self):
|
||||||
|
return deepcopy(self)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.__dict__)
|
return len(self.__dict__)
|
||||||
|
|
||||||
|
|||||||
@@ -46,7 +46,6 @@ from torch.nn.parallel import DistributedDataParallel as DDP
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from rvc.layers import utils
|
|
||||||
from infer.lib.train.data_utils import (
|
from infer.lib.train.data_utils import (
|
||||||
DistributedBucketSampler,
|
DistributedBucketSampler,
|
||||||
TextAudioCollate,
|
TextAudioCollate,
|
||||||
@@ -77,6 +76,11 @@ from infer.lib.train.losses import (
|
|||||||
from infer.lib.train.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
|
from infer.lib.train.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
|
||||||
from infer.lib.train.process_ckpt import save_small_model
|
from infer.lib.train.process_ckpt import save_small_model
|
||||||
|
|
||||||
|
from rvc.layers.utils import (
|
||||||
|
slice_on_last_dim,
|
||||||
|
total_grad_norm,
|
||||||
|
)
|
||||||
|
|
||||||
global_step = 0
|
global_step = 0
|
||||||
|
|
||||||
|
|
||||||
@@ -118,7 +122,7 @@ def main():
|
|||||||
children[i].join()
|
children[i].join()
|
||||||
|
|
||||||
|
|
||||||
def run(rank, n_gpus, hps, logger: logging.Logger):
|
def run(rank, n_gpus, hps: utils.HParams, logger: logging.Logger):
|
||||||
global global_step
|
global global_step
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
# logger = utils.get_logger(hps.model_dir)
|
# logger = utils.get_logger(hps.model_dir)
|
||||||
@@ -163,20 +167,20 @@ def run(rank, n_gpus, hps, logger: logging.Logger):
|
|||||||
persistent_workers=True,
|
persistent_workers=True,
|
||||||
prefetch_factor=8,
|
prefetch_factor=8,
|
||||||
)
|
)
|
||||||
|
mdl = hps.copy().model
|
||||||
|
del mdl.use_spectral_norm
|
||||||
if hps.if_f0 == 1:
|
if hps.if_f0 == 1:
|
||||||
net_g = RVC_Model_f0(
|
net_g = RVC_Model_f0(
|
||||||
hps.data.filter_length // 2 + 1,
|
hps.data.filter_length // 2 + 1,
|
||||||
hps.train.segment_size // hps.data.hop_length,
|
hps.train.segment_size // hps.data.hop_length,
|
||||||
**hps.model,
|
**mdl,
|
||||||
is_half=hps.train.fp16_run,
|
|
||||||
sr=hps.sample_rate,
|
sr=hps.sample_rate,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
net_g = RVC_Model_nof0(
|
net_g = RVC_Model_nof0(
|
||||||
hps.data.filter_length // 2 + 1,
|
hps.data.filter_length // 2 + 1,
|
||||||
hps.train.segment_size // hps.data.hop_length,
|
hps.train.segment_size // hps.data.hop_length,
|
||||||
**hps.model,
|
**mdl,
|
||||||
is_half=hps.train.fp16_run,
|
|
||||||
)
|
)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
net_g = net_g.cuda(rank)
|
net_g = net_g.cuda(rank)
|
||||||
@@ -459,7 +463,7 @@ def train_and_evaluate(
|
|||||||
hps.data.mel_fmin,
|
hps.data.mel_fmin,
|
||||||
hps.data.mel_fmax,
|
hps.data.mel_fmax,
|
||||||
)
|
)
|
||||||
y_mel = utils.slice_on_last_dim(
|
y_mel = slice_on_last_dim(
|
||||||
mel, ids_slice, hps.train.segment_size // hps.data.hop_length
|
mel, ids_slice, hps.train.segment_size // hps.data.hop_length
|
||||||
)
|
)
|
||||||
with autocast(enabled=False):
|
with autocast(enabled=False):
|
||||||
@@ -475,7 +479,7 @@ def train_and_evaluate(
|
|||||||
)
|
)
|
||||||
if hps.train.fp16_run == True:
|
if hps.train.fp16_run == True:
|
||||||
y_hat_mel = y_hat_mel.half()
|
y_hat_mel = y_hat_mel.half()
|
||||||
wave = utils.slice_on_last_dim(
|
wave = slice_on_last_dim(
|
||||||
wave, ids_slice * hps.data.hop_length, hps.train.segment_size
|
wave, ids_slice * hps.data.hop_length, hps.train.segment_size
|
||||||
) # slice
|
) # slice
|
||||||
|
|
||||||
@@ -488,7 +492,7 @@ def train_and_evaluate(
|
|||||||
optim_d.zero_grad()
|
optim_d.zero_grad()
|
||||||
scaler.scale(loss_disc).backward()
|
scaler.scale(loss_disc).backward()
|
||||||
scaler.unscale_(optim_d)
|
scaler.unscale_(optim_d)
|
||||||
grad_norm_d = utils.total_grad_norm(net_d.parameters())
|
grad_norm_d = total_grad_norm(net_d.parameters())
|
||||||
scaler.step(optim_d)
|
scaler.step(optim_d)
|
||||||
|
|
||||||
with autocast(enabled=hps.train.fp16_run):
|
with autocast(enabled=hps.train.fp16_run):
|
||||||
@@ -503,7 +507,7 @@ def train_and_evaluate(
|
|||||||
optim_g.zero_grad()
|
optim_g.zero_grad()
|
||||||
scaler.scale(loss_gen_all).backward()
|
scaler.scale(loss_gen_all).backward()
|
||||||
scaler.unscale_(optim_g)
|
scaler.unscale_(optim_g)
|
||||||
grad_norm_g = utils.total_grad_norm(net_g.parameters())
|
grad_norm_g = total_grad_norm(net_g.parameters())
|
||||||
scaler.step(optim_g)
|
scaler.step(optim_g)
|
||||||
scaler.update()
|
scaler.update()
|
||||||
|
|
||||||
|
|||||||
3
web.py
3
web.py
@@ -24,7 +24,6 @@ import torch, platform
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import faiss
|
import faiss
|
||||||
import fairseq
|
|
||||||
import pathlib
|
import pathlib
|
||||||
import json
|
import json
|
||||||
from time import sleep
|
from time import sleep
|
||||||
@@ -72,7 +71,9 @@ if config.dml == True:
|
|||||||
res = x.clone().detach()
|
res = x.clone().detach()
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
import fairseq
|
||||||
fairseq.modules.grad_multiply.GradMultiply.forward = forward_dml
|
fairseq.modules.grad_multiply.GradMultiply.forward = forward_dml
|
||||||
|
|
||||||
i18n = I18nAuto()
|
i18n = I18nAuto()
|
||||||
logger.info(i18n)
|
logger.info(i18n)
|
||||||
# 判断是否有能用来训练和加速推理的N卡
|
# 判断是否有能用来训练和加速推理的N卡
|
||||||
|
|||||||
Reference in New Issue
Block a user