Add GPU device selection to all loader nodes

- Add get_gpu_device_options() and resolve_gpu_device_option() helpers
  in model_management.py for vendor-agnostic GPU device selection
- Add device widget to CheckpointLoaderSimple, UNETLoader, VAELoader
- Expand device options in CLIPLoader, DualCLIPLoader, LTXAVTextEncoderLoader
  from [default, cpu] to include gpu:0, gpu:1, etc. on multi-GPU systems
- Wire load_diffusion_model_state_dict and load_state_dict_guess_config
  to respect model_options['load_device']
- Graceful fallback: unrecognized or unavailable devices (e.g. gpu:1 on
  a single-GPU system) fall back to default with a logged warning

Amp-Thread-ID: https://ampcode.com/threads/T-019daa41-f394-731a-8955-4cff4f16283a
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Jedrzej Kosinski 2026-04-20 09:12:39 -07:00
parent bae191c294
commit bf0e7bd246
4 changed files with 87 additions and 17 deletions

View File

@ -231,6 +231,46 @@ def get_all_torch_devices(exclude_current=False):
devices.remove(get_torch_device())
return devices
def get_gpu_device_options():
    """Return the device option strings offered by node widgets.

    The list always starts with "default" and "cpu". When more than one
    torch device is detected, vendor-agnostic labels "gpu:0", "gpu:1", ...
    are appended, one per device.
    """
    labels = ["default", "cpu"]
    devices = get_all_torch_devices()
    if len(devices) > 1:
        labels += [f"gpu:{index}" for index in range(len(devices))]
    return labels
def resolve_gpu_device_option(option: str):
    """Resolve a device option string to a torch.device.

    Args:
        option: One of "default", "cpu", "gpu:N", or None.

    Returns:
        None for None/"default" (caller should use its normal default
        device), torch.device("cpu") for "cpu", or the Nth device from
        get_all_torch_devices() for "gpu:N". Any invalid, unavailable,
        or unrecognized option falls back to None with a logged warning.
    """
    if option is None or option == "default":
        return None
    if option == "cpu":
        return torch.device("cpu")
    if option.startswith("gpu:"):
        try:
            idx = int(option[4:])
        except ValueError:
            logging.warning(f"Invalid device option '{option}', using default.")
            return None
        # Reject negative indices explicitly: Python's negative indexing
        # would otherwise silently resolve e.g. "gpu:-1" to the last device.
        if idx < 0:
            logging.warning(f"Invalid device option '{option}', using default.")
            return None
        devices = get_all_torch_devices()
        if idx < len(devices):
            return devices[idx]
        logging.warning(f"Device '{option}' not available (only {len(devices)} GPU(s)), using default.")
        return None
    logging.warning(f"Unrecognized device option '{option}', using default.")
    return None
def get_total_memory(dev=None, torch_total_too=False):
global directml_enabled
if dev is None:

View File

@ -1633,7 +1633,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
load_device = model_management.get_torch_device()
load_device = model_options.get("load_device", model_management.get_torch_device())
custom_operations = model_options.get("custom_operations", None)
if custom_operations is None:
@ -1763,7 +1763,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
parameters = comfy.utils.calculate_parameters(sd)
weight_dtype = comfy.utils.weight_dtype(sd)
load_device = model_management.get_torch_device()
load_device = model_options.get("load_device", model_management.get_torch_device())
model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
if model_config is not None:

View File

@ -188,7 +188,7 @@ class LTXAVTextEncoderLoader(io.ComfyNode):
),
io.Combo.Input(
"device",
options=["default", "cpu"],
options=comfy.model_management.get_gpu_device_options(),
advanced=True,
)
],
@ -203,8 +203,12 @@ class LTXAVTextEncoderLoader(io.ComfyNode):
clip_path2 = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
model_options = {}
if device == "cpu":
model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
resolved = comfy.model_management.resolve_gpu_device_option(device)
if resolved is not None:
if resolved.type == "cpu":
model_options["load_device"] = model_options["offload_device"] = resolved
else:
model_options["load_device"] = resolved
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
return io.NodeOutput(clip)

View File

@ -591,6 +591,9 @@ class CheckpointLoaderSimple:
return {
"required": {
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
},
"optional": {
"device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}),
}
}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
@ -603,9 +606,13 @@ class CheckpointLoaderSimple:
DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents."
SEARCH_ALIASES = ["load model", "checkpoint", "model loader", "load checkpoint", "ckpt", "model"]
def load_checkpoint(self, ckpt_name):
def load_checkpoint(self, ckpt_name, device="default"):
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
model_options = {}
resolved = comfy.model_management.resolve_gpu_device_option(device)
if resolved is not None:
model_options["load_device"] = resolved
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options=model_options)
return out[:3]
class DiffusersLoader:
@ -807,14 +814,17 @@ class VAELoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "vae_name": (s.vae_list(s), )}}
return {"required": { "vae_name": (s.vae_list(s), )},
"optional": {
"device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}),
}}
RETURN_TYPES = ("VAE",)
FUNCTION = "load_vae"
CATEGORY = "loaders"
#TODO: scale factor?
def load_vae(self, vae_name):
def load_vae(self, vae_name, device="default"):
metadata = None
if vae_name == "pixel_space":
sd = {}
@ -827,7 +837,8 @@ class VAELoader:
else:
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True)
vae = comfy.sd.VAE(sd=sd, metadata=metadata)
resolved = comfy.model_management.resolve_gpu_device_option(device)
vae = comfy.sd.VAE(sd=sd, metadata=metadata, device=resolved)
vae.throw_exception_if_invalid()
return (vae,)
@ -953,13 +964,16 @@ class UNETLoader:
def INPUT_TYPES(s):
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"], {"advanced": True})
},
"optional": {
"device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_unet"
CATEGORY = "advanced/loaders"
def load_unet(self, unet_name, weight_dtype):
def load_unet(self, unet_name, weight_dtype, device="default"):
model_options = {}
if weight_dtype == "fp8_e4m3fn":
model_options["dtype"] = torch.float8_e4m3fn
@ -969,6 +983,10 @@ class UNETLoader:
elif weight_dtype == "fp8_e5m2":
model_options["dtype"] = torch.float8_e5m2
resolved = comfy.model_management.resolve_gpu_device_option(device)
if resolved is not None:
model_options["load_device"] = resolved
unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
return (model,)
@ -980,7 +998,7 @@ class CLIPLoader:
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
"device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
@ -993,8 +1011,12 @@ class CLIPLoader:
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
model_options = {}
if device == "cpu":
model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
resolved = comfy.model_management.resolve_gpu_device_option(device)
if resolved is not None:
if resolved.type == "cpu":
model_options["load_device"] = model_options["offload_device"] = resolved
else:
model_options["load_device"] = resolved
clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
@ -1008,7 +1030,7 @@ class DualCLIPLoader:
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "ltxv", "newbie", "ace"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
"device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
@ -1024,8 +1046,12 @@ class DualCLIPLoader:
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
model_options = {}
if device == "cpu":
model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
resolved = comfy.model_management.resolve_gpu_device_option(device)
if resolved is not None:
if resolved.type == "cpu":
model_options["load_device"] = model_options["offload_device"] = resolved
else:
model_options["load_device"] = resolved
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
return (clip,)