1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-08 03:55:47 +08:00

optimize(infer): move jit into rvc

This commit is contained in:
源文雨
2024-06-14 22:44:07 +09:00
parent e936e24a91
commit c51a73f521
8 changed files with 143 additions and 229 deletions

13
rvc/f0/models.py Normal file
View File

@@ -0,0 +1,13 @@
import torch
def get_rmvpe(model_path="assets/rmvpe/rmvpe.pt", device=torch.device("cpu"), is_half=False):
from rvc.f0.e2e import E2E
model = E2E(4, 1, (2, 2))
ckpt = torch.load(model_path, map_location=device)
model.load_state_dict(ckpt)
model.eval()
if is_half:
model = model.half()
model = model.to(device)
return model

View File

@@ -6,13 +6,36 @@ import numpy as np
import torch
import torch.nn.functional as F
from infer.lib import jit
from rvc.jit import load_inputs, get_jit_model, export_jit_model, save_pickle
from .mel import MelSpectrogram
from .e2e import E2E
from .f0 import F0Predictor
from .models import get_rmvpe
def rmvpe_jit_export(
model_path: str,
mode: str = "script",
inputs_path: str = None,
save_path: str = None,
device=torch.device("cpu"),
is_half=False,
):
if not save_path:
save_path = model_path.rstrip(".pth")
save_path += ".half.jit" if is_half else ".jit"
if "cuda" in str(device) and ":" not in str(device):
device = torch.device("cuda:0")
model = get_rmvpe(model_path, device, is_half)
inputs = None
if mode == "trace":
inputs = load_inputs(inputs_path, device, is_half)
ckpt = export_jit_model(model, mode, inputs, device, is_half)
ckpt["device"] = str(device)
save_pickle(ckpt, save_path)
return ckpt
class RMVPE(F0Predictor):
def __init__(
self,
@@ -57,51 +80,16 @@ class RMVPE(F0Predictor):
providers=["DmlExecutionProvider"],
)
else:
def get_jit_model():
jit_model_path = model_path.rstrip(".pth")
jit_model_path += ".half.jit" if is_half else ".jit"
ckpt = None
if os.path.exists(jit_model_path):
ckpt = jit.load(jit_model_path)
model_device = ckpt["device"]
if model_device != str(self.device):
del ckpt
ckpt = None
if ckpt is None:
ckpt = jit.rmvpe_jit_export(
model_path=model_path,
mode="script",
inputs_path=None,
save_path=jit_model_path,
device=self.device,
is_half=is_half,
)
def rmvpe_jit_model():
ckpt = get_jit_model(model_path, is_half, self.device, rmvpe_jit_export)
model = torch.jit.load(BytesIO(ckpt["model"]), map_location=self.device)
model = model.to(self.device)
return model
def get_default_model():
model = E2E(4, 1, (2, 2))
ckpt = torch.load(model_path, map_location="cpu")
model.load_state_dict(ckpt)
model.eval()
if is_half:
model = model.half()
else:
model = model.float()
return model
if use_jit:
if is_half and "cpu" in str(self.device):
self.model = get_default_model()
else:
self.model = get_jit_model()
if use_jit and not (is_half and "cpu" in str(self.device)):
self.model = rmvpe_jit_model()
else:
self.model = get_default_model()
self.model = self.model.to(self.device)
self.model = get_rmvpe(model_path, self.device, is_half)
def compute_f0(
self,