diff --git a/infer/lib/audio.py b/infer/lib/audio.py index 8c29960..ce57bfe 100644 --- a/infer/lib/audio.py +++ b/infer/lib/audio.py @@ -1,6 +1,6 @@ from io import BufferedWriter, BytesIO from pathlib import Path -from typing import Dict +from typing import Dict, Tuple import numpy as np import av import os @@ -47,7 +47,7 @@ def load_audio(file: str, sr: int) -> np.ndarray: for frame in container.decode(audio=0): frame.pts = None # Clear presentation timestamp to avoid resampling issues resampled = resampler.resample(frame) - decoded_audio.append(resampled.to_ndarray()) + decoded_audio.append(np.array(resampled)) audio = np.concatenate(decoded_audio) except Exception as e: @@ -85,7 +85,7 @@ def downsample_audio(input_path: str, output_path: str, format: str) -> None: except Exception as e: print(f"Failed to remove the original file: {e}") -def resample_audio(input_path: str, output_path: str, format: str, sr: int, layout: str) -> None: +def resample_audio(input_path: str, output_path: str, codec: str, format: str, sr: int, layout: str) -> None: if not os.path.exists(input_path): return input_container = av.open(input_path) @@ -93,7 +93,7 @@ def resample_audio(input_path: str, output_path: str, format: str, sr: int, layo # Create a stream in the output container input_stream = input_container.streams.audio[0] - output_stream = output_container.add_stream(format, rate=sr, layout=layout) + output_stream = output_container.add_stream(codec, rate=sr, layout=layout) resampler = AudioResampler(format, layout, sr) @@ -101,9 +101,10 @@ def resample_audio(input_path: str, output_path: str, format: str, sr: int, layo for packet in input_container.demux(input_stream): for frame in packet.decode(): frame.pts = None # Clear presentation timestamp to avoid resampling issues - resampled = resampler.resample(frame) - for out_packet in output_stream.encode(resampled): - output_container.mux(out_packet) + out_frames = resampler.resample(frame) + for out_frame in out_frames: + for out_packet in output_stream.encode(out_frame): + output_container.mux(out_packet) for packet in output_stream.encode(): output_container.mux(packet) @@ -117,5 +118,13 @@ def resample_audio(input_path: str, output_path: str, format: str, sr: int, layo except Exception as e: print(f"Failed to remove the original file: {e}") +def get_audio_properties(input_path: str) -> Tuple: + container = av.open(input_path) + audio_stream = next(s for s in container.streams if s.type == 'audio') + channels = 1 if audio_stream.layout == 'mono' else 2 + rate = audio_stream.base_rate + container.close() + return channels, rate + def clean_path(path: str) -> Path: return Path(path.strip(' "\n')).resolve() diff --git a/infer/modules/uvr5/modules.py b/infer/modules/uvr5/modules.py index b06619d..850507e 100644 --- a/infer/modules/uvr5/modules.py +++ b/infer/modules/uvr5/modules.py @@ -4,8 +4,7 @@ import logging logger = logging.getLogger(__name__) -import av -from infer.lib.audio import resample_audio +from infer.lib.audio import resample_audio, get_audio_properties import torch from configs import Config @@ -47,11 +46,10 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format need_reformat = 1 done = 0 try: - container = av.open(inp_path) - audio_stream = next(s for s in container.streams if s.type == 'audio') + channels, rate = get_audio_properties(inp_path) # Check the audio stream's properties - if audio_stream.channels == 2 and audio_stream.rate == 44100: + if channels == 2 and rate == 44100: pre_fun._path_audio_(inp_path, save_root_ins, save_root_vocal, format0, is_hp3=is_hp3) need_reformat = 0 done = 1 @@ -63,7 +61,7 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format os.path.join(os.environ["TEMP"]), os.path.basename(inp_path), ) - resample_audio(inp_path, tmp_path, 'pcm_s16le', 44100, 'stereo') + resample_audio(inp_path, tmp_path, 'pcm_s16le', 's16', 44100, 'stereo') inp_path = tmp_path try: if done == 0: