mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-08 12:00:49 +08:00
fix(fairseq): hubert load model error
This commit is contained in:
@@ -2,6 +2,7 @@ from typing import Optional, List, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn.utils import parametrize
|
||||
|
||||
|
||||
from .encoders import TextEncoder, PosteriorEncoder
|
||||
@@ -118,29 +119,13 @@ class SynthesizerTrnMsNSFsid(nn.Module):
|
||||
self.enc_q.remove_weight_norm()
|
||||
|
||||
def __prepare_scriptable__(self):
|
||||
for hook in self.dec._forward_pre_hooks.values():
|
||||
# The hook we want to remove is an instance of WeightNorm class, so
|
||||
# normally we would do `if isinstance(...)` but this class is not accessible
|
||||
# because of shadowing, so we check the module name directly.
|
||||
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.dec)
|
||||
for hook in self.flow._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.flow)
|
||||
if parametrize.is_parametrized(self.dec, "weight"):
|
||||
parametrize.remove_parametrizations(self.dec, "weight")
|
||||
if parametrize.is_parametrized(self.flow, "weight"):
|
||||
parametrize.remove_parametrizations(self.flow, "weight")
|
||||
if hasattr(self, "enc_q"):
|
||||
for hook in self.enc_q._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.enc_q)
|
||||
if parametrize.is_parametrized(self.enc_q, "weight"):
|
||||
parametrize.remove_parametrizations(self.enc_q, "weight")
|
||||
return self
|
||||
|
||||
@torch.jit.ignore()
|
||||
|
||||
Reference in New Issue
Block a user