mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-07 19:40:44 +08:00
fix(train): cannot extract feature on non-cuda devices (fix #123)
This commit is contained in:
@@ -53,6 +53,7 @@ class Config(metaclass=Singleton):
|
||||
self.instead = ""
|
||||
self.preprocess_per = 3.7
|
||||
self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()
|
||||
self.default_batch_size = self.get_default_batch_size()
|
||||
|
||||
@staticmethod
|
||||
def load_config_json() -> dict:
|
||||
@@ -136,6 +137,32 @@ class Config(metaclass=Singleton):
|
||||
logging.warning("Using insecure weight loading for fairseq dictionary")
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def get_default_batch_size() -> int:
|
||||
if not torch.cuda.is_available():
|
||||
#TODO: add non-cuda multicards
|
||||
return 1
|
||||
# 判断是否有能用来训练和加速推理的N卡
|
||||
ngpu = torch.cuda.device_count()
|
||||
if not ngpu:
|
||||
return 1
|
||||
mem = []
|
||||
if_gpu_ok = False
|
||||
|
||||
for i in range(ngpu):
|
||||
if_gpu_ok = True # 至少有一张能用的N卡
|
||||
mem.append(
|
||||
int(
|
||||
torch.cuda.get_device_properties(i).total_memory
|
||||
/ 1024 / 1024 / 1024 + 0.4
|
||||
)
|
||||
)
|
||||
if if_gpu_ok:
|
||||
default_batch_size = min(mem) // 2
|
||||
else:
|
||||
default_batch_size = 1
|
||||
return default_batch_size
|
||||
|
||||
def use_fp32_config(self):
|
||||
for config_file in version_config_list:
|
||||
|
||||
Reference in New Issue
Block a user