From 6d8fa05f86fd4715d635d89701ab99bfa2515a4b Mon Sep 17 00:00:00 2001 From: Jairo Correa Date: Sat, 14 Oct 2023 13:45:19 -0300 Subject: [PATCH] Command option to set different devices for extensions --- comfy/cli_args.py | 1 + comfy/model_management.py | 12 ++++++++++++ comfy/utils.py | 11 +++++++++++ 3 files changed, 24 insertions(+) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index d86557646..39aee7a92 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -43,6 +43,7 @@ parser.add_argument("--input-directory", type=str, default=None, help="Set the C parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.") parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.") parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.") +parser.add_argument("--extension-device", type=str, default=None, help="Set the device for extensions in the format 'extension:device;extension:device;...'.") cm_group = parser.add_mutually_exclusive_group() cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).") cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.") diff --git a/comfy/model_management.py b/comfy/model_management.py index c24c7b27e..145cf035b 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -41,6 +41,12 @@ if args.directml is not None: # torch_directml.disable_tiled_resources(True) lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default. 
+extensions_devices = {} +if args.extension_device is not None: + for ext_dev in args.extension_device.split(";"): + ext, dev = ext_dev.split(":", 1) + extensions_devices[ext] = dev + try: import intel_extension_for_pytorch as ipex if torch.xpu.is_available(): @@ -69,6 +75,12 @@ def is_intel_xpu(): def get_torch_device(): global directml_enabled global cpu_state + global extensions_devices + import comfy.utils + extension = comfy.utils.get_extension_calling() + if extension is not None and extension in extensions_devices: + return torch.device(extensions_devices[extension]) + if directml_enabled: global directml_device return directml_device diff --git a/comfy/utils.py b/comfy/utils.py index df016ef9e..ecc3ea6f4 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -4,8 +4,19 @@ import struct import comfy.checkpoint_pickle import safetensors.torch import numpy as np +import inspect +import re from PIL import Image +def get_extension_calling(): + for frame in inspect.stack(): + if "/custom_nodes/" in frame.filename.replace("\\", "/"): + stack_module = inspect.getmodule(frame[0]) + if stack_module: + return re.sub(r".*\.?custom_nodes\.([^\.]+).*", r"\1", stack_module.__name__).split(".")[0] + + return None + def load_torch_file(ckpt, safe_load=False, device=None): if device is None: device = torch.device("cpu")