mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-05 01:10:22 +08:00
fix(fairseq): hubert load model error
This commit is contained in:
@@ -4,7 +4,8 @@ import torch
|
||||
from torch import nn
|
||||
from torch.nn import Conv1d, Conv2d
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.utils import spectral_norm, weight_norm
|
||||
from torch.nn.utils import spectral_norm
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
|
||||
from .residuals import LRELU_SLOPE
|
||||
from .utils import get_padding
|
||||
|
||||
@@ -212,10 +212,7 @@ class PosteriorEncoder(nn.Module):
|
||||
self.enc.remove_weight_norm()
|
||||
|
||||
def __prepare_scriptable__(self):
|
||||
for hook in self.enc._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.enc)
|
||||
from torch.nn.utils import parametrize
|
||||
if parametrize.is_parametrized(self.enc, "weight"):
|
||||
parametrize.remove_parametrizations(self.enc, "weight")
|
||||
return self
|
||||
|
||||
@@ -4,7 +4,8 @@ import torch
|
||||
from torch import nn
|
||||
from torch.nn import Conv1d, ConvTranspose1d
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.utils import remove_weight_norm, weight_norm
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
from torch.nn.utils.parametrize import is_parametrized, remove_parametrizations
|
||||
|
||||
from .residuals import ResBlock1, ResBlock2, LRELU_SLOPE
|
||||
from .utils import call_weight_data_normal_if_Conv
|
||||
@@ -98,29 +99,16 @@ class Generator(torch.nn.Module):
|
||||
|
||||
def __prepare_scriptable__(self):
|
||||
for l in self.ups:
|
||||
for hook in l._forward_pre_hooks.values():
|
||||
# The hook we want to remove is an instance of WeightNorm class, so
|
||||
# normally we would do `if isinstance(...)` but this class is not accessible
|
||||
# because of shadowing, so we check the module name directly.
|
||||
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
|
||||
if is_parametrized(l, "weight"):
|
||||
remove_parametrizations(l, "weight")
|
||||
for l in self.resblocks:
|
||||
for hook in l._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
if is_parametrized(l, "weight"):
|
||||
remove_parametrizations(l, "weight")
|
||||
return self
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.ups:
|
||||
remove_weight_norm(l)
|
||||
remove_parametrizations(l, "weight")
|
||||
for l in self.resblocks:
|
||||
l.remove_weight_norm()
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@ from torch.nn import functional as F
|
||||
|
||||
from .utils import activate_add_tanh_sigmoid_multiply
|
||||
|
||||
from torch.nn.utils.parametrize import is_parametrized, remove_parametrizations
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, channels: int, eps: float = 1e-5):
|
||||
@@ -49,7 +51,7 @@ class WN(torch.nn.Module):
|
||||
cond_layer = torch.nn.Conv1d(
|
||||
gin_channels, 2 * hidden_channels * n_layers, 1
|
||||
)
|
||||
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
|
||||
self.cond_layer = torch.nn.utils.parametrizations.weight_norm(cond_layer, name="weight")
|
||||
|
||||
for i in range(n_layers):
|
||||
dilation = dilation_rate**i
|
||||
@@ -61,7 +63,7 @@ class WN(torch.nn.Module):
|
||||
dilation=dilation,
|
||||
padding=padding,
|
||||
)
|
||||
in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
|
||||
in_layer = torch.nn.utils.parametrizations.weight_norm(in_layer, name="weight")
|
||||
self.in_layers.append(in_layer)
|
||||
|
||||
# last one is not necessary
|
||||
@@ -71,7 +73,7 @@ class WN(torch.nn.Module):
|
||||
res_skip_channels = hidden_channels
|
||||
|
||||
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
|
||||
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
|
||||
res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer, name="weight")
|
||||
self.res_skip_layers.append(res_skip_layer)
|
||||
|
||||
def __call__(
|
||||
@@ -117,32 +119,20 @@ class WN(torch.nn.Module):
|
||||
|
||||
def remove_weight_norm(self):
|
||||
if self.gin_channels != 0:
|
||||
torch.nn.utils.remove_weight_norm(self.cond_layer)
|
||||
remove_parametrizations(self.cond_layer, "weight")
|
||||
for l in self.in_layers:
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
remove_parametrizations(l, "weight")
|
||||
for l in self.res_skip_layers:
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
remove_parametrizations(l, "weight")
|
||||
|
||||
def __prepare_scriptable__(self):
|
||||
if self.gin_channels != 0:
|
||||
for hook in self.cond_layer._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.cond_layer)
|
||||
if is_parametrized(self.cond_layer, "weight"):
|
||||
remove_parametrizations(self.cond_layer, "weight")
|
||||
for l in self.in_layers:
|
||||
for hook in l._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
if is_parametrized(l, "weight"):
|
||||
remove_parametrizations(l, "weight")
|
||||
for l in self.res_skip_layers:
|
||||
for hook in l._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
if is_parametrized(l, "weight"):
|
||||
remove_parametrizations(l, "weight")
|
||||
return self
|
||||
|
||||
@@ -5,7 +5,8 @@ import torch
|
||||
from torch import nn
|
||||
from torch.nn import Conv1d, ConvTranspose1d
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.utils import remove_weight_norm, weight_norm
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
from torch.nn.utils.parametrize import is_parametrized, remove_parametrizations
|
||||
|
||||
from .generators import SineGenerator
|
||||
from .residuals import ResBlock1, ResBlock2, LRELU_SLOPE
|
||||
@@ -191,27 +192,15 @@ class NSFGenerator(torch.nn.Module):
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.ups:
|
||||
remove_weight_norm(l)
|
||||
remove_parametrizations(l, "weight")
|
||||
for l in self.resblocks:
|
||||
l.remove_weight_norm()
|
||||
|
||||
def __prepare_scriptable__(self):
|
||||
for l in self.ups:
|
||||
for hook in l._forward_pre_hooks.values():
|
||||
# The hook we want to remove is an instance of WeightNorm class, so
|
||||
# normally we would do `if isinstance(...)` but this class is not accessible
|
||||
# because of shadowing, so we check the module name directly.
|
||||
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
if is_parametrized(l, "weight"):
|
||||
remove_parametrizations(l, "weight")
|
||||
for l in self.resblocks:
|
||||
for hook in self.resblocks._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
if is_parametrized(l, "weight"):
|
||||
remove_parametrizations(l, "weight")
|
||||
return self
|
||||
|
||||
@@ -4,7 +4,8 @@ import torch
|
||||
from torch import nn
|
||||
from torch.nn import Conv1d
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.utils import remove_weight_norm, weight_norm
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
from torch.nn.utils.parametrize import is_parametrized, remove_parametrizations
|
||||
|
||||
from .norms import WN
|
||||
from .utils import (
|
||||
@@ -85,25 +86,17 @@ class ResBlock1(torch.nn.Module):
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs1:
|
||||
remove_weight_norm(l)
|
||||
remove_parametrizations(l, "weight")
|
||||
for l in self.convs2:
|
||||
remove_weight_norm(l)
|
||||
remove_parametrizations(l, "weight")
|
||||
|
||||
def __prepare_scriptable__(self):
|
||||
for l in self.convs1:
|
||||
for hook in l._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
if is_parametrized(l, "weight"):
|
||||
remove_parametrizations(l, "weight")
|
||||
for l in self.convs2:
|
||||
for hook in l._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
if is_parametrized(l, "weight"):
|
||||
remove_parametrizations(l, "weight")
|
||||
return self
|
||||
|
||||
|
||||
@@ -161,16 +154,12 @@ class ResBlock2(torch.nn.Module):
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs:
|
||||
remove_weight_norm(l)
|
||||
remove_parametrizations(l, "weight")
|
||||
|
||||
def __prepare_scriptable__(self):
|
||||
for l in self.convs:
|
||||
for hook in l._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
if is_parametrized(l, "weight"):
|
||||
remove_parametrizations(l, "weight")
|
||||
return self
|
||||
|
||||
|
||||
@@ -249,12 +238,8 @@ class ResidualCouplingLayer(nn.Module):
|
||||
self.enc.remove_weight_norm()
|
||||
|
||||
def __prepare_scriptable__(self):
|
||||
for hook in self.enc._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.enc)
|
||||
if is_parametrized(self.enc, "weight"):
|
||||
remove_parametrizations(self.enc, "weight")
|
||||
return self
|
||||
|
||||
|
||||
@@ -344,10 +329,6 @@ class ResidualCouplingBlock(nn.Module):
|
||||
|
||||
def __prepare_scriptable__(self):
|
||||
for i in range(self.n_flows):
|
||||
for hook in self.flows[i * 2]._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.flows[i * 2])
|
||||
if is_parametrized(self.flows[i * 2], "weight"):
|
||||
remove_parametrizations(self.flows[i * 2], "weight")
|
||||
return self
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import Optional, List, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn.utils import parametrize
|
||||
|
||||
|
||||
from .encoders import TextEncoder, PosteriorEncoder
|
||||
@@ -118,29 +119,13 @@ class SynthesizerTrnMsNSFsid(nn.Module):
|
||||
self.enc_q.remove_weight_norm()
|
||||
|
||||
def __prepare_scriptable__(self):
|
||||
for hook in self.dec._forward_pre_hooks.values():
|
||||
# The hook we want to remove is an instance of WeightNorm class, so
|
||||
# normally we would do `if isinstance(...)` but this class is not accessible
|
||||
# because of shadowing, so we check the module name directly.
|
||||
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.dec)
|
||||
for hook in self.flow._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.flow)
|
||||
if parametrize.is_parametrized(self.dec, "weight"):
|
||||
parametrize.remove_parametrizations(self.dec, "weight")
|
||||
if parametrize.is_parametrized(self.flow, "weight"):
|
||||
parametrize.remove_parametrizations(self.flow, "weight")
|
||||
if hasattr(self, "enc_q"):
|
||||
for hook in self.enc_q._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.enc_q)
|
||||
if parametrize.is_parametrized(self.enc_q, "weight"):
|
||||
parametrize.remove_parametrizations(self.enc_q, "weight")
|
||||
return self
|
||||
|
||||
@torch.jit.ignore()
|
||||
|
||||
Reference in New Issue
Block a user