1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-05 09:10:25 +08:00

fix: Add weight whitelist support for torch 2.6 (#110)

This commit is contained in:
Yongkun Li
2025-02-07 15:26:01 +08:00
committed by GitHub
parent e1aeb16630
commit ef9c8eb656
7 changed files with 45 additions and 21 deletions

View File

@@ -22,18 +22,16 @@ version_config_list = [
]
def singleton_variable(func):
def wrapper(*args, **kwargs):
if wrapper.instance is None:
wrapper.instance = func(*args, **kwargs)
return wrapper.instance
class Singleton(type):
_instances = {}
wrapper.instance = None
return wrapper
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
@singleton_variable
class Config:
class Config(metaclass=Singleton):
def __init__(self):
self.device = "cuda:0"
self.is_half = True
@@ -129,6 +127,16 @@ class Config:
else:
return False
@staticmethod
def use_insecure_load():
try:
from fairseq.data.dictionary import Dictionary
torch.serialization.add_safe_globals([Dictionary])
logging.warning("Using insecure weight loading for fairseq dictionary")
except AttributeError:
pass
def use_fp32_config(self):
for config_file in version_config_list:
self.json_config[config_file]["train"]["fp16_run"] = False
@@ -210,15 +218,20 @@ class Config:
else:
if self.instead:
logger.info(f"Use {self.instead} instead")
logger.info(
"Half-precision floating-point: %s, device: %s"
% (self.is_half, self.device)
)
# Check if the pytorch is 2.6 or higher
if tuple(map(int, torch.__version__.split("+")[0].split("."))) >= (2, 6, 0):
self.use_insecure_load()
return x_pad, x_query, x_center, x_max
@singleton_variable
class CPUConfig:
class CPUConfig(metaclass=Singleton):
def __init__(self):
self.device = "cpu"
self.is_half = False