From 9d699b1d9910d97071893bfc39b6945fc628fd4c Mon Sep 17 00:00:00 2001 From: Alex Murkoff <413x1nkp@gmail.com> Date: Tue, 11 Jun 2024 11:02:54 +0700 Subject: [PATCH] perf: use hashing to determine the format in infer/lib/audio.py (#26) --- infer/lib/audio.py | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/infer/lib/audio.py b/infer/lib/audio.py index 89799ab..e406ab8 100644 --- a/infer/lib/audio.py +++ b/infer/lib/audio.py @@ -1,18 +1,24 @@ -import platform +from io import BufferedWriter, BytesIO +from pathlib import Path +from typing import Dict import ffmpeg import numpy as np import av +video_format_dict: Dict[str, str] = { + "m4a": "mp4", +} -def wav2(i, o, format): +audio_format_dict: Dict[str, str] = { + "ogg": "libvorbis", + "mp4": "aac", +} + +def wav2(i: BytesIO, o: BufferedWriter, format: str): inp = av.open(i, "r") - if format == "m4a": - format = "mp4" + format = video_format_dict.get(format, format) out = av.open(o, "w", format=format) - if format == "ogg": - format = "libvorbis" - if format == "mp4": - format = "aac" + format = audio_format_dict.get(format, format) ostream = out.add_stream(format) @@ -27,12 +33,15 @@ def wav2(i, o, format): inp.close() -def load_audio(file, sr): +def load_audio(file: str, sr: int) -> np.ndarray: + if not Path(file).exists(): + raise FileNotFoundError(f"File not found: {file}") + try: # https://github.com/openai/whisper/blob/main/whisper/audio.py#L26 # This launches a subprocess to decode audio while down-mixing and resampling as necessary. # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed. - file = clean_path(file) # 防止小白拷路径头尾带了空格和"和回车 + file = str(clean_path(file)) # 防止小白拷路径头尾带了空格和"和回车 out, _ = ( ffmpeg.input(file, threads=0) .output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr) @@ -44,7 +53,5 @@ def load_audio(file, sr): return np.frombuffer(out, np.float32).flatten() -def clean_path(path_str): - if platform.system() == "Windows": - path_str = path_str.replace("/", "\\") - return path_str.strip(" ").strip('"').strip("\n").strip('"').strip(" ") +def clean_path(path: str) -> Path: + return Path(path.strip(' "\n')).resolve()