From deebee4ff6fd1b2713683d22e5e2e07170daa867 Mon Sep 17 00:00:00 2001 From: guill Date: Thu, 14 Aug 2025 18:46:55 -0700 Subject: [PATCH 01/14] Update default parameters for Moonvalley video nodes (#9290) * Update default parameters for Moonvalley video nodes - Changed default negative prompts to a more extensive list for both BaseMoonvalleyVideoNode and MoonvalleyVideo2VideoNode. - Updated default guidance scale values for both nodes to enhance prompt adherence. - Set a fixed default seed value for consistency in video generation. * no message * ruff fix --------- Co-authored-by: thorsten --- comfy_api_nodes/nodes_moonvalley.py | 128 ++++++++++++++++++++-------- 1 file changed, 91 insertions(+), 37 deletions(-) diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py index 164ca3ea5..806a70e06 100644 --- a/comfy_api_nodes/nodes_moonvalley.py +++ b/comfy_api_nodes/nodes_moonvalley.py @@ -1,6 +1,5 @@ import logging from typing import Any, Callable, Optional, TypeVar -import random import torch from comfy_api_nodes.util.validation_utils import ( get_image_dimensions, @@ -208,20 +207,29 @@ def _get_video_dimensions(video: VideoInput) -> tuple[int, int]: def _validate_video_dimensions(width: int, height: int) -> None: """Validates video dimensions meet Moonvalley V2V requirements.""" supported_resolutions = { - (1920, 1080), (1080, 1920), (1152, 1152), - (1536, 1152), (1152, 1536) + (1920, 1080), + (1080, 1920), + (1152, 1152), + (1536, 1152), + (1152, 1536), } if (width, height) not in supported_resolutions: - supported_list = ', '.join([f'{w}x{h}' for w, h in sorted(supported_resolutions)]) - raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}") + supported_list = ", ".join( + [f"{w}x{h}" for w, h in sorted(supported_resolutions)] + ) + raise ValueError( + f"Resolution {width}x{height} not supported. Supported: {supported_list}" + ) def _validate_container_format(video: VideoInput) -> None: """Validates video container format is MP4.""" container_format = video.get_container_format() - if container_format not in ['mp4', 'mov,mp4,m4a,3gp,3g2,mj2']: - raise ValueError(f"Only MP4 container format supported. Got: {container_format}") + if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]: + raise ValueError( + f"Only MP4 container format supported. 
Got: {container_format}" + ) def _validate_and_trim_duration(video: VideoInput) -> VideoInput: @@ -244,7 +252,6 @@ def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput: return video - def trim_video(video: VideoInput, duration_sec: float) -> VideoInput: """ Returns a new VideoInput object trimmed from the beginning to the specified duration, @@ -302,7 +309,9 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput: # Calculate target frame count that's divisible by 16 fps = input_container.streams.video[0].average_rate estimated_frames = int(duration_sec * fps) - target_frames = (estimated_frames // 16) * 16 # Round down to nearest multiple of 16 + target_frames = ( + estimated_frames // 16 + ) * 16 # Round down to nearest multiple of 16 if target_frames == 0: raise ValueError("Video too short: need at least 16 frames for Moonvalley") @@ -424,7 +433,7 @@ class BaseMoonvalleyVideoNode: MoonvalleyTextToVideoInferenceParams, "negative_prompt", multiline=True, - default="low-poly, flat shader, bad rigging, stiff animation, uncanny eyes, low-quality textures, looping glitch, cheap effect, overbloom, bloom spam, default lighting, game asset, stiff face, ugly specular, AI artifacts", + default=" gopro, bright, contrast, static, overexposed, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring", ), "resolution": ( IO.COMBO, @@ -441,12 +450,11 @@ class BaseMoonvalleyVideoNode: "tooltip": "Resolution of the output video", }, ), - # "length": (IO.COMBO,{"options":['5s','10s'], "default": '5s'}), "prompt_adherence": model_field_to_node_input( IO.FLOAT, MoonvalleyTextToVideoInferenceParams, "guidance_scale", - default=7.0, + default=10.0, step=1, min=1, max=20, @@ -455,13 +463,12 @@ class BaseMoonvalleyVideoNode: IO.INT, MoonvalleyTextToVideoInferenceParams, "seed", - default=random.randint(0, 2**32 - 1), + default=9, min=0, max=4294967295, step=1, display="number", tooltip="Random seed value", - control_after_generate=True, ), "steps": model_field_to_node_input( IO.INT, @@ -532,9 +539,11 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode): # Get MIME type from tensor - assuming PNG format for image tensors mime_type = "image/png" - image_url = (await upload_images_to_comfyapi( - image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type - ))[0] + image_url = ( + await upload_images_to_comfyapi( + image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type + ) + )[0] request = MoonvalleyTextToVideoRequest( image_url=image_url, prompt_text=prompt, inference_params=inference_params @@ -570,17 +579,39 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): return { "required": { "prompt": model_field_to_node_input( - IO.STRING, MoonvalleyVideoToVideoRequest, "prompt_text", - multiline=True + IO.STRING, + MoonvalleyVideoToVideoRequest, + "prompt_text", + multiline=True, ), "negative_prompt": model_field_to_node_input( IO.STRING, MoonvalleyVideoToVideoInferenceParams, "negative_prompt", multiline=True, - default="low-poly, flat shader, bad rigging, stiff animation, uncanny eyes, low-quality textures, looping glitch, cheap effect, overbloom, bloom spam, default lighting, game asset, stiff face, ugly 
specular, AI artifacts" + default=" gopro, bright, contrast, static, overexposed, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring", + ), + "seed": model_field_to_node_input( + IO.INT, + MoonvalleyVideoToVideoInferenceParams, + "seed", + default=9, + min=0, + max=4294967295, + step=1, + display="number", + tooltip="Random seed value", + control_after_generate=False, + ), + "prompt_adherence": model_field_to_node_input( + IO.FLOAT, + MoonvalleyVideoToVideoInferenceParams, + "guidance_scale", + default=10.0, + step=1, + min=1, + max=20, ), - "seed": model_field_to_node_input(IO.INT,MoonvalleyVideoToVideoInferenceParams, "seed", default=random.randint(0, 2**32 - 1), min=0, max=4294967295, step=1, display="number", tooltip="Random seed value", control_after_generate=True), }, "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", @@ -588,7 +619,14 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): "unique_id": "UNIQUE_ID", }, "optional": { - "video": (IO.VIDEO, {"default": "", "multiline": False, "tooltip": "The reference video used to generate the output video. Must be at least 5 seconds long. Videos longer than 5s will be automatically trimmed. Only MP4 format supported."}), + "video": ( + IO.VIDEO, + { + "default": "", + "multiline": False, + "tooltip": "The reference video used to generate the output video. Must be at least 5 seconds long. Videos longer than 5s will be automatically trimmed. 
Only MP4 format supported.", + }, + ), "control_type": ( ["Motion Transfer", "Pose Transfer"], {"default": "Motion Transfer"}, @@ -602,8 +640,14 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): "max": 100, "tooltip": "Only used if control_type is 'Motion Transfer'", }, - ) - } + ), + "image": model_field_to_node_input( + IO.IMAGE, + MoonvalleyTextToVideoRequest, + "image_url", + tooltip="The reference image used to generate the video", + ), + }, } RETURN_TYPES = ("VIDEO",) @@ -613,6 +657,7 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs ): video = kwargs.get("video") + image = kwargs.get("image", None) if not video: raise MoonvalleyApiError("video is required") @@ -620,8 +665,16 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): video_url = "" if video: validated_video = validate_video_to_video_input(video) - video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs) + video_url = await upload_video_to_comfyapi( + validated_video, auth_kwargs=kwargs + ) + mime_type = "image/png" + if not image is None: + validate_input_image(image, with_frame_conditioning=True) + image_url = await upload_images_to_comfyapi( + image=image, auth_kwargs=kwargs, max_images=1, mime_type=mime_type + ) control_type = kwargs.get("control_type") motion_intensity = kwargs.get("motion_intensity") @@ -631,12 +684,12 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): # Only include motion_intensity for Motion Transfer control_params = {} if control_type == "Motion Transfer" and motion_intensity is not None: - control_params['motion_intensity'] = motion_intensity + control_params["motion_intensity"] = motion_intensity - inference_params=MoonvalleyVideoToVideoInferenceParams( + inference_params = MoonvalleyVideoToVideoInferenceParams( negative_prompt=negative_prompt, seed=kwargs.get("seed"), - control_params=control_params + control_params=control_params, ) control = self.parseControlParameter(control_type) @@ -647,6 +700,7 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): prompt_text=prompt, inference_params=inference_params, ) + request.image_url = image_url if not image is None else None initial_operation = SynchronousOperation( endpoint=ApiEndpoint( @@ -694,15 +748,15 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode): validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) width_height = self.parseWidthHeightFromRes(kwargs.get("resolution")) - inference_params=MoonvalleyTextToVideoInferenceParams( - negative_prompt=negative_prompt, - steps=kwargs.get("steps"), - seed=kwargs.get("seed"), - guidance_scale=kwargs.get("prompt_adherence"), - num_frames=128, - width=width_height.get("width"), - height=width_height.get("height"), - ) + inference_params = MoonvalleyTextToVideoInferenceParams( + negative_prompt=negative_prompt, + steps=kwargs.get("steps"), + seed=kwargs.get("seed"), + guidance_scale=kwargs.get("prompt_adherence"), + num_frames=128, + width=width_height.get("width"), + height=width_height.get("height"), + ) request = MoonvalleyTextToVideoRequest( prompt_text=prompt, inference_params=inference_params ) From 5d65d6753b195d674ce16522d6c34f9a33f36269 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 15 Aug 2025 04:48:41 +0300 Subject: [PATCH 02/14] convert WAN nodes to V3 schema (#9201) --- comfy_extras/nodes_wan.py | 549 +++++++++++++++++++++----------------- 1 file changed, 298 
insertions(+), 251 deletions(-) diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index f80c83ba6..694a183f6 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -9,29 +9,35 @@ import comfy.clip_vision import json import numpy as np from typing import Tuple +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io -class WanImageToVideo: +class WanImageToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), - "start_image": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="WanImageToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Image.Input("start_image", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None): + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) if start_image is not None: start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) @@ -51,32 +57,36 @@ class WanImageToVideo: out_latent = {} out_latent["samples"] = latent - return (positive, negative, out_latent) + return io.NodeOutput(positive, negative, out_latent) -class WanFunControlToVideo: +class WanFunControlToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), - "start_image": ("IMAGE", ), - "control_video": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="WanFunControlToVideo", + 
category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Image.Input("start_image", optional=True), + io.Image.Input("control_video", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None): + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent) @@ -101,31 +111,34 @@ class WanFunControlToVideo: out_latent = {} out_latent["samples"] = latent - return (positive, negative, out_latent) + return io.NodeOutput(positive, negative, out_latent) -class Wan22FunControlToVideo: +class Wan22FunControlToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"ref_image": ("IMAGE", ), - "control_video": ("IMAGE", ), - # "start_image": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="Wan22FunControlToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("ref_image", optional=True), + io.Image.Input("control_video", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, positive, negative, vae, width, height, length, 
batch_size, ref_image=None, start_image=None, control_video=None): + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, start_image=None, control_video=None) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent) @@ -158,32 +171,36 @@ class Wan22FunControlToVideo: out_latent = {} out_latent["samples"] = latent - return (positive, negative, out_latent) + return io.NodeOutput(positive, negative, out_latent) -class WanFirstLastFrameToVideo: +class WanFirstLastFrameToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"clip_vision_start_image": ("CLIP_VISION_OUTPUT", ), - "clip_vision_end_image": ("CLIP_VISION_OUTPUT", ), - "start_image": ("IMAGE", ), - "end_image": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="WanFirstLastFrameToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.ClipVisionOutput.Input("clip_vision_start_image", optional=True), + io.ClipVisionOutput.Input("clip_vision_end_image", optional=True), + io.Image.Input("start_image", optional=True), + io.Image.Input("end_image", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None): + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) if start_image is not None: start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) @@ -224,62 +241,70 @@ class WanFirstLastFrameToVideo: out_latent = {} out_latent["samples"] = latent - return (positive, negative, out_latent) + return io.NodeOutput(positive, negative, out_latent) -class 
WanFunInpaintToVideo: +class WanFunInpaintToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), - "start_image": ("IMAGE", ), - "end_image": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="WanFunInpaintToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Image.Input("start_image", optional=True), + io.Image.Input("end_image", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None): + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None) -> io.NodeOutput: flfv = WanFirstLastFrameToVideo() - return flfv.encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output) + return flfv.execute(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output) -class WanVaceToVideo: +class WanVaceToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1000.0, "step": 0.01}), - }, - "optional": {"control_video": ("IMAGE", ), - "control_masks": ("MASK", ), - "reference_image": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="WanVaceToVideo", + category="conditioning/video_models", + is_experimental=True, + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + 
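                # A note on the V3 pattern used for every node in this commit: a
                # classmethod define_schema() returning io.Schema replaces the old
                # INPUT_TYPES / RETURN_TYPES / RETURN_NAMES / FUNCTION attributes,
                # and a classmethod execute(...) -> io.NodeOutput replaces the bound
                # method. A minimal node of this shape (hypothetical example, not
                # part of the patch):
                #
                #     class ExampleNode(io.ComfyNode):
                #         @classmethod
                #         def define_schema(cls):
                #             return io.Schema(
                #                 node_id="ExampleNode",
                #                 category="utils",
                #                 inputs=[io.Int.Input("value", default=0)],
                #                 outputs=[io.Int.Output(display_name="value_plus_one")],
                #             )
                #
                #         @classmethod
                #         def execute(cls, value) -> io.NodeOutput:
                #             return io.NodeOutput(value + 1)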
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Float.Input("strength", default=1.0, min=0.0, max=1000.0, step=0.01), + io.Image.Input("control_video", optional=True), + io.Mask.Input("control_masks", optional=True), + io.Image.Input("reference_image", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + io.Int.Output(display_name="trim_latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT", "INT") - RETURN_NAMES = ("positive", "negative", "latent", "trim_latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - EXPERIMENTAL = True - - def encode(self, positive, negative, vae, width, height, length, batch_size, strength, control_video=None, control_masks=None, reference_image=None): + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, strength, control_video=None, control_masks=None, reference_image=None) -> io.NodeOutput: latent_length = ((length - 1) // 4) + 1 if control_video is not None: control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) @@ -336,52 +361,59 @@ class WanVaceToVideo: latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device()) out_latent = {} out_latent["samples"] = latent - return (positive, negative, out_latent, trim_latent) + return io.NodeOutput(positive, negative, out_latent, trim_latent) -class TrimVideoLatent: +class TrimVideoLatent(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT",), - "trim_amount": ("INT", {"default": 0, "min": 0, "max": 99999}), - }} + def define_schema(cls): + return io.Schema( + node_id="TrimVideoLatent", + category="latent/video", + is_experimental=True, + inputs=[ + io.Latent.Input("samples"), + io.Int.Input("trim_amount", default=0, min=0, max=99999), + ], + outputs=[ + io.Latent.Output(), + ], + ) - RETURN_TYPES = ("LATENT",) - FUNCTION = "op" - - CATEGORY = "latent/video" - - EXPERIMENTAL = True - - def op(self, samples, trim_amount): + @classmethod + def execute(cls, samples, trim_amount) -> io.NodeOutput: samples_out = samples.copy() s1 = samples["samples"] samples_out["samples"] = s1[:, :, trim_amount:] - return (samples_out,) + return io.NodeOutput(samples_out) -class WanCameraImageToVideo: +class WanCameraImageToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), - "start_image": ("IMAGE", ), - "camera_conditions": ("WAN_CAMERA_EMBEDDING", ), - }} + def define_schema(cls): + return io.Schema( + node_id="WanCameraImageToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, 
max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Image.Input("start_image", optional=True), + io.WanCameraEmbedding.Input("camera_conditions", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, camera_conditions=None): + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, camera_conditions=None) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent) @@ -404,29 +436,34 @@ class WanCameraImageToVideo: out_latent = {} out_latent["samples"] = latent - return (positive, negative, out_latent) + return io.NodeOutput(positive, negative, out_latent) -class WanPhantomSubjectToVideo: +class WanPhantomSubjectToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"images": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="WanPhantomSubjectToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("images", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative_text"), + io.Conditioning.Output(display_name="negative_img_text"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative_text", "negative_img_text", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, positive, negative, vae, width, height, length, batch_size, images): + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, images) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, 
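        # Latent geometry shared by all of these nodes: Wan's VAE keeps the first
        # frame and compresses the remaining frames 4x in time — hence the
        # ((length - 1) // 4) + 1 term below — while compressing space 8x into
        # 16 latent channels.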
((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) cond2 = negative if images is not None: @@ -442,7 +479,7 @@ class WanPhantomSubjectToVideo: out_latent = {} out_latent["samples"] = latent - return (positive, cond2, negative, out_latent) + return io.NodeOutput(positive, cond2, negative, out_latent) def parse_json_tracks(tracks): """Parse JSON track data into a standardized format""" @@ -655,39 +692,41 @@ def patch_motion( return out_mask_full, out_feature_full -class WanTrackToVideo: +class WanTrackToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "tracks": ("STRING", {"multiline": True, "default": "[]"}), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "temperature": ("FLOAT", {"default": 220.0, "min": 1.0, "max": 1000.0, "step": 0.1}), - "topk": ("INT", {"default": 2, "min": 1, "max": 10}), - "start_image": ("IMAGE", ), - }, - "optional": { - "clip_vision_output": ("CLIP_VISION_OUTPUT", ), - }} + def define_schema(cls): + return io.Schema( + node_id="WanPhantomSubjectToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.String.Input("tracks", multiline=True, default="[]"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Float.Input("temperature", default=220.0, min=1.0, max=1000.0, step=0.1), + io.Int.Input("topk", default=2, min=1, max=10), + io.Image.Input("start_image"), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, positive, negative, vae, tracks, width, height, length, batch_size, - temperature, topk, start_image=None, clip_vision_output=None): + @classmethod + def execute(cls, positive, negative, vae, tracks, width, height, length, batch_size, + temperature, topk, start_image=None, clip_vision_output=None) -> io.NodeOutput: tracks_data = parse_json_tracks(tracks) if not tracks_data: - return WanImageToVideo().encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, clip_vision_output=clip_vision_output) + return WanImageToVideo().execute(positive, negative, vae, width, height, length, batch_size, start_image=start_image, clip_vision_output=clip_vision_output) latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) @@ -741,34 +780,36 @@ class WanTrackToVideo: out_latent = {} out_latent["samples"] = latent - return (positive, negative, out_latent) + return 
io.NodeOutput(positive, negative, out_latent) -class Wan22ImageToVideoLatent: +class Wan22ImageToVideoLatent(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"vae": ("VAE", ), - "width": ("INT", {"default": 1280, "min": 32, "max": nodes.MAX_RESOLUTION, "step": 32}), - "height": ("INT", {"default": 704, "min": 32, "max": nodes.MAX_RESOLUTION, "step": 32}), - "length": ("INT", {"default": 49, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"start_image": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="Wan22ImageToVideoLatent", + category="conditioning/inpaint", + inputs=[ + io.Vae.Input("vae"), + io.Int.Input("width", default=1280, min=32, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("height", default=704, min=32, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("length", default=49, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("start_image", optional=True), + ], + outputs=[ + io.Latent.Output(), + ], + ) - - RETURN_TYPES = ("LATENT",) - FUNCTION = "encode" - - CATEGORY = "conditioning/inpaint" - - def encode(self, vae, width, height, length, batch_size, start_image=None): + @classmethod + def execute(cls, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput: latent = torch.zeros([1, 48, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device()) if start_image is None: out_latent = {} out_latent["samples"] = latent - return (out_latent,) + return io.NodeOutput(out_latent) mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device()) @@ -783,19 +824,25 @@ class Wan22ImageToVideoLatent: latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask) out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1)) out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1)) - return (out_latent,) + return io.NodeOutput(out_latent) -NODE_CLASS_MAPPINGS = { - "WanTrackToVideo": WanTrackToVideo, - "WanImageToVideo": WanImageToVideo, - "WanFunControlToVideo": WanFunControlToVideo, - "Wan22FunControlToVideo": Wan22FunControlToVideo, - "WanFunInpaintToVideo": WanFunInpaintToVideo, - "WanFirstLastFrameToVideo": WanFirstLastFrameToVideo, - "WanVaceToVideo": WanVaceToVideo, - "TrimVideoLatent": TrimVideoLatent, - "WanCameraImageToVideo": WanCameraImageToVideo, - "WanPhantomSubjectToVideo": WanPhantomSubjectToVideo, - "Wan22ImageToVideoLatent": Wan22ImageToVideoLatent, -} +class WanExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + WanTrackToVideo, + WanImageToVideo, + WanFunControlToVideo, + Wan22FunControlToVideo, + WanFunInpaintToVideo, + WanFirstLastFrameToVideo, + WanVaceToVideo, + TrimVideoLatent, + WanCameraImageToVideo, + WanPhantomSubjectToVideo, + Wan22ImageToVideoLatent, + ] + +async def comfy_entrypoint() -> WanExtension: + return WanExtension() From ad19a069f68a19566632b9bda3e72f4eed8a22d8 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 14 Aug 2025 20:16:01 -0700 Subject: [PATCH 03/14] Make SLG nodes work on Qwen Image model. 
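The patch below threads `transformer_options` through the Qwen image transformer and consults a `patches_replace["dit"]` table, so callers can substitute individual double blocks. As a hedged sketch of the caller side — the "dit" / ("double_block", i) keys are taken from the hunk, but the exact `ModelPatcher.set_model_patch_replace` signature is an assumption:

    def skip_double_block(args, extra):
        # Returning the inputs unchanged effectively skips the wrapped block,
        # which is the core of skip-layer guidance (SLG).
        return {"img": args["img"], "txt": args["txt"]}

    def apply_slg(model, layer_index):
        m = model.clone()
        m.set_model_patch_replace(skip_double_block, "dit", "double_block", layer_index)
        return m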
(#9345) --- comfy/ldm/qwen_image/model.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index c15ab8e40..99843f88d 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -356,6 +356,7 @@ class QwenImageTransformer2DModel(nn.Module): context, attention_mask=None, guidance: torch.Tensor = None, + transformer_options={}, **kwargs ): timestep = timesteps @@ -383,14 +384,26 @@ class QwenImageTransformer2DModel(nn.Module): else self.time_text_embed(timestep, guidance, hidden_states) ) - for block in self.transformer_blocks: - encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - encoder_hidden_states_mask=encoder_hidden_states_mask, - temb=temb, - image_rotary_emb=image_rotary_emb, - ) + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + + for i, block in enumerate(self.transformer_blocks): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"]) + return out + out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb}, {"original_block": block_wrap}) + hidden_states = out["img"] + encoder_hidden_states = out["txt"] + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) From f0d5d0111f1f78bc8ce5d1f3968f19e40cd2ce7b Mon Sep 17 00:00:00 2001 From: "Xiangxi Guo (Ryan)" Date: Thu, 14 Aug 2025 20:41:37 -0700 Subject: [PATCH 04/14] Avoid torch compile graphbreak for older pytorch versions (#9344) Turns out torch.compile has some gaps in context manager decorator syntax support. I've sent patches to fix that in PyTorch, but it won't be available for all the folks running older versions of PyTorch, hence this trivial patch. --- comfy/ops.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index be312d714..2be35f76a 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -41,9 +41,11 @@ try: SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION) - @sdpa_kernel(backends=SDPA_BACKEND_PRIORITY, set_priority=True) def scaled_dot_product_attention(q, k, v, *args, **kwargs): - return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) + # Use this (rather than the decorator syntax) to eliminate graph + # break for pytorch < 2.9 + with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True): + return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) except (ModuleNotFoundError, TypeError): logging.warning("Could not set sdpa backend priority.") From 4e5c230f6a957962961794c07f02be748076c771 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 14 Aug 2025 20:44:02 -0700 Subject: [PATCH 05/14] Fix last commit not working on older pytorch. 
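This follow-up guards the new `set_priority` keyword behind an `inspect.signature` check, so the context-manager form introduced in the previous patch keeps working where the keyword does not exist. A condensed sketch of the combined result of patches 04 and 05 (the check runs once at import time, mirroring the diff below):

    import inspect
    import torch
    from torch.nn.attention import SDPBackend, sdpa_kernel

    PRIORITY = [SDPBackend.CUDNN_ATTENTION, SDPBackend.FLASH_ATTENTION,
                SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]

    if "set_priority" in inspect.signature(sdpa_kernel).parameters:
        def sdpa(q, k, v, *args, **kwargs):
            # Context-manager form instead of the @sdpa_kernel(...) decorator:
            # older torch.compile versions trace this without a graph break.
            with sdpa_kernel(PRIORITY, set_priority=True):
                return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
    else:
        sdpa = torch.nn.functional.scaled_dot_product_attention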
(#9346) --- comfy/ops.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 2be35f76a..18e7db705 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -32,20 +32,21 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs): try: if torch.cuda.is_available(): from torch.nn.attention import SDPBackend, sdpa_kernel + import inspect + if "set_priority" in inspect.signature(sdpa_kernel).parameters: + SDPA_BACKEND_PRIORITY = [ + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.MATH, + ] - SDPA_BACKEND_PRIORITY = [ - SDPBackend.FLASH_ATTENTION, - SDPBackend.EFFICIENT_ATTENTION, - SDPBackend.MATH, - ] + SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION) - SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION) - - def scaled_dot_product_attention(q, k, v, *args, **kwargs): - # Use this (rather than the decorator syntax) to eliminate graph - # break for pytorch < 2.9 - with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True): - return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) + def scaled_dot_product_attention(q, k, v, *args, **kwargs): + with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True): + return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) + else: + logging.warning("Torch version too old to set sdpa backend priority.") except (ModuleNotFoundError, TypeError): logging.warning("Could not set sdpa backend priority.") From e08ecfbd8a9deda8939b14d7f1ff7d7139f1a4ed Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 14 Aug 2025 21:22:26 -0700 Subject: [PATCH 06/14] Add warning when using old pytorch. (#9347) --- comfy/rmsnorm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/rmsnorm.py b/comfy/rmsnorm.py index 66ae8321d..555542a46 100644 --- a/comfy/rmsnorm.py +++ b/comfy/rmsnorm.py @@ -1,6 +1,7 @@ import torch import comfy.model_management import numbers +import logging RMSNorm = None @@ -9,6 +10,7 @@ try: RMSNorm = torch.nn.RMSNorm except: rms_norm_torch = None + logging.warning("Please update pytorch to use native RMSNorm") def rms_norm(x, weight=None, eps=1e-6): From 027c63f63a7f5f380a4df1057c548410b0a87606 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 15 Aug 2025 21:57:47 +0300 Subject: [PATCH 07/14] fix(OpenAIGPTImage1): set correct MIME type for multipart uploads to OpenAI edits (#9348) --- comfy_api_nodes/nodes_openai.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index ab3c5363b..cbff2b2d2 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -464,8 +464,6 @@ class OpenAIGPTImage1(ComfyNodeABC): path = "/proxy/openai/images/generations" content_type = "application/json" request_class = OpenAIImageGenerationRequest - img_binaries = [] - mask_binary = None files = [] if image is not None: @@ -484,14 +482,11 @@ class OpenAIGPTImage1(ComfyNodeABC): img_byte_arr = io.BytesIO() img.save(img_byte_arr, format="PNG") img_byte_arr.seek(0) - img_binary = img_byte_arr - img_binary.name = f"image_{i}.png" - img_binaries.append(img_binary) if batch_size == 1: - files.append(("image", img_binary)) + files.append(("image", (f"image_{i}.png", img_byte_arr, "image/png"))) else: - files.append(("image[]", img_binary)) + files.append(("image[]", (f"image_{i}.png", 
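                    # The requests-style (filename, fileobj, mime_type) 3-tuple
                    # sets an explicit Content-Type on each multipart part,
                    # which the previous BytesIO-with-.name approach never did.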
img_byte_arr, "image/png"))) if mask is not None: if image is None: @@ -511,9 +506,7 @@ class OpenAIGPTImage1(ComfyNodeABC): mask_img_byte_arr = io.BytesIO() mask_img.save(mask_img_byte_arr, format="PNG") mask_img_byte_arr.seek(0) - mask_binary = mask_img_byte_arr - mask_binary.name = "mask.png" - files.append(("mask", mask_binary)) + files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png"))) # Build the operation operation = SynchronousOperation( From c308a8840aebf06649364e8e175862250a2d8823 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 15 Aug 2025 12:50:39 -0700 Subject: [PATCH 08/14] Add FluxKontextMultiReferenceLatentMethod node. (#9356) This node is only useful if someone trains the kontext model to properly use multiple reference images via the index method. The default is the offset method which feeds the multiple images like if they were stitched together as one. This method works with the current flux kontext model. --- comfy/ldm/flux/model.py | 24 ++++++++++++++++-------- comfy/model_base.py | 4 ++++ comfy_extras/nodes_flux.py | 19 +++++++++++++++++++ 3 files changed, 39 insertions(+), 8 deletions(-) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 8f4d99f54..c4de82795 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -224,19 +224,27 @@ class Flux(nn.Module): if ref_latents is not None: h = 0 w = 0 + index = 0 + index_ref_method = kwargs.get("ref_latents_method", "offset") == "index" for ref in ref_latents: - h_offset = 0 - w_offset = 0 - if ref.shape[-2] + h > ref.shape[-1] + w: - w_offset = w + if index_ref_method: + index += 1 + h_offset = 0 + w_offset = 0 else: - h_offset = h + index = 1 + h_offset = 0 + w_offset = 0 + if ref.shape[-2] + h > ref.shape[-1] + w: + w_offset = w + else: + h_offset = h + h = max(h, ref.shape[-2] + h_offset) + w = max(w, ref.shape[-1] + w_offset) - kontext, kontext_ids = self.process_img(ref, index=1, h_offset=h_offset, w_offset=w_offset) + kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) img = torch.cat([img, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) - h = max(h, ref.shape[-2] + h_offset) - w = max(w, ref.shape[-1] + w_offset) txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) diff --git a/comfy/model_base.py b/comfy/model_base.py index cde61df7c..bf874b875 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -890,6 +890,10 @@ class Flux(BaseModel): for lat in ref_latents: latents.append(self.process_latent_in(lat)) out['ref_latents'] = comfy.conds.CONDList(latents) + + ref_latents_method = kwargs.get("reference_latents_method", None) + if ref_latents_method is not None: + out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method) return out def extra_conds_shapes(self, **kwargs): diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py index 8a8a17698..c8db75bb3 100644 --- a/comfy_extras/nodes_flux.py +++ b/comfy_extras/nodes_flux.py @@ -100,9 +100,28 @@ class FluxKontextImageScale: return (image, ) +class FluxKontextMultiReferenceLatentMethod: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "conditioning": ("CONDITIONING", ), + "reference_latents_method": (("offset", "index"), ), + }} + + RETURN_TYPES = ("CONDITIONING",) + FUNCTION 
= "append" + EXPERIMENTAL = True + + CATEGORY = "advanced/conditioning/flux" + + def append(self, conditioning, reference_latents_method): + c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method}) + return (c, ) + NODE_CLASS_MAPPINGS = { "CLIPTextEncodeFlux": CLIPTextEncodeFlux, "FluxGuidance": FluxGuidance, "FluxDisableGuidance": FluxDisableGuidance, "FluxKontextImageScale": FluxKontextImageScale, + "FluxKontextMultiReferenceLatentMethod": FluxKontextMultiReferenceLatentMethod, } From 1702e6df16b0a52e147f19e3d5c5548c25a64339 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 15 Aug 2025 14:29:58 -0700 Subject: [PATCH 09/14] Implement wan2.2 camera model. (#9357) Use the old WanCameraImageToVideo node. --- comfy/ldm/wan/model.py | 7 ++++++- comfy/model_detection.py | 5 ++++- comfy/supported_models.py | 14 +++++++++++++- comfy_extras/nodes_wan.py | 7 +++++-- 4 files changed, 28 insertions(+), 5 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 4e2d99566..9d3741be3 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -768,7 +768,12 @@ class CameraWanModel(WanModel): operations=None, ): - super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations) + if model_type == 'camera': + model_type = 'i2v' + else: + model_type = 't2v' + + super().__init__(model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations) operation_settings = {"operations": operations, "device": device, "dtype": dtype} self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 8acc51e20..2bec0541e 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -364,7 +364,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1] dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.') elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys: - dit_config["model_type"] = "camera" + if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: + dit_config["model_type"] = "camera" + else: + dit_config["model_type"] = "camera_2.2" else: if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "i2v" diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 156ff9e26..7ed6dfd69 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1046,6 +1046,18 @@ class WAN21_Camera(WAN21_T2V): def get_model(self, state_dict, prefix="", device=None): out 
= model_base.WAN21_Camera(self, image_to_video=False, device=device) return out + +class WAN22_Camera(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "camera_2.2", + "in_dim": 36, + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN21_Camera(self, image_to_video=False, device=device) + return out + class WAN21_Vace(WAN21_T2V): unet_config = { "image_model": "wan2.1", @@ -1260,6 +1272,6 @@ class QwenImage(supported_models_base.BASE): return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 694a183f6..83a990688 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -422,9 +422,12 @@ class WanCameraImageToVideo(io.ComfyNode): start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) concat_latent_image = vae.encode(start_image[:, :, :, :3]) concat_latent[:,:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] + mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1])) + mask[:, :, :start_image.shape[0] + 3] = 0.0 + mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2) - positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent}) - negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent}) + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask}) if camera_conditions is not None: positive = node_helpers.conditioning_set_values(positive, {'camera_conditions': camera_conditions}) From ed2e33c69a291094c4fcc13d8426c49844a6363c Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Fri, 15 Aug 2025 20:32:58 -0700 Subject: [PATCH 10/14] bump frontend version to 1.25.8 (#9361) --- requirements.txt | 2 +- 1 file changed, 1 
From ed2e33c69a291094c4fcc13d8426c49844a6363c Mon Sep 17 00:00:00 2001
From: Christian Byrne
Date: Fri, 15 Aug 2025 20:32:58 -0700
Subject: [PATCH 10/14] bump frontend version to 1.25.8 (#9361)

---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index 551002b5b..2ae44ebe1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-comfyui-frontend-package==1.24.4
+comfyui-frontend-package==1.25.8
 comfyui-workflow-templates==0.1.59
 comfyui-embedded-docs==0.2.6
 torch

From 20a84166d0d37dd6833caa6cadf3bfac8c241b48 Mon Sep 17 00:00:00 2001
From: Terry Jia
Date: Sat, 16 Aug 2025 02:07:12 -0400
Subject: [PATCH 11/14] record audio node (#8716)

* record audio node

* sf
---
 comfy_extras/nodes_audio.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py
index a90b31779..3b23f65d8 100644
--- a/comfy_extras/nodes_audio.py
+++ b/comfy_extras/nodes_audio.py
@@ -346,6 +346,24 @@ class LoadAudio:
             return "Invalid audio file: {}".format(audio)
         return True
 
+class RecordAudio:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"audio": ("AUDIO_RECORD", {})}}
+
+    CATEGORY = "audio"
+
+    RETURN_TYPES = ("AUDIO", )
+    FUNCTION = "load"
+
+    def load(self, audio):
+        audio_path = folder_paths.get_annotated_filepath(audio)
+
+        waveform, sample_rate = torchaudio.load(audio_path)
+        audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
+        return (audio, )
+
+
 NODE_CLASS_MAPPINGS = {
     "EmptyLatentAudio": EmptyLatentAudio,
     "VAEEncodeAudio": VAEEncodeAudio,
@@ -356,6 +374,7 @@ NODE_CLASS_MAPPINGS = {
     "LoadAudio": LoadAudio,
     "PreviewAudio": PreviewAudio,
     "ConditioningStableAudio": ConditioningStableAudio,
+    "RecordAudio": RecordAudio,
 }
 
 NODE_DISPLAY_NAME_MAPPINGS = {
@@ -367,4 +386,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "SaveAudio": "Save Audio (FLAC)",
     "SaveAudioMP3": "Save Audio (MP3)",
     "SaveAudioOpus": "Save Audio (Opus)",
+    "RecordAudio": "Record Audio",
 }
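RecordAudio is essentially LoadAudio with the file picker swapped for the frontend's AUDIO_RECORD widget: the widget hands back an annotated file path and the node decodes it into the standard AUDIO dict. Outside a workflow, the same decode step looks roughly like this (plain torchaudio; "recording.wav" is a placeholder path, not something the node writes):

    import torchaudio

    # What RecordAudio.load() does once it has a path: decode, then add a
    # leading batch dimension so downstream audio nodes see (B, C, T).
    waveform, sample_rate = torchaudio.load("recording.wav")
    audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
    print(audio["waveform"].shape, sample_rate)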
From 0f2b8525bcafe213e8421a49564a90f926e81f2e Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Sat, 16 Aug 2025 14:51:28 -0700
Subject: [PATCH 12/14] Qwen image model refactor. (#9375)

---
 comfy/ldm/qwen_image/model.py | 36 +++++++++++++++++++----------------
 1 file changed, 20 insertions(+), 16 deletions(-)

diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py
index 99843f88d..40d8fd979 100644
--- a/comfy/ldm/qwen_image/model.py
+++ b/comfy/ldm/qwen_image/model.py
@@ -333,21 +333,25 @@ class QwenImageTransformer2DModel(nn.Module):
         self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
         self.gradient_checkpointing = False
 
-    def pos_embeds(self, x, context):
+    def process_img(self, x, index=0, h_offset=0, w_offset=0):
         bs, c, t, h, w = x.shape
         patch_size = self.patch_size
+        hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
+        orig_shape = hidden_states.shape
+        hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
+        hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
+        hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
         h_len = ((h + (patch_size // 2)) // patch_size)
         w_len = ((w + (patch_size // 2)) // patch_size)
 
-        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
-        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
-        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
-        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+        h_offset = ((h_offset + (patch_size // 2)) // patch_size)
+        w_offset = ((w_offset + (patch_size // 2)) // patch_size)
 
-        txt_start = round(max(h_len, w_len))
-        txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(bs, 1, 3)
-        ids = torch.cat((txt_ids, img_ids), dim=1)
-        return self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
+        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
+        img_ids[:, :, 0] = img_ids[:, :, 1] + index
+        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
+        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
+        return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
 
     def forward(
         self,
@@ -363,13 +367,13 @@ class QwenImageTransformer2DModel(nn.Module):
         encoder_hidden_states = context
         encoder_hidden_states_mask = attention_mask
 
-        image_rotary_emb = self.pos_embeds(x, context)
+        hidden_states, img_ids, orig_shape = self.process_img(x)
+        num_embeds = hidden_states.shape[1]
 
-        hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
-        orig_shape = hidden_states.shape
-        hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
-        hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
-        hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
+        txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size), ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size)))
+        txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
+        ids = torch.cat((txt_ids, img_ids), dim=1)
+        image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
 
         hidden_states = self.img_in(hidden_states)
         encoder_hidden_states = self.txt_norm(encoder_hidden_states)
@@ -408,6 +412,6 @@ class QwenImageTransformer2DModel(nn.Module):
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
 
-        hidden_states = hidden_states.view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
+        hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
         hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
         return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
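The refactor above folds the 2x2 patchify and the RoPE coordinate grid into a single process_img() helper, so later changes can place several images in the same token sequence: index becomes the first coordinate axis and h_offset/w_offset shift the spatial grid. A toy sketch of the shape math, detached from the model (assumes patch_size=2 and an input already divisible by 2, so padding is skipped):

    import torch

    bs, c, h, w = 1, 16, 8, 8
    x = torch.randn(bs, c, h, w)

    # 2x2 patchify: (bs, c, h, w) -> (bs, (h/2)*(w/2), c*4)
    tokens = (x.view(bs, c, h // 2, 2, w // 2, 2)
               .permute(0, 2, 4, 1, 3, 5)
               .reshape(bs, (h // 2) * (w // 2), c * 4))

    # One (index, row, col) coordinate per token, shifted by the offsets.
    index, h_off, w_off = 0, 0, 0
    ids = torch.zeros(h // 2, w // 2, 3)
    ids[:, :, 0] = index
    ids[:, :, 1] = torch.arange(h_off, h_off + h // 2).float().unsqueeze(1)
    ids[:, :, 2] = torch.arange(w_off, w_off + w // 2).float().unsqueeze(0)
    print(tokens.shape, ids.shape)  # torch.Size([1, 16, 64]) torch.Size([4, 4, 3])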
From ed43784b0d04e5b8e8ff2c057fa84b9c5132aaf2 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Sun, 17 Aug 2025 13:45:39 -0700
Subject: [PATCH 13/14] WIP Qwen edit model: The diffusion model part. (#9383)

---
 comfy/ldm/qwen_image/model.py | 26 ++++++++++++++++++++++++++
 comfy/model_base.py           | 10 ++++++++++
 2 files changed, 36 insertions(+)

diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py
index 40d8fd979..a3c726299 100644
--- a/comfy/ldm/qwen_image/model.py
+++ b/comfy/ldm/qwen_image/model.py
@@ -360,6 +360,7 @@ class QwenImageTransformer2DModel(nn.Module):
         context,
         attention_mask=None,
         guidance: torch.Tensor = None,
+        ref_latents=None,
         transformer_options={},
         **kwargs
     ):
@@ -370,6 +371,31 @@ class QwenImageTransformer2DModel(nn.Module):
         hidden_states, img_ids, orig_shape = self.process_img(x)
         num_embeds = hidden_states.shape[1]
 
+        if ref_latents is not None:
+            h = 0
+            w = 0
+            index = 0
+            index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
+            for ref in ref_latents:
+                if index_ref_method:
+                    index += 1
+                    h_offset = 0
+                    w_offset = 0
+                else:
+                    index = 1
+                    h_offset = 0
+                    w_offset = 0
+                    if ref.shape[-2] + h > ref.shape[-1] + w:
+                        w_offset = w
+                    else:
+                        h_offset = h
+                    h = max(h, ref.shape[-2] + h_offset)
+                    w = max(w, ref.shape[-1] + w_offset)
+
+                kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
+                hidden_states = torch.cat([hidden_states, kontext], dim=1)
+                img_ids = torch.cat([img_ids, kontext_ids], dim=1)
+
         txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size), ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size)))
         txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
         ids = torch.cat((txt_ids, img_ids), dim=1)
diff --git a/comfy/model_base.py b/comfy/model_base.py
index bf874b875..15bd7abef 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1331,4 +1331,14 @@ class QwenImage(BaseModel):
         cross_attn = kwargs.get("cross_attn", None)
         if cross_attn is not None:
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+        ref_latents = kwargs.get("reference_latents", None)
+        if ref_latents is not None:
+            latents = []
+            for lat in ref_latents:
+                latents.append(self.process_latent_in(lat))
+            out['ref_latents'] = comfy.conds.CONDList(latents)
+
+        ref_latents_method = kwargs.get("reference_latents_method", None)
+        if ref_latents_method is not None:
+            out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
         return out
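Each reference latent is run through process_img() and its tokens appended after the main image; the model_base.py side simply wires the reference_latents conditioning through as a CONDList. Only the RoPE coordinates separate the two placement modes: with "index" (the Kontext-style default) every reference gets the next value on the first axis, while with "offset" all references share index 1 and are tiled in the spatial plane. A sketch of just the offset bookkeeping from the loop above, pulled out as a standalone helper:

    def offset_placements(ref_sizes):
        # ref_sizes: list of (height, width) latent sizes; returns the
        # (h_offset, w_offset) the "offset" branch would assign each one.
        h = w = 0
        placements = []
        for rh, rw in ref_sizes:
            h_offset = w_offset = 0
            if rh + h > rw + w:
                w_offset = w   # place it to the right
            else:
                h_offset = h   # place it below
            placements.append((h_offset, w_offset))
            h = max(h, rh + h_offset)
            w = max(w, rw + w_offset)
        return placements

    # Two 64x64 refs: the second lands below the first.
    print(offset_placements([(64, 64), (64, 64)]))  # [(0, 0), (64, 0)]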
From d4e353a94ec5a8cb15ed151990a9518b890e5d4f Mon Sep 17 00:00:00 2001
From: ComfyUI Wiki
Date: Mon, 18 Aug 2025 05:38:40 +0800
Subject: [PATCH 14/14] Update template to 0.1.60 (#9377)

---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index 2ae44ebe1..72a700028 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
 comfyui-frontend-package==1.25.8
-comfyui-workflow-templates==0.1.59
+comfyui-workflow-templates==0.1.60
 comfyui-embedded-docs==0.2.6
 torch
 torchsde