mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-25 18:02:37 +08:00
Multi-GPU device selection for loader nodes + CUDA context fixes (#13483)
* Fix Hunyuan 3D 2.1 multi-GPU worksplit: use cond_or_uncond instead of hardcoded chunk(2) Amp-Thread-ID: https://ampcode.com/threads/T-019da964-2cc8-77f9-9aae-23f65da233db Co-authored-by: Amp <amp@ampcode.com> * Add GPU device selection to all loader nodes - Add get_gpu_device_options() and resolve_gpu_device_option() helpers in model_management.py for vendor-agnostic GPU device selection - Add device widget to CheckpointLoaderSimple, UNETLoader, VAELoader - Expand device options in CLIPLoader, DualCLIPLoader, LTXAVTextEncoderLoader from [default, cpu] to include gpu:0, gpu:1, etc. on multi-GPU systems - Wire load_diffusion_model_state_dict and load_state_dict_guess_config to respect model_options['load_device'] - Graceful fallback: unrecognized devices (e.g. gpu:1 on single-GPU) silently fall back to default Amp-Thread-ID: https://ampcode.com/threads/T-019daa41-f394-731a-8955-4cff4f16283a Co-authored-by: Amp <amp@ampcode.com> * Add VALIDATE_INPUTS to skip device combo validation for workflow portability When a workflow saved on a 2-GPU machine (with device=gpu:1) is loaded on a 1-GPU machine, the combo validation would reject the unknown value. VALIDATE_INPUTS with the device parameter bypasses combo validation for that input only, allowing resolve_gpu_device_option to handle the graceful fallback at runtime. Amp-Thread-ID: https://ampcode.com/threads/T-019daa41-f394-731a-8955-4cff4f16283a Co-authored-by: Amp <amp@ampcode.com> * Set CUDA device context in outer_sample to match model load_device Custom CUDA kernels (comfy_kitchen fp8 quantization) use torch.cuda.current_device() for DLPack tensor export. When a model is loaded on a non-default GPU (e.g. cuda:1), the CUDA context must match or the kernel fails with 'Can't export tensors on a different CUDA device index'. Save and restore the previous device around sampling. Amp-Thread-ID: https://ampcode.com/threads/T-019daa41-f394-731a-8955-4cff4f16283a Co-authored-by: Amp <amp@ampcode.com> * Fix code review bugs: negative index guard, CPU offload_device, checkpoint te_model_options - resolve_gpu_device_option: reject negative indices (gpu:-1) - UNETLoader: set offload_device when cpu is selected - CheckpointLoaderSimple: pass te_model_options for CLIP device, set offload_device for cpu, pass load_device to VAE - load_diffusion_model_state_dict: respect offload_device from model_options - load_state_dict_guess_config: respect offload_device, pass load_device to VAE Amp-Thread-ID: https://ampcode.com/threads/T-019daa41-f394-731a-8955-4cff4f16283a Co-authored-by: Amp <amp@ampcode.com> * Fix CUDA device context for CLIP encoding and VAE encode/decode Add torch.cuda.set_device() calls to match model's load device in: - CLIP.encode_from_tokens: fixes 'Can't export tensors on a different CUDA device index' when CLIP is loaded on a non-default GPU - CLIP.encode_from_tokens_scheduled: same fix for the hooks code path - CLIP.generate: same fix for text generation - VAE.decode: fixes VAE decoding on non-default GPU - VAE.encode: fixes VAE encoding on non-default GPU Same pattern as the existing outer_sample fix in samplers.py - saves and restores previous CUDA device in a try/finally block. Amp-Thread-ID: https://ampcode.com/threads/T-019dabdc-8feb-766f-b4dc-f46ef4d8ff57 Co-authored-by: Amp <amp@ampcode.com> * Extract cuda_device_context manager, fix tiled VAE methods Add model_management.cuda_device_context() — a context manager that saves/restores torch.cuda.current_device when operating on a non-default GPU. Replaces 6 copies of the manual save/set/restore boilerplate. Refactored call sites: - CLIP.encode_from_tokens - CLIP.encode_from_tokens_scheduled (hooks path) - CLIP.generate - VAE.decode - VAE.encode - samplers.outer_sample Bug fixes (newly wrapped): - VAE.decode_tiled: was missing device context entirely, would fail on non-default GPU when called from 'VAE Decode (Tiled)' node - VAE.encode_tiled: same issue for 'VAE Encode (Tiled)' node Amp-Thread-ID: https://ampcode.com/threads/T-019dabdc-8feb-766f-b4dc-f46ef4d8ff57 Co-authored-by: Amp <amp@ampcode.com> * Restore CheckpointLoaderSimple, add CheckpointLoaderDevice Revert CheckpointLoaderSimple to its original form (no device input) so it remains the simple default loader. Add new CheckpointLoaderDevice node (advanced/loaders) with separate model_device, clip_device, and vae_device inputs for per-component GPU placement in multi-GPU setups. Amp-Thread-ID: https://ampcode.com/threads/T-019dabdc-8feb-766f-b4dc-f46ef4d8ff57 Co-authored-by: Amp <amp@ampcode.com> --------- Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
parent
7b8b3673ff
commit
aa464b36b3
@ -28,7 +28,7 @@ import platform
|
|||||||
import weakref
|
import weakref
|
||||||
import gc
|
import gc
|
||||||
import os
|
import os
|
||||||
from contextlib import nullcontext
|
from contextlib import contextmanager, nullcontext
|
||||||
import comfy.memory_management
|
import comfy.memory_management
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.quant_ops
|
import comfy.quant_ops
|
||||||
@ -231,6 +231,70 @@ def get_all_torch_devices(exclude_current=False):
|
|||||||
devices.remove(get_torch_device())
|
devices.remove(get_torch_device())
|
||||||
return devices
|
return devices
|
||||||
|
|
||||||
|
def get_gpu_device_options():
|
||||||
|
"""Return list of device option strings for node widgets.
|
||||||
|
|
||||||
|
Always includes "default" and "cpu". When multiple GPUs are present,
|
||||||
|
adds "gpu:0", "gpu:1", etc. (vendor-agnostic labels).
|
||||||
|
"""
|
||||||
|
options = ["default", "cpu"]
|
||||||
|
devices = get_all_torch_devices()
|
||||||
|
if len(devices) > 1:
|
||||||
|
for i in range(len(devices)):
|
||||||
|
options.append(f"gpu:{i}")
|
||||||
|
return options
|
||||||
|
|
||||||
|
def resolve_gpu_device_option(option: str):
|
||||||
|
"""Resolve a device option string to a torch.device.
|
||||||
|
|
||||||
|
Returns None for "default" (let the caller use its normal default).
|
||||||
|
Returns torch.device("cpu") for "cpu".
|
||||||
|
For "gpu:N", returns the Nth torch device. Falls back to None if
|
||||||
|
the index is out of range (caller should use default).
|
||||||
|
"""
|
||||||
|
if option is None or option == "default":
|
||||||
|
return None
|
||||||
|
if option == "cpu":
|
||||||
|
return torch.device("cpu")
|
||||||
|
if option.startswith("gpu:"):
|
||||||
|
try:
|
||||||
|
idx = int(option[4:])
|
||||||
|
devices = get_all_torch_devices()
|
||||||
|
if 0 <= idx < len(devices):
|
||||||
|
return devices[idx]
|
||||||
|
else:
|
||||||
|
logging.warning(f"Device '{option}' not available (only {len(devices)} GPU(s)), using default.")
|
||||||
|
return None
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
logging.warning(f"Invalid device option '{option}', using default.")
|
||||||
|
return None
|
||||||
|
logging.warning(f"Unrecognized device option '{option}', using default.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def cuda_device_context(device):
|
||||||
|
"""Context manager that sets torch.cuda.current_device to match *device*.
|
||||||
|
|
||||||
|
Used when running operations on a non-default CUDA device so that custom
|
||||||
|
CUDA kernels (e.g. comfy_kitchen fp8 quantization) pick up the correct
|
||||||
|
device index. The previous device is restored on exit.
|
||||||
|
|
||||||
|
No-op when *device* is not CUDA, has no explicit index, or already matches
|
||||||
|
the current device.
|
||||||
|
"""
|
||||||
|
prev = None
|
||||||
|
if device.type == "cuda" and device.index is not None:
|
||||||
|
prev = torch.cuda.current_device()
|
||||||
|
if prev != device.index:
|
||||||
|
torch.cuda.set_device(device)
|
||||||
|
else:
|
||||||
|
prev = None
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
if prev is not None:
|
||||||
|
torch.cuda.set_device(prev)
|
||||||
|
|
||||||
def get_total_memory(dev=None, torch_total_too=False):
|
def get_total_memory(dev=None, torch_total_too=False):
|
||||||
global directml_enabled
|
global directml_enabled
|
||||||
if dev is None:
|
if dev is None:
|
||||||
|
|||||||
@ -1208,23 +1208,24 @@ class CFGGuider:
|
|||||||
all_devices = [device] + extra_devices
|
all_devices = [device] + extra_devices
|
||||||
self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(all_devices)
|
self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(all_devices)
|
||||||
|
|
||||||
try:
|
with comfy.model_management.cuda_device_context(device):
|
||||||
noise = noise.to(device=device, dtype=torch.float32)
|
try:
|
||||||
latent_image = latent_image.to(device=device, dtype=torch.float32)
|
noise = noise.to(device=device, dtype=torch.float32)
|
||||||
sigmas = sigmas.to(device)
|
latent_image = latent_image.to(device=device, dtype=torch.float32)
|
||||||
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
|
sigmas = sigmas.to(device)
|
||||||
|
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
|
||||||
|
|
||||||
self.model_patcher.pre_run()
|
self.model_patcher.pre_run()
|
||||||
for multigpu_patcher in multigpu_patchers:
|
for multigpu_patcher in multigpu_patchers:
|
||||||
multigpu_patcher.pre_run()
|
multigpu_patcher.pre_run()
|
||||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
|
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
|
||||||
finally:
|
finally:
|
||||||
thread_pool = self.model_options.pop("multigpu_thread_pool", None)
|
thread_pool = self.model_options.pop("multigpu_thread_pool", None)
|
||||||
if thread_pool is not None:
|
if thread_pool is not None:
|
||||||
thread_pool.shutdown()
|
thread_pool.shutdown()
|
||||||
self.model_patcher.cleanup()
|
self.model_patcher.cleanup()
|
||||||
for multigpu_patcher in multigpu_patchers:
|
for multigpu_patcher in multigpu_patchers:
|
||||||
multigpu_patcher.cleanup()
|
multigpu_patcher.cleanup()
|
||||||
|
|
||||||
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
|
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
|
||||||
del self.inner_model
|
del self.inner_model
|
||||||
|
|||||||
313
comfy/sd.py
313
comfy/sd.py
@ -324,41 +324,43 @@ class CLIP:
|
|||||||
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
||||||
|
|
||||||
self.load_model(tokens)
|
self.load_model(tokens)
|
||||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
device = self.patcher.load_device
|
||||||
|
self.cond_stage_model.set_clip_options({"execution_device": device})
|
||||||
all_hooks.reset()
|
all_hooks.reset()
|
||||||
self.patcher.patch_hooks(None)
|
self.patcher.patch_hooks(None)
|
||||||
if show_pbar:
|
if show_pbar:
|
||||||
pbar = ProgressBar(len(scheduled_keyframes))
|
pbar = ProgressBar(len(scheduled_keyframes))
|
||||||
|
|
||||||
for scheduled_opts in scheduled_keyframes:
|
with model_management.cuda_device_context(device):
|
||||||
t_range = scheduled_opts[0]
|
for scheduled_opts in scheduled_keyframes:
|
||||||
# don't bother encoding any conds outside of start_percent and end_percent bounds
|
t_range = scheduled_opts[0]
|
||||||
if "start_percent" in add_dict:
|
# don't bother encoding any conds outside of start_percent and end_percent bounds
|
||||||
if t_range[1] < add_dict["start_percent"]:
|
if "start_percent" in add_dict:
|
||||||
continue
|
if t_range[1] < add_dict["start_percent"]:
|
||||||
if "end_percent" in add_dict:
|
continue
|
||||||
if t_range[0] > add_dict["end_percent"]:
|
if "end_percent" in add_dict:
|
||||||
continue
|
if t_range[0] > add_dict["end_percent"]:
|
||||||
hooks_keyframes = scheduled_opts[1]
|
continue
|
||||||
for hook, keyframe in hooks_keyframes:
|
hooks_keyframes = scheduled_opts[1]
|
||||||
hook.hook_keyframe._current_keyframe = keyframe
|
for hook, keyframe in hooks_keyframes:
|
||||||
# apply appropriate hooks with values that match new hook_keyframe
|
hook.hook_keyframe._current_keyframe = keyframe
|
||||||
self.patcher.patch_hooks(all_hooks)
|
# apply appropriate hooks with values that match new hook_keyframe
|
||||||
# perform encoding as normal
|
self.patcher.patch_hooks(all_hooks)
|
||||||
o = self.cond_stage_model.encode_token_weights(tokens)
|
# perform encoding as normal
|
||||||
cond, pooled = o[:2]
|
o = self.cond_stage_model.encode_token_weights(tokens)
|
||||||
pooled_dict = {"pooled_output": pooled}
|
cond, pooled = o[:2]
|
||||||
# add clip_start_percent and clip_end_percent in pooled
|
pooled_dict = {"pooled_output": pooled}
|
||||||
pooled_dict["clip_start_percent"] = t_range[0]
|
# add clip_start_percent and clip_end_percent in pooled
|
||||||
pooled_dict["clip_end_percent"] = t_range[1]
|
pooled_dict["clip_start_percent"] = t_range[0]
|
||||||
# add/update any keys with the provided add_dict
|
pooled_dict["clip_end_percent"] = t_range[1]
|
||||||
pooled_dict.update(add_dict)
|
# add/update any keys with the provided add_dict
|
||||||
# add hooks stored on clip
|
pooled_dict.update(add_dict)
|
||||||
self.add_hooks_to_dict(pooled_dict)
|
# add hooks stored on clip
|
||||||
all_cond_pooled.append([cond, pooled_dict])
|
self.add_hooks_to_dict(pooled_dict)
|
||||||
if show_pbar:
|
all_cond_pooled.append([cond, pooled_dict])
|
||||||
pbar.update(1)
|
if show_pbar:
|
||||||
model_management.throw_exception_if_processing_interrupted()
|
pbar.update(1)
|
||||||
|
model_management.throw_exception_if_processing_interrupted()
|
||||||
all_hooks.reset()
|
all_hooks.reset()
|
||||||
return all_cond_pooled
|
return all_cond_pooled
|
||||||
|
|
||||||
@ -372,8 +374,12 @@ class CLIP:
|
|||||||
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
||||||
|
|
||||||
self.load_model(tokens)
|
self.load_model(tokens)
|
||||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
device = self.patcher.load_device
|
||||||
o = self.cond_stage_model.encode_token_weights(tokens)
|
self.cond_stage_model.set_clip_options({"execution_device": device})
|
||||||
|
|
||||||
|
with model_management.cuda_device_context(device):
|
||||||
|
o = self.cond_stage_model.encode_token_weights(tokens)
|
||||||
|
|
||||||
cond, pooled = o[:2]
|
cond, pooled = o[:2]
|
||||||
if return_dict:
|
if return_dict:
|
||||||
out = {"cond": cond, "pooled_output": pooled}
|
out = {"cond": cond, "pooled_output": pooled}
|
||||||
@ -428,9 +434,12 @@ class CLIP:
|
|||||||
self.cond_stage_model.reset_clip_options()
|
self.cond_stage_model.reset_clip_options()
|
||||||
|
|
||||||
self.load_model(tokens)
|
self.load_model(tokens)
|
||||||
|
device = self.patcher.load_device
|
||||||
self.cond_stage_model.set_clip_options({"layer": None})
|
self.cond_stage_model.set_clip_options({"layer": None})
|
||||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
self.cond_stage_model.set_clip_options({"execution_device": device})
|
||||||
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
|
|
||||||
|
with model_management.cuda_device_context(device):
|
||||||
|
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
|
||||||
|
|
||||||
def decode(self, token_ids, skip_special_tokens=True):
|
def decode(self, token_ids, skip_special_tokens=True):
|
||||||
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
|
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
|
||||||
@ -947,50 +956,52 @@ class VAE:
|
|||||||
do_tile = False
|
do_tile = False
|
||||||
if self.latent_dim == 2 and samples_in.ndim == 5:
|
if self.latent_dim == 2 and samples_in.ndim == 5:
|
||||||
samples_in = samples_in[:, :, 0]
|
samples_in = samples_in[:, :, 0]
|
||||||
try:
|
|
||||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
|
||||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
|
||||||
free_memory = self.patcher.get_free_memory(self.device)
|
|
||||||
batch_number = int(free_memory / memory_used)
|
|
||||||
batch_number = max(1, batch_number)
|
|
||||||
|
|
||||||
# Pre-allocate output for VAEs that support direct buffer writes
|
with model_management.cuda_device_context(self.device):
|
||||||
preallocated = False
|
try:
|
||||||
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||||
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
|
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||||
preallocated = True
|
free_memory = self.patcher.get_free_memory(self.device)
|
||||||
|
batch_number = int(free_memory / memory_used)
|
||||||
|
batch_number = max(1, batch_number)
|
||||||
|
|
||||||
for x in range(0, samples_in.shape[0], batch_number):
|
# Pre-allocate output for VAEs that support direct buffer writes
|
||||||
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
|
preallocated = False
|
||||||
if preallocated:
|
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||||
self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
|
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
|
||||||
else:
|
preallocated = True
|
||||||
out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
|
|
||||||
if pixel_samples is None:
|
|
||||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
|
||||||
pixel_samples[x:x+batch_number].copy_(out)
|
|
||||||
del out
|
|
||||||
self.process_output(pixel_samples[x:x+batch_number])
|
|
||||||
except Exception as e:
|
|
||||||
model_management.raise_non_oom(e)
|
|
||||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
|
||||||
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
|
||||||
#exception and the exception itself refs them all until we get out of this except block.
|
|
||||||
#So we just set a flag for tiler fallback so that tensor gc can happen once the
|
|
||||||
#exception is fully off the books.
|
|
||||||
do_tile = True
|
|
||||||
|
|
||||||
if do_tile:
|
for x in range(0, samples_in.shape[0], batch_number):
|
||||||
comfy.model_management.soft_empty_cache()
|
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
|
||||||
dims = samples_in.ndim - 2
|
if preallocated:
|
||||||
if dims == 1 or self.extra_1d_channel is not None:
|
self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
|
||||||
pixel_samples = self.decode_tiled_1d(samples_in)
|
else:
|
||||||
elif dims == 2:
|
out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
|
||||||
pixel_samples = self.decode_tiled_(samples_in)
|
if pixel_samples is None:
|
||||||
elif dims == 3:
|
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||||
tile = 256 // self.spacial_compression_decode()
|
pixel_samples[x:x+batch_number].copy_(out)
|
||||||
overlap = tile // 4
|
del out
|
||||||
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
self.process_output(pixel_samples[x:x+batch_number])
|
||||||
|
except Exception as e:
|
||||||
|
model_management.raise_non_oom(e)
|
||||||
|
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||||
|
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
||||||
|
#exception and the exception itself refs them all until we get out of this except block.
|
||||||
|
#So we just set a flag for tiler fallback so that tensor gc can happen once the
|
||||||
|
#exception is fully off the books.
|
||||||
|
do_tile = True
|
||||||
|
|
||||||
|
if do_tile:
|
||||||
|
comfy.model_management.soft_empty_cache()
|
||||||
|
dims = samples_in.ndim - 2
|
||||||
|
if dims == 1 or self.extra_1d_channel is not None:
|
||||||
|
pixel_samples = self.decode_tiled_1d(samples_in)
|
||||||
|
elif dims == 2:
|
||||||
|
pixel_samples = self.decode_tiled_(samples_in)
|
||||||
|
elif dims == 3:
|
||||||
|
tile = 256 // self.spacial_compression_decode()
|
||||||
|
overlap = tile // 4
|
||||||
|
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||||
|
|
||||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
||||||
return pixel_samples
|
return pixel_samples
|
||||||
@ -1008,20 +1019,21 @@ class VAE:
|
|||||||
if overlap is not None:
|
if overlap is not None:
|
||||||
args["overlap"] = overlap
|
args["overlap"] = overlap
|
||||||
|
|
||||||
if dims == 1 or self.extra_1d_channel is not None:
|
with model_management.cuda_device_context(self.device):
|
||||||
args.pop("tile_y")
|
if dims == 1 or self.extra_1d_channel is not None:
|
||||||
output = self.decode_tiled_1d(samples, **args)
|
args.pop("tile_y")
|
||||||
elif dims == 2:
|
output = self.decode_tiled_1d(samples, **args)
|
||||||
output = self.decode_tiled_(samples, **args)
|
elif dims == 2:
|
||||||
elif dims == 3:
|
output = self.decode_tiled_(samples, **args)
|
||||||
if overlap_t is None:
|
elif dims == 3:
|
||||||
args["overlap"] = (1, overlap, overlap)
|
if overlap_t is None:
|
||||||
else:
|
args["overlap"] = (1, overlap, overlap)
|
||||||
args["overlap"] = (max(1, overlap_t), overlap, overlap)
|
else:
|
||||||
if tile_t is not None:
|
args["overlap"] = (max(1, overlap_t), overlap, overlap)
|
||||||
args["tile_t"] = max(2, tile_t)
|
if tile_t is not None:
|
||||||
|
args["tile_t"] = max(2, tile_t)
|
||||||
|
|
||||||
output = self.decode_tiled_3d(samples, **args)
|
output = self.decode_tiled_3d(samples, **args)
|
||||||
return output.movedim(1, -1)
|
return output.movedim(1, -1)
|
||||||
|
|
||||||
def encode(self, pixel_samples):
|
def encode(self, pixel_samples):
|
||||||
@ -1034,44 +1046,46 @@ class VAE:
|
|||||||
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||||
else:
|
else:
|
||||||
pixel_samples = pixel_samples.unsqueeze(2)
|
pixel_samples = pixel_samples.unsqueeze(2)
|
||||||
try:
|
|
||||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
with model_management.cuda_device_context(self.device):
|
||||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
try:
|
||||||
free_memory = self.patcher.get_free_memory(self.device)
|
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||||
batch_number = int(free_memory / max(1, memory_used))
|
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||||
batch_number = max(1, batch_number)
|
free_memory = self.patcher.get_free_memory(self.device)
|
||||||
samples = None
|
batch_number = int(free_memory / max(1, memory_used))
|
||||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
batch_number = max(1, batch_number)
|
||||||
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
|
samples = None
|
||||||
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||||
out = self.first_stage_model.encode(pixels_in, device=self.device)
|
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
|
||||||
|
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||||
|
out = self.first_stage_model.encode(pixels_in, device=self.device)
|
||||||
|
else:
|
||||||
|
pixels_in = pixels_in.to(self.device)
|
||||||
|
out = self.first_stage_model.encode(pixels_in)
|
||||||
|
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
|
||||||
|
if samples is None:
|
||||||
|
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||||
|
samples[x:x + batch_number] = out
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
model_management.raise_non_oom(e)
|
||||||
|
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||||
|
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
||||||
|
#exception and the exception itself refs them all until we get out of this except block.
|
||||||
|
#So we just set a flag for tiler fallback so that tensor gc can happen once the
|
||||||
|
#exception is fully off the books.
|
||||||
|
do_tile = True
|
||||||
|
|
||||||
|
if do_tile:
|
||||||
|
comfy.model_management.soft_empty_cache()
|
||||||
|
if self.latent_dim == 3:
|
||||||
|
tile = 256
|
||||||
|
overlap = tile // 4
|
||||||
|
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||||
|
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
|
||||||
|
samples = self.encode_tiled_1d(pixel_samples)
|
||||||
else:
|
else:
|
||||||
pixels_in = pixels_in.to(self.device)
|
samples = self.encode_tiled_(pixel_samples)
|
||||||
out = self.first_stage_model.encode(pixels_in)
|
|
||||||
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
|
|
||||||
if samples is None:
|
|
||||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
|
||||||
samples[x:x + batch_number] = out
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
model_management.raise_non_oom(e)
|
|
||||||
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
|
||||||
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
|
||||||
#exception and the exception itself refs them all until we get out of this except block.
|
|
||||||
#So we just set a flag for tiler fallback so that tensor gc can happen once the
|
|
||||||
#exception is fully off the books.
|
|
||||||
do_tile = True
|
|
||||||
|
|
||||||
if do_tile:
|
|
||||||
comfy.model_management.soft_empty_cache()
|
|
||||||
if self.latent_dim == 3:
|
|
||||||
tile = 256
|
|
||||||
overlap = tile // 4
|
|
||||||
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
|
||||||
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
|
|
||||||
samples = self.encode_tiled_1d(pixel_samples)
|
|
||||||
else:
|
|
||||||
samples = self.encode_tiled_(pixel_samples)
|
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
@ -1097,26 +1111,27 @@ class VAE:
|
|||||||
if overlap is not None:
|
if overlap is not None:
|
||||||
args["overlap"] = overlap
|
args["overlap"] = overlap
|
||||||
|
|
||||||
if dims == 1:
|
with model_management.cuda_device_context(self.device):
|
||||||
args.pop("tile_y")
|
if dims == 1:
|
||||||
samples = self.encode_tiled_1d(pixel_samples, **args)
|
args.pop("tile_y")
|
||||||
elif dims == 2:
|
samples = self.encode_tiled_1d(pixel_samples, **args)
|
||||||
samples = self.encode_tiled_(pixel_samples, **args)
|
elif dims == 2:
|
||||||
elif dims == 3:
|
samples = self.encode_tiled_(pixel_samples, **args)
|
||||||
if tile_t is not None:
|
elif dims == 3:
|
||||||
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
|
if tile_t is not None:
|
||||||
else:
|
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
|
||||||
tile_t_latent = 9999
|
else:
|
||||||
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
|
tile_t_latent = 9999
|
||||||
|
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
|
||||||
|
|
||||||
if overlap_t is None:
|
if overlap_t is None:
|
||||||
args["overlap"] = (1, overlap, overlap)
|
args["overlap"] = (1, overlap, overlap)
|
||||||
else:
|
else:
|
||||||
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
|
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
|
||||||
maximum = pixel_samples.shape[2]
|
maximum = pixel_samples.shape[2]
|
||||||
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
|
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
|
||||||
|
|
||||||
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
|
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
@ -1633,7 +1648,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||||||
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
||||||
parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
|
parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
|
||||||
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
|
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
|
||||||
load_device = model_management.get_torch_device()
|
load_device = model_options.get("load_device", model_management.get_torch_device())
|
||||||
|
|
||||||
custom_operations = model_options.get("custom_operations", None)
|
custom_operations = model_options.get("custom_operations", None)
|
||||||
if custom_operations is None:
|
if custom_operations is None:
|
||||||
@ -1673,13 +1688,15 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||||||
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
||||||
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
|
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
|
||||||
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
|
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
|
||||||
model_patcher = ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
offload_device = model_options.get("offload_device", model_management.unet_offload_device())
|
||||||
|
model_patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
||||||
model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
|
model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
|
||||||
|
|
||||||
if output_vae:
|
if output_vae:
|
||||||
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
|
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
|
||||||
vae_sd = model_config.process_vae_state_dict(vae_sd)
|
vae_sd = model_config.process_vae_state_dict(vae_sd)
|
||||||
vae = VAE(sd=vae_sd, metadata=metadata)
|
vae_device = model_options.get("load_device", None)
|
||||||
|
vae = VAE(sd=vae_sd, metadata=metadata, device=vae_device)
|
||||||
|
|
||||||
if output_clip:
|
if output_clip:
|
||||||
if te_model_options.get("custom_operations", None) is None:
|
if te_model_options.get("custom_operations", None) is None:
|
||||||
@ -1763,7 +1780,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
|
|||||||
parameters = comfy.utils.calculate_parameters(sd)
|
parameters = comfy.utils.calculate_parameters(sd)
|
||||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||||
|
|
||||||
load_device = model_management.get_torch_device()
|
load_device = model_options.get("load_device", model_management.get_torch_device())
|
||||||
model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
|
model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
|
||||||
|
|
||||||
if model_config is not None:
|
if model_config is not None:
|
||||||
@ -1788,7 +1805,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
|
|||||||
else:
|
else:
|
||||||
logging.warning("{} {}".format(diffusers_keys[k], k))
|
logging.warning("{} {}".format(diffusers_keys[k], k))
|
||||||
|
|
||||||
offload_device = model_management.unet_offload_device()
|
offload_device = model_options.get("offload_device", model_management.unet_offload_device())
|
||||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||||
if model_config.quant_config is not None:
|
if model_config.quant_config is not None:
|
||||||
weight_dtype = None
|
weight_dtype = None
|
||||||
|
|||||||
@ -188,7 +188,7 @@ class LTXAVTextEncoderLoader(io.ComfyNode):
|
|||||||
),
|
),
|
||||||
io.Combo.Input(
|
io.Combo.Input(
|
||||||
"device",
|
"device",
|
||||||
options=["default", "cpu"],
|
options=comfy.model_management.get_gpu_device_options(),
|
||||||
advanced=True,
|
advanced=True,
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
@ -203,8 +203,12 @@ class LTXAVTextEncoderLoader(io.ComfyNode):
|
|||||||
clip_path2 = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
clip_path2 = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||||
|
|
||||||
model_options = {}
|
model_options = {}
|
||||||
if device == "cpu":
|
resolved = comfy.model_management.resolve_gpu_device_option(device)
|
||||||
model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
|
if resolved is not None:
|
||||||
|
if resolved.type == "cpu":
|
||||||
|
model_options["load_device"] = model_options["offload_device"] = resolved
|
||||||
|
else:
|
||||||
|
model_options["load_device"] = resolved
|
||||||
|
|
||||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
|
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
|
||||||
return io.NodeOutput(clip)
|
return io.NodeOutput(clip)
|
||||||
|
|||||||
127
nodes.py
127
nodes.py
@ -608,6 +608,73 @@ class CheckpointLoaderSimple:
|
|||||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
return out[:3]
|
return out[:3]
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointLoaderDevice:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
device_options = comfy.model_management.get_gpu_device_options()
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"model_device": (device_options, {"advanced": True, "tooltip": "Device for the diffusion model (UNET)."}),
|
||||||
|
"clip_device": (device_options, {"advanced": True, "tooltip": "Device for the CLIP text encoder."}),
|
||||||
|
"vae_device": (device_options, {"advanced": True, "tooltip": "Device for the VAE."}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
||||||
|
OUTPUT_TOOLTIPS = ("The model used for denoising latents.",
|
||||||
|
"The CLIP model used for encoding text prompts.",
|
||||||
|
"The VAE model used for encoding and decoding images to and from latent space.")
|
||||||
|
FUNCTION = "load_checkpoint"
|
||||||
|
|
||||||
|
CATEGORY = "advanced/loaders"
|
||||||
|
DESCRIPTION = "Loads a diffusion model checkpoint with per-component device selection for multi-GPU setups."
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def VALIDATE_INPUTS(cls, model_device="default", clip_device="default", vae_device="default"):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def load_checkpoint(self, ckpt_name, model_device="default", clip_device="default", vae_device="default"):
|
||||||
|
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||||
|
|
||||||
|
model_options = {}
|
||||||
|
resolved_model = comfy.model_management.resolve_gpu_device_option(model_device)
|
||||||
|
if resolved_model is not None:
|
||||||
|
if resolved_model.type == "cpu":
|
||||||
|
model_options["load_device"] = model_options["offload_device"] = resolved_model
|
||||||
|
else:
|
||||||
|
model_options["load_device"] = resolved_model
|
||||||
|
|
||||||
|
te_model_options = {}
|
||||||
|
resolved_clip = comfy.model_management.resolve_gpu_device_option(clip_device)
|
||||||
|
if resolved_clip is not None:
|
||||||
|
if resolved_clip.type == "cpu":
|
||||||
|
te_model_options["load_device"] = te_model_options["offload_device"] = resolved_clip
|
||||||
|
else:
|
||||||
|
te_model_options["load_device"] = resolved_clip
|
||||||
|
|
||||||
|
# VAE device is passed via model_options["load_device"] which
|
||||||
|
# load_state_dict_guess_config forwards to the VAE constructor.
|
||||||
|
# If vae_device differs from model_device, we override after loading.
|
||||||
|
resolved_vae = comfy.model_management.resolve_gpu_device_option(vae_device)
|
||||||
|
|
||||||
|
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options=model_options, te_model_options=te_model_options)
|
||||||
|
model_patcher, clip, vae = out[:3]
|
||||||
|
|
||||||
|
# Apply VAE device override if it differs from the model device
|
||||||
|
if resolved_vae is not None and vae is not None:
|
||||||
|
vae.device = resolved_vae
|
||||||
|
if resolved_vae.type == "cpu":
|
||||||
|
offload = resolved_vae
|
||||||
|
else:
|
||||||
|
offload = comfy.model_management.vae_offload_device()
|
||||||
|
vae.patcher.load_device = resolved_vae
|
||||||
|
vae.patcher.offload_device = offload
|
||||||
|
|
||||||
|
return (model_patcher, clip, vae)
|
||||||
|
|
||||||
class DiffusersLoader:
|
class DiffusersLoader:
|
||||||
SEARCH_ALIASES = ["load diffusers model"]
|
SEARCH_ALIASES = ["load diffusers model"]
|
||||||
|
|
||||||
@ -807,14 +874,21 @@ class VAELoader:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "vae_name": (s.vae_list(s), )}}
|
return {"required": { "vae_name": (s.vae_list(s), )},
|
||||||
|
"optional": {
|
||||||
|
"device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}),
|
||||||
|
}}
|
||||||
RETURN_TYPES = ("VAE",)
|
RETURN_TYPES = ("VAE",)
|
||||||
FUNCTION = "load_vae"
|
FUNCTION = "load_vae"
|
||||||
|
|
||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def VALIDATE_INPUTS(cls, device="default"):
|
||||||
|
return True
|
||||||
|
|
||||||
#TODO: scale factor?
|
#TODO: scale factor?
|
||||||
def load_vae(self, vae_name):
|
def load_vae(self, vae_name, device="default"):
|
||||||
metadata = None
|
metadata = None
|
||||||
if vae_name == "pixel_space":
|
if vae_name == "pixel_space":
|
||||||
sd = {}
|
sd = {}
|
||||||
@ -827,7 +901,8 @@ class VAELoader:
|
|||||||
else:
|
else:
|
||||||
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
|
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
|
||||||
sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True)
|
sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True)
|
||||||
vae = comfy.sd.VAE(sd=sd, metadata=metadata)
|
resolved = comfy.model_management.resolve_gpu_device_option(device)
|
||||||
|
vae = comfy.sd.VAE(sd=sd, metadata=metadata, device=resolved)
|
||||||
vae.throw_exception_if_invalid()
|
vae.throw_exception_if_invalid()
|
||||||
return (vae,)
|
return (vae,)
|
||||||
|
|
||||||
@ -953,13 +1028,20 @@ class UNETLoader:
|
|||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
|
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
|
||||||
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"], {"advanced": True})
|
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"], {"advanced": True})
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}),
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("MODEL",)
|
RETURN_TYPES = ("MODEL",)
|
||||||
FUNCTION = "load_unet"
|
FUNCTION = "load_unet"
|
||||||
|
|
||||||
CATEGORY = "advanced/loaders"
|
CATEGORY = "advanced/loaders"
|
||||||
|
|
||||||
def load_unet(self, unet_name, weight_dtype):
|
@classmethod
|
||||||
|
def VALIDATE_INPUTS(cls, device="default"):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def load_unet(self, unet_name, weight_dtype, device="default"):
|
||||||
model_options = {}
|
model_options = {}
|
||||||
if weight_dtype == "fp8_e4m3fn":
|
if weight_dtype == "fp8_e4m3fn":
|
||||||
model_options["dtype"] = torch.float8_e4m3fn
|
model_options["dtype"] = torch.float8_e4m3fn
|
||||||
@ -969,6 +1051,13 @@ class UNETLoader:
|
|||||||
elif weight_dtype == "fp8_e5m2":
|
elif weight_dtype == "fp8_e5m2":
|
||||||
model_options["dtype"] = torch.float8_e5m2
|
model_options["dtype"] = torch.float8_e5m2
|
||||||
|
|
||||||
|
resolved = comfy.model_management.resolve_gpu_device_option(device)
|
||||||
|
if resolved is not None:
|
||||||
|
if resolved.type == "cpu":
|
||||||
|
model_options["load_device"] = model_options["offload_device"] = resolved
|
||||||
|
else:
|
||||||
|
model_options["load_device"] = resolved
|
||||||
|
|
||||||
unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)
|
unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)
|
||||||
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
|
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
|
||||||
return (model,)
|
return (model,)
|
||||||
@ -980,7 +1069,7 @@ class CLIPLoader:
|
|||||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image"], ),
|
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image"], ),
|
||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
"device": (["default", "cpu"], {"advanced": True}),
|
"device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}),
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("CLIP",)
|
RETURN_TYPES = ("CLIP",)
|
||||||
FUNCTION = "load_clip"
|
FUNCTION = "load_clip"
|
||||||
@ -989,12 +1078,20 @@ class CLIPLoader:
|
|||||||
|
|
||||||
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B"
|
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def VALIDATE_INPUTS(cls, device="default"):
|
||||||
|
return True
|
||||||
|
|
||||||
def load_clip(self, clip_name, type="stable_diffusion", device="default"):
|
def load_clip(self, clip_name, type="stable_diffusion", device="default"):
|
||||||
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
|
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
|
||||||
|
|
||||||
model_options = {}
|
model_options = {}
|
||||||
if device == "cpu":
|
resolved = comfy.model_management.resolve_gpu_device_option(device)
|
||||||
model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
|
if resolved is not None:
|
||||||
|
if resolved.type == "cpu":
|
||||||
|
model_options["load_device"] = model_options["offload_device"] = resolved
|
||||||
|
else:
|
||||||
|
model_options["load_device"] = resolved
|
||||||
|
|
||||||
clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name)
|
clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name)
|
||||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
|
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
|
||||||
@ -1008,7 +1105,7 @@ class DualCLIPLoader:
|
|||||||
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "ltxv", "newbie", "ace"], ),
|
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "ltxv", "newbie", "ace"], ),
|
||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
"device": (["default", "cpu"], {"advanced": True}),
|
"device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}),
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("CLIP",)
|
RETURN_TYPES = ("CLIP",)
|
||||||
FUNCTION = "load_clip"
|
FUNCTION = "load_clip"
|
||||||
@ -1017,6 +1114,10 @@ class DualCLIPLoader:
|
|||||||
|
|
||||||
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small\nnewbie: gemma-3-4b-it, jina clip v2"
|
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small\nnewbie: gemma-3-4b-it, jina clip v2"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def VALIDATE_INPUTS(cls, device="default"):
|
||||||
|
return True
|
||||||
|
|
||||||
def load_clip(self, clip_name1, clip_name2, type, device="default"):
|
def load_clip(self, clip_name1, clip_name2, type, device="default"):
|
||||||
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
|
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
|
||||||
|
|
||||||
@ -1024,8 +1125,12 @@ class DualCLIPLoader:
|
|||||||
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
|
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
|
||||||
|
|
||||||
model_options = {}
|
model_options = {}
|
||||||
if device == "cpu":
|
resolved = comfy.model_management.resolve_gpu_device_option(device)
|
||||||
model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
|
if resolved is not None:
|
||||||
|
if resolved.type == "cpu":
|
||||||
|
model_options["load_device"] = model_options["offload_device"] = resolved
|
||||||
|
else:
|
||||||
|
model_options["load_device"] = resolved
|
||||||
|
|
||||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
|
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
|
||||||
return (clip,)
|
return (clip,)
|
||||||
@ -2098,6 +2203,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"InpaintModelConditioning": InpaintModelConditioning,
|
"InpaintModelConditioning": InpaintModelConditioning,
|
||||||
|
|
||||||
"CheckpointLoader": CheckpointLoader,
|
"CheckpointLoader": CheckpointLoader,
|
||||||
|
"CheckpointLoaderDevice": CheckpointLoaderDevice,
|
||||||
"DiffusersLoader": DiffusersLoader,
|
"DiffusersLoader": DiffusersLoader,
|
||||||
|
|
||||||
"LoadLatent": LoadLatent,
|
"LoadLatent": LoadLatent,
|
||||||
@ -2115,6 +2221,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
# Loaders
|
# Loaders
|
||||||
"CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)",
|
"CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)",
|
||||||
"CheckpointLoaderSimple": "Load Checkpoint",
|
"CheckpointLoaderSimple": "Load Checkpoint",
|
||||||
|
"CheckpointLoaderDevice": "Load Checkpoint (Device)",
|
||||||
"VAELoader": "Load VAE",
|
"VAELoader": "Load VAE",
|
||||||
"LoraLoader": "Load LoRA (Model and CLIP)",
|
"LoraLoader": "Load LoRA (Model and CLIP)",
|
||||||
"LoraLoaderModelOnly": "Load LoRA",
|
"LoraLoaderModelOnly": "Load LoRA",
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user