diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 988acdb57..6cb69dcdf 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -236,11 +236,11 @@ class ControlNet(ControlBase): self.cond_hint = None compression_ratio = self.compression_ratio if self.vae is not None: - compression_ratio *= self.vae.downscale_ratio + compression_ratio *= self.vae.spacial_compression_encode() else: if self.latent_format is not None: raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.") - self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center") + self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[-1] * compression_ratio, x_noisy.shape[-2] * compression_ratio, self.upscale_algorithm, "center") self.cond_hint = self.preprocess_image(self.cond_hint) if self.vae is not None: loaded_models = comfy.model_management.loaded_models(only_currently_used=True) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 2503583cb..d0e39833a 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -293,6 +293,7 @@ class QwenImageTransformer2DModel(nn.Module): guidance_embeds: bool = False, axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), image_model=None, + final_layer=True, dtype=None, device=None, operations=None, @@ -300,6 +301,7 @@ class QwenImageTransformer2DModel(nn.Module): super().__init__() self.dtype = dtype self.patch_size = patch_size + self.in_channels = in_channels self.out_channels = out_channels or in_channels self.inner_dim = num_attention_heads * attention_head_dim @@ -329,9 +331,9 @@ class QwenImageTransformer2DModel(nn.Module): for _ in range(num_layers) ]) - self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations) - self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device) - self.gradient_checkpointing = False + if final_layer: + self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations) + self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device) def process_img(self, x, index=0, h_offset=0, w_offset=0): bs, c, t, h, w = x.shape @@ -362,6 +364,7 @@ class QwenImageTransformer2DModel(nn.Module): guidance: torch.Tensor = None, ref_latents=None, transformer_options={}, + control=None, **kwargs ): timestep = timesteps @@ -443,6 +446,13 @@ class QwenImageTransformer2DModel(nn.Module): hidden_states = out["img"] encoder_hidden_states = out["txt"] + if control is not None: # Controlnet + control_i = control.get("input") + if i < len(control_i): + add = control_i[i] + if add is not None: + hidden_states += add + hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 2bec0541e..0caff53e0 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -492,6 +492,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image dit_config = {} dit_config["image_model"] = "qwen_image" + dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1] + dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.') return dit_config if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py index db24e6da4..d28895f3e 100644 --- a/comfy_api_nodes/nodes_ideogram.py +++ b/comfy_api_nodes/nodes_ideogram.py @@ -1,8 +1,8 @@ -from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict -from inspect import cleandoc +from io import BytesIO +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io as comfy_io from PIL import Image import numpy as np -import io import torch from comfy_api_nodes.apis import ( IdeogramGenerateRequest, @@ -246,90 +246,81 @@ def display_image_urls_on_node(image_urls, node_id): PromptServer.instance.send_progress_text(urls_text, node_id) -class IdeogramV1(ComfyNodeABC): - """ - Generates images using the Ideogram V1 model. - """ - - def __init__(self): - pass +class IdeogramV1(comfy_io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation", - }, + def define_schema(cls): + return comfy_io.Schema( + node_id="IdeogramV1", + display_name="Ideogram V1", + category="api node/image/Ideogram", + description="Generates images using the Ideogram V1 model.", + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation", ), - "turbo": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to use turbo mode (faster generation, potentially lower quality)", - } + comfy_io.Boolean.Input( + "turbo", + default=False, + tooltip="Whether to use turbo mode (faster generation, potentially lower quality)", ), - }, - "optional": { - "aspect_ratio": ( - IO.COMBO, - { - "options": list(V1_V2_RATIO_MAP.keys()), - "default": "1:1", - "tooltip": "The aspect ratio for image generation.", - }, + comfy_io.Combo.Input( + "aspect_ratio", + options=list(V1_V2_RATIO_MAP.keys()), + default="1:1", + tooltip="The aspect ratio for image generation.", + optional=True, ), - "magic_prompt_option": ( - IO.COMBO, - { - "options": ["AUTO", "ON", "OFF"], - "default": "AUTO", - "tooltip": "Determine if MagicPrompt should be used in generation", - }, + comfy_io.Combo.Input( + "magic_prompt_option", + options=["AUTO", "ON", "OFF"], + default="AUTO", + tooltip="Determine if MagicPrompt should be used in generation", + optional=True, ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2147483647, - "step": 1, - "control_after_generate": True, - "display": "number", - }, + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + control_after_generate=True, + display_mode=comfy_io.NumberDisplay.number, + optional=True, ), - "negative_prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Description of what to exclude from the image", - }, + comfy_io.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Description of what to exclude from the image", + optional=True, ), - "num_images": ( - IO.INT, - {"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"}, + comfy_io.Int.Input( + "num_images", + default=1, + min=1, + max=8, + step=1, + display_mode=comfy_io.NumberDisplay.number, + optional=True, ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + ) - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "api_call" - CATEGORY = "api node/image/Ideogram" - DESCRIPTION = cleandoc(__doc__ or "") - API_NODE = True - - async def api_call( - self, + @classmethod + async def execute( + cls, prompt, turbo=False, aspect_ratio="1:1", @@ -337,13 +328,15 @@ class IdeogramV1(ComfyNodeABC): seed=0, negative_prompt="", num_images=1, - unique_id=None, - **kwargs, ): # Determine the model based on turbo setting aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None) model = "V_1_TURBO" if turbo else "V_1" + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } operation = SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/ideogram/generate", @@ -364,7 +357,7 @@ class IdeogramV1(ComfyNodeABC): negative_prompt=negative_prompt if negative_prompt else None, ) ), - auth_kwargs=kwargs, + auth_kwargs=auth, ) response = await operation.execute() @@ -377,93 +370,85 @@ class IdeogramV1(ComfyNodeABC): if not image_urls: raise Exception("No image URLs were generated in the response") - display_image_urls_on_node(image_urls, unique_id) - return (await download_and_process_images(image_urls),) + display_image_urls_on_node(image_urls, cls.hidden.unique_id) + return comfy_io.NodeOutput(await download_and_process_images(image_urls)) -class IdeogramV2(ComfyNodeABC): - """ - Generates images using the Ideogram V2 model. - """ - - def __init__(self): - pass +class IdeogramV2(comfy_io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation", - }, + def define_schema(cls): + return comfy_io.Schema( + node_id="IdeogramV2", + display_name="Ideogram V2", + category="api node/image/Ideogram", + description="Generates images using the Ideogram V2 model.", + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation", ), - "turbo": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to use turbo mode (faster generation, potentially lower quality)", - } + comfy_io.Boolean.Input( + "turbo", + default=False, + tooltip="Whether to use turbo mode (faster generation, potentially lower quality)", ), - }, - "optional": { - "aspect_ratio": ( - IO.COMBO, - { - "options": list(V1_V2_RATIO_MAP.keys()), - "default": "1:1", - "tooltip": "The aspect ratio for image generation. Ignored if resolution is not set to AUTO.", - }, + comfy_io.Combo.Input( + "aspect_ratio", + options=list(V1_V2_RATIO_MAP.keys()), + default="1:1", + tooltip="The aspect ratio for image generation. Ignored if resolution is not set to AUTO.", + optional=True, ), - "resolution": ( - IO.COMBO, - { - "options": list(V1_V1_RES_MAP.keys()), - "default": "Auto", - "tooltip": "The resolution for image generation. If not set to AUTO, this overrides the aspect_ratio setting.", - }, + comfy_io.Combo.Input( + "resolution", + options=list(V1_V1_RES_MAP.keys()), + default="Auto", + tooltip="The resolution for image generation. " + "If not set to AUTO, this overrides the aspect_ratio setting.", + optional=True, ), - "magic_prompt_option": ( - IO.COMBO, - { - "options": ["AUTO", "ON", "OFF"], - "default": "AUTO", - "tooltip": "Determine if MagicPrompt should be used in generation", - }, + comfy_io.Combo.Input( + "magic_prompt_option", + options=["AUTO", "ON", "OFF"], + default="AUTO", + tooltip="Determine if MagicPrompt should be used in generation", + optional=True, ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2147483647, - "step": 1, - "control_after_generate": True, - "display": "number", - }, + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + control_after_generate=True, + display_mode=comfy_io.NumberDisplay.number, + optional=True, ), - "style_type": ( - IO.COMBO, - { - "options": ["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"], - "default": "NONE", - "tooltip": "Style type for generation (V2 only)", - }, + comfy_io.Combo.Input( + "style_type", + options=["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"], + default="NONE", + tooltip="Style type for generation (V2 only)", + optional=True, ), - "negative_prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Description of what to exclude from the image", - }, + comfy_io.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Description of what to exclude from the image", + optional=True, ), - "num_images": ( - IO.INT, - {"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"}, + comfy_io.Int.Input( + "num_images", + default=1, + min=1, + max=8, + step=1, + display_mode=comfy_io.NumberDisplay.number, + optional=True, ), #"color_palette": ( # IO.STRING, @@ -473,22 +458,20 @@ class IdeogramV2(ComfyNodeABC): # "tooltip": "Color palette preset name or hex colors with weights", # }, #), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + ) - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "api_call" - CATEGORY = "api node/image/Ideogram" - DESCRIPTION = cleandoc(__doc__ or "") - API_NODE = True - - async def api_call( - self, + @classmethod + async def execute( + cls, prompt, turbo=False, aspect_ratio="1:1", @@ -499,8 +482,6 @@ class IdeogramV2(ComfyNodeABC): negative_prompt="", num_images=1, color_palette="", - unique_id=None, - **kwargs, ): aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None) resolution = V1_V1_RES_MAP.get(resolution, None) @@ -517,6 +498,10 @@ class IdeogramV2(ComfyNodeABC): else: final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } operation = SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/ideogram/generate", @@ -540,7 +525,7 @@ class IdeogramV2(ComfyNodeABC): color_palette=color_palette if color_palette else None, ) ), - auth_kwargs=kwargs, + auth_kwargs=auth, ) response = await operation.execute() @@ -553,108 +538,99 @@ class IdeogramV2(ComfyNodeABC): if not image_urls: raise Exception("No image URLs were generated in the response") - display_image_urls_on_node(image_urls, unique_id) - return (await download_and_process_images(image_urls),) + display_image_urls_on_node(image_urls, cls.hidden.unique_id) + return comfy_io.NodeOutput(await download_and_process_images(image_urls)) -class IdeogramV3(ComfyNodeABC): - """ - Generates images using the Ideogram V3 model. Supports both regular image generation from text prompts and image editing with mask. - """ - def __init__(self): - pass +class IdeogramV3(comfy_io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation or editing", - }, + def define_schema(cls): + return comfy_io.Schema( + node_id="IdeogramV3", + display_name="Ideogram V3", + category="api node/image/Ideogram", + description="Generates images using the Ideogram V3 model. " + "Supports both regular image generation from text prompts and image editing with mask.", + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation or editing", ), - }, - "optional": { - "image": ( - IO.IMAGE, - { - "default": None, - "tooltip": "Optional reference image for image editing.", - }, + comfy_io.Image.Input( + "image", + tooltip="Optional reference image for image editing.", + optional=True, ), - "mask": ( - IO.MASK, - { - "default": None, - "tooltip": "Optional mask for inpainting (white areas will be replaced)", - }, + comfy_io.Mask.Input( + "mask", + tooltip="Optional mask for inpainting (white areas will be replaced)", + optional=True, ), - "aspect_ratio": ( - IO.COMBO, - { - "options": list(V3_RATIO_MAP.keys()), - "default": "1:1", - "tooltip": "The aspect ratio for image generation. Ignored if resolution is not set to Auto.", - }, + comfy_io.Combo.Input( + "aspect_ratio", + options=list(V3_RATIO_MAP.keys()), + default="1:1", + tooltip="The aspect ratio for image generation. Ignored if resolution is not set to Auto.", + optional=True, ), - "resolution": ( - IO.COMBO, - { - "options": V3_RESOLUTIONS, - "default": "Auto", - "tooltip": "The resolution for image generation. If not set to Auto, this overrides the aspect_ratio setting.", - }, + comfy_io.Combo.Input( + "resolution", + options=V3_RESOLUTIONS, + default="Auto", + tooltip="The resolution for image generation. " + "If not set to Auto, this overrides the aspect_ratio setting.", + optional=True, ), - "magic_prompt_option": ( - IO.COMBO, - { - "options": ["AUTO", "ON", "OFF"], - "default": "AUTO", - "tooltip": "Determine if MagicPrompt should be used in generation", - }, + comfy_io.Combo.Input( + "magic_prompt_option", + options=["AUTO", "ON", "OFF"], + default="AUTO", + tooltip="Determine if MagicPrompt should be used in generation", + optional=True, ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2147483647, - "step": 1, - "control_after_generate": True, - "display": "number", - }, + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + control_after_generate=True, + display_mode=comfy_io.NumberDisplay.number, + optional=True, ), - "num_images": ( - IO.INT, - {"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"}, + comfy_io.Int.Input( + "num_images", + default=1, + min=1, + max=8, + step=1, + display_mode=comfy_io.NumberDisplay.number, + optional=True, ), - "rendering_speed": ( - IO.COMBO, - { - "options": ["BALANCED", "TURBO", "QUALITY"], - "default": "BALANCED", - "tooltip": "Controls the trade-off between generation speed and quality", - }, + comfy_io.Combo.Input( + "rendering_speed", + options=["BALANCED", "TURBO", "QUALITY"], + default="BALANCED", + tooltip="Controls the trade-off between generation speed and quality", + optional=True, ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + ) - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "api_call" - CATEGORY = "api node/image/Ideogram" - DESCRIPTION = cleandoc(__doc__ or "") - API_NODE = True - - async def api_call( - self, + @classmethod + async def execute( + cls, prompt, image=None, mask=None, @@ -664,9 +640,11 @@ class IdeogramV3(ComfyNodeABC): seed=0, num_images=1, rendering_speed="BALANCED", - unique_id=None, - **kwargs, ): + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } # Check if both image and mask are provided for editing mode if image is not None and mask is not None: # Edit mode @@ -686,7 +664,7 @@ class IdeogramV3(ComfyNodeABC): # Process image img_np = (input_tensor.numpy() * 255).astype(np.uint8) img = Image.fromarray(img_np) - img_byte_arr = io.BytesIO() + img_byte_arr = BytesIO() img.save(img_byte_arr, format="PNG") img_byte_arr.seek(0) img_binary = img_byte_arr @@ -695,7 +673,7 @@ class IdeogramV3(ComfyNodeABC): # Process mask - white areas will be replaced mask_np = (mask.squeeze().cpu().numpy() * 255).astype(np.uint8) mask_img = Image.fromarray(mask_np) - mask_byte_arr = io.BytesIO() + mask_byte_arr = BytesIO() mask_img.save(mask_byte_arr, format="PNG") mask_byte_arr.seek(0) mask_binary = mask_byte_arr @@ -729,7 +707,7 @@ class IdeogramV3(ComfyNodeABC): "mask": mask_binary, }, content_type="multipart/form-data", - auth_kwargs=kwargs, + auth_kwargs=auth, ) elif image is not None or mask is not None: @@ -770,7 +748,7 @@ class IdeogramV3(ComfyNodeABC): response_model=IdeogramGenerateResponse, ), request=gen_request, - auth_kwargs=kwargs, + auth_kwargs=auth, ) # Execute the operation and process response @@ -784,18 +762,18 @@ class IdeogramV3(ComfyNodeABC): if not image_urls: raise Exception("No image URLs were generated in the response") - display_image_urls_on_node(image_urls, unique_id) - return (await download_and_process_images(image_urls),) + display_image_urls_on_node(image_urls, cls.hidden.unique_id) + return comfy_io.NodeOutput(await download_and_process_images(image_urls)) -NODE_CLASS_MAPPINGS = { - "IdeogramV1": IdeogramV1, - "IdeogramV2": IdeogramV2, - "IdeogramV3": IdeogramV3, -} +class IdeogramExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: + return [ + IdeogramV1, + IdeogramV2, + IdeogramV3, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "IdeogramV1": "Ideogram V1", - "IdeogramV2": "Ideogram V2", - "IdeogramV3": "Ideogram V3", -} +async def comfy_entrypoint() -> IdeogramExtension: + return IdeogramExtension() diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index e25dab2f5..251aecd42 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -1,17 +1,18 @@ -import io import logging import base64 import aiohttp import torch +from io import BytesIO from typing import Optional +from typing_extensions import override -from comfy.comfy_types.node_typing import IO, ComfyNodeABC +from comfy_api.latest import ComfyExtension, io as comfy_io from comfy_api.input_impl.video_types import VideoFromFile from comfy_api_nodes.apis import ( VeoGenVidRequest, VeoGenVidResponse, VeoGenVidPollRequest, - VeoGenVidPollResponse + VeoGenVidPollResponse, ) from comfy_api_nodes.apis.client import ( ApiEndpoint, @@ -22,7 +23,7 @@ from comfy_api_nodes.apis.client import ( from comfy_api_nodes.apinode_utils import ( downscale_image_tensor, - tensor_to_base64_string + tensor_to_base64_string, ) AVERAGE_DURATION_VIDEO_GEN = 32 @@ -50,7 +51,7 @@ def get_video_url_from_response(poll_response: VeoGenVidPollResponse) -> Optiona return None -class VeoVideoGenerationNode(ComfyNodeABC): +class VeoVideoGenerationNode(comfy_io.ComfyNode): """ Generates videos from text prompts using Google's Veo API. @@ -59,101 +60,93 @@ class VeoVideoGenerationNode(ComfyNodeABC): """ @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Text description of the video", - }, + def define_schema(cls): + return comfy_io.Schema( + node_id="VeoVideoGenerationNode", + display_name="Google Veo 2 Video Generation", + category="api node/video/Veo", + description="Generates videos from text prompts using Google's Veo 2 API", + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text description of the video", ), - "aspect_ratio": ( - IO.COMBO, - { - "options": ["16:9", "9:16"], - "default": "16:9", - "tooltip": "Aspect ratio of the output video", - }, + comfy_io.Combo.Input( + "aspect_ratio", + options=["16:9", "9:16"], + default="16:9", + tooltip="Aspect ratio of the output video", ), - }, - "optional": { - "negative_prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Negative text prompt to guide what to avoid in the video", - }, + comfy_io.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Negative text prompt to guide what to avoid in the video", + optional=True, ), - "duration_seconds": ( - IO.INT, - { - "default": 5, - "min": 5, - "max": 8, - "step": 1, - "display": "number", - "tooltip": "Duration of the output video in seconds", - }, + comfy_io.Int.Input( + "duration_seconds", + default=5, + min=5, + max=8, + step=1, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Duration of the output video in seconds", + optional=True, ), - "enhance_prompt": ( - IO.BOOLEAN, - { - "default": True, - "tooltip": "Whether to enhance the prompt with AI assistance", - } + comfy_io.Boolean.Input( + "enhance_prompt", + default=True, + tooltip="Whether to enhance the prompt with AI assistance", + optional=True, ), - "person_generation": ( - IO.COMBO, - { - "options": ["ALLOW", "BLOCK"], - "default": "ALLOW", - "tooltip": "Whether to allow generating people in the video", - }, + comfy_io.Combo.Input( + "person_generation", + options=["ALLOW", "BLOCK"], + default="ALLOW", + tooltip="Whether to allow generating people in the video", + optional=True, ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFF, - "step": 1, - "display": "number", - "control_after_generate": True, - "tooltip": "Seed for video generation (0 for random)", - }, + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFF, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed for video generation (0 for random)", + optional=True, ), - "image": (IO.IMAGE, { - "default": None, - "tooltip": "Optional reference image to guide video generation", - }), - "model": ( - IO.COMBO, - { - "options": ["veo-2.0-generate-001"], - "default": "veo-2.0-generate-001", - "tooltip": "Veo 2 model to use for video generation", - }, + comfy_io.Image.Input( + "image", + tooltip="Optional reference image to guide video generation", + optional=True, ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + comfy_io.Combo.Input( + "model", + options=["veo-2.0-generate-001"], + default="veo-2.0-generate-001", + tooltip="Veo 2 model to use for video generation", + optional=True, + ), + ], + outputs=[ + comfy_io.Video.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) - RETURN_TYPES = (IO.VIDEO,) - FUNCTION = "generate_video" - CATEGORY = "api node/video/Veo" - DESCRIPTION = "Generates videos from text prompts using Google's Veo 2 API" - API_NODE = True - - async def generate_video( - self, + @classmethod + async def execute( + cls, prompt, aspect_ratio="16:9", negative_prompt="", @@ -164,8 +157,6 @@ class VeoVideoGenerationNode(ComfyNodeABC): image=None, model="veo-2.0-generate-001", generate_audio=False, - unique_id: Optional[str] = None, - **kwargs, ): # Prepare the instances for the request instances = [] @@ -202,6 +193,10 @@ class VeoVideoGenerationNode(ComfyNodeABC): if "veo-3.0" in model: parameters["generateAudio"] = generate_audio + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } # Initial request to start video generation initial_operation = SynchronousOperation( endpoint=ApiEndpoint( @@ -214,7 +209,7 @@ class VeoVideoGenerationNode(ComfyNodeABC): instances=instances, parameters=parameters ), - auth_kwargs=kwargs, + auth_kwargs=auth, ) initial_response = await initial_operation.execute() @@ -248,10 +243,10 @@ class VeoVideoGenerationNode(ComfyNodeABC): request=VeoGenVidPollRequest( operationName=operation_name ), - auth_kwargs=kwargs, + auth_kwargs=auth, poll_interval=5.0, result_url_extractor=get_video_url_from_response, - node_id=unique_id, + node_id=cls.hidden.unique_id, estimated_duration=AVERAGE_DURATION_VIDEO_GEN, ) @@ -304,10 +299,10 @@ class VeoVideoGenerationNode(ComfyNodeABC): logging.info("Video generation completed successfully") # Convert video data to BytesIO object - video_io = io.BytesIO(video_data) + video_io = BytesIO(video_data) # Return VideoFromFile object - return (VideoFromFile(video_io),) + return comfy_io.NodeOutput(VideoFromFile(video_io)) class Veo3VideoGenerationNode(VeoVideoGenerationNode): @@ -323,51 +318,104 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode): """ @classmethod - def INPUT_TYPES(s): - parent_input = super().INPUT_TYPES() - - # Update model options for Veo 3 - parent_input["optional"]["model"] = ( - IO.COMBO, - { - "options": ["veo-3.0-generate-001", "veo-3.0-fast-generate-001"], - "default": "veo-3.0-generate-001", - "tooltip": "Veo 3 model to use for video generation", - }, + def define_schema(cls): + return comfy_io.Schema( + node_id="Veo3VideoGenerationNode", + display_name="Google Veo 3 Video Generation", + category="api node/video/Veo", + description="Generates videos from text prompts using Google's Veo 3 API", + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text description of the video", + ), + comfy_io.Combo.Input( + "aspect_ratio", + options=["16:9", "9:16"], + default="16:9", + tooltip="Aspect ratio of the output video", + ), + comfy_io.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Negative text prompt to guide what to avoid in the video", + optional=True, + ), + comfy_io.Int.Input( + "duration_seconds", + default=8, + min=8, + max=8, + step=1, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Duration of the output video in seconds (Veo 3 only supports 8 seconds)", + optional=True, + ), + comfy_io.Boolean.Input( + "enhance_prompt", + default=True, + tooltip="Whether to enhance the prompt with AI assistance", + optional=True, + ), + comfy_io.Combo.Input( + "person_generation", + options=["ALLOW", "BLOCK"], + default="ALLOW", + tooltip="Whether to allow generating people in the video", + optional=True, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFF, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed for video generation (0 for random)", + optional=True, + ), + comfy_io.Image.Input( + "image", + tooltip="Optional reference image to guide video generation", + optional=True, + ), + comfy_io.Combo.Input( + "model", + options=["veo-3.0-generate-001", "veo-3.0-fast-generate-001"], + default="veo-3.0-generate-001", + tooltip="Veo 3 model to use for video generation", + optional=True, + ), + comfy_io.Boolean.Input( + "generate_audio", + default=False, + tooltip="Generate audio for the video. Supported by all Veo 3 models.", + optional=True, + ), + ], + outputs=[ + comfy_io.Video.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, ) - # Add generateAudio parameter - parent_input["optional"]["generate_audio"] = ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Generate audio for the video. Supported by all Veo 3 models.", - } - ) - # Update duration constraints for Veo 3 (only 8 seconds supported) - parent_input["optional"]["duration_seconds"] = ( - IO.INT, - { - "default": 8, - "min": 8, - "max": 8, - "step": 1, - "display": "number", - "tooltip": "Duration of the output video in seconds (Veo 3 only supports 8 seconds)", - }, - ) +class VeoExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: + return [ + VeoVideoGenerationNode, + Veo3VideoGenerationNode, + ] - return parent_input - - -# Register the nodes -NODE_CLASS_MAPPINGS = { - "VeoVideoGenerationNode": VeoVideoGenerationNode, - "Veo3VideoGenerationNode": Veo3VideoGenerationNode, -} - -NODE_DISPLAY_NAME_MAPPINGS = { - "VeoVideoGenerationNode": "Google Veo 2 Video Generation", - "Veo3VideoGenerationNode": "Google Veo 3 Video Generation", -} +async def comfy_entrypoint() -> VeoExtension: + return VeoExtension() diff --git a/comfy_extras/nodes_ace.py b/comfy_extras/nodes_ace.py index cbfec15a2..1409233c9 100644 --- a/comfy_extras/nodes_ace.py +++ b/comfy_extras/nodes_ace.py @@ -1,49 +1,63 @@ import torch +from typing_extensions import override + import comfy.model_management import node_helpers +from comfy_api.latest import ComfyExtension, io -class TextEncodeAceStepAudio: + +class TextEncodeAceStepAudio(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "clip": ("CLIP", ), - "tags": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "lyrics": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "lyrics_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "encode" + def define_schema(cls): + return io.Schema( + node_id="TextEncodeAceStepAudio", + category="conditioning", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("tags", multiline=True, dynamic_prompts=True), + io.String.Input("lyrics", multiline=True, dynamic_prompts=True), + io.Float.Input("lyrics_strength", default=1.0, min=0.0, max=10.0, step=0.01), + ], + outputs=[io.Conditioning.Output()], + ) - CATEGORY = "conditioning" - - def encode(self, clip, tags, lyrics, lyrics_strength): + @classmethod + def execute(cls, clip, tags, lyrics, lyrics_strength) -> io.NodeOutput: tokens = clip.tokenize(tags, lyrics=lyrics) conditioning = clip.encode_from_tokens_scheduled(tokens) conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength}) - return (conditioning, ) + return io.NodeOutput(conditioning) -class EmptyAceStepLatentAudio: - def __init__(self): - self.device = comfy.model_management.intermediate_device() +class EmptyAceStepLatentAudio(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="EmptyAceStepLatentAudio", + category="latent/audio", + inputs=[ + io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1), + io.Int.Input( + "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch." + ), + ], + outputs=[io.Latent.Output()], + ) @classmethod - def INPUT_TYPES(s): - return {"required": {"seconds": ("FLOAT", {"default": 120.0, "min": 1.0, "max": 1000.0, "step": 0.1}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}), - }} - RETURN_TYPES = ("LATENT",) - FUNCTION = "generate" - - CATEGORY = "latent/audio" - - def generate(self, seconds, batch_size): + def execute(cls, seconds, batch_size) -> io.NodeOutput: length = int(seconds * 44100 / 512 / 8) - latent = torch.zeros([batch_size, 8, 16, length], device=self.device) - return ({"samples": latent, "type": "audio"}, ) + latent = torch.zeros([batch_size, 8, 16, length], device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples": latent, "type": "audio"}) -NODE_CLASS_MAPPINGS = { - "TextEncodeAceStepAudio": TextEncodeAceStepAudio, - "EmptyAceStepLatentAudio": EmptyAceStepLatentAudio, -} +class AceExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TextEncodeAceStepAudio, + EmptyAceStepLatentAudio, + ] + +async def comfy_entrypoint() -> AceExtension: + return AceExtension() diff --git a/comfy_extras/nodes_advanced_samplers.py b/comfy_extras/nodes_advanced_samplers.py index 5fbb096fb..5532ffe6a 100644 --- a/comfy_extras/nodes_advanced_samplers.py +++ b/comfy_extras/nodes_advanced_samplers.py @@ -1,8 +1,13 @@ +import numpy as np +import torch +from tqdm.auto import trange +from typing_extensions import override + +import comfy.model_patcher import comfy.samplers import comfy.utils -import torch -import numpy as np -from tqdm.auto import trange +from comfy.k_diffusion.sampling import to_d +from comfy_api.latest import ComfyExtension, io @torch.no_grad() @@ -33,30 +38,29 @@ def sample_lcm_upscale(model, x, sigmas, extra_args=None, callback=None, disable return x -class SamplerLCMUpscale: - upscale_methods = ["bislerp", "nearest-exact", "bilinear", "area", "bicubic"] +class SamplerLCMUpscale(io.ComfyNode): + UPSCALE_METHODS = ["bislerp", "nearest-exact", "bilinear", "area", "bicubic"] @classmethod - def INPUT_TYPES(s): - return {"required": - {"scale_ratio": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 20.0, "step": 0.01}), - "scale_steps": ("INT", {"default": -1, "min": -1, "max": 1000, "step": 1}), - "upscale_method": (s.upscale_methods,), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="SamplerLCMUpscale", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("scale_ratio", default=1.0, min=0.1, max=20.0, step=0.01), + io.Int.Input("scale_steps", default=-1, min=-1, max=1000, step=1), + io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS), + ], + outputs=[io.Sampler.Output()], + ) - FUNCTION = "get_sampler" - - def get_sampler(self, scale_ratio, scale_steps, upscale_method): + @classmethod + def execute(cls, scale_ratio, scale_steps, upscale_method) -> io.NodeOutput: if scale_steps < 0: scale_steps = None sampler = comfy.samplers.KSAMPLER(sample_lcm_upscale, extra_options={"total_upscale": scale_ratio, "upscale_steps": scale_steps, "upscale_method": upscale_method}) - return (sampler, ) + return io.NodeOutput(sampler) -from comfy.k_diffusion.sampling import to_d -import comfy.model_patcher @torch.no_grad() def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=None): @@ -82,30 +86,36 @@ def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=No return x -class SamplerEulerCFGpp: +class SamplerEulerCFGpp(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"version": (["regular", "alternative"],),} - } - RETURN_TYPES = ("SAMPLER",) - # CATEGORY = "sampling/custom_sampling/samplers" - CATEGORY = "_for_testing" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="SamplerEulerCFGpp", + display_name="SamplerEulerCFG++", + category="_for_testing", # "sampling/custom_sampling/samplers" + inputs=[ + io.Combo.Input("version", options=["regular", "alternative"]), + ], + outputs=[io.Sampler.Output()], + is_experimental=True, + ) - FUNCTION = "get_sampler" - - def get_sampler(self, version): + @classmethod + def execute(cls, version) -> io.NodeOutput: if version == "alternative": sampler = comfy.samplers.KSAMPLER(sample_euler_pp) else: sampler = comfy.samplers.ksampler("euler_cfg_pp") - return (sampler, ) + return io.NodeOutput(sampler) -NODE_CLASS_MAPPINGS = { - "SamplerLCMUpscale": SamplerLCMUpscale, - "SamplerEulerCFGpp": SamplerEulerCFGpp, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "SamplerEulerCFGpp": "SamplerEulerCFG++", -} +class AdvancedSamplersExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SamplerLCMUpscale, + SamplerEulerCFGpp, + ] + +async def comfy_entrypoint() -> AdvancedSamplersExtension: + return AdvancedSamplersExtension() diff --git a/comfy_extras/nodes_apg.py b/comfy_extras/nodes_apg.py index 25b21b1b8..f27ae7da8 100644 --- a/comfy_extras/nodes_apg.py +++ b/comfy_extras/nodes_apg.py @@ -1,4 +1,8 @@ import torch +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + def project(v0, v1): v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3]) @@ -6,22 +10,45 @@ def project(v0, v1): v0_orthogonal = v0 - v0_parallel return v0_parallel, v0_orthogonal -class APG: +class APG(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": ("MODEL",), - "eta": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01, "tooltip": "Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1."}), - "norm_threshold": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 50.0, "step": 0.1, "tooltip": "Normalize guidance vector to this value, normalization disable at a setting of 0."}), - "momentum": ("FLOAT", {"default": 0.0, "min": -5.0, "max": 1.0, "step": 0.01, "tooltip":"Controls a running average of guidance during diffusion, disabled at a setting of 0."}), - } - } - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" - CATEGORY = "sampling/custom_sampling" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="APG", + display_name="Adaptive Projected Guidance", + category="sampling/custom_sampling", + inputs=[ + io.Model.Input("model"), + io.Float.Input( + "eta", + default=1.0, + min=-10.0, + max=10.0, + step=0.01, + tooltip="Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1.", + ), + io.Float.Input( + "norm_threshold", + default=5.0, + min=0.0, + max=50.0, + step=0.1, + tooltip="Normalize guidance vector to this value, normalization disable at a setting of 0.", + ), + io.Float.Input( + "momentum", + default=0.0, + min=-5.0, + max=1.0, + step=0.01, + tooltip="Controls a running average of guidance during diffusion, disabled at a setting of 0.", + ), + ], + outputs=[io.Model.Output()], + ) - def patch(self, model, eta, norm_threshold, momentum): + @classmethod + def execute(cls, model, eta, norm_threshold, momentum) -> io.NodeOutput: running_avg = 0 prev_sigma = None @@ -65,12 +92,15 @@ class APG: m = model.clone() m.set_model_sampler_pre_cfg_function(pre_cfg_function) - return (m,) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "APG": APG, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "APG": "Adaptive Projected Guidance", -} +class ApgExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + APG, + ] + +async def comfy_entrypoint() -> ApgExtension: + return ApgExtension() diff --git a/comfy_extras/nodes_attention_multiply.py b/comfy_extras/nodes_attention_multiply.py index 4747eb395..c0e494c2a 100644 --- a/comfy_extras/nodes_attention_multiply.py +++ b/comfy_extras/nodes_attention_multiply.py @@ -1,3 +1,7 @@ +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + def attention_multiply(attn, model, q, k, v, out): m = model.clone() @@ -16,57 +20,71 @@ def attention_multiply(attn, model, q, k, v, out): return m -class UNetSelfAttentionMultiply: +class UNetSelfAttentionMultiply(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="UNetSelfAttentionMultiply", + category="_for_testing/attention_experiments", + inputs=[ + io.Model.Input("model"), + io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01), + ], + outputs=[io.Model.Output()], + is_experimental=True, + ) - CATEGORY = "_for_testing/attention_experiments" - - def patch(self, model, q, k, v, out): + @classmethod + def execute(cls, model, q, k, v, out) -> io.NodeOutput: m = attention_multiply("attn1", model, q, k, v, out) - return (m, ) + return io.NodeOutput(m) -class UNetCrossAttentionMultiply: + +class UNetCrossAttentionMultiply(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="UNetCrossAttentionMultiply", + category="_for_testing/attention_experiments", + inputs=[ + io.Model.Input("model"), + io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01), + ], + outputs=[io.Model.Output()], + is_experimental=True, + ) - CATEGORY = "_for_testing/attention_experiments" - - def patch(self, model, q, k, v, out): + @classmethod + def execute(cls, model, q, k, v, out) -> io.NodeOutput: m = attention_multiply("attn2", model, q, k, v, out) - return (m, ) + return io.NodeOutput(m) -class CLIPAttentionMultiply: + +class CLIPAttentionMultiply(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "clip": ("CLIP",), - "q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("CLIP",) - FUNCTION = "patch" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="CLIPAttentionMultiply", + category="_for_testing/attention_experiments", + inputs=[ + io.Clip.Input("clip"), + io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01), + ], + outputs=[io.Clip.Output()], + is_experimental=True, + ) - CATEGORY = "_for_testing/attention_experiments" - - def patch(self, clip, q, k, v, out): + @classmethod + def execute(cls, clip, q, k, v, out) -> io.NodeOutput: m = clip.clone() sd = m.patcher.model_state_dict() @@ -79,23 +97,28 @@ class CLIPAttentionMultiply: m.add_patches({key: (None,)}, 0.0, v) if key.endswith("self_attn.out_proj.weight") or key.endswith("self_attn.out_proj.bias"): m.add_patches({key: (None,)}, 0.0, out) - return (m, ) + return io.NodeOutput(m) -class UNetTemporalAttentionMultiply: + +class UNetTemporalAttentionMultiply(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "self_structural": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "self_temporal": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "cross_structural": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "cross_temporal": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="UNetTemporalAttentionMultiply", + category="_for_testing/attention_experiments", + inputs=[ + io.Model.Input("model"), + io.Float.Input("self_structural", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("self_temporal", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("cross_structural", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("cross_temporal", default=1.0, min=0.0, max=10.0, step=0.01), + ], + outputs=[io.Model.Output()], + is_experimental=True, + ) - CATEGORY = "_for_testing/attention_experiments" - - def patch(self, model, self_structural, self_temporal, cross_structural, cross_temporal): + @classmethod + def execute(cls, model, self_structural, self_temporal, cross_structural, cross_temporal) -> io.NodeOutput: m = model.clone() sd = model.model_state_dict() @@ -110,11 +133,18 @@ class UNetTemporalAttentionMultiply: m.add_patches({k: (None,)}, 0.0, cross_temporal) else: m.add_patches({k: (None,)}, 0.0, cross_structural) - return (m, ) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "UNetSelfAttentionMultiply": UNetSelfAttentionMultiply, - "UNetCrossAttentionMultiply": UNetCrossAttentionMultiply, - "CLIPAttentionMultiply": CLIPAttentionMultiply, - "UNetTemporalAttentionMultiply": UNetTemporalAttentionMultiply, -} + +class AttentionMultiplyExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + UNetSelfAttentionMultiply, + UNetCrossAttentionMultiply, + CLIPAttentionMultiply, + UNetTemporalAttentionMultiply, + ] + +async def comfy_entrypoint() -> AttentionMultiplyExtension: + return AttentionMultiplyExtension() diff --git a/comfy_extras/nodes_string.py b/comfy_extras/nodes_string.py index b1a8ceef0..571d89f62 100644 --- a/comfy_extras/nodes_string.py +++ b/comfy_extras/nodes_string.py @@ -1,77 +1,91 @@ import re +from typing_extensions import override -from comfy.comfy_types.node_typing import IO +from comfy_api.latest import ComfyExtension, io -class StringConcatenate(): + +class StringConcatenate(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string_a": (IO.STRING, {"multiline": True}), - "string_b": (IO.STRING, {"multiline": True}), - "delimiter": (IO.STRING, {"multiline": False, "default": ""}) - } - } + def define_schema(cls): + return io.Schema( + node_id="StringConcatenate", + display_name="Concatenate", + category="utils/string", + inputs=[ + io.String.Input("string_a", multiline=True), + io.String.Input("string_b", multiline=True), + io.String.Input("delimiter", multiline=False, default=""), + ], + outputs=[ + io.String.Output(), + ] + ) - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string_a, string_b, delimiter, **kwargs): - return delimiter.join((string_a, string_b)), - -class StringSubstring(): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "start": (IO.INT, {}), - "end": (IO.INT, {}), - } - } + def execute(cls, string_a, string_b, delimiter): + return io.NodeOutput(delimiter.join((string_a, string_b))) - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - def execute(self, string, start, end, **kwargs): - return string[start:end], - -class StringLength(): +class StringSubstring(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}) - } - } + def define_schema(cls): + return io.Schema( + node_id="StringSubstring", + display_name="Substring", + category="utils/string", + inputs=[ + io.String.Input("string", multiline=True), + io.Int.Input("start"), + io.Int.Input("end"), + ], + outputs=[ + io.String.Output(), + ] + ) - RETURN_TYPES = (IO.INT,) - RETURN_NAMES = ("length",) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string, **kwargs): - length = len(string) - - return length, - -class CaseConverter(): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "mode": (IO.COMBO, {"options": ["UPPERCASE", "lowercase", "Capitalize", "Title Case"]}) - } - } + def execute(cls, string, start, end): + return io.NodeOutput(string[start:end]) - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - def execute(self, string, mode, **kwargs): +class StringLength(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="StringLength", + display_name="Length", + category="utils/string", + inputs=[ + io.String.Input("string", multiline=True), + ], + outputs=[ + io.Int.Output(display_name="length"), + ] + ) + + @classmethod + def execute(cls, string): + return io.NodeOutput(len(string)) + + +class CaseConverter(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CaseConverter", + display_name="Case Converter", + category="utils/string", + inputs=[ + io.String.Input("string", multiline=True), + io.Combo.Input("mode", options=["UPPERCASE", "lowercase", "Capitalize", "Title Case"]), + ], + outputs=[ + io.String.Output(), + ] + ) + + @classmethod + def execute(cls, string, mode): if mode == "UPPERCASE": result = string.upper() elif mode == "lowercase": @@ -83,24 +97,27 @@ class CaseConverter(): else: result = string - return result, + return io.NodeOutput(result) -class StringTrim(): +class StringTrim(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "mode": (IO.COMBO, {"options": ["Both", "Left", "Right"]}) - } - } + def define_schema(cls): + return io.Schema( + node_id="StringTrim", + display_name="Trim", + category="utils/string", + inputs=[ + io.String.Input("string", multiline=True), + io.Combo.Input("mode", options=["Both", "Left", "Right"]), + ], + outputs=[ + io.String.Output(), + ] + ) - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string, mode, **kwargs): + @classmethod + def execute(cls, string, mode): if mode == "Both": result = string.strip() elif mode == "Left": @@ -110,70 +127,78 @@ class StringTrim(): else: result = string - return result, + return io.NodeOutput(result) -class StringReplace(): + +class StringReplace(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "find": (IO.STRING, {"multiline": True}), - "replace": (IO.STRING, {"multiline": True}) - } - } + def define_schema(cls): + return io.Schema( + node_id="StringReplace", + display_name="Replace", + category="utils/string", + inputs=[ + io.String.Input("string", multiline=True), + io.String.Input("find", multiline=True), + io.String.Input("replace", multiline=True), + ], + outputs=[ + io.String.Output(), + ] + ) - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string, find, replace, **kwargs): - result = string.replace(find, replace) - return result, - - -class StringContains(): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "substring": (IO.STRING, {"multiline": True}), - "case_sensitive": (IO.BOOLEAN, {"default": True}) - } - } + def execute(cls, string, find, replace): + return io.NodeOutput(string.replace(find, replace)) - RETURN_TYPES = (IO.BOOLEAN,) - RETURN_NAMES = ("contains",) - FUNCTION = "execute" - CATEGORY = "utils/string" - def execute(self, string, substring, case_sensitive, **kwargs): +class StringContains(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="StringContains", + display_name="Contains", + category="utils/string", + inputs=[ + io.String.Input("string", multiline=True), + io.String.Input("substring", multiline=True), + io.Boolean.Input("case_sensitive", default=True), + ], + outputs=[ + io.Boolean.Output(display_name="contains"), + ] + ) + + @classmethod + def execute(cls, string, substring, case_sensitive): if case_sensitive: contains = substring in string else: contains = substring.lower() in string.lower() - return contains, + return io.NodeOutput(contains) -class StringCompare(): +class StringCompare(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string_a": (IO.STRING, {"multiline": True}), - "string_b": (IO.STRING, {"multiline": True}), - "mode": (IO.COMBO, {"options": ["Starts With", "Ends With", "Equal"]}), - "case_sensitive": (IO.BOOLEAN, {"default": True}) - } - } + def define_schema(cls): + return io.Schema( + node_id="StringCompare", + display_name="Compare", + category="utils/string", + inputs=[ + io.String.Input("string_a", multiline=True), + io.String.Input("string_b", multiline=True), + io.Combo.Input("mode", options=["Starts With", "Ends With", "Equal"]), + io.Boolean.Input("case_sensitive", default=True), + ], + outputs=[ + io.Boolean.Output(), + ] + ) - RETURN_TYPES = (IO.BOOLEAN,) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string_a, string_b, mode, case_sensitive, **kwargs): + @classmethod + def execute(cls, string_a, string_b, mode, case_sensitive): if case_sensitive: a = string_a b = string_b @@ -182,31 +207,34 @@ class StringCompare(): b = string_b.lower() if mode == "Equal": - return a == b, + return io.NodeOutput(a == b) elif mode == "Starts With": - return a.startswith(b), + return io.NodeOutput(a.startswith(b)) elif mode == "Ends With": - return a.endswith(b), + return io.NodeOutput(a.endswith(b)) -class RegexMatch(): + +class RegexMatch(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "regex_pattern": (IO.STRING, {"multiline": True}), - "case_insensitive": (IO.BOOLEAN, {"default": True}), - "multiline": (IO.BOOLEAN, {"default": False}), - "dotall": (IO.BOOLEAN, {"default": False}) - } - } + def define_schema(cls): + return io.Schema( + node_id="RegexMatch", + display_name="Regex Match", + category="utils/string", + inputs=[ + io.String.Input("string", multiline=True), + io.String.Input("regex_pattern", multiline=True), + io.Boolean.Input("case_insensitive", default=True), + io.Boolean.Input("multiline", default=False), + io.Boolean.Input("dotall", default=False), + ], + outputs=[ + io.Boolean.Output(display_name="matches"), + ] + ) - RETURN_TYPES = (IO.BOOLEAN,) - RETURN_NAMES = ("matches",) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string, regex_pattern, case_insensitive, multiline, dotall, **kwargs): + @classmethod + def execute(cls, string, regex_pattern, case_insensitive, multiline, dotall): flags = 0 if case_insensitive: @@ -223,29 +251,32 @@ class RegexMatch(): except re.error: result = False - return result, + return io.NodeOutput(result) -class RegexExtract(): +class RegexExtract(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "regex_pattern": (IO.STRING, {"multiline": True}), - "mode": (IO.COMBO, {"options": ["First Match", "All Matches", "First Group", "All Groups"]}), - "case_insensitive": (IO.BOOLEAN, {"default": True}), - "multiline": (IO.BOOLEAN, {"default": False}), - "dotall": (IO.BOOLEAN, {"default": False}), - "group_index": (IO.INT, {"default": 1, "min": 0, "max": 100}) - } - } + def define_schema(cls): + return io.Schema( + node_id="RegexExtract", + display_name="Regex Extract", + category="utils/string", + inputs=[ + io.String.Input("string", multiline=True), + io.String.Input("regex_pattern", multiline=True), + io.Combo.Input("mode", options=["First Match", "All Matches", "First Group", "All Groups"]), + io.Boolean.Input("case_insensitive", default=True), + io.Boolean.Input("multiline", default=False), + io.Boolean.Input("dotall", default=False), + io.Int.Input("group_index", default=1, min=0, max=100), + ], + outputs=[ + io.String.Output(), + ] + ) - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string, regex_pattern, mode, case_insensitive, multiline, dotall, group_index, **kwargs): + @classmethod + def execute(cls, string, regex_pattern, mode, case_insensitive, multiline, dotall, group_index): join_delimiter = "\n" flags = 0 @@ -294,32 +325,33 @@ class RegexExtract(): except re.error: result = "" - return result, + return io.NodeOutput(result) -class RegexReplace(): - DESCRIPTION = "Find and replace text using regex patterns." +class RegexReplace(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "regex_pattern": (IO.STRING, {"multiline": True}), - "replace": (IO.STRING, {"multiline": True}), - }, - "optional": { - "case_insensitive": (IO.BOOLEAN, {"default": True}), - "multiline": (IO.BOOLEAN, {"default": False}), - "dotall": (IO.BOOLEAN, {"default": False, "tooltip": "When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines."}), - "count": (IO.INT, {"default": 0, "min": 0, "max": 100, "tooltip": "Maximum number of replacements to make. Set to 0 to replace all occurrences (default). Set to 1 to replace only the first match, 2 for the first two matches, etc."}), - } - } + def define_schema(cls): + return io.Schema( + node_id="RegexReplace", + display_name="Regex Replace", + category="utils/string", + description="Find and replace text using regex patterns.", + inputs=[ + io.String.Input("string", multiline=True), + io.String.Input("regex_pattern", multiline=True), + io.String.Input("replace", multiline=True), + io.Boolean.Input("case_insensitive", default=True, optional=True), + io.Boolean.Input("multiline", default=False, optional=True), + io.Boolean.Input("dotall", default=False, optional=True, tooltip="When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines."), + io.Int.Input("count", default=0, min=0, max=100, optional=True, tooltip="Maximum number of replacements to make. Set to 0 to replace all occurrences (default). Set to 1 to replace only the first match, 2 for the first two matches, etc."), + ], + outputs=[ + io.String.Output(), + ] + ) - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string, regex_pattern, replace, case_insensitive=True, multiline=False, dotall=False, count=0, **kwargs): + @classmethod + def execute(cls, string, regex_pattern, replace, case_insensitive=True, multiline=False, dotall=False, count=0): flags = 0 if case_insensitive: @@ -329,32 +361,25 @@ class RegexReplace(): if dotall: flags |= re.DOTALL result = re.sub(regex_pattern, replace, string, count=count, flags=flags) - return result, + return io.NodeOutput(result) -NODE_CLASS_MAPPINGS = { - "StringConcatenate": StringConcatenate, - "StringSubstring": StringSubstring, - "StringLength": StringLength, - "CaseConverter": CaseConverter, - "StringTrim": StringTrim, - "StringReplace": StringReplace, - "StringContains": StringContains, - "StringCompare": StringCompare, - "RegexMatch": RegexMatch, - "RegexExtract": RegexExtract, - "RegexReplace": RegexReplace, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "StringConcatenate": "Concatenate", - "StringSubstring": "Substring", - "StringLength": "Length", - "CaseConverter": "Case Converter", - "StringTrim": "Trim", - "StringReplace": "Replace", - "StringContains": "Contains", - "StringCompare": "Compare", - "RegexMatch": "Regex Match", - "RegexExtract": "Regex Extract", - "RegexReplace": "Regex Replace", -} +class StringExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + StringConcatenate, + StringSubstring, + StringLength, + CaseConverter, + StringTrim, + StringReplace, + StringContains, + StringCompare, + RegexMatch, + RegexExtract, + RegexReplace, + ] + +async def comfy_entrypoint() -> StringExtension: + return StringExtension()