diff --git a/configs/config.py b/configs/config.py index 2f2f438..14d3827 100644 --- a/configs/config.py +++ b/configs/config.py @@ -6,16 +6,7 @@ import shutil from multiprocessing import cpu_count import torch - -try: - import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import - - if torch.xpu.is_available(): - from infer.modules.ipex import ipex_init - - ipex_init() -except Exception: # pylint: disable=broad-exception-caught - pass +# TODO: move device selection into rvc import logging logger = logging.getLogger(__name__) diff --git a/infer/modules/train/train.py b/infer/modules/train/train.py index 2d70f83..4e30005 100644 --- a/infer/modules/train/train.py +++ b/infer/modules/train/train.py @@ -24,8 +24,7 @@ try: import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import if torch.xpu.is_available(): - from infer.modules.ipex import ipex_init - from infer.modules.ipex.gradscaler import gradscaler_init + from rvc.ipex import ipex_init, gradscaler_init from torch.xpu.amp import autocast GradScaler = gradscaler_init() diff --git a/rvc/__init__.py b/rvc/__init__.py index e69de29..bfe7152 100644 --- a/rvc/__init__.py +++ b/rvc/__init__.py @@ -0,0 +1,4 @@ +from . import ipex +import sys +del sys.modules["rvc.ipex"] + diff --git a/rvc/ipex/__init__.py b/rvc/ipex/__init__.py new file mode 100644 index 0000000..12288d0 --- /dev/null +++ b/rvc/ipex/__init__.py @@ -0,0 +1,8 @@ +try: + import torch + if torch.xpu.is_available(): + from .init import ipex_init + ipex_init() + from .gradscaler import gradscaler_init +except Exception: # pylint: disable=broad-exception-caught + pass diff --git a/infer/modules/ipex/attention.py b/rvc/ipex/attention.py similarity index 100% rename from infer/modules/ipex/attention.py rename to rvc/ipex/attention.py diff --git a/infer/modules/ipex/gradscaler.py b/rvc/ipex/gradscaler.py similarity index 99% rename from infer/modules/ipex/gradscaler.py rename to rvc/ipex/gradscaler.py index 7875151..a88fe22 100644 --- a/infer/modules/ipex/gradscaler.py +++ b/rvc/ipex/gradscaler.py @@ -1,4 +1,5 @@ from collections import defaultdict + import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import diff --git a/infer/modules/ipex/hijacks.py b/rvc/ipex/hijacks.py similarity index 99% rename from infer/modules/ipex/hijacks.py rename to rvc/ipex/hijacks.py index fc75f0c..51e72e7 100644 --- a/infer/modules/ipex/hijacks.py +++ b/rvc/ipex/hijacks.py @@ -1,5 +1,6 @@ import contextlib import importlib + import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import diff --git a/infer/modules/ipex/__init__.py b/rvc/ipex/init.py similarity index 99% rename from infer/modules/ipex/__init__.py rename to rvc/ipex/init.py index cd27bc1..82b93f6 100644 --- a/infer/modules/ipex/__init__.py +++ b/rvc/ipex/init.py @@ -1,8 +1,10 @@ import os import sys import contextlib + import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import + from .hijacks import ipex_hijacks from .attention import attention_init