mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-08 16:22:38 +08:00
Add device selection on Image Only Load Checkpoint (CORE-158) (#13748)
* Add device selection on Image Only Load Checkpoint * Rename variables * Update variable name * Fix linting
This commit is contained in:
parent
1b96430c60
commit
a61e2bbb85
@ -23,6 +23,69 @@ class ImageOnlyCheckpointLoader:
|
|||||||
return (out[0], out[3], out[2])
|
return (out[0], out[3], out[2])
|
||||||
|
|
||||||
|
|
||||||
|
class ImageOnlyCheckpointLoaderDevice:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
device_options = comfy.model_management.get_gpu_device_options()
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"model_device": (device_options, {"advanced": True, "tooltip": "Device for the diffusion model (UNET)."}),
|
||||||
|
"clip_vision_device": (device_options, {"advanced": True, "tooltip": "Device for the CLIP vision encoder."}),
|
||||||
|
"vae_device": (device_options, {"advanced": True, "tooltip": "Device for the VAE."}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
|
||||||
|
FUNCTION = "load_checkpoint"
|
||||||
|
|
||||||
|
CATEGORY = "loaders/video_models"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def VALIDATE_INPUTS(cls, model_device="default", clip_vision_device="default", vae_device="default"):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True, model_device="default", clip_vision_device="default", vae_device="default"):
|
||||||
|
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||||
|
|
||||||
|
model_options = {}
|
||||||
|
resolved_model = comfy.model_management.resolve_gpu_device_option(model_device)
|
||||||
|
if resolved_model is not None:
|
||||||
|
if resolved_model.type == "cpu":
|
||||||
|
model_options["load_device"] = model_options["offload_device"] = resolved_model
|
||||||
|
else:
|
||||||
|
model_options["load_device"] = resolved_model
|
||||||
|
|
||||||
|
cv_model_options = {}
|
||||||
|
resolved_clip = comfy.model_management.resolve_gpu_device_option(clip_vision_device)
|
||||||
|
if resolved_clip is not None:
|
||||||
|
if resolved_clip.type == "cpu":
|
||||||
|
cv_model_options["load_device"] = cv_model_options["offload_device"] = resolved_clip
|
||||||
|
else:
|
||||||
|
cv_model_options["load_device"] = resolved_clip
|
||||||
|
|
||||||
|
# VAE device is passed via model_options["load_device"] which
|
||||||
|
# load_state_dict_guess_config forwards to the VAE constructor.
|
||||||
|
# If vae_device differs from model_device, we override after loading.
|
||||||
|
resolved_vae = comfy.model_management.resolve_gpu_device_option(vae_device)
|
||||||
|
|
||||||
|
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
|
model_patcher, clip, vae, clip_vision = out[:4]
|
||||||
|
|
||||||
|
# Apply VAE device override if it differs from the model device
|
||||||
|
if resolved_vae is not None and vae is not None:
|
||||||
|
vae.device = resolved_vae
|
||||||
|
if resolved_vae.type == "cpu":
|
||||||
|
offload = resolved_vae
|
||||||
|
else:
|
||||||
|
offload = comfy.model_management.vae_offload_device()
|
||||||
|
vae.patcher.load_device = resolved_vae
|
||||||
|
vae.patcher.offload_device = offload
|
||||||
|
|
||||||
|
return (model_patcher, clip_vision, vae)
|
||||||
|
|
||||||
|
|
||||||
class SVD_img2vid_Conditioning:
|
class SVD_img2vid_Conditioning:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -149,6 +212,7 @@ class ConditioningSetAreaPercentageVideo:
|
|||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
|
"ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
|
||||||
|
"ImageOnlyCheckpointLoaderDevice": ImageOnlyCheckpointLoaderDevice,
|
||||||
"SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
|
"SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
|
||||||
"VideoLinearCFGGuidance": VideoLinearCFGGuidance,
|
"VideoLinearCFGGuidance": VideoLinearCFGGuidance,
|
||||||
"VideoTriangleCFGGuidance": VideoTriangleCFGGuidance,
|
"VideoTriangleCFGGuidance": VideoTriangleCFGGuidance,
|
||||||
@ -158,4 +222,5 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
"ImageOnlyCheckpointLoader": "Image Only Checkpoint Loader (img2vid model)",
|
"ImageOnlyCheckpointLoader": "Image Only Checkpoint Loader (img2vid model)",
|
||||||
|
"ImageOnlyCheckpointLoaderDevice": "Image Only Checkpoint Loader (Device)",
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user