* Fix Hunyuan 3D 2.1 multi-GPU worksplit: use cond_or_uncond instead of hardcoded chunk(2)
Amp-Thread-ID: https://ampcode.com/threads/T-019da964-2cc8-77f9-9aae-23f65da233db
Co-authored-by: Amp <amp@ampcode.com>
* Add GPU device selection to all loader nodes
- Add get_gpu_device_options() and resolve_gpu_device_option() helpers
in model_management.py for vendor-agnostic GPU device selection
- Add device widget to CheckpointLoaderSimple, UNETLoader, VAELoader
- Expand device options in CLIPLoader, DualCLIPLoader, LTXAVTextEncoderLoader
from [default, cpu] to include gpu:0, gpu:1, etc. on multi-GPU systems
- Wire load_diffusion_model_state_dict and load_state_dict_guess_config
to respect model_options['load_device']
- Graceful fallback: unrecognized devices (e.g. gpu:1 on single-GPU)
silently fall back to default
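A minimal sketch of the options helper and how a loader widget might consume it (the gpu:N naming comes from this change; the body is illustrative, not the actual implementation):

    import torch

    def get_gpu_device_options():
        # "default" keeps the automatic device pick, "cpu" forces CPU,
        # plus one "gpu:N" entry per visible GPU on multi-GPU systems.
        options = ["default", "cpu"]
        if torch.cuda.is_available():
            options += ["gpu:%d" % i for i in range(torch.cuda.device_count())]
        return options

    # A loader's INPUT_TYPES could then declare the widget as roughly:
    #   "device": (get_gpu_device_options(), {"default": "default"})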
Amp-Thread-ID: https://ampcode.com/threads/T-019daa41-f394-731a-8955-4cff4f16283a
Co-authored-by: Amp <amp@ampcode.com>
* Add VALIDATE_INPUTS to skip device combo validation for workflow portability
When a workflow saved on a 2-GPU machine (with device=gpu:1) is loaded
on a 1-GPU machine, the combo validation would reject the unknown value.
VALIDATE_INPUTS with the device parameter bypasses combo validation for
that input only, allowing resolve_gpu_device_option to handle the
graceful fallback at runtime.
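Sketch of the pattern on one loader (per the description above, naming an input in VALIDATE_INPUTS' signature bypasses combo validation for that input; the body here is illustrative):

    class UNETLoader:
        @classmethod
        def VALIDATE_INPUTS(cls, device):
            # Listing 'device' here disables the "value must be in the combo
            # list" check for this input only; any saved string is accepted,
            # and resolve_gpu_device_option() decides at runtime whether it
            # maps to a real device or falls back to the default.
            return True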
Amp-Thread-ID: https://ampcode.com/threads/T-019daa41-f394-731a-8955-4cff4f16283a
Co-authored-by: Amp <amp@ampcode.com>
* Set CUDA device context in outer_sample to match model load_device
Custom CUDA kernels (comfy_kitchen fp8 quantization) use
torch.cuda.current_device() for DLPack tensor export. When a model is
loaded on a non-default GPU (e.g. cuda:1), the CUDA context must match
or the kernel fails with 'Can't export tensors on a different CUDA
device index'. Save and restore the previous device around sampling.
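The pattern, roughly (the actual change inlines this in outer_sample; the helper shape below is hypothetical):

    import torch

    def run_on_load_device(load_device, fn):
        # Make torch.cuda.current_device() match the model's load device so
        # DLPack export in the custom fp8 kernels targets the right GPU.
        if load_device.type != "cuda":
            return fn()
        prev = torch.cuda.current_device()
        torch.cuda.set_device(load_device)
        try:
            return fn()
        finally:
            torch.cuda.set_device(prev)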
Amp-Thread-ID: https://ampcode.com/threads/T-019daa41-f394-731a-8955-4cff4f16283a
Co-authored-by: Amp <amp@ampcode.com>
* Fix code review bugs: negative index guard, CPU offload_device, checkpoint te_model_options
- resolve_gpu_device_option: reject negative indices (gpu:-1)
- UNETLoader: set offload_device when cpu is selected
- CheckpointLoaderSimple: pass te_model_options for CLIP device,
set offload_device for cpu, pass load_device to VAE
- load_diffusion_model_state_dict: respect offload_device from model_options
- load_state_dict_guess_config: respect offload_device, pass load_device to VAE
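Sketch of the guarded resolution from the first bullet (the return-value convention is an assumption; None stands for "keep the default selection"):

    import torch

    def resolve_gpu_device_option(value):
        if value == "cpu":
            return torch.device("cpu")
        if isinstance(value, str) and value.startswith("gpu:"):
            try:
                index = int(value.split(":", 1)[1])
            except ValueError:
                return None
            if index < 0 or index >= torch.cuda.device_count():
                return None  # rejects gpu:-1 and indices beyond the visible GPUs
            return torch.device("cuda", index)
        return None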
Amp-Thread-ID: https://ampcode.com/threads/T-019daa41-f394-731a-8955-4cff4f16283a
Co-authored-by: Amp <amp@ampcode.com>
* Fix CUDA device context for CLIP encoding and VAE encode/decode
Add torch.cuda.set_device() calls to match model's load device in:
- CLIP.encode_from_tokens: fixes 'Can't export tensors on a different
CUDA device index' when CLIP is loaded on a non-default GPU
- CLIP.encode_from_tokens_scheduled: same fix for the hooks code path
- CLIP.generate: same fix for text generation
- VAE.decode: fixes VAE decoding on non-default GPU
- VAE.encode: fixes VAE encoding on non-default GPU
Same pattern as the existing outer_sample fix in samplers.py: save and
restore the previous CUDA device in a try/finally block.
Amp-Thread-ID: https://ampcode.com/threads/T-019dabdc-8feb-766f-b4dc-f46ef4d8ff57
Co-authored-by: Amp <amp@ampcode.com>
* Extract cuda_device_context manager, fix tiled VAE methods
Add model_management.cuda_device_context() — a context manager that
saves/restores torch.cuda.current_device when operating on a non-default
GPU. Replaces 6 copies of the manual save/set/restore boilerplate.
Refactored call sites:
- CLIP.encode_from_tokens
- CLIP.encode_from_tokens_scheduled (hooks path)
- CLIP.generate
- VAE.decode
- VAE.encode
- samplers.outer_sample
Bug fixes (newly wrapped):
- VAE.decode_tiled: was missing device context entirely, would fail
on non-default GPU when called from 'VAE Decode (Tiled)' node
- VAE.encode_tiled: same issue for 'VAE Encode (Tiled)' node
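Sketch of the extracted helper (illustrative; the real one lives in model_management):

    import contextlib
    import torch

    @contextlib.contextmanager
    def cuda_device_context(device):
        # No-op unless 'device' is a CUDA device; otherwise make it current
        # for the duration of the block and restore the previous device after.
        if device is None or getattr(device, "type", None) != "cuda":
            yield
            return
        prev = torch.cuda.current_device()
        torch.cuda.set_device(device)
        try:
            yield
        finally:
            torch.cuda.set_device(prev)

Call sites then reduce to 'with model_management.cuda_device_context(self.device): ...'.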
Amp-Thread-ID: https://ampcode.com/threads/T-019dabdc-8feb-766f-b4dc-f46ef4d8ff57
Co-authored-by: Amp <amp@ampcode.com>
* Restore CheckpointLoaderSimple, add CheckpointLoaderDevice
Revert CheckpointLoaderSimple to its original form (no device input)
so it remains the simple default loader.
Add new CheckpointLoaderDevice node (advanced/loaders) with separate
model_device, clip_device, and vae_device inputs for per-component
GPU placement in multi-GPU setups.
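Roughly the node interface described above (sketch; tooltips, defaults and the load function body are omitted and may differ):

    import folder_paths
    from comfy.model_management import get_gpu_device_options

    class CheckpointLoaderDevice:
        CATEGORY = "advanced/loaders"
        RETURN_TYPES = ("MODEL", "CLIP", "VAE")
        FUNCTION = "load_checkpoint"

        @classmethod
        def INPUT_TYPES(cls):
            devices = get_gpu_device_options()
            return {"required": {
                "ckpt_name": (folder_paths.get_filename_list("checkpoints"),),
                "model_device": (devices, {"default": "default"}),
                "clip_device": (devices, {"default": "default"}),
                "vae_device": (devices, {"default": "default"}),
            }}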
Amp-Thread-ID: https://ampcode.com/threads/T-019dabdc-8feb-766f-b4dc-f46ef4d8ff57
Co-authored-by: Amp <amp@ampcode.com>
---------
Co-authored-by: Amp <amp@ampcode.com>
Skip unnecessary clone of inference-mode tensors when already inside
torch.inference_mode(), matching the existing guard in set_attr_param.
The unconditional clone introduced in 20561aa9 caused transient VRAM
doubling during model movement for FP8/quantized models.
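The guard amounts to something like this (sketch mirroring the description above):

    import torch

    def clone_if_needed(tensor):
        # Inside torch.inference_mode() an inference tensor can be reused as-is;
        # cloning it only doubles VRAM transiently during model movement.
        if torch.is_inference_mode_enabled() and tensor.is_inference():
            return tensor
        return tensor.clone()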
Benchmarked hybrid (main thread + pool) vs all-pool on 2x RTX 4090
with SD1.5 and NetaYume models. No meaningful performance difference
(within noise). All-pool is simpler: eliminates the main_device
special case, main_batch_tuple deferred execution, and the 3-way
branch in the dispatch loop.
Replace per-step thread create/destroy in _calc_cond_batch_multigpu with a
persistent MultiGPUThreadPool. Each worker thread calls torch.cuda.set_device()
once at startup, preserving compiled kernel caches across diffusion steps.
- Add MultiGPUThreadPool class in comfy/multigpu.py
- Create pool in CFGGuider.outer_sample(), shut down in finally block
- Main thread handles its own device batch directly for zero overhead
- Falls back to sequential execution if no pool is available
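A stripped-down sketch of the pool (internals are illustrative; the real class is in comfy/multigpu.py):

    import queue
    import threading
    import torch

    class MultiGPUThreadPool:
        # One persistent worker per device. torch.cuda.set_device() runs once
        # per thread, so compiled kernel caches survive across diffusion steps.
        def __init__(self, devices):
            self.queues = {dev: queue.Queue() for dev in devices}
            self.threads = []
            for dev, q in self.queues.items():
                t = threading.Thread(target=self._worker, args=(dev, q), daemon=True)
                t.start()
                self.threads.append(t)

        def _worker(self, device, q):
            torch.cuda.set_device(device)
            while True:
                job = q.get()
                if job is None:  # shutdown sentinel
                    return
                fn, args, result = job
                try:
                    result["value"] = fn(*args)
                except Exception as e:
                    result["error"] = e
                finally:
                    result["done"].set()

        def submit(self, device, fn, *args):
            # Caller waits on result["done"] and reads result["value"]/["error"].
            result = {"done": threading.Event()}
            self.queues[device].put((fn, args, result))
            return result

        def shutdown(self):
            for q in self.queues.values():
                q.put(None)
            for t in self.threads:
                t.join()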
When a multigpu clone ModelPatcher is garbage collected, LoadedModel._switch_parent
switches the weakref to point at the parent (main) ModelPatcher. However, it was not
updating LoadedModel.device, leaving it with the old clone's device (e.g., cuda:1).
On subsequent runs, this stale device was passed to ModelPatcherDynamic.load(), causing
an assertion failure (device_to != self.load_device).
Amp-Thread-ID: https://ampcode.com/threads/T-019d3f5c-28c5-72c9-abed-34681f1b54ba
Co-authored-by: Amp <amp@ampcode.com>
* mm: Lower Windows pin threshold
Some workflows make more extraneous use of shared GPU memory than the
5% pin headroom accounts for. Lower the threshold for safety.
* mm: Remove pin count clearing threshold.
TOTAL_PINNED_MEMORY is shared between the legacy and aimdo pinning
systems, but this catch-all assumes only the legacy system exists.
Remove the catch-all, as the PINNED_MEMORY buffer is already coherent.
The resample split happened too early and dropped one of the rolling
convolutions a frame early. This is most noticeable as a lighting/color
change between pixel frames 5->6 (latent 2->3), or as a lighting change
between the first and last frame in an FLF Wan flow.
The recent PR that added resize_cond_for_context_window methods to
model classes used inline 'import comfy.context_windows' in each
method body. This moves that import to the top-level import section,
replacing 4 duplicate inline imports with a single top-level one.
* Add slice_cond and per-model context window cond resizing
* Fix cond_value.size() call in context window cond resizing
* Expose additional advanced inputs for ContextWindowsManualNode
Necessary for WanAnimate context windows workflow, which needs cond_retain_index_list = 0 to work properly with its reference input.
---------
* sd: soft_empty_cache on tiler fallback
This doesn't cost much and produces the expected VRAM reduction in
resource monitors when falling back to the tiler.
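Sketch of the fallback path (decode_full is a hypothetical name for the non-tiled attempt):

    import comfy.model_management as mm

    def decode_with_fallback(vae, samples):
        try:
            return vae.decode_full(samples)   # hypothetical non-tiled path
        except mm.OOM_EXCEPTION:
            # Freeing cached blocks here is cheap and makes the VRAM drop
            # visible in resource monitors before the tiled decode starts.
            mm.soft_empty_cache()
            return vae.decode_tiled(samples)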
* wan: vae: Don't recurse in local fns (move run_up)
Move Decoder3d's recursive run_up out of forward into a class method
to avoid nested closure self-reference cycles. Such cycles create
cyclic garbage that delays collection of the captured tensors, which
in turn delays VRAM release before the tiled fallback.
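The shape of the change, with generic names (the real Decoder3d is more involved):

    import torch.nn as nn

    class Decoder3d(nn.Module):
        def __init__(self, up_blocks):
            super().__init__()
            self.up = nn.ModuleList(up_blocks)

        def run_up(self, x, depth=0):
            # A plain method recursing via self.run_up carries no closure cycle,
            # so intermediates are freed as soon as they drop out of scope.
            if depth == len(self.up):
                return x
            return self.run_up(self.up[depth](x), depth + 1)

        def forward(self, x):
            # Previously forward() held a local recursive run_up() closure whose
            # self-reference created cyclic garbage keeping activations alive.
            return self.run_up(x)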
* ltx: vae: Don't recurse in local fns (move run_up)
Move the recursive run_up out of forward into a class method to avoid
nested closure self-reference cycles. Such cycles create cyclic garbage
that delays collection of the captured tensors, which in turn delays
VRAM release before the tiled fallback.
* ltx: vae: add cache state to downsample block
* ltx: vae: Add time stride awareness to causal_conv_3d
* ltx: vae: Automate truncation for encoder
Other VAEs just truncate without error. Do the same.
* sd/ltx: Make chunked_io a flag in its own right
Chunked IO is going bidirectional, so give it its own purpose-named flag.
* ltx: vae: implement chunked encoder + CPU IO chunking
People are running large frame counts through LTX, including V2V
flows. Implement the time-chunked encoder to keep VRAM down, using the
converse of the new CPU pre-allocation technique: chunks are brought
over from the CPU just in time.
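Roughly the shape of the chunked encode loop (sketch; it ignores the causal-conv cache state the real encoder carries between chunks):

    import torch

    def encode_time_chunked(encoder, pixels_cpu, chunk_frames, device):
        # Keep the full pixel tensor on the CPU and move each temporal chunk to
        # the GPU just in time, so peak VRAM scales with the chunk, not the clip.
        outs = []
        for start in range(0, pixels_cpu.shape[2], chunk_frames):
            chunk = pixels_cpu[:, :, start:start + chunk_frames].to(device)
            outs.append(encoder(chunk).cpu())
        return torch.cat(outs, dim=2)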
* ltx: vae-encode: round chunk sizes more strictly
Only chunk sizes that are powers of 2 and multiples of 8 are valid due
to cache slicing.
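A possible rounding rule consistent with that constraint (sketch):

    def round_chunk_frames(requested):
        # Round down to the largest power of two, clamped to at least 8,
        # so cache slices always line up.
        chunk = 8
        while chunk * 2 <= requested:
            chunk *= 2
        return chunk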