1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-08 03:55:47 +08:00

fix(fairseq): hubert load model error

This commit is contained in:
源文雨
2026-04-18 19:04:13 +08:00
parent 8ded36e9e1
commit f9ae0b5d32
12 changed files with 101 additions and 151 deletions

View File

@@ -4,7 +4,8 @@ import torch
from torch import nn
from torch.nn import Conv1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import is_parametrized, remove_parametrizations
from .norms import WN
from .utils import (
@@ -85,25 +86,17 @@ class ResBlock1(torch.nn.Module):
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
remove_parametrizations(l, "weight")
for l in self.convs2:
remove_weight_norm(l)
remove_parametrizations(l, "weight")
def __prepare_scriptable__(self):
for l in self.convs1:
for hook in l._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
if is_parametrized(l, "weight"):
remove_parametrizations(l, "weight")
for l in self.convs2:
for hook in l._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
if is_parametrized(l, "weight"):
remove_parametrizations(l, "weight")
return self
@@ -161,16 +154,12 @@ class ResBlock2(torch.nn.Module):
def remove_weight_norm(self):
for l in self.convs:
remove_weight_norm(l)
remove_parametrizations(l, "weight")
def __prepare_scriptable__(self):
for l in self.convs:
for hook in l._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
if is_parametrized(l, "weight"):
remove_parametrizations(l, "weight")
return self
@@ -249,12 +238,8 @@ class ResidualCouplingLayer(nn.Module):
self.enc.remove_weight_norm()
def __prepare_scriptable__(self):
for hook in self.enc._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.enc)
if is_parametrized(self.enc, "weight"):
remove_parametrizations(self.enc, "weight")
return self
@@ -344,10 +329,6 @@ class ResidualCouplingBlock(nn.Module):
def __prepare_scriptable__(self):
for i in range(self.n_flows):
for hook in self.flows[i * 2]._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.flows[i * 2])
if is_parametrized(self.flows[i * 2], "weight"):
remove_parametrizations(self.flows[i * 2], "weight")
return self