mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-05 01:10:22 +08:00
81 lines
2.1 KiB
Python
81 lines
2.1 KiB
Python
import pickle
|
|
from io import BytesIO
|
|
from collections import OrderedDict
|
|
import os
|
|
|
|
import torch
|
|
|
|
from rvc.utils import FileLike
|
|
|
|
|
|
def load_pickle(path: str):
|
|
with open(path, "rb") as f:
|
|
return pickle.load(f)
|
|
|
|
|
|
def save_pickle(ckpt: dict, save_path: str):
|
|
with open(save_path, "wb") as f:
|
|
pickle.dump(ckpt, f)
|
|
|
|
|
|
def load_inputs(path: FileLike, device: str, is_half=False): # type: ignore
|
|
parm = torch.load(path, map_location=torch.device("cpu"))
|
|
for key in parm.keys():
|
|
parm[key] = parm[key].to(device)
|
|
if is_half and parm[key].dtype == torch.float32:
|
|
parm[key] = parm[key].half()
|
|
elif not is_half and parm[key].dtype == torch.float16:
|
|
parm[key] = parm[key].float()
|
|
return parm
|
|
|
|
|
|
def export_jit_model(
|
|
model: torch.nn.Module,
|
|
mode: str = "trace",
|
|
inputs: dict = None,
|
|
device=torch.device("cpu"),
|
|
is_half: bool = False,
|
|
) -> dict:
|
|
model = model.half() if is_half else model.float()
|
|
model.eval()
|
|
if mode == "trace":
|
|
assert inputs is not None
|
|
model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs)
|
|
elif mode == "script":
|
|
model_jit = torch.jit.script(model)
|
|
model_jit.to(device)
|
|
model_jit = model_jit.half() if is_half else model_jit.float()
|
|
buffer = BytesIO()
|
|
# model_jit=model_jit.cpu()
|
|
torch.jit.save(model_jit, buffer)
|
|
del model_jit
|
|
cpt = OrderedDict()
|
|
cpt["model"] = buffer.getvalue()
|
|
cpt["is_half"] = is_half
|
|
return cpt
|
|
|
|
|
|
def get_jit_model(model_path: str, is_half: bool, device: str, exporter):
|
|
jit_model_path = model_path.rstrip(".pth")
|
|
jit_model_path += ".half.jit" if is_half else ".jit"
|
|
ckpt = None
|
|
|
|
if os.path.exists(jit_model_path):
|
|
ckpt = load_pickle(jit_model_path)
|
|
model_device = ckpt["device"]
|
|
if model_device != str(device):
|
|
del ckpt
|
|
ckpt = None
|
|
|
|
if ckpt is None:
|
|
ckpt = exporter(
|
|
model_path=model_path,
|
|
mode="script",
|
|
inputs_path=None,
|
|
save_path=jit_model_path,
|
|
device=device,
|
|
is_half=is_half,
|
|
)
|
|
|
|
return ckpt
|