1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-05 01:10:22 +08:00
Files
Retrieval-based-Voice-Conve…/gui.py
github-actions[bot] 3b4a546ced chore(format): run black on dev (#121)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-06-19 17:56:04 +09:00

1198 lines
53 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import sys
from dotenv import load_dotenv
import shutil
load_dotenv()
load_dotenv("sha256.env")
os.environ["OMP_NUM_THREADS"] = "4"
if sys.platform == "darwin":
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
now_dir = os.getcwd()
sys.path.append(now_dir)
import multiprocessing
flag_vc = False
def printt(strr, *args):
if len(args) == 0:
print(strr)
else:
print(strr % args)
def phase_vocoder(a, b, fade_out, fade_in):
window = torch.sqrt(fade_out * fade_in)
fa = torch.fft.rfft(a * window)
fb = torch.fft.rfft(b * window)
absab = torch.abs(fa) + torch.abs(fb)
n = a.shape[0]
if n % 2 == 0:
absab[1:-1] *= 2
else:
absab[1:] *= 2
phia = torch.angle(fa)
phib = torch.angle(fb)
deltaphase = phib - phia
deltaphase = deltaphase - 2 * np.pi * torch.floor(deltaphase / 2 / np.pi + 0.5)
w = 2 * np.pi * torch.arange(n // 2 + 1).to(a) + deltaphase
t = torch.arange(n).unsqueeze(-1).to(a) / n
result = (
a * (fade_out**2)
+ b * (fade_in**2)
+ torch.sum(absab * torch.cos(w * t + phia), -1) * window / n
)
return result
class Harvest(multiprocessing.Process):
def __init__(self, inp_q, opt_q):
multiprocessing.Process.__init__(self)
self.inp_q = inp_q
self.opt_q = opt_q
def run(self):
import numpy as np
import pyworld
while 1:
idx, x, res_f0, n_cpu, ts = self.inp_q.get()
f0, t = pyworld.harvest(
x.astype(np.double),
fs=16000,
f0_ceil=1100,
f0_floor=50,
frame_period=10,
)
res_f0[idx] = f0
if len(res_f0.keys()) >= n_cpu:
self.opt_q.put(ts)
if __name__ == "__main__":
import json
import multiprocessing
import re
import threading
import time
from multiprocessing import Queue, cpu_count
from infer.lib.audio import AudioIoProcess
from multiprocessing.shared_memory import SharedMemory
import librosa
from infer.modules.gui import TorchGate
import numpy as np
import FreeSimpleGUI as sg
import sounddevice as sd
import torch
import torch.nn.functional as F
import torchaudio.transforms as tat
import infer.lib.rtrvc as rtrvc
from i18n.i18n import I18nAuto
from configs import Config
i18n = I18nAuto()
# device = rvc_for_realtime.config.device
# device = torch.device(
# "cuda"
# if torch.cuda.is_available()
# else ("mps" if torch.backends.mps.is_available() else "cpu")
# )
current_dir = os.getcwd()
inp_q = Queue()
opt_q = Queue()
n_cpu = min(cpu_count(), 8)
for _ in range(n_cpu):
p = Harvest(inp_q, opt_q)
p.daemon = True
p.start()
class GUIConfig:
def __init__(self) -> None:
self.pth_path: str = ""
self.index_path: str = ""
self.pitch: int = 0
self.formant: float = 0.0
self.sr_type: str = "sr_model"
self.block_time: float = 0.25 # s
self.threhold: int = -60
self.crossfade_time: float = 0.05
self.extra_time: float = 2.5
self.I_noise_reduce: bool = False
self.O_noise_reduce: bool = False
self.use_pv: bool = False
self.rms_mix_rate: float = 0.0
self.index_rate: float = 0.0
self.n_cpu: int = min(n_cpu, 4)
self.f0method: str = "fcpe"
self.sg_hostapi: str = ""
self.wasapi_exclusive: bool = False
self.sg_input_device: str = ""
self.sg_output_device: str = ""
class GUI:
def __init__(self) -> None:
self.gui_config = GUIConfig()
self.config = Config()
self.function = "vc"
self.delay_time = 0
self.hostapis = None
self.input_devices = None
self.output_devices = None
self.input_devices_indices = None
self.output_devices_indices = None
self.stream = None
self.in_mem = None
self.out_mem = None
self.in_buf = None
self.out_buf = None
self.in_ptr = None
self.out_ptr = None
self.play_ptr = None
self.in_evt = None
self.stop_evt = None
self.update_devices()
self.launcher()
def check_assets(self):
global now_dir
from infer.lib.rvcmd import check_all_assets, download_all_assets
tmp = os.path.join(now_dir, "TEMP")
shutil.rmtree(tmp, ignore_errors=True)
os.makedirs(tmp, exist_ok=True)
if not check_all_assets(update=self.config.update):
if self.config.update:
download_all_assets(tmpdir=tmp)
if not check_all_assets(update=self.config.update):
printt("counld not satisfy all assets needed.")
exit(1)
def load(self):
try:
if not os.path.exists("configs/inuse/config.json"):
shutil.copy("configs/config.json", "configs/inuse/config.json")
with open("configs/inuse/config.json", "r") as j:
data = json.load(j)
data["sr_model"] = data["sr_type"] == "sr_model"
data["sr_device"] = data["sr_type"] == "sr_device"
data["pm"] = data["f0method"] == "pm"
data["dio"] = data["f0method"] == "dio"
data["harvest"] = data["f0method"] == "harvest"
data["crepe"] = data["f0method"] == "crepe"
data["rmvpe"] = data["f0method"] == "rmvpe"
data["fcpe"] = data["f0method"] == "fcpe"
if data["sg_hostapi"] in self.hostapis:
self.update_devices(hostapi_name=data["sg_hostapi"])
if (
data["sg_input_device"] not in self.input_devices
or data["sg_output_device"] not in self.output_devices
):
self.update_devices()
data["sg_hostapi"] = self.hostapis[0]
data["sg_input_device"] = self.input_devices[
self.input_devices_indices.index(sd.default.device[0])
]
data["sg_output_device"] = self.output_devices[
self.output_devices_indices.index(sd.default.device[1])
]
else:
data["sg_hostapi"] = self.hostapis[0]
data["sg_input_device"] = self.input_devices[
self.input_devices_indices.index(sd.default.device[0])
]
data["sg_output_device"] = self.output_devices[
self.output_devices_indices.index(sd.default.device[1])
]
except:
with open("configs/inuse/config.json", "w") as j:
data = {
"pth_path": "",
"index_path": "",
"sg_hostapi": self.hostapis[0],
"sg_wasapi_exclusive": False,
"sg_input_device": self.input_devices[
self.input_devices_indices.index(sd.default.device[0])
],
"sg_output_device": self.output_devices[
self.output_devices_indices.index(sd.default.device[1])
],
"sr_type": "sr_model",
"threhold": -60,
"pitch": 0,
"formant": 0.0,
"index_rate": 0,
"rms_mix_rate": 0,
"block_time": 0.25,
"crossfade_length": 0.05,
"extra_time": 2.5,
"n_cpu": 4,
"f0method": "rmvpe",
"use_jit": False,
"use_pv": False,
}
data["sr_model"] = data["sr_type"] == "sr_model"
data["sr_device"] = data["sr_type"] == "sr_device"
data["pm"] = data["f0method"] == "pm"
data["dio"] = data["f0method"] == "dio"
data["harvest"] = data["f0method"] == "harvest"
data["crepe"] = data["f0method"] == "crepe"
data["rmvpe"] = data["f0method"] == "rmvpe"
data["fcpe"] = data["f0method"] == "fcpe"
return data
def launcher(self):
data = self.load()
self.config.use_jit = False # data.get("use_jit", self.config.use_jit)
sg.theme("LightBlue3")
layout = [
[
sg.Frame(
title=i18n("Load model"),
layout=[
[
sg.Input(
default_text=data.get("pth_path", ""),
key="pth_path",
),
sg.FileBrowse(
i18n("Select the .pth file"),
initial_folder=os.path.join(
os.getcwd(), "assets", "weights"
),
file_types=[("Model File", "*.pth")],
),
],
[
sg.Input(
default_text=data.get("index_path", ""),
key="index_path",
),
sg.FileBrowse(
i18n("Select the .index file"),
initial_folder=os.path.join(os.getcwd(), "logs"),
file_types=[("Index File", "*.index")],
),
],
],
)
],
[
sg.Frame(
layout=[
[
sg.Text(i18n("Device type")),
sg.Combo(
self.hostapis,
key="sg_hostapi",
default_value=data.get("sg_hostapi", ""),
enable_events=True,
size=(20, 1),
),
sg.Checkbox(
i18n("Takeover WASAPI device"),
key="sg_wasapi_exclusive",
default=data.get("sg_wasapi_exclusive", False),
enable_events=True,
),
],
[
sg.Text(i18n("Input device")),
sg.Combo(
self.input_devices,
key="sg_input_device",
default_value=data.get("sg_input_device", ""),
enable_events=True,
size=(45, 1),
),
],
[
sg.Text(i18n("Output device")),
sg.Combo(
self.output_devices,
key="sg_output_device",
default_value=data.get("sg_output_device", ""),
enable_events=True,
size=(45, 1),
),
],
[
sg.Button(
i18n("Reload device list"), key="reload_devices"
),
sg.Radio(
i18n("Choose sample rate of the model"),
"sr_type",
key="sr_model",
default=data.get("sr_model", True),
enable_events=True,
),
sg.Radio(
i18n("Choose sample rate of the device"),
"sr_type",
key="sr_device",
default=data.get("sr_device", False),
enable_events=True,
),
sg.Text(i18n("Sampling rate")),
sg.Text("", key="sr_stream"),
],
],
title=i18n("Audio device"),
)
],
[
sg.Frame(
layout=[
[
sg.Text(i18n("Response threshold")),
sg.Slider(
range=(-60, 0),
key="threhold",
resolution=1,
orientation="h",
default_value=data.get("threhold", -60),
enable_events=True,
),
],
[
sg.Text(i18n("Pitch settings")),
sg.Slider(
range=(-24, 24),
key="pitch",
resolution=1,
orientation="h",
default_value=data.get("pitch", 0),
enable_events=True,
),
],
[
sg.Text(i18n("Formant offset")),
sg.Slider(
range=(-5, 5),
key="formant",
resolution=0.01,
orientation="h",
default_value=data.get("formant", 0.0),
enable_events=True,
),
],
[
sg.Text(i18n("Feature searching ratio")),
sg.Slider(
range=(0.0, 1.0),
key="index_rate",
resolution=0.01,
orientation="h",
default_value=data.get("index_rate", 0),
enable_events=True,
),
],
[
sg.Text(i18n("Loudness factor")),
sg.Slider(
range=(0.0, 1.0),
key="rms_mix_rate",
resolution=0.01,
orientation="h",
default_value=data.get("rms_mix_rate", 0),
enable_events=True,
),
],
[
sg.Text(i18n("Pitch detection algorithm")),
sg.Radio(
"pm",
"f0method",
key="pm",
default=data.get("pm", False),
enable_events=True,
),
sg.Radio(
"dio",
"f0method",
key="dio",
default=data.get("dio", False),
enable_events=True,
),
sg.Radio(
"harvest",
"f0method",
key="harvest",
default=data.get("harvest", False),
enable_events=True,
),
sg.Radio(
"crepe",
"f0method",
key="crepe",
default=data.get("crepe", False),
enable_events=True,
),
sg.Radio(
"rmvpe",
"f0method",
key="rmvpe",
default=data.get("rmvpe", False),
enable_events=True,
),
sg.Radio(
"fcpe",
"f0method",
key="fcpe",
default=data.get("fcpe", True),
enable_events=True,
),
],
],
title=i18n("General settings"),
),
sg.Frame(
layout=[
[
sg.Text(i18n("Sample length")),
sg.Slider(
range=(0.02, 1.5),
key="block_time",
resolution=0.01,
orientation="h",
default_value=data.get("block_time", 0.25),
enable_events=True,
),
],
# [
# sg.Text("设备延迟"),
# sg.Slider(
# range=(0, 1),
# key="device_latency",
# resolution=0.001,
# orientation="h",
# default_value=data.get("device_latency", 0.1),
# enable_events=True,
# ),
# ],
[
sg.Text(
i18n(
"Number of CPU processes used for harvest pitch algorithm"
)
),
sg.Slider(
range=(1, n_cpu),
key="n_cpu",
resolution=1,
orientation="h",
default_value=data.get(
"n_cpu", min(self.gui_config.n_cpu, n_cpu)
),
enable_events=True,
),
],
[
sg.Text(i18n("Fade length")),
sg.Slider(
range=(0.01, 0.15),
key="crossfade_length",
resolution=0.01,
orientation="h",
default_value=data.get("crossfade_length", 0.05),
enable_events=True,
),
],
[
sg.Text(i18n("Extra inference time")),
sg.Slider(
range=(0.05, 5.00),
key="extra_time",
resolution=0.01,
orientation="h",
default_value=data.get("extra_time", 2.5),
enable_events=True,
),
],
[
sg.Checkbox(
i18n("Input noise reduction"),
key="I_noise_reduce",
enable_events=True,
),
sg.Checkbox(
i18n("Output noise reduction"),
key="O_noise_reduce",
enable_events=True,
),
sg.Checkbox(
i18n("Enable phase vocoder"),
key="use_pv",
default=data.get("use_pv", False),
enable_events=True,
),
# sg.Checkbox(
# "JIT加速",
# default=self.config.use_jit,
# key="use_jit",
# enable_events=False,
# ),
],
# [sg.Text("注首次使用JIT加速时会出现卡顿\n 并伴随一些噪音,但这是正常现象!")],
],
title=i18n("Performance settings"),
),
],
[
sg.Button(i18n("Start audio conversion"), key="start_vc"),
sg.Button(i18n("Stop audio conversion"), key="stop_vc"),
sg.Radio(
i18n("Input voice monitor"),
"function",
key="im",
default=False,
enable_events=True,
),
sg.Radio(
i18n("Output converted voice"),
"function",
key="vc",
default=True,
enable_events=True,
),
sg.Text(i18n("Algorithmic delays (ms)")),
sg.Text("0", key="delay_time"),
sg.Text(i18n("Inference time (ms)")),
sg.Text("0", key="infer_time"),
],
]
self.window = sg.Window("RVC - GUI", layout=layout, finalize=True)
self.event_handler()
def event_handler(self):
global flag_vc
while True:
event, values = self.window.read()
if event == sg.WINDOW_CLOSED:
self.stop_stream()
# exit()
if event == "reload_devices" or event == "sg_hostapi":
self.gui_config.sg_hostapi = values["sg_hostapi"]
self.update_devices(hostapi_name=values["sg_hostapi"])
if self.gui_config.sg_hostapi not in self.hostapis:
self.gui_config.sg_hostapi = self.hostapis[0]
self.window["sg_hostapi"].Update(values=self.hostapis)
self.window["sg_hostapi"].Update(value=self.gui_config.sg_hostapi)
if (
self.gui_config.sg_input_device not in self.input_devices
and len(self.input_devices) > 0
):
self.gui_config.sg_input_device = self.input_devices[0]
self.window["sg_input_device"].Update(values=self.input_devices)
self.window["sg_input_device"].Update(
value=self.gui_config.sg_input_device
)
if self.gui_config.sg_output_device not in self.output_devices:
self.gui_config.sg_output_device = self.output_devices[0]
self.window["sg_output_device"].Update(values=self.output_devices)
self.window["sg_output_device"].Update(
value=self.gui_config.sg_output_device
)
if event == "start_vc" and not flag_vc:
if self.set_values(values) == True:
printt("cuda_is_available: %s", torch.cuda.is_available())
self.start_vc()
settings = {
"pth_path": values["pth_path"],
"index_path": values["index_path"],
"sg_hostapi": values["sg_hostapi"],
"sg_wasapi_exclusive": values["sg_wasapi_exclusive"],
"sg_input_device": values["sg_input_device"],
"sg_output_device": values["sg_output_device"],
"sr_type": ["sr_model", "sr_device"][
[
values["sr_model"],
values["sr_device"],
].index(True)
],
"threhold": values["threhold"],
"pitch": values["pitch"],
"formant": values["formant"],
"rms_mix_rate": values["rms_mix_rate"],
"index_rate": values["index_rate"],
# "device_latency": values["device_latency"],
"block_time": values["block_time"],
"crossfade_length": values["crossfade_length"],
"extra_time": values["extra_time"],
"n_cpu": values["n_cpu"],
# "use_jit": values["use_jit"],
"use_jit": False,
"use_pv": values["use_pv"],
"f0method": [
"pm",
"dio",
"harvest",
"crepe",
"rmvpe",
"fcpe",
][
[
values["pm"],
values["dio"],
values["harvest"],
values["crepe"],
values["rmvpe"],
values["fcpe"],
].index(True)
],
}
with open("configs/inuse/config.json", "w") as j:
json.dump(settings, j)
if self.stream is not None:
self.delay_time = (
self.stream.get_latency()
+ values["block_time"]
+ values["crossfade_length"]
+ 0.01
)
if values["I_noise_reduce"]:
self.delay_time += min(values["crossfade_length"], 0.04)
self.window["sr_stream"].update(self.gui_config.samplerate)
self.window["delay_time"].update(
int(np.round(self.delay_time * 1000))
)
# Parameter hot update
if event == "threhold":
self.gui_config.threhold = values["threhold"]
elif event == "pitch":
self.gui_config.pitch = values["pitch"]
if hasattr(self, "rvc"):
self.rvc.set_key(values["pitch"])
elif event == "formant":
self.gui_config.formant = values["formant"]
if hasattr(self, "rvc"):
self.rvc.set_formant(values["formant"])
elif event == "index_rate":
self.gui_config.index_rate = values["index_rate"]
if hasattr(self, "rvc"):
self.rvc.set_index_rate(values["index_rate"])
elif event == "rms_mix_rate":
self.gui_config.rms_mix_rate = values["rms_mix_rate"]
elif event in ["pm", "harvest", "crepe", "rmvpe", "fcpe"]:
self.gui_config.f0method = event
elif event == "I_noise_reduce":
self.gui_config.I_noise_reduce = values["I_noise_reduce"]
if self.stream is not None:
self.delay_time += (
1 if values["I_noise_reduce"] else -1
) * min(values["crossfade_length"], 0.04)
self.window["delay_time"].update(
int(np.round(self.delay_time * 1000))
)
elif event == "O_noise_reduce":
self.gui_config.O_noise_reduce = values["O_noise_reduce"]
elif event == "use_pv":
self.gui_config.use_pv = values["use_pv"]
elif event in ["vc", "im"]:
self.function = event
elif event == "stop_vc" or event != "start_vc":
# Other parameters do not support hot update
self.stop_stream()
def set_values(self, values):
if len(values["pth_path"].strip()) == 0:
sg.popup(i18n("Please choose the .pth file"))
return False
if len(values["index_path"].strip()) == 0:
sg.popup(i18n("Please choose the .index file"))
return False
pattern = re.compile("[^\x00-\x7f]+")
if pattern.findall(values["pth_path"]):
sg.popup(i18n("pth path cannot contain unicode characters"))
return False
if pattern.findall(values["index_path"]):
sg.popup(i18n("index path cannot contain unicode characters"))
return False
self.set_devices(values["sg_input_device"], values["sg_output_device"])
self.config.use_jit = False # values["use_jit"]
# self.device_latency = values["device_latency"]
self.gui_config.sg_hostapi = values["sg_hostapi"]
self.gui_config.sg_wasapi_exclusive = values["sg_wasapi_exclusive"]
self.gui_config.sg_input_device = values["sg_input_device"]
self.gui_config.sg_output_device = values["sg_output_device"]
self.gui_config.pth_path = values["pth_path"]
self.gui_config.index_path = values["index_path"]
self.gui_config.sr_type = ["sr_model", "sr_device"][
[
values["sr_model"],
values["sr_device"],
].index(True)
]
self.gui_config.threhold = values["threhold"]
self.gui_config.pitch = values["pitch"]
self.gui_config.formant = values["formant"]
self.gui_config.block_time = values["block_time"]
self.gui_config.crossfade_time = values["crossfade_length"]
self.gui_config.extra_time = values["extra_time"]
self.gui_config.I_noise_reduce = values["I_noise_reduce"]
self.gui_config.O_noise_reduce = values["O_noise_reduce"]
self.gui_config.use_pv = values["use_pv"]
self.gui_config.rms_mix_rate = values["rms_mix_rate"]
self.gui_config.index_rate = values["index_rate"]
self.gui_config.n_cpu = values["n_cpu"]
self.gui_config.f0method = [
"pm",
"dio",
"harvest",
"crepe",
"rmvpe",
"fcpe",
][
[
values["pm"],
values["dio"],
values["harvest"],
values["crepe"],
values["rmvpe"],
values["fcpe"],
].index(True)
]
return True
def start_vc(self):
torch.cuda.empty_cache()
self.rvc = rtrvc.RVC(
self.gui_config.pitch,
self.gui_config.formant,
self.gui_config.pth_path,
self.gui_config.index_path,
self.gui_config.index_rate,
self.gui_config.n_cpu,
self.config.device,
self.config.use_jit,
self.config.is_half,
self.config.dml,
)
self.gui_config.samplerate = (
self.rvc.tgt_sr
if self.gui_config.sr_type == "sr_model"
else self.get_device_samplerate()
)
self.gui_config.channels = self.get_device_channels()
self.zc = self.gui_config.samplerate // 100
self.block_frame = (
int(
np.round(
self.gui_config.block_time
* self.gui_config.samplerate
/ self.zc
)
)
* self.zc
)
self.block_frame_16k = 160 * self.block_frame // self.zc
self.crossfade_frame = (
int(
np.round(
self.gui_config.crossfade_time
* self.gui_config.samplerate
/ self.zc
)
)
* self.zc
)
self.sola_buffer_frame = min(self.crossfade_frame, 4 * self.zc)
self.sola_search_frame = self.zc
self.extra_frame = (
int(
np.round(
self.gui_config.extra_time
* self.gui_config.samplerate
/ self.zc
)
)
* self.zc
)
self.input_wav: torch.Tensor = torch.zeros(
self.extra_frame
+ self.crossfade_frame
+ self.sola_search_frame
+ self.block_frame,
device=self.config.device,
dtype=torch.float32,
)
self.input_wav_denoise: torch.Tensor = self.input_wav.clone()
self.input_wav_res: torch.Tensor = torch.zeros(
160 * self.input_wav.shape[0] // self.zc,
device=self.config.device,
dtype=torch.float32,
)
self.rms_buffer: np.ndarray = np.zeros(4 * self.zc, dtype="float32")
self.sola_buffer: torch.Tensor = torch.zeros(
self.sola_buffer_frame, device=self.config.device, dtype=torch.float32
)
self.nr_buffer: torch.Tensor = self.sola_buffer.clone()
self.output_buffer: torch.Tensor = self.input_wav.clone()
self.skip_head = self.extra_frame // self.zc
self.return_length = (
self.block_frame + self.sola_buffer_frame + self.sola_search_frame
) // self.zc
self.fade_in_window: torch.Tensor = (
torch.sin(
0.5
* np.pi
* torch.linspace(
0.0,
1.0,
steps=self.sola_buffer_frame,
device=self.config.device,
dtype=torch.float32,
)
)
** 2
)
self.fade_out_window: torch.Tensor = 1 - self.fade_in_window
self.resampler = tat.Resample(
orig_freq=self.gui_config.samplerate,
new_freq=16000,
dtype=torch.float32,
).to(self.config.device)
if self.rvc.tgt_sr != self.gui_config.samplerate:
self.resampler2 = tat.Resample(
orig_freq=self.rvc.tgt_sr,
new_freq=self.gui_config.samplerate,
dtype=torch.float32,
).to(self.config.device)
else:
self.resampler2 = None
self.tg = TorchGate(
sr=self.gui_config.samplerate, n_fft=4 * self.zc, prop_decrease=0.9
).to(self.config.device)
self.start_stream()
def start_stream(self):
global flag_vc
if not flag_vc:
flag_vc = True
if (
"WASAPI" in self.gui_config.sg_hostapi
and self.gui_config.sg_wasapi_exclusive
):
wasapi_exclusive = True
else:
wasapi_exclusive = False
self.stream = AudioIoProcess(
input_device=sd.default.device[0],
output_device=sd.default.device[1],
input_audio_block_size=self.block_frame,
sample_rate=self.gui_config.samplerate,
channel_num=self.gui_config.channels,
is_input_wasapi_exclusive=wasapi_exclusive,
is_output_wasapi_exclusive=wasapi_exclusive,
is_device_combined=True,
# TODO: Add control UI to allow devices with different type API & different WASAPI settings
)
self.in_mem = SharedMemory(name=self.stream.get_in_mem_name())
self.out_mem = SharedMemory(name=self.stream.get_out_mem_name())
self.in_buf = np.ndarray(
self.stream.get_np_shape(),
dtype=self.stream.get_np_dtype(),
buffer=self.in_mem.buf,
order="C",
)
self.out_buf = np.ndarray(
self.stream.get_np_shape(),
dtype=self.stream.get_np_dtype(),
buffer=self.out_mem.buf,
order="C",
)
self.in_ptr, self.out_ptr, self.play_ptr, self.in_evt, self.stop_evt = (
self.stream.get_ptrs_and_events()
)
self.stream.start()
def audio_loop():
while flag_vc:
self.audio_infer(self.block_frame << 1)
threading.Thread(target=audio_loop, daemon=True).start()
def stop_stream(self):
global flag_vc
if flag_vc:
flag_vc = False
if self.stream is not None:
print("Exiting")
self.stop_evt.set()
self.in_mem.close()
self.out_mem.close()
self.stream.join()
self.stream = None
def audio_infer(self, buf_size: int): # 2 * self.block_frame
"""
音频处理
"""
global flag_vc
self.in_evt.wait()
rptr = self.in_ptr.value
self.in_evt.clear()
start_time = time.perf_counter()
rend = rptr + self.block_frame
indata = np.copy(self.in_buf[rptr:rend])
indata = librosa.to_mono(indata.T)
if self.gui_config.threhold > -60:
indata = np.append(self.rms_buffer, indata)
rms = librosa.feature.rms(
y=indata, frame_length=4 * self.zc, hop_length=self.zc
)[:, 2:]
self.rms_buffer[:] = indata[-4 * self.zc :]
indata = indata[2 * self.zc - self.zc // 2 :]
db_threhold = (
librosa.amplitude_to_db(rms, ref=1.0)[0] < self.gui_config.threhold
)
for i in range(db_threhold.shape[0]):
if db_threhold[i]:
indata[i * self.zc : (i + 1) * self.zc] = 0
indata = indata[self.zc // 2 :]
self.input_wav[: -self.block_frame] = self.input_wav[
self.block_frame :
].clone()
self.input_wav[-indata.shape[0] :] = torch.from_numpy(indata).to(
self.config.device
)
self.input_wav_res[: -self.block_frame_16k] = self.input_wav_res[
self.block_frame_16k :
].clone()
# input noise reduction and resampling
if self.gui_config.I_noise_reduce:
self.input_wav_denoise[: -self.block_frame] = self.input_wav_denoise[
self.block_frame :
].clone()
input_wav = self.input_wav[-self.sola_buffer_frame - self.block_frame :]
input_wav = self.tg(
input_wav.unsqueeze(0), self.input_wav.unsqueeze(0)
).squeeze(0)
input_wav[: self.sola_buffer_frame] *= self.fade_in_window
input_wav[: self.sola_buffer_frame] += (
self.nr_buffer * self.fade_out_window
)
self.input_wav_denoise[-self.block_frame :] = input_wav[
: self.block_frame
]
self.nr_buffer[:] = input_wav[self.block_frame :]
self.input_wav_res[-self.block_frame_16k - 160 :] = self.resampler(
self.input_wav_denoise[-self.block_frame - 2 * self.zc :]
)[160:]
else:
self.input_wav_res[-160 * (indata.shape[0] // self.zc + 1) :] = (
self.resampler(self.input_wav[-indata.shape[0] - 2 * self.zc :])[
160:
]
)
# infer
if self.function == "vc":
infer_wav = self.rvc.infer(
self.input_wav_res,
self.block_frame_16k,
self.skip_head,
self.return_length,
self.gui_config.f0method,
)
if self.resampler2 is not None:
infer_wav = self.resampler2(infer_wav)
elif self.gui_config.I_noise_reduce:
infer_wav = self.input_wav_denoise[self.extra_frame :].clone()
else:
infer_wav = self.input_wav[self.extra_frame :].clone()
# output noise reduction
if self.gui_config.O_noise_reduce and self.function == "vc":
self.output_buffer[: -self.block_frame] = self.output_buffer[
self.block_frame :
].clone()
self.output_buffer[-self.block_frame :] = infer_wav[-self.block_frame :]
infer_wav = self.tg(
infer_wav.unsqueeze(0), self.output_buffer.unsqueeze(0)
).squeeze(0)
# volume envelop mixing
if self.gui_config.rms_mix_rate < 1 and self.function == "vc":
if self.gui_config.I_noise_reduce:
input_wav = self.input_wav_denoise[self.extra_frame :]
else:
input_wav = self.input_wav[self.extra_frame :]
rms1 = librosa.feature.rms(
y=input_wav[: infer_wav.shape[0]].cpu().numpy(),
frame_length=4 * self.zc,
hop_length=self.zc,
)
rms1 = torch.from_numpy(rms1).to(self.config.device)
rms1 = F.interpolate(
rms1.unsqueeze(0),
size=infer_wav.shape[0] + 1,
mode="linear",
align_corners=True,
)[0, 0, :-1]
rms2 = librosa.feature.rms(
y=infer_wav[:].cpu().numpy(),
frame_length=4 * self.zc,
hop_length=self.zc,
)
rms2 = torch.from_numpy(rms2).to(self.config.device)
rms2 = F.interpolate(
rms2.unsqueeze(0),
size=infer_wav.shape[0] + 1,
mode="linear",
align_corners=True,
)[0, 0, :-1]
rms2 = torch.max(rms2, torch.zeros_like(rms2) + 1e-3)
infer_wav *= torch.pow(
rms1 / rms2, torch.tensor(1 - self.gui_config.rms_mix_rate)
)
# SOLA algorithm from https://github.com/yxlllc/DDSP-SVC
conv_input = infer_wav[
None, None, : self.sola_buffer_frame + self.sola_search_frame
]
cor_nom = F.conv1d(conv_input, self.sola_buffer[None, None, :])
cor_den = torch.sqrt(
F.conv1d(
conv_input**2,
torch.ones(1, 1, self.sola_buffer_frame, device=self.config.device),
)
+ 1e-8
)
if sys.platform == "darwin":
_, sola_offset = torch.max(cor_nom[0, 0] / cor_den[0, 0])
sola_offset = sola_offset.item()
else:
sola_offset = torch.argmax(cor_nom[0, 0] / cor_den[0, 0])
# printt("sola_offset = %d", int(sola_offset))
infer_wav = infer_wav[sola_offset:]
if "privateuseone" in str(self.config.device) or not self.gui_config.use_pv:
infer_wav[: self.sola_buffer_frame] *= self.fade_in_window
infer_wav[: self.sola_buffer_frame] += (
self.sola_buffer * self.fade_out_window
)
else:
infer_wav[: self.sola_buffer_frame] = phase_vocoder(
self.sola_buffer,
infer_wav[: self.sola_buffer_frame],
self.fade_out_window,
self.fade_in_window,
)
self.sola_buffer[:] = infer_wav[
self.block_frame : self.block_frame + self.sola_buffer_frame
]
outdata = (
infer_wav[: self.block_frame]
.repeat(self.gui_config.channels, 1)
.t()
.cpu()
.numpy()
)
# 装填输出缓冲
start = self.out_ptr.value
play_pos = self.play_ptr.value
# 计算播放进度差(写指针距离播放指针的帧数)
delta = (start - play_pos + buf_size) % buf_size
if delta < self.block_frame:
# 装填赶不上播放,导致播放进度追上来了,
# 此时已产生无法挽回的破音,
# 只好直接卡着播放指针写入,保证接下来的尽快放出来
print("[W] Output underrun")
write_pos = play_pos
else:
# 否则按块对齐
write_pos = (start + self.block_frame) % buf_size
# 写入共享缓冲区
end = (write_pos + self.block_frame) % buf_size
if end > write_pos:
self.out_buf[write_pos:end] = outdata
else:
first = buf_size - write_pos
self.out_buf[write_pos:] = outdata[:first]
self.out_buf[:end] = outdata[first:]
# 更新写指针
self.out_ptr.value = write_pos
if self.in_evt.is_set():
print("[W] Input overrun")
self.in_evt.clear()
total_time = time.perf_counter() - start_time
if flag_vc:
self.window["infer_time"].update(int(total_time * 1000))
# printt("Infer time: %.2f", total_time)
def update_devices(self, hostapi_name=None):
"""获取设备列表"""
global flag_vc
flag_vc = False
sd._terminate()
sd._initialize()
devices = sd.query_devices()
hostapis = sd.query_hostapis()
for hostapi in hostapis:
for device_idx in hostapi["devices"]:
devices[device_idx]["hostapi_name"] = hostapi["name"]
self.hostapis = [hostapi["name"] for hostapi in hostapis]
if hostapi_name not in self.hostapis:
hostapi_name = self.hostapis[0]
self.input_devices = [
d["name"]
for d in devices
if d["max_input_channels"] > 0 and d["hostapi_name"] == hostapi_name
]
self.output_devices = [
d["name"]
for d in devices
if d["max_output_channels"] > 0 and d["hostapi_name"] == hostapi_name
]
self.input_devices_indices = [
d["index"] if "index" in d else d["name"]
for d in devices
if d["max_input_channels"] > 0 and d["hostapi_name"] == hostapi_name
]
self.output_devices_indices = [
d["index"] if "index" in d else d["name"]
for d in devices
if d["max_output_channels"] > 0 and d["hostapi_name"] == hostapi_name
]
def set_devices(self, input_device, output_device):
"""设置输出设备"""
sd.default.device[0] = self.input_devices_indices[
self.input_devices.index(input_device)
]
sd.default.device[1] = self.output_devices_indices[
self.output_devices.index(output_device)
]
printt("Input device: %s:%s", str(sd.default.device[0]), input_device)
printt("Output device: %s:%s", str(sd.default.device[1]), output_device)
def get_device_samplerate(self):
return int(
sd.query_devices(device=sd.default.device[0])["default_samplerate"]
)
def get_device_channels(self):
max_input_channels = sd.query_devices(device=sd.default.device[0])[
"max_input_channels"
]
max_output_channels = sd.query_devices(device=sd.default.device[1])[
"max_output_channels"
]
return min(max_input_channels, max_output_channels, 2)
gui = GUI()