Restore CheckpointLoaderSimple, add CheckpointLoaderDevice
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled

Revert CheckpointLoaderSimple to its original form (no device input)
so it remains the simple default loader.

Add new CheckpointLoaderDevice node (advanced/loaders) with separate
model_device, clip_device, and vae_device inputs for per-component
GPU placement in multi-GPU setups.

Amp-Thread-ID: https://ampcode.com/threads/T-019dabdc-8feb-766f-b4dc-f46ef4d8ff57
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Jedrzej Kosinski 2026-04-21 11:12:35 -07:00
parent 767b4ee099
commit 3931ef6a1f

View File

@ -591,9 +591,6 @@ class CheckpointLoaderSimple:
return { return {
"required": { "required": {
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}), "ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
},
"optional": {
"device": (comfy.model_management.get_gpu_device_options(), {"advanced": True, "tooltip": "Target GPU device. Use 'default' for automatic, 'cpu' for CPU, or 'gpu:N' for a specific GPU."}),
} }
} }
RETURN_TYPES = ("MODEL", "CLIP", "VAE") RETURN_TYPES = ("MODEL", "CLIP", "VAE")
@ -606,24 +603,77 @@ class CheckpointLoaderSimple:
DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents." DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents."
SEARCH_ALIASES = ["load model", "checkpoint", "model loader", "load checkpoint", "ckpt", "model"] SEARCH_ALIASES = ["load model", "checkpoint", "model loader", "load checkpoint", "ckpt", "model"]
def load_checkpoint(self, ckpt_name):
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return out[:3]
class CheckpointLoaderDevice:
@classmethod @classmethod
def VALIDATE_INPUTS(cls, device="default"): def INPUT_TYPES(s):
device_options = comfy.model_management.get_gpu_device_options()
return {
"required": {
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
},
"optional": {
"model_device": (device_options, {"advanced": True, "tooltip": "Device for the diffusion model (UNET)."}),
"clip_device": (device_options, {"advanced": True, "tooltip": "Device for the CLIP text encoder."}),
"vae_device": (device_options, {"advanced": True, "tooltip": "Device for the VAE."}),
}
}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
OUTPUT_TOOLTIPS = ("The model used for denoising latents.",
"The CLIP model used for encoding text prompts.",
"The VAE model used for encoding and decoding images to and from latent space.")
FUNCTION = "load_checkpoint"
CATEGORY = "advanced/loaders"
DESCRIPTION = "Loads a diffusion model checkpoint with per-component device selection for multi-GPU setups."
@classmethod
def VALIDATE_INPUTS(cls, model_device="default", clip_device="default", vae_device="default"):
return True return True
def load_checkpoint(self, ckpt_name, device="default"): def load_checkpoint(self, ckpt_name, model_device="default", clip_device="default", vae_device="default"):
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
model_options = {} model_options = {}
te_model_options = {} resolved_model = comfy.model_management.resolve_gpu_device_option(model_device)
resolved = comfy.model_management.resolve_gpu_device_option(device) if resolved_model is not None:
if resolved is not None: if resolved_model.type == "cpu":
if resolved.type == "cpu": model_options["load_device"] = model_options["offload_device"] = resolved_model
model_options["load_device"] = model_options["offload_device"] = resolved
te_model_options["load_device"] = te_model_options["offload_device"] = resolved
else: else:
model_options["load_device"] = resolved model_options["load_device"] = resolved_model
te_model_options["load_device"] = resolved
te_model_options = {}
resolved_clip = comfy.model_management.resolve_gpu_device_option(clip_device)
if resolved_clip is not None:
if resolved_clip.type == "cpu":
te_model_options["load_device"] = te_model_options["offload_device"] = resolved_clip
else:
te_model_options["load_device"] = resolved_clip
# VAE device is passed via model_options["load_device"] which
# load_state_dict_guess_config forwards to the VAE constructor.
# If vae_device differs from model_device, we override after loading.
resolved_vae = comfy.model_management.resolve_gpu_device_option(vae_device)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options=model_options, te_model_options=te_model_options) out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options=model_options, te_model_options=te_model_options)
return out[:3] model_patcher, clip, vae = out[:3]
# Apply VAE device override if it differs from the model device
if resolved_vae is not None and vae is not None:
vae.device = resolved_vae
if resolved_vae.type == "cpu":
offload = resolved_vae
else:
offload = comfy.model_management.vae_offload_device()
vae.patcher.load_device = resolved_vae
vae.patcher.offload_device = offload
return (model_patcher, clip, vae)
class DiffusersLoader: class DiffusersLoader:
SEARCH_ALIASES = ["load diffusers model"] SEARCH_ALIASES = ["load diffusers model"]
@ -2153,6 +2203,7 @@ NODE_CLASS_MAPPINGS = {
"InpaintModelConditioning": InpaintModelConditioning, "InpaintModelConditioning": InpaintModelConditioning,
"CheckpointLoader": CheckpointLoader, "CheckpointLoader": CheckpointLoader,
"CheckpointLoaderDevice": CheckpointLoaderDevice,
"DiffusersLoader": DiffusersLoader, "DiffusersLoader": DiffusersLoader,
"LoadLatent": LoadLatent, "LoadLatent": LoadLatent,
@ -2170,6 +2221,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
# Loaders # Loaders
"CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)", "CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)",
"CheckpointLoaderSimple": "Load Checkpoint", "CheckpointLoaderSimple": "Load Checkpoint",
"CheckpointLoaderDevice": "Load Checkpoint (Device)",
"VAELoader": "Load VAE", "VAELoader": "Load VAE",
"LoraLoader": "Load LoRA (Model and CLIP)", "LoraLoader": "Load LoRA (Model and CLIP)",
"LoraLoaderModelOnly": "Load LoRA", "LoraLoaderModelOnly": "Load LoRA",