From 22715eab7cccd13014bb0bbfa82cfa99788a7045 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Wed, 12 Jun 2024 17:29:23 +0900 Subject: [PATCH] optimize(rmvpe): move mel&stft into rvc --- infer/lib/rmvpe.py | 222 ++------------------------------- infer/modules/gui/torchgate.py | 2 +- rvc/f0/mel.py | 71 +++++++++++ rvc/f0/stft.py | 194 ++++++++++++++++++++++++++++ 4 files changed, 276 insertions(+), 213 deletions(-) create mode 100644 rvc/f0/mel.py create mode 100644 rvc/f0/stft.py diff --git a/infer/lib/rmvpe.py b/infer/lib/rmvpe.py index ba5a09b..1f924c1 100644 --- a/infer/lib/rmvpe.py +++ b/infer/lib/rmvpe.py @@ -1,6 +1,6 @@ from io import BytesIO import os -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -25,136 +25,7 @@ import logging logger = logging.getLogger(__name__) - -class STFT(torch.nn.Module): - def __init__( - self, filter_length=1024, hop_length=512, win_length=None, window="hann" - ): - """ - This module implements an STFT using 1D convolution and 1D transpose convolutions. - This is a bit tricky so there are some cases that probably won't work as working - out the same sizes before and after in all overlap add setups is tough. Right now, - this code should work with hop lengths that are half the filter length (50% overlap - between frames). - - Keyword Arguments: - filter_length {int} -- Length of filters used (default: {1024}) - hop_length {int} -- Hop length of STFT (restrict to 50% overlap between frames) (default: {512}) - win_length {[type]} -- Length of the window function applied to each frame (if not specified, it - equals the filter length). (default: {None}) - window {str} -- Type of window to use (options are bartlett, hann, hamming, blackman, blackmanharris) - (default: {'hann'}) - """ - super(STFT, self).__init__() - self.filter_length = filter_length - self.hop_length = hop_length - self.win_length = win_length if win_length else filter_length - self.window = window - self.forward_transform = None - self.pad_amount = int(self.filter_length / 2) - fourier_basis = np.fft.fft(np.eye(self.filter_length)) - - cutoff = int((self.filter_length / 2 + 1)) - fourier_basis = np.vstack( - [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] - ) - forward_basis = torch.FloatTensor(fourier_basis) - inverse_basis = torch.FloatTensor(np.linalg.pinv(fourier_basis)) - - assert filter_length >= self.win_length - # get window and zero center pad it to filter_length - fft_window = get_window(window, self.win_length, fftbins=True) - fft_window = pad_center(fft_window, size=filter_length) - fft_window = torch.from_numpy(fft_window).float() - - # window the bases - forward_basis *= fft_window - inverse_basis = (inverse_basis.T * fft_window).T - - self.register_buffer("forward_basis", forward_basis.float()) - self.register_buffer("inverse_basis", inverse_basis.float()) - self.register_buffer("fft_window", fft_window.float()) - - def transform(self, input_data, return_phase=False): - """Take input data (audio) to STFT domain. - - Arguments: - input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples) - - Returns: - magnitude {tensor} -- Magnitude of STFT with shape (num_batch, - num_frequencies, num_frames) - phase {tensor} -- Phase of STFT with shape (num_batch, - num_frequencies, num_frames) - """ - input_data = F.pad( - input_data, - (self.pad_amount, self.pad_amount), - mode="reflect", - ) - forward_transform = input_data.unfold( - 1, self.filter_length, self.hop_length - ).permute(0, 2, 1) - forward_transform = torch.matmul(self.forward_basis, forward_transform) - cutoff = int((self.filter_length / 2) + 1) - real_part = forward_transform[:, :cutoff, :] - imag_part = forward_transform[:, cutoff:, :] - magnitude = torch.sqrt(real_part**2 + imag_part**2) - if return_phase: - phase = torch.atan2(imag_part.data, real_part.data) - return magnitude, phase - else: - return magnitude - - def inverse(self, magnitude, phase): - """Call the inverse STFT (iSTFT), given magnitude and phase tensors produced - by the ```transform``` function. - - Arguments: - magnitude {tensor} -- Magnitude of STFT with shape (num_batch, - num_frequencies, num_frames) - phase {tensor} -- Phase of STFT with shape (num_batch, - num_frequencies, num_frames) - - Returns: - inverse_transform {tensor} -- Reconstructed audio given magnitude and phase. Of - shape (num_batch, num_samples) - """ - cat = torch.cat( - [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 - ) - fold = torch.nn.Fold( - output_size=(1, (cat.size(-1) - 1) * self.hop_length + self.filter_length), - kernel_size=(1, self.filter_length), - stride=(1, self.hop_length), - ) - inverse_transform = torch.matmul(self.inverse_basis, cat) - inverse_transform = fold(inverse_transform)[ - :, 0, 0, self.pad_amount : -self.pad_amount - ] - window_square_sum = ( - self.fft_window.pow(2).repeat(cat.size(-1), 1).T.unsqueeze(0) - ) - window_square_sum = fold(window_square_sum)[ - :, 0, 0, self.pad_amount : -self.pad_amount - ] - inverse_transform /= window_square_sum - return inverse_transform - - def forward(self, input_data): - """Take input data (audio) to STFT domain and then back to audio. - - Arguments: - input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples) - - Returns: - reconstruction {tensor} -- Reconstructed audio given magnitude and phase. Of - shape (num_batch, num_samples) - """ - self.magnitude, self.phase = self.transform(input_data, return_phase=True) - reconstruction = self.inverse(self.magnitude, self.phase) - return reconstruction - +from rvc.f0.mel import MelSpectrogram from time import time as ttime @@ -412,86 +283,6 @@ class E2E(nn.Module): return x -from librosa.filters import mel - - -class MelSpectrogram(torch.nn.Module): - def __init__( - self, - is_half, - n_mel_channels, - sampling_rate, - win_length, - hop_length, - n_fft=None, - mel_fmin=0, - mel_fmax=None, - clamp=1e-5, - ): - super().__init__() - n_fft = win_length if n_fft is None else n_fft - self.hann_window = {} - mel_basis = mel( - sr=sampling_rate, - n_fft=n_fft, - n_mels=n_mel_channels, - fmin=mel_fmin, - fmax=mel_fmax, - htk=True, - ) - mel_basis = torch.from_numpy(mel_basis).float() - self.register_buffer("mel_basis", mel_basis) - self.n_fft = win_length if n_fft is None else n_fft - self.hop_length = hop_length - self.win_length = win_length - self.sampling_rate = sampling_rate - self.n_mel_channels = n_mel_channels - self.clamp = clamp - self.is_half = is_half - - def forward(self, audio, keyshift=0, speed=1, center=True): - factor = 2 ** (keyshift / 12) - n_fft_new = int(np.round(self.n_fft * factor)) - win_length_new = int(np.round(self.win_length * factor)) - hop_length_new = int(np.round(self.hop_length * speed)) - keyshift_key = str(keyshift) + "_" + str(audio.device) - if keyshift_key not in self.hann_window: - self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to( - audio.device - ) - if "privateuseone" in str(audio.device): - if not hasattr(self, "stft"): - self.stft = STFT( - filter_length=n_fft_new, - hop_length=hop_length_new, - win_length=win_length_new, - window="hann", - ).to(audio.device) - magnitude = self.stft.transform(audio) - else: - fft = torch.stft( - audio, - n_fft=n_fft_new, - hop_length=hop_length_new, - win_length=win_length_new, - window=self.hann_window[keyshift_key], - center=center, - return_complex=True, - ) - magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2)) - if keyshift != 0: - size = self.n_fft // 2 + 1 - resize = magnitude.size(1) - if resize < size: - magnitude = F.pad(magnitude, (0, 0, 0, size - resize)) - magnitude = magnitude[:, :size, :] * self.win_length / win_length_new - mel_output = torch.matmul(self.mel_basis, magnitude) - if self.is_half == True: - mel_output = mel_output.half() - log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp)) - return log_mel_spec - - class RMVPE: def __init__(self, model_path: str, is_half, device=None, use_jit=False): self.resample_kernel = {} @@ -501,7 +292,14 @@ class RMVPE: device = "cuda:0" if torch.cuda.is_available() else "cpu" self.device = device self.mel_extractor = MelSpectrogram( - is_half, 128, 16000, 1024, 160, None, 30, 8000 + is_half=is_half, + n_mel_channels=128, + sampling_rate=16000, + win_length=1024, + hop_length=160, + mel_fmin=30, + mel_fmax=8000, + device=device, ).to(device) if "privateuseone" in str(device): import onnxruntime as ort diff --git a/infer/modules/gui/torchgate.py b/infer/modules/gui/torchgate.py index e4b80c4..3111bfa 100644 --- a/infer/modules/gui/torchgate.py +++ b/infer/modules/gui/torchgate.py @@ -1,5 +1,5 @@ import torch -from infer.lib.rmvpe import STFT +from rvc.f0.stft import STFT from torch.nn.functional import conv1d, conv2d from typing import Union, Optional from .utils import linspace, temperature_sigmoid, amp_to_db diff --git a/rvc/f0/mel.py b/rvc/f0/mel.py new file mode 100644 index 0000000..439a258 --- /dev/null +++ b/rvc/f0/mel.py @@ -0,0 +1,71 @@ +from typing import Optional + +import torch +import numpy as np +from librosa.filters import mel + +from .stft import STFT + + +class MelSpectrogram(torch.nn.Module): + def __init__( + self, + is_half: bool, + n_mel_channels: int, + sampling_rate: int, + win_length: int, + hop_length: int, + n_fft: Optional[int] = None, + mel_fmin: int = 0, + mel_fmax: int = None, + clamp: float = 1e-5, + device = torch.device("cpu"), + ): + super().__init__() + if n_fft is None: + n_fft = win_length + mel_basis = mel( + sr=sampling_rate, + n_fft=n_fft, + n_mels=n_mel_channels, + fmin=mel_fmin, + fmax=mel_fmax, + htk=True, + ) + mel_basis = torch.from_numpy(mel_basis).float() + self.register_buffer("mel_basis", mel_basis) + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.clamp = clamp + self.is_half = is_half + + self.stft = STFT( + filter_length=n_fft, + hop_length=hop_length, + win_length=win_length, + window="hann", + use_torch_stft="privateuseone" not in str(device) + ).to(device) + + def forward( + self, + audio: torch.Tensor, + keyshift=0, + speed=1, + center=True, + ): + factor = 2 ** (keyshift / 12) + win_length_new = int(np.round(self.win_length * factor)) + magnitude = self.stft(audio, keyshift, speed, center) + if keyshift != 0: + size = self.n_fft // 2 + 1 + resize = magnitude.size(1) + if resize < size: + magnitude = torch.nn.functional.pad(magnitude, (0, 0, 0, size - resize)) + magnitude = magnitude[:, :size, :] * self.win_length / win_length_new + mel_output = torch.matmul(self.mel_basis, magnitude) + if self.is_half: + mel_output = mel_output.half() + log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp)) + return log_mel_spec diff --git a/rvc/f0/stft.py b/rvc/f0/stft.py new file mode 100644 index 0000000..262f3c7 --- /dev/null +++ b/rvc/f0/stft.py @@ -0,0 +1,194 @@ +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from librosa.util import pad_center +from scipy.signal import get_window + + +class STFT(torch.nn.Module): + def __init__( + self, + filter_length=1024, + hop_length=512, + win_length: Optional[int] = None, + window="hann", + use_torch_stft = True, + ): + """ + This module implements an STFT using 1D convolution and 1D transpose convolutions. + This is a bit tricky so there are some cases that probably won't work as working + out the same sizes before and after in all overlap add setups is tough. Right now, + this code should work with hop lengths that are half the filter length (50% overlap + between frames). + + Keyword Arguments: + filter_length {int} -- Length of filters used (default: {1024}) + hop_length {int} -- Hop length of STFT (restrict to 50% overlap between frames) (default: {512}) + win_length {[type]} -- Length of the window function applied to each frame (if not specified, it + equals the filter length). (default: {None}) + window {str} -- Type of window to use (options are bartlett, hann, hamming, blackman, blackmanharris) + (default: {'hann'}) + """ + super(STFT, self).__init__() + self.filter_length = filter_length + self.hop_length = hop_length + self.pad_amount = int(self.filter_length / 2) + self.win_length = win_length + self.hann_window = {} + self.use_torch_stft = use_torch_stft + + if use_torch_stft: + return + + fourier_basis = np.fft.fft(np.eye(self.filter_length)) + + cutoff = int((self.filter_length / 2 + 1)) + fourier_basis = np.vstack( + [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] + ) + forward_basis = torch.FloatTensor(fourier_basis) + inverse_basis = torch.FloatTensor(np.linalg.pinv(fourier_basis)) + + if win_length is None or not win_length: + win_length = filter_length + assert filter_length >= win_length + + # get window and zero center pad it to filter_length + fft_window = get_window(window, win_length, fftbins=True) + fft_window = pad_center(fft_window, size=filter_length) + fft_window = torch.from_numpy(fft_window).float() + + # window the bases + forward_basis *= fft_window + inverse_basis = (inverse_basis.T * fft_window).T + + self.register_buffer("forward_basis", forward_basis.float()) + self.register_buffer("inverse_basis", inverse_basis.float()) + self.register_buffer("fft_window", fft_window.float()) + + def __call__( + self, + input_data: torch.Tensor, + keyshift: int = 0, + speed: int = 1, + center: bool = True, + ) -> torch.Tensor: + return super().__call__(input_data, keyshift, speed, center) + + def transform( + self, + input_data: torch.Tensor, + return_phase=False, + ) -> Tuple[Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]]: + """Take input data (audio) to STFT domain. + + Arguments: + input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples) + + Returns: + magnitude {tensor} -- Magnitude of STFT with shape (num_batch, + num_frequencies, num_frames) + phase {tensor} -- Phase of STFT with shape (num_batch, + num_frequencies, num_frames) + """ + input_data = F.pad( + input_data, + (self.pad_amount, self.pad_amount), + mode="reflect", + ) + forward_transform = input_data.unfold( + 1, self.filter_length, self.hop_length + ).permute(0, 2, 1) + forward_transform = torch.matmul(self.forward_basis, forward_transform) + cutoff = int((self.filter_length / 2) + 1) + real_part = forward_transform[:, :cutoff, :] + imag_part = forward_transform[:, cutoff:, :] + magnitude = torch.sqrt(real_part**2 + imag_part**2) + if return_phase: + phase = torch.atan2(imag_part.data, real_part.data) + return magnitude, phase + else: + return magnitude + + def inverse( + self, + magnitude: torch.Tensor, + phase: torch.Tensor, + ) -> torch.Tensor: + """Call the inverse STFT (iSTFT), given magnitude and phase tensors produced + by the ```transform``` function. + + Arguments: + magnitude {tensor} -- Magnitude of STFT with shape (num_batch, + num_frequencies, num_frames) + phase {tensor} -- Phase of STFT with shape (num_batch, + num_frequencies, num_frames) + + Returns: + inverse_transform {tensor} -- Reconstructed audio given magnitude and phase. Of + shape (num_batch, num_samples) + """ + cat = torch.cat( + [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 + ) + fold = torch.nn.Fold( + output_size=(1, (cat.size(-1) - 1) * self.hop_length + self.filter_length), + kernel_size=(1, self.filter_length), + stride=(1, self.hop_length), + ) + inverse_transform = torch.matmul(self.inverse_basis, cat) + inverse_transform: torch.Tensor = fold(inverse_transform)[ + :, 0, 0, self.pad_amount : -self.pad_amount + ] + window_square_sum = ( + self.fft_window.pow(2).repeat(cat.size(-1), 1).T.unsqueeze(0) + ) + window_square_sum = fold(window_square_sum)[ + :, 0, 0, self.pad_amount : -self.pad_amount + ] + inverse_transform /= window_square_sum + return inverse_transform + + def forward( + self, + input_data: torch.Tensor, + keyshift: int = 0, + speed: int = 1, + center: bool = True, + ) -> torch.Tensor: + factor = 2 ** (keyshift / 12) + n_fft_new = int(np.round(self.filter_length * factor)) + win_length_new = int(np.round(self.win_length * factor)) + hop_length_new = int(np.round(self.hop_length * speed)) + if self.use_torch_stft: + keyshift_key = str(keyshift) + "_" + str(input_data.device) + if keyshift_key not in self.hann_window: + self.hann_window[keyshift_key] = torch.hann_window( + self.win_length, + ).to(input_data.device) + fft = torch.stft( + input_data, + n_fft=n_fft_new, + hop_length=hop_length_new, + win_length=win_length_new, + window=self.hann_window[keyshift_key], + center=center, + return_complex=True, + ) + return torch.sqrt(fft.real.pow(2) + fft.imag.pow(2)) + return self.transform(input_data) + """Take input data (audio) to STFT domain and then back to audio. + + Arguments: + input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples) + + Returns: + reconstruction {tensor} -- Reconstructed audio given magnitude and phase. Of + shape (num_batch, num_samples) + reconstruction = self.inverse( + self.transform(input_data, return_phase=True), + ) + return reconstruction + """