mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-08 03:55:47 +08:00
fix(fairseq): hubert load model error
This commit is contained in:
@@ -4,7 +4,8 @@ import torch
|
||||
from torch import nn
|
||||
from torch.nn import Conv1d
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.utils import remove_weight_norm, weight_norm
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
from torch.nn.utils.parametrize import is_parametrized, remove_parametrizations
|
||||
|
||||
from .norms import WN
|
||||
from .utils import (
|
||||
@@ -85,25 +86,17 @@ class ResBlock1(torch.nn.Module):
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs1:
|
||||
remove_weight_norm(l)
|
||||
remove_parametrizations(l, "weight")
|
||||
for l in self.convs2:
|
||||
remove_weight_norm(l)
|
||||
remove_parametrizations(l, "weight")
|
||||
|
||||
def __prepare_scriptable__(self):
|
||||
for l in self.convs1:
|
||||
for hook in l._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
if is_parametrized(l, "weight"):
|
||||
remove_parametrizations(l, "weight")
|
||||
for l in self.convs2:
|
||||
for hook in l._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
if is_parametrized(l, "weight"):
|
||||
remove_parametrizations(l, "weight")
|
||||
return self
|
||||
|
||||
|
||||
@@ -161,16 +154,12 @@ class ResBlock2(torch.nn.Module):
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs:
|
||||
remove_weight_norm(l)
|
||||
remove_parametrizations(l, "weight")
|
||||
|
||||
def __prepare_scriptable__(self):
|
||||
for l in self.convs:
|
||||
for hook in l._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
if is_parametrized(l, "weight"):
|
||||
remove_parametrizations(l, "weight")
|
||||
return self
|
||||
|
||||
|
||||
@@ -249,12 +238,8 @@ class ResidualCouplingLayer(nn.Module):
|
||||
self.enc.remove_weight_norm()
|
||||
|
||||
def __prepare_scriptable__(self):
|
||||
for hook in self.enc._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.enc)
|
||||
if is_parametrized(self.enc, "weight"):
|
||||
remove_parametrizations(self.enc, "weight")
|
||||
return self
|
||||
|
||||
|
||||
@@ -344,10 +329,6 @@ class ResidualCouplingBlock(nn.Module):
|
||||
|
||||
def __prepare_scriptable__(self):
|
||||
for i in range(self.n_flows):
|
||||
for hook in self.flows[i * 2]._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.flows[i * 2])
|
||||
if is_parametrized(self.flows[i * 2], "weight"):
|
||||
remove_parametrizations(self.flows[i * 2], "weight")
|
||||
return self
|
||||
|
||||
Reference in New Issue
Block a user