1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-08 20:10:44 +08:00

optimize(rmvpe): move rmvpe into rvc.f0

This commit is contained in:
源文雨
2024-06-13 00:42:42 +09:00
parent 77b371d615
commit 8ac5597a3f
12 changed files with 96 additions and 95 deletions

View File

@@ -313,7 +313,7 @@ class RVC:
def get_f0_rmvpe(self, x, f0_up_key): def get_f0_rmvpe(self, x, f0_up_key):
if hasattr(self, "model_rmvpe") == False: if hasattr(self, "model_rmvpe") == False:
from infer.lib.rmvpe import RMVPE from rvc.f0 import RMVPE
printt("Loading rmvpe model") printt("Loading rmvpe model")
self.model_rmvpe = RMVPE( self.model_rmvpe = RMVPE(
@@ -322,7 +322,7 @@ class RVC:
device=self.device, device=self.device,
use_jit=self.config.use_jit, use_jit=self.config.use_jit,
) )
f0 = self.model_rmvpe.infer_from_audio(x, thred=0.03) f0 = self.model_rmvpe.compute_f0(x, thred=0.03)
f0 *= pow(2, f0_up_key / 12) f0 *= pow(2, f0_up_key / 12)
return self.get_f0_post(f0) return self.get_f0_post(f0)

View File

@@ -83,13 +83,13 @@ class FeatureInput(object):
f0 = pyworld.stonemask(x.astype(np.double), f0, t, self.fs) f0 = pyworld.stonemask(x.astype(np.double), f0, t, self.fs)
elif f0_method == "rmvpe": elif f0_method == "rmvpe":
if hasattr(self, "model_rmvpe") == False: if hasattr(self, "model_rmvpe") == False:
from infer.lib.rmvpe import RMVPE from rvc.f0.rmvpe import RMVPE
print("Loading rmvpe model") print("Loading rmvpe model")
self.model_rmvpe = RMVPE( self.model_rmvpe = RMVPE(
"assets/rmvpe/rmvpe.pt", is_half=False, device="cpu" "assets/rmvpe/rmvpe.pt", is_half=False, device="cpu"
) )
f0 = self.model_rmvpe.infer_from_audio(x, threshold=0.03) f0 = self.model_rmvpe.compute_f0(x, filter_radius=0.03)
return f0 return f0
def coarse_f0(self, f0): def coarse_f0(self, f0):

View File

@@ -46,13 +46,13 @@ class FeatureInput(object):
# p_len = x.shape[0] // self.hop # p_len = x.shape[0] // self.hop
if f0_method == "rmvpe": if f0_method == "rmvpe":
if hasattr(self, "model_rmvpe") == False: if hasattr(self, "model_rmvpe") == False:
from infer.lib.rmvpe import RMVPE from rvc.f0.rmvpe import RMVPE
print("Loading rmvpe model") print("Loading rmvpe model")
self.model_rmvpe = RMVPE( self.model_rmvpe = RMVPE(
"assets/rmvpe/rmvpe.pt", is_half=is_half, device="cuda" "assets/rmvpe/rmvpe.pt", is_half=is_half, device="cuda"
) )
f0 = self.model_rmvpe.infer_from_audio(x, threshold=0.03) f0 = self.model_rmvpe.compute_f0(x, filter_radius=0.03)
return f0 return f0
def coarse_f0(self, f0): def coarse_f0(self, f0):

View File

@@ -44,13 +44,13 @@ class FeatureInput(object):
# p_len = x.shape[0] // self.hop # p_len = x.shape[0] // self.hop
if f0_method == "rmvpe": if f0_method == "rmvpe":
if hasattr(self, "model_rmvpe") == False: if hasattr(self, "model_rmvpe") == False:
from infer.lib.rmvpe import RMVPE from rvc.f0.rmvpe import RMVPE
print("Loading rmvpe model") print("Loading rmvpe model")
self.model_rmvpe = RMVPE( self.model_rmvpe = RMVPE(
"assets/rmvpe/rmvpe.pt", is_half=False, device=device "assets/rmvpe/rmvpe.pt", is_half=False, device=device
) )
f0 = self.model_rmvpe.infer_from_audio(x, threshold=0.03) f0 = self.model_rmvpe.compute_f0(x, filter_radius=0.03)
return f0 return f0
def coarse_f0(self, f0): def coarse_f0(self, f0):

View File

@@ -16,7 +16,7 @@ import torch.nn.functional as F
import torchcrepe import torchcrepe
from scipy import signal from scipy import signal
from rvc.f0 import PM, Harvest from rvc.f0 import PM, Harvest, RMVPE
now_dir = os.getcwd() now_dir = os.getcwd()
sys.path.append(now_dir) sys.path.append(now_dir)
@@ -108,24 +108,23 @@ class Pipeline(object):
f0[pd < 0.1] = 0 f0[pd < 0.1] = 0
f0 = f0[0].cpu().numpy() f0 = f0[0].cpu().numpy()
elif f0_method == "rmvpe": elif f0_method == "rmvpe":
if not hasattr(self, "model_rmvpe"): if not hasattr(self, "rmvpe"):
from infer.lib.rmvpe import RMVPE
logger.info( logger.info(
"Loading rmvpe model %s" % "%s/rmvpe.pt" % os.environ["rmvpe_root"] "Loading rmvpe model %s" % "%s/rmvpe.pt" % os.environ["rmvpe_root"]
) )
self.model_rmvpe = RMVPE( self.rmvpe = RMVPE(
"%s/rmvpe.pt" % os.environ["rmvpe_root"], "%s/rmvpe.pt" % os.environ["rmvpe_root"],
is_half=self.is_half, is_half=self.is_half,
device=self.device, device=self.device,
# use_jit=self.config.use_jit, # use_jit=self.config.use_jit,
) )
f0 = self.model_rmvpe.infer_from_audio(x, threshold=0.03) f0 = self.rmvpe.compute_f0(x, filter_radius=0.03)
if "privateuseone" in str(self.device): # clean ortruntime memory if "privateuseone" in str(self.device): # clean ortruntime memory
del self.model_rmvpe.model del self.rmvpe.model
del self.model_rmvpe del self.rmvpe
logger.info("Cleaning ortruntime memory") logger.info("Cleaning ortruntime memory")
elif f0_method == "fcpe": elif f0_method == "fcpe":
if not hasattr(self, "model_fcpe"): if not hasattr(self, "model_fcpe"):
from torchfcpe import spawn_bundled_infer_model from torchfcpe import spawn_bundled_infer_model

View File

@@ -1,4 +1,6 @@
from .f0 import F0Predictor
from .dio import Dio from .dio import Dio
from .harvest import Harvest from .harvest import Harvest
from .pm import PM from .pm import PM
from .f0 import F0Predictor from .rmvpe import RMVPE

View File

@@ -1,4 +1,4 @@
from typing import Any, Optional from typing import Any, Optional, Union
import numpy as np import numpy as np
import pyworld import pyworld
@@ -14,7 +14,7 @@ class Dio(F0Predictor):
self, self,
wav: np.ndarray[Any, np.dtype], wav: np.ndarray[Any, np.dtype],
p_len: Optional[int] = None, p_len: Optional[int] = None,
filter_radius: Optional[int] = None, filter_radius: Optional[Union[int, float]] = None,
): ):
if p_len is None: if p_len is None:
p_len = wav.shape[0] // self.hop_length p_len = wav.shape[0] // self.hop_length

View File

@@ -1,4 +1,4 @@
from typing import Any, Optional from typing import Any, Optional, Union
import numpy as np import numpy as np
@@ -14,7 +14,7 @@ class F0Predictor(object):
self, self,
wav: np.ndarray[Any, np.dtype], wav: np.ndarray[Any, np.dtype],
p_len: Optional[int] = None, p_len: Optional[int] = None,
filter_radius: Optional[int] = None, filter_radius: Optional[Union[int, float]] = None,
): ... ): ...
def interpolate_f0(self, f0: np.ndarray[Any, np.dtype]): def interpolate_f0(self, f0: np.ndarray[Any, np.dtype]):

View File

@@ -1,4 +1,4 @@
from typing import Any, Optional from typing import Any, Optional, Union
import numpy as np import numpy as np
import pyworld import pyworld
@@ -15,7 +15,7 @@ class Harvest(F0Predictor):
self, self,
wav: np.ndarray[Any, np.dtype], wav: np.ndarray[Any, np.dtype],
p_len: Optional[int] = None, p_len: Optional[int] = None,
filter_radius: Optional[int] = None, filter_radius: Optional[Union[int, float]] = None,
): ):
if p_len is None: if p_len is None:
p_len = wav.shape[0] // self.hop_length p_len = wav.shape[0] // self.hop_length

View File

@@ -1,51 +1,60 @@
from io import BytesIO from io import BytesIO
import os import os
from typing import List, Optional, Tuple, Union from typing import Any, Optional, Union
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F
from infer.lib import jit from infer.lib import jit
import torch.nn.functional as F from .mel import MelSpectrogram
from .e2e import E2E
import logging from .f0 import F0Predictor
logger = logging.getLogger(__name__)
from rvc.f0.mel import MelSpectrogram
from rvc.f0.e2e import E2E
class RMVPE: class RMVPE(F0Predictor):
def __init__(self, model_path: str, is_half, device=None, use_jit=False): def __init__(
self.resample_kernel = {} self,
self.resample_kernel = {} model_path: str,
is_half: bool,
device: str,
use_jit=False,
):
hop_length=160
f0_min=30
f0_max=8000
sampling_rate=16000
super().__init__(hop_length, f0_min, f0_max, sampling_rate)
self.is_half = is_half self.is_half = is_half
cents_mapping = 20 * np.arange(360) + 1997.3794084376191
self.cents_mapping = np.pad(cents_mapping, (4, 4)) # 368
if device is None: if device is None:
device = "cuda:0" if torch.cuda.is_available() else "cpu" device = "cuda:0" if torch.cuda.is_available() else "cpu"
self.device = device self.device = device
self.mel_extractor = MelSpectrogram( self.mel_extractor = MelSpectrogram(
is_half=is_half, is_half=is_half,
n_mel_channels=128, n_mel_channels=128,
sampling_rate=16000, sampling_rate=sampling_rate,
win_length=1024, win_length=1024,
hop_length=160, hop_length=hop_length,
mel_fmin=30, mel_fmin=f0_min,
mel_fmax=8000, mel_fmax=f0_max,
device=device, device=device,
).to(device) ).to(device)
if "privateuseone" in str(device): if "privateuseone" in str(device):
import onnxruntime as ort import onnxruntime as ort
ort_session = ort.InferenceSession( self.model = ort.InferenceSession(
"%s/rmvpe.onnx" % os.environ["rmvpe_root"], "%s/rmvpe.onnx" % os.environ["rmvpe_root"],
providers=["DmlExecutionProvider"], providers=["DmlExecutionProvider"],
) )
self.model = ort_session
else: else:
if str(self.device) == "cuda":
self.device = torch.device("cuda:0")
def get_jit_model(): def get_jit_model():
jit_model_path = model_path.rstrip(".pth") jit_model_path = model_path.rstrip(".pth")
jit_model_path += ".half.jit" if is_half else ".jit" jit_model_path += ".half.jit" if is_half else ".jit"
@@ -83,10 +92,6 @@ class RMVPE:
if use_jit: if use_jit:
if is_half and "cpu" in str(self.device): if is_half and "cpu" in str(self.device):
logger.warning(
"Use default rmvpe model. \
Jit is not supported on the CPU for half floating point"
)
self.model = get_default_model() self.model = get_default_model()
else: else:
self.model = get_jit_model() self.model = get_jit_model()
@@ -94,49 +99,21 @@ class RMVPE:
self.model = get_default_model() self.model = get_default_model()
self.model = self.model.to(device) self.model = self.model.to(device)
cents_mapping = 20 * np.arange(360) + 1997.3794084376191
self.cents_mapping = np.pad(cents_mapping, (4, 4)) # 368
def mel2hidden(self, mel): def compute_f0(
with torch.no_grad(): self,
n_frames = mel.shape[-1] wav: np.ndarray[Any, np.dtype],
n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames p_len: Optional[int] = None,
if n_pad > 0: filter_radius: Optional[Union[int, float]] = None,
mel = F.pad(mel, (0, n_pad), mode="constant") ):
if "privateuseone" in str(self.device): if p_len is None:
onnx_input_name = self.model.get_inputs()[0].name p_len = wav.shape[0] // self.hop_length
onnx_outputs_names = self.model.get_outputs()[0].name if not torch.is_tensor(wav):
hidden = self.model.run( wav = torch.from_numpy(wav)
[onnx_outputs_names],
input_feed={onnx_input_name: mel.cpu().numpy()},
)[0]
else:
mel = mel.half() if self.is_half else mel.float()
hidden = self.model(mel)
return hidden[:, :n_frames]
def decode(self, hidden, thred=0.03):
cents_pred = self.to_local_average_cents(hidden, threshold=thred)
f0 = 10 * (2 ** (cents_pred / 1200))
f0[f0 == 10] = 0
# f0 = np.array([10 * (2 ** (cent_pred / 1200)) if cent_pred else 0 for cent_pred in cents_pred])
return f0
def infer_from_audio(self, audio, threshold=0.03):
# torch.cuda.synchronize()
# t0 = ttime()
if not torch.is_tensor(audio):
audio = torch.from_numpy(audio)
mel = self.mel_extractor( mel = self.mel_extractor(
audio.float().to(self.device).unsqueeze(0), center=True wav.float().to(self.device).unsqueeze(0), center=True
) )
# print(123123123,mel.device.type) hidden = self._mel2hidden(mel)
# torch.cuda.synchronize()
# t1 = ttime()
hidden = self.mel2hidden(mel)
# torch.cuda.synchronize()
# t2 = ttime()
# print(234234,hidden.device.type)
if "privateuseone" not in str(self.device): if "privateuseone" not in str(self.device):
hidden = hidden.squeeze(0).cpu().numpy() hidden = hidden.squeeze(0).cpu().numpy()
else: else:
@@ -144,13 +121,11 @@ class RMVPE:
if self.is_half == True: if self.is_half == True:
hidden = hidden.astype("float32") hidden = hidden.astype("float32")
f0 = self.decode(hidden, thred=threshold) f0 = self._decode(hidden, thred=filter_radius)
# torch.cuda.synchronize()
# t3 = ttime()
# print("hmvpe:%s\t%s\t%s\t%s"%(t1-t0,t2-t1,t3-t2,t3-t0))
return f0
def to_local_average_cents(self, salience, threshold=0.05): return self.interpolate_f0(self.resize_f0(f0, p_len))[0]
def _to_local_average_cents(self, salience, threshold=0.05):
center = np.argmax(salience, axis=1) # 帧长#index center = np.argmax(salience, axis=1) # 帧长#index
salience = np.pad(salience, ((0, 0), (4, 4))) # 帧长,368 salience = np.pad(salience, ((0, 0), (4, 4))) # 帧长,368
center += 4 center += 4
@@ -169,3 +144,28 @@ class RMVPE:
maxx = np.max(salience, axis=1) # 帧长 maxx = np.max(salience, axis=1) # 帧长
devided[maxx <= threshold] = 0 devided[maxx <= threshold] = 0
return devided return devided
def _mel2hidden(self, mel):
with torch.no_grad():
n_frames = mel.shape[-1]
n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames
if n_pad > 0:
mel = F.pad(mel, (0, n_pad), mode="constant")
if "privateuseone" in str(self.device):
onnx_input_name = self.model.get_inputs()[0].name
onnx_outputs_names = self.model.get_outputs()[0].name
hidden = self.model.run(
[onnx_outputs_names],
input_feed={onnx_input_name: mel.cpu().numpy()},
)[0]
else:
mel = mel.half() if self.is_half else mel.float()
hidden = self.model(mel)
return hidden[:, :n_frames]
def _decode(self, hidden, thred=0.03):
cents_pred = self._to_local_average_cents(hidden, threshold=thred)
f0 = 10 * (2 ** (cents_pred / 1200))
f0[f0 == 10] = 0
# f0 = np.array([10 * (2 ** (cent_pred / 1200)) if cent_pred else 0 for cent_pred in cents_pred])
return f0