1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-05 01:10:22 +08:00

fix(fairseq): hubert load model error

This commit is contained in:
源文雨
2026-04-18 19:04:13 +08:00
parent 8ded36e9e1
commit f9ae0b5d32
12 changed files with 101 additions and 151 deletions

View File

@@ -29,6 +29,24 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
saved_state_dict = checkpoint_dict["model"]
# Convert old-style weight_norm keys (weight_g/weight_v) to new
# parametrizations format (parametrizations.weight.original0/original1)
# so that checkpoints saved with the deprecated API can still be loaded.
_converted = {}
for k, v in list(saved_state_dict.items()):
if k.endswith(".weight_g"):
new_key = k[: -len(".weight_g")] + ".parametrizations.weight.original0"
_converted[new_key] = v
elif k.endswith(".weight_v"):
new_key = k[: -len(".weight_v")] + ".parametrizations.weight.original1"
_converted[new_key] = v
if _converted:
logger.info(
"Converting %d old-style weight_norm keys from checkpoint to new parametrizations format",
len(_converted),
)
saved_state_dict.update(_converted)
if hasattr(model, "module"):
state_dict = model.module.state_dict()
else:

View File

@@ -29,10 +29,12 @@ try:
GradScaler = gradscaler_init()
ipex_init()
else:
from torch.cuda.amp import GradScaler, autocast
except Exception:
from torch.cuda.amp import GradScaler, autocast
pass
finally:
if not ('GradScaler' in globals() and 'autocast' in globals()):
from torch.amp.grad_scaler import GradScaler
from torch.amp.autocast_mode import autocast
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False
@@ -535,7 +537,7 @@ def train_and_evaluate(
# wave_lengths = wave_lengths.cuda(rank, non_blocking=True)
# Calculate
with autocast(enabled=hps.train.fp16_run):
with autocast(device_type="cuda", enabled=hps.train.fp16_run):
(
y_hat,
ids_slice,
@@ -554,7 +556,7 @@ def train_and_evaluate(
y_mel = slice_on_last_dim(
mel, ids_slice, hps.train.segment_size // hps.data.hop_length
)
with autocast(enabled=False):
with autocast(device_type="cuda", enabled=False):
y_hat_mel = mel_spectrogram_torch(
y_hat.float().squeeze(1),
hps.data.filter_length,
@@ -573,7 +575,7 @@ def train_and_evaluate(
# Discriminator
y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
with autocast(enabled=False):
with autocast(device_type="cuda", enabled=False):
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
y_d_hat_r, y_d_hat_g
)
@@ -583,10 +585,10 @@ def train_and_evaluate(
grad_norm_d = total_grad_norm(net_d.parameters())
scaler.step(optim_d)
with autocast(enabled=hps.train.fp16_run):
with autocast(device_type="cuda", enabled=hps.train.fp16_run):
# Generator
y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
with autocast(enabled=False):
with autocast(device_type="cuda", enabled=False):
loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
loss_fm = feature_loss(fmap_r, fmap_g)

View File

@@ -10,9 +10,6 @@ from pybase16384 import encode_to_string, decode_from_string
from configs import CPUConfig
from rvc.synthesizer import get_synthesizer
from .pipeline import Pipeline
from .utils import load_hubert
class TorchSeedContext:
def __init__(self, seed):
@@ -95,6 +92,9 @@ def wave_hash(time_field):
def model_hash(config, tgt_sr, net_g, if_f0, version):
from .pipeline import Pipeline
from .utils import load_hubert
pipeline = Pipeline(tgt_sr, config)
audio = original_audio()
hbt = load_hubert(config.device, config.is_half)

View File

@@ -1,6 +1,7 @@
import os, pathlib
from fairseq import checkpoint_utils
import torch
from fairseq import checkpoint_utils, data
def get_index_path_from_model(sid):
@@ -21,6 +22,7 @@ def get_index_path_from_model(sid):
def load_hubert(device, is_half):
with torch.serialization.safe_globals([data.dictionary.Dictionary]):
models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
["assets/hubert/hubert_base.pt"],
suffix="",

View File

@@ -4,7 +4,8 @@ import torch
from torch import nn
from torch.nn import Conv1d, Conv2d
from torch.nn import functional as F
from torch.nn.utils import spectral_norm, weight_norm
from torch.nn.utils import spectral_norm
from torch.nn.utils.parametrizations import weight_norm
from .residuals import LRELU_SLOPE
from .utils import get_padding

View File

@@ -212,10 +212,7 @@ class PosteriorEncoder(nn.Module):
self.enc.remove_weight_norm()
def __prepare_scriptable__(self):
for hook in self.enc._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.enc)
from torch.nn.utils import parametrize
if parametrize.is_parametrized(self.enc, "weight"):
parametrize.remove_parametrizations(self.enc, "weight")
return self

View File

@@ -4,7 +4,8 @@ import torch
from torch import nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import is_parametrized, remove_parametrizations
from .residuals import ResBlock1, ResBlock2, LRELU_SLOPE
from .utils import call_weight_data_normal_if_Conv
@@ -98,29 +99,16 @@ class Generator(torch.nn.Module):
def __prepare_scriptable__(self):
for l in self.ups:
for hook in l._forward_pre_hooks.values():
# The hook we want to remove is an instance of WeightNorm class, so
# normally we would do `if isinstance(...)` but this class is not accessible
# because of shadowing, so we check the module name directly.
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
if is_parametrized(l, "weight"):
remove_parametrizations(l, "weight")
for l in self.resblocks:
for hook in l._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
if is_parametrized(l, "weight"):
remove_parametrizations(l, "weight")
return self
def remove_weight_norm(self):
for l in self.ups:
remove_weight_norm(l)
remove_parametrizations(l, "weight")
for l in self.resblocks:
l.remove_weight_norm()

View File

@@ -6,6 +6,8 @@ from torch.nn import functional as F
from .utils import activate_add_tanh_sigmoid_multiply
from torch.nn.utils.parametrize import is_parametrized, remove_parametrizations
class LayerNorm(nn.Module):
def __init__(self, channels: int, eps: float = 1e-5):
@@ -49,7 +51,7 @@ class WN(torch.nn.Module):
cond_layer = torch.nn.Conv1d(
gin_channels, 2 * hidden_channels * n_layers, 1
)
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
self.cond_layer = torch.nn.utils.parametrizations.weight_norm(cond_layer, name="weight")
for i in range(n_layers):
dilation = dilation_rate**i
@@ -61,7 +63,7 @@ class WN(torch.nn.Module):
dilation=dilation,
padding=padding,
)
in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
in_layer = torch.nn.utils.parametrizations.weight_norm(in_layer, name="weight")
self.in_layers.append(in_layer)
# last one is not necessary
@@ -71,7 +73,7 @@ class WN(torch.nn.Module):
res_skip_channels = hidden_channels
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer, name="weight")
self.res_skip_layers.append(res_skip_layer)
def __call__(
@@ -117,32 +119,20 @@ class WN(torch.nn.Module):
def remove_weight_norm(self):
if self.gin_channels != 0:
torch.nn.utils.remove_weight_norm(self.cond_layer)
remove_parametrizations(self.cond_layer, "weight")
for l in self.in_layers:
torch.nn.utils.remove_weight_norm(l)
remove_parametrizations(l, "weight")
for l in self.res_skip_layers:
torch.nn.utils.remove_weight_norm(l)
remove_parametrizations(l, "weight")
def __prepare_scriptable__(self):
if self.gin_channels != 0:
for hook in self.cond_layer._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.cond_layer)
if is_parametrized(self.cond_layer, "weight"):
remove_parametrizations(self.cond_layer, "weight")
for l in self.in_layers:
for hook in l._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
if is_parametrized(l, "weight"):
remove_parametrizations(l, "weight")
for l in self.res_skip_layers:
for hook in l._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
if is_parametrized(l, "weight"):
remove_parametrizations(l, "weight")
return self

View File

@@ -5,7 +5,8 @@ import torch
from torch import nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import is_parametrized, remove_parametrizations
from .generators import SineGenerator
from .residuals import ResBlock1, ResBlock2, LRELU_SLOPE
@@ -191,27 +192,15 @@ class NSFGenerator(torch.nn.Module):
def remove_weight_norm(self):
for l in self.ups:
remove_weight_norm(l)
remove_parametrizations(l, "weight")
for l in self.resblocks:
l.remove_weight_norm()
def __prepare_scriptable__(self):
for l in self.ups:
for hook in l._forward_pre_hooks.values():
# The hook we want to remove is an instance of WeightNorm class, so
# normally we would do `if isinstance(...)` but this class is not accessible
# because of shadowing, so we check the module name directly.
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
if is_parametrized(l, "weight"):
remove_parametrizations(l, "weight")
for l in self.resblocks:
for hook in self.resblocks._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
if is_parametrized(l, "weight"):
remove_parametrizations(l, "weight")
return self

View File

@@ -4,7 +4,8 @@ import torch
from torch import nn
from torch.nn import Conv1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import is_parametrized, remove_parametrizations
from .norms import WN
from .utils import (
@@ -85,25 +86,17 @@ class ResBlock1(torch.nn.Module):
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
remove_parametrizations(l, "weight")
for l in self.convs2:
remove_weight_norm(l)
remove_parametrizations(l, "weight")
def __prepare_scriptable__(self):
for l in self.convs1:
for hook in l._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
if is_parametrized(l, "weight"):
remove_parametrizations(l, "weight")
for l in self.convs2:
for hook in l._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
if is_parametrized(l, "weight"):
remove_parametrizations(l, "weight")
return self
@@ -161,16 +154,12 @@ class ResBlock2(torch.nn.Module):
def remove_weight_norm(self):
for l in self.convs:
remove_weight_norm(l)
remove_parametrizations(l, "weight")
def __prepare_scriptable__(self):
for l in self.convs:
for hook in l._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
if is_parametrized(l, "weight"):
remove_parametrizations(l, "weight")
return self
@@ -249,12 +238,8 @@ class ResidualCouplingLayer(nn.Module):
self.enc.remove_weight_norm()
def __prepare_scriptable__(self):
for hook in self.enc._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.enc)
if is_parametrized(self.enc, "weight"):
remove_parametrizations(self.enc, "weight")
return self
@@ -344,10 +329,6 @@ class ResidualCouplingBlock(nn.Module):
def __prepare_scriptable__(self):
for i in range(self.n_flows):
for hook in self.flows[i * 2]._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.flows[i * 2])
if is_parametrized(self.flows[i * 2], "weight"):
remove_parametrizations(self.flows[i * 2], "weight")
return self

View File

@@ -2,6 +2,7 @@ from typing import Optional, List, Union
import torch
from torch import nn
from torch.nn.utils import parametrize
from .encoders import TextEncoder, PosteriorEncoder
@@ -118,29 +119,13 @@ class SynthesizerTrnMsNSFsid(nn.Module):
self.enc_q.remove_weight_norm()
def __prepare_scriptable__(self):
for hook in self.dec._forward_pre_hooks.values():
# The hook we want to remove is an instance of WeightNorm class, so
# normally we would do `if isinstance(...)` but this class is not accessible
# because of shadowing, so we check the module name directly.
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.dec)
for hook in self.flow._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.flow)
if parametrize.is_parametrized(self.dec, "weight"):
parametrize.remove_parametrizations(self.dec, "weight")
if parametrize.is_parametrized(self.flow, "weight"):
parametrize.remove_parametrizations(self.flow, "weight")
if hasattr(self, "enc_q"):
for hook in self.enc_q._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.enc_q)
if parametrize.is_parametrized(self.enc_q, "weight"):
parametrize.remove_parametrizations(self.enc_q, "weight")
return self
@torch.jit.ignore()

19
web.py
View File

@@ -88,23 +88,24 @@ index_paths = [""]
def lookup_names(weight_root):
global names
names = []
for name in os.listdir(weight_root):
if name.endswith(".pth"):
names.append(name)
return names
def lookup_indices(index_root):
global index_paths
index_paths = []
for root, _, files in os.walk(index_root, topdown=False):
for name in files:
if name.endswith(".index") and "trained" not in name:
index_paths.append(str(pathlib.Path(root, name)))
return index_paths
lookup_names(weight_root)
lookup_indices(index_root)
lookup_indices(outside_index_root)
names = [""] + lookup_names(weight_root)
index_paths = [""] + lookup_indices(index_root) + lookup_indices(outside_index_root)
uvr5_names = []
for name in os.listdir(weight_uvr5_root):
if name.endswith(".pth") or "onnx" in name:
@@ -112,12 +113,8 @@ for name in os.listdir(weight_uvr5_root):
def change_choices():
global index_paths, names
names = [""]
lookup_names(weight_root)
index_paths = [""]
lookup_indices(index_root)
lookup_indices(outside_index_root)
names = [""] + lookup_names(weight_root)
index_paths = [""] + lookup_indices(index_root) + lookup_indices(outside_index_root)
return {"choices": sorted(names), "__type__": "update"}, {
"choices": sorted(index_paths),
"__type__": "update",