mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-05 09:10:25 +08:00
optimize(jit): move hubert & synthesizer into rvc
This commit is contained in:
@@ -1,2 +1 @@
|
||||
from .utils import load, rmvpe_jit_export, synthesizer_jit_export
|
||||
from .synthesizer import get_synthesizer, get_synthesizer_ckpt
|
||||
|
||||
@@ -44,18 +44,18 @@ def to_jit_model(
|
||||
):
|
||||
model = None
|
||||
if model_type.lower() == "synthesizer":
|
||||
from .synthesizer import get_synthesizer
|
||||
from rvc.synthesizer import load_synthesizer
|
||||
|
||||
model, _ = get_synthesizer(model_path, device)
|
||||
model, _ = load_synthesizer(model_path, device)
|
||||
model.forward = model.infer
|
||||
elif model_type.lower() == "rmvpe":
|
||||
from .rmvpe import get_rmvpe
|
||||
|
||||
model = get_rmvpe(model_path, device)
|
||||
elif model_type.lower() == "hubert":
|
||||
from .hubert import get_hubert_model
|
||||
from rvc.hubert import get_hubert
|
||||
|
||||
model = get_hubert_model(model_path, device)
|
||||
model = get_hubert(model_path, device)
|
||||
model.forward = model.infer
|
||||
else:
|
||||
raise ValueError(f"No model type named {model_type}")
|
||||
@@ -147,9 +147,9 @@ def synthesizer_jit_export(
|
||||
save_path += ".half.jit" if is_half else ".jit"
|
||||
if "cuda" in str(device) and ":" not in str(device):
|
||||
device = torch.device("cuda:0")
|
||||
from .synthesizer import get_synthesizer
|
||||
from rvc.synthesizer import load_synthesizer
|
||||
|
||||
model, cpt = get_synthesizer(model_path, device)
|
||||
model, cpt = load_synthesizer(model_path, device)
|
||||
assert isinstance(cpt, dict)
|
||||
model.forward = model.infer
|
||||
inputs = None
|
||||
|
||||
@@ -518,16 +518,15 @@ class RMVPE:
|
||||
def get_jit_model():
|
||||
jit_model_path = model_path.rstrip(".pth")
|
||||
jit_model_path += ".half.jit" if is_half else ".jit"
|
||||
reload = False
|
||||
ckpt = None
|
||||
if os.path.exists(jit_model_path):
|
||||
ckpt = jit.load(jit_model_path)
|
||||
model_device = ckpt["device"]
|
||||
if model_device != str(self.device):
|
||||
reload = True
|
||||
else:
|
||||
reload = True
|
||||
del ckpt
|
||||
ckpt = None
|
||||
|
||||
if reload:
|
||||
if ckpt is None:
|
||||
ckpt = jit.rmvpe_jit_export(
|
||||
model_path=model_path,
|
||||
mode="script",
|
||||
@@ -536,6 +535,7 @@ class RMVPE:
|
||||
device=device,
|
||||
is_half=is_half,
|
||||
)
|
||||
|
||||
model = torch.jit.load(BytesIO(ckpt["model"]), map_location=device)
|
||||
return model
|
||||
|
||||
|
||||
@@ -16,6 +16,8 @@ import torch.nn.functional as F
|
||||
import torchcrepe
|
||||
from torchaudio.transforms import Resample
|
||||
|
||||
from rvc.synthesizer import load_synthesizer
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
from multiprocessing import Manager as M
|
||||
@@ -113,7 +115,7 @@ class RVC:
|
||||
self.net_g: nn.Module = None
|
||||
|
||||
def set_default_model():
|
||||
self.net_g, cpt = jit.get_synthesizer(self.pth_path, self.device)
|
||||
self.net_g, cpt = load_synthesizer(self.pth_path, self.device)
|
||||
self.tgt_sr = cpt["config"][-1]
|
||||
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
|
||||
self.if_f0 = cpt.get("f0", 1)
|
||||
|
||||
@@ -6,7 +6,7 @@ from scipy.fft import fft
|
||||
from pybase16384 import encode_to_string, decode_from_string
|
||||
|
||||
from configs import CPUConfig, singleton_variable
|
||||
from infer.lib.jit import get_synthesizer_ckpt
|
||||
from rvc.synthesizer import get_synthesizer
|
||||
|
||||
from .pipeline import Pipeline
|
||||
from .utils import load_hubert
|
||||
@@ -132,7 +132,7 @@ def model_hash_ckpt(cpt):
|
||||
config = CPUConfig()
|
||||
|
||||
with TorchSeedContext(114514):
|
||||
net_g, cpt = get_synthesizer_ckpt(cpt, config.device)
|
||||
net_g, cpt = get_synthesizer(cpt, config.device)
|
||||
tgt_sr = cpt["config"][-1]
|
||||
if_f0 = cpt.get("f0", 1)
|
||||
version = cpt.get("version", "v1")
|
||||
|
||||
@@ -10,7 +10,7 @@ import torch
|
||||
from io import BytesIO
|
||||
|
||||
from infer.lib.audio import load_audio, wav2
|
||||
from infer.lib.jit import get_synthesizer_ckpt, get_synthesizer
|
||||
from rvc.synthesizer import get_synthesizer, load_synthesizer
|
||||
from .info import show_model_info
|
||||
from .pipeline import Pipeline
|
||||
from .utils import get_index_path_from_model, load_hubert
|
||||
@@ -62,7 +62,7 @@ class VC:
|
||||
elif torch.backends.mps.is_available():
|
||||
torch.mps.empty_cache()
|
||||
###楼下不这么折腾清理不干净
|
||||
self.net_g, self.cpt = get_synthesizer_ckpt(
|
||||
self.net_g, self.cpt = get_synthesizer(
|
||||
self.cpt, self.config.device
|
||||
)
|
||||
self.if_f0 = self.cpt.get("f0", 1)
|
||||
@@ -88,7 +88,7 @@ class VC:
|
||||
person = f'{os.getenv("weight_root")}/{sid}'
|
||||
logger.info(f"Loading: {person}")
|
||||
|
||||
self.net_g, self.cpt = get_synthesizer(person, self.config.device)
|
||||
self.net_g, self.cpt = load_synthesizer(person, self.config.device)
|
||||
self.tgt_sr = self.cpt["config"][-1]
|
||||
self.cpt["config"][-3] = self.cpt["weight"]["emb_g.weight"].shape[0] # n_spk
|
||||
self.if_f0 = self.cpt.get("f0", 1)
|
||||
|
||||
0
rvc/f0/__init__.py
Normal file
0
rvc/f0/__init__.py
Normal file
@@ -1,14 +1,13 @@
|
||||
import math
|
||||
import random
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from fairseq.checkpoint_utils import load_model_ensemble_and_task
|
||||
from fairseq.utils import index_put
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
# from fairseq.data.data_utils import compute_mask_indices
|
||||
from fairseq.utils import index_put
|
||||
|
||||
|
||||
# @torch.jit.script
|
||||
def pad_to_multiple(x, multiple, dim=-1, value=0):
|
||||
@@ -263,7 +262,7 @@ def apply_mask(self, x, padding_mask, target_list):
|
||||
return x, mask_indices
|
||||
|
||||
|
||||
def get_hubert_model(
|
||||
def get_hubert(
|
||||
model_path="assets/hubert/hubert_base.pt", device=torch.device("cpu")
|
||||
):
|
||||
models, _, _ = load_model_ensemble_and_task(
|
||||
@@ -5,7 +5,7 @@ import librosa
|
||||
import numpy as np
|
||||
import onnxruntime
|
||||
|
||||
from .f0predictors import (
|
||||
from .f0 import (
|
||||
PMF0Predictor,
|
||||
HarvestF0Predictor,
|
||||
DioF0Predictor,
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
|
||||
from rvc.layers.synthesizers import SynthesizerTrnMsNSFsid
|
||||
from .layers.synthesizers import SynthesizerTrnMsNSFsid
|
||||
|
||||
|
||||
def get_synthesizer_ckpt(cpt, device=torch.device("cpu")):
|
||||
def get_synthesizer(cpt: OrderedDict, device=torch.device("cpu")):
|
||||
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
|
||||
if_f0 = cpt.get("f0", 1)
|
||||
version = cpt.get("version", "v1")
|
||||
@@ -24,8 +26,9 @@ def get_synthesizer_ckpt(cpt, device=torch.device("cpu")):
|
||||
return net_g, cpt
|
||||
|
||||
|
||||
def get_synthesizer(pth_path, device=torch.device("cpu")):
|
||||
return get_synthesizer_ckpt(
|
||||
def load_synthesizer(
|
||||
pth_path: torch.serialization.FILE_LIKE, device=torch.device("cpu")):
|
||||
return get_synthesizer(
|
||||
torch.load(pth_path, map_location=torch.device("cpu")),
|
||||
device,
|
||||
)
|
||||
Reference in New Issue
Block a user