diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index dbaadf723..9ebd0efe2 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -49,7 +49,7 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co
 parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
 parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
 parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
-parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
+parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use. All other devices will not be visible.")
 parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
 cm_group = parser.add_mutually_exclusive_group()
 cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index ba670b16d..837aa907a 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -15,13 +15,14 @@
     You should have received a copy of the GNU General Public License
     along with this program.  If not, see <https://www.gnu.org/licenses/>.
 """
-
+from __future__ import annotations
 import torch
 from enum import Enum
 import math
 import os
 import logging
+import copy
 import comfy.utils
 import comfy.model_management
 import comfy.model_detection
@@ -38,7 +39,7 @@ import comfy.ldm.hydit.controlnet
 import comfy.ldm.flux.controlnet
 import comfy.ldm.qwen_image.controlnet
 import comfy.cldm.dit_embedder
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Union
 if TYPE_CHECKING:
     from comfy.hooks import HookGroup
@@ -64,6 +65,18 @@ class StrengthType(Enum):
     CONSTANT = 1
     LINEAR_UP = 2
 
+class ControlIsolation:
+    '''Temporarily set a ControlBase object's previous_controlnet to None to prevent cascading calls.'''
+    def __init__(self, control: ControlBase):
+        self.control = control
+        self.orig_previous_controlnet = control.previous_controlnet
+
+    def __enter__(self):
+        self.control.previous_controlnet = None
+
+    def __exit__(self, *args):
+        self.control.previous_controlnet = self.orig_previous_controlnet
+
 class ControlBase:
     def __init__(self):
         self.cond_hint_original = None
@@ -77,7 +90,7 @@ class ControlBase:
         self.compression_ratio = 8
         self.upscale_algorithm = 'nearest-exact'
         self.extra_args = {}
-        self.previous_controlnet = None
+        self.previous_controlnet: Union[ControlBase, None] = None
         self.extra_conds = []
         self.strength_type = StrengthType.CONSTANT
         self.concat_mask = False
@@ -85,6 +98,7 @@ class ControlBase:
         self.extra_concat = None
         self.extra_hooks: HookGroup = None
         self.preprocess_image = lambda a: a
+        self.multigpu_clones: dict[torch.device, ControlBase] = {}
 
     def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
         self.cond_hint_original = cond_hint
@@ -111,17 +125,38 @@ class ControlBase:
     def cleanup(self):
         if self.previous_controlnet is not None:
             self.previous_controlnet.cleanup()
-
+        for device_cnet in self.multigpu_clones.values():
+            with ControlIsolation(device_cnet):
+                device_cnet.cleanup()
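+        # (illustrative) ControlIsolation above is a plain context manager: inside the
+        # 'with' block the clone's previous_controlnet is None, so each device clone
+        # cleans up only itself instead of re-walking the shared chain, which the
+        # recursive call at the top of cleanup() has already traversed.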
        self.cond_hint = None
        self.extra_concat = None
        self.timestep_range = None

    def get_models(self):
        out = []
+        for device_cnet in self.multigpu_clones.values():
+            out += device_cnet.get_models_only_self()
        if self.previous_controlnet is not None:
            out += self.previous_controlnet.get_models()
        return out

+    def get_models_only_self(self):
+        'Calls get_models, but temporarily sets previous_controlnet to None.'
+        with ControlIsolation(self):
+            return self.get_models()
+
+    def get_instance_for_device(self, device):
+        'Returns instance of this Control object intended for selected device.'
+        return self.multigpu_clones.get(device, self)
+
+    def deepclone_multigpu(self, load_device, autoregister=False):
+        '''
+        Create a deep clone of this Control object where the model(s) are set to another device.
+
+        When autoregister is set to True, the deep clone is also added to the multigpu_clones dict.
+        '''
+        raise NotImplementedError("Classes inheriting from ControlBase should define their own deepclone_multigpu function.")
+
    def get_extra_hooks(self):
        out = []
        if self.extra_hooks is not None:
@@ -130,7 +165,7 @@ class ControlBase:
            out += self.previous_controlnet.get_extra_hooks()
        return out

-    def copy_to(self, c):
+    def copy_to(self, c: ControlBase):
        c.cond_hint_original = self.cond_hint_original
        c.strength = self.strength
        c.timestep_percent_range = self.timestep_percent_range
@@ -284,6 +319,14 @@ class ControlNet(ControlBase):
        self.copy_to(c)
        return c

+    def deepclone_multigpu(self, load_device, autoregister=False):
+        c = self.copy()
+        c.control_model = copy.deepcopy(c.control_model)
+        c.control_model_wrapped = comfy.model_patcher.ModelPatcher(c.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
+        if autoregister:
+            self.multigpu_clones[load_device] = c
+        return c
+
    def get_models(self):
        out = super().get_models()
        out.append(self.control_model_wrapped)
@@ -906,6 +949,14 @@ class T2IAdapter(ControlBase):
        self.copy_to(c)
        return c

+    def deepclone_multigpu(self, load_device, autoregister=False):
+        c = self.copy()
+        c.t2i_model = copy.deepcopy(c.t2i_model)
+        c.device = load_device
+        if autoregister:
+            self.multigpu_clones[load_device] = c
+        return c
+
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
    compression_ratio = 8
    upscale_algorithm = 'nearest-exact'
diff --git a/comfy/model_management.py b/comfy/model_management.py
index bcf1399c4..46261a0ed 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -15,6 +15,7 @@
    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
""" +from __future__ import annotations import psutil import logging @@ -32,6 +33,11 @@ import comfy.memory_management import comfy.utils import comfy.quant_ops +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from comfy.model_patcher import ModelPatcher + + class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram NO_VRAM = 1 #Very low vram: enable all the options to save vram @@ -206,6 +212,25 @@ def get_torch_device(): else: return torch.device(torch.cuda.current_device()) +def get_all_torch_devices(exclude_current=False): + global cpu_state + devices = [] + if cpu_state == CPUState.GPU: + if is_nvidia(): + for i in range(torch.cuda.device_count()): + devices.append(torch.device(i)) + elif is_intel_xpu(): + for i in range(torch.xpu.device_count()): + devices.append(torch.device(i)) + elif is_ascend_npu(): + for i in range(torch.npu.device_count()): + devices.append(torch.device(i)) + else: + devices.append(get_torch_device()) + if exclude_current: + devices.remove(get_torch_device()) + return devices + def get_total_memory(dev=None, torch_total_too=False): global directml_enabled if dev is None: @@ -494,9 +519,13 @@ try: logging.info("Device: {}".format(get_torch_device_name(get_torch_device()))) except: logging.warning("Could not pick default device.") +try: + for device in get_all_torch_devices(exclude_current=True): + logging.info("Device: {}".format(get_torch_device_name(device))) +except: + pass - -current_loaded_models = [] +current_loaded_models: list[LoadedModel] = [] def module_size(module): module_mem = 0 @@ -529,7 +558,7 @@ def module_mmap_residency(module, free=False): return mmap_touched_mem, module_mem class LoadedModel: - def __init__(self, model): + def __init__(self, model: ModelPatcher): self._set_model(model) self.device = model.load_device self.real_model = None @@ -537,7 +566,7 @@ class LoadedModel: self.model_finalizer = None self._patcher_finalizer = None - def _set_model(self, model): + def _set_model(self, model: ModelPatcher): self._model = weakref.ref(model) if model.parent is not None: self._parent_model = weakref.ref(model.parent) @@ -548,6 +577,7 @@ class LoadedModel: model = self._parent_model() if model is not None: self._set_model(model) + self.device = model.load_device @property def model(self): @@ -1794,7 +1824,34 @@ def soft_empty_cache(force=False): torch.cuda.ipc_collect() def unload_all_models(): - free_memory(1e30, get_torch_device()) + for device in get_all_torch_devices(): + free_memory(1e30, device) + +def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True, all_devices=False): + 'Unload only model and its clones - primarily for multigpu cloning purposes.' 
+ initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy() + additional_models = [] + if unload_additional_models: + additional_models = model.get_nested_additional_models() + keep_loaded = [] + for loaded_model in initial_keep_loaded: + if loaded_model.model is not None: + if model.clone_base_uuid == loaded_model.model.clone_base_uuid: + continue + # check additional models if they are a match + skip = False + for add_model in additional_models: + if add_model.clone_base_uuid == loaded_model.model.clone_base_uuid: + skip = True + break + if skip: + continue + keep_loaded.append(loaded_model) + if not all_devices: + free_memory(1e30, get_torch_device(), keep_loaded) + else: + for device in get_all_torch_devices(): + free_memory(1e30, device, keep_loaded) def debug_memory_summary(): if is_amd() or is_nvidia(): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 93d19d6fe..092bc6a79 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -23,6 +23,7 @@ import inspect import logging import math import uuid +import copy from typing import Callable, Optional import torch @@ -75,12 +76,15 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_ def create_model_options_clone(orig_model_options: dict): return comfy.patcher_extension.copy_nested_dicts(orig_model_options) -def create_hook_patches_clone(orig_hook_patches): +def create_hook_patches_clone(orig_hook_patches, copy_tuples=False): new_hook_patches = {} for hook_ref in orig_hook_patches: new_hook_patches[hook_ref] = {} for k in orig_hook_patches[hook_ref]: new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:] + if copy_tuples: + for i in range(len(new_hook_patches[hook_ref][k])): + new_hook_patches[hook_ref][k][i] = tuple(new_hook_patches[hook_ref][k][i]) return new_hook_patches def wipe_lowvram_weight(m): @@ -272,7 +276,10 @@ class ModelPatcher: self.is_clip = False self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed - self.cached_patcher_init: tuple[Callable, tuple] | None = None + self.cached_patcher_init: tuple[Callable, tuple] | tuple[Callable, tuple, int] | None = None + self.is_multigpu_base_clone = False + self.clone_base_uuid = uuid.uuid4() + if not hasattr(self.model, 'model_loaded_weight_memory'): self.model.model_loaded_weight_memory = 0 @@ -312,7 +319,8 @@ class ModelPatcher: #than pays for CFG. 
So return everything both torch and Aimdo could give us
        aimdo_mem = 0
        if comfy.memory_management.aimdo_enabled:
-            aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze()
+            aimdo_device = device.index if getattr(device, "type", None) == "cuda" else None
+            aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze(aimdo_device)
        return comfy.model_management.get_free_memory(device) + aimdo_mem

    def get_clone_model_override(self):
@@ -326,6 +334,8 @@ class ModelPatcher:
        if self.cached_patcher_init is None:
            raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.")
        temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
+        if len(self.cached_patcher_init) > 2:
+            temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
        model_override = temp_model_patcher.get_clone_model_override()
        if model_override is None:
            model_override = self.get_clone_model_override()
@@ -384,19 +394,98 @@ class ModelPatcher:
        n.hook_mode = self.hook_mode

        n.cached_patcher_init = self.cached_patcher_init
+        n.is_multigpu_base_clone = self.is_multigpu_base_clone
+        n.clone_base_uuid = self.clone_base_uuid

        for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
            callback(self, n)
        return n

+    def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None):
+        logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.")
+        comfy.model_management.unload_model_and_clones(self)
+        n = self.clone()
+        # set load device, if present
+        if new_load_device is not None:
+            n.load_device = new_load_device
+        if self.cached_patcher_init is not None:
+            temp_model_patcher: ModelPatcher | list[ModelPatcher] = self.cached_patcher_init[0](*self.cached_patcher_init[1])
+            if len(self.cached_patcher_init) > 2:
+                temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
+            n.model = temp_model_patcher.model
+        else:
+            n.model = copy.deepcopy(n.model)
+        # unlike for a normal clone, backup dicts that shared the same ref must not stay shared here;
+        # otherwise, patchers that hold deep copies of base models would erroneously influence each other.
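+        # (illustrative) if two device clones aliased one backup dict, restoring
+        # original weights on the cuda:1 copy could write entries recorded while
+        # patching the cuda:0 copy, corrupting both.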
+ n.backup = copy.deepcopy(n.backup) + n.object_patches_backup = copy.deepcopy(n.object_patches_backup) + n.hook_backup = copy.deepcopy(n.hook_backup) + # multigpu clone should not have multigpu additional_models entry + n.remove_additional_models("multigpu") + # multigpu_clone all stored additional_models; make sure circular references are properly handled + if models_cache is None: + models_cache = {} + for key, model_list in n.additional_models.items(): + for i in range(len(model_list)): + add_model = n.additional_models[key][i] + if add_model.clone_base_uuid not in models_cache: + models_cache[add_model.clone_base_uuid] = add_model.deepclone_multigpu(new_load_device=new_load_device, models_cache=models_cache) + n.additional_models[key][i] = models_cache[add_model.clone_base_uuid] + for callback in self.get_all_callbacks(CallbacksMP.ON_DEEPCLONE_MULTIGPU): + callback(self, n) + return n + + def match_multigpu_clones(self): + multigpu_models = self.get_additional_models_with_key("multigpu") + if len(multigpu_models) > 0: + new_multigpu_models = [] + for mm in multigpu_models: + # clone main model, but bring over relevant props from existing multigpu clone + n = self.clone() + n.load_device = mm.load_device + n.backup = mm.backup + n.object_patches_backup = mm.object_patches_backup + n.hook_backup = mm.hook_backup + n.model = mm.model + n.is_multigpu_base_clone = mm.is_multigpu_base_clone + n.remove_additional_models("multigpu") + orig_additional_models: dict[str, list[ModelPatcher]] = comfy.patcher_extension.copy_nested_dicts(n.additional_models) + n.additional_models = comfy.patcher_extension.copy_nested_dicts(mm.additional_models) + # figure out which additional models are not present in multigpu clone + models_cache = {} + for mm_add_model in mm.get_additional_models(): + models_cache[mm_add_model.clone_base_uuid] = mm_add_model + remove_models_uuids = set(list(models_cache.keys())) + for key, model_list in orig_additional_models.items(): + for orig_add_model in model_list: + if orig_add_model.clone_base_uuid not in models_cache: + models_cache[orig_add_model.clone_base_uuid] = orig_add_model.deepclone_multigpu(new_load_device=n.load_device, models_cache=models_cache) + existing_list = n.get_additional_models_with_key(key) + existing_list.append(models_cache[orig_add_model.clone_base_uuid]) + n.set_additional_models(key, existing_list) + if orig_add_model.clone_base_uuid in remove_models_uuids: + remove_models_uuids.remove(orig_add_model.clone_base_uuid) + # remove duplicate additional models + for key, model_list in n.additional_models.items(): + new_model_list = [x for x in model_list if x.clone_base_uuid not in remove_models_uuids] + n.set_additional_models(key, new_model_list) + for callback in self.get_all_callbacks(CallbacksMP.ON_MATCH_MULTIGPU_CLONES): + callback(self, n) + new_multigpu_models.append(n) + self.set_additional_models("multigpu", new_multigpu_models) + def is_clone(self, other): if hasattr(other, 'model') and self.model is other.model: return True return False - def clone_has_same_weights(self, clone: 'ModelPatcher'): - if not self.is_clone(clone): - return False + def clone_has_same_weights(self, clone: ModelPatcher, allow_multigpu=False): + if allow_multigpu: + if self.clone_base_uuid != clone.clone_base_uuid: + return False + else: + if not self.is_clone(clone): + return False if self.current_hooks != clone.current_hooks: return False @@ -1171,7 +1260,7 @@ class ModelPatcher: return self.additional_models.get(key, []) def get_additional_models(self): - 
all_models = [] + all_models: list[ModelPatcher] = [] for models in self.additional_models.values(): all_models.extend(models) return all_models @@ -1225,9 +1314,13 @@ class ModelPatcher: for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN): callback(self) - def prepare_state(self, timestep): + def prepare_state(self, timestep, model_options, ignore_multigpu=False): for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE): - callback(self, timestep) + callback(self, timestep, model_options, ignore_multigpu) + if not ignore_multigpu and "multigpu_clones" in model_options: + for p in model_options["multigpu_clones"].values(): + p: ModelPatcher + p.prepare_state(timestep, model_options, ignore_multigpu=True) def restore_hook_patches(self): if self.hook_patches_backup is not None: @@ -1240,12 +1333,18 @@ class ModelPatcher: def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]): curr_t = t[0] reset_current_hooks = False + multigpu_kf_changed_cache = None transformer_options = model_options.get("transformer_options", {}) for hook in hook_group.hooks: changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options) # if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref; # this will cause the weights to be recalculated when sampling if changed: + # cache changed for multigpu usage + if "multigpu_clones" in model_options: + if multigpu_kf_changed_cache is None: + multigpu_kf_changed_cache = [] + multigpu_kf_changed_cache.append(hook) # reset current_hooks if contains hook that changed if self.current_hooks is not None: for current_hook in self.current_hooks.hooks: @@ -1257,6 +1356,28 @@ class ModelPatcher: self.cached_hook_patches.pop(cached_group) if reset_current_hooks: self.patch_hooks(None) + if "multigpu_clones" in model_options: + for p in model_options["multigpu_clones"].values(): + p: ModelPatcher + p._handle_changed_hook_keyframes(multigpu_kf_changed_cache) + + def _handle_changed_hook_keyframes(self, kf_changed_cache: list[comfy.hooks.Hook]): + 'Used to handle multigpu behavior inside prepare_hook_patches_current_keyframe.' + if kf_changed_cache is None: + return + reset_current_hooks = False + # reset current_hooks if contains hook that changed + for hook in kf_changed_cache: + if self.current_hooks is not None: + for current_hook in self.current_hooks.hooks: + if current_hook == hook: + reset_current_hooks = True + break + for cached_group in list(self.cached_hook_patches.keys()): + if cached_group.contains(hook): + self.cached_hook_patches.pop(cached_group) + if reset_current_hooks: + self.patch_hooks(None) def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None, registered: comfy.hooks.HookGroup = None): diff --git a/comfy/multigpu.py b/comfy/multigpu.py new file mode 100644 index 000000000..096270c12 --- /dev/null +++ b/comfy/multigpu.py @@ -0,0 +1,230 @@ +from __future__ import annotations +import queue +import threading +import torch +import logging + +from collections import namedtuple +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from comfy.model_patcher import ModelPatcher +import comfy.utils +import comfy.patcher_extension +import comfy.model_management + + +class MultiGPUThreadPool: + """Persistent thread pool for multi-GPU work distribution. + + Maintains one worker thread per extra GPU device. 
Each thread calls
+    torch.cuda.set_device() once at startup so that compiled kernel caches
+    (inductor/triton) stay warm across diffusion steps.
+    """
+
+    def __init__(self, devices: list[torch.device]):
+        self._workers: list[threading.Thread] = []
+        self._work_queues: dict[torch.device, queue.Queue] = {}
+        self._result_queues: dict[torch.device, queue.Queue] = {}
+
+        for device in devices:
+            wq = queue.Queue()
+            rq = queue.Queue()
+            self._work_queues[device] = wq
+            self._result_queues[device] = rq
+            t = threading.Thread(target=self._worker_loop, args=(device, wq, rq), daemon=True)
+            t.start()
+            self._workers.append(t)
+
+    def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue):
+        try:
+            torch.cuda.set_device(device)
+        except Exception as e:
+            logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}")
+            result_q.put((None, e))
+            return
+        while True:
+            item = work_q.get()
+            if item is None:
+                break
+            fn, args, kwargs = item
+            try:
+                result = fn(*args, **kwargs)
+                result_q.put((result, None))
+            except Exception as e:
+                result_q.put((None, e))
+
+    def submit(self, device: torch.device, fn, *args, **kwargs):
+        self._work_queues[device].put((fn, args, kwargs))
+
+    def get_result(self, device: torch.device):
+        return self._result_queues[device].get()
+
+    @property
+    def devices(self) -> list[torch.device]:
+        return list(self._work_queues.keys())
+
+    def shutdown(self):
+        for wq in self._work_queues.values():
+            wq.put(None)  # sentinel
+        for t in self._workers:
+            t.join(timeout=5.0)
+
+
+class GPUOptions:
+    def __init__(self, device_index: int, relative_speed: float):
+        self.device_index = device_index
+        self.relative_speed = relative_speed
+
+    def clone(self):
+        return GPUOptions(self.device_index, self.relative_speed)
+
+    def create_dict(self):
+        return {
+            "relative_speed": self.relative_speed
+        }
+
+class GPUOptionsGroup:
+    def __init__(self):
+        self.options: dict[int, GPUOptions] = {}
+
+    def add(self, info: GPUOptions):
+        self.options[info.device_index] = info
+
+    def clone(self):
+        c = GPUOptionsGroup()
+        for opt in self.options.values():
+            c.add(opt)
+        return c
+
+    def register(self, model: ModelPatcher):
+        opts_dict = {}
+        # get devices that are valid for this model
+        devices: list[torch.device] = [model.load_device]
+        for extra_model in model.get_additional_models_with_key("multigpu"):
+            extra_model: ModelPatcher
+            devices.append(extra_model.load_device)
+        # create dictionary with actual device mapped to its GPUOptions
+        device_opts_list: list[GPUOptions] = []
+        for device in devices:
+            device_opts = self.options.get(device.index, GPUOptions(device_index=device.index, relative_speed=1.0))
+            opts_dict[device] = device_opts.create_dict()
+            device_opts_list.append(device_opts)
+        # make relative_speed relative to 1.0
+        min_speed = min([x.relative_speed for x in device_opts_list])
+        for value in opts_dict.values():
+            value['relative_speed'] /= min_speed
+        model.model_options['multigpu_options'] = opts_dict
+
+
+def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False):
+    'Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.'
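+    # (illustrative sketch, assuming a 2-GPU machine) the returned patcher gains an
+    # additional_models["multigpu"] entry holding one deepclone per extra device:
+    #     model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus=2)
+    #     [m.load_device for m in model.get_additional_models_with_key("multigpu")]
+    #     # -> [device(type='cuda', index=1)]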
+ model = model.clone() + # check if multigpu is already prepared - get the load devices from them if possible to exclude + skip_devices = set() + multigpu_models = model.get_additional_models_with_key("multigpu") + if len(multigpu_models) > 0: + for mm in multigpu_models: + skip_devices.add(mm.load_device) + skip_devices = list(skip_devices) + + full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True) + limit_extra_devices = full_extra_devices[:max_gpus-1] + extra_devices = limit_extra_devices.copy() + # exclude skipped devices + for skip in skip_devices: + if skip in extra_devices: + extra_devices.remove(skip) + # create new deepclones + if len(extra_devices) > 0: + for device in extra_devices: + device_patcher = None + if reuse_loaded: + # check if there are any ModelPatchers currently loaded that could be referenced here after a clone + loaded_models: list[ModelPatcher] = comfy.model_management.loaded_models() + for lm in loaded_models: + if lm.model is not None and lm.clone_base_uuid == model.clone_base_uuid and lm.load_device == device: + device_patcher = lm.clone() + logging.info(f"Reusing loaded deepclone of {device_patcher.model.__class__.__name__} for {device}") + break + if device_patcher is None: + device_patcher = model.deepclone_multigpu(new_load_device=device) + device_patcher.is_multigpu_base_clone = True + multigpu_models = model.get_additional_models_with_key("multigpu") + multigpu_models.append(device_patcher) + model.set_additional_models("multigpu", multigpu_models) + model.match_multigpu_clones() + if gpu_options is None: + gpu_options = GPUOptionsGroup() + gpu_options.register(model) + else: + logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.") + # TODO: only keep model clones that don't go 'past' the intended max_gpu count + # multigpu_models = model.get_additional_models_with_key("multigpu") + # new_multigpu_models = [] + # for m in multigpu_models: + # if m.load_device in limit_extra_devices: + # new_multigpu_models.append(m) + # model.set_additional_models("multigpu", new_multigpu_models) + # persist skip_devices for use in sampling code + # if len(skip_devices) > 0 or "multigpu_skip_devices" in model.model_options: + # model.model_options["multigpu_skip_devices"] = skip_devices + return model + + +LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time']) +def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None): + 'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.' 
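+    # (illustrative worked example) total_work=9 with relative speeds
+    # {cuda:0: 1.0, cuda:1: 2.0}: total_speed is 3.0, so the raw shares are
+    # (9*1.0)/3.0 = 3.0 and (9*2.0)/3.0 = 6.0, and round_preserved() rounds
+    # them while keeping their sum equal to 9.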
+ opts_dict = model_options['multigpu_options'] + devices = list(model_options['multigpu_clones'].keys()) + speed_per_device = [] + work_per_device = [] + # get sum of each device's relative_speed + total_speed = 0.0 + for opts in opts_dict.values(): + total_speed += opts['relative_speed'] + # get relative work for each device; + # obtained by w = (W*r)/R + for device in devices: + relative_speed = opts_dict[device]['relative_speed'] + relative_work = (total_work*relative_speed) / total_speed + speed_per_device.append(relative_speed) + work_per_device.append(relative_work) + # relative work must be expressed in whole numbers, but likely is a decimal; + # perform rounding while maintaining total sum equal to total work (sum of relative works) + work_per_device = round_preserved(work_per_device) + dict_work_per_device = {} + for device, relative_work in zip(devices, work_per_device): + dict_work_per_device[device] = relative_work + if not return_idle_time: + return LoadBalance(dict_work_per_device, None) + # divide relative work by relative speed to get estimated completion time of said work by each device; + # time here is relative and does not correspond to real-world units + completion_time = [w/r for w,r in zip(work_per_device, speed_per_device)] + # calculate relative time spent by the devices waiting on each other after their work is completed + idle_time = abs(min(completion_time) - max(completion_time)) + # if need to compare work idle time, need to normalize to a common total work + if work_normalized: + idle_time *= (work_normalized/total_work) + + return LoadBalance(dict_work_per_device, idle_time) + +def round_preserved(values: list[float]): + 'Round all values in a list, preserving the combined sum of values.' + # get floor of values; casting to int does it too + floored = [int(x) for x in values] + total_floored = sum(floored) + # get remainder to distribute + remainder = round(sum(values)) - total_floored + # pair values with fractional portions + fractional = [(i, x-floored[i]) for i, x in enumerate(values)] + # sort by fractional part in descending order + fractional.sort(key=lambda x: x[1], reverse=True) + # distribute the remainder + for i in range(remainder): + index = fractional[i][0] + floored[index] += 1 + return floored diff --git a/comfy/patcher_extension.py b/comfy/patcher_extension.py index 5ee4d5ee5..4b276b175 100644 --- a/comfy/patcher_extension.py +++ b/comfy/patcher_extension.py @@ -3,6 +3,8 @@ from typing import Callable class CallbacksMP: ON_CLONE = "on_clone" + ON_DEEPCLONE_MULTIGPU = "on_deepclone_multigpu" + ON_MATCH_MULTIGPU_CLONES = "on_match_multigpu_clones" ON_LOAD = "on_load_after" ON_DETACH = "on_detach_after" ON_CLEANUP = "on_cleanup" diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 42ee08fb2..37e546722 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -20,7 +20,6 @@ try: if cuda_version < (13,): ck.registry.disable("cuda") logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.") - ck.registry.disable("triton") for k, v in ck.list_backends().items(): logging.info(f"Found comfy_kitchen backend {k}: {v}") diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index bbba09e26..6f5447d95 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -1,16 +1,18 @@ from __future__ import annotations +import torch import uuid import math import collections import comfy.model_management import comfy.conds +import comfy.model_patcher import comfy.utils import comfy.hooks 
import comfy.patcher_extension

from typing import TYPE_CHECKING
if TYPE_CHECKING:
-    from comfy.model_patcher import ModelPatcher
    from comfy.model_base import BaseModel
+    from comfy.model_patcher import ModelPatcher
    from comfy.controlnet import ControlBase

def prepare_mask(noise_mask, shape, device):
@@ -118,6 +120,47 @@ def cleanup_additional_models(models):
        if hasattr(m, 'cleanup'):
            m.cleanup()

+def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model: ModelPatcher, model_options: dict[str]):
+    '''If multigpu acceleration is required, creates deepclones of ControlNets per device; GLIGEN is not yet handled.'''
+    multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu")
+    if len(multigpu_models) == 0:
+        return
+    extra_devices = [x.load_device for x in multigpu_models]
+    # handle controlnets
+    controlnets: set[ControlBase] = set()
+    for k in conds:
+        for kk in conds[k]:
+            if 'control' in kk:
+                controlnets.add(kk['control'])
+    if len(controlnets) > 0:
+        # first, unload all controlnet clones
+        for cnet in list(controlnets):
+            cnet_models = cnet.get_models()
+            for cm in cnet_models:
+                comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True)
+
+        # next, make sure each controlnet has a deepclone for all relevant devices
+        for cnet in controlnets:
+            curr_cnet = cnet
+            while curr_cnet is not None:
+                for device in extra_devices:
+                    if device not in curr_cnet.multigpu_clones:
+                        curr_cnet.deepclone_multigpu(device, autoregister=True)
+                curr_cnet = curr_cnet.previous_controlnet
+        # since all device clones are now present, recreate the linked list for cloned cnets per device
+        for cnet in controlnets:
+            curr_cnet = cnet
+            while curr_cnet is not None:
+                prev_cnet = curr_cnet.previous_controlnet
+                for device in extra_devices:
+                    device_cnet = curr_cnet.get_instance_for_device(device)
+                    prev_device_cnet = None
+                    if prev_cnet is not None:
+                        prev_device_cnet = prev_cnet.get_instance_for_device(device)
+                    device_cnet.set_previous_controlnet(prev_device_cnet)
+                curr_cnet = prev_cnet
+    # potentially handle gligen - since not widely used, ignored for now
+
def estimate_memory(model, noise_shape, conds):
    cond_shapes = collections.defaultdict(list)
    cond_shapes_min = {}
@@ -142,7 +185,8 @@ def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None
    return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)

def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
-    real_model: BaseModel = None
+    model.match_multigpu_clones()
+    preprocess_multigpu_conds(conds, model, model_options)
    models, inference_memory = get_additional_models(conds, model.model_dtype())
    models += get_additional_models_from_model_options(model_options)
    models += model.get_nested_additional_models()  # TODO: does this require inference_memory update?
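+    # (illustrative) after preprocess_multigpu_conds() above, every ControlNet chain
+    # has a parallel per-device chain: for a chain a -> b,
+    # a.get_instance_for_device(dev).previous_controlnet is b's clone on dev, so a
+    # work unit running on dev never crosses devices mid-chain.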
@@ -154,7 +198,7 @@ def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=Non memory_required += inference_memory minimum_memory_required += inference_memory comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load) - real_model = model.model + real_model: BaseModel = model.model return real_model, conds, models @@ -200,3 +244,18 @@ def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict): comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name], copy_dict1=False) return to_load_options + +def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_models: list[ModelPatcher], model_options: dict): + ''' + In case multigpu acceleration is enabled, prep ModelPatchers for each device. + ''' + multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_base_clone] + if len(multigpu_patchers) > 0: + multigpu_dict: dict[torch.device, ModelPatcher] = {} + multigpu_dict[model_patcher.load_device] = model_patcher + for x in multigpu_patchers: + x.hook_patches = comfy.model_patcher.create_hook_patches_clone(model_patcher.hook_patches, copy_tuples=True) + x.hook_mode = model_patcher.hook_mode # match main model's hook_mode + multigpu_dict[x.load_device] = x + model_options["multigpu_clones"] = multigpu_dict + return multigpu_patchers diff --git a/comfy/samplers.py b/comfy/samplers.py index 0a4d062db..8ebf1c496 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -1,7 +1,9 @@ from __future__ import annotations + +import comfy.model_management from .k_diffusion import sampling as k_diffusion_sampling from .extra_samplers import uni_pc -from typing import TYPE_CHECKING, Callable, NamedTuple +from typing import TYPE_CHECKING, Callable, NamedTuple, Any if TYPE_CHECKING: from comfy.model_patcher import ModelPatcher from comfy.model_base import BaseModel @@ -16,6 +18,7 @@ import comfy.model_patcher import comfy.patcher_extension import comfy.hooks import comfy.context_windows +import comfy.multigpu import comfy.utils import scipy.stats import numpy @@ -141,7 +144,7 @@ def can_concat_cond(c1, c2): return cond_equal_size(c1.conditioning, c2.conditioning) -def cond_cat(c_list): +def cond_cat(c_list, device=None): temp = {} for x in c_list: for k in x: @@ -153,6 +156,8 @@ def cond_cat(c_list): for k in temp: conds = temp[k] out[k] = conds[0].concat(conds[1:]) + if device is not None and hasattr(out[k], 'to'): + out[k] = out[k].to(device) return out @@ -212,7 +217,9 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc ) return executor.execute(model, conds, x_in, timestep, model_options) -def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options): +def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): + if 'multigpu_clones' in model_options: + return _calc_cond_batch_multigpu(model, conds, x_in, timestep, model_options) out_conds = [] out_counts = [] # separate conds by matching hooks @@ -244,7 +251,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens if has_default_conds: finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options) - model.current_patcher.prepare_state(timestep) + 
model.current_patcher.prepare_state(timestep, model_options)

    # run every hooked_to_run separately
    for hooks, to_run in hooked_to_run.items():
@@ -345,6 +352,212 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
    return out_conds

+def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
+    out_conds = []
+    out_counts = []
+    # separate conds by matching hooks
+    hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
+    default_conds = []
+    has_default_conds = False
+
+    output_device = x_in.device
+
+    for i in range(len(conds)):
+        out_conds.append(torch.zeros_like(x_in))
+        out_counts.append(torch.ones_like(x_in) * 1e-37)
+
+        cond = conds[i]
+        default_c = []
+        if cond is not None:
+            for x in cond:
+                if 'default' in x:
+                    default_c.append(x)
+                    has_default_conds = True
+                    continue
+                p = get_area_and_mult(x, x_in, timestep)
+                if p is None:
+                    continue
+                if p.hooks is not None:
+                    model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
+                hooked_to_run.setdefault(p.hooks, list())
+                hooked_to_run[p.hooks] += [(p, i)]
+        default_conds.append(default_c)
+
+    if has_default_conds:
+        finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
+
+    model.current_patcher.prepare_state(timestep, model_options)
+
+    devices = [dev_m for dev_m in model_options['multigpu_clones'].keys()]
+    device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {}
+
+    total_conds = 0
+    for to_run in hooked_to_run.values():
+        total_conds += len(to_run)
+    conds_per_device = max(1, math.ceil(total_conds/len(devices)))
+    index_device = 0
+    current_device = devices[index_device]
+    # run every hooked_to_run separately
+    for hooks, to_run in hooked_to_run.items():
+        while len(to_run) > 0:
+            current_device = devices[index_device % len(devices)]
+            batched_to_run = device_batched_hooked_to_run.setdefault(current_device, [])
+            # keep track of conds currently scheduled onto this device
+            batched_to_run_length = 0
+            for btr in batched_to_run:
+                batched_to_run_length += len(btr[1])
+
+            first = to_run[0]
+            first_shape = first[0][0].shape
+            to_batch_temp = []
+            # make sure not over conds_per_device limit when creating temp batch
+            for x in range(len(to_run)):
+                if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < (conds_per_device - batched_to_run_length):
+                    to_batch_temp += [x]
+
+            to_batch_temp.reverse()
+            to_batch = to_batch_temp[:1]
+
+            free_memory = comfy.model_management.get_free_memory(current_device)
+            for i in range(1, len(to_batch_temp) + 1):
+                batch_amount = to_batch_temp[:len(to_batch_temp)//i]
+                input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
+                if model.memory_required(input_shape) * 1.5 < free_memory:
+                    to_batch = batch_amount
+                    break
+            conds_to_batch = []
+            for x in to_batch:
+                conds_to_batch.append(to_run.pop(x))
+            batched_to_run_length += len(conds_to_batch)
+
+            batched_to_run.append((hooks, conds_to_batch))
+            if batched_to_run_length >= conds_per_device:
+                index_device += 1
+
+    class thread_result(NamedTuple):
+        output: Any
+        mult: Any
+        area: Any
+        batch_chunks: int
+        cond_or_uncond: Any
+        error: Exception = None
+
+    def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
+        try:
+            torch.cuda.set_device(device)
+            model_current: BaseModel = model_options["multigpu_clones"][device].model
+            # 
run every hooked_to_run separately + with torch.no_grad(): + for hooks, to_batch in batch_tuple: + input_x = [] + mult = [] + c = [] + cond_or_uncond = [] + uuids = [] + area = [] + control: ControlBase = None + patches = None + for x in to_batch: + o = x + p = o[0] + input_x.append(p.input_x) + mult.append(p.mult) + c.append(p.conditioning) + area.append(p.area) + cond_or_uncond.append(o[1]) + uuids.append(p.uuid) + control = p.control + patches = p.patches + + batch_chunks = len(cond_or_uncond) + input_x = torch.cat(input_x).to(device) + c = cond_cat(c, device=device) + timestep_ = torch.cat([timestep.to(device)] * batch_chunks) + + transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks) + if 'transformer_options' in model_options: + transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options, + model_options['transformer_options'], + copy_dict1=False) + + if patches is not None: + transformer_options["patches"] = comfy.patcher_extension.merge_nested_dicts( + transformer_options.get("patches", {}), + patches + ) + + transformer_options["cond_or_uncond"] = cond_or_uncond[:] + transformer_options["uuids"] = uuids[:] + transformer_options["sigmas"] = timestep.to(device) + transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device) + transformer_options["multigpu_thread_device"] = device + + cast_transformer_options(transformer_options, device=device) + c['transformer_options'] = transformer_options + + if control is not None: + device_control = control.get_instance_for_device(device) + c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options) + + if 'model_function_wrapper' in model_options: + output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks) + else: + output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks) + results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond)) + except Exception as e: + results.append(thread_result(None, None, None, None, None, error=e)) + raise + + + def _handle_batch_pooled(device, batch_tuple): + worker_results = [] + _handle_batch(device, batch_tuple, worker_results) + return worker_results + + results: list[thread_result] = [] + thread_pool: comfy.multigpu.MultiGPUThreadPool = model_options.get("multigpu_thread_pool") + + # Submit all GPU work to pool threads + pool_devices = [] + for device, batch_tuple in device_batched_hooked_to_run.items(): + if thread_pool is not None: + thread_pool.submit(device, _handle_batch_pooled, device, batch_tuple) + pool_devices.append(device) + else: + # Fallback: no pool, run everything on main thread + _handle_batch(device, batch_tuple, results) + + # Collect results from pool workers + for device in pool_devices: + worker_results, error = thread_pool.get_result(device) + if error is not None: + raise error + results.extend(worker_results) + + for output, mult, area, batch_chunks, cond_or_uncond, error in results: + if error is not None: + raise error + for o in range(batch_chunks): + cond_index = cond_or_uncond[o] + a = area[o] + if a is None: + out_conds[cond_index] += output[o] * mult[o] + out_counts[cond_index] += mult[o] + else: + out_c = out_conds[cond_index] + out_cts = out_counts[cond_index] + dims = len(a) // 2 + for i in range(dims): + out_c = out_c.narrow(i + 2, a[i + dims], a[i]) + out_cts 
= out_cts.narrow(i + 2, a[i + dims], a[i]) + out_c += output[o] * mult[o] + out_cts += mult[o] + + for i in range(len(out_conds)): + out_conds[i] /= out_counts[i] + + return out_conds + def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.") return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options)) @@ -649,6 +862,8 @@ def pre_run_control(model, conds): percent_to_timestep_function = lambda a: s.percent_to_sigma(a) if 'control' in x: x['control'].pre_run(model, percent_to_timestep_function) + for device_cnet in x['control'].multigpu_clones.values(): + device_cnet.pre_run(model, percent_to_timestep_function) def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): cond_cnets = [] @@ -891,7 +1106,9 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): to_load_options = model_options.get("to_load_options", None) if to_load_options is None: return + cast_transformer_options(to_load_options, device, dtype) +def cast_transformer_options(transformer_options: dict[str], device=None, dtype=None): casts = [] if device is not None: casts.append(device) @@ -900,18 +1117,17 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): # if nothing to apply, do nothing if len(casts) == 0: return - # try to call .to on patches - if "patches" in to_load_options: - patches = to_load_options["patches"] + if "patches" in transformer_options: + patches = transformer_options["patches"] for name in patches: patch_list = patches[name] for i in range(len(patch_list)): if hasattr(patch_list[i], "to"): for cast in casts: patch_list[i] = patch_list[i].to(cast) - if "patches_replace" in to_load_options: - patches = to_load_options["patches_replace"] + if "patches_replace" in transformer_options: + patches = transformer_options["patches_replace"] for name in patches: patch_list = patches[name] for k in patch_list: @@ -921,8 +1137,8 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): # try to call .to on any wrappers/callbacks wrappers_and_callbacks = ["wrappers", "callbacks"] for wc_name in wrappers_and_callbacks: - if wc_name in to_load_options: - wc: dict[str, list] = to_load_options[wc_name] + if wc_name in transformer_options: + wc: dict[str, list] = transformer_options[wc_name] for wc_dict in wc.values(): for wc_list in wc_dict.values(): for i in range(len(wc_list)): @@ -930,7 +1146,6 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): for cast in casts: wc_list[i] = wc_list[i].to(cast) - class CFGGuider: def __init__(self, model_patcher: ModelPatcher): self.model_patcher = model_patcher @@ -985,16 +1200,31 @@ class CFGGuider: self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) device = self.model_patcher.load_device - noise = noise.to(device=device, dtype=torch.float32) - latent_image = latent_image.to(device=device, dtype=torch.float32) - sigmas = sigmas.to(device) - cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype()) + multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options) + + # Create persistent thread pool for all GPU devices (main + extras) + if multigpu_patchers: + 
extra_devices = [p.load_device for p in multigpu_patchers]
+            all_devices = [device] + extra_devices
+            self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(all_devices)

        try:
+            noise = noise.to(device=device, dtype=torch.float32)
+            latent_image = latent_image.to(device=device, dtype=torch.float32)
+            sigmas = sigmas.to(device)
+            cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
+
            self.model_patcher.pre_run()
+            for multigpu_patcher in multigpu_patchers:
+                multigpu_patcher.pre_run()
            output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
        finally:
+            thread_pool = self.model_options.pop("multigpu_thread_pool", None)
+            if thread_pool is not None:
+                thread_pool.shutdown()
            self.model_patcher.cleanup()
+            for multigpu_patcher in multigpu_patchers:
+                multigpu_patcher.cleanup()

            comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
            del self.inner_model
diff --git a/comfy/sd.py b/comfy/sd.py
index e573804a5..0ce450ace 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -1604,10 +1604,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
    out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
    if out is None:
        raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
-    if output_model and out[0] is not None:
-        out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options))
-    if output_clip and out[1] is not None:
-        out[1].patcher.cached_patcher_init = (load_checkpoint_guess_config_clip_only, (ckpt_path, embedding_directory, model_options, te_model_options))
+    if output_model and out[0] is not None:
+        out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0)
    return out

def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py
new file mode 100644
index 000000000..5d24952bf
--- /dev/null
+++ b/comfy_extras/nodes_multigpu.py
@@ -0,0 +1,89 @@
+from __future__ import annotations
+
+from inspect import cleandoc
+from typing import TYPE_CHECKING
+from typing_extensions import override
+
+from comfy_api.latest import ComfyExtension, io
+
+if TYPE_CHECKING:
+    from comfy.model_patcher import ModelPatcher
+import comfy.multigpu
+
+
+class MultiGPUCFGSplitNode(io.ComfyNode):
+    """
+    Prepares model to have sampling accelerated via splitting work units.
+
+    Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes.
+
+    Other than those exceptions, this node can be placed in any order.
+    """
+
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="MultiGPU_WorkUnits",
+            display_name="MultiGPU CFG Split",
+            category="advanced/multigpu",
+            description=cleandoc(cls.__doc__),
+            inputs=[
+                io.Model.Input("model"),
+                io.Int.Input("max_gpus", default=2, min=1, step=1),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput:
+        model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True)
+        return io.NodeOutput(model)
+
+
+class MultiGPUOptionsNode(io.ComfyNode):
+    """
+    Select the relative speed of GPUs in the special case they have significantly different performance from one another.
+    """
+
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="MultiGPU_Options",
+            display_name="MultiGPU Options",
+            category="advanced/multigpu",
+            description=cleandoc(cls.__doc__),
+            inputs=[
+                io.Int.Input("device_index", default=0, min=0, max=64),
+                io.Float.Input("relative_speed", default=1.0, min=0.0, step=0.01),
+                io.Custom("GPU_OPTIONS").Input("gpu_options", optional=True),
+            ],
+            outputs=[
+                io.Custom("GPU_OPTIONS").Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup = None) -> io.NodeOutput:
+        if not gpu_options:
+            gpu_options = comfy.multigpu.GPUOptionsGroup()
+        else:
+            gpu_options = gpu_options.clone()
+
+        opt = comfy.multigpu.GPUOptions(device_index=device_index, relative_speed=relative_speed)
+        gpu_options.add(opt)
+
+        return io.NodeOutput(gpu_options)
+
+
+class MultiGPUExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            MultiGPUCFGSplitNode,
+            # MultiGPUOptionsNode,
+        ]
+
+
+async def comfy_entrypoint() -> MultiGPUExtension:
+    return MultiGPUExtension()
diff --git a/main.py b/main.py
index 12b04719d..de145a1e9 100644
--- a/main.py
+++ b/main.py
@@ -192,7 +192,7 @@ import gc

if 'torch' in sys.modules:
    logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
-
+import torch
import comfy.utils
import execution
@@ -210,7 +210,7 @@ import comfy.model_patcher
if args.enable_dynamic_vram or (enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl()):
    if (not args.enable_dynamic_vram) and (comfy.model_management.torch_version_numeric < (2, 8)):
        logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
-    elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
+    elif comfy_aimdo.control.init_devices(range(torch.cuda.device_count())):
        if args.verbose == 'DEBUG':
            comfy_aimdo.control.set_log_debug()
        elif args.verbose == 'CRITICAL':
diff --git a/nodes.py b/nodes.py
index 299b3d758..9eced6838 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2412,6 +2412,7 @@ async def init_builtin_extra_nodes():
        "nodes_lt_audio.py",
        "nodes_lt.py",
        "nodes_hooks.py",
+        "nodes_multigpu.py",
        "nodes_load_3d.py",
        "nodes_cosmos.py",
        "nodes_video.py",
diff --git a/requirements.txt b/requirements.txt
index 3de845f48..a8e4f9bf6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -23,7 +23,7 @@ SQLAlchemy
filelock
av>=14.2.0
comfy-kitchen>=0.2.8
-comfy-aimdo>=0.2.12
+comfy-aimdo==0.0.213
requests
simpleeval>=1.0.0
blake3
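
Illustrative usage sketch of the comfy.multigpu helpers introduced above (not part of the patch; assumes a machine with two CUDA devices, and the device indices and relative speeds are made up):

    import torch
    import comfy.multigpu

    d0, d1 = torch.device("cuda:0"), torch.device("cuda:1")

    # split 9 cond evaluations across devices with unequal relative speeds;
    # load_balance_devices only reads the keys of "multigpu_clones" and the
    # per-device "relative_speed" entries of "multigpu_options"
    model_options = {
        "multigpu_clones": {d0: None, d1: None},
        "multigpu_options": {d0: {"relative_speed": 1.0}, d1: {"relative_speed": 2.0}},
    }
    balance = comfy.multigpu.load_balance_devices(model_options, total_work=9)
    print(balance.work_per_device)  # roughly {cuda:0: 3, cuda:1: 6}

    # persistent per-device workers, as used by CFGGuider during sampling
    pool = comfy.multigpu.MultiGPUThreadPool([d0, d1])
    pool.submit(d1, lambda a, b: a + b, 2, 3)
    result, error = pool.get_result(d1)  # (5, None) on success
    pool.shutdown()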