mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-08 03:55:47 +08:00
optimize(infer): move jit into rvc
This commit is contained in:
13
rvc/f0/models.py
Normal file
13
rvc/f0/models.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import torch
|
||||
|
||||
def get_rmvpe(model_path="assets/rmvpe/rmvpe.pt", device=torch.device("cpu"), is_half=False):
|
||||
from rvc.f0.e2e import E2E
|
||||
|
||||
model = E2E(4, 1, (2, 2))
|
||||
ckpt = torch.load(model_path, map_location=device)
|
||||
model.load_state_dict(ckpt)
|
||||
model.eval()
|
||||
if is_half:
|
||||
model = model.half()
|
||||
model = model.to(device)
|
||||
return model
|
||||
@@ -6,13 +6,36 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from infer.lib import jit
|
||||
from rvc.jit import load_inputs, get_jit_model, export_jit_model, save_pickle
|
||||
|
||||
from .mel import MelSpectrogram
|
||||
from .e2e import E2E
|
||||
from .f0 import F0Predictor
|
||||
from .models import get_rmvpe
|
||||
|
||||
|
||||
def rmvpe_jit_export(
|
||||
model_path: str,
|
||||
mode: str = "script",
|
||||
inputs_path: str = None,
|
||||
save_path: str = None,
|
||||
device=torch.device("cpu"),
|
||||
is_half=False,
|
||||
):
|
||||
if not save_path:
|
||||
save_path = model_path.rstrip(".pth")
|
||||
save_path += ".half.jit" if is_half else ".jit"
|
||||
if "cuda" in str(device) and ":" not in str(device):
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
model = get_rmvpe(model_path, device, is_half)
|
||||
inputs = None
|
||||
if mode == "trace":
|
||||
inputs = load_inputs(inputs_path, device, is_half)
|
||||
ckpt = export_jit_model(model, mode, inputs, device, is_half)
|
||||
ckpt["device"] = str(device)
|
||||
save_pickle(ckpt, save_path)
|
||||
return ckpt
|
||||
|
||||
class RMVPE(F0Predictor):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -57,51 +80,16 @@ class RMVPE(F0Predictor):
|
||||
providers=["DmlExecutionProvider"],
|
||||
)
|
||||
else:
|
||||
|
||||
def get_jit_model():
|
||||
jit_model_path = model_path.rstrip(".pth")
|
||||
jit_model_path += ".half.jit" if is_half else ".jit"
|
||||
ckpt = None
|
||||
if os.path.exists(jit_model_path):
|
||||
ckpt = jit.load(jit_model_path)
|
||||
model_device = ckpt["device"]
|
||||
if model_device != str(self.device):
|
||||
del ckpt
|
||||
ckpt = None
|
||||
|
||||
if ckpt is None:
|
||||
ckpt = jit.rmvpe_jit_export(
|
||||
model_path=model_path,
|
||||
mode="script",
|
||||
inputs_path=None,
|
||||
save_path=jit_model_path,
|
||||
device=self.device,
|
||||
is_half=is_half,
|
||||
)
|
||||
|
||||
def rmvpe_jit_model():
|
||||
ckpt = get_jit_model(model_path, is_half, self.device, rmvpe_jit_export)
|
||||
model = torch.jit.load(BytesIO(ckpt["model"]), map_location=self.device)
|
||||
model = model.to(self.device)
|
||||
return model
|
||||
|
||||
def get_default_model():
|
||||
model = E2E(4, 1, (2, 2))
|
||||
ckpt = torch.load(model_path, map_location="cpu")
|
||||
model.load_state_dict(ckpt)
|
||||
model.eval()
|
||||
if is_half:
|
||||
model = model.half()
|
||||
else:
|
||||
model = model.float()
|
||||
return model
|
||||
|
||||
if use_jit:
|
||||
if is_half and "cpu" in str(self.device):
|
||||
self.model = get_default_model()
|
||||
else:
|
||||
self.model = get_jit_model()
|
||||
if use_jit and not (is_half and "cpu" in str(self.device)):
|
||||
self.model = rmvpe_jit_model()
|
||||
else:
|
||||
self.model = get_default_model()
|
||||
|
||||
self.model = self.model.to(self.device)
|
||||
self.model = get_rmvpe(model_path, self.device, is_half)
|
||||
|
||||
def compute_f0(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user