1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-05 01:10:22 +08:00

chore(format): run black on dev (#55)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
github-actions[bot]
2024-06-14 22:45:03 +09:00
committed by GitHub
parent c51a73f521
commit ed8a0c3e34
8 changed files with 14 additions and 3 deletions

View File

@@ -6,6 +6,7 @@ import shutil
from multiprocessing import cpu_count
import torch
# TODO: move device selection into rvc
import logging

View File

@@ -1,4 +1,4 @@
from . import ipex
import sys
del sys.modules["rvc.ipex"]
del sys.modules["rvc.ipex"]

View File

@@ -1,6 +1,9 @@
import torch
def get_rmvpe(model_path="assets/rmvpe/rmvpe.pt", device=torch.device("cpu"), is_half=False):
def get_rmvpe(
model_path="assets/rmvpe/rmvpe.pt", device=torch.device("cpu"), is_half=False
):
from rvc.f0.e2e import E2E
model = E2E(4, 1, (2, 2))

View File

@@ -36,6 +36,7 @@ def rmvpe_jit_export(
save_pickle(ckpt, save_path)
return ckpt
class RMVPE(F0Predictor):
def __init__(
self,
@@ -80,6 +81,7 @@ class RMVPE(F0Predictor):
providers=["DmlExecutionProvider"],
)
else:
def rmvpe_jit_model():
ckpt = get_jit_model(model_path, is_half, self.device, rmvpe_jit_export)
model = torch.jit.load(BytesIO(ckpt["model"]), map_location=self.device)

View File

@@ -1,7 +1,9 @@
try:
import torch
if torch.xpu.is_available():
from .init import ipex_init
ipex_init()
from .gradscaler import gradscaler_init
except Exception: # pylint: disable=broad-exception-caught

View File

@@ -1 +1 @@
from .jit import load_inputs, get_jit_model, export_jit_model, save_pickle
from .jit import load_inputs, get_jit_model, export_jit_model, save_pickle

View File

@@ -15,6 +15,7 @@ def save_pickle(ckpt: dict, save_path: str):
with open(save_path, "wb") as f:
pickle.dump(ckpt, f)
def load_inputs(path: torch.serialization.FILE_LIKE, device: str, is_half=False):
parm = torch.load(path, map_location=torch.device("cpu"))
for key in parm.keys():
@@ -25,6 +26,7 @@ def load_inputs(path: torch.serialization.FILE_LIKE, device: str, is_half=False)
parm[key] = parm[key].float()
return parm
def export_jit_model(
model: torch.nn.Module,
mode: str = "trace",

View File

@@ -35,6 +35,7 @@ def load_synthesizer(
device,
)
def synthesizer_jit_export(
model_path: str,
mode: str = "script",