From f9ae0b5d3270ea9134d4e2812e9e362df419d505 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Sat, 18 Apr 2026 19:04:13 +0800 Subject: [PATCH] fix(fairseq): hubert load model error --- infer/lib/train/utils.py | 18 +++++++++++++ infer/modules/train/train.py | 18 +++++++------ infer/modules/vc/hash.py | 6 ++--- infer/modules/vc/utils.py | 12 +++++---- rvc/layers/discriminators.py | 3 ++- rvc/layers/encoders.py | 9 +++---- rvc/layers/generators.py | 26 ++++++------------- rvc/layers/norms.py | 38 +++++++++++----------------- rvc/layers/nsf.py | 25 ++++++------------ rvc/layers/residuals.py | 49 +++++++++++------------------------- rvc/layers/synthesizers.py | 29 ++++++--------------- web.py | 19 ++++++-------- 12 files changed, 101 insertions(+), 151 deletions(-) diff --git a/infer/lib/train/utils.py b/infer/lib/train/utils.py index 9f50ca7..6dd4e9c 100644 --- a/infer/lib/train/utils.py +++ b/infer/lib/train/utils.py @@ -29,6 +29,24 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1): checkpoint_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True) saved_state_dict = checkpoint_dict["model"] + # Convert old-style weight_norm keys (weight_g/weight_v) to new + # parametrizations format (parametrizations.weight.original0/original1) + # so that checkpoints saved with the deprecated API can still be loaded. + _converted = {} + for k, v in list(saved_state_dict.items()): + if k.endswith(".weight_g"): + new_key = k[: -len(".weight_g")] + ".parametrizations.weight.original0" + _converted[new_key] = v + elif k.endswith(".weight_v"): + new_key = k[: -len(".weight_v")] + ".parametrizations.weight.original1" + _converted[new_key] = v + if _converted: + logger.info( + "Converting %d old-style weight_norm keys from checkpoint to new parametrizations format", + len(_converted), + ) + saved_state_dict.update(_converted) + if hasattr(model, "module"): state_dict = model.module.state_dict() else: diff --git a/infer/modules/train/train.py b/infer/modules/train/train.py index c8b4875..42d525a 100644 --- a/infer/modules/train/train.py +++ b/infer/modules/train/train.py @@ -29,10 +29,12 @@ try: GradScaler = gradscaler_init() ipex_init() - else: - from torch.cuda.amp import GradScaler, autocast except Exception: - from torch.cuda.amp import GradScaler, autocast + pass +finally: + if not ('GradScaler' in globals() and 'autocast' in globals()): + from torch.amp.grad_scaler import GradScaler + from torch.amp.autocast_mode import autocast torch.backends.cudnn.deterministic = False torch.backends.cudnn.benchmark = False @@ -535,7 +537,7 @@ def train_and_evaluate( # wave_lengths = wave_lengths.cuda(rank, non_blocking=True) # Calculate - with autocast(enabled=hps.train.fp16_run): + with autocast(device_type="cuda", enabled=hps.train.fp16_run): ( y_hat, ids_slice, @@ -554,7 +556,7 @@ def train_and_evaluate( y_mel = slice_on_last_dim( mel, ids_slice, hps.train.segment_size // hps.data.hop_length ) - with autocast(enabled=False): + with autocast(device_type="cuda", enabled=False): y_hat_mel = mel_spectrogram_torch( y_hat.float().squeeze(1), hps.data.filter_length, @@ -573,7 +575,7 @@ def train_and_evaluate( # Discriminator y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach()) - with autocast(enabled=False): + with autocast(device_type="cuda", enabled=False): loss_disc, losses_disc_r, losses_disc_g = discriminator_loss( y_d_hat_r, y_d_hat_g ) @@ -583,10 +585,10 @@ def train_and_evaluate( grad_norm_d = total_grad_norm(net_d.parameters()) scaler.step(optim_d) - with autocast(enabled=hps.train.fp16_run): + with autocast(device_type="cuda", enabled=hps.train.fp16_run): # Generator y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat) - with autocast(enabled=False): + with autocast(device_type="cuda", enabled=False): loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl loss_fm = feature_loss(fmap_r, fmap_g) diff --git a/infer/modules/vc/hash.py b/infer/modules/vc/hash.py index 6a3561b..949f5d2 100644 --- a/infer/modules/vc/hash.py +++ b/infer/modules/vc/hash.py @@ -10,9 +10,6 @@ from pybase16384 import encode_to_string, decode_from_string from configs import CPUConfig from rvc.synthesizer import get_synthesizer -from .pipeline import Pipeline -from .utils import load_hubert - class TorchSeedContext: def __init__(self, seed): @@ -95,6 +92,9 @@ def wave_hash(time_field): def model_hash(config, tgt_sr, net_g, if_f0, version): + from .pipeline import Pipeline + from .utils import load_hubert + pipeline = Pipeline(tgt_sr, config) audio = original_audio() hbt = load_hubert(config.device, config.is_half) diff --git a/infer/modules/vc/utils.py b/infer/modules/vc/utils.py index dfe4b72..cc8f1bd 100644 --- a/infer/modules/vc/utils.py +++ b/infer/modules/vc/utils.py @@ -1,6 +1,7 @@ import os, pathlib -from fairseq import checkpoint_utils +import torch +from fairseq import checkpoint_utils, data def get_index_path_from_model(sid): @@ -21,10 +22,11 @@ def get_index_path_from_model(sid): def load_hubert(device, is_half): - models, _, _ = checkpoint_utils.load_model_ensemble_and_task( - ["assets/hubert/hubert_base.pt"], - suffix="", - ) + with torch.serialization.safe_globals([data.dictionary.Dictionary]): + models, _, _ = checkpoint_utils.load_model_ensemble_and_task( + ["assets/hubert/hubert_base.pt"], + suffix="", + ) hubert_model = models[0] hubert_model = hubert_model.to(device) if is_half: diff --git a/rvc/layers/discriminators.py b/rvc/layers/discriminators.py index 29e2dff..9811ef4 100644 --- a/rvc/layers/discriminators.py +++ b/rvc/layers/discriminators.py @@ -4,7 +4,8 @@ import torch from torch import nn from torch.nn import Conv1d, Conv2d from torch.nn import functional as F -from torch.nn.utils import spectral_norm, weight_norm +from torch.nn.utils import spectral_norm +from torch.nn.utils.parametrizations import weight_norm from .residuals import LRELU_SLOPE from .utils import get_padding diff --git a/rvc/layers/encoders.py b/rvc/layers/encoders.py index eeaa3fd..2a6db19 100644 --- a/rvc/layers/encoders.py +++ b/rvc/layers/encoders.py @@ -212,10 +212,7 @@ class PosteriorEncoder(nn.Module): self.enc.remove_weight_norm() def __prepare_scriptable__(self): - for hook in self.enc._forward_pre_hooks.values(): - if ( - hook.__module__ == "torch.nn.utils.weight_norm" - and hook.__class__.__name__ == "WeightNorm" - ): - torch.nn.utils.remove_weight_norm(self.enc) + from torch.nn.utils import parametrize + if parametrize.is_parametrized(self.enc, "weight"): + parametrize.remove_parametrizations(self.enc, "weight") return self diff --git a/rvc/layers/generators.py b/rvc/layers/generators.py index d78d6cd..309a8bd 100644 --- a/rvc/layers/generators.py +++ b/rvc/layers/generators.py @@ -4,7 +4,8 @@ import torch from torch import nn from torch.nn import Conv1d, ConvTranspose1d from torch.nn import functional as F -from torch.nn.utils import remove_weight_norm, weight_norm +from torch.nn.utils.parametrizations import weight_norm +from torch.nn.utils.parametrize import is_parametrized, remove_parametrizations from .residuals import ResBlock1, ResBlock2, LRELU_SLOPE from .utils import call_weight_data_normal_if_Conv @@ -98,29 +99,16 @@ class Generator(torch.nn.Module): def __prepare_scriptable__(self): for l in self.ups: - for hook in l._forward_pre_hooks.values(): - # The hook we want to remove is an instance of WeightNorm class, so - # normally we would do `if isinstance(...)` but this class is not accessible - # because of shadowing, so we check the module name directly. - # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3 - if ( - hook.__module__ == "torch.nn.utils.weight_norm" - and hook.__class__.__name__ == "WeightNorm" - ): - torch.nn.utils.remove_weight_norm(l) - + if is_parametrized(l, "weight"): + remove_parametrizations(l, "weight") for l in self.resblocks: - for hook in l._forward_pre_hooks.values(): - if ( - hook.__module__ == "torch.nn.utils.weight_norm" - and hook.__class__.__name__ == "WeightNorm" - ): - torch.nn.utils.remove_weight_norm(l) + if is_parametrized(l, "weight"): + remove_parametrizations(l, "weight") return self def remove_weight_norm(self): for l in self.ups: - remove_weight_norm(l) + remove_parametrizations(l, "weight") for l in self.resblocks: l.remove_weight_norm() diff --git a/rvc/layers/norms.py b/rvc/layers/norms.py index 4b07143..0512d6d 100644 --- a/rvc/layers/norms.py +++ b/rvc/layers/norms.py @@ -6,6 +6,8 @@ from torch.nn import functional as F from .utils import activate_add_tanh_sigmoid_multiply +from torch.nn.utils.parametrize import is_parametrized, remove_parametrizations + class LayerNorm(nn.Module): def __init__(self, channels: int, eps: float = 1e-5): @@ -49,7 +51,7 @@ class WN(torch.nn.Module): cond_layer = torch.nn.Conv1d( gin_channels, 2 * hidden_channels * n_layers, 1 ) - self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight") + self.cond_layer = torch.nn.utils.parametrizations.weight_norm(cond_layer, name="weight") for i in range(n_layers): dilation = dilation_rate**i @@ -61,7 +63,7 @@ class WN(torch.nn.Module): dilation=dilation, padding=padding, ) - in_layer = torch.nn.utils.weight_norm(in_layer, name="weight") + in_layer = torch.nn.utils.parametrizations.weight_norm(in_layer, name="weight") self.in_layers.append(in_layer) # last one is not necessary @@ -71,7 +73,7 @@ class WN(torch.nn.Module): res_skip_channels = hidden_channels res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) - res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight") + res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer, name="weight") self.res_skip_layers.append(res_skip_layer) def __call__( @@ -117,32 +119,20 @@ class WN(torch.nn.Module): def remove_weight_norm(self): if self.gin_channels != 0: - torch.nn.utils.remove_weight_norm(self.cond_layer) + remove_parametrizations(self.cond_layer, "weight") for l in self.in_layers: - torch.nn.utils.remove_weight_norm(l) + remove_parametrizations(l, "weight") for l in self.res_skip_layers: - torch.nn.utils.remove_weight_norm(l) + remove_parametrizations(l, "weight") def __prepare_scriptable__(self): if self.gin_channels != 0: - for hook in self.cond_layer._forward_pre_hooks.values(): - if ( - hook.__module__ == "torch.nn.utils.weight_norm" - and hook.__class__.__name__ == "WeightNorm" - ): - torch.nn.utils.remove_weight_norm(self.cond_layer) + if is_parametrized(self.cond_layer, "weight"): + remove_parametrizations(self.cond_layer, "weight") for l in self.in_layers: - for hook in l._forward_pre_hooks.values(): - if ( - hook.__module__ == "torch.nn.utils.weight_norm" - and hook.__class__.__name__ == "WeightNorm" - ): - torch.nn.utils.remove_weight_norm(l) + if is_parametrized(l, "weight"): + remove_parametrizations(l, "weight") for l in self.res_skip_layers: - for hook in l._forward_pre_hooks.values(): - if ( - hook.__module__ == "torch.nn.utils.weight_norm" - and hook.__class__.__name__ == "WeightNorm" - ): - torch.nn.utils.remove_weight_norm(l) + if is_parametrized(l, "weight"): + remove_parametrizations(l, "weight") return self diff --git a/rvc/layers/nsf.py b/rvc/layers/nsf.py index 22fd968..5ba4737 100644 --- a/rvc/layers/nsf.py +++ b/rvc/layers/nsf.py @@ -5,7 +5,8 @@ import torch from torch import nn from torch.nn import Conv1d, ConvTranspose1d from torch.nn import functional as F -from torch.nn.utils import remove_weight_norm, weight_norm +from torch.nn.utils.parametrizations import weight_norm +from torch.nn.utils.parametrize import is_parametrized, remove_parametrizations from .generators import SineGenerator from .residuals import ResBlock1, ResBlock2, LRELU_SLOPE @@ -191,27 +192,15 @@ class NSFGenerator(torch.nn.Module): def remove_weight_norm(self): for l in self.ups: - remove_weight_norm(l) + remove_parametrizations(l, "weight") for l in self.resblocks: l.remove_weight_norm() def __prepare_scriptable__(self): for l in self.ups: - for hook in l._forward_pre_hooks.values(): - # The hook we want to remove is an instance of WeightNorm class, so - # normally we would do `if isinstance(...)` but this class is not accessible - # because of shadowing, so we check the module name directly. - # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3 - if ( - hook.__module__ == "torch.nn.utils.weight_norm" - and hook.__class__.__name__ == "WeightNorm" - ): - torch.nn.utils.remove_weight_norm(l) + if is_parametrized(l, "weight"): + remove_parametrizations(l, "weight") for l in self.resblocks: - for hook in self.resblocks._forward_pre_hooks.values(): - if ( - hook.__module__ == "torch.nn.utils.weight_norm" - and hook.__class__.__name__ == "WeightNorm" - ): - torch.nn.utils.remove_weight_norm(l) + if is_parametrized(l, "weight"): + remove_parametrizations(l, "weight") return self diff --git a/rvc/layers/residuals.py b/rvc/layers/residuals.py index 45b3d11..521c739 100644 --- a/rvc/layers/residuals.py +++ b/rvc/layers/residuals.py @@ -4,7 +4,8 @@ import torch from torch import nn from torch.nn import Conv1d from torch.nn import functional as F -from torch.nn.utils import remove_weight_norm, weight_norm +from torch.nn.utils.parametrizations import weight_norm +from torch.nn.utils.parametrize import is_parametrized, remove_parametrizations from .norms import WN from .utils import ( @@ -85,25 +86,17 @@ class ResBlock1(torch.nn.Module): def remove_weight_norm(self): for l in self.convs1: - remove_weight_norm(l) + remove_parametrizations(l, "weight") for l in self.convs2: - remove_weight_norm(l) + remove_parametrizations(l, "weight") def __prepare_scriptable__(self): for l in self.convs1: - for hook in l._forward_pre_hooks.values(): - if ( - hook.__module__ == "torch.nn.utils.weight_norm" - and hook.__class__.__name__ == "WeightNorm" - ): - torch.nn.utils.remove_weight_norm(l) + if is_parametrized(l, "weight"): + remove_parametrizations(l, "weight") for l in self.convs2: - for hook in l._forward_pre_hooks.values(): - if ( - hook.__module__ == "torch.nn.utils.weight_norm" - and hook.__class__.__name__ == "WeightNorm" - ): - torch.nn.utils.remove_weight_norm(l) + if is_parametrized(l, "weight"): + remove_parametrizations(l, "weight") return self @@ -161,16 +154,12 @@ class ResBlock2(torch.nn.Module): def remove_weight_norm(self): for l in self.convs: - remove_weight_norm(l) + remove_parametrizations(l, "weight") def __prepare_scriptable__(self): for l in self.convs: - for hook in l._forward_pre_hooks.values(): - if ( - hook.__module__ == "torch.nn.utils.weight_norm" - and hook.__class__.__name__ == "WeightNorm" - ): - torch.nn.utils.remove_weight_norm(l) + if is_parametrized(l, "weight"): + remove_parametrizations(l, "weight") return self @@ -249,12 +238,8 @@ class ResidualCouplingLayer(nn.Module): self.enc.remove_weight_norm() def __prepare_scriptable__(self): - for hook in self.enc._forward_pre_hooks.values(): - if ( - hook.__module__ == "torch.nn.utils.weight_norm" - and hook.__class__.__name__ == "WeightNorm" - ): - torch.nn.utils.remove_weight_norm(self.enc) + if is_parametrized(self.enc, "weight"): + remove_parametrizations(self.enc, "weight") return self @@ -344,10 +329,6 @@ class ResidualCouplingBlock(nn.Module): def __prepare_scriptable__(self): for i in range(self.n_flows): - for hook in self.flows[i * 2]._forward_pre_hooks.values(): - if ( - hook.__module__ == "torch.nn.utils.weight_norm" - and hook.__class__.__name__ == "WeightNorm" - ): - torch.nn.utils.remove_weight_norm(self.flows[i * 2]) + if is_parametrized(self.flows[i * 2], "weight"): + remove_parametrizations(self.flows[i * 2], "weight") return self diff --git a/rvc/layers/synthesizers.py b/rvc/layers/synthesizers.py index c2c70ba..1a07899 100644 --- a/rvc/layers/synthesizers.py +++ b/rvc/layers/synthesizers.py @@ -2,6 +2,7 @@ from typing import Optional, List, Union import torch from torch import nn +from torch.nn.utils import parametrize from .encoders import TextEncoder, PosteriorEncoder @@ -118,29 +119,13 @@ class SynthesizerTrnMsNSFsid(nn.Module): self.enc_q.remove_weight_norm() def __prepare_scriptable__(self): - for hook in self.dec._forward_pre_hooks.values(): - # The hook we want to remove is an instance of WeightNorm class, so - # normally we would do `if isinstance(...)` but this class is not accessible - # because of shadowing, so we check the module name directly. - # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3 - if ( - hook.__module__ == "torch.nn.utils.weight_norm" - and hook.__class__.__name__ == "WeightNorm" - ): - torch.nn.utils.remove_weight_norm(self.dec) - for hook in self.flow._forward_pre_hooks.values(): - if ( - hook.__module__ == "torch.nn.utils.weight_norm" - and hook.__class__.__name__ == "WeightNorm" - ): - torch.nn.utils.remove_weight_norm(self.flow) + if parametrize.is_parametrized(self.dec, "weight"): + parametrize.remove_parametrizations(self.dec, "weight") + if parametrize.is_parametrized(self.flow, "weight"): + parametrize.remove_parametrizations(self.flow, "weight") if hasattr(self, "enc_q"): - for hook in self.enc_q._forward_pre_hooks.values(): - if ( - hook.__module__ == "torch.nn.utils.weight_norm" - and hook.__class__.__name__ == "WeightNorm" - ): - torch.nn.utils.remove_weight_norm(self.enc_q) + if parametrize.is_parametrized(self.enc_q, "weight"): + parametrize.remove_parametrizations(self.enc_q, "weight") return self @torch.jit.ignore() diff --git a/web.py b/web.py index df7f11c..9d41177 100644 --- a/web.py +++ b/web.py @@ -88,23 +88,24 @@ index_paths = [""] def lookup_names(weight_root): - global names + names = [] for name in os.listdir(weight_root): if name.endswith(".pth"): names.append(name) + return names def lookup_indices(index_root): - global index_paths + index_paths = [] for root, _, files in os.walk(index_root, topdown=False): for name in files: if name.endswith(".index") and "trained" not in name: index_paths.append(str(pathlib.Path(root, name))) + return index_paths -lookup_names(weight_root) -lookup_indices(index_root) -lookup_indices(outside_index_root) +names = [""] + lookup_names(weight_root) +index_paths = [""] + lookup_indices(index_root) + lookup_indices(outside_index_root) uvr5_names = [] for name in os.listdir(weight_uvr5_root): if name.endswith(".pth") or "onnx" in name: @@ -112,12 +113,8 @@ for name in os.listdir(weight_uvr5_root): def change_choices(): - global index_paths, names - names = [""] - lookup_names(weight_root) - index_paths = [""] - lookup_indices(index_root) - lookup_indices(outside_index_root) + names = [""] + lookup_names(weight_root) + index_paths = [""] + lookup_indices(index_root) + lookup_indices(outside_index_root) return {"choices": sorted(names), "__type__": "update"}, { "choices": sorted(index_paths), "__type__": "update",