mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-05 09:10:25 +08:00
fix(fairseq): hubert load model error
This commit is contained in:
@@ -29,6 +29,24 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
|
||||
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
|
||||
|
||||
saved_state_dict = checkpoint_dict["model"]
|
||||
# Convert old-style weight_norm keys (weight_g/weight_v) to new
|
||||
# parametrizations format (parametrizations.weight.original0/original1)
|
||||
# so that checkpoints saved with the deprecated API can still be loaded.
|
||||
_converted = {}
|
||||
for k, v in list(saved_state_dict.items()):
|
||||
if k.endswith(".weight_g"):
|
||||
new_key = k[: -len(".weight_g")] + ".parametrizations.weight.original0"
|
||||
_converted[new_key] = v
|
||||
elif k.endswith(".weight_v"):
|
||||
new_key = k[: -len(".weight_v")] + ".parametrizations.weight.original1"
|
||||
_converted[new_key] = v
|
||||
if _converted:
|
||||
logger.info(
|
||||
"Converting %d old-style weight_norm keys from checkpoint to new parametrizations format",
|
||||
len(_converted),
|
||||
)
|
||||
saved_state_dict.update(_converted)
|
||||
|
||||
if hasattr(model, "module"):
|
||||
state_dict = model.module.state_dict()
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user