1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-05 17:20:25 +08:00

feat(infer): add model hash identification

and optimize infer-web ui
This commit is contained in:
源文雨
2024-06-02 22:47:52 +09:00
parent 7e48279c6c
commit b9ad0258ae
13 changed files with 327 additions and 105 deletions

View File

@@ -74,7 +74,7 @@ from infer.lib.train.losses import (
kl_loss,
)
from infer.lib.train.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from infer.lib.train.process_ckpt import savee
from infer.lib.train.process_ckpt import save_small_model
global_step = 0
@@ -602,7 +602,7 @@ def train_and_evaluate(
% (
hps.name,
epoch,
savee(
save_small_model(
ckpt,
hps.sample_rate,
hps.if_f0,
@@ -626,7 +626,7 @@ def train_and_evaluate(
logger.info(
"saving final ckpt:%s"
% (
savee(
save_small_model(
ckpt, hps.sample_rate, hps.if_f0, hps.name, epoch, hps.version, hps
)
)

View File

@@ -1,3 +1,5 @@
from .pipeline import Pipeline
from .modules import VC
from .utils import get_index_path_from_model, load_hubert
from .info import show_info
from .hash import model_hash_ckpt, hash_id

170
infer/modules/vc/hash.py Normal file
View File

@@ -0,0 +1,170 @@
import numpy as np
import torch
import hashlib
import pathlib
from scipy.fft import fft
from pybase16384 import encode_to_string, decode_from_string
if __name__ == "__main__":
import os, sys
now_dir = os.getcwd()
sys.path.append(now_dir)
from configs.config import Config, singleton_variable
from .pipeline import Pipeline
from .utils import load_hubert
from infer.lib.audio import load_audio
class TorchSeedContext:
def __init__(self, seed):
self.seed = seed
self.state = None
def __enter__(self):
self.state = torch.random.get_rng_state()
torch.manual_seed(self.seed)
def __exit__(self, type, value, traceback):
torch.random.set_rng_state(self.state)
half_hash_len = 512
expand_factor = 65536*8
@singleton_variable
def original_audio_time_minus():
__original_audio = load_audio(str(pathlib.Path(__file__).parent / "lgdsng.mp3"), 16000)
np.divide(__original_audio, np.abs(__original_audio).max(), __original_audio)
return -__original_audio
@singleton_variable
def original_audio_freq_minus():
__original_audio = load_audio(str(pathlib.Path(__file__).parent / "lgdsng.mp3"), 16000)
np.divide(__original_audio, np.abs(__original_audio).max(), __original_audio)
__original_audio = fft(__original_audio)
return -__original_audio
def _cut_u16(n):
if n > 16384: n = 16384 + 16384*(1-np.exp((16384-n)/expand_factor))
elif n < -16384: n = -16384 - 16384*(1-np.exp((n+16384)/expand_factor))
return n
# wave_hash will change time_field, use carefully
def wave_hash(time_field):
np.divide(time_field, np.abs(time_field).max(), time_field)
if len(time_field) != 48000:
raise Exception("time not hashable")
freq_field = fft(time_field)
if len(freq_field) != 48000:
raise Exception("freq not hashable")
np.add(time_field, original_audio_time_minus(), out=time_field)
np.add(freq_field, original_audio_freq_minus(), out=freq_field)
hash = np.zeros(half_hash_len//2*2, dtype='>i2')
d = 375 * 512 // half_hash_len
for i in range(half_hash_len//4):
a = i*2
b = a+1
x = a + half_hash_len//2
y = x+1
s = np.average(freq_field[i*d:(i+1)*d])
hash[a] = np.int16(_cut_u16(round(32768*np.real(s))))
hash[b] = np.int16(_cut_u16(round(32768*np.imag(s))))
hash[x] = np.int16(_cut_u16(round(32768*np.sum(time_field[i*d:i*d+d//2]))))
hash[y] = np.int16(_cut_u16(round(32768*np.sum(time_field[i*d+d//2:(i+1)*d]))))
return encode_to_string(hash.tobytes())
def audio_hash(file):
return wave_hash(load_audio(file, 16000))
def model_hash(config, tgt_sr, net_g, if_f0, version):
pipeline = Pipeline(tgt_sr, config)
audio = load_audio(str(pathlib.Path(__file__).parent / "lgdsng.mp3"), 16000)
audio_max = np.abs(audio).max() / 0.95
if audio_max > 1:
np.divide(audio, audio_max, audio)
audio_opt = pipeline.pipeline(load_hubert(config.device, config.is_half), net_g, 0, audio,
[0, 0, 0], 6, "rmvpe", "", 0, if_f0, 3, tgt_sr, 16000, 0.25,
version, 0.33)
opt_len = len(audio_opt)
diff = 48000 - opt_len
n = diff//2
if n > 0:
audio_opt = np.pad(audio_opt, (n, n))
elif n < 0:
n = -n
audio_opt = audio_opt[n:-n]
h = wave_hash(audio_opt)
del pipeline, audio, audio_opt
return h
def model_hash_ckpt(cpt):
from infer.lib.infer_pack.models import (
SynthesizerTrnMs256NSFsid,
SynthesizerTrnMs256NSFsid_nono,
SynthesizerTrnMs768NSFsid,
SynthesizerTrnMs768NSFsid_nono,
)
config = Config()
with TorchSeedContext(114514):
tgt_sr = cpt["config"][-1]
if_f0 = cpt.get("f0", 1)
version = cpt.get("version", "v1")
synthesizer_class = {
("v1", 1): SynthesizerTrnMs256NSFsid,
("v1", 0): SynthesizerTrnMs256NSFsid_nono,
("v2", 1): SynthesizerTrnMs768NSFsid,
("v2", 0): SynthesizerTrnMs768NSFsid_nono,
}
net_g = synthesizer_class.get(
(version, if_f0), SynthesizerTrnMs256NSFsid
)(*cpt["config"], is_half=config.is_half)
del net_g.enc_q
net_g.load_state_dict(cpt["weight"], strict=False)
net_g.eval().to(config.device)
if config.is_half:
net_g = net_g.half()
else:
net_g = net_g.float()
h = model_hash(config, tgt_sr, net_g, if_f0, version)
del net_g
return h
def model_hash_from(path):
cpt = torch.load(path, map_location="cpu")
h = model_hash_ckpt(cpt)
del cpt
return h
def _extend_difference(n, a, b):
if n < a: n = a
elif n > b: n = b
n -= a
n /= (b-a)
return n
def hash_similarity(h1: str, h2: str) -> int:
h1b, h2b = decode_from_string(h1), decode_from_string(h2)
if len(h1b) != half_hash_len*2 or len(h2b) != half_hash_len*2:
raise Exception("invalid hash length")
h1n, h2n = np.frombuffer(h1b, dtype='>i2'), np.frombuffer(h2b, dtype='>i2')
d = 0
for i in range(half_hash_len//4):
a = i*2
b = a+1
ax = complex(h1n[a], h1n[b])
bx = complex(h2n[a], h2n[b])
if abs(ax) == 0 or abs(bx) == 0: continue
d += np.abs(ax - bx)
frac = (np.linalg.norm(h1n) * np.linalg.norm(h2n))
cosine = np.dot(h1n.astype(np.float32), h2n.astype(np.float32)) / frac if frac != 0 else 1.0
distance = _extend_difference(np.exp(-d/expand_factor), 0.5, 1.0)
return round((abs(cosine) + distance) / 2, 6)
def hash_id(h: str) -> str:
return encode_to_string(hashlib.md5(decode_from_string(h)).digest())[:-1]

50
infer/modules/vc/info.py Normal file
View File

@@ -0,0 +1,50 @@
import traceback
from i18n.i18n import I18nAuto
from datetime import datetime
import torch
from .hash import model_hash_ckpt, hash_id
i18n = I18nAuto()
def show_model_info(cpt, show_long_id=False):
try:
h = model_hash_ckpt(cpt)
id = hash_id(h)
idread = cpt.get("id", "None")
hread = cpt.get("hash", "None")
if id != idread:
id += "("+i18n("实际计算")+"), "+idread+"("+i18n("从模型中读取")+")"
if not show_long_id: h = i18n("不显示")
elif h != hread:
h += "("+i18n("实际计算")+"), "+hread+"("+i18n("从模型中读取")+")"
txt = f"""{i18n("模型名")}: %s
{i18n("封装时间")}: %s
{i18n("信息")}: %s
{i18n("采样率")}: %s
{i18n("音高引导(f0)")}: %s
{i18n("版本")}: %s
{i18n("ID(短)")}: %s
{i18n("ID(长)")}: %s""" % (
cpt.get("name", "None"),
datetime.fromtimestamp(float(cpt.get("timestamp", 0))),
cpt.get("info", "None"),
cpt.get("sr", "None"),
i18n("") if cpt.get("f0", 0) == 1 else i18n(""),
cpt.get("version", "None"),
id, h
)
except:
txt = traceback.format_exc()
return txt
def show_info(path):
try:
a = torch.load(path, map_location="cpu")
txt = show_model_info(a, show_long_id=True)
del a
except:
txt = traceback.format_exc()
return txt

BIN
infer/modules/vc/lgdsng.mp3 Normal file

Binary file not shown.

View File

@@ -16,7 +16,7 @@ from infer.lib.infer_pack.models import (
SynthesizerTrnMs768NSFsid,
SynthesizerTrnMs768NSFsid_nono,
)
from .info import show_model_info
from .pipeline import Pipeline
from .utils import get_index_path_from_model, load_hubert
@@ -136,6 +136,7 @@ class VC:
to_return_protect1,
index,
index,
show_model_info(self.cpt)
)
if to_return_protect
else {"visible": True, "maximum": n_spk, "__type__": "update"}
@@ -158,6 +159,8 @@ class VC:
):
if input_audio_path is None:
return "You need to upload an audio", None
elif hasattr(input_audio_path, "name"):
input_audio_path = str(input_audio_path.name)
f0_up_key = int(f0_up_key)
try:
audio = load_audio(input_audio_path, 16000)
@@ -170,6 +173,7 @@ class VC:
self.hubert_model = load_hubert(self.config.device, self.config.is_half)
if file_index:
if hasattr(file_index, "name"): file_index = str(file_index.name)
file_index = (
file_index.strip(" ")
.strip('"')
@@ -207,12 +211,12 @@ class VC:
else:
tgt_sr = self.tgt_sr
index_info = (
"Index:\n%s." % file_index
"Index: %s." % file_index
if os.path.exists(file_index)
else "Index not used."
)
return (
"Success.\n%s\nTime:\nnpy: %.2fs, f0: %.2fs, infer: %.2fs."
"Success.\n%s\nTime: npy: %.2fs, f0: %.2fs, infer: %.2fs."
% (index_info, *times),
(tgt_sr, audio_opt),
)