1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-06 17:50:25 +08:00

chore(format): run black on dev (#94)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
github-actions[bot]
2024-11-28 03:21:10 +09:00
committed by GitHub
parent a8783c6639
commit d3add81469
10 changed files with 126 additions and 47 deletions

View File

@@ -62,15 +62,24 @@ class PreProcess:
tmp_audio = (tmp_audio / tmp_max * (self.max * self.alpha)) + (
1 - self.alpha
) * tmp_audio
save_audio("%s/%s_%s.wav" % (self.gt_wavs_dir, idx0, idx1), tmp_audio, self.sr, f32=True)
save_audio(
"%s/%s_%s.wav" % (self.gt_wavs_dir, idx0, idx1),
tmp_audio,
self.sr,
f32=True,
)
with open("%s/%s_%s.wav" % (self.wavs16k_dir, idx0, idx1), "wb") as f:
f.write(float_np_array_to_wav_buf(
load_audio(
float_np_array_to_wav_buf(tmp_audio, self.sr, f32=True),
sr=16000,
format="wav",
)
, 16000, True).getbuffer())
f.write(
float_np_array_to_wav_buf(
load_audio(
float_np_array_to_wav_buf(tmp_audio, self.sr, f32=True),
sr=16000,
format="wav",
),
16000,
True,
).getbuffer()
)
def pipeline(self, path, idx0):
try:

View File

@@ -133,11 +133,21 @@ def run(rank, n_gpus, hps: utils.HParams, logger: logging.Logger):
try:
dist.init_process_group(
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", init_method="env://", world_size=n_gpus, rank=rank
backend=(
"gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl"
),
init_method="env://",
world_size=n_gpus,
rank=rank,
)
except:
dist.init_process_group(
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", init_method="env://?use_libuv=False", world_size=n_gpus, rank=rank
backend=(
"gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl"
),
init_method="env://?use_libuv=False",
world_size=n_gpus,
rank=rank,
)
torch.manual_seed(hps.train.seed)
if torch.cuda.is_available():
@@ -243,13 +253,17 @@ def run(rank, n_gpus, hps: utils.HParams, logger: logging.Logger):
if hasattr(net_g, "module"):
logger.info(
net_g.module.load_state_dict(
torch.load(hps.pretrainG, map_location="cpu", weights_only=True)["model"]
torch.load(
hps.pretrainG, map_location="cpu", weights_only=True
)["model"]
)
) ##测试不加载优化器
else:
logger.info(
net_g.load_state_dict(
torch.load(hps.pretrainG, map_location="cpu", weights_only=True)["model"]
torch.load(
hps.pretrainG, map_location="cpu", weights_only=True
)["model"]
)
) ##测试不加载优化器
if hps.pretrainD != "":
@@ -258,13 +272,17 @@ def run(rank, n_gpus, hps: utils.HParams, logger: logging.Logger):
if hasattr(net_d, "module"):
logger.info(
net_d.module.load_state_dict(
torch.load(hps.pretrainD, map_location="cpu", weights_only=True)["model"]
torch.load(
hps.pretrainD, map_location="cpu", weights_only=True
)["model"]
)
)
else:
logger.info(
net_d.load_state_dict(
torch.load(hps.pretrainD, map_location="cpu", weights_only=True)["model"]
torch.load(
hps.pretrainD, map_location="cpu", weights_only=True
)["model"]
)
)