From 54f7ae097d06ccfda088684430e407d7f437e240 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Wed, 12 Jun 2024 00:03:26 +0900 Subject: [PATCH] optimize(jit): move hubert & synthesizer into rvc --- infer/lib/jit/__init__.py | 1 - infer/lib/jit/utils.py | 12 ++++++------ infer/lib/rmvpe.py | 10 +++++----- infer/lib/rtrvc.py | 4 +++- infer/modules/vc/hash.py | 4 ++-- infer/modules/vc/modules.py | 6 +++--- rvc/f0/__init__.py | 0 {infer/lib/jit => rvc}/hubert.py | 7 +++---- rvc/onnx/{f0predictors => f0}/__init__.py | 0 rvc/onnx/{f0predictors => f0}/dio.py | 0 rvc/onnx/{f0predictors => f0}/f0.py | 0 rvc/onnx/{f0predictors => f0}/harvest.py | 0 rvc/onnx/{f0predictors => f0}/pm.py | 0 rvc/onnx/infer.py | 2 +- {infer/lib/jit => rvc}/synthesizer.py | 11 +++++++---- 15 files changed, 30 insertions(+), 27 deletions(-) create mode 100644 rvc/f0/__init__.py rename {infer/lib/jit => rvc}/hubert.py (99%) rename rvc/onnx/{f0predictors => f0}/__init__.py (100%) rename rvc/onnx/{f0predictors => f0}/dio.py (100%) rename rvc/onnx/{f0predictors => f0}/f0.py (100%) rename rvc/onnx/{f0predictors => f0}/harvest.py (100%) rename rvc/onnx/{f0predictors => f0}/pm.py (100%) rename {infer/lib/jit => rvc}/synthesizer.py (69%) diff --git a/infer/lib/jit/__init__.py b/infer/lib/jit/__init__.py index 179fbe0..1625df5 100644 --- a/infer/lib/jit/__init__.py +++ b/infer/lib/jit/__init__.py @@ -1,2 +1 @@ from .utils import load, rmvpe_jit_export, synthesizer_jit_export -from .synthesizer import get_synthesizer, get_synthesizer_ckpt diff --git a/infer/lib/jit/utils.py b/infer/lib/jit/utils.py index ea73ba1..d9896fa 100644 --- a/infer/lib/jit/utils.py +++ b/infer/lib/jit/utils.py @@ -44,18 +44,18 @@ def to_jit_model( ): model = None if model_type.lower() == "synthesizer": - from .synthesizer import get_synthesizer + from rvc.synthesizer import load_synthesizer - model, _ = get_synthesizer(model_path, device) + model, _ = load_synthesizer(model_path, device) model.forward = model.infer elif model_type.lower() == "rmvpe": from .rmvpe import get_rmvpe model = get_rmvpe(model_path, device) elif model_type.lower() == "hubert": - from .hubert import get_hubert_model + from rvc.hubert import get_hubert - model = get_hubert_model(model_path, device) + model = get_hubert(model_path, device) model.forward = model.infer else: raise ValueError(f"No model type named {model_type}") @@ -147,9 +147,9 @@ def synthesizer_jit_export( save_path += ".half.jit" if is_half else ".jit" if "cuda" in str(device) and ":" not in str(device): device = torch.device("cuda:0") - from .synthesizer import get_synthesizer + from rvc.synthesizer import load_synthesizer - model, cpt = get_synthesizer(model_path, device) + model, cpt = load_synthesizer(model_path, device) assert isinstance(cpt, dict) model.forward = model.infer inputs = None diff --git a/infer/lib/rmvpe.py b/infer/lib/rmvpe.py index 86c6899..ba5a09b 100644 --- a/infer/lib/rmvpe.py +++ b/infer/lib/rmvpe.py @@ -518,16 +518,15 @@ class RMVPE: def get_jit_model(): jit_model_path = model_path.rstrip(".pth") jit_model_path += ".half.jit" if is_half else ".jit" - reload = False + ckpt = None if os.path.exists(jit_model_path): ckpt = jit.load(jit_model_path) model_device = ckpt["device"] if model_device != str(self.device): - reload = True - else: - reload = True + del ckpt + ckpt = None - if reload: + if ckpt is None: ckpt = jit.rmvpe_jit_export( model_path=model_path, mode="script", @@ -536,6 +535,7 @@ class RMVPE: device=device, is_half=is_half, ) + model = torch.jit.load(BytesIO(ckpt["model"]), map_location=device) return model diff --git a/infer/lib/rtrvc.py b/infer/lib/rtrvc.py index 7fc9dae..781aa0d 100644 --- a/infer/lib/rtrvc.py +++ b/infer/lib/rtrvc.py @@ -16,6 +16,8 @@ import torch.nn.functional as F import torchcrepe from torchaudio.transforms import Resample +from rvc.synthesizer import load_synthesizer + now_dir = os.getcwd() sys.path.append(now_dir) from multiprocessing import Manager as M @@ -113,7 +115,7 @@ class RVC: self.net_g: nn.Module = None def set_default_model(): - self.net_g, cpt = jit.get_synthesizer(self.pth_path, self.device) + self.net_g, cpt = load_synthesizer(self.pth_path, self.device) self.tgt_sr = cpt["config"][-1] cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] self.if_f0 = cpt.get("f0", 1) diff --git a/infer/modules/vc/hash.py b/infer/modules/vc/hash.py index 22b8798..d8d7e7c 100644 --- a/infer/modules/vc/hash.py +++ b/infer/modules/vc/hash.py @@ -6,7 +6,7 @@ from scipy.fft import fft from pybase16384 import encode_to_string, decode_from_string from configs import CPUConfig, singleton_variable -from infer.lib.jit import get_synthesizer_ckpt +from rvc.synthesizer import get_synthesizer from .pipeline import Pipeline from .utils import load_hubert @@ -132,7 +132,7 @@ def model_hash_ckpt(cpt): config = CPUConfig() with TorchSeedContext(114514): - net_g, cpt = get_synthesizer_ckpt(cpt, config.device) + net_g, cpt = get_synthesizer(cpt, config.device) tgt_sr = cpt["config"][-1] if_f0 = cpt.get("f0", 1) version = cpt.get("version", "v1") diff --git a/infer/modules/vc/modules.py b/infer/modules/vc/modules.py index d33a2bb..dd2e42b 100644 --- a/infer/modules/vc/modules.py +++ b/infer/modules/vc/modules.py @@ -10,7 +10,7 @@ import torch from io import BytesIO from infer.lib.audio import load_audio, wav2 -from infer.lib.jit import get_synthesizer_ckpt, get_synthesizer +from rvc.synthesizer import get_synthesizer, load_synthesizer from .info import show_model_info from .pipeline import Pipeline from .utils import get_index_path_from_model, load_hubert @@ -62,7 +62,7 @@ class VC: elif torch.backends.mps.is_available(): torch.mps.empty_cache() ###楼下不这么折腾清理不干净 - self.net_g, self.cpt = get_synthesizer_ckpt( + self.net_g, self.cpt = get_synthesizer( self.cpt, self.config.device ) self.if_f0 = self.cpt.get("f0", 1) @@ -88,7 +88,7 @@ class VC: person = f'{os.getenv("weight_root")}/{sid}' logger.info(f"Loading: {person}") - self.net_g, self.cpt = get_synthesizer(person, self.config.device) + self.net_g, self.cpt = load_synthesizer(person, self.config.device) self.tgt_sr = self.cpt["config"][-1] self.cpt["config"][-3] = self.cpt["weight"]["emb_g.weight"].shape[0] # n_spk self.if_f0 = self.cpt.get("f0", 1) diff --git a/rvc/f0/__init__.py b/rvc/f0/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/infer/lib/jit/hubert.py b/rvc/hubert.py similarity index 99% rename from infer/lib/jit/hubert.py rename to rvc/hubert.py index aec7132..dbb4769 100644 --- a/infer/lib/jit/hubert.py +++ b/rvc/hubert.py @@ -1,14 +1,13 @@ import math import random from typing import Optional, Tuple + from fairseq.checkpoint_utils import load_model_ensemble_and_task +from fairseq.utils import index_put import numpy as np import torch import torch.nn.functional as F -# from fairseq.data.data_utils import compute_mask_indices -from fairseq.utils import index_put - # @torch.jit.script def pad_to_multiple(x, multiple, dim=-1, value=0): @@ -263,7 +262,7 @@ def apply_mask(self, x, padding_mask, target_list): return x, mask_indices -def get_hubert_model( +def get_hubert( model_path="assets/hubert/hubert_base.pt", device=torch.device("cpu") ): models, _, _ = load_model_ensemble_and_task( diff --git a/rvc/onnx/f0predictors/__init__.py b/rvc/onnx/f0/__init__.py similarity index 100% rename from rvc/onnx/f0predictors/__init__.py rename to rvc/onnx/f0/__init__.py diff --git a/rvc/onnx/f0predictors/dio.py b/rvc/onnx/f0/dio.py similarity index 100% rename from rvc/onnx/f0predictors/dio.py rename to rvc/onnx/f0/dio.py diff --git a/rvc/onnx/f0predictors/f0.py b/rvc/onnx/f0/f0.py similarity index 100% rename from rvc/onnx/f0predictors/f0.py rename to rvc/onnx/f0/f0.py diff --git a/rvc/onnx/f0predictors/harvest.py b/rvc/onnx/f0/harvest.py similarity index 100% rename from rvc/onnx/f0predictors/harvest.py rename to rvc/onnx/f0/harvest.py diff --git a/rvc/onnx/f0predictors/pm.py b/rvc/onnx/f0/pm.py similarity index 100% rename from rvc/onnx/f0predictors/pm.py rename to rvc/onnx/f0/pm.py diff --git a/rvc/onnx/infer.py b/rvc/onnx/infer.py index b14b590..a8e5a4f 100644 --- a/rvc/onnx/infer.py +++ b/rvc/onnx/infer.py @@ -5,7 +5,7 @@ import librosa import numpy as np import onnxruntime -from .f0predictors import ( +from .f0 import ( PMF0Predictor, HarvestF0Predictor, DioF0Predictor, diff --git a/infer/lib/jit/synthesizer.py b/rvc/synthesizer.py similarity index 69% rename from infer/lib/jit/synthesizer.py rename to rvc/synthesizer.py index 6f40ffb..9ff9266 100644 --- a/infer/lib/jit/synthesizer.py +++ b/rvc/synthesizer.py @@ -1,9 +1,11 @@ +from collections import OrderedDict + import torch -from rvc.layers.synthesizers import SynthesizerTrnMsNSFsid +from .layers.synthesizers import SynthesizerTrnMsNSFsid -def get_synthesizer_ckpt(cpt, device=torch.device("cpu")): +def get_synthesizer(cpt: OrderedDict, device=torch.device("cpu")): cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] if_f0 = cpt.get("f0", 1) version = cpt.get("version", "v1") @@ -24,8 +26,9 @@ def get_synthesizer_ckpt(cpt, device=torch.device("cpu")): return net_g, cpt -def get_synthesizer(pth_path, device=torch.device("cpu")): - return get_synthesizer_ckpt( +def load_synthesizer( + pth_path: torch.serialization.FILE_LIKE, device=torch.device("cpu")): + return get_synthesizer( torch.load(pth_path, map_location=torch.device("cpu")), device, )