From ef9c8eb656b5b768c8092a9795589bf331395e53 Mon Sep 17 00:00:00 2001
From: Yongkun Li <44155313+Nativu5@users.noreply.github.com>
Date: Fri, 7 Feb 2025 15:26:01 +0800
Subject: [PATCH] fix: Add weight whitelist support for torch 2.6 (#110)

---
Reviewer notes (everything between the three-dash line above and the first
"diff --git" line is ignored when the patch is applied):

* The `singleton_variable` decorator rebound each decorated class name to a
  plain wrapper function, so class attributes could not be reached through
  the name. It is replaced by a `Singleton` metaclass; `Config` stays a real
  class, which is what makes `Config.use_insecure_load()` callable without
  creating an instance.
* Function-level uses of `singleton_variable` in infer/modules/vc/hash.py are
  replaced by `functools.lru_cache(None)`.
* `Config.use_insecure_load()` registers fairseq's `Dictionary` through
  `torch.serialization.add_safe_globals`. `AttributeError` is swallowed,
  presumably so torch builds without that API keep working (confirm).
* The torch version gate compares only (major, minor). Converting every
  dotted component with int() raises ValueError on pre-release or nightly
  version strings such as "2.6.0a0" or "2.7.0.dev20250101".
* The warning in `use_insecure_load` goes through the module-level `logger`,
  like every other log call in configs/config.py, instead of the root logger.
* Leading whitespace in the hunks below was reconstructed from the Python
  syntax; check it against the target tree before applying.

 configs/__init__.py                          |  2 +-
 configs/config.py                            | 35 ++++++++++++++------
 i18n/i18n.py                                 |  5 ++-
 infer/modules/train/extract_f0_print.py      |  4 +++
 infer/modules/train/extract_feature_print.py |  4 +++
 infer/modules/train/preprocess.py            |  2 ++
 infer/modules/vc/hash.py                     | 14 ++++----
 7 files changed, 45 insertions(+), 21 deletions(-)

diff --git a/configs/__init__.py b/configs/__init__.py
index e9ab5e6..47d9863 100644
--- a/configs/__init__.py
+++ b/configs/__init__.py
@@ -1 +1 @@
-from .config import singleton_variable, Config, CPUConfig
+from .config import Singleton, Config, CPUConfig
diff --git a/configs/config.py b/configs/config.py
index 0e8057c..3fb3410 100644
--- a/configs/config.py
+++ b/configs/config.py
@@ -22,18 +22,16 @@ version_config_list = [
 ]
 
 
-def singleton_variable(func):
-    def wrapper(*args, **kwargs):
-        if wrapper.instance is None:
-            wrapper.instance = func(*args, **kwargs)
-        return wrapper.instance
+class Singleton(type):
+    _instances = {}
 
-    wrapper.instance = None
-    return wrapper
+    def __call__(cls, *args, **kwargs):
+        if cls not in cls._instances:
+            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
+        return cls._instances[cls]
 
 
-@singleton_variable
-class Config:
+class Config(metaclass=Singleton):
     def __init__(self):
         self.device = "cuda:0"
         self.is_half = True
@@ -129,6 +127,16 @@ class Config:
         else:
             return False
 
+    @staticmethod
+    def use_insecure_load():
+        try:
+            from fairseq.data.dictionary import Dictionary
+
+            torch.serialization.add_safe_globals([Dictionary])
+            logger.warning("Using insecure weight loading for fairseq dictionary")
+        except AttributeError:
+            pass
+
     def use_fp32_config(self):
         for config_file in version_config_list:
             self.json_config[config_file]["train"]["fp16_run"] = False
@@ -210,15 +218,20 @@ class Config:
         else:
             if self.instead:
                 logger.info(f"Use {self.instead} instead")
+
         logger.info(
             "Half-precision floating-point: %s, device: %s"
             % (self.is_half, self.device)
         )
+
+        # torch >= 2.6 gate on (major, minor) only, so "2.6.0a0"-style builds parse
+        if tuple(map(int, torch.__version__.split("+")[0].split(".")[:2])) >= (2, 6):
+            self.use_insecure_load()
+
         return x_pad, x_query, x_center, x_max
 
 
-@singleton_variable
-class CPUConfig:
+class CPUConfig(metaclass=Singleton):
     def __init__(self):
         self.device = "cpu"
         self.is_half = False
diff --git a/i18n/i18n.py b/i18n/i18n.py
index ea555da..6590c81 100644
--- a/i18n/i18n.py
+++ b/i18n/i18n.py
@@ -1,7 +1,7 @@
 import json
 import locale
 import os
-from configs import singleton_variable
+from configs import Singleton
 
 
 def load_language_list(language):
@@ -10,8 +10,7 @@ def load_language_list(language):
     return language_list
 
 
-@singleton_variable
-class I18nAuto:
+class I18nAuto(metaclass=Singleton):
     def __init__(self, language=None):
         if language in ["Auto", None]:
             language = locale.getdefaultlocale(
diff --git a/infer/modules/train/extract_f0_print.py b/infer/modules/train/extract_f0_print.py
index 24fcb95..d6f8714 100644
--- a/infer/modules/train/extract_f0_print.py
+++ b/infer/modules/train/extract_f0_print.py
@@ -96,6 +96,10 @@ if __name__ == "__main__":
     # exp_dir=r"E:\codes\py39\dataset\mi-test"
     # n_p=16
     # f = open("%s/log_extract_f0.log"%exp_dir, "w")
+
+    from configs import Config
+    Config.use_insecure_load()
+
     printt(" ".join(sys.argv))
     featureInput = FeatureInput(is_half, device)
     paths = []
diff --git a/infer/modules/train/extract_feature_print.py b/infer/modules/train/extract_feature_print.py
index 3b59974..1106f4c 100644
--- a/infer/modules/train/extract_feature_print.py
+++ b/infer/modules/train/extract_feature_print.py
@@ -23,11 +23,15 @@ else:
     os.environ["CUDA_VISIBLE_DEVICES"] = str(i_gpu)
     version = sys.argv[6]
     is_half = sys.argv[7].lower() == "true"
+
 import fairseq
 import numpy as np
 import torch
 import torch.nn.functional as F
 
+from configs import Config
+Config.use_insecure_load()
+
 if "privateuseone" not in device:
     device = "cpu"
     if torch.cuda.is_available():
diff --git a/infer/modules/train/preprocess.py b/infer/modules/train/preprocess.py
index 86625c9..e7b5368 100644
--- a/infer/modules/train/preprocess.py
+++ b/infer/modules/train/preprocess.py
@@ -142,4 +142,6 @@ def preprocess_trainset(inp_root, sr, n_p, exp_dir, per):
 
 
 if __name__ == "__main__":
+    from configs import Config
+    Config.use_insecure_load()
     preprocess_trainset(inp_root, sr, n_p, exp_dir, per)
diff --git a/infer/modules/vc/hash.py b/infer/modules/vc/hash.py
index d8d7e7c..77b6ae1 100644
--- a/infer/modules/vc/hash.py
+++ b/infer/modules/vc/hash.py
@@ -2,10 +2,12 @@ import numpy as np
 import torch
 import hashlib
 import pathlib
+
+from functools import lru_cache
 from scipy.fft import fft
 from pybase16384 import encode_to_string, decode_from_string
 
-from configs import CPUConfig, singleton_variable
+from configs import CPUConfig
 from rvc.synthesizer import get_synthesizer
 
 from .pipeline import Pipeline
@@ -29,27 +31,27 @@ half_hash_len = 512
 expand_factor = 65536 * 8
 
 
-@singleton_variable
+@lru_cache(None)  # None means the cache is unbounded
 def original_audio_storage():
     return np.load(pathlib.Path(__file__).parent / "lgdsng.npz")
 
 
-@singleton_variable
+@lru_cache(None)
 def original_audio():
     return original_audio_storage()["a"]
 
 
-@singleton_variable
+@lru_cache(None)
 def original_audio_time_minus():
     return original_audio_storage()["t"]
 
 
-@singleton_variable
+@lru_cache(None)
 def original_audio_freq_minus():
     return original_audio_storage()["f"]
 
 
-@singleton_variable
+@lru_cache(None)
 def original_rmvpe_f0():
     x = original_audio_storage()
     return x["pitch"], x["pitchf"]