diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 47b8174f4..9bda414d1 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -49,7 +49,7 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
-parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
+parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use, as a comma-separated list (e.g. '0' or '0,1'). All other devices will not be visible.")
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
cm_group = parser.add_mutually_exclusive_group()
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index ba670b16d..6dbbaa959 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -15,13 +15,14 @@
You should have received a copy of the GNU General Public License
along with this program. If not, see .
"""
-
+from __future__ import annotations
import torch
from enum import Enum
import math
import os
import logging
+import copy
import comfy.utils
import comfy.model_management
import comfy.model_detection
@@ -38,7 +39,7 @@ import comfy.ldm.hydit.controlnet
import comfy.ldm.flux.controlnet
import comfy.ldm.qwen_image.controlnet
import comfy.cldm.dit_embedder
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Union
if TYPE_CHECKING:
from comfy.hooks import HookGroup
@@ -64,6 +65,18 @@ class StrengthType(Enum):
CONSTANT = 1
LINEAR_UP = 2
+class ControlIsolation:
+ '''Temporarily set a ControlBase object's previous_controlnet to None to prevent cascading calls.'''
+ def __init__(self, control: ControlBase):
+ self.control = control
+ self.orig_previous_controlnet = control.previous_controlnet
+
+ def __enter__(self):
+ self.control.previous_controlnet = None
+
+ def __exit__(self, *args):
+ self.control.previous_controlnet = self.orig_previous_controlnet
+
class ControlBase:
def __init__(self):
self.cond_hint_original = None
@@ -77,7 +90,7 @@ class ControlBase:
self.compression_ratio = 8
self.upscale_algorithm = 'nearest-exact'
self.extra_args = {}
- self.previous_controlnet = None
+ self.previous_controlnet: Union[ControlBase, None] = None
self.extra_conds = []
self.strength_type = StrengthType.CONSTANT
self.concat_mask = False
@@ -85,6 +98,7 @@ class ControlBase:
self.extra_concat = None
self.extra_hooks: HookGroup = None
self.preprocess_image = lambda a: a
+ self.multigpu_clones: dict[torch.device, ControlBase] = {}
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
self.cond_hint_original = cond_hint
@@ -111,17 +125,38 @@ class ControlBase:
def cleanup(self):
if self.previous_controlnet is not None:
self.previous_controlnet.cleanup()
-
+ for device_cnet in self.multigpu_clones.values():
+ with ControlIsolation(device_cnet):
+ device_cnet.cleanup()
self.cond_hint = None
self.extra_concat = None
self.timestep_range = None
def get_models(self):
out = []
+ for device_cnet in self.multigpu_clones.values():
+ out += device_cnet.get_models_only_self()
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_models()
return out
+ def get_models_only_self(self):
+ 'Calls get_models, but temporarily sets previous_controlnet to None.'
+ with ControlIsolation(self):
+ return self.get_models()
+
+ def get_instance_for_device(self, device):
+ 'Returns instance of this Control object intended for selected device.'
+ return self.multigpu_clones.get(device, self)
+
+ def deepclone_multigpu(self, load_device, autoregister=False):
+ '''
+ Create deep clone of Control object where model(s) is set to other devices.
+
+ When autoregister is set to True, the deep clone is also added to multigpu_clones dict.
+ '''
+ raise NotImplementedError("Classes inheriting from ControlBase should define their own deepclone_multigpu funtion.")
+
def get_extra_hooks(self):
out = []
if self.extra_hooks is not None:
@@ -130,7 +165,7 @@ class ControlBase:
out += self.previous_controlnet.get_extra_hooks()
return out
- def copy_to(self, c):
+ def copy_to(self, c: ControlBase):
c.cond_hint_original = self.cond_hint_original
c.strength = self.strength
c.timestep_percent_range = self.timestep_percent_range
@@ -284,6 +319,14 @@ class ControlNet(ControlBase):
self.copy_to(c)
return c
+ def deepclone_multigpu(self, load_device, autoregister=False):
+ c = self.copy()
+ c.control_model = copy.deepcopy(c.control_model)
+ c.control_model_wrapped = comfy.model_patcher.ModelPatcher(c.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
+ if autoregister:
+ self.multigpu_clones[load_device] = c
+ return c
+
def get_models(self):
out = super().get_models()
out.append(self.control_model_wrapped)
@@ -314,6 +357,10 @@ class QwenFunControlNet(ControlNet):
super().pre_run(model, percent_to_timestep_function)
self.set_extra_arg("base_model", model.diffusion_model)
+ def cleanup(self):
+ self.extra_args.pop("base_model", None)
+ super().cleanup()
+
def copy(self):
c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
c.control_model = self.control_model
@@ -906,6 +953,14 @@ class T2IAdapter(ControlBase):
self.copy_to(c)
return c
+ def deepclone_multigpu(self, load_device, autoregister=False):
+ c = self.copy()
+ c.t2i_model = copy.deepcopy(c.t2i_model)
+ c.device = load_device
+ if autoregister:
+ self.multigpu_clones[load_device] = c
+ return c
+
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
compression_ratio = 8
upscale_algorithm = 'nearest-exact'
diff --git a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py
index bc36b8998..4e4819fe3 100644
--- a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py
+++ b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py
@@ -607,9 +607,13 @@ class HunYuanDiTPlain(nn.Module):
def forward(self, x, t, context, transformer_options = {}, **kwargs):
x = x.movedim(-1, -2)
- if context.shape[0] >= 2:
- uncond_emb, cond_emb = context.chunk(2, dim = 0)
- context = torch.cat([cond_emb, uncond_emb], dim = 0)
+
+ swap_cfg_halves = context.shape[0] >= 2
+
+ if swap_cfg_halves:
+ first_half, second_half = context.chunk(2, dim = 0)
+ context = torch.cat([second_half, first_half], dim = 0)
+
main_condition = context
t = 1.0 - t
@@ -657,8 +661,8 @@ class HunYuanDiTPlain(nn.Module):
output = self.final_layer(combined)
output = output.movedim(-2, -1) * (-1.0)
- if output.shape[0] >= 2:
- cond_emb, uncond_emb = output.chunk(2, dim = 0)
- return torch.cat([uncond_emb, cond_emb])
- else:
- return output
+ if swap_cfg_halves:
+ first_half, second_half = output.chunk(2, dim = 0)
+ output = torch.cat([second_half, first_half], dim = 0)
+
+ return output
diff --git a/comfy/memory_management.py b/comfy/memory_management.py
index c43f0c4a2..962addb27 100644
--- a/comfy/memory_management.py
+++ b/comfy/memory_management.py
@@ -1,6 +1,5 @@
import math
import ctypes
-import threading
import dataclasses
import torch
from typing import NamedTuple
@@ -10,7 +9,7 @@ from comfy.quant_ops import QuantizedTensor
class TensorFileSlice(NamedTuple):
file_ref: object
- thread_id: int
+ lock: object
offset: int
size: int
@@ -43,7 +42,6 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
file_obj = info.file_ref
if (destination.device.type != "cpu"
or file_obj is None
- or threading.get_ident() != info.thread_id
or destination.numel() * destination.element_size() < info.size
or tensor.numel() * tensor.element_size() != info.size
or tensor.storage_offset() != 0
@@ -57,27 +55,29 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
if hostbuf is not None:
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
device_ptr = destination2.data_ptr() if destination2 is not None else 0
- hostbuf.read_file_slice(file_obj, info.offset, info.size,
- offset=destination.data_ptr() - hostbuf.get_raw_address(),
- stream=stream_ptr,
- device_ptr=device_ptr,
- device=None if destination2 is None else destination2.device.index)
+ with info.lock:
+ hostbuf.read_file_slice(file_obj, info.offset, info.size,
+ offset=destination.data_ptr() - hostbuf.get_raw_address(),
+ stream=stream_ptr,
+ device_ptr=device_ptr,
+ device=None if destination2 is None else destination2.device.index)
return True
buf_type = ctypes.c_ubyte * info.size
view = memoryview(buf_type.from_address(destination.data_ptr()))
try:
- file_obj.seek(info.offset)
- done = 0
- while done < info.size:
- try:
- n = file_obj.readinto(view[done:])
- except OSError:
- return False
- if n <= 0:
- return False
- done += n
+ with info.lock:
+ file_obj.seek(info.offset)
+ done = 0
+ while done < info.size:
+ try:
+ n = file_obj.readinto(view[done:])
+ except OSError:
+ return False
+ if n <= 0:
+ return False
+ done += n
return True
finally:
view.release()
diff --git a/comfy/model_management.py b/comfy/model_management.py
index cd8772d3a..b01c4d7fa 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -15,6 +15,7 @@
You should have received a copy of the GNU General Public License
along with this program. If not, see .
"""
+from __future__ import annotations
import psutil
import logging
@@ -27,13 +28,18 @@ import platform
import weakref
import gc
import os
-from contextlib import nullcontext
+from contextlib import contextmanager, nullcontext
import comfy.memory_management
import comfy.utils
import comfy.quant_ops
import comfy_aimdo.host_buffer
import comfy_aimdo.vram_buffer
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+ from comfy.model_patcher import ModelPatcher
+
+
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
NO_VRAM = 1 #Very low vram: enable all the options to save vram
@@ -204,6 +210,107 @@ def get_torch_device():
else:
return torch.device(torch.cuda.current_device())
+def get_all_torch_devices(exclude_current=False):
+ global cpu_state
+ devices = []
+ if cpu_state == CPUState.GPU:
+ # NVIDIA + AMD/ROCm both expose their GPUs through torch.cuda.*;
+ # without the AMD arm, single-GPU ROCm users get an empty list
+ # which silently turns unload_all_models() into a no-op.
+ if is_nvidia() or is_amd():
+ for i in range(torch.cuda.device_count()):
+ devices.append(torch.device("cuda", i))
+ elif is_intel_xpu():
+ for i in range(torch.xpu.device_count()):
+ devices.append(torch.device("xpu", i))
+ elif is_ascend_npu():
+ for i in range(torch.npu.device_count()):
+ devices.append(torch.device("npu", i))
+ elif is_mlu():
+ for i in range(torch.mlu.device_count()):
+ devices.append(torch.device("mlu", i))
+ else:
+ # Fallback for unhandled GPU backends (e.g. DirectML): at least
+ # report the current device so callers like unload_all_models()
+ # do not silently no-op.
+ devices.append(get_torch_device())
+ else:
+ devices.append(get_torch_device())
+ if exclude_current:
+ current = get_torch_device()
+ if current in devices:
+ devices.remove(current)
+ return devices
+
+def get_gpu_device_options():
+ """Return list of device option strings for node widgets.
+
+ Always includes "default" and "cpu". When multiple GPUs are present,
+ adds "gpu:0", "gpu:1", etc. (vendor-agnostic labels).
+ """
+ options = ["default", "cpu"]
+ devices = get_all_torch_devices()
+ if len(devices) > 1:
+ for i in range(len(devices)):
+ options.append(f"gpu:{i}")
+ return options
+
+def get_gpu_device_options_no_cpu():
+ """Variant of get_gpu_device_options that omits "cpu".
+
+ Intended for components like the VAE selector where running on CPU
+ is impractical and should not be offered as a choice.
+ """
+ return [o for o in get_gpu_device_options() if o != "cpu"]
+
+def resolve_gpu_device_option(option: str):
+ """Resolve a device option string to a torch.device.
+
+ Returns None for "default" (let the caller use its normal default).
+ Returns torch.device("cpu") for "cpu".
+ For "gpu:N", returns the Nth torch device. Returns None if the
+ index is out of range, the option string is malformed, or
+ unrecognized (callers are expected to log their own context-rich
+ message before falling back to the default device).
+ """
+ if option is None or option == "default":
+ return None
+ if option == "cpu":
+ return torch.device("cpu")
+ if option.startswith("gpu:"):
+ try:
+ idx = int(option[4:])
+ except ValueError:
+ return None
+ devices = get_all_torch_devices()
+ if 0 <= idx < len(devices):
+ return devices[idx]
+ return None
+
+@contextmanager
+def cuda_device_context(device):
+ """Context manager that sets torch.cuda.current_device to match *device*.
+
+ Used when running operations on a non-default CUDA device so that custom
+ CUDA kernels (e.g. comfy_kitchen fp8 quantization) pick up the correct
+ device index. The previous device is restored on exit.
+
+ No-op when *device* is not CUDA, has no explicit index, or already matches
+ the current device.
+ """
+ prev = None
+ if device.type == "cuda" and device.index is not None:
+ prev = torch.cuda.current_device()
+ if prev != device.index:
+ torch.cuda.set_device(device)
+ else:
+ prev = None
+ try:
+ yield
+ finally:
+ if prev is not None:
+ torch.cuda.set_device(prev)
+
def get_total_memory(dev=None, torch_total_too=False):
global directml_enabled
if dev is None:
@@ -492,9 +599,13 @@ try:
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
except:
logging.warning("Could not pick default device.")
+try:
+ for device in get_all_torch_devices(exclude_current=True):
+ logging.info("Device: {}".format(get_torch_device_name(device)))
+except:
+ pass
-
-current_loaded_models = []
+current_loaded_models: list[LoadedModel] = []
DIRTY_MMAPS = set()
@@ -554,7 +665,7 @@ def ensure_pin_registerable(size, evict_active=False):
return shortfall <= REGISTERABLE_PIN_HYSTERESIS
class LoadedModel:
- def __init__(self, model):
+ def __init__(self, model: ModelPatcher):
self._set_model(model)
self.device = model.load_device
self.real_model = None
@@ -562,7 +673,7 @@ class LoadedModel:
self.model_finalizer = None
self._patcher_finalizer = None
- def _set_model(self, model):
+ def _set_model(self, model: ModelPatcher):
self._model = weakref.ref(model)
if model.parent is not None:
self._parent_model = weakref.ref(model.parent)
@@ -573,6 +684,7 @@ class LoadedModel:
model = self._parent_model()
if model is not None:
self._set_model(model)
+ self.device = model.load_device
@property
def model(self):
@@ -1848,7 +1960,34 @@ def soft_empty_cache(force=False):
torch.cuda.ipc_collect()
def unload_all_models():
- free_memory(1e30, get_torch_device())
+ for device in get_all_torch_devices():
+ free_memory(1e30, device)
+
+def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True, all_devices=False):
+ 'Unload only model and its clones - primarily for multigpu cloning purposes.'
+ initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy()
+ additional_models = []
+ if unload_additional_models:
+ additional_models = model.get_nested_additional_models()
+ keep_loaded = []
+ for loaded_model in initial_keep_loaded:
+ if loaded_model.model is not None:
+ if model.clone_base_uuid == loaded_model.model.clone_base_uuid:
+ continue
+ # check additional models if they are a match
+ skip = False
+ for add_model in additional_models:
+ if add_model.clone_base_uuid == loaded_model.model.clone_base_uuid:
+ skip = True
+ break
+ if skip:
+ continue
+ keep_loaded.append(loaded_model)
+ if not all_devices:
+ free_memory(1e30, get_torch_device(), keep_loaded)
+ else:
+ for device in get_all_torch_devices():
+ free_memory(1e30, device, keep_loaded)
def debug_memory_summary():
if is_amd() or is_nvidia():
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index b44b99e4a..00a15fa63 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -78,12 +78,15 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
def create_model_options_clone(orig_model_options: dict):
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
-def create_hook_patches_clone(orig_hook_patches):
+def create_hook_patches_clone(orig_hook_patches, copy_tuples=False):
new_hook_patches = {}
for hook_ref in orig_hook_patches:
new_hook_patches[hook_ref] = {}
for k in orig_hook_patches[hook_ref]:
new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:]
+ if copy_tuples:
+ for i in range(len(new_hook_patches[hook_ref][k])):
+ new_hook_patches[hook_ref][k][i] = tuple(new_hook_patches[hook_ref][k][i])
return new_hook_patches
def wipe_lowvram_weight(m):
@@ -329,7 +332,10 @@ class ModelPatcher:
self.is_clip = False
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
- self.cached_patcher_init: tuple[Callable, tuple] | None = None
+ self.cached_patcher_init: tuple[Callable, tuple] | tuple[Callable, tuple, int] | None = None
+ self.is_multigpu_base_clone = False
+ self.clone_base_uuid = uuid.uuid4()
+
if not hasattr(self.model, 'model_loaded_weight_memory'):
self.model.model_loaded_weight_memory = 0
@@ -366,7 +372,8 @@ class ModelPatcher:
#than pays for CFG. So return everything both torch and Aimdo could give us
aimdo_mem = 0
if comfy.memory_management.aimdo_enabled:
- aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze()
+ aimdo_device = device.index if getattr(device, "type", None) == "cuda" else None
+ aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze(aimdo_device)
return comfy.model_management.get_free_memory(device) + aimdo_mem
def get_clone_model_override(self):
@@ -380,6 +387,8 @@ class ModelPatcher:
if self.cached_patcher_init is None:
raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.")
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
+ if len(self.cached_patcher_init) > 2:
+ temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
model_override = temp_model_patcher.get_clone_model_override()
if model_override is None:
model_override = self.get_clone_model_override()
@@ -438,19 +447,113 @@ class ModelPatcher:
n.hook_mode = self.hook_mode
n.cached_patcher_init = self.cached_patcher_init
+ n.is_multigpu_base_clone = self.is_multigpu_base_clone
+ n.clone_base_uuid = self.clone_base_uuid
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
callback(self, n)
return n
+ def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None):
+ logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.")
+ if self.cached_patcher_init is None:
+ raise RuntimeError(
+ f"Cannot create multigpu deepclone of {self.model.__class__.__name__}: "
+ "the loader that produced this model does not support multigpu "
+ "(cached_patcher_init is not initialized). Use a core loader "
+ "(CheckpointLoaderSimple, UNETLoader, CLIPLoader/DualCLIPLoader, VAELoader), "
+ "or have the custom loader register a cached_patcher_init factory."
+ )
+ comfy.model_management.unload_model_and_clones(self)
+ # Produce a freshly-loaded patcher from the loader factory so the multigpu
+ # clone owns its own untainted model weights (rather than relying on
+ # copy.deepcopy of an already-patched/already-loaded module).
+ temp_model_patcher: ModelPatcher | list[ModelPatcher] = self.cached_patcher_init[0](*self.cached_patcher_init[1])
+ if len(self.cached_patcher_init) > 2:
+ temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
+ # Override clone()'s normal "share self.model + share backup containers" with
+ # the pristine model from temp_model_patcher plus empty backup containers --
+ # the fresh model has no patches applied, so any deepcopy of self's stale
+ # backup/object_patches_backup/pinned would just propagate dead state that
+ # no longer corresponds to anything in n.model.
+ model_override = (temp_model_patcher.model, ({}, {}, {}, set()))
+ n = self.clone(model_override=model_override)
+ # clone() copies hook_backup by reference from self; reset since model is pristine.
+ n.hook_backup = {}
+ # set load device, if present
+ if new_load_device is not None:
+ n.load_device = new_load_device
+ # Ensure any per-device bookkeeping (e.g. ModelPatcherDynamic.dynamic_pins)
+ # has an entry for n.load_device on the freshly-loaded n.model. temp_model_patcher's
+ # __init__ only registered its own (default) load_device.
+ if hasattr(n, "register_load_device"):
+ n.register_load_device(n.load_device)
+ # multigpu clone should not have multigpu additional_models entry
+ n.remove_additional_models("multigpu")
+ # multigpu_clone all stored additional_models; make sure circular references are properly handled
+ if models_cache is None:
+ models_cache = {}
+ for key, model_list in n.additional_models.items():
+ for i in range(len(model_list)):
+ add_model = n.additional_models[key][i]
+ if add_model.clone_base_uuid not in models_cache:
+ models_cache[add_model.clone_base_uuid] = add_model.deepclone_multigpu(new_load_device=new_load_device, models_cache=models_cache)
+ n.additional_models[key][i] = models_cache[add_model.clone_base_uuid]
+ for callback in self.get_all_callbacks(CallbacksMP.ON_DEEPCLONE_MULTIGPU):
+ callback(self, n)
+ return n
+
+ def match_multigpu_clones(self):
+ multigpu_models = self.get_additional_models_with_key("multigpu")
+ if len(multigpu_models) > 0:
+ new_multigpu_models = []
+ for mm in multigpu_models:
+ # clone main model, but bring over relevant props from existing multigpu clone
+ n = self.clone()
+ n.load_device = mm.load_device
+ n.backup = mm.backup
+ n.object_patches_backup = mm.object_patches_backup
+ n.hook_backup = mm.hook_backup
+ n.model = mm.model
+ n.is_multigpu_base_clone = mm.is_multigpu_base_clone
+ n.remove_additional_models("multigpu")
+ orig_additional_models: dict[str, list[ModelPatcher]] = comfy.patcher_extension.copy_nested_dicts(n.additional_models)
+ n.additional_models = comfy.patcher_extension.copy_nested_dicts(mm.additional_models)
+ # figure out which additional models are not present in multigpu clone
+ models_cache = {}
+ for mm_add_model in mm.get_additional_models():
+ models_cache[mm_add_model.clone_base_uuid] = mm_add_model
+ remove_models_uuids = set(list(models_cache.keys()))
+ for key, model_list in orig_additional_models.items():
+ for orig_add_model in model_list:
+ if orig_add_model.clone_base_uuid not in models_cache:
+ models_cache[orig_add_model.clone_base_uuid] = orig_add_model.deepclone_multigpu(new_load_device=n.load_device, models_cache=models_cache)
+ existing_list = n.get_additional_models_with_key(key)
+ existing_list.append(models_cache[orig_add_model.clone_base_uuid])
+ n.set_additional_models(key, existing_list)
+ if orig_add_model.clone_base_uuid in remove_models_uuids:
+ remove_models_uuids.remove(orig_add_model.clone_base_uuid)
+ # remove duplicate additional models
+ for key, model_list in n.additional_models.items():
+ new_model_list = [x for x in model_list if x.clone_base_uuid not in remove_models_uuids]
+ n.set_additional_models(key, new_model_list)
+ for callback in self.get_all_callbacks(CallbacksMP.ON_MATCH_MULTIGPU_CLONES):
+ callback(self, n)
+ new_multigpu_models.append(n)
+ self.set_additional_models("multigpu", new_multigpu_models)
+
def is_clone(self, other):
if hasattr(other, 'model') and self.model is other.model:
return True
return False
- def clone_has_same_weights(self, clone: 'ModelPatcher'):
- if not self.is_clone(clone):
- return False
+ def clone_has_same_weights(self, clone: ModelPatcher, allow_multigpu=False):
+ if allow_multigpu:
+ if self.clone_base_uuid != clone.clone_base_uuid:
+ return False
+ else:
+ if not self.is_clone(clone):
+ return False
if self.current_hooks != clone.current_hooks:
return False
@@ -1232,7 +1335,7 @@ class ModelPatcher:
return self.additional_models.get(key, [])
def get_additional_models(self):
- all_models = []
+ all_models: list[ModelPatcher] = []
for models in self.additional_models.values():
all_models.extend(models)
return all_models
@@ -1286,9 +1389,18 @@ class ModelPatcher:
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
callback(self)
- def prepare_state(self, timestep):
+ def prepare_state(self, timestep, model_options):
+ ignore_multigpu = model_options.get("ignore_multigpu", False)
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
- callback(self, timestep)
+ callback(self, timestep, model_options)
+ if not ignore_multigpu and "multigpu_clones" in model_options:
+ model_options["ignore_multigpu"] = True
+ try:
+ for p in model_options["multigpu_clones"].values():
+ p: ModelPatcher
+ p.prepare_state(timestep, model_options)
+ finally:
+ model_options.pop("ignore_multigpu", None)
def restore_hook_patches(self):
if self.hook_patches_backup is not None:
@@ -1301,12 +1413,18 @@ class ModelPatcher:
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
curr_t = t[0]
reset_current_hooks = False
+ multigpu_kf_changed_cache = None
transformer_options = model_options.get("transformer_options", {})
for hook in hook_group.hooks:
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
# this will cause the weights to be recalculated when sampling
if changed:
+ # cache changed for multigpu usage
+ if "multigpu_clones" in model_options:
+ if multigpu_kf_changed_cache is None:
+ multigpu_kf_changed_cache = []
+ multigpu_kf_changed_cache.append(hook)
# reset current_hooks if contains hook that changed
if self.current_hooks is not None:
for current_hook in self.current_hooks.hooks:
@@ -1318,6 +1436,28 @@ class ModelPatcher:
self.cached_hook_patches.pop(cached_group)
if reset_current_hooks:
self.patch_hooks(None)
+ if "multigpu_clones" in model_options:
+ for p in model_options["multigpu_clones"].values():
+ p: ModelPatcher
+ p._handle_changed_hook_keyframes(multigpu_kf_changed_cache)
+
+ def _handle_changed_hook_keyframes(self, kf_changed_cache: list[comfy.hooks.Hook]):
+ 'Used to handle multigpu behavior inside prepare_hook_patches_current_keyframe.'
+ if kf_changed_cache is None:
+ return
+ reset_current_hooks = False
+ # reset current_hooks if contains hook that changed
+ for hook in kf_changed_cache:
+ if self.current_hooks is not None:
+ for current_hook in self.current_hooks.hooks:
+ if current_hook == hook:
+ reset_current_hooks = True
+ break
+ for cached_group in list(self.cached_hook_patches.keys()):
+ if cached_group.contains(hook):
+ self.cached_hook_patches.pop(cached_group)
+ if reset_current_hooks:
+ self.patch_hooks(None)
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
registered: comfy.hooks.HookGroup = None):
@@ -1566,16 +1706,27 @@ class ModelPatcherDynamic(ModelPatcher):
self.model.dynamic_vbars = {}
if not hasattr(self.model, "dynamic_pins"):
self.model.dynamic_pins = {}
- if self.load_device not in self.model.dynamic_pins:
- self.model.dynamic_pins[self.load_device] = {
+ self.register_load_device(self.load_device)
+ self.non_dynamic_delegate_model = None
+ assert load_device is not None
+
+ def register_load_device(self, device):
+ """Ensure dynamic_pins has an entry for *device*.
+
+ Called from __init__ and also from any code that retargets an
+ already-constructed patcher to a new load_device (e.g. the
+ Select{Model,CLIP,VAE}Device selector nodes); without this entry
+ partially_unload_ram() raises KeyError when it tries to read the
+ per-device pin state.
+ """
+ if device not in self.model.dynamic_pins:
+ self.model.dynamic_pins[device] = {
"weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
"patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
"hostbufs_initialized": False,
"failed": False,
"active": False,
}
- self.non_dynamic_delegate_model = None
- assert load_device is not None
def is_dynamic(self):
return True
diff --git a/comfy/multigpu.py b/comfy/multigpu.py
new file mode 100644
index 000000000..e7f5b3d6f
--- /dev/null
+++ b/comfy/multigpu.py
@@ -0,0 +1,248 @@
+from __future__ import annotations
+import queue
+import threading
+import torch
+import logging
+
+from collections import namedtuple
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+ from comfy.model_patcher import ModelPatcher
+import comfy.utils
+import comfy.patcher_extension
+import comfy.model_management
+
+
+class MultiGPUThreadPool:
+ """Persistent thread pool for multi-GPU work distribution.
+
+ Maintains one worker thread per extra GPU device. Each thread calls
+ torch.cuda.set_device() once at startup so that compiled kernel caches
+ (inductor/triton) stay warm across diffusion steps.
+ """
+
+ def __init__(self, devices: list[torch.device]):
+ self._workers: list[threading.Thread] = []
+ self._work_queues: dict[torch.device, queue.Queue] = {}
+ self._result_queues: dict[torch.device, queue.Queue] = {}
+
+ for device in devices:
+ wq = queue.Queue()
+ rq = queue.Queue()
+ self._work_queues[device] = wq
+ self._result_queues[device] = rq
+ t = threading.Thread(target=self._worker_loop, args=(device, wq, rq), daemon=True)
+ t.start()
+ self._workers.append(t)
+
+ def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue):
+ try:
+ torch.cuda.set_device(device)
+ except Exception as e:
+ logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}")
+ while True:
+ item = work_q.get()
+ if item is None:
+ return
+ result_q.put((None, e))
+ return
+ while True:
+ item = work_q.get()
+ if item is None:
+ break
+ fn, args, kwargs = item
+ try:
+ result = fn(*args, **kwargs)
+ result_q.put((result, None))
+ except Exception as e:
+ result_q.put((None, e))
+
+ def submit(self, device: torch.device, fn, *args, **kwargs):
+ self._work_queues[device].put((fn, args, kwargs))
+
+ def get_result(self, device: torch.device):
+ return self._result_queues[device].get()
+
+ @property
+ def devices(self) -> list[torch.device]:
+ return list(self._work_queues.keys())
+
+ def shutdown(self):
+ for wq in self._work_queues.values():
+ wq.put(None) # sentinel
+ for t in self._workers:
+ t.join(timeout=5.0)
+
+
+class GPUOptions:
+ def __init__(self, device_index: int, relative_speed: float):
+ self.device_index = device_index
+ self.relative_speed = relative_speed
+
+ def clone(self):
+ return GPUOptions(self.device_index, self.relative_speed)
+
+ def create_dict(self):
+ return {
+ "relative_speed": self.relative_speed
+ }
+
+class GPUOptionsGroup:
+ def __init__(self):
+ self.options: dict[int, GPUOptions] = {}
+
+ def add(self, info: GPUOptions):
+ self.options[info.device_index] = info
+
+ def clone(self):
+ c = GPUOptionsGroup()
+ for opt in self.options.values():
+ c.add(opt)
+ return c
+
+ def register(self, model: ModelPatcher):
+ opts_dict = {}
+ # get devices that are valid for this model
+ devices: list[torch.device] = [model.load_device]
+ for extra_model in model.get_additional_models_with_key("multigpu"):
+ extra_model: ModelPatcher
+ devices.append(extra_model.load_device)
+ # create dictionary with actual device mapped to its GPUOptions
+ device_opts_list: list[GPUOptions] = []
+ for device in devices:
+ device_opts = self.options.get(device.index, GPUOptions(device_index=device.index, relative_speed=1.0))
+ opts_dict[device] = device_opts.create_dict()
+ device_opts_list.append(device_opts)
+ # make relative_speed relative to 1.0
+ min_speed = min([x.relative_speed for x in device_opts_list])
+ for value in opts_dict.values():
+ value['relative_speed'] /= min_speed
+ model.model_options['multigpu_options'] = opts_dict
+
+
+def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False):
+ 'Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.'
+ model = model.clone()
+ # check if multigpu is already prepared - get the load devices from them if possible to exclude
+ skip_devices = set()
+ multigpu_models = model.get_additional_models_with_key("multigpu")
+ if len(multigpu_models) > 0:
+ for mm in multigpu_models:
+ skip_devices.add(mm.load_device)
+ skip_devices = list(skip_devices)
+
+ # Exclude the primary model's actual device, not the global current device:
+ # after SelectModelDevice(gpu:N) the primary may not live on the process's
+ # current CUDA device, and excluding the wrong device picks bad extras.
+ all_devices = comfy.model_management.get_all_torch_devices(exclude_current=False)
+ full_extra_devices = [d for d in all_devices if d != model.load_device]
+ limit_extra_devices = full_extra_devices[:max_gpus-1]
+ extra_devices = limit_extra_devices.copy()
+ # exclude skipped devices
+ for skip in skip_devices:
+ if skip in extra_devices:
+ extra_devices.remove(skip)
+ # create new deepclones
+ if len(extra_devices) > 0:
+ for device in extra_devices:
+ device_patcher = None
+ if reuse_loaded:
+ # Only reuse a previously-loaded MultiGPU clone. A SelectModelDevice
+ # patcher on the same device shares clone_base_uuid but has
+ # is_multigpu_base_clone=False, which would later be filtered out by
+ # prepare_model_patcher_multigpu_clones() and silently shrink the
+ # work split back to one GPU.
+ loaded_models: list[ModelPatcher] = comfy.model_management.loaded_models()
+ for lm in loaded_models:
+ if lm.model is None:
+ continue
+ if lm.load_device != device:
+ continue
+ if lm.clone_base_uuid != model.clone_base_uuid:
+ continue
+ if not getattr(lm, "is_multigpu_base_clone", False):
+ continue
+ device_patcher = lm.clone()
+ logging.info(f"Reusing loaded multigpu deepclone of {device_patcher.model.__class__.__name__} for {device}")
+ break
+ if device_patcher is None:
+ device_patcher = model.deepclone_multigpu(new_load_device=device)
+ # Always flag the clone; whether reused or freshly deepcloned, it must
+ # advertise itself as a MultiGPU base clone so the cond scheduler picks
+ # it up in prepare_model_patcher_multigpu_clones().
+ device_patcher.is_multigpu_base_clone = True
+ multigpu_models = model.get_additional_models_with_key("multigpu")
+ multigpu_models.append(device_patcher)
+ model.set_additional_models("multigpu", multigpu_models)
+ model.match_multigpu_clones()
+ if gpu_options is None:
+ gpu_options = GPUOptionsGroup()
+ gpu_options.register(model)
+ else:
+ logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.")
+ # only keep model clones that don't go 'past' the intended max_gpu count;
+ # this prunes any inherited multigpu clones whose load_device is no longer allowed
+ # when max_gpus is lowered between runs.
+ allowed_devices = set(limit_extra_devices)
+ allowed_devices.add(model.load_device)
+ multigpu_models = model.get_additional_models_with_key("multigpu")
+ new_multigpu_models = [m for m in multigpu_models if m.load_device in allowed_devices]
+ if len(new_multigpu_models) != len(multigpu_models):
+ model.set_additional_models("multigpu", new_multigpu_models)
+ model.match_multigpu_clones()
+ return model
+
+
+LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
+def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
+ 'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'
+ opts_dict = model_options['multigpu_options']
+ devices = list(model_options['multigpu_clones'].keys())
+ speed_per_device = []
+ work_per_device = []
+ # get sum of each device's relative_speed
+ total_speed = 0.0
+ for opts in opts_dict.values():
+ total_speed += opts['relative_speed']
+ # get relative work for each device;
+ # obtained by w = (W*r)/R
+ for device in devices:
+ relative_speed = opts_dict[device]['relative_speed']
+ relative_work = (total_work*relative_speed) / total_speed
+ speed_per_device.append(relative_speed)
+ work_per_device.append(relative_work)
+ # relative work must be expressed in whole numbers, but likely is a decimal;
+ # perform rounding while maintaining total sum equal to total work (sum of relative works)
+ work_per_device = round_preserved(work_per_device)
+ dict_work_per_device = {}
+ for device, relative_work in zip(devices, work_per_device):
+ dict_work_per_device[device] = relative_work
+ if not return_idle_time:
+ return LoadBalance(dict_work_per_device, None)
+ # divide relative work by relative speed to get estimated completion time of said work by each device;
+ # time here is relative and does not correspond to real-world units
+ completion_time = [w/r for w,r in zip(work_per_device, speed_per_device)]
+ # calculate relative time spent by the devices waiting on each other after their work is completed
+ idle_time = abs(min(completion_time) - max(completion_time))
+ # if need to compare work idle time, need to normalize to a common total work
+ if work_normalized:
+ idle_time *= (work_normalized/total_work)
+
+ return LoadBalance(dict_work_per_device, idle_time)
+
+def round_preserved(values: list[float]):
+ 'Round all values in a list, preserving the combined sum of values.'
+ # get floor of values; casting to int does it too
+ floored = [int(x) for x in values]
+ total_floored = sum(floored)
+ # get remainder to distribute
+ remainder = round(sum(values)) - total_floored
+ # pair values with fractional portions
+ fractional = [(i, x-floored[i]) for i, x in enumerate(values)]
+ # sort by fractional part in descending order
+ fractional.sort(key=lambda x: x[1], reverse=True)
+ # distribute the remainder
+ for i in range(remainder):
+ index = fractional[i][0]
+ floored[index] += 1
+ return floored
diff --git a/comfy/patcher_extension.py b/comfy/patcher_extension.py
index 5ee4d5ee5..4b276b175 100644
--- a/comfy/patcher_extension.py
+++ b/comfy/patcher_extension.py
@@ -3,6 +3,8 @@ from typing import Callable
class CallbacksMP:
ON_CLONE = "on_clone"
+ ON_DEEPCLONE_MULTIGPU = "on_deepclone_multigpu"
+ ON_MATCH_MULTIGPU_CLONES = "on_match_multigpu_clones"
ON_LOAD = "on_load_after"
ON_DETACH = "on_detach_after"
ON_CLEANUP = "on_cleanup"
diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py
index 3782fd2d5..bdce2f2d8 100644
--- a/comfy/sampler_helpers.py
+++ b/comfy/sampler_helpers.py
@@ -1,16 +1,18 @@
from __future__ import annotations
+import torch
import uuid
import math
import collections
import comfy.model_management
import comfy.conds
+import comfy.model_patcher
import comfy.utils
import comfy.hooks
import comfy.patcher_extension
from typing import TYPE_CHECKING
if TYPE_CHECKING:
- from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel
+ from comfy.model_patcher import ModelPatcher
from comfy.controlnet import ControlBase
def prepare_mask(noise_mask, shape, device):
@@ -119,6 +121,47 @@ def cleanup_additional_models(models):
if hasattr(m, 'cleanup'):
m.cleanup()
+def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model: ModelPatcher, model_options: dict[str]):
+ '''If multigpu acceleration required, creates deepclones of ControlNets and GLIGEN per device.'''
+ multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu")
+ if len(multigpu_models) == 0:
+ return
+ extra_devices = [x.load_device for x in multigpu_models]
+ # handle controlnets
+ controlnets: set[ControlBase] = set()
+ for k in conds:
+ for kk in conds[k]:
+ if 'control' in kk:
+ controlnets.add(kk['control'])
+ if len(controlnets) > 0:
+ # first, unload all controlnet clones
+ for cnet in list(controlnets):
+ cnet_models = cnet.get_models()
+ for cm in cnet_models:
+ comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True)
+
+ # next, make sure each controlnet has a deepclone for all relevant devices
+ for cnet in controlnets:
+ curr_cnet = cnet
+ while curr_cnet is not None:
+ for device in extra_devices:
+ if device not in curr_cnet.multigpu_clones:
+ curr_cnet.deepclone_multigpu(device, autoregister=True)
+ curr_cnet = curr_cnet.previous_controlnet
+ # since all device clones are now present, recreate the linked list for cloned cnets per device
+ for cnet in controlnets:
+ curr_cnet = cnet
+ while curr_cnet is not None:
+ prev_cnet = curr_cnet.previous_controlnet
+ for device in extra_devices:
+ device_cnet = curr_cnet.get_instance_for_device(device)
+ prev_device_cnet = None
+ if prev_cnet is not None:
+ prev_device_cnet = prev_cnet.get_instance_for_device(device)
+ device_cnet.set_previous_controlnet(prev_device_cnet)
+ curr_cnet = prev_cnet
+ # potentially handle gligen - since not widely used, ignored for now
+
def estimate_memory(model, noise_shape, conds):
cond_shapes = collections.defaultdict(list)
cond_shapes_min = {}
@@ -143,7 +186,8 @@ def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
- real_model: BaseModel = None
+ model.match_multigpu_clones()
+ preprocess_multigpu_conds(conds, model, model_options)
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
@@ -155,7 +199,7 @@ def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=Non
memory_required += inference_memory
minimum_memory_required += inference_memory
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
- real_model = model.model
+ real_model: BaseModel = model.model
return real_model, conds, models
@@ -201,3 +245,18 @@ def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict):
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
copy_dict1=False)
return to_load_options
+
+def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_models: list[ModelPatcher], model_options: dict):
+ '''
+ In case multigpu acceleration is enabled, prep ModelPatchers for each device.
+ '''
+ multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_base_clone]
+ if len(multigpu_patchers) > 0:
+ multigpu_dict: dict[torch.device, ModelPatcher] = {}
+ multigpu_dict[model_patcher.load_device] = model_patcher
+ for x in multigpu_patchers:
+ x.hook_patches = comfy.model_patcher.create_hook_patches_clone(model_patcher.hook_patches, copy_tuples=True)
+ x.hook_mode = model_patcher.hook_mode # match main model's hook_mode
+ multigpu_dict[x.load_device] = x
+ model_options["multigpu_clones"] = multigpu_dict
+ return multigpu_patchers
diff --git a/comfy/samplers.py b/comfy/samplers.py
index c5e36ff05..e31277f7b 100755
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -1,7 +1,9 @@
from __future__ import annotations
+
+import comfy.model_management
from .k_diffusion import sampling as k_diffusion_sampling
from .extra_samplers import uni_pc
-from typing import TYPE_CHECKING, Callable, NamedTuple
+from typing import TYPE_CHECKING, Callable, NamedTuple, Any
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel
@@ -16,6 +18,7 @@ import comfy.model_patcher
import comfy.patcher_extension
import comfy.hooks
import comfy.context_windows
+import comfy.multigpu
import comfy.utils
import scipy.stats
import numpy
@@ -141,7 +144,7 @@ def can_concat_cond(c1, c2):
return cond_equal_size(c1.conditioning, c2.conditioning)
-def cond_cat(c_list):
+def cond_cat(c_list, device=None):
temp = {}
for x in c_list:
for k in x:
@@ -153,6 +156,8 @@ def cond_cat(c_list):
for k in temp:
conds = temp[k]
out[k] = conds[0].concat(conds[1:])
+ if device is not None and hasattr(out[k], 'to'):
+ out[k] = out[k].to(device)
return out
@@ -212,7 +217,12 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc
)
return executor.execute(model, conds, x_in, timestep, model_options)
-def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
+def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
+ # NOTE: keep in sync with _calc_cond_batch_multigpu below. Shared logic
+ # (hooked_to_run accumulation, memory-fit batching, per-chunk output
+ # aggregation) is duplicated there with per-device scheduling layered on top.
+ if 'multigpu_clones' in model_options:
+ return _calc_cond_batch_multigpu(model, conds, x_in, timestep, model_options)
out_conds = []
out_counts = []
# separate conds by matching hooks
@@ -244,7 +254,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
if has_default_conds:
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
- model.current_patcher.prepare_state(timestep)
+ model.current_patcher.prepare_state(timestep, model_options)
# run every hooked_to_run separately
for hooks, to_run in hooked_to_run.items():
@@ -344,6 +354,239 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
return out_conds
+def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
+ # NOTE: keep in sync with _calc_cond_batch above. Same conds-by-hooks
+ # accumulation, memory-fit batching, and output aggregation, but adds a
+ # per-device scheduler, per-device patcher/control lookup, tensor .to(device)
+ # placement, and MultiGPUThreadPool dispatch around the inner loop.
+ out_conds = []
+ out_counts = []
+ # separate conds by matching hooks
+ hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
+ default_conds = []
+ has_default_conds = False
+
+ output_device = x_in.device
+
+ for i in range(len(conds)):
+ out_conds.append(torch.zeros_like(x_in))
+ out_counts.append(torch.ones_like(x_in) * 1e-37)
+
+ cond = conds[i]
+ default_c = []
+ if cond is not None:
+ for x in cond:
+ if 'default' in x:
+ default_c.append(x)
+ has_default_conds = True
+ continue
+ p = get_area_and_mult(x, x_in, timestep)
+ if p is None:
+ continue
+ if p.hooks is not None:
+ model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
+ hooked_to_run.setdefault(p.hooks, list())
+ hooked_to_run[p.hooks] += [(p, i)]
+ default_conds.append(default_c)
+
+ if has_default_conds:
+ finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
+
+ model.current_patcher.prepare_state(timestep, model_options)
+
+ devices = list(model_options['multigpu_clones'].keys())
+ device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {}
+ # Track conds currently scheduled per device; single source of truth for capacity checks.
+ device_load: dict[torch.device, int] = {d: 0 for d in devices}
+
+ total_conds = sum(len(to_run) for to_run in hooked_to_run.values())
+ conds_per_device = max(1, math.ceil(total_conds / len(devices)))
+
+ def next_available_device(start: int) -> tuple[int, torch.device]:
+ """Return (index, device) for the next device with remaining capacity, starting at `start`.
+
+ Scans at most len(devices) positions, so this always terminates. Raises if no device
+ has remaining capacity, which would indicate a bug in conds_per_device accounting.
+ """
+ for offset in range(len(devices)):
+ i = (start + offset) % len(devices)
+ if device_load[devices[i]] < conds_per_device:
+ return i, devices[i]
+ raise RuntimeError(
+ f"MultiGPU scheduler: all {len(devices)} devices at capacity "
+ f"({conds_per_device}) but conds remain to schedule"
+ )
+
+ # run every hooked_to_run separately
+ index_device = 0
+ for hooks, to_run in hooked_to_run.items():
+ while len(to_run) > 0:
+ index_device, current_device = next_available_device(index_device)
+ remaining_capacity = conds_per_device - device_load[current_device]
+
+ first = to_run[0]
+ first_shape = first[0][0].shape
+ # collect candidate indices that can be concatenated with `first`, up to remaining capacity
+ to_batch_temp = []
+ for x in range(len(to_run)):
+ if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < remaining_capacity:
+ to_batch_temp += [x]
+
+ to_batch_temp.reverse()
+ to_batch = to_batch_temp[:1]
+
+ free_memory = comfy.model_management.get_free_memory(current_device)
+ for i in range(1, len(to_batch_temp) + 1):
+ batch_amount = to_batch_temp[:len(to_batch_temp)//i]
+ input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
+ cond_shapes = collections.defaultdict(list)
+ for tt in batch_amount:
+ for k, v in to_run[tt][0].conditioning.items():
+ cond_shapes[k].append(v.size())
+ if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory:
+ to_batch = batch_amount
+ break
+
+ conds_to_batch = [to_run.pop(x) for x in to_batch]
+ device_load[current_device] += len(conds_to_batch)
+ device_batched_hooked_to_run.setdefault(current_device, []).append((hooks, conds_to_batch))
+
+ if device_load[current_device] >= conds_per_device:
+ index_device += 1
+
+ class thread_result(NamedTuple):
+ output: Any
+ mult: Any
+ area: Any
+ batch_chunks: int
+ cond_or_uncond: Any
+ error: Exception = None
+
+ def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
+ try:
+ # TODO: non-NVIDIA support -- guard with `if device.type == "cuda":` once
+ # we extend multigpu QA beyond CUDA. Unconditional call crashes on
+ # XPU/NPU/MPS/CPU/DirectML backends.
+ torch.cuda.set_device(device)
+ model_current: BaseModel = model_options["multigpu_clones"][device].model
+ # run every hooked_to_run separately
+ with torch.no_grad():
+ for hooks, to_batch in batch_tuple:
+ input_x = []
+ mult = []
+ c = []
+ cond_or_uncond = []
+ uuids = []
+ area = []
+ control: ControlBase = None
+ patches = None
+ for x in to_batch:
+ o = x
+ p = o[0]
+ input_x.append(p.input_x)
+ mult.append(p.mult)
+ c.append(p.conditioning)
+ area.append(p.area)
+ cond_or_uncond.append(o[1])
+ uuids.append(p.uuid)
+ control = p.control
+ patches = p.patches
+
+ batch_chunks = len(cond_or_uncond)
+ input_x = torch.cat(input_x).to(device)
+ c = cond_cat(c, device=device)
+ timestep_ = torch.cat([timestep.to(device)] * batch_chunks)
+
+ transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks)
+ if 'transformer_options' in model_options:
+ transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
+ model_options['transformer_options'],
+ copy_dict1=False)
+
+ if patches is not None:
+ transformer_options["patches"] = comfy.patcher_extension.merge_nested_dicts(
+ transformer_options.get("patches", {}),
+ patches
+ )
+
+ transformer_options["cond_or_uncond"] = cond_or_uncond[:]
+ transformer_options["uuids"] = uuids[:]
+ transformer_options["sigmas"] = timestep.to(device)
+ transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device)
+ transformer_options["multigpu_thread_device"] = device
+
+ cast_transformer_options(transformer_options, device=device)
+ c['transformer_options'] = transformer_options
+
+ if control is not None:
+ device_control = control.get_instance_for_device(device)
+ c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
+
+ if 'model_function_wrapper' in model_options:
+ output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
+ else:
+ output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks)
+ # TODO: non-NVIDIA support -- the `.to(output_device)` copies
+ # above are async on CUDA, so the main thread's aggregation
+ # could race with in-flight transfers. CUDA-only QA has not
+ # surfaced this in practice, but before extending multigpu
+ # beyond NVIDIA add a `torch.cuda.synchronize(output_device)`
+ # here (guarded by `output_device.type == "cuda"`).
+ results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond))
+ except Exception as e:
+ results.append(thread_result(None, None, None, None, None, error=e))
+ raise
+
+
+ def _handle_batch_pooled(device, batch_tuple):
+ worker_results = []
+ _handle_batch(device, batch_tuple, worker_results)
+ return worker_results
+
+ results: list[thread_result] = []
+ thread_pool: comfy.multigpu.MultiGPUThreadPool = model_options.get("multigpu_thread_pool")
+
+ # Submit all GPU work to pool threads
+ pool_devices = []
+ for device, batch_tuple in device_batched_hooked_to_run.items():
+ if thread_pool is not None:
+ thread_pool.submit(device, _handle_batch_pooled, device, batch_tuple)
+ pool_devices.append(device)
+ else:
+ # Fallback: no pool, run everything on main thread
+ _handle_batch(device, batch_tuple, results)
+
+ # Collect results from pool workers
+ for device in pool_devices:
+ worker_results, error = thread_pool.get_result(device)
+ if error is not None:
+ raise error
+ results.extend(worker_results)
+
+ for output, mult, area, batch_chunks, cond_or_uncond, error in results:
+ if error is not None:
+ raise error
+ for o in range(batch_chunks):
+ cond_index = cond_or_uncond[o]
+ a = area[o]
+ if a is None:
+ out_conds[cond_index] += output[o] * mult[o]
+ out_counts[cond_index] += mult[o]
+ else:
+ out_c = out_conds[cond_index]
+ out_cts = out_counts[cond_index]
+ dims = len(a) // 2
+ for i in range(dims):
+ out_c = out_c.narrow(i + 2, a[i + dims], a[i])
+ out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
+ out_c += output[o] * mult[o]
+ out_cts += mult[o]
+
+ for i in range(len(out_conds)):
+ out_conds[i] /= out_counts[i]
+
+ return out_conds
+
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
@@ -642,12 +885,21 @@ def calculate_start_end_timesteps(model, conds):
def pre_run_control(model, conds):
s = model.model_sampling
+ # Per-device model lookup so multigpu control clones get the matching
+ # diffusion_model (e.g. QwenFunControlNet stashes it into extra_args).
+ device_models: dict = {}
+ patcher = getattr(model, "current_patcher", None)
+ if patcher is not None:
+ for p in patcher.get_additional_models_with_key("multigpu"):
+ device_models[p.load_device] = p.model
for t in range(len(conds)):
x = conds[t]
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
if 'control' in x:
x['control'].pre_run(model, percent_to_timestep_function)
+ for device, device_cnet in x['control'].multigpu_clones.items():
+ device_cnet.pre_run(device_models.get(device, model), percent_to_timestep_function)
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
cond_cnets = []
@@ -890,7 +1142,9 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
to_load_options = model_options.get("to_load_options", None)
if to_load_options is None:
return
+ cast_transformer_options(to_load_options, device, dtype)
+def cast_transformer_options(transformer_options: dict[str], device=None, dtype=None):
casts = []
if device is not None:
casts.append(device)
@@ -899,18 +1153,17 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
# if nothing to apply, do nothing
if len(casts) == 0:
return
-
# try to call .to on patches
- if "patches" in to_load_options:
- patches = to_load_options["patches"]
+ if "patches" in transformer_options:
+ patches = transformer_options["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], "to"):
for cast in casts:
patch_list[i] = patch_list[i].to(cast)
- if "patches_replace" in to_load_options:
- patches = to_load_options["patches_replace"]
+ if "patches_replace" in transformer_options:
+ patches = transformer_options["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
@@ -920,8 +1173,8 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
# try to call .to on any wrappers/callbacks
wrappers_and_callbacks = ["wrappers", "callbacks"]
for wc_name in wrappers_and_callbacks:
- if wc_name in to_load_options:
- wc: dict[str, list] = to_load_options[wc_name]
+ if wc_name in transformer_options:
+ wc: dict[str, list] = transformer_options[wc_name]
for wc_dict in wc.values():
for wc_list in wc_dict.values():
for i in range(len(wc_list)):
@@ -929,7 +1182,6 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
for cast in casts:
wc_list[i] = wc_list[i].to(cast)
-
class CFGGuider:
def __init__(self, model_patcher: ModelPatcher):
self.model_patcher = model_patcher
@@ -984,16 +1236,32 @@ class CFGGuider:
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device
- noise = noise.to(device=device, dtype=torch.float32)
- latent_image = latent_image.to(device=device, dtype=torch.float32)
- sigmas = sigmas.to(device)
- cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
+ multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options)
- try:
- self.model_patcher.pre_run()
- output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
- finally:
- self.model_patcher.cleanup()
+ # Create persistent thread pool for all GPU devices (main + extras)
+ if multigpu_patchers:
+ extra_devices = [p.load_device for p in multigpu_patchers]
+ all_devices = [device] + extra_devices
+ self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(all_devices)
+
+ with comfy.model_management.cuda_device_context(device):
+ try:
+ noise = noise.to(device=device, dtype=torch.float32)
+ latent_image = latent_image.to(device=device, dtype=torch.float32)
+ sigmas = sigmas.to(device)
+ cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
+
+ self.model_patcher.pre_run()
+ for multigpu_patcher in multigpu_patchers:
+ multigpu_patcher.pre_run()
+ output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
+ finally:
+ thread_pool = self.model_options.pop("multigpu_thread_pool", None)
+ if thread_pool is not None:
+ thread_pool.shutdown()
+ self.model_patcher.cleanup()
+ for multigpu_patcher in multigpu_patchers:
+ multigpu_patcher.cleanup()
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
del self.inner_model
diff --git a/comfy/sd.py b/comfy/sd.py
index 7bd07ed3a..084170c62 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -335,41 +335,43 @@ class CLIP:
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model(tokens)
- self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
+ device = self.patcher.load_device
+ self.cond_stage_model.set_clip_options({"execution_device": device})
all_hooks.reset()
self.patcher.patch_hooks(None)
if show_pbar:
pbar = ProgressBar(len(scheduled_keyframes))
- for scheduled_opts in scheduled_keyframes:
- t_range = scheduled_opts[0]
- # don't bother encoding any conds outside of start_percent and end_percent bounds
- if "start_percent" in add_dict:
- if t_range[1] < add_dict["start_percent"]:
- continue
- if "end_percent" in add_dict:
- if t_range[0] > add_dict["end_percent"]:
- continue
- hooks_keyframes = scheduled_opts[1]
- for hook, keyframe in hooks_keyframes:
- hook.hook_keyframe._current_keyframe = keyframe
- # apply appropriate hooks with values that match new hook_keyframe
- self.patcher.patch_hooks(all_hooks)
- # perform encoding as normal
- o = self.cond_stage_model.encode_token_weights(tokens)
- cond, pooled = o[:2]
- pooled_dict = {"pooled_output": pooled}
- # add clip_start_percent and clip_end_percent in pooled
- pooled_dict["clip_start_percent"] = t_range[0]
- pooled_dict["clip_end_percent"] = t_range[1]
- # add/update any keys with the provided add_dict
- pooled_dict.update(add_dict)
- # add hooks stored on clip
- self.add_hooks_to_dict(pooled_dict)
- all_cond_pooled.append([cond, pooled_dict])
- if show_pbar:
- pbar.update(1)
- model_management.throw_exception_if_processing_interrupted()
+ with model_management.cuda_device_context(device):
+ for scheduled_opts in scheduled_keyframes:
+ t_range = scheduled_opts[0]
+ # don't bother encoding any conds outside of start_percent and end_percent bounds
+ if "start_percent" in add_dict:
+ if t_range[1] < add_dict["start_percent"]:
+ continue
+ if "end_percent" in add_dict:
+ if t_range[0] > add_dict["end_percent"]:
+ continue
+ hooks_keyframes = scheduled_opts[1]
+ for hook, keyframe in hooks_keyframes:
+ hook.hook_keyframe._current_keyframe = keyframe
+ # apply appropriate hooks with values that match new hook_keyframe
+ self.patcher.patch_hooks(all_hooks)
+ # perform encoding as normal
+ o = self.cond_stage_model.encode_token_weights(tokens)
+ cond, pooled = o[:2]
+ pooled_dict = {"pooled_output": pooled}
+ # add clip_start_percent and clip_end_percent in pooled
+ pooled_dict["clip_start_percent"] = t_range[0]
+ pooled_dict["clip_end_percent"] = t_range[1]
+ # add/update any keys with the provided add_dict
+ pooled_dict.update(add_dict)
+ # add hooks stored on clip
+ self.add_hooks_to_dict(pooled_dict)
+ all_cond_pooled.append([cond, pooled_dict])
+ if show_pbar:
+ pbar.update(1)
+ model_management.throw_exception_if_processing_interrupted()
all_hooks.reset()
return all_cond_pooled
@@ -383,8 +385,12 @@ class CLIP:
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model(tokens)
- self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
- o = self.cond_stage_model.encode_token_weights(tokens)
+ device = self.patcher.load_device
+ self.cond_stage_model.set_clip_options({"execution_device": device})
+
+ with model_management.cuda_device_context(device):
+ o = self.cond_stage_model.encode_token_weights(tokens)
+
cond, pooled = o[:2]
if return_dict:
out = {"cond": cond, "pooled_output": pooled}
@@ -446,9 +452,12 @@ class CLIP:
self.cond_stage_model.reset_clip_options()
self.load_model(tokens)
+ device = self.patcher.load_device
self.cond_stage_model.set_clip_options({"layer": None})
- self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
- return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
+ self.cond_stage_model.set_clip_options({"execution_device": device})
+
+ with model_management.cuda_device_context(device):
+ return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
def decode(self, token_ids, skip_special_tokens=True):
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
@@ -1026,50 +1035,52 @@ class VAE:
do_tile = False
if self.latent_dim == 2 and samples_in.ndim == 5:
samples_in = samples_in[:, :, 0]
- try:
- memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
- model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
- free_memory = self.patcher.get_free_memory(self.device)
- batch_number = int(free_memory / memory_used)
- batch_number = max(1, batch_number)
- # Pre-allocate output for VAEs that support direct buffer writes
- preallocated = False
- if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
- pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
- preallocated = True
+ with model_management.cuda_device_context(self.device):
+ try:
+ memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
+ model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
+ free_memory = self.patcher.get_free_memory(self.device)
+ batch_number = int(free_memory / memory_used)
+ batch_number = max(1, batch_number)
- for x in range(0, samples_in.shape[0], batch_number):
- samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
- if preallocated:
- self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
- else:
- out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
- if pixel_samples is None:
- pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
- pixel_samples[x:x+batch_number].copy_(out)
- del out
- self.process_output(pixel_samples[x:x+batch_number])
- except Exception as e:
- model_management.raise_non_oom(e)
- logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
- #NOTE: We don't know what tensors were allocated to stack variables at the time of the
- #exception and the exception itself refs them all until we get out of this except block.
- #So we just set a flag for tiler fallback so that tensor gc can happen once the
- #exception is fully off the books.
- do_tile = True
+ # Pre-allocate output for VAEs that support direct buffer writes
+ preallocated = False
+ if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
+ pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
+ preallocated = True
- if do_tile:
- comfy.model_management.soft_empty_cache()
- dims = samples_in.ndim - 2
- if dims == 1 or self.extra_1d_channel is not None:
- pixel_samples = self.decode_tiled_1d(samples_in)
- elif dims == 2:
- pixel_samples = self.decode_tiled_(samples_in)
- elif dims == 3:
- tile = 256 // self.spacial_compression_decode()
- overlap = tile // 4
- pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
+ for x in range(0, samples_in.shape[0], batch_number):
+ samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
+ if preallocated:
+ self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
+ else:
+ out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
+ if pixel_samples is None:
+ pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
+ pixel_samples[x:x+batch_number].copy_(out)
+ del out
+ self.process_output(pixel_samples[x:x+batch_number])
+ except Exception as e:
+ model_management.raise_non_oom(e)
+ logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
+ #NOTE: We don't know what tensors were allocated to stack variables at the time of the
+ #exception and the exception itself refs them all until we get out of this except block.
+ #So we just set a flag for tiler fallback so that tensor gc can happen once the
+ #exception is fully off the books.
+ do_tile = True
+
+ if do_tile:
+ comfy.model_management.soft_empty_cache()
+ dims = samples_in.ndim - 2
+ if dims == 1 or self.extra_1d_channel is not None:
+ pixel_samples = self.decode_tiled_1d(samples_in)
+ elif dims == 2:
+ pixel_samples = self.decode_tiled_(samples_in)
+ elif dims == 3:
+ tile = 256 // self.spacial_compression_decode()
+ overlap = tile // 4
+ pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples
@@ -1087,20 +1098,21 @@ class VAE:
if overlap is not None:
args["overlap"] = overlap
- if dims == 1 or self.extra_1d_channel is not None:
- args.pop("tile_y")
- output = self.decode_tiled_1d(samples, **args)
- elif dims == 2:
- output = self.decode_tiled_(samples, **args)
- elif dims == 3:
- if overlap_t is None:
- args["overlap"] = (1, overlap, overlap)
- else:
- args["overlap"] = (max(1, overlap_t), overlap, overlap)
- if tile_t is not None:
- args["tile_t"] = max(2, tile_t)
+ with model_management.cuda_device_context(self.device):
+ if dims == 1 or self.extra_1d_channel is not None:
+ args.pop("tile_y")
+ output = self.decode_tiled_1d(samples, **args)
+ elif dims == 2:
+ output = self.decode_tiled_(samples, **args)
+ elif dims == 3:
+ if overlap_t is None:
+ args["overlap"] = (1, overlap, overlap)
+ else:
+ args["overlap"] = (max(1, overlap_t), overlap, overlap)
+ if tile_t is not None:
+ args["tile_t"] = max(2, tile_t)
- output = self.decode_tiled_3d(samples, **args)
+ output = self.decode_tiled_3d(samples, **args)
return output.movedim(1, -1)
def encode(self, pixel_samples):
@@ -1113,44 +1125,46 @@ class VAE:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
else:
pixel_samples = pixel_samples.unsqueeze(2)
- try:
- memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
- model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
- free_memory = self.patcher.get_free_memory(self.device)
- batch_number = int(free_memory / max(1, memory_used))
- batch_number = max(1, batch_number)
- samples = None
- for x in range(0, pixel_samples.shape[0], batch_number):
- pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
- if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
- out = self.first_stage_model.encode(pixels_in, device=self.device)
+
+ with model_management.cuda_device_context(self.device):
+ try:
+ memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
+ model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
+ free_memory = self.patcher.get_free_memory(self.device)
+ batch_number = int(free_memory / max(1, memory_used))
+ batch_number = max(1, batch_number)
+ samples = None
+ for x in range(0, pixel_samples.shape[0], batch_number):
+ pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
+ if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
+ out = self.first_stage_model.encode(pixels_in, device=self.device)
+ else:
+ pixels_in = pixels_in.to(self.device)
+ out = self.first_stage_model.encode(pixels_in)
+ out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
+ if samples is None:
+ samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
+ samples[x:x + batch_number] = out
+
+ except Exception as e:
+ model_management.raise_non_oom(e)
+ logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
+ #NOTE: We don't know what tensors were allocated to stack variables at the time of the
+ #exception and the exception itself refs them all until we get out of this except block.
+ #So we just set a flag for tiler fallback so that tensor gc can happen once the
+ #exception is fully off the books.
+ do_tile = True
+
+ if do_tile:
+ comfy.model_management.soft_empty_cache()
+ if self.latent_dim == 3:
+ tile = 256
+ overlap = tile // 4
+ samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
+ elif self.latent_dim == 1 or self.extra_1d_channel is not None:
+ samples = self.encode_tiled_1d(pixel_samples)
else:
- pixels_in = pixels_in.to(self.device)
- out = self.first_stage_model.encode(pixels_in)
- out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
- if samples is None:
- samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
- samples[x:x + batch_number] = out
-
- except Exception as e:
- model_management.raise_non_oom(e)
- logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
- #NOTE: We don't know what tensors were allocated to stack variables at the time of the
- #exception and the exception itself refs them all until we get out of this except block.
- #So we just set a flag for tiler fallback so that tensor gc can happen once the
- #exception is fully off the books.
- do_tile = True
-
- if do_tile:
- comfy.model_management.soft_empty_cache()
- if self.latent_dim == 3:
- tile = 256
- overlap = tile // 4
- samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
- elif self.latent_dim == 1 or self.extra_1d_channel is not None:
- samples = self.encode_tiled_1d(pixel_samples)
- else:
- samples = self.encode_tiled_(pixel_samples)
+ samples = self.encode_tiled_(pixel_samples)
return samples
@@ -1176,26 +1190,27 @@ class VAE:
if overlap is not None:
args["overlap"] = overlap
- if dims == 1:
- args.pop("tile_y")
- samples = self.encode_tiled_1d(pixel_samples, **args)
- elif dims == 2:
- samples = self.encode_tiled_(pixel_samples, **args)
- elif dims == 3:
- if tile_t is not None:
- tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
- else:
- tile_t_latent = 9999
- args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
+ with model_management.cuda_device_context(self.device):
+ if dims == 1:
+ args.pop("tile_y")
+ samples = self.encode_tiled_1d(pixel_samples, **args)
+ elif dims == 2:
+ samples = self.encode_tiled_(pixel_samples, **args)
+ elif dims == 3:
+ if tile_t is not None:
+ tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
+ else:
+ tile_t_latent = 9999
+ args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
- if overlap_t is None:
- args["overlap"] = (1, overlap, overlap)
- else:
- args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
- maximum = pixel_samples.shape[2]
- maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
+ if overlap_t is None:
+ args["overlap"] = (1, overlap, overlap)
+ else:
+ args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
+ maximum = pixel_samples.shape[2]
+ maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
- samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
+ samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
return samples
@@ -1710,12 +1725,52 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
if out is None:
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
- if output_model and out[0] is not None:
- out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options))
- if output_clip and out[1] is not None:
- out[1].patcher.cached_patcher_init = (load_checkpoint_guess_config_clip_only, (ckpt_path, embedding_directory, model_options, te_model_options))
+ if out[0] is not None:
+ out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0)
+ # Register reload factories for the CLIP and VAE produced by the same checkpoint so
+ # ModelPatcher.deepclone_multigpu can spawn per-device copies (Select{CLIP,VAE}Device,
+ # MultiGPU work-units, etc.) without falling back to copy.deepcopy of an
+ # already-loaded module.
+ if out[1] is not None and getattr(out[1], "patcher", None) is not None:
+ out[1].patcher.cached_patcher_init = (load_checkpoint_clip_patcher, (ckpt_path, embedding_directory, model_options, te_model_options))
+ if out[2] is not None and getattr(out[2], "patcher", None) is not None:
+ out[2].patcher.cached_patcher_init = (load_checkpoint_vae_patcher, (ckpt_path, embedding_directory, model_options, te_model_options))
return out
+
+def load_checkpoint_clip_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
+ """Reload only the CLIP patcher from a checkpoint. Used as the cached_patcher_init
+ factory for the CLIP returned by load_checkpoint_guess_config."""
+ _, clip, _, _ = load_checkpoint_guess_config(
+ ckpt_path,
+ output_vae=False,
+ output_clip=True,
+ output_clipvision=False,
+ embedding_directory=embedding_directory,
+ output_model=False,
+ model_options=model_options,
+ te_model_options=te_model_options,
+ disable_dynamic=disable_dynamic,
+ )
+ return clip.patcher
+
+
+def load_checkpoint_vae_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
+ """Reload only the VAE patcher from a checkpoint. Used as the cached_patcher_init
+ factory for the VAE returned by load_checkpoint_guess_config."""
+ _, _, vae, _ = load_checkpoint_guess_config(
+ ckpt_path,
+ output_vae=True,
+ output_clip=False,
+ output_clipvision=False,
+ embedding_directory=embedding_directory,
+ output_model=False,
+ model_options=model_options,
+ te_model_options=te_model_options,
+ disable_dynamic=disable_dynamic,
+ )
+ return vae.patcher
+
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False,
embedding_directory=embedding_directory,
@@ -1742,7 +1797,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
- load_device = model_management.get_torch_device()
+ load_device = model_options.get("load_device", model_management.get_torch_device())
custom_operations = model_options.get("custom_operations", None)
if custom_operations is None:
@@ -1782,13 +1837,15 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
- model_patcher = ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
+ offload_device = model_options.get("offload_device", model_management.unet_offload_device())
+ model_patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device)
model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
if output_vae:
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
vae_sd = model_config.process_vae_state_dict(vae_sd)
- vae = VAE(sd=vae_sd, metadata=metadata)
+ vae_device = model_options.get("load_device", None)
+ vae = VAE(sd=vae_sd, metadata=metadata, device=vae_device)
if output_clip:
if te_model_options.get("custom_operations", None) is None:
@@ -1872,7 +1929,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
parameters = comfy.utils.calculate_parameters(sd)
weight_dtype = comfy.utils.weight_dtype(sd)
- load_device = model_management.get_torch_device()
+ load_device = model_options.get("load_device", model_management.get_torch_device())
model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
if model_config is not None:
@@ -1897,7 +1954,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
else:
logging.warning("{} {}".format(diffusers_keys[k], k))
- offload_device = model_management.unet_offload_device()
+ offload_device = model_options.get("offload_device", model_management.unet_offload_device())
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if model_config.quant_config is not None:
weight_dtype = None
@@ -1939,6 +1996,26 @@ def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False):
model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options))
return model
+
+def load_vae_patcher(vae_path, metadata=None, device=None, disable_dynamic=False):
+ """Reload a disk-backed VAE from ``vae_path`` and return its patcher.
+
+ Used as the ``cached_patcher_init`` factory on ``VAE.patcher`` so
+ :meth:`comfy.model_patcher.ModelPatcher.deepclone_multigpu` can produce a
+ fresh, untainted VAE patcher (no inherited per-device load state, no
+ in-place quantization fallout) for multigpu work-units and the
+ SelectVAEDevice node. The optional ``device`` matches the source loader's
+ VAE initialization path; the deepclone's ``load_device`` still controls
+ where the cloned patcher is targeted.
+ """
+ if metadata is None:
+ sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True)
+ else:
+ sd = comfy.utils.load_torch_file(vae_path)
+ vae = VAE(sd=sd, metadata=metadata, device=device)
+ vae.throw_exception_if_invalid()
+ return vae.patcher
+
def load_unet(unet_path, dtype=None):
logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
return load_diffusion_model(unet_path, model_options={"dtype": dtype})
diff --git a/comfy/utils.py b/comfy/utils.py
index 31052714a..49ae12b06 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -86,6 +86,7 @@ def load_safetensors(ckpt):
import comfy_aimdo.model_mmap
f = open(ckpt, "rb", buffering=0)
+ file_lock = threading.Lock()
model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
file_size = os.path.getsize(ckpt)
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
@@ -111,7 +112,7 @@ def load_safetensors(ckpt):
storage = tensor.untyped_storage()
setattr(storage,
"_comfy_tensor_file_slice",
- comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
+ comfy.memory_management.TensorFileSlice(f, file_lock, data_base_offset + start, end - start))
setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
sd[name] = tensor
diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py
new file mode 100644
index 000000000..2bd752b7d
--- /dev/null
+++ b/comfy_extras/nodes_multigpu.py
@@ -0,0 +1,412 @@
+from __future__ import annotations
+
+import copy
+import logging
+from inspect import cleandoc
+from typing import TYPE_CHECKING
+from typing_extensions import override
+
+from comfy_api.latest import ComfyExtension, io
+
+if TYPE_CHECKING:
+ from comfy.model_patcher import ModelPatcher
+ from comfy.sd import CLIP, VAE
+import torch
+
+import comfy.model_management
+import comfy.multigpu
+
+
+class MultiGPUCFGSplitNode(io.ComfyNode):
+ """
+ Prepares model to have sampling accelerated via splitting work units.
+
+ Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes.
+
+ Other than those exceptions, this node can be placed in any order.
+ """
+
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="MultiGPU_WorkUnits",
+ display_name="MultiGPU CFG Split",
+ category="advanced/multigpu",
+ description=cleandoc(cls.__doc__),
+ inputs=[
+ io.Model.Input("model"),
+ io.Int.Input("max_gpus", default=2, min=1, step=1),
+ ],
+ outputs=[
+ io.Model.Output(),
+ ],
+ )
+
+ @classmethod
+ def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput:
+ model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True)
+ return io.NodeOutput(model)
+
+
+def _force_fp32_cpu_compute(patcher: ModelPatcher):
+ """Force fp32 inference dtype for CPU.
+
+ PyTorch's CPU conv2d kernels fall back to software emulation for fp16/bf16
+ and run ~500-600x slower than fp32, which makes a normal-sized workflow
+ look frozen for hours. Routing through set_model_compute_dtype leaves the
+ weights as-is and casts at use, so peak memory does not blow up."""
+ dtype = patcher.model_dtype()
+ if dtype in (torch.float16, torch.bfloat16):
+ logging.info(f"Select Model Device: using fp32 compute dtype for CPU inference (model dtype was {dtype}).")
+ patcher.set_model_compute_dtype(torch.float32)
+
+
+def _remember_base_devices(patcher: ModelPatcher):
+ """Stash the original load/offload device on the underlying model.
+
+ Stored on patcher.model (which is shared with the input patcher), so
+ later "default" selections can recover the loader's original routing.
+ Only the first Select on a given chain writes these attrs; subsequent
+ deepclones inherit them onto their freshly-loaded model below.
+ """
+ if not hasattr(patcher.model, "_select_base_load_device"):
+ patcher.model._select_base_load_device = patcher.load_device
+ patcher.model._select_base_offload_device = patcher.offload_device
+
+
+def _propagate_base_devices(src_model, dst_model):
+ """Carry the loader-original device attrs onto the freshly-deepcloned model."""
+ if hasattr(src_model, "_select_base_load_device") and not hasattr(dst_model, "_select_base_load_device"):
+ dst_model._select_base_load_device = src_model._select_base_load_device
+ dst_model._select_base_offload_device = src_model._select_base_offload_device
+
+
+def _retarget_patcher(patcher: ModelPatcher, target_load_device, target_offload_device):
+ """Return a patcher whose actual model weights live on *target_load_device*.
+
+ If *patcher* is already on *target_load_device* we just retarget the
+ (already-cloned) patcher's metadata in place. Otherwise we call
+ :meth:`ModelPatcher.deepclone_multigpu` to spawn a fresh model from
+ the loader's ``cached_patcher_init`` factory -- the only safe way to
+ move weights that may already be partially loaded onto another device.
+
+ NOTE: reusing the input patcher's model when the requested device
+ matches its current load_device is a deliberate fast path. Anything
+ that has already mutated the original model (e.g. a prior KSampler
+ invocation on the same model) will be observed here. This is by
+ design and documented on the SelectXDeviceNode docstrings -- placing
+ Select X Device after a node that consumes the same model is not
+ recommended.
+ """
+ if patcher.load_device == target_load_device:
+ # Fast path: weights already on the desired device, just update offload.
+ patcher.offload_device = target_offload_device
+ return patcher
+ src_model = patcher.model
+ patcher = patcher.deepclone_multigpu(new_load_device=target_load_device)
+ patcher.offload_device = target_offload_device
+ _propagate_base_devices(src_model, patcher.model)
+ if hasattr(patcher, "register_load_device"):
+ patcher.register_load_device(patcher.load_device)
+ return patcher
+
+
+def _apply_patcher_device(patcher: ModelPatcher, resolved, base_offload_override=None):
+ """Resolve the requested device and produce a patcher routed there.
+
+ For "default" we restore the loader's original load/offload pair.
+ For CPU we pin both load and offload to CPU (and, on a dynamic
+ patcher, downgrade to a plain ModelPatcher so the dynamic-only
+ code paths are bypassed).
+ For an explicit GPU we keep the loader's original offload but
+ target the requested load device; if that differs from the current
+ load device the patcher is deepcloned onto the new device.
+ """
+ _remember_base_devices(patcher)
+ base_load = patcher.model._select_base_load_device
+ base_offload = base_offload_override if base_offload_override is not None else patcher.model._select_base_offload_device
+
+ if resolved is None:
+ # "default" -> route back to the loader's original devices.
+ return _retarget_patcher(patcher, base_load, base_offload)
+ if resolved.type == "cpu":
+ if patcher.is_dynamic():
+ # clone(disable_dynamic=True) requires cached_patcher_init; let the
+ # exception surface to the caller (Select*DeviceNode.execute), which
+ # will translate it into a passthrough+log so unsupported loaders
+ # don't hard-fail the workflow.
+ patcher = patcher.clone(disable_dynamic=True)
+ patcher.load_device = resolved
+ patcher.offload_device = resolved
+ return patcher
+ return _retarget_patcher(patcher, resolved, base_offload)
+
+
+def _prune_multigpu_collision(model: ModelPatcher, primary_device):
+ """Drop any multigpu clone whose load_device matches *primary_device*.
+
+ Without pruning, MultiGPU CFG Split would have stacked a clone on
+ the same device the primary now occupies (i.e. the workflow places
+ MultiGPU CFG Split before Select Model Device). Keeps the clone set
+ consistent with the new primary placement.
+ """
+ multigpu_models = model.get_additional_models_with_key("multigpu")
+ if not multigpu_models:
+ return
+ filtered = [m for m in multigpu_models if m.load_device != primary_device]
+ if len(filtered) != len(multigpu_models):
+ logging.info(f"Select Model Device: pruning MultiGPU clone on {primary_device} that now collides with the primary model.")
+ model.set_additional_models("multigpu", filtered)
+ if hasattr(model, "match_multigpu_clones"):
+ model.match_multigpu_clones()
+
+
+class SelectModelDeviceNode(io.ComfyNode):
+ """
+ Place the diffusion model on a specific device (default / cpu / gpu:N).
+
+ - "default" restores the device assigned by the loader (even after a
+ prior Select Model Device call).
+ - "cpu" pins both the load and offload device to CPU.
+ - "gpu:N" pins the load device to the Nth available GPU; the offload
+ device is restored to the loader's original choice.
+
+ When the requested device differs from the device the input model is
+ already on, a fresh model is spawned via the loader's reload factory
+ (cached_patcher_init) so the new patcher owns independent weights on
+ the new device. Loaders that don't support multigpu (no factory) will
+ cause the node to pass through unchanged with a warning.
+
+ If the workflow already has MultiGPU CFG Split applied and the chosen
+ GPU collides with one of the existing multigpu clones, that clone is
+ dropped so two patchers don't end up bound to the same device.
+
+ When the selected device does not exist on the current machine
+ (e.g. a workflow built on a 2-GPU box opened on a 1-GPU box),
+ the node passes the model through unchanged and logs a message
+ instead of failing.
+
+ NOTE: Placing Select Model Device *after* a node that has already
+ consumed the same model (e.g. a KSampler that ran on this model on
+ the original device) is not recommended -- any state the prior
+ consumer mutated on the original model will be observed when the
+ selected device matches the original (fast path). Place Select Model
+ Device before any consumer of the model.
+ """
+
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="SelectModelDevice",
+ display_name="Select Model Device",
+ category="advanced/multigpu",
+ description=cleandoc(cls.__doc__),
+ inputs=[
+ io.Model.Input("model"),
+ io.Combo.Input("device", options=comfy.model_management.get_gpu_device_options()),
+ ],
+ outputs=[
+ io.Model.Output(),
+ ],
+ )
+
+ @classmethod
+ def validate_inputs(cls, device="default"):
+ # Allow unknown gpu:N values so portable workflows do not error
+ # at validation time; runtime fallback will handle them.
+ return True
+
+ @classmethod
+ def execute(cls, model: ModelPatcher, device: str = "default") -> io.NodeOutput:
+ model = model.clone()
+ resolved = comfy.model_management.resolve_gpu_device_option(device)
+ if resolved is None and device not in (None, "default"):
+ logging.info(f"Select Model Device: requested device '{device}' not available, passing through unchanged.")
+ return io.NodeOutput(model)
+ try:
+ model = _apply_patcher_device(model, resolved)
+ except RuntimeError as e:
+ logging.warning(f"Select Model Device: cannot retarget model, passing through unchanged. ({e})")
+ return io.NodeOutput(model)
+ if resolved is not None:
+ if resolved.type == "cpu":
+ _force_fp32_cpu_compute(model)
+ _prune_multigpu_collision(model, model.load_device)
+ return io.NodeOutput(model)
+
+
+class SelectCLIPDeviceNode(io.ComfyNode):
+ """
+ Place the CLIP text encoder on a specific device (default / cpu / gpu:N).
+
+ - "default" restores the device assigned by the loader.
+ - "cpu" pins both the load and offload device to CPU.
+ - "gpu:N" pins the load device to the Nth available GPU.
+
+ When the selected device does not exist on the current machine
+ (e.g. a workflow built on a 2-GPU box opened on a 1-GPU box),
+ the node passes the CLIP through unchanged and logs a message
+ instead of failing.
+ """
+
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="SelectCLIPDevice",
+ display_name="Select CLIP Device",
+ category="advanced/multigpu",
+ description=cleandoc(cls.__doc__),
+ inputs=[
+ io.Clip.Input("clip"),
+ io.Combo.Input("device", options=comfy.model_management.get_gpu_device_options()),
+ ],
+ outputs=[
+ io.Clip.Output(),
+ ],
+ )
+
+ @classmethod
+ def validate_inputs(cls, device="default"):
+ return True
+
+ @classmethod
+ def execute(cls, clip: CLIP, device: str = "default") -> io.NodeOutput:
+ clip = clip.clone()
+ resolved = comfy.model_management.resolve_gpu_device_option(device)
+ if resolved is None and device not in (None, "default"):
+ logging.info(f"Select CLIP Device: requested device '{device}' not available, passing through unchanged.")
+ return io.NodeOutput(clip)
+ try:
+ clip.patcher = _apply_patcher_device(clip.patcher, resolved)
+ except RuntimeError as e:
+ logging.warning(f"Select CLIP Device: cannot retarget CLIP, passing through unchanged. ({e})")
+ return io.NodeOutput(clip)
+
+
+class SelectVAEDeviceNode(io.ComfyNode):
+ """
+ Place the VAE on a specific device (default / gpu:N).
+
+ - "default" restores the device assigned by the loader.
+ - "gpu:N" pins the load device to the Nth available GPU; the offload
+ device is set to the standard VAE offload device.
+
+ CPU is intentionally not exposed in the UI for the VAE; if a workflow
+ supplies "cpu" anyway (e.g. opened from another machine), the request
+ is dropped with a log message and the VAE is passed through unchanged.
+
+ When the selected device does not exist on the current machine
+ (e.g. a workflow built on a 2-GPU box opened on a 1-GPU box),
+ the node passes the VAE through unchanged and logs a message
+ instead of failing.
+ """
+
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="SelectVAEDevice",
+ display_name="Select VAE Device",
+ category="advanced/multigpu",
+ description=cleandoc(cls.__doc__),
+ inputs=[
+ io.Vae.Input("vae"),
+ io.Combo.Input("device", options=comfy.model_management.get_gpu_device_options_no_cpu()),
+ ],
+ outputs=[
+ io.Vae.Output(),
+ ],
+ )
+
+ @classmethod
+ def validate_inputs(cls, device="default"):
+ return True
+
+ @classmethod
+ def execute(cls, vae: VAE, device: str = "default") -> io.NodeOutput:
+ # VAE has no .clone(); shallow-copy the wrapper and clone the patcher
+ # so we can retarget load/offload device without affecting the input VAE.
+ vae = copy.copy(vae)
+ vae.patcher = vae.patcher.clone()
+ resolved = comfy.model_management.resolve_gpu_device_option(device)
+ if resolved is None and device not in (None, "default"):
+ logging.info(f"Select VAE Device: requested device '{device}' not available, passing through unchanged.")
+ return io.NodeOutput(vae)
+ if resolved is not None and resolved.type == "cpu":
+ logging.info("Select VAE Device: CPU is not a supported choice, passing through unchanged.")
+ return io.NodeOutput(vae)
+ if not hasattr(vae, "_select_base_device"):
+ vae._select_base_device = vae.device
+ try:
+ vae.patcher = _apply_patcher_device(
+ vae.patcher, resolved,
+ base_offload_override=comfy.model_management.vae_offload_device(),
+ )
+ except RuntimeError as e:
+ logging.warning(f"Select VAE Device: cannot retarget VAE, passing through unchanged. ({e})")
+ return io.NodeOutput(vae)
+ # Keep VAE wrapper in sync with whatever model the patcher now owns;
+ # deepclone_multigpu may have produced a fresh first_stage_model.
+ vae.first_stage_model = vae.patcher.model
+ vae.device = vae._select_base_device if resolved is None else resolved
+ return io.NodeOutput(vae)
+
+
+class MultiGPUOptionsNode(io.ComfyNode):
+ """
+ Select the relative speed of GPUs in the special case they have significantly different performance from one another.
+
+ NOTE (not registered yet, see MultiGPUExtension.get_node_list below):
+ The output GPUOptionsGroup is plumbed through create_multigpu_deepclones() and stored on
+ model.model_options['multigpu_options'] via GPUOptionsGroup.register(), but the cond
+ scheduler in comfy/samplers.py (calc_cond_batch_outer_multigpu) does NOT yet consult
+ relative_speed when distributing conds across devices; it uses a uniform conds_per_device
+ round-robin via next_available_device(). Before re-enabling this node, wire its
+ relative_speed into the scheduler (e.g. via comfy.multigpu.load_balance_devices(),
+ which already implements the proportional split) so the input actually affects work
+ distribution.
+ """
+
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="MultiGPU_Options",
+ display_name="MultiGPU Options",
+ category="advanced/multigpu",
+ description=cleandoc(cls.__doc__),
+ inputs=[
+ io.Int.Input("device_index", default=0, min=0, max=64),
+ io.Float.Input("relative_speed", default=1.0, min=0.0, step=0.01),
+ io.Custom("GPU_OPTIONS").Input("gpu_options", optional=True),
+ ],
+ outputs=[
+ io.Custom("GPU_OPTIONS").Output(),
+ ],
+ )
+
+ @classmethod
+ def execute(cls, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup = None) -> io.NodeOutput:
+ if not gpu_options:
+ gpu_options = comfy.multigpu.GPUOptionsGroup()
+ else:
+ gpu_options = gpu_options.clone()
+
+ opt = comfy.multigpu.GPUOptions(device_index=device_index, relative_speed=relative_speed)
+ gpu_options.add(opt)
+
+ return io.NodeOutput(gpu_options)
+
+
+class MultiGPUExtension(ComfyExtension):
+ @override
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
+ return [
+ MultiGPUCFGSplitNode,
+ SelectModelDeviceNode,
+ SelectCLIPDeviceNode,
+ SelectVAEDeviceNode,
+ # MultiGPUOptionsNode,
+ ]
+
+
+async def comfy_entrypoint() -> MultiGPUExtension:
+ return MultiGPUExtension()
diff --git a/main.py b/main.py
index 26d523c30..bce451a83 100644
--- a/main.py
+++ b/main.py
@@ -218,7 +218,7 @@ import comfy.model_patcher
if args.enable_dynamic_vram or (enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl()):
if (not args.enable_dynamic_vram) and (comfy.model_management.torch_version_numeric < (2, 8)):
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
- elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
+ elif comfy_aimdo.control.init_devices(d.index for d in comfy.model_management.get_all_torch_devices()):
if args.verbose == 'DEBUG':
comfy_aimdo.control.set_log_debug()
elif args.verbose == 'CRITICAL':
diff --git a/nodes.py b/nodes.py
index 13e46ac8a..fd4365c90 100644
--- a/nodes.py
+++ b/nodes.py
@@ -795,6 +795,7 @@ class VAELoader:
#TODO: scale factor?
def load_vae(self, vae_name):
metadata = None
+ vae_path = None
if vae_name == "pixel_space":
sd = {}
sd["pixel_space_vae"] = torch.tensor(1.0)
@@ -813,6 +814,14 @@ class VAELoader:
metadata["tae_latent_channels"] = 128
vae = comfy.sd.VAE(sd=sd, metadata=metadata)
vae.throw_exception_if_invalid()
+ # Register a reload factory on the patcher so multigpu deepclones
+ # (Select VAE Device, future MultiGPU VAE work-units) can produce
+ # per-device clones from the same loader context. Only set when we
+ # actually have a single backing file -- pixel_space and the
+ # image TAESDs (composed from separate encoder/decoder files via
+ # load_taesd) are not addressable by a single vae_path.
+ if vae_path is not None:
+ vae.patcher.cached_patcher_init = (comfy.sd.load_vae_patcher, (vae_path, metadata, None))
return (vae,)
class ControlNetLoader:
@@ -2389,6 +2398,7 @@ async def init_builtin_extra_nodes():
"nodes_lt_audio.py",
"nodes_lt.py",
"nodes_hooks.py",
+ "nodes_multigpu.py",
"nodes_load_3d.py",
"nodes_cosmos.py",
"nodes_video.py",
diff --git a/server.py b/server.py
index 44470b904..268441bd1 100644
--- a/server.py
+++ b/server.py
@@ -646,18 +646,37 @@ class PromptServer():
@routes.get("/system_stats")
async def system_stats(request):
- device = comfy.model_management.get_torch_device()
- device_name = comfy.model_management.get_torch_device_name(device)
+ primary_device = comfy.model_management.get_torch_device()
cpu_device = comfy.model_management.torch.device("cpu")
ram_total = comfy.model_management.get_total_memory(cpu_device)
ram_free = comfy.model_management.get_free_memory(cpu_device)
- vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
- vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
required_frontend_version = FrontendManager.get_required_frontend_version()
installed_templates_version = FrontendManager.get_installed_templates_version()
required_templates_version = FrontendManager.get_required_templates_version()
comfy_package_versions = FrontendManager.get_comfy_package_versions()
+ # Report every torch device visible to multigpu, with the primary
+ # device first so existing clients that read devices[0] keep working.
+ torch_devices = comfy.model_management.get_all_torch_devices()
+ if primary_device in torch_devices:
+ torch_devices = [primary_device] + [d for d in torch_devices if d != primary_device]
+ else:
+ torch_devices = [primary_device] + list(torch_devices)
+
+ device_entries = []
+ for d in torch_devices:
+ vram_total, torch_vram_total = comfy.model_management.get_total_memory(d, torch_total_too=True)
+ vram_free, torch_vram_free = comfy.model_management.get_free_memory(d, torch_free_too=True)
+ device_entries.append({
+ "name": comfy.model_management.get_torch_device_name(d),
+ "type": d.type,
+ "index": d.index,
+ "vram_total": vram_total,
+ "vram_free": vram_free,
+ "torch_vram_total": torch_vram_total,
+ "torch_vram_free": torch_vram_free,
+ })
+
system_stats = {
"system": {
"os": sys.platform,
@@ -673,17 +692,7 @@ class PromptServer():
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
"argv": sys.argv
},
- "devices": [
- {
- "name": device_name,
- "type": device.type,
- "index": device.index,
- "vram_total": vram_total,
- "vram_free": vram_free,
- "torch_vram_total": torch_vram_total,
- "torch_vram_free": torch_vram_free,
- }
- ]
+ "devices": device_entries
}
return web.json_response(system_stats)