1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-07 19:40:44 +08:00

fix(fairseq): hubert load model error

This commit is contained in:
源文雨
2026-04-18 19:04:13 +08:00
parent 8ded36e9e1
commit f9ae0b5d32
12 changed files with 101 additions and 151 deletions

View File

@@ -6,6 +6,8 @@ from torch.nn import functional as F
from .utils import activate_add_tanh_sigmoid_multiply
from torch.nn.utils.parametrize import is_parametrized, remove_parametrizations
class LayerNorm(nn.Module):
def __init__(self, channels: int, eps: float = 1e-5):
@@ -49,7 +51,7 @@ class WN(torch.nn.Module):
cond_layer = torch.nn.Conv1d(
gin_channels, 2 * hidden_channels * n_layers, 1
)
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
self.cond_layer = torch.nn.utils.parametrizations.weight_norm(cond_layer, name="weight")
for i in range(n_layers):
dilation = dilation_rate**i
@@ -61,7 +63,7 @@ class WN(torch.nn.Module):
dilation=dilation,
padding=padding,
)
in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
in_layer = torch.nn.utils.parametrizations.weight_norm(in_layer, name="weight")
self.in_layers.append(in_layer)
# last one is not necessary
@@ -71,7 +73,7 @@ class WN(torch.nn.Module):
res_skip_channels = hidden_channels
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer, name="weight")
self.res_skip_layers.append(res_skip_layer)
def __call__(
@@ -117,32 +119,20 @@ class WN(torch.nn.Module):
def remove_weight_norm(self):
if self.gin_channels != 0:
torch.nn.utils.remove_weight_norm(self.cond_layer)
remove_parametrizations(self.cond_layer, "weight")
for l in self.in_layers:
torch.nn.utils.remove_weight_norm(l)
remove_parametrizations(l, "weight")
for l in self.res_skip_layers:
torch.nn.utils.remove_weight_norm(l)
remove_parametrizations(l, "weight")
def __prepare_scriptable__(self):
if self.gin_channels != 0:
for hook in self.cond_layer._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.cond_layer)
if is_parametrized(self.cond_layer, "weight"):
remove_parametrizations(self.cond_layer, "weight")
for l in self.in_layers:
for hook in l._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
if is_parametrized(l, "weight"):
remove_parametrizations(l, "weight")
for l in self.res_skip_layers:
for hook in l._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
if is_parametrized(l, "weight"):
remove_parametrizations(l, "weight")
return self