diff --git a/configs/config.py b/configs/config.py index 9db1126..5232048 100644 --- a/configs/config.py +++ b/configs/config.py @@ -137,11 +137,11 @@ class Config(metaclass=Singleton): logging.warning("Using insecure weight loading for fairseq dictionary") except AttributeError: pass - + @staticmethod def get_default_batch_size() -> int: if not torch.cuda.is_available(): - #TODO: add non-cuda multicards + # TODO: add non-cuda multicards return 1 # 判断是否有能用来训练和加速推理的N卡 ngpu = torch.cuda.device_count() @@ -155,7 +155,10 @@ class Config(metaclass=Singleton): mem.append( int( torch.cuda.get_device_properties(i).total_memory - / 1024 / 1024 / 1024 + 0.4 + / 1024 + / 1024 + / 1024 + + 0.4 ) ) if if_gpu_ok: