diff --git a/configs/config.py b/configs/config.py index 14d3827..0e8057c 100644 --- a/configs/config.py +++ b/configs/config.py @@ -6,6 +6,7 @@ import shutil from multiprocessing import cpu_count import torch + # TODO: move device selection into rvc import logging diff --git a/rvc/__init__.py b/rvc/__init__.py index bfe7152..a7a2950 100644 --- a/rvc/__init__.py +++ b/rvc/__init__.py @@ -1,4 +1,4 @@ from . import ipex import sys -del sys.modules["rvc.ipex"] +del sys.modules["rvc.ipex"] diff --git a/rvc/f0/models.py b/rvc/f0/models.py index 28f6751..7c2853e 100644 --- a/rvc/f0/models.py +++ b/rvc/f0/models.py @@ -1,6 +1,9 @@ import torch -def get_rmvpe(model_path="assets/rmvpe/rmvpe.pt", device=torch.device("cpu"), is_half=False): + +def get_rmvpe( + model_path="assets/rmvpe/rmvpe.pt", device=torch.device("cpu"), is_half=False +): from rvc.f0.e2e import E2E model = E2E(4, 1, (2, 2)) diff --git a/rvc/f0/rmvpe.py b/rvc/f0/rmvpe.py index 29d4f09..d878ef4 100644 --- a/rvc/f0/rmvpe.py +++ b/rvc/f0/rmvpe.py @@ -36,6 +36,7 @@ def rmvpe_jit_export( save_pickle(ckpt, save_path) return ckpt + class RMVPE(F0Predictor): def __init__( self, @@ -80,6 +81,7 @@ class RMVPE(F0Predictor): providers=["DmlExecutionProvider"], ) else: + def rmvpe_jit_model(): ckpt = get_jit_model(model_path, is_half, self.device, rmvpe_jit_export) model = torch.jit.load(BytesIO(ckpt["model"]), map_location=self.device) diff --git a/rvc/ipex/__init__.py b/rvc/ipex/__init__.py index 12288d0..93f59a4 100644 --- a/rvc/ipex/__init__.py +++ b/rvc/ipex/__init__.py @@ -1,7 +1,9 @@ try: import torch + if torch.xpu.is_available(): from .init import ipex_init + ipex_init() from .gradscaler import gradscaler_init except Exception: # pylint: disable=broad-exception-caught diff --git a/rvc/jit/__init__.py b/rvc/jit/__init__.py index 56197a6..9016b93 100644 --- a/rvc/jit/__init__.py +++ b/rvc/jit/__init__.py @@ -1 +1 @@ -from .jit import load_inputs, get_jit_model, export_jit_model, save_pickle \ No newline at end of file +from .jit import load_inputs, get_jit_model, export_jit_model, save_pickle diff --git a/rvc/jit/jit.py b/rvc/jit/jit.py index 0da1010..1da0f2c 100644 --- a/rvc/jit/jit.py +++ b/rvc/jit/jit.py @@ -15,6 +15,7 @@ def save_pickle(ckpt: dict, save_path: str): with open(save_path, "wb") as f: pickle.dump(ckpt, f) + def load_inputs(path: torch.serialization.FILE_LIKE, device: str, is_half=False): parm = torch.load(path, map_location=torch.device("cpu")) for key in parm.keys(): @@ -25,6 +26,7 @@ def load_inputs(path: torch.serialization.FILE_LIKE, device: str, is_half=False) parm[key] = parm[key].float() return parm + def export_jit_model( model: torch.nn.Module, mode: str = "trace", diff --git a/rvc/synthesizer.py b/rvc/synthesizer.py index 25bd77e..bfc7c4d 100644 --- a/rvc/synthesizer.py +++ b/rvc/synthesizer.py @@ -35,6 +35,7 @@ def load_synthesizer( device, ) + def synthesizer_jit_export( model_path: str, mode: str = "script",