1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-09 04:29:50 +08:00

optimize: some training optimizations (#95)

* optimzie(train&uvr5): rm sf & simp. AudioPre

* fix(audio): too many mallocs

* feat(audio): load_audio support stereo

* fix(audio): float32 wav saving

* fix(train): missing ckpt var
This commit is contained in:
源文雨
2024-11-28 03:20:14 +09:00
committed by GitHub
parent f4644ec1ec
commit a8783c6639
19 changed files with 163 additions and 433 deletions

View File

@@ -1,11 +1,16 @@
from io import BufferedWriter, BytesIO
from pathlib import Path
from typing import Dict, Tuple
from typing import Dict, Tuple, Optional, Union, List
import os
import math
import wave
import numpy as np
from numba import jit
import av
from av.audio.resampler import AudioResampler
from av.audio.frame import AudioFrame
import scipy.io.wavfile as wavfile
video_format_dict: Dict[str, str] = {
"m4a": "mp4",
@@ -17,6 +22,29 @@ audio_format_dict: Dict[str, str] = {
}
@jit(nopython=True)
def float_to_int16(audio: np.ndarray) -> np.ndarray:
am = int(math.ceil(float(np.abs(audio).max())) * 32768)
am = 32767 * 32768 // am
return np.multiply(audio, am).astype(np.int16)
def float_np_array_to_wav_buf(wav: np.ndarray, sr: int, f32=False) -> BytesIO:
buf = BytesIO()
if f32:
wavfile.write(buf, sr, wav.astype(np.float32))
else:
with wave.open(buf, "wb") as wf:
wf.setnchannels(2 if len(wav.shape) > 1 else 1)
wf.setsampwidth(2) # Sample width in bytes
wf.setframerate(sr) # Sample rate in Hz
wf.writeframes(float_to_int16(wav.T if len(wav.shape) > 1 else wav))
buf.seek(0, 0)
return buf
def save_audio(path: str, audio: np.ndarray, sr: int, f32=False):
with open(path, "wb") as f:
f.write(float_np_array_to_wav_buf(audio, sr, f32).getbuffer())
def wav2(i: BytesIO, o: BufferedWriter, format: str):
inp = av.open(i, "r")
format = video_format_dict.get(format, format)
@@ -36,43 +64,72 @@ def wav2(i: BytesIO, o: BufferedWriter, format: str):
inp.close()
def load_audio(file: str, sr: int) -> np.ndarray:
if not Path(file).exists():
def load_audio(
file: Union[str, BytesIO, Path],
sr: Optional[int]=None,
format: Optional[str]=None,
mono=True
) -> Union[np.ndarray, Tuple[np.ndarray, int]]:
if (isinstance(file, str) and not Path(file).exists()) or (isinstance(file, Path) and not file.exists()):
raise FileNotFoundError(f"File not found: {file}")
rate = 0
try:
container = av.open(file)
resampler = AudioResampler(format="fltp", layout="mono", rate=sr)
container = av.open(file, format=format)
audio_stream = next(s for s in container.streams if s.type == "audio")
channels = 1 if audio_stream.layout == "mono" else 2
container.seek(0)
resampler = AudioResampler(format="fltp", layout=audio_stream.layout, rate=sr) if sr is not None else None
# Estimated maximum total number of samples to pre-allocate the array
# AV stores length in microseconds by default
estimated_total_samples = int(container.duration * sr // 1_000_000)
decoded_audio = np.zeros(estimated_total_samples + 1, dtype=np.float32)
# Estimated maximum total number of samples to pre-allocate the array
# AV stores length in microseconds by default
estimated_total_samples = int(container.duration * sr // 1_000_000) if sr is not None else 48000
decoded_audio = np.zeros(estimated_total_samples + 1 if channels == 1 else (channels, estimated_total_samples + 1), dtype=np.float32)
offset = 0
for frame in container.decode(audio=0):
frame.pts = None # Clear presentation timestamp to avoid resampling issues
resampled_frames = resampler.resample(frame)
offset = 0
def process_packet(packet: List[AudioFrame]):
frames_data = []
rate = 0
for frame in packet:
frame.pts = None # 清除时间戳,避免重新采样问题
resampled_frames = resampler.resample(frame) if resampler is not None else [frame]
for resampled_frame in resampled_frames:
frame_data = resampled_frame.to_ndarray()[0]
end_index = offset + len(frame_data)
frame_data = resampled_frame.to_ndarray()
rate = resampled_frame.rate
frames_data.append(frame_data)
return (rate, frames_data)
# Check if decoded_audio has enough space, and resize if necessary
if end_index > decoded_audio.shape[0]:
decoded_audio = np.resize(decoded_audio, end_index + 1)
def frame_iter(container):
for p in container.demux(container.streams.audio[0]):
yield p.decode()
decoded_audio[offset:end_index] = frame_data
offset += len(frame_data)
for r, frames_data in map(process_packet, frame_iter(container)):
if not rate: rate = r
for frame_data in frames_data:
end_index = offset + len(frame_data[0])
# Truncate the array to the actual size
decoded_audio = decoded_audio[:offset]
except Exception as e:
raise RuntimeError(f"Failed to load audio: {e}")
# 检查 decoded_audio 是否有足够的空间,并在必要时调整大小
if end_index > decoded_audio.shape[1]:
decoded_audio = np.resize(decoded_audio, (decoded_audio.shape[0], end_index*4))
return decoded_audio
np.copyto(decoded_audio[..., offset:end_index], frame_data)
offset += len(frame_data[0])
# Truncate the array to the actual size
decoded_audio = decoded_audio[..., :offset]
if mono and decoded_audio.shape[0] > 1:
decoded_audio = decoded_audio.mean(0)
if sr is not None:
return decoded_audio
return decoded_audio, rate
def downsample_audio(input_path: str, output_path: str, format: str) -> None:
def downsample_audio(input_path: str, output_path: str, format: str, br=128_000) -> None:
"""
default to 128kb/s (equivalent to -q:a 2)
"""
if not os.path.exists(input_path):
return
@@ -83,7 +140,7 @@ def downsample_audio(input_path: str, output_path: str, format: str) -> None:
input_stream = input_container.streams.audio[0]
output_stream = output_container.add_stream(format)
output_stream.bit_rate = 128_000 # 128kb/s (equivalent to -q:a 2)
output_stream.bit_rate = br
# Copy packets from the input file to the output file
for packet in input_container.demux(input_stream):
@@ -141,7 +198,7 @@ def resample_audio(
print(f"Failed to remove the original file: {e}")
def get_audio_properties(input_path: str) -> Tuple:
def get_audio_properties(input_path: str) -> Tuple[int, int]:
container = av.open(input_path)
audio_stream = next(s for s in container.streams if s.type == "audio")
channels = 1 if audio_stream.layout == "mono" else 2

View File

@@ -183,8 +183,7 @@ def main():
import os.path
from argparse import ArgumentParser
import librosa
import soundfile
from .audio import load_audio, save_audio
parser = ArgumentParser()
parser.add_argument("audio", type=str, help="The audio to be sliced")
@@ -230,7 +229,7 @@ def main():
out = args.out
if out is None:
out = os.path.dirname(os.path.abspath(args.audio))
audio, sr = librosa.load(args.audio, sr=None, mono=False)
audio, sr = load_audio(args.audio, mono=False)
slicer = Slicer(
sr=sr,
threshold=args.db_thresh,
@@ -245,15 +244,11 @@ def main():
for i, chunk in enumerate(chunks):
if len(chunk.shape) > 1:
chunk = chunk.T
soundfile.write(
os.path.join(
out,
f"%s_%d.wav"
% (os.path.basename(args.audio).rsplit(".", maxsplit=1)[0], i),
),
chunk,
sr,
)
save_audio(os.path.join(
out,
f"%s_%d.wav"
% (os.path.basename(args.audio).rsplit(".", maxsplit=1)[0], i),
), chunk, sr)
if __name__ == "__main__":

View File

@@ -16,62 +16,12 @@ MATPLOTLIB_FLAG = False
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logger = logging
"""
def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1):
assert os.path.isfile(checkpoint_path)
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
##################
def go(model, bkey):
saved_state_dict = checkpoint_dict[bkey]
if hasattr(model, "module"):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
new_state_dict = {}
for k, v in state_dict.items(): # 模型需要的shape
try:
new_state_dict[k] = saved_state_dict[k]
if saved_state_dict[k].shape != state_dict[k].shape:
logger.warning(
"shape-%s-mismatch. need: %s, get: %s",
k,
state_dict[k].shape,
saved_state_dict[k].shape,
) #
raise KeyError
except:
# logger.info(traceback.format_exc())
logger.info("%s is not in the checkpoint", k) # pretrain缺失的
new_state_dict[k] = v # 模型自带的随机值
if hasattr(model, "module"):
model.module.load_state_dict(new_state_dict, strict=False)
else:
model.load_state_dict(new_state_dict, strict=False)
return model
go(combd, "combd")
model = go(sbd, "sbd")
#############
logger.info("Loaded model weights")
iteration = checkpoint_dict["iteration"]
learning_rate = checkpoint_dict["learning_rate"]
if (
optimizer is not None and load_opt == 1
): ###加载不了如果是空的的话重新初始化可能还会影响lr时间表的更新因此在train文件最外围catch
# try:
optimizer.load_state_dict(checkpoint_dict["optimizer"])
# except:
# traceback.print_exc()
logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration))
return model, optimizer, learning_rate, iteration
"""
def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
assert os.path.isfile(checkpoint_path)
saved_state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)["model"]
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
saved_state_dict = checkpoint_dict["model"]
if hasattr(model, "module"):
state_dict = model.module.state_dict()
else:
@@ -132,34 +82,6 @@ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path)
)
"""
def save_checkpoint_d(combd, sbd, optimizer, learning_rate, iteration, checkpoint_path):
logger.info(
"Saving model and optimizer state at epoch {} to {}".format(
iteration, checkpoint_path
)
)
if hasattr(combd, "module"):
state_dict_combd = combd.module.state_dict()
else:
state_dict_combd = combd.state_dict()
if hasattr(sbd, "module"):
state_dict_sbd = sbd.module.state_dict()
else:
state_dict_sbd = sbd.state_dict()
torch.save(
{
"combd": state_dict_combd,
"sbd": state_dict_sbd,
"iteration": iteration,
"optimizer": optimizer.state_dict(),
"learning_rate": learning_rate,
},
checkpoint_path,
)
"""
def summarize(
writer,
global_step,
@@ -366,53 +288,6 @@ def get_hparams(init=True):
return hparams
"""
def get_hparams_from_dir(model_dir):
config_save_path = os.path.join(model_dir, "config.json")
with open(config_save_path, "r") as f:
data = f.read()
config = json.loads(data)
hparams = HParams(**config)
hparams.model_dir = model_dir
return hparams
def get_hparams_from_file(config_path):
with open(config_path, "r") as f:
data = f.read()
config = json.loads(data)
hparams = HParams(**config)
return hparams
def check_git_hash(model_dir):
source_dir = os.path.dirname(os.path.realpath(__file__))
if not os.path.exists(os.path.join(source_dir, ".git")):
logger.warning(
"{} is not a git repository, therefore hash value comparison will be ignored.".format(
source_dir
)
)
return
cur_hash = subprocess.getoutput("git rev-parse HEAD")
path = os.path.join(model_dir, "githash")
if os.path.exists(path):
saved_hash = open(path).read()
if saved_hash != cur_hash:
logger.warning(
"git hash values are different. {}(saved) != {}(current)".format(
saved_hash[:8], cur_hash[:8]
)
)
else:
open(path, "w").write(cur_hash)
"""
def get_logger(model_dir, filename="train.log"):
global logger
logger = logging.getLogger(os.path.basename(model_dir))