mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-05 09:10:25 +08:00
fix(train): save small model fail
This commit is contained in:
@@ -14,7 +14,7 @@ MATPLOTLIB_FLAG = False
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
|
||||
logger = logging
|
||||
|
||||
|
||||
"""
|
||||
def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1):
|
||||
assert os.path.isfile(checkpoint_path)
|
||||
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
|
||||
@@ -64,37 +64,8 @@ def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1):
|
||||
# traceback.print_exc()
|
||||
logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration))
|
||||
return model, optimizer, learning_rate, iteration
|
||||
"""
|
||||
|
||||
|
||||
# def load_checkpoint(checkpoint_path, model, optimizer=None):
|
||||
# assert os.path.isfile(checkpoint_path)
|
||||
# checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
|
||||
# iteration = checkpoint_dict['iteration']
|
||||
# learning_rate = checkpoint_dict['learning_rate']
|
||||
# if optimizer is not None:
|
||||
# optimizer.load_state_dict(checkpoint_dict['optimizer'])
|
||||
# # print(1111)
|
||||
# saved_state_dict = checkpoint_dict['model']
|
||||
# # print(1111)
|
||||
#
|
||||
# if hasattr(model, 'module'):
|
||||
# state_dict = model.module.state_dict()
|
||||
# else:
|
||||
# state_dict = model.state_dict()
|
||||
# new_state_dict= {}
|
||||
# for k, v in state_dict.items():
|
||||
# try:
|
||||
# new_state_dict[k] = saved_state_dict[k]
|
||||
# except:
|
||||
# logger.info("%s is not in the checkpoint" % k)
|
||||
# new_state_dict[k] = v
|
||||
# if hasattr(model, 'module'):
|
||||
# model.module.load_state_dict(new_state_dict)
|
||||
# else:
|
||||
# model.load_state_dict(new_state_dict)
|
||||
# logger.info("Loaded checkpoint '{}' (epoch {})" .format(
|
||||
# checkpoint_path, iteration))
|
||||
# return model, optimizer, learning_rate, iteration
|
||||
def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
|
||||
assert os.path.isfile(checkpoint_path)
|
||||
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
|
||||
@@ -159,7 +130,7 @@ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path)
|
||||
checkpoint_path,
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
def save_checkpoint_d(combd, sbd, optimizer, learning_rate, iteration, checkpoint_path):
|
||||
logger.info(
|
||||
"Saving model and optimizer state at epoch {} to {}".format(
|
||||
@@ -184,7 +155,7 @@ def save_checkpoint_d(combd, sbd, optimizer, learning_rate, iteration, checkpoin
|
||||
},
|
||||
checkpoint_path,
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
def summarize(
|
||||
writer,
|
||||
|
||||
@@ -3,6 +3,7 @@ import sys
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.getLogger("numba").setLevel(logging.WARNING)
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(os.path.join(now_dir))
|
||||
|
||||
@@ -5,19 +5,12 @@ import pathlib
|
||||
from scipy.fft import fft
|
||||
from pybase16384 import encode_to_string, decode_from_string
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os, sys
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
|
||||
from configs import Config, singleton_variable
|
||||
from infer.lib.audio import load_audio
|
||||
from configs import CPUConfig, singleton_variable
|
||||
|
||||
from .pipeline import Pipeline
|
||||
from .utils import load_hubert
|
||||
|
||||
from infer.lib.audio import load_audio
|
||||
|
||||
|
||||
class TorchSeedContext:
|
||||
def __init__(self, seed):
|
||||
@@ -102,8 +95,9 @@ def model_hash(config, tgt_sr, net_g, if_f0, version):
|
||||
audio_max = np.abs(audio).max() / 0.95
|
||||
if audio_max > 1:
|
||||
np.divide(audio, audio_max, audio)
|
||||
hbt = load_hubert(config.device, config.is_half)
|
||||
audio_opt = pipeline.pipeline(
|
||||
load_hubert(config.device, config.is_half),
|
||||
hbt,
|
||||
net_g,
|
||||
0,
|
||||
audio,
|
||||
@@ -120,6 +114,7 @@ def model_hash(config, tgt_sr, net_g, if_f0, version):
|
||||
version,
|
||||
0.33,
|
||||
)
|
||||
del hbt
|
||||
opt_len = len(audio_opt)
|
||||
diff = 48000 - opt_len
|
||||
n = diff // 2
|
||||
@@ -141,7 +136,8 @@ def model_hash_ckpt(cpt):
|
||||
SynthesizerTrnMs768NSFsid_nono,
|
||||
)
|
||||
|
||||
config = Config()
|
||||
config = CPUConfig()
|
||||
|
||||
with TorchSeedContext(114514):
|
||||
tgt_sr = cpt["config"][-1]
|
||||
if_f0 = cpt.get("f0", 1)
|
||||
@@ -167,7 +163,7 @@ def model_hash_ckpt(cpt):
|
||||
|
||||
h = model_hash(config, tgt_sr, net_g, if_f0, version)
|
||||
|
||||
del net_g
|
||||
del net_g
|
||||
|
||||
return h
|
||||
|
||||
@@ -217,4 +213,9 @@ def hash_similarity(h1: str, h2: str) -> float:
|
||||
|
||||
|
||||
def hash_id(h: str) -> str:
|
||||
return encode_to_string(hashlib.md5(decode_from_string(h)).digest())[:-1]
|
||||
d = decode_from_string(h)
|
||||
if len(d) != half_hash_len * 2:
|
||||
return "invalid hash length"
|
||||
return encode_to_string(
|
||||
np.frombuffer(d, dtype=np.uint64).sum(keepdims=True).tobytes()
|
||||
)[:-2] + encode_to_string(hashlib.md5(d).digest()[:7])
|
||||
|
||||
@@ -8,6 +8,7 @@ logger = logging.getLogger(__name__)
|
||||
from functools import lru_cache
|
||||
from time import time
|
||||
|
||||
import faiss
|
||||
import librosa
|
||||
import numpy as np
|
||||
import parselmouth
|
||||
@@ -330,7 +331,6 @@ class Pipeline(object):
|
||||
and os.path.exists(file_index)
|
||||
and index_rate != 0
|
||||
):
|
||||
if "faiss" not in sys.modules: import faiss
|
||||
try:
|
||||
index = faiss.read_index(file_index)
|
||||
big_npy = index.reconstruct_n(0, index.ntotal)
|
||||
|
||||
@@ -2,8 +2,6 @@ import os
|
||||
|
||||
from fairseq import checkpoint_utils
|
||||
|
||||
from configs import singleton_variable
|
||||
|
||||
|
||||
def get_index_path_from_model(sid):
|
||||
return next(
|
||||
@@ -22,7 +20,6 @@ def get_index_path_from_model(sid):
|
||||
)
|
||||
|
||||
|
||||
@singleton_variable
|
||||
def load_hubert(device, is_half):
|
||||
models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
|
||||
["assets/hubert/hubert_base.pt"],
|
||||
|
||||
Reference in New Issue
Block a user