From e936e24a91fa57ec509fb5ba49657c3aec48ea2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Fri, 14 Jun 2024 22:01:39 +0900 Subject: [PATCH] optimize(infer): move ipex into rvc --- configs/config.py | 11 +---------- infer/modules/train/train.py | 3 +-- rvc/__init__.py | 4 ++++ rvc/ipex/__init__.py | 8 ++++++++ {infer/modules => rvc}/ipex/attention.py | 0 {infer/modules => rvc}/ipex/gradscaler.py | 1 + {infer/modules => rvc}/ipex/hijacks.py | 1 + infer/modules/ipex/__init__.py => rvc/ipex/init.py | 2 ++ 8 files changed, 18 insertions(+), 12 deletions(-) create mode 100644 rvc/ipex/__init__.py rename {infer/modules => rvc}/ipex/attention.py (100%) rename {infer/modules => rvc}/ipex/gradscaler.py (99%) rename {infer/modules => rvc}/ipex/hijacks.py (99%) rename infer/modules/ipex/__init__.py => rvc/ipex/init.py (99%) diff --git a/configs/config.py b/configs/config.py index 2f2f438..14d3827 100644 --- a/configs/config.py +++ b/configs/config.py @@ -6,16 +6,7 @@ import shutil from multiprocessing import cpu_count import torch - -try: - import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import - - if torch.xpu.is_available(): - from infer.modules.ipex import ipex_init - - ipex_init() -except Exception: # pylint: disable=broad-exception-caught - pass +# TODO: move device selection into rvc import logging logger = logging.getLogger(__name__) diff --git a/infer/modules/train/train.py b/infer/modules/train/train.py index 2d70f83..4e30005 100644 --- a/infer/modules/train/train.py +++ b/infer/modules/train/train.py @@ -24,8 +24,7 @@ try: import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import if torch.xpu.is_available(): - from infer.modules.ipex import ipex_init - from infer.modules.ipex.gradscaler import gradscaler_init + from rvc.ipex import ipex_init, gradscaler_init from torch.xpu.amp import autocast GradScaler = gradscaler_init() diff --git a/rvc/__init__.py b/rvc/__init__.py index e69de29..bfe7152 100644 --- a/rvc/__init__.py +++ b/rvc/__init__.py @@ -0,0 +1,4 @@ +from . import ipex +import sys +del sys.modules["rvc.ipex"] + diff --git a/rvc/ipex/__init__.py b/rvc/ipex/__init__.py new file mode 100644 index 0000000..12288d0 --- /dev/null +++ b/rvc/ipex/__init__.py @@ -0,0 +1,8 @@ +try: + import torch + if torch.xpu.is_available(): + from .init import ipex_init + ipex_init() + from .gradscaler import gradscaler_init +except Exception: # pylint: disable=broad-exception-caught + pass diff --git a/infer/modules/ipex/attention.py b/rvc/ipex/attention.py similarity index 100% rename from infer/modules/ipex/attention.py rename to rvc/ipex/attention.py diff --git a/infer/modules/ipex/gradscaler.py b/rvc/ipex/gradscaler.py similarity index 99% rename from infer/modules/ipex/gradscaler.py rename to rvc/ipex/gradscaler.py index 7875151..a88fe22 100644 --- a/infer/modules/ipex/gradscaler.py +++ b/rvc/ipex/gradscaler.py @@ -1,4 +1,5 @@ from collections import defaultdict + import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import diff --git a/infer/modules/ipex/hijacks.py b/rvc/ipex/hijacks.py similarity index 99% rename from infer/modules/ipex/hijacks.py rename to rvc/ipex/hijacks.py index fc75f0c..51e72e7 100644 --- a/infer/modules/ipex/hijacks.py +++ b/rvc/ipex/hijacks.py @@ -1,5 +1,6 @@ import contextlib import importlib + import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import diff --git a/infer/modules/ipex/__init__.py b/rvc/ipex/init.py similarity index 99% rename from infer/modules/ipex/__init__.py rename to rvc/ipex/init.py index cd27bc1..82b93f6 100644 --- a/infer/modules/ipex/__init__.py +++ b/rvc/ipex/init.py @@ -1,8 +1,10 @@ import os import sys import contextlib + import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import + from .hijacks import ipex_hijacks from .attention import attention_init