diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 47b8174f4..9bda414d1 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -49,7 +49,7 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.") parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.") parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.") -parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.") +parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use, as a comma-separated list (e.g. '0' or '0,1'). All other devices will not be visible.") parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.") cm_group = parser.add_mutually_exclusive_group() cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).") diff --git a/comfy/controlnet.py b/comfy/controlnet.py index ba670b16d..6dbbaa959 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -15,13 +15,14 @@ You should have received a copy of the GNU General Public License along with this program. If not, see . """ - +from __future__ import annotations import torch from enum import Enum import math import os import logging +import copy import comfy.utils import comfy.model_management import comfy.model_detection @@ -38,7 +39,7 @@ import comfy.ldm.hydit.controlnet import comfy.ldm.flux.controlnet import comfy.ldm.qwen_image.controlnet import comfy.cldm.dit_embedder -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union if TYPE_CHECKING: from comfy.hooks import HookGroup @@ -64,6 +65,18 @@ class StrengthType(Enum): CONSTANT = 1 LINEAR_UP = 2 +class ControlIsolation: + '''Temporarily set a ControlBase object's previous_controlnet to None to prevent cascading calls.''' + def __init__(self, control: ControlBase): + self.control = control + self.orig_previous_controlnet = control.previous_controlnet + + def __enter__(self): + self.control.previous_controlnet = None + + def __exit__(self, *args): + self.control.previous_controlnet = self.orig_previous_controlnet + class ControlBase: def __init__(self): self.cond_hint_original = None @@ -77,7 +90,7 @@ class ControlBase: self.compression_ratio = 8 self.upscale_algorithm = 'nearest-exact' self.extra_args = {} - self.previous_controlnet = None + self.previous_controlnet: Union[ControlBase, None] = None self.extra_conds = [] self.strength_type = StrengthType.CONSTANT self.concat_mask = False @@ -85,6 +98,7 @@ class ControlBase: self.extra_concat = None self.extra_hooks: HookGroup = None self.preprocess_image = lambda a: a + self.multigpu_clones: dict[torch.device, ControlBase] = {} def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]): self.cond_hint_original = cond_hint @@ -111,17 +125,38 @@ class ControlBase: def cleanup(self): if self.previous_controlnet is not None: self.previous_controlnet.cleanup() - + for device_cnet in self.multigpu_clones.values(): + with ControlIsolation(device_cnet): + device_cnet.cleanup() self.cond_hint = None self.extra_concat = None self.timestep_range = None def get_models(self): out = [] + for device_cnet in self.multigpu_clones.values(): + out += device_cnet.get_models_only_self() if self.previous_controlnet is not None: out += self.previous_controlnet.get_models() return out + def get_models_only_self(self): + 'Calls get_models, but temporarily sets previous_controlnet to None.' + with ControlIsolation(self): + return self.get_models() + + def get_instance_for_device(self, device): + 'Returns instance of this Control object intended for selected device.' + return self.multigpu_clones.get(device, self) + + def deepclone_multigpu(self, load_device, autoregister=False): + ''' + Create deep clone of Control object where model(s) is set to other devices. + + When autoregister is set to True, the deep clone is also added to multigpu_clones dict. + ''' + raise NotImplementedError("Classes inheriting from ControlBase should define their own deepclone_multigpu funtion.") + def get_extra_hooks(self): out = [] if self.extra_hooks is not None: @@ -130,7 +165,7 @@ class ControlBase: out += self.previous_controlnet.get_extra_hooks() return out - def copy_to(self, c): + def copy_to(self, c: ControlBase): c.cond_hint_original = self.cond_hint_original c.strength = self.strength c.timestep_percent_range = self.timestep_percent_range @@ -284,6 +319,14 @@ class ControlNet(ControlBase): self.copy_to(c) return c + def deepclone_multigpu(self, load_device, autoregister=False): + c = self.copy() + c.control_model = copy.deepcopy(c.control_model) + c.control_model_wrapped = comfy.model_patcher.ModelPatcher(c.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device()) + if autoregister: + self.multigpu_clones[load_device] = c + return c + def get_models(self): out = super().get_models() out.append(self.control_model_wrapped) @@ -314,6 +357,10 @@ class QwenFunControlNet(ControlNet): super().pre_run(model, percent_to_timestep_function) self.set_extra_arg("base_model", model.diffusion_model) + def cleanup(self): + self.extra_args.pop("base_model", None) + super().cleanup() + def copy(self): c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) c.control_model = self.control_model @@ -906,6 +953,14 @@ class T2IAdapter(ControlBase): self.copy_to(c) return c + def deepclone_multigpu(self, load_device, autoregister=False): + c = self.copy() + c.t2i_model = copy.deepcopy(c.t2i_model) + c.device = load_device + if autoregister: + self.multigpu_clones[load_device] = c + return c + def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options compression_ratio = 8 upscale_algorithm = 'nearest-exact' diff --git a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py index bc36b8998..4e4819fe3 100644 --- a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py +++ b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py @@ -607,9 +607,13 @@ class HunYuanDiTPlain(nn.Module): def forward(self, x, t, context, transformer_options = {}, **kwargs): x = x.movedim(-1, -2) - if context.shape[0] >= 2: - uncond_emb, cond_emb = context.chunk(2, dim = 0) - context = torch.cat([cond_emb, uncond_emb], dim = 0) + + swap_cfg_halves = context.shape[0] >= 2 + + if swap_cfg_halves: + first_half, second_half = context.chunk(2, dim = 0) + context = torch.cat([second_half, first_half], dim = 0) + main_condition = context t = 1.0 - t @@ -657,8 +661,8 @@ class HunYuanDiTPlain(nn.Module): output = self.final_layer(combined) output = output.movedim(-2, -1) * (-1.0) - if output.shape[0] >= 2: - cond_emb, uncond_emb = output.chunk(2, dim = 0) - return torch.cat([uncond_emb, cond_emb]) - else: - return output + if swap_cfg_halves: + first_half, second_half = output.chunk(2, dim = 0) + output = torch.cat([second_half, first_half], dim = 0) + + return output diff --git a/comfy/memory_management.py b/comfy/memory_management.py index c43f0c4a2..962addb27 100644 --- a/comfy/memory_management.py +++ b/comfy/memory_management.py @@ -1,6 +1,5 @@ import math import ctypes -import threading import dataclasses import torch from typing import NamedTuple @@ -10,7 +9,7 @@ from comfy.quant_ops import QuantizedTensor class TensorFileSlice(NamedTuple): file_ref: object - thread_id: int + lock: object offset: int size: int @@ -43,7 +42,6 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N file_obj = info.file_ref if (destination.device.type != "cpu" or file_obj is None - or threading.get_ident() != info.thread_id or destination.numel() * destination.element_size() < info.size or tensor.numel() * tensor.element_size() != info.size or tensor.storage_offset() != 0 @@ -57,27 +55,29 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N if hostbuf is not None: stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0 device_ptr = destination2.data_ptr() if destination2 is not None else 0 - hostbuf.read_file_slice(file_obj, info.offset, info.size, - offset=destination.data_ptr() - hostbuf.get_raw_address(), - stream=stream_ptr, - device_ptr=device_ptr, - device=None if destination2 is None else destination2.device.index) + with info.lock: + hostbuf.read_file_slice(file_obj, info.offset, info.size, + offset=destination.data_ptr() - hostbuf.get_raw_address(), + stream=stream_ptr, + device_ptr=device_ptr, + device=None if destination2 is None else destination2.device.index) return True buf_type = ctypes.c_ubyte * info.size view = memoryview(buf_type.from_address(destination.data_ptr())) try: - file_obj.seek(info.offset) - done = 0 - while done < info.size: - try: - n = file_obj.readinto(view[done:]) - except OSError: - return False - if n <= 0: - return False - done += n + with info.lock: + file_obj.seek(info.offset) + done = 0 + while done < info.size: + try: + n = file_obj.readinto(view[done:]) + except OSError: + return False + if n <= 0: + return False + done += n return True finally: view.release() diff --git a/comfy/model_management.py b/comfy/model_management.py index cd8772d3a..b01c4d7fa 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -15,6 +15,7 @@ You should have received a copy of the GNU General Public License along with this program. If not, see . """ +from __future__ import annotations import psutil import logging @@ -27,13 +28,18 @@ import platform import weakref import gc import os -from contextlib import nullcontext +from contextlib import contextmanager, nullcontext import comfy.memory_management import comfy.utils import comfy.quant_ops import comfy_aimdo.host_buffer import comfy_aimdo.vram_buffer +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from comfy.model_patcher import ModelPatcher + + class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram NO_VRAM = 1 #Very low vram: enable all the options to save vram @@ -204,6 +210,107 @@ def get_torch_device(): else: return torch.device(torch.cuda.current_device()) +def get_all_torch_devices(exclude_current=False): + global cpu_state + devices = [] + if cpu_state == CPUState.GPU: + # NVIDIA + AMD/ROCm both expose their GPUs through torch.cuda.*; + # without the AMD arm, single-GPU ROCm users get an empty list + # which silently turns unload_all_models() into a no-op. + if is_nvidia() or is_amd(): + for i in range(torch.cuda.device_count()): + devices.append(torch.device("cuda", i)) + elif is_intel_xpu(): + for i in range(torch.xpu.device_count()): + devices.append(torch.device("xpu", i)) + elif is_ascend_npu(): + for i in range(torch.npu.device_count()): + devices.append(torch.device("npu", i)) + elif is_mlu(): + for i in range(torch.mlu.device_count()): + devices.append(torch.device("mlu", i)) + else: + # Fallback for unhandled GPU backends (e.g. DirectML): at least + # report the current device so callers like unload_all_models() + # do not silently no-op. + devices.append(get_torch_device()) + else: + devices.append(get_torch_device()) + if exclude_current: + current = get_torch_device() + if current in devices: + devices.remove(current) + return devices + +def get_gpu_device_options(): + """Return list of device option strings for node widgets. + + Always includes "default" and "cpu". When multiple GPUs are present, + adds "gpu:0", "gpu:1", etc. (vendor-agnostic labels). + """ + options = ["default", "cpu"] + devices = get_all_torch_devices() + if len(devices) > 1: + for i in range(len(devices)): + options.append(f"gpu:{i}") + return options + +def get_gpu_device_options_no_cpu(): + """Variant of get_gpu_device_options that omits "cpu". + + Intended for components like the VAE selector where running on CPU + is impractical and should not be offered as a choice. + """ + return [o for o in get_gpu_device_options() if o != "cpu"] + +def resolve_gpu_device_option(option: str): + """Resolve a device option string to a torch.device. + + Returns None for "default" (let the caller use its normal default). + Returns torch.device("cpu") for "cpu". + For "gpu:N", returns the Nth torch device. Returns None if the + index is out of range, the option string is malformed, or + unrecognized (callers are expected to log their own context-rich + message before falling back to the default device). + """ + if option is None or option == "default": + return None + if option == "cpu": + return torch.device("cpu") + if option.startswith("gpu:"): + try: + idx = int(option[4:]) + except ValueError: + return None + devices = get_all_torch_devices() + if 0 <= idx < len(devices): + return devices[idx] + return None + +@contextmanager +def cuda_device_context(device): + """Context manager that sets torch.cuda.current_device to match *device*. + + Used when running operations on a non-default CUDA device so that custom + CUDA kernels (e.g. comfy_kitchen fp8 quantization) pick up the correct + device index. The previous device is restored on exit. + + No-op when *device* is not CUDA, has no explicit index, or already matches + the current device. + """ + prev = None + if device.type == "cuda" and device.index is not None: + prev = torch.cuda.current_device() + if prev != device.index: + torch.cuda.set_device(device) + else: + prev = None + try: + yield + finally: + if prev is not None: + torch.cuda.set_device(prev) + def get_total_memory(dev=None, torch_total_too=False): global directml_enabled if dev is None: @@ -492,9 +599,13 @@ try: logging.info("Device: {}".format(get_torch_device_name(get_torch_device()))) except: logging.warning("Could not pick default device.") +try: + for device in get_all_torch_devices(exclude_current=True): + logging.info("Device: {}".format(get_torch_device_name(device))) +except: + pass - -current_loaded_models = [] +current_loaded_models: list[LoadedModel] = [] DIRTY_MMAPS = set() @@ -554,7 +665,7 @@ def ensure_pin_registerable(size, evict_active=False): return shortfall <= REGISTERABLE_PIN_HYSTERESIS class LoadedModel: - def __init__(self, model): + def __init__(self, model: ModelPatcher): self._set_model(model) self.device = model.load_device self.real_model = None @@ -562,7 +673,7 @@ class LoadedModel: self.model_finalizer = None self._patcher_finalizer = None - def _set_model(self, model): + def _set_model(self, model: ModelPatcher): self._model = weakref.ref(model) if model.parent is not None: self._parent_model = weakref.ref(model.parent) @@ -573,6 +684,7 @@ class LoadedModel: model = self._parent_model() if model is not None: self._set_model(model) + self.device = model.load_device @property def model(self): @@ -1848,7 +1960,34 @@ def soft_empty_cache(force=False): torch.cuda.ipc_collect() def unload_all_models(): - free_memory(1e30, get_torch_device()) + for device in get_all_torch_devices(): + free_memory(1e30, device) + +def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True, all_devices=False): + 'Unload only model and its clones - primarily for multigpu cloning purposes.' + initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy() + additional_models = [] + if unload_additional_models: + additional_models = model.get_nested_additional_models() + keep_loaded = [] + for loaded_model in initial_keep_loaded: + if loaded_model.model is not None: + if model.clone_base_uuid == loaded_model.model.clone_base_uuid: + continue + # check additional models if they are a match + skip = False + for add_model in additional_models: + if add_model.clone_base_uuid == loaded_model.model.clone_base_uuid: + skip = True + break + if skip: + continue + keep_loaded.append(loaded_model) + if not all_devices: + free_memory(1e30, get_torch_device(), keep_loaded) + else: + for device in get_all_torch_devices(): + free_memory(1e30, device, keep_loaded) def debug_memory_summary(): if is_amd() or is_nvidia(): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index b44b99e4a..00a15fa63 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -78,12 +78,15 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_ def create_model_options_clone(orig_model_options: dict): return comfy.patcher_extension.copy_nested_dicts(orig_model_options) -def create_hook_patches_clone(orig_hook_patches): +def create_hook_patches_clone(orig_hook_patches, copy_tuples=False): new_hook_patches = {} for hook_ref in orig_hook_patches: new_hook_patches[hook_ref] = {} for k in orig_hook_patches[hook_ref]: new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:] + if copy_tuples: + for i in range(len(new_hook_patches[hook_ref][k])): + new_hook_patches[hook_ref][k][i] = tuple(new_hook_patches[hook_ref][k][i]) return new_hook_patches def wipe_lowvram_weight(m): @@ -329,7 +332,10 @@ class ModelPatcher: self.is_clip = False self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed - self.cached_patcher_init: tuple[Callable, tuple] | None = None + self.cached_patcher_init: tuple[Callable, tuple] | tuple[Callable, tuple, int] | None = None + self.is_multigpu_base_clone = False + self.clone_base_uuid = uuid.uuid4() + if not hasattr(self.model, 'model_loaded_weight_memory'): self.model.model_loaded_weight_memory = 0 @@ -366,7 +372,8 @@ class ModelPatcher: #than pays for CFG. So return everything both torch and Aimdo could give us aimdo_mem = 0 if comfy.memory_management.aimdo_enabled: - aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze() + aimdo_device = device.index if getattr(device, "type", None) == "cuda" else None + aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze(aimdo_device) return comfy.model_management.get_free_memory(device) + aimdo_mem def get_clone_model_override(self): @@ -380,6 +387,8 @@ class ModelPatcher: if self.cached_patcher_init is None: raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.") temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True) + if len(self.cached_patcher_init) > 2: + temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]] model_override = temp_model_patcher.get_clone_model_override() if model_override is None: model_override = self.get_clone_model_override() @@ -438,19 +447,113 @@ class ModelPatcher: n.hook_mode = self.hook_mode n.cached_patcher_init = self.cached_patcher_init + n.is_multigpu_base_clone = self.is_multigpu_base_clone + n.clone_base_uuid = self.clone_base_uuid for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE): callback(self, n) return n + def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None): + logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.") + if self.cached_patcher_init is None: + raise RuntimeError( + f"Cannot create multigpu deepclone of {self.model.__class__.__name__}: " + "the loader that produced this model does not support multigpu " + "(cached_patcher_init is not initialized). Use a core loader " + "(CheckpointLoaderSimple, UNETLoader, CLIPLoader/DualCLIPLoader, VAELoader), " + "or have the custom loader register a cached_patcher_init factory." + ) + comfy.model_management.unload_model_and_clones(self) + # Produce a freshly-loaded patcher from the loader factory so the multigpu + # clone owns its own untainted model weights (rather than relying on + # copy.deepcopy of an already-patched/already-loaded module). + temp_model_patcher: ModelPatcher | list[ModelPatcher] = self.cached_patcher_init[0](*self.cached_patcher_init[1]) + if len(self.cached_patcher_init) > 2: + temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]] + # Override clone()'s normal "share self.model + share backup containers" with + # the pristine model from temp_model_patcher plus empty backup containers -- + # the fresh model has no patches applied, so any deepcopy of self's stale + # backup/object_patches_backup/pinned would just propagate dead state that + # no longer corresponds to anything in n.model. + model_override = (temp_model_patcher.model, ({}, {}, {}, set())) + n = self.clone(model_override=model_override) + # clone() copies hook_backup by reference from self; reset since model is pristine. + n.hook_backup = {} + # set load device, if present + if new_load_device is not None: + n.load_device = new_load_device + # Ensure any per-device bookkeeping (e.g. ModelPatcherDynamic.dynamic_pins) + # has an entry for n.load_device on the freshly-loaded n.model. temp_model_patcher's + # __init__ only registered its own (default) load_device. + if hasattr(n, "register_load_device"): + n.register_load_device(n.load_device) + # multigpu clone should not have multigpu additional_models entry + n.remove_additional_models("multigpu") + # multigpu_clone all stored additional_models; make sure circular references are properly handled + if models_cache is None: + models_cache = {} + for key, model_list in n.additional_models.items(): + for i in range(len(model_list)): + add_model = n.additional_models[key][i] + if add_model.clone_base_uuid not in models_cache: + models_cache[add_model.clone_base_uuid] = add_model.deepclone_multigpu(new_load_device=new_load_device, models_cache=models_cache) + n.additional_models[key][i] = models_cache[add_model.clone_base_uuid] + for callback in self.get_all_callbacks(CallbacksMP.ON_DEEPCLONE_MULTIGPU): + callback(self, n) + return n + + def match_multigpu_clones(self): + multigpu_models = self.get_additional_models_with_key("multigpu") + if len(multigpu_models) > 0: + new_multigpu_models = [] + for mm in multigpu_models: + # clone main model, but bring over relevant props from existing multigpu clone + n = self.clone() + n.load_device = mm.load_device + n.backup = mm.backup + n.object_patches_backup = mm.object_patches_backup + n.hook_backup = mm.hook_backup + n.model = mm.model + n.is_multigpu_base_clone = mm.is_multigpu_base_clone + n.remove_additional_models("multigpu") + orig_additional_models: dict[str, list[ModelPatcher]] = comfy.patcher_extension.copy_nested_dicts(n.additional_models) + n.additional_models = comfy.patcher_extension.copy_nested_dicts(mm.additional_models) + # figure out which additional models are not present in multigpu clone + models_cache = {} + for mm_add_model in mm.get_additional_models(): + models_cache[mm_add_model.clone_base_uuid] = mm_add_model + remove_models_uuids = set(list(models_cache.keys())) + for key, model_list in orig_additional_models.items(): + for orig_add_model in model_list: + if orig_add_model.clone_base_uuid not in models_cache: + models_cache[orig_add_model.clone_base_uuid] = orig_add_model.deepclone_multigpu(new_load_device=n.load_device, models_cache=models_cache) + existing_list = n.get_additional_models_with_key(key) + existing_list.append(models_cache[orig_add_model.clone_base_uuid]) + n.set_additional_models(key, existing_list) + if orig_add_model.clone_base_uuid in remove_models_uuids: + remove_models_uuids.remove(orig_add_model.clone_base_uuid) + # remove duplicate additional models + for key, model_list in n.additional_models.items(): + new_model_list = [x for x in model_list if x.clone_base_uuid not in remove_models_uuids] + n.set_additional_models(key, new_model_list) + for callback in self.get_all_callbacks(CallbacksMP.ON_MATCH_MULTIGPU_CLONES): + callback(self, n) + new_multigpu_models.append(n) + self.set_additional_models("multigpu", new_multigpu_models) + def is_clone(self, other): if hasattr(other, 'model') and self.model is other.model: return True return False - def clone_has_same_weights(self, clone: 'ModelPatcher'): - if not self.is_clone(clone): - return False + def clone_has_same_weights(self, clone: ModelPatcher, allow_multigpu=False): + if allow_multigpu: + if self.clone_base_uuid != clone.clone_base_uuid: + return False + else: + if not self.is_clone(clone): + return False if self.current_hooks != clone.current_hooks: return False @@ -1232,7 +1335,7 @@ class ModelPatcher: return self.additional_models.get(key, []) def get_additional_models(self): - all_models = [] + all_models: list[ModelPatcher] = [] for models in self.additional_models.values(): all_models.extend(models) return all_models @@ -1286,9 +1389,18 @@ class ModelPatcher: for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN): callback(self) - def prepare_state(self, timestep): + def prepare_state(self, timestep, model_options): + ignore_multigpu = model_options.get("ignore_multigpu", False) for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE): - callback(self, timestep) + callback(self, timestep, model_options) + if not ignore_multigpu and "multigpu_clones" in model_options: + model_options["ignore_multigpu"] = True + try: + for p in model_options["multigpu_clones"].values(): + p: ModelPatcher + p.prepare_state(timestep, model_options) + finally: + model_options.pop("ignore_multigpu", None) def restore_hook_patches(self): if self.hook_patches_backup is not None: @@ -1301,12 +1413,18 @@ class ModelPatcher: def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]): curr_t = t[0] reset_current_hooks = False + multigpu_kf_changed_cache = None transformer_options = model_options.get("transformer_options", {}) for hook in hook_group.hooks: changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options) # if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref; # this will cause the weights to be recalculated when sampling if changed: + # cache changed for multigpu usage + if "multigpu_clones" in model_options: + if multigpu_kf_changed_cache is None: + multigpu_kf_changed_cache = [] + multigpu_kf_changed_cache.append(hook) # reset current_hooks if contains hook that changed if self.current_hooks is not None: for current_hook in self.current_hooks.hooks: @@ -1318,6 +1436,28 @@ class ModelPatcher: self.cached_hook_patches.pop(cached_group) if reset_current_hooks: self.patch_hooks(None) + if "multigpu_clones" in model_options: + for p in model_options["multigpu_clones"].values(): + p: ModelPatcher + p._handle_changed_hook_keyframes(multigpu_kf_changed_cache) + + def _handle_changed_hook_keyframes(self, kf_changed_cache: list[comfy.hooks.Hook]): + 'Used to handle multigpu behavior inside prepare_hook_patches_current_keyframe.' + if kf_changed_cache is None: + return + reset_current_hooks = False + # reset current_hooks if contains hook that changed + for hook in kf_changed_cache: + if self.current_hooks is not None: + for current_hook in self.current_hooks.hooks: + if current_hook == hook: + reset_current_hooks = True + break + for cached_group in list(self.cached_hook_patches.keys()): + if cached_group.contains(hook): + self.cached_hook_patches.pop(cached_group) + if reset_current_hooks: + self.patch_hooks(None) def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None, registered: comfy.hooks.HookGroup = None): @@ -1566,16 +1706,27 @@ class ModelPatcherDynamic(ModelPatcher): self.model.dynamic_vbars = {} if not hasattr(self.model, "dynamic_pins"): self.model.dynamic_pins = {} - if self.load_device not in self.model.dynamic_pins: - self.model.dynamic_pins[self.load_device] = { + self.register_load_device(self.load_device) + self.non_dynamic_delegate_model = None + assert load_device is not None + + def register_load_device(self, device): + """Ensure dynamic_pins has an entry for *device*. + + Called from __init__ and also from any code that retargets an + already-constructed patcher to a new load_device (e.g. the + Select{Model,CLIP,VAE}Device selector nodes); without this entry + partially_unload_ram() raises KeyError when it tries to read the + per-device pin state. + """ + if device not in self.model.dynamic_pins: + self.model.dynamic_pins[device] = { "weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]), "patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]), "hostbufs_initialized": False, "failed": False, "active": False, } - self.non_dynamic_delegate_model = None - assert load_device is not None def is_dynamic(self): return True diff --git a/comfy/multigpu.py b/comfy/multigpu.py new file mode 100644 index 000000000..e7f5b3d6f --- /dev/null +++ b/comfy/multigpu.py @@ -0,0 +1,248 @@ +from __future__ import annotations +import queue +import threading +import torch +import logging + +from collections import namedtuple +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from comfy.model_patcher import ModelPatcher +import comfy.utils +import comfy.patcher_extension +import comfy.model_management + + +class MultiGPUThreadPool: + """Persistent thread pool for multi-GPU work distribution. + + Maintains one worker thread per extra GPU device. Each thread calls + torch.cuda.set_device() once at startup so that compiled kernel caches + (inductor/triton) stay warm across diffusion steps. + """ + + def __init__(self, devices: list[torch.device]): + self._workers: list[threading.Thread] = [] + self._work_queues: dict[torch.device, queue.Queue] = {} + self._result_queues: dict[torch.device, queue.Queue] = {} + + for device in devices: + wq = queue.Queue() + rq = queue.Queue() + self._work_queues[device] = wq + self._result_queues[device] = rq + t = threading.Thread(target=self._worker_loop, args=(device, wq, rq), daemon=True) + t.start() + self._workers.append(t) + + def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue): + try: + torch.cuda.set_device(device) + except Exception as e: + logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}") + while True: + item = work_q.get() + if item is None: + return + result_q.put((None, e)) + return + while True: + item = work_q.get() + if item is None: + break + fn, args, kwargs = item + try: + result = fn(*args, **kwargs) + result_q.put((result, None)) + except Exception as e: + result_q.put((None, e)) + + def submit(self, device: torch.device, fn, *args, **kwargs): + self._work_queues[device].put((fn, args, kwargs)) + + def get_result(self, device: torch.device): + return self._result_queues[device].get() + + @property + def devices(self) -> list[torch.device]: + return list(self._work_queues.keys()) + + def shutdown(self): + for wq in self._work_queues.values(): + wq.put(None) # sentinel + for t in self._workers: + t.join(timeout=5.0) + + +class GPUOptions: + def __init__(self, device_index: int, relative_speed: float): + self.device_index = device_index + self.relative_speed = relative_speed + + def clone(self): + return GPUOptions(self.device_index, self.relative_speed) + + def create_dict(self): + return { + "relative_speed": self.relative_speed + } + +class GPUOptionsGroup: + def __init__(self): + self.options: dict[int, GPUOptions] = {} + + def add(self, info: GPUOptions): + self.options[info.device_index] = info + + def clone(self): + c = GPUOptionsGroup() + for opt in self.options.values(): + c.add(opt) + return c + + def register(self, model: ModelPatcher): + opts_dict = {} + # get devices that are valid for this model + devices: list[torch.device] = [model.load_device] + for extra_model in model.get_additional_models_with_key("multigpu"): + extra_model: ModelPatcher + devices.append(extra_model.load_device) + # create dictionary with actual device mapped to its GPUOptions + device_opts_list: list[GPUOptions] = [] + for device in devices: + device_opts = self.options.get(device.index, GPUOptions(device_index=device.index, relative_speed=1.0)) + opts_dict[device] = device_opts.create_dict() + device_opts_list.append(device_opts) + # make relative_speed relative to 1.0 + min_speed = min([x.relative_speed for x in device_opts_list]) + for value in opts_dict.values(): + value['relative_speed'] /= min_speed + model.model_options['multigpu_options'] = opts_dict + + +def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False): + 'Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.' + model = model.clone() + # check if multigpu is already prepared - get the load devices from them if possible to exclude + skip_devices = set() + multigpu_models = model.get_additional_models_with_key("multigpu") + if len(multigpu_models) > 0: + for mm in multigpu_models: + skip_devices.add(mm.load_device) + skip_devices = list(skip_devices) + + # Exclude the primary model's actual device, not the global current device: + # after SelectModelDevice(gpu:N) the primary may not live on the process's + # current CUDA device, and excluding the wrong device picks bad extras. + all_devices = comfy.model_management.get_all_torch_devices(exclude_current=False) + full_extra_devices = [d for d in all_devices if d != model.load_device] + limit_extra_devices = full_extra_devices[:max_gpus-1] + extra_devices = limit_extra_devices.copy() + # exclude skipped devices + for skip in skip_devices: + if skip in extra_devices: + extra_devices.remove(skip) + # create new deepclones + if len(extra_devices) > 0: + for device in extra_devices: + device_patcher = None + if reuse_loaded: + # Only reuse a previously-loaded MultiGPU clone. A SelectModelDevice + # patcher on the same device shares clone_base_uuid but has + # is_multigpu_base_clone=False, which would later be filtered out by + # prepare_model_patcher_multigpu_clones() and silently shrink the + # work split back to one GPU. + loaded_models: list[ModelPatcher] = comfy.model_management.loaded_models() + for lm in loaded_models: + if lm.model is None: + continue + if lm.load_device != device: + continue + if lm.clone_base_uuid != model.clone_base_uuid: + continue + if not getattr(lm, "is_multigpu_base_clone", False): + continue + device_patcher = lm.clone() + logging.info(f"Reusing loaded multigpu deepclone of {device_patcher.model.__class__.__name__} for {device}") + break + if device_patcher is None: + device_patcher = model.deepclone_multigpu(new_load_device=device) + # Always flag the clone; whether reused or freshly deepcloned, it must + # advertise itself as a MultiGPU base clone so the cond scheduler picks + # it up in prepare_model_patcher_multigpu_clones(). + device_patcher.is_multigpu_base_clone = True + multigpu_models = model.get_additional_models_with_key("multigpu") + multigpu_models.append(device_patcher) + model.set_additional_models("multigpu", multigpu_models) + model.match_multigpu_clones() + if gpu_options is None: + gpu_options = GPUOptionsGroup() + gpu_options.register(model) + else: + logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.") + # only keep model clones that don't go 'past' the intended max_gpu count; + # this prunes any inherited multigpu clones whose load_device is no longer allowed + # when max_gpus is lowered between runs. + allowed_devices = set(limit_extra_devices) + allowed_devices.add(model.load_device) + multigpu_models = model.get_additional_models_with_key("multigpu") + new_multigpu_models = [m for m in multigpu_models if m.load_device in allowed_devices] + if len(new_multigpu_models) != len(multigpu_models): + model.set_additional_models("multigpu", new_multigpu_models) + model.match_multigpu_clones() + return model + + +LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time']) +def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None): + 'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.' + opts_dict = model_options['multigpu_options'] + devices = list(model_options['multigpu_clones'].keys()) + speed_per_device = [] + work_per_device = [] + # get sum of each device's relative_speed + total_speed = 0.0 + for opts in opts_dict.values(): + total_speed += opts['relative_speed'] + # get relative work for each device; + # obtained by w = (W*r)/R + for device in devices: + relative_speed = opts_dict[device]['relative_speed'] + relative_work = (total_work*relative_speed) / total_speed + speed_per_device.append(relative_speed) + work_per_device.append(relative_work) + # relative work must be expressed in whole numbers, but likely is a decimal; + # perform rounding while maintaining total sum equal to total work (sum of relative works) + work_per_device = round_preserved(work_per_device) + dict_work_per_device = {} + for device, relative_work in zip(devices, work_per_device): + dict_work_per_device[device] = relative_work + if not return_idle_time: + return LoadBalance(dict_work_per_device, None) + # divide relative work by relative speed to get estimated completion time of said work by each device; + # time here is relative and does not correspond to real-world units + completion_time = [w/r for w,r in zip(work_per_device, speed_per_device)] + # calculate relative time spent by the devices waiting on each other after their work is completed + idle_time = abs(min(completion_time) - max(completion_time)) + # if need to compare work idle time, need to normalize to a common total work + if work_normalized: + idle_time *= (work_normalized/total_work) + + return LoadBalance(dict_work_per_device, idle_time) + +def round_preserved(values: list[float]): + 'Round all values in a list, preserving the combined sum of values.' + # get floor of values; casting to int does it too + floored = [int(x) for x in values] + total_floored = sum(floored) + # get remainder to distribute + remainder = round(sum(values)) - total_floored + # pair values with fractional portions + fractional = [(i, x-floored[i]) for i, x in enumerate(values)] + # sort by fractional part in descending order + fractional.sort(key=lambda x: x[1], reverse=True) + # distribute the remainder + for i in range(remainder): + index = fractional[i][0] + floored[index] += 1 + return floored diff --git a/comfy/patcher_extension.py b/comfy/patcher_extension.py index 5ee4d5ee5..4b276b175 100644 --- a/comfy/patcher_extension.py +++ b/comfy/patcher_extension.py @@ -3,6 +3,8 @@ from typing import Callable class CallbacksMP: ON_CLONE = "on_clone" + ON_DEEPCLONE_MULTIGPU = "on_deepclone_multigpu" + ON_MATCH_MULTIGPU_CLONES = "on_match_multigpu_clones" ON_LOAD = "on_load_after" ON_DETACH = "on_detach_after" ON_CLEANUP = "on_cleanup" diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index 3782fd2d5..bdce2f2d8 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -1,16 +1,18 @@ from __future__ import annotations +import torch import uuid import math import collections import comfy.model_management import comfy.conds +import comfy.model_patcher import comfy.utils import comfy.hooks import comfy.patcher_extension from typing import TYPE_CHECKING if TYPE_CHECKING: - from comfy.model_patcher import ModelPatcher from comfy.model_base import BaseModel + from comfy.model_patcher import ModelPatcher from comfy.controlnet import ControlBase def prepare_mask(noise_mask, shape, device): @@ -119,6 +121,47 @@ def cleanup_additional_models(models): if hasattr(m, 'cleanup'): m.cleanup() +def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model: ModelPatcher, model_options: dict[str]): + '''If multigpu acceleration required, creates deepclones of ControlNets and GLIGEN per device.''' + multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu") + if len(multigpu_models) == 0: + return + extra_devices = [x.load_device for x in multigpu_models] + # handle controlnets + controlnets: set[ControlBase] = set() + for k in conds: + for kk in conds[k]: + if 'control' in kk: + controlnets.add(kk['control']) + if len(controlnets) > 0: + # first, unload all controlnet clones + for cnet in list(controlnets): + cnet_models = cnet.get_models() + for cm in cnet_models: + comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True) + + # next, make sure each controlnet has a deepclone for all relevant devices + for cnet in controlnets: + curr_cnet = cnet + while curr_cnet is not None: + for device in extra_devices: + if device not in curr_cnet.multigpu_clones: + curr_cnet.deepclone_multigpu(device, autoregister=True) + curr_cnet = curr_cnet.previous_controlnet + # since all device clones are now present, recreate the linked list for cloned cnets per device + for cnet in controlnets: + curr_cnet = cnet + while curr_cnet is not None: + prev_cnet = curr_cnet.previous_controlnet + for device in extra_devices: + device_cnet = curr_cnet.get_instance_for_device(device) + prev_device_cnet = None + if prev_cnet is not None: + prev_device_cnet = prev_cnet.get_instance_for_device(device) + device_cnet.set_previous_controlnet(prev_device_cnet) + curr_cnet = prev_cnet + # potentially handle gligen - since not widely used, ignored for now + def estimate_memory(model, noise_shape, conds): cond_shapes = collections.defaultdict(list) cond_shapes_min = {} @@ -143,7 +186,8 @@ def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload) def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False): - real_model: BaseModel = None + model.match_multigpu_clones() + preprocess_multigpu_conds(conds, model, model_options) models, inference_memory = get_additional_models(conds, model.model_dtype()) models += get_additional_models_from_model_options(model_options) models += model.get_nested_additional_models() # TODO: does this require inference_memory update? @@ -155,7 +199,7 @@ def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=Non memory_required += inference_memory minimum_memory_required += inference_memory comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load) - real_model = model.model + real_model: BaseModel = model.model return real_model, conds, models @@ -201,3 +245,18 @@ def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict): comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name], copy_dict1=False) return to_load_options + +def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_models: list[ModelPatcher], model_options: dict): + ''' + In case multigpu acceleration is enabled, prep ModelPatchers for each device. + ''' + multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_base_clone] + if len(multigpu_patchers) > 0: + multigpu_dict: dict[torch.device, ModelPatcher] = {} + multigpu_dict[model_patcher.load_device] = model_patcher + for x in multigpu_patchers: + x.hook_patches = comfy.model_patcher.create_hook_patches_clone(model_patcher.hook_patches, copy_tuples=True) + x.hook_mode = model_patcher.hook_mode # match main model's hook_mode + multigpu_dict[x.load_device] = x + model_options["multigpu_clones"] = multigpu_dict + return multigpu_patchers diff --git a/comfy/samplers.py b/comfy/samplers.py index c5e36ff05..e31277f7b 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -1,7 +1,9 @@ from __future__ import annotations + +import comfy.model_management from .k_diffusion import sampling as k_diffusion_sampling from .extra_samplers import uni_pc -from typing import TYPE_CHECKING, Callable, NamedTuple +from typing import TYPE_CHECKING, Callable, NamedTuple, Any if TYPE_CHECKING: from comfy.model_patcher import ModelPatcher from comfy.model_base import BaseModel @@ -16,6 +18,7 @@ import comfy.model_patcher import comfy.patcher_extension import comfy.hooks import comfy.context_windows +import comfy.multigpu import comfy.utils import scipy.stats import numpy @@ -141,7 +144,7 @@ def can_concat_cond(c1, c2): return cond_equal_size(c1.conditioning, c2.conditioning) -def cond_cat(c_list): +def cond_cat(c_list, device=None): temp = {} for x in c_list: for k in x: @@ -153,6 +156,8 @@ def cond_cat(c_list): for k in temp: conds = temp[k] out[k] = conds[0].concat(conds[1:]) + if device is not None and hasattr(out[k], 'to'): + out[k] = out[k].to(device) return out @@ -212,7 +217,12 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc ) return executor.execute(model, conds, x_in, timestep, model_options) -def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options): +def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): + # NOTE: keep in sync with _calc_cond_batch_multigpu below. Shared logic + # (hooked_to_run accumulation, memory-fit batching, per-chunk output + # aggregation) is duplicated there with per-device scheduling layered on top. + if 'multigpu_clones' in model_options: + return _calc_cond_batch_multigpu(model, conds, x_in, timestep, model_options) out_conds = [] out_counts = [] # separate conds by matching hooks @@ -244,7 +254,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens if has_default_conds: finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options) - model.current_patcher.prepare_state(timestep) + model.current_patcher.prepare_state(timestep, model_options) # run every hooked_to_run separately for hooks, to_run in hooked_to_run.items(): @@ -344,6 +354,239 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens return out_conds +def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): + # NOTE: keep in sync with _calc_cond_batch above. Same conds-by-hooks + # accumulation, memory-fit batching, and output aggregation, but adds a + # per-device scheduler, per-device patcher/control lookup, tensor .to(device) + # placement, and MultiGPUThreadPool dispatch around the inner loop. + out_conds = [] + out_counts = [] + # separate conds by matching hooks + hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {} + default_conds = [] + has_default_conds = False + + output_device = x_in.device + + for i in range(len(conds)): + out_conds.append(torch.zeros_like(x_in)) + out_counts.append(torch.ones_like(x_in) * 1e-37) + + cond = conds[i] + default_c = [] + if cond is not None: + for x in cond: + if 'default' in x: + default_c.append(x) + has_default_conds = True + continue + p = get_area_and_mult(x, x_in, timestep) + if p is None: + continue + if p.hooks is not None: + model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options) + hooked_to_run.setdefault(p.hooks, list()) + hooked_to_run[p.hooks] += [(p, i)] + default_conds.append(default_c) + + if has_default_conds: + finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options) + + model.current_patcher.prepare_state(timestep, model_options) + + devices = list(model_options['multigpu_clones'].keys()) + device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {} + # Track conds currently scheduled per device; single source of truth for capacity checks. + device_load: dict[torch.device, int] = {d: 0 for d in devices} + + total_conds = sum(len(to_run) for to_run in hooked_to_run.values()) + conds_per_device = max(1, math.ceil(total_conds / len(devices))) + + def next_available_device(start: int) -> tuple[int, torch.device]: + """Return (index, device) for the next device with remaining capacity, starting at `start`. + + Scans at most len(devices) positions, so this always terminates. Raises if no device + has remaining capacity, which would indicate a bug in conds_per_device accounting. + """ + for offset in range(len(devices)): + i = (start + offset) % len(devices) + if device_load[devices[i]] < conds_per_device: + return i, devices[i] + raise RuntimeError( + f"MultiGPU scheduler: all {len(devices)} devices at capacity " + f"({conds_per_device}) but conds remain to schedule" + ) + + # run every hooked_to_run separately + index_device = 0 + for hooks, to_run in hooked_to_run.items(): + while len(to_run) > 0: + index_device, current_device = next_available_device(index_device) + remaining_capacity = conds_per_device - device_load[current_device] + + first = to_run[0] + first_shape = first[0][0].shape + # collect candidate indices that can be concatenated with `first`, up to remaining capacity + to_batch_temp = [] + for x in range(len(to_run)): + if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < remaining_capacity: + to_batch_temp += [x] + + to_batch_temp.reverse() + to_batch = to_batch_temp[:1] + + free_memory = comfy.model_management.get_free_memory(current_device) + for i in range(1, len(to_batch_temp) + 1): + batch_amount = to_batch_temp[:len(to_batch_temp)//i] + input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] + cond_shapes = collections.defaultdict(list) + for tt in batch_amount: + for k, v in to_run[tt][0].conditioning.items(): + cond_shapes[k].append(v.size()) + if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory: + to_batch = batch_amount + break + + conds_to_batch = [to_run.pop(x) for x in to_batch] + device_load[current_device] += len(conds_to_batch) + device_batched_hooked_to_run.setdefault(current_device, []).append((hooks, conds_to_batch)) + + if device_load[current_device] >= conds_per_device: + index_device += 1 + + class thread_result(NamedTuple): + output: Any + mult: Any + area: Any + batch_chunks: int + cond_or_uncond: Any + error: Exception = None + + def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]): + try: + # TODO: non-NVIDIA support -- guard with `if device.type == "cuda":` once + # we extend multigpu QA beyond CUDA. Unconditional call crashes on + # XPU/NPU/MPS/CPU/DirectML backends. + torch.cuda.set_device(device) + model_current: BaseModel = model_options["multigpu_clones"][device].model + # run every hooked_to_run separately + with torch.no_grad(): + for hooks, to_batch in batch_tuple: + input_x = [] + mult = [] + c = [] + cond_or_uncond = [] + uuids = [] + area = [] + control: ControlBase = None + patches = None + for x in to_batch: + o = x + p = o[0] + input_x.append(p.input_x) + mult.append(p.mult) + c.append(p.conditioning) + area.append(p.area) + cond_or_uncond.append(o[1]) + uuids.append(p.uuid) + control = p.control + patches = p.patches + + batch_chunks = len(cond_or_uncond) + input_x = torch.cat(input_x).to(device) + c = cond_cat(c, device=device) + timestep_ = torch.cat([timestep.to(device)] * batch_chunks) + + transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks) + if 'transformer_options' in model_options: + transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options, + model_options['transformer_options'], + copy_dict1=False) + + if patches is not None: + transformer_options["patches"] = comfy.patcher_extension.merge_nested_dicts( + transformer_options.get("patches", {}), + patches + ) + + transformer_options["cond_or_uncond"] = cond_or_uncond[:] + transformer_options["uuids"] = uuids[:] + transformer_options["sigmas"] = timestep.to(device) + transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device) + transformer_options["multigpu_thread_device"] = device + + cast_transformer_options(transformer_options, device=device) + c['transformer_options'] = transformer_options + + if control is not None: + device_control = control.get_instance_for_device(device) + c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options) + + if 'model_function_wrapper' in model_options: + output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks) + else: + output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks) + # TODO: non-NVIDIA support -- the `.to(output_device)` copies + # above are async on CUDA, so the main thread's aggregation + # could race with in-flight transfers. CUDA-only QA has not + # surfaced this in practice, but before extending multigpu + # beyond NVIDIA add a `torch.cuda.synchronize(output_device)` + # here (guarded by `output_device.type == "cuda"`). + results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond)) + except Exception as e: + results.append(thread_result(None, None, None, None, None, error=e)) + raise + + + def _handle_batch_pooled(device, batch_tuple): + worker_results = [] + _handle_batch(device, batch_tuple, worker_results) + return worker_results + + results: list[thread_result] = [] + thread_pool: comfy.multigpu.MultiGPUThreadPool = model_options.get("multigpu_thread_pool") + + # Submit all GPU work to pool threads + pool_devices = [] + for device, batch_tuple in device_batched_hooked_to_run.items(): + if thread_pool is not None: + thread_pool.submit(device, _handle_batch_pooled, device, batch_tuple) + pool_devices.append(device) + else: + # Fallback: no pool, run everything on main thread + _handle_batch(device, batch_tuple, results) + + # Collect results from pool workers + for device in pool_devices: + worker_results, error = thread_pool.get_result(device) + if error is not None: + raise error + results.extend(worker_results) + + for output, mult, area, batch_chunks, cond_or_uncond, error in results: + if error is not None: + raise error + for o in range(batch_chunks): + cond_index = cond_or_uncond[o] + a = area[o] + if a is None: + out_conds[cond_index] += output[o] * mult[o] + out_counts[cond_index] += mult[o] + else: + out_c = out_conds[cond_index] + out_cts = out_counts[cond_index] + dims = len(a) // 2 + for i in range(dims): + out_c = out_c.narrow(i + 2, a[i + dims], a[i]) + out_cts = out_cts.narrow(i + 2, a[i + dims], a[i]) + out_c += output[o] * mult[o] + out_cts += mult[o] + + for i in range(len(out_conds)): + out_conds[i] /= out_counts[i] + + return out_conds + def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.") return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options)) @@ -642,12 +885,21 @@ def calculate_start_end_timesteps(model, conds): def pre_run_control(model, conds): s = model.model_sampling + # Per-device model lookup so multigpu control clones get the matching + # diffusion_model (e.g. QwenFunControlNet stashes it into extra_args). + device_models: dict = {} + patcher = getattr(model, "current_patcher", None) + if patcher is not None: + for p in patcher.get_additional_models_with_key("multigpu"): + device_models[p.load_device] = p.model for t in range(len(conds)): x = conds[t] percent_to_timestep_function = lambda a: s.percent_to_sigma(a) if 'control' in x: x['control'].pre_run(model, percent_to_timestep_function) + for device, device_cnet in x['control'].multigpu_clones.items(): + device_cnet.pre_run(device_models.get(device, model), percent_to_timestep_function) def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): cond_cnets = [] @@ -890,7 +1142,9 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): to_load_options = model_options.get("to_load_options", None) if to_load_options is None: return + cast_transformer_options(to_load_options, device, dtype) +def cast_transformer_options(transformer_options: dict[str], device=None, dtype=None): casts = [] if device is not None: casts.append(device) @@ -899,18 +1153,17 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): # if nothing to apply, do nothing if len(casts) == 0: return - # try to call .to on patches - if "patches" in to_load_options: - patches = to_load_options["patches"] + if "patches" in transformer_options: + patches = transformer_options["patches"] for name in patches: patch_list = patches[name] for i in range(len(patch_list)): if hasattr(patch_list[i], "to"): for cast in casts: patch_list[i] = patch_list[i].to(cast) - if "patches_replace" in to_load_options: - patches = to_load_options["patches_replace"] + if "patches_replace" in transformer_options: + patches = transformer_options["patches_replace"] for name in patches: patch_list = patches[name] for k in patch_list: @@ -920,8 +1173,8 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): # try to call .to on any wrappers/callbacks wrappers_and_callbacks = ["wrappers", "callbacks"] for wc_name in wrappers_and_callbacks: - if wc_name in to_load_options: - wc: dict[str, list] = to_load_options[wc_name] + if wc_name in transformer_options: + wc: dict[str, list] = transformer_options[wc_name] for wc_dict in wc.values(): for wc_list in wc_dict.values(): for i in range(len(wc_list)): @@ -929,7 +1182,6 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): for cast in casts: wc_list[i] = wc_list[i].to(cast) - class CFGGuider: def __init__(self, model_patcher: ModelPatcher): self.model_patcher = model_patcher @@ -984,16 +1236,32 @@ class CFGGuider: self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) device = self.model_patcher.load_device - noise = noise.to(device=device, dtype=torch.float32) - latent_image = latent_image.to(device=device, dtype=torch.float32) - sigmas = sigmas.to(device) - cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype()) + multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options) - try: - self.model_patcher.pre_run() - output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes) - finally: - self.model_patcher.cleanup() + # Create persistent thread pool for all GPU devices (main + extras) + if multigpu_patchers: + extra_devices = [p.load_device for p in multigpu_patchers] + all_devices = [device] + extra_devices + self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(all_devices) + + with comfy.model_management.cuda_device_context(device): + try: + noise = noise.to(device=device, dtype=torch.float32) + latent_image = latent_image.to(device=device, dtype=torch.float32) + sigmas = sigmas.to(device) + cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype()) + + self.model_patcher.pre_run() + for multigpu_patcher in multigpu_patchers: + multigpu_patcher.pre_run() + output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes) + finally: + thread_pool = self.model_options.pop("multigpu_thread_pool", None) + if thread_pool is not None: + thread_pool.shutdown() + self.model_patcher.cleanup() + for multigpu_patcher in multigpu_patchers: + multigpu_patcher.cleanup() comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models) del self.inner_model diff --git a/comfy/sd.py b/comfy/sd.py index 7bd07ed3a..084170c62 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -335,41 +335,43 @@ class CLIP: self.cond_stage_model.set_clip_options({"projected_pooled": False}) self.load_model(tokens) - self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) + device = self.patcher.load_device + self.cond_stage_model.set_clip_options({"execution_device": device}) all_hooks.reset() self.patcher.patch_hooks(None) if show_pbar: pbar = ProgressBar(len(scheduled_keyframes)) - for scheduled_opts in scheduled_keyframes: - t_range = scheduled_opts[0] - # don't bother encoding any conds outside of start_percent and end_percent bounds - if "start_percent" in add_dict: - if t_range[1] < add_dict["start_percent"]: - continue - if "end_percent" in add_dict: - if t_range[0] > add_dict["end_percent"]: - continue - hooks_keyframes = scheduled_opts[1] - for hook, keyframe in hooks_keyframes: - hook.hook_keyframe._current_keyframe = keyframe - # apply appropriate hooks with values that match new hook_keyframe - self.patcher.patch_hooks(all_hooks) - # perform encoding as normal - o = self.cond_stage_model.encode_token_weights(tokens) - cond, pooled = o[:2] - pooled_dict = {"pooled_output": pooled} - # add clip_start_percent and clip_end_percent in pooled - pooled_dict["clip_start_percent"] = t_range[0] - pooled_dict["clip_end_percent"] = t_range[1] - # add/update any keys with the provided add_dict - pooled_dict.update(add_dict) - # add hooks stored on clip - self.add_hooks_to_dict(pooled_dict) - all_cond_pooled.append([cond, pooled_dict]) - if show_pbar: - pbar.update(1) - model_management.throw_exception_if_processing_interrupted() + with model_management.cuda_device_context(device): + for scheduled_opts in scheduled_keyframes: + t_range = scheduled_opts[0] + # don't bother encoding any conds outside of start_percent and end_percent bounds + if "start_percent" in add_dict: + if t_range[1] < add_dict["start_percent"]: + continue + if "end_percent" in add_dict: + if t_range[0] > add_dict["end_percent"]: + continue + hooks_keyframes = scheduled_opts[1] + for hook, keyframe in hooks_keyframes: + hook.hook_keyframe._current_keyframe = keyframe + # apply appropriate hooks with values that match new hook_keyframe + self.patcher.patch_hooks(all_hooks) + # perform encoding as normal + o = self.cond_stage_model.encode_token_weights(tokens) + cond, pooled = o[:2] + pooled_dict = {"pooled_output": pooled} + # add clip_start_percent and clip_end_percent in pooled + pooled_dict["clip_start_percent"] = t_range[0] + pooled_dict["clip_end_percent"] = t_range[1] + # add/update any keys with the provided add_dict + pooled_dict.update(add_dict) + # add hooks stored on clip + self.add_hooks_to_dict(pooled_dict) + all_cond_pooled.append([cond, pooled_dict]) + if show_pbar: + pbar.update(1) + model_management.throw_exception_if_processing_interrupted() all_hooks.reset() return all_cond_pooled @@ -383,8 +385,12 @@ class CLIP: self.cond_stage_model.set_clip_options({"projected_pooled": False}) self.load_model(tokens) - self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) - o = self.cond_stage_model.encode_token_weights(tokens) + device = self.patcher.load_device + self.cond_stage_model.set_clip_options({"execution_device": device}) + + with model_management.cuda_device_context(device): + o = self.cond_stage_model.encode_token_weights(tokens) + cond, pooled = o[:2] if return_dict: out = {"cond": cond, "pooled_output": pooled} @@ -446,9 +452,12 @@ class CLIP: self.cond_stage_model.reset_clip_options() self.load_model(tokens) + device = self.patcher.load_device self.cond_stage_model.set_clip_options({"layer": None}) - self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) - return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty) + self.cond_stage_model.set_clip_options({"execution_device": device}) + + with model_management.cuda_device_context(device): + return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty) def decode(self, token_ids, skip_special_tokens=True): return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens) @@ -1026,50 +1035,52 @@ class VAE: do_tile = False if self.latent_dim == 2 and samples_in.ndim == 5: samples_in = samples_in[:, :, 0] - try: - memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) - model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) - free_memory = self.patcher.get_free_memory(self.device) - batch_number = int(free_memory / memory_used) - batch_number = max(1, batch_number) - # Pre-allocate output for VAEs that support direct buffer writes - preallocated = False - if getattr(self.first_stage_model, 'comfy_has_chunked_io', False): - pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype()) - preallocated = True + with model_management.cuda_device_context(self.device): + try: + memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) + model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) + free_memory = self.patcher.get_free_memory(self.device) + batch_number = int(free_memory / memory_used) + batch_number = max(1, batch_number) - for x in range(0, samples_in.shape[0], batch_number): - samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype) - if preallocated: - self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options) - else: - out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True) - if pixel_samples is None: - pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) - pixel_samples[x:x+batch_number].copy_(out) - del out - self.process_output(pixel_samples[x:x+batch_number]) - except Exception as e: - model_management.raise_non_oom(e) - logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") - #NOTE: We don't know what tensors were allocated to stack variables at the time of the - #exception and the exception itself refs them all until we get out of this except block. - #So we just set a flag for tiler fallback so that tensor gc can happen once the - #exception is fully off the books. - do_tile = True + # Pre-allocate output for VAEs that support direct buffer writes + preallocated = False + if getattr(self.first_stage_model, 'comfy_has_chunked_io', False): + pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype()) + preallocated = True - if do_tile: - comfy.model_management.soft_empty_cache() - dims = samples_in.ndim - 2 - if dims == 1 or self.extra_1d_channel is not None: - pixel_samples = self.decode_tiled_1d(samples_in) - elif dims == 2: - pixel_samples = self.decode_tiled_(samples_in) - elif dims == 3: - tile = 256 // self.spacial_compression_decode() - overlap = tile // 4 - pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) + for x in range(0, samples_in.shape[0], batch_number): + samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype) + if preallocated: + self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options) + else: + out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True) + if pixel_samples is None: + pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) + pixel_samples[x:x+batch_number].copy_(out) + del out + self.process_output(pixel_samples[x:x+batch_number]) + except Exception as e: + model_management.raise_non_oom(e) + logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") + #NOTE: We don't know what tensors were allocated to stack variables at the time of the + #exception and the exception itself refs them all until we get out of this except block. + #So we just set a flag for tiler fallback so that tensor gc can happen once the + #exception is fully off the books. + do_tile = True + + if do_tile: + comfy.model_management.soft_empty_cache() + dims = samples_in.ndim - 2 + if dims == 1 or self.extra_1d_channel is not None: + pixel_samples = self.decode_tiled_1d(samples_in) + elif dims == 2: + pixel_samples = self.decode_tiled_(samples_in) + elif dims == 3: + tile = 256 // self.spacial_compression_decode() + overlap = tile // 4 + pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1) return pixel_samples @@ -1087,20 +1098,21 @@ class VAE: if overlap is not None: args["overlap"] = overlap - if dims == 1 or self.extra_1d_channel is not None: - args.pop("tile_y") - output = self.decode_tiled_1d(samples, **args) - elif dims == 2: - output = self.decode_tiled_(samples, **args) - elif dims == 3: - if overlap_t is None: - args["overlap"] = (1, overlap, overlap) - else: - args["overlap"] = (max(1, overlap_t), overlap, overlap) - if tile_t is not None: - args["tile_t"] = max(2, tile_t) + with model_management.cuda_device_context(self.device): + if dims == 1 or self.extra_1d_channel is not None: + args.pop("tile_y") + output = self.decode_tiled_1d(samples, **args) + elif dims == 2: + output = self.decode_tiled_(samples, **args) + elif dims == 3: + if overlap_t is None: + args["overlap"] = (1, overlap, overlap) + else: + args["overlap"] = (max(1, overlap_t), overlap, overlap) + if tile_t is not None: + args["tile_t"] = max(2, tile_t) - output = self.decode_tiled_3d(samples, **args) + output = self.decode_tiled_3d(samples, **args) return output.movedim(1, -1) def encode(self, pixel_samples): @@ -1113,44 +1125,46 @@ class VAE: pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) else: pixel_samples = pixel_samples.unsqueeze(2) - try: - memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) - model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) - free_memory = self.patcher.get_free_memory(self.device) - batch_number = int(free_memory / max(1, memory_used)) - batch_number = max(1, batch_number) - samples = None - for x in range(0, pixel_samples.shape[0], batch_number): - pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype) - if getattr(self.first_stage_model, 'comfy_has_chunked_io', False): - out = self.first_stage_model.encode(pixels_in, device=self.device) + + with model_management.cuda_device_context(self.device): + try: + memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) + model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) + free_memory = self.patcher.get_free_memory(self.device) + batch_number = int(free_memory / max(1, memory_used)) + batch_number = max(1, batch_number) + samples = None + for x in range(0, pixel_samples.shape[0], batch_number): + pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype) + if getattr(self.first_stage_model, 'comfy_has_chunked_io', False): + out = self.first_stage_model.encode(pixels_in, device=self.device) + else: + pixels_in = pixels_in.to(self.device) + out = self.first_stage_model.encode(pixels_in) + out = out.to(self.output_device).to(dtype=self.vae_output_dtype()) + if samples is None: + samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) + samples[x:x + batch_number] = out + + except Exception as e: + model_management.raise_non_oom(e) + logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") + #NOTE: We don't know what tensors were allocated to stack variables at the time of the + #exception and the exception itself refs them all until we get out of this except block. + #So we just set a flag for tiler fallback so that tensor gc can happen once the + #exception is fully off the books. + do_tile = True + + if do_tile: + comfy.model_management.soft_empty_cache() + if self.latent_dim == 3: + tile = 256 + overlap = tile // 4 + samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) + elif self.latent_dim == 1 or self.extra_1d_channel is not None: + samples = self.encode_tiled_1d(pixel_samples) else: - pixels_in = pixels_in.to(self.device) - out = self.first_stage_model.encode(pixels_in) - out = out.to(self.output_device).to(dtype=self.vae_output_dtype()) - if samples is None: - samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) - samples[x:x + batch_number] = out - - except Exception as e: - model_management.raise_non_oom(e) - logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") - #NOTE: We don't know what tensors were allocated to stack variables at the time of the - #exception and the exception itself refs them all until we get out of this except block. - #So we just set a flag for tiler fallback so that tensor gc can happen once the - #exception is fully off the books. - do_tile = True - - if do_tile: - comfy.model_management.soft_empty_cache() - if self.latent_dim == 3: - tile = 256 - overlap = tile // 4 - samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) - elif self.latent_dim == 1 or self.extra_1d_channel is not None: - samples = self.encode_tiled_1d(pixel_samples) - else: - samples = self.encode_tiled_(pixel_samples) + samples = self.encode_tiled_(pixel_samples) return samples @@ -1176,26 +1190,27 @@ class VAE: if overlap is not None: args["overlap"] = overlap - if dims == 1: - args.pop("tile_y") - samples = self.encode_tiled_1d(pixel_samples, **args) - elif dims == 2: - samples = self.encode_tiled_(pixel_samples, **args) - elif dims == 3: - if tile_t is not None: - tile_t_latent = max(2, self.downscale_ratio[0](tile_t)) - else: - tile_t_latent = 9999 - args["tile_t"] = self.upscale_ratio[0](tile_t_latent) + with model_management.cuda_device_context(self.device): + if dims == 1: + args.pop("tile_y") + samples = self.encode_tiled_1d(pixel_samples, **args) + elif dims == 2: + samples = self.encode_tiled_(pixel_samples, **args) + elif dims == 3: + if tile_t is not None: + tile_t_latent = max(2, self.downscale_ratio[0](tile_t)) + else: + tile_t_latent = 9999 + args["tile_t"] = self.upscale_ratio[0](tile_t_latent) - if overlap_t is None: - args["overlap"] = (1, overlap, overlap) - else: - args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap) - maximum = pixel_samples.shape[2] - maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum)) + if overlap_t is None: + args["overlap"] = (1, overlap, overlap) + else: + args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap) + maximum = pixel_samples.shape[2] + maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum)) - samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args) + samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args) return samples @@ -1710,12 +1725,52 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic) if out is None: raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd))) - if output_model and out[0] is not None: - out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options)) - if output_clip and out[1] is not None: - out[1].patcher.cached_patcher_init = (load_checkpoint_guess_config_clip_only, (ckpt_path, embedding_directory, model_options, te_model_options)) + if out[0] is not None: + out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0) + # Register reload factories for the CLIP and VAE produced by the same checkpoint so + # ModelPatcher.deepclone_multigpu can spawn per-device copies (Select{CLIP,VAE}Device, + # MultiGPU work-units, etc.) without falling back to copy.deepcopy of an + # already-loaded module. + if out[1] is not None and getattr(out[1], "patcher", None) is not None: + out[1].patcher.cached_patcher_init = (load_checkpoint_clip_patcher, (ckpt_path, embedding_directory, model_options, te_model_options)) + if out[2] is not None and getattr(out[2], "patcher", None) is not None: + out[2].patcher.cached_patcher_init = (load_checkpoint_vae_patcher, (ckpt_path, embedding_directory, model_options, te_model_options)) return out + +def load_checkpoint_clip_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): + """Reload only the CLIP patcher from a checkpoint. Used as the cached_patcher_init + factory for the CLIP returned by load_checkpoint_guess_config.""" + _, clip, _, _ = load_checkpoint_guess_config( + ckpt_path, + output_vae=False, + output_clip=True, + output_clipvision=False, + embedding_directory=embedding_directory, + output_model=False, + model_options=model_options, + te_model_options=te_model_options, + disable_dynamic=disable_dynamic, + ) + return clip.patcher + + +def load_checkpoint_vae_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): + """Reload only the VAE patcher from a checkpoint. Used as the cached_patcher_init + factory for the VAE returned by load_checkpoint_guess_config.""" + _, _, vae, _ = load_checkpoint_guess_config( + ckpt_path, + output_vae=True, + output_clip=False, + output_clipvision=False, + embedding_directory=embedding_directory, + output_model=False, + model_options=model_options, + te_model_options=te_model_options, + disable_dynamic=disable_dynamic, + ) + return vae.patcher + def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False, embedding_directory=embedding_directory, @@ -1742,7 +1797,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd) parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix) weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix) - load_device = model_management.get_torch_device() + load_device = model_options.get("load_device", model_management.get_torch_device()) custom_operations = model_options.get("custom_operations", None) if custom_operations is None: @@ -1782,13 +1837,15 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype) model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device) ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher - model_patcher = ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device()) + offload_device = model_options.get("offload_device", model_management.unet_offload_device()) + model_patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device) model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic()) if output_vae: vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True) vae_sd = model_config.process_vae_state_dict(vae_sd) - vae = VAE(sd=vae_sd, metadata=metadata) + vae_device = model_options.get("load_device", None) + vae = VAE(sd=vae_sd, metadata=metadata, device=vae_device) if output_clip: if te_model_options.get("custom_operations", None) is None: @@ -1872,7 +1929,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable parameters = comfy.utils.calculate_parameters(sd) weight_dtype = comfy.utils.weight_dtype(sd) - load_device = model_management.get_torch_device() + load_device = model_options.get("load_device", model_management.get_torch_device()) model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata) if model_config is not None: @@ -1897,7 +1954,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable else: logging.warning("{} {}".format(diffusers_keys[k], k)) - offload_device = model_management.unet_offload_device() + offload_device = model_options.get("offload_device", model_management.unet_offload_device()) unet_weight_dtype = list(model_config.supported_inference_dtypes) if model_config.quant_config is not None: weight_dtype = None @@ -1939,6 +1996,26 @@ def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False): model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options)) return model + +def load_vae_patcher(vae_path, metadata=None, device=None, disable_dynamic=False): + """Reload a disk-backed VAE from ``vae_path`` and return its patcher. + + Used as the ``cached_patcher_init`` factory on ``VAE.patcher`` so + :meth:`comfy.model_patcher.ModelPatcher.deepclone_multigpu` can produce a + fresh, untainted VAE patcher (no inherited per-device load state, no + in-place quantization fallout) for multigpu work-units and the + SelectVAEDevice node. The optional ``device`` matches the source loader's + VAE initialization path; the deepclone's ``load_device`` still controls + where the cloned patcher is targeted. + """ + if metadata is None: + sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True) + else: + sd = comfy.utils.load_torch_file(vae_path) + vae = VAE(sd=sd, metadata=metadata, device=device) + vae.throw_exception_if_invalid() + return vae.patcher + def load_unet(unet_path, dtype=None): logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model") return load_diffusion_model(unet_path, model_options={"dtype": dtype}) diff --git a/comfy/utils.py b/comfy/utils.py index 31052714a..49ae12b06 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -86,6 +86,7 @@ def load_safetensors(ckpt): import comfy_aimdo.model_mmap f = open(ckpt, "rb", buffering=0) + file_lock = threading.Lock() model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt) file_size = os.path.getsize(ckpt) mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get())) @@ -111,7 +112,7 @@ def load_safetensors(ckpt): storage = tensor.untyped_storage() setattr(storage, "_comfy_tensor_file_slice", - comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start)) + comfy.memory_management.TensorFileSlice(f, file_lock, data_base_offset + start, end - start)) setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv)) sd[name] = tensor diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py new file mode 100644 index 000000000..2bd752b7d --- /dev/null +++ b/comfy_extras/nodes_multigpu.py @@ -0,0 +1,412 @@ +from __future__ import annotations + +import copy +import logging +from inspect import cleandoc +from typing import TYPE_CHECKING +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + +if TYPE_CHECKING: + from comfy.model_patcher import ModelPatcher + from comfy.sd import CLIP, VAE +import torch + +import comfy.model_management +import comfy.multigpu + + +class MultiGPUCFGSplitNode(io.ComfyNode): + """ + Prepares model to have sampling accelerated via splitting work units. + + Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes. + + Other than those exceptions, this node can be placed in any order. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="MultiGPU_WorkUnits", + display_name="MultiGPU CFG Split", + category="advanced/multigpu", + description=cleandoc(cls.__doc__), + inputs=[ + io.Model.Input("model"), + io.Int.Input("max_gpus", default=2, min=1, step=1), + ], + outputs=[ + io.Model.Output(), + ], + ) + + @classmethod + def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput: + model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True) + return io.NodeOutput(model) + + +def _force_fp32_cpu_compute(patcher: ModelPatcher): + """Force fp32 inference dtype for CPU. + + PyTorch's CPU conv2d kernels fall back to software emulation for fp16/bf16 + and run ~500-600x slower than fp32, which makes a normal-sized workflow + look frozen for hours. Routing through set_model_compute_dtype leaves the + weights as-is and casts at use, so peak memory does not blow up.""" + dtype = patcher.model_dtype() + if dtype in (torch.float16, torch.bfloat16): + logging.info(f"Select Model Device: using fp32 compute dtype for CPU inference (model dtype was {dtype}).") + patcher.set_model_compute_dtype(torch.float32) + + +def _remember_base_devices(patcher: ModelPatcher): + """Stash the original load/offload device on the underlying model. + + Stored on patcher.model (which is shared with the input patcher), so + later "default" selections can recover the loader's original routing. + Only the first Select on a given chain writes these attrs; subsequent + deepclones inherit them onto their freshly-loaded model below. + """ + if not hasattr(patcher.model, "_select_base_load_device"): + patcher.model._select_base_load_device = patcher.load_device + patcher.model._select_base_offload_device = patcher.offload_device + + +def _propagate_base_devices(src_model, dst_model): + """Carry the loader-original device attrs onto the freshly-deepcloned model.""" + if hasattr(src_model, "_select_base_load_device") and not hasattr(dst_model, "_select_base_load_device"): + dst_model._select_base_load_device = src_model._select_base_load_device + dst_model._select_base_offload_device = src_model._select_base_offload_device + + +def _retarget_patcher(patcher: ModelPatcher, target_load_device, target_offload_device): + """Return a patcher whose actual model weights live on *target_load_device*. + + If *patcher* is already on *target_load_device* we just retarget the + (already-cloned) patcher's metadata in place. Otherwise we call + :meth:`ModelPatcher.deepclone_multigpu` to spawn a fresh model from + the loader's ``cached_patcher_init`` factory -- the only safe way to + move weights that may already be partially loaded onto another device. + + NOTE: reusing the input patcher's model when the requested device + matches its current load_device is a deliberate fast path. Anything + that has already mutated the original model (e.g. a prior KSampler + invocation on the same model) will be observed here. This is by + design and documented on the SelectXDeviceNode docstrings -- placing + Select X Device after a node that consumes the same model is not + recommended. + """ + if patcher.load_device == target_load_device: + # Fast path: weights already on the desired device, just update offload. + patcher.offload_device = target_offload_device + return patcher + src_model = patcher.model + patcher = patcher.deepclone_multigpu(new_load_device=target_load_device) + patcher.offload_device = target_offload_device + _propagate_base_devices(src_model, patcher.model) + if hasattr(patcher, "register_load_device"): + patcher.register_load_device(patcher.load_device) + return patcher + + +def _apply_patcher_device(patcher: ModelPatcher, resolved, base_offload_override=None): + """Resolve the requested device and produce a patcher routed there. + + For "default" we restore the loader's original load/offload pair. + For CPU we pin both load and offload to CPU (and, on a dynamic + patcher, downgrade to a plain ModelPatcher so the dynamic-only + code paths are bypassed). + For an explicit GPU we keep the loader's original offload but + target the requested load device; if that differs from the current + load device the patcher is deepcloned onto the new device. + """ + _remember_base_devices(patcher) + base_load = patcher.model._select_base_load_device + base_offload = base_offload_override if base_offload_override is not None else patcher.model._select_base_offload_device + + if resolved is None: + # "default" -> route back to the loader's original devices. + return _retarget_patcher(patcher, base_load, base_offload) + if resolved.type == "cpu": + if patcher.is_dynamic(): + # clone(disable_dynamic=True) requires cached_patcher_init; let the + # exception surface to the caller (Select*DeviceNode.execute), which + # will translate it into a passthrough+log so unsupported loaders + # don't hard-fail the workflow. + patcher = patcher.clone(disable_dynamic=True) + patcher.load_device = resolved + patcher.offload_device = resolved + return patcher + return _retarget_patcher(patcher, resolved, base_offload) + + +def _prune_multigpu_collision(model: ModelPatcher, primary_device): + """Drop any multigpu clone whose load_device matches *primary_device*. + + Without pruning, MultiGPU CFG Split would have stacked a clone on + the same device the primary now occupies (i.e. the workflow places + MultiGPU CFG Split before Select Model Device). Keeps the clone set + consistent with the new primary placement. + """ + multigpu_models = model.get_additional_models_with_key("multigpu") + if not multigpu_models: + return + filtered = [m for m in multigpu_models if m.load_device != primary_device] + if len(filtered) != len(multigpu_models): + logging.info(f"Select Model Device: pruning MultiGPU clone on {primary_device} that now collides with the primary model.") + model.set_additional_models("multigpu", filtered) + if hasattr(model, "match_multigpu_clones"): + model.match_multigpu_clones() + + +class SelectModelDeviceNode(io.ComfyNode): + """ + Place the diffusion model on a specific device (default / cpu / gpu:N). + + - "default" restores the device assigned by the loader (even after a + prior Select Model Device call). + - "cpu" pins both the load and offload device to CPU. + - "gpu:N" pins the load device to the Nth available GPU; the offload + device is restored to the loader's original choice. + + When the requested device differs from the device the input model is + already on, a fresh model is spawned via the loader's reload factory + (cached_patcher_init) so the new patcher owns independent weights on + the new device. Loaders that don't support multigpu (no factory) will + cause the node to pass through unchanged with a warning. + + If the workflow already has MultiGPU CFG Split applied and the chosen + GPU collides with one of the existing multigpu clones, that clone is + dropped so two patchers don't end up bound to the same device. + + When the selected device does not exist on the current machine + (e.g. a workflow built on a 2-GPU box opened on a 1-GPU box), + the node passes the model through unchanged and logs a message + instead of failing. + + NOTE: Placing Select Model Device *after* a node that has already + consumed the same model (e.g. a KSampler that ran on this model on + the original device) is not recommended -- any state the prior + consumer mutated on the original model will be observed when the + selected device matches the original (fast path). Place Select Model + Device before any consumer of the model. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SelectModelDevice", + display_name="Select Model Device", + category="advanced/multigpu", + description=cleandoc(cls.__doc__), + inputs=[ + io.Model.Input("model"), + io.Combo.Input("device", options=comfy.model_management.get_gpu_device_options()), + ], + outputs=[ + io.Model.Output(), + ], + ) + + @classmethod + def validate_inputs(cls, device="default"): + # Allow unknown gpu:N values so portable workflows do not error + # at validation time; runtime fallback will handle them. + return True + + @classmethod + def execute(cls, model: ModelPatcher, device: str = "default") -> io.NodeOutput: + model = model.clone() + resolved = comfy.model_management.resolve_gpu_device_option(device) + if resolved is None and device not in (None, "default"): + logging.info(f"Select Model Device: requested device '{device}' not available, passing through unchanged.") + return io.NodeOutput(model) + try: + model = _apply_patcher_device(model, resolved) + except RuntimeError as e: + logging.warning(f"Select Model Device: cannot retarget model, passing through unchanged. ({e})") + return io.NodeOutput(model) + if resolved is not None: + if resolved.type == "cpu": + _force_fp32_cpu_compute(model) + _prune_multigpu_collision(model, model.load_device) + return io.NodeOutput(model) + + +class SelectCLIPDeviceNode(io.ComfyNode): + """ + Place the CLIP text encoder on a specific device (default / cpu / gpu:N). + + - "default" restores the device assigned by the loader. + - "cpu" pins both the load and offload device to CPU. + - "gpu:N" pins the load device to the Nth available GPU. + + When the selected device does not exist on the current machine + (e.g. a workflow built on a 2-GPU box opened on a 1-GPU box), + the node passes the CLIP through unchanged and logs a message + instead of failing. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SelectCLIPDevice", + display_name="Select CLIP Device", + category="advanced/multigpu", + description=cleandoc(cls.__doc__), + inputs=[ + io.Clip.Input("clip"), + io.Combo.Input("device", options=comfy.model_management.get_gpu_device_options()), + ], + outputs=[ + io.Clip.Output(), + ], + ) + + @classmethod + def validate_inputs(cls, device="default"): + return True + + @classmethod + def execute(cls, clip: CLIP, device: str = "default") -> io.NodeOutput: + clip = clip.clone() + resolved = comfy.model_management.resolve_gpu_device_option(device) + if resolved is None and device not in (None, "default"): + logging.info(f"Select CLIP Device: requested device '{device}' not available, passing through unchanged.") + return io.NodeOutput(clip) + try: + clip.patcher = _apply_patcher_device(clip.patcher, resolved) + except RuntimeError as e: + logging.warning(f"Select CLIP Device: cannot retarget CLIP, passing through unchanged. ({e})") + return io.NodeOutput(clip) + + +class SelectVAEDeviceNode(io.ComfyNode): + """ + Place the VAE on a specific device (default / gpu:N). + + - "default" restores the device assigned by the loader. + - "gpu:N" pins the load device to the Nth available GPU; the offload + device is set to the standard VAE offload device. + + CPU is intentionally not exposed in the UI for the VAE; if a workflow + supplies "cpu" anyway (e.g. opened from another machine), the request + is dropped with a log message and the VAE is passed through unchanged. + + When the selected device does not exist on the current machine + (e.g. a workflow built on a 2-GPU box opened on a 1-GPU box), + the node passes the VAE through unchanged and logs a message + instead of failing. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SelectVAEDevice", + display_name="Select VAE Device", + category="advanced/multigpu", + description=cleandoc(cls.__doc__), + inputs=[ + io.Vae.Input("vae"), + io.Combo.Input("device", options=comfy.model_management.get_gpu_device_options_no_cpu()), + ], + outputs=[ + io.Vae.Output(), + ], + ) + + @classmethod + def validate_inputs(cls, device="default"): + return True + + @classmethod + def execute(cls, vae: VAE, device: str = "default") -> io.NodeOutput: + # VAE has no .clone(); shallow-copy the wrapper and clone the patcher + # so we can retarget load/offload device without affecting the input VAE. + vae = copy.copy(vae) + vae.patcher = vae.patcher.clone() + resolved = comfy.model_management.resolve_gpu_device_option(device) + if resolved is None and device not in (None, "default"): + logging.info(f"Select VAE Device: requested device '{device}' not available, passing through unchanged.") + return io.NodeOutput(vae) + if resolved is not None and resolved.type == "cpu": + logging.info("Select VAE Device: CPU is not a supported choice, passing through unchanged.") + return io.NodeOutput(vae) + if not hasattr(vae, "_select_base_device"): + vae._select_base_device = vae.device + try: + vae.patcher = _apply_patcher_device( + vae.patcher, resolved, + base_offload_override=comfy.model_management.vae_offload_device(), + ) + except RuntimeError as e: + logging.warning(f"Select VAE Device: cannot retarget VAE, passing through unchanged. ({e})") + return io.NodeOutput(vae) + # Keep VAE wrapper in sync with whatever model the patcher now owns; + # deepclone_multigpu may have produced a fresh first_stage_model. + vae.first_stage_model = vae.patcher.model + vae.device = vae._select_base_device if resolved is None else resolved + return io.NodeOutput(vae) + + +class MultiGPUOptionsNode(io.ComfyNode): + """ + Select the relative speed of GPUs in the special case they have significantly different performance from one another. + + NOTE (not registered yet, see MultiGPUExtension.get_node_list below): + The output GPUOptionsGroup is plumbed through create_multigpu_deepclones() and stored on + model.model_options['multigpu_options'] via GPUOptionsGroup.register(), but the cond + scheduler in comfy/samplers.py (calc_cond_batch_outer_multigpu) does NOT yet consult + relative_speed when distributing conds across devices; it uses a uniform conds_per_device + round-robin via next_available_device(). Before re-enabling this node, wire its + relative_speed into the scheduler (e.g. via comfy.multigpu.load_balance_devices(), + which already implements the proportional split) so the input actually affects work + distribution. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="MultiGPU_Options", + display_name="MultiGPU Options", + category="advanced/multigpu", + description=cleandoc(cls.__doc__), + inputs=[ + io.Int.Input("device_index", default=0, min=0, max=64), + io.Float.Input("relative_speed", default=1.0, min=0.0, step=0.01), + io.Custom("GPU_OPTIONS").Input("gpu_options", optional=True), + ], + outputs=[ + io.Custom("GPU_OPTIONS").Output(), + ], + ) + + @classmethod + def execute(cls, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup = None) -> io.NodeOutput: + if not gpu_options: + gpu_options = comfy.multigpu.GPUOptionsGroup() + else: + gpu_options = gpu_options.clone() + + opt = comfy.multigpu.GPUOptions(device_index=device_index, relative_speed=relative_speed) + gpu_options.add(opt) + + return io.NodeOutput(gpu_options) + + +class MultiGPUExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + MultiGPUCFGSplitNode, + SelectModelDeviceNode, + SelectCLIPDeviceNode, + SelectVAEDeviceNode, + # MultiGPUOptionsNode, + ] + + +async def comfy_entrypoint() -> MultiGPUExtension: + return MultiGPUExtension() diff --git a/main.py b/main.py index 26d523c30..bce451a83 100644 --- a/main.py +++ b/main.py @@ -218,7 +218,7 @@ import comfy.model_patcher if args.enable_dynamic_vram or (enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl()): if (not args.enable_dynamic_vram) and (comfy.model_management.torch_version_numeric < (2, 8)): logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") - elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index): + elif comfy_aimdo.control.init_devices(d.index for d in comfy.model_management.get_all_torch_devices()): if args.verbose == 'DEBUG': comfy_aimdo.control.set_log_debug() elif args.verbose == 'CRITICAL': diff --git a/nodes.py b/nodes.py index 13e46ac8a..fd4365c90 100644 --- a/nodes.py +++ b/nodes.py @@ -795,6 +795,7 @@ class VAELoader: #TODO: scale factor? def load_vae(self, vae_name): metadata = None + vae_path = None if vae_name == "pixel_space": sd = {} sd["pixel_space_vae"] = torch.tensor(1.0) @@ -813,6 +814,14 @@ class VAELoader: metadata["tae_latent_channels"] = 128 vae = comfy.sd.VAE(sd=sd, metadata=metadata) vae.throw_exception_if_invalid() + # Register a reload factory on the patcher so multigpu deepclones + # (Select VAE Device, future MultiGPU VAE work-units) can produce + # per-device clones from the same loader context. Only set when we + # actually have a single backing file -- pixel_space and the + # image TAESDs (composed from separate encoder/decoder files via + # load_taesd) are not addressable by a single vae_path. + if vae_path is not None: + vae.patcher.cached_patcher_init = (comfy.sd.load_vae_patcher, (vae_path, metadata, None)) return (vae,) class ControlNetLoader: @@ -2389,6 +2398,7 @@ async def init_builtin_extra_nodes(): "nodes_lt_audio.py", "nodes_lt.py", "nodes_hooks.py", + "nodes_multigpu.py", "nodes_load_3d.py", "nodes_cosmos.py", "nodes_video.py", diff --git a/server.py b/server.py index 44470b904..268441bd1 100644 --- a/server.py +++ b/server.py @@ -646,18 +646,37 @@ class PromptServer(): @routes.get("/system_stats") async def system_stats(request): - device = comfy.model_management.get_torch_device() - device_name = comfy.model_management.get_torch_device_name(device) + primary_device = comfy.model_management.get_torch_device() cpu_device = comfy.model_management.torch.device("cpu") ram_total = comfy.model_management.get_total_memory(cpu_device) ram_free = comfy.model_management.get_free_memory(cpu_device) - vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True) - vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True) required_frontend_version = FrontendManager.get_required_frontend_version() installed_templates_version = FrontendManager.get_installed_templates_version() required_templates_version = FrontendManager.get_required_templates_version() comfy_package_versions = FrontendManager.get_comfy_package_versions() + # Report every torch device visible to multigpu, with the primary + # device first so existing clients that read devices[0] keep working. + torch_devices = comfy.model_management.get_all_torch_devices() + if primary_device in torch_devices: + torch_devices = [primary_device] + [d for d in torch_devices if d != primary_device] + else: + torch_devices = [primary_device] + list(torch_devices) + + device_entries = [] + for d in torch_devices: + vram_total, torch_vram_total = comfy.model_management.get_total_memory(d, torch_total_too=True) + vram_free, torch_vram_free = comfy.model_management.get_free_memory(d, torch_free_too=True) + device_entries.append({ + "name": comfy.model_management.get_torch_device_name(d), + "type": d.type, + "index": d.index, + "vram_total": vram_total, + "vram_free": vram_free, + "torch_vram_total": torch_vram_total, + "torch_vram_free": torch_vram_free, + }) + system_stats = { "system": { "os": sys.platform, @@ -673,17 +692,7 @@ class PromptServer(): "embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded", "argv": sys.argv }, - "devices": [ - { - "name": device_name, - "type": device.type, - "index": device.index, - "vram_total": vram_total, - "vram_free": vram_free, - "torch_vram_total": torch_vram_total, - "torch_vram_free": torch_vram_free, - } - ] + "devices": device_entries } return web.json_response(system_stats)