1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-05 01:10:22 +08:00

fix(fairseq): hubert load model error

This commit is contained in:
源文雨
2026-04-18 19:04:13 +08:00
parent 8ded36e9e1
commit f9ae0b5d32
12 changed files with 101 additions and 151 deletions

View File

@@ -4,7 +4,8 @@ import torch
from torch import nn
from torch.nn import Conv1d, Conv2d
from torch.nn import functional as F
from torch.nn.utils import spectral_norm, weight_norm
from torch.nn.utils import spectral_norm
from torch.nn.utils.parametrizations import weight_norm
from .residuals import LRELU_SLOPE
from .utils import get_padding

View File

@@ -212,10 +212,7 @@ class PosteriorEncoder(nn.Module):
self.enc.remove_weight_norm()
def __prepare_scriptable__(self):
for hook in self.enc._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.enc)
from torch.nn.utils import parametrize
if parametrize.is_parametrized(self.enc, "weight"):
parametrize.remove_parametrizations(self.enc, "weight")
return self

View File

@@ -4,7 +4,8 @@ import torch
from torch import nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import is_parametrized, remove_parametrizations
from .residuals import ResBlock1, ResBlock2, LRELU_SLOPE
from .utils import call_weight_data_normal_if_Conv
@@ -98,29 +99,16 @@ class Generator(torch.nn.Module):
def __prepare_scriptable__(self):
for l in self.ups:
for hook in l._forward_pre_hooks.values():
# The hook we want to remove is an instance of WeightNorm class, so
# normally we would do `if isinstance(...)` but this class is not accessible
# because of shadowing, so we check the module name directly.
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
if is_parametrized(l, "weight"):
remove_parametrizations(l, "weight")
for l in self.resblocks:
for hook in l._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
if is_parametrized(l, "weight"):
remove_parametrizations(l, "weight")
return self
def remove_weight_norm(self):
for l in self.ups:
remove_weight_norm(l)
remove_parametrizations(l, "weight")
for l in self.resblocks:
l.remove_weight_norm()

View File

@@ -6,6 +6,8 @@ from torch.nn import functional as F
from .utils import activate_add_tanh_sigmoid_multiply
from torch.nn.utils.parametrize import is_parametrized, remove_parametrizations
class LayerNorm(nn.Module):
def __init__(self, channels: int, eps: float = 1e-5):
@@ -49,7 +51,7 @@ class WN(torch.nn.Module):
cond_layer = torch.nn.Conv1d(
gin_channels, 2 * hidden_channels * n_layers, 1
)
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
self.cond_layer = torch.nn.utils.parametrizations.weight_norm(cond_layer, name="weight")
for i in range(n_layers):
dilation = dilation_rate**i
@@ -61,7 +63,7 @@ class WN(torch.nn.Module):
dilation=dilation,
padding=padding,
)
in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
in_layer = torch.nn.utils.parametrizations.weight_norm(in_layer, name="weight")
self.in_layers.append(in_layer)
# last one is not necessary
@@ -71,7 +73,7 @@ class WN(torch.nn.Module):
res_skip_channels = hidden_channels
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer, name="weight")
self.res_skip_layers.append(res_skip_layer)
def __call__(
@@ -117,32 +119,20 @@ class WN(torch.nn.Module):
def remove_weight_norm(self):
if self.gin_channels != 0:
torch.nn.utils.remove_weight_norm(self.cond_layer)
remove_parametrizations(self.cond_layer, "weight")
for l in self.in_layers:
torch.nn.utils.remove_weight_norm(l)
remove_parametrizations(l, "weight")
for l in self.res_skip_layers:
torch.nn.utils.remove_weight_norm(l)
remove_parametrizations(l, "weight")
def __prepare_scriptable__(self):
if self.gin_channels != 0:
for hook in self.cond_layer._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.cond_layer)
if is_parametrized(self.cond_layer, "weight"):
remove_parametrizations(self.cond_layer, "weight")
for l in self.in_layers:
for hook in l._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
if is_parametrized(l, "weight"):
remove_parametrizations(l, "weight")
for l in self.res_skip_layers:
for hook in l._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
if is_parametrized(l, "weight"):
remove_parametrizations(l, "weight")
return self

View File

@@ -5,7 +5,8 @@ import torch
from torch import nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import is_parametrized, remove_parametrizations
from .generators import SineGenerator
from .residuals import ResBlock1, ResBlock2, LRELU_SLOPE
@@ -191,27 +192,15 @@ class NSFGenerator(torch.nn.Module):
def remove_weight_norm(self):
for l in self.ups:
remove_weight_norm(l)
remove_parametrizations(l, "weight")
for l in self.resblocks:
l.remove_weight_norm()
def __prepare_scriptable__(self):
for l in self.ups:
for hook in l._forward_pre_hooks.values():
# The hook we want to remove is an instance of WeightNorm class, so
# normally we would do `if isinstance(...)` but this class is not accessible
# because of shadowing, so we check the module name directly.
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
if is_parametrized(l, "weight"):
remove_parametrizations(l, "weight")
for l in self.resblocks:
for hook in self.resblocks._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
if is_parametrized(l, "weight"):
remove_parametrizations(l, "weight")
return self

View File

@@ -4,7 +4,8 @@ import torch
from torch import nn
from torch.nn import Conv1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import is_parametrized, remove_parametrizations
from .norms import WN
from .utils import (
@@ -85,25 +86,17 @@ class ResBlock1(torch.nn.Module):
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
remove_parametrizations(l, "weight")
for l in self.convs2:
remove_weight_norm(l)
remove_parametrizations(l, "weight")
def __prepare_scriptable__(self):
for l in self.convs1:
for hook in l._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
if is_parametrized(l, "weight"):
remove_parametrizations(l, "weight")
for l in self.convs2:
for hook in l._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
if is_parametrized(l, "weight"):
remove_parametrizations(l, "weight")
return self
@@ -161,16 +154,12 @@ class ResBlock2(torch.nn.Module):
def remove_weight_norm(self):
for l in self.convs:
remove_weight_norm(l)
remove_parametrizations(l, "weight")
def __prepare_scriptable__(self):
for l in self.convs:
for hook in l._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
if is_parametrized(l, "weight"):
remove_parametrizations(l, "weight")
return self
@@ -249,12 +238,8 @@ class ResidualCouplingLayer(nn.Module):
self.enc.remove_weight_norm()
def __prepare_scriptable__(self):
for hook in self.enc._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.enc)
if is_parametrized(self.enc, "weight"):
remove_parametrizations(self.enc, "weight")
return self
@@ -344,10 +329,6 @@ class ResidualCouplingBlock(nn.Module):
def __prepare_scriptable__(self):
for i in range(self.n_flows):
for hook in self.flows[i * 2]._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.flows[i * 2])
if is_parametrized(self.flows[i * 2], "weight"):
remove_parametrizations(self.flows[i * 2], "weight")
return self

View File

@@ -2,6 +2,7 @@ from typing import Optional, List, Union
import torch
from torch import nn
from torch.nn.utils import parametrize
from .encoders import TextEncoder, PosteriorEncoder
@@ -118,29 +119,13 @@ class SynthesizerTrnMsNSFsid(nn.Module):
self.enc_q.remove_weight_norm()
def __prepare_scriptable__(self):
for hook in self.dec._forward_pre_hooks.values():
# The hook we want to remove is an instance of WeightNorm class, so
# normally we would do `if isinstance(...)` but this class is not accessible
# because of shadowing, so we check the module name directly.
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.dec)
for hook in self.flow._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.flow)
if parametrize.is_parametrized(self.dec, "weight"):
parametrize.remove_parametrizations(self.dec, "weight")
if parametrize.is_parametrized(self.flow, "weight"):
parametrize.remove_parametrizations(self.flow, "weight")
if hasattr(self, "enc_q"):
for hook in self.enc_q._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.enc_q)
if parametrize.is_parametrized(self.enc_q, "weight"):
parametrize.remove_parametrizations(self.enc_q, "weight")
return self
@torch.jit.ignore()