Two CodeRabbit findings from #7063 (#13 and #14) are deferred because
worksplit-multigpu's initial release scope is NVIDIA-only QA. Leave a
TODO at the unconditional torch.cuda.set_device call and at the
post-aggregation point so the required guards/synchronize are easy to
find when multigpu support is extended to XPU/NPU/MPS/CPU/DirectML.
Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082
Co-authored-by: Amp <amp@ampcode.com>
Per review feedback on #7063. The two functions share the conds-by-hooks
accumulation, memory-fit batching, and per-chunk output aggregation; the
multigpu variant adds per-device scheduling, .to(device) placement,
per-device patcher/control lookup, and thread-pool dispatch around the
inner loop. Documenting the relationship without extracting helpers --
extraction can land after the initial worksplit-multigpu release once
both paths have settled.
Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082
Co-authored-by: Amp <amp@ampcode.com>
QwenFunControlNet.pre_run stashes model.diffusion_model into extra_args,
which the control_model then uses for forward passes (img_in, txt_in,
pe_embedder, time_text_embed). With multigpu, every per-device control
clone was being pre_run with the base model on GPU0, so secondary
devices would invoke those modules with parameters on GPU0 and inputs
on their own device, raising 'Expected all tensors to be on the same
device'. Build a device -> per-device BaseModel lookup from the
patcher's additional multigpu models and pass each clone the model on
its own device. Falls back to the base model when no per-device match
is found (single-GPU path and the case where cnet.multigpu_clones lags
the patcher's clone set).
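A minimal sketch of the lookup-with-fallback described above. The function names and the (device, model) pair list are hypothetical stand-ins, not the actual patcher API:

```python
def build_device_model_lookup(base_device, base_model, extra_models):
    """Map each device to the model copy that lives on it.

    extra_models is assumed to be an iterable of (device, model) pairs,
    standing in for the patcher's additional multigpu models.
    """
    lookup = {base_device: base_model}
    for device, model in extra_models:
        lookup[device] = model
    return lookup


def model_for_device(lookup, device, base_model):
    # Fall back to the base model when no per-device copy is registered,
    # e.g. on the single-GPU path.
    return lookup.get(device, base_model)
```

With this shape, a control clone on cuda:1 receives the model registered for cuda:1, and any device missing from the lookup gets the base model.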
Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082
Co-authored-by: Amp <amp@ampcode.com>
The multigpu cond-batching loop called model.memory_required(input_shape) without conditioning shapes, while the single-GPU path at line 279 passes cond_shapes. Large conditioning tensors (e.g. video prompts, control inputs) were therefore under-counted, risking OOM at runtime when the chosen batch size was too large. Match the single-GPU pattern by building cond_shapes from each batched cond's conditioning dict and passing it to memory_required.
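A rough illustration of why the conditioning shapes matter. The data layout (dicts with a "conditioning" entry holding shape tuples) and toy_memory_required are assumptions made for this sketch; the real memory_required estimate is model-specific:

```python
from collections import defaultdict
from math import prod


def collect_cond_shapes(batched_conds):
    """Gather the shape of every conditioning tensor, keyed by name."""
    cond_shapes = defaultdict(list)
    for cond in batched_conds:
        for name, shape in cond["conditioning"].items():
            cond_shapes[name].append(shape)
    return cond_shapes


def toy_memory_required(input_shape, cond_shapes=None):
    # Toy estimate: count elements only. It exists purely to show that
    # omitting cond_shapes yields a smaller number.
    total = prod(input_shape)
    for shapes in (cond_shapes or {}).values():
        total += sum(prod(shape) for shape in shapes)
    return total
```

Calling toy_memory_required(input_shape) without cond_shapes ignores every conditioning tensor, which is the under-count this commit removes.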
Amp-Thread-ID: https://ampcode.com/threads/T-019e43b8-8258-70fd-ab3a-53e4c97f85d5
Co-authored-by: Amp <amp@ampcode.com>
Behaviour-equivalent cleanup of _calc_cond_batch_multigpu device
scheduling. No change to batching decisions or memory checks for any
valid input.
Changes:
* Replace re-summed batched_to_run_length with a per-device load
dict (device_load), so capacity checks are O(1) and use a single
source of truth.
* Extract device selection into next_available_device(), which scans
at most len(devices) positions and raises if no device has
remaining capacity. This makes the 'skip a full device' rule live
in one place instead of two and guarantees the outer while loop
cannot spin forever on a scheduling bug.
* Drop the unused current_device assignment before the outer loop
and the index_device % len(devices) modulo dance (now handled
inside next_available_device).
* Minor cleanups: list comprehensions for total_conds, conds_to_batch,
and the devices list.
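A simplified sketch of the bounded device scan and the per-device load dict. Only the names next_available_device and device_load come from this change; the signature, the schedule() helper, and the one-cond-at-a-time assignment are assumptions made for illustration:

```python
import math


def next_available_device(devices, device_load, conds_per_device, start_index):
    """Return (device, next_index) for the first device with spare capacity.

    Scans at most len(devices) positions starting at start_index and
    raises if every device is already full.
    """
    count = len(devices)
    for offset in range(count):
        index = (start_index + offset) % count
        device = devices[index]
        if device_load[device] < conds_per_device:
            return device, (index + 1) % count
    raise RuntimeError("no device has remaining capacity")


def schedule(total_conds, devices):
    # Hypothetical driver: assigns one cond at a time, round-robin,
    # skipping devices that have reached conds_per_device.
    conds_per_device = math.ceil(total_conds / len(devices))
    device_load = {device: 0 for device in devices}
    index = 0
    for _ in range(total_conds):
        device, index = next_available_device(
            devices, device_load, conds_per_device, index
        )
        device_load[device] += 1
    return device_load
```

Because the scan is bounded by len(devices), a scheduling bug surfaces as an exception rather than an endless loop.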
Fixes _calc_cond_batch_multigpu so that:
1. conds_per_device uses real division before math.ceil. The previous
expression math.ceil(total_conds // len(devices)) applied integer
floor division first, making ceil a no-op. For 3 conds across 2
devices this produced conds_per_device=1 instead of 2.
2. The scheduling loop skips devices that have already reached
capacity instead of appending empty batch groups. Without this
guard, the loop could repeatedly emit zero-length groups for a
full device, leaving sampling stuck at 0/N until timeout.
Reproduces with an Omnigen2 image workflow that produces three
condition entries scheduled across two CUDA devices. With the fix
the scheduler assigns conds_per_device=2 and splits the batches as
2 + 1 across the two devices, allowing sampling to complete.
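The arithmetic behind item 1, as a standalone check (variable names are illustrative):

```python
import math

total_conds = 3
num_devices = 2

# Floor division runs first, so ceil receives an int and changes nothing.
buggy = math.ceil(total_conds // num_devices)

# Real division keeps the fraction for ceil to round up.
fixed = math.ceil(total_conds / num_devices)
```

Here buggy evaluates to 1 and fixed to 2, matching the conds_per_device values quoted above.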
Original fix authored and validated by @pollockjj in
pollockjj/ComfyUI#64.
Co-authored-by: John Pollock <pollockjj@gmail.com>
* Fix Hunyuan 3D 2.1 multi-GPU worksplit: use cond_or_uncond instead of hardcoded chunk(2)
Amp-Thread-ID: https://ampcode.com/threads/T-019da964-2cc8-77f9-9aae-23f65da233db
Co-authored-by: Amp <amp@ampcode.com>
* Add GPU device selection to all loader nodes
- Add get_gpu_device_options() and resolve_gpu_device_option() helpers
in model_management.py for vendor-agnostic GPU device selection
- Add device widget to CheckpointLoaderSimple, UNETLoader, VAELoader
- Expand device options in CLIPLoader, DualCLIPLoader, LTXAVTextEncoderLoader
from [default, cpu] to include gpu:0, gpu:1, etc. on multi-GPU systems
- Wire load_diffusion_model_state_dict and load_state_dict_guess_config
to respect model_options['load_device']
- Graceful fallback: unrecognized devices (e.g. gpu:1 on single-GPU)
silently fall back to default
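A minimal sketch of how such option strings might be generated and resolved. The function names, signatures, and the use of None to mean "default" are hypothetical and are not the model_management.py implementation:

```python
def device_options_sketch(num_gpus):
    """Build the widget choices: default, cpu, then gpu:0 .. gpu:N-1."""
    return ["default", "cpu"] + [f"gpu:{i}" for i in range(num_gpus)]


def resolve_option_sketch(option, available_devices):
    """Map an option string to a device, or None to mean 'use the default'.

    available_devices is assumed to be a list such as ["cuda:0", "cuda:1"].
    """
    if option == "cpu":
        return "cpu"
    if isinstance(option, str) and option.startswith("gpu:"):
        try:
            index = int(option[len("gpu:"):])
        except ValueError:
            return None  # malformed index: fall back to default
        if 0 <= index < len(available_devices):
            return available_devices[index]
    # "default", out-of-range indices, and unknown strings all fall back.
    return None
```

Under these assumptions, gpu:1 on a single-GPU machine resolves to None and the loader proceeds with its default device.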
Amp-Thread-ID: https://ampcode.com/threads/T-019daa41-f394-731a-8955-4cff4f16283a
Co-authored-by: Amp <amp@ampcode.com>
* Add VALIDATE_INPUTS to skip device combo validation for workflow portability
When a workflow saved on a 2-GPU machine (with device=gpu:1) is loaded
on a 1-GPU machine, the combo validation would reject the unknown value.
VALIDATE_INPUTS with the device parameter bypasses combo validation for
that input only, allowing resolve_gpu_device_option to handle the
graceful fallback at runtime.
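Roughly, the pattern looks like the following. DeviceLoaderSketch and its option list are made up for illustration and are not one of the loader nodes touched here:

```python
class DeviceLoaderSketch:
    """Hypothetical node showing the shape of the VALIDATE_INPUTS hook."""

    @classmethod
    def INPUT_TYPES(cls):
        # Combo built for the current machine; a saved workflow may carry
        # a value (e.g. "gpu:1") that is absent from this list.
        return {"required": {"device": (["default", "cpu", "gpu:0"],)}}

    @classmethod
    def VALIDATE_INPUTS(cls, device):
        # Naming `device` as a parameter is what opts this input out of
        # the combo check, per the description above. Accept any value
        # and leave the fallback to runtime resolution.
        return True
```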
Amp-Thread-ID: https://ampcode.com/threads/T-019daa41-f394-731a-8955-4cff4f16283a
Co-authored-by: Amp <amp@ampcode.com>
* Set CUDA device context in outer_sample to match model load_device
Custom CUDA kernels (comfy_kitchen fp8 quantization) use
torch.cuda.current_device() for DLPack tensor export. When a model is
loaded on a non-default GPU (e.g. cuda:1), the CUDA context must match
or the kernel fails with 'Can't export tensors on a different CUDA
device index'. Save and restore the previous device around sampling.
Amp-Thread-ID: https://ampcode.com/threads/T-019daa41-f394-731a-8955-4cff4f16283a
Co-authored-by: Amp <amp@ampcode.com>
* Fix code review bugs: negative index guard, CPU offload_device, checkpoint te_model_options
- resolve_gpu_device_option: reject negative indices (gpu:-1)
- UNETLoader: set offload_device when cpu is selected
- CheckpointLoaderSimple: pass te_model_options for CLIP device,
set offload_device for cpu, pass load_device to VAE
- load_diffusion_model_state_dict: respect offload_device from model_options
- load_state_dict_guess_config: respect offload_device, pass load_device to VAE
Amp-Thread-ID: https://ampcode.com/threads/T-019daa41-f394-731a-8955-4cff4f16283a
Co-authored-by: Amp <amp@ampcode.com>
* Fix CUDA device context for CLIP encoding and VAE encode/decode
Add torch.cuda.set_device() calls to match model's load device in:
- CLIP.encode_from_tokens: fixes 'Can't export tensors on a different
CUDA device index' when CLIP is loaded on a non-default GPU
- CLIP.encode_from_tokens_scheduled: same fix for the hooks code path
- CLIP.generate: same fix for text generation
- VAE.decode: fixes VAE decoding on non-default GPU
- VAE.encode: fixes VAE encoding on non-default GPU
Same pattern as the existing outer_sample fix in samplers.py - saves
and restores previous CUDA device in a try/finally block.
Amp-Thread-ID: https://ampcode.com/threads/T-019dabdc-8feb-766f-b4dc-f46ef4d8ff57
Co-authored-by: Amp <amp@ampcode.com>
* Extract cuda_device_context manager, fix tiled VAE methods
Add model_management.cuda_device_context() — a context manager that
saves/restores torch.cuda.current_device when operating on a non-default
GPU. Replaces 6 copies of the manual save/set/restore boilerplate.
Refactored call sites:
- CLIP.encode_from_tokens
- CLIP.encode_from_tokens_scheduled (hooks path)
- CLIP.generate
- VAE.decode
- VAE.encode
- samplers.outer_sample
Bug fixes (newly wrapped):
- VAE.decode_tiled: was missing device context entirely, would fail
on non-default GPU when called from 'VAE Decode (Tiled)' node
- VAE.encode_tiled: same issue for 'VAE Encode (Tiled)' node
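A backend-agnostic sketch of the save/set/restore pattern. The name device_context and the injected get_current/set_current callables are assumptions; with PyTorch they would correspond to torch.cuda.current_device and torch.cuda.set_device:

```python
from contextlib import contextmanager


@contextmanager
def device_context(target, get_current, set_current):
    """Temporarily make `target` the current device, restoring on exit."""
    previous = get_current()
    if target is None or target == previous:
        # Already on the right device: nothing to switch or restore.
        yield
        return
    set_current(target)
    try:
        yield
    finally:
        # Restore even if the wrapped code raises.
        set_current(previous)
```

Wrapping a decode or encode call in such a context replaces the repeated try/finally boilerplate at each call site.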
Amp-Thread-ID: https://ampcode.com/threads/T-019dabdc-8feb-766f-b4dc-f46ef4d8ff57
Co-authored-by: Amp <amp@ampcode.com>
* Restore CheckpointLoaderSimple, add CheckpointLoaderDevice
Revert CheckpointLoaderSimple to its original form (no device input)
so it remains the simple default loader.
Add new CheckpointLoaderDevice node (advanced/loaders) with separate
model_device, clip_device, and vae_device inputs for per-component
GPU placement in multi-GPU setups.
Amp-Thread-ID: https://ampcode.com/threads/T-019dabdc-8feb-766f-b4dc-f46ef4d8ff57
Co-authored-by: Amp <amp@ampcode.com>
---------
Co-authored-by: Amp <amp@ampcode.com>
Benchmarked hybrid (main thread + pool) vs all-pool on 2x RTX 4090
with SD1.5 and NetaYume models. No meaningful performance difference
(within noise). All-pool is simpler: eliminates the main_device
special case, main_batch_tuple deferred execution, and the 3-way
branch in the dispatch loop.
Replace per-step thread create/destroy in _calc_cond_batch_multigpu with a
persistent MultiGPUThreadPool. Each worker thread calls torch.cuda.set_device()
once at startup, preserving compiled kernel caches across diffusion steps.
- Add MultiGPUThreadPool class in comfy/multigpu.py
- Create pool in CFGGuider.outer_sample(), shut down in finally block
- Main thread handles its own device batch directly for zero overhead
- Falls back to sequential execution if no pool is available
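A minimal sketch of a persistent per-device worker pool, assuming one thread and one queue per device. PerDeviceThreadPool and the init_device callback are illustrative and are not the MultiGPUThreadPool code; init_device stands in for the one-time torch.cuda.set_device call:

```python
import queue
import threading
from concurrent.futures import Future


class PerDeviceThreadPool:
    """One long-lived worker thread per device."""

    def __init__(self, devices, init_device):
        self._queues = {}
        self._threads = []
        for device in devices:
            work = queue.Queue()
            thread = threading.Thread(
                target=self._worker, args=(device, work, init_device), daemon=True
            )
            thread.start()
            self._queues[device] = work
            self._threads.append(thread)

    @staticmethod
    def _worker(device, work, init_device):
        init_device(device)  # runs once per thread, not once per step
        while True:
            item = work.get()
            if item is None:  # shutdown sentinel
                break
            future, fn, args = item
            try:
                future.set_result(fn(*args))
            except Exception as exc:
                future.set_exception(exc)

    def submit(self, device, fn, *args):
        """Queue fn(*args) on the device's worker and return a Future."""
        future = Future()
        self._queues[device].put((future, fn, args))
        return future

    def shutdown(self):
        for work in self._queues.values():
            work.put(None)
        for thread in self._threads:
            thread.join()
```

The caller would create the pool once before sampling, submit one batch per device each step, and call shutdown() in a finally block.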
* sd: add support for clip model reconstruction
* nodes: SetClipHooks: Demote the dynamic model patcher
* mp: Make dynamic_disable more robust
The backup must not be cloned. In addition, add a delegate object
to ModelPatcherDynamic so that non-cloning code can perform
ModelPatcherDynamic demotion.
* sampler_helpers: Demote to non-dynamic model patcher when hooking
* Address CodeRabbit review comments
* Attempting a universal implementation of EasyCache, starting with flux as test; I screwed up the math a bit, but when I set it just right it works.
* Fixed math to make threshold work as expected, refactored code to use EasyCacheHolder instead of a dict wrapped by object
* Use sigmas from transformer_options instead of timesteps to be compatible with more models, make end_percent work
* Make log statement when not skipping useful, preparing for per-cond caching
* Added DIFFUSION_MODEL wrapper around forward function for wan model
* Add subsampling for heuristic inputs
* Add subsampling to output_prev (output_prev_subsampled now)
* Properly consider conds in EasyCache logic
* Created SuperEasyCache to test what happens if caching and reuse is moved outside the scope of conds, added PREDICT_NOISE wrapper to facilitate this test
* Change max reuse_threshold to 3.0
* Mark EasyCache/SuperEasyCache as experimental (beta)
* Make Lumina2 compatible with EasyCache
* Add EasyCache support for Qwen Image
* Fix missing comma, curse you Cursor
* Add EasyCache support to AceStep
* Add EasyCache support to Chroma
* Added EasyCache support to Cosmos Predict t2i
* Make EasyCache not crash with Cosmos Predict ImageToVideo latents, though results are still poor
* Add EasyCache support to hidream
* Added EasyCache support to hunyuan video
* Added EasyCache support to hunyuan3d
* Added EasyCache support to LTXV (not very good, but does not crash)
* Implemented EasyCache for aura_flow
* Renamed SuperEasyCache to LazyCache, hardcoded subsample_factor to 8 on nodes
* Extra logging when verbose is true for EasyCache
* Added initial support for basic context windows - in progress
* Add prepare_sampling wrapper for context window to more accurately estimate latent memory requirements, fixed merging wrappers/callbacks dicts in prepare_model_patcher
* Made context windows compatible with different dimensions; works for WAN, but results are bad
* Fix comfy.patcher_extension.merge_nested_dicts calls in prepare_model_patcher in sampler_helpers.py
* Considering adding some callbacks to context window code to allow extensions of behavior without the need to rewrite code
* Made dim slicing cleaner
* Add WAN Context Windows node for testing
* Made context schedule and fuse method functions be stored on the handler instead of needing to be registered in core code to be found
* Moved some code around between node_context_windows.py and context_windows.py
* Change manual context window nodes names/ids
* Added callbacks to IndexListContextHandler
* Adjusted default values for context_length and context_overlap, made schema.inputs definition for WAN Context Windows less annoying
* Make get_resized_cond more robust for various dim sizes
* Fix typo
* Another small fix