Multi-GPU device selection for loader nodes + CUDA context fixes (#13483)

* Fix Hunyuan 3D 2.1 multi-GPU worksplit: use cond_or_uncond instead of hardcoded chunk(2)

Amp-Thread-ID: https://ampcode.com/threads/T-019da964-2cc8-77f9-9aae-23f65da233db
Co-authored-by: Amp <amp@ampcode.com>

* Add GPU device selection to all loader nodes

- Add get_gpu_device_options() and resolve_gpu_device_option() helpers
  in model_management.py for vendor-agnostic GPU device selection
- Add device widget to CheckpointLoaderSimple, UNETLoader, VAELoader
- Expand device options in CLIPLoader, DualCLIPLoader, LTXAVTextEncoderLoader
  from [default, cpu] to include gpu:0, gpu:1, etc. on multi-GPU systems
- Wire load_diffusion_model_state_dict and load_state_dict_guess_config
  to respect model_options['load_device']
- Graceful fallback: unavailable devices (e.g. gpu:1 on a single-GPU
  system) fall back to the default device with a logged warning
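
The resolution rules above can be sketched standalone. This is a simplified stand-in for the real helpers (strings replace torch devices, and the device list is passed in rather than queried), not the shipped implementation:

```python
def get_gpu_device_options(devices):
    # Always offer "default" and "cpu"; add vendor-agnostic gpu:N labels
    # only when more than one GPU is present.
    options = ["default", "cpu"]
    if len(devices) > 1:
        options += [f"gpu:{i}" for i in range(len(devices))]
    return options

def resolve_gpu_device_option(option, devices):
    # None means "use the caller's normal default device".
    if option in (None, "default"):
        return None
    if option == "cpu":
        return "cpu"
    if option.startswith("gpu:"):
        try:
            idx = int(option[4:])
        except ValueError:
            return None  # malformed index -> default
        if 0 <= idx < len(devices):
            return devices[idx]
        return None  # out of range (e.g. gpu:1 on a single-GPU box) -> default
    return None  # unrecognized option -> default
```

The `0 <= idx` bound is what rejects negative indices such as `gpu:-1`.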

Amp-Thread-ID: https://ampcode.com/threads/T-019daa41-f394-731a-8955-4cff4f16283a
Co-authored-by: Amp <amp@ampcode.com>

* Add VALIDATE_INPUTS to skip device combo validation for workflow portability

When a workflow saved on a 2-GPU machine (with device=gpu:1) is loaded
on a 1-GPU machine, the combo validation would reject the unknown value.
VALIDATE_INPUTS with the device parameter bypasses combo validation for
that input only, allowing resolve_gpu_device_option to handle the
graceful fallback at runtime.
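
The mechanism can be sketched as follows (the node and its option list are illustrative, not one of the real loaders): when VALIDATE_INPUTS declares a parameter by name, the built-in combo membership check is skipped for that input and the node's own validation runs instead.

```python
class DeviceAwareLoaderSketch:
    @classmethod
    def INPUT_TYPES(cls):
        # Combo options as generated on the machine that runs the workflow
        return {"required": {"device": (["default", "cpu", "gpu:0"],)}}

    @classmethod
    def VALIDATE_INPUTS(cls, device="default"):
        # Declaring "device" here bypasses combo validation for that input
        # only; a value like "gpu:1" saved on a 2-GPU machine still passes
        # and is resolved (or gracefully defaulted) at runtime.
        return True
```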

Amp-Thread-ID: https://ampcode.com/threads/T-019daa41-f394-731a-8955-4cff4f16283a
Co-authored-by: Amp <amp@ampcode.com>

* Set CUDA device context in outer_sample to match model load_device

Custom CUDA kernels (comfy_kitchen fp8 quantization) use
torch.cuda.current_device() for DLPack tensor export. When a model is
loaded on a non-default GPU (e.g. cuda:1), the CUDA context must match
or the kernel fails with 'Can't export tensors on a different CUDA
device index'. Save and restore the previous device around sampling.
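
The save/restore follows this shape, sketched here with a stand-in for `torch.cuda` so it runs without a GPU (the real code calls `torch.cuda.current_device()` and `torch.cuda.set_device()`):

```python
from contextlib import contextmanager

class _FakeCuda:
    """Stand-in for torch.cuda's current_device()/set_device() pair."""
    def __init__(self):
        self._current = 0
    def current_device(self):
        return self._current
    def set_device(self, index):
        self._current = index

cuda = _FakeCuda()

@contextmanager
def device_context(index):
    # Save the active device, switch if needed, and always restore it on
    # exit -- even if sampling raises.
    prev = cuda.current_device()
    if prev != index:
        cuda.set_device(index)
    try:
        yield
    finally:
        cuda.set_device(prev)
```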

Amp-Thread-ID: https://ampcode.com/threads/T-019daa41-f394-731a-8955-4cff4f16283a
Co-authored-by: Amp <amp@ampcode.com>

* Fix code review bugs: negative index guard, CPU offload_device, checkpoint te_model_options

- resolve_gpu_device_option: reject negative indices (gpu:-1)
- UNETLoader: set offload_device when cpu is selected
- CheckpointLoaderSimple: pass te_model_options for CLIP device,
  set offload_device for cpu, pass load_device to VAE
- load_diffusion_model_state_dict: respect offload_device from model_options
- load_state_dict_guess_config: respect offload_device, pass load_device to VAE
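
The override pattern behind these fixes is a plain dict lookup with a managed fallback; a sketch with illustrative names (the real defaults come from model_management):

```python
def pick_devices(model_options, default_load, default_offload):
    # Entries set by the loader node win; otherwise fall back to the
    # devices model management would normally choose.
    load_device = model_options.get("load_device", default_load)
    offload_device = model_options.get("offload_device", default_offload)
    return load_device, offload_device
```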

Amp-Thread-ID: https://ampcode.com/threads/T-019daa41-f394-731a-8955-4cff4f16283a
Co-authored-by: Amp <amp@ampcode.com>

* Fix CUDA device context for CLIP encoding and VAE encode/decode

Add torch.cuda.set_device() calls to match model's load device in:
- CLIP.encode_from_tokens: fixes 'Can't export tensors on a different
  CUDA device index' when CLIP is loaded on a non-default GPU
- CLIP.encode_from_tokens_scheduled: same fix for the hooks code path
- CLIP.generate: same fix for text generation
- VAE.decode: fixes VAE decoding on non-default GPU
- VAE.encode: fixes VAE encoding on non-default GPU

Same pattern as the existing outer_sample fix in samplers.py - saves
and restores previous CUDA device in a try/finally block.

Amp-Thread-ID: https://ampcode.com/threads/T-019dabdc-8feb-766f-b4dc-f46ef4d8ff57
Co-authored-by: Amp <amp@ampcode.com>

* Extract cuda_device_context manager, fix tiled VAE methods

Add model_management.cuda_device_context() — a context manager that
saves/restores torch.cuda.current_device when operating on a non-default
GPU. Replaces 6 copies of the manual save/set/restore boilerplate.

Refactored call sites:
- CLIP.encode_from_tokens
- CLIP.encode_from_tokens_scheduled (hooks path)
- CLIP.generate
- VAE.decode
- VAE.encode
- samplers.outer_sample

Bug fixes (newly wrapped):
- VAE.decode_tiled: was missing device context entirely, would fail
  on non-default GPU when called from 'VAE Decode (Tiled)' node
- VAE.encode_tiled: same issue for 'VAE Encode (Tiled)' node

Amp-Thread-ID: https://ampcode.com/threads/T-019dabdc-8feb-766f-b4dc-f46ef4d8ff57
Co-authored-by: Amp <amp@ampcode.com>

* Restore CheckpointLoaderSimple, add CheckpointLoaderDevice

Revert CheckpointLoaderSimple to its original form (no device input)
so it remains the simple default loader.

Add new CheckpointLoaderDevice node (advanced/loaders) with separate
model_device, clip_device, and vae_device inputs for per-component
GPU placement in multi-GPU setups.

Amp-Thread-ID: https://ampcode.com/threads/T-019dabdc-8feb-766f-b4dc-f46ef4d8ff57
Co-authored-by: Amp <amp@ampcode.com>

---------

Co-authored-by: Amp <amp@ampcode.com>
Jedrzej Kosinski 2026-04-23 19:10:33 -07:00, committed by GitHub
parent 7b8b3673ff
commit aa464b36b3
5 changed files with 371 additions and 178 deletions

model_management.py

@@ -28,7 +28,7 @@
import platform
import weakref
import gc
import os
from contextlib import contextmanager, nullcontext
import comfy.memory_management
import comfy.utils
import comfy.quant_ops

@@ -231,6 +231,70 @@ def get_all_torch_devices(exclude_current=False):
        devices.remove(get_torch_device())
    return devices

def get_gpu_device_options():
    """Return list of device option strings for node widgets.

    Always includes "default" and "cpu". When multiple GPUs are present,
    adds "gpu:0", "gpu:1", etc. (vendor-agnostic labels).
    """
    options = ["default", "cpu"]
    devices = get_all_torch_devices()
    if len(devices) > 1:
        for i in range(len(devices)):
            options.append(f"gpu:{i}")
    return options

def resolve_gpu_device_option(option: str):
    """Resolve a device option string to a torch.device.

    Returns None for "default" (let the caller use its normal default).
    Returns torch.device("cpu") for "cpu".
    For "gpu:N", returns the Nth torch device. Falls back to None if
    the index is out of range (caller should use default).
    """
    if option is None or option == "default":
        return None
    if option == "cpu":
        return torch.device("cpu")
    if option.startswith("gpu:"):
        try:
            idx = int(option[4:])
            devices = get_all_torch_devices()
            if 0 <= idx < len(devices):
                return devices[idx]
            else:
                logging.warning(f"Device '{option}' not available (only {len(devices)} GPU(s)), using default.")
                return None
        except (ValueError, IndexError):
            logging.warning(f"Invalid device option '{option}', using default.")
            return None
    logging.warning(f"Unrecognized device option '{option}', using default.")
    return None

@contextmanager
def cuda_device_context(device):
    """Context manager that sets torch.cuda.current_device to match *device*.

    Used when running operations on a non-default CUDA device so that custom
    CUDA kernels (e.g. comfy_kitchen fp8 quantization) pick up the correct
    device index. The previous device is restored on exit.

    No-op when *device* is not CUDA, has no explicit index, or already matches
    the current device.
    """
    prev = None
    if device.type == "cuda" and device.index is not None:
        prev = torch.cuda.current_device()
        if prev != device.index:
            torch.cuda.set_device(device)
        else:
            prev = None
    try:
        yield
    finally:
        if prev is not None:
            torch.cuda.set_device(prev)

def get_total_memory(dev=None, torch_total_too=False):
    global directml_enabled
    if dev is None:

samplers.py

@@ -1208,23 +1208,24 @@ class CFGGuider:
        all_devices = [device] + extra_devices
        self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(all_devices)

        with comfy.model_management.cuda_device_context(device):
            try:
                noise = noise.to(device=device, dtype=torch.float32)
                latent_image = latent_image.to(device=device, dtype=torch.float32)
                sigmas = sigmas.to(device)
                cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())

                self.model_patcher.pre_run()
                for multigpu_patcher in multigpu_patchers:
                    multigpu_patcher.pre_run()
                output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
            finally:
                thread_pool = self.model_options.pop("multigpu_thread_pool", None)
                if thread_pool is not None:
                    thread_pool.shutdown()
                self.model_patcher.cleanup()
                for multigpu_patcher in multigpu_patchers:
                    multigpu_patcher.cleanup()
                comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
                del self.inner_model

sd.py

@@ -324,41 +324,43 @@ class CLIP:
        self.cond_stage_model.set_clip_options({"projected_pooled": False})

        self.load_model(tokens)
        device = self.patcher.load_device
        self.cond_stage_model.set_clip_options({"execution_device": device})

        all_hooks.reset()
        self.patcher.patch_hooks(None)
        if show_pbar:
            pbar = ProgressBar(len(scheduled_keyframes))

        with model_management.cuda_device_context(device):
            for scheduled_opts in scheduled_keyframes:
                t_range = scheduled_opts[0]
                # don't bother encoding any conds outside of start_percent and end_percent bounds
                if "start_percent" in add_dict:
                    if t_range[1] < add_dict["start_percent"]:
                        continue
                if "end_percent" in add_dict:
                    if t_range[0] > add_dict["end_percent"]:
                        continue
                hooks_keyframes = scheduled_opts[1]
                for hook, keyframe in hooks_keyframes:
                    hook.hook_keyframe._current_keyframe = keyframe
                # apply appropriate hooks with values that match new hook_keyframe
                self.patcher.patch_hooks(all_hooks)
                # perform encoding as normal
                o = self.cond_stage_model.encode_token_weights(tokens)
                cond, pooled = o[:2]
                pooled_dict = {"pooled_output": pooled}
                # add clip_start_percent and clip_end_percent in pooled
                pooled_dict["clip_start_percent"] = t_range[0]
                pooled_dict["clip_end_percent"] = t_range[1]
                # add/update any keys with the provided add_dict
                pooled_dict.update(add_dict)
                # add hooks stored on clip
                self.add_hooks_to_dict(pooled_dict)
                all_cond_pooled.append([cond, pooled_dict])
                if show_pbar:
                    pbar.update(1)
                model_management.throw_exception_if_processing_interrupted()

        all_hooks.reset()
        return all_cond_pooled

@@ -372,8 +374,12 @@ class CLIP:
        self.cond_stage_model.set_clip_options({"projected_pooled": False})

        self.load_model(tokens)
        device = self.patcher.load_device
        self.cond_stage_model.set_clip_options({"execution_device": device})
        with model_management.cuda_device_context(device):
            o = self.cond_stage_model.encode_token_weights(tokens)
        cond, pooled = o[:2]
        if return_dict:
            out = {"cond": cond, "pooled_output": pooled}

@@ -428,9 +434,12 @@ class CLIP:
        self.cond_stage_model.reset_clip_options()
        self.load_model(tokens)
        device = self.patcher.load_device
        self.cond_stage_model.set_clip_options({"layer": None})
        self.cond_stage_model.set_clip_options({"execution_device": device})

        with model_management.cuda_device_context(device):
            return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)

    def decode(self, token_ids, skip_special_tokens=True):
        return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)

@@ -947,50 +956,52 @@ class VAE:
        do_tile = False
        if self.latent_dim == 2 and samples_in.ndim == 5:
            samples_in = samples_in[:, :, 0]

        with model_management.cuda_device_context(self.device):
            try:
                memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
                model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
                free_memory = self.patcher.get_free_memory(self.device)
                batch_number = int(free_memory / memory_used)
                batch_number = max(1, batch_number)

                # Pre-allocate output for VAEs that support direct buffer writes
                preallocated = False
                if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
                    pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
                    preallocated = True

                for x in range(0, samples_in.shape[0], batch_number):
                    samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
                    if preallocated:
                        self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
                    else:
                        out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
                        if pixel_samples is None:
                            pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
                        pixel_samples[x:x+batch_number].copy_(out)
                        del out
                    self.process_output(pixel_samples[x:x+batch_number])
            except Exception as e:
                model_management.raise_non_oom(e)
                logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
                #NOTE: We don't know what tensors were allocated to stack variables at the time of the
                #exception and the exception itself refs them all until we get out of this except block.
                #So we just set a flag for tiler fallback so that tensor gc can happen once the
                #exception is fully off the books.
                do_tile = True

            if do_tile:
                comfy.model_management.soft_empty_cache()
                dims = samples_in.ndim - 2
                if dims == 1 or self.extra_1d_channel is not None:
                    pixel_samples = self.decode_tiled_1d(samples_in)
                elif dims == 2:
                    pixel_samples = self.decode_tiled_(samples_in)
                elif dims == 3:
                    tile = 256 // self.spacial_compression_decode()
                    overlap = tile // 4
                    pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))

        pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
        return pixel_samples

@@ -1008,20 +1019,21 @@ class VAE:
        if overlap is not None:
            args["overlap"] = overlap

        with model_management.cuda_device_context(self.device):
            if dims == 1 or self.extra_1d_channel is not None:
                args.pop("tile_y")
                output = self.decode_tiled_1d(samples, **args)
            elif dims == 2:
                output = self.decode_tiled_(samples, **args)
            elif dims == 3:
                if overlap_t is None:
                    args["overlap"] = (1, overlap, overlap)
                else:
                    args["overlap"] = (max(1, overlap_t), overlap, overlap)
                if tile_t is not None:
                    args["tile_t"] = max(2, tile_t)
                output = self.decode_tiled_3d(samples, **args)

        return output.movedim(1, -1)

    def encode(self, pixel_samples):

@@ -1034,44 +1046,46 @@ class VAE:
            pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
        else:
            pixel_samples = pixel_samples.unsqueeze(2)

        with model_management.cuda_device_context(self.device):
            try:
                memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
                model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
                free_memory = self.patcher.get_free_memory(self.device)
                batch_number = int(free_memory / max(1, memory_used))
                batch_number = max(1, batch_number)
                samples = None
                for x in range(0, pixel_samples.shape[0], batch_number):
                    pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
                    if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
                        out = self.first_stage_model.encode(pixels_in, device=self.device)
                    else:
                        pixels_in = pixels_in.to(self.device)
                        out = self.first_stage_model.encode(pixels_in)
                    out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
                    if samples is None:
                        samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
                    samples[x:x + batch_number] = out
            except Exception as e:
                model_management.raise_non_oom(e)
                logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
                #NOTE: We don't know what tensors were allocated to stack variables at the time of the
                #exception and the exception itself refs them all until we get out of this except block.
                #So we just set a flag for tiler fallback so that tensor gc can happen once the
                #exception is fully off the books.
                do_tile = True

            if do_tile:
                comfy.model_management.soft_empty_cache()
                if self.latent_dim == 3:
                    tile = 256
                    overlap = tile // 4
                    samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
                elif self.latent_dim == 1 or self.extra_1d_channel is not None:
                    samples = self.encode_tiled_1d(pixel_samples)
                else:
                    samples = self.encode_tiled_(pixel_samples)

        return samples

@@ -1097,26 +1111,27 @@ class VAE:
        if overlap is not None:
            args["overlap"] = overlap

        with model_management.cuda_device_context(self.device):
            if dims == 1:
                args.pop("tile_y")
                samples = self.encode_tiled_1d(pixel_samples, **args)
            elif dims == 2:
                samples = self.encode_tiled_(pixel_samples, **args)
            elif dims == 3:
                if tile_t is not None:
                    tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
                else:
                    tile_t_latent = 9999
                args["tile_t"] = self.upscale_ratio[0](tile_t_latent)

                if overlap_t is None:
                    args["overlap"] = (1, overlap, overlap)
                else:
                    args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
                maximum = pixel_samples.shape[2]
                maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))

                samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)

        return samples

@@ -1633,7 +1648,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
    diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
    parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
    weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
    load_device = model_options.get("load_device", model_management.get_torch_device())

    custom_operations = model_options.get("custom_operations", None)
    if custom_operations is None:

@@ -1673,13 +1688,15 @@
        inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
        model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
        ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
        offload_device = model_options.get("offload_device", model_management.unet_offload_device())
        model_patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device)
        model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())

    if output_vae:
        vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
        vae_sd = model_config.process_vae_state_dict(vae_sd)
        vae_device = model_options.get("load_device", None)
        vae = VAE(sd=vae_sd, metadata=metadata, device=vae_device)

    if output_clip:
        if te_model_options.get("custom_operations", None) is None:

@@ -1763,7 +1780,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
    parameters = comfy.utils.calculate_parameters(sd)
    weight_dtype = comfy.utils.weight_dtype(sd)
    load_device = model_options.get("load_device", model_management.get_torch_device())

    model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
    if model_config is not None:

@@ -1788,7 +1805,7 @@
        else:
            logging.warning("{} {}".format(diffusers_keys[k], k))

    offload_device = model_options.get("offload_device", model_management.unet_offload_device())
    unet_weight_dtype = list(model_config.supported_inference_dtypes)
    if model_config.quant_config is not None:
        weight_dtype = None

@@ -188,7 +188,7 @@ class LTXAVTextEncoderLoader(io.ComfyNode):
            ),
            io.Combo.Input(
                "device",
                options=comfy.model_management.get_gpu_device_options(),
                advanced=True,
            )
        ],

@@ -203,8 +203,12 @@
        clip_path2 = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)

        model_options = {}
        resolved = comfy.model_management.resolve_gpu_device_option(device)
        if resolved is not None:
            if resolved.type == "cpu":
                model_options["load_device"] = model_options["offload_device"] = resolved
            else:
                model_options["load_device"] = resolved
        clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
        return io.NodeOutput(clip)

127
nodes.py
View File

@@ -608,6 +608,73 @@ class CheckpointLoaderSimple:
         out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
         return out[:3]

+class CheckpointLoaderDevice:
+    @classmethod
+    def INPUT_TYPES(s):
+        device_options = comfy.model_management.get_gpu_device_options()
+        return {
+            "required": {
+                "ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
+            },
+            "optional": {
+                "model_device": (device_options, {"advanced": True, "tooltip": "Device for the diffusion model (UNET)."}),
+                "clip_device": (device_options, {"advanced": True, "tooltip": "Device for the CLIP text encoder."}),
+                "vae_device": (device_options, {"advanced": True, "tooltip": "Device for the VAE."}),
+            }
+        }
+
+    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
+    OUTPUT_TOOLTIPS = ("The model used for denoising latents.",
+                       "The CLIP model used for encoding text prompts.",
+                       "The VAE model used for encoding and decoding images to and from latent space.")
+    FUNCTION = "load_checkpoint"
+    CATEGORY = "advanced/loaders"
+    DESCRIPTION = "Loads a diffusion model checkpoint with per-component device selection for multi-GPU setups."
+
+    @classmethod
+    def VALIDATE_INPUTS(cls, model_device="default", clip_device="default", vae_device="default"):
+        return True
+
+    def load_checkpoint(self, ckpt_name, model_device="default", clip_device="default", vae_device="default"):
+        ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
+
+        model_options = {}
+        resolved_model = comfy.model_management.resolve_gpu_device_option(model_device)
+        if resolved_model is not None:
+            if resolved_model.type == "cpu":
+                model_options["load_device"] = model_options["offload_device"] = resolved_model
+            else:
+                model_options["load_device"] = resolved_model
+
+        te_model_options = {}
+        resolved_clip = comfy.model_management.resolve_gpu_device_option(clip_device)
+        if resolved_clip is not None:
+            if resolved_clip.type == "cpu":
+                te_model_options["load_device"] = te_model_options["offload_device"] = resolved_clip
+            else:
+                te_model_options["load_device"] = resolved_clip
+
+        # model_options["load_device"] is forwarded to the VAE constructor by
+        # load_state_dict_guess_config, so the VAE initially lands on the model
+        # device; a requested vae_device is applied as an override after loading.
+        resolved_vae = comfy.model_management.resolve_gpu_device_option(vae_device)
+
+        out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options=model_options, te_model_options=te_model_options)
+        model_patcher, clip, vae = out[:3]
+
+        # Apply the requested VAE device override
+        if resolved_vae is not None and vae is not None:
+            vae.device = resolved_vae
+            if resolved_vae.type == "cpu":
+                offload = resolved_vae
+            else:
+                offload = comfy.model_management.vae_offload_device()
+            vae.patcher.load_device = resolved_vae
+            vae.patcher.offload_device = offload
+
+        return (model_patcher, clip, vae)
+
 class DiffusersLoader:
     SEARCH_ALIASES = ["load diffusers model"]
@@ -807,14 +874,21 @@ class VAELoader:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "vae_name": (s.vae_list(s), )}}
+        return {"required": { "vae_name": (s.vae_list(s), )},
+                "optional": {
+                    "device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}),
+                }}
     RETURN_TYPES = ("VAE",)
     FUNCTION = "load_vae"
     CATEGORY = "loaders"

+    @classmethod
+    def VALIDATE_INPUTS(cls, device="default"):
+        return True
+
     #TODO: scale factor?
-    def load_vae(self, vae_name):
+    def load_vae(self, vae_name, device="default"):
         metadata = None
         if vae_name == "pixel_space":
             sd = {}

@@ -827,7 +901,8 @@ class VAELoader:
         else:
             vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
             sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True)
-        vae = comfy.sd.VAE(sd=sd, metadata=metadata)
+        resolved = comfy.model_management.resolve_gpu_device_option(device)
+        vae = comfy.sd.VAE(sd=sd, metadata=metadata, device=resolved)
         vae.throw_exception_if_invalid()
         return (vae,)
@@ -953,13 +1028,20 @@ class UNETLoader:
     def INPUT_TYPES(s):
         return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
                              "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"], {"advanced": True})
+                            },
+                "optional": {
+                    "device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}),
                 }}
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "load_unet"
     CATEGORY = "advanced/loaders"

-    def load_unet(self, unet_name, weight_dtype):
+    @classmethod
+    def VALIDATE_INPUTS(cls, device="default"):
+        return True
+
+    def load_unet(self, unet_name, weight_dtype, device="default"):
         model_options = {}
         if weight_dtype == "fp8_e4m3fn":
             model_options["dtype"] = torch.float8_e4m3fn

@@ -969,6 +1051,13 @@ class UNETLoader:
         elif weight_dtype == "fp8_e5m2":
             model_options["dtype"] = torch.float8_e5m2

+        resolved = comfy.model_management.resolve_gpu_device_option(device)
+        if resolved is not None:
+            if resolved.type == "cpu":
+                model_options["load_device"] = model_options["offload_device"] = resolved
+            else:
+                model_options["load_device"] = resolved
+
         unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)
         model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
         return (model,)
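UNETLoader's if/elif chain over `weight_dtype` could equally be a table lookup. A sketch of that alternative, using dtype attribute names as placeholder values so it runs without torch (the real code assigns `torch.float8_e4m3fn` etc.; the `fp8_optimizations` flag for `fp8_e4m3fn_fast` is from upstream ComfyUI behavior, as that branch is elided in this hunk):

```python
# Hypothetical table-driven version of UNETLoader's weight_dtype handling.
# Values name torch dtype attributes rather than holding the dtypes themselves,
# purely so this sketch is runnable without torch installed.
WEIGHT_DTYPE_MAP = {
    "fp8_e4m3fn":      {"dtype": "float8_e4m3fn"},
    "fp8_e4m3fn_fast": {"dtype": "float8_e4m3fn", "fp8_optimizations": True},
    "fp8_e5m2":        {"dtype": "float8_e5m2"},
}

def build_model_options(weight_dtype: str) -> dict:
    """Return the model_options entries implied by a weight_dtype choice."""
    # Unknown values (including "default") contribute nothing, matching the
    # fall-through of the original if/elif chain.
    return dict(WEIGHT_DTYPE_MAP.get(weight_dtype, {}))
```

A table keeps the dtype list in one place, so the combo options in INPUT_TYPES could be derived from `WEIGHT_DTYPE_MAP` instead of being duplicated by hand.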
@@ -980,7 +1069,7 @@ class CLIPLoader:
                 "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image"], ),
             },
             "optional": {
-                "device": (["default", "cpu"], {"advanced": True}),
+                "device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}),
             }}
     RETURN_TYPES = ("CLIP",)
     FUNCTION = "load_clip"

@@ -989,12 +1078,20 @@ class CLIPLoader:
     DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B"

+    @classmethod
+    def VALIDATE_INPUTS(cls, device="default"):
+        return True
+
     def load_clip(self, clip_name, type="stable_diffusion", device="default"):
         clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)

         model_options = {}
-        if device == "cpu":
-            model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
+        resolved = comfy.model_management.resolve_gpu_device_option(device)
+        if resolved is not None:
+            if resolved.type == "cpu":
+                model_options["load_device"] = model_options["offload_device"] = resolved
+            else:
+                model_options["load_device"] = resolved
         clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name)
         clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
@@ -1008,7 +1105,7 @@ class DualCLIPLoader:
                 "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "ltxv", "newbie", "ace"], ),
             },
             "optional": {
-                "device": (["default", "cpu"], {"advanced": True}),
+                "device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}),
             }}
     RETURN_TYPES = ("CLIP",)
     FUNCTION = "load_clip"

@@ -1017,6 +1114,10 @@ class DualCLIPLoader:
     DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small\nnewbie: gemma-3-4b-it, jina clip v2"

+    @classmethod
+    def VALIDATE_INPUTS(cls, device="default"):
+        return True
+
     def load_clip(self, clip_name1, clip_name2, type, device="default"):
         clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)

@@ -1024,8 +1125,12 @@ class DualCLIPLoader:
         clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)

         model_options = {}
-        if device == "cpu":
-            model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
+        resolved = comfy.model_management.resolve_gpu_device_option(device)
+        if resolved is not None:
+            if resolved.type == "cpu":
+                model_options["load_device"] = model_options["offload_device"] = resolved
+            else:
+                model_options["load_device"] = resolved

         clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
         return (clip,)
@@ -2098,6 +2203,7 @@ NODE_CLASS_MAPPINGS = {
     "InpaintModelConditioning": InpaintModelConditioning,
     "CheckpointLoader": CheckpointLoader,
+    "CheckpointLoaderDevice": CheckpointLoaderDevice,
     "DiffusersLoader": DiffusersLoader,
     "LoadLatent": LoadLatent,

@@ -2115,6 +2221,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     # Loaders
     "CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)",
     "CheckpointLoaderSimple": "Load Checkpoint",
+    "CheckpointLoaderDevice": "Load Checkpoint (Device)",
     "VAELoader": "Load VAE",
     "LoraLoader": "Load LoRA (Model and CLIP)",
     "LoraLoaderModelOnly": "Load LoRA",