From 94ee49b1612824366a8631ea069b2a1fa5c73720 Mon Sep 17 00:00:00 2001 From: Matt Miller Date: Thu, 18 Jun 2026 12:30:57 -0700 Subject: [PATCH 1/8] harden: load training-dataset shards with weights_only=True (#14543) LoadTrainingDataset was the only torch.load call in the codebase without weights_only=True; comfy/utils.py and comfy/sd1_clip.py already pass it. Recent PyTorch defaults to weights_only=True, so this is defense-in-depth for installs pinned to older PyTorch. Verified a typical shard (latents + standard conditioning) round-trips cleanly under weights_only=True. --- comfy_extras/nodes_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py index 0253b4b4f..73fe75b7f 100644 --- a/comfy_extras/nodes_dataset.py +++ b/comfy_extras/nodes_dataset.py @@ -1583,7 +1583,7 @@ class LoadTrainingDataset(io.ComfyNode): shard_path = os.path.join(dataset_dir, shard_file) with open(shard_path, "rb") as f: - shard_data = torch.load(f) + shard_data = torch.load(f, weights_only=True) all_latents.extend(shard_data["latents"]) all_conditioning.extend(shard_data["conditioning"]) From 5ef0092af943cf17c0cd3ba3ea5507137cdc37d0 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 18 Jun 2026 19:32:55 -0700 Subject: [PATCH 2/8] Move comfy sys path insert to custom node loading. (#14459) --- nodes.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/nodes.py b/nodes.py index 0b3fdab63..b1a663f4c 100644 --- a/nodes.py +++ b/nodes.py @@ -20,8 +20,6 @@ from PIL.PngImagePlugin import PngInfo import numpy as np import safetensors.torch -sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) - import comfy.diffusers_load import comfy.samplers import comfy.sample @@ -2299,6 +2297,9 @@ async def init_external_custom_nodes(): Returns: None """ + # TODO: remove at some point when custom nodes don't break. + sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) + base_node_names = set(NODE_CLASS_MAPPINGS.keys()) node_paths = folder_paths.get_folder_paths("custom_nodes") node_import_times = [] From 5955ddff52a2eda2ba0cf7f3fb0927c93fb2fbb8 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 19 Jun 2026 08:46:07 +0300 Subject: [PATCH 3/8] [Partner Nodes] feat(Luma): add support for Luma Rays 3.2 (#14540) Signed-off-by: bigcat88 --- comfy_api_nodes/apis/luma.py | 142 ++++++--- comfy_api_nodes/nodes_luma.py | 564 +++++++++++++++++++++++++++++++++- 2 files changed, 652 insertions(+), 54 deletions(-) diff --git a/comfy_api_nodes/apis/luma.py b/comfy_api_nodes/apis/luma.py index 8c6db2022..2465c3b37 100644 --- a/comfy_api_nodes/apis/luma.py +++ b/comfy_api_nodes/apis/luma.py @@ -10,6 +10,7 @@ from pydantic import BaseModel, Field, confloat class LumaIO: LUMA_REF = "LUMA_REF" LUMA_CONCEPTS = "LUMA_CONCEPTS" + LUMA_RAY32_KEYFRAME = "LUMA_RAY32_KEYFRAME" class LumaReference: @@ -20,13 +21,14 @@ class LumaReference: def create_api_model(self, download_url: str): return LumaImageRef(url=download_url, weight=self.weight) + class LumaReferenceChain: - def __init__(self, first_ref: LumaReference=None): + def __init__(self, first_ref: LumaReference = None): self.refs: list[LumaReference] = [] if first_ref: self.refs.append(first_ref) - def add(self, luma_ref: LumaReference=None): + def add(self, luma_ref: LumaReference = None): self.refs.append(luma_ref) def create_api_model(self, download_urls: list[str], max_refs=4): @@ -124,7 +126,7 @@ def get_luma_concepts(include_none=False): "pull_out", "aerial", "crane_up", - "eye_level" + "eye_level", ] @@ -162,8 +164,8 @@ class LumaVideoModelOutputDuration(str, Enum): class LumaGenerationType(str, Enum): - video = 'video' - image = 'image' + video = "video" + image = "image" class LumaState(str, Enum): @@ -174,86 +176,109 @@ class LumaState(str, Enum): class LumaAssets(BaseModel): - video: Optional[str] = Field(None, description='The URL of the video') - image: Optional[str] = Field(None, description='The URL of the image') - progress_video: Optional[str] = Field(None, description='The URL of the progress video') + video: Optional[str] = Field(None, description="The URL of the video") + image: Optional[str] = Field(None, description="The URL of the image") + progress_video: Optional[str] = Field(None, description="The URL of the progress video") class LumaImageRef(BaseModel): """Used for image gen""" - url: str = Field(..., description='The URL of the image reference') - weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference') + + url: str = Field(..., description="The URL of the image reference") + weight: confloat(ge=0.0, le=1.0) = Field(..., description="The weight of the image reference") class LumaImageReference(BaseModel): """Used for video gen""" - type: Optional[str] = Field('image', description='Input type, defaults to image') - url: str = Field(..., description='The URL of the image') + + type: Optional[str] = Field("image", description="Input type, defaults to image") + url: str = Field(..., description="The URL of the image") class LumaModifyImageRef(BaseModel): - url: str = Field(..., description='The URL of the image reference') - weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference') + url: str = Field(..., description="The URL of the image reference") + weight: confloat(ge=0.0, le=1.0) = Field(..., description="The weight of the image reference") class LumaCharacterRef(BaseModel): - identity0: LumaImageIdentity = Field(..., description='The image identity object') + identity0: LumaImageIdentity = Field(..., description="The image identity object") class LumaImageIdentity(BaseModel): - images: list[str] = Field(..., description='The URLs of the image identity') + images: list[str] = Field(..., description="The URLs of the image identity") class LumaGenerationReference(BaseModel): - type: str = Field('generation', description='Input type, defaults to generation') - id: str = Field(..., description='The ID of the generation') + type: str = Field("generation", description="Input type, defaults to generation") + id: str = Field(..., description="The ID of the generation") class LumaKeyframes(BaseModel): - frame0: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description='') - frame1: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description='') + frame0: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description="") + frame1: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description="") class LumaConceptObject(BaseModel): - key: str = Field(..., description='Camera Concept name') + key: str = Field(..., description="Camera Concept name") class LumaImageGenerationRequest(BaseModel): - prompt: str = Field(..., description='The prompt of the generation') - model: LumaImageModel = Field(LumaImageModel.photon_1, description='The image model used for the generation') - aspect_ratio: Optional[LumaAspectRatio] = Field(LumaAspectRatio.ratio_16_9, description='The aspect ratio of the generation') - image_ref: Optional[list[LumaImageRef]] = Field(None, description='List of image reference objects') - style_ref: Optional[list[LumaImageRef]] = Field(None, description='List of style reference objects') - character_ref: Optional[LumaCharacterRef] = Field(None, description='The image identity object') - modify_image_ref: Optional[LumaModifyImageRef] = Field(None, description='The modify image reference object') + prompt: str = Field(..., description="The prompt of the generation") + model: LumaImageModel = Field(LumaImageModel.photon_1, description="The image model used for the generation") + aspect_ratio: Optional[LumaAspectRatio] = Field(LumaAspectRatio.ratio_16_9) + image_ref: Optional[list[LumaImageRef]] = Field(None, description="List of image reference objects") + style_ref: Optional[list[LumaImageRef]] = Field(None, description="List of style reference objects") + character_ref: Optional[LumaCharacterRef] = Field(None, description="The image identity object") + modify_image_ref: Optional[LumaModifyImageRef] = Field(None, description="The modify image reference object") class LumaGenerationRequest(BaseModel): - prompt: str = Field(..., description='The prompt of the generation') - model: LumaVideoModel = Field(LumaVideoModel.ray_2, description='The video model used for the generation') - duration: Optional[LumaVideoModelOutputDuration] = Field(None, description='The duration of the generation') - aspect_ratio: Optional[LumaAspectRatio] = Field(None, description='The aspect ratio of the generation') - resolution: Optional[LumaVideoOutputResolution] = Field(None, description='The resolution of the generation') - loop: Optional[bool] = Field(None, description='Whether to loop the video') - keyframes: Optional[LumaKeyframes] = Field(None, description='The keyframes of the generation') - concepts: Optional[list[LumaConceptObject]] = Field(None, description='Camera Concepts to apply to generation') + prompt: str = Field(..., description="The prompt of the generation") + model: LumaVideoModel = Field(LumaVideoModel.ray_2, description="The video model used for the generation") + duration: Optional[LumaVideoModelOutputDuration] = Field(None, description="The duration of the generation") + aspect_ratio: Optional[LumaAspectRatio] = Field(None, description="The aspect ratio of the generation") + resolution: Optional[LumaVideoOutputResolution] = Field(None, description="The resolution of the generation") + loop: Optional[bool] = Field(None, description="Whether to loop the video") + keyframes: Optional[LumaKeyframes] = Field(None, description="The keyframes of the generation") + concepts: Optional[list[LumaConceptObject]] = Field(None, description="Camera Concepts to apply to generation") class LumaGeneration(BaseModel): - id: str = Field(..., description='The ID of the generation') - generation_type: LumaGenerationType = Field(..., description='Generation type, image or video') - state: LumaState = Field(..., description='The state of the generation') - failure_reason: Optional[str] = Field(None, description='The reason for the state of the generation') - created_at: str = Field(..., description='The date and time when the generation was created') - assets: Optional[LumaAssets] = Field(None, description='The assets of the generation') - model: str = Field(..., description='The model used for the generation') - request: Union[LumaGenerationRequest, LumaImageGenerationRequest] = Field(..., description="The request used for the generation") + id: str = Field(..., description="The ID of the generation") + generation_type: LumaGenerationType = Field(..., description="Generation type, image or video") + state: LumaState = Field(..., description="The state of the generation") + failure_reason: Optional[str] = Field(None, description="The reason for the state of the generation") + created_at: str = Field(..., description="The date and time when the generation was created") + assets: Optional[LumaAssets] = Field(None, description="The assets of the generation") + model: str = Field(..., description="The model used for the generation") + request: Union[LumaGenerationRequest, LumaImageGenerationRequest] = Field(...) class Luma2ImageRef(BaseModel): url: str | None = None data: str | None = None media_type: str | None = None + generation_id: str | None = Field(None, description="reference a prior generation (extend / source reuse)") + + +class Luma2VideoEdit(BaseModel): + """Edit controls for Ray 3.2 ``video_edit`` generations.""" + + auto_controls: bool | None = Field(None, description="derive a conditioning schedule from the source (recommended)") + strength: str | None = Field(None, description="'adhere_1' .. 'reimagine_3'; constrained by IO.Combo") + + +class Luma2VideoOptions(BaseModel): + """Ray 3.2 ``video`` output settings (text / image / keyframe / edit / extend).""" + + resolution: str | None = Field(None, description="360p | 540p | 720p | 1080p") + duration: str | None = Field(None, description="5s | 10s") + loop: bool | None = Field(None) + start_frame: Luma2ImageRef | None = Field(None) + end_frame: Luma2ImageRef | None = Field(None) + keyframes: list[Luma2ImageRef] | None = Field(None) + keyframe_indexes: list[int] | None = Field(None) + edit: Luma2VideoEdit | None = Field(None) class Luma2GenerationRequest(BaseModel): @@ -266,6 +291,7 @@ class Luma2GenerationRequest(BaseModel): web_search: bool | None = None image_ref: list[Luma2ImageRef] | None = None source: Luma2ImageRef | None = None + video: Luma2VideoOptions | None = Field(None) class Luma2Generation(BaseModel): @@ -277,3 +303,31 @@ class Luma2Generation(BaseModel): output: list[LumaImageReference] | None = None failure_reason: str | None = None failure_code: str | None = None + + +# --- Ray 3.2 multi-keyframe chain --- + +LUMA_KEYFRAME_MODE_FRACTION = "fraction" # value in [0.0, 1.0] of the output video duration +LUMA_KEYFRAME_MODE_SECONDS = "seconds" # absolute time, in seconds, from the start of the output + + +class LumaRay32KeyframeItem: + """One guide image anchored at a position on the Ray 3.2 output timeline.""" + + def __init__(self, image: torch.Tensor, mode: str, value: float): + self.image = image + self.mode = mode # LUMA_KEYFRAME_MODE_FRACTION | LUMA_KEYFRAME_MODE_SECONDS + self.value = value + + +class LumaRay32KeyframeChain: + def __init__(self): + self.items: list[LumaRay32KeyframeItem] = [] + + def add(self, item: LumaRay32KeyframeItem) -> None: + self.items.append(item) + + def clone(self) -> "LumaRay32KeyframeChain": + c = LumaRay32KeyframeChain() + c.items = list(self.items) + return c diff --git a/comfy_api_nodes/nodes_luma.py b/comfy_api_nodes/nodes_luma.py index 0d31ac77e..cdfa32d8b 100644 --- a/comfy_api_nodes/nodes_luma.py +++ b/comfy_api_nodes/nodes_luma.py @@ -3,9 +3,13 @@ from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input from comfy_api_nodes.apis.luma import ( + LUMA_KEYFRAME_MODE_FRACTION, + LUMA_KEYFRAME_MODE_SECONDS, Luma2Generation, Luma2GenerationRequest, Luma2ImageRef, + Luma2VideoEdit, + Luma2VideoOptions, LumaAspectRatio, LumaCharacterRef, LumaConceptChain, @@ -18,6 +22,8 @@ from comfy_api_nodes.apis.luma import ( LumaIO, LumaKeyframes, LumaModifyImageRef, + LumaRay32KeyframeChain, + LumaRay32KeyframeItem, LumaReference, LumaReferenceChain, LumaVideoModel, @@ -33,6 +39,7 @@ from comfy_api_nodes.util import ( sync_op, upload_image_to_comfyapi, upload_images_to_comfyapi, + upload_video_to_comfyapi, validate_string, ) @@ -692,7 +699,10 @@ async def _luma2_upload_image_refs( async def _luma2_submit_and_poll( cls: type[IO.ComfyNode], request: Luma2GenerationRequest, -) -> Input.Image: + *, + estimated_duration: int | None = None, +) -> Luma2Generation: + """Submit a Luma Agents generation and poll until done; returns the completed generation.""" initial = await sync_op( cls, ApiEndpoint(path="/proxy/luma_2/generations", method="POST"), @@ -700,21 +710,21 @@ async def _luma2_submit_and_poll( data=request, ) if not initial.id: - raise RuntimeError("Luma 2 API did not return a generation id.") + raise RuntimeError("Luma API did not return a generation id.") final = await poll_op( cls, ApiEndpoint(path=f"/proxy/luma_2/generations/{initial.id}", method="GET"), response_model=Luma2Generation, status_extractor=lambda r: r.state, progress_extractor=lambda r: None, + estimated_duration=estimated_duration, ) - if not final.output: + if not final.output or not final.output[0].url: msg = final.failure_reason or "no output returned" - raise RuntimeError(f"Luma 2 generation failed: {msg}") - url = final.output[0].url - if not url: - raise RuntimeError("Luma 2 generation completed without an output URL.") - return await download_url_to_image_tensor(url) + if final.failure_code: + msg = f"{msg} [{final.failure_code}]" + raise RuntimeError(f"Luma generation failed: {msg}") + return final class LumaImageNode(IO.ComfyNode): @@ -843,7 +853,8 @@ class LumaImageNode(IO.ComfyNode): web_search=model["web_search"], image_ref=await _luma2_upload_image_refs(cls, model.get("image_ref"), max_count=9), ) - return IO.NodeOutput(await _luma2_submit_and_poll(cls, request)) + final = await _luma2_submit_and_poll(cls, request) + return IO.NodeOutput(await download_url_to_image_tensor(final.output[0].url)) class LumaImageEditNode(IO.ComfyNode): @@ -929,7 +940,533 @@ class LumaImageEditNode(IO.ComfyNode): web_search=model["web_search"], image_ref=await _luma2_upload_image_refs(cls, model.get("image_ref"), max_count=8), ) - return IO.NodeOutput(await _luma2_submit_and_poll(cls, request)) + final = await _luma2_submit_and_poll(cls, request) + return IO.NodeOutput(await download_url_to_image_tensor(final.output[0].url)) + + +_BADGE_RAY32_VIDEO = IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["resolution", "duration"]), + expr=""" + ( + $p := { + "360p": {"5s": 0.06, "10s": 0.18}, + "540p": {"5s": 0.15, "10s": 0.45}, + "720p": {"5s": 0.3, "10s": 0.9}, + "1080p": {"5s": 1.2, "10s": 3.6} + }; + {"type": "usd", "usd": $lookup($lookup($p, widgets.resolution), widgets.duration)} + ) + """, +) + +_BADGE_RAY32_VIDEO_5S = IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["resolution"]), + expr=""" + ( + $p := {"360p": 0.06, "540p": 0.15, "720p": 0.3, "1080p": 1.2}; + {"type": "usd", "usd": $lookup($p, widgets.resolution)} + ) + """, +) + +_BADGE_RAY32_EDIT = IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["resolution"]), + expr=""" + ( + $p := { + "360p": {"min": 0.54, "max": 1.08}, + "540p": {"min": 0.72, "max": 1.44}, + "720p": {"min": 1.08, "max": 2.16}, + "1080p": {"min": 2.16, "max": 4.32} + }; + $r := $lookup($p, widgets.resolution); + {"type": "range_usd", "min_usd": $r.min, "max_usd": $r.max, "format": {"note": "(by source length)"}} + ) + """, +) + +_BADGE_RAY32_REFRAME = IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["resolution"]), + expr=""" + ( + $p := {"360p": 0.03, "540p": 0.06, "720p": 0.12, "1080p": 0.36}; + {"type": "usd", "usd": $lookup($p, widgets.resolution), "format": {"suffix": "/second"}} + ) + """, +) + + +def _ray32_seed_input() -> IO.Input: + return IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; results are nondeterministic regardless of seed.", + ) + + +async def _ray32_generate(cls: type[IO.ComfyNode], request: Luma2GenerationRequest) -> IO.NodeOutput: + """Run a ray-3.2 generation and return (video, generation_id).""" + final = await _luma2_submit_and_poll(cls, request, estimated_duration=120) + video = await download_url_to_video_output(final.output[0].url) + return IO.NodeOutput(video, final.id or "") + + +class LumaRay32TextToVideoNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="LumaRay32TextToVideoNode", + display_name="Luma Ray 3.2 Text to Video", + category="partner/video/Luma", + description="Generate a video from a text prompt using Luma's Ray 3.2 model.", + inputs=[ + IO.String.Input("prompt", multiline=True, default="", tooltip="Text prompt for the video generation."), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1", "4:3", "3:4", "21:9"]), + IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"), + IO.Combo.Input("duration", options=["5s", "10s"]), + IO.Boolean.Input( + "loop", + default=False, + tooltip="Make the video loop seamlessly. Only available with 5s duration.", + ), + _ray32_seed_input(), + ], + outputs=[IO.Video.Output(), IO.String.Output(display_name="generation_id")], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=_BADGE_RAY32_VIDEO, + ) + + @classmethod + async def execute( + cls, prompt: str, aspect_ratio: str, resolution: str, duration: str, loop: bool, seed: int + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1, max_length=6000) + if loop and duration == "10s": + raise ValueError("Looping is only available with 5s duration on Ray 3.2.") + request = Luma2GenerationRequest( + prompt=prompt, + model="ray-3.2", + type="video", + aspect_ratio=aspect_ratio, + video=Luma2VideoOptions(resolution=resolution, duration=duration, loop=loop or None), + ) + return await _ray32_generate(cls, request) + + +class LumaRay32ImageToVideoNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="LumaRay32ImageToVideoNode", + display_name="Luma Ray 3.2 Image to Video", + category="partner/video/Luma", + description="Generate a video from a start and/or end frame using Luma's Ray 3.2 model. " + "Image-anchored generations are always 5 seconds.", + inputs=[ + IO.String.Input("prompt", multiline=True, default="", tooltip="Text prompt for the video generation."), + IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"), + IO.Boolean.Input( + "loop", + default=False, + tooltip="Make the video loop seamlessly. Not available when an end_frame is set.", + ), + _ray32_seed_input(), + IO.Image.Input("start_frame", optional=True, tooltip="First frame of the generated video."), + IO.Image.Input("end_frame", optional=True, tooltip="Last frame of the generated video."), + ], + outputs=[IO.Video.Output(), IO.String.Output(display_name="generation_id")], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=_BADGE_RAY32_VIDEO_5S, + ) + + @classmethod + async def execute( + cls, + prompt: str, + resolution: str, + loop: bool, + seed: int, + start_frame: torch.Tensor | None = None, + end_frame: torch.Tensor | None = None, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1, max_length=6000) + if start_frame is None and end_frame is None: + raise ValueError("Provide at least one of start_frame / end_frame.") + if loop and end_frame is not None: + raise ValueError("Looping is not available when an end_frame is set.") + video = Luma2VideoOptions(resolution=resolution, duration="5s", loop=loop or None) + if start_frame is not None: + url = await upload_image_to_comfyapi(cls, start_frame, mime_type="image/png") + video.start_frame = Luma2ImageRef(url=url) + if end_frame is not None: + url = await upload_image_to_comfyapi(cls, end_frame, mime_type="image/png") + video.end_frame = Luma2ImageRef(url=url) + request = Luma2GenerationRequest(prompt=prompt, model="ray-3.2", type="video", video=video) + return await _ray32_generate(cls, request) + + +class LumaRay32KeyframeNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="LumaRay32KeyframeNode", + display_name="Luma Ray 3.2 Keyframe", + category="partner/video/Luma", + description="Anchor a guide image to a position on the Ray 3.2 output video timeline. Connect this to " + "the 'keyframes' input of the Luma Ray 3.2 Keyframes to Video node; chain several together via the " + "optional 'keyframes' input below.", + inputs=[ + IO.Image.Input("image", tooltip="Guide image to place at the chosen moment of the output video."), + IO.DynamicCombo.Input( + "position", + options=[ + IO.DynamicCombo.Option( + "Fraction of duration (0.0-1.0)", + [ + IO.Float.Input( + "fraction", + default=0.0, + min=0.0, + max=1.0, + step=0.01, + display_mode=IO.NumberDisplay.number, + tooltip="Where in the output video this image applies " "(0.0 = start, 1.0 = end).", + ), + ], + ), + IO.DynamicCombo.Option( + "Absolute time (seconds)", + [ + IO.Float.Input( + "seconds", + default=0.0, + min=0.0, + max=10.0, + step=0.1, + display_mode=IO.NumberDisplay.number, + tooltip="Time in seconds from the start of the output video where this " + "image applies.", + ), + ], + ), + ], + tooltip="How to place this image on the output video's timeline.", + ), + IO.Custom(LumaIO.LUMA_RAY32_KEYFRAME).Input( + "keyframes", + optional=True, + tooltip="Optional earlier keyframes to chain with this one.", + ), + ], + outputs=[IO.Custom(LumaIO.LUMA_RAY32_KEYFRAME).Output(display_name="keyframes")], + ) + + @classmethod + def execute( + cls, + image: torch.Tensor, + position: dict, + keyframes: LumaRay32KeyframeChain | None = None, + ) -> IO.NodeOutput: + chain = keyframes.clone() if keyframes is not None else LumaRay32KeyframeChain() + if position["position"] == "Absolute time (seconds)": + mode, value = LUMA_KEYFRAME_MODE_SECONDS, float(position["seconds"]) + else: + mode, value = LUMA_KEYFRAME_MODE_FRACTION, float(position["fraction"]) + chain.add(LumaRay32KeyframeItem(image=image, mode=mode, value=value)) + return IO.NodeOutput(chain) + + +class LumaRay32KeyframesToVideoNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="LumaRay32KeyframesToVideoNode", + display_name="Luma Ray 3.2 Keyframes to Video", + category="partner/video/Luma", + description="Generate a video that interpolates through a sequence of guide images, each anchored to a " + "position on the timeline, using Luma Ray 3.2. Build the sequence with Luma Ray 3.2 Keyframe nodes " + "(at least 2).", + inputs=[ + IO.String.Input("prompt", multiline=True, default="", tooltip="Text prompt for the video generation."), + IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"), + IO.Combo.Input("duration", options=["5s", "10s"]), + _ray32_seed_input(), + IO.Custom(LumaIO.LUMA_RAY32_KEYFRAME).Input( + "keyframes", + tooltip="Keyframe sequence from Luma Ray 3.2 Keyframe nodes (at least 2).", + ), + ], + outputs=[IO.Video.Output(), IO.String.Output(display_name="generation_id")], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=_BADGE_RAY32_VIDEO, + ) + + @classmethod + async def execute( + cls, + prompt: str, + resolution: str, + duration: str, + seed: int, + keyframes: LumaRay32KeyframeChain | None = None, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1, max_length=6000) + items = keyframes.items if keyframes is not None else [] + if len(items) < 2: + raise ValueError( + "Connect at least 2 Luma Ray 3.2 Keyframe nodes " + "(use Luma Ray 3.2 Image to Video for a single start/end frame)." + ) + if len(items) > 64: + raise ValueError(f"Ray 3.2 supports at most 64 keyframes; got {len(items)}.") + maxframe = 120 if duration == "5s" else 240 + duration_seconds = maxframe / 24 # 5.0 or 10.0 + # Resolve each keyframe to an output-frame index, then order by position + # (so the user can chain keyframes in any order — the position is what places them) + placed: list[tuple[int, torch.Tensor]] = [] + for item in items: + if item.mode == LUMA_KEYFRAME_MODE_SECONDS: + if item.value > duration_seconds: + raise ValueError( + f"Keyframe position {item.value:g}s is past the end of the {duration} video; " + f"use 0-{duration_seconds:g}s (or switch the keyframe to fraction mode)." + ) + idx = round(item.value * 24) + else: + idx = round(item.value * maxframe) + placed.append((max(0, min(maxframe, idx)), item.image)) + placed.sort(key=lambda p: p[0]) + indexes = [idx for idx, _ in placed] + for a, b in zip(indexes, indexes[1:]): + if a == b: + raise ValueError( + f"Two keyframes resolve to the same output frame ({a}) for a {duration} video " + f"(valid range 0-{maxframe}); give each keyframe a distinct position." + ) + refs: list[Luma2ImageRef] = [] + for _, image in placed: + url = await upload_image_to_comfyapi(cls, image, mime_type="image/png") + refs.append(Luma2ImageRef(url=url)) + request = Luma2GenerationRequest( + prompt=prompt, + model="ray-3.2", + type="video", + video=Luma2VideoOptions(resolution=resolution, duration=duration, keyframes=refs, keyframe_indexes=indexes), + ) + return await _ray32_generate(cls, request) + + +class LumaRay32VideoEditNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="LumaRay32VideoEditNode", + display_name="Luma Ray 3.2 Video Edit", + category="partner/video/Luma", + description="Re-render an existing video under a new prompt using Luma Ray 3.2 (restyle, relight, add " + "or remove elements) while keeping the original motion. Source video up to 18 seconds; the edited " + "video keeps the source's length.", + inputs=[ + IO.Video.Input("video", tooltip="Source video to edit. Up to 18 seconds."), + IO.String.Input("prompt", multiline=True, default="", tooltip="Describes the desired edit."), + IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"), + IO.Combo.Input( + "strength", + options=[ + "auto", + "adhere_1", + "adhere_2", + "adhere_3", + "flex_1", + "flex_2", + "flex_3", + "reimagine_1", + "reimagine_2", + "reimagine_3", + ], + default="auto", + tooltip="How strongly to preserve vs. reimagine the source. 'auto' lets Ray 3.2 choose; " + "adhere_* preserves the most, flex_* is balanced, reimagine_* changes the most.", + ), + _ray32_seed_input(), + ], + outputs=[ + IO.Video.Output(), + IO.String.Output(display_name="generation_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=_BADGE_RAY32_EDIT, + ) + + @classmethod + async def execute( + cls, video: Input.Video, prompt: str, resolution: str, strength: str, seed: int + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1, max_length=6000) + try: + duration = "5s" if video.get_duration() <= 5.0 else "10s" + except Exception: + duration = "10s" + source_url = await upload_video_to_comfyapi(cls, video, max_duration=18) + edit = Luma2VideoEdit(auto_controls=True) if strength == "auto" else Luma2VideoEdit(strength=strength) + request = Luma2GenerationRequest( + prompt=prompt, + model="ray-3.2", + type="video_edit", + source=Luma2ImageRef(url=source_url, media_type="video/mp4"), + video=Luma2VideoOptions(resolution=resolution, duration=duration, edit=edit), + ) + return await _ray32_generate(cls, request) + + +class LumaRay32VideoReframeNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="LumaRay32VideoReframeNode", + display_name="Luma Ray 3.2 Video Reframe", + category="partner/video/Luma", + description="Change the aspect ratio of an existing video, using Luma Ray 3.2 to fill the newly " + "exposed canvas areas. Source video up to 30 seconds. Billed per second of output.", + inputs=[ + IO.Video.Input("video", tooltip="Source video to reframe. Up to 30 seconds."), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Describes how the newly exposed canvas areas should be filled.", + ), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1", "4:3", "3:4", "21:9"]), + IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"), + _ray32_seed_input(), + ], + outputs=[ + IO.Video.Output(), + IO.String.Output(display_name="generation_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=_BADGE_RAY32_REFRAME, + ) + + @classmethod + async def execute( + cls, video: Input.Video, prompt: str, aspect_ratio: str, resolution: str, seed: int + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=False, min_length=1, max_length=6000) + if resolution == "1080p" and aspect_ratio in {"9:16", "3:4"}: + raise ValueError("1080p is not available for vertical aspect ratios (9:16, 3:4) when reframing.") + source_url = await upload_video_to_comfyapi(cls, video, max_duration=30) + request = Luma2GenerationRequest( + prompt=prompt, + model="ray-3.2", + type="video_reframe", + aspect_ratio=aspect_ratio, + source=Luma2ImageRef(url=source_url, media_type="video/mp4"), + video=Luma2VideoOptions(resolution=resolution), + ) + return await _ray32_generate(cls, request) + + +class LumaRay32ExtendVideoNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="LumaRay32ExtendVideoNode", + display_name="Luma Ray 3.2 Extend Video", + category="partner/video/Luma", + description="Extend a previous Ray 3.2 generation forward (continue after it) or backward (lead-in " + "before it). Connect the generation_id output of a prior Luma Ray 3.2 node." + " Extensions are always 5 seconds.", + inputs=[ + IO.String.Input( + "source_generation_id", + default="", + tooltip="generation_id of the prior Ray 3.2 video to extend." + " Connect the generation_id output of another Luma Ray 3.2 node.", + ), + IO.DynamicCombo.Input( + "direction", + options=[ + IO.DynamicCombo.Option( + "Forward (continue after)", + [ + IO.Boolean.Input( + "loop", + default=False, + tooltip="Loop the extended video seamlessly (forward extend only).", + ), + ], + ), + IO.DynamicCombo.Option("Backward (lead-in before)", []), + ], + tooltip="Forward continues after the prior clip; backward is prepended before it.", + ), + IO.String.Input("prompt", multiline=True, default="", tooltip="Text prompt for the new content."), + IO.Combo.Input("resolution", options=["540p", "720p", "1080p"], default="720p"), + _ray32_seed_input(), + ], + outputs=[ + IO.Video.Output(), + IO.String.Output(display_name="generation_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=_BADGE_RAY32_VIDEO_5S, + ) + + @classmethod + async def execute( + cls, source_generation_id: str, direction: dict, prompt: str, resolution: str, seed: int + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=False, min_length=1, max_length=6000) + gen_id = (source_generation_id or "").strip() + if not gen_id: + raise ValueError( + "source_generation_id is required (connect the generation_id output of a prior Luma Ray 3.2 node)." + ) + video = Luma2VideoOptions(resolution=resolution, duration="5s") + ref = Luma2ImageRef(generation_id=gen_id) + if direction["direction"] == "Forward (continue after)": + video.start_frame = ref + if direction.get("loop"): + video.loop = True + else: + video.end_frame = ref + request = Luma2GenerationRequest(prompt=prompt, model="ray-3.2", type="video", video=video) + return await _ray32_generate(cls, request) class LumaExtension(ComfyExtension): @@ -944,6 +1481,13 @@ class LumaExtension(ComfyExtension): LumaConceptsNode, LumaImageNode, LumaImageEditNode, + LumaRay32TextToVideoNode, + LumaRay32ImageToVideoNode, + LumaRay32KeyframeNode, + LumaRay32KeyframesToVideoNode, + LumaRay32VideoEditNode, + LumaRay32VideoReframeNode, + LumaRay32ExtendVideoNode, ] From bd39bbf0678ebd31c972fd365733a8c729f2cd74 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 19 Jun 2026 11:32:56 +0300 Subject: [PATCH 4/8] [Partner Nodes] fix: respect Retry-After header (#14234) Signed-off-by: bigcat88 --- comfy_api_nodes/util/_helpers.py | 28 ++++++++++++++++++++++++++++ comfy_api_nodes/util/client.py | 3 +++ 2 files changed, 31 insertions(+) diff --git a/comfy_api_nodes/util/_helpers.py b/comfy_api_nodes/util/_helpers.py index 83cf7b001..6b8121cab 100644 --- a/comfy_api_nodes/util/_helpers.py +++ b/comfy_api_nodes/util/_helpers.py @@ -4,6 +4,8 @@ import os import re import time from collections.abc import Callable +from datetime import datetime, timezone +from email.utils import parsedate_to_datetime from io import BytesIO from yarl import URL @@ -91,6 +93,32 @@ async def sleep_with_interrupt( await asyncio.sleep(min(1.0, end - now)) +def _retry_after_wait(value: str | None, fallback: float, max_wait: float) -> float: + """Delay before the next retry, honoring a server ``Retry-After`` header.""" + + seconds: float | None = None + if value is not None: + value = value.strip() + if value.isascii() and value.isdigit(): + # delay-seconds form. The ASCII-digit guard keeps exotic Unicode "digit" characters away from float() + # an all-digit string always converts (huge values become inf, never raising). + seconds = float(value) + elif value: + # HTTP-date form. parsedate_to_datetime raises OverflowError (not a ValueError) on absurd years/offsets + try: + parsed = parsedate_to_datetime(value) + except (TypeError, ValueError, OverflowError): + parsed = None + if parsed is not None: + if parsed.tzinfo is None: # naive datetime: HTTP-date is UTC + parsed = parsed.replace(tzinfo=timezone.utc) + delta = (parsed - datetime.now(timezone.utc)).total_seconds() + seconds = delta if delta > 0 else 0.0 + if seconds is None: + return fallback + return min(seconds, max_wait) + + def mimetype_to_extension(mime_type: str) -> str: """Converts a MIME type to a file extension.""" return mime_type.split("/")[-1].lower() diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index adcde7bcb..66aab17f8 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -21,6 +21,7 @@ from server import PromptServer from . import request_logger from ._helpers import ( + _retry_after_wait, default_base_url, get_comfy_api_headers, get_node_id, @@ -82,6 +83,7 @@ class _PollUIState: _RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately +_MAX_RETRY_AFTER_WAIT = 150.0 # Cap a server Retry-After at this many seconds so a large hint can't block execution COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"] FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"] QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait", "in_queue"] @@ -747,6 +749,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): should_retry = True if should_retry: + wait_time = _retry_after_wait(resp.headers.get("Retry-After"), wait_time, _MAX_RETRY_AFTER_WAIT) logging.warning( "HTTP %s %s -> %s. Waiting %.2fs (%s).", method, From bc11e8a65a57dc6bd1768edfca14ddb1523f7882 Mon Sep 17 00:00:00 2001 From: Comfy Org PR Bot Date: Sat, 20 Jun 2026 08:01:34 +0900 Subject: [PATCH 5/8] Bump comfyui-frontend-package to 1.45.19 (#14559) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 392709e64..ad8b1c2ee 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.45.15 +comfyui-frontend-package==1.45.19 comfyui-workflow-templates==0.10.0 comfyui-embedded-docs==0.5.4 torch From 2ab3816dcf66d58f2c0b3e79e910311b21697e0d Mon Sep 17 00:00:00 2001 From: Terry Jia Date: Fri, 19 Jun 2026 19:06:55 -0400 Subject: [PATCH 6/8] feat: add Load3DAdvanced node (#14316) --- comfy_extras/nodes_load_3d.py | 63 +++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py index 455897859..6e3e88471 100644 --- a/comfy_extras/nodes_load_3d.py +++ b/comfy_extras/nodes_load_3d.py @@ -317,11 +317,74 @@ class PreviewPointCloud(IO.ComfyNode): ) +MESH_EXTENSIONS = {'.gltf', '.glb', '.obj', '.fbx', '.stl'} + + +class Load3DAdvanced(IO.ComfyNode): + @classmethod + def define_schema(cls): + input_dir = os.path.join(folder_paths.get_input_directory(), "3d") + os.makedirs(input_dir, exist_ok=True) + + input_path = Path(input_dir) + base_path = Path(folder_paths.get_input_directory()) + + files = [ + normalize_path(str(file_path.relative_to(base_path))) + for file_path in input_path.rglob("*") + if file_path.suffix.lower() in MESH_EXTENSIONS + ] + return IO.Schema( + node_id="Load3DAdvanced", + display_name="Load 3D (Advanced)", + category="3d", + search_aliases=[ + "load mesh", + "load gltf", + "load glb", + "load obj", + "load fbx", + "load stl", + ], + is_experimental=True, + inputs=[ + IO.Combo.Input("model_file", options=["none"] + sorted(files), upload=IO.UploadType.model), + IO.Load3D.Input("viewport_state"), + IO.Int.Input("width", default=1024, min=1, max=4096, step=1), + IO.Int.Input("height", default=1024, min=1, max=4096, step=1), + ], + outputs=[ + IO.File3DAny.Output(display_name="model_3d"), + IO.Load3DModelInfo.Output(display_name="model_3d_info"), + IO.Load3DCamera.Output(display_name="camera_info"), + IO.Int.Output(display_name="width"), + IO.Int.Output(display_name="height"), + ], + ) + + @classmethod + def validate_inputs(cls, model_file, **kwargs) -> bool | str: + if not model_file or model_file == "none": + return True + if not folder_paths.exists_annotated_filepath(model_file): + return f"Invalid 3D model file: {model_file}" + return True + + @classmethod + def execute(cls, model_file, viewport_state, width: int, height: int, **kwargs) -> IO.NodeOutput: + file_3d = None + if model_file and model_file != "none": + file_3d = Types.File3D(folder_paths.get_annotated_filepath(model_file)) + model_3d_info = viewport_state.get('model_3d_info', []) + return IO.NodeOutput(file_3d, model_3d_info, viewport_state['camera_info'], width, height) + + class Load3DExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ Load3D, + Load3DAdvanced, Preview3D, Preview3DAdvanced, PreviewGaussianSplat, From 4e716f7c5769fd7bdd851d95a323c2377dfeb5a7 Mon Sep 17 00:00:00 2001 From: Matt Miller Date: Fri, 19 Jun 2026 16:39:35 -0700 Subject: [PATCH 7/8] Add jobs-namespace cancel endpoints (POST /api/jobs/{job_id}/cancel, POST /api/jobs/cancel) (#14493) * Add jobs-namespace cancel endpoints Add two cancel endpoints under the jobs namespace so a job can be cancelled by id without the caller needing to know whether the job is running or pending, or branching between /interrupt and /queue. - POST /api/jobs/{job_id}/cancel cancels one job by id. Idempotent: an already-finished or unknown id returns 200 {"cancelled": false} rather than an error. - POST /api/jobs/cancel takes {"job_ids": [...]} and cancels a batch. Fail-fast: if any id is unknown the request returns 404 listing the unknown ids and cancels nothing (no partial side effects). Both are state-agnostic and map onto the existing queue mechanics: a running job is interrupted (same path as /interrupt), a pending job is dequeued (same path as /queue {"delete": [...]}). The cancel logic lives in comfy_execution.jobs as pure, unit-tested helpers; the server handlers are thin wrappers. openapi.yaml documents both routes. * fix: resolve review feedback on cancel endpoints - Guard cancel_job() against TOCTOU: when dequeue() returns False the pending job left the queue between snapshot and delete; return CANCEL_UNKNOWN so callers never report cancelled=True for a remove that did not happen. - Validate each job_ids element in the batch cancel endpoint before any queue access; unhashable or non-UUID values now return 400 instead of raising TypeError (500). - Update batch HTTP tests to use canonical UUID ids (required now that the endpoint validates id format) and add tests for the new guards. * fix: make job cancel atomic and best-effort Addresses two cancel races/edges raised in review. Targeted, atomic interrupt. cancel_job's interrupt callback now takes the prompt id and returns whether it fired; the single-cancel route backs it with the new PromptQueue.interrupt_if_running, which checks the running set and signals the interrupt under the queue mutex. This closes the TOCTOU where a pending job that starts executing between the snapshot and dequeue (or a running job that finishes between the snapshot and interrupt) could be missed or, worse, cause an unrelated prompt to be interrupted. The per-prompt interrupt-flag reset in execute_async keeps a finished job from leaking the interrupt onto its successor. Best-effort batch cancel. POST /api/jobs/cancel no longer fails the whole batch with 404 when one id is unknown/finished; such ids are treated as no-ops, so "cancel all" still cancels the in-progress jobs even if some finished between the client's snapshot and the request. Malformed ids are still rejected with 400. --- comfy_execution/jobs.py | 81 +++- execution.py | 19 + server.py | 111 ++++- tests-unit/jobs_cancel_test/__init__.py | 0 .../jobs_cancel_test/jobs_cancel_test.py | 453 ++++++++++++++++++ 5 files changed, 662 insertions(+), 2 deletions(-) create mode 100644 tests-unit/jobs_cancel_test/__init__.py create mode 100644 tests-unit/jobs_cancel_test/jobs_cancel_test.py diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py index 20ebae155..fa3ab0faf 100644 --- a/comfy_execution/jobs.py +++ b/comfy_execution/jobs.py @@ -4,11 +4,22 @@ Provides normalization and helper functions for job status tracking. """ import uuid -from typing import Optional +from typing import Callable, Optional from comfy_api.internal import prune_dict +# Result of classifying a job for cancellation. +# 'running' -> job is currently executing (interrupt it) +# 'pending' -> job is queued but not started (dequeue it) +# 'terminal' -> job already finished (present in history); cancel is a no-op +# 'unknown' -> job id is not present anywhere +CANCEL_RUNNING = 'running' +CANCEL_PENDING = 'pending' +CANCEL_TERMINAL = 'terminal' +CANCEL_UNKNOWN = 'unknown' + + class JobStatus: """Job status constants.""" PENDING = 'pending' @@ -407,3 +418,71 @@ def get_all_jobs( jobs = jobs[:limit] return (jobs, total_count) + + +def classify_job_for_cancel(prompt_id: str, running: list, queued: list, history: dict) -> str: + """Classify a job id for cancellation. + + Returns one of CANCEL_RUNNING, CANCEL_PENDING, CANCEL_TERMINAL, CANCEL_UNKNOWN. + + Queue items are tuples whose second element (index 1) is the prompt_id. + History is a dict keyed by prompt_id, so a job present there has already + finished and cancelling it is a no-op. + """ + for item in running: + if item[1] == prompt_id: + return CANCEL_RUNNING + for item in queued: + if item[1] == prompt_id: + return CANCEL_PENDING + if prompt_id in history: + return CANCEL_TERMINAL + return CANCEL_UNKNOWN + + +def cancel_job( + prompt_id: str, + running: list, + queued: list, + history: dict, + interrupt: Callable[[str], bool], + dequeue: Callable[[str], bool], +) -> str: + """Cancel a single job by id, regardless of state. + + Maps the cancel onto the runtime's existing mechanics: + - a running job is interrupted via ``interrupt`` + - a pending job is removed from the queue via ``dequeue`` + - a job that already finished (terminal) is a no-op + - an unknown id is a no-op (callers that need fail-fast behaviour should + validate ids up front with ``classify_job_for_cancel``) + + Both ``interrupt`` and ``dequeue`` take the prompt id and return whether + they acted on a job that was *actually* in that state, so the value returned + here reflects what truly happened rather than the (possibly stale) + classification. This matters around the narrow TOCTOU windows where a job + changes state between the caller's snapshot and the action: + + - a job classified RUNNING may have finished before ``interrupt`` fires: + ``interrupt`` returns False and this returns CANCEL_UNKNOWN (no-op). + - a job classified PENDING may have started executing before ``dequeue`` + fires: ``dequeue`` returns False, ``interrupt`` then catches the now- + running job and this returns CANCEL_RUNNING. If it had simply finished + instead, both return False and this returns CANCEL_UNKNOWN. + + ``interrupt`` must be atomic — interrupt the job only if it is still the one + running — so a cancel can never land on an unrelated prompt that started in + the meantime (see ``execution.PromptQueue.interrupt_if_running``). + """ + classification = classify_job_for_cancel(prompt_id, running, queued, history) + if classification == CANCEL_RUNNING: + return CANCEL_RUNNING if interrupt(prompt_id) else CANCEL_UNKNOWN + if classification == CANCEL_PENDING: + if dequeue(prompt_id): + return CANCEL_PENDING + # Left the pending queue between classification and dequeue: if it + # started executing, interrupt the now-running job; otherwise it has + # already finished and the cancel is a genuine no-op. + return CANCEL_RUNNING if interrupt(prompt_id) else CANCEL_UNKNOWN + # CANCEL_TERMINAL and CANCEL_UNKNOWN are intentional no-ops. + return classification diff --git a/execution.py b/execution.py index 9e16e451d..c45317593 100644 --- a/execution.py +++ b/execution.py @@ -1308,6 +1308,25 @@ class PromptQueue: queued = copy.copy(self.queue) return (running, queued) + def interrupt_if_running(self, prompt_id): + """Interrupt the running prompt with this id, atomically. + + Checks the live running set and signals the interrupt under the queue + mutex, so the worker cannot move the job to done (and start the next + prompt) in between. Returns True if a matching job was running and an + interrupt was signalled, False otherwise. The atomicity is what keeps a + cancel from landing on an unrelated prompt that started after a separate + is-running check: the global interrupt flag is reset at the start of + every prompt (execute_async), so a job that finishes before consuming + the flag cannot leak the interrupt onto its successor. + """ + with self.mutex: + for item in self.currently_running.values(): + if item[1] == prompt_id: + nodes.interrupt_processing() + return True + return False + def get_tasks_remaining(self): with self.mutex: return len(self.queue) + len(self.currently_running) diff --git a/server.py b/server.py index 6b0029adf..361850f38 100644 --- a/server.py +++ b/server.py @@ -8,7 +8,15 @@ import time import nodes import folder_paths import execution -from comfy_execution.jobs import JobStatus, get_job, get_all_jobs, validate_job_id +from comfy_execution.jobs import ( + JobStatus, + get_job, + get_all_jobs, + validate_job_id, + cancel_job, + CANCEL_PENDING, + CANCEL_RUNNING, +) import uuid import urllib import json @@ -899,6 +907,107 @@ class PromptServer(): return web.json_response(job) + def _cancel_job_by_id(job_id): + """Cancel a single job by id using the queue's existing mechanics. + + Running jobs are interrupted (same mechanism as /interrupt); pending + jobs are dequeued (same mechanism as /queue {"delete": [...]}). + Already-finished or unknown ids are no-ops. State-agnostic. + + Returns True when a cancel was actually dispatched (running or + pending job), False when the call was a no-op (terminal/unknown id). + """ + running, queued = self.prompt_queue.get_current_queue() + history = self.prompt_queue.get_history() + + def interrupt(prompt_id): + logging.info(f"Cancelling running prompt {prompt_id}") + # Atomic: only interrupts if the job is still the one running, + # so a cancel can't land on a prompt that started in the gap + # since the snapshot above. Returns whether it actually fired. + return self.prompt_queue.interrupt_if_running(prompt_id) + + def dequeue(prompt_id): + logging.info(f"Cancelling pending prompt {prompt_id}") + return self.prompt_queue.delete_queue_item(lambda a: a[1] == prompt_id) + + classification = cancel_job(job_id, running, queued, history, interrupt, dequeue) + return classification in (CANCEL_RUNNING, CANCEL_PENDING) + + @routes.post("/api/jobs/{job_id}/cancel") + async def cancel_job_by_id(request): + """Cancel a single job by id, regardless of state. + + Idempotent: cancelling a job that has already finished, or an id + that is not known, returns 200 with {"cancelled": false} rather + than an error. + """ + job_id = request.match_info.get("job_id", None) + if not job_id: + return web.json_response( + {"error": "job_id is required"}, + status=400 + ) + + cancelled = _cancel_job_by_id(job_id) + return web.json_response({"cancelled": cancelled}) + + @routes.post("/api/jobs/cancel") + async def cancel_jobs_batch(request): + """Cancel a batch of jobs by id. + + Body: {"job_ids": ["", ...]} + + Best-effort and idempotent: every well-formed id is cancelled if it + is running or pending; ids that are already finished or unknown are + no-ops, not errors. A batch of all no-ops still returns 200 with + {"cancelled": false}. This matches the single-cancel endpoint and + means "cancel all" still cancels the in-progress jobs even if some + finished between the client's snapshot and the request. Malformed + ids are still rejected up front with 400 (see below). + """ + try: + json_data = await request.json() + except json.JSONDecodeError: + return web.json_response( + {"error": "Request body must be valid JSON"}, + status=400 + ) + + job_ids = json_data.get("job_ids") if isinstance(json_data, dict) else None + if not isinstance(job_ids, list): + return web.json_response( + {"error": "job_ids must be a list"}, + status=400 + ) + + # Validate that every element is a well-formed job id before doing + # anything else. An unhashable element (e.g. a nested dict or list) + # would cause a TypeError when used as a history dict key; a + # non-string or non-UUID value is never a valid id. Reject early + # with 400 rather than letting the classify loop raise 500. + invalid_ids = [] + for jid in job_ids: + try: + validate_job_id(jid) + except (ValueError, AttributeError): + invalid_ids.append(jid if isinstance(jid, str) else repr(jid)) + if invalid_ids: + return web.json_response( + {"error": "job_ids contains invalid id(s)", "invalid_ids": invalid_ids}, + status=400, + ) + + # Best-effort: cancel each id that is still running/pending; an id + # that has finished or never existed is a no-op rather than a reason + # to fail the whole batch. + cancelled = False + for jid in job_ids: + if _cancel_job_by_id(jid): + cancelled = True + + return web.json_response({"cancelled": cancelled}) + @routes.get("/history") async def get_history(request): max_items = request.rel_url.query.get("max_items", None) diff --git a/tests-unit/jobs_cancel_test/__init__.py b/tests-unit/jobs_cancel_test/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests-unit/jobs_cancel_test/jobs_cancel_test.py b/tests-unit/jobs_cancel_test/jobs_cancel_test.py new file mode 100644 index 000000000..f1d591b0d --- /dev/null +++ b/tests-unit/jobs_cancel_test/jobs_cancel_test.py @@ -0,0 +1,453 @@ +"""Tests for the jobs-namespace cancel endpoints. + +Covers both layers: + +* the pure cancel helpers in ``comfy_execution.jobs`` + (``classify_job_for_cancel`` / ``cancel_job``), which hold the business + logic of mapping a cancel onto interrupt-vs-dequeue, and + +* the HTTP contract of ``POST /api/jobs/{job_id}/cancel`` and + ``POST /api/jobs/cancel`` (status codes, single-cancel idempotency, and + best-effort batch cancellation that treats unknown/finished ids as no-ops + while still rejecting malformed ids with 400). + +The HTTP layer is exercised against a small aiohttp app whose handlers are a +faithful copy of the wiring in ``server.py`` driven by a fake queue that +mirrors ``execution.PromptQueue`` (``get_current_queue`` / ``get_history`` / +``delete_queue_item``). This keeps the test free of the heavy ComfyUI runtime +(torch, nodes, ...) while still testing the real cancel logic. +""" + +import json + +import pytest +from aiohttp import web + +from comfy_execution.jobs import ( + CANCEL_PENDING, + CANCEL_RUNNING, + CANCEL_TERMINAL, + CANCEL_UNKNOWN, + cancel_job, + classify_job_for_cancel, + validate_job_id, +) + +# Classifications for which a cancel was actually dispatched (vs a no-op). +_CANCELLED = (CANCEL_RUNNING, CANCEL_PENDING) + +# Canonical UUID ids for HTTP-layer tests (the batch endpoint validates UUID format). +_UUID_A = "aaaaaaaa-aaaa-4aaa-aaaa-aaaaaaaaaaaa" +_UUID_B = "bbbbbbbb-bbbb-4bbb-bbbb-bbbbbbbbbbbb" +_UUID_C = "cccccccc-cccc-4ccc-cccc-cccccccccccc" +_UUID_D = "dddddddd-dddd-4ddd-dddd-dddddddddddd" +_UUID_MISSING = "ffffffff-ffff-4fff-ffff-ffffffffffff" + + +def make_queue_item(prompt_id, number=0): + """Build a queue tuple shaped like the real ones: index 1 is the id.""" + return (number, prompt_id, {}, {}, []) + + +class FakePromptQueue: + """Minimal stand-in for execution.PromptQueue for the cancel paths. + + Tracks interrupts and dequeues so tests can assert side effects. + """ + + def __init__(self, running=None, pending=None, history=None): + self._running = list(running or []) + self._pending = list(pending or []) + self._history = dict(history or {}) + self.interrupt_count = 0 + + def get_current_queue(self): + return (list(self._running), list(self._pending)) + + def get_history(self, prompt_id=None): + if prompt_id is None: + return dict(self._history) + if prompt_id in self._history: + return {prompt_id: self._history[prompt_id]} + return {} + + def delete_queue_item(self, function): + for i, item in enumerate(self._pending): + if function(item): + self._pending.pop(i) + return True + return False + + def interrupt_if_running(self, prompt_id): + # Mirrors execution.PromptQueue.interrupt_if_running: only signals an + # interrupt when the id is actually in the running set. + if any(item[1] == prompt_id for item in self._running): + self.interrupt_count += 1 + return True + return False + + +def build_app(queue): + """Build an aiohttp app exposing the cancel routes against ``queue``. + + Handler bodies mirror server.py exactly. + """ + + def _cancel_job_by_id(job_id): + running, pending = queue.get_current_queue() + history = queue.get_history() + + def interrupt(prompt_id): + return queue.interrupt_if_running(prompt_id) + + def dequeue(prompt_id): + return queue.delete_queue_item(lambda a: a[1] == prompt_id) + + classification = cancel_job( + job_id, running, pending, history, interrupt, dequeue + ) + return classification in _CANCELLED + + async def cancel_job_by_id(request): + job_id = request.match_info.get("job_id", None) + if not job_id: + return web.json_response({"error": "job_id is required"}, status=400) + cancelled = _cancel_job_by_id(job_id) + return web.json_response({"cancelled": cancelled}) + + async def cancel_jobs_batch(request): + try: + json_data = await request.json() + except json.JSONDecodeError: + return web.json_response( + {"error": "Request body must be valid JSON"}, status=400 + ) + + job_ids = json_data.get("job_ids") if isinstance(json_data, dict) else None + if not isinstance(job_ids, list): + return web.json_response({"error": "job_ids must be a list"}, status=400) + + invalid_ids = [] + for jid in job_ids: + try: + validate_job_id(jid) + except (ValueError, AttributeError): + invalid_ids.append(jid if isinstance(jid, str) else repr(jid)) + if invalid_ids: + return web.json_response( + {"error": "job_ids contains invalid id(s)", "invalid_ids": invalid_ids}, + status=400, + ) + + cancelled = False + for jid in job_ids: + if _cancel_job_by_id(jid): + cancelled = True + return web.json_response({"cancelled": cancelled}) + + app = web.Application() + app.router.add_post("/api/jobs/{job_id}/cancel", cancel_job_by_id) + app.router.add_post("/api/jobs/cancel", cancel_jobs_batch) + return app + + +# --------------------------------------------------------------------------- +# Pure helper tests: classification + cancel side effects +# --------------------------------------------------------------------------- + + +class TestClassifyJobForCancel: + def test_running(self): + running = [make_queue_item("a")] + assert classify_job_for_cancel("a", running, [], {}) == CANCEL_RUNNING + + def test_pending(self): + pending = [make_queue_item("b")] + assert classify_job_for_cancel("b", [], pending, {}) == CANCEL_PENDING + + def test_terminal(self): + history = {"c": {"prompt": make_queue_item("c"), "outputs": {}, "status": {}}} + assert classify_job_for_cancel("c", [], [], history) == CANCEL_TERMINAL + + def test_unknown(self): + assert classify_job_for_cancel("z", [], [], {}) == CANCEL_UNKNOWN + + +class TestCancelJobHelper: + """``interrupt`` and ``dequeue`` both take the id and return whether they + actually acted, so cancel_job's return reflects the real outcome.""" + + def test_running_is_interrupted_not_dequeued(self): + interrupts = [] + dequeues = [] + result = cancel_job( + "a", [make_queue_item("a")], [], {}, + interrupt=lambda pid: interrupts.append(pid) or True, + dequeue=lambda pid: dequeues.append(pid) or True, + ) + assert result == CANCEL_RUNNING + assert interrupts == ["a"] + assert dequeues == [] + + def test_pending_is_dequeued_not_interrupted(self): + interrupts = [] + dequeues = [] + result = cancel_job( + "b", [], [make_queue_item("b")], {}, + interrupt=lambda pid: interrupts.append(pid) or True, + dequeue=lambda pid: dequeues.append(pid) or True, + ) + assert result == CANCEL_PENDING + assert dequeues == ["b"] + assert interrupts == [] + + def test_terminal_is_noop(self): + history = {"c": {"prompt": make_queue_item("c"), "outputs": {}, "status": {}}} + interrupts = [] + dequeues = [] + result = cancel_job( + "c", [], [], history, + interrupt=lambda pid: interrupts.append(pid) or True, + dequeue=lambda pid: dequeues.append(pid) or True, + ) + assert result == CANCEL_TERMINAL + assert interrupts == [] + assert dequeues == [] + + def test_unknown_is_noop(self): + interrupts = [] + dequeues = [] + result = cancel_job( + "z", [], [], {}, + interrupt=lambda pid: interrupts.append(pid) or True, + dequeue=lambda pid: dequeues.append(pid) or True, + ) + assert result == CANCEL_UNKNOWN + assert interrupts == [] + assert dequeues == [] + + def test_running_but_finished_before_interrupt_returns_unknown(self): + """Classified RUNNING from a stale snapshot, but the job finished before + the atomic interrupt fired (interrupt returns False). cancel_job reports + UNKNOWN rather than claiming a cancel that did not happen — and the + atomic interrupt guarantees no unrelated job was hit.""" + interrupts = [] + result = cancel_job( + "a", [make_queue_item("a")], [], {}, + interrupt=lambda pid: interrupts.append(pid) or False, + dequeue=lambda pid: True, + ) + assert result == CANCEL_UNKNOWN + assert interrupts == ["a"] # interrupt was attempted atomically + + def test_pending_started_running_is_interrupted(self): + """Pending->running race: the job leaves the queue (dequeue False) + because it started executing. The atomic interrupt catches the now- + running job, so cancel_job interrupts it and reports CANCEL_RUNNING.""" + interrupts = [] + dequeues = [] + result = cancel_job( + "b", [], [make_queue_item("b")], {}, + interrupt=lambda pid: interrupts.append(pid) or True, + dequeue=lambda pid: (dequeues.append(pid), False)[1], + ) + assert result == CANCEL_RUNNING + assert dequeues == ["b"] # dequeue attempted first + assert interrupts == ["b"] # then the now-running job was interrupted + + def test_pending_dequeue_miss_not_running_returns_unknown(self): + """Dequeue miss where the job is not running anymore (it finished): the + atomic interrupt finds nothing to interrupt and returns False, so + cancel_job is a no-op reporting UNKNOWN — never reporting a cancel that + did not happen, and never interrupting a bystander.""" + interrupts = [] + dequeues = [] + result = cancel_job( + "b", [], [make_queue_item("b")], {}, + interrupt=lambda pid: interrupts.append(pid) or False, + dequeue=lambda pid: (dequeues.append(pid), False)[1], + ) + assert result == CANCEL_UNKNOWN + assert dequeues == ["b"] + assert interrupts == ["b"] # interrupt attempted, found nothing running + + +# --------------------------------------------------------------------------- +# HTTP contract tests: POST /api/jobs/{job_id}/cancel +# --------------------------------------------------------------------------- + + +class TestSingleCancelEndpoint: + @pytest.mark.asyncio + async def test_cancel_running_job_interrupts(self, aiohttp_client): + queue = FakePromptQueue(running=[make_queue_item("a")]) + client = await aiohttp_client(build_app(queue)) + + resp = await client.post("/api/jobs/a/cancel") + + assert resp.status == 200 + assert (await resp.json()) == {"cancelled": True} + assert queue.interrupt_count == 1 + + @pytest.mark.asyncio + async def test_cancel_pending_job_dequeues(self, aiohttp_client): + queue = FakePromptQueue(pending=[make_queue_item("b")]) + client = await aiohttp_client(build_app(queue)) + + resp = await client.post("/api/jobs/b/cancel") + + assert resp.status == 200 + assert (await resp.json()) == {"cancelled": True} + # Pending job removed from the queue; nothing interrupted. + assert queue.get_current_queue()[1] == [] + assert queue.interrupt_count == 0 + + @pytest.mark.asyncio + async def test_cancel_terminal_job_is_idempotent_noop(self, aiohttp_client): + history = {"c": {"prompt": make_queue_item("c"), "outputs": {}, "status": {}}} + queue = FakePromptQueue(history=history) + client = await aiohttp_client(build_app(queue)) + + resp = await client.post("/api/jobs/c/cancel") + + # Already-finished job: 200 no-op (cancelled=false), not an error. + assert resp.status == 200 + assert (await resp.json()) == {"cancelled": False} + assert queue.interrupt_count == 0 + + @pytest.mark.asyncio + async def test_cancel_unknown_id_is_200_noop(self, aiohttp_client): + queue = FakePromptQueue() + client = await aiohttp_client(build_app(queue)) + + resp = await client.post("/api/jobs/does-not-exist/cancel") + + # Single-cancel of an unknown id is treated as an idempotent no-op. + assert resp.status == 200 + assert (await resp.json()) == {"cancelled": False} + assert queue.interrupt_count == 0 + + @pytest.mark.asyncio + async def test_cancel_pending_that_started_running_interrupts(self, aiohttp_client): + """Pending->running race end to end: the job is pending at snapshot time + but starts executing by the time we dequeue (delete misses). The live + re-check sees it running and interrupts it, so the cancel is not dropped + and the caller still gets cancelled=True.""" + + class RacingQueue(FakePromptQueue): + def delete_queue_item(self, function): + # The worker picked the job up just before we removed it: it + # leaves the pending queue (delete misses) and is now running. + self._running = list(self._pending) + self._pending = [] + return False + + queue = RacingQueue(pending=[make_queue_item("b")]) + client = await aiohttp_client(build_app(queue)) + + resp = await client.post("/api/jobs/b/cancel") + + assert resp.status == 200 + assert (await resp.json()) == {"cancelled": True} + assert queue.interrupt_count == 1 + + +# --------------------------------------------------------------------------- +# HTTP contract tests: POST /api/jobs/cancel (batch) +# --------------------------------------------------------------------------- + + +class TestBatchCancelEndpoint: + @pytest.mark.asyncio + async def test_batch_happy_path(self, aiohttp_client): + queue = FakePromptQueue( + running=[make_queue_item(_UUID_A)], + pending=[make_queue_item(_UUID_B, number=1)], + ) + client = await aiohttp_client(build_app(queue)) + + resp = await client.post("/api/jobs/cancel", json={"job_ids": [_UUID_A, _UUID_B]}) + + assert resp.status == 200 + assert (await resp.json()) == {"cancelled": True} + assert queue.interrupt_count == 1 # running job interrupted + assert queue.get_current_queue()[1] == [] # pending job dequeued + + @pytest.mark.asyncio + async def test_batch_best_effort_skips_unknown_id(self, aiohttp_client): + """An unknown id in the batch is a no-op, not a reason to abort: the + running and pending jobs are still cancelled (200, cancelled=true). This + is the "cancel all as a job finishes" case from review.""" + queue = FakePromptQueue( + running=[make_queue_item(_UUID_A)], + pending=[make_queue_item(_UUID_B, number=1)], + ) + client = await aiohttp_client(build_app(queue)) + + resp = await client.post( + "/api/jobs/cancel", json={"job_ids": [_UUID_A, _UUID_MISSING, _UUID_B]} + ) + + assert resp.status == 200 + assert (await resp.json()) == {"cancelled": True} + assert queue.interrupt_count == 1 # running job interrupted + assert queue.get_current_queue()[1] == [] # pending job dequeued + + @pytest.mark.asyncio + async def test_batch_all_terminal_is_idempotent_noop(self, aiohttp_client): + history = { + _UUID_C: {"prompt": make_queue_item(_UUID_C), "outputs": {}, "status": {}}, + _UUID_D: {"prompt": make_queue_item(_UUID_D), "outputs": {}, "status": {}}, + } + queue = FakePromptQueue(history=history) + client = await aiohttp_client(build_app(queue)) + + resp = await client.post("/api/jobs/cancel", json={"job_ids": [_UUID_C, _UUID_D]}) + + # All known but terminal: 200 with cancelled=false, nothing dispatched. + assert resp.status == 200 + assert (await resp.json()) == {"cancelled": False} + assert queue.interrupt_count == 0 + + @pytest.mark.asyncio + async def test_batch_missing_job_ids_is_400(self, aiohttp_client): + queue = FakePromptQueue() + client = await aiohttp_client(build_app(queue)) + + resp = await client.post("/api/jobs/cancel", json={}) + + assert resp.status == 400 + + @pytest.mark.asyncio + async def test_batch_unhashable_element_is_400_not_500(self, aiohttp_client): + """An unhashable element such as a dict or list must yield 400, not 500. + + Previously, passing e.g. {"job_ids": [{}]} would reach the classify + loop where ``prompt_id in history`` raises TypeError on an unhashable + type, resulting in an unhandled 500. The input-validation guard must + catch this before any queue or history access. + """ + queue = FakePromptQueue() + client = await aiohttp_client(build_app(queue)) + + resp = await client.post("/api/jobs/cancel", json={"job_ids": [{}]}) + + assert resp.status == 400 + body = await resp.json() + assert "invalid_ids" in body + # No queue side effects. + assert queue.interrupt_count == 0 + + @pytest.mark.asyncio + async def test_batch_non_uuid_string_element_is_400(self, aiohttp_client): + """A string that is not a valid UUID must be rejected with 400.""" + queue = FakePromptQueue() + client = await aiohttp_client(build_app(queue)) + + resp = await client.post( + "/api/jobs/cancel", json={"job_ids": ["not-a-uuid"]} + ) + + assert resp.status == 400 + body = await resp.json() + assert "invalid_ids" in body From cd77c551d6c7efa46a8ba514fd6f4e04aac76b4d Mon Sep 17 00:00:00 2001 From: Barish Ozbay <17261091+drozbay@users.noreply.github.com> Date: Fri, 19 Jun 2026 19:47:31 -0400 Subject: [PATCH 8/8] feat: Context Windows sampling with LTX2 models and IC-LoRa guides (CORE-3) (#13325) --- comfy/context_windows.py | 466 ++++++++++++++++++++++---- comfy/ldm/lightricks/model.py | 4 +- comfy/model_base.py | 122 +++++++ comfy_extras/nodes_context_windows.py | 77 ++++- 4 files changed, 592 insertions(+), 77 deletions(-) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index db57537a2..5f9899c67 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -8,6 +8,8 @@ from abc import ABC, abstractmethod import logging import comfy.model_management import comfy.patcher_extension +import comfy.utils +import comfy.conds if TYPE_CHECKING: from comfy.model_base import BaseModel from comfy.model_patcher import ModelPatcher @@ -51,12 +53,18 @@ class ContextHandlerABC(ABC): class IndexListContextWindow(ContextWindowABC): - def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0): + def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0, modality_windows: dict=None, context_overlap: int=0): self.index_list = index_list self.context_length = len(index_list) + self.context_overlap = context_overlap self.dim = dim self.total_frames = total_frames self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames) + self.modality_windows = modality_windows # dict of {mod_idx: IndexListContextWindow} + self.guide_frames_indices: list[int] = [] + self.guide_overlap_info: list[tuple[int, int]] = [] + self.guide_kf_local_positions: list[int] = [] + self.guide_downscale_factors: list[int] = [] def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor: if dim is None: @@ -85,6 +93,11 @@ class IndexListContextWindow(ContextWindowABC): region_idx = int(self.center_ratio * num_regions) return min(max(region_idx, 0), num_regions - 1) + def get_window_for_modality(self, modality_idx: int) -> 'IndexListContextWindow': + if modality_idx == 0: + return self + return self.modality_windows[modality_idx] + class IndexListCallbacks: EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows" @@ -148,6 +161,172 @@ def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, d return cond_value._copy_with(sliced) +def compute_guide_overlap(guide_entries: list[dict], keyframe_idxs: torch.Tensor, temporal_downscale_ratio: int, window_index_list: list[int]): + """Compute which concatenated guide frames overlap with a context window. + + Each guide's latent-space start is derived from its first token's pixel-t-start + in keyframe_idxs (shape (B, [t,h,w], num_tokens, [start, end])), divided by the + model's temporal_downscale_ratio. + + Args: + guide_entries: list of guide_attention_entry dicts + keyframe_idxs: per-token pixel coords cond tensor for the modality + temporal_downscale_ratio: model's pixel-to-latent temporal compression ratio + window_index_list: the window's frame indices into the video portion + + Returns: + suffix_indices: indices into the guide_frames tensor for frame selection + overlap_info: list of (entry_idx, overlap_count) for guide_attention_entries adjustment + kf_local_positions: window-local frame positions for keyframe_idxs regeneration + total_overlap: total number of overlapping guide frames + """ + window_set = set(window_index_list) + window_list = list(window_index_list) + suffix_indices = [] + overlap_info = [] + kf_local_positions = [] + suffix_base = 0 + token_offset = 0 + + for entry_idx, entry in enumerate(guide_entries): + first_t_pixel = int(keyframe_idxs[0, 0, token_offset, 0].item()) + latent_start = (first_t_pixel + temporal_downscale_ratio - 1) // temporal_downscale_ratio + guide_len = entry["latent_shape"][0] + entry_overlap = 0 + + for local_offset in range(guide_len): + video_pos = latent_start + local_offset + if video_pos in window_set: + suffix_indices.append(suffix_base + local_offset) + kf_local_positions.append(window_list.index(video_pos)) + entry_overlap += 1 + + if entry_overlap > 0: + overlap_info.append((entry_idx, entry_overlap)) + suffix_base += guide_len + token_offset += entry["pre_filter_count"] + + return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices) + + +@dataclass +class WindowingState: + """Per-modality context windowing state for each step, + built using IndexListContextHandler._build_window_state(). + For non-multimodal models the lists are length 1 + """ + latents: list[torch.Tensor] # per-modality working latents (guide frames stripped) + guide_latents: list[torch.Tensor | None] # per-modality guide frames stripped from latents + guide_entries: list[list[dict] | None] # per-modality guide_attention_entry metadata + keyframe_idxs: list[torch.Tensor | None] # per-modality keyframe_idxs tensor for guide latent_start derivation + latent_shapes: list | None # original packed shapes for unpack/pack (None if not multimodal) + dim: int = 0 # primary modality temporal dim for context windowing + is_multimodal: bool = False + temporal_downscale_ratio: int = 1 # model's pixel-to-latent temporal compression ratio + + def prepare_window(self, window: IndexListContextWindow, model) -> IndexListContextWindow: + """Reformat window for multimodal contexts by deriving per-modality index lists. + Non-multimodal contexts return the input window unchanged. + """ + if not self.is_multimodal: + return window + + x = self.latents[0] + primary_total = self.latent_shapes[0][self.dim] + primary_overlap = window.context_overlap + map_shapes = self.latent_shapes + if x.size(self.dim) != primary_total: + map_shapes = list(self.latent_shapes) + video_shape = list(self.latent_shapes[0]) + video_shape[self.dim] = x.size(self.dim) + map_shapes[0] = torch.Size(video_shape) + try: + per_modality_indices = model.map_context_window_to_modalities( + window.index_list, map_shapes, self.dim) + except AttributeError: + raise NotImplementedError( + f"{type(model).__name__} must implement map_context_window_to_modalities for multimodal context windows.") + modality_windows = {} + for mod_idx in range(1, len(self.latents)): + modality_total_frames = self.latents[mod_idx].shape[self.dim] + ratio = modality_total_frames / primary_total if primary_total > 0 else 1 + modality_overlap = max(round(primary_overlap * ratio), 0) + modality_windows[mod_idx] = IndexListContextWindow( + per_modality_indices[mod_idx], dim=self.dim, + total_frames=modality_total_frames, + context_overlap=modality_overlap) + return IndexListContextWindow( + window.index_list, dim=self.dim, total_frames=x.shape[self.dim], + modality_windows=modality_windows, context_overlap=primary_overlap) + + def slice_for_window(self, window: IndexListContextWindow, retain_index_list: list[int], device=None) -> tuple[list[torch.Tensor], list[int]]: + """Slice latents for a context window, injecting guide frames where applicable. + For multimodal contexts, uses the modality-specific windows derived in prepare_window(). + """ + sliced = [] + guide_frame_counts = [] + for idx in range(len(self.latents)): + modality_window = window.get_window_for_modality(idx) + retain = retain_index_list if idx == 0 else [] + s = modality_window.get_tensor(self.latents[idx], device, retain_index_list=retain) + if self.guide_entries[idx] is not None: + s, ng = self._inject_guide_frames(s, modality_window, modality_idx=idx) + else: + ng = 0 + sliced.append(s) + guide_frame_counts.append(ng) + return sliced, guide_frame_counts + + def strip_guide_frames(self, out_per_modality: list[list[torch.Tensor]], guide_frame_counts: list[int], window: IndexListContextWindow): + """Strip injected guide frames from per-cond, per-modality outputs in place.""" + for idx in range(len(self.latents)): + if guide_frame_counts[idx] > 0: + window_len = len(window.get_window_for_modality(idx).index_list) + for ci in range(len(out_per_modality)): + out_per_modality[ci][idx] = out_per_modality[ci][idx].narrow(self.dim, 0, window_len) + + def _inject_guide_frames(self, latent_slice: torch.Tensor, window: IndexListContextWindow, modality_idx: int = 0) -> tuple[torch.Tensor, int]: + guide_entries = self.guide_entries[modality_idx] + guide_frames = self.guide_latents[modality_idx] + keyframe_idxs = self.keyframe_idxs[modality_idx] + suffix_idx, overlap_info, kf_local_pos, guide_frame_count = compute_guide_overlap( + guide_entries, keyframe_idxs, self.temporal_downscale_ratio, window.index_list) + # Shift keyframe positions to account for causal_window_fix anchor occupying sub-pos 0. + anchor_idx = getattr(window, 'causal_anchor_index', None) + if anchor_idx is not None and anchor_idx >= 0: + kf_local_pos = [p + 1 for p in kf_local_pos] + window.guide_frames_indices = suffix_idx + window.guide_overlap_info = overlap_info + window.guide_kf_local_positions = kf_local_pos + + # Derive per-overlap-entry latent_downscale_factor from guide entry latent_shape vs guide frame spatial dims. + # guide_frames has full (post-dilation) spatial dims; entry["latent_shape"] has pre-dilation dims. + guide_downscale_factors = [] + if guide_frame_count > 0: + full_H = guide_frames.shape[3] + for entry_idx, _ in overlap_info: + entry_H = guide_entries[entry_idx]["latent_shape"][1] + guide_downscale_factors.append(full_H // entry_H) + window.guide_downscale_factors = guide_downscale_factors + + if guide_frame_count > 0: + idx = tuple([slice(None)] * self.dim + [suffix_idx]) + return torch.cat([latent_slice, guide_frames[idx]], dim=self.dim), guide_frame_count + return latent_slice, 0 + + def patch_latent_shapes(self, sub_conds, new_shapes): + if not self.is_multimodal: + return + + for cond_list in sub_conds: + if cond_list is None: + continue + for cond_dict in cond_list: + model_conds = cond_dict.get('model_conds', {}) + if 'latent_shapes' in model_conds: + model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes) + + @dataclass class ContextSchedule: name: str @@ -162,7 +341,7 @@ ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_co class IndexListContextHandler(ContextHandlerABC): def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, - causal_window_fix: bool=True): + latent_retain_index_list: list[int]=[], causal_window_fix: bool=True): self.context_schedule = context_schedule self.fuse_method = fuse_method self.context_length = context_length @@ -174,17 +353,118 @@ class IndexListContextHandler(ContextHandlerABC): self.freenoise = freenoise self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else [] self.split_conds_to_windows = split_conds_to_windows + self.latent_retain_index_list = [int(x.strip()) for x in latent_retain_index_list.split(",")] if latent_retain_index_list else [] self.causal_window_fix = causal_window_fix self.callbacks = {} + @staticmethod + def _get_latent_shapes(conds): + for cond_list in conds: + if cond_list is None: + continue + for cond_dict in cond_list: + model_conds = cond_dict.get('model_conds', {}) + if 'latent_shapes' in model_conds: + return model_conds['latent_shapes'].cond + return None + + @staticmethod + def _get_guide_entries(conds): + for cond_list in conds: + if cond_list is None: + continue + for cond_dict in cond_list: + model_conds = cond_dict.get('model_conds', {}) + entries = model_conds.get('guide_attention_entries') + if entries is not None and hasattr(entries, 'cond') and entries.cond: + return entries.cond + return None + + @staticmethod + def _get_keyframe_idxs(conds): + for cond_list in conds: + if cond_list is None: + continue + for cond_dict in cond_list: + model_conds = cond_dict.get('model_conds', {}) + kf = model_conds.get('keyframe_idxs') + if kf is not None and hasattr(kf, 'cond') and kf.cond is not None: + return kf.cond + return None + + def _apply_freenoise(self, noise: torch.Tensor, conds: list[list[dict]], seed: int) -> torch.Tensor: + """Apply FreeNoise shuffling, scaling context length/overlap per-modality by frame ratio. + If guide frames are present on the primary modality, only the video portion is shuffled. + """ + guide_entries = self._get_guide_entries(conds) + guide_count = sum(e["latent_shape"][0] for e in guide_entries) if guide_entries else 0 + + latent_shapes = self._get_latent_shapes(conds) + if latent_shapes is not None and len(latent_shapes) > 1: + modalities = comfy.utils.unpack_latents(noise, latent_shapes) + primary_total = latent_shapes[0][self.dim] + primary_video_count = modalities[0].size(self.dim) - guide_count + apply_freenoise(modalities[0].narrow(self.dim, 0, primary_video_count), self.dim, self.context_length, self.context_overlap, seed) + for i in range(1, len(modalities)): + mod_total = latent_shapes[i][self.dim] + ratio = mod_total / primary_total if primary_total > 0 else 1 + mod_ctx_len = max(round(self.context_length * ratio), 1) + mod_ctx_overlap = max(round(self.context_overlap * ratio), 0) + modalities[i] = apply_freenoise(modalities[i], self.dim, mod_ctx_len, mod_ctx_overlap, seed) + noise, _ = comfy.utils.pack_latents(modalities) + return noise + video_count = noise.size(self.dim) - guide_count + apply_freenoise(noise.narrow(self.dim, 0, video_count), self.dim, self.context_length, self.context_overlap, seed) + return noise + + def _build_window_state(self, x_in: torch.Tensor, conds: list[list[dict]], model: BaseModel) -> WindowingState: + """Build windowing state for the current step, including unpacking latents and extracting guide frame info from conds.""" + latent_shapes = self._get_latent_shapes(conds) + is_multimodal = latent_shapes is not None and len(latent_shapes) > 1 + unpacked_latents = comfy.utils.unpack_latents(x_in, latent_shapes) if is_multimodal else [x_in] + + unpacked_latents_list = list(unpacked_latents) + guide_latents_list = [None] * len(unpacked_latents) + guide_entries_list = [None] * len(unpacked_latents) + keyframe_idxs_list = [None] * len(unpacked_latents) + + extracted_guide_entries = self._get_guide_entries(conds) + extracted_keyframe_idxs = self._get_keyframe_idxs(conds) + + # Strip guide frames (only from first modality for now) + if extracted_guide_entries is not None: + guide_count = sum(e["latent_shape"][0] for e in extracted_guide_entries) + if guide_count > 0: + x = unpacked_latents[0] + latent_count = x.size(self.dim) - guide_count + unpacked_latents_list[0] = x.narrow(self.dim, 0, latent_count) + guide_latents_list[0] = x.narrow(self.dim, latent_count, guide_count) + guide_entries_list[0] = extracted_guide_entries + keyframe_idxs_list[0] = extracted_keyframe_idxs + + + return WindowingState( + latents=unpacked_latents_list, + guide_latents=guide_latents_list, + guide_entries=guide_entries_list, + keyframe_idxs=keyframe_idxs_list, + latent_shapes=latent_shapes, + dim=self.dim, + is_multimodal=is_multimodal, + temporal_downscale_ratio=model.latent_format.temporal_downscale_ratio) + def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool: - # for now, assume first dim is batch - should have stored on BaseModel in actual implementation - if x_in.size(self.dim) > self.context_length: - logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.") + window_state = self._build_window_state(x_in, conds, model) # build window_state to check frame counts, will be built again in execute + total_frame_count = window_state.latents[0].size(self.dim) + if total_frame_count > self.context_length: + logging.info(f"\nUsing context windows: Context length {self.context_length} with overlap {self.context_overlap} for {total_frame_count} frames.") if self.cond_retain_index_list: logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}") + if self.latent_retain_index_list: + logging.info(f"Retaining original latent for indexes: {self.latent_retain_index_list}") return True + logging.info(f"\nNot using context windows since context length ({self.context_length}) exceeds input frames ({total_frame_count}).") return False def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase: @@ -275,7 +555,9 @@ class IndexListContextHandler(ContextHandlerABC): return resized_cond def set_step(self, timestep: torch.Tensor, model_options: dict[str]): - mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001) + sample_sigmas = model_options["transformer_options"]["sample_sigmas"] + current_timestep = timestep[0].to(sample_sigmas.dtype) + mask = torch.isclose(sample_sigmas, current_timestep, rtol=0.0001) matches = torch.nonzero(mask) if torch.numel(matches) == 0: return # substep from multi-step sampler: keep self._step from the last full step @@ -284,54 +566,98 @@ class IndexListContextHandler(ContextHandlerABC): def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]: full_length = x_in.size(self.dim) # TODO: choose dim based on model context_windows = self.context_schedule.func(full_length, self, model_options) - context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows] + context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length, context_overlap=self.context_overlap) for window in context_windows] return context_windows def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): self._model = model self.set_step(timestep, model_options) - context_windows = self.get_context_windows(model, x_in, model_options) - enumerated_context_windows = list(enumerate(context_windows)) - conds_final = [torch.zeros_like(x_in) for _ in conds] + window_state = self._build_window_state(x_in, conds, model) + num_modalities = len(window_state.latents) + + context_windows = self.get_context_windows(model, window_state.latents[0], model_options) + enumerated_context_windows = list(enumerate(context_windows)) + total_windows = len(enumerated_context_windows) + + # Initialize per-modality accumulators (length 1 for single-modality) + accum = [[torch.zeros_like(m) for _ in conds] for m in window_state.latents] if self.fuse_method.name == ContextFuseMethods.RELATIVE: - counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds] + counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in window_state.latents] else: - counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds] - biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds] + counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in window_state.latents] + biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in window_state.latents] for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks): callback(self, model, x_in, conds, timestep, model_options) + # accumulate results from each context window for enum_window in enumerated_context_windows: - results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options) + results = self.evaluate_context_windows( + calc_cond_batch, model, x_in, conds, timestep, [enum_window], + model_options, window_state=window_state, total_windows=total_windows) for result in results: - self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep, - conds_final, counts_final, biases_final) + # result.sub_conds_out is per-cond, per-modality: list[list[Tensor]] + for mod_idx in range(num_modalities): + mod_out = [result.sub_conds_out[ci][mod_idx] for ci in range(len(conds))] + modality_window = result.window.get_window_for_modality(mod_idx) + self.combine_context_window_results( + window_state.latents[mod_idx], mod_out, result.sub_conds, modality_window, + result.window_idx, total_windows, timestep, + accum[mod_idx], counts[mod_idx], biases[mod_idx]) + + # fuse accumulated results into final conds try: - # finalize conds - if self.fuse_method.name == ContextFuseMethods.RELATIVE: - # relative is already normalized, so return as is - del counts_final - return conds_final - else: - # normalize conds via division by context usage counts - for i in range(len(conds_final)): - conds_final[i] /= counts_final[i] - del counts_final - return conds_final + result_out = [] + for ci in range(len(conds)): + finalized = [] + for mod_idx in range(num_modalities): + if self.fuse_method.name != ContextFuseMethods.RELATIVE: + accum[mod_idx][ci] /= counts[mod_idx][ci] + f = accum[mod_idx][ci] + + # if guide frames were injected, append them to the end of the fused latents for the next step + if window_state.guide_latents[mod_idx] is not None: + f = torch.cat([f, window_state.guide_latents[mod_idx]], dim=self.dim) + finalized.append(f) + + # pack modalities together if needed + if window_state.is_multimodal and len(finalized) > 1: + packed, _ = comfy.utils.pack_latents(finalized) + else: + packed = finalized[0] + + result_out.append(packed) + return result_out finally: for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks): callback(self, model, x_in, conds, timestep, model_options) - def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]], - model_options, device=None, first_device=None): + def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, + timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]], + model_options, window_state: WindowingState, total_windows: int = None, + device=None, first_device=None): + """Evaluate context windows and return per-cond, per-modality outputs in ContextResults.sub_conds_out + + For each window: + 1. Builds windows (for each modality if multimodal) + 2. Slices window for each modality + 3. Injects concatenated latent guide frames where present + 4. Packs together if needed and calls model + 5. Unpacks and strips any guides from outputs + """ + x = window_state.latents[0] + results: list[ContextResults] = [] for window_idx, window in enumerated_context_windows: # allow processing to end between context window executions for faster Cancel comfy.model_management.throw_exception_if_processing_interrupted() - # causal_window_fix: prepend a pre-window frame that will be stripped post-forward + # prepare the window accounting for multimodal windows + window = window_state.prepare_window(window, model) + + # causal_window_fix: prepend a pre-window frame that will be stripped post-forward. + # Set anchor before slice_for_window so the latent slice and downstream cond slices both pick it up. anchor_applied = False if self.causal_window_fix: anchor_idx = window.index_list[0] - 1 @@ -339,27 +665,46 @@ class IndexListContextHandler(ContextHandlerABC): window.causal_anchor_index = anchor_idx anchor_applied = True + # slice the window for each modality, injecting guide frames where applicable + sliced, guide_frame_counts_per_modality = window_state.slice_for_window(window, self.latent_retain_index_list, device) + for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks): callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device) - # update exposed params + logging.info(f"Context window {window_idx + 1}/{total_windows or len(enumerated_context_windows)}: frames {window.index_list[0]}-{window.index_list[-1]} of {x.shape[self.dim]}" + + (f" (+{guide_frame_counts_per_modality[0]} guide frames)" if guide_frame_counts_per_modality[0] > 0 else "") + ) + + # if multimodal, pack modalities together + if window_state.is_multimodal and len(sliced) > 1: + sub_x, sub_shapes = comfy.utils.pack_latents(sliced) + else: + sub_x, sub_shapes = sliced[0], [sliced[0].shape] + + # get resized conds for window model_options["transformer_options"]["context_window"] = window - # get subsections of x, timestep, conds - sub_x = window.get_tensor(x_in, device) - sub_timestep = window.get_tensor(timestep, device, dim=0) - sub_conds = [self.get_resized_cond(cond, x_in, window, device) for cond in conds] + sub_timestep = window.get_tensor(timestep, dim=0) + sub_conds = [self.get_resized_cond(cond, x, window) for cond in conds] + # if multimodal, patch latent_shapes in conds for correct unpacking in model + window_state.patch_latent_shapes(sub_conds, sub_shapes) + + # call model on window sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options) - if device is not None: - for i in range(len(sub_conds_out)): - sub_conds_out[i] = sub_conds_out[i].to(x_in.device) - # strip causal_window_fix anchor if applied + # unpack outputs + out_per_modality = [comfy.utils.unpack_latents(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))] + + # strip causal_window_fix anchor from primary modality before guide strip so window_len math stays correct if anchor_applied: - for i in range(len(sub_conds_out)): - sub_conds_out[i] = sub_conds_out[i].narrow(self.dim, 1, sub_conds_out[i].shape[self.dim] - 1) + for ci in range(len(out_per_modality)): + t = out_per_modality[ci][0] + out_per_modality[ci][0] = t.narrow(self.dim, 1, t.shape[self.dim] - 1) - results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window)) + # strip injected guide frames + window_state.strip_guide_frames(out_per_modality, guide_frame_counts_per_modality, window) + + results.append(ContextResults(window_idx, out_per_modality, sub_conds, window)) return results @@ -383,7 +728,7 @@ class IndexListContextHandler(ContextHandlerABC): biases_final[i][idx] = bias_total + bias else: # add conds and counts based on weights of fuse method - weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep) + weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep, context_overlap=window.context_overlap) weights_tensor = match_weights_to_dim(weights, x_in, self.dim, device=x_in.device) for i in range(len(sub_conds_out)): window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor) @@ -393,16 +738,22 @@ class IndexListContextHandler(ContextHandlerABC): callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final) -def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args, **kwargs): - # limit noise_shape length to context_length for more accurate vram use estimation +def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, conds, *args, **kwargs): + # Scale noise_shape to a single context window so VRAM estimation budgets per-window. model_options = kwargs.get("model_options", None) if model_options is None: raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.") handler: IndexListContextHandler = model_options.get("context_handler", None) if handler is not None: noise_shape = list(noise_shape) - noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length) - return executor(model, noise_shape, *args, **kwargs) + is_packed = len(noise_shape) == 3 and noise_shape[1] == 1 + if is_packed: + # TODO: latent_shapes cond isn't attached yet at this point, so we can't compute a + # per-window flat latent here. Skipping the clamp over-estimates but prevents immediate OOM. + pass + elif handler.dim < len(noise_shape) and noise_shape[handler.dim] > handler.context_length: + noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length) + return executor(model, noise_shape, conds, *args, **kwargs) def create_prepare_sampling_wrapper(model: ModelPatcher): @@ -422,11 +773,12 @@ def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, nois raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.") if not handler.freenoise: return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs) - noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"]) + + conds = [guider.conds.get('positive', guider.conds.get('negative', []))] + noise = handler._apply_freenoise(noise, conds, extra_args["seed"]) return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs) - def create_sampler_sample_wrapper(model: ModelPatcher): model.add_wrapper_with_key( comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, @@ -434,7 +786,6 @@ def create_sampler_sample_wrapper(model: ModelPatcher): _sampler_sample_wrapper ) - def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor: total_dims = len(x_in.shape) weights_tensor = torch.Tensor(weights).to(device=device) @@ -580,8 +931,9 @@ def get_matching_context_schedule(context_schedule: str) -> ContextSchedule: return ContextSchedule(context_schedule, func) -def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None): - return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs) +def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None, context_overlap: int=None): + context_overlap = handler.context_overlap if context_overlap is None else context_overlap + return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs, context_overlap=context_overlap) def create_weights_flat(length: int, **kwargs) -> list[float]: @@ -599,18 +951,18 @@ def create_weights_pyramid(length: int, **kwargs) -> list[float]: weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1)) return weight_sequence -def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, **kwargs): +def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], context_overlap: int, **kwargs): # based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302 # only expected overlap is given different weights weights_torch = torch.ones((length)) # blend left-side on all except first window if min(idxs) > 0: - ramp_up = torch.linspace(1e-37, 1, handler.context_overlap) - weights_torch[:handler.context_overlap] = ramp_up + ramp_up = torch.linspace(1e-37, 1, context_overlap) + weights_torch[:context_overlap] = ramp_up # blend right-side on all except last window if max(idxs) < full_length-1: - ramp_down = torch.linspace(1, 1e-37, handler.context_overlap) - weights_torch[-handler.context_overlap:] = ramp_down + ramp_down = torch.linspace(1, 1e-37, context_overlap) + weights_torch[-context_overlap:] = ramp_down return weights_torch class ContextFuseMethods: diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index e0a4a0f9b..9953b6679 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -1085,7 +1085,7 @@ class LTXVModel(LTXBaseModel): ) grid_mask = None - if keyframe_idxs is not None: + if keyframe_idxs is not None and keyframe_idxs.shape[2] > 0: additional_args.update({ "orig_patchified_shape": list(x.shape)}) denoise_mask = self.patchifier.patchify(denoise_mask)[0] grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0] @@ -1330,7 +1330,7 @@ class LTXVModel(LTXBaseModel): x = x * (1 + scale) + shift x = self.proj_out(x) - if keyframe_idxs is not None: + if keyframe_idxs is not None and keyframe_idxs.shape[2] > 0: grid_mask = kwargs["grid_mask"] orig_patchified_shape = kwargs["orig_patchified_shape"] full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device) diff --git a/comfy/model_base.py b/comfy/model_base.py index f49da50ae..264dbb9b3 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -21,6 +21,7 @@ import comfy.ldm.hunyuan3dv2_1.hunyuandit import torch import logging import comfy.ldm.lightricks.av_model +import comfy.ldm.lightricks.symmetric_patchifier import comfy.context_windows from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep from comfy.ldm.cascade.stage_c import StageC @@ -1204,6 +1205,127 @@ class LTXAV(BaseModel): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): return latent_image + def map_context_window_to_modalities(self, primary_indices, latent_shapes, dim): + result = [primary_indices] + if len(latent_shapes) < 2: + return result + + video_total = latent_shapes[0][dim] + + for i in range(1, len(latent_shapes)): + mod_total = latent_shapes[i][dim] + # Map each primary index to its proportional range of modality indices and + # concatenate in order. Preserves wrapped/strided geometry so the modality + # attends to the same temporal regions as the primary window. + mod_indices = [] + seen = set() + for v_idx in primary_indices: + a_start = min(int(round(v_idx * mod_total / video_total)), mod_total - 1) + a_end = min(int(round((v_idx + 1) * mod_total / video_total)), mod_total) + if a_end <= a_start: + a_end = a_start + 1 + for a in range(a_start, a_end): + if a not in seen: + seen.add(a) + mod_indices.append(a) + result.append(mod_indices) + + return result + + @staticmethod + def _get_guide_entries(conds): + for cond_list in conds: + if cond_list is None: + continue + for cond_dict in cond_list: + model_conds = cond_dict.get('model_conds', {}) + entries = model_conds.get('guide_attention_entries') + if entries is not None and hasattr(entries, 'cond') and entries.cond: + return entries.cond + return None + + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + # Audio denoise mask — slice using audio modality window + if cond_key == "audio_denoise_mask" and hasattr(window, 'modality_windows') and window.modality_windows: + audio_window = window.modality_windows.get(1) + if audio_window is not None and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): + sliced = audio_window.get_tensor(cond_value.cond, device, dim=2) + return cond_value._copy_with(sliced) + + # Video denoise mask — split into video + guide portions, slice each + if cond_key == "denoise_mask" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): + cond_tensor = cond_value.cond + guide_count = cond_tensor.size(window.dim) - x_in.size(window.dim) + if guide_count > 0: + T_video = x_in.size(window.dim) + video_mask = cond_tensor.narrow(window.dim, 0, T_video) + guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count) + sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list) + suffix_indices = window.guide_frames_indices + if suffix_indices: + idx = tuple([slice(None)] * window.dim + [suffix_indices]) + sliced_guide = guide_mask[idx].to(device) + return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim)) + else: + return cond_value._copy_with(sliced_video) + + # Keyframe indices — regenerate pixel coords for window, select guide positions + if cond_key == "keyframe_idxs": + kf_local_pos = window.guide_kf_local_positions + if not kf_local_pos: + return cond_value._copy_with(cond_value.cond[:, :, :0, :]) # empty + H, W = x_in.shape[3], x_in.shape[4] + window_len = len(window.index_list) + # account for causal_window_fix anchor in coord space size + anchor_idx = getattr(window, 'causal_anchor_index', None) + if anchor_idx is not None and anchor_idx >= 0: + window_len += 1 + patchifier = self.diffusion_model.patchifier + latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device) + scale_factors = self.diffusion_model.vae_scale_factors + pixel_coords = comfy.ldm.lightricks.symmetric_patchifier.latent_to_pixel_coords( + latent_coords, + scale_factors, + causal_fix=self.diffusion_model.causal_temporal_positioning) + tokens = [] + for pos in kf_local_pos: + tokens.extend(range(pos * H * W, (pos + 1) * H * W)) + pixel_coords = pixel_coords[:, :, tokens, :] + + # Adjust spatial end positions for dilated (downscaled) guides. + # Each guide entry may have a different downscale factor; expand the + # per-entry factor to cover all tokens belonging to that entry. + downscale_factors = window.guide_downscale_factors + overlap_info = window.guide_overlap_info + if downscale_factors: + per_token_factor = [] + for (entry_idx, overlap_count), dsf in zip(overlap_info, downscale_factors): + per_token_factor.extend([dsf] * (overlap_count * H * W)) + factor_tensor = torch.tensor(per_token_factor, device=pixel_coords.device, dtype=pixel_coords.dtype) + spatial_end_offset = (factor_tensor.unsqueeze(0).unsqueeze(0).unsqueeze(-1) - 1) * torch.tensor( + scale_factors[1:], device=pixel_coords.device, dtype=pixel_coords.dtype, + ).view(1, -1, 1, 1) + pixel_coords[:, 1:, :, 1:] += spatial_end_offset + + B = cond_value.cond.shape[0] + if B > 1: + pixel_coords = pixel_coords.expand(B, -1, -1, -1) + return cond_value._copy_with(pixel_coords) + + # Guide attention entries — adjust per-guide counts based on window overlap + if cond_key == "guide_attention_entries": + overlap_info = window.guide_overlap_info + H, W = x_in.shape[3], x_in.shape[4] + new_entries = [] + for entry_idx, overlap_count in overlap_info: + e = cond_value.cond[entry_idx] + new_entries.append({**e, + "pre_filter_count": overlap_count * H * W, + "latent_shape": [overlap_count, H, W]}) + return cond_value._copy_with(new_entries) + + return None + class HunyuanVideo(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo) diff --git a/comfy_extras/nodes_context_windows.py b/comfy_extras/nodes_context_windows.py index 098c26f23..15d2dc506 100644 --- a/comfy_extras/nodes_context_windows.py +++ b/comfy_extras/nodes_context_windows.py @@ -13,21 +13,22 @@ class ContextWindowsManualNode(io.ComfyNode): description="Manually set context windows.", inputs=[ io.Model.Input("model", tooltip="The model to apply context windows to during sampling."), - io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window.", advanced=True), - io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window.", advanced=True), + io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window."), + io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window."), io.Combo.Input("context_schedule", options=[ comfy.context_windows.ContextSchedules.STATIC_STANDARD, comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, comfy.context_windows.ContextSchedules.UNIFORM_LOOPED, comfy.context_windows.ContextSchedules.BATCHED, - ], tooltip="The stride of the context window."), - io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True), + ], default=comfy.context_windows.ContextSchedules.STATIC_STANDARD, tooltip="Step-dependent scheduling algorithm for context windows."), + io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."), io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."), io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."), io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), - io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), + io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window. For concat-style I2V models (e.g. Wan I2V, HunyuanVideo I2V, Cosmos I2V, SVD) the encoded start image lives in the c_concat conditioning channels; setting this to '0' will retain that start image content at sub-pos 0 of every window."), io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), + io.String.Input("latent_retain_index_list", default="", tooltip="List of latent indices to retain in the noise latent itself for each window. Use for workflows where reference content (e.g. a start image) lives directly in the noise latent rather than in separate conditioning channels (e.g. inplace-style I2V like LTXV, AnimateDiff). Independent of cond_retain_index_list."), io.Boolean.Input("causal_window_fix", default=True, tooltip="Whether to add a causal fix frame to non-0-indexed context windows."), ], outputs=[ @@ -38,7 +39,7 @@ class ContextWindowsManualNode(io.ComfyNode): @classmethod def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool, - cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, causal_window_fix: bool=True) -> io.Model: + cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, latent_retain_index_list: list[int]=[], causal_window_fix: bool=True) -> io.Model: model = model.clone() model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler( context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule), @@ -51,6 +52,7 @@ class ContextWindowsManualNode(io.ComfyNode): freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows, + latent_retain_index_list=latent_retain_index_list, causal_window_fix=causal_window_fix, ) # make memory usage calculation only take into account the context window latents @@ -65,33 +67,71 @@ class WanContextWindowsManualNode(ContextWindowsManualNode): schema = super().define_schema() schema.node_id = "WanContextWindowsManual" schema.display_name = "WAN Context Windows (Manual)" - schema.description = "Manually set context windows for WAN-like models (dim=2)." + schema.display_name = "Wan Context Windows" + schema.description = "Set context windows for Wan-like models." schema.category="model/patch/wan" schema.inputs = [ io.Model.Input("model", tooltip="The model to apply context windows to during sampling."), - io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window.", advanced=True), - io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window.", advanced=True), + io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window in real frames. Must be 4*n + 1."), + io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window in real frames."), io.Combo.Input("context_schedule", options=[ comfy.context_windows.ContextSchedules.STATIC_STANDARD, comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, comfy.context_windows.ContextSchedules.UNIFORM_LOOPED, comfy.context_windows.ContextSchedules.BATCHED, - ], tooltip="The stride of the context window."), + ], default=comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, tooltip="Step-dependent scheduling algorithm for context windows."), io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True), - io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."), + io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules.", advanced=True), io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), - io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), - #io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), - #io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), + io.Boolean.Input("freenoise", default=True, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending.", advanced=True), + io.Boolean.Input("retain_first_frame", default=False, tooltip="Retain the first I2V frame in every context window (may help retain initial reference)."), + io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index.", advanced=True), ] return schema @classmethod def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool, - cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model: - context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1 - context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0 - return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows) + retain_first_frame: bool=False, split_conds_to_windows: bool=False) -> io.Model: + context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1 + context_overlap = max(context_overlap // 4, 0) # at least overlap 0 + retain_index_list = "0" if retain_first_frame else "" + return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=retain_index_list, split_conds_to_windows=split_conds_to_windows) + + +class LTXVContextWindowsNode(ContextWindowsManualNode): + @classmethod + def define_schema(cls) -> io.Schema: + schema = super().define_schema() + schema.node_id = "LTXVContextWindows" + schema.display_name = "LTXV Context Windows" + schema.description = "Set context windows for LTXV-like models." + schema.inputs = [ + io.Model.Input("model", tooltip="The model to apply context windows to during sampling."), + io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=8, default=145, tooltip="The length of the context window in real frames. Must be 8*n + 1."), + io.Int.Input("context_overlap", min=0, step=8, default=40, tooltip="The overlap of the context window in real frames."), + io.Combo.Input("context_schedule", options=[ + comfy.context_windows.ContextSchedules.STATIC_STANDARD, + comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, + comfy.context_windows.ContextSchedules.UNIFORM_LOOPED, + comfy.context_windows.ContextSchedules.BATCHED, + ], default=comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, tooltip="Step-dependent scheduling algorithm for context windows."), + io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True), + io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules.", advanced=True), + io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), + io.Boolean.Input("freenoise", default=True, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending.", advanced=True), + io.Boolean.Input("retain_first_frame", default=False, tooltip="Retain the first latent frame in every context window (may help retain initial reference)."), + io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index.", advanced=True), + ] + return schema + + @classmethod + def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, fuse_method: str, freenoise: bool, + retain_first_frame: bool=False, split_conds_to_windows: bool=False, context_stride: int=1, closed_loop: bool=False) -> io.Model: + context_length = max(((context_length - 1) // 8) + 1, 1) # at least length 1 + context_overlap = max(context_overlap // 8, 0) # at least overlap 0 + retain_index_list = "0" if retain_first_frame else "" + return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, + cond_retain_index_list=retain_index_list, latent_retain_index_list=retain_index_list, split_conds_to_windows=split_conds_to_windows) class ContextWindowsExtension(ComfyExtension): @@ -99,6 +139,7 @@ class ContextWindowsExtension(ComfyExtension): return [ ContextWindowsManualNode, WanContextWindowsManualNode, + LTXVContextWindowsNode, ] def comfy_entrypoint():