1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-08 03:55:47 +08:00

optimize(crepe): move crepe into rvc.f0

This commit is contained in:
源文雨
2024-06-14 14:29:36 +09:00
parent f79b925ee2
commit e298fde29c
7 changed files with 106 additions and 54 deletions

View File

@@ -292,22 +292,16 @@ class RVC:
self.device self.device
): ###不支持dmlcpu又太慢用不成拿fcpe顶替 ): ###不支持dmlcpu又太慢用不成拿fcpe顶替
return self.get_f0(x, f0_up_key, 1, "fcpe") return self.get_f0(x, f0_up_key, 1, "fcpe")
# printt("using crepe,device:%s"%self.device) if hasattr(self, "model_crepe") == False:
f0, pd = torchcrepe.predict( from rvc.f0 import CRePE
x.unsqueeze(0).float(), self.model_crepe = CRePE(
16000, 160,
160, self.f0_min,
self.f0_min, self.f0_max,
self.f0_max, 16000,
"full", self.device,
batch_size=512, )
# device=self.device if self.device.type!="privateuseone" else "cpu",###crepe不用半精度全部是全精度所以不愁###cpu延迟高到没法用 f0 = self.model_crepe.compute_f0(x)
device=self.device,
return_periodicity=True,
)
pd = torchcrepe.filter.median(pd, 3)
f0 = torchcrepe.filter.mean(f0, 3)
f0[pd < 0.1] = 0
f0 *= pow(2, f0_up_key / 12) f0 *= pow(2, f0_up_key / 12)
return self.get_f0_post(f0) return self.get_f0_post(f0)

View File

@@ -12,10 +12,9 @@ import librosa
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import torchcrepe
from scipy import signal from scipy import signal
from rvc.f0 import PM, Harvest, RMVPE from rvc.f0 import PM, Harvest, RMVPE, CRePE, Dio
now_dir = os.getcwd() now_dir = os.getcwd()
sys.path.append(now_dir) sys.path.append(now_dir)
@@ -81,31 +80,24 @@ class Pipeline(object):
if not hasattr(self, "pm"): if not hasattr(self, "pm"):
self.pm = PM(self.window, f0_min, f0_max, self.sr) self.pm = PM(self.window, f0_min, f0_max, self.sr)
f0 = self.pm.compute_f0(x, p_len=p_len) f0 = self.pm.compute_f0(x, p_len=p_len)
if f0_method == "dio":
if not hasattr(self, "dio"):
self.dio = Dio(self.window, f0_min, f0_max, self.sr)
f0 = self.dio.compute_f0(x, p_len=p_len)
elif f0_method == "harvest": elif f0_method == "harvest":
if not hasattr(self, "harvest"): if not hasattr(self, "harvest"):
self.harvest = Harvest(self.window, f0_min, f0_max, self.sr) self.harvest = Harvest(self.window, f0_min, f0_max, self.sr)
f0 = self.harvest.compute_f0(x, p_len=p_len, filter_radius=filter_radius) f0 = self.harvest.compute_f0(x, p_len=p_len, filter_radius=filter_radius)
elif f0_method == "crepe": elif f0_method == "crepe":
model = "full" if not hasattr(self, "crepe"):
# Pick a batch size that doesn't cause memory errors on your gpu self.crepe = CRePE(
batch_size = 512 self.window,
# Compute pitch using first gpu f0_min,
audio = torch.tensor(np.copy(x))[None].float() f0_max,
f0, pd = torchcrepe.predict( self.sr,
audio, self.device,
self.sr, )
self.window, f0 = self.crepe.compute_f0(x, p_len=p_len)
f0_min,
f0_max,
model,
batch_size=batch_size,
device=self.device,
return_periodicity=True,
)
pd = torchcrepe.filter.median(pd, 3)
f0 = torchcrepe.filter.mean(f0, 3)
f0[pd < 0.1] = 0
f0 = f0[0].cpu().numpy()
elif f0_method == "rmvpe": elif f0_method == "rmvpe":
if not hasattr(self, "rmvpe"): if not hasattr(self, "rmvpe"):
logger.info( logger.info(
@@ -117,7 +109,7 @@ class Pipeline(object):
device=self.device, device=self.device,
# use_jit=self.config.use_jit, # use_jit=self.config.use_jit,
) )
f0 = self.rmvpe.compute_f0(x, filter_radius=0.03) f0 = self.rmvpe.compute_f0(x, p_len=p_len, filter_radius=0.03)
if "privateuseone" in str(self.device): # clean ortruntime memory if "privateuseone" in str(self.device): # clean ortruntime memory
del self.rmvpe.model del self.rmvpe.model

View File

@@ -1,5 +1,6 @@
from .f0 import F0Predictor from .f0 import F0Predictor
from .crepe import CRePE
from .dio import Dio from .dio import Dio
from .harvest import Harvest from .harvest import Harvest
from .pm import PM from .pm import PM

52
rvc/f0/crepe.py Normal file
View File

@@ -0,0 +1,52 @@
from typing import Any, Optional, Union
import numpy as np
import torch
import torchcrepe
from .f0 import F0Predictor
class CRePE(F0Predictor):
def __init__(
self,
hop_length=512,
f0_min=50,
f0_max=1100,
sampling_rate=44100,
device="cpu",
):
super().__init__(
hop_length,
f0_min,
f0_max,
sampling_rate,
device,
)
def compute_f0(
self,
wav: np.ndarray[Any, np.dtype],
p_len: Optional[int] = None,
filter_radius: Optional[Union[int, float]] = None,
):
if p_len is None:
p_len = wav.shape[0] // self.hop_length
# Pick a batch size that doesn't cause memory errors on your gpu
batch_size = 512
# Compute pitch using device 'device'
f0, pd = torchcrepe.predict(
torch.tensor(np.copy(wav))[None].float().to(self.device),
self.sampling_rate,
self.hop_length,
self.f0_min,
self.f0_max,
batch_size=batch_size,
device=self.device,
return_periodicity=True,
)
pd = torchcrepe.filter.median(pd, 3)
f0 = torchcrepe.filter.mean(f0, 3)
f0[pd < 0.1] = 0
f0 = f0[0].cpu().numpy()
return self._interpolate_f0(self._resize_f0(f0, p_len))[0]

View File

@@ -1,14 +1,25 @@
from typing import Any, Optional, Union from typing import Any, Optional, Union
import torch
import numpy as np import numpy as np
class F0Predictor(object): class F0Predictor(object):
def __init__(self, hop_length=512, f0_min=50, f0_max=1100, sampling_rate=44100): def __init__(
self,
hop_length=512,
f0_min=50,
f0_max=1100,
sampling_rate=44100,
device: Optional[str] = None,
):
self.hop_length = hop_length self.hop_length = hop_length
self.f0_min = f0_min self.f0_min = f0_min
self.f0_max = f0_max self.f0_max = f0_max
self.sampling_rate = sampling_rate self.sampling_rate = sampling_rate
if device is None:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
self.device = device
def compute_f0( def compute_f0(
self, self,

View File

@@ -26,16 +26,18 @@ class RMVPE(F0Predictor):
f0_max = 8000 f0_max = 8000
sampling_rate = 16000 sampling_rate = 16000
super().__init__(hop_length, f0_min, f0_max, sampling_rate) super().__init__(
hop_length,
f0_min,
f0_max,
sampling_rate,
device,
)
self.is_half = is_half self.is_half = is_half
cents_mapping = 20 * np.arange(360) + 1997.3794084376191 cents_mapping = 20 * np.arange(360) + 1997.3794084376191
self.cents_mapping = np.pad(cents_mapping, (4, 4)) # 368 self.cents_mapping = np.pad(cents_mapping, (4, 4)) # 368
if device is None:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
self.device = device
self.mel_extractor = MelSpectrogram( self.mel_extractor = MelSpectrogram(
is_half=is_half, is_half=is_half,
n_mel_channels=128, n_mel_channels=128,
@@ -44,10 +46,10 @@ class RMVPE(F0Predictor):
hop_length=hop_length, hop_length=hop_length,
mel_fmin=f0_min, mel_fmin=f0_min,
mel_fmax=f0_max, mel_fmax=f0_max,
device=device, device=self.device,
).to(device) ).to(self.device)
if "privateuseone" in str(device): if "privateuseone" in str(self.device):
import onnxruntime as ort import onnxruntime as ort
self.model = ort.InferenceSession( self.model = ort.InferenceSession(
@@ -73,11 +75,11 @@ class RMVPE(F0Predictor):
mode="script", mode="script",
inputs_path=None, inputs_path=None,
save_path=jit_model_path, save_path=jit_model_path,
device=device, device=self.device,
is_half=is_half, is_half=is_half,
) )
model = torch.jit.load(BytesIO(ckpt["model"]), map_location=device) model = torch.jit.load(BytesIO(ckpt["model"]), map_location=self.device)
return model return model
def get_default_model(): def get_default_model():
@@ -99,7 +101,7 @@ class RMVPE(F0Predictor):
else: else:
self.model = get_default_model() self.model = get_default_model()
self.model = self.model.to(device) self.model = self.model.to(self.device)
def compute_f0( def compute_f0(
self, self,

6
web.py
View File

@@ -861,9 +861,9 @@ with gr.Blocks(title="RVC WebUI") as app:
"Select the pitch extraction algorithm ('pm': faster extraction but lower-quality speech; 'harvest': better bass but extremely slow; 'crepe': better quality but GPU intensive), 'rmvpe': best quality, and little GPU requirement" "Select the pitch extraction algorithm ('pm': faster extraction but lower-quality speech; 'harvest': better bass but extremely slow; 'crepe': better quality but GPU intensive), 'rmvpe': best quality, and little GPU requirement"
), ),
choices=( choices=(
["pm", "harvest", "crepe", "rmvpe"] ["pm", "dio", "harvest", "rmvpe"]
if config.dml == False if config.dml
else ["pm", "harvest", "rmvpe"] else ["pm", "dio", "harvest", "crepe", "rmvpe"]
), ),
value="rmvpe", value="rmvpe",
interactive=True, interactive=True,