From fa7553138e3c75befe6aaf988048d4a0a95c1a32 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 1 May 2026 21:09:25 +0300 Subject: [PATCH 01/25] chore(api-nodes): remove Moonvalley API nodes (#13659) Signed-off-by: bigcat88 --- comfy_api_nodes/apis/moonvalley.py | 152 -------- comfy_api_nodes/nodes_moonvalley.py | 534 ---------------------------- 2 files changed, 686 deletions(-) delete mode 100644 comfy_api_nodes/apis/moonvalley.py delete mode 100644 comfy_api_nodes/nodes_moonvalley.py diff --git a/comfy_api_nodes/apis/moonvalley.py b/comfy_api_nodes/apis/moonvalley.py deleted file mode 100644 index 7ec7a4ade..000000000 --- a/comfy_api_nodes/apis/moonvalley.py +++ /dev/null @@ -1,152 +0,0 @@ -from enum import Enum -from typing import Optional, Dict, Any - -from pydantic import BaseModel, Field, StrictBytes - - -class MoonvalleyPromptResponse(BaseModel): - error: Optional[Dict[str, Any]] = None - frame_conditioning: Optional[Dict[str, Any]] = None - id: Optional[str] = None - inference_params: Optional[Dict[str, Any]] = None - meta: Optional[Dict[str, Any]] = None - model_params: Optional[Dict[str, Any]] = None - output_url: Optional[str] = None - prompt_text: Optional[str] = None - status: Optional[str] = None - - -class MoonvalleyTextToVideoInferenceParams(BaseModel): - add_quality_guidance: Optional[bool] = Field( - True, description='Whether to add quality guidance' - ) - caching_coefficient: Optional[float] = Field( - 0.3, description='Caching coefficient for optimization' - ) - caching_cooldown: Optional[int] = Field( - 3, description='Number of caching cooldown steps' - ) - caching_warmup: Optional[int] = Field( - 3, description='Number of caching warmup steps' - ) - clip_value: Optional[float] = Field( - 3, description='CLIP value for generation control' - ) - conditioning_frame_index: Optional[int] = Field( - 0, description='Index of the conditioning frame' - ) - cooldown_steps: Optional[int] = Field( - 75, description='Number of cooldown steps (calculated based on num_frames)' - ) - fps: Optional[int] = Field( - 24, description='Frames per second of the generated video' - ) - guidance_scale: Optional[float] = Field( - 10, description='Guidance scale for generation control' - ) - height: Optional[int] = Field( - 1080, description='Height of the generated video in pixels' - ) - negative_prompt: Optional[str] = Field(None, description='Negative prompt text') - num_frames: Optional[int] = Field(64, description='Number of frames to generate') - seed: Optional[int] = Field( - None, description='Random seed for generation (default: random)' - ) - shift_value: Optional[float] = Field( - 3, description='Shift value for generation control' - ) - steps: Optional[int] = Field(80, description='Number of denoising steps') - use_guidance_schedule: Optional[bool] = Field( - True, description='Whether to use guidance scheduling' - ) - use_negative_prompts: Optional[bool] = Field( - False, description='Whether to use negative prompts' - ) - use_timestep_transform: Optional[bool] = Field( - True, description='Whether to use timestep transformation' - ) - warmup_steps: Optional[int] = Field( - 0, description='Number of warmup steps (calculated based on num_frames)' - ) - width: Optional[int] = Field( - 1920, description='Width of the generated video in pixels' - ) - - -class MoonvalleyTextToVideoRequest(BaseModel): - image_url: Optional[str] = None - inference_params: Optional[MoonvalleyTextToVideoInferenceParams] = None - prompt_text: Optional[str] = None - webhook_url: Optional[str] = None - - -class MoonvalleyUploadFileRequest(BaseModel): - file: Optional[StrictBytes] = None - - -class MoonvalleyUploadFileResponse(BaseModel): - access_url: Optional[str] = None - - -class MoonvalleyVideoToVideoInferenceParams(BaseModel): - add_quality_guidance: Optional[bool] = Field( - True, description='Whether to add quality guidance' - ) - caching_coefficient: Optional[float] = Field( - 0.3, description='Caching coefficient for optimization' - ) - caching_cooldown: Optional[int] = Field( - 3, description='Number of caching cooldown steps' - ) - caching_warmup: Optional[int] = Field( - 3, description='Number of caching warmup steps' - ) - clip_value: Optional[float] = Field( - 3, description='CLIP value for generation control' - ) - conditioning_frame_index: Optional[int] = Field( - 0, description='Index of the conditioning frame' - ) - cooldown_steps: Optional[int] = Field( - 36, description='Number of cooldown steps (calculated based on num_frames)' - ) - guidance_scale: Optional[float] = Field( - 15, description='Guidance scale for generation control' - ) - negative_prompt: Optional[str] = Field(None, description='Negative prompt text') - seed: Optional[int] = Field( - None, description='Random seed for generation (default: random)' - ) - shift_value: Optional[float] = Field( - 3, description='Shift value for generation control' - ) - steps: Optional[int] = Field(80, description='Number of denoising steps') - use_guidance_schedule: Optional[bool] = Field( - True, description='Whether to use guidance scheduling' - ) - use_negative_prompts: Optional[bool] = Field( - False, description='Whether to use negative prompts' - ) - use_timestep_transform: Optional[bool] = Field( - True, description='Whether to use timestep transformation' - ) - warmup_steps: Optional[int] = Field( - 24, description='Number of warmup steps (calculated based on num_frames)' - ) - - -class ControlType(str, Enum): - motion_control = 'motion_control' - pose_control = 'pose_control' - - -class MoonvalleyVideoToVideoRequest(BaseModel): - control_type: ControlType = Field( - ..., description='Supported types for video control' - ) - inference_params: Optional[MoonvalleyVideoToVideoInferenceParams] = None - prompt_text: str = Field(..., description='Describes the video to generate') - video_url: str = Field(..., description='Url to control video') - webhook_url: Optional[str] = Field( - None, description='Optional webhook URL for notifications' - ) diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py deleted file mode 100644 index 78a230529..000000000 --- a/comfy_api_nodes/nodes_moonvalley.py +++ /dev/null @@ -1,534 +0,0 @@ -import logging - -from typing_extensions import override - -from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api_nodes.apis.moonvalley import ( - MoonvalleyPromptResponse, - MoonvalleyTextToVideoInferenceParams, - MoonvalleyTextToVideoRequest, - MoonvalleyVideoToVideoInferenceParams, - MoonvalleyVideoToVideoRequest, -) -from comfy_api_nodes.util import ( - ApiEndpoint, - download_url_to_video_output, - poll_op, - sync_op, - trim_video, - upload_images_to_comfyapi, - upload_video_to_comfyapi, - validate_container_format_is_mp4, - validate_image_dimensions, - validate_string, -) - -API_UPLOADS_ENDPOINT = "/proxy/moonvalley/uploads" -API_PROMPTS_ENDPOINT = "/proxy/moonvalley/prompts" -API_VIDEO2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/video-to-video" -API_TXT2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/text-to-video" -API_IMG2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/image-to-video" - -MIN_WIDTH = 300 -MIN_HEIGHT = 300 - -MAX_WIDTH = 10000 -MAX_HEIGHT = 10000 - -MIN_VID_WIDTH = 300 -MIN_VID_HEIGHT = 300 - -MAX_VID_WIDTH = 10000 -MAX_VID_HEIGHT = 10000 - -MAX_VIDEO_SIZE = 1024 * 1024 * 1024 # 1 GB max for in-memory video processing - -MOONVALLEY_MAREY_MAX_PROMPT_LENGTH = 5000 - - -def is_valid_task_creation_response(response: MoonvalleyPromptResponse) -> bool: - """Verifies that the initial response contains a task ID.""" - return bool(response.id) - - -def validate_task_creation_response(response) -> None: - if not is_valid_task_creation_response(response): - error_msg = f"Moonvalley Marey API: Initial request failed. Code: {response.code}, Message: {response.message}, Data: {response}" - logging.error(error_msg) - raise RuntimeError(error_msg) - - -def validate_video_to_video_input(video: Input.Video) -> Input.Video: - """ - Validates and processes video input for Moonvalley Video-to-Video generation. - - Args: - video: Input video to validate - - Returns: - Validated and potentially trimmed video - - Raises: - ValueError: If video doesn't meet requirements - MoonvalleyApiError: If video duration is too short - """ - width, height = _get_video_dimensions(video) - _validate_video_dimensions(width, height) - validate_container_format_is_mp4(video) - - return _validate_and_trim_duration(video) - - -def _get_video_dimensions(video: Input.Video) -> tuple[int, int]: - """Extracts video dimensions with error handling.""" - try: - return video.get_dimensions() - except Exception as e: - logging.error("Error getting dimensions of video: %s", e) - raise ValueError(f"Cannot get video dimensions: {e}") from e - - -def _validate_video_dimensions(width: int, height: int) -> None: - """Validates video dimensions meet Moonvalley V2V requirements.""" - supported_resolutions = { - (1920, 1080), - (1080, 1920), - (1152, 1152), - (1536, 1152), - (1152, 1536), - } - - if (width, height) not in supported_resolutions: - supported_list = ", ".join([f"{w}x{h}" for w, h in sorted(supported_resolutions)]) - raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}") - - -def _validate_and_trim_duration(video: Input.Video) -> Input.Video: - """Validates video duration and trims to 5 seconds if needed.""" - duration = video.get_duration() - _validate_minimum_duration(duration) - return _trim_if_too_long(video, duration) - - -def _validate_minimum_duration(duration: float) -> None: - """Ensures video is at least 5 seconds long.""" - if duration < 5: - raise ValueError("Input video must be at least 5 seconds long.") - - -def _trim_if_too_long(video: Input.Video, duration: float) -> Input.Video: - """Trims video to 5 seconds if longer.""" - if duration > 5: - return trim_video(video, 5) - return video - - -def parse_width_height_from_res(resolution: str): - # Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict - res_map = { - "16:9 (1920 x 1080)": {"width": 1920, "height": 1080}, - "9:16 (1080 x 1920)": {"width": 1080, "height": 1920}, - "1:1 (1152 x 1152)": {"width": 1152, "height": 1152}, - "4:3 (1536 x 1152)": {"width": 1536, "height": 1152}, - "3:4 (1152 x 1536)": {"width": 1152, "height": 1536}, - # "21:9 (2560 x 1080)": {"width": 2560, "height": 1080}, - } - return res_map.get(resolution, {"width": 1920, "height": 1080}) - - -def parse_control_parameter(value): - control_map = { - "Motion Transfer": "motion_control", - "Canny": "canny_control", - "Pose Transfer": "pose_control", - "Depth": "depth_control", - } - return control_map.get(value, control_map["Motion Transfer"]) - - -async def get_response(cls: type[IO.ComfyNode], task_id: str) -> MoonvalleyPromptResponse: - return await poll_op( - cls, - ApiEndpoint(path=f"{API_PROMPTS_ENDPOINT}/{task_id}"), - response_model=MoonvalleyPromptResponse, - status_extractor=lambda r: (r.status if r and r.status else None), - poll_interval=16.0, - max_poll_attempts=240, - ) - - -class MoonvalleyImg2VideoNode(IO.ComfyNode): - - @classmethod - def define_schema(cls) -> IO.Schema: - return IO.Schema( - node_id="MoonvalleyImg2VideoNode", - display_name="Moonvalley Marey Image to Video", - category="api node/video/Moonvalley Marey", - description="Moonvalley Marey Image to Video Node", - inputs=[ - IO.Image.Input( - "image", - tooltip="The reference image used to generate the video", - ), - IO.String.Input( - "prompt", - multiline=True, - ), - IO.String.Input( - "negative_prompt", - multiline=True, - default=" gopro, bright, contrast, static, overexposed, vignette, " - "artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, " - "flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, " - "cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, " - "blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, " - "wobbly, weird, low quality, plastic, stock footage, video camera, boring", - tooltip="Negative prompt text", - ), - IO.Combo.Input( - "resolution", - options=[ - "16:9 (1920 x 1080)", - "9:16 (1080 x 1920)", - "1:1 (1152 x 1152)", - "4:3 (1536 x 1152)", - "3:4 (1152 x 1536)", - # "21:9 (2560 x 1080)", - ], - default="16:9 (1920 x 1080)", - tooltip="Resolution of the output video", - ), - IO.Float.Input( - "prompt_adherence", - default=4.5, - min=1.0, - max=20.0, - step=1.0, - tooltip="Guidance scale for generation control", - ), - IO.Int.Input( - "seed", - default=9, - min=0, - max=4294967295, - step=1, - display_mode=IO.NumberDisplay.number, - tooltip="Random seed value", - control_after_generate=True, - ), - IO.Int.Input( - "steps", - default=80, - min=75, # steps should be greater or equal to cooldown_steps(75) + warmup_steps(0) - max=100, - step=1, - tooltip="Number of denoising steps", - ), - ], - outputs=[IO.Video.Output()], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(), - expr="""{"type":"usd","usd": 1.5}""", - ), - ) - - @classmethod - async def execute( - cls, - image: Input.Image, - prompt: str, - negative_prompt: str, - resolution: str, - prompt_adherence: float, - seed: int, - steps: int, - ) -> IO.NodeOutput: - validate_image_dimensions(image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH) - validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) - validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) - width_height = parse_width_height_from_res(resolution) - - inference_params = MoonvalleyTextToVideoInferenceParams( - negative_prompt=negative_prompt, - steps=steps, - seed=seed, - guidance_scale=prompt_adherence, - width=width_height["width"], - height=width_height["height"], - use_negative_prompts=True, - ) - - # Get MIME type from tensor - assuming PNG format for image tensors - mime_type = "image/png" - image_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type=mime_type))[0] - task_creation_response = await sync_op( - cls, - endpoint=ApiEndpoint(path=API_IMG2VIDEO_ENDPOINT, method="POST"), - response_model=MoonvalleyPromptResponse, - data=MoonvalleyTextToVideoRequest( - image_url=image_url, prompt_text=prompt, inference_params=inference_params - ), - ) - validate_task_creation_response(task_creation_response) - final_response = await get_response(cls, task_creation_response.id) - video = await download_url_to_video_output(final_response.output_url) - return IO.NodeOutput(video) - - -class MoonvalleyVideo2VideoNode(IO.ComfyNode): - - @classmethod - def define_schema(cls) -> IO.Schema: - return IO.Schema( - node_id="MoonvalleyVideo2VideoNode", - display_name="Moonvalley Marey Video to Video", - category="api node/video/Moonvalley Marey", - description="", - inputs=[ - IO.String.Input( - "prompt", - multiline=True, - tooltip="Describes the video to generate", - ), - IO.String.Input( - "negative_prompt", - multiline=True, - default=" gopro, bright, contrast, static, overexposed, vignette, " - "artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, " - "flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, " - "cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, " - "blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, " - "wobbly, weird, low quality, plastic, stock footage, video camera, boring", - tooltip="Negative prompt text", - ), - IO.Int.Input( - "seed", - default=9, - min=0, - max=4294967295, - step=1, - display_mode=IO.NumberDisplay.number, - tooltip="Random seed value", - control_after_generate=False, - ), - IO.Video.Input( - "video", - tooltip="The reference video used to generate the output video. Must be at least 5 seconds long. " - "Videos longer than 5s will be automatically trimmed. Only MP4 format supported.", - ), - IO.Combo.Input( - "control_type", - options=["Motion Transfer", "Pose Transfer"], - default="Motion Transfer", - optional=True, - ), - IO.Int.Input( - "motion_intensity", - default=100, - min=0, - max=100, - step=1, - tooltip="Only used if control_type is 'Motion Transfer'", - optional=True, - ), - IO.Int.Input( - "steps", - default=60, - min=60, # steps should be greater or equal to cooldown_steps(36) + warmup_steps(24) - max=100, - step=1, - display_mode=IO.NumberDisplay.number, - tooltip="Number of inference steps", - ), - ], - outputs=[IO.Video.Output()], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(), - expr="""{"type":"usd","usd": 2.25}""", - ), - ) - - @classmethod - async def execute( - cls, - prompt: str, - negative_prompt: str, - seed: int, - video: Input.Video | None = None, - control_type: str = "Motion Transfer", - motion_intensity: int | None = 100, - steps=60, - prompt_adherence=4.5, - ) -> IO.NodeOutput: - validated_video = validate_video_to_video_input(video) - video_url = await upload_video_to_comfyapi(cls, validated_video) - validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) - validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) - - # Only include motion_intensity for Motion Transfer - control_params = {} - if control_type == "Motion Transfer" and motion_intensity is not None: - control_params["motion_intensity"] = motion_intensity - - inference_params = MoonvalleyVideoToVideoInferenceParams( - negative_prompt=negative_prompt, - seed=seed, - control_params=control_params, - steps=steps, - guidance_scale=prompt_adherence, - ) - - task_creation_response = await sync_op( - cls, - endpoint=ApiEndpoint(path=API_VIDEO2VIDEO_ENDPOINT, method="POST"), - response_model=MoonvalleyPromptResponse, - data=MoonvalleyVideoToVideoRequest( - control_type=parse_control_parameter(control_type), - video_url=video_url, - prompt_text=prompt, - inference_params=inference_params, - ), - ) - validate_task_creation_response(task_creation_response) - final_response = await get_response(cls, task_creation_response.id) - return IO.NodeOutput(await download_url_to_video_output(final_response.output_url)) - - -class MoonvalleyTxt2VideoNode(IO.ComfyNode): - - @classmethod - def define_schema(cls) -> IO.Schema: - return IO.Schema( - node_id="MoonvalleyTxt2VideoNode", - display_name="Moonvalley Marey Text to Video", - category="api node/video/Moonvalley Marey", - description="", - inputs=[ - IO.String.Input( - "prompt", - multiline=True, - ), - IO.String.Input( - "negative_prompt", - multiline=True, - default=" gopro, bright, contrast, static, overexposed, vignette, " - "artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, " - "flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, " - "cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, " - "blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, " - "wobbly, weird, low quality, plastic, stock footage, video camera, boring", - tooltip="Negative prompt text", - ), - IO.Combo.Input( - "resolution", - options=[ - "16:9 (1920 x 1080)", - "9:16 (1080 x 1920)", - "1:1 (1152 x 1152)", - "4:3 (1536 x 1152)", - "3:4 (1152 x 1536)", - "21:9 (2560 x 1080)", - ], - default="16:9 (1920 x 1080)", - tooltip="Resolution of the output video", - ), - IO.Float.Input( - "prompt_adherence", - default=4.0, - min=1.0, - max=20.0, - step=1.0, - tooltip="Guidance scale for generation control", - ), - IO.Int.Input( - "seed", - default=9, - min=0, - max=4294967295, - step=1, - display_mode=IO.NumberDisplay.number, - control_after_generate=True, - tooltip="Random seed value", - ), - IO.Int.Input( - "steps", - default=80, - min=75, # steps should be greater or equal to cooldown_steps(75) + warmup_steps(0) - max=100, - step=1, - tooltip="Inference steps", - ), - ], - outputs=[IO.Video.Output()], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(), - expr="""{"type":"usd","usd": 1.5}""", - ), - ) - - @classmethod - async def execute( - cls, - prompt: str, - negative_prompt: str, - resolution: str, - prompt_adherence: float, - seed: int, - steps: int, - ) -> IO.NodeOutput: - validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) - validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) - width_height = parse_width_height_from_res(resolution) - - inference_params = MoonvalleyTextToVideoInferenceParams( - negative_prompt=negative_prompt, - steps=steps, - seed=seed, - guidance_scale=prompt_adherence, - num_frames=128, - width=width_height["width"], - height=width_height["height"], - ) - - task_creation_response = await sync_op( - cls, - endpoint=ApiEndpoint(path=API_TXT2VIDEO_ENDPOINT, method="POST"), - response_model=MoonvalleyPromptResponse, - data=MoonvalleyTextToVideoRequest(prompt_text=prompt, inference_params=inference_params), - ) - validate_task_creation_response(task_creation_response) - final_response = await get_response(cls, task_creation_response.id) - return IO.NodeOutput(await download_url_to_video_output(final_response.output_url)) - - -class MoonvalleyExtension(ComfyExtension): - @override - async def get_node_list(self) -> list[type[IO.ComfyNode]]: - return [ - MoonvalleyImg2VideoNode, - MoonvalleyTxt2VideoNode, - MoonvalleyVideo2VideoNode, - ] - - -async def comfy_entrypoint() -> MoonvalleyExtension: - return MoonvalleyExtension() From 10b45a71cdac2898693bb42aa0a21e2cb23a2daa Mon Sep 17 00:00:00 2001 From: "Daxiong (Lin)" Date: Sat, 2 May 2026 03:11:30 +0800 Subject: [PATCH 02/25] chore: update workflow templates to v0.9.66 (#13662) Co-authored-by: Jedrzej Kosinski --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index cb85d970b..932034076 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.42.15 -comfyui-workflow-templates==0.9.65 +comfyui-workflow-templates==0.9.66 comfyui-embedded-docs==0.4.4 torch torchsde From cf758bd2566a04a156496fa77ec2c7fa76ff8873 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 1 May 2026 22:48:41 +0300 Subject: [PATCH 03/25] chore(api-nodes): increase default timeout for partner API node tasks (#13663) Signed-off-by: bigcat88 Co-authored-by: Jedrzej Kosinski --- comfy_api_nodes/nodes_bytedance.py | 3 --- comfy_api_nodes/nodes_hitpaw.py | 2 -- comfy_api_nodes/nodes_kling.py | 3 --- comfy_api_nodes/nodes_magnific.py | 5 ----- comfy_api_nodes/nodes_topaz.py | 1 - comfy_api_nodes/nodes_vidu.py | 3 +-- comfy_api_nodes/nodes_wan.py | 1 - comfy_api_nodes/nodes_wavespeed.py | 2 -- comfy_api_nodes/util/client.py | 4 ++-- 9 files changed, 3 insertions(+), 21 deletions(-) diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index fee0ab888..2f241a775 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -1403,7 +1403,6 @@ class ByteDance2TextToVideoNode(IO.ComfyNode): status_extractor=lambda r: r.status, price_extractor=_seedance2_price_extractor(model_id, has_video_input=False), poll_interval=9, - max_poll_attempts=180, ) return IO.NodeOutput(await download_url_to_video_output(response.content.video_url)) @@ -1585,7 +1584,6 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode): status_extractor=lambda r: r.status, price_extractor=_seedance2_price_extractor(model_id, has_video_input=False), poll_interval=9, - max_poll_attempts=180, ) return IO.NodeOutput(await download_url_to_video_output(response.content.video_url)) @@ -1907,7 +1905,6 @@ class ByteDance2ReferenceNode(IO.ComfyNode): status_extractor=lambda r: r.status, price_extractor=_seedance2_price_extractor(model_id, has_video_input=has_video_input), poll_interval=9, - max_poll_attempts=180, ) return IO.NodeOutput(await download_url_to_video_output(response.content.video_url)) diff --git a/comfy_api_nodes/nodes_hitpaw.py b/comfy_api_nodes/nodes_hitpaw.py index 488080a74..bca5170e4 100644 --- a/comfy_api_nodes/nodes_hitpaw.py +++ b/comfy_api_nodes/nodes_hitpaw.py @@ -178,7 +178,6 @@ class HitPawGeneralImageEnhance(IO.ComfyNode): status_extractor=lambda x: x.data.status, price_extractor=lambda x: request_price, poll_interval=10.0, - max_poll_attempts=480, ) return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.res_url)) @@ -324,7 +323,6 @@ class HitPawVideoEnhance(IO.ComfyNode): status_extractor=lambda x: x.data.status, price_extractor=lambda x: request_price, poll_interval=10.0, - max_poll_attempts=320, ) return IO.NodeOutput(await download_url_to_video_output(final_response.data.res_url)) diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 709b3726c..efd58fac3 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -276,7 +276,6 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusRe cls, ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"), response_model=TaskStatusResponse, - max_poll_attempts=280, status_extractor=lambda r: (r.data.task_status if r.data else None), ) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) @@ -3062,7 +3061,6 @@ class KlingVideoNode(IO.ComfyNode): cls, ApiEndpoint(path=poll_path), response_model=TaskStatusResponse, - max_poll_attempts=280, status_extractor=lambda r: (r.data.task_status if r.data else None), ) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) @@ -3188,7 +3186,6 @@ class KlingFirstLastFrameNode(IO.ComfyNode): cls, ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"), response_model=TaskStatusResponse, - max_poll_attempts=280, status_extractor=lambda r: (r.data.task_status if r.data else None), ) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) diff --git a/comfy_api_nodes/nodes_magnific.py b/comfy_api_nodes/nodes_magnific.py index 0f53208d4..38b881fea 100644 --- a/comfy_api_nodes/nodes_magnific.py +++ b/comfy_api_nodes/nodes_magnific.py @@ -230,7 +230,6 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode): status_extractor=lambda x: x.status, price_extractor=lambda _: price_usd, poll_interval=10.0, - max_poll_attempts=480, ) return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) @@ -391,7 +390,6 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode): status_extractor=lambda x: x.status, price_extractor=lambda _: price_usd, poll_interval=10.0, - max_poll_attempts=480, ) return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) @@ -541,7 +539,6 @@ class MagnificImageStyleTransferNode(IO.ComfyNode): response_model=TaskResponse, status_extractor=lambda x: x.status, poll_interval=10.0, - max_poll_attempts=480, ) return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) @@ -782,7 +779,6 @@ class MagnificImageRelightNode(IO.ComfyNode): response_model=TaskResponse, status_extractor=lambda x: x.status, poll_interval=10.0, - max_poll_attempts=480, ) return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) @@ -924,7 +920,6 @@ class MagnificImageSkinEnhancerNode(IO.ComfyNode): response_model=TaskResponse, status_extractor=lambda x: x.status, poll_interval=10.0, - max_poll_attempts=480, ) return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) diff --git a/comfy_api_nodes/nodes_topaz.py b/comfy_api_nodes/nodes_topaz.py index b18b31af1..fe3666ec9 100644 --- a/comfy_api_nodes/nodes_topaz.py +++ b/comfy_api_nodes/nodes_topaz.py @@ -453,7 +453,6 @@ class TopazVideoEnhance(IO.ComfyNode): progress_extractor=lambda x: getattr(x, "progress", 0), price_extractor=lambda x: (x.estimates.cost[0] * 0.08 if x.estimates and x.estimates.cost[0] else None), poll_interval=10.0, - max_poll_attempts=320, ) return IO.NodeOutput(await download_url_to_video_output(final_response.download.url)) diff --git a/comfy_api_nodes/nodes_vidu.py b/comfy_api_nodes/nodes_vidu.py index f04407eb5..8d90cefeb 100644 --- a/comfy_api_nodes/nodes_vidu.py +++ b/comfy_api_nodes/nodes_vidu.py @@ -38,7 +38,7 @@ async def execute_task( cls: type[IO.ComfyNode], vidu_endpoint: str, payload: TaskCreationRequest | TaskExtendCreationRequest | TaskMultiFrameCreationRequest, - max_poll_attempts: int = 320, + max_poll_attempts: int = 480, ) -> list[TaskResult]: task_creation_response = await sync_op( cls, @@ -1097,7 +1097,6 @@ class ViduExtendVideoNode(IO.ComfyNode): video_url=await upload_video_to_comfyapi(cls, video, wait_label="Uploading video"), images=[image_url] if image_url else None, ), - max_poll_attempts=480, ) return IO.NodeOutput(await download_url_to_video_output(results[0].url)) diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py index 7d7466fb6..68061bb5c 100644 --- a/comfy_api_nodes/nodes_wan.py +++ b/comfy_api_nodes/nodes_wan.py @@ -818,7 +818,6 @@ class WanReferenceVideoApi(IO.ComfyNode): response_model=VideoTaskStatusResponse, status_extractor=lambda x: x.output.task_status, poll_interval=6, - max_poll_attempts=280, ) return IO.NodeOutput(await download_url_to_video_output(response.output.video_url)) diff --git a/comfy_api_nodes/nodes_wavespeed.py b/comfy_api_nodes/nodes_wavespeed.py index c59fafd3b..65e45f60a 100644 --- a/comfy_api_nodes/nodes_wavespeed.py +++ b/comfy_api_nodes/nodes_wavespeed.py @@ -84,7 +84,6 @@ class WavespeedFlashVSRNode(IO.ComfyNode): response_model=TaskResultResponse, status_extractor=lambda x: "failed" if x.data is None else x.data.status, poll_interval=10.0, - max_poll_attempts=480, ) if final_response.code != 200: raise ValueError( @@ -156,7 +155,6 @@ class WavespeedImageUpscaleNode(IO.ComfyNode): response_model=TaskResultResponse, status_extractor=lambda x: "failed" if x.data is None else x.data.status, poll_interval=10.0, - max_poll_attempts=480, ) if final_response.code != 200: raise ValueError( diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index b0cf97ae4..a0b8d35e1 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -148,7 +148,7 @@ async def poll_op( queued_statuses: list[str | int] | None = None, data: BaseModel | None = None, poll_interval: float = 5.0, - max_poll_attempts: int = 160, + max_poll_attempts: int = 480, timeout_per_poll: float = 120.0, max_retries_per_poll: int = 10, retry_delay_per_poll: float = 1.0, @@ -254,7 +254,7 @@ async def poll_op_raw( queued_statuses: list[str | int] | None = None, data: dict[str, Any] | BaseModel | None = None, poll_interval: float = 5.0, - max_poll_attempts: int = 160, + max_poll_attempts: int = 480, timeout_per_poll: float = 120.0, max_retries_per_poll: int = 10, retry_delay_per_poll: float = 1.0, From 63103d519ec960701438e8617452ef64b02609c7 Mon Sep 17 00:00:00 2001 From: Simon Lui <502929+simonlui@users.noreply.github.com> Date: Fri, 1 May 2026 14:16:41 -0700 Subject: [PATCH 04/25] Remove IPEX and clean up checks and add missing synchronize during empty cache. (#13653) --- comfy/cli_args.py | 1 - comfy/model_management.py | 18 +++--------------- 2 files changed, 3 insertions(+), 16 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index dbaadf723..cef1a5e6b 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -90,7 +90,6 @@ parser.add_argument("--force-channels-last", action="store_true", help="Force ch parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.") parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.") -parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize default when loading models with Intel's Extension for Pytorch.") parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.") class LatentPreviewMethod(enum.Enum): diff --git a/comfy/model_management.py b/comfy/model_management.py index 95af40012..f86e2a4aa 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -112,10 +112,6 @@ if args.directml is not None: # torch_directml.disable_tiled_resources(True) lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default. -try: - import intel_extension_for_pytorch as ipex # noqa: F401 -except: - pass try: _ = torch.xpu.device_count() @@ -583,9 +579,6 @@ class LoadedModel: real_model = self.model.model - if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None: - with torch.no_grad(): - real_model = ipex.optimize(real_model.eval(), inplace=True, graph_mode=True, concat_linear=True) self.real_model = weakref.ref(real_model) self.model_finalizer = weakref.finalize(real_model, cleanup_models) @@ -1581,10 +1574,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma return False if is_intel_xpu(): - if torch_version_numeric < (2, 3): - return True - else: - return torch.xpu.get_device_properties(device).has_fp16 + return torch.xpu.get_device_properties(device).has_fp16 if is_ascend_npu(): return True @@ -1650,10 +1640,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma return False if is_intel_xpu(): - if torch_version_numeric < (2, 3): - return True - else: - return torch.xpu.is_bf16_supported() + return torch.xpu.is_bf16_supported() if is_ascend_npu(): return True @@ -1784,6 +1771,7 @@ def soft_empty_cache(force=False): if cpu_state == CPUState.MPS: torch.mps.empty_cache() elif is_intel_xpu(): + torch.xpu.synchronize() torch.xpu.empty_cache() elif is_ascend_npu(): torch.npu.empty_cache() From b5921c8ac2d3cd1171bb33245f4343b1471224ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Sat, 2 May 2026 00:17:25 +0300 Subject: [PATCH 05/25] SDPose: resize fix (#13656) --- comfy_extras/nodes_sdpose.py | 38 ++++++++++++++++-------------------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/comfy_extras/nodes_sdpose.py b/comfy_extras/nodes_sdpose.py index 7d54967d5..96b6821bd 100644 --- a/comfy_extras/nodes_sdpose.py +++ b/comfy_extras/nodes_sdpose.py @@ -459,27 +459,23 @@ class SDPoseKeypointExtractor(io.ComfyNode): total_images = image.shape[0] captured_feat = None - model_h = int(head.heatmap_size[0]) * 4 # e.g. 192 * 4 = 768 - model_w = int(head.heatmap_size[1]) * 4 # e.g. 256 * 4 = 1024 + model_w = int(head.heatmap_size[0]) * 4 # 192 * 4 = 768 + model_h = int(head.heatmap_size[1]) * 4 # 256 * 4 = 1024 def _resize_to_model(imgs): - """Aspect-preserving resize + zero-pad BHWC images to (model_h, model_w). Returns (resized_bhwc, scale, pad_top, pad_left).""" + """Stretch BHWC images to (model_h, model_w), model expects no aspect preservation.""" h, w = imgs.shape[-3], imgs.shape[-2] - scale = min(model_h / h, model_w / w) - sh, sw = int(round(h * scale)), int(round(w * scale)) - pt, pl = (model_h - sh) // 2, (model_w - sw) // 2 + method = "area" if (model_h <= h and model_w <= w) else "bilinear" chw = imgs.permute(0, 3, 1, 2).float() - scaled = comfy.utils.common_upscale(chw, sw, sh, upscale_method="bilinear", crop="disabled") - padded = torch.zeros(scaled.shape[0], scaled.shape[1], model_h, model_w, dtype=scaled.dtype, device=scaled.device) - padded[:, :, pt:pt + sh, pl:pl + sw] = scaled - return padded.permute(0, 2, 3, 1), scale, pt, pl + scaled = comfy.utils.common_upscale(chw, model_w, model_h, upscale_method=method, crop="disabled") + return scaled.permute(0, 2, 3, 1), model_w / w, model_h / h - def _remap_keypoints(kp, scale, pad_top, pad_left, offset_x=0, offset_y=0): + def _remap_keypoints(kp, scale_x, scale_y, offset_x=0, offset_y=0): """Remap keypoints from model space back to original image space.""" kp = kp.copy() if isinstance(kp, np.ndarray) else np.array(kp, dtype=np.float32) invalid = kp[..., 0] < 0 - kp[..., 0] = (kp[..., 0] - pad_left) / scale + offset_x - kp[..., 1] = (kp[..., 1] - pad_top) / scale + offset_y + kp[..., 0] = kp[..., 0] / scale_x + offset_x + kp[..., 1] = kp[..., 1] / scale_y + offset_y kp[invalid] = -1 return kp @@ -529,18 +525,18 @@ class SDPoseKeypointExtractor(io.ComfyNode): continue crop = img[:, y1:y2, x1:x2, :] # (1, crop_h, crop_w, C) - crop_resized, scale, pad_top, pad_left = _resize_to_model(crop) + crop_resized, sx, sy = _resize_to_model(crop) latent_crop = vae.encode(crop_resized) kp_batch, sc_batch = _run_on_latent(latent_crop) - kp = _remap_keypoints(kp_batch[0], scale, pad_top, pad_left, x1, y1) + kp = _remap_keypoints(kp_batch[0], sx, sy, x1, y1) img_keypoints.append(kp) img_scores.append(sc_batch[0]) else: - img_resized, scale, pad_top, pad_left = _resize_to_model(img) + img_resized, sx, sy = _resize_to_model(img) latent_img = vae.encode(img_resized) kp_batch, sc_batch = _run_on_latent(latent_img) - img_keypoints.append(_remap_keypoints(kp_batch[0], scale, pad_top, pad_left)) + img_keypoints.append(_remap_keypoints(kp_batch[0], sx, sy)) img_scores.append(sc_batch[0]) all_keypoints.append(img_keypoints) @@ -549,12 +545,12 @@ class SDPoseKeypointExtractor(io.ComfyNode): else: # full-image mode, batched for batch_start in tqdm(range(0, total_images, batch_size), desc="Extracting keypoints"): - batch_resized, scale, pad_top, pad_left = _resize_to_model(image[batch_start:batch_start + batch_size]) + batch_resized, sx, sy = _resize_to_model(image[batch_start:batch_start + batch_size]) latent_batch = vae.encode(batch_resized) kp_batch, sc_batch = _run_on_latent(latent_batch) for kp, sc in zip(kp_batch, sc_batch): - all_keypoints.append([_remap_keypoints(kp, scale, pad_top, pad_left)]) + all_keypoints.append([_remap_keypoints(kp, sx, sy)]) all_scores.append([sc]) pbar.update(len(kp_batch)) @@ -727,13 +723,13 @@ class CropByBBoxes(io.ComfyNode): scale = min(output_width / crop_w, output_height / crop_h) scaled_w = int(round(crop_w * scale)) scaled_h = int(round(crop_h * scale)) - scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="bilinear", crop="disabled") + scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="area", crop="disabled") pad_left = (output_width - scaled_w) // 2 pad_top = (output_height - scaled_h) // 2 resized = torch.zeros(1, num_ch, output_height, output_width, dtype=image.dtype, device=image.device) resized[:, :, pad_top:pad_top + scaled_h, pad_left:pad_left + scaled_w] = scaled else: # "stretch" - resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="bilinear", crop="disabled") + resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="area", crop="disabled") crops.append(resized) if not crops: From 0230e0e7cc389979e509cd6237a7b9244798e69c Mon Sep 17 00:00:00 2001 From: Alexis Rolland Date: Sat, 2 May 2026 06:37:18 +0800 Subject: [PATCH 06/25] Adding kijai (#13664) Co-authored-by: Jedrzej Kosinski --- CODEOWNERS | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CODEOWNERS b/CODEOWNERS index e693955a0..946dbf946 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,2 +1,2 @@ # Admins -* @comfyanonymous @kosinkadink @guill @alexisrolland @rattus128 +* @comfyanonymous @kosinkadink @guill @alexisrolland @rattus128 @kijai From 67f6cb35273d00278d2b1ef2a8c3efe21238f22d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 1 May 2026 17:19:32 -0700 Subject: [PATCH 07/25] List all the portable downloads in the README section. (#13666) --- README.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index f05311421..3b5114633 100644 --- a/README.md +++ b/README.md @@ -193,13 +193,15 @@ If you have trouble extracting it, right click the file -> properties -> unblock The portable above currently comes with python 3.13 and pytorch cuda 13.0. Update your Nvidia drivers if it doesn't start. -#### Alternative Downloads: +#### All Official Portable Downloads: [Portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z) -[Experimental portable for Intel GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_intel.7z) +[Portable for Intel GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_intel.7z) -[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs). +[Portable for Nvidia GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia.7z) (supports 20 series and above). + +[Portable for Nvidia GPUs with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs). #### How do I share models between another UI and ComfyUI? From 3e3ed8cc2aaa142711e89e1e799956e1e57af62f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 1 May 2026 17:19:46 -0700 Subject: [PATCH 08/25] Add script in AMD portable to launch with dynamic vram. (#13667) --- ...ble_smart_memory.bat => run_amd_gpu_enable_dynamic_vram.bat} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename .ci/windows_amd_base_files/{run_amd_gpu_disable_smart_memory.bat => run_amd_gpu_enable_dynamic_vram.bat} (66%) diff --git a/.ci/windows_amd_base_files/run_amd_gpu_disable_smart_memory.bat b/.ci/windows_amd_base_files/run_amd_gpu_enable_dynamic_vram.bat similarity index 66% rename from .ci/windows_amd_base_files/run_amd_gpu_disable_smart_memory.bat rename to .ci/windows_amd_base_files/run_amd_gpu_enable_dynamic_vram.bat index cece0aeb2..94ad31942 100755 --- a/.ci/windows_amd_base_files/run_amd_gpu_disable_smart_memory.bat +++ b/.ci/windows_amd_base_files/run_amd_gpu_enable_dynamic_vram.bat @@ -1,2 +1,2 @@ -.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --disable-smart-memory +.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --enable-dynamic-vram pause From 783782d5d742a7bc38dd0b661e030813bc50839a Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sun, 3 May 2026 09:23:24 +1000 Subject: [PATCH 09/25] Implement block prefetch + Lora Async load + and adopt in LTX (Speedup!) (CORE-111) (#13618) * mm: Use Aimdo raw allocator for cast buffers pytorch manages allocation of growing buffers on streams poorly. Pyt has no windows support for the expandable segments allocator (which is the right tool for this job), while also segmenting the memory by stream such that it can be generally re-used. So kick the problem to aimdo which can just grow a virtual region thats freed per stream. * plan * ops: move cpu handler up to the caller * ops: split up prefetch from weight prep block prefetching API Split up the casting and weight formating/lora stuff in prep for arbitrary prefetch support. * ops: implement block prefetching API allow a model to construct a prefetch list and operate it for increased async offload. * ltxv2: Implement block prefetching * Implement lora async offload Implement async offload of loras. --- comfy/ldm/lightricks/av_model.py | 5 + comfy/lora.py | 15 +++ comfy/model_base.py | 5 + comfy/model_management.py | 22 +++- comfy/model_patcher.py | 13 ++- comfy/model_prefetch.py | 65 +++++++++++ comfy/ops.py | 181 ++++++++++++++++++++++--------- execution.py | 2 + 8 files changed, 251 insertions(+), 57 deletions(-) create mode 100644 comfy/model_prefetch.py diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py index 6f2ba41ef..3fb87b4a3 100644 --- a/comfy/ldm/lightricks/av_model.py +++ b/comfy/ldm/lightricks/av_model.py @@ -16,6 +16,7 @@ from comfy.ldm.lightricks.model import ( from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector import comfy.ldm.common_dit +import comfy.model_prefetch class CompressedTimestep: """Store video timestep embeddings in compressed form using per-frame indexing.""" @@ -907,9 +908,11 @@ class LTXAVModel(LTXVModel): """Process transformer blocks for LTXAV.""" patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + prefetch_queue = comfy.model_prefetch.make_prefetch_queue(list(self.transformer_blocks), vx.device, transformer_options) # Process transformer blocks for i, block in enumerate(self.transformer_blocks): + comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, block) if ("double_block", i) in blocks_replace: def block_wrap(args): @@ -982,6 +985,8 @@ class LTXAVModel(LTXVModel): a_prompt_timestep=a_prompt_timestep, ) + comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, None) + return [vx, ax] def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs): diff --git a/comfy/lora.py b/comfy/lora.py index e4337c729..db8f16bcb 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -17,6 +17,7 @@ """ from __future__ import annotations +import comfy.memory_management import comfy.utils import comfy.model_management import comfy.model_base @@ -473,3 +474,17 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori weight = old_weight return weight + +def prefetch_prepared_value(value, allocate_buffer, stream): + if isinstance(value, torch.Tensor): + dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value)) + comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream) + return comfy.memory_management.interpret_gathered_like([value], dest)[0] + elif isinstance(value, weight_adapter.WeightAdapterBase): + return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream)) + elif isinstance(value, tuple): + return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value) + elif isinstance(value, list): + return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value] + + return value diff --git a/comfy/model_base.py b/comfy/model_base.py index 50dab5782..b61a2aa09 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -214,6 +214,11 @@ class BaseModel(torch.nn.Module): if "latent_shapes" in extra_conds: xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes")) + transformer_options = transformer_options.copy() + transformer_options["prefetch_dynamic_vbars"] = ( + self.current_patcher is not None and self.current_patcher.is_dynamic() + ) + model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds) if len(model_output) > 1 and not torch.is_tensor(model_output): model_output, _ = utils.pack_latents(model_output) diff --git a/comfy/model_management.py b/comfy/model_management.py index f86e2a4aa..02ad66656 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -31,6 +31,7 @@ from contextlib import nullcontext import comfy.memory_management import comfy.utils import comfy.quant_ops +import comfy_aimdo.vram_buffer class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram @@ -1175,6 +1176,10 @@ stream_counters = {} STREAM_CAST_BUFFERS = {} LARGEST_CASTED_WEIGHT = (None, 0) +STREAM_AIMDO_CAST_BUFFERS = {} +LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) + +DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3 def get_cast_buffer(offload_stream, device, size, ref): global LARGEST_CASTED_WEIGHT @@ -1208,13 +1213,26 @@ def get_cast_buffer(offload_stream, device, size, ref): return cast_buffer +def get_aimdo_cast_buffer(offload_stream, device): + cast_buffer = STREAM_AIMDO_CAST_BUFFERS.get(offload_stream, None) + if cast_buffer is None: + cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index) + STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer + + return cast_buffer def reset_cast_buffers(): global LARGEST_CASTED_WEIGHT + global LARGEST_AIMDO_CASTED_WEIGHT + LARGEST_CASTED_WEIGHT = (None, 0) - for offload_stream in STREAM_CAST_BUFFERS: - offload_stream.synchronize() + LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) + for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS): + if offload_stream is not None: + offload_stream.synchronize() synchronize() + STREAM_CAST_BUFFERS.clear() + STREAM_AIMDO_CAST_BUFFERS.clear() soft_empty_cache() def get_offload_stream(device): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index e259aed63..7d2d6883f 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -121,9 +121,20 @@ class LowVramPatch: self.patches = patches self.convert_func = convert_func # TODO: remove self.set_func = set_func + self.prepared_patches = None + + def prepare(self, allocate_buffer, stream): + self.prepared_patches = [ + (patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4]) + for patch in self.patches[self.key] + ] + + def clear_prepared(self): + self.prepared_patches = None def __call__(self, weight): - return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype) + patches = self.prepared_patches if self.prepared_patches is not None else self.patches[self.key] + return comfy.lora.calculate_weight(patches, weight, self.key, intermediate_dtype=weight.dtype) LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2 diff --git a/comfy/model_prefetch.py b/comfy/model_prefetch.py new file mode 100644 index 000000000..0ad35deb5 --- /dev/null +++ b/comfy/model_prefetch.py @@ -0,0 +1,65 @@ +import comfy_aimdo.model_vbar +import comfy.model_management +import comfy.ops + +PREFETCH_QUEUES = [] + +def cleanup_prefetched_modules(comfy_modules): + for s in comfy_modules: + prefetch = getattr(s, "_prefetch", None) + if prefetch is None: + continue + for param_key in ("weight", "bias"): + lowvram_fn = getattr(s, param_key + "_lowvram_function", None) + if lowvram_fn is not None: + lowvram_fn.clear_prepared() + if prefetch["signature"] is not None: + comfy_aimdo.model_vbar.vbar_unpin(s._v) + delattr(s, "_prefetch") + +def cleanup_prefetch_queues(): + global PREFETCH_QUEUES + + for queue in PREFETCH_QUEUES: + for entry in queue: + if entry is None or not isinstance(entry, tuple): + continue + _, prefetch_state = entry + comfy_modules = prefetch_state[1] + if comfy_modules is not None: + cleanup_prefetched_modules(comfy_modules) + PREFETCH_QUEUES = [] + +def prefetch_queue_pop(queue, device, module): + if queue is None: + return + + consumed = queue.pop(0) + if consumed is not None: + offload_stream, prefetch_state = consumed + offload_stream.wait_stream(comfy.model_management.current_stream(device)) + _, comfy_modules = prefetch_state + if comfy_modules is not None: + cleanup_prefetched_modules(comfy_modules) + + prefetch = queue[0] + if prefetch is not None: + comfy_modules = [] + for s in prefetch.modules(): + if hasattr(s, "_v"): + comfy_modules.append(s) + + offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True) + comfy.model_management.sync_stream(device, offload_stream) + queue[0] = (offload_stream, (prefetch, comfy_modules)) + +def make_prefetch_queue(queue, device, transformer_options): + if (not transformer_options.get("prefetch_dynamic_vbars", False) + or comfy.model_management.NUM_STREAMS == 0 + or comfy.model_management.is_device_cpu(device) + or not comfy.model_management.device_supports_non_blocking(device)): + return None + + queue = [None] + queue + [None] + PREFETCH_QUEUES.append(queue) + return queue diff --git a/comfy/ops.py b/comfy/ops.py index 050f7cda0..96db1411c 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -86,38 +86,61 @@ def materialize_meta_param(s, param_keys): setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad)) -def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant): - #vbar doesn't support CPU weights, but some custom nodes have weird paths - #that might switch the layer to the CPU and expect it to work. We have to take - #a clone conservatively as we are mmapped and some SFT files are packed misaligned - #If you are a custom node author reading this, please move your layer to the GPU - #or declare your ModelPatcher as CPU in the first place. - if comfy.model_management.is_device_cpu(device): - materialize_meta_param(s, ["weight", "bias"]) - weight = s.weight.to(dtype=dtype, copy=True) - if isinstance(weight, QuantizedTensor): - weight = weight.dequantize() - bias = None - if s.bias is not None: - bias = s.bias.to(dtype=bias_dtype, copy=True) - return weight, bias, (None, None, None) - +# FIXME: add n=1 cache hit fast path +def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blocking): offload_stream = None - xfer_dest = None + cast_buffer = None + cast_buffer_offset = 0 + + def ensure_offload_stream(module, required_size, check_largest): + nonlocal offload_stream + nonlocal cast_buffer + + if offload_stream is None: + offload_stream = comfy.model_management.get_offload_stream(device) + if offload_stream is None or not check_largest or len(comfy_modules) != 1: + return + + current_size = 0 if cast_buffer is None else cast_buffer.size() + if current_size < required_size and module is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]: + offload_stream = comfy.model_management.get_offload_stream(device) + cast_buffer = None + if required_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]: + comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (module, required_size) + + def get_cast_buffer(buffer_size): + nonlocal offload_stream + nonlocal cast_buffer + nonlocal cast_buffer_offset + + if buffer_size == 0: + return None + + if offload_stream is None: + return torch.empty((buffer_size,), dtype=torch.uint8, device=device) + + cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device) + buffer = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(buffer_size, cast_buffer_offset), device) + cast_buffer_offset += buffer_size + return buffer + + for s in comfy_modules: + signature = comfy_aimdo.model_vbar.vbar_fault(s._v) + resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) + prefetch = { + "signature": signature, + "resident": resident, + } - signature = comfy_aimdo.model_vbar.vbar_fault(s._v) - resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) - if signature is not None: if resident: - weight = s._v_weight - bias = s._v_bias - else: - xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) + s._prefetch = prefetch + continue - if not resident: materialize_meta_param(s, ["weight", "bias"]) + xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) if signature is not None else None cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ]) cast_dest = None + needs_cast = False xfer_source = [ s.weight, s.bias ] @@ -129,22 +152,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu if data is None: continue if data.dtype != geometry.dtype: + needs_cast = True cast_dest = xfer_dest - if cast_dest is None: - cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device) xfer_dest = None break dest_size = comfy.memory_management.vram_aligned_size(xfer_source) - offload_stream = comfy.model_management.get_offload_stream(device) - if xfer_dest is None and offload_stream is not None: - xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s) - if xfer_dest is None: - offload_stream = comfy.model_management.get_offload_stream(device) - xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s) + ensure_offload_stream(s, dest_size if xfer_dest is None else 0, True) if xfer_dest is None: - xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device) - offload_stream = None + xfer_dest = get_cast_buffer(dest_size) if signature is None and pin is None: comfy.pinned_memory.pin_memory(s) @@ -157,27 +173,54 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu xfer_source = [ pin ] #send it over comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream) - comfy.model_management.sync_stream(device, offload_stream) - if cast_dest is not None: + for param_key in ("weight", "bias"): + lowvram_fn = getattr(s, param_key + "_lowvram_function", None) + if lowvram_fn is not None: + ensure_offload_stream(s, cast_buffer_offset, False) + lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream) + + prefetch["xfer_dest"] = xfer_dest + prefetch["cast_dest"] = cast_dest + prefetch["cast_geometry"] = cast_geometry + prefetch["needs_cast"] = needs_cast + s._prefetch = prefetch + + return offload_stream + + +def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant): + + prefetch = getattr(s, "_prefetch", None) + + if prefetch["resident"]: + weight = s._v_weight + bias = s._v_bias + else: + xfer_dest = prefetch["xfer_dest"] + if prefetch["needs_cast"]: + cast_dest = prefetch["cast_dest"] if prefetch["cast_dest"] is not None else torch.empty((comfy.memory_management.vram_aligned_size(prefetch["cast_geometry"]),), dtype=torch.uint8, device=device) for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest), - comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)): + comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], cast_dest)): if post_cast is not None: post_cast.copy_(pre_cast) xfer_dest = cast_dest - params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest) + params = comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], xfer_dest) weight = params[0] bias = params[1] - if signature is not None: + if prefetch["signature"] is not None: s._v_weight = weight s._v_bias = bias - s._v_signature=signature + s._v_signature = prefetch["signature"] def post_cast(s, param_key, x, dtype, resident, update_weight): lowvram_fn = getattr(s, param_key + "_lowvram_function", None) fns = getattr(s, param_key + "_function", []) + if x is None: + return None + orig = x def to_dequant(tensor, dtype): @@ -205,14 +248,12 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu x = f(x) return x - update_weight = signature is not None + update_weight = prefetch["signature"] is not None + weight = post_cast(s, "weight", weight, dtype, prefetch["resident"], update_weight) + if bias is not None: + bias = post_cast(s, "bias", bias, bias_dtype, prefetch["resident"], update_weight) - weight = post_cast(s, "weight", weight, dtype, resident, update_weight) - if s.bias is not None: - bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight) - - #FIXME: weird offload return protocol - return weight, bias, (offload_stream, device if signature is not None else None, None) + return weight, bias def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False): @@ -230,10 +271,46 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if device is None: device = input.device + def format_return(result, offloadable): + weight, bias, offload_stream = result + return (weight, bias, offload_stream) if offloadable else (weight, bias) + non_blocking = comfy.model_management.device_supports_non_blocking(device) if hasattr(s, "_v"): - return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant) + + #vbar doesn't support CPU weights, but some custom nodes have weird paths + #that might switch the layer to the CPU and expect it to work. We have to take + #a clone conservatively as we are mmapped and some SFT files are packed misaligned + #If you are a custom node author reading this, please move your layer to the GPU + #or declare your ModelPatcher as CPU in the first place. + if comfy.model_management.is_device_cpu(device): + materialize_meta_param(s, ["weight", "bias"]) + weight = s.weight.to(dtype=dtype, copy=True) + if isinstance(weight, QuantizedTensor): + weight = weight.dequantize() + bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None + return format_return((weight, bias, (None, None, None)), offloadable) + + prefetched = hasattr(s, "_prefetch") + offload_stream = None + offload_device = None + if not prefetched: + offload_stream = cast_modules_with_vbar([s], dtype, device, bias_dtype, non_blocking) + comfy.model_management.sync_stream(device, offload_stream) + + weight, bias = resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant) + + if not prefetched: + if getattr(s, "_prefetch")["signature"] is not None: + offload_device = device + for param_key in ("weight", "bias"): + lowvram_fn = getattr(s, param_key + "_lowvram_function", None) + if lowvram_fn is not None: + lowvram_fn.clear_prepared() + delattr(s, "_prefetch") + return format_return((weight, bias, (offload_stream, offload_device, None)), offloadable) + if offloadable and (device != s.weight.device or (s.bias is not None and device != s.bias.device)): @@ -280,11 +357,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of for f in s.weight_function: weight = f(weight) - if offloadable: - return weight, bias, (offload_stream, weight_a, bias_a) - else: - #Legacy function signature - return weight, bias + return format_return((weight, bias, (offload_stream, weight_a, bias_a)), offloadable) def uncast_bias_weight(s, weight, bias, offload_stream): diff --git a/execution.py b/execution.py index 5a6d3404c..654db8426 100644 --- a/execution.py +++ b/execution.py @@ -15,6 +15,7 @@ import torch from comfy.cli_args import args import comfy.memory_management import comfy.model_management +import comfy.model_prefetch import comfy_aimdo.model_vbar from latent_preview import set_preview_method @@ -537,6 +538,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, if args.verbose == "DEBUG": comfy_aimdo.control.analyze() comfy.model_management.reset_cast_buffers() + comfy.model_prefetch.cleanup_prefetch_queues() comfy_aimdo.model_vbar.vbars_reset_watermark_limits() if has_pending_tasks: From ef6722f6be7bf073d225d21da47354905a6abd2b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 2 May 2026 17:34:27 -0700 Subject: [PATCH 10/25] Some cleanups to the load image node. (#13677) --- nodes.py | 29 ++++++++++------------------- 1 file changed, 10 insertions(+), 19 deletions(-) diff --git a/nodes.py b/nodes.py index 99dc07227..710cccffe 100644 --- a/nodes.py +++ b/nodes.py @@ -1694,26 +1694,27 @@ class LoadImage: RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "load_image" + def load_image(self, image): image_path = folder_paths.get_annotated_filepath(image) + dtype = comfy.model_management.intermediate_dtype() + device = comfy.model_management.intermediate_device() + components = InputImpl.VideoFromFile(image_path).get_components() if components.images.shape[0] > 0: - return (components.images, 1.0 - components.alpha[..., -1] if components.alpha is not None else torch.zeros((components.images.shape[0], 64, 64), dtype=torch.float32, device="cpu")) + return (components.images.to(device=device, dtype=dtype), (1.0 - components.alpha[..., -1]).to(device=device, dtype=dtype) if components.alpha is not None else torch.zeros((components.images.shape[0], 64, 64), dtype=dtype, device=device)) + # This code is left here to handle animated webp which pyav does not support loading img = node_helpers.pillow(Image.open, image_path) output_images = [] output_masks = [] w, h = None, None - dtype = comfy.model_management.intermediate_dtype() - for i in ImageSequence.Iterator(img): i = node_helpers.pillow(ImageOps.exif_transpose, i) - if i.mode == 'I': - i = i.point(lambda i: i * (1 / 255)) image = i.convert("RGB") if len(output_images) == 0: @@ -1728,25 +1729,15 @@ class LoadImage: if 'A' in i.getbands(): mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 mask = 1. - torch.from_numpy(mask) - elif i.mode == 'P' and 'transparency' in i.info: - mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0 - mask = 1. - torch.from_numpy(mask) else: - mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") + mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu") output_images.append(image.to(dtype=dtype)) output_masks.append(mask.unsqueeze(0).to(dtype=dtype)) - if img.format == "MPO": - break # ignore all frames except the first one for MPO format + output_image = torch.cat(output_images, dim=0) + output_mask = torch.cat(output_masks, dim=0) - if len(output_images) > 1: - output_image = torch.cat(output_images, dim=0) - output_mask = torch.cat(output_masks, dim=0) - else: - output_image = output_images[0] - output_mask = output_masks[0] - - return (output_image, output_mask) + return (output_image.to(device=device, dtype=dtype), output_mask.to(device=device, dtype=dtype)) @classmethod def IS_CHANGED(s, image): From 1d23a875ed0d4644538265635c1259be08a3370e Mon Sep 17 00:00:00 2001 From: "Daxiong (Lin)" Date: Sun, 3 May 2026 10:06:55 +0800 Subject: [PATCH 11/25] chore: update workflow templates to v0.9.68 (#13678) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 932034076..32826e25a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.42.15 -comfyui-workflow-templates==0.9.66 +comfyui-workflow-templates==0.9.68 comfyui-embedded-docs==0.4.4 torch torchsde From f756d801a1e5fbafe81cdfdf8a1c0aadf54c9bea Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sun, 3 May 2026 05:29:00 +0300 Subject: [PATCH 12/25] [Partner Nodes] Topaz Astra 2 model (#13672) * feat(api-nodes): add Topaz Astra 2 model Signed-off-by: bigcat88 * feat(api-nodes): make Astra 2 the default Topaz upscaler model Reorder UPSCALER_MODELS_MAP and the upscaler_model dynamic combo so "Astra 2" appears first, surfacing it as the default selection. --------- Signed-off-by: bigcat88 Co-authored-by: Marwan Mostafa --- comfy_api_nodes/apis/topaz.py | 9 +- comfy_api_nodes/nodes_topaz.py | 361 ++++++++++++++++++++++++++++++++- 2 files changed, 365 insertions(+), 5 deletions(-) diff --git a/comfy_api_nodes/apis/topaz.py b/comfy_api_nodes/apis/topaz.py index a9e6235a7..f91980e3d 100644 --- a/comfy_api_nodes/apis/topaz.py +++ b/comfy_api_nodes/apis/topaz.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional from pydantic import BaseModel, Field @@ -72,8 +72,11 @@ class VideoEnhancementFilter(BaseModel): grain: Optional[float] = Field(None, description="Grain after AI model processing") grainSize: Optional[float] = Field(None, description="Size of generated grain") recoverOriginalDetailValue: Optional[float] = Field(None, description="Source details into the output video") - creativity: Optional[str] = Field(None, description="Creativity level(high, low) for slc-1 only") + creativity: float | str | None = Field(None, description="slc-1/slp-2.5: enum (low/middle/high). ast-2: decimal 0.0-1.0.") isOptimizedMode: Optional[bool] = Field(None, description="Set to true for Starlight Creative (slc-1) only") + prompt: str | None = Field(None, description="Descriptive scene prompt (ast-2 only)") + sharp: float | None = Field(None, description="ast-2 pre-enhance sharpness") + realism: float | None = Field(None, description="ast-2 realism control") class OutputInformationVideo(BaseModel): @@ -90,7 +93,7 @@ class Overrides(BaseModel): class CreateVideoRequest(BaseModel): source: CreateVideoRequestSource = Field(...) - filters: list[Union[VideoFrameInterpolationFilter, VideoEnhancementFilter]] = Field(...) + filters: list[VideoFrameInterpolationFilter | VideoEnhancementFilter] = Field(...) output: OutputInformationVideo = Field(...) overrides: Overrides = Field(Overrides(isPaidDiffusion=True)) diff --git a/comfy_api_nodes/nodes_topaz.py b/comfy_api_nodes/nodes_topaz.py index fe3666ec9..e79c16d3c 100644 --- a/comfy_api_nodes/nodes_topaz.py +++ b/comfy_api_nodes/nodes_topaz.py @@ -36,11 +36,15 @@ from comfy_api_nodes.util import ( ) UPSCALER_MODELS_MAP = { + "Astra 2": "ast-2", "Starlight (Astra) Fast": "slf-1", "Starlight (Astra) Creative": "slc-1", "Starlight Precise 2.5": "slp-2.5", } +AST2_MAX_FRAMES = 9000 +AST2_MAX_FRAMES_WITH_PROMPT = 450 + class TopazImageEnhance(IO.ComfyNode): @classmethod @@ -230,13 +234,20 @@ class TopazVideoEnhance(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="TopazVideoEnhance", - display_name="Topaz Video Enhance", + display_name="Topaz Video Enhance (Legacy)", category="api node/video/Topaz", description="Breathe new life into video with powerful upscaling and recovery technology.", inputs=[ IO.Video.Input("video"), IO.Boolean.Input("upscaler_enabled", default=True), - IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())), + IO.Combo.Input( + "upscaler_model", + options=[ + "Starlight (Astra) Fast", + "Starlight (Astra) Creative", + "Starlight Precise 2.5", + ], + ), IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]), IO.Combo.Input( "upscaler_creativity", @@ -304,6 +315,7 @@ class TopazVideoEnhance(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod @@ -457,12 +469,357 @@ class TopazVideoEnhance(IO.ComfyNode): return IO.NodeOutput(await download_url_to_video_output(final_response.download.url)) +class TopazVideoEnhanceV2(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TopazVideoEnhanceV2", + display_name="Topaz Video Enhance", + category="api node/video/Topaz", + description="Breathe new life into video with powerful upscaling and recovery technology.", + inputs=[ + IO.Video.Input("video"), + IO.DynamicCombo.Input( + "upscaler_model", + options=[ + IO.DynamicCombo.Option( + "Astra 2", + [ + IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]), + IO.Float.Input( + "creativity", + default=0.5, + min=0.0, + max=1.0, + step=0.1, + display_mode=IO.NumberDisplay.slider, + tooltip="Creative strength of the upscale.", + ), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Optional descriptive (not instructive) scene prompt." + f"Capping input at {AST2_MAX_FRAMES_WITH_PROMPT} frames (~15s @ 30fps) when set.", + ), + IO.Float.Input( + "sharp", + default=0.5, + min=0.0, + max=1.0, + step=0.01, + display_mode=IO.NumberDisplay.slider, + tooltip="Pre-enhance sharpness: " + "0.0=Gaussian blur, 0.5=passthrough (default), 1.0=USM sharpening.", + advanced=True, + ), + IO.Float.Input( + "realism", + default=0.0, + min=0.0, + max=1.0, + step=0.01, + display_mode=IO.NumberDisplay.slider, + tooltip="Pulls output toward photographic realism." + "Leave at 0 for the model default.", + advanced=True, + ), + ], + ), + IO.DynamicCombo.Option( + "Starlight (Astra) Fast", + [IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),], + ), + IO.DynamicCombo.Option( + "Starlight (Astra) Creative", + [ + IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]), + IO.Combo.Input( + "creativity", + options=["low", "middle", "high"], + default="low", + tooltip="Creative strength of the upscale.", + ), + ], + ), + IO.DynamicCombo.Option( + "Starlight Precise 2.5", + [IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"])], + ), + IO.DynamicCombo.Option("Disabled", []), + ], + ), + IO.DynamicCombo.Input( + "interpolation_model", + options=[ + IO.DynamicCombo.Option("Disabled", []), + IO.DynamicCombo.Option( + "apo-8", + [ + IO.Int.Input( + "interpolation_frame_rate", + default=60, + min=15, + max=240, + display_mode=IO.NumberDisplay.number, + tooltip="Output frame rate.", + ), + IO.Int.Input( + "interpolation_slowmo", + default=1, + min=1, + max=16, + display_mode=IO.NumberDisplay.number, + tooltip="Slow-motion factor applied to the input video. " + "For example, 2 makes the output twice as slow and doubles the duration.", + advanced=True, + ), + IO.Boolean.Input( + "interpolation_duplicate", + default=False, + tooltip="Analyze the input for duplicate frames and remove them.", + advanced=True, + ), + IO.Float.Input( + "interpolation_duplicate_threshold", + default=0.01, + min=0.001, + max=0.1, + step=0.001, + display_mode=IO.NumberDisplay.number, + tooltip="Detection sensitivity for duplicate frames.", + advanced=True, + ), + ], + ), + ], + ), + IO.Combo.Input( + "dynamic_compression_level", + options=["Low", "Mid", "High"], + default="Low", + tooltip="CQP level.", + optional=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=[ + "upscaler_model", + "upscaler_model.upscaler_resolution", + "interpolation_model", + ]), + expr=""" + ( + $model := $lookup(widgets, "upscaler_model"); + $res := $lookup(widgets, "upscaler_model.upscaler_resolution"); + $interp := $lookup(widgets, "interpolation_model"); + $is4k := $contains($res, "4k"); + $hasInterp := $interp != "disabled"; + $rates := { + "starlight (astra) fast": {"hd": 0.43, "uhd": 0.85}, + "starlight precise 2.5": {"hd": 0.70, "uhd": 1.54}, + "astra 2": {"hd": 1.72, "uhd": 2.85}, + "starlight (astra) creative": {"hd": 2.25, "uhd": 3.99} + }; + $surcharge := $is4k ? 0.28 : 0.14; + $entry := $lookup($rates, $model); + $base := $is4k ? $entry.uhd : $entry.hd; + $hi := $base + ($hasInterp ? $surcharge : 0); + $model = "disabled" + ? {"type":"text","text":"Interpolation only"} + : ($hasInterp + ? {"type":"text","text":"~" & $string($base) & "–" & $string($hi) & " credits/src frame"} + : {"type":"text","text":"~" & $string($base) & " credits/src frame"}) + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + video: Input.Video, + upscaler_model: dict, + interpolation_model: dict, + dynamic_compression_level: str = "Low", + ) -> IO.NodeOutput: + upscaler_choice = upscaler_model["upscaler_model"] + interpolation_choice = interpolation_model["interpolation_model"] + if upscaler_choice == "Disabled" and interpolation_choice == "Disabled": + raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.") + validate_container_format_is_mp4(video) + src_width, src_height = video.get_dimensions() + src_frame_rate = int(video.get_frame_rate()) + duration_sec = video.get_duration() + src_video_stream = video.get_stream_source() + target_width = src_width + target_height = src_height + target_frame_rate = src_frame_rate + filters = [] + if upscaler_choice != "Disabled": + if "1080p" in upscaler_model["upscaler_resolution"]: + target_pixel_p = 1080 + max_long_side = 1920 + else: + target_pixel_p = 2160 + max_long_side = 3840 + ar = src_width / src_height + if src_width >= src_height: + # Landscape or Square; Attempt to set height to target (e.g., 2160), calculate width + target_height = target_pixel_p + target_width = int(target_height * ar) + # Check if width exceeds standard bounds (for ultra-wide e.g., 21:9 ARs) + if target_width > max_long_side: + target_width = max_long_side + target_height = int(target_width / ar) + else: + # Portrait; Attempt to set width to target (e.g., 2160), calculate height + target_width = target_pixel_p + target_height = int(target_width / ar) + # Check if height exceeds standard bounds + if target_height > max_long_side: + target_height = max_long_side + target_width = int(target_height * ar) + if target_width % 2 != 0: + target_width += 1 + if target_height % 2 != 0: + target_height += 1 + model_id = UPSCALER_MODELS_MAP[upscaler_choice] + if model_id == "slc-1": + filters.append( + VideoEnhancementFilter( + model=model_id, + creativity=upscaler_model["creativity"], + isOptimizedMode=True, + ) + ) + elif model_id == "ast-2": + n_frames = video.get_frame_count() + ast2_prompt = (upscaler_model["prompt"] or "").strip() + if ast2_prompt and n_frames > AST2_MAX_FRAMES_WITH_PROMPT: + raise ValueError( + f"Astra 2 with a prompt is limited to {AST2_MAX_FRAMES_WITH_PROMPT} input frames " + f"(~15s @ 30fps); video has {n_frames}. Clear the prompt or shorten the clip." + ) + if n_frames > AST2_MAX_FRAMES: + raise ValueError(f"Astra 2 is limited to {AST2_MAX_FRAMES} input frames; video has {n_frames}.") + realism = upscaler_model["realism"] + filters.append( + VideoEnhancementFilter( + model=model_id, + creativity=upscaler_model["creativity"], + prompt=(ast2_prompt or None), + sharp=upscaler_model["sharp"], + realism=(realism if realism > 0 else None), + ) + ) + else: + filters.append(VideoEnhancementFilter(model=model_id)) + if interpolation_choice != "Disabled": + target_frame_rate = interpolation_model["interpolation_frame_rate"] + filters.append( + VideoFrameInterpolationFilter( + model=interpolation_choice, + slowmo=interpolation_model["interpolation_slowmo"], + fps=interpolation_model["interpolation_frame_rate"], + duplicate=interpolation_model["interpolation_duplicate"], + duplicate_threshold=interpolation_model["interpolation_duplicate_threshold"], + ), + ) + initial_res = await sync_op( + cls, + ApiEndpoint(path="/proxy/topaz/video/", method="POST"), + response_model=CreateVideoResponse, + data=CreateVideoRequest( + source=CreateVideoRequestSource( + container="mp4", + size=get_fs_object_size(src_video_stream), + duration=int(duration_sec), + frameCount=video.get_frame_count(), + frameRate=src_frame_rate, + resolution=Resolution(width=src_width, height=src_height), + ), + filters=filters, + output=OutputInformationVideo( + resolution=Resolution(width=target_width, height=target_height), + frameRate=target_frame_rate, + audioCodec="AAC", + audioTransfer="Copy", + dynamicCompressionLevel=dynamic_compression_level, + ), + ), + wait_label="Creating task", + final_label_on_success="Task created", + ) + upload_res = await sync_op( + cls, + ApiEndpoint( + path=f"/proxy/topaz/video/{initial_res.requestId}/accept", + method="PATCH", + ), + response_model=VideoAcceptResponse, + wait_label="Preparing upload", + final_label_on_success="Upload started", + ) + if len(upload_res.urls) > 1: + raise NotImplementedError( + "Large files are not currently supported. Please open an issue in the ComfyUI repository." + ) + async with aiohttp.ClientSession(headers={"Content-Type": "video/mp4"}) as session: + if isinstance(src_video_stream, BytesIO): + src_video_stream.seek(0) + async with session.put(upload_res.urls[0], data=src_video_stream, raise_for_status=True) as res: + upload_etag = res.headers["Etag"] + else: + with builtins.open(src_video_stream, "rb") as video_file: + async with session.put(upload_res.urls[0], data=video_file, raise_for_status=True) as res: + upload_etag = res.headers["Etag"] + await sync_op( + cls, + ApiEndpoint( + path=f"/proxy/topaz/video/{initial_res.requestId}/complete-upload", + method="PATCH", + ), + response_model=VideoCompleteUploadResponse, + data=VideoCompleteUploadRequest( + uploadResults=[ + VideoCompleteUploadRequestPart( + partNum=1, + eTag=upload_etag, + ), + ], + ), + wait_label="Finalizing upload", + final_label_on_success="Upload completed", + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/topaz/video/{initial_res.requestId}/status"), + response_model=VideoStatusResponse, + status_extractor=lambda x: x.status, + progress_extractor=lambda x: getattr(x, "progress", 0), + price_extractor=lambda x: (x.estimates.cost[0] * 0.08 if x.estimates and x.estimates.cost[0] else None), + poll_interval=10.0, + ) + return IO.NodeOutput(await download_url_to_video_output(final_response.download.url)) + + class TopazExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ TopazImageEnhance, TopazVideoEnhance, + TopazVideoEnhanceV2, ] From be95871adccfac92a91ebdc06e52a85511f7b85c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Sun, 3 May 2026 05:46:15 +0300 Subject: [PATCH 13/25] feat: Gemma4 text generation support (CORE-30) (#13376) * initial gemma4 support * parity with reference implementation outputs can 100% match transformers with same sdpa flags, checkpoint this and then optimize * Cleanup, video fixes * cleanup, enable fused rms norm by default * update comment * Cleanup * Update sd.py * Various fixes * Add fp8 scaled embedding support * small fixes * Translate think tokens * Fix image encoder attention mask type So it works with basic attention * Handle thinking tokens different only for Gemma4 * Code cleanup * Update nodes_textgen.py * Use embed scale class instead of buffer Slight difference to HF, but technically more accurate and simpler code * Default to fused rms_norm * Update gemma4.py --- comfy/ldm/modules/attention.py | 24 +- comfy/ops.py | 87 +++ comfy/rmsnorm.py | 1 + comfy/sd.py | 17 + comfy/text_encoders/gemma4.py | 1298 ++++++++++++++++++++++++++++++++ comfy/text_encoders/llama.py | 40 +- comfy/text_encoders/lt.py | 3 +- comfy/text_encoders/lumina2.py | 3 +- comfy/text_encoders/qwen35.py | 2 - comfy/utils.py | 7 - comfy_extras/nodes_textgen.py | 13 +- 11 files changed, 1453 insertions(+), 42 deletions(-) create mode 100644 comfy/text_encoders/gemma4.py diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index b193fe5e8..a68cb8439 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -14,6 +14,8 @@ from .sub_quadratic_attention import efficient_dot_product_attention from comfy import model_management +TORCH_HAS_GQA = model_management.torch_version_numeric >= (2, 5) + if model_management.xformers_enabled(): import xformers import xformers.ops @@ -150,7 +152,12 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape b, _, dim_head = q.shape dim_head //= heads - scale = dim_head ** -0.5 + if kwargs.get("enable_gqa", False) and q.shape[-3] != k.shape[-3]: + n_rep = q.shape[-3] // k.shape[-3] + k = k.repeat_interleave(n_rep, dim=-3) + v = v.repeat_interleave(n_rep, dim=-3) + + scale = kwargs.get("scale", dim_head ** -0.5) h = heads if skip_reshape: @@ -219,6 +226,10 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, b, _, dim_head = query.shape dim_head //= heads + if "scale" in kwargs: + # Pre-scale query to match requested scale (cancels internal 1/sqrt(dim_head)) + query = query * (kwargs["scale"] * dim_head ** 0.5) + if skip_reshape: query = query.reshape(b * heads, -1, dim_head) value = value.reshape(b * heads, -1, dim_head) @@ -290,7 +301,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape b, _, dim_head = q.shape dim_head //= heads - scale = dim_head ** -0.5 + scale = kwargs.get("scale", dim_head ** -0.5) if skip_reshape: q, k, v = map( @@ -500,8 +511,13 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha if mask.ndim == 3: mask = mask.unsqueeze(1) + # Pass through extra SDPA kwargs (scale, enable_gqa) if provided + # enable_gqa requires PyTorch 2.5+; older versions use manual KV expansion above + sdpa_keys = ("scale", "enable_gqa") if TORCH_HAS_GQA else ("scale",) + sdpa_extra = {k: v for k, v in kwargs.items() if k in sdpa_keys} + if SDP_BATCH_LIMIT >= b: - out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, **sdpa_extra) if not skip_output_reshape: out = ( out.transpose(1, 2).reshape(b, -1, heads * dim_head) @@ -519,7 +535,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha k[i : i + SDP_BATCH_LIMIT], v[i : i + SDP_BATCH_LIMIT], attn_mask=m, - dropout_p=0.0, is_causal=False + dropout_p=0.0, is_causal=False, **sdpa_extra ).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head) return out diff --git a/comfy/ops.py b/comfy/ops.py index 96db1411c..4f0338346 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -1246,6 +1246,93 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec self._buffers[key] = fn(buf) return self + class Embedding(manual_cast.Embedding): + def _load_from_state_dict(self, state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, error_msgs): + weight_key = f"{prefix}weight" + layer_conf = state_dict.pop(f"{prefix}comfy_quant", None) + if layer_conf is not None: + layer_conf = json.loads(layer_conf.numpy().tobytes()) + + # Only fp8 makes sense for embeddings (per-row dequant via index select). + # Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently. + quant_format = layer_conf.get("format", None) if layer_conf is not None else None + if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict: + self.quant_format = quant_format + qconfig = QUANT_ALGOS[quant_format] + layout_cls = get_layout_class(qconfig["comfy_tensor_layout"]) + weight = state_dict.pop(weight_key) + manually_loaded_keys = [weight_key] + + scale_key = f"{prefix}weight_scale" + scale = state_dict.pop(scale_key, None) + if scale is not None: + scale = scale.float() + manually_loaded_keys.append(scale_key) + + params = layout_cls.Params( + scale=scale if scale is not None else torch.ones((), dtype=torch.float32), + orig_dtype=MixedPrecisionOps._compute_dtype, + orig_shape=(self.num_embeddings, self.embedding_dim), + ) + self.weight = torch.nn.Parameter( + QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params), + requires_grad=False) + + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + for k in manually_loaded_keys: + if k in missing_keys: + missing_keys.remove(k) + else: + if layer_conf is not None: + state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8) + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + def state_dict(self, *args, destination=None, prefix="", **kwargs): + if destination is not None: + sd = destination + else: + sd = {} + + if not hasattr(self, 'weight') or self.weight is None: + return sd + + if isinstance(self.weight, QuantizedTensor): + sd_out = self.weight.state_dict("{}weight".format(prefix)) + for k in sd_out: + sd[k] = sd_out[k] + + quant_conf = {"format": self.quant_format} + sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8) + else: + sd["{}weight".format(prefix)] = self.weight + return sd + + def forward_comfy_cast_weights(self, input, out_dtype=None): + weight = self.weight + + # Optimized path: lookup in fp8, dequantize only the selected rows. + if isinstance(weight, QuantizedTensor) and len(self.weight_function) == 0: + qdata, _, offload_stream = cast_bias_weight(self, device=input.device, dtype=weight.dtype, offloadable=True) + if isinstance(qdata, QuantizedTensor): + scale = qdata._params.scale + qdata = qdata._qdata + else: + scale = None + + x = torch.nn.functional.embedding( + input, qdata, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse) + uncast_bias_weight(self, qdata, None, offload_stream) + target_dtype = out_dtype if out_dtype is not None else weight._params.orig_dtype + x = x.to(dtype=target_dtype) + if scale is not None and scale != 1.0: + x = x * scale.to(dtype=target_dtype) + return x + + # Fallback for non-quantized or weight_function (LoRA) case + return super().forward_comfy_cast_weights(input, out_dtype=out_dtype) + return MixedPrecisionOps def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None): diff --git a/comfy/rmsnorm.py b/comfy/rmsnorm.py index ab7cf14fa..e54be98d6 100644 --- a/comfy/rmsnorm.py +++ b/comfy/rmsnorm.py @@ -3,6 +3,7 @@ import comfy.model_management RMSNorm = torch.nn.RMSNorm +# Note: torch's fused F.rms_norm is faster but produces slightly different output than manual implementations (rsqrt/reduction rounding). def rms_norm(x, weight=None, eps=1e-6): if weight is None: return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps) diff --git a/comfy/sd.py b/comfy/sd.py index ee66490f5..9fce0e7d0 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -65,6 +65,7 @@ import comfy.text_encoders.ace15 import comfy.text_encoders.longcat_image import comfy.text_encoders.qwen35 import comfy.text_encoders.ernie +import comfy.text_encoders.gemma4 import comfy.model_patcher import comfy.lora @@ -1271,6 +1272,9 @@ class TEModel(Enum): QWEN35_9B = 26 QWEN35_27B = 27 MINISTRAL_3_3B = 28 + GEMMA_4_E4B = 29 + GEMMA_4_E2B = 30 + GEMMA_4_31B = 31 def detect_te_model(sd): @@ -1296,6 +1300,12 @@ def detect_te_model(sd): return TEModel.BYT5_SMALL_GLYPH return TEModel.T5_BASE if 'model.layers.0.post_feedforward_layernorm.weight' in sd: + if 'model.layers.59.self_attn.q_norm.weight' in sd: + return TEModel.GEMMA_4_31B + if 'model.layers.41.self_attn.q_norm.weight' in sd and 'model.layers.47.self_attn.q_norm.weight' not in sd: + return TEModel.GEMMA_4_E4B + if 'model.layers.34.self_attn.q_norm.weight' in sd and 'model.layers.41.self_attn.q_norm.weight' not in sd: + return TEModel.GEMMA_4_E2B if 'model.layers.47.self_attn.q_norm.weight' in sd: return TEModel.GEMMA_3_12B if 'model.layers.0.self_attn.q_norm.weight' in sd: @@ -1435,6 +1445,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip else: clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer + elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B): + variant = {TEModel.GEMMA_4_E4B: comfy.text_encoders.gemma4.Gemma4_E4B, + TEModel.GEMMA_4_E2B: comfy.text_encoders.gemma4.Gemma4_E2B, + TEModel.GEMMA_4_31B: comfy.text_encoders.gemma4.Gemma4_31B}[te_model] + clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=variant) + clip_target.tokenizer = variant.tokenizer + tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None) elif te_model == TEModel.GEMMA_2_2B: clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer diff --git a/comfy/text_encoders/gemma4.py b/comfy/text_encoders/gemma4.py new file mode 100644 index 000000000..f050061ed --- /dev/null +++ b/comfy/text_encoders/gemma4.py @@ -0,0 +1,1298 @@ +import torch +import torch.nn as nn +import numpy as np +from dataclasses import dataclass +import math + +from comfy import sd1_clip +import comfy.model_management +from comfy.ldm.modules.attention import optimized_attention_for_device +from comfy.rmsnorm import rms_norm +from comfy.text_encoders.llama import RMSNorm, MLP, BaseLlama, BaseGenerate, _make_scaled_embedding + + +# Intentional minor divergences from transformers -reference implementation: +# - Embedding sqrt(hidden_size) scale applied as a Python scalar (full precision) instead of dtype-matched buffer tensor. +# - RMSNorm uses torch fused F.rms_norm, very slight numerical differences, but considerably faster +# - Input image and audio resizing/resampling slightly different numerically + + +GEMMA4_VISION_CONFIG = {"hidden_size": 768, "image_size": 896, "intermediate_size": 3072, "num_attention_heads": 12, "num_hidden_layers": 16, "patch_size": 16, "head_dim": 64, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3} +GEMMA4_VISION_31B_CONFIG = {"hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 16, "head_dim": 72, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3} +GEMMA4_AUDIO_CONFIG = {"hidden_size": 1024, "num_hidden_layers": 12, "num_attention_heads": 8, "intermediate_size": 4096, "conv_kernel_size": 5, "attention_chunk_size": 12, "attention_context_left": 13, "attention_context_right": 0, "attention_logit_cap": 50.0, "output_proj_dims": 1536, "rms_norm_eps": 1e-6, "residual_weight": 0.5} + +@dataclass +class Gemma4Config: + vocab_size: int = 262144 + hidden_size: int = 2560 + intermediate_size: int = 10240 + num_hidden_layers: int = 42 + num_attention_heads: int = 8 + num_key_value_heads: int = 2 + max_position_embeddings: int = 131072 + rms_norm_eps: float = 1e-6 + rope_theta = [1000000.0, 10000.0] + transformer_type: str = "gemma4" + head_dim = 256 + global_head_dim = 512 + rms_norm_add = False + mlp_activation = "gelu_pytorch_tanh" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + sliding_attention = [512, 512, 512, 512, 512, False] + rope_scale = None + partial_rotary_factor: float = 0.25 + final_norm: bool = True + lm_head: bool = False + final_logit_softcapping: float = 30.0 + hidden_size_per_layer_input: int = 256 + num_kv_shared_layers: int = 18 + use_double_wide_mlp: bool = False + stop_tokens = [1, 50, 106] + vision_config = GEMMA4_VISION_CONFIG + audio_config = GEMMA4_AUDIO_CONFIG + mm_tokens_per_image = 280 + +@dataclass +class Gemma4_E2B_Config(Gemma4Config): + hidden_size: int = 1536 + intermediate_size: int = 6144 + num_hidden_layers: int = 35 + num_key_value_heads: int = 1 + sliding_attention = [512, 512, 512, 512, False] + num_kv_shared_layers: int = 20 + use_double_wide_mlp: bool = True + +@dataclass +class Gemma4_31B_Config(Gemma4Config): + hidden_size: int = 5376 + intermediate_size: int = 21504 + num_hidden_layers: int = 60 + num_attention_heads: int = 32 + num_key_value_heads: int = 16 + sliding_attention = [1024, 1024, 1024, 1024, 1024, False] + hidden_size_per_layer_input: int = 0 + num_kv_shared_layers: int = 0 + audio_config = None + vision_config = GEMMA4_VISION_31B_CONFIG + + +# unfused RoPE as addcmul_ RoPE diverges from reference code +def _apply_rotary_pos_emb(x, freqs_cis): + cos, sin = freqs_cis[0], freqs_cis[1] + half = x.shape[-1] // 2 + out = x * cos + out[..., :half] -= x[..., half:] * sin[..., :half] + out[..., half:] += x[..., :half] * sin[..., half:] + return out + +class Gemma4Attention(nn.Module): + def __init__(self, config, head_dim, device=None, dtype=None, ops=None): + super().__init__() + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.hidden_size = config.hidden_size + self.head_dim = head_dim + self.inner_size = self.num_heads * head_dim + + self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, device=device, dtype=dtype) + self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * head_dim, bias=config.qkv_bias, device=device, dtype=dtype) + self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * head_dim, bias=config.qkv_bias, device=device, dtype=dtype) + self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype) + + self.q_norm = None + self.k_norm = None + if config.q_norm == "gemma3": + self.q_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype) + if config.k_norm == "gemma3": + self.k_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask=None, + freqs_cis=None, + past_key_value=None, + sliding_window=None, + shared_kv=None, + ): + batch_size, seq_length, _ = hidden_states.shape + + xq = self.q_proj(hidden_states) + xq = xq.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + if self.q_norm is not None: + xq = self.q_norm(xq) + + if shared_kv is not None: + xk, xv = shared_kv + # Apply RoPE to Q only (K already has RoPE from source layer) + xq = _apply_rotary_pos_emb(xq, freqs_cis) + present_key_value = None + shareable_kv = None + else: + xk = self.k_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim) + xv = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim) + if self.k_norm is not None: + xk = self.k_norm(xk) + xv = rms_norm(xv) + xk = xk.transpose(1, 2) + xv = xv.transpose(1, 2) + xq = _apply_rotary_pos_emb(xq, freqs_cis) + xk = _apply_rotary_pos_emb(xk, freqs_cis) + + present_key_value = None + if past_key_value is not None: + cumulative_len = 0 + if len(past_key_value) > 0: + past_key, past_value, cumulative_len = past_key_value + xk = torch.cat((past_key, xk), dim=2) + xv = torch.cat((past_value, xv), dim=2) + new_cumulative = cumulative_len + seq_length + if sliding_window is not None and xk.shape[2] > sliding_window - 1: + cache_k = xk[:, :, -(sliding_window - 1):] + cache_v = xv[:, :, -(sliding_window - 1):] + else: + cache_k = xk + cache_v = xv + present_key_value = (cache_k, cache_v, new_cumulative) + + # KV for sharing: full xk/xv that SDPA sees (not evicted cache) + shareable_kv = (xk, xv) + + # GQA: pass unexpanded KV with enable_gqa when no sliding mask, + # expand heads when sliding mask is present + # has to be done within SDPA itself to match the reference code, pre-scaling expansion causes numerical differences + expand_kv = (self.num_heads != self.num_kv_heads and + sliding_window is not None and + xk.shape[2] >= sliding_window) + if expand_kv: + xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + gqa_kwargs = {} if expand_kv else ({"enable_gqa": True} if self.num_heads != self.num_kv_heads else {}) + output = optimized_attention_for_device(xq.device, mask=attention_mask is not None, small_input=True)(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True, scale=1.0, **gqa_kwargs) + + return self.o_proj(output), present_key_value, shareable_kv + + +class TransformerBlockGemma4(nn.Module): + def __init__(self, config, index, device=None, dtype=None, ops=None): + super().__init__() + if config.sliding_attention is not None: + self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)] + else: + self.sliding_attention = False + + head_dim = config.head_dim if self.sliding_attention else config.global_head_dim + + self.self_attn = Gemma4Attention(config, head_dim=head_dim, device=device, dtype=dtype, ops=ops) + + num_kv_shared = config.num_kv_shared_layers + first_kv_shared = config.num_hidden_layers - num_kv_shared + mlp_size = config.intermediate_size * 2 if config.use_double_wide_mlp and index >= first_kv_shared else None + self.mlp = MLP(config, device=device, dtype=dtype, ops=ops, intermediate_size=mlp_size) + + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) + self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) + self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) + + self.hidden_size_per_layer_input = config.hidden_size_per_layer_input + if self.hidden_size_per_layer_input: + self.per_layer_input_gate = ops.Linear(config.hidden_size, self.hidden_size_per_layer_input, bias=False, device=device, dtype=dtype) + self.per_layer_projection = ops.Linear(self.hidden_size_per_layer_input, config.hidden_size, bias=False, device=device, dtype=dtype) + self.post_per_layer_input_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) + self.register_buffer("layer_scalar", torch.ones(1, device=device, dtype=dtype)) + else: + self.layer_scalar = None + + def forward(self, x, attention_mask=None, freqs_cis=None, past_key_value=None, per_layer_input=None, shared_kv=None): + sliding_window = None + if self.sliding_attention: + sliding_window = self.sliding_attention + # For prefill > sliding window, add sliding window restriction to the causal mask. + if x.shape[1] > self.sliding_attention: + sw_mask = torch.zeros(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device) + sw_mask.masked_fill_(torch.ones_like(sw_mask, dtype=torch.bool).tril_(-self.sliding_attention), torch.finfo(x.dtype).min) + attention_mask = attention_mask + sw_mask if attention_mask is not None else sw_mask + freqs_cis = freqs_cis[1] + else: + freqs_cis = freqs_cis[0] + + residual = x + x = self.input_layernorm(x) + x, present_key_value, shareable_kv = self.self_attn( + hidden_states=x, attention_mask=attention_mask, freqs_cis=freqs_cis, + past_key_value=past_key_value, sliding_window=sliding_window, shared_kv=shared_kv, + ) + x = self.post_attention_layernorm(x) + x = residual + x + + residual = x + x = self.pre_feedforward_layernorm(x) + x = self.mlp(x) + x = self.post_feedforward_layernorm(x) + x = residual + x + + if self.hidden_size_per_layer_input and per_layer_input is not None: + residual = x + x = self.per_layer_input_gate(x) + x = torch.nn.functional.gelu(x, approximate="tanh") + x = x * per_layer_input + x = self.per_layer_projection(x) + x = self.post_per_layer_input_norm(x) + x = residual + x + + if self.layer_scalar is not None: + x = x * self.layer_scalar + + return x, present_key_value, shareable_kv + + +class Gemma4Transformer(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.config = config + + self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype) + + self.layers = nn.ModuleList([ + TransformerBlockGemma4(config, index=i, device=device, dtype=dtype, ops=ops) + for i in range(config.num_hidden_layers) + ]) + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) if config.final_norm else None + + # Precompute RoPE inv_freq on CPU to match reference code's exact value + rope_angles_global = int(config.partial_rotary_factor * config.global_head_dim // 2) + nope_global = config.global_head_dim // 2 - rope_angles_global + global_inv = 1.0 / (config.rope_theta[0] ** (torch.arange(0, 2 * rope_angles_global, 2).float() / config.global_head_dim)) + if nope_global > 0: + global_inv = torch.cat([global_inv, torch.zeros(nope_global)]) + self.register_buffer("_global_inv_freq", global_inv, persistent=False) + + sliding_inv = 1.0 / (config.rope_theta[1] ** (torch.arange(0, config.head_dim, 2).float() / config.head_dim)) + self.register_buffer("_sliding_inv_freq", sliding_inv, persistent=False) + + # Per-layer input mechanism + self.hidden_size_per_layer_input = config.hidden_size_per_layer_input + if self.hidden_size_per_layer_input: + self.embed_tokens_per_layer = _make_scaled_embedding(ops, config.vocab_size, config.num_hidden_layers * self.hidden_size_per_layer_input, self.hidden_size_per_layer_input ** 0.5, device, dtype) + self.per_layer_model_projection = ops.Linear( + config.hidden_size, config.num_hidden_layers * self.hidden_size_per_layer_input, + bias=False, device=device, dtype=dtype) + self.per_layer_projection_norm = RMSNorm( + self.hidden_size_per_layer_input, eps=config.rms_norm_eps, + device=device, dtype=dtype) + + def get_past_len(self, past_key_values): + for kv in past_key_values: + if len(kv) >= 3: + return kv[2] + return 0 + + def _freqs_from_inv(self, inv_freq, position_ids, device, dtype): + """Compute cos/sin from stored inv_freq""" + inv_exp = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(device) + pos_exp = position_ids[:, None, :].float() + freqs = (inv_exp @ pos_exp).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + return emb.cos().unsqueeze(1).to(dtype), emb.sin().unsqueeze(1).to(dtype) + + def compute_freqs_cis(self, position_ids, device, dtype=None): + global_freqs = self._freqs_from_inv(self._global_inv_freq, position_ids, device, dtype) + sliding_freqs = self._freqs_from_inv(self._sliding_inv_freq, position_ids, device, dtype) + return [global_freqs, sliding_freqs] + + def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, + final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=None, + past_key_values=None, input_ids=None): + if embeds is not None: + x = embeds + else: + x = self.embed_tokens(x, out_dtype=dtype) + + seq_len = x.shape[1] + past_len = 0 + if past_key_values is not None and len(past_key_values) > 0: + past_len = self.get_past_len(past_key_values) + + if position_ids is None: + position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0) + + freqs_cis = self.compute_freqs_cis(position_ids, x.device, dtype=x.dtype) + + mask = None + min_val = torch.finfo(x.dtype).min + if attention_mask is not None: + mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1]) + mask = mask.masked_fill(mask.to(torch.bool), min_val) + + if seq_len > 1: + causal_mask = torch.zeros(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device) + causal_mask.masked_fill_(torch.ones_like(causal_mask, dtype=torch.bool).triu_(1), min_val) + mask = mask + causal_mask if mask is not None else causal_mask + + # Per-layer inputs + per_layer_inputs = None + if self.hidden_size_per_layer_input: + num_layers = self.config.num_hidden_layers + hpl = self.hidden_size_per_layer_input + per_layer_proj = self.per_layer_model_projection(x) * (1.0 / (self.config.hidden_size ** 0.5)) + per_layer_proj = self.per_layer_projection_norm(per_layer_proj.reshape(*x.shape[:-1], num_layers, hpl)) + if input_ids is not None and input_ids.shape[1] == x.shape[1]: + per_layer_emb = self.embed_tokens_per_layer(input_ids).reshape(*input_ids.shape, num_layers, hpl) + per_layer_inputs = (per_layer_proj + per_layer_emb) * (0.5 ** 0.5) + else: + per_layer_inputs = per_layer_proj + + # KV sharing: later layers reuse KV from the last non-shared sliding/global layer + num_kv_shared = self.config.num_kv_shared_layers + first_kv_shared = self.config.num_hidden_layers - num_kv_shared if num_kv_shared > 0 else self.config.num_hidden_layers + shared_sliding_kv = None # KV from last non-shared sliding layer + shared_global_kv = None # KV from last non-shared global layer + + intermediate = None + next_key_values = [] + for i, layer in enumerate(self.layers): + past_kv = past_key_values[i] if past_key_values is not None and len(past_key_values) > 0 else None + + layer_kwargs = {} + if per_layer_inputs is not None: + layer_kwargs['per_layer_input'] = per_layer_inputs[:, :, i, :] + + is_sliding = hasattr(layer, 'sliding_attention') and layer.sliding_attention + if i >= first_kv_shared and num_kv_shared > 0: + shared = shared_sliding_kv if is_sliding else shared_global_kv + if shared is not None: + layer_kwargs['shared_kv'] = shared + + x, current_kv, shareable_kv = layer(x=x, attention_mask=mask, freqs_cis=freqs_cis, past_key_value=past_kv, **layer_kwargs) + + next_key_values.append(current_kv if current_kv is not None else ()) + + # Only track the last sliding/global before the sharing boundary + if i < first_kv_shared and shareable_kv is not None: + if is_sliding: + shared_sliding_kv = shareable_kv + else: + shared_global_kv = shareable_kv + + if i == intermediate_output: + intermediate = x.clone() + + if self.norm is not None: + x = self.norm(x) + + if len(next_key_values) > 0: + return x, intermediate, next_key_values + return x, intermediate + + +class Gemma4Base(BaseLlama, BaseGenerate, torch.nn.Module): + """Common base for all Gemma4 variants: text model + vision.""" + def _init_model(self, config, dtype, device, operations): + self.num_layers = config.num_hidden_layers + self.model = Gemma4Transformer(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + self.multi_modal_projector = Gemma4MultiModalProjector(config, dtype=dtype, device=device, ops=operations) + self.vision_model = Gemma4VisionEncoder(config.vision_config, dtype=dtype, device=device, ops=operations) + + def logits(self, x): + logits = super().logits(x) + cap = self.model.config.final_logit_softcapping + if cap: + logits = cap * torch.tanh(logits / cap) + return logits + + def init_kv_cache(self, batch, max_cache_len, device, execution_dtype): + past_key_values = [] + for _ in range(self.model.config.num_hidden_layers): + past_key_values.append(()) + return past_key_values + + def preprocess_embed(self, embed, device): + if embed["type"] == "image": + image = embed.pop("data").movedim(-1, 1) # [B, H, W, C] -> [B, C, H, W] + max_soft_tokens = embed.get("max_soft_tokens", None) + vision_out = self.vision_model(image.to(device, dtype=torch.float32), max_soft_tokens=max_soft_tokens) + return self.multi_modal_projector(vision_out), None + return None, None + + +class Gemma4AudioMixin: + """Adds audio support to a Gemma4 model.""" + def _init_audio(self, config, dtype, device, operations): + self.audio_model = Gemma4AudioEncoder(config.audio_config, dtype=dtype, device=device, ops=operations) + self.audio_projector = Gemma4AudioProjector({"audio_output_proj_dims": config.audio_config["output_proj_dims"], "text_hidden_size": config.hidden_size, "rms_norm_eps": config.rms_norm_eps}, dtype=dtype, device=device, ops=operations) + + def preprocess_embed(self, embed, device): + result, extra = super().preprocess_embed(embed, device) + if result is not None: + return result, extra + if embed["type"] == "audio": + audio = embed.pop("data").to(device, dtype=torch.float32) + audio_mask = embed.pop("mask", None) + if audio_mask is not None: + audio_mask = audio_mask.to(device) + audio_out = self.audio_model(audio, audio_mask=audio_mask) + return self.audio_projector(audio_out), None + return None, None + + +# Vision Encoder + +def _compute_vision_2d_rope(head_dim, pixel_position_ids, theta=100.0, device=None): + """Compute 2D RoPE for vision: separate frequencies for x and y dimensions. + + Args: + head_dim: dimension per head (e.g. 64) + pixel_position_ids: [batch, num_patches, 2] with (x, y) coords + theta: RoPE base frequency + Returns: + (cos, sin) each of shape [batch, num_patches, head_dim] + """ + rotary_dim_per_axis = head_dim // 2 + freq_indices = torch.arange(0, rotary_dim_per_axis, 2, device=device).float() + inv_freq = 1.0 / (theta ** (freq_indices / rotary_dim_per_axis)) + + all_cos, all_sin = [], [] + for i in range(2): # x and y + dim_positions = pixel_position_ids[:, :, i].float() # [batch, num_patches] + freqs = torch.einsum('bi,j->bij', dim_positions, inv_freq.to(device)) # [batch, num_patches, rotary_dim/2] + emb = torch.cat([freqs, freqs], dim=-1) # [batch, num_patches, rotary_dim] + all_cos.append(emb.cos()) + all_sin.append(emb.sin()) + + cos = torch.cat(all_cos, dim=-1).to(pixel_position_ids.device) # [batch, num_patches, head_dim] + sin = torch.cat(all_sin, dim=-1).to(pixel_position_ids.device) + return cos, sin + + +def _apply_vision_2d_rope(x, freqs): + """Apply 2D RoPE (multidimensional) to vision query/key states. + + Splits x and cos/sin into ndim=2 parts, applies 1D RoPE to each independently. + + x: [batch, heads, seq, head_dim] + freqs: (cos, sin) each [batch, seq, head_dim] + """ + cos = freqs[0].unsqueeze(1) # [batch, 1, seq, head_dim] + sin = freqs[1].unsqueeze(1) + half = x.shape[-1] // 2 + a = _apply_rotary_pos_emb(x[..., :half], (cos[..., :half], sin[..., :half])) + b = _apply_rotary_pos_emb(x[..., half:], (cos[..., half:], sin[..., half:])) + return torch.cat([a, b], dim=-1) + + +class ClippedLinear(nn.Module): + """Linear layer with activation clipping (from quantization-aware training). + + Stores input_max/min and output_max/min as buffers loaded from checkpoint. + """ + def __init__(self, in_features, out_features, bias=False, device=None, dtype=None, ops=None): + super().__init__() + self.linear = ops.Linear(in_features, out_features, bias=bias, device=device, dtype=dtype) + self.register_buffer('input_max', torch.tensor(float('inf'), device=device, dtype=dtype)) + self.register_buffer('input_min', torch.tensor(float('-inf'), device=device, dtype=dtype)) + self.register_buffer('output_max', torch.tensor(float('inf'), device=device, dtype=dtype)) + self.register_buffer('output_min', torch.tensor(float('-inf'), device=device, dtype=dtype)) + + @property + def weight(self): + return self.linear.weight + + def forward(self, x): + x = x.clamp(min=self.input_min, max=self.input_max) + x = self.linear(x) + return x.clamp_(min=self.output_min, max=self.output_max) + + +class Gemma4VisionMLP(nn.Module): + """SwiGLU MLP matching gate_proj/up_proj/down_proj structure.""" + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + hidden_size = config["hidden_size"] + intermediate_size = config["intermediate_size"] + self.gate_proj = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops) + self.up_proj = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops) + self.down_proj = ClippedLinear(intermediate_size, hidden_size, device=device, dtype=dtype, ops=ops) + + def forward(self, x): + return self.down_proj(torch.nn.functional.gelu(self.gate_proj(x), approximate="tanh") * self.up_proj(x)) + + +class Gemma4VisionAttention(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.hidden_size = config["hidden_size"] + self.num_heads = config["num_attention_heads"] + self.head_dim = config.get("head_dim", self.hidden_size // self.num_heads) + + self.q_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, ops=ops) + self.k_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, ops=ops) + self.v_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, ops=ops) + self.o_proj = ClippedLinear(self.num_heads * self.head_dim, self.hidden_size, device=device, dtype=dtype, ops=ops) + + self.q_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype) + self.k_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype) + + def forward(self, x, freqs, attention_mask=None): + batch_size, seq_length, _ = x.shape + + xq = self.q_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim) + xk = self.k_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim) + xv = self.v_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim) + + xq = self.q_norm(xq).transpose(1, 2) + xk = self.k_norm(xk).transpose(1, 2) + xv = rms_norm(xv) + + xq = _apply_vision_2d_rope(xq, freqs) + xk = _apply_vision_2d_rope(xk, freqs) + + xv = xv.to(xq.dtype).transpose(1, 2) + + output = optimized_attention_for_device(xq.device, mask=attention_mask is not None, small_input=True)(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True, scale=1.0) + return self.o_proj(output) + + +class Gemma4VisionLayer(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.self_attn = Gemma4VisionAttention(config, device=device, dtype=dtype, ops=ops) + self.mlp = Gemma4VisionMLP(config, device=device, dtype=dtype, ops=ops) + norm_kwargs = dict(eps=config["rms_norm_eps"], device=device, dtype=dtype) + hidden = config["hidden_size"] + self.input_layernorm = RMSNorm(hidden, **norm_kwargs) + self.post_attention_layernorm = RMSNorm(hidden, **norm_kwargs) + self.pre_feedforward_layernorm = RMSNorm(hidden, **norm_kwargs) + self.post_feedforward_layernorm = RMSNorm(hidden, **norm_kwargs) + + def forward(self, x, freqs, attention_mask=None): + residual = x + x = self.input_layernorm(x) + x = self.self_attn(x, freqs, attention_mask=attention_mask) + x = self.post_attention_layernorm(x) + x = residual + x + + residual = x + x = self.pre_feedforward_layernorm(x) + x = self.mlp(x) + x = self.post_feedforward_layernorm(x) + x = residual + x + return x + + +class Gemma4PatchEmbedder(nn.Module): + """Patch embedding with learned 2D position embeddings via one-hot lookup.""" + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + hidden_size = config["hidden_size"] + patch_size = config["patch_size"] + self.patch_size = patch_size + self.position_embedding_size = config.get("position_embedding_size", 10240) + + self.input_proj = ops.Linear(3 * patch_size * patch_size, hidden_size, bias=False, device=device, dtype=dtype) + self.position_embedding_table = nn.Parameter( + torch.empty(2, self.position_embedding_size, hidden_size, device=device, dtype=dtype) + ) + + def forward(self, patches, pixel_position_ids): + """ + patches: [B, num_patches, 3*patch_size²] in [0,1] range (normalized to [-1,1] inside, matching HF) + pixel_position_ids: [B, num_patches, 2] with (x,y) positions, (-1,-1) for padding + """ + hidden_states = self.input_proj((2.0 * (patches - 0.5)).to(self.input_proj.weight.dtype)) + + clamped_positions = pixel_position_ids.clamp(min=0) + pos_table = comfy.model_management.cast_to_device(self.position_embedding_table, hidden_states.device, hidden_states.dtype) + position_embeddings = pos_table[0][clamped_positions[..., 0]] + pos_table[1][clamped_positions[..., 1]] + + # Zero out position embeddings for padding patches (matching HF) + padding_positions = (pixel_position_ids == -1).all(dim=-1) + position_embeddings = torch.where(padding_positions.unsqueeze(-1), 0.0, position_embeddings) + + return hidden_states + position_embeddings + + +class Gemma4VisionEncoderLayers(nn.Module): + """Wrapper to produce state dict keys as encoder.layers.X.*""" + def __init__(self, config, dtype=None, device=None, ops=None): + super().__init__() + self.layers = nn.ModuleList([ + Gemma4VisionLayer(config, device=device, dtype=dtype, ops=ops) + for _ in range(config["num_hidden_layers"]) + ]) + + +class Gemma4VisionEncoder(nn.Module): + def __init__(self, config, dtype=None, device=None, ops=None): + super().__init__() + self.config = config + self.hidden_size = config["hidden_size"] + self.head_dim = config.get("head_dim", config["hidden_size"] // config["num_attention_heads"]) + self.patch_size = config["patch_size"] + self.pooling_kernel_size = config.get("pooling_kernel_size", 3) + self.root_hidden_size = self.hidden_size ** 0.5 + + self.patch_embedder = Gemma4PatchEmbedder(config, device=device, dtype=dtype, ops=ops) + self.encoder = Gemma4VisionEncoderLayers(config, dtype=dtype, device=device, ops=ops) + + def forward(self, pixel_values, max_soft_tokens=None): + """ + pixel_values: [B, C, H, W] in [0,1] range + max_soft_tokens: if provided, pad to max_soft_tokens * k² total patches + """ + batch_size, _, height, width = pixel_values.shape + ps = self.patch_size + k = self.pooling_kernel_size + patches_h, patches_w = height // ps, width // ps + num_patches = patches_h * patches_w + output_length = max_soft_tokens if max_soft_tokens is not None else num_patches // (k * k) + n_padding = output_length * k * k - num_patches + + # Patchify and build position grid + patches = pixel_values.reshape(batch_size, -1, patches_h, ps, patches_w, ps) + patches = patches.permute(0, 2, 4, 3, 5, 1).reshape(batch_size, num_patches, -1) + grid_y, grid_x = torch.meshgrid(torch.arange(patches_h, device=pixel_values.device), torch.arange(patches_w, device=pixel_values.device), indexing='ij') + position_ids = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=-1).unsqueeze(0).expand(batch_size, -1, -1) + + # Append zero-pixel padding with (-1,-1) positions + if n_padding > 0: + patches = torch.cat([patches, patches.new_zeros(batch_size, n_padding, patches.shape[-1])], dim=1) + position_ids = torch.cat([position_ids, position_ids.new_full((batch_size, n_padding, 2), -1)], dim=1) + + padding = (position_ids == -1).all(dim=-1) + + # Embed, encode, pool + x = self.patch_embedder(patches, position_ids) + freqs = _compute_vision_2d_rope(self.head_dim, position_ids, device=pixel_values.device) + freqs = tuple(t.to(x.dtype) for t in freqs) + if n_padding > 0: + mask = padding.unsqueeze(1).unsqueeze(2).expand(-1, 1, position_ids.shape[1], -1) + mask = torch.zeros_like(mask, dtype=x.dtype).masked_fill_(mask, torch.finfo(x.dtype).min) + else: + mask = None + + for layer in self.encoder.layers: + x = layer(x, freqs, attention_mask=mask) + + if n_padding > 0: + x = x.masked_fill(padding.unsqueeze(-1), 0.0) + + # Average pool by spatial position + clamped = position_ids.clamp(min=0) + max_x = clamped[:, :, 0].max(dim=-1, keepdim=True)[0] + 1 + ki = torch.div(clamped, k, rounding_mode="floor") + ki = ki[:, :, 0] + (max_x // k) * ki[:, :, 1] + weights = torch.nn.functional.one_hot(ki.long(), output_length).float() / (k * k) + x = (weights.transpose(1, 2) @ x.float()).to(x.dtype) + + # Strip empty output tokens + valid_out = ~((weights == 0).all(dim=1)) + if valid_out.any() and not valid_out.all(): + x = x[:, valid_out[0]] if batch_size > 1 else x[valid_out].unsqueeze(0) + + return x * self.root_hidden_size + + +class Gemma4RMSNormProjector(nn.Module): + """Shared projector: parameterless RMSNorm → linear. Used for both vision and audio.""" + def __init__(self, in_dim, out_dim, dtype=None, device=None, ops=None): + super().__init__() + self.embedding_projection = ops.Linear(in_dim, out_dim, bias=False, device=device, dtype=dtype) + + def forward(self, x): + return self.embedding_projection(rms_norm(x)) + + +class Gemma4MultiModalProjector(Gemma4RMSNormProjector): + def __init__(self, config, dtype=None, device=None, ops=None): + super().__init__(config.vision_config["hidden_size"], config.hidden_size, dtype=dtype, device=device, ops=ops) + + +# Audio Encoder + +class Gemma4AudioConvSubsampler(nn.Module): + """2D convolution subsampling for audio features""" + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + eps = config["rms_norm_eps"] + self.layer0 = nn.ModuleDict({ + 'conv': ops.Conv2d(1, 128, kernel_size=3, stride=2, padding=1, bias=False, device=device, dtype=dtype), + 'norm': ops.LayerNorm(128, eps=eps, elementwise_affine=True, bias=False, device=device, dtype=dtype), + }) + self.layer1 = nn.ModuleDict({ + 'conv': ops.Conv2d(128, 32, kernel_size=3, stride=2, padding=1, bias=False, device=device, dtype=dtype), + 'norm': ops.LayerNorm(32, eps=eps, elementwise_affine=True, bias=False, device=device, dtype=dtype), + }) + # proj_input_dim = (128 // 4) * 32 = 1024 + self.input_proj_linear = ops.Linear(1024, config["hidden_size"], bias=False, device=device, dtype=dtype) + + def _conv_layer(self, x, layer, mask): + if mask is not None: + x = x * mask[:, None, :, None].to(x.device) + x = layer['conv'](x.to(layer['conv'].weight.dtype)) + x = torch.relu(layer['norm'](x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2).contiguous()) + if mask is not None: + mask = mask[:, ::2] + return x, mask + + def forward(self, x, mask=None): + x = x.unsqueeze(1) + x, mask = self._conv_layer(x, self.layer0, mask) + x, mask = self._conv_layer(x, self.layer1, mask) + batch_size, _, seq_len, _ = x.shape + x = x.permute(0, 2, 3, 1).contiguous().reshape(batch_size, seq_len, -1) + return self.input_proj_linear(x), mask + + +class Gemma4AudioFeedForward(nn.Module): + """Conformer feed-forward with residual scaling.""" + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + hidden_size = config["hidden_size"] + intermediate_size = config.get("intermediate_size", hidden_size * 4) + self.pre_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype) + self.ffw_layer_1 = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops) + self.ffw_layer_2 = ClippedLinear(intermediate_size, hidden_size, device=device, dtype=dtype, ops=ops) + self.post_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype) + self.post_layer_scale = config.get("residual_weight", 0.5) + + def forward(self, x): + residual = x + x = self.pre_layer_norm(x) + x = torch.nn.functional.silu(self.ffw_layer_1(x)) + x = self.ffw_layer_2(x) + x = self.post_layer_norm(x) + x = x * self.post_layer_scale + return x + residual + + +class Gemma4AudioRelPositionalEncoding(nn.Module): + """Sinusoidal relative positional encoding for audio attention.""" + def __init__(self, config, device=None, dtype=None): + super().__init__() + hidden_size = config["hidden_size"] + context_left = config.get("attention_context_left", 13) + context_right = config.get("attention_context_right", 0) + self.chunk_size = config.get("attention_chunk_size", 12) + self.context_size = self.chunk_size + context_left - 1 + context_right + + num_timescales = hidden_size // 2 + log_inc = math.log(10000.0) / max(num_timescales - 1, 1) + inv_timescales = torch.exp(torch.arange(num_timescales) * -log_inc).to(dtype=dtype).unsqueeze(0).unsqueeze(0) + self.register_buffer("inv_timescales", inv_timescales, persistent=False) + + def forward(self, hidden_states): + positions = torch.arange(self.chunk_size, -1, -1, device=hidden_states.device).unsqueeze(-1) + scaled = positions * self.inv_timescales.to(device=hidden_states.device) + return torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=-1).to(dtype=hidden_states.dtype) + + +class Gemma4AudioAttention(nn.Module): + """Chunked block attention with relative position bias and softcap.""" + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.hidden_size = config["hidden_size"] + self.num_heads = config["num_attention_heads"] + self.head_dim = self.hidden_size // self.num_heads + self.chunk_size = config.get("attention_chunk_size", 12) + self.max_past_horizon = config.get("attention_context_left", 13) - 1 + self.max_future_horizon = config.get("attention_context_right", 0) + self.context_size = self.chunk_size + self.max_past_horizon + self.max_future_horizon + + self.q_scale = (self.head_dim ** -0.5) / math.log(2) + self.k_scale = math.log(1 + math.e) / math.log(2) + self.register_buffer("softcap", torch.tensor(config.get("attention_logit_cap", 50.0), dtype=dtype), persistent=False) + + self.q_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops) + self.k_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops) + self.v_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops) + self.post = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops) + self.per_dim_scale = nn.Parameter(torch.empty(self.head_dim, device=device, dtype=dtype)) + self.relative_k_proj = ops.Linear(self.hidden_size, self.hidden_size, bias=False, device=device, dtype=dtype) + + def _convert_to_block(self, x): + B, S, H, D = x.shape + num_blocks = (S + self.chunk_size - 1) // self.chunk_size + pad = num_blocks * self.chunk_size - S + x = torch.nn.functional.pad(x, (0, 0, 0, 0, 0, pad)) + return x.reshape(B, num_blocks, self.chunk_size, H, D).contiguous() + + def _extract_block_context(self, x): + x = torch.nn.functional.pad(x, (0, 0, 0, 0, self.max_past_horizon, self.max_future_horizon + self.chunk_size - 1)) + x = x.unfold(1, self.context_size, self.chunk_size) + return torch.movedim(x, -1, 2).contiguous() + + def _rel_shift(self, x): + B, H, NB, BS, PL = x.shape + CS = self.context_size + x = torch.nn.functional.pad(x, (0, CS + 1 - PL)) + x = x.view(B, H, NB, BS * (CS + 1)) + x = x[..., :BS * CS] + return x.view(B, H, NB, BS, CS) + + def _build_blocked_mask(self, seq_len, num_blocks, device, audio_mask=None): + """Build 5D boolean blocked attention mask (True=attend, False=mask)""" + q = torch.arange(seq_len, device=device) + dist = q[:, None] - q[None, :] + mask = (dist >= 0) & (dist < self.max_past_horizon) + if self.max_future_horizon > 0: + mask = mask | ((dist < 0) & ((-dist) < self.max_future_horizon)) + if audio_mask is not None: + mask = mask & audio_mask[0, None, :].bool() + m = mask[None, None] + # Reshape to blocked 5D matching reference code + p = num_blocks * self.chunk_size - seq_len + m = torch.nn.functional.pad(m, (0, p, 0, p), value=False) + m = m.reshape(1, 1, num_blocks, self.chunk_size, -1) + m = torch.nn.functional.pad(m, (self.max_past_horizon, self.max_future_horizon), value=False) + idx = (torch.arange(num_blocks, device=device) * self.chunk_size)[:, None] + torch.arange(self.context_size, device=device)[None, :] + return m.gather(-1, idx[None, None, :, None, :].expand(1, 1, -1, self.chunk_size, -1)) + + def forward(self, x, position_embeddings=None, attn_mask=None): + B, S, _ = x.shape + + q = self.q_proj(x).float().view(B, S, self.num_heads, self.head_dim) + k = self.k_proj(x).float().view(B, S, self.num_heads, self.head_dim) + v = self.v_proj(x).float().view(B, S, self.num_heads, self.head_dim) + + q = q * self.q_scale * torch.nn.functional.softplus(self.per_dim_scale) + k = k * self.k_scale + + q_blocks = self._convert_to_block(q) + k_context = self._extract_block_context(k) + v_context = self._extract_block_context(v) + num_blocks = q_blocks.shape[1] + + rel_k = self.relative_k_proj(position_embeddings).view(-1, self.num_heads, self.head_dim).to(q.dtype) + + queries = q_blocks.permute(0, 3, 1, 2, 4) # [B, H, NB, CS, D] + matrix_ac = queries @ k_context.permute(0, 3, 1, 4, 2) + + queries_flat = queries.reshape(B, self.num_heads, -1, self.head_dim) + matrix_bd = queries_flat @ rel_k.permute(1, 2, 0) + matrix_bd = matrix_bd.reshape(B, self.num_heads, num_blocks, self.chunk_size, -1) + matrix_bd = self._rel_shift(matrix_bd) + + attn_weights = matrix_ac + matrix_bd + attn_weights = torch.tanh(attn_weights / self.softcap) * self.softcap + + # Mask out invalid positions in chunk context (matching reference's masked_fill approach) + if attn_mask is None: + attn_mask = self._build_blocked_mask(S, num_blocks, x.device) + attn_weights = attn_weights.masked_fill(attn_mask.logical_not(), -1e9) + + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(v.dtype) + out = attn_weights @ v_context.permute(0, 3, 1, 2, 4) + out = out.permute(0, 2, 3, 1, 4).reshape(B, num_blocks * self.chunk_size, -1) + out = out[:, :S].contiguous() + return self.post(out.to(self.post.linear.weight.dtype)) + + +class Gemma4AudioLConv1d(nn.Module): + """Lightweight convolution with standard GLU.""" + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + hidden_size = config["hidden_size"] + conv_kernel_size = config.get("conv_kernel_size", 5) + self.pre_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype) + self.linear_start = ClippedLinear(hidden_size, hidden_size * 2, device=device, dtype=dtype, ops=ops) + # Causal conv: left-pad only + self.depthwise_conv1d = ops.Conv1d(hidden_size, hidden_size, kernel_size=conv_kernel_size, padding=0, groups=hidden_size, bias=False, device=device, dtype=dtype) + self.conv_left_pad = conv_kernel_size - 1 # causal: pad left by kernel-1 + self.conv_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype) + self.linear_end = ClippedLinear(hidden_size, hidden_size, device=device, dtype=dtype, ops=ops) + + def forward(self, x): + residual = x + x = self.pre_layer_norm(x) + x = self.linear_start(x) + x = torch.nn.functional.glu(x, dim=-1) + x = x.transpose(1, 2) + x = torch.nn.functional.pad(x, (self.conv_left_pad, 0)) + x = self.depthwise_conv1d(x).transpose(1, 2) + x = self.conv_norm(x) + x = torch.nn.functional.silu(x) + x = self.linear_end(x) + return x + residual + + +class Gemma4AudioLayer(nn.Module): + """Conformer block: FFN1 -> Attention -> LConv -> FFN2.""" + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.feed_forward1 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, ops=ops) + self.self_attn = Gemma4AudioAttention(config, device=device, dtype=dtype, ops=ops) + norm_kwargs = dict(eps=config["rms_norm_eps"], device=device, dtype=dtype) + hidden_size = config["hidden_size"] + self.norm_pre_attn = RMSNorm(hidden_size, **norm_kwargs) + self.norm_post_attn = RMSNorm(hidden_size, **norm_kwargs) + self.lconv1d = Gemma4AudioLConv1d(config, device=device, dtype=dtype, ops=ops) + self.feed_forward2 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, ops=ops) + self.norm_out = RMSNorm(hidden_size, **norm_kwargs) + + def forward(self, x, position_embeddings=None, attn_mask=None): + x = self.feed_forward1(x) + + residual = x + x = self.norm_pre_attn(x) + x = self.self_attn(x, position_embeddings=position_embeddings, attn_mask=attn_mask) + x = self.norm_post_attn(x) + x = x + residual + + x = self.lconv1d(x) + x = self.feed_forward2(x) + + x = self.norm_out(x) + return x + + +class Gemma4AudioEncoder(nn.Module): + def __init__(self, config, dtype=None, device=None, ops=None): + super().__init__() + self.hidden_size = config["hidden_size"] + self.output_proj_dims = config.get("output_proj_dims", 1536) + + self.subsample_conv_projection = Gemma4AudioConvSubsampler(config, device=device, dtype=dtype, ops=ops) + self.rel_pos_enc = Gemma4AudioRelPositionalEncoding(config, device=device, dtype=dtype) + + self.layers = nn.ModuleList([ + Gemma4AudioLayer(config, device=device, dtype=dtype, ops=ops) + for _ in range(config["num_hidden_layers"]) + ]) + + self.output_proj = ops.Linear(self.hidden_size, self.output_proj_dims, bias=True, device=device, dtype=dtype) + + def forward(self, audio_features, audio_mask=None): + x, audio_mask = self.subsample_conv_projection(audio_features, audio_mask) + position_embeddings = self.rel_pos_enc(x) + + # Build blocked attention mask once for all layers + attn_mask = self.layers[0].self_attn._build_blocked_mask( + x.shape[1], (x.shape[1] + self.layers[0].self_attn.chunk_size - 1) // self.layers[0].self_attn.chunk_size, + x.device, audio_mask=audio_mask) + + for layer in self.layers: + x = layer(x, position_embeddings=position_embeddings, attn_mask=attn_mask) + + x = self.output_proj(x) + return x + + +class Gemma4AudioProjector(Gemma4RMSNormProjector): + def __init__(self, config, dtype=None, device=None, ops=None): + super().__init__(config.get("audio_output_proj_dims", 1536), config.get("text_hidden_size", 2560), dtype=dtype, device=device, ops=ops) + + +# Tokenizer and Wrappers + +class Gemma4_Tokenizer(): + tokenizer_json_data = None + + def state_dict(self): + if self.tokenizer_json_data is not None: + return {"tokenizer_json": self.tokenizer_json_data} + return {} + + def _extract_mel_spectrogram(self, waveform, sample_rate): + """Extract 128-bin log mel spectrogram. + Uses numpy for FFT/matmul/log to produce bit-identical results with reference code. + """ + # Mix to mono first, then resample to 16kHz + if waveform.dim() > 1 and waveform.shape[0] > 1: + waveform = waveform.mean(dim=0, keepdim=True) + if waveform.dim() == 1: + waveform = waveform.unsqueeze(0) + audio = waveform.squeeze(0).float().numpy() + if sample_rate != 16000: + # Use scipy's resample_poly with a high-quality FIR filter to get as close as possible to librosa's resampling (while still not full match) + from scipy.signal import resample_poly, firwin + from math import gcd + g = gcd(sample_rate, 16000) + up, down = 16000 // g, sample_rate // g + L = max(up, down) + h = firwin(160 * L + 1, 0.96 / L, window=('kaiser', 6.5)) + audio = resample_poly(audio, up, down, window=h).astype(np.float32) + n = len(audio) + + # Pad to multiple of 128, build sample-level mask + if n % 128 != 0: + audio = np.pad(audio, (0, 128 - n % 128)) + mask_raw = np.ones(len(audio), dtype=np.float32) + mask_raw[n:] = 0.0 + + # Semicausal padding: 160 zeros prepended + audio = np.pad(audio, (160, 0)) + mask_raw = np.pad(mask_raw, (160, 0)) + + # Extract 321-sample frames via stride tricks, drop last → 320 + nf = (len(audio) - 321) // 160 + 1 + strides = (audio.strides[0] * 160, audio.strides[0]) + frames = np.lib.stride_tricks.as_strided(audio, (nf, 321), strides)[..., :-1].copy() + + # Periodic Hann window, FFT magnitude, mel filterbank, log + window = (0.5 - 0.5 * np.cos(2 * np.pi * np.arange(320) / 320)).astype(np.float32) + magnitude = np.abs(np.fft.rfft(frames * window, n=512, axis=-1)) + mel_fb = self._build_mel_filterbank() + log_mel = np.log(np.matmul(magnitude, mel_fb) + np.float64(0.001)).astype(np.float32) + + # Frame mask: valid when last sample in window is real audio + mask = mask_raw[np.arange(nf) * 160 + 320].astype(bool) + log_mel = log_mel * mask[:, None] + return torch.from_numpy(log_mel), torch.from_numpy(mask) # [T, 128], [T] + + @staticmethod + def _build_mel_filterbank(): + """Build 128-bin HTK mel filterbank [257, 128] for 512-pt FFT at 16kHz.""" + mel_freqs = np.linspace(0.0, 2595.0 * np.log10(1.0 + 8000.0 / 700.0), 130) + filter_freqs = 700.0 * (10.0 ** (mel_freqs / 2595.0) - 1.0) + fft_freqs = np.linspace(0, 16000 // 2, 257) + filter_diff = np.diff(filter_freqs) + slopes = np.expand_dims(filter_freqs, 0) - np.expand_dims(fft_freqs, 1) + down_slopes = -slopes[:, :-2] / filter_diff[:-1] + up_slopes = slopes[:, 2:] / filter_diff[1:] + return np.maximum(np.zeros(1), np.minimum(down_slopes, up_slopes)) + + def tokenize_with_weights(self, text, return_word_ids=False, image=None, audio=None, video=None, llama_template=None, skip_template=True, thinking=False, **kwargs): + + # Process audio + audio_features = [] + if audio is not None: + waveform = audio["waveform"].squeeze(0) if hasattr(audio, "__getitem__") else audio + sample_rate = audio.get("sample_rate", 16000) if hasattr(audio, "get") else 16000 + mel, mel_mask = self._extract_mel_spectrogram(waveform, sample_rate) + audio_features = [(mel.unsqueeze(0), mel_mask.unsqueeze(0))] # ([1, T, 128], [1, T]) + + # Process image/video frames + is_video = video is not None + source = video if is_video else image + images = [] + if source is not None: + samples = source.movedim(-1, 1) # [B, C, H, W] + num_frames = samples.shape[0] + + # Subsample video to 1fps + if is_video: + fps = kwargs.get("fps", 24) + step = max(1, round(fps)) + indices = list(range(0, num_frames, step)) + if len(indices) == 0: + indices = [0] + samples = samples[indices] + num_frames = len(indices) + + h, w = samples.shape[2], samples.shape[3] + patch_size = 16 + pooling_k = 3 + max_soft_tokens = 70 if is_video else 280 # video uses smaller token budget per frame + max_patches = max_soft_tokens * pooling_k * pooling_k + target_px = max_patches * patch_size * patch_size + factor = (target_px / (h * w)) ** 0.5 + side_mult = pooling_k * patch_size + target_h = max(int(factor * h // side_mult) * side_mult, side_mult) + target_w = max(int(factor * w // side_mult) * side_mult, side_mult) + + import torchvision.transforms.functional as TVF + for i in range(num_frames): + # rescaling to match reference code + s = (samples[i].clamp(0, 1) * 255).to(torch.uint8) # [C, H, W] uint8 + if target_h != h or target_w != w: + s = TVF.resize(s, [target_h, target_w], interpolation=TVF.InterpolationMode.BICUBIC, antialias=True) + s = s.float() * (1.0 / 255.0) + images.append({"pixels": s.unsqueeze(0).movedim(1, -1)[:, :, :, :3], "max_soft_tokens": max_soft_tokens}) + + if text.startswith('<|turn>'): + skip_template = True + + if skip_template: + llama_text = text + else: + if llama_template is not None: + llama_text = llama_template.format(text) + else: + # Build template from modalities present + system = "<|turn>system\n<|think|>\n" if thinking else "" + media = "" + if len(images) > 0: + if is_video: + media += "\n\n" + for i in range(len(images)): + ts = f"{int(i // 60):02d}:{int(i % 60):02d}" + sep = "" if i == 0 else " " + media += f"{sep}{ts} <|image><|video|>" + media += "\n\n" + else: + media += "\n\n" + for i in range(len(images)): + if i > 0: + media += "\n\n\n\n" + media += "<|image><|image|>" + media += "\n\n" + if len(audio_features) > 0: + # Compute audio token count (always at 16kHz) + num_samples = int(waveform.shape[-1] * 16000 / sample_rate) if sample_rate != 16000 else waveform.shape[-1] + _fl = 320 # int(round(16000 * 20.0 / 1000.0)) + _hl = 160 # int(round(16000 * 10.0 / 1000.0)) + _nmel = (num_samples + _fl // 2 - (_fl + 1)) // _hl + 1 + _t = _nmel + for _ in range(2): + _t = (_t + 2 - 3) // 2 + 1 + n_audio_tokens = min(_t, 750) + media += "<|audio>" + "<|audio|>" * n_audio_tokens + "" + llama_text = f"{system}<|turn>user\n{media}{text}\n<|turn>model\n" + + text_tokens = super().tokenize_with_weights(llama_text, return_word_ids) + + def _replace_placeholders(token_list, token_id, embeds): + """Replace first placeholder with embed dict, remove remaining consecutive ones.""" + embed_idx = 0 + i = 0 + while i < len(token_list): + if token_list[i][0] == token_id and embed_idx < len(embeds): + token_list[i] = (embeds[embed_idx],) + token_list[i][1:] + embed_idx += 1 + i += 1 + while i < len(token_list) and token_list[i][0] == token_id: + token_list.pop(i) + else: + i += 1 + + if len(images) > 0: + img_token_id = 258884 if is_video else 258880 + img_embeds = [{"type": "image", "data": img["pixels"], "max_soft_tokens": img["max_soft_tokens"]} for img in images] + for r in text_tokens: + _replace_placeholders(r, img_token_id, img_embeds) + + if len(audio_features) > 0: + aud_embeds = [{"type": "audio", "data": mel, "mask": mask} for mel, mask in audio_features] + for r in text_tokens: + _replace_placeholders(r, 258881, aud_embeds) + + return text_tokens + + +class _Gemma4Tokenizer: + """Tokenizer using the tokenizers (Gemma4 doesn't come with sentencepiece model)""" + def __init__(self, tokenizer_json_bytes=None, **kwargs): + from tokenizers import Tokenizer + if isinstance(tokenizer_json_bytes, torch.Tensor): + tokenizer_json_bytes = bytes(tokenizer_json_bytes.tolist()) + self.tokenizer = Tokenizer.from_str(tokenizer_json_bytes.decode("utf-8")) + + @classmethod + def from_pretrained(cls, tokenizer_data, **kwargs): + return cls(tokenizer_json_bytes=tokenizer_data, **kwargs) + + def __call__(self, text): + return {"input_ids": self.tokenizer.encode(text, add_special_tokens=False).ids} + + def get_vocab(self): + return self.tokenizer.get_vocab() + + def convert_tokens_to_ids(self, tokens): + return [self.tokenizer.token_to_id(t) for t in tokens] + + def decode(self, ids, **kwargs): + return self.tokenizer.decode(ids, skip_special_tokens=kwargs.get("skip_special_tokens", False)) + + +# Tokenizer +class Gemma4SDTokenizer(Gemma4_Tokenizer, sd1_clip.SDTokenizer): + embedding_size = 2560 + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_json = tokenizer_data.get("tokenizer_json", None) + self.tokenizer_json_data = tokenizer_json + super().__init__(tokenizer_json, pad_with_end=False, embedding_size=self.embedding_size, embedding_key='gemma4', tokenizer_class=_Gemma4Tokenizer, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_left=True, disable_weights=True, start_token=2, tokenizer_data=tokenizer_data) + + def decode(self, token_ids, **kwargs): + text = super().decode(token_ids, skip_special_tokens=False) + # Translate thinking channel markers to standard / tags + text = text.replace("<|channel>thought\n", "\n") + text = text.replace("", "") + # Strip remaining special tokens + text = text.replace("", "").replace("", "").strip() + return text + + +class Gemma4Tokenizer(sd1_clip.SD1Tokenizer): + tokenizer_class = Gemma4SDTokenizer + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma4", tokenizer=self.tokenizer_class) + + +# Model wrappers +class Gemma4Model(sd1_clip.SDClipModel): + model_class = None + def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + self.dtypes = set() + self.dtypes.add(dtype) + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=self.model_class, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + def process_tokens(self, tokens, device): + embeds, _, _, _ = super().process_tokens(tokens, device) + return embeds + + def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=0.0): + if isinstance(tokens, dict): + tokens = next(iter(tokens.values())) + tokens_only = [[t[0] for t in b] for b in tokens] + embeds, _, _, embeds_info = sd1_clip.SDClipModel.process_tokens(self, tokens_only, self.execution_device) + seq_len = embeds.shape[1] + ids = [0] * seq_len + expanded_idx = 0 + embed_map = {info["index"]: info["size"] for info in embeds_info} + for t in tokens_only[0]: + if expanded_idx in embed_map: + expanded_idx += embed_map[expanded_idx] + elif isinstance(t, int): + if expanded_idx < seq_len: + ids[expanded_idx] = t + expanded_idx += 1 + else: + expanded_idx += 1 + initial_token_ids = [ids] + input_ids = torch.tensor(initial_token_ids, device=self.execution_device) + return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, initial_tokens=initial_token_ids[0], presence_penalty=presence_penalty, initial_input_ids=input_ids) + + +def gemma4_te(dtype_llama=None, llama_quantization_metadata=None, model_class=None): + clip_model = type('Gemma4Model_', (Gemma4Model,), {'model_class': model_class}) + class Gemma4TEModel_(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(device=device, dtype=dtype, name="gemma4", clip_model=clip_model, model_options=model_options) + return Gemma4TEModel_ + + +# Variants + +def _make_variant(config_cls): + audio = config_cls.audio_config is not None + bases = (Gemma4AudioMixin, Gemma4Base) if audio else (Gemma4Base,) + class Variant(*bases): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + self._init_model(config_cls(**config_dict), dtype, device, operations) + if audio: + self._init_audio(self.model.config, dtype, device, operations) + embedding_size = config_cls.hidden_size + if embedding_size != Gemma4SDTokenizer.embedding_size: + tok_cls = type('T', (Gemma4SDTokenizer,), {'embedding_size': embedding_size}) + class Tokenizer(Gemma4Tokenizer): + tokenizer_class = tok_cls + Variant.tokenizer = Tokenizer + else: + Variant.tokenizer = Gemma4Tokenizer + return Variant + +Gemma4_E4B = _make_variant(Gemma4Config) +Gemma4_E2B = _make_variant(Gemma4_E2B_Config) +Gemma4_31B = _make_variant(Gemma4_31B_Config) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 6ea8e36b1..a34c41144 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -521,7 +521,7 @@ class Attention(nn.Module): else: present_key_value = (xk, xv, index + num_tokens) - if sliding_window is not None and xk.shape[2] > sliding_window: + if sliding_window is not None and xk.shape[2] > sliding_window and seq_length == 1: xk = xk[:, :, -sliding_window:] xv = xv[:, :, -sliding_window:] attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None @@ -533,12 +533,12 @@ class Attention(nn.Module): return self.o_proj(output), present_key_value class MLP(nn.Module): - def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None): + def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None, intermediate_size=None): super().__init__() - ops = ops or nn - self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype) - self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype) - self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype) + intermediate_size = intermediate_size or config.intermediate_size + self.gate_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype) + self.up_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype) + self.down_proj = ops.Linear(intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype) if config.mlp_activation == "silu": self.activation = torch.nn.functional.silu elif config.mlp_activation == "gelu_pytorch_tanh": @@ -647,24 +647,25 @@ class TransformerBlockGemma2(nn.Module): return x, present_key_value +def _make_scaled_embedding(ops, vocab_size, hidden_size, scale, device, dtype): + class ScaledEmbedding(ops.Embedding): + def forward(self, input_ids, out_dtype=None): + return super().forward(input_ids, out_dtype=out_dtype) * scale + return ScaledEmbedding(vocab_size, hidden_size, device=device, dtype=dtype) + + class Llama2_(nn.Module): def __init__(self, config, device=None, dtype=None, ops=None): super().__init__() self.config = config self.vocab_size = config.vocab_size - self.embed_tokens = ops.Embedding( - config.vocab_size, - config.hidden_size, - device=device, - dtype=dtype - ) if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3": transformer = TransformerBlockGemma2 - self.normalize_in = True + self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype) else: transformer = TransformerBlock - self.normalize_in = False + self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype) self.layers = nn.ModuleList([ transformer(config, index=i, device=device, dtype=dtype, ops=ops) @@ -690,15 +691,12 @@ class Llama2_(nn.Module): self.config.rope_dims, device=device) - def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None): + def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None): if embeds is not None: x = embeds else: x = self.embed_tokens(x, out_dtype=dtype) - if self.normalize_in: - x *= self.config.hidden_size ** 0.5 - seq_len = x.shape[1] past_len = 0 if past_key_values is not None and len(past_key_values) > 0: @@ -850,7 +848,7 @@ class BaseGenerate: torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0)) return past_key_values - def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0): + def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None): device = embeds.device if stop_tokens is None: @@ -875,14 +873,16 @@ class BaseGenerate: pbar = comfy.utils.ProgressBar(max_length) # Generation loop + current_input_ids = initial_input_ids for step in tqdm(range(max_length), desc="Generating tokens"): - x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values) + x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids) logits = self.logits(x)[:, -1] next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty) token_id = next_token[0].item() generated_token_ids.append(token_id) embeds = self.model.embed_tokens(next_token).to(execution_dtype) + current_input_ids = next_token if initial_input_ids is not None else None pbar.update(1) if token_id in stop_tokens: diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 5aee1f4c0..bc5cbae28 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -93,8 +93,7 @@ class Gemma3_12BModel(sd1_clip.SDClipModel): def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty): tokens_only = [[t[0] for t in b] for b in tokens] - embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device) - comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5) + embeds, _, _, _ = self.process_tokens(tokens_only, self.execution_device) return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], presence_penalty=presence_penalty) # 106 is class DualLinearProjection(torch.nn.Module): diff --git a/comfy/text_encoders/lumina2.py b/comfy/text_encoders/lumina2.py index 01ebdfabe..b1f1dbb9f 100644 --- a/comfy/text_encoders/lumina2.py +++ b/comfy/text_encoders/lumina2.py @@ -50,8 +50,7 @@ class Gemma3_4B_Vision_Model(sd1_clip.SDClipModel): super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B_Vision, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) def process_tokens(self, tokens, device): - embeds, _, _, embeds_info = super().process_tokens(tokens, device) - comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5) + embeds, _, _, _ = super().process_tokens(tokens, device) return embeds class LuminaModel(sd1_clip.SD1ClipModel): diff --git a/comfy/text_encoders/qwen35.py b/comfy/text_encoders/qwen35.py index ce9b07464..d8ed9cd32 100644 --- a/comfy/text_encoders/qwen35.py +++ b/comfy/text_encoders/qwen35.py @@ -408,8 +408,6 @@ class Qwen35Transformer(Llama2_): nn.Module.__init__(self) self.config = config self.vocab_size = config.vocab_size - self.normalize_in = False - self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype) self.layers = nn.ModuleList([ Qwen35TransformerBlock(config, index=i, device=device, dtype=dtype, ops=ops) diff --git a/comfy/utils.py b/comfy/utils.py index 78c491b98..7b7faad3a 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1446,10 +1446,3 @@ def deepcopy_list_dict(obj, memo=None): memo[obj_id] = res return res -def normalize_image_embeddings(embeds, embeds_info, scale_factor): - """Normalize image embeddings to match text embedding scale""" - for info in embeds_info: - if info.get("type") == "image": - start_idx = info["index"] - end_idx = start_idx + info["size"] - embeds[:, start_idx:end_idx, :] /= scale_factor diff --git a/comfy_extras/nodes_textgen.py b/comfy_extras/nodes_textgen.py index 1f46d820f..1661a1011 100644 --- a/comfy_extras/nodes_textgen.py +++ b/comfy_extras/nodes_textgen.py @@ -32,6 +32,8 @@ class TextGenerate(io.ComfyNode): io.Clip.Input("clip"), io.String.Input("prompt", multiline=True, dynamic_prompts=True, default=""), io.Image.Input("image", optional=True), + io.Image.Input("video", optional=True, tooltip="Video frames as image batch. Assumed to be 24 FPS; subsampled to 1 FPS internally."), + io.Audio.Input("audio", optional=True), io.Int.Input("max_length", default=256, min=1, max=2048), io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"), io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."), @@ -43,9 +45,9 @@ class TextGenerate(io.ComfyNode): ) @classmethod - def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput: + def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True, video=None, audio=None) -> io.NodeOutput: - tokens = clip.tokenize(prompt, image=image, skip_template=not use_default_template, min_length=1, thinking=thinking) + tokens = clip.tokenize(prompt, image=image, skip_template=not use_default_template, min_length=1, thinking=thinking, video=video, audio=audio) # Get sampling parameters from dynamic combo do_sample = sampling_mode.get("sampling_mode") == "on" @@ -70,7 +72,8 @@ class TextGenerate(io.ComfyNode): seed=seed ) - generated_text = clip.decode(generated_ids, skip_special_tokens=True) + generated_text = clip.decode(generated_ids) + return io.NodeOutput(generated_text) @@ -161,12 +164,12 @@ class TextGenerateLTX2Prompt(TextGenerate): ) @classmethod - def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput: + def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True, video=None, audio=None) -> io.NodeOutput: if image is None: formatted_prompt = f"system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}\nuser\nUser Raw Input Prompt: {prompt}.\nmodel\n" else: formatted_prompt = f"system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}\nuser\n\n\n\nUser Raw Input Prompt: {prompt}.\nmodel\n" - return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking, use_default_template) + return super().execute(clip, formatted_prompt, max_length, sampling_mode, image=image, thinking=thinking, use_default_template=use_default_template, video=video, audio=audio) class TextgenExtension(ComfyExtension): From f6d5068ac0163e7f626c9cec2e7c663cf6fa64a8 Mon Sep 17 00:00:00 2001 From: Alexis Rolland Date: Sun, 3 May 2026 12:20:17 +0800 Subject: [PATCH 14/25] Update README (#13679) Updated the README to include a new screenshot, improved description and add Ernie Image to supported models. --- README.md | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 3b5114633..ee68e8bb8 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@
# ComfyUI -**The most powerful and modular visual AI engine and application.** +**The most powerful and modular AI engine for content creation.** [![Website][website-shield]][website-url] @@ -31,10 +31,15 @@ [github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest [github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases -![ComfyUI Screenshot](https://github.com/user-attachments/assets/7ccaf2c1-9b72-41ae-9a89-5688c94b7abe) +ComfyUI Screenshot
-ComfyUI lets you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. Available on Windows, Linux, and macOS. +ComfyUI is the AI creation engine for visual professionals who demand control over every model, every parameter, and every output. Its powerful and modular node graph interface empowers creatives to generate images, videos, 3D models, audio, and more... +- ComfyUI natively supports the latest open-source state of the art models. +- API nodes provide access to the best closed source models such as Nano Banana, Seedance, Hunyuan3D, etc. +- It is available on Windows, Linux, and macOS, locally with our desktop application or on our cloud. +- The most sophisticated workflows can be exposed through a simple UI thanks to App Mode. +- It integrates seamlessly into production pipelines with our API endpoints. ## Get Started @@ -77,6 +82,7 @@ See what ComfyUI can do with the [newer template workflows](https://comfy.org/wo - [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/) - [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/) - [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/) + - Ernie Image - Image Editing Models - [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/) - [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model) From b5bb83c964519b7574ce9229b2314e04c17592c0 Mon Sep 17 00:00:00 2001 From: Alexis Rolland Date: Sun, 3 May 2026 18:17:08 +0800 Subject: [PATCH 15/25] Fix issue blend images with alpha (#13615) Make ImageBlend and ImageCompositeMasked nodes handle images with different channel counts --- node_helpers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/node_helpers.py b/node_helpers.py index d3d834516..cac4e88dd 100644 --- a/node_helpers.py +++ b/node_helpers.py @@ -86,6 +86,6 @@ def image_alpha_fix(destination, source): if destination.shape[-1] < source.shape[-1]: source = source[...,:destination.shape[-1]] elif destination.shape[-1] > source.shape[-1]: - destination = torch.nn.functional.pad(destination, (0, 1)) - destination[..., -1] = 1.0 + source = torch.nn.functional.pad(source, (0, 1)) + source[..., -1] = 1.0 return destination, source From d0f0b15cf5d1fbff67390c8d90ec8654c2582f7a Mon Sep 17 00:00:00 2001 From: Alexis Rolland Date: Sun, 3 May 2026 18:48:58 +0800 Subject: [PATCH 16/25] Update ComfyUI screenshot in README (#13683) Update ComfyUI screenshot to showcase a more modern workflow --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index ee68e8bb8..a3bd3ba0a 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,8 @@ [github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest [github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases -ComfyUI Screenshot +ComfyUI Screenshot +
ComfyUI is the AI creation engine for visual professionals who demand control over every model, every parameter, and every output. Its powerful and modular node graph interface empowers creatives to generate images, videos, 3D models, audio, and more... From 867b8d2408a8f3062f25bd6707a4b96755d70e1d Mon Sep 17 00:00:00 2001 From: Luke Mino-Altherr Date: Sun, 3 May 2026 05:44:20 -0700 Subject: [PATCH 17/25] fix: gracefully handle port-in-use error on server startup (#13001) Catch EADDRINUSE OSError when binding the TCP site and exit with a clear error message instead of an unhandled traceback. --- server.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/server.py b/server.py index 881da8e66..2f3b438bb 100644 --- a/server.py +++ b/server.py @@ -1,3 +1,4 @@ +import errno import os import sys import asyncio @@ -1245,7 +1246,13 @@ class PromptServer(): address = addr[0] port = addr[1] site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx) - await site.start() + try: + await site.start() + except OSError as e: + if e.errno == errno.EADDRINUSE: + logging.error(f"Port {port} is already in use on address {address}. Please close the other application or use a different port with --port.") + raise SystemExit(1) + raise if not hasattr(self, 'address'): self.address = address #TODO: remove this From 025e6792ee64181ddce8a84411e0c7311e00b179 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Sun, 3 May 2026 16:30:00 +0300 Subject: [PATCH 18/25] Batch broadcasting in JoinImageWithAlpha node (#13686) * Batch broadcasting in JoinImageWithAlpha node --- comfy_extras/nodes_compositing.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/comfy_extras/nodes_compositing.py b/comfy_extras/nodes_compositing.py index 3bc9fccb3..5b4423734 100644 --- a/comfy_extras/nodes_compositing.py +++ b/comfy_extras/nodes_compositing.py @@ -202,14 +202,11 @@ class JoinImageWithAlpha(io.ComfyNode): @classmethod def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput: - batch_size = min(len(image), len(alpha)) - out_images = [] - + batch_size = max(len(image), len(alpha)) alpha = 1.0 - resize_mask(alpha, image.shape[1:]) - for i in range(batch_size): - out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2)) - - return io.NodeOutput(torch.stack(out_images)) + alpha = comfy.utils.repeat_to_batch_size(alpha, batch_size) + image = comfy.utils.repeat_to_batch_size(image, batch_size) + return io.NodeOutput(torch.cat((image[..., :3], alpha.unsqueeze(-1)), dim=-1)) class CompositingExtension(ComfyExtension): From b138133ffa43541c85b5f9ca57f449c8345ca005 Mon Sep 17 00:00:00 2001 From: Silver <65376327+silveroxides@users.noreply.github.com> Date: Sun, 3 May 2026 20:07:21 +0200 Subject: [PATCH 19/25] Enable triton comfy kitchen via cli-arg (#12730) --- comfy/cli_args.py | 1 + comfy/quant_ops.py | 12 +++++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index cef1a5e6b..d2fde8b67 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -91,6 +91,7 @@ parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE" parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.") parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.") +parser.add_argument("--enable-triton-backend", action="store_true", help="ComfyUI will enable the use of Triton backend in comfy-kitchen. Is disabled at launch by default.") class LatentPreviewMethod(enum.Enum): NoPreviews = "none" diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 42ee08fb2..b90bcfd25 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -1,6 +1,8 @@ import torch import logging +from comfy.cli_args import args + try: import comfy_kitchen as ck from comfy_kitchen.tensor import ( @@ -21,7 +23,15 @@ try: ck.registry.disable("cuda") logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.") - ck.registry.disable("triton") + if args.enable_triton_backend: + try: + import triton + logging.info("Found triton %s. Enabling comfy-kitchen triton backend.", triton.__version__) + except ImportError as e: + logging.error(f"Failed to import triton, Error: {e}, the comfy-kitchen triton backend will not be available.") + ck.registry.disable("triton") + else: + ck.registry.disable("triton") for k, v in ck.list_backends().items(): logging.info(f"Found comfy_kitchen backend {k}: {v}") except ImportError as e: From cea8d0925febb4dd32e400bbbf94243f55af3371 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 3 May 2026 13:18:27 -0700 Subject: [PATCH 20/25] Refactor LoadImageMask to use LoadImage code. (#13687) --- nodes.py | 66 +++++++++++++++++++++++++------------------------------- 1 file changed, 29 insertions(+), 37 deletions(-) diff --git a/nodes.py b/nodes.py index 710cccffe..8f8f90cf6 100644 --- a/nodes.py +++ b/nodes.py @@ -1754,57 +1754,49 @@ class LoadImage: return True -class LoadImageMask: + +class LoadImageMask(LoadImage): ESSENTIALS_CATEGORY = "Image Tools" SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"] _color_channels = ["alpha", "red", "green", "blue"] + @classmethod def INPUT_TYPES(s): - input_dir = folder_paths.get_input_directory() - files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] - return {"required": - {"image": (sorted(files), {"image_upload": True}), - "channel": (s._color_channels, ), } - } + types = super().INPUT_TYPES() + return { + "required": { + **types["required"], + "channel": (s._color_channels, ) + } + } CATEGORY = "mask" - RETURN_TYPES = ("MASK",) - FUNCTION = "load_image" - def load_image(self, image, channel): - image_path = folder_paths.get_annotated_filepath(image) - i = node_helpers.pillow(Image.open, image_path) - i = node_helpers.pillow(ImageOps.exif_transpose, i) - if i.getbands() != ("R", "G", "B", "A"): - if i.mode == 'I': - i = i.point(lambda i: i * (1 / 255)) - i = i.convert("RGBA") - mask = None + FUNCTION = "load_image_mask" + + def load_image_mask(self, image, channel): + image_tensor, mask_tensor = super().load_image(image) c = channel[0].upper() - if c in i.getbands(): - mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0 - mask = torch.from_numpy(mask) - if c == 'A': - mask = 1. - mask + + if c == 'A': + return (mask_tensor,) + + channel_idx = {'R': 0, 'G': 1, 'B': 2}.get(c, 0) + + if channel_idx < image_tensor.shape[-1]: + return (image_tensor[..., channel_idx].clone(),) else: - mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") - return (mask.unsqueeze(0),) + empty_mask = torch.zeros( + image_tensor.shape[:-1], + dtype=image_tensor.dtype, + device=image_tensor.device + ) + return (empty_mask,) @classmethod def IS_CHANGED(s, image, channel): - image_path = folder_paths.get_annotated_filepath(image) - m = hashlib.sha256() - with open(image_path, 'rb') as f: - m.update(f.read()) - return m.digest().hex() - - @classmethod - def VALIDATE_INPUTS(s, image): - if not folder_paths.exists_annotated_filepath(image): - return "Invalid image file: {}".format(image) - - return True + return super().IS_CHANGED(image) class LoadImageOutput(LoadImage): From 2806163f6e06465bacb1b16906cd17a8b78c9610 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 3 May 2026 16:21:34 -0700 Subject: [PATCH 21/25] Default control_after_generate to fixed in PrimitiveInt node (#13690) --- comfy_extras/nodes_primitive.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_primitive.py b/comfy_extras/nodes_primitive.py index 9c2e98758..3c8f90b19 100644 --- a/comfy_extras/nodes_primitive.py +++ b/comfy_extras/nodes_primitive.py @@ -49,7 +49,7 @@ class Int(io.ComfyNode): display_name="Int", category="utils/primitive", inputs=[ - io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=True), + io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=io.ControlAfterGenerate.fixed), ], outputs=[io.Int.Output()], ) From 5538f62b0b81102c382849fd90469283c725b212 Mon Sep 17 00:00:00 2001 From: Alexis Rolland Date: Mon, 4 May 2026 12:33:11 +0800 Subject: [PATCH 22/25] fix: Update ColorTransfer node ref_image to be mandatory (#13691) --- comfy_extras/nodes_post_processing.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index c932b747a..345fdb695 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -666,12 +666,13 @@ class ColorTransfer(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ColorTransfer", + display_name="Color Transfer", category="image/postprocessing", description="Match the colors of one image to another using various algorithms.", search_aliases=["color match", "color grading", "color correction", "match colors", "color transform", "mkl", "reinhard", "histogram"], inputs=[ io.Image.Input("image_target", tooltip="Image(s) to apply the color transform to."), - io.Image.Input("image_ref", optional=True, tooltip="Reference image(s) to match colors to. If not provided, processing is skipped"), + io.Image.Input("image_ref", tooltip="Reference image(s) to match colors to."), io.Combo.Input("method", options=['reinhard_lab', 'mkl_lab', 'histogram'],), io.DynamicCombo.Input("source_stats", tooltip="per_frame: each frame matched to image_ref individually. uniform: pool stats across all source frames as baseline, match to image_ref. target_frame: use one chosen frame as the baseline for the transform to image_ref, applied uniformly to all frames (preserves relative differences)", From f3ea976cba8743a87efeb9fbca717309e3d65c47 Mon Sep 17 00:00:00 2001 From: Soof Golan <83900570+soof-golan@users.noreply.github.com> Date: Mon, 4 May 2026 10:01:46 +0200 Subject: [PATCH 23/25] Fix a1111 typo in extra_model_paths.yaml (#2720) --- extra_model_paths.yaml.example | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extra_model_paths.yaml.example b/extra_model_paths.yaml.example index 34df01681..9c395c0b2 100644 --- a/extra_model_paths.yaml.example +++ b/extra_model_paths.yaml.example @@ -28,7 +28,7 @@ #config for a1111 ui #all you have to do is uncomment this (remove the #) and change the base_path to where yours is installed -#a111: +#a1111: # base_path: path/to/stable-diffusion-webui/ # checkpoints: models/Stable-diffusion # configs: models/Stable-diffusion From c33d26c283ea53b8ba3e42ef3dca1f03ddf4d7b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Mon, 4 May 2026 20:20:40 +0300 Subject: [PATCH 24/25] fix: Proper memory estimation for frame interpolation when not using dynamic VRAM (#13698) --- comfy_extras/frame_interpolation_models/film_net.py | 3 +++ comfy_extras/frame_interpolation_models/ifnet.py | 3 +++ comfy_extras/nodes_frame_interpolation.py | 11 ++++------- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/comfy_extras/frame_interpolation_models/film_net.py b/comfy_extras/frame_interpolation_models/film_net.py index cf4f6e1e1..36bc79dc3 100644 --- a/comfy_extras/frame_interpolation_models/film_net.py +++ b/comfy_extras/frame_interpolation_models/film_net.py @@ -199,6 +199,9 @@ class FILMNet(nn.Module): def get_dtype(self): return self.extract.extract_sublevels.convs[0][0].conv.weight.dtype + def memory_used_forward(self, shape, dtype): + return 1700 * shape[1] * shape[2] * dtype.itemsize + def _build_warp_grids(self, H, W, device): """Pre-compute warp grids for all pyramid levels.""" if (H, W) in self._warp_grids: diff --git a/comfy_extras/frame_interpolation_models/ifnet.py b/comfy_extras/frame_interpolation_models/ifnet.py index 03cb34c50..ad6edbec9 100644 --- a/comfy_extras/frame_interpolation_models/ifnet.py +++ b/comfy_extras/frame_interpolation_models/ifnet.py @@ -74,6 +74,9 @@ class IFNet(nn.Module): def get_dtype(self): return self.encode.cnn0.weight.dtype + def memory_used_forward(self, shape, dtype): + return 300 * shape[1] * shape[2] * dtype.itemsize + def _build_warp_grids(self, H, W, device): if (H, W) in self._warp_grids: return diff --git a/comfy_extras/nodes_frame_interpolation.py b/comfy_extras/nodes_frame_interpolation.py index a3b00d36e..fa49c203a 100644 --- a/comfy_extras/nodes_frame_interpolation.py +++ b/comfy_extras/nodes_frame_interpolation.py @@ -37,7 +37,7 @@ class FrameInterpolationModelLoader(io.ComfyNode): model = cls._detect_and_load(sd) dtype = torch.float16 if model_management.should_use_fp16(model_management.get_torch_device()) else torch.float32 model.eval().to(dtype) - patcher = comfy.model_patcher.ModelPatcher( + patcher = comfy.model_patcher.CoreModelPatcher( model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), @@ -98,16 +98,13 @@ class FrameInterpolate(io.ComfyNode): if num_frames < 2 or multiplier < 2: return io.NodeOutput(images) - model_management.load_model_gpu(interp_model) device = interp_model.load_device dtype = interp_model.model_dtype() inference_model = interp_model.model - - # Free VRAM for inference activations (model weights + ~20x a single frame's worth) - H, W = images.shape[1], images.shape[2] - activation_mem = H * W * 3 * images.element_size() * 20 - model_management.free_memory(activation_mem, device) + activation_mem = inference_model.memory_used_forward(images.shape, dtype) + model_management.load_models_gpu([interp_model], memory_required=activation_mem) align = getattr(inference_model, "pad_align", 1) + H, W = images.shape[1], images.shape[2] # Prepare a single padded frame on device for determining output dimensions def prepare_frame(idx): From c47633f3befbc32bf5aeece6a899c20d55a9feb1 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 5 May 2026 05:56:05 +1000 Subject: [PATCH 25/25] prefetch: guard against no offload (#13703) cast_ will return no stream if there is no work to do. guard against this is the consume logic. --- comfy/model_prefetch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy/model_prefetch.py b/comfy/model_prefetch.py index 0ad35deb5..72e11dec6 100644 --- a/comfy/model_prefetch.py +++ b/comfy/model_prefetch.py @@ -37,7 +37,8 @@ def prefetch_queue_pop(queue, device, module): consumed = queue.pop(0) if consumed is not None: offload_stream, prefetch_state = consumed - offload_stream.wait_stream(comfy.model_management.current_stream(device)) + if offload_stream is not None: + offload_stream.wait_stream(comfy.model_management.current_stream(device)) _, comfy_modules = prefetch_state if comfy_modules is not None: cleanup_prefetched_modules(comfy_modules)