1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-08 12:00:49 +08:00

fix(fairseq): hubert load model error

This commit is contained in:
源文雨
2026-04-18 19:04:13 +08:00
parent 8ded36e9e1
commit f9ae0b5d32
12 changed files with 101 additions and 151 deletions

View File

@@ -2,6 +2,7 @@ from typing import Optional, List, Union
import torch
from torch import nn
from torch.nn.utils import parametrize
from .encoders import TextEncoder, PosteriorEncoder
@@ -118,29 +119,13 @@ class SynthesizerTrnMsNSFsid(nn.Module):
self.enc_q.remove_weight_norm()
def __prepare_scriptable__(self):
for hook in self.dec._forward_pre_hooks.values():
# The hook we want to remove is an instance of WeightNorm class, so
# normally we would do `if isinstance(...)` but this class is not accessible
# because of shadowing, so we check the module name directly.
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.dec)
for hook in self.flow._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.flow)
if parametrize.is_parametrized(self.dec, "weight"):
parametrize.remove_parametrizations(self.dec, "weight")
if parametrize.is_parametrized(self.flow, "weight"):
parametrize.remove_parametrizations(self.flow, "weight")
if hasattr(self, "enc_q"):
for hook in self.enc_q._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.enc_q)
if parametrize.is_parametrized(self.enc_q, "weight"):
parametrize.remove_parametrizations(self.enc_q, "weight")
return self
@torch.jit.ignore()