1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-05 09:10:25 +08:00

fix(fairseq): hubert load model error

This commit is contained in:
源文雨
2026-04-18 19:04:13 +08:00
parent 8ded36e9e1
commit f9ae0b5d32
12 changed files with 101 additions and 151 deletions

View File

@@ -29,6 +29,24 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
saved_state_dict = checkpoint_dict["model"]
# Convert old-style weight_norm keys (weight_g/weight_v) to new
# parametrizations format (parametrizations.weight.original0/original1)
# so that checkpoints saved with the deprecated API can still be loaded.
_converted = {}
for k, v in list(saved_state_dict.items()):
if k.endswith(".weight_g"):
new_key = k[: -len(".weight_g")] + ".parametrizations.weight.original0"
_converted[new_key] = v
elif k.endswith(".weight_v"):
new_key = k[: -len(".weight_v")] + ".parametrizations.weight.original1"
_converted[new_key] = v
if _converted:
logger.info(
"Converting %d old-style weight_norm keys from checkpoint to new parametrizations format",
len(_converted),
)
saved_state_dict.update(_converted)
if hasattr(model, "module"):
state_dict = model.module.state_dict()
else: