mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-09 20:40:48 +08:00
optimize(infer): move jit into rvc
This commit is contained in:
@@ -3,6 +3,7 @@ from collections import OrderedDict
|
||||
import torch
|
||||
|
||||
from .layers.synthesizers import SynthesizerTrnMsNSFsid
|
||||
from .jit import load_inputs, export_jit_model, save_pickle
|
||||
|
||||
|
||||
def get_synthesizer(cpt: OrderedDict, device=torch.device("cpu")):
|
||||
@@ -33,3 +34,31 @@ def load_synthesizer(
|
||||
torch.load(pth_path, map_location=torch.device("cpu")),
|
||||
device,
|
||||
)
|
||||
|
||||
def synthesizer_jit_export(
|
||||
model_path: str,
|
||||
mode: str = "script",
|
||||
inputs_path: str = None,
|
||||
save_path: str = None,
|
||||
device=torch.device("cpu"),
|
||||
is_half=False,
|
||||
):
|
||||
if not save_path:
|
||||
save_path = model_path.rstrip(".pth")
|
||||
save_path += ".half.jit" if is_half else ".jit"
|
||||
if "cuda" in str(device) and ":" not in str(device):
|
||||
device = torch.device("cuda:0")
|
||||
from rvc.synthesizer import load_synthesizer
|
||||
|
||||
model, cpt = load_synthesizer(model_path, device)
|
||||
assert isinstance(cpt, dict)
|
||||
model.forward = model.infer
|
||||
inputs = None
|
||||
if mode == "trace":
|
||||
inputs = load_inputs(inputs_path, device, is_half)
|
||||
ckpt = export_jit_model(model, mode, inputs, device, is_half)
|
||||
cpt.pop("weight")
|
||||
cpt["model"] = ckpt["model"]
|
||||
cpt["device"] = device
|
||||
save_pickle(cpt, save_path)
|
||||
return cpt
|
||||
|
||||
Reference in New Issue
Block a user