1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-05 09:10:25 +08:00

feat: update to latest torch & gradio version

This commit is contained in:
源文雨
2024-11-27 22:16:06 +09:00
parent 7c0b1c01f1
commit 4b68fb0e13
8 changed files with 40 additions and 37 deletions

View File

@@ -131,9 +131,14 @@ def run(rank, n_gpus, hps: utils.HParams, logger: logging.Logger):
writer = SummaryWriter(log_dir=hps.model_dir)
writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))
dist.init_process_group(
backend="gloo", init_method="env://", world_size=n_gpus, rank=rank
)
try:
dist.init_process_group(
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", init_method="env://", world_size=n_gpus, rank=rank
)
except:
dist.init_process_group(
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", init_method="env://?use_libuv=False", world_size=n_gpus, rank=rank
)
torch.manual_seed(hps.train.seed)
if torch.cuda.is_available():
torch.cuda.set_device(rank)
@@ -238,13 +243,13 @@ def run(rank, n_gpus, hps: utils.HParams, logger: logging.Logger):
if hasattr(net_g, "module"):
logger.info(
net_g.module.load_state_dict(
torch.load(hps.pretrainG, map_location="cpu")["model"]
torch.load(hps.pretrainG, map_location="cpu", weights_only=True)["model"]
)
) ##测试不加载优化器
else:
logger.info(
net_g.load_state_dict(
torch.load(hps.pretrainG, map_location="cpu")["model"]
torch.load(hps.pretrainG, map_location="cpu", weights_only=True)["model"]
)
) ##测试不加载优化器
if hps.pretrainD != "":
@@ -253,13 +258,13 @@ def run(rank, n_gpus, hps: utils.HParams, logger: logging.Logger):
if hasattr(net_d, "module"):
logger.info(
net_d.module.load_state_dict(
torch.load(hps.pretrainD, map_location="cpu")["model"]
torch.load(hps.pretrainD, map_location="cpu", weights_only=True)["model"]
)
)
else:
logger.info(
net_d.load_state_dict(
torch.load(hps.pretrainD, map_location="cpu")["model"]
torch.load(hps.pretrainD, map_location="cpu", weights_only=True)["model"]
)
)