1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-05 01:10:22 +08:00

fix(fairseq): hubert load model error

This commit is contained in:
源文雨
2026-04-18 19:04:13 +08:00
parent 8ded36e9e1
commit f9ae0b5d32
12 changed files with 101 additions and 151 deletions

View File

@@ -29,6 +29,24 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True) checkpoint_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
saved_state_dict = checkpoint_dict["model"] saved_state_dict = checkpoint_dict["model"]
# Convert old-style weight_norm keys (weight_g/weight_v) to new
# parametrizations format (parametrizations.weight.original0/original1)
# so that checkpoints saved with the deprecated API can still be loaded.
_converted = {}
for k, v in list(saved_state_dict.items()):
if k.endswith(".weight_g"):
new_key = k[: -len(".weight_g")] + ".parametrizations.weight.original0"
_converted[new_key] = v
elif k.endswith(".weight_v"):
new_key = k[: -len(".weight_v")] + ".parametrizations.weight.original1"
_converted[new_key] = v
if _converted:
logger.info(
"Converting %d old-style weight_norm keys from checkpoint to new parametrizations format",
len(_converted),
)
saved_state_dict.update(_converted)
if hasattr(model, "module"): if hasattr(model, "module"):
state_dict = model.module.state_dict() state_dict = model.module.state_dict()
else: else:

View File

@@ -29,10 +29,12 @@ try:
GradScaler = gradscaler_init() GradScaler = gradscaler_init()
ipex_init() ipex_init()
else:
from torch.cuda.amp import GradScaler, autocast
except Exception: except Exception:
from torch.cuda.amp import GradScaler, autocast pass
finally:
if not ('GradScaler' in globals() and 'autocast' in globals()):
from torch.amp.grad_scaler import GradScaler
from torch.amp.autocast_mode import autocast
torch.backends.cudnn.deterministic = False torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False torch.backends.cudnn.benchmark = False
@@ -535,7 +537,7 @@ def train_and_evaluate(
# wave_lengths = wave_lengths.cuda(rank, non_blocking=True) # wave_lengths = wave_lengths.cuda(rank, non_blocking=True)
# Calculate # Calculate
with autocast(enabled=hps.train.fp16_run): with autocast(device_type="cuda", enabled=hps.train.fp16_run):
( (
y_hat, y_hat,
ids_slice, ids_slice,
@@ -554,7 +556,7 @@ def train_and_evaluate(
y_mel = slice_on_last_dim( y_mel = slice_on_last_dim(
mel, ids_slice, hps.train.segment_size // hps.data.hop_length mel, ids_slice, hps.train.segment_size // hps.data.hop_length
) )
with autocast(enabled=False): with autocast(device_type="cuda", enabled=False):
y_hat_mel = mel_spectrogram_torch( y_hat_mel = mel_spectrogram_torch(
y_hat.float().squeeze(1), y_hat.float().squeeze(1),
hps.data.filter_length, hps.data.filter_length,
@@ -573,7 +575,7 @@ def train_and_evaluate(
# Discriminator # Discriminator
y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach()) y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
with autocast(enabled=False): with autocast(device_type="cuda", enabled=False):
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss( loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
y_d_hat_r, y_d_hat_g y_d_hat_r, y_d_hat_g
) )
@@ -583,10 +585,10 @@ def train_and_evaluate(
grad_norm_d = total_grad_norm(net_d.parameters()) grad_norm_d = total_grad_norm(net_d.parameters())
scaler.step(optim_d) scaler.step(optim_d)
with autocast(enabled=hps.train.fp16_run): with autocast(device_type="cuda", enabled=hps.train.fp16_run):
# Generator # Generator
y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat) y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
with autocast(enabled=False): with autocast(device_type="cuda", enabled=False):
loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
loss_fm = feature_loss(fmap_r, fmap_g) loss_fm = feature_loss(fmap_r, fmap_g)

View File

@@ -10,9 +10,6 @@ from pybase16384 import encode_to_string, decode_from_string
from configs import CPUConfig from configs import CPUConfig
from rvc.synthesizer import get_synthesizer from rvc.synthesizer import get_synthesizer
from .pipeline import Pipeline
from .utils import load_hubert
class TorchSeedContext: class TorchSeedContext:
def __init__(self, seed): def __init__(self, seed):
@@ -95,6 +92,9 @@ def wave_hash(time_field):
def model_hash(config, tgt_sr, net_g, if_f0, version): def model_hash(config, tgt_sr, net_g, if_f0, version):
from .pipeline import Pipeline
from .utils import load_hubert
pipeline = Pipeline(tgt_sr, config) pipeline = Pipeline(tgt_sr, config)
audio = original_audio() audio = original_audio()
hbt = load_hubert(config.device, config.is_half) hbt = load_hubert(config.device, config.is_half)

View File

@@ -1,6 +1,7 @@
import os, pathlib import os, pathlib
from fairseq import checkpoint_utils import torch
from fairseq import checkpoint_utils, data
def get_index_path_from_model(sid): def get_index_path_from_model(sid):
@@ -21,10 +22,11 @@ def get_index_path_from_model(sid):
def load_hubert(device, is_half): def load_hubert(device, is_half):
models, _, _ = checkpoint_utils.load_model_ensemble_and_task( with torch.serialization.safe_globals([data.dictionary.Dictionary]):
["assets/hubert/hubert_base.pt"], models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
suffix="", ["assets/hubert/hubert_base.pt"],
) suffix="",
)
hubert_model = models[0] hubert_model = models[0]
hubert_model = hubert_model.to(device) hubert_model = hubert_model.to(device)
if is_half: if is_half:

View File

@@ -4,7 +4,8 @@ import torch
from torch import nn from torch import nn
from torch.nn import Conv1d, Conv2d from torch.nn import Conv1d, Conv2d
from torch.nn import functional as F from torch.nn import functional as F
from torch.nn.utils import spectral_norm, weight_norm from torch.nn.utils import spectral_norm
from torch.nn.utils.parametrizations import weight_norm
from .residuals import LRELU_SLOPE from .residuals import LRELU_SLOPE
from .utils import get_padding from .utils import get_padding

View File

@@ -212,10 +212,7 @@ class PosteriorEncoder(nn.Module):
self.enc.remove_weight_norm() self.enc.remove_weight_norm()
def __prepare_scriptable__(self): def __prepare_scriptable__(self):
for hook in self.enc._forward_pre_hooks.values(): from torch.nn.utils import parametrize
if ( if parametrize.is_parametrized(self.enc, "weight"):
hook.__module__ == "torch.nn.utils.weight_norm" parametrize.remove_parametrizations(self.enc, "weight")
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.enc)
return self return self

View File

@@ -4,7 +4,8 @@ import torch
from torch import nn from torch import nn
from torch.nn import Conv1d, ConvTranspose1d from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import is_parametrized, remove_parametrizations
from .residuals import ResBlock1, ResBlock2, LRELU_SLOPE from .residuals import ResBlock1, ResBlock2, LRELU_SLOPE
from .utils import call_weight_data_normal_if_Conv from .utils import call_weight_data_normal_if_Conv
@@ -98,29 +99,16 @@ class Generator(torch.nn.Module):
def __prepare_scriptable__(self): def __prepare_scriptable__(self):
for l in self.ups: for l in self.ups:
for hook in l._forward_pre_hooks.values(): if is_parametrized(l, "weight"):
# The hook we want to remove is an instance of WeightNorm class, so remove_parametrizations(l, "weight")
# normally we would do `if isinstance(...)` but this class is not accessible
# because of shadowing, so we check the module name directly.
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
for l in self.resblocks: for l in self.resblocks:
for hook in l._forward_pre_hooks.values(): if is_parametrized(l, "weight"):
if ( remove_parametrizations(l, "weight")
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
return self return self
def remove_weight_norm(self): def remove_weight_norm(self):
for l in self.ups: for l in self.ups:
remove_weight_norm(l) remove_parametrizations(l, "weight")
for l in self.resblocks: for l in self.resblocks:
l.remove_weight_norm() l.remove_weight_norm()

View File

@@ -6,6 +6,8 @@ from torch.nn import functional as F
from .utils import activate_add_tanh_sigmoid_multiply from .utils import activate_add_tanh_sigmoid_multiply
from torch.nn.utils.parametrize import is_parametrized, remove_parametrizations
class LayerNorm(nn.Module): class LayerNorm(nn.Module):
def __init__(self, channels: int, eps: float = 1e-5): def __init__(self, channels: int, eps: float = 1e-5):
@@ -49,7 +51,7 @@ class WN(torch.nn.Module):
cond_layer = torch.nn.Conv1d( cond_layer = torch.nn.Conv1d(
gin_channels, 2 * hidden_channels * n_layers, 1 gin_channels, 2 * hidden_channels * n_layers, 1
) )
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight") self.cond_layer = torch.nn.utils.parametrizations.weight_norm(cond_layer, name="weight")
for i in range(n_layers): for i in range(n_layers):
dilation = dilation_rate**i dilation = dilation_rate**i
@@ -61,7 +63,7 @@ class WN(torch.nn.Module):
dilation=dilation, dilation=dilation,
padding=padding, padding=padding,
) )
in_layer = torch.nn.utils.weight_norm(in_layer, name="weight") in_layer = torch.nn.utils.parametrizations.weight_norm(in_layer, name="weight")
self.in_layers.append(in_layer) self.in_layers.append(in_layer)
# last one is not necessary # last one is not necessary
@@ -71,7 +73,7 @@ class WN(torch.nn.Module):
res_skip_channels = hidden_channels res_skip_channels = hidden_channels
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight") res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer, name="weight")
self.res_skip_layers.append(res_skip_layer) self.res_skip_layers.append(res_skip_layer)
def __call__( def __call__(
@@ -117,32 +119,20 @@ class WN(torch.nn.Module):
def remove_weight_norm(self): def remove_weight_norm(self):
if self.gin_channels != 0: if self.gin_channels != 0:
torch.nn.utils.remove_weight_norm(self.cond_layer) remove_parametrizations(self.cond_layer, "weight")
for l in self.in_layers: for l in self.in_layers:
torch.nn.utils.remove_weight_norm(l) remove_parametrizations(l, "weight")
for l in self.res_skip_layers: for l in self.res_skip_layers:
torch.nn.utils.remove_weight_norm(l) remove_parametrizations(l, "weight")
def __prepare_scriptable__(self): def __prepare_scriptable__(self):
if self.gin_channels != 0: if self.gin_channels != 0:
for hook in self.cond_layer._forward_pre_hooks.values(): if is_parametrized(self.cond_layer, "weight"):
if ( remove_parametrizations(self.cond_layer, "weight")
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.cond_layer)
for l in self.in_layers: for l in self.in_layers:
for hook in l._forward_pre_hooks.values(): if is_parametrized(l, "weight"):
if ( remove_parametrizations(l, "weight")
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
for l in self.res_skip_layers: for l in self.res_skip_layers:
for hook in l._forward_pre_hooks.values(): if is_parametrized(l, "weight"):
if ( remove_parametrizations(l, "weight")
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
return self return self

View File

@@ -5,7 +5,8 @@ import torch
from torch import nn from torch import nn
from torch.nn import Conv1d, ConvTranspose1d from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import is_parametrized, remove_parametrizations
from .generators import SineGenerator from .generators import SineGenerator
from .residuals import ResBlock1, ResBlock2, LRELU_SLOPE from .residuals import ResBlock1, ResBlock2, LRELU_SLOPE
@@ -191,27 +192,15 @@ class NSFGenerator(torch.nn.Module):
def remove_weight_norm(self): def remove_weight_norm(self):
for l in self.ups: for l in self.ups:
remove_weight_norm(l) remove_parametrizations(l, "weight")
for l in self.resblocks: for l in self.resblocks:
l.remove_weight_norm() l.remove_weight_norm()
def __prepare_scriptable__(self): def __prepare_scriptable__(self):
for l in self.ups: for l in self.ups:
for hook in l._forward_pre_hooks.values(): if is_parametrized(l, "weight"):
# The hook we want to remove is an instance of WeightNorm class, so remove_parametrizations(l, "weight")
# normally we would do `if isinstance(...)` but this class is not accessible
# because of shadowing, so we check the module name directly.
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
for l in self.resblocks: for l in self.resblocks:
for hook in self.resblocks._forward_pre_hooks.values(): if is_parametrized(l, "weight"):
if ( remove_parametrizations(l, "weight")
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
return self return self

View File

@@ -4,7 +4,8 @@ import torch
from torch import nn from torch import nn
from torch.nn import Conv1d from torch.nn import Conv1d
from torch.nn import functional as F from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import is_parametrized, remove_parametrizations
from .norms import WN from .norms import WN
from .utils import ( from .utils import (
@@ -85,25 +86,17 @@ class ResBlock1(torch.nn.Module):
def remove_weight_norm(self): def remove_weight_norm(self):
for l in self.convs1: for l in self.convs1:
remove_weight_norm(l) remove_parametrizations(l, "weight")
for l in self.convs2: for l in self.convs2:
remove_weight_norm(l) remove_parametrizations(l, "weight")
def __prepare_scriptable__(self): def __prepare_scriptable__(self):
for l in self.convs1: for l in self.convs1:
for hook in l._forward_pre_hooks.values(): if is_parametrized(l, "weight"):
if ( remove_parametrizations(l, "weight")
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
for l in self.convs2: for l in self.convs2:
for hook in l._forward_pre_hooks.values(): if is_parametrized(l, "weight"):
if ( remove_parametrizations(l, "weight")
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
return self return self
@@ -161,16 +154,12 @@ class ResBlock2(torch.nn.Module):
def remove_weight_norm(self): def remove_weight_norm(self):
for l in self.convs: for l in self.convs:
remove_weight_norm(l) remove_parametrizations(l, "weight")
def __prepare_scriptable__(self): def __prepare_scriptable__(self):
for l in self.convs: for l in self.convs:
for hook in l._forward_pre_hooks.values(): if is_parametrized(l, "weight"):
if ( remove_parametrizations(l, "weight")
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
return self return self
@@ -249,12 +238,8 @@ class ResidualCouplingLayer(nn.Module):
self.enc.remove_weight_norm() self.enc.remove_weight_norm()
def __prepare_scriptable__(self): def __prepare_scriptable__(self):
for hook in self.enc._forward_pre_hooks.values(): if is_parametrized(self.enc, "weight"):
if ( remove_parametrizations(self.enc, "weight")
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.enc)
return self return self
@@ -344,10 +329,6 @@ class ResidualCouplingBlock(nn.Module):
def __prepare_scriptable__(self): def __prepare_scriptable__(self):
for i in range(self.n_flows): for i in range(self.n_flows):
for hook in self.flows[i * 2]._forward_pre_hooks.values(): if is_parametrized(self.flows[i * 2], "weight"):
if ( remove_parametrizations(self.flows[i * 2], "weight")
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.flows[i * 2])
return self return self

View File

@@ -2,6 +2,7 @@ from typing import Optional, List, Union
import torch import torch
from torch import nn from torch import nn
from torch.nn.utils import parametrize
from .encoders import TextEncoder, PosteriorEncoder from .encoders import TextEncoder, PosteriorEncoder
@@ -118,29 +119,13 @@ class SynthesizerTrnMsNSFsid(nn.Module):
self.enc_q.remove_weight_norm() self.enc_q.remove_weight_norm()
def __prepare_scriptable__(self): def __prepare_scriptable__(self):
for hook in self.dec._forward_pre_hooks.values(): if parametrize.is_parametrized(self.dec, "weight"):
# The hook we want to remove is an instance of WeightNorm class, so parametrize.remove_parametrizations(self.dec, "weight")
# normally we would do `if isinstance(...)` but this class is not accessible if parametrize.is_parametrized(self.flow, "weight"):
# because of shadowing, so we check the module name directly. parametrize.remove_parametrizations(self.flow, "weight")
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.dec)
for hook in self.flow._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.flow)
if hasattr(self, "enc_q"): if hasattr(self, "enc_q"):
for hook in self.enc_q._forward_pre_hooks.values(): if parametrize.is_parametrized(self.enc_q, "weight"):
if ( parametrize.remove_parametrizations(self.enc_q, "weight")
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.enc_q)
return self return self
@torch.jit.ignore() @torch.jit.ignore()

19
web.py
View File

@@ -88,23 +88,24 @@ index_paths = [""]
def lookup_names(weight_root): def lookup_names(weight_root):
global names names = []
for name in os.listdir(weight_root): for name in os.listdir(weight_root):
if name.endswith(".pth"): if name.endswith(".pth"):
names.append(name) names.append(name)
return names
def lookup_indices(index_root): def lookup_indices(index_root):
global index_paths index_paths = []
for root, _, files in os.walk(index_root, topdown=False): for root, _, files in os.walk(index_root, topdown=False):
for name in files: for name in files:
if name.endswith(".index") and "trained" not in name: if name.endswith(".index") and "trained" not in name:
index_paths.append(str(pathlib.Path(root, name))) index_paths.append(str(pathlib.Path(root, name)))
return index_paths
lookup_names(weight_root) names = [""] + lookup_names(weight_root)
lookup_indices(index_root) index_paths = [""] + lookup_indices(index_root) + lookup_indices(outside_index_root)
lookup_indices(outside_index_root)
uvr5_names = [] uvr5_names = []
for name in os.listdir(weight_uvr5_root): for name in os.listdir(weight_uvr5_root):
if name.endswith(".pth") or "onnx" in name: if name.endswith(".pth") or "onnx" in name:
@@ -112,12 +113,8 @@ for name in os.listdir(weight_uvr5_root):
def change_choices(): def change_choices():
global index_paths, names names = [""] + lookup_names(weight_root)
names = [""] index_paths = [""] + lookup_indices(index_root) + lookup_indices(outside_index_root)
lookup_names(weight_root)
index_paths = [""]
lookup_indices(index_root)
lookup_indices(outside_index_root)
return {"choices": sorted(names), "__type__": "update"}, { return {"choices": sorted(names), "__type__": "update"}, {
"choices": sorted(index_paths), "choices": sorted(index_paths),
"__type__": "update", "__type__": "update",