Mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git (synced 2026-06-05 01:10:22 +08:00)
fix(fairseq): hubert load model error
This commit is contained in:
@@ -29,6 +29,24 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
|
||||
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
|
||||
|
||||
saved_state_dict = checkpoint_dict["model"]
|
||||
# Convert old-style weight_norm keys (weight_g/weight_v) to new
|
||||
# parametrizations format (parametrizations.weight.original0/original1)
|
||||
# so that checkpoints saved with the deprecated API can still be loaded.
|
||||
_converted = {}
|
||||
for k, v in list(saved_state_dict.items()):
|
||||
if k.endswith(".weight_g"):
|
||||
new_key = k[: -len(".weight_g")] + ".parametrizations.weight.original0"
|
||||
_converted[new_key] = v
|
||||
elif k.endswith(".weight_v"):
|
||||
new_key = k[: -len(".weight_v")] + ".parametrizations.weight.original1"
|
||||
_converted[new_key] = v
|
||||
if _converted:
|
||||
logger.info(
|
||||
"Converting %d old-style weight_norm keys from checkpoint to new parametrizations format",
|
||||
len(_converted),
|
||||
)
|
||||
saved_state_dict.update(_converted)
|
||||
|
||||
if hasattr(model, "module"):
|
||||
state_dict = model.module.state_dict()
|
||||
else:
|
||||
|
||||
@@ -29,10 +29,12 @@ try:
|
||||
|
||||
GradScaler = gradscaler_init()
|
||||
ipex_init()
|
||||
else:
|
||||
from torch.cuda.amp import GradScaler, autocast
|
||||
except Exception:
|
||||
from torch.cuda.amp import GradScaler, autocast
|
||||
pass
|
||||
finally:
|
||||
if not ('GradScaler' in globals() and 'autocast' in globals()):
|
||||
from torch.amp.grad_scaler import GradScaler
|
||||
from torch.amp.autocast_mode import autocast
|
||||
|
||||
torch.backends.cudnn.deterministic = False
|
||||
torch.backends.cudnn.benchmark = False
|
||||
@@ -535,7 +537,7 @@ def train_and_evaluate(
|
||||
# wave_lengths = wave_lengths.cuda(rank, non_blocking=True)
|
||||
|
||||
# Calculate
|
||||
with autocast(enabled=hps.train.fp16_run):
|
||||
with autocast(device_type="cuda", enabled=hps.train.fp16_run):
|
||||
(
|
||||
y_hat,
|
||||
ids_slice,
|
||||
@@ -554,7 +556,7 @@ def train_and_evaluate(
|
||||
y_mel = slice_on_last_dim(
|
||||
mel, ids_slice, hps.train.segment_size // hps.data.hop_length
|
||||
)
|
||||
with autocast(enabled=False):
|
||||
with autocast(device_type="cuda", enabled=False):
|
||||
y_hat_mel = mel_spectrogram_torch(
|
||||
y_hat.float().squeeze(1),
|
||||
hps.data.filter_length,
|
||||
@@ -573,7 +575,7 @@ def train_and_evaluate(
|
||||
|
||||
# Discriminator
|
||||
y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
|
||||
with autocast(enabled=False):
|
||||
with autocast(device_type="cuda", enabled=False):
|
||||
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
|
||||
y_d_hat_r, y_d_hat_g
|
||||
)
|
||||
@@ -583,10 +585,10 @@ def train_and_evaluate(
|
||||
grad_norm_d = total_grad_norm(net_d.parameters())
|
||||
scaler.step(optim_d)
|
||||
|
||||
with autocast(enabled=hps.train.fp16_run):
|
||||
with autocast(device_type="cuda", enabled=hps.train.fp16_run):
|
||||
# Generator
|
||||
y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
|
||||
with autocast(enabled=False):
|
||||
with autocast(device_type="cuda", enabled=False):
|
||||
loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
|
||||
loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
|
||||
loss_fm = feature_loss(fmap_r, fmap_g)
|
||||
|
||||
@@ -10,9 +10,6 @@ from pybase16384 import encode_to_string, decode_from_string
|
||||
from configs import CPUConfig
|
||||
from rvc.synthesizer import get_synthesizer
|
||||
|
||||
from .pipeline import Pipeline
|
||||
from .utils import load_hubert
|
||||
|
||||
|
||||
class TorchSeedContext:
|
||||
def __init__(self, seed):
|
||||
@@ -95,6 +92,9 @@ def wave_hash(time_field):
|
||||
|
||||
|
||||
def model_hash(config, tgt_sr, net_g, if_f0, version):
|
||||
from .pipeline import Pipeline
|
||||
from .utils import load_hubert
|
||||
|
||||
pipeline = Pipeline(tgt_sr, config)
|
||||
audio = original_audio()
|
||||
hbt = load_hubert(config.device, config.is_half)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import os, pathlib
|
||||
|
||||
from fairseq import checkpoint_utils
|
||||
import torch
|
||||
from fairseq import checkpoint_utils, data
|
||||
|
||||
|
||||
def get_index_path_from_model(sid):
|
||||
@@ -21,10 +22,11 @@ def get_index_path_from_model(sid):
|
||||
|
||||
|
||||
def load_hubert(device, is_half):
|
||||
models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
|
||||
["assets/hubert/hubert_base.pt"],
|
||||
suffix="",
|
||||
)
|
||||
with torch.serialization.safe_globals([data.dictionary.Dictionary]):
|
||||
models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
|
||||
["assets/hubert/hubert_base.pt"],
|
||||
suffix="",
|
||||
)
|
||||
hubert_model = models[0]
|
||||
hubert_model = hubert_model.to(device)
|
||||
if is_half:
|
||||
|
||||
Reference in New Issue
Block a user