From bc11e8a65a57dc6bd1768edfca14ddb1523f7882 Mon Sep 17 00:00:00 2001 From: Comfy Org PR Bot Date: Sat, 20 Jun 2026 08:01:34 +0900 Subject: [PATCH 1/5] Bump comfyui-frontend-package to 1.45.19 (#14559) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 392709e64..ad8b1c2ee 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.45.15 +comfyui-frontend-package==1.45.19 comfyui-workflow-templates==0.10.0 comfyui-embedded-docs==0.5.4 torch From 2ab3816dcf66d58f2c0b3e79e910311b21697e0d Mon Sep 17 00:00:00 2001 From: Terry Jia Date: Fri, 19 Jun 2026 19:06:55 -0400 Subject: [PATCH 2/5] feat: add Load3DAdvanced node (#14316) --- comfy_extras/nodes_load_3d.py | 63 +++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py index 455897859..6e3e88471 100644 --- a/comfy_extras/nodes_load_3d.py +++ b/comfy_extras/nodes_load_3d.py @@ -317,11 +317,74 @@ class PreviewPointCloud(IO.ComfyNode): ) +MESH_EXTENSIONS = {'.gltf', '.glb', '.obj', '.fbx', '.stl'} + + +class Load3DAdvanced(IO.ComfyNode): + @classmethod + def define_schema(cls): + input_dir = os.path.join(folder_paths.get_input_directory(), "3d") + os.makedirs(input_dir, exist_ok=True) + + input_path = Path(input_dir) + base_path = Path(folder_paths.get_input_directory()) + + files = [ + normalize_path(str(file_path.relative_to(base_path))) + for file_path in input_path.rglob("*") + if file_path.suffix.lower() in MESH_EXTENSIONS + ] + return IO.Schema( + node_id="Load3DAdvanced", + display_name="Load 3D (Advanced)", + category="3d", + search_aliases=[ + "load mesh", + "load gltf", + "load glb", + "load obj", + "load fbx", + "load stl", + ], + is_experimental=True, + inputs=[ + IO.Combo.Input("model_file", options=["none"] + sorted(files), upload=IO.UploadType.model), + IO.Load3D.Input("viewport_state"), + IO.Int.Input("width", default=1024, min=1, max=4096, step=1), + IO.Int.Input("height", default=1024, min=1, max=4096, step=1), + ], + outputs=[ + IO.File3DAny.Output(display_name="model_3d"), + IO.Load3DModelInfo.Output(display_name="model_3d_info"), + IO.Load3DCamera.Output(display_name="camera_info"), + IO.Int.Output(display_name="width"), + IO.Int.Output(display_name="height"), + ], + ) + + @classmethod + def validate_inputs(cls, model_file, **kwargs) -> bool | str: + if not model_file or model_file == "none": + return True + if not folder_paths.exists_annotated_filepath(model_file): + return f"Invalid 3D model file: {model_file}" + return True + + @classmethod + def execute(cls, model_file, viewport_state, width: int, height: int, **kwargs) -> IO.NodeOutput: + file_3d = None + if model_file and model_file != "none": + file_3d = Types.File3D(folder_paths.get_annotated_filepath(model_file)) + model_3d_info = viewport_state.get('model_3d_info', []) + return IO.NodeOutput(file_3d, model_3d_info, viewport_state['camera_info'], width, height) + + class Load3DExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ Load3D, + Load3DAdvanced, Preview3D, Preview3DAdvanced, PreviewGaussianSplat, From 4e716f7c5769fd7bdd851d95a323c2377dfeb5a7 Mon Sep 17 00:00:00 2001 From: Matt Miller Date: Fri, 19 Jun 2026 16:39:35 -0700 Subject: [PATCH 3/5] Add jobs-namespace cancel endpoints (POST /api/jobs/{job_id}/cancel, POST /api/jobs/cancel) (#14493) * Add jobs-namespace cancel endpoints Add two cancel endpoints under the jobs namespace so a job can be cancelled by id without the caller needing to know whether the job is running or pending, or branching between /interrupt and /queue. - POST /api/jobs/{job_id}/cancel cancels one job by id. Idempotent: an already-finished or unknown id returns 200 {"cancelled": false} rather than an error. - POST /api/jobs/cancel takes {"job_ids": [...]} and cancels a batch. Fail-fast: if any id is unknown the request returns 404 listing the unknown ids and cancels nothing (no partial side effects). Both are state-agnostic and map onto the existing queue mechanics: a running job is interrupted (same path as /interrupt), a pending job is dequeued (same path as /queue {"delete": [...]}). The cancel logic lives in comfy_execution.jobs as pure, unit-tested helpers; the server handlers are thin wrappers. openapi.yaml documents both routes. * fix: resolve review feedback on cancel endpoints - Guard cancel_job() against TOCTOU: when dequeue() returns False the pending job left the queue between snapshot and delete; return CANCEL_UNKNOWN so callers never report cancelled=True for a remove that did not happen. - Validate each job_ids element in the batch cancel endpoint before any queue access; unhashable or non-UUID values now return 400 instead of raising TypeError (500). - Update batch HTTP tests to use canonical UUID ids (required now that the endpoint validates id format) and add tests for the new guards. * fix: make job cancel atomic and best-effort Addresses two cancel races/edges raised in review. Targeted, atomic interrupt. cancel_job's interrupt callback now takes the prompt id and returns whether it fired; the single-cancel route backs it with the new PromptQueue.interrupt_if_running, which checks the running set and signals the interrupt under the queue mutex. This closes the TOCTOU where a pending job that starts executing between the snapshot and dequeue (or a running job that finishes between the snapshot and interrupt) could be missed or, worse, cause an unrelated prompt to be interrupted. The per-prompt interrupt-flag reset in execute_async keeps a finished job from leaking the interrupt onto its successor. Best-effort batch cancel. POST /api/jobs/cancel no longer fails the whole batch with 404 when one id is unknown/finished; such ids are treated as no-ops, so "cancel all" still cancels the in-progress jobs even if some finished between the client's snapshot and the request. Malformed ids are still rejected with 400. --- comfy_execution/jobs.py | 81 +++- execution.py | 19 + server.py | 111 ++++- tests-unit/jobs_cancel_test/__init__.py | 0 .../jobs_cancel_test/jobs_cancel_test.py | 453 ++++++++++++++++++ 5 files changed, 662 insertions(+), 2 deletions(-) create mode 100644 tests-unit/jobs_cancel_test/__init__.py create mode 100644 tests-unit/jobs_cancel_test/jobs_cancel_test.py diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py index 20ebae155..fa3ab0faf 100644 --- a/comfy_execution/jobs.py +++ b/comfy_execution/jobs.py @@ -4,11 +4,22 @@ Provides normalization and helper functions for job status tracking. """ import uuid -from typing import Optional +from typing import Callable, Optional from comfy_api.internal import prune_dict +# Result of classifying a job for cancellation. +# 'running' -> job is currently executing (interrupt it) +# 'pending' -> job is queued but not started (dequeue it) +# 'terminal' -> job already finished (present in history); cancel is a no-op +# 'unknown' -> job id is not present anywhere +CANCEL_RUNNING = 'running' +CANCEL_PENDING = 'pending' +CANCEL_TERMINAL = 'terminal' +CANCEL_UNKNOWN = 'unknown' + + class JobStatus: """Job status constants.""" PENDING = 'pending' @@ -407,3 +418,71 @@ def get_all_jobs( jobs = jobs[:limit] return (jobs, total_count) + + +def classify_job_for_cancel(prompt_id: str, running: list, queued: list, history: dict) -> str: + """Classify a job id for cancellation. + + Returns one of CANCEL_RUNNING, CANCEL_PENDING, CANCEL_TERMINAL, CANCEL_UNKNOWN. + + Queue items are tuples whose second element (index 1) is the prompt_id. + History is a dict keyed by prompt_id, so a job present there has already + finished and cancelling it is a no-op. + """ + for item in running: + if item[1] == prompt_id: + return CANCEL_RUNNING + for item in queued: + if item[1] == prompt_id: + return CANCEL_PENDING + if prompt_id in history: + return CANCEL_TERMINAL + return CANCEL_UNKNOWN + + +def cancel_job( + prompt_id: str, + running: list, + queued: list, + history: dict, + interrupt: Callable[[str], bool], + dequeue: Callable[[str], bool], +) -> str: + """Cancel a single job by id, regardless of state. + + Maps the cancel onto the runtime's existing mechanics: + - a running job is interrupted via ``interrupt`` + - a pending job is removed from the queue via ``dequeue`` + - a job that already finished (terminal) is a no-op + - an unknown id is a no-op (callers that need fail-fast behaviour should + validate ids up front with ``classify_job_for_cancel``) + + Both ``interrupt`` and ``dequeue`` take the prompt id and return whether + they acted on a job that was *actually* in that state, so the value returned + here reflects what truly happened rather than the (possibly stale) + classification. This matters around the narrow TOCTOU windows where a job + changes state between the caller's snapshot and the action: + + - a job classified RUNNING may have finished before ``interrupt`` fires: + ``interrupt`` returns False and this returns CANCEL_UNKNOWN (no-op). + - a job classified PENDING may have started executing before ``dequeue`` + fires: ``dequeue`` returns False, ``interrupt`` then catches the now- + running job and this returns CANCEL_RUNNING. If it had simply finished + instead, both return False and this returns CANCEL_UNKNOWN. + + ``interrupt`` must be atomic — interrupt the job only if it is still the one + running — so a cancel can never land on an unrelated prompt that started in + the meantime (see ``execution.PromptQueue.interrupt_if_running``). + """ + classification = classify_job_for_cancel(prompt_id, running, queued, history) + if classification == CANCEL_RUNNING: + return CANCEL_RUNNING if interrupt(prompt_id) else CANCEL_UNKNOWN + if classification == CANCEL_PENDING: + if dequeue(prompt_id): + return CANCEL_PENDING + # Left the pending queue between classification and dequeue: if it + # started executing, interrupt the now-running job; otherwise it has + # already finished and the cancel is a genuine no-op. + return CANCEL_RUNNING if interrupt(prompt_id) else CANCEL_UNKNOWN + # CANCEL_TERMINAL and CANCEL_UNKNOWN are intentional no-ops. + return classification diff --git a/execution.py b/execution.py index 9e16e451d..c45317593 100644 --- a/execution.py +++ b/execution.py @@ -1308,6 +1308,25 @@ class PromptQueue: queued = copy.copy(self.queue) return (running, queued) + def interrupt_if_running(self, prompt_id): + """Interrupt the running prompt with this id, atomically. + + Checks the live running set and signals the interrupt under the queue + mutex, so the worker cannot move the job to done (and start the next + prompt) in between. Returns True if a matching job was running and an + interrupt was signalled, False otherwise. The atomicity is what keeps a + cancel from landing on an unrelated prompt that started after a separate + is-running check: the global interrupt flag is reset at the start of + every prompt (execute_async), so a job that finishes before consuming + the flag cannot leak the interrupt onto its successor. + """ + with self.mutex: + for item in self.currently_running.values(): + if item[1] == prompt_id: + nodes.interrupt_processing() + return True + return False + def get_tasks_remaining(self): with self.mutex: return len(self.queue) + len(self.currently_running) diff --git a/server.py b/server.py index 6b0029adf..361850f38 100644 --- a/server.py +++ b/server.py @@ -8,7 +8,15 @@ import time import nodes import folder_paths import execution -from comfy_execution.jobs import JobStatus, get_job, get_all_jobs, validate_job_id +from comfy_execution.jobs import ( + JobStatus, + get_job, + get_all_jobs, + validate_job_id, + cancel_job, + CANCEL_PENDING, + CANCEL_RUNNING, +) import uuid import urllib import json @@ -899,6 +907,107 @@ class PromptServer(): return web.json_response(job) + def _cancel_job_by_id(job_id): + """Cancel a single job by id using the queue's existing mechanics. + + Running jobs are interrupted (same mechanism as /interrupt); pending + jobs are dequeued (same mechanism as /queue {"delete": [...]}). + Already-finished or unknown ids are no-ops. State-agnostic. + + Returns True when a cancel was actually dispatched (running or + pending job), False when the call was a no-op (terminal/unknown id). + """ + running, queued = self.prompt_queue.get_current_queue() + history = self.prompt_queue.get_history() + + def interrupt(prompt_id): + logging.info(f"Cancelling running prompt {prompt_id}") + # Atomic: only interrupts if the job is still the one running, + # so a cancel can't land on a prompt that started in the gap + # since the snapshot above. Returns whether it actually fired. + return self.prompt_queue.interrupt_if_running(prompt_id) + + def dequeue(prompt_id): + logging.info(f"Cancelling pending prompt {prompt_id}") + return self.prompt_queue.delete_queue_item(lambda a: a[1] == prompt_id) + + classification = cancel_job(job_id, running, queued, history, interrupt, dequeue) + return classification in (CANCEL_RUNNING, CANCEL_PENDING) + + @routes.post("/api/jobs/{job_id}/cancel") + async def cancel_job_by_id(request): + """Cancel a single job by id, regardless of state. + + Idempotent: cancelling a job that has already finished, or an id + that is not known, returns 200 with {"cancelled": false} rather + than an error. + """ + job_id = request.match_info.get("job_id", None) + if not job_id: + return web.json_response( + {"error": "job_id is required"}, + status=400 + ) + + cancelled = _cancel_job_by_id(job_id) + return web.json_response({"cancelled": cancelled}) + + @routes.post("/api/jobs/cancel") + async def cancel_jobs_batch(request): + """Cancel a batch of jobs by id. + + Body: {"job_ids": ["", ...]} + + Best-effort and idempotent: every well-formed id is cancelled if it + is running or pending; ids that are already finished or unknown are + no-ops, not errors. A batch of all no-ops still returns 200 with + {"cancelled": false}. This matches the single-cancel endpoint and + means "cancel all" still cancels the in-progress jobs even if some + finished between the client's snapshot and the request. Malformed + ids are still rejected up front with 400 (see below). + """ + try: + json_data = await request.json() + except json.JSONDecodeError: + return web.json_response( + {"error": "Request body must be valid JSON"}, + status=400 + ) + + job_ids = json_data.get("job_ids") if isinstance(json_data, dict) else None + if not isinstance(job_ids, list): + return web.json_response( + {"error": "job_ids must be a list"}, + status=400 + ) + + # Validate that every element is a well-formed job id before doing + # anything else. An unhashable element (e.g. a nested dict or list) + # would cause a TypeError when used as a history dict key; a + # non-string or non-UUID value is never a valid id. Reject early + # with 400 rather than letting the classify loop raise 500. + invalid_ids = [] + for jid in job_ids: + try: + validate_job_id(jid) + except (ValueError, AttributeError): + invalid_ids.append(jid if isinstance(jid, str) else repr(jid)) + if invalid_ids: + return web.json_response( + {"error": "job_ids contains invalid id(s)", "invalid_ids": invalid_ids}, + status=400, + ) + + # Best-effort: cancel each id that is still running/pending; an id + # that has finished or never existed is a no-op rather than a reason + # to fail the whole batch. + cancelled = False + for jid in job_ids: + if _cancel_job_by_id(jid): + cancelled = True + + return web.json_response({"cancelled": cancelled}) + @routes.get("/history") async def get_history(request): max_items = request.rel_url.query.get("max_items", None) diff --git a/tests-unit/jobs_cancel_test/__init__.py b/tests-unit/jobs_cancel_test/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests-unit/jobs_cancel_test/jobs_cancel_test.py b/tests-unit/jobs_cancel_test/jobs_cancel_test.py new file mode 100644 index 000000000..f1d591b0d --- /dev/null +++ b/tests-unit/jobs_cancel_test/jobs_cancel_test.py @@ -0,0 +1,453 @@ +"""Tests for the jobs-namespace cancel endpoints. + +Covers both layers: + +* the pure cancel helpers in ``comfy_execution.jobs`` + (``classify_job_for_cancel`` / ``cancel_job``), which hold the business + logic of mapping a cancel onto interrupt-vs-dequeue, and + +* the HTTP contract of ``POST /api/jobs/{job_id}/cancel`` and + ``POST /api/jobs/cancel`` (status codes, single-cancel idempotency, and + best-effort batch cancellation that treats unknown/finished ids as no-ops + while still rejecting malformed ids with 400). + +The HTTP layer is exercised against a small aiohttp app whose handlers are a +faithful copy of the wiring in ``server.py`` driven by a fake queue that +mirrors ``execution.PromptQueue`` (``get_current_queue`` / ``get_history`` / +``delete_queue_item``). This keeps the test free of the heavy ComfyUI runtime +(torch, nodes, ...) while still testing the real cancel logic. +""" + +import json + +import pytest +from aiohttp import web + +from comfy_execution.jobs import ( + CANCEL_PENDING, + CANCEL_RUNNING, + CANCEL_TERMINAL, + CANCEL_UNKNOWN, + cancel_job, + classify_job_for_cancel, + validate_job_id, +) + +# Classifications for which a cancel was actually dispatched (vs a no-op). +_CANCELLED = (CANCEL_RUNNING, CANCEL_PENDING) + +# Canonical UUID ids for HTTP-layer tests (the batch endpoint validates UUID format). +_UUID_A = "aaaaaaaa-aaaa-4aaa-aaaa-aaaaaaaaaaaa" +_UUID_B = "bbbbbbbb-bbbb-4bbb-bbbb-bbbbbbbbbbbb" +_UUID_C = "cccccccc-cccc-4ccc-cccc-cccccccccccc" +_UUID_D = "dddddddd-dddd-4ddd-dddd-dddddddddddd" +_UUID_MISSING = "ffffffff-ffff-4fff-ffff-ffffffffffff" + + +def make_queue_item(prompt_id, number=0): + """Build a queue tuple shaped like the real ones: index 1 is the id.""" + return (number, prompt_id, {}, {}, []) + + +class FakePromptQueue: + """Minimal stand-in for execution.PromptQueue for the cancel paths. + + Tracks interrupts and dequeues so tests can assert side effects. + """ + + def __init__(self, running=None, pending=None, history=None): + self._running = list(running or []) + self._pending = list(pending or []) + self._history = dict(history or {}) + self.interrupt_count = 0 + + def get_current_queue(self): + return (list(self._running), list(self._pending)) + + def get_history(self, prompt_id=None): + if prompt_id is None: + return dict(self._history) + if prompt_id in self._history: + return {prompt_id: self._history[prompt_id]} + return {} + + def delete_queue_item(self, function): + for i, item in enumerate(self._pending): + if function(item): + self._pending.pop(i) + return True + return False + + def interrupt_if_running(self, prompt_id): + # Mirrors execution.PromptQueue.interrupt_if_running: only signals an + # interrupt when the id is actually in the running set. + if any(item[1] == prompt_id for item in self._running): + self.interrupt_count += 1 + return True + return False + + +def build_app(queue): + """Build an aiohttp app exposing the cancel routes against ``queue``. + + Handler bodies mirror server.py exactly. + """ + + def _cancel_job_by_id(job_id): + running, pending = queue.get_current_queue() + history = queue.get_history() + + def interrupt(prompt_id): + return queue.interrupt_if_running(prompt_id) + + def dequeue(prompt_id): + return queue.delete_queue_item(lambda a: a[1] == prompt_id) + + classification = cancel_job( + job_id, running, pending, history, interrupt, dequeue + ) + return classification in _CANCELLED + + async def cancel_job_by_id(request): + job_id = request.match_info.get("job_id", None) + if not job_id: + return web.json_response({"error": "job_id is required"}, status=400) + cancelled = _cancel_job_by_id(job_id) + return web.json_response({"cancelled": cancelled}) + + async def cancel_jobs_batch(request): + try: + json_data = await request.json() + except json.JSONDecodeError: + return web.json_response( + {"error": "Request body must be valid JSON"}, status=400 + ) + + job_ids = json_data.get("job_ids") if isinstance(json_data, dict) else None + if not isinstance(job_ids, list): + return web.json_response({"error": "job_ids must be a list"}, status=400) + + invalid_ids = [] + for jid in job_ids: + try: + validate_job_id(jid) + except (ValueError, AttributeError): + invalid_ids.append(jid if isinstance(jid, str) else repr(jid)) + if invalid_ids: + return web.json_response( + {"error": "job_ids contains invalid id(s)", "invalid_ids": invalid_ids}, + status=400, + ) + + cancelled = False + for jid in job_ids: + if _cancel_job_by_id(jid): + cancelled = True + return web.json_response({"cancelled": cancelled}) + + app = web.Application() + app.router.add_post("/api/jobs/{job_id}/cancel", cancel_job_by_id) + app.router.add_post("/api/jobs/cancel", cancel_jobs_batch) + return app + + +# --------------------------------------------------------------------------- +# Pure helper tests: classification + cancel side effects +# --------------------------------------------------------------------------- + + +class TestClassifyJobForCancel: + def test_running(self): + running = [make_queue_item("a")] + assert classify_job_for_cancel("a", running, [], {}) == CANCEL_RUNNING + + def test_pending(self): + pending = [make_queue_item("b")] + assert classify_job_for_cancel("b", [], pending, {}) == CANCEL_PENDING + + def test_terminal(self): + history = {"c": {"prompt": make_queue_item("c"), "outputs": {}, "status": {}}} + assert classify_job_for_cancel("c", [], [], history) == CANCEL_TERMINAL + + def test_unknown(self): + assert classify_job_for_cancel("z", [], [], {}) == CANCEL_UNKNOWN + + +class TestCancelJobHelper: + """``interrupt`` and ``dequeue`` both take the id and return whether they + actually acted, so cancel_job's return reflects the real outcome.""" + + def test_running_is_interrupted_not_dequeued(self): + interrupts = [] + dequeues = [] + result = cancel_job( + "a", [make_queue_item("a")], [], {}, + interrupt=lambda pid: interrupts.append(pid) or True, + dequeue=lambda pid: dequeues.append(pid) or True, + ) + assert result == CANCEL_RUNNING + assert interrupts == ["a"] + assert dequeues == [] + + def test_pending_is_dequeued_not_interrupted(self): + interrupts = [] + dequeues = [] + result = cancel_job( + "b", [], [make_queue_item("b")], {}, + interrupt=lambda pid: interrupts.append(pid) or True, + dequeue=lambda pid: dequeues.append(pid) or True, + ) + assert result == CANCEL_PENDING + assert dequeues == ["b"] + assert interrupts == [] + + def test_terminal_is_noop(self): + history = {"c": {"prompt": make_queue_item("c"), "outputs": {}, "status": {}}} + interrupts = [] + dequeues = [] + result = cancel_job( + "c", [], [], history, + interrupt=lambda pid: interrupts.append(pid) or True, + dequeue=lambda pid: dequeues.append(pid) or True, + ) + assert result == CANCEL_TERMINAL + assert interrupts == [] + assert dequeues == [] + + def test_unknown_is_noop(self): + interrupts = [] + dequeues = [] + result = cancel_job( + "z", [], [], {}, + interrupt=lambda pid: interrupts.append(pid) or True, + dequeue=lambda pid: dequeues.append(pid) or True, + ) + assert result == CANCEL_UNKNOWN + assert interrupts == [] + assert dequeues == [] + + def test_running_but_finished_before_interrupt_returns_unknown(self): + """Classified RUNNING from a stale snapshot, but the job finished before + the atomic interrupt fired (interrupt returns False). cancel_job reports + UNKNOWN rather than claiming a cancel that did not happen — and the + atomic interrupt guarantees no unrelated job was hit.""" + interrupts = [] + result = cancel_job( + "a", [make_queue_item("a")], [], {}, + interrupt=lambda pid: interrupts.append(pid) or False, + dequeue=lambda pid: True, + ) + assert result == CANCEL_UNKNOWN + assert interrupts == ["a"] # interrupt was attempted atomically + + def test_pending_started_running_is_interrupted(self): + """Pending->running race: the job leaves the queue (dequeue False) + because it started executing. The atomic interrupt catches the now- + running job, so cancel_job interrupts it and reports CANCEL_RUNNING.""" + interrupts = [] + dequeues = [] + result = cancel_job( + "b", [], [make_queue_item("b")], {}, + interrupt=lambda pid: interrupts.append(pid) or True, + dequeue=lambda pid: (dequeues.append(pid), False)[1], + ) + assert result == CANCEL_RUNNING + assert dequeues == ["b"] # dequeue attempted first + assert interrupts == ["b"] # then the now-running job was interrupted + + def test_pending_dequeue_miss_not_running_returns_unknown(self): + """Dequeue miss where the job is not running anymore (it finished): the + atomic interrupt finds nothing to interrupt and returns False, so + cancel_job is a no-op reporting UNKNOWN — never reporting a cancel that + did not happen, and never interrupting a bystander.""" + interrupts = [] + dequeues = [] + result = cancel_job( + "b", [], [make_queue_item("b")], {}, + interrupt=lambda pid: interrupts.append(pid) or False, + dequeue=lambda pid: (dequeues.append(pid), False)[1], + ) + assert result == CANCEL_UNKNOWN + assert dequeues == ["b"] + assert interrupts == ["b"] # interrupt attempted, found nothing running + + +# --------------------------------------------------------------------------- +# HTTP contract tests: POST /api/jobs/{job_id}/cancel +# --------------------------------------------------------------------------- + + +class TestSingleCancelEndpoint: + @pytest.mark.asyncio + async def test_cancel_running_job_interrupts(self, aiohttp_client): + queue = FakePromptQueue(running=[make_queue_item("a")]) + client = await aiohttp_client(build_app(queue)) + + resp = await client.post("/api/jobs/a/cancel") + + assert resp.status == 200 + assert (await resp.json()) == {"cancelled": True} + assert queue.interrupt_count == 1 + + @pytest.mark.asyncio + async def test_cancel_pending_job_dequeues(self, aiohttp_client): + queue = FakePromptQueue(pending=[make_queue_item("b")]) + client = await aiohttp_client(build_app(queue)) + + resp = await client.post("/api/jobs/b/cancel") + + assert resp.status == 200 + assert (await resp.json()) == {"cancelled": True} + # Pending job removed from the queue; nothing interrupted. + assert queue.get_current_queue()[1] == [] + assert queue.interrupt_count == 0 + + @pytest.mark.asyncio + async def test_cancel_terminal_job_is_idempotent_noop(self, aiohttp_client): + history = {"c": {"prompt": make_queue_item("c"), "outputs": {}, "status": {}}} + queue = FakePromptQueue(history=history) + client = await aiohttp_client(build_app(queue)) + + resp = await client.post("/api/jobs/c/cancel") + + # Already-finished job: 200 no-op (cancelled=false), not an error. + assert resp.status == 200 + assert (await resp.json()) == {"cancelled": False} + assert queue.interrupt_count == 0 + + @pytest.mark.asyncio + async def test_cancel_unknown_id_is_200_noop(self, aiohttp_client): + queue = FakePromptQueue() + client = await aiohttp_client(build_app(queue)) + + resp = await client.post("/api/jobs/does-not-exist/cancel") + + # Single-cancel of an unknown id is treated as an idempotent no-op. + assert resp.status == 200 + assert (await resp.json()) == {"cancelled": False} + assert queue.interrupt_count == 0 + + @pytest.mark.asyncio + async def test_cancel_pending_that_started_running_interrupts(self, aiohttp_client): + """Pending->running race end to end: the job is pending at snapshot time + but starts executing by the time we dequeue (delete misses). The live + re-check sees it running and interrupts it, so the cancel is not dropped + and the caller still gets cancelled=True.""" + + class RacingQueue(FakePromptQueue): + def delete_queue_item(self, function): + # The worker picked the job up just before we removed it: it + # leaves the pending queue (delete misses) and is now running. + self._running = list(self._pending) + self._pending = [] + return False + + queue = RacingQueue(pending=[make_queue_item("b")]) + client = await aiohttp_client(build_app(queue)) + + resp = await client.post("/api/jobs/b/cancel") + + assert resp.status == 200 + assert (await resp.json()) == {"cancelled": True} + assert queue.interrupt_count == 1 + + +# --------------------------------------------------------------------------- +# HTTP contract tests: POST /api/jobs/cancel (batch) +# --------------------------------------------------------------------------- + + +class TestBatchCancelEndpoint: + @pytest.mark.asyncio + async def test_batch_happy_path(self, aiohttp_client): + queue = FakePromptQueue( + running=[make_queue_item(_UUID_A)], + pending=[make_queue_item(_UUID_B, number=1)], + ) + client = await aiohttp_client(build_app(queue)) + + resp = await client.post("/api/jobs/cancel", json={"job_ids": [_UUID_A, _UUID_B]}) + + assert resp.status == 200 + assert (await resp.json()) == {"cancelled": True} + assert queue.interrupt_count == 1 # running job interrupted + assert queue.get_current_queue()[1] == [] # pending job dequeued + + @pytest.mark.asyncio + async def test_batch_best_effort_skips_unknown_id(self, aiohttp_client): + """An unknown id in the batch is a no-op, not a reason to abort: the + running and pending jobs are still cancelled (200, cancelled=true). This + is the "cancel all as a job finishes" case from review.""" + queue = FakePromptQueue( + running=[make_queue_item(_UUID_A)], + pending=[make_queue_item(_UUID_B, number=1)], + ) + client = await aiohttp_client(build_app(queue)) + + resp = await client.post( + "/api/jobs/cancel", json={"job_ids": [_UUID_A, _UUID_MISSING, _UUID_B]} + ) + + assert resp.status == 200 + assert (await resp.json()) == {"cancelled": True} + assert queue.interrupt_count == 1 # running job interrupted + assert queue.get_current_queue()[1] == [] # pending job dequeued + + @pytest.mark.asyncio + async def test_batch_all_terminal_is_idempotent_noop(self, aiohttp_client): + history = { + _UUID_C: {"prompt": make_queue_item(_UUID_C), "outputs": {}, "status": {}}, + _UUID_D: {"prompt": make_queue_item(_UUID_D), "outputs": {}, "status": {}}, + } + queue = FakePromptQueue(history=history) + client = await aiohttp_client(build_app(queue)) + + resp = await client.post("/api/jobs/cancel", json={"job_ids": [_UUID_C, _UUID_D]}) + + # All known but terminal: 200 with cancelled=false, nothing dispatched. + assert resp.status == 200 + assert (await resp.json()) == {"cancelled": False} + assert queue.interrupt_count == 0 + + @pytest.mark.asyncio + async def test_batch_missing_job_ids_is_400(self, aiohttp_client): + queue = FakePromptQueue() + client = await aiohttp_client(build_app(queue)) + + resp = await client.post("/api/jobs/cancel", json={}) + + assert resp.status == 400 + + @pytest.mark.asyncio + async def test_batch_unhashable_element_is_400_not_500(self, aiohttp_client): + """An unhashable element such as a dict or list must yield 400, not 500. + + Previously, passing e.g. {"job_ids": [{}]} would reach the classify + loop where ``prompt_id in history`` raises TypeError on an unhashable + type, resulting in an unhandled 500. The input-validation guard must + catch this before any queue or history access. + """ + queue = FakePromptQueue() + client = await aiohttp_client(build_app(queue)) + + resp = await client.post("/api/jobs/cancel", json={"job_ids": [{}]}) + + assert resp.status == 400 + body = await resp.json() + assert "invalid_ids" in body + # No queue side effects. + assert queue.interrupt_count == 0 + + @pytest.mark.asyncio + async def test_batch_non_uuid_string_element_is_400(self, aiohttp_client): + """A string that is not a valid UUID must be rejected with 400.""" + queue = FakePromptQueue() + client = await aiohttp_client(build_app(queue)) + + resp = await client.post( + "/api/jobs/cancel", json={"job_ids": ["not-a-uuid"]} + ) + + assert resp.status == 400 + body = await resp.json() + assert "invalid_ids" in body From cd77c551d6c7efa46a8ba514fd6f4e04aac76b4d Mon Sep 17 00:00:00 2001 From: Barish Ozbay <17261091+drozbay@users.noreply.github.com> Date: Fri, 19 Jun 2026 19:47:31 -0400 Subject: [PATCH 4/5] feat: Context Windows sampling with LTX2 models and IC-LoRa guides (CORE-3) (#13325) --- comfy/context_windows.py | 466 ++++++++++++++++++++++---- comfy/ldm/lightricks/model.py | 4 +- comfy/model_base.py | 122 +++++++ comfy_extras/nodes_context_windows.py | 77 ++++- 4 files changed, 592 insertions(+), 77 deletions(-) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index db57537a2..5f9899c67 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -8,6 +8,8 @@ from abc import ABC, abstractmethod import logging import comfy.model_management import comfy.patcher_extension +import comfy.utils +import comfy.conds if TYPE_CHECKING: from comfy.model_base import BaseModel from comfy.model_patcher import ModelPatcher @@ -51,12 +53,18 @@ class ContextHandlerABC(ABC): class IndexListContextWindow(ContextWindowABC): - def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0): + def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0, modality_windows: dict=None, context_overlap: int=0): self.index_list = index_list self.context_length = len(index_list) + self.context_overlap = context_overlap self.dim = dim self.total_frames = total_frames self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames) + self.modality_windows = modality_windows # dict of {mod_idx: IndexListContextWindow} + self.guide_frames_indices: list[int] = [] + self.guide_overlap_info: list[tuple[int, int]] = [] + self.guide_kf_local_positions: list[int] = [] + self.guide_downscale_factors: list[int] = [] def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor: if dim is None: @@ -85,6 +93,11 @@ class IndexListContextWindow(ContextWindowABC): region_idx = int(self.center_ratio * num_regions) return min(max(region_idx, 0), num_regions - 1) + def get_window_for_modality(self, modality_idx: int) -> 'IndexListContextWindow': + if modality_idx == 0: + return self + return self.modality_windows[modality_idx] + class IndexListCallbacks: EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows" @@ -148,6 +161,172 @@ def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, d return cond_value._copy_with(sliced) +def compute_guide_overlap(guide_entries: list[dict], keyframe_idxs: torch.Tensor, temporal_downscale_ratio: int, window_index_list: list[int]): + """Compute which concatenated guide frames overlap with a context window. + + Each guide's latent-space start is derived from its first token's pixel-t-start + in keyframe_idxs (shape (B, [t,h,w], num_tokens, [start, end])), divided by the + model's temporal_downscale_ratio. + + Args: + guide_entries: list of guide_attention_entry dicts + keyframe_idxs: per-token pixel coords cond tensor for the modality + temporal_downscale_ratio: model's pixel-to-latent temporal compression ratio + window_index_list: the window's frame indices into the video portion + + Returns: + suffix_indices: indices into the guide_frames tensor for frame selection + overlap_info: list of (entry_idx, overlap_count) for guide_attention_entries adjustment + kf_local_positions: window-local frame positions for keyframe_idxs regeneration + total_overlap: total number of overlapping guide frames + """ + window_set = set(window_index_list) + window_list = list(window_index_list) + suffix_indices = [] + overlap_info = [] + kf_local_positions = [] + suffix_base = 0 + token_offset = 0 + + for entry_idx, entry in enumerate(guide_entries): + first_t_pixel = int(keyframe_idxs[0, 0, token_offset, 0].item()) + latent_start = (first_t_pixel + temporal_downscale_ratio - 1) // temporal_downscale_ratio + guide_len = entry["latent_shape"][0] + entry_overlap = 0 + + for local_offset in range(guide_len): + video_pos = latent_start + local_offset + if video_pos in window_set: + suffix_indices.append(suffix_base + local_offset) + kf_local_positions.append(window_list.index(video_pos)) + entry_overlap += 1 + + if entry_overlap > 0: + overlap_info.append((entry_idx, entry_overlap)) + suffix_base += guide_len + token_offset += entry["pre_filter_count"] + + return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices) + + +@dataclass +class WindowingState: + """Per-modality context windowing state for each step, + built using IndexListContextHandler._build_window_state(). + For non-multimodal models the lists are length 1 + """ + latents: list[torch.Tensor] # per-modality working latents (guide frames stripped) + guide_latents: list[torch.Tensor | None] # per-modality guide frames stripped from latents + guide_entries: list[list[dict] | None] # per-modality guide_attention_entry metadata + keyframe_idxs: list[torch.Tensor | None] # per-modality keyframe_idxs tensor for guide latent_start derivation + latent_shapes: list | None # original packed shapes for unpack/pack (None if not multimodal) + dim: int = 0 # primary modality temporal dim for context windowing + is_multimodal: bool = False + temporal_downscale_ratio: int = 1 # model's pixel-to-latent temporal compression ratio + + def prepare_window(self, window: IndexListContextWindow, model) -> IndexListContextWindow: + """Reformat window for multimodal contexts by deriving per-modality index lists. + Non-multimodal contexts return the input window unchanged. + """ + if not self.is_multimodal: + return window + + x = self.latents[0] + primary_total = self.latent_shapes[0][self.dim] + primary_overlap = window.context_overlap + map_shapes = self.latent_shapes + if x.size(self.dim) != primary_total: + map_shapes = list(self.latent_shapes) + video_shape = list(self.latent_shapes[0]) + video_shape[self.dim] = x.size(self.dim) + map_shapes[0] = torch.Size(video_shape) + try: + per_modality_indices = model.map_context_window_to_modalities( + window.index_list, map_shapes, self.dim) + except AttributeError: + raise NotImplementedError( + f"{type(model).__name__} must implement map_context_window_to_modalities for multimodal context windows.") + modality_windows = {} + for mod_idx in range(1, len(self.latents)): + modality_total_frames = self.latents[mod_idx].shape[self.dim] + ratio = modality_total_frames / primary_total if primary_total > 0 else 1 + modality_overlap = max(round(primary_overlap * ratio), 0) + modality_windows[mod_idx] = IndexListContextWindow( + per_modality_indices[mod_idx], dim=self.dim, + total_frames=modality_total_frames, + context_overlap=modality_overlap) + return IndexListContextWindow( + window.index_list, dim=self.dim, total_frames=x.shape[self.dim], + modality_windows=modality_windows, context_overlap=primary_overlap) + + def slice_for_window(self, window: IndexListContextWindow, retain_index_list: list[int], device=None) -> tuple[list[torch.Tensor], list[int]]: + """Slice latents for a context window, injecting guide frames where applicable. + For multimodal contexts, uses the modality-specific windows derived in prepare_window(). + """ + sliced = [] + guide_frame_counts = [] + for idx in range(len(self.latents)): + modality_window = window.get_window_for_modality(idx) + retain = retain_index_list if idx == 0 else [] + s = modality_window.get_tensor(self.latents[idx], device, retain_index_list=retain) + if self.guide_entries[idx] is not None: + s, ng = self._inject_guide_frames(s, modality_window, modality_idx=idx) + else: + ng = 0 + sliced.append(s) + guide_frame_counts.append(ng) + return sliced, guide_frame_counts + + def strip_guide_frames(self, out_per_modality: list[list[torch.Tensor]], guide_frame_counts: list[int], window: IndexListContextWindow): + """Strip injected guide frames from per-cond, per-modality outputs in place.""" + for idx in range(len(self.latents)): + if guide_frame_counts[idx] > 0: + window_len = len(window.get_window_for_modality(idx).index_list) + for ci in range(len(out_per_modality)): + out_per_modality[ci][idx] = out_per_modality[ci][idx].narrow(self.dim, 0, window_len) + + def _inject_guide_frames(self, latent_slice: torch.Tensor, window: IndexListContextWindow, modality_idx: int = 0) -> tuple[torch.Tensor, int]: + guide_entries = self.guide_entries[modality_idx] + guide_frames = self.guide_latents[modality_idx] + keyframe_idxs = self.keyframe_idxs[modality_idx] + suffix_idx, overlap_info, kf_local_pos, guide_frame_count = compute_guide_overlap( + guide_entries, keyframe_idxs, self.temporal_downscale_ratio, window.index_list) + # Shift keyframe positions to account for causal_window_fix anchor occupying sub-pos 0. + anchor_idx = getattr(window, 'causal_anchor_index', None) + if anchor_idx is not None and anchor_idx >= 0: + kf_local_pos = [p + 1 for p in kf_local_pos] + window.guide_frames_indices = suffix_idx + window.guide_overlap_info = overlap_info + window.guide_kf_local_positions = kf_local_pos + + # Derive per-overlap-entry latent_downscale_factor from guide entry latent_shape vs guide frame spatial dims. + # guide_frames has full (post-dilation) spatial dims; entry["latent_shape"] has pre-dilation dims. + guide_downscale_factors = [] + if guide_frame_count > 0: + full_H = guide_frames.shape[3] + for entry_idx, _ in overlap_info: + entry_H = guide_entries[entry_idx]["latent_shape"][1] + guide_downscale_factors.append(full_H // entry_H) + window.guide_downscale_factors = guide_downscale_factors + + if guide_frame_count > 0: + idx = tuple([slice(None)] * self.dim + [suffix_idx]) + return torch.cat([latent_slice, guide_frames[idx]], dim=self.dim), guide_frame_count + return latent_slice, 0 + + def patch_latent_shapes(self, sub_conds, new_shapes): + if not self.is_multimodal: + return + + for cond_list in sub_conds: + if cond_list is None: + continue + for cond_dict in cond_list: + model_conds = cond_dict.get('model_conds', {}) + if 'latent_shapes' in model_conds: + model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes) + + @dataclass class ContextSchedule: name: str @@ -162,7 +341,7 @@ ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_co class IndexListContextHandler(ContextHandlerABC): def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, - causal_window_fix: bool=True): + latent_retain_index_list: list[int]=[], causal_window_fix: bool=True): self.context_schedule = context_schedule self.fuse_method = fuse_method self.context_length = context_length @@ -174,17 +353,118 @@ class IndexListContextHandler(ContextHandlerABC): self.freenoise = freenoise self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else [] self.split_conds_to_windows = split_conds_to_windows + self.latent_retain_index_list = [int(x.strip()) for x in latent_retain_index_list.split(",")] if latent_retain_index_list else [] self.causal_window_fix = causal_window_fix self.callbacks = {} + @staticmethod + def _get_latent_shapes(conds): + for cond_list in conds: + if cond_list is None: + continue + for cond_dict in cond_list: + model_conds = cond_dict.get('model_conds', {}) + if 'latent_shapes' in model_conds: + return model_conds['latent_shapes'].cond + return None + + @staticmethod + def _get_guide_entries(conds): + for cond_list in conds: + if cond_list is None: + continue + for cond_dict in cond_list: + model_conds = cond_dict.get('model_conds', {}) + entries = model_conds.get('guide_attention_entries') + if entries is not None and hasattr(entries, 'cond') and entries.cond: + return entries.cond + return None + + @staticmethod + def _get_keyframe_idxs(conds): + for cond_list in conds: + if cond_list is None: + continue + for cond_dict in cond_list: + model_conds = cond_dict.get('model_conds', {}) + kf = model_conds.get('keyframe_idxs') + if kf is not None and hasattr(kf, 'cond') and kf.cond is not None: + return kf.cond + return None + + def _apply_freenoise(self, noise: torch.Tensor, conds: list[list[dict]], seed: int) -> torch.Tensor: + """Apply FreeNoise shuffling, scaling context length/overlap per-modality by frame ratio. + If guide frames are present on the primary modality, only the video portion is shuffled. + """ + guide_entries = self._get_guide_entries(conds) + guide_count = sum(e["latent_shape"][0] for e in guide_entries) if guide_entries else 0 + + latent_shapes = self._get_latent_shapes(conds) + if latent_shapes is not None and len(latent_shapes) > 1: + modalities = comfy.utils.unpack_latents(noise, latent_shapes) + primary_total = latent_shapes[0][self.dim] + primary_video_count = modalities[0].size(self.dim) - guide_count + apply_freenoise(modalities[0].narrow(self.dim, 0, primary_video_count), self.dim, self.context_length, self.context_overlap, seed) + for i in range(1, len(modalities)): + mod_total = latent_shapes[i][self.dim] + ratio = mod_total / primary_total if primary_total > 0 else 1 + mod_ctx_len = max(round(self.context_length * ratio), 1) + mod_ctx_overlap = max(round(self.context_overlap * ratio), 0) + modalities[i] = apply_freenoise(modalities[i], self.dim, mod_ctx_len, mod_ctx_overlap, seed) + noise, _ = comfy.utils.pack_latents(modalities) + return noise + video_count = noise.size(self.dim) - guide_count + apply_freenoise(noise.narrow(self.dim, 0, video_count), self.dim, self.context_length, self.context_overlap, seed) + return noise + + def _build_window_state(self, x_in: torch.Tensor, conds: list[list[dict]], model: BaseModel) -> WindowingState: + """Build windowing state for the current step, including unpacking latents and extracting guide frame info from conds.""" + latent_shapes = self._get_latent_shapes(conds) + is_multimodal = latent_shapes is not None and len(latent_shapes) > 1 + unpacked_latents = comfy.utils.unpack_latents(x_in, latent_shapes) if is_multimodal else [x_in] + + unpacked_latents_list = list(unpacked_latents) + guide_latents_list = [None] * len(unpacked_latents) + guide_entries_list = [None] * len(unpacked_latents) + keyframe_idxs_list = [None] * len(unpacked_latents) + + extracted_guide_entries = self._get_guide_entries(conds) + extracted_keyframe_idxs = self._get_keyframe_idxs(conds) + + # Strip guide frames (only from first modality for now) + if extracted_guide_entries is not None: + guide_count = sum(e["latent_shape"][0] for e in extracted_guide_entries) + if guide_count > 0: + x = unpacked_latents[0] + latent_count = x.size(self.dim) - guide_count + unpacked_latents_list[0] = x.narrow(self.dim, 0, latent_count) + guide_latents_list[0] = x.narrow(self.dim, latent_count, guide_count) + guide_entries_list[0] = extracted_guide_entries + keyframe_idxs_list[0] = extracted_keyframe_idxs + + + return WindowingState( + latents=unpacked_latents_list, + guide_latents=guide_latents_list, + guide_entries=guide_entries_list, + keyframe_idxs=keyframe_idxs_list, + latent_shapes=latent_shapes, + dim=self.dim, + is_multimodal=is_multimodal, + temporal_downscale_ratio=model.latent_format.temporal_downscale_ratio) + def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool: - # for now, assume first dim is batch - should have stored on BaseModel in actual implementation - if x_in.size(self.dim) > self.context_length: - logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.") + window_state = self._build_window_state(x_in, conds, model) # build window_state to check frame counts, will be built again in execute + total_frame_count = window_state.latents[0].size(self.dim) + if total_frame_count > self.context_length: + logging.info(f"\nUsing context windows: Context length {self.context_length} with overlap {self.context_overlap} for {total_frame_count} frames.") if self.cond_retain_index_list: logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}") + if self.latent_retain_index_list: + logging.info(f"Retaining original latent for indexes: {self.latent_retain_index_list}") return True + logging.info(f"\nNot using context windows since context length ({self.context_length}) exceeds input frames ({total_frame_count}).") return False def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase: @@ -275,7 +555,9 @@ class IndexListContextHandler(ContextHandlerABC): return resized_cond def set_step(self, timestep: torch.Tensor, model_options: dict[str]): - mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001) + sample_sigmas = model_options["transformer_options"]["sample_sigmas"] + current_timestep = timestep[0].to(sample_sigmas.dtype) + mask = torch.isclose(sample_sigmas, current_timestep, rtol=0.0001) matches = torch.nonzero(mask) if torch.numel(matches) == 0: return # substep from multi-step sampler: keep self._step from the last full step @@ -284,54 +566,98 @@ class IndexListContextHandler(ContextHandlerABC): def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]: full_length = x_in.size(self.dim) # TODO: choose dim based on model context_windows = self.context_schedule.func(full_length, self, model_options) - context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows] + context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length, context_overlap=self.context_overlap) for window in context_windows] return context_windows def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): self._model = model self.set_step(timestep, model_options) - context_windows = self.get_context_windows(model, x_in, model_options) - enumerated_context_windows = list(enumerate(context_windows)) - conds_final = [torch.zeros_like(x_in) for _ in conds] + window_state = self._build_window_state(x_in, conds, model) + num_modalities = len(window_state.latents) + + context_windows = self.get_context_windows(model, window_state.latents[0], model_options) + enumerated_context_windows = list(enumerate(context_windows)) + total_windows = len(enumerated_context_windows) + + # Initialize per-modality accumulators (length 1 for single-modality) + accum = [[torch.zeros_like(m) for _ in conds] for m in window_state.latents] if self.fuse_method.name == ContextFuseMethods.RELATIVE: - counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds] + counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in window_state.latents] else: - counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds] - biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds] + counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in window_state.latents] + biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in window_state.latents] for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks): callback(self, model, x_in, conds, timestep, model_options) + # accumulate results from each context window for enum_window in enumerated_context_windows: - results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options) + results = self.evaluate_context_windows( + calc_cond_batch, model, x_in, conds, timestep, [enum_window], + model_options, window_state=window_state, total_windows=total_windows) for result in results: - self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep, - conds_final, counts_final, biases_final) + # result.sub_conds_out is per-cond, per-modality: list[list[Tensor]] + for mod_idx in range(num_modalities): + mod_out = [result.sub_conds_out[ci][mod_idx] for ci in range(len(conds))] + modality_window = result.window.get_window_for_modality(mod_idx) + self.combine_context_window_results( + window_state.latents[mod_idx], mod_out, result.sub_conds, modality_window, + result.window_idx, total_windows, timestep, + accum[mod_idx], counts[mod_idx], biases[mod_idx]) + + # fuse accumulated results into final conds try: - # finalize conds - if self.fuse_method.name == ContextFuseMethods.RELATIVE: - # relative is already normalized, so return as is - del counts_final - return conds_final - else: - # normalize conds via division by context usage counts - for i in range(len(conds_final)): - conds_final[i] /= counts_final[i] - del counts_final - return conds_final + result_out = [] + for ci in range(len(conds)): + finalized = [] + for mod_idx in range(num_modalities): + if self.fuse_method.name != ContextFuseMethods.RELATIVE: + accum[mod_idx][ci] /= counts[mod_idx][ci] + f = accum[mod_idx][ci] + + # if guide frames were injected, append them to the end of the fused latents for the next step + if window_state.guide_latents[mod_idx] is not None: + f = torch.cat([f, window_state.guide_latents[mod_idx]], dim=self.dim) + finalized.append(f) + + # pack modalities together if needed + if window_state.is_multimodal and len(finalized) > 1: + packed, _ = comfy.utils.pack_latents(finalized) + else: + packed = finalized[0] + + result_out.append(packed) + return result_out finally: for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks): callback(self, model, x_in, conds, timestep, model_options) - def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]], - model_options, device=None, first_device=None): + def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, + timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]], + model_options, window_state: WindowingState, total_windows: int = None, + device=None, first_device=None): + """Evaluate context windows and return per-cond, per-modality outputs in ContextResults.sub_conds_out + + For each window: + 1. Builds windows (for each modality if multimodal) + 2. Slices window for each modality + 3. Injects concatenated latent guide frames where present + 4. Packs together if needed and calls model + 5. Unpacks and strips any guides from outputs + """ + x = window_state.latents[0] + results: list[ContextResults] = [] for window_idx, window in enumerated_context_windows: # allow processing to end between context window executions for faster Cancel comfy.model_management.throw_exception_if_processing_interrupted() - # causal_window_fix: prepend a pre-window frame that will be stripped post-forward + # prepare the window accounting for multimodal windows + window = window_state.prepare_window(window, model) + + # causal_window_fix: prepend a pre-window frame that will be stripped post-forward. + # Set anchor before slice_for_window so the latent slice and downstream cond slices both pick it up. anchor_applied = False if self.causal_window_fix: anchor_idx = window.index_list[0] - 1 @@ -339,27 +665,46 @@ class IndexListContextHandler(ContextHandlerABC): window.causal_anchor_index = anchor_idx anchor_applied = True + # slice the window for each modality, injecting guide frames where applicable + sliced, guide_frame_counts_per_modality = window_state.slice_for_window(window, self.latent_retain_index_list, device) + for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks): callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device) - # update exposed params + logging.info(f"Context window {window_idx + 1}/{total_windows or len(enumerated_context_windows)}: frames {window.index_list[0]}-{window.index_list[-1]} of {x.shape[self.dim]}" + + (f" (+{guide_frame_counts_per_modality[0]} guide frames)" if guide_frame_counts_per_modality[0] > 0 else "") + ) + + # if multimodal, pack modalities together + if window_state.is_multimodal and len(sliced) > 1: + sub_x, sub_shapes = comfy.utils.pack_latents(sliced) + else: + sub_x, sub_shapes = sliced[0], [sliced[0].shape] + + # get resized conds for window model_options["transformer_options"]["context_window"] = window - # get subsections of x, timestep, conds - sub_x = window.get_tensor(x_in, device) - sub_timestep = window.get_tensor(timestep, device, dim=0) - sub_conds = [self.get_resized_cond(cond, x_in, window, device) for cond in conds] + sub_timestep = window.get_tensor(timestep, dim=0) + sub_conds = [self.get_resized_cond(cond, x, window) for cond in conds] + # if multimodal, patch latent_shapes in conds for correct unpacking in model + window_state.patch_latent_shapes(sub_conds, sub_shapes) + + # call model on window sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options) - if device is not None: - for i in range(len(sub_conds_out)): - sub_conds_out[i] = sub_conds_out[i].to(x_in.device) - # strip causal_window_fix anchor if applied + # unpack outputs + out_per_modality = [comfy.utils.unpack_latents(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))] + + # strip causal_window_fix anchor from primary modality before guide strip so window_len math stays correct if anchor_applied: - for i in range(len(sub_conds_out)): - sub_conds_out[i] = sub_conds_out[i].narrow(self.dim, 1, sub_conds_out[i].shape[self.dim] - 1) + for ci in range(len(out_per_modality)): + t = out_per_modality[ci][0] + out_per_modality[ci][0] = t.narrow(self.dim, 1, t.shape[self.dim] - 1) - results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window)) + # strip injected guide frames + window_state.strip_guide_frames(out_per_modality, guide_frame_counts_per_modality, window) + + results.append(ContextResults(window_idx, out_per_modality, sub_conds, window)) return results @@ -383,7 +728,7 @@ class IndexListContextHandler(ContextHandlerABC): biases_final[i][idx] = bias_total + bias else: # add conds and counts based on weights of fuse method - weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep) + weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep, context_overlap=window.context_overlap) weights_tensor = match_weights_to_dim(weights, x_in, self.dim, device=x_in.device) for i in range(len(sub_conds_out)): window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor) @@ -393,16 +738,22 @@ class IndexListContextHandler(ContextHandlerABC): callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final) -def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args, **kwargs): - # limit noise_shape length to context_length for more accurate vram use estimation +def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, conds, *args, **kwargs): + # Scale noise_shape to a single context window so VRAM estimation budgets per-window. model_options = kwargs.get("model_options", None) if model_options is None: raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.") handler: IndexListContextHandler = model_options.get("context_handler", None) if handler is not None: noise_shape = list(noise_shape) - noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length) - return executor(model, noise_shape, *args, **kwargs) + is_packed = len(noise_shape) == 3 and noise_shape[1] == 1 + if is_packed: + # TODO: latent_shapes cond isn't attached yet at this point, so we can't compute a + # per-window flat latent here. Skipping the clamp over-estimates but prevents immediate OOM. + pass + elif handler.dim < len(noise_shape) and noise_shape[handler.dim] > handler.context_length: + noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length) + return executor(model, noise_shape, conds, *args, **kwargs) def create_prepare_sampling_wrapper(model: ModelPatcher): @@ -422,11 +773,12 @@ def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, nois raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.") if not handler.freenoise: return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs) - noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"]) + + conds = [guider.conds.get('positive', guider.conds.get('negative', []))] + noise = handler._apply_freenoise(noise, conds, extra_args["seed"]) return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs) - def create_sampler_sample_wrapper(model: ModelPatcher): model.add_wrapper_with_key( comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, @@ -434,7 +786,6 @@ def create_sampler_sample_wrapper(model: ModelPatcher): _sampler_sample_wrapper ) - def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor: total_dims = len(x_in.shape) weights_tensor = torch.Tensor(weights).to(device=device) @@ -580,8 +931,9 @@ def get_matching_context_schedule(context_schedule: str) -> ContextSchedule: return ContextSchedule(context_schedule, func) -def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None): - return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs) +def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None, context_overlap: int=None): + context_overlap = handler.context_overlap if context_overlap is None else context_overlap + return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs, context_overlap=context_overlap) def create_weights_flat(length: int, **kwargs) -> list[float]: @@ -599,18 +951,18 @@ def create_weights_pyramid(length: int, **kwargs) -> list[float]: weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1)) return weight_sequence -def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, **kwargs): +def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], context_overlap: int, **kwargs): # based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302 # only expected overlap is given different weights weights_torch = torch.ones((length)) # blend left-side on all except first window if min(idxs) > 0: - ramp_up = torch.linspace(1e-37, 1, handler.context_overlap) - weights_torch[:handler.context_overlap] = ramp_up + ramp_up = torch.linspace(1e-37, 1, context_overlap) + weights_torch[:context_overlap] = ramp_up # blend right-side on all except last window if max(idxs) < full_length-1: - ramp_down = torch.linspace(1, 1e-37, handler.context_overlap) - weights_torch[-handler.context_overlap:] = ramp_down + ramp_down = torch.linspace(1, 1e-37, context_overlap) + weights_torch[-context_overlap:] = ramp_down return weights_torch class ContextFuseMethods: diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index e0a4a0f9b..9953b6679 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -1085,7 +1085,7 @@ class LTXVModel(LTXBaseModel): ) grid_mask = None - if keyframe_idxs is not None: + if keyframe_idxs is not None and keyframe_idxs.shape[2] > 0: additional_args.update({ "orig_patchified_shape": list(x.shape)}) denoise_mask = self.patchifier.patchify(denoise_mask)[0] grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0] @@ -1330,7 +1330,7 @@ class LTXVModel(LTXBaseModel): x = x * (1 + scale) + shift x = self.proj_out(x) - if keyframe_idxs is not None: + if keyframe_idxs is not None and keyframe_idxs.shape[2] > 0: grid_mask = kwargs["grid_mask"] orig_patchified_shape = kwargs["orig_patchified_shape"] full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device) diff --git a/comfy/model_base.py b/comfy/model_base.py index f49da50ae..264dbb9b3 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -21,6 +21,7 @@ import comfy.ldm.hunyuan3dv2_1.hunyuandit import torch import logging import comfy.ldm.lightricks.av_model +import comfy.ldm.lightricks.symmetric_patchifier import comfy.context_windows from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep from comfy.ldm.cascade.stage_c import StageC @@ -1204,6 +1205,127 @@ class LTXAV(BaseModel): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): return latent_image + def map_context_window_to_modalities(self, primary_indices, latent_shapes, dim): + result = [primary_indices] + if len(latent_shapes) < 2: + return result + + video_total = latent_shapes[0][dim] + + for i in range(1, len(latent_shapes)): + mod_total = latent_shapes[i][dim] + # Map each primary index to its proportional range of modality indices and + # concatenate in order. Preserves wrapped/strided geometry so the modality + # attends to the same temporal regions as the primary window. + mod_indices = [] + seen = set() + for v_idx in primary_indices: + a_start = min(int(round(v_idx * mod_total / video_total)), mod_total - 1) + a_end = min(int(round((v_idx + 1) * mod_total / video_total)), mod_total) + if a_end <= a_start: + a_end = a_start + 1 + for a in range(a_start, a_end): + if a not in seen: + seen.add(a) + mod_indices.append(a) + result.append(mod_indices) + + return result + + @staticmethod + def _get_guide_entries(conds): + for cond_list in conds: + if cond_list is None: + continue + for cond_dict in cond_list: + model_conds = cond_dict.get('model_conds', {}) + entries = model_conds.get('guide_attention_entries') + if entries is not None and hasattr(entries, 'cond') and entries.cond: + return entries.cond + return None + + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + # Audio denoise mask — slice using audio modality window + if cond_key == "audio_denoise_mask" and hasattr(window, 'modality_windows') and window.modality_windows: + audio_window = window.modality_windows.get(1) + if audio_window is not None and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): + sliced = audio_window.get_tensor(cond_value.cond, device, dim=2) + return cond_value._copy_with(sliced) + + # Video denoise mask — split into video + guide portions, slice each + if cond_key == "denoise_mask" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): + cond_tensor = cond_value.cond + guide_count = cond_tensor.size(window.dim) - x_in.size(window.dim) + if guide_count > 0: + T_video = x_in.size(window.dim) + video_mask = cond_tensor.narrow(window.dim, 0, T_video) + guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count) + sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list) + suffix_indices = window.guide_frames_indices + if suffix_indices: + idx = tuple([slice(None)] * window.dim + [suffix_indices]) + sliced_guide = guide_mask[idx].to(device) + return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim)) + else: + return cond_value._copy_with(sliced_video) + + # Keyframe indices — regenerate pixel coords for window, select guide positions + if cond_key == "keyframe_idxs": + kf_local_pos = window.guide_kf_local_positions + if not kf_local_pos: + return cond_value._copy_with(cond_value.cond[:, :, :0, :]) # empty + H, W = x_in.shape[3], x_in.shape[4] + window_len = len(window.index_list) + # account for causal_window_fix anchor in coord space size + anchor_idx = getattr(window, 'causal_anchor_index', None) + if anchor_idx is not None and anchor_idx >= 0: + window_len += 1 + patchifier = self.diffusion_model.patchifier + latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device) + scale_factors = self.diffusion_model.vae_scale_factors + pixel_coords = comfy.ldm.lightricks.symmetric_patchifier.latent_to_pixel_coords( + latent_coords, + scale_factors, + causal_fix=self.diffusion_model.causal_temporal_positioning) + tokens = [] + for pos in kf_local_pos: + tokens.extend(range(pos * H * W, (pos + 1) * H * W)) + pixel_coords = pixel_coords[:, :, tokens, :] + + # Adjust spatial end positions for dilated (downscaled) guides. + # Each guide entry may have a different downscale factor; expand the + # per-entry factor to cover all tokens belonging to that entry. + downscale_factors = window.guide_downscale_factors + overlap_info = window.guide_overlap_info + if downscale_factors: + per_token_factor = [] + for (entry_idx, overlap_count), dsf in zip(overlap_info, downscale_factors): + per_token_factor.extend([dsf] * (overlap_count * H * W)) + factor_tensor = torch.tensor(per_token_factor, device=pixel_coords.device, dtype=pixel_coords.dtype) + spatial_end_offset = (factor_tensor.unsqueeze(0).unsqueeze(0).unsqueeze(-1) - 1) * torch.tensor( + scale_factors[1:], device=pixel_coords.device, dtype=pixel_coords.dtype, + ).view(1, -1, 1, 1) + pixel_coords[:, 1:, :, 1:] += spatial_end_offset + + B = cond_value.cond.shape[0] + if B > 1: + pixel_coords = pixel_coords.expand(B, -1, -1, -1) + return cond_value._copy_with(pixel_coords) + + # Guide attention entries — adjust per-guide counts based on window overlap + if cond_key == "guide_attention_entries": + overlap_info = window.guide_overlap_info + H, W = x_in.shape[3], x_in.shape[4] + new_entries = [] + for entry_idx, overlap_count in overlap_info: + e = cond_value.cond[entry_idx] + new_entries.append({**e, + "pre_filter_count": overlap_count * H * W, + "latent_shape": [overlap_count, H, W]}) + return cond_value._copy_with(new_entries) + + return None + class HunyuanVideo(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo) diff --git a/comfy_extras/nodes_context_windows.py b/comfy_extras/nodes_context_windows.py index 098c26f23..15d2dc506 100644 --- a/comfy_extras/nodes_context_windows.py +++ b/comfy_extras/nodes_context_windows.py @@ -13,21 +13,22 @@ class ContextWindowsManualNode(io.ComfyNode): description="Manually set context windows.", inputs=[ io.Model.Input("model", tooltip="The model to apply context windows to during sampling."), - io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window.", advanced=True), - io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window.", advanced=True), + io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window."), + io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window."), io.Combo.Input("context_schedule", options=[ comfy.context_windows.ContextSchedules.STATIC_STANDARD, comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, comfy.context_windows.ContextSchedules.UNIFORM_LOOPED, comfy.context_windows.ContextSchedules.BATCHED, - ], tooltip="The stride of the context window."), - io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True), + ], default=comfy.context_windows.ContextSchedules.STATIC_STANDARD, tooltip="Step-dependent scheduling algorithm for context windows."), + io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."), io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."), io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."), io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), - io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), + io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window. For concat-style I2V models (e.g. Wan I2V, HunyuanVideo I2V, Cosmos I2V, SVD) the encoded start image lives in the c_concat conditioning channels; setting this to '0' will retain that start image content at sub-pos 0 of every window."), io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), + io.String.Input("latent_retain_index_list", default="", tooltip="List of latent indices to retain in the noise latent itself for each window. Use for workflows where reference content (e.g. a start image) lives directly in the noise latent rather than in separate conditioning channels (e.g. inplace-style I2V like LTXV, AnimateDiff). Independent of cond_retain_index_list."), io.Boolean.Input("causal_window_fix", default=True, tooltip="Whether to add a causal fix frame to non-0-indexed context windows."), ], outputs=[ @@ -38,7 +39,7 @@ class ContextWindowsManualNode(io.ComfyNode): @classmethod def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool, - cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, causal_window_fix: bool=True) -> io.Model: + cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, latent_retain_index_list: list[int]=[], causal_window_fix: bool=True) -> io.Model: model = model.clone() model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler( context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule), @@ -51,6 +52,7 @@ class ContextWindowsManualNode(io.ComfyNode): freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows, + latent_retain_index_list=latent_retain_index_list, causal_window_fix=causal_window_fix, ) # make memory usage calculation only take into account the context window latents @@ -65,33 +67,71 @@ class WanContextWindowsManualNode(ContextWindowsManualNode): schema = super().define_schema() schema.node_id = "WanContextWindowsManual" schema.display_name = "WAN Context Windows (Manual)" - schema.description = "Manually set context windows for WAN-like models (dim=2)." + schema.display_name = "Wan Context Windows" + schema.description = "Set context windows for Wan-like models." schema.category="model/patch/wan" schema.inputs = [ io.Model.Input("model", tooltip="The model to apply context windows to during sampling."), - io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window.", advanced=True), - io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window.", advanced=True), + io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window in real frames. Must be 4*n + 1."), + io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window in real frames."), io.Combo.Input("context_schedule", options=[ comfy.context_windows.ContextSchedules.STATIC_STANDARD, comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, comfy.context_windows.ContextSchedules.UNIFORM_LOOPED, comfy.context_windows.ContextSchedules.BATCHED, - ], tooltip="The stride of the context window."), + ], default=comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, tooltip="Step-dependent scheduling algorithm for context windows."), io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True), - io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."), + io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules.", advanced=True), io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), - io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), - #io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), - #io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), + io.Boolean.Input("freenoise", default=True, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending.", advanced=True), + io.Boolean.Input("retain_first_frame", default=False, tooltip="Retain the first I2V frame in every context window (may help retain initial reference)."), + io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index.", advanced=True), ] return schema @classmethod def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool, - cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model: - context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1 - context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0 - return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows) + retain_first_frame: bool=False, split_conds_to_windows: bool=False) -> io.Model: + context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1 + context_overlap = max(context_overlap // 4, 0) # at least overlap 0 + retain_index_list = "0" if retain_first_frame else "" + return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=retain_index_list, split_conds_to_windows=split_conds_to_windows) + + +class LTXVContextWindowsNode(ContextWindowsManualNode): + @classmethod + def define_schema(cls) -> io.Schema: + schema = super().define_schema() + schema.node_id = "LTXVContextWindows" + schema.display_name = "LTXV Context Windows" + schema.description = "Set context windows for LTXV-like models." + schema.inputs = [ + io.Model.Input("model", tooltip="The model to apply context windows to during sampling."), + io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=8, default=145, tooltip="The length of the context window in real frames. Must be 8*n + 1."), + io.Int.Input("context_overlap", min=0, step=8, default=40, tooltip="The overlap of the context window in real frames."), + io.Combo.Input("context_schedule", options=[ + comfy.context_windows.ContextSchedules.STATIC_STANDARD, + comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, + comfy.context_windows.ContextSchedules.UNIFORM_LOOPED, + comfy.context_windows.ContextSchedules.BATCHED, + ], default=comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, tooltip="Step-dependent scheduling algorithm for context windows."), + io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True), + io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules.", advanced=True), + io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), + io.Boolean.Input("freenoise", default=True, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending.", advanced=True), + io.Boolean.Input("retain_first_frame", default=False, tooltip="Retain the first latent frame in every context window (may help retain initial reference)."), + io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index.", advanced=True), + ] + return schema + + @classmethod + def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, fuse_method: str, freenoise: bool, + retain_first_frame: bool=False, split_conds_to_windows: bool=False, context_stride: int=1, closed_loop: bool=False) -> io.Model: + context_length = max(((context_length - 1) // 8) + 1, 1) # at least length 1 + context_overlap = max(context_overlap // 8, 0) # at least overlap 0 + retain_index_list = "0" if retain_first_frame else "" + return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, + cond_retain_index_list=retain_index_list, latent_retain_index_list=retain_index_list, split_conds_to_windows=split_conds_to_windows) class ContextWindowsExtension(ComfyExtension): @@ -99,6 +139,7 @@ class ContextWindowsExtension(ComfyExtension): return [ ContextWindowsManualNode, WanContextWindowsManualNode, + LTXVContextWindowsNode, ] def comfy_entrypoint(): From 69d34f265407a1829da143347e13975d770ca6d7 Mon Sep 17 00:00:00 2001 From: Alexis Rolland Date: Sat, 20 Jun 2026 08:01:28 +0800 Subject: [PATCH 5/5] Rename a bunch of nodes (#14547) --- comfy_extras/nodes_frame_interpolation.py | 2 +- comfy_extras/nodes_logic.py | 3 ++- comfy_extras/nodes_primitive.py | 13 +++++-------- comfy_extras/nodes_video.py | 9 ++------- 4 files changed, 10 insertions(+), 17 deletions(-) diff --git a/comfy_extras/nodes_frame_interpolation.py b/comfy_extras/nodes_frame_interpolation.py index 4d5bca17e..44708e5ec 100644 --- a/comfy_extras/nodes_frame_interpolation.py +++ b/comfy_extras/nodes_frame_interpolation.py @@ -77,7 +77,7 @@ class FrameInterpolate(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="FrameInterpolate", - display_name="Frame Interpolate", + display_name="Run Frame Interpolation Model", category="video", search_aliases=["rife", "film", "frame interpolation", "slow motion", "interpolate frames", "vfi"], inputs=[ diff --git a/comfy_extras/nodes_logic.py b/comfy_extras/nodes_logic.py index 95f6ab848..13c1685f7 100644 --- a/comfy_extras/nodes_logic.py +++ b/comfy_extras/nodes_logic.py @@ -89,7 +89,8 @@ class SwitchNode(io.ComfyNode): template = io.MatchType.Template("switch") return io.Schema( node_id="ComfySwitchNode", - display_name="Switch", + search_aliases=["if", "then", "switch", "conditional", "branch"], + display_name="If/Else Switch", category="utilities/logic", is_experimental=True, inputs=[ diff --git a/comfy_extras/nodes_primitive.py b/comfy_extras/nodes_primitive.py index c44b09098..7f90daf14 100644 --- a/comfy_extras/nodes_primitive.py +++ b/comfy_extras/nodes_primitive.py @@ -10,12 +10,11 @@ class String(io.ComfyNode): return io.Schema( node_id="PrimitiveString", search_aliases=["text", "string", "text box", "prompt"], - display_name="Text String", + display_name="Text String (DEPRECATED)", category="utilities/primitive", - inputs=[ - io.String.Input("value"), - ], + inputs=[io.String.Input("value")], outputs=[io.String.Output()], + is_deprecated=True ) @classmethod @@ -29,12 +28,10 @@ class StringMultiline(io.ComfyNode): return io.Schema( node_id="PrimitiveStringMultiline", search_aliases=["text", "string", "text multiline", "string multiline", "text box", "prompt"], - display_name="Text String (Multiline)", + display_name="Input Text", category="utilities/primitive", essentials_category="Basics", - inputs=[ - io.String.Input("value", multiline=True), - ], + inputs=[io.String.Input("value", multiline=True)], outputs=[io.String.Output()], ) diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index 050a897dd..8d76af1c1 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -233,13 +233,8 @@ class VideoSlice(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="Video Slice", - display_name="Video Slice", - search_aliases=[ - "trim video duration", - "skip first frames", - "frame load cap", - "start time", - ], + display_name="Trim Video", + search_aliases=["trim video duration", "skip first frames", "frame load cap", "start time"], category="video", essentials_category="Video Tools", inputs=[