From c51a73f521ea8d34eb4909774bdc043ba7d53c84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Fri, 14 Jun 2024 22:44:07 +0900 Subject: [PATCH] optimize(infer): move jit into rvc --- infer/lib/jit/__init__.py | 1 - infer/lib/jit/utils.py | 163 --------------------- infer/lib/rtrvc.py | 23 +-- infer/lib/jit/rmvpe.py => rvc/f0/models.py | 5 +- rvc/f0/rmvpe.py | 74 ++++------ rvc/jit/__init__.py | 1 + rvc/jit/jit.py | 76 ++++++++++ rvc/synthesizer.py | 29 ++++ 8 files changed, 143 insertions(+), 229 deletions(-) delete mode 100644 infer/lib/jit/__init__.py delete mode 100644 infer/lib/jit/utils.py rename infer/lib/jit/rmvpe.py => rvc/f0/models.py (79%) create mode 100644 rvc/jit/__init__.py create mode 100644 rvc/jit/jit.py diff --git a/infer/lib/jit/__init__.py b/infer/lib/jit/__init__.py deleted file mode 100644 index 1625df5..0000000 --- a/infer/lib/jit/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .utils import load, rmvpe_jit_export, synthesizer_jit_export diff --git a/infer/lib/jit/utils.py b/infer/lib/jit/utils.py deleted file mode 100644 index d9896fa..0000000 --- a/infer/lib/jit/utils.py +++ /dev/null @@ -1,163 +0,0 @@ -from io import BytesIO -import pickle -import time -import torch -from tqdm import tqdm -from collections import OrderedDict - - -def load_inputs(path, device, is_half=False): - parm = torch.load(path, map_location=torch.device("cpu")) - for key in parm.keys(): - parm[key] = parm[key].to(device) - if is_half and parm[key].dtype == torch.float32: - parm[key] = parm[key].half() - elif not is_half and parm[key].dtype == torch.float16: - parm[key] = parm[key].float() - return parm - - -def benchmark( - model, inputs_path, device=torch.device("cpu"), epoch=1000, is_half=False -): - parm = load_inputs(inputs_path, device, is_half) - total_ts = 0.0 - bar = tqdm(range(epoch)) - for i in bar: - start_time = time.perf_counter() - o = model(**parm) - total_ts += time.perf_counter() - start_time - print(f"num_epoch: {epoch} | avg time(ms): {(total_ts*1000)/epoch}") - - -def jit_warm_up(model, inputs_path, device=torch.device("cpu"), epoch=5, is_half=False): - benchmark(model, inputs_path, device, epoch=epoch, is_half=is_half) - - -def to_jit_model( - model_path, - model_type: str, - mode: str = "trace", - inputs_path: str = None, - device=torch.device("cpu"), - is_half=False, -): - model = None - if model_type.lower() == "synthesizer": - from rvc.synthesizer import load_synthesizer - - model, _ = load_synthesizer(model_path, device) - model.forward = model.infer - elif model_type.lower() == "rmvpe": - from .rmvpe import get_rmvpe - - model = get_rmvpe(model_path, device) - elif model_type.lower() == "hubert": - from rvc.hubert import get_hubert - - model = get_hubert(model_path, device) - model.forward = model.infer - else: - raise ValueError(f"No model type named {model_type}") - model = model.eval() - model = model.half() if is_half else model.float() - if mode == "trace": - assert not inputs_path - inputs = load_inputs(inputs_path, device, is_half) - model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs) - elif mode == "script": - model_jit = torch.jit.script(model) - model_jit.to(device) - model_jit = model_jit.half() if is_half else model_jit.float() - # model = model.half() if is_half else model.float() - return (model, model_jit) - - -def export( - model: torch.nn.Module, - mode: str = "trace", - inputs: dict = None, - device=torch.device("cpu"), - is_half: bool = False, -) -> dict: - model = model.half() if is_half else model.float() - model.eval() - if mode == "trace": - assert inputs is not None - model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs) - elif mode == "script": - model_jit = torch.jit.script(model) - model_jit.to(device) - model_jit = model_jit.half() if is_half else model_jit.float() - buffer = BytesIO() - # model_jit=model_jit.cpu() - torch.jit.save(model_jit, buffer) - del model_jit - cpt = OrderedDict() - cpt["model"] = buffer.getvalue() - cpt["is_half"] = is_half - return cpt - - -def load(path: str): - with open(path, "rb") as f: - return pickle.load(f) - - -def save(ckpt: dict, save_path: str): - with open(save_path, "wb") as f: - pickle.dump(ckpt, f) - - -def rmvpe_jit_export( - model_path: str, - mode: str = "script", - inputs_path: str = None, - save_path: str = None, - device=torch.device("cpu"), - is_half=False, -): - if not save_path: - save_path = model_path.rstrip(".pth") - save_path += ".half.jit" if is_half else ".jit" - if "cuda" in str(device) and ":" not in str(device): - device = torch.device("cuda:0") - from .rmvpe import get_rmvpe - - model = get_rmvpe(model_path, device) - inputs = None - if mode == "trace": - inputs = load_inputs(inputs_path, device, is_half) - ckpt = export(model, mode, inputs, device, is_half) - ckpt["device"] = str(device) - save(ckpt, save_path) - return ckpt - - -def synthesizer_jit_export( - model_path: str, - mode: str = "script", - inputs_path: str = None, - save_path: str = None, - device=torch.device("cpu"), - is_half=False, -): - if not save_path: - save_path = model_path.rstrip(".pth") - save_path += ".half.jit" if is_half else ".jit" - if "cuda" in str(device) and ":" not in str(device): - device = torch.device("cuda:0") - from rvc.synthesizer import load_synthesizer - - model, cpt = load_synthesizer(model_path, device) - assert isinstance(cpt, dict) - model.forward = model.infer - inputs = None - if mode == "trace": - inputs = load_inputs(inputs_path, device, is_half) - ckpt = export(model, mode, inputs, device, is_half) - cpt.pop("weight") - cpt["model"] = ckpt["model"] - cpt["device"] = device - save(cpt, save_path) - return cpt diff --git a/infer/lib/rtrvc.py b/infer/lib/rtrvc.py index 77f8db0..acd2773 100644 --- a/infer/lib/rtrvc.py +++ b/infer/lib/rtrvc.py @@ -125,27 +125,10 @@ class RVC: self.net_g = self.net_g.float() def set_jit_model(): - jit_pth_path = self.pth_path.rstrip(".pth") - jit_pth_path += ".half.jit" if self.is_half else ".jit" - reload = False - if str(self.device) == "cuda": - self.device = torch.device("cuda:0") - if os.path.exists(jit_pth_path): - cpt = jit.load(jit_pth_path) - model_device = cpt["device"] - if model_device != str(self.device): - reload = True - else: - reload = True + from rvc.jit import get_jit_model + from rvc.synthesizer import synthesizer_jit_export - if reload: - cpt = jit.synthesizer_jit_export( - self.pth_path, - "script", - None, - device=self.device, - is_half=self.is_half, - ) + cpt = get_jit_model(self.pth_path, self.is_half, synthesizer_jit_export) self.tgt_sr = cpt["config"][-1] self.if_f0 = cpt.get("f0", 1) diff --git a/infer/lib/jit/rmvpe.py b/rvc/f0/models.py similarity index 79% rename from infer/lib/jit/rmvpe.py rename to rvc/f0/models.py index 6240802..28f6751 100644 --- a/infer/lib/jit/rmvpe.py +++ b/rvc/f0/models.py @@ -1,12 +1,13 @@ import torch - -def get_rmvpe(model_path="assets/rmvpe/rmvpe.pt", device=torch.device("cpu")): +def get_rmvpe(model_path="assets/rmvpe/rmvpe.pt", device=torch.device("cpu"), is_half=False): from rvc.f0.e2e import E2E model = E2E(4, 1, (2, 2)) ckpt = torch.load(model_path, map_location=device) model.load_state_dict(ckpt) model.eval() + if is_half: + model = model.half() model = model.to(device) return model diff --git a/rvc/f0/rmvpe.py b/rvc/f0/rmvpe.py index c94b16b..29d4f09 100644 --- a/rvc/f0/rmvpe.py +++ b/rvc/f0/rmvpe.py @@ -6,13 +6,36 @@ import numpy as np import torch import torch.nn.functional as F -from infer.lib import jit +from rvc.jit import load_inputs, get_jit_model, export_jit_model, save_pickle from .mel import MelSpectrogram -from .e2e import E2E from .f0 import F0Predictor +from .models import get_rmvpe +def rmvpe_jit_export( + model_path: str, + mode: str = "script", + inputs_path: str = None, + save_path: str = None, + device=torch.device("cpu"), + is_half=False, +): + if not save_path: + save_path = model_path.rstrip(".pth") + save_path += ".half.jit" if is_half else ".jit" + if "cuda" in str(device) and ":" not in str(device): + device = torch.device("cuda:0") + + model = get_rmvpe(model_path, device, is_half) + inputs = None + if mode == "trace": + inputs = load_inputs(inputs_path, device, is_half) + ckpt = export_jit_model(model, mode, inputs, device, is_half) + ckpt["device"] = str(device) + save_pickle(ckpt, save_path) + return ckpt + class RMVPE(F0Predictor): def __init__( self, @@ -57,51 +80,16 @@ class RMVPE(F0Predictor): providers=["DmlExecutionProvider"], ) else: - - def get_jit_model(): - jit_model_path = model_path.rstrip(".pth") - jit_model_path += ".half.jit" if is_half else ".jit" - ckpt = None - if os.path.exists(jit_model_path): - ckpt = jit.load(jit_model_path) - model_device = ckpt["device"] - if model_device != str(self.device): - del ckpt - ckpt = None - - if ckpt is None: - ckpt = jit.rmvpe_jit_export( - model_path=model_path, - mode="script", - inputs_path=None, - save_path=jit_model_path, - device=self.device, - is_half=is_half, - ) - + def rmvpe_jit_model(): + ckpt = get_jit_model(model_path, is_half, self.device, rmvpe_jit_export) model = torch.jit.load(BytesIO(ckpt["model"]), map_location=self.device) + model = model.to(self.device) return model - def get_default_model(): - model = E2E(4, 1, (2, 2)) - ckpt = torch.load(model_path, map_location="cpu") - model.load_state_dict(ckpt) - model.eval() - if is_half: - model = model.half() - else: - model = model.float() - return model - - if use_jit: - if is_half and "cpu" in str(self.device): - self.model = get_default_model() - else: - self.model = get_jit_model() + if use_jit and not (is_half and "cpu" in str(self.device)): + self.model = rmvpe_jit_model() else: - self.model = get_default_model() - - self.model = self.model.to(self.device) + self.model = get_rmvpe(model_path, self.device, is_half) def compute_f0( self, diff --git a/rvc/jit/__init__.py b/rvc/jit/__init__.py new file mode 100644 index 0000000..56197a6 --- /dev/null +++ b/rvc/jit/__init__.py @@ -0,0 +1 @@ +from .jit import load_inputs, get_jit_model, export_jit_model, save_pickle \ No newline at end of file diff --git a/rvc/jit/jit.py b/rvc/jit/jit.py new file mode 100644 index 0000000..0da1010 --- /dev/null +++ b/rvc/jit/jit.py @@ -0,0 +1,76 @@ +import pickle +from io import BytesIO +from collections import OrderedDict +import os + +import torch + + +def load_pickle(path: str): + with open(path, "rb") as f: + return pickle.load(f) + + +def save_pickle(ckpt: dict, save_path: str): + with open(save_path, "wb") as f: + pickle.dump(ckpt, f) + +def load_inputs(path: torch.serialization.FILE_LIKE, device: str, is_half=False): + parm = torch.load(path, map_location=torch.device("cpu")) + for key in parm.keys(): + parm[key] = parm[key].to(device) + if is_half and parm[key].dtype == torch.float32: + parm[key] = parm[key].half() + elif not is_half and parm[key].dtype == torch.float16: + parm[key] = parm[key].float() + return parm + +def export_jit_model( + model: torch.nn.Module, + mode: str = "trace", + inputs: dict = None, + device=torch.device("cpu"), + is_half: bool = False, +) -> dict: + model = model.half() if is_half else model.float() + model.eval() + if mode == "trace": + assert inputs is not None + model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs) + elif mode == "script": + model_jit = torch.jit.script(model) + model_jit.to(device) + model_jit = model_jit.half() if is_half else model_jit.float() + buffer = BytesIO() + # model_jit=model_jit.cpu() + torch.jit.save(model_jit, buffer) + del model_jit + cpt = OrderedDict() + cpt["model"] = buffer.getvalue() + cpt["is_half"] = is_half + return cpt + + +def get_jit_model(model_path: str, is_half: bool, device: str, exporter): + jit_model_path = model_path.rstrip(".pth") + jit_model_path += ".half.jit" if is_half else ".jit" + ckpt = None + + if os.path.exists(jit_model_path): + ckpt = load_pickle(jit_model_path) + model_device = ckpt["device"] + if model_device != str(device): + del ckpt + ckpt = None + + if ckpt is None: + ckpt = exporter( + model_path=model_path, + mode="script", + inputs_path=None, + save_path=jit_model_path, + device=device, + is_half=is_half, + ) + + return ckpt diff --git a/rvc/synthesizer.py b/rvc/synthesizer.py index e972d7f..25bd77e 100644 --- a/rvc/synthesizer.py +++ b/rvc/synthesizer.py @@ -3,6 +3,7 @@ from collections import OrderedDict import torch from .layers.synthesizers import SynthesizerTrnMsNSFsid +from .jit import load_inputs, export_jit_model, save_pickle def get_synthesizer(cpt: OrderedDict, device=torch.device("cpu")): @@ -33,3 +34,31 @@ def load_synthesizer( torch.load(pth_path, map_location=torch.device("cpu")), device, ) + +def synthesizer_jit_export( + model_path: str, + mode: str = "script", + inputs_path: str = None, + save_path: str = None, + device=torch.device("cpu"), + is_half=False, +): + if not save_path: + save_path = model_path.rstrip(".pth") + save_path += ".half.jit" if is_half else ".jit" + if "cuda" in str(device) and ":" not in str(device): + device = torch.device("cuda:0") + from rvc.synthesizer import load_synthesizer + + model, cpt = load_synthesizer(model_path, device) + assert isinstance(cpt, dict) + model.forward = model.infer + inputs = None + if mode == "trace": + inputs = load_inputs(inputs_path, device, is_half) + ckpt = export_jit_model(model, mode, inputs, device, is_half) + cpt.pop("weight") + cpt["model"] = ckpt["model"] + cpt["device"] = device + save_pickle(cpt, save_path) + return cpt