1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-05 09:10:25 +08:00

optimize(jit): move hubert & synthesizer into rvc

This commit is contained in:
源文雨
2024-06-12 00:03:26 +09:00
parent 0efe48c49c
commit 54f7ae097d
15 changed files with 30 additions and 27 deletions

View File

@@ -1,2 +1 @@
from .utils import load, rmvpe_jit_export, synthesizer_jit_export
from .synthesizer import get_synthesizer, get_synthesizer_ckpt

View File

@@ -44,18 +44,18 @@ def to_jit_model(
):
model = None
if model_type.lower() == "synthesizer":
from .synthesizer import get_synthesizer
from rvc.synthesizer import load_synthesizer
model, _ = get_synthesizer(model_path, device)
model, _ = load_synthesizer(model_path, device)
model.forward = model.infer
elif model_type.lower() == "rmvpe":
from .rmvpe import get_rmvpe
model = get_rmvpe(model_path, device)
elif model_type.lower() == "hubert":
from .hubert import get_hubert_model
from rvc.hubert import get_hubert
model = get_hubert_model(model_path, device)
model = get_hubert(model_path, device)
model.forward = model.infer
else:
raise ValueError(f"No model type named {model_type}")
@@ -147,9 +147,9 @@ def synthesizer_jit_export(
save_path += ".half.jit" if is_half else ".jit"
if "cuda" in str(device) and ":" not in str(device):
device = torch.device("cuda:0")
from .synthesizer import get_synthesizer
from rvc.synthesizer import load_synthesizer
model, cpt = get_synthesizer(model_path, device)
model, cpt = load_synthesizer(model_path, device)
assert isinstance(cpt, dict)
model.forward = model.infer
inputs = None

View File

@@ -518,16 +518,15 @@ class RMVPE:
def get_jit_model():
jit_model_path = model_path.rstrip(".pth")
jit_model_path += ".half.jit" if is_half else ".jit"
reload = False
ckpt = None
if os.path.exists(jit_model_path):
ckpt = jit.load(jit_model_path)
model_device = ckpt["device"]
if model_device != str(self.device):
reload = True
else:
reload = True
del ckpt
ckpt = None
if reload:
if ckpt is None:
ckpt = jit.rmvpe_jit_export(
model_path=model_path,
mode="script",
@@ -536,6 +535,7 @@ class RMVPE:
device=device,
is_half=is_half,
)
model = torch.jit.load(BytesIO(ckpt["model"]), map_location=device)
return model

View File

@@ -16,6 +16,8 @@ import torch.nn.functional as F
import torchcrepe
from torchaudio.transforms import Resample
from rvc.synthesizer import load_synthesizer
now_dir = os.getcwd()
sys.path.append(now_dir)
from multiprocessing import Manager as M
@@ -113,7 +115,7 @@ class RVC:
self.net_g: nn.Module = None
def set_default_model():
self.net_g, cpt = jit.get_synthesizer(self.pth_path, self.device)
self.net_g, cpt = load_synthesizer(self.pth_path, self.device)
self.tgt_sr = cpt["config"][-1]
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
self.if_f0 = cpt.get("f0", 1)

View File

@@ -6,7 +6,7 @@ from scipy.fft import fft
from pybase16384 import encode_to_string, decode_from_string
from configs import CPUConfig, singleton_variable
from infer.lib.jit import get_synthesizer_ckpt
from rvc.synthesizer import get_synthesizer
from .pipeline import Pipeline
from .utils import load_hubert
@@ -132,7 +132,7 @@ def model_hash_ckpt(cpt):
config = CPUConfig()
with TorchSeedContext(114514):
net_g, cpt = get_synthesizer_ckpt(cpt, config.device)
net_g, cpt = get_synthesizer(cpt, config.device)
tgt_sr = cpt["config"][-1]
if_f0 = cpt.get("f0", 1)
version = cpt.get("version", "v1")

View File

@@ -10,7 +10,7 @@ import torch
from io import BytesIO
from infer.lib.audio import load_audio, wav2
from infer.lib.jit import get_synthesizer_ckpt, get_synthesizer
from rvc.synthesizer import get_synthesizer, load_synthesizer
from .info import show_model_info
from .pipeline import Pipeline
from .utils import get_index_path_from_model, load_hubert
@@ -62,7 +62,7 @@ class VC:
elif torch.backends.mps.is_available():
torch.mps.empty_cache()
###楼下不这么折腾清理不干净
self.net_g, self.cpt = get_synthesizer_ckpt(
self.net_g, self.cpt = get_synthesizer(
self.cpt, self.config.device
)
self.if_f0 = self.cpt.get("f0", 1)
@@ -88,7 +88,7 @@ class VC:
person = f'{os.getenv("weight_root")}/{sid}'
logger.info(f"Loading: {person}")
self.net_g, self.cpt = get_synthesizer(person, self.config.device)
self.net_g, self.cpt = load_synthesizer(person, self.config.device)
self.tgt_sr = self.cpt["config"][-1]
self.cpt["config"][-3] = self.cpt["weight"]["emb_g.weight"].shape[0] # n_spk
self.if_f0 = self.cpt.get("f0", 1)

0
rvc/f0/__init__.py Normal file
View File

View File

@@ -1,14 +1,13 @@
import math
import random
from typing import Optional, Tuple
from fairseq.checkpoint_utils import load_model_ensemble_and_task
from fairseq.utils import index_put
import numpy as np
import torch
import torch.nn.functional as F
# from fairseq.data.data_utils import compute_mask_indices
from fairseq.utils import index_put
# @torch.jit.script
def pad_to_multiple(x, multiple, dim=-1, value=0):
@@ -263,7 +262,7 @@ def apply_mask(self, x, padding_mask, target_list):
return x, mask_indices
def get_hubert_model(
def get_hubert(
model_path="assets/hubert/hubert_base.pt", device=torch.device("cpu")
):
models, _, _ = load_model_ensemble_and_task(

View File

@@ -5,7 +5,7 @@ import librosa
import numpy as np
import onnxruntime
from .f0predictors import (
from .f0 import (
PMF0Predictor,
HarvestF0Predictor,
DioF0Predictor,

View File

@@ -1,9 +1,11 @@
from collections import OrderedDict
import torch
from rvc.layers.synthesizers import SynthesizerTrnMsNSFsid
from .layers.synthesizers import SynthesizerTrnMsNSFsid
def get_synthesizer_ckpt(cpt, device=torch.device("cpu")):
def get_synthesizer(cpt: OrderedDict, device=torch.device("cpu")):
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
if_f0 = cpt.get("f0", 1)
version = cpt.get("version", "v1")
@@ -24,8 +26,9 @@ def get_synthesizer_ckpt(cpt, device=torch.device("cpu")):
return net_g, cpt
def get_synthesizer(pth_path, device=torch.device("cpu")):
return get_synthesizer_ckpt(
def load_synthesizer(
pth_path: torch.serialization.FILE_LIKE, device=torch.device("cpu")):
return get_synthesizer(
torch.load(pth_path, map_location=torch.device("cpu")),
device,
)