mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-07 02:00:25 +08:00
chore(format): run black on dev (#94)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
a8783c6639
commit
d3add81469
@@ -133,11 +133,21 @@ def run(rank, n_gpus, hps: utils.HParams, logger: logging.Logger):
|
||||
|
||||
try:
|
||||
dist.init_process_group(
|
||||
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", init_method="env://", world_size=n_gpus, rank=rank
|
||||
backend=(
|
||||
"gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl"
|
||||
),
|
||||
init_method="env://",
|
||||
world_size=n_gpus,
|
||||
rank=rank,
|
||||
)
|
||||
except:
|
||||
dist.init_process_group(
|
||||
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", init_method="env://?use_libuv=False", world_size=n_gpus, rank=rank
|
||||
backend=(
|
||||
"gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl"
|
||||
),
|
||||
init_method="env://?use_libuv=False",
|
||||
world_size=n_gpus,
|
||||
rank=rank,
|
||||
)
|
||||
torch.manual_seed(hps.train.seed)
|
||||
if torch.cuda.is_available():
|
||||
@@ -243,13 +253,17 @@ def run(rank, n_gpus, hps: utils.HParams, logger: logging.Logger):
|
||||
if hasattr(net_g, "module"):
|
||||
logger.info(
|
||||
net_g.module.load_state_dict(
|
||||
torch.load(hps.pretrainG, map_location="cpu", weights_only=True)["model"]
|
||||
torch.load(
|
||||
hps.pretrainG, map_location="cpu", weights_only=True
|
||||
)["model"]
|
||||
)
|
||||
) ##测试不加载优化器
|
||||
else:
|
||||
logger.info(
|
||||
net_g.load_state_dict(
|
||||
torch.load(hps.pretrainG, map_location="cpu", weights_only=True)["model"]
|
||||
torch.load(
|
||||
hps.pretrainG, map_location="cpu", weights_only=True
|
||||
)["model"]
|
||||
)
|
||||
) ##测试不加载优化器
|
||||
if hps.pretrainD != "":
|
||||
@@ -258,13 +272,17 @@ def run(rank, n_gpus, hps: utils.HParams, logger: logging.Logger):
|
||||
if hasattr(net_d, "module"):
|
||||
logger.info(
|
||||
net_d.module.load_state_dict(
|
||||
torch.load(hps.pretrainD, map_location="cpu", weights_only=True)["model"]
|
||||
torch.load(
|
||||
hps.pretrainD, map_location="cpu", weights_only=True
|
||||
)["model"]
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
net_d.load_state_dict(
|
||||
torch.load(hps.pretrainD, map_location="cpu", weights_only=True)["model"]
|
||||
torch.load(
|
||||
hps.pretrainD, map_location="cpu", weights_only=True
|
||||
)["model"]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user