1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-07 02:00:25 +08:00

fix(train): save small model fail

This commit is contained in:
源文雨
2024-06-04 04:07:19 +09:00
parent 5df99f2f73
commit 481f14dd74
8 changed files with 71 additions and 53 deletions

View File

@@ -3,6 +3,7 @@ import sys
import logging
logger = logging.getLogger(__name__)
logging.getLogger("numba").setLevel(logging.WARNING)
now_dir = os.getcwd()
sys.path.append(os.path.join(now_dir))

View File

@@ -5,19 +5,12 @@ import pathlib
from scipy.fft import fft
from pybase16384 import encode_to_string, decode_from_string
if __name__ == "__main__":
import os, sys
now_dir = os.getcwd()
sys.path.append(now_dir)
from configs import Config, singleton_variable
from infer.lib.audio import load_audio
from configs import CPUConfig, singleton_variable
from .pipeline import Pipeline
from .utils import load_hubert
from infer.lib.audio import load_audio
class TorchSeedContext:
def __init__(self, seed):
@@ -102,8 +95,9 @@ def model_hash(config, tgt_sr, net_g, if_f0, version):
audio_max = np.abs(audio).max() / 0.95
if audio_max > 1:
np.divide(audio, audio_max, audio)
hbt = load_hubert(config.device, config.is_half)
audio_opt = pipeline.pipeline(
load_hubert(config.device, config.is_half),
hbt,
net_g,
0,
audio,
@@ -120,6 +114,7 @@ def model_hash(config, tgt_sr, net_g, if_f0, version):
version,
0.33,
)
del hbt
opt_len = len(audio_opt)
diff = 48000 - opt_len
n = diff // 2
@@ -141,7 +136,8 @@ def model_hash_ckpt(cpt):
SynthesizerTrnMs768NSFsid_nono,
)
config = Config()
config = CPUConfig()
with TorchSeedContext(114514):
tgt_sr = cpt["config"][-1]
if_f0 = cpt.get("f0", 1)
@@ -167,7 +163,7 @@ def model_hash_ckpt(cpt):
h = model_hash(config, tgt_sr, net_g, if_f0, version)
del net_g
del net_g
return h
@@ -217,4 +213,9 @@ def hash_similarity(h1: str, h2: str) -> float:
def hash_id(h: str) -> str:
return encode_to_string(hashlib.md5(decode_from_string(h)).digest())[:-1]
d = decode_from_string(h)
if len(d) != half_hash_len * 2:
return "invalid hash length"
return encode_to_string(
np.frombuffer(d, dtype=np.uint64).sum(keepdims=True).tobytes()
)[:-2] + encode_to_string(hashlib.md5(d).digest()[:7])

View File

@@ -8,6 +8,7 @@ logger = logging.getLogger(__name__)
from functools import lru_cache
from time import time
import faiss
import librosa
import numpy as np
import parselmouth
@@ -330,7 +331,6 @@ class Pipeline(object):
and os.path.exists(file_index)
and index_rate != 0
):
if "faiss" not in sys.modules: import faiss
try:
index = faiss.read_index(file_index)
big_npy = index.reconstruct_n(0, index.ntotal)

View File

@@ -2,8 +2,6 @@ import os
from fairseq import checkpoint_utils
from configs import singleton_variable
def get_index_path_from_model(sid):
return next(
@@ -22,7 +20,6 @@ def get_index_path_from_model(sid):
)
@singleton_variable
def load_hubert(device, is_half):
models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
["assets/hubert/hubert_base.pt"],