diff --git a/infer/lib/audio.py b/infer/lib/audio.py index 30d1923..8c29960 100644 --- a/infer/lib/audio.py +++ b/infer/lib/audio.py @@ -3,6 +3,7 @@ from pathlib import Path from typing import Dict import numpy as np import av +import os from av.audio.resampler import AudioResampler video_format_dict: Dict[str, str] = { @@ -54,6 +55,67 @@ def load_audio(file: str, sr: int) -> np.ndarray: return audio.flatten() +def downsample_audio(input_path: str, output_path: str, format: str) -> None: + if not os.path.exists(input_path): return + + input_container = av.open(input_path) + output_container = av.open(output_path, 'w') + + # Create a stream in the output container + input_stream = input_container.streams.audio[0] + output_stream = output_container.add_stream(format) + + output_stream.bit_rate = 128_000 # 128kb/s (equivalent to -q:a 2) + + # Copy packets from the input file to the output file + for packet in input_container.demux(input_stream): + for frame in packet.decode(): + for out_packet in output_stream.encode(frame): + output_container.mux(out_packet) + + for packet in output_stream.encode(): + output_container.mux(packet) + + # Close the containers + input_container.close() + output_container.close() + + try: # Remove the original file + os.remove(input_path) + except Exception as e: + print(f"Failed to remove the original file: {e}") + +def resample_audio(input_path: str, output_path: str, format: str, sr: int, layout: str) -> None: + if not os.path.exists(input_path): return + + input_container = av.open(input_path) + output_container = av.open(output_path, 'w') + + # Create a stream in the output container + input_stream = input_container.streams.audio[0] + output_stream = output_container.add_stream(format, rate=sr, layout=layout) + + resampler = AudioResampler(format, layout, sr) + + # Copy packets from the input file to the output file + for packet in input_container.demux(input_stream): + for frame in packet.decode(): + frame.pts = None # Clear presentation timestamp to avoid resampling issues + resampled = resampler.resample(frame) + for out_packet in output_stream.encode(resampled): + output_container.mux(out_packet) + + for packet in output_stream.encode(): + output_container.mux(packet) + + # Close the containers + input_container.close() + output_container.close() + + try: # Remove the original file + os.remove(input_path) + except Exception as e: + print(f"Failed to remove the original file: {e}") def clean_path(path: str) -> Path: return Path(path.strip(' "\n')).resolve() diff --git a/infer/modules/uvr5/mdxnet.py b/infer/modules/uvr5/mdxnet.py index 0a431be..3e6d652 100644 --- a/infer/modules/uvr5/mdxnet.py +++ b/infer/modules/uvr5/mdxnet.py @@ -10,6 +10,8 @@ import torch from tqdm import tqdm import av +from infer.lib.audio import downsample_audio + cpu = torch.device("cpu") @@ -219,10 +221,10 @@ class Predictor: sf.write(path_other, opt, rate) opt_path_vocal = path_vocal[:-4] + ".%s" % format opt_path_other = path_other[:-4] + ".%s" % format - process_audio(path_vocal, opt_path_vocal, format) - process_audio(path_other, opt_path_other, format) + downsample_audio(path_vocal, opt_path_vocal, format) + downsample_audio(path_other, opt_path_other, format) -def process_audio(input_path: str, output_path: str, format: str) -> None: +def downsample_audio(input_path: str, output_path: str, format: str) -> None: if not os.path.exists(input_path): return input_container = av.open(input_path) diff --git a/infer/modules/uvr5/modules.py b/infer/modules/uvr5/modules.py index 1463402..b06619d 100644 --- a/infer/modules/uvr5/modules.py +++ b/infer/modules/uvr5/modules.py @@ -5,7 +5,7 @@ import logging logger = logging.getLogger(__name__) import av -from av.audio.resampler import AudioResampler +from infer.lib.audio import resample_audio import torch from configs import Config @@ -63,7 +63,7 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format os.path.join(os.environ["TEMP"]), os.path.basename(inp_path), ) - process_audio(inp_path, tmp_path) + resample_audio(inp_path, tmp_path, 'pcm_s16le', 44100, 'stereo') inp_path = tmp_path try: if done == 0: @@ -105,37 +105,3 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format torch.mps.empty_cache() logger.info("Executed torch.mps.empty_cache()") yield "\n".join(infos) - -def process_audio(input_path: str, output_path: str) -> None: - if not os.path.exists(input_path): return - - input_container = av.open(input_path) - output_container = av.open(output_path, 'w') - - # Create a stream in the output container - input_stream = input_container.streams.audio[0] - output_stream = output_container.add_stream('pcm_s16le', rate=44100, layout='stereo') - - resampler = AudioResampler('pcm_s16le', 'stereo', 44100) - - output_stream.bit_rate = 128_000 # 128kb/s (equivalent to -q:a 2) - - # Copy packets from the input file to the output file - for packet in input_container.demux(input_stream): - for frame in packet.decode(): - frame.pts = None # Clear presentation timestamp to avoid resampling issues - resampled = resampler.resample(frame) - for out_packet in output_stream.encode(resampled): - output_container.mux(out_packet) - - for packet in output_stream.encode(): - output_container.mux(packet) - - # Close the containers - input_container.close() - output_container.close() - - try: # Remove the original file - os.remove(input_path) - except Exception as e: - print(f"Failed to remove the original file: {e}") \ No newline at end of file diff --git a/infer/modules/uvr5/vr.py b/infer/modules/uvr5/vr.py index 8160e4a..5496ca3 100644 --- a/infer/modules/uvr5/vr.py +++ b/infer/modules/uvr5/vr.py @@ -6,7 +6,7 @@ logger = logging.getLogger(__name__) import librosa import numpy as np import soundfile as sf -import av +from infer.lib.audio import downsample_audio import torch from infer.lib.uvr5_pack.lib_v5 import nets_123821KB as Nets @@ -147,7 +147,7 @@ class AudioPre: ) if os.path.exists(path): opt_format_path = path[:-4] + ".%s" % format - process_audio(path, opt_format_path, format) + downsample_audio(path, opt_format_path, format) if vocal_root is not None: if is_hp3 == True: head = "instrument_" @@ -182,37 +182,7 @@ class AudioPre: self.mp.param["sr"], ) opt_format_path = path[:-4] + ".%s" % format - process_audio(path, opt_format_path, format) - -def process_audio(input_path: str, output_path: str, format: str) -> None: - if not os.path.exists(input_path): return - - input_container = av.open(input_path) - output_container = av.open(output_path, 'w') - - # Create a stream in the output container - input_stream = input_container.streams.audio[0] - output_stream = output_container.add_stream(format) - - output_stream.bit_rate = 128_000 # 128kb/s (equivalent to -q:a 2) - - # Copy packets from the input file to the output file - for packet in input_container.demux(input_stream): - for frame in packet.decode(): - for out_packet in output_stream.encode(frame): - output_container.mux(out_packet) - - for packet in output_stream.encode(): - output_container.mux(packet) - - # Close the containers - input_container.close() - output_container.close() - - try: # Remove the original file - os.remove(input_path) - except Exception as e: - print(f"Failed to remove the original file: {e}") + downsample_audio(path, opt_format_path, format) class AudioPreDeEcho: def __init__(self, agg, model_path, device, is_half, tta=False): @@ -342,7 +312,7 @@ class AudioPreDeEcho: ) if os.path.exists(path): opt_format_path = path[:-4] + ".%s" % format - process_audio(path, opt_format_path, format) + downsample_audio(path, opt_format_path, format) if vocal_root is not None: if self.data["high_end_process"].startswith("mirroring"): input_high_end_ = spec_utils.mirroring( @@ -374,4 +344,4 @@ class AudioPreDeEcho: ) if os.path.exists(path): opt_format_path = path[:-4] + ".%s" % format - process_audio(path, opt_format_path, format) + downsample_audio(path, opt_format_path, format)