mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-05 09:10:25 +08:00
optimize(infer): move jit into rvc
This commit is contained in:
@@ -1 +0,0 @@
|
||||
from .utils import load, rmvpe_jit_export, synthesizer_jit_export
|
||||
@@ -1,163 +0,0 @@
|
||||
from io import BytesIO
|
||||
import pickle
|
||||
import time
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
def load_inputs(path, device, is_half=False):
|
||||
parm = torch.load(path, map_location=torch.device("cpu"))
|
||||
for key in parm.keys():
|
||||
parm[key] = parm[key].to(device)
|
||||
if is_half and parm[key].dtype == torch.float32:
|
||||
parm[key] = parm[key].half()
|
||||
elif not is_half and parm[key].dtype == torch.float16:
|
||||
parm[key] = parm[key].float()
|
||||
return parm
|
||||
|
||||
|
||||
def benchmark(
|
||||
model, inputs_path, device=torch.device("cpu"), epoch=1000, is_half=False
|
||||
):
|
||||
parm = load_inputs(inputs_path, device, is_half)
|
||||
total_ts = 0.0
|
||||
bar = tqdm(range(epoch))
|
||||
for i in bar:
|
||||
start_time = time.perf_counter()
|
||||
o = model(**parm)
|
||||
total_ts += time.perf_counter() - start_time
|
||||
print(f"num_epoch: {epoch} | avg time(ms): {(total_ts*1000)/epoch}")
|
||||
|
||||
|
||||
def jit_warm_up(model, inputs_path, device=torch.device("cpu"), epoch=5, is_half=False):
|
||||
benchmark(model, inputs_path, device, epoch=epoch, is_half=is_half)
|
||||
|
||||
|
||||
def to_jit_model(
|
||||
model_path,
|
||||
model_type: str,
|
||||
mode: str = "trace",
|
||||
inputs_path: str = None,
|
||||
device=torch.device("cpu"),
|
||||
is_half=False,
|
||||
):
|
||||
model = None
|
||||
if model_type.lower() == "synthesizer":
|
||||
from rvc.synthesizer import load_synthesizer
|
||||
|
||||
model, _ = load_synthesizer(model_path, device)
|
||||
model.forward = model.infer
|
||||
elif model_type.lower() == "rmvpe":
|
||||
from .rmvpe import get_rmvpe
|
||||
|
||||
model = get_rmvpe(model_path, device)
|
||||
elif model_type.lower() == "hubert":
|
||||
from rvc.hubert import get_hubert
|
||||
|
||||
model = get_hubert(model_path, device)
|
||||
model.forward = model.infer
|
||||
else:
|
||||
raise ValueError(f"No model type named {model_type}")
|
||||
model = model.eval()
|
||||
model = model.half() if is_half else model.float()
|
||||
if mode == "trace":
|
||||
assert not inputs_path
|
||||
inputs = load_inputs(inputs_path, device, is_half)
|
||||
model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs)
|
||||
elif mode == "script":
|
||||
model_jit = torch.jit.script(model)
|
||||
model_jit.to(device)
|
||||
model_jit = model_jit.half() if is_half else model_jit.float()
|
||||
# model = model.half() if is_half else model.float()
|
||||
return (model, model_jit)
|
||||
|
||||
|
||||
def export(
|
||||
model: torch.nn.Module,
|
||||
mode: str = "trace",
|
||||
inputs: dict = None,
|
||||
device=torch.device("cpu"),
|
||||
is_half: bool = False,
|
||||
) -> dict:
|
||||
model = model.half() if is_half else model.float()
|
||||
model.eval()
|
||||
if mode == "trace":
|
||||
assert inputs is not None
|
||||
model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs)
|
||||
elif mode == "script":
|
||||
model_jit = torch.jit.script(model)
|
||||
model_jit.to(device)
|
||||
model_jit = model_jit.half() if is_half else model_jit.float()
|
||||
buffer = BytesIO()
|
||||
# model_jit=model_jit.cpu()
|
||||
torch.jit.save(model_jit, buffer)
|
||||
del model_jit
|
||||
cpt = OrderedDict()
|
||||
cpt["model"] = buffer.getvalue()
|
||||
cpt["is_half"] = is_half
|
||||
return cpt
|
||||
|
||||
|
||||
def load(path: str):
|
||||
with open(path, "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
|
||||
def save(ckpt: dict, save_path: str):
|
||||
with open(save_path, "wb") as f:
|
||||
pickle.dump(ckpt, f)
|
||||
|
||||
|
||||
def rmvpe_jit_export(
|
||||
model_path: str,
|
||||
mode: str = "script",
|
||||
inputs_path: str = None,
|
||||
save_path: str = None,
|
||||
device=torch.device("cpu"),
|
||||
is_half=False,
|
||||
):
|
||||
if not save_path:
|
||||
save_path = model_path.rstrip(".pth")
|
||||
save_path += ".half.jit" if is_half else ".jit"
|
||||
if "cuda" in str(device) and ":" not in str(device):
|
||||
device = torch.device("cuda:0")
|
||||
from .rmvpe import get_rmvpe
|
||||
|
||||
model = get_rmvpe(model_path, device)
|
||||
inputs = None
|
||||
if mode == "trace":
|
||||
inputs = load_inputs(inputs_path, device, is_half)
|
||||
ckpt = export(model, mode, inputs, device, is_half)
|
||||
ckpt["device"] = str(device)
|
||||
save(ckpt, save_path)
|
||||
return ckpt
|
||||
|
||||
|
||||
def synthesizer_jit_export(
|
||||
model_path: str,
|
||||
mode: str = "script",
|
||||
inputs_path: str = None,
|
||||
save_path: str = None,
|
||||
device=torch.device("cpu"),
|
||||
is_half=False,
|
||||
):
|
||||
if not save_path:
|
||||
save_path = model_path.rstrip(".pth")
|
||||
save_path += ".half.jit" if is_half else ".jit"
|
||||
if "cuda" in str(device) and ":" not in str(device):
|
||||
device = torch.device("cuda:0")
|
||||
from rvc.synthesizer import load_synthesizer
|
||||
|
||||
model, cpt = load_synthesizer(model_path, device)
|
||||
assert isinstance(cpt, dict)
|
||||
model.forward = model.infer
|
||||
inputs = None
|
||||
if mode == "trace":
|
||||
inputs = load_inputs(inputs_path, device, is_half)
|
||||
ckpt = export(model, mode, inputs, device, is_half)
|
||||
cpt.pop("weight")
|
||||
cpt["model"] = ckpt["model"]
|
||||
cpt["device"] = device
|
||||
save(cpt, save_path)
|
||||
return cpt
|
||||
@@ -125,27 +125,10 @@ class RVC:
|
||||
self.net_g = self.net_g.float()
|
||||
|
||||
def set_jit_model():
|
||||
jit_pth_path = self.pth_path.rstrip(".pth")
|
||||
jit_pth_path += ".half.jit" if self.is_half else ".jit"
|
||||
reload = False
|
||||
if str(self.device) == "cuda":
|
||||
self.device = torch.device("cuda:0")
|
||||
if os.path.exists(jit_pth_path):
|
||||
cpt = jit.load(jit_pth_path)
|
||||
model_device = cpt["device"]
|
||||
if model_device != str(self.device):
|
||||
reload = True
|
||||
else:
|
||||
reload = True
|
||||
from rvc.jit import get_jit_model
|
||||
from rvc.synthesizer import synthesizer_jit_export
|
||||
|
||||
if reload:
|
||||
cpt = jit.synthesizer_jit_export(
|
||||
self.pth_path,
|
||||
"script",
|
||||
None,
|
||||
device=self.device,
|
||||
is_half=self.is_half,
|
||||
)
|
||||
cpt = get_jit_model(self.pth_path, self.is_half, synthesizer_jit_export)
|
||||
|
||||
self.tgt_sr = cpt["config"][-1]
|
||||
self.if_f0 = cpt.get("f0", 1)
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import torch
|
||||
|
||||
|
||||
def get_rmvpe(model_path="assets/rmvpe/rmvpe.pt", device=torch.device("cpu")):
|
||||
def get_rmvpe(model_path="assets/rmvpe/rmvpe.pt", device=torch.device("cpu"), is_half=False):
|
||||
from rvc.f0.e2e import E2E
|
||||
|
||||
model = E2E(4, 1, (2, 2))
|
||||
ckpt = torch.load(model_path, map_location=device)
|
||||
model.load_state_dict(ckpt)
|
||||
model.eval()
|
||||
if is_half:
|
||||
model = model.half()
|
||||
model = model.to(device)
|
||||
return model
|
||||
@@ -6,13 +6,36 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from infer.lib import jit
|
||||
from rvc.jit import load_inputs, get_jit_model, export_jit_model, save_pickle
|
||||
|
||||
from .mel import MelSpectrogram
|
||||
from .e2e import E2E
|
||||
from .f0 import F0Predictor
|
||||
from .models import get_rmvpe
|
||||
|
||||
|
||||
def rmvpe_jit_export(
|
||||
model_path: str,
|
||||
mode: str = "script",
|
||||
inputs_path: str = None,
|
||||
save_path: str = None,
|
||||
device=torch.device("cpu"),
|
||||
is_half=False,
|
||||
):
|
||||
if not save_path:
|
||||
save_path = model_path.rstrip(".pth")
|
||||
save_path += ".half.jit" if is_half else ".jit"
|
||||
if "cuda" in str(device) and ":" not in str(device):
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
model = get_rmvpe(model_path, device, is_half)
|
||||
inputs = None
|
||||
if mode == "trace":
|
||||
inputs = load_inputs(inputs_path, device, is_half)
|
||||
ckpt = export_jit_model(model, mode, inputs, device, is_half)
|
||||
ckpt["device"] = str(device)
|
||||
save_pickle(ckpt, save_path)
|
||||
return ckpt
|
||||
|
||||
class RMVPE(F0Predictor):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -57,51 +80,16 @@ class RMVPE(F0Predictor):
|
||||
providers=["DmlExecutionProvider"],
|
||||
)
|
||||
else:
|
||||
|
||||
def get_jit_model():
|
||||
jit_model_path = model_path.rstrip(".pth")
|
||||
jit_model_path += ".half.jit" if is_half else ".jit"
|
||||
ckpt = None
|
||||
if os.path.exists(jit_model_path):
|
||||
ckpt = jit.load(jit_model_path)
|
||||
model_device = ckpt["device"]
|
||||
if model_device != str(self.device):
|
||||
del ckpt
|
||||
ckpt = None
|
||||
|
||||
if ckpt is None:
|
||||
ckpt = jit.rmvpe_jit_export(
|
||||
model_path=model_path,
|
||||
mode="script",
|
||||
inputs_path=None,
|
||||
save_path=jit_model_path,
|
||||
device=self.device,
|
||||
is_half=is_half,
|
||||
)
|
||||
|
||||
def rmvpe_jit_model():
|
||||
ckpt = get_jit_model(model_path, is_half, self.device, rmvpe_jit_export)
|
||||
model = torch.jit.load(BytesIO(ckpt["model"]), map_location=self.device)
|
||||
model = model.to(self.device)
|
||||
return model
|
||||
|
||||
def get_default_model():
|
||||
model = E2E(4, 1, (2, 2))
|
||||
ckpt = torch.load(model_path, map_location="cpu")
|
||||
model.load_state_dict(ckpt)
|
||||
model.eval()
|
||||
if is_half:
|
||||
model = model.half()
|
||||
else:
|
||||
model = model.float()
|
||||
return model
|
||||
|
||||
if use_jit:
|
||||
if is_half and "cpu" in str(self.device):
|
||||
self.model = get_default_model()
|
||||
else:
|
||||
self.model = get_jit_model()
|
||||
if use_jit and not (is_half and "cpu" in str(self.device)):
|
||||
self.model = rmvpe_jit_model()
|
||||
else:
|
||||
self.model = get_default_model()
|
||||
|
||||
self.model = self.model.to(self.device)
|
||||
self.model = get_rmvpe(model_path, self.device, is_half)
|
||||
|
||||
def compute_f0(
|
||||
self,
|
||||
|
||||
1
rvc/jit/__init__.py
Normal file
1
rvc/jit/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .jit import load_inputs, get_jit_model, export_jit_model, save_pickle
|
||||
76
rvc/jit/jit.py
Normal file
76
rvc/jit/jit.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import pickle
|
||||
from io import BytesIO
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def load_pickle(path: str):
|
||||
with open(path, "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
|
||||
def save_pickle(ckpt: dict, save_path: str):
|
||||
with open(save_path, "wb") as f:
|
||||
pickle.dump(ckpt, f)
|
||||
|
||||
def load_inputs(path: torch.serialization.FILE_LIKE, device: str, is_half=False):
|
||||
parm = torch.load(path, map_location=torch.device("cpu"))
|
||||
for key in parm.keys():
|
||||
parm[key] = parm[key].to(device)
|
||||
if is_half and parm[key].dtype == torch.float32:
|
||||
parm[key] = parm[key].half()
|
||||
elif not is_half and parm[key].dtype == torch.float16:
|
||||
parm[key] = parm[key].float()
|
||||
return parm
|
||||
|
||||
def export_jit_model(
|
||||
model: torch.nn.Module,
|
||||
mode: str = "trace",
|
||||
inputs: dict = None,
|
||||
device=torch.device("cpu"),
|
||||
is_half: bool = False,
|
||||
) -> dict:
|
||||
model = model.half() if is_half else model.float()
|
||||
model.eval()
|
||||
if mode == "trace":
|
||||
assert inputs is not None
|
||||
model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs)
|
||||
elif mode == "script":
|
||||
model_jit = torch.jit.script(model)
|
||||
model_jit.to(device)
|
||||
model_jit = model_jit.half() if is_half else model_jit.float()
|
||||
buffer = BytesIO()
|
||||
# model_jit=model_jit.cpu()
|
||||
torch.jit.save(model_jit, buffer)
|
||||
del model_jit
|
||||
cpt = OrderedDict()
|
||||
cpt["model"] = buffer.getvalue()
|
||||
cpt["is_half"] = is_half
|
||||
return cpt
|
||||
|
||||
|
||||
def get_jit_model(model_path: str, is_half: bool, device: str, exporter):
|
||||
jit_model_path = model_path.rstrip(".pth")
|
||||
jit_model_path += ".half.jit" if is_half else ".jit"
|
||||
ckpt = None
|
||||
|
||||
if os.path.exists(jit_model_path):
|
||||
ckpt = load_pickle(jit_model_path)
|
||||
model_device = ckpt["device"]
|
||||
if model_device != str(device):
|
||||
del ckpt
|
||||
ckpt = None
|
||||
|
||||
if ckpt is None:
|
||||
ckpt = exporter(
|
||||
model_path=model_path,
|
||||
mode="script",
|
||||
inputs_path=None,
|
||||
save_path=jit_model_path,
|
||||
device=device,
|
||||
is_half=is_half,
|
||||
)
|
||||
|
||||
return ckpt
|
||||
@@ -3,6 +3,7 @@ from collections import OrderedDict
|
||||
import torch
|
||||
|
||||
from .layers.synthesizers import SynthesizerTrnMsNSFsid
|
||||
from .jit import load_inputs, export_jit_model, save_pickle
|
||||
|
||||
|
||||
def get_synthesizer(cpt: OrderedDict, device=torch.device("cpu")):
|
||||
@@ -33,3 +34,31 @@ def load_synthesizer(
|
||||
torch.load(pth_path, map_location=torch.device("cpu")),
|
||||
device,
|
||||
)
|
||||
|
||||
def synthesizer_jit_export(
|
||||
model_path: str,
|
||||
mode: str = "script",
|
||||
inputs_path: str = None,
|
||||
save_path: str = None,
|
||||
device=torch.device("cpu"),
|
||||
is_half=False,
|
||||
):
|
||||
if not save_path:
|
||||
save_path = model_path.rstrip(".pth")
|
||||
save_path += ".half.jit" if is_half else ".jit"
|
||||
if "cuda" in str(device) and ":" not in str(device):
|
||||
device = torch.device("cuda:0")
|
||||
from rvc.synthesizer import load_synthesizer
|
||||
|
||||
model, cpt = load_synthesizer(model_path, device)
|
||||
assert isinstance(cpt, dict)
|
||||
model.forward = model.infer
|
||||
inputs = None
|
||||
if mode == "trace":
|
||||
inputs = load_inputs(inputs_path, device, is_half)
|
||||
ckpt = export_jit_model(model, mode, inputs, device, is_half)
|
||||
cpt.pop("weight")
|
||||
cpt["model"] = ckpt["model"]
|
||||
cpt["device"] = device
|
||||
save_pickle(cpt, save_path)
|
||||
return cpt
|
||||
|
||||
Reference in New Issue
Block a user