1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-05 17:20:25 +08:00
Files
Retrieval-based-Voice-Conve…/infer/lib/rtrvc.py
github-actions[bot] 89f7fa25cc chore(format): run black on dev (#102)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2024-11-29 00:36:19 +09:00

274 lines
10 KiB
Python

from io import BytesIO
import logging
import os
from pathlib import Path
from typing import Union, Literal, Optional

import fairseq
import faiss
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchaudio.transforms import Resample

from rvc.f0 import Generator
from rvc.synthesizer import load_synthesizer
class RVC:
    """Block-wise voice-conversion pipeline.

    Each call to :meth:`infer` extracts HuBERT features from a block of audio,
    optionally blends them with vectors retrieved from a faiss index, derives
    pitch information, and runs the synthesizer (``net_g``) on the result.

    The instance keeps rolling pitch caches (``cache_pitch`` / ``cache_pitchf``)
    that are shifted on every :meth:`infer` call that extracts f0, so a single
    instance is meant to be fed consecutive blocks of one stream.
    """

    # Floor applied to the magnitude of faiss distances before they are
    # inverted into blend weights.  A distance of exactly zero would otherwise
    # give an infinite weight and NaN features after normalisation.
    _MIN_INDEX_DISTANCE = 1e-8

    def __init__(
        self,
        key: Union[int, float],
        formant: Union[int, float],
        pth_path: "torch.serialization.FILE_LIKE",
        index_path: str,
        index_rate: Union[int, float],
        n_cpu: int = os.cpu_count(),
        device: str = "cpu",
        use_jit: bool = False,
        is_half: bool = False,
        is_dml: bool = False,
    ) -> None:
        """Load the HuBERT encoder, the f0 generator and the synthesizer.

        Args:
            key: Pitch shift passed to the f0 generator (combined with
                ``formant`` in :meth:`infer`).
            formant: Formant shift; used as ``2 ** (formant / 12)`` in
                :meth:`infer`.
            pth_path: Synthesizer checkpoint handed to ``load_synthesizer`` or
                ``get_jit_model``.  The annotation is quoted so that it is not
                resolved at import time.
            index_path: Path of the faiss index file.
            index_rate: Blend ratio of retrieved features; the index is only
                loaded when this is greater than zero.
            n_cpu: Stored on the instance as ``self.n_cpu``.
            device: Torch device for every model and tensor.
            use_jit: Request the TorchScript synthesizer.  Ignored when
                ``is_dml`` is set or when half precision is combined with a
                CPU device.
            is_half: Run HuBERT and the synthesizer in float16.
            is_dml: Patch fairseq's ``GradMultiply.forward`` with a plain
                clone/detach implementation (presumably for DirectML
                compatibility — TODO confirm).

        Raises:
            KeyError: If the ``rmvpe_root`` environment variable is not set.
        """
        if is_dml:

            def forward_dml(ctx, x, scale):
                ctx.scale = scale
                res = x.clone().detach()
                return res

            # NOTE: this replaces the method process-wide, not per instance.
            fairseq.modules.grad_multiply.GradMultiply.forward = forward_dml

        self.device = device
        self.f0_up_key = key
        self.formant_shift = formant

        self.sr = 16000  # hubert sampling rate
        self.window = 160  # hop length
        self.f0_min = 50
        self.f0_max = 1100
        self.f0_mel_min = 1127 * np.log(1 + self.f0_min / 700)
        self.f0_mel_max = 1127 * np.log(1 + self.f0_max / 700)

        self.n_cpu = n_cpu
        self.use_jit = use_jit
        self.is_half = is_half

        self.pth_path = pth_path
        self.index_path = index_path
        self.index_rate = index_rate
        # ``self.index`` / ``self.big_npy`` are deliberately left undefined
        # while no index is in use; ``infer`` checks ``hasattr(self, "index")``.
        if index_rate > 0:
            self._load_index()

        # Rolling pitch history, shifted left on every f0 extraction.
        self.cache_pitch: torch.Tensor = torch.zeros(
            1024, device=self.device, dtype=torch.long
        )
        self.cache_pitchf = torch.zeros(1024, device=self.device, dtype=torch.float32)

        # Resample modules keyed by their source rate (see ``infer``).
        self.resample_kernel = {}

        self.f0_gen = Generator(
            Path(os.environ["rmvpe_root"]), is_half, 0, device, self.window, self.sr
        )

        models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            ["assets/hubert/hubert_base.pt"],
            suffix="",
        )
        hubert_model = models[0].to(self.device)
        hubert_model = hubert_model.half() if self.is_half else hubert_model.float()
        hubert_model.eval()
        self.hubert = hubert_model

        self.net_g: Optional[nn.Module] = None
        jit_usable = (
            self.use_jit
            and not is_dml
            and not (self.is_half and "cpu" in str(self.device))
        )
        if jit_usable:
            self._load_jit_model()
        else:
            self._load_default_model()

    def _load_index(self) -> None:
        """Read the faiss index from ``self.index_path`` and cache all its vectors."""
        self.index = faiss.read_index(self.index_path)
        self.big_npy = self.index.reconstruct_n(0, self.index.ntotal)

    def _load_default_model(self) -> None:
        """Load the eager synthesizer and read its metadata from the checkpoint."""
        self.net_g, cpt = load_synthesizer(self.pth_path, self.device)
        self.tgt_sr = cpt["config"][-1]
        cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
        self.if_f0 = cpt.get("f0", 1)
        self.version = cpt.get("version", "v1")
        if self.is_half:
            self.net_g = self.net_g.half()
        else:
            self.net_g = self.net_g.float()

    def _load_jit_model(self) -> None:
        """Load the TorchScript synthesizer and expose ``forward`` as ``infer``."""
        from rvc.jit import get_jit_model
        from rvc.synthesizer import synthesizer_jit_export

        cpt = get_jit_model(self.pth_path, self.is_half, synthesizer_jit_export)
        self.tgt_sr = cpt["config"][-1]
        self.if_f0 = cpt.get("f0", 1)
        self.version = cpt.get("version", "v1")
        self.net_g = torch.jit.load(BytesIO(cpt["model"]), map_location=self.device)
        # ``infer`` below always calls ``net_g.infer``; alias it for the
        # scripted module.
        self.net_g.infer = self.net_g.forward
        self.net_g.eval().to(self.device)

    def set_key(self, new_key):
        """Replace the pitch shift used for subsequent f0 extraction."""
        self.f0_up_key = new_key

    def set_formant(self, new_formant):
        """Replace the formant shift used by subsequent ``infer`` calls."""
        self.formant_shift = new_formant

    def set_index_rate(self, new_index_rate):
        """Change the index blend ratio.

        The index is (re)read from ``self.index_path`` whenever the rate goes
        from a non-positive value to a positive one.
        """
        if new_index_rate > 0 and self.index_rate <= 0:
            self._load_index()
        self.index_rate = new_index_rate

    def _blend_index_features(self, feats: torch.Tensor, skip_head: int) -> None:
        """Blend retrieved index vectors into ``feats`` in place.

        Frames from ``skip_head // 2`` onwards are replaced by
        ``index_rate * retrieved + (1 - index_rate) * original``, where
        ``retrieved`` is the inverse-squared-distance weighted mean of the 8
        nearest index vectors.  Nothing is changed when the search returns a
        negative (missing) neighbour id.
        """
        start = skip_head // 2
        npy = feats[0][start:].cpu().numpy()
        if self.is_half:
            npy = npy.astype("float32")

        score, ix = self.index.search(npy, k=8)
        if not (ix >= 0).all():
            return

        # Floor the distance magnitude so an exact match (distance 0) yields a
        # very large finite weight instead of inf / NaN.
        safe_score = np.maximum(np.abs(score), self._MIN_INDEX_DISTANCE)
        weight = np.square(1 / safe_score)
        weight /= weight.sum(axis=1, keepdims=True)
        npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
        if self.is_half:
            npy = npy.astype("float16")
        feats[0][start:] = (
            torch.from_numpy(npy).unsqueeze(0).to(self.device) * self.index_rate
            + (1 - self.index_rate) * feats[0][start:]
        )

    def infer(
        self,
        input_wav: torch.Tensor,
        block_frame_16k: int,
        skip_head: int,
        return_length: int,
        f0method: Union[tuple, str],
        protect: float = 1.0,
    ) -> torch.Tensor:
        """Convert one block of audio.

        Args:
            input_wav: Input samples.  A 2-D tensor is averaged over its last
                axis before feature extraction.  ``input_wav.shape[0]`` divided
                by the hop length gives the number of synthesizer frames.
            block_frame_16k: Block size in samples; sets how far the pitch
                caches are shifted and how much audio is given to the f0
                extractor.
            skip_head: Forwarded to the synthesizer; ``skip_head // 2`` is the
                first feature frame that index retrieval touches.
            return_length: Forwarded to the synthesizer; also bounds the audio
                that is resampled when a formant shift is active.
            f0method: Either a ``(pitch, pitchf)`` tuple of precomputed values
                or the name of an extraction method for :meth:`_get_f0`.
            protect: Values below ``0.5`` blend the pre-retrieval features back
                in, using ``protect`` as the weight of the retrieved features
                wherever ``pitchf < 1``.  Only applies to f0 models.

        Returns:
            The synthesized audio as a float32 tensor with singleton
            dimensions squeezed out.
        """
        feats0 = None  # pre-retrieval copy, only kept when protect is active

        with torch.no_grad():
            feats = input_wav.half() if self.is_half else input_wav.float()
            feats = feats.to(self.device)
            if feats.dim() == 2:  # double channels
                feats = feats.mean(-1)
            feats = feats.view(1, -1)
            padding_mask = torch.zeros(
                feats.shape, dtype=torch.bool, device=self.device
            )
            inputs = {
                "source": feats,
                "padding_mask": padding_mask,
                "output_layer": 9 if self.version == "v1" else 12,
            }
            logits = self.hubert.extract_features(**inputs)
            feats = (
                self.hubert.final_proj(logits[0]) if self.version == "v1" else logits[0]
            )
            # Repeat the last frame once.
            feats = torch.cat((feats, feats[:, -1:, :]), 1)
            if protect < 0.5 and self.if_f0 == 1:
                feats0 = feats.clone()

        # Index retrieval is best-effort: a failure leaves the features as
        # they are, but it is reported instead of being silently dropped.
        if hasattr(self, "index") and self.index_rate > 0:
            try:
                self._blend_index_features(feats, skip_head)
            except Exception:
                logging.getLogger(__name__).warning(
                    "Index search failed; continuing without index features",
                    exc_info=True,
                )

        p_len = input_wav.shape[0] // self.window
        factor = pow(2, self.formant_shift / 12)
        return_length2 = int(np.ceil(return_length * factor))

        cache_pitch = cache_pitchf = None
        pitch = pitchf = None
        if isinstance(f0method, tuple):
            # NOTE(review): in this branch the synthesizer still receives
            # pitch=None / pitchf=None because the caches are not filled —
            # confirm that this is intended.
            pitch, pitchf = f0method
            pitch = torch.tensor(pitch, device=self.device).unsqueeze(0).long()
            pitchf = torch.tensor(pitchf, device=self.device).unsqueeze(0).float()
        elif self.if_f0 == 1:
            f0_extractor_frame = block_frame_16k + 800
            if f0method == "rmvpe":
                # Round up to a multiple of 5120 samples, minus one hop.
                f0_extractor_frame = (
                    5120 * ((f0_extractor_frame - 1) // 5120 + 1) - self.window
                )
            pitch, pitchf = self._get_f0(
                input_wav[-f0_extractor_frame:],
                self.f0_up_key - self.formant_shift,
                method=f0method,
            )
            # Slide the caches left by one block and write the fresh values
            # (minus 3 leading and 1 trailing frame) at the end.
            shift = block_frame_16k // self.window
            self.cache_pitch[:-shift] = self.cache_pitch[shift:].clone()
            self.cache_pitchf[:-shift] = self.cache_pitchf[shift:].clone()
            self.cache_pitch[4 - pitch.shape[0] :] = pitch[3:-1]
            self.cache_pitchf[4 - pitch.shape[0] :] = pitchf[3:-1]
            cache_pitch = self.cache_pitch[None, -p_len:]
            cache_pitchf = (
                self.cache_pitchf[None, -p_len:] * return_length2 / return_length
            )

        # Double the frame count of the features, then trim to p_len frames.
        feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
        feats = feats[:, :p_len, :]

        # ``feats0`` is only set for f0 models with protect < 0.5, so checking
        # it covers both conditions and avoids reading an unassigned name.
        if feats0 is not None and pitch is not None and pitchf is not None:
            feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(
                0, 2, 1
            )
            feats0 = feats0[:, :p_len, :]
            pitchff = pitchf.clone()
            pitchff[pitchf > 0] = 1
            pitchff[pitchf < 1] = protect
            pitchff = pitchff.unsqueeze(-1)
            feats = feats * pitchff + feats0 * (1 - pitchff)
            feats = feats.to(feats0.dtype)

        p_len = torch.tensor([p_len], dtype=torch.long, device=self.device)
        sid = torch.tensor([0], dtype=torch.long, device=self.device)
        with torch.no_grad():
            infered_audio = (
                self.net_g.infer(
                    feats,
                    p_len,
                    sid,
                    pitch=cache_pitch,
                    pitchf=cache_pitchf,
                    skip_head=skip_head,
                    return_length=return_length,
                    return_length2=return_length2,
                )
                .squeeze(1)
                .float()
            )

        # With a formant shift the output is resampled from ``upp_res`` back
        # to ``tgt_sr // 100``; the Resample module is cached per source rate.
        upp_res = int(np.floor(factor * self.tgt_sr // 100))
        if upp_res != self.tgt_sr // 100:
            if upp_res not in self.resample_kernel:
                self.resample_kernel[upp_res] = Resample(
                    orig_freq=upp_res,
                    new_freq=self.tgt_sr // 100,
                    dtype=torch.float32,
                ).to(self.device)
            infered_audio = self.resample_kernel[upp_res](
                infered_audio[:, : return_length * upp_res]
            )
        return infered_audio.squeeze()

    def _get_f0(
        self,
        x: torch.Tensor,
        f0_up_key: Union[int, float],
        filter_radius: Optional[Union[int, float]] = None,
        method: Literal["crepe", "rmvpe", "fcpe", "pm", "harvest", "dio"] = "fcpe",
    ):
        """Run the f0 generator and return its two outputs as device tensors.

        Returns:
            A ``(coarse, f0)`` pair: the first output as ``long`` and the
            second as ``float32``, both on ``self.device``.  NumPy results are
            converted to tensors first.
        """
        coarse, f0 = self.f0_gen.calculate(x, None, f0_up_key, method, filter_radius)
        if not torch.is_tensor(coarse):
            coarse = torch.from_numpy(coarse)
        if not torch.is_tensor(f0):
            f0 = torch.from_numpy(f0)
        return coarse.long().to(self.device), f0.float().to(self.device)