1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-05 09:10:25 +08:00

fix(train): save small model fail

This commit is contained in:
源文雨
2024-06-04 04:07:19 +09:00
parent 5df99f2f73
commit 481f14dd74
8 changed files with 71 additions and 53 deletions

View File

@@ -14,7 +14,7 @@ MATPLOTLIB_FLAG = False
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logger = logging
"""
def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1):
assert os.path.isfile(checkpoint_path)
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
@@ -64,37 +64,8 @@ def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1):
# traceback.print_exc()
logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration))
return model, optimizer, learning_rate, iteration
"""
# def load_checkpoint(checkpoint_path, model, optimizer=None):
# assert os.path.isfile(checkpoint_path)
# checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
# iteration = checkpoint_dict['iteration']
# learning_rate = checkpoint_dict['learning_rate']
# if optimizer is not None:
# optimizer.load_state_dict(checkpoint_dict['optimizer'])
# # print(1111)
# saved_state_dict = checkpoint_dict['model']
# # print(1111)
#
# if hasattr(model, 'module'):
# state_dict = model.module.state_dict()
# else:
# state_dict = model.state_dict()
# new_state_dict= {}
# for k, v in state_dict.items():
# try:
# new_state_dict[k] = saved_state_dict[k]
# except:
# logger.info("%s is not in the checkpoint" % k)
# new_state_dict[k] = v
# if hasattr(model, 'module'):
# model.module.load_state_dict(new_state_dict)
# else:
# model.load_state_dict(new_state_dict)
# logger.info("Loaded checkpoint '{}' (epoch {})" .format(
# checkpoint_path, iteration))
# return model, optimizer, learning_rate, iteration
def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
assert os.path.isfile(checkpoint_path)
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
@@ -159,7 +130,7 @@ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path)
checkpoint_path,
)
"""
def save_checkpoint_d(combd, sbd, optimizer, learning_rate, iteration, checkpoint_path):
logger.info(
"Saving model and optimizer state at epoch {} to {}".format(
@@ -184,7 +155,7 @@ def save_checkpoint_d(combd, sbd, optimizer, learning_rate, iteration, checkpoin
},
checkpoint_path,
)
"""
def summarize(
writer,

View File

@@ -3,6 +3,7 @@ import sys
import logging
logger = logging.getLogger(__name__)
logging.getLogger("numba").setLevel(logging.WARNING)
now_dir = os.getcwd()
sys.path.append(os.path.join(now_dir))

View File

@@ -5,19 +5,12 @@ import pathlib
from scipy.fft import fft
from pybase16384 import encode_to_string, decode_from_string
if __name__ == "__main__":
import os, sys
now_dir = os.getcwd()
sys.path.append(now_dir)
from configs import Config, singleton_variable
from infer.lib.audio import load_audio
from configs import CPUConfig, singleton_variable
from .pipeline import Pipeline
from .utils import load_hubert
from infer.lib.audio import load_audio
class TorchSeedContext:
def __init__(self, seed):
@@ -102,8 +95,9 @@ def model_hash(config, tgt_sr, net_g, if_f0, version):
audio_max = np.abs(audio).max() / 0.95
if audio_max > 1:
np.divide(audio, audio_max, audio)
hbt = load_hubert(config.device, config.is_half)
audio_opt = pipeline.pipeline(
load_hubert(config.device, config.is_half),
hbt,
net_g,
0,
audio,
@@ -120,6 +114,7 @@ def model_hash(config, tgt_sr, net_g, if_f0, version):
version,
0.33,
)
del hbt
opt_len = len(audio_opt)
diff = 48000 - opt_len
n = diff // 2
@@ -141,7 +136,8 @@ def model_hash_ckpt(cpt):
SynthesizerTrnMs768NSFsid_nono,
)
config = Config()
config = CPUConfig()
with TorchSeedContext(114514):
tgt_sr = cpt["config"][-1]
if_f0 = cpt.get("f0", 1)
@@ -167,7 +163,7 @@ def model_hash_ckpt(cpt):
h = model_hash(config, tgt_sr, net_g, if_f0, version)
del net_g
del net_g
return h
@@ -217,4 +213,9 @@ def hash_similarity(h1: str, h2: str) -> float:
def hash_id(h: str) -> str:
return encode_to_string(hashlib.md5(decode_from_string(h)).digest())[:-1]
d = decode_from_string(h)
if len(d) != half_hash_len * 2:
return "invalid hash length"
return encode_to_string(
np.frombuffer(d, dtype=np.uint64).sum(keepdims=True).tobytes()
)[:-2] + encode_to_string(hashlib.md5(d).digest()[:7])

View File

@@ -8,6 +8,7 @@ logger = logging.getLogger(__name__)
from functools import lru_cache
from time import time
import faiss
import librosa
import numpy as np
import parselmouth
@@ -330,7 +331,6 @@ class Pipeline(object):
and os.path.exists(file_index)
and index_rate != 0
):
if "faiss" not in sys.modules: import faiss
try:
index = faiss.read_index(file_index)
big_npy = index.reconstruct_n(0, index.ntotal)

View File

@@ -2,8 +2,6 @@ import os
from fairseq import checkpoint_utils
from configs import singleton_variable
def get_index_path_from_model(sid):
return next(
@@ -22,7 +20,6 @@ def get_index_path_from_model(sid):
)
@singleton_variable
def load_hubert(device, is_half):
models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
["assets/hubert/hubert_base.pt"],