diff --git a/comfy_extras/nodes_video_model.py b/comfy_extras/nodes_video_model.py
index bf98e6b82..a3b148d7d 100644
--- a/comfy_extras/nodes_video_model.py
+++ b/comfy_extras/nodes_video_model.py
@@ -23,6 +23,74 @@ class ImageOnlyCheckpointLoader:
         return (out[0], out[3], out[2])
 
 
+class ImageOnlyCheckpointLoaderDevice:
+    @classmethod
+    def INPUT_TYPES(s):
+        device_options = comfy.model_management.get_gpu_device_options()
+        return {
+            "required": {
+                "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
+            },
+            "optional": {
+                "model_device": (device_options, {"advanced": True, "tooltip": "Device for the diffusion model (UNET)."}),
+                "clip_vision_device": (device_options, {"advanced": True, "tooltip": "Device for the CLIP vision encoder."}),
+                "vae_device": (device_options, {"advanced": True, "tooltip": "Device for the VAE."}),
+            }
+        }
+    RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
+    FUNCTION = "load_checkpoint"
+
+    CATEGORY = "loaders/video_models"
+
+    @classmethod
+    def VALIDATE_INPUTS(cls, model_device="default", clip_vision_device="default", vae_device="default"):
+        return True
+
+    def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True, model_device="default", clip_vision_device="default", vae_device="default"):
+        ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
+
+        # Build UNET load options; selecting CPU also pins the offload device
+        # so the model is never migrated back onto a GPU.
+        model_options = {}
+        resolved_model = comfy.model_management.resolve_gpu_device_option(model_device)
+        if resolved_model is not None:
+            if resolved_model.type == "cpu":
+                model_options["load_device"] = model_options["offload_device"] = resolved_model
+            else:
+                model_options["load_device"] = resolved_model
+
+        resolved_clip = comfy.model_management.resolve_gpu_device_option(clip_vision_device)
+        resolved_vae = comfy.model_management.resolve_gpu_device_option(vae_device)
+
+        # model_options must be forwarded here; without it the UNET device
+        # selection above would be silently ignored.
+        out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options=model_options)
+        model_patcher, clip, vae, clip_vision = out[:4]
+
+        # This call exposes no per-component device arguments for the CLIP
+        # vision encoder or the VAE, so their devices are overridden after
+        # loading (NOTE(review): confirm patcher attribute names upstream).
+        if resolved_clip is not None and clip_vision is not None:
+            if resolved_clip.type == "cpu":
+                cv_offload = resolved_clip
+            else:
+                cv_offload = comfy.model_management.text_encoder_offload_device()
+            clip_vision.load_device = resolved_clip
+            clip_vision.patcher.load_device = resolved_clip
+            clip_vision.patcher.offload_device = cv_offload
+
+        if resolved_vae is not None and vae is not None:
+            vae.device = resolved_vae
+            if resolved_vae.type == "cpu":
+                vae_offload = resolved_vae
+            else:
+                vae_offload = comfy.model_management.vae_offload_device()
+            vae.patcher.load_device = resolved_vae
+            vae.patcher.offload_device = vae_offload
+
+        return (model_patcher, clip_vision, vae)
+
+
 class SVD_img2vid_Conditioning:
     @classmethod
     def INPUT_TYPES(s):
@@ -149,6 +217,7 @@ class ConditioningSetAreaPercentageVideo:
 
 NODE_CLASS_MAPPINGS = {
     "ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
+    "ImageOnlyCheckpointLoaderDevice": ImageOnlyCheckpointLoaderDevice,
     "SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
     "VideoLinearCFGGuidance": VideoLinearCFGGuidance,
     "VideoTriangleCFGGuidance": VideoTriangleCFGGuidance,
@@ -158,4 +227,5 @@ NODE_CLASS_MAPPINGS = {
 
 NODE_DISPLAY_NAME_MAPPINGS = {
     "ImageOnlyCheckpointLoader": "Image Only Checkpoint Loader (img2vid model)",
+    "ImageOnlyCheckpointLoaderDevice": "Image Only Checkpoint Loader (Device)",
 }