mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-08 03:55:47 +08:00
optimize(infer): move jit into rvc
This commit is contained in:
76
rvc/jit/jit.py
Normal file
76
rvc/jit/jit.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import pickle
|
||||
from io import BytesIO
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def load_pickle(path: str):
|
||||
with open(path, "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
|
||||
def save_pickle(ckpt: dict, save_path: str):
|
||||
with open(save_path, "wb") as f:
|
||||
pickle.dump(ckpt, f)
|
||||
|
||||
def load_inputs(path: torch.serialization.FILE_LIKE, device: str, is_half=False):
|
||||
parm = torch.load(path, map_location=torch.device("cpu"))
|
||||
for key in parm.keys():
|
||||
parm[key] = parm[key].to(device)
|
||||
if is_half and parm[key].dtype == torch.float32:
|
||||
parm[key] = parm[key].half()
|
||||
elif not is_half and parm[key].dtype == torch.float16:
|
||||
parm[key] = parm[key].float()
|
||||
return parm
|
||||
|
||||
def export_jit_model(
|
||||
model: torch.nn.Module,
|
||||
mode: str = "trace",
|
||||
inputs: dict = None,
|
||||
device=torch.device("cpu"),
|
||||
is_half: bool = False,
|
||||
) -> dict:
|
||||
model = model.half() if is_half else model.float()
|
||||
model.eval()
|
||||
if mode == "trace":
|
||||
assert inputs is not None
|
||||
model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs)
|
||||
elif mode == "script":
|
||||
model_jit = torch.jit.script(model)
|
||||
model_jit.to(device)
|
||||
model_jit = model_jit.half() if is_half else model_jit.float()
|
||||
buffer = BytesIO()
|
||||
# model_jit=model_jit.cpu()
|
||||
torch.jit.save(model_jit, buffer)
|
||||
del model_jit
|
||||
cpt = OrderedDict()
|
||||
cpt["model"] = buffer.getvalue()
|
||||
cpt["is_half"] = is_half
|
||||
return cpt
|
||||
|
||||
|
||||
def get_jit_model(model_path: str, is_half: bool, device: str, exporter):
|
||||
jit_model_path = model_path.rstrip(".pth")
|
||||
jit_model_path += ".half.jit" if is_half else ".jit"
|
||||
ckpt = None
|
||||
|
||||
if os.path.exists(jit_model_path):
|
||||
ckpt = load_pickle(jit_model_path)
|
||||
model_device = ckpt["device"]
|
||||
if model_device != str(device):
|
||||
del ckpt
|
||||
ckpt = None
|
||||
|
||||
if ckpt is None:
|
||||
ckpt = exporter(
|
||||
model_path=model_path,
|
||||
mode="script",
|
||||
inputs_path=None,
|
||||
save_path=jit_model_path,
|
||||
device=device,
|
||||
is_half=is_half,
|
||||
)
|
||||
|
||||
return ckpt
|
||||
Reference in New Issue
Block a user