mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-07 02:00:25 +08:00
fix(train): save small model fail
This commit is contained in:
@@ -3,6 +3,7 @@ import sys
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.getLogger("numba").setLevel(logging.WARNING)
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(os.path.join(now_dir))
|
||||
|
||||
@@ -5,19 +5,12 @@ import pathlib
|
||||
from scipy.fft import fft
|
||||
from pybase16384 import encode_to_string, decode_from_string
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os, sys
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
|
||||
from configs import Config, singleton_variable
|
||||
from infer.lib.audio import load_audio
|
||||
from configs import CPUConfig, singleton_variable
|
||||
|
||||
from .pipeline import Pipeline
|
||||
from .utils import load_hubert
|
||||
|
||||
from infer.lib.audio import load_audio
|
||||
|
||||
|
||||
class TorchSeedContext:
|
||||
def __init__(self, seed):
|
||||
@@ -102,8 +95,9 @@ def model_hash(config, tgt_sr, net_g, if_f0, version):
|
||||
audio_max = np.abs(audio).max() / 0.95
|
||||
if audio_max > 1:
|
||||
np.divide(audio, audio_max, audio)
|
||||
hbt = load_hubert(config.device, config.is_half)
|
||||
audio_opt = pipeline.pipeline(
|
||||
load_hubert(config.device, config.is_half),
|
||||
hbt,
|
||||
net_g,
|
||||
0,
|
||||
audio,
|
||||
@@ -120,6 +114,7 @@ def model_hash(config, tgt_sr, net_g, if_f0, version):
|
||||
version,
|
||||
0.33,
|
||||
)
|
||||
del hbt
|
||||
opt_len = len(audio_opt)
|
||||
diff = 48000 - opt_len
|
||||
n = diff // 2
|
||||
@@ -141,7 +136,8 @@ def model_hash_ckpt(cpt):
|
||||
SynthesizerTrnMs768NSFsid_nono,
|
||||
)
|
||||
|
||||
config = Config()
|
||||
config = CPUConfig()
|
||||
|
||||
with TorchSeedContext(114514):
|
||||
tgt_sr = cpt["config"][-1]
|
||||
if_f0 = cpt.get("f0", 1)
|
||||
@@ -167,7 +163,7 @@ def model_hash_ckpt(cpt):
|
||||
|
||||
h = model_hash(config, tgt_sr, net_g, if_f0, version)
|
||||
|
||||
del net_g
|
||||
del net_g
|
||||
|
||||
return h
|
||||
|
||||
@@ -217,4 +213,9 @@ def hash_similarity(h1: str, h2: str) -> float:
|
||||
|
||||
|
||||
def hash_id(h: str) -> str:
|
||||
return encode_to_string(hashlib.md5(decode_from_string(h)).digest())[:-1]
|
||||
d = decode_from_string(h)
|
||||
if len(d) != half_hash_len * 2:
|
||||
return "invalid hash length"
|
||||
return encode_to_string(
|
||||
np.frombuffer(d, dtype=np.uint64).sum(keepdims=True).tobytes()
|
||||
)[:-2] + encode_to_string(hashlib.md5(d).digest()[:7])
|
||||
|
||||
@@ -8,6 +8,7 @@ logger = logging.getLogger(__name__)
|
||||
from functools import lru_cache
|
||||
from time import time
|
||||
|
||||
import faiss
|
||||
import librosa
|
||||
import numpy as np
|
||||
import parselmouth
|
||||
@@ -330,7 +331,6 @@ class Pipeline(object):
|
||||
and os.path.exists(file_index)
|
||||
and index_rate != 0
|
||||
):
|
||||
if "faiss" not in sys.modules: import faiss
|
||||
try:
|
||||
index = faiss.read_index(file_index)
|
||||
big_npy = index.reconstruct_n(0, index.ntotal)
|
||||
|
||||
@@ -2,8 +2,6 @@ import os
|
||||
|
||||
from fairseq import checkpoint_utils
|
||||
|
||||
from configs import singleton_variable
|
||||
|
||||
|
||||
def get_index_path_from_model(sid):
|
||||
return next(
|
||||
@@ -22,7 +20,6 @@ def get_index_path_from_model(sid):
|
||||
)
|
||||
|
||||
|
||||
@singleton_variable
|
||||
def load_hubert(device, is_half):
|
||||
models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
|
||||
["assets/hubert/hubert_base.pt"],
|
||||
|
||||
Reference in New Issue
Block a user