1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-07 19:40:44 +08:00

fix(train): save small model fail

This commit is contained in:
源文雨
2024-06-04 04:07:19 +09:00
parent 5df99f2f73
commit 481f14dd74
8 changed files with 71 additions and 53 deletions

View File

@@ -14,7 +14,7 @@ MATPLOTLIB_FLAG = False
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logger = logging
"""
def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1):
assert os.path.isfile(checkpoint_path)
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
@@ -64,37 +64,8 @@ def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1):
# traceback.print_exc()
logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration))
return model, optimizer, learning_rate, iteration
"""
# def load_checkpoint(checkpoint_path, model, optimizer=None):
# assert os.path.isfile(checkpoint_path)
# checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
# iteration = checkpoint_dict['iteration']
# learning_rate = checkpoint_dict['learning_rate']
# if optimizer is not None:
# optimizer.load_state_dict(checkpoint_dict['optimizer'])
# # print(1111)
# saved_state_dict = checkpoint_dict['model']
# # print(1111)
#
# if hasattr(model, 'module'):
# state_dict = model.module.state_dict()
# else:
# state_dict = model.state_dict()
# new_state_dict= {}
# for k, v in state_dict.items():
# try:
# new_state_dict[k] = saved_state_dict[k]
# except:
# logger.info("%s is not in the checkpoint" % k)
# new_state_dict[k] = v
# if hasattr(model, 'module'):
# model.module.load_state_dict(new_state_dict)
# else:
# model.load_state_dict(new_state_dict)
# logger.info("Loaded checkpoint '{}' (epoch {})" .format(
# checkpoint_path, iteration))
# return model, optimizer, learning_rate, iteration
def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
assert os.path.isfile(checkpoint_path)
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
@@ -159,7 +130,7 @@ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path)
checkpoint_path,
)
"""
def save_checkpoint_d(combd, sbd, optimizer, learning_rate, iteration, checkpoint_path):
logger.info(
"Saving model and optimizer state at epoch {} to {}".format(
@@ -184,7 +155,7 @@ def save_checkpoint_d(combd, sbd, optimizer, learning_rate, iteration, checkpoin
},
checkpoint_path,
)
"""
def summarize(
writer,