mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-07 19:40:44 +08:00
optimize(crepe): move crepe into rvc.f0
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from .f0 import F0Predictor
|
||||
|
||||
from .crepe import CRePE
|
||||
from .dio import Dio
|
||||
from .harvest import Harvest
|
||||
from .pm import PM
|
||||
|
||||
52
rvc/f0/crepe.py
Normal file
52
rvc/f0/crepe.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchcrepe
|
||||
|
||||
from .f0 import F0Predictor
|
||||
|
||||
|
||||
class CRePE(F0Predictor):
|
||||
def __init__(
|
||||
self,
|
||||
hop_length=512,
|
||||
f0_min=50,
|
||||
f0_max=1100,
|
||||
sampling_rate=44100,
|
||||
device="cpu",
|
||||
):
|
||||
super().__init__(
|
||||
hop_length,
|
||||
f0_min,
|
||||
f0_max,
|
||||
sampling_rate,
|
||||
device,
|
||||
)
|
||||
|
||||
def compute_f0(
|
||||
self,
|
||||
wav: np.ndarray[Any, np.dtype],
|
||||
p_len: Optional[int] = None,
|
||||
filter_radius: Optional[Union[int, float]] = None,
|
||||
):
|
||||
if p_len is None:
|
||||
p_len = wav.shape[0] // self.hop_length
|
||||
# Pick a batch size that doesn't cause memory errors on your gpu
|
||||
batch_size = 512
|
||||
# Compute pitch using device 'device'
|
||||
f0, pd = torchcrepe.predict(
|
||||
torch.tensor(np.copy(wav))[None].float().to(self.device),
|
||||
self.sampling_rate,
|
||||
self.hop_length,
|
||||
self.f0_min,
|
||||
self.f0_max,
|
||||
batch_size=batch_size,
|
||||
device=self.device,
|
||||
return_periodicity=True,
|
||||
)
|
||||
pd = torchcrepe.filter.median(pd, 3)
|
||||
f0 = torchcrepe.filter.mean(f0, 3)
|
||||
f0[pd < 0.1] = 0
|
||||
f0 = f0[0].cpu().numpy()
|
||||
return self._interpolate_f0(self._resize_f0(f0, p_len))[0]
|
||||
13
rvc/f0/f0.py
13
rvc/f0/f0.py
@@ -1,14 +1,25 @@
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
class F0Predictor(object):
|
||||
def __init__(self, hop_length=512, f0_min=50, f0_max=1100, sampling_rate=44100):
|
||||
def __init__(
|
||||
self,
|
||||
hop_length=512,
|
||||
f0_min=50,
|
||||
f0_max=1100,
|
||||
sampling_rate=44100,
|
||||
device: Optional[str] = None,
|
||||
):
|
||||
self.hop_length = hop_length
|
||||
self.f0_min = f0_min
|
||||
self.f0_max = f0_max
|
||||
self.sampling_rate = sampling_rate
|
||||
if device is None:
|
||||
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||
self.device = device
|
||||
|
||||
def compute_f0(
|
||||
self,
|
||||
|
||||
@@ -26,16 +26,18 @@ class RMVPE(F0Predictor):
|
||||
f0_max = 8000
|
||||
sampling_rate = 16000
|
||||
|
||||
super().__init__(hop_length, f0_min, f0_max, sampling_rate)
|
||||
super().__init__(
|
||||
hop_length,
|
||||
f0_min,
|
||||
f0_max,
|
||||
sampling_rate,
|
||||
device,
|
||||
)
|
||||
|
||||
self.is_half = is_half
|
||||
cents_mapping = 20 * np.arange(360) + 1997.3794084376191
|
||||
self.cents_mapping = np.pad(cents_mapping, (4, 4)) # 368
|
||||
|
||||
if device is None:
|
||||
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||
self.device = device
|
||||
|
||||
self.mel_extractor = MelSpectrogram(
|
||||
is_half=is_half,
|
||||
n_mel_channels=128,
|
||||
@@ -44,10 +46,10 @@ class RMVPE(F0Predictor):
|
||||
hop_length=hop_length,
|
||||
mel_fmin=f0_min,
|
||||
mel_fmax=f0_max,
|
||||
device=device,
|
||||
).to(device)
|
||||
device=self.device,
|
||||
).to(self.device)
|
||||
|
||||
if "privateuseone" in str(device):
|
||||
if "privateuseone" in str(self.device):
|
||||
import onnxruntime as ort
|
||||
|
||||
self.model = ort.InferenceSession(
|
||||
@@ -73,11 +75,11 @@ class RMVPE(F0Predictor):
|
||||
mode="script",
|
||||
inputs_path=None,
|
||||
save_path=jit_model_path,
|
||||
device=device,
|
||||
device=self.device,
|
||||
is_half=is_half,
|
||||
)
|
||||
|
||||
model = torch.jit.load(BytesIO(ckpt["model"]), map_location=device)
|
||||
model = torch.jit.load(BytesIO(ckpt["model"]), map_location=self.device)
|
||||
return model
|
||||
|
||||
def get_default_model():
|
||||
@@ -99,7 +101,7 @@ class RMVPE(F0Predictor):
|
||||
else:
|
||||
self.model = get_default_model()
|
||||
|
||||
self.model = self.model.to(device)
|
||||
self.model = self.model.to(self.device)
|
||||
|
||||
def compute_f0(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user