mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-08 03:55:47 +08:00
feat: model embed author
This commit is contained in:
@@ -43,6 +43,7 @@ def save_small_model(ckpt, sr, if_f0, name, epoch, version, hps):
|
||||
opt["info"] = "%sepoch" % epoch
|
||||
opt["name"] = name
|
||||
opt["timestamp"] = int(time())
|
||||
if hps.author: opt["author"] = hps.author
|
||||
opt["sr"] = sr
|
||||
opt["f0"] = if_f0
|
||||
opt["version"] = version
|
||||
@@ -55,7 +56,7 @@ def save_small_model(ckpt, sr, if_f0, name, epoch, version, hps):
|
||||
return traceback.format_exc()
|
||||
|
||||
|
||||
def extract_small_model(path, name, sr, if_f0, info, version):
|
||||
def extract_small_model(path, name, author, sr, if_f0, info, version):
|
||||
try:
|
||||
ckpt = torch.load(path, map_location="cpu")
|
||||
if "model" in ckpt:
|
||||
@@ -178,6 +179,7 @@ def extract_small_model(path, name, sr, if_f0, info, version):
|
||||
opt["info"] = info
|
||||
opt["name"] = name
|
||||
opt["timestamp"] = int(time())
|
||||
if author: opt["author"] = author
|
||||
opt["version"] = version
|
||||
opt["sr"] = sr
|
||||
opt["f0"] = int(if_f0)
|
||||
@@ -214,6 +216,13 @@ def merge(path1, path2, alpha1, sr, f0, info, name, version):
|
||||
continue
|
||||
opt["weight"][key] = a[key]
|
||||
return opt
|
||||
|
||||
def authors(c1, c2):
|
||||
a1, a2 = c1.get("author", ""), c2.get("author", "")
|
||||
if a1 == a2: return a1
|
||||
if not a1: a1 = "Unknown"
|
||||
if not a2: a2 = "Unknown"
|
||||
return f"{a1} & {a2}"
|
||||
|
||||
ckpt1 = torch.load(path1, map_location="cpu")
|
||||
ckpt2 = torch.load(path2, map_location="cpu")
|
||||
@@ -242,8 +251,7 @@ def merge(path1, path2, alpha1, sr, f0, info, name, version):
|
||||
opt["weight"][key] = (
|
||||
alpha1 * (ckpt1[key].float()) + (1 - alpha1) * (ckpt2[key].float())
|
||||
).half()
|
||||
# except:
|
||||
# pdb.set_trace()
|
||||
author = authors(ckpt1, ckpt2)
|
||||
opt["config"] = cfg
|
||||
"""
|
||||
if(sr=="40k"):opt["config"] = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 10, 2, 2], 512, [16, 16, 4, 4,4], 109, 256, 40000]
|
||||
@@ -252,6 +260,7 @@ def merge(path1, path2, alpha1, sr, f0, info, name, version):
|
||||
"""
|
||||
opt["name"] = name
|
||||
opt["timestamp"] = int(time())
|
||||
if author: opt["author"] = author
|
||||
opt["sr"] = sr
|
||||
opt["f0"] = 1 if f0 == i18n("是") else 0
|
||||
opt["version"] = version
|
||||
|
||||
@@ -358,6 +358,9 @@ def get_hparams(init=True):
|
||||
required=True,
|
||||
help="if caching the dataset in GPU memory, 1 or 0",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-a", "--author", type=str, default="", help="Model author"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
name = args.experiment_dir
|
||||
@@ -383,9 +386,10 @@ def get_hparams(init=True):
|
||||
hparams.save_every_weights = args.save_every_weights
|
||||
hparams.if_cache_data_in_gpu = args.if_cache_data_in_gpu
|
||||
hparams.data.training_files = "%s/filelist.txt" % experiment_dir
|
||||
hparams.author = args.author
|
||||
return hparams
|
||||
|
||||
|
||||
"""
|
||||
def get_hparams_from_dir(model_dir):
|
||||
config_save_path = os.path.join(model_dir, "config.json")
|
||||
with open(config_save_path, "r") as f:
|
||||
@@ -429,7 +433,7 @@ def check_git_hash(model_dir):
|
||||
)
|
||||
else:
|
||||
open(path, "w").write(cur_hash)
|
||||
|
||||
"""
|
||||
|
||||
def get_logger(model_dir, filename="train.log"):
|
||||
global logger
|
||||
|
||||
Reference in New Issue
Block a user