mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-06 01:30:24 +08:00
feat: update to latest torch & gradio version
This commit is contained in:
@@ -131,9 +131,14 @@ def run(rank, n_gpus, hps: utils.HParams, logger: logging.Logger):
|
||||
writer = SummaryWriter(log_dir=hps.model_dir)
|
||||
writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))
|
||||
|
||||
dist.init_process_group(
|
||||
backend="gloo", init_method="env://", world_size=n_gpus, rank=rank
|
||||
)
|
||||
try:
|
||||
dist.init_process_group(
|
||||
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", init_method="env://", world_size=n_gpus, rank=rank
|
||||
)
|
||||
except:
|
||||
dist.init_process_group(
|
||||
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", init_method="env://?use_libuv=False", world_size=n_gpus, rank=rank
|
||||
)
|
||||
torch.manual_seed(hps.train.seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.set_device(rank)
|
||||
@@ -238,13 +243,13 @@ def run(rank, n_gpus, hps: utils.HParams, logger: logging.Logger):
|
||||
if hasattr(net_g, "module"):
|
||||
logger.info(
|
||||
net_g.module.load_state_dict(
|
||||
torch.load(hps.pretrainG, map_location="cpu")["model"]
|
||||
torch.load(hps.pretrainG, map_location="cpu", weights_only=True)["model"]
|
||||
)
|
||||
) ##测试不加载优化器
|
||||
else:
|
||||
logger.info(
|
||||
net_g.load_state_dict(
|
||||
torch.load(hps.pretrainG, map_location="cpu")["model"]
|
||||
torch.load(hps.pretrainG, map_location="cpu", weights_only=True)["model"]
|
||||
)
|
||||
) ##测试不加载优化器
|
||||
if hps.pretrainD != "":
|
||||
@@ -253,13 +258,13 @@ def run(rank, n_gpus, hps: utils.HParams, logger: logging.Logger):
|
||||
if hasattr(net_d, "module"):
|
||||
logger.info(
|
||||
net_d.module.load_state_dict(
|
||||
torch.load(hps.pretrainD, map_location="cpu")["model"]
|
||||
torch.load(hps.pretrainD, map_location="cpu", weights_only=True)["model"]
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
net_d.load_state_dict(
|
||||
torch.load(hps.pretrainD, map_location="cpu")["model"]
|
||||
torch.load(hps.pretrainD, map_location="cpu", weights_only=True)["model"]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import os
|
||||
import os, pathlib
|
||||
|
||||
from fairseq import checkpoint_utils
|
||||
|
||||
@@ -8,7 +8,7 @@ def get_index_path_from_model(sid):
|
||||
(
|
||||
f
|
||||
for f in [
|
||||
os.path.join(root, name)
|
||||
str(pathlib.Path(root, name))
|
||||
for path in [os.getenv("outside_index_root"), os.getenv("index_root")]
|
||||
for root, _, files in os.walk(path, topdown=False)
|
||||
for name in files
|
||||
|
||||
Reference in New Issue
Block a user