From 980621da83267beffcb84839a27101b7092256e7 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Wed, 11 Mar 2026 08:49:38 -0700
Subject: [PATCH 01/80] comfy-aimdo 0.2.10 (#12890)
Comfy Aimdo 0.2.10 fixes the aimdo allocator hook for legacy cudaMalloc
consumers. Some consumers of cudaMalloc assume implicit synchronization
built in closed source logic inside cuda. This is preserved by passing
through to cuda as-is and accouting after the fact as opposed to
integrating these hooks with Aimdos VMA based allocator.
---
requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/requirements.txt b/requirements.txt
index bb58f8d01..89cd994e9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -23,7 +23,7 @@ SQLAlchemy
filelock
av>=14.2.0
comfy-kitchen>=0.2.7
-comfy-aimdo>=0.2.9
+comfy-aimdo>=0.2.10
requests
simpleeval>=1.0.0
blake3
From 3365008dfe5a7a46cbe76d8ad0d7efb054617733 Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Wed, 11 Mar 2026 18:53:55 +0200
Subject: [PATCH 02/80] feat(api-nodes): add Reve Image nodes (#12848)
---
comfy_api_nodes/apis/reve.py | 68 ++++++
comfy_api_nodes/nodes_reve.py | 395 +++++++++++++++++++++++++++++++++
comfy_api_nodes/util/client.py | 12 +-
3 files changed, 474 insertions(+), 1 deletion(-)
create mode 100644 comfy_api_nodes/apis/reve.py
create mode 100644 comfy_api_nodes/nodes_reve.py
diff --git a/comfy_api_nodes/apis/reve.py b/comfy_api_nodes/apis/reve.py
new file mode 100644
index 000000000..c6b5a69d8
--- /dev/null
+++ b/comfy_api_nodes/apis/reve.py
@@ -0,0 +1,68 @@
+from pydantic import BaseModel, Field
+
+
+class RevePostprocessingOperation(BaseModel):
+ process: str = Field(..., description="The postprocessing operation: upscale or remove_background.")
+ upscale_factor: int | None = Field(
+ None,
+ description="Upscale factor (2, 3, or 4). Only used when process is upscale.",
+ ge=2,
+ le=4,
+ )
+
+
+class ReveImageCreateRequest(BaseModel):
+ prompt: str = Field(...)
+ aspect_ratio: str | None = Field(...)
+ version: str = Field(...)
+ test_time_scaling: int = Field(
+ ...,
+ description="If included, the model will spend more effort making better images. Values between 1 and 15.",
+ ge=1,
+ le=15,
+ )
+ postprocessing: list[RevePostprocessingOperation] | None = Field(
+ None, description="Optional postprocessing operations to apply after generation."
+ )
+
+
+class ReveImageEditRequest(BaseModel):
+ edit_instruction: str = Field(...)
+ reference_image: str = Field(..., description="A base64 encoded image to use as reference for the edit.")
+ aspect_ratio: str | None = Field(...)
+ version: str = Field(...)
+ test_time_scaling: int | None = Field(
+ ...,
+ description="If included, the model will spend more effort making better images. Values between 1 and 15.",
+ ge=1,
+ le=15,
+ )
+ postprocessing: list[RevePostprocessingOperation] | None = Field(
+ None, description="Optional postprocessing operations to apply after generation."
+ )
+
+
+class ReveImageRemixRequest(BaseModel):
+ prompt: str = Field(...)
+ reference_images: list[str] = Field(..., description="A list of 1-6 base64 encoded reference images.")
+ aspect_ratio: str | None = Field(...)
+ version: str = Field(...)
+ test_time_scaling: int | None = Field(
+ ...,
+ description="If included, the model will spend more effort making better images. Values between 1 and 15.",
+ ge=1,
+ le=15,
+ )
+ postprocessing: list[RevePostprocessingOperation] | None = Field(
+ None, description="Optional postprocessing operations to apply after generation."
+ )
+
+
+class ReveImageResponse(BaseModel):
+ image: str | None = Field(None, description="The base64 encoded image data.")
+ request_id: str | None = Field(None, description="A unique id for the request.")
+ credits_used: float | None = Field(None, description="The number of credits used for this request.")
+ version: str | None = Field(None, description="The specific model version used.")
+ content_violation: bool | None = Field(
+ None, description="Indicates whether the generated image violates the content policy."
+ )
diff --git a/comfy_api_nodes/nodes_reve.py b/comfy_api_nodes/nodes_reve.py
new file mode 100644
index 000000000..608d9f058
--- /dev/null
+++ b/comfy_api_nodes/nodes_reve.py
@@ -0,0 +1,395 @@
+from io import BytesIO
+
+from typing_extensions import override
+
+from comfy_api.latest import IO, ComfyExtension, Input
+from comfy_api_nodes.apis.reve import (
+ ReveImageCreateRequest,
+ ReveImageEditRequest,
+ ReveImageRemixRequest,
+ RevePostprocessingOperation,
+)
+from comfy_api_nodes.util import (
+ ApiEndpoint,
+ bytesio_to_image_tensor,
+ sync_op_raw,
+ tensor_to_base64_string,
+ validate_string,
+)
+
+
+def _build_postprocessing(upscale: dict, remove_background: bool) -> list[RevePostprocessingOperation] | None:
+ ops = []
+ if upscale["upscale"] == "enabled":
+ ops.append(
+ RevePostprocessingOperation(
+ process="upscale",
+ upscale_factor=upscale["upscale_factor"],
+ )
+ )
+ if remove_background:
+ ops.append(RevePostprocessingOperation(process="remove_background"))
+ return ops or None
+
+
+def _postprocessing_inputs():
+ return [
+ IO.DynamicCombo.Input(
+ "upscale",
+ options=[
+ IO.DynamicCombo.Option("disabled", []),
+ IO.DynamicCombo.Option(
+ "enabled",
+ [
+ IO.Int.Input(
+ "upscale_factor",
+ default=2,
+ min=2,
+ max=4,
+ step=1,
+ tooltip="Upscale factor (2x, 3x, or 4x).",
+ ),
+ ],
+ ),
+ ],
+ tooltip="Upscale the generated image. May add additional cost.",
+ ),
+ IO.Boolean.Input(
+ "remove_background",
+ default=False,
+ tooltip="Remove the background from the generated image. May add additional cost.",
+ ),
+ ]
+
+
+def _reve_price_extractor(headers: dict) -> float | None:
+ credits_used = headers.get("x-reve-credits-used")
+ if credits_used is not None:
+ return float(credits_used) / 524.48
+ return None
+
+
+def _reve_response_header_validator(headers: dict) -> None:
+ error_code = headers.get("x-reve-error-code")
+ if error_code:
+ raise ValueError(f"Reve API error: {error_code}")
+ if headers.get("x-reve-content-violation", "").lower() == "true":
+ raise ValueError("The generated image was flagged for content policy violation.")
+
+
+def _model_inputs(versions: list[str], aspect_ratios: list[str]):
+ return [
+ IO.DynamicCombo.Option(
+ version,
+ [
+ IO.Combo.Input(
+ "aspect_ratio",
+ options=aspect_ratios,
+ tooltip="Aspect ratio of the output image.",
+ ),
+ IO.Int.Input(
+ "test_time_scaling",
+ default=1,
+ min=1,
+ max=5,
+ step=1,
+ tooltip="Higher values produce better images but cost more credits.",
+ advanced=True,
+ ),
+ ],
+ )
+ for version in versions
+ ]
+
+
+class ReveImageCreateNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="ReveImageCreateNode",
+ display_name="Reve Image Create",
+ category="api node/image/Reve",
+ description="Generate images from text descriptions using Reve.",
+ inputs=[
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ default="",
+ tooltip="Text description of the desired image. Maximum 2560 characters.",
+ ),
+ IO.DynamicCombo.Input(
+ "model",
+ options=_model_inputs(
+ ["reve-create@20250915"],
+ aspect_ratios=["3:2", "16:9", "9:16", "2:3", "4:3", "3:4", "1:1"],
+ ),
+ tooltip="Model version to use for generation.",
+ ),
+ *_postprocessing_inputs(),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ ),
+ ],
+ outputs=[IO.Image.Output()],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ expr="""{"type":"usd","usd":0.03432,"format":{"approximate":true,"note":"(base)"}}""",
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ prompt: str,
+ model: dict,
+ upscale: dict,
+ remove_background: bool,
+ seed: int,
+ ) -> IO.NodeOutput:
+ validate_string(prompt, min_length=1, max_length=2560)
+ response = await sync_op_raw(
+ cls,
+ ApiEndpoint(
+ path="/proxy/reve/v1/image/create",
+ method="POST",
+ headers={"Accept": "image/webp"},
+ ),
+ as_binary=True,
+ price_extractor=_reve_price_extractor,
+ response_header_validator=_reve_response_header_validator,
+ data=ReveImageCreateRequest(
+ prompt=prompt,
+ aspect_ratio=model["aspect_ratio"],
+ version=model["model"],
+ test_time_scaling=model["test_time_scaling"],
+ postprocessing=_build_postprocessing(upscale, remove_background),
+ ),
+ )
+ return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
+
+
+class ReveImageEditNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="ReveImageEditNode",
+ display_name="Reve Image Edit",
+ category="api node/image/Reve",
+ description="Edit images using natural language instructions with Reve.",
+ inputs=[
+ IO.Image.Input("image", tooltip="The image to edit."),
+ IO.String.Input(
+ "edit_instruction",
+ multiline=True,
+ default="",
+ tooltip="Text description of how to edit the image. Maximum 2560 characters.",
+ ),
+ IO.DynamicCombo.Input(
+ "model",
+ options=_model_inputs(
+ ["reve-edit@20250915", "reve-edit-fast@20251030"],
+ aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"],
+ ),
+ tooltip="Model version to use for editing.",
+ ),
+ *_postprocessing_inputs(),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ ),
+ ],
+ outputs=[IO.Image.Output()],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ depends_on=IO.PriceBadgeDepends(
+ widgets=["model"],
+ ),
+ expr="""
+ (
+ $isFast := $contains(widgets.model, "fast");
+ $base := $isFast ? 0.01001 : 0.0572;
+ {"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
+ )
+ """,
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ image: Input.Image,
+ edit_instruction: str,
+ model: dict,
+ upscale: dict,
+ remove_background: bool,
+ seed: int,
+ ) -> IO.NodeOutput:
+ validate_string(edit_instruction, min_length=1, max_length=2560)
+ tts = model["test_time_scaling"]
+ ar = model["aspect_ratio"]
+ response = await sync_op_raw(
+ cls,
+ ApiEndpoint(
+ path="/proxy/reve/v1/image/edit",
+ method="POST",
+ headers={"Accept": "image/webp"},
+ ),
+ as_binary=True,
+ price_extractor=_reve_price_extractor,
+ response_header_validator=_reve_response_header_validator,
+ data=ReveImageEditRequest(
+ edit_instruction=edit_instruction,
+ reference_image=tensor_to_base64_string(image),
+ aspect_ratio=ar if ar != "auto" else None,
+ version=model["model"],
+ test_time_scaling=tts if tts and tts > 1 else None,
+ postprocessing=_build_postprocessing(upscale, remove_background),
+ ),
+ )
+ return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
+
+
+class ReveImageRemixNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="ReveImageRemixNode",
+ display_name="Reve Image Remix",
+ category="api node/image/Reve",
+ description="Combine reference images with text prompts to create new images using Reve.",
+ inputs=[
+ IO.Autogrow.Input(
+ "reference_images",
+ template=IO.Autogrow.TemplatePrefix(
+ IO.Image.Input("image"),
+ prefix="image_",
+ min=1,
+ max=6,
+ ),
+ ),
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ default="",
+ tooltip="Text description of the desired image. "
+ "May include XML img tags to reference specific images by index, "
+ "e.g.
0,
1, etc.",
+ ),
+ IO.DynamicCombo.Input(
+ "model",
+ options=_model_inputs(
+ ["reve-remix@20250915", "reve-remix-fast@20251030"],
+ aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"],
+ ),
+ tooltip="Model version to use for remixing.",
+ ),
+ *_postprocessing_inputs(),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ ),
+ ],
+ outputs=[IO.Image.Output()],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ depends_on=IO.PriceBadgeDepends(
+ widgets=["model"],
+ ),
+ expr="""
+ (
+ $isFast := $contains(widgets.model, "fast");
+ $base := $isFast ? 0.01001 : 0.0572;
+ {"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
+ )
+ """,
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ reference_images: IO.Autogrow.Type,
+ prompt: str,
+ model: dict,
+ upscale: dict,
+ remove_background: bool,
+ seed: int,
+ ) -> IO.NodeOutput:
+ validate_string(prompt, min_length=1, max_length=2560)
+ if not reference_images:
+ raise ValueError("At least one reference image is required.")
+ ref_base64_list = []
+ for key in reference_images:
+ ref_base64_list.append(tensor_to_base64_string(reference_images[key]))
+ if len(ref_base64_list) > 6:
+ raise ValueError("Maximum 6 reference images are allowed.")
+ tts = model["test_time_scaling"]
+ ar = model["aspect_ratio"]
+ response = await sync_op_raw(
+ cls,
+ ApiEndpoint(
+ path="/proxy/reve/v1/image/remix",
+ method="POST",
+ headers={"Accept": "image/webp"},
+ ),
+ as_binary=True,
+ price_extractor=_reve_price_extractor,
+ response_header_validator=_reve_response_header_validator,
+ data=ReveImageRemixRequest(
+ prompt=prompt,
+ reference_images=ref_base64_list,
+ aspect_ratio=ar if ar != "auto" else None,
+ version=model["model"],
+ test_time_scaling=tts if tts and tts > 1 else None,
+ postprocessing=_build_postprocessing(upscale, remove_background),
+ ),
+ )
+ return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
+
+
+class ReveExtension(ComfyExtension):
+ @override
+ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
+ return [
+ ReveImageCreateNode,
+ ReveImageEditNode,
+ ReveImageRemixNode,
+ ]
+
+
+async def comfy_entrypoint() -> ReveExtension:
+ return ReveExtension()
diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py
index 79ffb77c1..9d730b81a 100644
--- a/comfy_api_nodes/util/client.py
+++ b/comfy_api_nodes/util/client.py
@@ -67,6 +67,7 @@ class _RequestConfig:
progress_origin_ts: float | None = None
price_extractor: Callable[[dict[str, Any]], float | None] | None = None
is_rate_limited: Callable[[int, Any], bool] | None = None
+ response_header_validator: Callable[[dict[str, str]], None] | None = None
@dataclass
@@ -202,11 +203,13 @@ async def sync_op_raw(
monitor_progress: bool = True,
max_retries_on_rate_limit: int = 16,
is_rate_limited: Callable[[int, Any], bool] | None = None,
+ response_header_validator: Callable[[dict[str, str]], None] | None = None,
) -> dict[str, Any] | bytes:
"""
Make a single network request.
- If as_binary=False (default): returns JSON dict (or {'_raw': ''} if non-JSON).
- If as_binary=True: returns bytes.
+ - response_header_validator: optional callback receiving response headers dict
"""
if isinstance(data, BaseModel):
data = data.model_dump(exclude_none=True)
@@ -232,6 +235,7 @@ async def sync_op_raw(
price_extractor=price_extractor,
max_retries_on_rate_limit=max_retries_on_rate_limit,
is_rate_limited=is_rate_limited,
+ response_header_validator=response_header_validator,
)
return await _request_base(cfg, expect_binary=as_binary)
@@ -769,6 +773,12 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
)
bytes_payload = bytes(buff)
+ resp_headers = {k.lower(): v for k, v in resp.headers.items()}
+ if cfg.price_extractor:
+ with contextlib.suppress(Exception):
+ extracted_price = cfg.price_extractor(resp_headers)
+ if cfg.response_header_validator:
+ cfg.response_header_validator(resp_headers)
operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time)
request_logger.log_request_response(
@@ -776,7 +786,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
request_method=method,
request_url=url,
response_status_code=resp.status,
- response_headers=dict(resp.headers),
+ response_headers=resp_headers,
response_content=bytes_payload,
)
return bytes_payload
From 4f4f8659c205069f74da8ac47378a5b1c0e142ca Mon Sep 17 00:00:00 2001
From: Adi Borochov <58855640+adiborochov@users.noreply.github.com>
Date: Wed, 11 Mar 2026 19:04:13 +0200
Subject: [PATCH 03/80] fix: guard torch.AcceleratorError for compatibility
with torch < 2.8.0 (#12874)
* fix: guard torch.AcceleratorError for compatibility with torch < 2.8.0
torch.AcceleratorError was introduced in PyTorch 2.8.0. Accessing it
directly raises AttributeError on older versions. Use a try/except
fallback at module load time, consistent with the existing pattern used
for OOM_EXCEPTION.
* fix: address review feedback for AcceleratorError compat
- Fall back to RuntimeError instead of type(None) for ACCELERATOR_ERROR,
consistent with OOM_EXCEPTION fallback pattern and valid for except clauses
- Add "out of memory" message introspection for RuntimeError fallback case
- Use RuntimeError directly in discard_cuda_async_error except clause
---------
---
comfy/model_management.py | 9 +++++++--
1 file changed, 7 insertions(+), 2 deletions(-)
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 81550c790..81c89b180 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -270,10 +270,15 @@ try:
except:
OOM_EXCEPTION = Exception
+try:
+ ACCELERATOR_ERROR = torch.AcceleratorError
+except AttributeError:
+ ACCELERATOR_ERROR = RuntimeError
+
def is_oom(e):
if isinstance(e, OOM_EXCEPTION):
return True
- if isinstance(e, torch.AcceleratorError) and getattr(e, 'error_code', None) == 2:
+ if isinstance(e, ACCELERATOR_ERROR) and (getattr(e, 'error_code', None) == 2 or "out of memory" in str(e).lower()):
discard_cuda_async_error()
return True
return False
@@ -1275,7 +1280,7 @@ def discard_cuda_async_error():
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
_ = a + b
synchronize()
- except torch.AcceleratorError:
+ except RuntimeError:
#Dump it! We already know about it from the synchronous return
pass
From f6274c06b4e7bce8adbc1c60ae5a4c168825a614 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Wed, 11 Mar 2026 13:37:31 -0700
Subject: [PATCH 04/80] Fix issue with batch_size > 1 on some models. (#12892)
---
comfy/ldm/flux/layers.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py
index e20d498f8..e28d704b4 100644
--- a/comfy/ldm/flux/layers.py
+++ b/comfy/ldm/flux/layers.py
@@ -144,9 +144,9 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
return tensor * m_mult
else:
for d in modulation_dims:
- tensor[:, d[0]:d[1]] *= m_mult[:, d[2]]
+ tensor[:, d[0]:d[1]] *= m_mult[:, d[2]:d[2] + 1]
if m_add is not None:
- tensor[:, d[0]:d[1]] += m_add[:, d[2]]
+ tensor[:, d[0]:d[1]] += m_add[:, d[2]:d[2] + 1]
return tensor
From abc87d36693b007bdbdab5ee753ccea6326acb34 Mon Sep 17 00:00:00 2001
From: Comfy Org PR Bot
Date: Thu, 12 Mar 2026 06:04:51 +0900
Subject: [PATCH 05/80] Bump comfyui-frontend-package to 1.41.15 (#12891)
---------
Co-authored-by: Alexander Brown
---
requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/requirements.txt b/requirements.txt
index 89cd994e9..ffa5fa376 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-comfyui-frontend-package==1.39.19
+comfyui-frontend-package==1.41.15
comfyui-workflow-templates==0.9.18
comfyui-embedded-docs==0.4.3
torch
From 9ce4c3dd87c9c77dfe0371045fa920ce55e08973 Mon Sep 17 00:00:00 2001
From: Comfy Org PR Bot
Date: Thu, 12 Mar 2026 10:16:30 +0900
Subject: [PATCH 06/80] Bump comfyui-frontend-package to 1.41.16 (#12894)
Co-authored-by: github-actions[bot]
---
requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/requirements.txt b/requirements.txt
index ffa5fa376..2272d121a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-comfyui-frontend-package==1.41.15
+comfyui-frontend-package==1.41.16
comfyui-workflow-templates==0.9.18
comfyui-embedded-docs==0.4.3
torch
From 8f9ea495713d4565dfe564e0c06f362bd627f902 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Wed, 11 Mar 2026 21:17:31 -0700
Subject: [PATCH 07/80] Bump comfy-kitchen version to 0.2.8 (#12895)
---
requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/requirements.txt b/requirements.txt
index 2272d121a..96cd0254f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -22,7 +22,7 @@ alembic
SQLAlchemy
filelock
av>=14.2.0
-comfy-kitchen>=0.2.7
+comfy-kitchen>=0.2.8
comfy-aimdo>=0.2.10
requests
simpleeval>=1.0.0
From 44f1246c899ed188759f799dbd00c31def289114 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Thu, 12 Mar 2026 08:30:50 -0700
Subject: [PATCH 08/80] Support flux 2 klein kv cache model: Use the
FluxKVCache node. (#12905)
---
comfy/ldm/flux/model.py | 76 ++++++++++++++++++++++++++++++++------
comfy_extras/nodes_flux.py | 64 ++++++++++++++++++++++++++++++++
2 files changed, 129 insertions(+), 11 deletions(-)
diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index 00f12c031..8e7912e6d 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -44,6 +44,22 @@ class FluxParams:
txt_norm: bool = False
+def invert_slices(slices, length):
+ sorted_slices = sorted(slices)
+ result = []
+ current = 0
+
+ for start, end in sorted_slices:
+ if current < start:
+ result.append((current, start))
+ current = max(current, end)
+
+ if current < length:
+ result.append((current, length))
+
+ return result
+
+
class Flux(nn.Module):
"""
Transformer model for flow matching on sequences.
@@ -138,6 +154,7 @@ class Flux(nn.Module):
y: Tensor,
guidance: Tensor = None,
control = None,
+ timestep_zero_index=None,
transformer_options={},
attn_mask: Tensor = None,
) -> Tensor:
@@ -164,10 +181,6 @@ class Flux(nn.Module):
txt = self.txt_norm(txt)
txt = self.txt_in(txt)
- vec_orig = vec
- if self.params.global_modulation:
- vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig))
-
if "post_input" in patches:
for p in patches["post_input"]:
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
@@ -182,6 +195,24 @@ class Flux(nn.Module):
else:
pe = None
+ vec_orig = vec
+ txt_vec = vec
+ extra_kwargs = {}
+ if timestep_zero_index is not None:
+ modulation_dims = []
+ batch = vec.shape[0] // 2
+ vec_orig = vec_orig.reshape(2, batch, vec.shape[1]).movedim(0, 1)
+ invert = invert_slices(timestep_zero_index, img.shape[1])
+ for s in invert:
+ modulation_dims.append((s[0], s[1], 0))
+ for s in timestep_zero_index:
+ modulation_dims.append((s[0], s[1], 1))
+ extra_kwargs["modulation_dims_img"] = modulation_dims
+ txt_vec = vec[:batch]
+
+ if self.params.global_modulation:
+ vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(txt_vec))
+
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.double_blocks)
transformer_options["block_type"] = "double"
@@ -195,7 +226,8 @@ class Flux(nn.Module):
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"),
- transformer_options=args.get("transformer_options"))
+ transformer_options=args.get("transformer_options"),
+ **extra_kwargs)
return out
out = blocks_replace[("double_block", i)]({"img": img,
@@ -213,7 +245,8 @@ class Flux(nn.Module):
vec=vec,
pe=pe,
attn_mask=attn_mask,
- transformer_options=transformer_options)
+ transformer_options=transformer_options,
+ **extra_kwargs)
if control is not None: # Controlnet
control_i = control.get("input")
@@ -230,6 +263,12 @@ class Flux(nn.Module):
if self.params.global_modulation:
vec, _ = self.single_stream_modulation(vec_orig)
+ extra_kwargs = {}
+ if timestep_zero_index is not None:
+ lambda a: 0 if a == 0 else a + txt.shape[1]
+ modulation_dims_combined = list(map(lambda x: (0 if x[0] == 0 else x[0] + txt.shape[1], x[1] + txt.shape[1], x[2]), modulation_dims))
+ extra_kwargs["modulation_dims"] = modulation_dims_combined
+
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
@@ -242,7 +281,8 @@ class Flux(nn.Module):
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"),
- transformer_options=args.get("transformer_options"))
+ transformer_options=args.get("transformer_options"),
+ **extra_kwargs)
return out
out = blocks_replace[("single_block", i)]({"img": img,
@@ -253,7 +293,7 @@ class Flux(nn.Module):
{"original_block": block_wrap})
img = out["img"]
else:
- img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
+ img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options, **extra_kwargs)
if control is not None: # Controlnet
control_o = control.get("output")
@@ -264,7 +304,11 @@ class Flux(nn.Module):
img = img[:, txt.shape[1] :, ...]
- img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels)
+ extra_kwargs = {}
+ if timestep_zero_index is not None:
+ extra_kwargs["modulation_dims"] = modulation_dims
+
+ img = self.final_layer(img, vec_orig, **extra_kwargs) # (N, T, patch_size ** 2 * out_channels)
return img
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
@@ -312,13 +356,16 @@ class Flux(nn.Module):
w_len = ((w_orig + (patch_size // 2)) // patch_size)
img, img_ids = self.process_img(x, transformer_options=transformer_options)
img_tokens = img.shape[1]
+ timestep_zero_index = None
if ref_latents is not None:
+ ref_num_tokens = []
h = 0
w = 0
index = 0
ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method)
+ timestep_zero = ref_latents_method == "index_timestep_zero"
for ref in ref_latents:
- if ref_latents_method == "index":
+ if ref_latents_method in ("index", "index_timestep_zero"):
index += self.params.ref_index_scale
h_offset = 0
w_offset = 0
@@ -342,6 +389,13 @@ class Flux(nn.Module):
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
img = torch.cat([img, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
+ ref_num_tokens.append(kontext.shape[1])
+ if timestep_zero:
+ if index > 0:
+ timestep = torch.cat([timestep, timestep * 0], dim=0)
+ timestep_zero_index = [[img_tokens, img_ids.shape[1]]]
+ transformer_options = transformer_options.copy()
+ transformer_options["reference_image_num_tokens"] = ref_num_tokens
txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
@@ -349,6 +403,6 @@ class Flux(nn.Module):
for i in self.params.txt_ids_dims:
txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
- out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
+ out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options, attn_mask=kwargs.get("attention_mask", None))
out = out[:, :img_tokens]
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig]
diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py
index fe9552022..c366d0d5b 100644
--- a/comfy_extras/nodes_flux.py
+++ b/comfy_extras/nodes_flux.py
@@ -6,6 +6,7 @@ import comfy.model_management
import torch
import math
import nodes
+import comfy.ldm.flux.math
class CLIPTextEncodeFlux(io.ComfyNode):
@classmethod
@@ -231,6 +232,68 @@ class Flux2Scheduler(io.ComfyNode):
sigmas = get_schedule(steps, round(seq_len))
return io.NodeOutput(sigmas)
+class KV_Attn_Input:
+ def __init__(self):
+ self.cache = {}
+
+ def __call__(self, q, k, v, extra_options, **kwargs):
+ reference_image_num_tokens = extra_options.get("reference_image_num_tokens", [])
+ if len(reference_image_num_tokens) == 0:
+ return {}
+
+ ref_toks = sum(reference_image_num_tokens)
+ cache_key = "{}_{}".format(extra_options["block_type"], extra_options["block_index"])
+ if cache_key in self.cache:
+ kk, vv = self.cache[cache_key]
+ self.set_cache = False
+ return {"q": q, "k": torch.cat((k, kk), dim=2), "v": torch.cat((v, vv), dim=2)}
+
+ self.cache[cache_key] = (k[:, :, -ref_toks:], v[:, :, -ref_toks:])
+ self.set_cache = True
+ return {"q": q, "k": k, "v": v}
+
+ def cleanup(self):
+ self.cache = {}
+
+
+class FluxKVCache(io.ComfyNode):
+ @classmethod
+ def define_schema(cls) -> io.Schema:
+ return io.Schema(
+ node_id="FluxKVCache",
+ display_name="Flux KV Cache",
+ description="Enables KV Cache optimization for reference images on Flux family models.",
+ category="",
+ is_experimental=True,
+ inputs=[
+ io.Model.Input("model", tooltip="The model to use KV Cache on."),
+ ],
+ outputs=[
+ io.Model.Output(tooltip="The patched model with KV Cache enabled."),
+ ],
+ )
+
+ @classmethod
+ def execute(cls, model: io.Model.Type) -> io.NodeOutput:
+ m = model.clone()
+ input_patch_obj = KV_Attn_Input()
+
+ def model_input_patch(inputs):
+ if len(input_patch_obj.cache) > 0:
+ ref_image_tokens = sum(inputs["transformer_options"].get("reference_image_num_tokens", []))
+ if ref_image_tokens > 0:
+ img = inputs["img"]
+ inputs["img"] = img[:, :-ref_image_tokens]
+ return inputs
+
+ m.set_model_attn1_patch(input_patch_obj)
+ m.set_model_post_input_patch(model_input_patch)
+ if hasattr(model.model.diffusion_model, "params"):
+ m.add_object_patch("diffusion_model.params.default_ref_method", "index_timestep_zero")
+ else:
+ m.add_object_patch("diffusion_model.default_ref_method", "index_timestep_zero")
+
+ return io.NodeOutput(m)
class FluxExtension(ComfyExtension):
@override
@@ -243,6 +306,7 @@ class FluxExtension(ComfyExtension):
FluxKontextMultiReferenceLatentMethod,
EmptyFlux2LatentImage,
Flux2Scheduler,
+ FluxKVCache,
]
From 73d9599495e45c22ef3672176f34945deeea5444 Mon Sep 17 00:00:00 2001
From: Terry Jia
Date: Thu, 12 Mar 2026 09:55:29 -0700
Subject: [PATCH 09/80] add painter node (#12294)
* add painter node
* use io.Color
* code improve
---------
Co-authored-by: guill
---
comfy_extras/nodes_painter.py | 127 ++++++++++++++++++++++++++++++++++
nodes.py | 1 +
2 files changed, 128 insertions(+)
create mode 100644 comfy_extras/nodes_painter.py
diff --git a/comfy_extras/nodes_painter.py b/comfy_extras/nodes_painter.py
new file mode 100644
index 000000000..b9ecdf5ea
--- /dev/null
+++ b/comfy_extras/nodes_painter.py
@@ -0,0 +1,127 @@
+from __future__ import annotations
+
+import hashlib
+import os
+
+import numpy as np
+import torch
+from PIL import Image
+
+import folder_paths
+import node_helpers
+from comfy_api.latest import ComfyExtension, io, UI
+from typing_extensions import override
+
+
+def hex_to_rgb(hex_color: str) -> tuple[float, float, float]:
+ hex_color = hex_color.lstrip("#")
+ if len(hex_color) != 6:
+ return (0.0, 0.0, 0.0)
+ r = int(hex_color[0:2], 16) / 255.0
+ g = int(hex_color[2:4], 16) / 255.0
+ b = int(hex_color[4:6], 16) / 255.0
+ return (r, g, b)
+
+
+class PainterNode(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="Painter",
+ display_name="Painter",
+ category="image",
+ inputs=[
+ io.Image.Input(
+ "image",
+ optional=True,
+ tooltip="Optional base image to paint over",
+ ),
+ io.String.Input(
+ "mask",
+ default="",
+ socketless=True,
+ extra_dict={"widgetType": "PAINTER", "image_upload": True},
+ ),
+ io.Int.Input(
+ "width",
+ default=512,
+ min=64,
+ max=4096,
+ step=64,
+ socketless=True,
+ extra_dict={"hidden": True},
+ ),
+ io.Int.Input(
+ "height",
+ default=512,
+ min=64,
+ max=4096,
+ step=64,
+ socketless=True,
+ extra_dict={"hidden": True},
+ ),
+ io.Color.Input("bg_color", default="#000000"),
+ ],
+ outputs=[
+ io.Image.Output("IMAGE"),
+ io.Mask.Output("MASK"),
+ ],
+ )
+
+ @classmethod
+ def execute(cls, mask, width, height, bg_color="#000000", image=None) -> io.NodeOutput:
+ if image is not None:
+ base_image = image[:1]
+ h, w = base_image.shape[1], base_image.shape[2]
+ else:
+ h, w = height, width
+ r, g, b = hex_to_rgb(bg_color)
+ base_image = torch.zeros((1, h, w, 3), dtype=torch.float32)
+ base_image[0, :, :, 0] = r
+ base_image[0, :, :, 1] = g
+ base_image[0, :, :, 2] = b
+
+ if mask and mask.strip():
+ mask_path = folder_paths.get_annotated_filepath(mask)
+ painter_img = node_helpers.pillow(Image.open, mask_path)
+ painter_img = painter_img.convert("RGBA")
+
+ if painter_img.size != (w, h):
+ painter_img = painter_img.resize((w, h), Image.LANCZOS)
+
+ painter_np = np.array(painter_img).astype(np.float32) / 255.0
+ painter_rgb = painter_np[:, :, :3]
+ painter_alpha = painter_np[:, :, 3:4]
+
+ mask_tensor = torch.from_numpy(painter_np[:, :, 3]).unsqueeze(0)
+
+ base_np = base_image[0].cpu().numpy()
+ composited = painter_rgb * painter_alpha + base_np * (1.0 - painter_alpha)
+ out_image = torch.from_numpy(composited).unsqueeze(0)
+ else:
+ mask_tensor = torch.zeros((1, h, w), dtype=torch.float32)
+ out_image = base_image
+
+ return io.NodeOutput(out_image, mask_tensor, ui=UI.PreviewImage(out_image))
+
+ @classmethod
+ def fingerprint_inputs(cls, mask, width, height, bg_color="#000000", image=None):
+ if mask and mask.strip():
+ mask_path = folder_paths.get_annotated_filepath(mask)
+ if os.path.exists(mask_path):
+ m = hashlib.sha256()
+ with open(mask_path, "rb") as f:
+ m.update(f.read())
+ return m.digest().hex()
+ return ""
+
+
+
+class PainterExtension(ComfyExtension):
+ @override
+ async def get_node_list(self):
+ return [PainterNode]
+
+
+async def comfy_entrypoint():
+ return PainterExtension()
diff --git a/nodes.py b/nodes.py
index 0ef23b640..eb63f9d44 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2450,6 +2450,7 @@ async def init_builtin_extra_nodes():
"nodes_nag.py",
"nodes_sdpose.py",
"nodes_math.py",
+ "nodes_painter.py",
]
import_failed = []
From 3fa8c5686dc86fe4e63ad3ca84d71524792a17b1 Mon Sep 17 00:00:00 2001
From: Terry Jia
Date: Thu, 12 Mar 2026 10:14:28 -0700
Subject: [PATCH 10/80] fix: use frontend-compatible format for Float
gradient_stops (#12789)
Co-authored-by: guill
Co-authored-by: Jedrzej Kosinski
---
comfy/comfy_types/node_typing.py | 4 ++--
comfy_api/latest/_io.py | 2 +-
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py
index 92b1acbd5..57126fa4a 100644
--- a/comfy/comfy_types/node_typing.py
+++ b/comfy/comfy_types/node_typing.py
@@ -176,8 +176,8 @@ class InputTypeOptions(TypedDict):
"""COMBO type only. Specifies the configuration for a multi-select widget.
Available after ComfyUI frontend v1.13.4
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
- gradient_stops: NotRequired[list[list[float]]]
- """Gradient color stops for gradientslider display mode. Each stop is [offset, r, g, b] (``FLOAT``)."""
+ gradient_stops: NotRequired[list[dict]]
+ """Gradient color stops for gradientslider display mode. Each stop is {"offset": float, "color": [r, g, b]}."""
class HiddenInputTypeDict(TypedDict):
diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py
index 050031dc0..7ca8f4e0c 100644
--- a/comfy_api/latest/_io.py
+++ b/comfy_api/latest/_io.py
@@ -297,7 +297,7 @@ class Float(ComfyTypeIO):
'''Float input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
- display_mode: NumberDisplay=None, gradient_stops: list[list[float]]=None,
+ display_mode: NumberDisplay=None, gradient_stops: list[dict]=None,
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
self.min = min
From 712411d53919350ae5050cbdf7ed60fcc2b52cda Mon Sep 17 00:00:00 2001
From: ComfyUI Wiki
Date: Fri, 13 Mar 2026 03:16:54 +0800
Subject: [PATCH 11/80] chore: update workflow templates to v0.9.21 (#12908)
---
requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/requirements.txt b/requirements.txt
index 96cd0254f..a2e53671e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
comfyui-frontend-package==1.41.16
-comfyui-workflow-templates==0.9.18
+comfyui-workflow-templates==0.9.21
comfyui-embedded-docs==0.4.3
torch
torchsde
From 47e1e316c580ce6bf264cb069bffc10a50d3f167 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Thu, 12 Mar 2026 13:54:38 -0700
Subject: [PATCH 12/80] Lower kv cache memory usage. (#12909)
---
comfy_extras/nodes_flux.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py
index c366d0d5b..3a23c7d04 100644
--- a/comfy_extras/nodes_flux.py
+++ b/comfy_extras/nodes_flux.py
@@ -248,7 +248,7 @@ class KV_Attn_Input:
self.set_cache = False
return {"q": q, "k": torch.cat((k, kk), dim=2), "v": torch.cat((v, vv), dim=2)}
- self.cache[cache_key] = (k[:, :, -ref_toks:], v[:, :, -ref_toks:])
+ self.cache[cache_key] = (k[:, :, -ref_toks:].clone(), v[:, :, -ref_toks:].clone())
self.set_cache = True
return {"q": q, "k": k, "v": v}
From 8d9faaa181b9089cf8e4e00284443ef5c3405a12 Mon Sep 17 00:00:00 2001
From: Christian Byrne
Date: Thu, 12 Mar 2026 15:14:59 -0700
Subject: [PATCH 13/80] Update requirements.txt (#12910)
---
requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/requirements.txt b/requirements.txt
index a2e53671e..511c62fee 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-comfyui-frontend-package==1.41.16
+comfyui-frontend-package==1.41.18
comfyui-workflow-templates==0.9.21
comfyui-embedded-docs==0.4.3
torch
From af7b4a921d7abab7c852d7b5febb654be6e57eba Mon Sep 17 00:00:00 2001
From: Deep Mehta <42841935+deepme987@users.noreply.github.com>
Date: Thu, 12 Mar 2026 16:09:07 -0700
Subject: [PATCH 14/80] feat: Add CacheProvider API for external distributed
caching (#12056)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* feat: Add CacheProvider API for external distributed caching
Introduces a public API for external cache providers, enabling distributed
caching across multiple ComfyUI instances (e.g., Kubernetes pods).
New files:
- comfy_execution/cache_provider.py: CacheProvider ABC, CacheContext/CacheValue
dataclasses, thread-safe provider registry, serialization utilities
Modified files:
- comfy_execution/caching.py: Add provider hooks to BasicCache (_notify_providers_store,
_check_providers_lookup), subcache exclusion, prompt ID propagation
- execution.py: Add prompt lifecycle hooks (on_prompt_start/on_prompt_end) to
PromptExecutor, set _current_prompt_id on caches
Key features:
- Local-first caching (check local before external for performance)
- NaN detection to prevent incorrect external cache hits
- Subcache exclusion (ephemeral subgraph results not cached externally)
- Thread-safe provider snapshot caching
- Graceful error handling (provider errors logged, never break execution)
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude
* fix: use deterministic hash for cache keys instead of pickle
Pickle serialization is NOT deterministic across Python sessions due
to hash randomization affecting frozenset iteration order. This causes
distributed caching to fail because different pods compute different
hashes for identical cache keys.
Fix: Use _canonicalize() + JSON serialization which ensures deterministic
ordering regardless of Python's hash randomization.
This is critical for cross-pod cache key consistency in Kubernetes
deployments.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude
* test: add unit tests for CacheProvider API
- Add comprehensive tests for _canonicalize deterministic ordering
- Add tests for serialize_cache_key hash consistency
- Add tests for contains_nan utility
- Add tests for estimate_value_size
- Add tests for provider registry (register, unregister, clear)
- Move json import to top-level (fix inline import)
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude
* style: remove unused imports in test_cache_provider.py
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude
* fix: move _torch_available before usage and use importlib.util.find_spec
Fixes ruff F821 (undefined name) and F401 (unused import) errors.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude
* fix: use hashable types in frozenset test and add dict test
Frozensets can only contain hashable types, so use nested frozensets
instead of dicts. Added separate test for dict handling via serialize_cache_key.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude
* refactor: expose CacheProvider API via comfy_api.latest.Caching
- Add Caching class to comfy_api/latest/__init__.py that re-exports
from comfy_execution.cache_provider (source of truth)
- Fix docstring: "Skip large values" instead of "Skip small values"
(small compute-heavy values are good cache targets)
- Maintain backward compatibility: comfy_execution.cache_provider
imports still work
Usage:
from comfy_api.latest import Caching
class MyProvider(Caching.CacheProvider):
def on_lookup(self, context): ...
def on_store(self, context, value): ...
Caching.register_provider(MyProvider())
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude
* docs: clarify should_cache filtering criteria
Change docstring from "Skip large values" to "Skip if download time > compute time"
which better captures the cost/benefit tradeoff for external caching.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude
* docs: make should_cache docstring implementation-agnostic
Remove prescriptive filtering suggestions - let implementations
decide their own caching logic based on their use case.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude
* feat: add optional ui field to CacheValue
- Add ui field to CacheValue dataclass (default None)
- Pass ui when creating CacheValue for external providers
- Use result.ui (or default {}) when returning from external cache lookup
This allows external cache implementations to store/retrieve UI data
if desired, while remaining optional for implementations that skip it.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude
* refactor: rename _is_cacheable_value to _is_external_cacheable_value
Clearer name since objects are also cached locally - this specifically
checks for external caching eligibility.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude
* refactor: async CacheProvider API + reduce public surface
- Make on_lookup/on_store async on CacheProvider ABC
- Simplify CacheContext: replace cache_key + cache_key_bytes with
cache_key_hash (str hex digest)
- Make registry/utility functions internal (_prefix)
- Trim comfy_api.latest.Caching exports to core API only
- Make cache get/set async throughout caching.py hierarchy
- Use asyncio.create_task for fire-and-forget on_store
- Add NaN gating before provider calls in Core
- Add await to 5 cache call sites in execution.py
Co-Authored-By: Claude Opus 4.6
* fix: remove unused imports (ruff) and update tests for internal API
- Remove unused CacheContext and _serialize_cache_key imports from
caching.py (now handled by _build_context helper)
- Update test_cache_provider.py to use _-prefixed internal names
- Update tests for new CacheContext.cache_key_hash field (str)
- Make MockCacheProvider methods async to match ABC
Co-Authored-By: Claude Opus 4.6
* fix: address coderabbit review feedback
- Add try/except to _build_context, return None when hash fails
- Return None from _serialize_cache_key on total failure (no id()-based fallback)
- Replace hex-like test literal with non-secret placeholder
Co-Authored-By: Claude Opus 4.6
* fix: use _-prefixed imports in _notify_prompt_lifecycle
The lifecycle notification method was importing the old non-prefixed
names (has_cache_providers, get_cache_providers, logger) which no
longer exist after the API cleanup.
Co-Authored-By: Claude Opus 4.6
* fix: add sync get_local/set_local for graph traversal
ExecutionList in graph.py calls output_cache.get() and .set() from
sync methods (is_cached, cache_link, get_cache). These cannot await
the now-async get/set. Add get_local/set_local that bypass external
providers and only access the local dict — which is all graph
traversal needs.
Co-Authored-By: Claude Opus 4.6
* chore: remove cloud-specific language from cache provider API
Make all docstrings and comments generic for the OSS codebase.
Remove references to Kubernetes, Redis, GCS, pods, and other
infrastructure-specific terminology.
Co-Authored-By: Claude Opus 4.6
* style: align documentation with codebase conventions
Strip verbose docstrings and section banners to match existing minimal
documentation style used throughout the codebase.
Co-Authored-By: Claude Opus 4.6
* fix: add usage example to Caching class, remove pickle fallback
- Add docstring with usage example to Caching class matching the
convention used by sibling APIs (Execution.set_progress, ComfyExtension)
- Remove non-deterministic pickle fallback from _serialize_cache_key;
return None on JSON failure instead of producing unretrievable hashes
- Move cache_provider imports to top of execution.py (no circular dep)
Co-Authored-By: Claude Opus 4.6
* refactor: move public types to comfy_api, eager provider snapshot
Address review feedback:
- Move CacheProvider/CacheContext/CacheValue definitions to
comfy_api/latest/_caching.py (source of truth for public API)
- comfy_execution/cache_provider.py re-exports types from there
- Build _providers_snapshot eagerly on register/unregister instead
of lazy memoization in _get_cache_providers
Co-Authored-By: Claude Opus 4.6
* fix: generalize self-inequality check, fail-closed canonicalization
Address review feedback from guill:
- Rename _contains_nan to _contains_self_unequal, use not (x == x)
instead of math.isnan to catch any self-unequal value
- Remove Unhashable and repr() fallbacks from _canonicalize; raise
ValueError for unknown types so _serialize_cache_key returns None
and external caching is skipped (fail-closed)
- Update tests for renamed function and new fail-closed behavior
Co-Authored-By: Claude Opus 4.6
* fix: suppress ruff F401 for re-exported CacheContext
CacheContext is imported from _caching and re-exported for use by
caching.py. Add noqa comment to satisfy the linter.
Co-Authored-By: Claude Opus 4.6
* fix: enable external caching for subcache (expanded) nodes
Subcache nodes (from node expansion) now participate in external
provider store/lookup. Previously skipped to avoid duplicates, but
the cost of missing partial-expansion cache hits outweighs redundant
stores — especially with looping behavior on the horizon.
Co-Authored-By: Claude Opus 4.6
* fix: wrap register/unregister as explicit static methods
Define register_provider and unregister_provider as wrapper functions
in the Caching class instead of re-importing. This locks the public
API signature in comfy_api/ so internal changes can't accidentally
break it.
Co-Authored-By: Claude Opus 4.6
* fix: use debug-level logging for provider registration
Co-Authored-By: Claude Opus 4.6
* fix: follow ProxiedSingleton pattern for Caching class
Add Caching as a nested class inside ComfyAPI_latest inheriting from
ProxiedSingleton with async instance methods, matching the Execution
and NodeReplacement patterns. Retains standalone Caching class for
direct import convenience.
Co-Authored-By: Claude Opus 4.6
* fix: inline registration logic in Caching class
Follow the Execution/NodeReplacement pattern — the public API methods
contain the actual logic operating on cache_provider module state,
not wrapper functions delegating to free functions.
Co-Authored-By: Claude Opus 4.6
* fix: single Caching definition inside ComfyAPI_latest
Remove duplicate standalone Caching class. Define it once as a nested
class in ComfyAPI_latest (matching Execution/NodeReplacement pattern),
with a module-level alias for import convenience.
Co-Authored-By: Claude Opus 4.6
* fix: remove prompt_id from CacheContext, type-safe canonicalization
Remove prompt_id from CacheContext — it's not relevant for cache
matching and added unnecessary plumbing (_current_prompt_id on every
cache). Lifecycle hooks still receive prompt_id directly.
Include type name in canonicalized primitives so that int 7 and
str "7" produce distinct hashes. Also canonicalize dict keys properly
instead of str() coercion.
Co-Authored-By: Claude Opus 4.6
* fix: address review feedback on cache provider API
- Hold references to pending store tasks to prevent "Task was destroyed
but it is still pending" warnings (bigcat88)
- Parallel cache lookups with asyncio.gather instead of sequential
awaits for better performance (bigcat88)
- Delegate Caching.register/unregister_provider to existing functions
in cache_provider.py instead of reimplementing (bigcat88)
Co-Authored-By: Claude Opus 4.6
---------
Co-authored-by: Claude
---
comfy_api/latest/__init__.py | 35 ++
comfy_api/latest/_caching.py | 42 ++
comfy_execution/cache_provider.py | 138 ++++++
comfy_execution/caching.py | 177 +++++++-
comfy_execution/graph.py | 6 +-
execution.py | 141 +++---
.../execution_test/test_cache_provider.py | 403 ++++++++++++++++++
7 files changed, 859 insertions(+), 83 deletions(-)
create mode 100644 comfy_api/latest/_caching.py
create mode 100644 comfy_execution/cache_provider.py
create mode 100644 tests-unit/execution_test/test_cache_provider.py
diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py
index f2399422b..04973fea0 100644
--- a/comfy_api/latest/__init__.py
+++ b/comfy_api/latest/__init__.py
@@ -25,6 +25,7 @@ class ComfyAPI_latest(ComfyAPIBase):
super().__init__()
self.node_replacement = self.NodeReplacement()
self.execution = self.Execution()
+ self.caching = self.Caching()
class NodeReplacement(ProxiedSingleton):
async def register(self, node_replace: io.NodeReplace) -> None:
@@ -84,6 +85,36 @@ class ComfyAPI_latest(ComfyAPIBase):
image=to_display,
)
+ class Caching(ProxiedSingleton):
+ """
+ External cache provider API for sharing cached node outputs
+ across ComfyUI instances.
+
+ Example::
+
+ from comfy_api.latest import Caching
+
+ class MyCacheProvider(Caching.CacheProvider):
+ async def on_lookup(self, context):
+ ... # check external storage
+
+ async def on_store(self, context, value):
+ ... # store to external storage
+
+ Caching.register_provider(MyCacheProvider())
+ """
+ from ._caching import CacheProvider, CacheContext, CacheValue
+
+ async def register_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None:
+ """Register an external cache provider. Providers are called in registration order."""
+ from comfy_execution.cache_provider import register_cache_provider
+ register_cache_provider(provider)
+
+ async def unregister_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None:
+ """Unregister a previously registered cache provider."""
+ from comfy_execution.cache_provider import unregister_cache_provider
+ unregister_cache_provider(provider)
+
class ComfyExtension(ABC):
async def on_load(self) -> None:
"""
@@ -116,6 +147,9 @@ class Types:
VOXEL = VOXEL
File3D = File3D
+
+Caching = ComfyAPI_latest.Caching
+
ComfyAPI = ComfyAPI_latest
# Create a synchronous version of the API
@@ -135,6 +169,7 @@ __all__ = [
"Input",
"InputImpl",
"Types",
+ "Caching",
"ComfyExtension",
"io",
"IO",
diff --git a/comfy_api/latest/_caching.py b/comfy_api/latest/_caching.py
new file mode 100644
index 000000000..30c8848cd
--- /dev/null
+++ b/comfy_api/latest/_caching.py
@@ -0,0 +1,42 @@
+from abc import ABC, abstractmethod
+from typing import Optional
+from dataclasses import dataclass
+
+
+@dataclass
+class CacheContext:
+ node_id: str
+ class_type: str
+ cache_key_hash: str # SHA256 hex digest
+
+
+@dataclass
+class CacheValue:
+ outputs: list
+ ui: dict = None
+
+
+class CacheProvider(ABC):
+ """Abstract base class for external cache providers.
+ Exceptions from provider methods are caught by the caller and never break execution.
+ """
+
+ @abstractmethod
+ async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
+ """Called on local cache miss. Return CacheValue if found, None otherwise."""
+ pass
+
+ @abstractmethod
+ async def on_store(self, context: CacheContext, value: CacheValue) -> None:
+ """Called after local store. Dispatched via asyncio.create_task."""
+ pass
+
+ def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool:
+ """Return False to skip external caching for this node. Default: True."""
+ return True
+
+ def on_prompt_start(self, prompt_id: str) -> None:
+ pass
+
+ def on_prompt_end(self, prompt_id: str) -> None:
+ pass
diff --git a/comfy_execution/cache_provider.py b/comfy_execution/cache_provider.py
new file mode 100644
index 000000000..d455d08e8
--- /dev/null
+++ b/comfy_execution/cache_provider.py
@@ -0,0 +1,138 @@
+from typing import Any, Optional, Tuple, List
+import hashlib
+import json
+import logging
+import threading
+
+# Public types — source of truth is comfy_api.latest._caching
+from comfy_api.latest._caching import CacheProvider, CacheContext, CacheValue # noqa: F401 (re-exported)
+
+_logger = logging.getLogger(__name__)
+
+
+_providers: List[CacheProvider] = []
+_providers_lock = threading.Lock()
+_providers_snapshot: Tuple[CacheProvider, ...] = ()
+
+
+def register_cache_provider(provider: CacheProvider) -> None:
+ """Register an external cache provider. Providers are called in registration order."""
+ global _providers_snapshot
+ with _providers_lock:
+ if provider in _providers:
+ _logger.warning(f"Provider {provider.__class__.__name__} already registered")
+ return
+ _providers.append(provider)
+ _providers_snapshot = tuple(_providers)
+ _logger.debug(f"Registered cache provider: {provider.__class__.__name__}")
+
+
+def unregister_cache_provider(provider: CacheProvider) -> None:
+ global _providers_snapshot
+ with _providers_lock:
+ try:
+ _providers.remove(provider)
+ _providers_snapshot = tuple(_providers)
+ _logger.debug(f"Unregistered cache provider: {provider.__class__.__name__}")
+ except ValueError:
+ _logger.warning(f"Provider {provider.__class__.__name__} was not registered")
+
+
+def _get_cache_providers() -> Tuple[CacheProvider, ...]:
+ return _providers_snapshot
+
+
+def _has_cache_providers() -> bool:
+ return bool(_providers_snapshot)
+
+
+def _clear_cache_providers() -> None:
+ global _providers_snapshot
+ with _providers_lock:
+ _providers.clear()
+ _providers_snapshot = ()
+
+
+def _canonicalize(obj: Any) -> Any:
+ # Convert to canonical JSON-serializable form with deterministic ordering.
+ # Frozensets have non-deterministic iteration order between Python sessions.
+ # Raises ValueError for non-cacheable types (Unhashable, unknown) so that
+ # _serialize_cache_key returns None and external caching is skipped.
+ if isinstance(obj, frozenset):
+ return ("__frozenset__", sorted(
+ [_canonicalize(item) for item in obj],
+ key=lambda x: json.dumps(x, sort_keys=True)
+ ))
+ elif isinstance(obj, set):
+ return ("__set__", sorted(
+ [_canonicalize(item) for item in obj],
+ key=lambda x: json.dumps(x, sort_keys=True)
+ ))
+ elif isinstance(obj, tuple):
+ return ("__tuple__", [_canonicalize(item) for item in obj])
+ elif isinstance(obj, list):
+ return [_canonicalize(item) for item in obj]
+ elif isinstance(obj, dict):
+ return {"__dict__": sorted(
+ [[_canonicalize(k), _canonicalize(v)] for k, v in obj.items()],
+ key=lambda x: json.dumps(x, sort_keys=True)
+ )}
+ elif isinstance(obj, (int, float, str, bool, type(None))):
+ return (type(obj).__name__, obj)
+ elif isinstance(obj, bytes):
+ return ("__bytes__", obj.hex())
+ else:
+ raise ValueError(f"Cannot canonicalize type: {type(obj).__name__}")
+
+
+def _serialize_cache_key(cache_key: Any) -> Optional[str]:
+ # Returns deterministic SHA256 hex digest, or None on failure.
+ # Uses JSON (not pickle) because pickle is non-deterministic across sessions.
+ try:
+ canonical = _canonicalize(cache_key)
+ json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':'))
+ return hashlib.sha256(json_str.encode('utf-8')).hexdigest()
+ except Exception as e:
+ _logger.warning(f"Failed to serialize cache key: {e}")
+ return None
+
+
+def _contains_self_unequal(obj: Any) -> bool:
+ # Local cache matches by ==. Values where not (x == x) (NaN, etc.) will
+ # never hit locally, but serialized form would match externally. Skip these.
+ try:
+ if not (obj == obj):
+ return True
+ except Exception:
+ return True
+ if isinstance(obj, (frozenset, tuple, list, set)):
+ return any(_contains_self_unequal(item) for item in obj)
+ if isinstance(obj, dict):
+ return any(_contains_self_unequal(k) or _contains_self_unequal(v) for k, v in obj.items())
+ if hasattr(obj, 'value'):
+ return _contains_self_unequal(obj.value)
+ return False
+
+
+def _estimate_value_size(value: CacheValue) -> int:
+ try:
+ import torch
+ except ImportError:
+ return 0
+
+ total = 0
+
+ def estimate(obj):
+ nonlocal total
+ if isinstance(obj, torch.Tensor):
+ total += obj.numel() * obj.element_size()
+ elif isinstance(obj, dict):
+ for v in obj.values():
+ estimate(v)
+ elif isinstance(obj, (list, tuple)):
+ for item in obj:
+ estimate(item)
+
+ for output in value.outputs:
+ estimate(output)
+ return total
diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py
index 326a279fc..750bddf2e 100644
--- a/comfy_execution/caching.py
+++ b/comfy_execution/caching.py
@@ -1,3 +1,4 @@
+import asyncio
import bisect
import gc
import itertools
@@ -154,6 +155,7 @@ class BasicCache:
self.cache_key_set: CacheKeySet
self.cache = {}
self.subcaches = {}
+ self._pending_store_tasks: set = set()
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
self.dynprompt = dynprompt
@@ -196,18 +198,134 @@ class BasicCache:
def poll(self, **kwargs):
pass
- def _set_immediate(self, node_id, value):
- assert self.initialized
- cache_key = self.cache_key_set.get_data_key(node_id)
- self.cache[cache_key] = value
-
- def _get_immediate(self, node_id):
+ def get_local(self, node_id):
if not self.initialized:
return None
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key in self.cache:
return self.cache[cache_key]
- else:
+ return None
+
+ def set_local(self, node_id, value):
+ assert self.initialized
+ cache_key = self.cache_key_set.get_data_key(node_id)
+ self.cache[cache_key] = value
+
+ async def _set_immediate(self, node_id, value):
+ assert self.initialized
+ cache_key = self.cache_key_set.get_data_key(node_id)
+ self.cache[cache_key] = value
+
+ await self._notify_providers_store(node_id, cache_key, value)
+
+ async def _get_immediate(self, node_id):
+ if not self.initialized:
+ return None
+ cache_key = self.cache_key_set.get_data_key(node_id)
+
+ if cache_key in self.cache:
+ return self.cache[cache_key]
+
+ external_result = await self._check_providers_lookup(node_id, cache_key)
+ if external_result is not None:
+ self.cache[cache_key] = external_result
+ return external_result
+
+ return None
+
+ async def _notify_providers_store(self, node_id, cache_key, value):
+ from comfy_execution.cache_provider import (
+ _has_cache_providers, _get_cache_providers,
+ CacheValue, _contains_self_unequal, _logger
+ )
+
+ if not _has_cache_providers():
+ return
+ if not self._is_external_cacheable_value(value):
+ return
+ if _contains_self_unequal(cache_key):
+ return
+
+ context = self._build_context(node_id, cache_key)
+ if context is None:
+ return
+ cache_value = CacheValue(outputs=value.outputs, ui=value.ui)
+
+ for provider in _get_cache_providers():
+ try:
+ if provider.should_cache(context, cache_value):
+ task = asyncio.create_task(self._safe_provider_store(provider, context, cache_value))
+ self._pending_store_tasks.add(task)
+ task.add_done_callback(self._pending_store_tasks.discard)
+ except Exception as e:
+ _logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}")
+
+ @staticmethod
+ async def _safe_provider_store(provider, context, cache_value):
+ from comfy_execution.cache_provider import _logger
+ try:
+ await provider.on_store(context, cache_value)
+ except Exception as e:
+ _logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}")
+
+ async def _check_providers_lookup(self, node_id, cache_key):
+ from comfy_execution.cache_provider import (
+ _has_cache_providers, _get_cache_providers,
+ CacheValue, _contains_self_unequal, _logger
+ )
+
+ if not _has_cache_providers():
+ return None
+ if _contains_self_unequal(cache_key):
+ return None
+
+ context = self._build_context(node_id, cache_key)
+ if context is None:
+ return None
+
+ for provider in _get_cache_providers():
+ try:
+ if not provider.should_cache(context):
+ continue
+ result = await provider.on_lookup(context)
+ if result is not None:
+ if not isinstance(result, CacheValue):
+ _logger.warning(f"Provider {provider.__class__.__name__} returned invalid type")
+ continue
+ if not isinstance(result.outputs, (list, tuple)):
+ _logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs")
+ continue
+ from execution import CacheEntry
+ return CacheEntry(ui=result.ui or {}, outputs=list(result.outputs))
+ except Exception as e:
+ _logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}")
+
+ return None
+
+ def _is_external_cacheable_value(self, value):
+ return hasattr(value, 'outputs') and hasattr(value, 'ui')
+
+ def _get_class_type(self, node_id):
+ if not self.initialized or not self.dynprompt:
+ return ''
+ try:
+ return self.dynprompt.get_node(node_id).get('class_type', '')
+ except Exception:
+ return ''
+
+ def _build_context(self, node_id, cache_key):
+ from comfy_execution.cache_provider import CacheContext, _serialize_cache_key, _logger
+ try:
+ cache_key_hash = _serialize_cache_key(cache_key)
+ if cache_key_hash is None:
+ return None
+ return CacheContext(
+ node_id=node_id,
+ class_type=self._get_class_type(node_id),
+ cache_key_hash=cache_key_hash,
+ )
+ except Exception as e:
+ _logger.warning(f"Failed to build cache context for node {node_id}: {e}")
return None
async def _ensure_subcache(self, node_id, children_ids):
@@ -257,16 +375,27 @@ class HierarchicalCache(BasicCache):
return None
return cache
- def get(self, node_id):
+ async def get(self, node_id):
cache = self._get_cache_for(node_id)
if cache is None:
return None
- return cache._get_immediate(node_id)
+ return await cache._get_immediate(node_id)
- def set(self, node_id, value):
+ def get_local(self, node_id):
+ cache = self._get_cache_for(node_id)
+ if cache is None:
+ return None
+ return BasicCache.get_local(cache, node_id)
+
+ async def set(self, node_id, value):
cache = self._get_cache_for(node_id)
assert cache is not None
- cache._set_immediate(node_id, value)
+ await cache._set_immediate(node_id, value)
+
+ def set_local(self, node_id, value):
+ cache = self._get_cache_for(node_id)
+ assert cache is not None
+ BasicCache.set_local(cache, node_id, value)
async def ensure_subcache_for(self, node_id, children_ids):
cache = self._get_cache_for(node_id)
@@ -287,10 +416,16 @@ class NullCache:
def poll(self, **kwargs):
pass
- def get(self, node_id):
+ async def get(self, node_id):
return None
- def set(self, node_id, value):
+ def get_local(self, node_id):
+ return None
+
+ async def set(self, node_id, value):
+ pass
+
+ def set_local(self, node_id, value):
pass
async def ensure_subcache_for(self, node_id, children_ids):
@@ -322,18 +457,18 @@ class LRUCache(BasicCache):
del self.children[key]
self._clean_subcaches()
- def get(self, node_id):
+ async def get(self, node_id):
self._mark_used(node_id)
- return self._get_immediate(node_id)
+ return await self._get_immediate(node_id)
def _mark_used(self, node_id):
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key is not None:
self.used_generation[cache_key] = self.generation
- def set(self, node_id, value):
+ async def set(self, node_id, value):
self._mark_used(node_id)
- return self._set_immediate(node_id, value)
+ return await self._set_immediate(node_id, value)
async def ensure_subcache_for(self, node_id, children_ids):
# Just uses subcaches for tracking 'live' nodes
@@ -373,13 +508,13 @@ class RAMPressureCache(LRUCache):
def clean_unused(self):
self._clean_subcaches()
- def set(self, node_id, value):
+ async def set(self, node_id, value):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
- super().set(node_id, value)
+ await super().set(node_id, value)
- def get(self, node_id):
+ async def get(self, node_id):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
- return super().get(node_id)
+ return await super().get(node_id)
def poll(self, ram_headroom):
def _ram_gb():
diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py
index 9d170b16e..c47f3c79b 100644
--- a/comfy_execution/graph.py
+++ b/comfy_execution/graph.py
@@ -204,12 +204,12 @@ class ExecutionList(TopologicalSort):
self.execution_cache_listeners = {}
def is_cached(self, node_id):
- return self.output_cache.get(node_id) is not None
+ return self.output_cache.get_local(node_id) is not None
def cache_link(self, from_node_id, to_node_id):
if to_node_id not in self.execution_cache:
self.execution_cache[to_node_id] = {}
- self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id)
+ self.execution_cache[to_node_id][from_node_id] = self.output_cache.get_local(from_node_id)
if from_node_id not in self.execution_cache_listeners:
self.execution_cache_listeners[from_node_id] = set()
self.execution_cache_listeners[from_node_id].add(to_node_id)
@@ -221,7 +221,7 @@ class ExecutionList(TopologicalSort):
if value is None:
return None
#Write back to the main cache on touch.
- self.output_cache.set(from_node_id, value)
+ self.output_cache.set_local(from_node_id, value)
return value
def cache_update(self, node_id, value):
diff --git a/execution.py b/execution.py
index a7791efed..a8e8fc59f 100644
--- a/execution.py
+++ b/execution.py
@@ -40,6 +40,7 @@ from comfy_execution.progress import get_progress_state, reset_progress_state, a
from comfy_execution.utils import CurrentNodeContext
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
from comfy_api.latest import io, _io
+from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger
class ExecutionResult(Enum):
@@ -418,7 +419,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
inputs = dynprompt.get_node(unique_id)['inputs']
class_type = dynprompt.get_node(unique_id)['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
- cached = caches.outputs.get(unique_id)
+ cached = await caches.outputs.get(unique_id)
if cached is not None:
if server.client_id is not None:
cached_ui = cached.ui or {}
@@ -474,10 +475,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
- obj = caches.objects.get(unique_id)
+ obj = await caches.objects.get(unique_id)
if obj is None:
obj = class_def()
- caches.objects.set(unique_id, obj)
+ await caches.objects.set(unique_id, obj)
if issubclass(class_def, _ComfyNodeInternal):
lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None
@@ -588,7 +589,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
execution_list.cache_update(unique_id, cache_entry)
- caches.outputs.set(unique_id, cache_entry)
+ await caches.outputs.set(unique_id, cache_entry)
except comfy.model_management.InterruptProcessingException as iex:
logging.info("Processing interrupted")
@@ -684,6 +685,19 @@ class PromptExecutor:
}
self.add_message("execution_error", mes, broadcast=False)
+ def _notify_prompt_lifecycle(self, event: str, prompt_id: str):
+ if not _has_cache_providers():
+ return
+
+ for provider in _get_cache_providers():
+ try:
+ if event == "start":
+ provider.on_prompt_start(prompt_id)
+ elif event == "end":
+ provider.on_prompt_end(prompt_id)
+ except Exception as e:
+ _cache_logger.warning(f"Cache provider {provider.__class__.__name__} error on {event}: {e}")
+
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
@@ -700,66 +714,75 @@ class PromptExecutor:
self.status_messages = []
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
- with torch.inference_mode():
- dynamic_prompt = DynamicPrompt(prompt)
- reset_progress_state(prompt_id, dynamic_prompt)
- add_progress_handler(WebUIProgressHandler(self.server))
- is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
- for cache in self.caches.all:
- await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
- cache.clean_unused()
+ self._notify_prompt_lifecycle("start", prompt_id)
- cached_nodes = []
- for node_id in prompt:
- if self.caches.outputs.get(node_id) is not None:
- cached_nodes.append(node_id)
+ try:
+ with torch.inference_mode():
+ dynamic_prompt = DynamicPrompt(prompt)
+ reset_progress_state(prompt_id, dynamic_prompt)
+ add_progress_handler(WebUIProgressHandler(self.server))
+ is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
+ for cache in self.caches.all:
+ await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
+ cache.clean_unused()
- comfy.model_management.cleanup_models_gc()
- self.add_message("execution_cached",
- { "nodes": cached_nodes, "prompt_id": prompt_id},
- broadcast=False)
- pending_subgraph_results = {}
- pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
- ui_node_outputs = {}
- executed = set()
- execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
- current_outputs = self.caches.outputs.all_node_ids()
- for node_id in list(execute_outputs):
- execution_list.add_node(node_id)
+ node_ids = list(prompt.keys())
+ cache_results = await asyncio.gather(
+ *(self.caches.outputs.get(node_id) for node_id in node_ids)
+ )
+ cached_nodes = [
+ node_id for node_id, result in zip(node_ids, cache_results)
+ if result is not None
+ ]
- while not execution_list.is_empty():
- node_id, error, ex = await execution_list.stage_node_execution()
- if error is not None:
- self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
- break
+ comfy.model_management.cleanup_models_gc()
+ self.add_message("execution_cached",
+ { "nodes": cached_nodes, "prompt_id": prompt_id},
+ broadcast=False)
+ pending_subgraph_results = {}
+ pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
+ ui_node_outputs = {}
+ executed = set()
+ execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
+ current_outputs = self.caches.outputs.all_node_ids()
+ for node_id in list(execute_outputs):
+ execution_list.add_node(node_id)
- assert node_id is not None, "Node ID should not be None at this point"
- result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
- self.success = result != ExecutionResult.FAILURE
- if result == ExecutionResult.FAILURE:
- self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
- break
- elif result == ExecutionResult.PENDING:
- execution_list.unstage_node_execution()
- else: # result == ExecutionResult.SUCCESS:
- execution_list.complete_node_execution()
- self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
- else:
- # Only execute when the while-loop ends without break
- self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
+ while not execution_list.is_empty():
+ node_id, error, ex = await execution_list.stage_node_execution()
+ if error is not None:
+ self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
+ break
- ui_outputs = {}
- meta_outputs = {}
- for node_id, ui_info in ui_node_outputs.items():
- ui_outputs[node_id] = ui_info["output"]
- meta_outputs[node_id] = ui_info["meta"]
- self.history_result = {
- "outputs": ui_outputs,
- "meta": meta_outputs,
- }
- self.server.last_node_id = None
- if comfy.model_management.DISABLE_SMART_MEMORY:
- comfy.model_management.unload_all_models()
+ assert node_id is not None, "Node ID should not be None at this point"
+ result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
+ self.success = result != ExecutionResult.FAILURE
+ if result == ExecutionResult.FAILURE:
+ self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
+ break
+ elif result == ExecutionResult.PENDING:
+ execution_list.unstage_node_execution()
+ else: # result == ExecutionResult.SUCCESS:
+ execution_list.complete_node_execution()
+ self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
+ else:
+ # Only execute when the while-loop ends without break
+ self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
+
+ ui_outputs = {}
+ meta_outputs = {}
+ for node_id, ui_info in ui_node_outputs.items():
+ ui_outputs[node_id] = ui_info["output"]
+ meta_outputs[node_id] = ui_info["meta"]
+ self.history_result = {
+ "outputs": ui_outputs,
+ "meta": meta_outputs,
+ }
+ self.server.last_node_id = None
+ if comfy.model_management.DISABLE_SMART_MEMORY:
+ comfy.model_management.unload_all_models()
+ finally:
+ self._notify_prompt_lifecycle("end", prompt_id)
async def validate_inputs(prompt_id, prompt, item, validated):
diff --git a/tests-unit/execution_test/test_cache_provider.py b/tests-unit/execution_test/test_cache_provider.py
new file mode 100644
index 000000000..ac3814746
--- /dev/null
+++ b/tests-unit/execution_test/test_cache_provider.py
@@ -0,0 +1,403 @@
+"""Tests for external cache provider API."""
+
+import importlib.util
+import pytest
+from typing import Optional
+
+
+def _torch_available() -> bool:
+ """Check if PyTorch is available."""
+ return importlib.util.find_spec("torch") is not None
+
+
+from comfy_execution.cache_provider import (
+ CacheProvider,
+ CacheContext,
+ CacheValue,
+ register_cache_provider,
+ unregister_cache_provider,
+ _get_cache_providers,
+ _has_cache_providers,
+ _clear_cache_providers,
+ _serialize_cache_key,
+ _contains_self_unequal,
+ _estimate_value_size,
+ _canonicalize,
+)
+
+
+class TestCanonicalize:
+ """Test _canonicalize function for deterministic ordering."""
+
+ def test_frozenset_ordering_is_deterministic(self):
+ """Frozensets should produce consistent canonical form regardless of iteration order."""
+ # Create two frozensets with same content
+ fs1 = frozenset([("a", 1), ("b", 2), ("c", 3)])
+ fs2 = frozenset([("c", 3), ("a", 1), ("b", 2)])
+
+ result1 = _canonicalize(fs1)
+ result2 = _canonicalize(fs2)
+
+ assert result1 == result2
+
+ def test_nested_frozenset_ordering(self):
+ """Nested frozensets should also be deterministically ordered."""
+ inner1 = frozenset([1, 2, 3])
+ inner2 = frozenset([3, 2, 1])
+
+ fs1 = frozenset([("key", inner1)])
+ fs2 = frozenset([("key", inner2)])
+
+ result1 = _canonicalize(fs1)
+ result2 = _canonicalize(fs2)
+
+ assert result1 == result2
+
+ def test_dict_ordering(self):
+ """Dicts should be sorted by key."""
+ d1 = {"z": 1, "a": 2, "m": 3}
+ d2 = {"a": 2, "m": 3, "z": 1}
+
+ result1 = _canonicalize(d1)
+ result2 = _canonicalize(d2)
+
+ assert result1 == result2
+
+ def test_tuple_preserved(self):
+ """Tuples should be marked and preserved."""
+ t = (1, 2, 3)
+ result = _canonicalize(t)
+
+ assert result[0] == "__tuple__"
+
+ def test_list_preserved(self):
+ """Lists should be recursively canonicalized."""
+ lst = [{"b": 2, "a": 1}, frozenset([3, 2, 1])]
+ result = _canonicalize(lst)
+
+ # First element should be canonicalized dict
+ assert "__dict__" in result[0]
+ # Second element should be canonicalized frozenset
+ assert result[1][0] == "__frozenset__"
+
+ def test_primitives_include_type(self):
+ """Primitive types should include type name for disambiguation."""
+ assert _canonicalize(42) == ("int", 42)
+ assert _canonicalize(3.14) == ("float", 3.14)
+ assert _canonicalize("hello") == ("str", "hello")
+ assert _canonicalize(True) == ("bool", True)
+ assert _canonicalize(None) == ("NoneType", None)
+
+ def test_int_and_str_distinguished(self):
+ """int 7 and str '7' must produce different canonical forms."""
+ assert _canonicalize(7) != _canonicalize("7")
+
+ def test_bytes_converted(self):
+ """Bytes should be converted to hex string."""
+ b = b"\x00\xff"
+ result = _canonicalize(b)
+
+ assert result[0] == "__bytes__"
+ assert result[1] == "00ff"
+
+ def test_set_ordering(self):
+ """Sets should be sorted like frozensets."""
+ s1 = {3, 1, 2}
+ s2 = {1, 2, 3}
+
+ result1 = _canonicalize(s1)
+ result2 = _canonicalize(s2)
+
+ assert result1 == result2
+ assert result1[0] == "__set__"
+
+ def test_unknown_type_raises(self):
+ """Unknown types should raise ValueError (fail-closed)."""
+ class CustomObj:
+ pass
+ with pytest.raises(ValueError):
+ _canonicalize(CustomObj())
+
+ def test_object_with_value_attr_raises(self):
+ """Objects with .value attribute (Unhashable-like) should raise ValueError."""
+ class FakeUnhashable:
+ def __init__(self):
+ self.value = float('nan')
+ with pytest.raises(ValueError):
+ _canonicalize(FakeUnhashable())
+
+
+class TestSerializeCacheKey:
+ """Test _serialize_cache_key for deterministic hashing."""
+
+ def test_same_content_same_hash(self):
+ """Same content should produce same hash."""
+ key1 = frozenset([("node_1", frozenset([("input", "value")]))])
+ key2 = frozenset([("node_1", frozenset([("input", "value")]))])
+
+ hash1 = _serialize_cache_key(key1)
+ hash2 = _serialize_cache_key(key2)
+
+ assert hash1 == hash2
+
+ def test_different_content_different_hash(self):
+ """Different content should produce different hash."""
+ key1 = frozenset([("node_1", "value_a")])
+ key2 = frozenset([("node_1", "value_b")])
+
+ hash1 = _serialize_cache_key(key1)
+ hash2 = _serialize_cache_key(key2)
+
+ assert hash1 != hash2
+
+ def test_returns_hex_string(self):
+ """Should return hex string (SHA256 hex digest)."""
+ key = frozenset([("test", 123)])
+ result = _serialize_cache_key(key)
+
+ assert isinstance(result, str)
+ assert len(result) == 64 # SHA256 hex digest is 64 chars
+
+ def test_complex_nested_structure(self):
+ """Complex nested structures should hash deterministically."""
+ # Note: frozensets can only contain hashable types, so we use
+ # nested frozensets of tuples to represent dict-like structures
+ key = frozenset([
+ ("node_1", frozenset([
+ ("input_a", ("tuple", "value")),
+ ("input_b", frozenset([("nested", "dict")])),
+ ])),
+ ("node_2", frozenset([
+ ("param", 42),
+ ])),
+ ])
+
+ # Hash twice to verify determinism
+ hash1 = _serialize_cache_key(key)
+ hash2 = _serialize_cache_key(key)
+
+ assert hash1 == hash2
+
+ def test_dict_in_cache_key(self):
+ """Dicts passed directly to _serialize_cache_key should work."""
+ key = {"node_1": {"input": "value"}, "node_2": 42}
+
+ hash1 = _serialize_cache_key(key)
+ hash2 = _serialize_cache_key(key)
+
+ assert hash1 == hash2
+ assert isinstance(hash1, str)
+ assert len(hash1) == 64
+
+ def test_unknown_type_returns_none(self):
+ """Non-cacheable types should return None (fail-closed)."""
+ class CustomObj:
+ pass
+ assert _serialize_cache_key(CustomObj()) is None
+
+
+class TestContainsSelfUnequal:
+ """Test _contains_self_unequal utility function."""
+
+ def test_nan_float_detected(self):
+ """NaN floats should be detected (not equal to itself)."""
+ assert _contains_self_unequal(float('nan')) is True
+
+ def test_regular_float_not_detected(self):
+ """Regular floats are equal to themselves."""
+ assert _contains_self_unequal(3.14) is False
+ assert _contains_self_unequal(0.0) is False
+ assert _contains_self_unequal(-1.5) is False
+
+ def test_infinity_not_detected(self):
+ """Infinity is equal to itself."""
+ assert _contains_self_unequal(float('inf')) is False
+ assert _contains_self_unequal(float('-inf')) is False
+
+ def test_nan_in_list(self):
+ """NaN in list should be detected."""
+ assert _contains_self_unequal([1, 2, float('nan'), 4]) is True
+ assert _contains_self_unequal([1, 2, 3, 4]) is False
+
+ def test_nan_in_tuple(self):
+ """NaN in tuple should be detected."""
+ assert _contains_self_unequal((1, float('nan'))) is True
+ assert _contains_self_unequal((1, 2, 3)) is False
+
+ def test_nan_in_frozenset(self):
+ """NaN in frozenset should be detected."""
+ assert _contains_self_unequal(frozenset([1, float('nan')])) is True
+ assert _contains_self_unequal(frozenset([1, 2, 3])) is False
+
+ def test_nan_in_dict_value(self):
+ """NaN in dict value should be detected."""
+ assert _contains_self_unequal({"key": float('nan')}) is True
+ assert _contains_self_unequal({"key": 42}) is False
+
+ def test_nan_in_nested_structure(self):
+ """NaN in deeply nested structure should be detected."""
+ nested = {"level1": [{"level2": (1, 2, float('nan'))}]}
+ assert _contains_self_unequal(nested) is True
+
+ def test_non_numeric_types(self):
+ """Non-numeric types should not be self-unequal."""
+ assert _contains_self_unequal("string") is False
+ assert _contains_self_unequal(None) is False
+ assert _contains_self_unequal(True) is False
+
+ def test_object_with_nan_value_attribute(self):
+ """Objects wrapping NaN in .value should be detected."""
+ class NanWrapper:
+ def __init__(self):
+ self.value = float('nan')
+ assert _contains_self_unequal(NanWrapper()) is True
+
+ def test_custom_self_unequal_object(self):
+ """Custom objects where not (x == x) should be detected."""
+ class NeverEqual:
+ def __eq__(self, other):
+ return False
+ assert _contains_self_unequal(NeverEqual()) is True
+
+
+class TestEstimateValueSize:
+ """Test _estimate_value_size utility function."""
+
+ def test_empty_outputs(self):
+ """Empty outputs should have zero size."""
+ value = CacheValue(outputs=[])
+ assert _estimate_value_size(value) == 0
+
+ @pytest.mark.skipif(
+ not _torch_available(),
+ reason="PyTorch not available"
+ )
+ def test_tensor_size_estimation(self):
+ """Tensor size should be estimated correctly."""
+ import torch
+
+ # 1000 float32 elements = 4000 bytes
+ tensor = torch.zeros(1000, dtype=torch.float32)
+ value = CacheValue(outputs=[[tensor]])
+
+ size = _estimate_value_size(value)
+ assert size == 4000
+
+ @pytest.mark.skipif(
+ not _torch_available(),
+ reason="PyTorch not available"
+ )
+ def test_nested_tensor_in_dict(self):
+ """Tensors nested in dicts should be counted."""
+ import torch
+
+ tensor = torch.zeros(100, dtype=torch.float32) # 400 bytes
+ value = CacheValue(outputs=[[{"samples": tensor}]])
+
+ size = _estimate_value_size(value)
+ assert size == 400
+
+
+class TestProviderRegistry:
+ """Test cache provider registration and retrieval."""
+
+ def setup_method(self):
+ """Clear providers before each test."""
+ _clear_cache_providers()
+
+ def teardown_method(self):
+ """Clear providers after each test."""
+ _clear_cache_providers()
+
+ def test_register_provider(self):
+ """Provider should be registered successfully."""
+ provider = MockCacheProvider()
+ register_cache_provider(provider)
+
+ assert _has_cache_providers() is True
+ providers = _get_cache_providers()
+ assert len(providers) == 1
+ assert providers[0] is provider
+
+ def test_unregister_provider(self):
+ """Provider should be unregistered successfully."""
+ provider = MockCacheProvider()
+ register_cache_provider(provider)
+ unregister_cache_provider(provider)
+
+ assert _has_cache_providers() is False
+
+ def test_multiple_providers(self):
+ """Multiple providers can be registered."""
+ provider1 = MockCacheProvider()
+ provider2 = MockCacheProvider()
+
+ register_cache_provider(provider1)
+ register_cache_provider(provider2)
+
+ providers = _get_cache_providers()
+ assert len(providers) == 2
+
+ def test_duplicate_registration_ignored(self):
+ """Registering same provider twice should be ignored."""
+ provider = MockCacheProvider()
+
+ register_cache_provider(provider)
+ register_cache_provider(provider) # Should be ignored
+
+ providers = _get_cache_providers()
+ assert len(providers) == 1
+
+ def test_clear_providers(self):
+ """_clear_cache_providers should remove all providers."""
+ provider1 = MockCacheProvider()
+ provider2 = MockCacheProvider()
+
+ register_cache_provider(provider1)
+ register_cache_provider(provider2)
+ _clear_cache_providers()
+
+ assert _has_cache_providers() is False
+ assert len(_get_cache_providers()) == 0
+
+
+class TestCacheContext:
+ """Test CacheContext dataclass."""
+
+ def test_context_creation(self):
+ """CacheContext should be created with all fields."""
+ context = CacheContext(
+ node_id="node-456",
+ class_type="KSampler",
+ cache_key_hash="a" * 64,
+ )
+
+ assert context.node_id == "node-456"
+ assert context.class_type == "KSampler"
+ assert context.cache_key_hash == "a" * 64
+
+
+class TestCacheValue:
+ """Test CacheValue dataclass."""
+
+ def test_value_creation(self):
+ """CacheValue should be created with outputs."""
+ outputs = [[{"samples": "tensor_data"}]]
+ value = CacheValue(outputs=outputs)
+
+ assert value.outputs == outputs
+
+
+class MockCacheProvider(CacheProvider):
+ """Mock cache provider for testing."""
+
+ def __init__(self):
+ self.lookups = []
+ self.stores = []
+
+ async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
+ self.lookups.append(context)
+ return None
+
+ async def on_store(self, context: CacheContext, value: CacheValue) -> None:
+ self.stores.append((context, value))
From d1d53c14be8442fca19aae978e944edad1935d46 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Thu, 12 Mar 2026 17:21:23 -0700
Subject: [PATCH 15/80] Revert "feat: Add CacheProvider API for external
distributed caching (#12056)" (#12912)
This reverts commit af7b4a921d7abab7c852d7b5febb654be6e57eba.
---
comfy_api/latest/__init__.py | 35 --
comfy_api/latest/_caching.py | 42 --
comfy_execution/cache_provider.py | 138 ------
comfy_execution/caching.py | 177 +-------
comfy_execution/graph.py | 6 +-
execution.py | 141 +++---
.../execution_test/test_cache_provider.py | 403 ------------------
7 files changed, 83 insertions(+), 859 deletions(-)
delete mode 100644 comfy_api/latest/_caching.py
delete mode 100644 comfy_execution/cache_provider.py
delete mode 100644 tests-unit/execution_test/test_cache_provider.py
diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py
index 04973fea0..f2399422b 100644
--- a/comfy_api/latest/__init__.py
+++ b/comfy_api/latest/__init__.py
@@ -25,7 +25,6 @@ class ComfyAPI_latest(ComfyAPIBase):
super().__init__()
self.node_replacement = self.NodeReplacement()
self.execution = self.Execution()
- self.caching = self.Caching()
class NodeReplacement(ProxiedSingleton):
async def register(self, node_replace: io.NodeReplace) -> None:
@@ -85,36 +84,6 @@ class ComfyAPI_latest(ComfyAPIBase):
image=to_display,
)
- class Caching(ProxiedSingleton):
- """
- External cache provider API for sharing cached node outputs
- across ComfyUI instances.
-
- Example::
-
- from comfy_api.latest import Caching
-
- class MyCacheProvider(Caching.CacheProvider):
- async def on_lookup(self, context):
- ... # check external storage
-
- async def on_store(self, context, value):
- ... # store to external storage
-
- Caching.register_provider(MyCacheProvider())
- """
- from ._caching import CacheProvider, CacheContext, CacheValue
-
- async def register_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None:
- """Register an external cache provider. Providers are called in registration order."""
- from comfy_execution.cache_provider import register_cache_provider
- register_cache_provider(provider)
-
- async def unregister_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None:
- """Unregister a previously registered cache provider."""
- from comfy_execution.cache_provider import unregister_cache_provider
- unregister_cache_provider(provider)
-
class ComfyExtension(ABC):
async def on_load(self) -> None:
"""
@@ -147,9 +116,6 @@ class Types:
VOXEL = VOXEL
File3D = File3D
-
-Caching = ComfyAPI_latest.Caching
-
ComfyAPI = ComfyAPI_latest
# Create a synchronous version of the API
@@ -169,7 +135,6 @@ __all__ = [
"Input",
"InputImpl",
"Types",
- "Caching",
"ComfyExtension",
"io",
"IO",
diff --git a/comfy_api/latest/_caching.py b/comfy_api/latest/_caching.py
deleted file mode 100644
index 30c8848cd..000000000
--- a/comfy_api/latest/_caching.py
+++ /dev/null
@@ -1,42 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import Optional
-from dataclasses import dataclass
-
-
-@dataclass
-class CacheContext:
- node_id: str
- class_type: str
- cache_key_hash: str # SHA256 hex digest
-
-
-@dataclass
-class CacheValue:
- outputs: list
- ui: dict = None
-
-
-class CacheProvider(ABC):
- """Abstract base class for external cache providers.
- Exceptions from provider methods are caught by the caller and never break execution.
- """
-
- @abstractmethod
- async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
- """Called on local cache miss. Return CacheValue if found, None otherwise."""
- pass
-
- @abstractmethod
- async def on_store(self, context: CacheContext, value: CacheValue) -> None:
- """Called after local store. Dispatched via asyncio.create_task."""
- pass
-
- def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool:
- """Return False to skip external caching for this node. Default: True."""
- return True
-
- def on_prompt_start(self, prompt_id: str) -> None:
- pass
-
- def on_prompt_end(self, prompt_id: str) -> None:
- pass
diff --git a/comfy_execution/cache_provider.py b/comfy_execution/cache_provider.py
deleted file mode 100644
index d455d08e8..000000000
--- a/comfy_execution/cache_provider.py
+++ /dev/null
@@ -1,138 +0,0 @@
-from typing import Any, Optional, Tuple, List
-import hashlib
-import json
-import logging
-import threading
-
-# Public types — source of truth is comfy_api.latest._caching
-from comfy_api.latest._caching import CacheProvider, CacheContext, CacheValue # noqa: F401 (re-exported)
-
-_logger = logging.getLogger(__name__)
-
-
-_providers: List[CacheProvider] = []
-_providers_lock = threading.Lock()
-_providers_snapshot: Tuple[CacheProvider, ...] = ()
-
-
-def register_cache_provider(provider: CacheProvider) -> None:
- """Register an external cache provider. Providers are called in registration order."""
- global _providers_snapshot
- with _providers_lock:
- if provider in _providers:
- _logger.warning(f"Provider {provider.__class__.__name__} already registered")
- return
- _providers.append(provider)
- _providers_snapshot = tuple(_providers)
- _logger.debug(f"Registered cache provider: {provider.__class__.__name__}")
-
-
-def unregister_cache_provider(provider: CacheProvider) -> None:
- global _providers_snapshot
- with _providers_lock:
- try:
- _providers.remove(provider)
- _providers_snapshot = tuple(_providers)
- _logger.debug(f"Unregistered cache provider: {provider.__class__.__name__}")
- except ValueError:
- _logger.warning(f"Provider {provider.__class__.__name__} was not registered")
-
-
-def _get_cache_providers() -> Tuple[CacheProvider, ...]:
- return _providers_snapshot
-
-
-def _has_cache_providers() -> bool:
- return bool(_providers_snapshot)
-
-
-def _clear_cache_providers() -> None:
- global _providers_snapshot
- with _providers_lock:
- _providers.clear()
- _providers_snapshot = ()
-
-
-def _canonicalize(obj: Any) -> Any:
- # Convert to canonical JSON-serializable form with deterministic ordering.
- # Frozensets have non-deterministic iteration order between Python sessions.
- # Raises ValueError for non-cacheable types (Unhashable, unknown) so that
- # _serialize_cache_key returns None and external caching is skipped.
- if isinstance(obj, frozenset):
- return ("__frozenset__", sorted(
- [_canonicalize(item) for item in obj],
- key=lambda x: json.dumps(x, sort_keys=True)
- ))
- elif isinstance(obj, set):
- return ("__set__", sorted(
- [_canonicalize(item) for item in obj],
- key=lambda x: json.dumps(x, sort_keys=True)
- ))
- elif isinstance(obj, tuple):
- return ("__tuple__", [_canonicalize(item) for item in obj])
- elif isinstance(obj, list):
- return [_canonicalize(item) for item in obj]
- elif isinstance(obj, dict):
- return {"__dict__": sorted(
- [[_canonicalize(k), _canonicalize(v)] for k, v in obj.items()],
- key=lambda x: json.dumps(x, sort_keys=True)
- )}
- elif isinstance(obj, (int, float, str, bool, type(None))):
- return (type(obj).__name__, obj)
- elif isinstance(obj, bytes):
- return ("__bytes__", obj.hex())
- else:
- raise ValueError(f"Cannot canonicalize type: {type(obj).__name__}")
-
-
-def _serialize_cache_key(cache_key: Any) -> Optional[str]:
- # Returns deterministic SHA256 hex digest, or None on failure.
- # Uses JSON (not pickle) because pickle is non-deterministic across sessions.
- try:
- canonical = _canonicalize(cache_key)
- json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':'))
- return hashlib.sha256(json_str.encode('utf-8')).hexdigest()
- except Exception as e:
- _logger.warning(f"Failed to serialize cache key: {e}")
- return None
-
-
-def _contains_self_unequal(obj: Any) -> bool:
- # Local cache matches by ==. Values where not (x == x) (NaN, etc.) will
- # never hit locally, but serialized form would match externally. Skip these.
- try:
- if not (obj == obj):
- return True
- except Exception:
- return True
- if isinstance(obj, (frozenset, tuple, list, set)):
- return any(_contains_self_unequal(item) for item in obj)
- if isinstance(obj, dict):
- return any(_contains_self_unequal(k) or _contains_self_unequal(v) for k, v in obj.items())
- if hasattr(obj, 'value'):
- return _contains_self_unequal(obj.value)
- return False
-
-
-def _estimate_value_size(value: CacheValue) -> int:
- try:
- import torch
- except ImportError:
- return 0
-
- total = 0
-
- def estimate(obj):
- nonlocal total
- if isinstance(obj, torch.Tensor):
- total += obj.numel() * obj.element_size()
- elif isinstance(obj, dict):
- for v in obj.values():
- estimate(v)
- elif isinstance(obj, (list, tuple)):
- for item in obj:
- estimate(item)
-
- for output in value.outputs:
- estimate(output)
- return total
diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py
index 750bddf2e..326a279fc 100644
--- a/comfy_execution/caching.py
+++ b/comfy_execution/caching.py
@@ -1,4 +1,3 @@
-import asyncio
import bisect
import gc
import itertools
@@ -155,7 +154,6 @@ class BasicCache:
self.cache_key_set: CacheKeySet
self.cache = {}
self.subcaches = {}
- self._pending_store_tasks: set = set()
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
self.dynprompt = dynprompt
@@ -198,134 +196,18 @@ class BasicCache:
def poll(self, **kwargs):
pass
- def get_local(self, node_id):
+ def _set_immediate(self, node_id, value):
+ assert self.initialized
+ cache_key = self.cache_key_set.get_data_key(node_id)
+ self.cache[cache_key] = value
+
+ def _get_immediate(self, node_id):
if not self.initialized:
return None
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key in self.cache:
return self.cache[cache_key]
- return None
-
- def set_local(self, node_id, value):
- assert self.initialized
- cache_key = self.cache_key_set.get_data_key(node_id)
- self.cache[cache_key] = value
-
- async def _set_immediate(self, node_id, value):
- assert self.initialized
- cache_key = self.cache_key_set.get_data_key(node_id)
- self.cache[cache_key] = value
-
- await self._notify_providers_store(node_id, cache_key, value)
-
- async def _get_immediate(self, node_id):
- if not self.initialized:
- return None
- cache_key = self.cache_key_set.get_data_key(node_id)
-
- if cache_key in self.cache:
- return self.cache[cache_key]
-
- external_result = await self._check_providers_lookup(node_id, cache_key)
- if external_result is not None:
- self.cache[cache_key] = external_result
- return external_result
-
- return None
-
- async def _notify_providers_store(self, node_id, cache_key, value):
- from comfy_execution.cache_provider import (
- _has_cache_providers, _get_cache_providers,
- CacheValue, _contains_self_unequal, _logger
- )
-
- if not _has_cache_providers():
- return
- if not self._is_external_cacheable_value(value):
- return
- if _contains_self_unequal(cache_key):
- return
-
- context = self._build_context(node_id, cache_key)
- if context is None:
- return
- cache_value = CacheValue(outputs=value.outputs, ui=value.ui)
-
- for provider in _get_cache_providers():
- try:
- if provider.should_cache(context, cache_value):
- task = asyncio.create_task(self._safe_provider_store(provider, context, cache_value))
- self._pending_store_tasks.add(task)
- task.add_done_callback(self._pending_store_tasks.discard)
- except Exception as e:
- _logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}")
-
- @staticmethod
- async def _safe_provider_store(provider, context, cache_value):
- from comfy_execution.cache_provider import _logger
- try:
- await provider.on_store(context, cache_value)
- except Exception as e:
- _logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}")
-
- async def _check_providers_lookup(self, node_id, cache_key):
- from comfy_execution.cache_provider import (
- _has_cache_providers, _get_cache_providers,
- CacheValue, _contains_self_unequal, _logger
- )
-
- if not _has_cache_providers():
- return None
- if _contains_self_unequal(cache_key):
- return None
-
- context = self._build_context(node_id, cache_key)
- if context is None:
- return None
-
- for provider in _get_cache_providers():
- try:
- if not provider.should_cache(context):
- continue
- result = await provider.on_lookup(context)
- if result is not None:
- if not isinstance(result, CacheValue):
- _logger.warning(f"Provider {provider.__class__.__name__} returned invalid type")
- continue
- if not isinstance(result.outputs, (list, tuple)):
- _logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs")
- continue
- from execution import CacheEntry
- return CacheEntry(ui=result.ui or {}, outputs=list(result.outputs))
- except Exception as e:
- _logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}")
-
- return None
-
- def _is_external_cacheable_value(self, value):
- return hasattr(value, 'outputs') and hasattr(value, 'ui')
-
- def _get_class_type(self, node_id):
- if not self.initialized or not self.dynprompt:
- return ''
- try:
- return self.dynprompt.get_node(node_id).get('class_type', '')
- except Exception:
- return ''
-
- def _build_context(self, node_id, cache_key):
- from comfy_execution.cache_provider import CacheContext, _serialize_cache_key, _logger
- try:
- cache_key_hash = _serialize_cache_key(cache_key)
- if cache_key_hash is None:
- return None
- return CacheContext(
- node_id=node_id,
- class_type=self._get_class_type(node_id),
- cache_key_hash=cache_key_hash,
- )
- except Exception as e:
- _logger.warning(f"Failed to build cache context for node {node_id}: {e}")
+ else:
return None
async def _ensure_subcache(self, node_id, children_ids):
@@ -375,27 +257,16 @@ class HierarchicalCache(BasicCache):
return None
return cache
- async def get(self, node_id):
+ def get(self, node_id):
cache = self._get_cache_for(node_id)
if cache is None:
return None
- return await cache._get_immediate(node_id)
+ return cache._get_immediate(node_id)
- def get_local(self, node_id):
- cache = self._get_cache_for(node_id)
- if cache is None:
- return None
- return BasicCache.get_local(cache, node_id)
-
- async def set(self, node_id, value):
+ def set(self, node_id, value):
cache = self._get_cache_for(node_id)
assert cache is not None
- await cache._set_immediate(node_id, value)
-
- def set_local(self, node_id, value):
- cache = self._get_cache_for(node_id)
- assert cache is not None
- BasicCache.set_local(cache, node_id, value)
+ cache._set_immediate(node_id, value)
async def ensure_subcache_for(self, node_id, children_ids):
cache = self._get_cache_for(node_id)
@@ -416,16 +287,10 @@ class NullCache:
def poll(self, **kwargs):
pass
- async def get(self, node_id):
+ def get(self, node_id):
return None
- def get_local(self, node_id):
- return None
-
- async def set(self, node_id, value):
- pass
-
- def set_local(self, node_id, value):
+ def set(self, node_id, value):
pass
async def ensure_subcache_for(self, node_id, children_ids):
@@ -457,18 +322,18 @@ class LRUCache(BasicCache):
del self.children[key]
self._clean_subcaches()
- async def get(self, node_id):
+ def get(self, node_id):
self._mark_used(node_id)
- return await self._get_immediate(node_id)
+ return self._get_immediate(node_id)
def _mark_used(self, node_id):
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key is not None:
self.used_generation[cache_key] = self.generation
- async def set(self, node_id, value):
+ def set(self, node_id, value):
self._mark_used(node_id)
- return await self._set_immediate(node_id, value)
+ return self._set_immediate(node_id, value)
async def ensure_subcache_for(self, node_id, children_ids):
# Just uses subcaches for tracking 'live' nodes
@@ -508,13 +373,13 @@ class RAMPressureCache(LRUCache):
def clean_unused(self):
self._clean_subcaches()
- async def set(self, node_id, value):
+ def set(self, node_id, value):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
- await super().set(node_id, value)
+ super().set(node_id, value)
- async def get(self, node_id):
+ def get(self, node_id):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
- return await super().get(node_id)
+ return super().get(node_id)
def poll(self, ram_headroom):
def _ram_gb():
diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py
index c47f3c79b..9d170b16e 100644
--- a/comfy_execution/graph.py
+++ b/comfy_execution/graph.py
@@ -204,12 +204,12 @@ class ExecutionList(TopologicalSort):
self.execution_cache_listeners = {}
def is_cached(self, node_id):
- return self.output_cache.get_local(node_id) is not None
+ return self.output_cache.get(node_id) is not None
def cache_link(self, from_node_id, to_node_id):
if to_node_id not in self.execution_cache:
self.execution_cache[to_node_id] = {}
- self.execution_cache[to_node_id][from_node_id] = self.output_cache.get_local(from_node_id)
+ self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id)
if from_node_id not in self.execution_cache_listeners:
self.execution_cache_listeners[from_node_id] = set()
self.execution_cache_listeners[from_node_id].add(to_node_id)
@@ -221,7 +221,7 @@ class ExecutionList(TopologicalSort):
if value is None:
return None
#Write back to the main cache on touch.
- self.output_cache.set_local(from_node_id, value)
+ self.output_cache.set(from_node_id, value)
return value
def cache_update(self, node_id, value):
diff --git a/execution.py b/execution.py
index a8e8fc59f..a7791efed 100644
--- a/execution.py
+++ b/execution.py
@@ -40,7 +40,6 @@ from comfy_execution.progress import get_progress_state, reset_progress_state, a
from comfy_execution.utils import CurrentNodeContext
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
from comfy_api.latest import io, _io
-from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger
class ExecutionResult(Enum):
@@ -419,7 +418,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
inputs = dynprompt.get_node(unique_id)['inputs']
class_type = dynprompt.get_node(unique_id)['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
- cached = await caches.outputs.get(unique_id)
+ cached = caches.outputs.get(unique_id)
if cached is not None:
if server.client_id is not None:
cached_ui = cached.ui or {}
@@ -475,10 +474,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
- obj = await caches.objects.get(unique_id)
+ obj = caches.objects.get(unique_id)
if obj is None:
obj = class_def()
- await caches.objects.set(unique_id, obj)
+ caches.objects.set(unique_id, obj)
if issubclass(class_def, _ComfyNodeInternal):
lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None
@@ -589,7 +588,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
execution_list.cache_update(unique_id, cache_entry)
- await caches.outputs.set(unique_id, cache_entry)
+ caches.outputs.set(unique_id, cache_entry)
except comfy.model_management.InterruptProcessingException as iex:
logging.info("Processing interrupted")
@@ -685,19 +684,6 @@ class PromptExecutor:
}
self.add_message("execution_error", mes, broadcast=False)
- def _notify_prompt_lifecycle(self, event: str, prompt_id: str):
- if not _has_cache_providers():
- return
-
- for provider in _get_cache_providers():
- try:
- if event == "start":
- provider.on_prompt_start(prompt_id)
- elif event == "end":
- provider.on_prompt_end(prompt_id)
- except Exception as e:
- _cache_logger.warning(f"Cache provider {provider.__class__.__name__} error on {event}: {e}")
-
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
@@ -714,75 +700,66 @@ class PromptExecutor:
self.status_messages = []
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
- self._notify_prompt_lifecycle("start", prompt_id)
+ with torch.inference_mode():
+ dynamic_prompt = DynamicPrompt(prompt)
+ reset_progress_state(prompt_id, dynamic_prompt)
+ add_progress_handler(WebUIProgressHandler(self.server))
+ is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
+ for cache in self.caches.all:
+ await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
+ cache.clean_unused()
- try:
- with torch.inference_mode():
- dynamic_prompt = DynamicPrompt(prompt)
- reset_progress_state(prompt_id, dynamic_prompt)
- add_progress_handler(WebUIProgressHandler(self.server))
- is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
- for cache in self.caches.all:
- await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
- cache.clean_unused()
+ cached_nodes = []
+ for node_id in prompt:
+ if self.caches.outputs.get(node_id) is not None:
+ cached_nodes.append(node_id)
- node_ids = list(prompt.keys())
- cache_results = await asyncio.gather(
- *(self.caches.outputs.get(node_id) for node_id in node_ids)
- )
- cached_nodes = [
- node_id for node_id, result in zip(node_ids, cache_results)
- if result is not None
- ]
+ comfy.model_management.cleanup_models_gc()
+ self.add_message("execution_cached",
+ { "nodes": cached_nodes, "prompt_id": prompt_id},
+ broadcast=False)
+ pending_subgraph_results = {}
+ pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
+ ui_node_outputs = {}
+ executed = set()
+ execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
+ current_outputs = self.caches.outputs.all_node_ids()
+ for node_id in list(execute_outputs):
+ execution_list.add_node(node_id)
- comfy.model_management.cleanup_models_gc()
- self.add_message("execution_cached",
- { "nodes": cached_nodes, "prompt_id": prompt_id},
- broadcast=False)
- pending_subgraph_results = {}
- pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
- ui_node_outputs = {}
- executed = set()
- execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
- current_outputs = self.caches.outputs.all_node_ids()
- for node_id in list(execute_outputs):
- execution_list.add_node(node_id)
+ while not execution_list.is_empty():
+ node_id, error, ex = await execution_list.stage_node_execution()
+ if error is not None:
+ self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
+ break
- while not execution_list.is_empty():
- node_id, error, ex = await execution_list.stage_node_execution()
- if error is not None:
- self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
- break
+ assert node_id is not None, "Node ID should not be None at this point"
+ result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
+ self.success = result != ExecutionResult.FAILURE
+ if result == ExecutionResult.FAILURE:
+ self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
+ break
+ elif result == ExecutionResult.PENDING:
+ execution_list.unstage_node_execution()
+ else: # result == ExecutionResult.SUCCESS:
+ execution_list.complete_node_execution()
+ self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
+ else:
+ # Only execute when the while-loop ends without break
+ self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
- assert node_id is not None, "Node ID should not be None at this point"
- result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
- self.success = result != ExecutionResult.FAILURE
- if result == ExecutionResult.FAILURE:
- self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
- break
- elif result == ExecutionResult.PENDING:
- execution_list.unstage_node_execution()
- else: # result == ExecutionResult.SUCCESS:
- execution_list.complete_node_execution()
- self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
- else:
- # Only execute when the while-loop ends without break
- self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
-
- ui_outputs = {}
- meta_outputs = {}
- for node_id, ui_info in ui_node_outputs.items():
- ui_outputs[node_id] = ui_info["output"]
- meta_outputs[node_id] = ui_info["meta"]
- self.history_result = {
- "outputs": ui_outputs,
- "meta": meta_outputs,
- }
- self.server.last_node_id = None
- if comfy.model_management.DISABLE_SMART_MEMORY:
- comfy.model_management.unload_all_models()
- finally:
- self._notify_prompt_lifecycle("end", prompt_id)
+ ui_outputs = {}
+ meta_outputs = {}
+ for node_id, ui_info in ui_node_outputs.items():
+ ui_outputs[node_id] = ui_info["output"]
+ meta_outputs[node_id] = ui_info["meta"]
+ self.history_result = {
+ "outputs": ui_outputs,
+ "meta": meta_outputs,
+ }
+ self.server.last_node_id = None
+ if comfy.model_management.DISABLE_SMART_MEMORY:
+ comfy.model_management.unload_all_models()
async def validate_inputs(prompt_id, prompt, item, validated):
diff --git a/tests-unit/execution_test/test_cache_provider.py b/tests-unit/execution_test/test_cache_provider.py
deleted file mode 100644
index ac3814746..000000000
--- a/tests-unit/execution_test/test_cache_provider.py
+++ /dev/null
@@ -1,403 +0,0 @@
-"""Tests for external cache provider API."""
-
-import importlib.util
-import pytest
-from typing import Optional
-
-
-def _torch_available() -> bool:
- """Check if PyTorch is available."""
- return importlib.util.find_spec("torch") is not None
-
-
-from comfy_execution.cache_provider import (
- CacheProvider,
- CacheContext,
- CacheValue,
- register_cache_provider,
- unregister_cache_provider,
- _get_cache_providers,
- _has_cache_providers,
- _clear_cache_providers,
- _serialize_cache_key,
- _contains_self_unequal,
- _estimate_value_size,
- _canonicalize,
-)
-
-
-class TestCanonicalize:
- """Test _canonicalize function for deterministic ordering."""
-
- def test_frozenset_ordering_is_deterministic(self):
- """Frozensets should produce consistent canonical form regardless of iteration order."""
- # Create two frozensets with same content
- fs1 = frozenset([("a", 1), ("b", 2), ("c", 3)])
- fs2 = frozenset([("c", 3), ("a", 1), ("b", 2)])
-
- result1 = _canonicalize(fs1)
- result2 = _canonicalize(fs2)
-
- assert result1 == result2
-
- def test_nested_frozenset_ordering(self):
- """Nested frozensets should also be deterministically ordered."""
- inner1 = frozenset([1, 2, 3])
- inner2 = frozenset([3, 2, 1])
-
- fs1 = frozenset([("key", inner1)])
- fs2 = frozenset([("key", inner2)])
-
- result1 = _canonicalize(fs1)
- result2 = _canonicalize(fs2)
-
- assert result1 == result2
-
- def test_dict_ordering(self):
- """Dicts should be sorted by key."""
- d1 = {"z": 1, "a": 2, "m": 3}
- d2 = {"a": 2, "m": 3, "z": 1}
-
- result1 = _canonicalize(d1)
- result2 = _canonicalize(d2)
-
- assert result1 == result2
-
- def test_tuple_preserved(self):
- """Tuples should be marked and preserved."""
- t = (1, 2, 3)
- result = _canonicalize(t)
-
- assert result[0] == "__tuple__"
-
- def test_list_preserved(self):
- """Lists should be recursively canonicalized."""
- lst = [{"b": 2, "a": 1}, frozenset([3, 2, 1])]
- result = _canonicalize(lst)
-
- # First element should be canonicalized dict
- assert "__dict__" in result[0]
- # Second element should be canonicalized frozenset
- assert result[1][0] == "__frozenset__"
-
- def test_primitives_include_type(self):
- """Primitive types should include type name for disambiguation."""
- assert _canonicalize(42) == ("int", 42)
- assert _canonicalize(3.14) == ("float", 3.14)
- assert _canonicalize("hello") == ("str", "hello")
- assert _canonicalize(True) == ("bool", True)
- assert _canonicalize(None) == ("NoneType", None)
-
- def test_int_and_str_distinguished(self):
- """int 7 and str '7' must produce different canonical forms."""
- assert _canonicalize(7) != _canonicalize("7")
-
- def test_bytes_converted(self):
- """Bytes should be converted to hex string."""
- b = b"\x00\xff"
- result = _canonicalize(b)
-
- assert result[0] == "__bytes__"
- assert result[1] == "00ff"
-
- def test_set_ordering(self):
- """Sets should be sorted like frozensets."""
- s1 = {3, 1, 2}
- s2 = {1, 2, 3}
-
- result1 = _canonicalize(s1)
- result2 = _canonicalize(s2)
-
- assert result1 == result2
- assert result1[0] == "__set__"
-
- def test_unknown_type_raises(self):
- """Unknown types should raise ValueError (fail-closed)."""
- class CustomObj:
- pass
- with pytest.raises(ValueError):
- _canonicalize(CustomObj())
-
- def test_object_with_value_attr_raises(self):
- """Objects with .value attribute (Unhashable-like) should raise ValueError."""
- class FakeUnhashable:
- def __init__(self):
- self.value = float('nan')
- with pytest.raises(ValueError):
- _canonicalize(FakeUnhashable())
-
-
-class TestSerializeCacheKey:
- """Test _serialize_cache_key for deterministic hashing."""
-
- def test_same_content_same_hash(self):
- """Same content should produce same hash."""
- key1 = frozenset([("node_1", frozenset([("input", "value")]))])
- key2 = frozenset([("node_1", frozenset([("input", "value")]))])
-
- hash1 = _serialize_cache_key(key1)
- hash2 = _serialize_cache_key(key2)
-
- assert hash1 == hash2
-
- def test_different_content_different_hash(self):
- """Different content should produce different hash."""
- key1 = frozenset([("node_1", "value_a")])
- key2 = frozenset([("node_1", "value_b")])
-
- hash1 = _serialize_cache_key(key1)
- hash2 = _serialize_cache_key(key2)
-
- assert hash1 != hash2
-
- def test_returns_hex_string(self):
- """Should return hex string (SHA256 hex digest)."""
- key = frozenset([("test", 123)])
- result = _serialize_cache_key(key)
-
- assert isinstance(result, str)
- assert len(result) == 64 # SHA256 hex digest is 64 chars
-
- def test_complex_nested_structure(self):
- """Complex nested structures should hash deterministically."""
- # Note: frozensets can only contain hashable types, so we use
- # nested frozensets of tuples to represent dict-like structures
- key = frozenset([
- ("node_1", frozenset([
- ("input_a", ("tuple", "value")),
- ("input_b", frozenset([("nested", "dict")])),
- ])),
- ("node_2", frozenset([
- ("param", 42),
- ])),
- ])
-
- # Hash twice to verify determinism
- hash1 = _serialize_cache_key(key)
- hash2 = _serialize_cache_key(key)
-
- assert hash1 == hash2
-
- def test_dict_in_cache_key(self):
- """Dicts passed directly to _serialize_cache_key should work."""
- key = {"node_1": {"input": "value"}, "node_2": 42}
-
- hash1 = _serialize_cache_key(key)
- hash2 = _serialize_cache_key(key)
-
- assert hash1 == hash2
- assert isinstance(hash1, str)
- assert len(hash1) == 64
-
- def test_unknown_type_returns_none(self):
- """Non-cacheable types should return None (fail-closed)."""
- class CustomObj:
- pass
- assert _serialize_cache_key(CustomObj()) is None
-
-
-class TestContainsSelfUnequal:
- """Test _contains_self_unequal utility function."""
-
- def test_nan_float_detected(self):
- """NaN floats should be detected (not equal to itself)."""
- assert _contains_self_unequal(float('nan')) is True
-
- def test_regular_float_not_detected(self):
- """Regular floats are equal to themselves."""
- assert _contains_self_unequal(3.14) is False
- assert _contains_self_unequal(0.0) is False
- assert _contains_self_unequal(-1.5) is False
-
- def test_infinity_not_detected(self):
- """Infinity is equal to itself."""
- assert _contains_self_unequal(float('inf')) is False
- assert _contains_self_unequal(float('-inf')) is False
-
- def test_nan_in_list(self):
- """NaN in list should be detected."""
- assert _contains_self_unequal([1, 2, float('nan'), 4]) is True
- assert _contains_self_unequal([1, 2, 3, 4]) is False
-
- def test_nan_in_tuple(self):
- """NaN in tuple should be detected."""
- assert _contains_self_unequal((1, float('nan'))) is True
- assert _contains_self_unequal((1, 2, 3)) is False
-
- def test_nan_in_frozenset(self):
- """NaN in frozenset should be detected."""
- assert _contains_self_unequal(frozenset([1, float('nan')])) is True
- assert _contains_self_unequal(frozenset([1, 2, 3])) is False
-
- def test_nan_in_dict_value(self):
- """NaN in dict value should be detected."""
- assert _contains_self_unequal({"key": float('nan')}) is True
- assert _contains_self_unequal({"key": 42}) is False
-
- def test_nan_in_nested_structure(self):
- """NaN in deeply nested structure should be detected."""
- nested = {"level1": [{"level2": (1, 2, float('nan'))}]}
- assert _contains_self_unequal(nested) is True
-
- def test_non_numeric_types(self):
- """Non-numeric types should not be self-unequal."""
- assert _contains_self_unequal("string") is False
- assert _contains_self_unequal(None) is False
- assert _contains_self_unequal(True) is False
-
- def test_object_with_nan_value_attribute(self):
- """Objects wrapping NaN in .value should be detected."""
- class NanWrapper:
- def __init__(self):
- self.value = float('nan')
- assert _contains_self_unequal(NanWrapper()) is True
-
- def test_custom_self_unequal_object(self):
- """Custom objects where not (x == x) should be detected."""
- class NeverEqual:
- def __eq__(self, other):
- return False
- assert _contains_self_unequal(NeverEqual()) is True
-
-
-class TestEstimateValueSize:
- """Test _estimate_value_size utility function."""
-
- def test_empty_outputs(self):
- """Empty outputs should have zero size."""
- value = CacheValue(outputs=[])
- assert _estimate_value_size(value) == 0
-
- @pytest.mark.skipif(
- not _torch_available(),
- reason="PyTorch not available"
- )
- def test_tensor_size_estimation(self):
- """Tensor size should be estimated correctly."""
- import torch
-
- # 1000 float32 elements = 4000 bytes
- tensor = torch.zeros(1000, dtype=torch.float32)
- value = CacheValue(outputs=[[tensor]])
-
- size = _estimate_value_size(value)
- assert size == 4000
-
- @pytest.mark.skipif(
- not _torch_available(),
- reason="PyTorch not available"
- )
- def test_nested_tensor_in_dict(self):
- """Tensors nested in dicts should be counted."""
- import torch
-
- tensor = torch.zeros(100, dtype=torch.float32) # 400 bytes
- value = CacheValue(outputs=[[{"samples": tensor}]])
-
- size = _estimate_value_size(value)
- assert size == 400
-
-
-class TestProviderRegistry:
- """Test cache provider registration and retrieval."""
-
- def setup_method(self):
- """Clear providers before each test."""
- _clear_cache_providers()
-
- def teardown_method(self):
- """Clear providers after each test."""
- _clear_cache_providers()
-
- def test_register_provider(self):
- """Provider should be registered successfully."""
- provider = MockCacheProvider()
- register_cache_provider(provider)
-
- assert _has_cache_providers() is True
- providers = _get_cache_providers()
- assert len(providers) == 1
- assert providers[0] is provider
-
- def test_unregister_provider(self):
- """Provider should be unregistered successfully."""
- provider = MockCacheProvider()
- register_cache_provider(provider)
- unregister_cache_provider(provider)
-
- assert _has_cache_providers() is False
-
- def test_multiple_providers(self):
- """Multiple providers can be registered."""
- provider1 = MockCacheProvider()
- provider2 = MockCacheProvider()
-
- register_cache_provider(provider1)
- register_cache_provider(provider2)
-
- providers = _get_cache_providers()
- assert len(providers) == 2
-
- def test_duplicate_registration_ignored(self):
- """Registering same provider twice should be ignored."""
- provider = MockCacheProvider()
-
- register_cache_provider(provider)
- register_cache_provider(provider) # Should be ignored
-
- providers = _get_cache_providers()
- assert len(providers) == 1
-
- def test_clear_providers(self):
- """_clear_cache_providers should remove all providers."""
- provider1 = MockCacheProvider()
- provider2 = MockCacheProvider()
-
- register_cache_provider(provider1)
- register_cache_provider(provider2)
- _clear_cache_providers()
-
- assert _has_cache_providers() is False
- assert len(_get_cache_providers()) == 0
-
-
-class TestCacheContext:
- """Test CacheContext dataclass."""
-
- def test_context_creation(self):
- """CacheContext should be created with all fields."""
- context = CacheContext(
- node_id="node-456",
- class_type="KSampler",
- cache_key_hash="a" * 64,
- )
-
- assert context.node_id == "node-456"
- assert context.class_type == "KSampler"
- assert context.cache_key_hash == "a" * 64
-
-
-class TestCacheValue:
- """Test CacheValue dataclass."""
-
- def test_value_creation(self):
- """CacheValue should be created with outputs."""
- outputs = [[{"samples": "tensor_data"}]]
- value = CacheValue(outputs=outputs)
-
- assert value.outputs == outputs
-
-
-class MockCacheProvider(CacheProvider):
- """Mock cache provider for testing."""
-
- def __init__(self):
- self.lookups = []
- self.stores = []
-
- async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
- self.lookups.append(context)
- return None
-
- async def on_store(self, context: CacheContext, value: CacheValue) -> None:
- self.stores.append((context, value))
From 5df1427124f6ceb70166326ee257d52076adea37 Mon Sep 17 00:00:00 2001
From: PxTicks
Date: Fri, 13 Mar 2026 00:44:15 +0000
Subject: [PATCH 16/80] Fix audio extraction and truncation bugs (#12652)
Bug report in #12651
- to_skip fix: Prevents negative array slicing when the start offset is negative.
- __duration check: Prevents the extraction loop from breaking after a single audio chunk when the requested duration is 0 (which is a sentinel for unlimited).
---
comfy_api/latest/_input_impl/video_types.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py
index 58a37c9e8..1b4993aa7 100644
--- a/comfy_api/latest/_input_impl/video_types.py
+++ b/comfy_api/latest/_input_impl/video_types.py
@@ -272,7 +272,7 @@ class VideoFromFile(VideoInput):
has_first_frame = False
for frame in frames:
offset_seconds = start_time - frame.pts * audio_stream.time_base
- to_skip = int(offset_seconds * audio_stream.sample_rate)
+ to_skip = max(0, int(offset_seconds * audio_stream.sample_rate))
if to_skip < frame.samples:
has_first_frame = True
break
@@ -280,7 +280,7 @@ class VideoFromFile(VideoInput):
audio_frames.append(frame.to_ndarray()[..., to_skip:])
for frame in frames:
- if frame.time > start_time + self.__duration:
+ if self.__duration and frame.time > start_time + self.__duration:
break
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
if len(audio_frames) > 0:
From 63d1bbdb407c69370d407ce5ced6ca3f917528a8 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Thu, 12 Mar 2026 20:41:48 -0400
Subject: [PATCH 17/80] ComfyUI v0.17.0
---
comfyui_version.py | 2 +-
pyproject.toml | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/comfyui_version.py b/comfyui_version.py
index 2723d02e7..701f4d66a 100644
--- a/comfyui_version.py
+++ b/comfyui_version.py
@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
-__version__ = "0.16.4"
+__version__ = "0.17.0"
diff --git a/pyproject.toml b/pyproject.toml
index 753b219b3..e2ca79be7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
-version = "0.16.4"
+version = "0.17.0"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"
From 4a8cf359fe596fc4c25a0d335d303e42c3f8605d Mon Sep 17 00:00:00 2001
From: Deep Mehta <42841935+deepme987@users.noreply.github.com>
Date: Thu, 12 Mar 2026 21:17:50 -0700
Subject: [PATCH 18/80] Revert "Revert "feat: Add CacheProvider API for
external distributed caching"" (#12915)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* Revert "Revert "feat: Add CacheProvider API for external distributed caching …"
This reverts commit d1d53c14be8442fca19aae978e944edad1935d46.
* fix: gate provider lookups to outputs cache and fix UI coercion
- Add `enable_providers` flag to BasicCache so only the outputs cache
triggers external provider lookups/stores. The objects cache stores
node class instances, not CacheEntry values, so provider calls were
wasted round-trips that always missed.
- Remove `or {}` coercion on `result.ui` — an empty dict passes the
`is not None` gate in execution.py and causes KeyError when the
history builder indexes `["output"]` and `["meta"]`. Preserving
`None` correctly skips the ui_node_outputs addition.
---
comfy_api/latest/__init__.py | 35 ++
comfy_api/latest/_caching.py | 42 ++
comfy_execution/cache_provider.py | 138 ++++++
comfy_execution/caching.py | 196 +++++++--
comfy_execution/graph.py | 6 +-
execution.py | 147 ++++---
.../execution_test/test_cache_provider.py | 403 ++++++++++++++++++
7 files changed, 874 insertions(+), 93 deletions(-)
create mode 100644 comfy_api/latest/_caching.py
create mode 100644 comfy_execution/cache_provider.py
create mode 100644 tests-unit/execution_test/test_cache_provider.py
diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py
index f2399422b..04973fea0 100644
--- a/comfy_api/latest/__init__.py
+++ b/comfy_api/latest/__init__.py
@@ -25,6 +25,7 @@ class ComfyAPI_latest(ComfyAPIBase):
super().__init__()
self.node_replacement = self.NodeReplacement()
self.execution = self.Execution()
+ self.caching = self.Caching()
class NodeReplacement(ProxiedSingleton):
async def register(self, node_replace: io.NodeReplace) -> None:
@@ -84,6 +85,36 @@ class ComfyAPI_latest(ComfyAPIBase):
image=to_display,
)
+ class Caching(ProxiedSingleton):
+ """
+ External cache provider API for sharing cached node outputs
+ across ComfyUI instances.
+
+ Example::
+
+ from comfy_api.latest import Caching
+
+ class MyCacheProvider(Caching.CacheProvider):
+ async def on_lookup(self, context):
+ ... # check external storage
+
+ async def on_store(self, context, value):
+ ... # store to external storage
+
+ Caching.register_provider(MyCacheProvider())
+ """
+ from ._caching import CacheProvider, CacheContext, CacheValue
+
+ async def register_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None:
+ """Register an external cache provider. Providers are called in registration order."""
+ from comfy_execution.cache_provider import register_cache_provider
+ register_cache_provider(provider)
+
+ async def unregister_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None:
+ """Unregister a previously registered cache provider."""
+ from comfy_execution.cache_provider import unregister_cache_provider
+ unregister_cache_provider(provider)
+
class ComfyExtension(ABC):
async def on_load(self) -> None:
"""
@@ -116,6 +147,9 @@ class Types:
VOXEL = VOXEL
File3D = File3D
+
+Caching = ComfyAPI_latest.Caching
+
ComfyAPI = ComfyAPI_latest
# Create a synchronous version of the API
@@ -135,6 +169,7 @@ __all__ = [
"Input",
"InputImpl",
"Types",
+ "Caching",
"ComfyExtension",
"io",
"IO",
diff --git a/comfy_api/latest/_caching.py b/comfy_api/latest/_caching.py
new file mode 100644
index 000000000..30c8848cd
--- /dev/null
+++ b/comfy_api/latest/_caching.py
@@ -0,0 +1,42 @@
+from abc import ABC, abstractmethod
+from typing import Optional
+from dataclasses import dataclass
+
+
+@dataclass
+class CacheContext:
+ node_id: str
+ class_type: str
+ cache_key_hash: str # SHA256 hex digest
+
+
+@dataclass
+class CacheValue:
+ outputs: list
+ ui: dict = None
+
+
+class CacheProvider(ABC):
+ """Abstract base class for external cache providers.
+ Exceptions from provider methods are caught by the caller and never break execution.
+ """
+
+ @abstractmethod
+ async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
+ """Called on local cache miss. Return CacheValue if found, None otherwise."""
+ pass
+
+ @abstractmethod
+ async def on_store(self, context: CacheContext, value: CacheValue) -> None:
+ """Called after local store. Dispatched via asyncio.create_task."""
+ pass
+
+ def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool:
+ """Return False to skip external caching for this node. Default: True."""
+ return True
+
+ def on_prompt_start(self, prompt_id: str) -> None:
+ pass
+
+ def on_prompt_end(self, prompt_id: str) -> None:
+ pass
diff --git a/comfy_execution/cache_provider.py b/comfy_execution/cache_provider.py
new file mode 100644
index 000000000..d455d08e8
--- /dev/null
+++ b/comfy_execution/cache_provider.py
@@ -0,0 +1,138 @@
+from typing import Any, Optional, Tuple, List
+import hashlib
+import json
+import logging
+import threading
+
+# Public types — source of truth is comfy_api.latest._caching
+from comfy_api.latest._caching import CacheProvider, CacheContext, CacheValue # noqa: F401 (re-exported)
+
+_logger = logging.getLogger(__name__)
+
+
+_providers: List[CacheProvider] = []
+_providers_lock = threading.Lock()
+_providers_snapshot: Tuple[CacheProvider, ...] = ()
+
+
+def register_cache_provider(provider: CacheProvider) -> None:
+ """Register an external cache provider. Providers are called in registration order."""
+ global _providers_snapshot
+ with _providers_lock:
+ if provider in _providers:
+ _logger.warning(f"Provider {provider.__class__.__name__} already registered")
+ return
+ _providers.append(provider)
+ _providers_snapshot = tuple(_providers)
+ _logger.debug(f"Registered cache provider: {provider.__class__.__name__}")
+
+
+def unregister_cache_provider(provider: CacheProvider) -> None:
+ global _providers_snapshot
+ with _providers_lock:
+ try:
+ _providers.remove(provider)
+ _providers_snapshot = tuple(_providers)
+ _logger.debug(f"Unregistered cache provider: {provider.__class__.__name__}")
+ except ValueError:
+ _logger.warning(f"Provider {provider.__class__.__name__} was not registered")
+
+
+def _get_cache_providers() -> Tuple[CacheProvider, ...]:
+ return _providers_snapshot
+
+
+def _has_cache_providers() -> bool:
+ return bool(_providers_snapshot)
+
+
+def _clear_cache_providers() -> None:
+ global _providers_snapshot
+ with _providers_lock:
+ _providers.clear()
+ _providers_snapshot = ()
+
+
+def _canonicalize(obj: Any) -> Any:
+ # Convert to canonical JSON-serializable form with deterministic ordering.
+ # Frozensets have non-deterministic iteration order between Python sessions.
+ # Raises ValueError for non-cacheable types (Unhashable, unknown) so that
+ # _serialize_cache_key returns None and external caching is skipped.
+ if isinstance(obj, frozenset):
+ return ("__frozenset__", sorted(
+ [_canonicalize(item) for item in obj],
+ key=lambda x: json.dumps(x, sort_keys=True)
+ ))
+ elif isinstance(obj, set):
+ return ("__set__", sorted(
+ [_canonicalize(item) for item in obj],
+ key=lambda x: json.dumps(x, sort_keys=True)
+ ))
+ elif isinstance(obj, tuple):
+ return ("__tuple__", [_canonicalize(item) for item in obj])
+ elif isinstance(obj, list):
+ return [_canonicalize(item) for item in obj]
+ elif isinstance(obj, dict):
+ return {"__dict__": sorted(
+ [[_canonicalize(k), _canonicalize(v)] for k, v in obj.items()],
+ key=lambda x: json.dumps(x, sort_keys=True)
+ )}
+ elif isinstance(obj, (int, float, str, bool, type(None))):
+ return (type(obj).__name__, obj)
+ elif isinstance(obj, bytes):
+ return ("__bytes__", obj.hex())
+ else:
+ raise ValueError(f"Cannot canonicalize type: {type(obj).__name__}")
+
+
+def _serialize_cache_key(cache_key: Any) -> Optional[str]:
+ # Returns deterministic SHA256 hex digest, or None on failure.
+ # Uses JSON (not pickle) because pickle is non-deterministic across sessions.
+ try:
+ canonical = _canonicalize(cache_key)
+ json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':'))
+ return hashlib.sha256(json_str.encode('utf-8')).hexdigest()
+ except Exception as e:
+ _logger.warning(f"Failed to serialize cache key: {e}")
+ return None
+
+
+def _contains_self_unequal(obj: Any) -> bool:
+ # Local cache matches by ==. Values where not (x == x) (NaN, etc.) will
+ # never hit locally, but serialized form would match externally. Skip these.
+ try:
+ if not (obj == obj):
+ return True
+ except Exception:
+ return True
+ if isinstance(obj, (frozenset, tuple, list, set)):
+ return any(_contains_self_unequal(item) for item in obj)
+ if isinstance(obj, dict):
+ return any(_contains_self_unequal(k) or _contains_self_unequal(v) for k, v in obj.items())
+ if hasattr(obj, 'value'):
+ return _contains_self_unequal(obj.value)
+ return False
+
+
+def _estimate_value_size(value: CacheValue) -> int:
+ try:
+ import torch
+ except ImportError:
+ return 0
+
+ total = 0
+
+ def estimate(obj):
+ nonlocal total
+ if isinstance(obj, torch.Tensor):
+ total += obj.numel() * obj.element_size()
+ elif isinstance(obj, dict):
+ for v in obj.values():
+ estimate(v)
+ elif isinstance(obj, (list, tuple)):
+ for item in obj:
+ estimate(item)
+
+ for output in value.outputs:
+ estimate(output)
+ return total
diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py
index 326a279fc..78212bde3 100644
--- a/comfy_execution/caching.py
+++ b/comfy_execution/caching.py
@@ -1,3 +1,4 @@
+import asyncio
import bisect
import gc
import itertools
@@ -147,13 +148,15 @@ class CacheKeySetInputSignature(CacheKeySet):
self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
class BasicCache:
- def __init__(self, key_class):
+ def __init__(self, key_class, enable_providers=False):
self.key_class = key_class
self.initialized = False
+ self.enable_providers = enable_providers
self.dynprompt: DynamicPrompt
self.cache_key_set: CacheKeySet
self.cache = {}
self.subcaches = {}
+ self._pending_store_tasks: set = set()
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
self.dynprompt = dynprompt
@@ -196,18 +199,138 @@ class BasicCache:
def poll(self, **kwargs):
pass
- def _set_immediate(self, node_id, value):
- assert self.initialized
- cache_key = self.cache_key_set.get_data_key(node_id)
- self.cache[cache_key] = value
-
- def _get_immediate(self, node_id):
+ def get_local(self, node_id):
if not self.initialized:
return None
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key in self.cache:
return self.cache[cache_key]
- else:
+ return None
+
+ def set_local(self, node_id, value):
+ assert self.initialized
+ cache_key = self.cache_key_set.get_data_key(node_id)
+ self.cache[cache_key] = value
+
+ async def _set_immediate(self, node_id, value):
+ assert self.initialized
+ cache_key = self.cache_key_set.get_data_key(node_id)
+ self.cache[cache_key] = value
+
+ await self._notify_providers_store(node_id, cache_key, value)
+
+ async def _get_immediate(self, node_id):
+ if not self.initialized:
+ return None
+ cache_key = self.cache_key_set.get_data_key(node_id)
+
+ if cache_key in self.cache:
+ return self.cache[cache_key]
+
+ external_result = await self._check_providers_lookup(node_id, cache_key)
+ if external_result is not None:
+ self.cache[cache_key] = external_result
+ return external_result
+
+ return None
+
+ async def _notify_providers_store(self, node_id, cache_key, value):
+ from comfy_execution.cache_provider import (
+ _has_cache_providers, _get_cache_providers,
+ CacheValue, _contains_self_unequal, _logger
+ )
+
+ if not self.enable_providers:
+ return
+ if not _has_cache_providers():
+ return
+ if not self._is_external_cacheable_value(value):
+ return
+ if _contains_self_unequal(cache_key):
+ return
+
+ context = self._build_context(node_id, cache_key)
+ if context is None:
+ return
+ cache_value = CacheValue(outputs=value.outputs, ui=value.ui)
+
+ for provider in _get_cache_providers():
+ try:
+ if provider.should_cache(context, cache_value):
+ task = asyncio.create_task(self._safe_provider_store(provider, context, cache_value))
+ self._pending_store_tasks.add(task)
+ task.add_done_callback(self._pending_store_tasks.discard)
+ except Exception as e:
+ _logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}")
+
+ @staticmethod
+ async def _safe_provider_store(provider, context, cache_value):
+ from comfy_execution.cache_provider import _logger
+ try:
+ await provider.on_store(context, cache_value)
+ except Exception as e:
+ _logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}")
+
+ async def _check_providers_lookup(self, node_id, cache_key):
+ from comfy_execution.cache_provider import (
+ _has_cache_providers, _get_cache_providers,
+ CacheValue, _contains_self_unequal, _logger
+ )
+
+ if not self.enable_providers:
+ return None
+ if not _has_cache_providers():
+ return None
+ if _contains_self_unequal(cache_key):
+ return None
+
+ context = self._build_context(node_id, cache_key)
+ if context is None:
+ return None
+
+ for provider in _get_cache_providers():
+ try:
+ if not provider.should_cache(context):
+ continue
+ result = await provider.on_lookup(context)
+ if result is not None:
+ if not isinstance(result, CacheValue):
+ _logger.warning(f"Provider {provider.__class__.__name__} returned invalid type")
+ continue
+ if not isinstance(result.outputs, (list, tuple)):
+ _logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs")
+ continue
+ from execution import CacheEntry
+ return CacheEntry(ui=result.ui, outputs=list(result.outputs))
+ except Exception as e:
+ _logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}")
+
+ return None
+
+ def _is_external_cacheable_value(self, value):
+ return hasattr(value, 'outputs') and hasattr(value, 'ui')
+
+ def _get_class_type(self, node_id):
+ if not self.initialized or not self.dynprompt:
+ return ''
+ try:
+ return self.dynprompt.get_node(node_id).get('class_type', '')
+ except Exception:
+ return ''
+
+ def _build_context(self, node_id, cache_key):
+ from comfy_execution.cache_provider import CacheContext, _serialize_cache_key, _logger
+ try:
+ cache_key_hash = _serialize_cache_key(cache_key)
+ if cache_key_hash is None:
+ return None
+ return CacheContext(
+ node_id=node_id,
+ class_type=self._get_class_type(node_id),
+ cache_key_hash=cache_key_hash,
+ )
+ except Exception as e:
+ _logger.warning(f"Failed to build cache context for node {node_id}: {e}")
return None
async def _ensure_subcache(self, node_id, children_ids):
@@ -236,8 +359,8 @@ class BasicCache:
return result
class HierarchicalCache(BasicCache):
- def __init__(self, key_class):
- super().__init__(key_class)
+ def __init__(self, key_class, enable_providers=False):
+ super().__init__(key_class, enable_providers=enable_providers)
def _get_cache_for(self, node_id):
assert self.dynprompt is not None
@@ -257,16 +380,27 @@ class HierarchicalCache(BasicCache):
return None
return cache
- def get(self, node_id):
+ async def get(self, node_id):
cache = self._get_cache_for(node_id)
if cache is None:
return None
- return cache._get_immediate(node_id)
+ return await cache._get_immediate(node_id)
- def set(self, node_id, value):
+ def get_local(self, node_id):
+ cache = self._get_cache_for(node_id)
+ if cache is None:
+ return None
+ return BasicCache.get_local(cache, node_id)
+
+ async def set(self, node_id, value):
cache = self._get_cache_for(node_id)
assert cache is not None
- cache._set_immediate(node_id, value)
+ await cache._set_immediate(node_id, value)
+
+ def set_local(self, node_id, value):
+ cache = self._get_cache_for(node_id)
+ assert cache is not None
+ BasicCache.set_local(cache, node_id, value)
async def ensure_subcache_for(self, node_id, children_ids):
cache = self._get_cache_for(node_id)
@@ -287,18 +421,24 @@ class NullCache:
def poll(self, **kwargs):
pass
- def get(self, node_id):
+ async def get(self, node_id):
return None
- def set(self, node_id, value):
+ def get_local(self, node_id):
+ return None
+
+ async def set(self, node_id, value):
+ pass
+
+ def set_local(self, node_id, value):
pass
async def ensure_subcache_for(self, node_id, children_ids):
return self
class LRUCache(BasicCache):
- def __init__(self, key_class, max_size=100):
- super().__init__(key_class)
+ def __init__(self, key_class, max_size=100, enable_providers=False):
+ super().__init__(key_class, enable_providers=enable_providers)
self.max_size = max_size
self.min_generation = 0
self.generation = 0
@@ -322,18 +462,18 @@ class LRUCache(BasicCache):
del self.children[key]
self._clean_subcaches()
- def get(self, node_id):
+ async def get(self, node_id):
self._mark_used(node_id)
- return self._get_immediate(node_id)
+ return await self._get_immediate(node_id)
def _mark_used(self, node_id):
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key is not None:
self.used_generation[cache_key] = self.generation
- def set(self, node_id, value):
+ async def set(self, node_id, value):
self._mark_used(node_id)
- return self._set_immediate(node_id, value)
+ return await self._set_immediate(node_id, value)
async def ensure_subcache_for(self, node_id, children_ids):
# Just uses subcaches for tracking 'live' nodes
@@ -366,20 +506,20 @@ RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
class RAMPressureCache(LRUCache):
- def __init__(self, key_class):
- super().__init__(key_class, 0)
+ def __init__(self, key_class, enable_providers=False):
+ super().__init__(key_class, 0, enable_providers=enable_providers)
self.timestamps = {}
def clean_unused(self):
self._clean_subcaches()
- def set(self, node_id, value):
+ async def set(self, node_id, value):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
- super().set(node_id, value)
+ await super().set(node_id, value)
- def get(self, node_id):
+ async def get(self, node_id):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
- return super().get(node_id)
+ return await super().get(node_id)
def poll(self, ram_headroom):
def _ram_gb():
diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py
index 9d170b16e..c47f3c79b 100644
--- a/comfy_execution/graph.py
+++ b/comfy_execution/graph.py
@@ -204,12 +204,12 @@ class ExecutionList(TopologicalSort):
self.execution_cache_listeners = {}
def is_cached(self, node_id):
- return self.output_cache.get(node_id) is not None
+ return self.output_cache.get_local(node_id) is not None
def cache_link(self, from_node_id, to_node_id):
if to_node_id not in self.execution_cache:
self.execution_cache[to_node_id] = {}
- self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id)
+ self.execution_cache[to_node_id][from_node_id] = self.output_cache.get_local(from_node_id)
if from_node_id not in self.execution_cache_listeners:
self.execution_cache_listeners[from_node_id] = set()
self.execution_cache_listeners[from_node_id].add(to_node_id)
@@ -221,7 +221,7 @@ class ExecutionList(TopologicalSort):
if value is None:
return None
#Write back to the main cache on touch.
- self.output_cache.set(from_node_id, value)
+ self.output_cache.set_local(from_node_id, value)
return value
def cache_update(self, node_id, value):
diff --git a/execution.py b/execution.py
index a7791efed..1a6c3429c 100644
--- a/execution.py
+++ b/execution.py
@@ -40,6 +40,7 @@ from comfy_execution.progress import get_progress_state, reset_progress_state, a
from comfy_execution.utils import CurrentNodeContext
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
from comfy_api.latest import io, _io
+from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger
class ExecutionResult(Enum):
@@ -126,15 +127,15 @@ class CacheSet:
# Performs like the old cache -- dump data ASAP
def init_classic_cache(self):
- self.outputs = HierarchicalCache(CacheKeySetInputSignature)
+ self.outputs = HierarchicalCache(CacheKeySetInputSignature, enable_providers=True)
self.objects = HierarchicalCache(CacheKeySetID)
def init_lru_cache(self, cache_size):
- self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
+ self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size, enable_providers=True)
self.objects = HierarchicalCache(CacheKeySetID)
def init_ram_cache(self, min_headroom):
- self.outputs = RAMPressureCache(CacheKeySetInputSignature)
+ self.outputs = RAMPressureCache(CacheKeySetInputSignature, enable_providers=True)
self.objects = HierarchicalCache(CacheKeySetID)
def init_null_cache(self):
@@ -418,7 +419,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
inputs = dynprompt.get_node(unique_id)['inputs']
class_type = dynprompt.get_node(unique_id)['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
- cached = caches.outputs.get(unique_id)
+ cached = await caches.outputs.get(unique_id)
if cached is not None:
if server.client_id is not None:
cached_ui = cached.ui or {}
@@ -474,10 +475,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
- obj = caches.objects.get(unique_id)
+ obj = await caches.objects.get(unique_id)
if obj is None:
obj = class_def()
- caches.objects.set(unique_id, obj)
+ await caches.objects.set(unique_id, obj)
if issubclass(class_def, _ComfyNodeInternal):
lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None
@@ -588,7 +589,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
execution_list.cache_update(unique_id, cache_entry)
- caches.outputs.set(unique_id, cache_entry)
+ await caches.outputs.set(unique_id, cache_entry)
except comfy.model_management.InterruptProcessingException as iex:
logging.info("Processing interrupted")
@@ -684,6 +685,19 @@ class PromptExecutor:
}
self.add_message("execution_error", mes, broadcast=False)
+ def _notify_prompt_lifecycle(self, event: str, prompt_id: str):
+ if not _has_cache_providers():
+ return
+
+ for provider in _get_cache_providers():
+ try:
+ if event == "start":
+ provider.on_prompt_start(prompt_id)
+ elif event == "end":
+ provider.on_prompt_end(prompt_id)
+ except Exception as e:
+ _cache_logger.warning(f"Cache provider {provider.__class__.__name__} error on {event}: {e}")
+
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
@@ -700,66 +714,75 @@ class PromptExecutor:
self.status_messages = []
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
- with torch.inference_mode():
- dynamic_prompt = DynamicPrompt(prompt)
- reset_progress_state(prompt_id, dynamic_prompt)
- add_progress_handler(WebUIProgressHandler(self.server))
- is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
- for cache in self.caches.all:
- await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
- cache.clean_unused()
+ self._notify_prompt_lifecycle("start", prompt_id)
- cached_nodes = []
- for node_id in prompt:
- if self.caches.outputs.get(node_id) is not None:
- cached_nodes.append(node_id)
+ try:
+ with torch.inference_mode():
+ dynamic_prompt = DynamicPrompt(prompt)
+ reset_progress_state(prompt_id, dynamic_prompt)
+ add_progress_handler(WebUIProgressHandler(self.server))
+ is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
+ for cache in self.caches.all:
+ await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
+ cache.clean_unused()
- comfy.model_management.cleanup_models_gc()
- self.add_message("execution_cached",
- { "nodes": cached_nodes, "prompt_id": prompt_id},
- broadcast=False)
- pending_subgraph_results = {}
- pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
- ui_node_outputs = {}
- executed = set()
- execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
- current_outputs = self.caches.outputs.all_node_ids()
- for node_id in list(execute_outputs):
- execution_list.add_node(node_id)
+ node_ids = list(prompt.keys())
+ cache_results = await asyncio.gather(
+ *(self.caches.outputs.get(node_id) for node_id in node_ids)
+ )
+ cached_nodes = [
+ node_id for node_id, result in zip(node_ids, cache_results)
+ if result is not None
+ ]
- while not execution_list.is_empty():
- node_id, error, ex = await execution_list.stage_node_execution()
- if error is not None:
- self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
- break
+ comfy.model_management.cleanup_models_gc()
+ self.add_message("execution_cached",
+ { "nodes": cached_nodes, "prompt_id": prompt_id},
+ broadcast=False)
+ pending_subgraph_results = {}
+ pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
+ ui_node_outputs = {}
+ executed = set()
+ execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
+ current_outputs = self.caches.outputs.all_node_ids()
+ for node_id in list(execute_outputs):
+ execution_list.add_node(node_id)
- assert node_id is not None, "Node ID should not be None at this point"
- result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
- self.success = result != ExecutionResult.FAILURE
- if result == ExecutionResult.FAILURE:
- self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
- break
- elif result == ExecutionResult.PENDING:
- execution_list.unstage_node_execution()
- else: # result == ExecutionResult.SUCCESS:
- execution_list.complete_node_execution()
- self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
- else:
- # Only execute when the while-loop ends without break
- self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
+ while not execution_list.is_empty():
+ node_id, error, ex = await execution_list.stage_node_execution()
+ if error is not None:
+ self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
+ break
- ui_outputs = {}
- meta_outputs = {}
- for node_id, ui_info in ui_node_outputs.items():
- ui_outputs[node_id] = ui_info["output"]
- meta_outputs[node_id] = ui_info["meta"]
- self.history_result = {
- "outputs": ui_outputs,
- "meta": meta_outputs,
- }
- self.server.last_node_id = None
- if comfy.model_management.DISABLE_SMART_MEMORY:
- comfy.model_management.unload_all_models()
+ assert node_id is not None, "Node ID should not be None at this point"
+ result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
+ self.success = result != ExecutionResult.FAILURE
+ if result == ExecutionResult.FAILURE:
+ self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
+ break
+ elif result == ExecutionResult.PENDING:
+ execution_list.unstage_node_execution()
+ else: # result == ExecutionResult.SUCCESS:
+ execution_list.complete_node_execution()
+ self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
+ else:
+ # Only execute when the while-loop ends without break
+ self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
+
+ ui_outputs = {}
+ meta_outputs = {}
+ for node_id, ui_info in ui_node_outputs.items():
+ ui_outputs[node_id] = ui_info["output"]
+ meta_outputs[node_id] = ui_info["meta"]
+ self.history_result = {
+ "outputs": ui_outputs,
+ "meta": meta_outputs,
+ }
+ self.server.last_node_id = None
+ if comfy.model_management.DISABLE_SMART_MEMORY:
+ comfy.model_management.unload_all_models()
+ finally:
+ self._notify_prompt_lifecycle("end", prompt_id)
async def validate_inputs(prompt_id, prompt, item, validated):
diff --git a/tests-unit/execution_test/test_cache_provider.py b/tests-unit/execution_test/test_cache_provider.py
new file mode 100644
index 000000000..ac3814746
--- /dev/null
+++ b/tests-unit/execution_test/test_cache_provider.py
@@ -0,0 +1,403 @@
+"""Tests for external cache provider API."""
+
+import importlib.util
+import pytest
+from typing import Optional
+
+
+def _torch_available() -> bool:
+ """Check if PyTorch is available."""
+ return importlib.util.find_spec("torch") is not None
+
+
+from comfy_execution.cache_provider import (
+ CacheProvider,
+ CacheContext,
+ CacheValue,
+ register_cache_provider,
+ unregister_cache_provider,
+ _get_cache_providers,
+ _has_cache_providers,
+ _clear_cache_providers,
+ _serialize_cache_key,
+ _contains_self_unequal,
+ _estimate_value_size,
+ _canonicalize,
+)
+
+
+class TestCanonicalize:
+ """Test _canonicalize function for deterministic ordering."""
+
+ def test_frozenset_ordering_is_deterministic(self):
+ """Frozensets should produce consistent canonical form regardless of iteration order."""
+ # Create two frozensets with same content
+ fs1 = frozenset([("a", 1), ("b", 2), ("c", 3)])
+ fs2 = frozenset([("c", 3), ("a", 1), ("b", 2)])
+
+ result1 = _canonicalize(fs1)
+ result2 = _canonicalize(fs2)
+
+ assert result1 == result2
+
+ def test_nested_frozenset_ordering(self):
+ """Nested frozensets should also be deterministically ordered."""
+ inner1 = frozenset([1, 2, 3])
+ inner2 = frozenset([3, 2, 1])
+
+ fs1 = frozenset([("key", inner1)])
+ fs2 = frozenset([("key", inner2)])
+
+ result1 = _canonicalize(fs1)
+ result2 = _canonicalize(fs2)
+
+ assert result1 == result2
+
+ def test_dict_ordering(self):
+ """Dicts should be sorted by key."""
+ d1 = {"z": 1, "a": 2, "m": 3}
+ d2 = {"a": 2, "m": 3, "z": 1}
+
+ result1 = _canonicalize(d1)
+ result2 = _canonicalize(d2)
+
+ assert result1 == result2
+
+ def test_tuple_preserved(self):
+ """Tuples should be marked and preserved."""
+ t = (1, 2, 3)
+ result = _canonicalize(t)
+
+ assert result[0] == "__tuple__"
+
+ def test_list_preserved(self):
+ """Lists should be recursively canonicalized."""
+ lst = [{"b": 2, "a": 1}, frozenset([3, 2, 1])]
+ result = _canonicalize(lst)
+
+ # First element should be canonicalized dict
+ assert "__dict__" in result[0]
+ # Second element should be canonicalized frozenset
+ assert result[1][0] == "__frozenset__"
+
+ def test_primitives_include_type(self):
+ """Primitive types should include type name for disambiguation."""
+ assert _canonicalize(42) == ("int", 42)
+ assert _canonicalize(3.14) == ("float", 3.14)
+ assert _canonicalize("hello") == ("str", "hello")
+ assert _canonicalize(True) == ("bool", True)
+ assert _canonicalize(None) == ("NoneType", None)
+
+ def test_int_and_str_distinguished(self):
+ """int 7 and str '7' must produce different canonical forms."""
+ assert _canonicalize(7) != _canonicalize("7")
+
+ def test_bytes_converted(self):
+ """Bytes should be converted to hex string."""
+ b = b"\x00\xff"
+ result = _canonicalize(b)
+
+ assert result[0] == "__bytes__"
+ assert result[1] == "00ff"
+
+ def test_set_ordering(self):
+ """Sets should be sorted like frozensets."""
+ s1 = {3, 1, 2}
+ s2 = {1, 2, 3}
+
+ result1 = _canonicalize(s1)
+ result2 = _canonicalize(s2)
+
+ assert result1 == result2
+ assert result1[0] == "__set__"
+
+ def test_unknown_type_raises(self):
+ """Unknown types should raise ValueError (fail-closed)."""
+ class CustomObj:
+ pass
+ with pytest.raises(ValueError):
+ _canonicalize(CustomObj())
+
+ def test_object_with_value_attr_raises(self):
+ """Objects with .value attribute (Unhashable-like) should raise ValueError."""
+ class FakeUnhashable:
+ def __init__(self):
+ self.value = float('nan')
+ with pytest.raises(ValueError):
+ _canonicalize(FakeUnhashable())
+
+
+class TestSerializeCacheKey:
+ """Test _serialize_cache_key for deterministic hashing."""
+
+ def test_same_content_same_hash(self):
+ """Same content should produce same hash."""
+ key1 = frozenset([("node_1", frozenset([("input", "value")]))])
+ key2 = frozenset([("node_1", frozenset([("input", "value")]))])
+
+ hash1 = _serialize_cache_key(key1)
+ hash2 = _serialize_cache_key(key2)
+
+ assert hash1 == hash2
+
+ def test_different_content_different_hash(self):
+ """Different content should produce different hash."""
+ key1 = frozenset([("node_1", "value_a")])
+ key2 = frozenset([("node_1", "value_b")])
+
+ hash1 = _serialize_cache_key(key1)
+ hash2 = _serialize_cache_key(key2)
+
+ assert hash1 != hash2
+
+ def test_returns_hex_string(self):
+ """Should return hex string (SHA256 hex digest)."""
+ key = frozenset([("test", 123)])
+ result = _serialize_cache_key(key)
+
+ assert isinstance(result, str)
+ assert len(result) == 64 # SHA256 hex digest is 64 chars
+
+ def test_complex_nested_structure(self):
+ """Complex nested structures should hash deterministically."""
+ # Note: frozensets can only contain hashable types, so we use
+ # nested frozensets of tuples to represent dict-like structures
+ key = frozenset([
+ ("node_1", frozenset([
+ ("input_a", ("tuple", "value")),
+ ("input_b", frozenset([("nested", "dict")])),
+ ])),
+ ("node_2", frozenset([
+ ("param", 42),
+ ])),
+ ])
+
+ # Hash twice to verify determinism
+ hash1 = _serialize_cache_key(key)
+ hash2 = _serialize_cache_key(key)
+
+ assert hash1 == hash2
+
+ def test_dict_in_cache_key(self):
+ """Dicts passed directly to _serialize_cache_key should work."""
+ key = {"node_1": {"input": "value"}, "node_2": 42}
+
+ hash1 = _serialize_cache_key(key)
+ hash2 = _serialize_cache_key(key)
+
+ assert hash1 == hash2
+ assert isinstance(hash1, str)
+ assert len(hash1) == 64
+
+ def test_unknown_type_returns_none(self):
+ """Non-cacheable types should return None (fail-closed)."""
+ class CustomObj:
+ pass
+ assert _serialize_cache_key(CustomObj()) is None
+
+
+class TestContainsSelfUnequal:
+ """Test _contains_self_unequal utility function."""
+
+ def test_nan_float_detected(self):
+ """NaN floats should be detected (not equal to itself)."""
+ assert _contains_self_unequal(float('nan')) is True
+
+ def test_regular_float_not_detected(self):
+ """Regular floats are equal to themselves."""
+ assert _contains_self_unequal(3.14) is False
+ assert _contains_self_unequal(0.0) is False
+ assert _contains_self_unequal(-1.5) is False
+
+ def test_infinity_not_detected(self):
+ """Infinity is equal to itself."""
+ assert _contains_self_unequal(float('inf')) is False
+ assert _contains_self_unequal(float('-inf')) is False
+
+ def test_nan_in_list(self):
+ """NaN in list should be detected."""
+ assert _contains_self_unequal([1, 2, float('nan'), 4]) is True
+ assert _contains_self_unequal([1, 2, 3, 4]) is False
+
+ def test_nan_in_tuple(self):
+ """NaN in tuple should be detected."""
+ assert _contains_self_unequal((1, float('nan'))) is True
+ assert _contains_self_unequal((1, 2, 3)) is False
+
+ def test_nan_in_frozenset(self):
+ """NaN in frozenset should be detected."""
+ assert _contains_self_unequal(frozenset([1, float('nan')])) is True
+ assert _contains_self_unequal(frozenset([1, 2, 3])) is False
+
+ def test_nan_in_dict_value(self):
+ """NaN in dict value should be detected."""
+ assert _contains_self_unequal({"key": float('nan')}) is True
+ assert _contains_self_unequal({"key": 42}) is False
+
+ def test_nan_in_nested_structure(self):
+ """NaN in deeply nested structure should be detected."""
+ nested = {"level1": [{"level2": (1, 2, float('nan'))}]}
+ assert _contains_self_unequal(nested) is True
+
+ def test_non_numeric_types(self):
+ """Non-numeric types should not be self-unequal."""
+ assert _contains_self_unequal("string") is False
+ assert _contains_self_unequal(None) is False
+ assert _contains_self_unequal(True) is False
+
+ def test_object_with_nan_value_attribute(self):
+ """Objects wrapping NaN in .value should be detected."""
+ class NanWrapper:
+ def __init__(self):
+ self.value = float('nan')
+ assert _contains_self_unequal(NanWrapper()) is True
+
+ def test_custom_self_unequal_object(self):
+ """Custom objects where not (x == x) should be detected."""
+ class NeverEqual:
+ def __eq__(self, other):
+ return False
+ assert _contains_self_unequal(NeverEqual()) is True
+
+
+class TestEstimateValueSize:
+ """Test _estimate_value_size utility function."""
+
+ def test_empty_outputs(self):
+ """Empty outputs should have zero size."""
+ value = CacheValue(outputs=[])
+ assert _estimate_value_size(value) == 0
+
+ @pytest.mark.skipif(
+ not _torch_available(),
+ reason="PyTorch not available"
+ )
+ def test_tensor_size_estimation(self):
+ """Tensor size should be estimated correctly."""
+ import torch
+
+ # 1000 float32 elements = 4000 bytes
+ tensor = torch.zeros(1000, dtype=torch.float32)
+ value = CacheValue(outputs=[[tensor]])
+
+ size = _estimate_value_size(value)
+ assert size == 4000
+
+ @pytest.mark.skipif(
+ not _torch_available(),
+ reason="PyTorch not available"
+ )
+ def test_nested_tensor_in_dict(self):
+ """Tensors nested in dicts should be counted."""
+ import torch
+
+ tensor = torch.zeros(100, dtype=torch.float32) # 400 bytes
+ value = CacheValue(outputs=[[{"samples": tensor}]])
+
+ size = _estimate_value_size(value)
+ assert size == 400
+
+
+class TestProviderRegistry:
+ """Test cache provider registration and retrieval."""
+
+ def setup_method(self):
+ """Clear providers before each test."""
+ _clear_cache_providers()
+
+ def teardown_method(self):
+ """Clear providers after each test."""
+ _clear_cache_providers()
+
+ def test_register_provider(self):
+ """Provider should be registered successfully."""
+ provider = MockCacheProvider()
+ register_cache_provider(provider)
+
+ assert _has_cache_providers() is True
+ providers = _get_cache_providers()
+ assert len(providers) == 1
+ assert providers[0] is provider
+
+ def test_unregister_provider(self):
+ """Provider should be unregistered successfully."""
+ provider = MockCacheProvider()
+ register_cache_provider(provider)
+ unregister_cache_provider(provider)
+
+ assert _has_cache_providers() is False
+
+ def test_multiple_providers(self):
+ """Multiple providers can be registered."""
+ provider1 = MockCacheProvider()
+ provider2 = MockCacheProvider()
+
+ register_cache_provider(provider1)
+ register_cache_provider(provider2)
+
+ providers = _get_cache_providers()
+ assert len(providers) == 2
+
+ def test_duplicate_registration_ignored(self):
+ """Registering same provider twice should be ignored."""
+ provider = MockCacheProvider()
+
+ register_cache_provider(provider)
+ register_cache_provider(provider) # Should be ignored
+
+ providers = _get_cache_providers()
+ assert len(providers) == 1
+
+ def test_clear_providers(self):
+ """_clear_cache_providers should remove all providers."""
+ provider1 = MockCacheProvider()
+ provider2 = MockCacheProvider()
+
+ register_cache_provider(provider1)
+ register_cache_provider(provider2)
+ _clear_cache_providers()
+
+ assert _has_cache_providers() is False
+ assert len(_get_cache_providers()) == 0
+
+
+class TestCacheContext:
+ """Test CacheContext dataclass."""
+
+ def test_context_creation(self):
+ """CacheContext should be created with all fields."""
+ context = CacheContext(
+ node_id="node-456",
+ class_type="KSampler",
+ cache_key_hash="a" * 64,
+ )
+
+ assert context.node_id == "node-456"
+ assert context.class_type == "KSampler"
+ assert context.cache_key_hash == "a" * 64
+
+
+class TestCacheValue:
+ """Test CacheValue dataclass."""
+
+ def test_value_creation(self):
+ """CacheValue should be created with outputs."""
+ outputs = [[{"samples": "tensor_data"}]]
+ value = CacheValue(outputs=outputs)
+
+ assert value.outputs == outputs
+
+
+class MockCacheProvider(CacheProvider):
+ """Mock cache provider for testing."""
+
+ def __init__(self):
+ self.lookups = []
+ self.stores = []
+
+ async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
+ self.lookups.append(context)
+ return None
+
+ async def on_store(self, context: CacheContext, value: CacheValue) -> None:
+ self.stores.append((context, value))
From f9ceed9eefe20f6b54b801096cb80f874316f5b2 Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Fri, 13 Mar 2026 19:10:40 +0200
Subject: [PATCH 19/80] fix(api-nodes): Tencent TextToModel and ImageToModel
nodes (#12680)
* fix(api-nodes): added "texture_image" output to TencentTextToModel and TencentImageToModel nodes. Fixed `OBJ` output when it is zipped
* support additional solid texture outputs
* fixed and enabled Tencent3DTextureEdit node
---
comfy_api_nodes/nodes_hunyuan3d.py | 97 +++++++++++++++++++++++++++---
1 file changed, 88 insertions(+), 9 deletions(-)
diff --git a/comfy_api_nodes/nodes_hunyuan3d.py b/comfy_api_nodes/nodes_hunyuan3d.py
index bd8bde997..753c09b6e 100644
--- a/comfy_api_nodes/nodes_hunyuan3d.py
+++ b/comfy_api_nodes/nodes_hunyuan3d.py
@@ -1,3 +1,7 @@
+import zipfile
+from io import BytesIO
+
+import torch
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input, Types
@@ -17,7 +21,10 @@ from comfy_api_nodes.apis.hunyuan3d import (
)
from comfy_api_nodes.util import (
ApiEndpoint,
+ bytesio_to_image_tensor,
+ download_url_to_bytesio,
download_url_to_file_3d,
+ download_url_to_image_tensor,
downscale_image_tensor_by_max_side,
poll_op,
sync_op,
@@ -36,6 +43,68 @@ def _is_tencent_rate_limited(status: int, body: object) -> bool:
)
+class ObjZipResult:
+ __slots__ = ("obj", "texture", "metallic", "normal", "roughness")
+
+ def __init__(
+ self,
+ obj: Types.File3D,
+ texture: Input.Image | None = None,
+ metallic: Input.Image | None = None,
+ normal: Input.Image | None = None,
+ roughness: Input.Image | None = None,
+ ):
+ self.obj = obj
+ self.texture = texture
+ self.metallic = metallic
+ self.normal = normal
+ self.roughness = roughness
+
+
+async def download_and_extract_obj_zip(url: str) -> ObjZipResult:
+ """The Tencent API returns OBJ results as ZIP archives containing the .obj mesh, and texture images.
+
+ When PBR is enabled, the ZIP may contain additional metallic, normal, and roughness maps
+ identified by their filename suffixes.
+ """
+ data = BytesIO()
+ await download_url_to_bytesio(url, data)
+ data.seek(0)
+ if not zipfile.is_zipfile(data):
+ data.seek(0)
+ return ObjZipResult(obj=Types.File3D(source=data, file_format="obj"))
+ data.seek(0)
+ obj_bytes = None
+ textures: dict[str, Input.Image] = {}
+ with zipfile.ZipFile(data) as zf:
+ for name in zf.namelist():
+ lower = name.lower()
+ if lower.endswith(".obj"):
+ obj_bytes = zf.read(name)
+ elif any(lower.endswith(ext) for ext in (".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".webp")):
+ stem = lower.rsplit(".", 1)[0]
+ tensor = bytesio_to_image_tensor(BytesIO(zf.read(name)), mode="RGB")
+ matched_key = "texture"
+ for suffix, key in {
+ "_metallic": "metallic",
+ "_normal": "normal",
+ "_roughness": "roughness",
+ }.items():
+ if stem.endswith(suffix):
+ matched_key = key
+ break
+ textures[matched_key] = tensor
+ if obj_bytes is None:
+ raise ValueError("ZIP archive does not contain an OBJ file.")
+ return ObjZipResult(
+ obj=Types.File3D(source=BytesIO(obj_bytes), file_format="obj"),
+ texture=textures.get("texture"),
+ metallic=textures.get("metallic"),
+ normal=textures.get("normal"),
+ roughness=textures.get("roughness"),
+ )
+
+
def get_file_from_response(
response_objs: list[ResultFile3D], file_type: str, raise_if_not_found: bool = True
) -> ResultFile3D | None:
@@ -93,6 +162,7 @@ class TencentTextToModelNode(IO.ComfyNode):
IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.File3DGLB.Output(display_name="GLB"),
IO.File3DOBJ.Output(display_name="OBJ"),
+ IO.Image.Output(display_name="texture_image"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -151,14 +221,14 @@ class TencentTextToModelNode(IO.ComfyNode):
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
+ obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url)
return IO.NodeOutput(
f"{task_id}.glb",
await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
),
- await download_url_to_file_3d(
- get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
- ),
+ obj_result.obj,
+ obj_result.texture,
)
@@ -211,6 +281,10 @@ class TencentImageToModelNode(IO.ComfyNode):
IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.File3DGLB.Output(display_name="GLB"),
IO.File3DOBJ.Output(display_name="OBJ"),
+ IO.Image.Output(display_name="texture_image"),
+ IO.Image.Output(display_name="optional_metallic"),
+ IO.Image.Output(display_name="optional_normal"),
+ IO.Image.Output(display_name="optional_roughness"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -304,14 +378,17 @@ class TencentImageToModelNode(IO.ComfyNode):
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
+ obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url)
return IO.NodeOutput(
f"{task_id}.glb",
await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
),
- await download_url_to_file_3d(
- get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
- ),
+ obj_result.obj,
+ obj_result.texture,
+ obj_result.metallic if obj_result.metallic is not None else torch.zeros(1, 1, 1, 3),
+ obj_result.normal if obj_result.normal is not None else torch.zeros(1, 1, 1, 3),
+ obj_result.roughness if obj_result.roughness is not None else torch.zeros(1, 1, 1, 3),
)
@@ -431,7 +508,8 @@ class Tencent3DTextureEditNode(IO.ComfyNode):
],
outputs=[
IO.File3DGLB.Output(display_name="GLB"),
- IO.File3DFBX.Output(display_name="FBX"),
+ IO.File3DOBJ.Output(display_name="OBJ"),
+ IO.Image.Output(display_name="texture_image"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -480,7 +558,8 @@ class Tencent3DTextureEditNode(IO.ComfyNode):
)
return IO.NodeOutput(
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb"),
- await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
+ await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
+ await download_url_to_image_tensor(get_file_from_response(result.ResultFile3Ds, "texture_image").Url),
)
@@ -654,7 +733,7 @@ class TencentHunyuan3DExtension(ComfyExtension):
TencentTextToModelNode,
TencentImageToModelNode,
TencentModelTo3DUVNode,
- # Tencent3DTextureEditNode,
+ Tencent3DTextureEditNode,
Tencent3DPartNode,
TencentSmartTopologyNode,
]
From 6cd35a0c5fd7d22df858be175f6a6e6ee0212e55 Mon Sep 17 00:00:00 2001
From: Comfy Org PR Bot
Date: Sat, 14 Mar 2026 03:31:25 +0900
Subject: [PATCH 20/80] Bump comfyui-frontend-package to 1.41.19 (#12923)
---
requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/requirements.txt b/requirements.txt
index 511c62fee..6efb77f29 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-comfyui-frontend-package==1.41.18
+comfyui-frontend-package==1.41.19
comfyui-workflow-templates==0.9.21
comfyui-embedded-docs==0.4.3
torch
From e1f10ca0932faf289757e7ec27a54894e271fdde Mon Sep 17 00:00:00 2001
From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com>
Date: Sat, 14 Mar 2026 09:14:27 +0900
Subject: [PATCH 21/80] bump manager version to 4.1b4 (#12930)
---
manager_requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/manager_requirements.txt b/manager_requirements.txt
index 6bcc3fb50..37a33bd4f 100644
--- a/manager_requirements.txt
+++ b/manager_requirements.txt
@@ -1 +1 @@
-comfyui_manager==4.1b2
\ No newline at end of file
+comfyui_manager==4.1b4
\ No newline at end of file
From 7810f49702eac6e617eb7f2c30b00a8939ef1404 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Fri, 13 Mar 2026 19:18:08 -0700
Subject: [PATCH 22/80] comfy aimdo 0.2.11 + Improved RAM Pressure release
strategies - Windows speedups (#12925)
* Implement seek and read for pins
Source pins from an mmap is pad because its its a CPU->CPU copy that
attempts to fully buffer the same data twice. Instead, use seek and
read which avoids the mmap buffering while usually being a faster
read in the first place (avoiding mmap faulting etc).
* pinned_memory: Use Aimdo pinner
The aimdo pinner bypasses pytorches CPU allocator which can leak
windows commit charge.
* ops: bypass init() of weight for embedding layer
This similarly consumes large commit charge especially for TEs. It can
cause a permanement leaked commit charge which can destabilize on
systems close to the commit ceiling and generally confuses the RAM
stats.
* model_patcher: implement pinned memory counter
Implement a pinned memory counter for better accounting of what volume
of memory pins have.
* implement touch accounting
Implement accounting of touching mmapped tensors.
* mm+mp: add residency mmap getter
* utils: use the aimdo mmap to load sft files
* model_management: Implement tigher RAM pressure semantics
Implement a pressure release on entire MMAPs as windows does perform
faster when mmaps are unloaded and model loads free ramp into fully
unallocated RAM.
Make the concept of freeing for pins a completely separate concept.
Now that pins are loadable directly from original file and don' touch
the mmap, tighten the freeing budget to just the current loaded model
- what you have left over. This still over-frees pins, but its a lot
better than before.
So after the pins are freed with that algorithm, bounce entire MMAPs
to free RAM based on what the model needs, deducting off any known
resident-in-mmap tensors to the free quota to keep it as tight as
possible.
* comfy-aimdo 0.2.11
Comfy aimdo 0.2.11
* mm: Implement file_slice path for QT
* ruff
* ops: put meta-tensors in place to allow custom nodes to check geo
---
comfy/memory_management.py | 59 +++++++++++++++++++++
comfy/model_management.py | 74 ++++++++++++++++++++++-----
comfy/model_patcher.py | 17 +++++++
comfy/ops.py | 102 ++++++++++++++++++++++++++++---------
comfy/pinned_memory.py | 26 +++++++---
comfy/utils.py | 28 +++++++---
requirements.txt | 2 +-
7 files changed, 258 insertions(+), 50 deletions(-)
diff --git a/comfy/memory_management.py b/comfy/memory_management.py
index 0b7da2852..563224098 100644
--- a/comfy/memory_management.py
+++ b/comfy/memory_management.py
@@ -1,9 +1,68 @@
import math
+import ctypes
+import threading
+import dataclasses
import torch
from typing import NamedTuple
from comfy.quant_ops import QuantizedTensor
+
+class TensorFileSlice(NamedTuple):
+ file_ref: object
+ thread_id: int
+ offset: int
+ size: int
+
+
+def read_tensor_file_slice_into(tensor, destination):
+
+ if isinstance(tensor, QuantizedTensor):
+ if not isinstance(destination, QuantizedTensor):
+ return False
+ if tensor._layout_cls != destination._layout_cls:
+ return False
+
+ if not read_tensor_file_slice_into(tensor._qdata, destination._qdata):
+ return False
+
+ dst_orig_dtype = destination._params.orig_dtype
+ destination._params.copy_from(tensor._params, non_blocking=False)
+ destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
+ return True
+
+ info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
+ if info is None:
+ return False
+
+ file_obj = info.file_ref
+ if (destination.device.type != "cpu"
+ or file_obj is None
+ or threading.get_ident() != info.thread_id
+ or destination.numel() * destination.element_size() < info.size):
+ return False
+
+ if info.size == 0:
+ return True
+
+ buf_type = ctypes.c_ubyte * info.size
+ view = memoryview(buf_type.from_address(destination.data_ptr()))
+
+ try:
+ file_obj.seek(info.offset)
+ done = 0
+ while done < info.size:
+ try:
+ n = file_obj.readinto(view[done:])
+ except OSError:
+ return False
+ if n <= 0:
+ return False
+ done += n
+ return True
+ finally:
+ view.release()
+
class TensorGeometry(NamedTuple):
shape: any
dtype: torch.dtype
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 81c89b180..4d5851bc0 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -505,6 +505,28 @@ def module_size(module):
module_mem += t.nbytes
return module_mem
+def module_mmap_residency(module, free=False):
+ mmap_touched_mem = 0
+ module_mem = 0
+ bounced_mmaps = set()
+ sd = module.state_dict()
+ for k in sd:
+ t = sd[k]
+ module_mem += t.nbytes
+ storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage()
+ if not getattr(storage, "_comfy_tensor_mmap_touched", False):
+ continue
+ mmap_touched_mem += t.nbytes
+ if not free:
+ continue
+ storage._comfy_tensor_mmap_touched = False
+ mmap_obj = storage._comfy_tensor_mmap_refs[0]
+ if mmap_obj in bounced_mmaps:
+ continue
+ mmap_obj.bounce()
+ bounced_mmaps.add(mmap_obj)
+ return mmap_touched_mem, module_mem
+
class LoadedModel:
def __init__(self, model):
self._set_model(model)
@@ -532,6 +554,9 @@ class LoadedModel:
def model_memory(self):
return self.model.model_size()
+ def model_mmap_residency(self, free=False):
+ return self.model.model_mmap_residency(free=free)
+
def model_loaded_memory(self):
return self.model.loaded_size()
@@ -633,7 +658,7 @@ def extra_reserved_memory():
def minimum_inference_memory():
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
-def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_required=0):
+def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0):
cleanup_models_gc()
unloaded_model = []
can_unload = []
@@ -646,13 +671,14 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
shift_model.currently_used = False
- for x in sorted(can_unload):
+ can_unload_sorted = sorted(can_unload)
+ for x in can_unload_sorted:
i = x[-1]
memory_to_free = 1e32
- ram_to_free = 1e32
+ pins_to_free = 1e32
if not DISABLE_SMART_MEMORY:
memory_to_free = memory_required - get_free_memory(device)
- ram_to_free = ram_required - get_free_ram()
+ pins_to_free = pins_required - get_free_ram()
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
#don't actually unload dynamic models for the sake of other dynamic models
#as that works on-demand.
@@ -661,9 +687,18 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
unloaded_model.append(i)
- if ram_to_free > 0:
+ if pins_to_free > 0:
+ logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
+ current_loaded_models[i].model.partially_unload_ram(pins_to_free)
+
+ for x in can_unload_sorted:
+ i = x[-1]
+ ram_to_free = ram_required - psutil.virtual_memory().available
+ if ram_to_free <= 0 and i not in unloaded_model:
+ continue
+ resident_memory, _ = current_loaded_models[i].model_mmap_residency(free=True)
+ if resident_memory > 0:
logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
- current_loaded_models[i].model.partially_unload_ram(ram_to_free)
for i in sorted(unloaded_model, reverse=True):
unloaded_models.append(current_loaded_models.pop(i))
@@ -729,17 +764,27 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
total_memory_required = {}
+ total_pins_required = {}
total_ram_required = {}
for loaded_model in models_to_load:
- total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
- #x2, one to make sure the OS can fit the model for loading in disk cache, and for us to do any pinning we
- #want to do.
- #FIXME: This should subtract off the to_load current pin consumption.
- total_ram_required[loaded_model.device] = total_ram_required.get(loaded_model.device, 0) + loaded_model.model_memory() * 2
+ device = loaded_model.device
+ total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
+ resident_memory, model_memory = loaded_model.model_mmap_residency()
+ pinned_memory = loaded_model.model.pinned_memory_size()
+ #FIXME: This can over-free the pins as it budgets to pin the entire model. We should
+ #make this JIT to keep as much pinned as possible.
+ pins_required = model_memory - pinned_memory
+ ram_required = model_memory - resident_memory
+ total_pins_required[device] = total_pins_required.get(device, 0) + pins_required
+ total_ram_required[device] = total_ram_required.get(device, 0) + ram_required
for device in total_memory_required:
if device != torch.device("cpu"):
- free_memory(total_memory_required[device] * 1.1 + extra_mem, device, for_dynamic=free_for_dynamic, ram_required=total_ram_required[device])
+ free_memory(total_memory_required[device] * 1.1 + extra_mem,
+ device,
+ for_dynamic=free_for_dynamic,
+ pins_required=total_pins_required[device],
+ ram_required=total_ram_required[device])
for device in total_memory_required:
if device != torch.device("cpu"):
@@ -1225,6 +1270,11 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
dest_view = dest_views.pop(0)
if tensor is None:
continue
+ if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view):
+ continue
+ storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
+ if hasattr(storage, "_comfy_tensor_mmap_touched"):
+ storage._comfy_tensor_mmap_touched = True
dest_view.copy_(tensor, non_blocking=non_blocking)
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index bc3a8f446..c26d37db2 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -297,6 +297,9 @@ class ModelPatcher:
self.size = comfy.model_management.module_size(self.model)
return self.size
+ def model_mmap_residency(self, free=False):
+ return comfy.model_management.module_mmap_residency(self.model, free=free)
+
def get_ram_usage(self):
return self.model_size()
@@ -1063,6 +1066,10 @@ class ModelPatcher:
return self.model.model_loaded_weight_memory - current_used
+ def pinned_memory_size(self):
+ # Pinned memory pressure tracking is only implemented for DynamicVram loading
+ return 0
+
def partially_unload_ram(self, ram_to_unload):
pass
@@ -1653,6 +1660,16 @@ class ModelPatcherDynamic(ModelPatcher):
return freed
+ def pinned_memory_size(self):
+ total = 0
+ loading = self._load_list(for_dynamic=True)
+ for x in loading:
+ _, _, _, _, m, _ = x
+ pin = comfy.pinned_memory.get_pin(m)
+ if pin is not None:
+ total += pin.numel() * pin.element_size()
+ return total
+
def partially_unload_ram(self, ram_to_unload):
loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
for x in loading:
diff --git a/comfy/ops.py b/comfy/ops.py
index 87b36b5c5..3f2da4e63 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -306,6 +306,33 @@ class CastWeightBiasOp:
bias_function = []
class disable_weight_init:
+ @staticmethod
+ def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata,
+ missing_keys, unexpected_keys, weight_shape,
+ bias_shape=None):
+ assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
+ prefix_len = len(prefix)
+ for k, v in state_dict.items():
+ key = k[prefix_len:]
+ if key == "weight":
+ if not assign_to_params_buffers:
+ v = v.clone()
+ module.weight = torch.nn.Parameter(v, requires_grad=False)
+ elif bias_shape is not None and key == "bias" and v is not None:
+ if not assign_to_params_buffers:
+ v = v.clone()
+ module.bias = torch.nn.Parameter(v, requires_grad=False)
+ else:
+ unexpected_keys.append(k)
+
+ if module.weight is None:
+ module.weight = torch.nn.Parameter(torch.zeros(weight_shape), requires_grad=False)
+ missing_keys.append(prefix + "weight")
+
+ if bias_shape is not None and module.bias is None and getattr(module, "comfy_need_lazy_init_bias", False):
+ module.bias = torch.nn.Parameter(torch.zeros(bias_shape), requires_grad=False)
+ missing_keys.append(prefix + "bias")
+
class Linear(torch.nn.Linear, CastWeightBiasOp):
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
@@ -333,29 +360,16 @@ class disable_weight_init:
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
- assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
- prefix_len = len(prefix)
- for k,v in state_dict.items():
- if k[prefix_len:] == "weight":
- if not assign_to_params_buffers:
- v = v.clone()
- self.weight = torch.nn.Parameter(v, requires_grad=False)
- elif k[prefix_len:] == "bias" and v is not None:
- if not assign_to_params_buffers:
- v = v.clone()
- self.bias = torch.nn.Parameter(v, requires_grad=False)
- else:
- unexpected_keys.append(k)
-
- #Reconcile default construction of the weight if its missing.
- if self.weight is None:
- v = torch.zeros(self.in_features, self.out_features)
- self.weight = torch.nn.Parameter(v, requires_grad=False)
- missing_keys.append(prefix+"weight")
- if self.bias is None and self.comfy_need_lazy_init_bias:
- v = torch.zeros(self.out_features,)
- self.bias = torch.nn.Parameter(v, requires_grad=False)
- missing_keys.append(prefix+"bias")
+ disable_weight_init._lazy_load_from_state_dict(
+ self,
+ state_dict,
+ prefix,
+ local_metadata,
+ missing_keys,
+ unexpected_keys,
+ weight_shape=(self.in_features, self.out_features),
+ bias_shape=(self.out_features,),
+ )
def reset_parameters(self):
@@ -547,6 +561,48 @@ class disable_weight_init:
return super().forward(*args, **kwargs)
class Embedding(torch.nn.Embedding, CastWeightBiasOp):
+ def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None,
+ norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None,
+ _freeze=False, device=None, dtype=None):
+ if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
+ super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm,
+ norm_type, scale_grad_by_freq, sparse, _weight,
+ _freeze, device, dtype)
+ return
+
+ torch.nn.Module.__init__(self)
+ self.num_embeddings = num_embeddings
+ self.embedding_dim = embedding_dim
+ self.padding_idx = padding_idx
+ self.max_norm = max_norm
+ self.norm_type = norm_type
+ self.scale_grad_by_freq = scale_grad_by_freq
+ self.sparse = sparse
+ # Keep shape/dtype visible for module introspection without reserving storage.
+ embedding_dtype = dtype if dtype is not None else torch.get_default_dtype()
+ self.weight = torch.nn.Parameter(
+ torch.empty((num_embeddings, embedding_dim), device="meta", dtype=embedding_dtype),
+ requires_grad=False,
+ )
+ self.bias = None
+ self.weight_comfy_model_dtype = dtype
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
+ strict, missing_keys, unexpected_keys, error_msgs):
+
+ if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
+ return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs)
+ disable_weight_init._lazy_load_from_state_dict(
+ self,
+ state_dict,
+ prefix,
+ local_metadata,
+ missing_keys,
+ unexpected_keys,
+ weight_shape=(self.num_embeddings, self.embedding_dim),
+ )
+
def reset_parameters(self):
self.bias = None
return None
diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py
index 8acc327a7..f6fb806c4 100644
--- a/comfy/pinned_memory.py
+++ b/comfy/pinned_memory.py
@@ -1,6 +1,7 @@
-import torch
import comfy.model_management
import comfy.memory_management
+import comfy_aimdo.host_buffer
+import comfy_aimdo.torch
from comfy.cli_args import args
@@ -12,18 +13,31 @@ def pin_memory(module):
return
#FIXME: This is a RAM cache trigger event
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
- pin = torch.empty((size,), dtype=torch.uint8)
- if comfy.model_management.pin_memory(pin):
- module._pin = pin
- else:
+
+ if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:
module.pin_failed = True
return False
+
+ try:
+ hostbuf = comfy_aimdo.host_buffer.HostBuffer(size)
+ except RuntimeError:
+ module.pin_failed = True
+ return False
+
+ module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)
+ module._pin_hostbuf = hostbuf
+ comfy.model_management.TOTAL_PINNED_MEMORY += size
return True
def unpin_memory(module):
if get_pin(module) is None:
return 0
size = module._pin.numel() * module._pin.element_size()
- comfy.model_management.unpin_memory(module._pin)
+
+ comfy.model_management.TOTAL_PINNED_MEMORY -= size
+ if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
+ comfy.model_management.TOTAL_PINNED_MEMORY = 0
+
del module._pin
+ del module._pin_hostbuf
return size
diff --git a/comfy/utils.py b/comfy/utils.py
index 6e1d14419..9931fe3b4 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -20,6 +20,8 @@
import torch
import math
import struct
+import ctypes
+import os
import comfy.memory_management
import safetensors.torch
import numpy as np
@@ -32,7 +34,7 @@ from einops import rearrange
from comfy.cli_args import args
import json
import time
-import mmap
+import threading
import warnings
MMAP_TORCH_FILES = args.mmap_torch_files
@@ -81,14 +83,17 @@ _TYPES = {
}
def load_safetensors(ckpt):
- f = open(ckpt, "rb")
- mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
- mv = memoryview(mapping)
+ import comfy_aimdo.model_mmap
- header_size = struct.unpack("=14.2.0
comfy-kitchen>=0.2.8
-comfy-aimdo>=0.2.10
+comfy-aimdo>=0.2.11
requests
simpleeval>=1.0.0
blake3
From 16cd8d8a8f5f16ce7e5f929fdba9f783990254ea Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Fri, 13 Mar 2026 19:33:28 -0700
Subject: [PATCH 23/80] Update README. (#12931)
---
README.md | 11 +++++++++--
1 file changed, 9 insertions(+), 2 deletions(-)
diff --git a/README.md b/README.md
index 56b7966cf..62c4f528c 100644
--- a/README.md
+++ b/README.md
@@ -38,6 +38,8 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
## Get Started
+### Local
+
#### [Desktop Application](https://www.comfy.org/download)
- The easiest way to get started.
- Available on Windows & macOS.
@@ -49,8 +51,13 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
#### [Manual Install](#manual-install-windows-linux)
Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon, Ascend).
-## [Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
-See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
+### Cloud
+
+#### [Comfy Cloud](https://www.comfy.org/cloud)
+- Our official paid cloud version for those who can't afford local hardware.
+
+## Examples
+See what ComfyUI can do with the [newer template workflows](https://comfy.org/workflows) or old [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
From 4c4be1bba5ae714c6f455a49757bd7fc2e32c577 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Sat, 14 Mar 2026 07:53:00 -0700
Subject: [PATCH 24/80] comfy-aimdo 0.2.12 (#12941)
comfy-aimdo 0.2.12 fixes support for non-ASCII filepaths in the new
mmap helper.
---
requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/requirements.txt b/requirements.txt
index 52bc0fd12..c32a765a0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -23,7 +23,7 @@ SQLAlchemy
filelock
av>=14.2.0
comfy-kitchen>=0.2.8
-comfy-aimdo>=0.2.11
+comfy-aimdo>=0.2.12
requests
simpleeval>=1.0.0
blake3
From e0982a7174a9cacb0c3cd3fb6bd1f8e06d9aaf51 Mon Sep 17 00:00:00 2001
From: Christian Byrne
Date: Sat, 14 Mar 2026 15:25:09 -0700
Subject: [PATCH 25/80] fix: use no-store cache headers to prevent stale
frontend chunks (#12911)
After a frontend update (e.g. nightly build), browsers could load
outdated cached index.html and JS/CSS chunks, causing dynamically
imported modules to fail with MIME type errors and vite:preloadError.
Hard refresh (Ctrl+Shift+R) was insufficient to fix the issue because
Cache-Control: no-cache still allows the browser to cache and
revalidate via ETags. aiohttp's FileResponse auto-generates ETags
based on file mtime+size, which may not change after pip reinstall,
so the browser gets 304 Not Modified and serves stale content.
Clearing ALL site data in DevTools did fix it, confirming the HTTP
cache was the root cause.
The fix changes:
- index.html: no-cache -> no-store, must-revalidate
- JS/CSS/JSON entry points: no-cache -> no-store
no-store instructs browsers to never cache these responses, ensuring
every page load fetches the current index.html with correct chunk
references. This is a small tradeoff (~5KB re-download per page load)
for guaranteed correctness after updates.
---
middleware/cache_middleware.py | 2 +-
server.py | 2 +-
tests-unit/server_test/test_cache_control.py | 16 ++++++++--------
3 files changed, 10 insertions(+), 10 deletions(-)
diff --git a/middleware/cache_middleware.py b/middleware/cache_middleware.py
index f02135369..7a18821b0 100644
--- a/middleware/cache_middleware.py
+++ b/middleware/cache_middleware.py
@@ -32,7 +32,7 @@ async def cache_control(
)
if request.path.endswith(".js") or request.path.endswith(".css") or is_entry_point:
- response.headers.setdefault("Cache-Control", "no-cache")
+ response.headers.setdefault("Cache-Control", "no-store")
return response
# Early return for non-image files - no cache headers needed
diff --git a/server.py b/server.py
index 76904ebc9..85a8964be 100644
--- a/server.py
+++ b/server.py
@@ -310,7 +310,7 @@ class PromptServer():
@routes.get("/")
async def get_root(request):
response = web.FileResponse(os.path.join(self.web_root, "index.html"))
- response.headers['Cache-Control'] = 'no-cache'
+ response.headers['Cache-Control'] = 'no-store, must-revalidate'
response.headers["Pragma"] = "no-cache"
response.headers["Expires"] = "0"
return response
diff --git a/tests-unit/server_test/test_cache_control.py b/tests-unit/server_test/test_cache_control.py
index fa68d9408..1d0366387 100644
--- a/tests-unit/server_test/test_cache_control.py
+++ b/tests-unit/server_test/test_cache_control.py
@@ -28,31 +28,31 @@ CACHE_SCENARIOS = [
},
# JavaScript/CSS scenarios
{
- "name": "js_no_cache",
+ "name": "js_no_store",
"path": "/script.js",
"status": 200,
- "expected_cache": "no-cache",
+ "expected_cache": "no-store",
"should_have_header": True,
},
{
- "name": "css_no_cache",
+ "name": "css_no_store",
"path": "/styles.css",
"status": 200,
- "expected_cache": "no-cache",
+ "expected_cache": "no-store",
"should_have_header": True,
},
{
- "name": "index_json_no_cache",
+ "name": "index_json_no_store",
"path": "/api/index.json",
"status": 200,
- "expected_cache": "no-cache",
+ "expected_cache": "no-store",
"should_have_header": True,
},
{
- "name": "localized_index_json_no_cache",
+ "name": "localized_index_json_no_store",
"path": "/templates/index.zh.json",
"status": 200,
- "expected_cache": "no-cache",
+ "expected_cache": "no-store",
"should_have_header": True,
},
# Non-matching files
From 1c5db7397d59eace38acef078b618c2f04e4e7fe Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?=
<40791699+kijai@users.noreply.github.com>
Date: Sun, 15 Mar 2026 00:36:29 +0200
Subject: [PATCH 26/80] feat: Support mxfp8 (#12907)
---
comfy/float.py | 36 ++++++++++++++++++++++++++++++
comfy/model_management.py | 13 +++++++++++
comfy/ops.py | 19 ++++++++++++++++
comfy/quant_ops.py | 47 +++++++++++++++++++++++++++++++++++++++
4 files changed, 115 insertions(+)
diff --git a/comfy/float.py b/comfy/float.py
index 88c47cd80..184b3d6d0 100644
--- a/comfy/float.py
+++ b/comfy/float.py
@@ -209,3 +209,39 @@ def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=
output_block[i:i + slice_size].copy_(block)
return output_fp4, to_blocked(output_block, flatten=False)
+
+
+def stochastic_round_quantize_mxfp8_by_block(x, pad_32x, seed=0):
+ def roundup(x_val, multiple):
+ return ((x_val + multiple - 1) // multiple) * multiple
+
+ if pad_32x:
+ rows, cols = x.shape
+ padded_rows = roundup(rows, 32)
+ padded_cols = roundup(cols, 32)
+ if padded_rows != rows or padded_cols != cols:
+ x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
+
+ F8_E4M3_MAX = 448.0
+ E8M0_BIAS = 127
+ BLOCK_SIZE = 32
+
+ rows, cols = x.shape
+ x_blocked = x.reshape(rows, -1, BLOCK_SIZE)
+ max_abs = torch.amax(torch.abs(x_blocked), dim=-1)
+
+ # E8M0 block scales (power-of-2 exponents)
+ scale_needed = torch.clamp(max_abs.float() / F8_E4M3_MAX, min=2**(-127))
+ exp_biased = torch.clamp(torch.ceil(torch.log2(scale_needed)).to(torch.int32) + E8M0_BIAS, 0, 254)
+ block_scales_e8m0 = exp_biased.to(torch.uint8)
+
+ zero_mask = (max_abs == 0)
+ block_scales_f32 = (block_scales_e8m0.to(torch.int32) << 23).view(torch.float32)
+ block_scales_f32 = torch.where(zero_mask, torch.ones_like(block_scales_f32), block_scales_f32)
+
+ # Scale per-block then stochastic round
+ data_scaled = (x_blocked.float() / block_scales_f32.unsqueeze(-1)).reshape(rows, cols)
+ output_fp8 = stochastic_rounding(data_scaled, torch.float8_e4m3fn, seed=seed)
+
+ block_scales_e8m0 = torch.where(zero_mask, torch.zeros_like(block_scales_e8m0), block_scales_e8m0)
+ return output_fp8, to_blocked(block_scales_e8m0, flatten=False).view(torch.float8_e8m0fnu)
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 4d5851bc0..bb77cff47 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1712,6 +1712,19 @@ def supports_nvfp4_compute(device=None):
return True
+def supports_mxfp8_compute(device=None):
+ if not is_nvidia():
+ return False
+
+ if torch_version_numeric < (2, 10):
+ return False
+
+ props = torch.cuda.get_device_properties(device)
+ if props.major < 10:
+ return False
+
+ return True
+
def extended_fp16_support():
# TODO: check why some models work with fp16 on newer torch versions but not on older
if torch_version_numeric < (2, 7):
diff --git a/comfy/ops.py b/comfy/ops.py
index 3f2da4e63..59c0df87d 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -857,6 +857,22 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
orig_shape=(self.out_features, self.in_features),
)
+ elif self.quant_format == "mxfp8":
+ # MXFP8: E8M0 block scales stored as uint8 in safetensors
+ block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
+ dtype=torch.uint8)
+
+ if block_scale is None:
+ raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
+
+ block_scale = block_scale.view(torch.float8_e8m0fnu)
+
+ params = layout_cls.Params(
+ scale=block_scale,
+ orig_dtype=MixedPrecisionOps._compute_dtype,
+ orig_shape=(self.out_features, self.in_features),
+ )
+
elif self.quant_format == "nvfp4":
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
@@ -1006,12 +1022,15 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
+ mxfp8_compute = comfy.model_management.supports_mxfp8_compute(load_device)
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
logging.info("Using mixed precision operations")
disabled = set()
if not nvfp4_compute:
disabled.add("nvfp4")
+ if not mxfp8_compute:
+ disabled.add("mxfp8")
if not fp8_compute:
disabled.add("float8_e4m3fn")
disabled.add("float8_e5m2")
diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py
index 15a4f457b..42ee08fb2 100644
--- a/comfy/quant_ops.py
+++ b/comfy/quant_ops.py
@@ -43,6 +43,18 @@ except ImportError as e:
def get_layout_class(name):
return None
+_CK_MXFP8_AVAILABLE = False
+if _CK_AVAILABLE:
+ try:
+ from comfy_kitchen.tensor import TensorCoreMXFP8Layout as _CKMxfp8Layout
+ _CK_MXFP8_AVAILABLE = True
+ except ImportError:
+ logging.warning("comfy_kitchen does not support MXFP8, please update comfy_kitchen.")
+
+if not _CK_MXFP8_AVAILABLE:
+ class _CKMxfp8Layout:
+ pass
+
import comfy.float
# ==============================================================================
@@ -84,6 +96,31 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
return qdata, params
+class TensorCoreMXFP8Layout(_CKMxfp8Layout):
+ @classmethod
+ def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
+ if tensor.dim() != 2:
+ raise ValueError(f"MXFP8 requires 2D tensor, got {tensor.dim()}D")
+
+ orig_dtype = tensor.dtype
+ orig_shape = tuple(tensor.shape)
+
+ padded_shape = cls.get_padded_shape(orig_shape)
+ needs_padding = padded_shape != orig_shape
+
+ if stochastic_rounding > 0:
+ qdata, block_scale = comfy.float.stochastic_round_quantize_mxfp8_by_block(tensor, pad_32x=needs_padding, seed=stochastic_rounding)
+ else:
+ qdata, block_scale = ck.quantize_mxfp8(tensor, pad_32x=needs_padding)
+
+ params = cls.Params(
+ scale=block_scale,
+ orig_dtype=orig_dtype,
+ orig_shape=orig_shape,
+ )
+ return qdata, params
+
+
class TensorCoreNVFP4Layout(_CKNvfp4Layout):
@classmethod
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
@@ -137,6 +174,8 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
+if _CK_MXFP8_AVAILABLE:
+ register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
QUANT_ALGOS = {
"float8_e4m3fn": {
@@ -157,6 +196,14 @@ QUANT_ALGOS = {
},
}
+if _CK_MXFP8_AVAILABLE:
+ QUANT_ALGOS["mxfp8"] = {
+ "storage_t": torch.float8_e4m3fn,
+ "parameters": {"weight_scale", "input_scale"},
+ "comfy_tensor_layout": "TensorCoreMXFP8Layout",
+ "group_size": 32,
+ }
+
# ==============================================================================
# Re-exports for backward compatibility
From c711b8f437923d9e732fa1d22ed101f81575683c Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Sat, 14 Mar 2026 16:18:19 -0700
Subject: [PATCH 27/80] Add --fp16-intermediates to use fp16 for intermediate
values between nodes (#12953)
This is an experimental WIP option that might not work in your workflow but
should lower memory usage if it does.
Currently only the VAE and the load image node will output in fp16 when
this option is turned on.
---
comfy/cli_args.py | 2 ++
comfy/model_management.py | 6 ++++++
comfy/sd.py | 27 +++++++++++++++------------
nodes.py | 6 ++++--
4 files changed, 27 insertions(+), 14 deletions(-)
diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index e9832acaf..0a0bf2f30 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -83,6 +83,8 @@ fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.")
+parser.add_argument("--fp16-intermediates", action="store_true", help="Experimental: Use fp16 for intermediate tensors between nodes instead of fp32.")
+
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
diff --git a/comfy/model_management.py b/comfy/model_management.py
index bb77cff47..442d5a40a 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1050,6 +1050,12 @@ def intermediate_device():
else:
return torch.device("cpu")
+def intermediate_dtype():
+ if args.fp16_intermediates:
+ return torch.float16
+ else:
+ return torch.float32
+
def vae_device():
if args.cpu_vae:
return torch.device("cpu")
diff --git a/comfy/sd.py b/comfy/sd.py
index adcd67767..4d427bb9a 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -871,13 +871,16 @@ class VAE:
pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
return pixels
+ def vae_output_dtype(self):
+ return model_management.intermediate_dtype()
+
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = comfy.utils.ProgressBar(steps)
- decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
+ decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
output = self.process_output(
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
@@ -887,16 +890,16 @@ class VAE:
def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
if samples.ndim == 3:
- decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
+ decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
else:
og_shape = samples.shape
samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
- decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).float()
+ decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
- decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
+ decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
@@ -905,7 +908,7 @@ class VAE:
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = comfy.utils.ProgressBar(steps)
- encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
+ encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
@@ -914,7 +917,7 @@ class VAE:
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
if self.latent_dim == 1:
- encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
+ encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
out_channels = self.latent_channels
upscale_amount = 1 / self.downscale_ratio
else:
@@ -923,7 +926,7 @@ class VAE:
tile_x = tile_x // extra_channel_size
overlap = overlap // extra_channel_size
upscale_amount = 1 / self.downscale_ratio
- encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).float()
+ encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).to(dtype=self.vae_output_dtype())
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
if self.latent_dim == 1:
@@ -932,7 +935,7 @@ class VAE:
return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1)
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
- encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
+ encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
def decode(self, samples_in, vae_options={}):
@@ -950,9 +953,9 @@ class VAE:
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
- out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float())
+ out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).to(dtype=self.vae_output_dtype()))
if pixel_samples is None:
- pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
+ pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
pixel_samples[x:x+batch_number] = out
except Exception as e:
model_management.raise_non_oom(e)
@@ -1025,9 +1028,9 @@ class VAE:
samples = None
for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
- out = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
+ out = self.first_stage_model.encode(pixels_in).to(self.output_device).to(dtype=self.vae_output_dtype())
if samples is None:
- samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
+ samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
samples[x:x + batch_number] = out
except Exception as e:
diff --git a/nodes.py b/nodes.py
index eb63f9d44..1e19a8223 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1724,6 +1724,8 @@ class LoadImage:
output_masks = []
w, h = None, None
+ dtype = comfy.model_management.intermediate_dtype()
+
for i in ImageSequence.Iterator(img):
i = node_helpers.pillow(ImageOps.exif_transpose, i)
@@ -1748,8 +1750,8 @@ class LoadImage:
mask = 1. - torch.from_numpy(mask)
else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
- output_images.append(image)
- output_masks.append(mask.unsqueeze(0))
+ output_images.append(image.to(dtype=dtype))
+ output_masks.append(mask.unsqueeze(0).to(dtype=dtype))
if img.format == "MPO":
break # ignore all frames except the first one for MPO format
From 4941cd046eb1cd3021708ab7fe4e81e90a7b5dbe Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Sat, 14 Mar 2026 16:53:31 -0700
Subject: [PATCH 28/80] Update comfyui-frontend-package to version 1.41.20
(#12954)
---
requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/requirements.txt b/requirements.txt
index c32a765a0..7e59ef206 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-comfyui-frontend-package==1.41.19
+comfyui-frontend-package==1.41.20
comfyui-workflow-templates==0.9.21
comfyui-embedded-docs==0.4.3
torch
From 0904cc3fe5a551e3716851f12a568e481badd301 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?=
<40791699+kijai@users.noreply.github.com>
Date: Sun, 15 Mar 2026 03:09:09 +0200
Subject: [PATCH 29/80] LTXV: Accumulate VAE decode results on
intermediate_device (#12955)
---
comfy/ldm/lightricks/vae/causal_video_autoencoder.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
index 5b57dfc5e..9f14f64a5 100644
--- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
+++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
@@ -11,6 +11,7 @@ from .causal_conv3d import CausalConv3d
from .pixel_norm import PixelNorm
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
import comfy.ops
+import comfy.model_management
from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
ops = comfy.ops.disable_weight_init
@@ -536,7 +537,7 @@ class Decoder(nn.Module):
mark_conv3d_ended(self.conv_out)
sample = self.conv_out(sample, causal=self.causal)
if sample is not None and sample.shape[2] > 0:
- output.append(sample)
+ output.append(sample.to(comfy.model_management.intermediate_device()))
return
up_block = self.up_blocks[idx]
From 192cb8eeb9f644cda8e52ae24171491228ac8bb1 Mon Sep 17 00:00:00 2001
From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com>
Date: Mon, 16 Mar 2026 03:48:56 +0900
Subject: [PATCH 30/80] bump manager version to 4.1b5 (#12957)
---
manager_requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/manager_requirements.txt b/manager_requirements.txt
index 37a33bd4f..1c5e8f071 100644
--- a/manager_requirements.txt
+++ b/manager_requirements.txt
@@ -1 +1 @@
-comfyui_manager==4.1b4
\ No newline at end of file
+comfyui_manager==4.1b5
\ No newline at end of file
From e84a200a3c68044c2b5d6621ea80d27d1585703f Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Sun, 15 Mar 2026 11:49:49 -0700
Subject: [PATCH 31/80] ops: opt out of deferred weight init if subclassed
(#12967)
If a subclass BYO _load_from_state_dict and doesnt call the super() the
needed default init of these weights is missed and can lead to problems
for uninitialized weights.
---
comfy/ops.py | 18 ++++++++++++++----
1 file changed, 14 insertions(+), 4 deletions(-)
diff --git a/comfy/ops.py b/comfy/ops.py
index 59c0df87d..f47d4137a 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -336,7 +336,10 @@ class disable_weight_init:
class Linear(torch.nn.Linear, CastWeightBiasOp):
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
- if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
+ # don't trust subclasses that BYO state dict loader to call us.
+ if (not comfy.model_management.WINDOWS
+ or not comfy.memory_management.aimdo_enabled
+ or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict):
super().__init__(in_features, out_features, bias, device, dtype)
return
@@ -357,7 +360,9 @@ class disable_weight_init:
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
- if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
+ if (not comfy.model_management.WINDOWS
+ or not comfy.memory_management.aimdo_enabled
+ or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict):
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
disable_weight_init._lazy_load_from_state_dict(
@@ -564,7 +569,10 @@ class disable_weight_init:
def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None,
norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None,
_freeze=False, device=None, dtype=None):
- if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
+ # don't trust subclasses that BYO state dict loader to call us.
+ if (not comfy.model_management.WINDOWS
+ or not comfy.memory_management.aimdo_enabled
+ or type(self)._load_from_state_dict is not disable_weight_init.Embedding._load_from_state_dict):
super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm,
norm_type, scale_grad_by_freq, sparse, _weight,
_freeze, device, dtype)
@@ -590,7 +598,9 @@ class disable_weight_init:
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
- if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
+ if (not comfy.model_management.WINDOWS
+ or not comfy.memory_management.aimdo_enabled
+ or type(self)._load_from_state_dict is not disable_weight_init.Embedding._load_from_state_dict):
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
disable_weight_init._lazy_load_from_state_dict(
From d062becb336da8430052381111e952d6ab51d39c Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Sun, 15 Mar 2026 12:37:27 -0700
Subject: [PATCH 32/80] Make EmptyLatentImage follow intermediate dtype.
(#12974)
---
nodes.py | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/nodes.py b/nodes.py
index 1e19a8223..dd9298b18 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1211,9 +1211,6 @@ class GLIGENTextBoxApply:
return (c, )
class EmptyLatentImage:
- def __init__(self):
- self.device = comfy.model_management.intermediate_device()
-
@classmethod
def INPUT_TYPES(s):
return {
@@ -1232,7 +1229,7 @@ class EmptyLatentImage:
SEARCH_ALIASES = ["empty", "empty latent", "new latent", "create latent", "blank latent", "blank"]
def generate(self, width, height, batch_size=1):
- latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
+ latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
return ({"samples": latent, "downscale_ratio_spacial": 8}, )
From 3814bf4454ef3302fd7f91750d7a194dcf979630 Mon Sep 17 00:00:00 2001
From: lostdisc <194321775+lostdisc@users.noreply.github.com>
Date: Sun, 15 Mar 2026 15:45:30 -0400
Subject: [PATCH 33/80] Enable Pytorch Attention for gfx1150 (#12973)
---
comfy/model_management.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 442d5a40a..a4af5ddb2 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -400,7 +400,7 @@ try:
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
- if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
+ if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1150", "gfx1151"]): # TODO: more arches, TODO: gfx950
ENABLE_PYTORCH_ATTENTION = True
if rocm_version >= (7, 0):
if any((a in arch) for a in ["gfx1200", "gfx1201"]):
From 593be209a45a8a306c26de550e240a363de405a7 Mon Sep 17 00:00:00 2001
From: Christian Byrne
Date: Sun, 15 Mar 2026 16:18:04 -0700
Subject: [PATCH 34/80] feat: add essentials_category to nodes and blueprints
for Essentials tab (#12573)
* feat: add essentials_category to nodes and blueprints for Essentials tab
Add ESSENTIALS_CATEGORY or essentials_category to 12 node classes and all
36 blueprint JSONs. Update SubgraphEntry TypedDict and subgraph_manager to
extract and pass through the field.
Fixes COM-15221
Amp-Thread-ID: https://ampcode.com/threads/T-019c83de-f7ab-7779-a451-0ba5940b56a9
* fix: import NotRequired from typing_extensions for Python 3.10 compat
* refactor: keep only node class ESSENTIALS_CATEGORY, remove blueprint/subgraph changes
Frontend will own blueprint categorization separately.
* fix: remove essentials_category from CreateVideo (not in spec)
---------
Co-authored-by: guill
---
comfy_api_nodes/nodes_kling.py | 1 +
comfy_api_nodes/nodes_recraft.py | 1 +
comfy_extras/nodes_audio.py | 2 ++
comfy_extras/nodes_image_compare.py | 1 +
comfy_extras/nodes_images.py | 1 +
comfy_extras/nodes_post_processing.py | 1 +
nodes.py | 3 +++
7 files changed, 10 insertions(+)
diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py
index 8963c335d..9a37ccc53 100644
--- a/comfy_api_nodes/nodes_kling.py
+++ b/comfy_api_nodes/nodes_kling.py
@@ -1459,6 +1459,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
node_id="KlingOmniProEditVideoNode",
display_name="Kling 3.0 Omni Edit Video",
category="api node/video/Kling",
+ essentials_category="Video Generation",
description="Edit an existing video with the latest model from Kling.",
inputs=[
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py
index 4d1d508fa..c60cfbc4a 100644
--- a/comfy_api_nodes/nodes_recraft.py
+++ b/comfy_api_nodes/nodes_recraft.py
@@ -833,6 +833,7 @@ class RecraftVectorizeImageNode(IO.ComfyNode):
node_id="RecraftVectorizeImageNode",
display_name="Recraft Vectorize Image",
category="api node/image/Recraft",
+ essentials_category="Image Tools",
description="Generates SVG synchronously from an input image.",
inputs=[
IO.Image.Input("image"),
diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py
index 5d8d9bf6f..a395392d8 100644
--- a/comfy_extras/nodes_audio.py
+++ b/comfy_extras/nodes_audio.py
@@ -19,6 +19,7 @@ class EmptyLatentAudio(IO.ComfyNode):
node_id="EmptyLatentAudio",
display_name="Empty Latent Audio",
category="latent/audio",
+ essentials_category="Audio",
inputs=[
IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
IO.Int.Input(
@@ -185,6 +186,7 @@ class SaveAudioMP3(IO.ComfyNode):
search_aliases=["export mp3"],
display_name="Save Audio (MP3)",
category="audio",
+ essentials_category="Audio",
inputs=[
IO.Audio.Input("audio"),
IO.String.Input("filename_prefix", default="audio/ComfyUI"),
diff --git a/comfy_extras/nodes_image_compare.py b/comfy_extras/nodes_image_compare.py
index 8e9f809e6..3d943be67 100644
--- a/comfy_extras/nodes_image_compare.py
+++ b/comfy_extras/nodes_image_compare.py
@@ -14,6 +14,7 @@ class ImageCompare(IO.ComfyNode):
display_name="Image Compare",
description="Compares two images side by side with a slider.",
category="image",
+ essentials_category="Image Tools",
is_experimental=True,
is_output_node=True,
inputs=[
diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py
index 4c57bb5cb..a8223cf8b 100644
--- a/comfy_extras/nodes_images.py
+++ b/comfy_extras/nodes_images.py
@@ -58,6 +58,7 @@ class ImageCropV2(IO.ComfyNode):
search_aliases=["trim"],
display_name="Image Crop",
category="image/transform",
+ essentials_category="Image Tools",
inputs=[
IO.Image.Input("image"),
IO.BoundingBox.Input("crop_region", component="ImageCrop"),
diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py
index 4a0f7141a..06626f9dd 100644
--- a/comfy_extras/nodes_post_processing.py
+++ b/comfy_extras/nodes_post_processing.py
@@ -21,6 +21,7 @@ class Blend(io.ComfyNode):
node_id="ImageBlend",
display_name="Image Blend",
category="image/postprocessing",
+ essentials_category="Image Tools",
inputs=[
io.Image.Input("image1"),
io.Image.Input("image2"),
diff --git a/nodes.py b/nodes.py
index dd9298b18..03dcc9d4a 100644
--- a/nodes.py
+++ b/nodes.py
@@ -81,6 +81,7 @@ class CLIPTextEncode(ComfyNodeABC):
class ConditioningCombine:
+ ESSENTIALS_CATEGORY = "Image Generation"
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning_1": ("CONDITIONING", ), "conditioning_2": ("CONDITIONING", )}}
@@ -1778,6 +1779,7 @@ class LoadImage:
return True
class LoadImageMask:
+ ESSENTIALS_CATEGORY = "Image Tools"
SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"]
_color_channels = ["alpha", "red", "green", "blue"]
@@ -1886,6 +1888,7 @@ class ImageScale:
return (s,)
class ImageScaleBy:
+ ESSENTIALS_CATEGORY = "Image Tools"
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
@classmethod
From 2bd4d82b4f19c30dc979a3a16ddae97068e1bdc8 Mon Sep 17 00:00:00 2001
From: Luke Mino-Altherr
Date: Mon, 16 Mar 2026 15:34:04 -0400
Subject: [PATCH 35/80] feat(assets): align local API with cloud spec (#12863)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* feat(assets): align local API with cloud spec
Unify response models, add missing fields, and align input schemas with
the cloud OpenAPI spec at cloud.comfy.org/openapi.
- Replace AssetSummary/AssetDetail/AssetUpdated with single Asset model
- Add is_immutable, metadata (system_metadata), prompt_id fields
- Support mime_type and preview_id in update endpoint
- Make CreateFromHashBody.name optional, add mime_type, require >=1 tag
- Add id/mime_type/preview_id to upload, relax tags to optional
- Rename total_tags → tags in tag add/remove responses
- Add GET /api/assets/tags/refine histogram endpoint
- Add DB migration for system_metadata and prompt_id columns
Co-Authored-By: Claude Opus 4.6
* Fix review issues: tags validation, size nullability, type annotation, hash mismatch check, and add tag histogram tests
- Remove contradictory min_length=1 from CreateFromHashBody.tags default
- Restore size field to int|None=None for proper null semantics
- Add Union type annotation to _build_asset_response result param
- Add hash mismatch validation on idempotent upload path (409 HASH_MISMATCH)
- Add unit tests for list_tag_histogram service function
Amp-Thread-ID: https://ampcode.com/threads/T-019cd993-f43c-704e-b3d7-6cfc3d4d4a80
Co-authored-by: Amp
* Add preview_url to /assets API response using /api/view endpoint
For input and output assets, generate a preview_url pointing to the
existing /api/view endpoint using the asset's filename and tag-derived
type (input/output). Handles subdirectories via subfolder param and
URL-encodes filenames with spaces, unicode, and special characters.
This aligns the OSS backend response with the frontend AssetCard
expectation for thumbnail rendering.
Amp-Thread-ID: https://ampcode.com/threads/T-019cda3f-5c2c-751a-a906-ac6c9153ac5c
Co-authored-by: Amp
* chore: remove unused imports from asset_reference queries
Amp-Thread-ID: https://ampcode.com/threads/T-019cda7d-cb21-77b4-a51b-b965af60208c
Co-authored-by: Amp
* feat: resolve blake3 hashes in /view endpoint via asset database
Amp-Thread-ID: https://ampcode.com/threads/T-019cda7d-cb21-77b4-a51b-b965af60208c
Co-authored-by: Amp
* Register uploaded images in asset database when --enable-assets is set
Add register_file_in_place() service function to ingest module for
registering already-saved files without moving them. Call it from the
/upload/image endpoint to return asset metadata in the response.
Amp-Thread-ID: https://ampcode.com/threads/T-019ce023-3384-7560-bacf-de40b0de0dd2
Co-authored-by: Amp
* Exclude None fields from asset API JSON responses
Add exclude_none=True to model_dump() calls across asset routes to
keep response payloads clean by omitting unset optional fields.
Amp-Thread-ID: https://ampcode.com/threads/T-019ce023-3384-7560-bacf-de40b0de0dd2
Co-authored-by: Amp
* Add comment explaining why /view resolves blake3 hashes
Amp-Thread-ID: https://ampcode.com/threads/T-019ce023-3384-7560-bacf-de40b0de0dd2
Co-authored-by: Amp
* Move blake3 hash resolution to asset_management service
Extract resolve_hash_to_path() into asset_management.py and remove
_resolve_blake3_to_path from server.py. Also revert loopback origin
check to original logic.
Amp-Thread-ID: https://ampcode.com/threads/T-019ce023-3384-7560-bacf-de40b0de0dd2
Co-authored-by: Amp
* Require at least one tag in UploadAssetSpec
Enforce non-empty tags at the Pydantic validation layer so uploads
with no tags are rejected with a 400 before reaching ingest. Adds
test_upload_empty_tags_rejected to cover this case.
Amp-Thread-ID: https://ampcode.com/threads/T-019ce377-8bde-7048-bc28-a9df063409f9
Co-authored-by: Amp
* Add owner_id check to resolve_hash_to_path
Filter asset references by owner visibility so the /view endpoint
only resolves hashes for assets the requesting user can access.
Adds table-driven tests for owner visibility cases.
Amp-Thread-ID: https://ampcode.com/threads/T-019ce377-8bde-7048-bc28-a9df063409f9
Co-authored-by: Amp
* Make ReferenceData.created_at and updated_at required
Remove None defaults and type: ignore comments. Move fields before
optional fields to satisfy dataclass ordering.
Amp-Thread-ID: https://ampcode.com/threads/T-019ce377-8bde-7048-bc28-a9df063409f9
Co-authored-by: Amp
* Fix double commit in create_from_hash
Move mime_type update into _register_existing_asset so it shares a
single transaction with reference creation. Log a warning when the
hash is not found instead of silently returning None.
Amp-Thread-ID: https://ampcode.com/threads/T-019ce377-8bde-7048-bc28-a9df063409f9
Co-authored-by: Amp
* Add exclude_none=True to create/upload responses
Align with get/update/list endpoints for consistent JSON output.
Amp-Thread-ID: https://ampcode.com/threads/T-019ce377-8bde-7048-bc28-a9df063409f9
Co-authored-by: Amp
* Change preview_id to reference asset by reference ID, not content ID
Clients receive preview_id in API responses but could not dereference it
through public routes (which use reference IDs). Now preview_id is a
self-referential FK to asset_references.id so the value is directly
usable in the public API.
Co-Authored-By: Claude Opus 4.6
* Filter soft-deleted and missing refs from visibility queries
list_references_by_asset_id and list_tags_with_usage were not filtering
out deleted_at/is_missing refs, allowing /view?filename=blake3:... to
serve files through hidden references and inflating tag usage counts.
Add list_all_file_paths_by_asset_id for orphan cleanup which
intentionally needs unfiltered access.
Co-Authored-By: Claude Opus 4.6
* Pass preview_id and mime_type through all asset creation fast paths
The duplicate-content upload path and hash-based creation paths were
silently dropping preview_id and mime_type. This wires both fields
through _register_existing_asset, create_from_hash, and all route
call sites so behavior is consistent regardless of whether the asset
content already exists.
Co-Authored-By: Claude Opus 4.6
* Remove unimplemented client-provided ID from upload API
The `id` field on UploadAssetSpec was advertised for idempotent creation
but never actually honored when creating new references. Remove it
rather than implementing the feature.
Co-Authored-By: Claude Opus 4.6
* Make asset mime_type immutable after first ingest
Prevents cross-tenant metadata mutation when multiple references share
the same content-addressed Asset row. mime_type can now only be set when
NULL (first ingest); subsequent attempts to change it are silently ignored.
Co-Authored-By: Claude Opus 4.6
* Use resolved content_type from asset lookup in /view endpoint
The /view endpoint was discarding the content_type computed by
resolve_hash_to_path() and re-guessing from the filename, which
produced wrong results for extensionless files or mismatched extensions.
Co-Authored-By: Claude Opus 4.6
* Merge system+user metadata into filter projection
Extract rebuild_metadata_projection() to build AssetReferenceMeta rows
from {**system_metadata, **user_metadata}, so system-generated metadata
is queryable via metadata_filter and user keys override system keys.
Co-Authored-By: Claude Opus 4.6
* Standardize tag ordering to alphabetical across all endpoints
Co-Authored-By: Claude Opus 4.6
* Derive subfolder tags from path in register_file_in_place
Co-Authored-By: Claude Opus 4.6
* Reject client-provided id, fix preview URLs, rename tags→total_tags
- Reject 'id' field in multipart upload with 400 UNSUPPORTED_FIELD
instead of silently ignoring it
- Build preview URL from the preview asset's own metadata rather than
the parent asset's
- Rename 'tags' to 'total_tags' in TagsAdd/TagsRemove response schemas
for clarity
Co-Authored-By: Claude Opus 4.6
* fix: SQLite migration 0003 FK drop fails on file-backed DBs (MB-2)
Add naming_convention to Base.metadata so Alembic batch-mode reflection
can match unnamed FK constraints created by migration 0002. Pass
naming_convention and render_as_batch=True through env.py online config.
Add migration roundtrip tests (upgrade/downgrade/cycle from baseline).
Amp-Thread-ID: https://ampcode.com/threads/T-019ce466-1683-7471-b6e1-bb078223cda0
Co-authored-by: Amp
* Fix missing tag count for is_missing references and update test for total_tags field
- Allow is_missing=True references to be counted in list_tags_with_usage
when the tag is 'missing', so the missing tag count reflects all
references that have been tagged as missing
- Add update_is_missing_by_asset_id query helper for bulk updates by asset
- Update test_add_and_remove_tags to use 'total_tags' matching the API schema
Amp-Thread-ID: https://ampcode.com/threads/T-019ce482-05e7-7324-a1b0-a56a929cc7ef
Co-authored-by: Amp
* Remove unused imports in scanner.py
Co-Authored-By: Claude Opus 4.6
* Rename prompt_id to job_id on asset_references
Rename the column in the DB model, migration, and service schemas.
The API response emits both job_id and prompt_id (deprecated alias)
for backward compatibility with the cloud API.
Amp-Thread-ID: https://ampcode.com/threads/T-019cef41-60b0-752a-aa3c-ed7f20fda2f7
Co-authored-by: Amp
* Add index on asset_references.preview_id for FK cascade performance
Amp-Thread-ID: https://ampcode.com/threads/T-019cef45-a4d2-7548-86d2-d46bcd3db419
Co-authored-by: Amp
* Add clarifying comments for Asset/AssetReference naming and preview_id
Amp-Thread-ID: https://ampcode.com/threads/T-019cef49-f94e-7348-bf23-9a19ebf65e0d
Co-authored-by: Amp
* Disallow all-null meta rows: add CHECK constraint, skip null values on write
- convert_metadata_to_rows returns [] for None values instead of an all-null row
- Remove dead None branch from _scalar_to_row
- Simplify null filter in common.py to just check for row absence
- Add CHECK constraint ck_asset_reference_meta_has_value to model and migration 0003
Amp-Thread-ID: https://ampcode.com/threads/T-019cef4e-5240-7749-bb25-1f17fcf9c09c
Co-authored-by: Amp
* Remove dead None guards on result.asset in upload handler
register_file_in_place guarantees a non-None asset, so the
'if result.asset else None' checks were unreachable.
Amp-Thread-ID: https://ampcode.com/threads/T-019cef5b-4cf8-723c-8a98-8fb8f333c133
Co-authored-by: Amp
* Remove mime_type from asset update API
Clients can no longer modify mime_type after asset creation via the
PUT /api/assets/{id} endpoint. This reduces the risk of mime_type
spoofing. The internal update_asset_hash_and_mime function remains
available for server-side use (e.g., enrichment).
Amp-Thread-ID: https://ampcode.com/threads/T-019cef5d-8d61-75cc-a1c6-2841ac395648
Co-authored-by: Amp
* Fix migration constraint naming double-prefix and NULL in mixed metadata lists
- Use fully-rendered constraint names in migration 0003 to avoid the
naming convention doubling the ck_ prefix on batch operations.
- Add table_args to downgrade so SQLite batch mode can find the CHECK
constraint (not exposed by SQLite reflection).
- Fix model CheckConstraint name to use bare 'has_value' (convention
auto-prefixes).
- Skip None items when converting metadata lists to rows, preventing
all-NULL rows that violate the has_value check constraint.
Amp-Thread-ID: https://ampcode.com/threads/T-019cef87-94f9-7172-a6af-c6282290ce4f
Co-authored-by: Amp
---------
Co-authored-by: Claude Opus 4.6
Co-authored-by: Amp
---
alembic_db/env.py | 7 +-
.../versions/0003_add_metadata_job_id.py | 98 +++++++
app/assets/api/routes.py | 172 +++++++-----
app/assets/api/schemas_in.py | 64 ++++-
app/assets/api/schemas_out.py | 63 ++---
app/assets/api/upload.py | 14 +
app/assets/database/models.py | 25 +-
app/assets/database/queries/__init__.py | 12 +
app/assets/database/queries/asset.py | 4 +-
.../database/queries/asset_reference.py | 247 +++++++++---------
app/assets/database/queries/common.py | 79 +++++-
app/assets/database/queries/tags.py | 70 ++++-
app/assets/scanner.py | 6 +-
app/assets/services/asset_management.py | 72 ++++-
app/assets/services/ingest.py | 126 +++++++--
app/assets/services/schemas.py | 6 +-
app/assets/services/tagging.py | 23 ++
app/database/models.py | 11 +-
server.py | 79 ++++--
tests-unit/app_test/test_migrations.py | 57 ++++
tests-unit/assets_test/queries/test_asset.py | 43 +++
.../assets_test/queries/test_asset_info.py | 21 +-
.../assets_test/queries/test_metadata.py | 51 +++-
.../services/test_asset_management.py | 54 +++-
.../assets_test/services/test_ingest.py | 12 +-
.../services/test_tag_histogram.py | 123 +++++++++
tests-unit/assets_test/test_uploads.py | 9 +
27 files changed, 1218 insertions(+), 330 deletions(-)
create mode 100644 alembic_db/versions/0003_add_metadata_job_id.py
create mode 100644 tests-unit/app_test/test_migrations.py
create mode 100644 tests-unit/assets_test/services/test_tag_histogram.py
diff --git a/alembic_db/env.py b/alembic_db/env.py
index 4d7770679..4ce37c012 100644
--- a/alembic_db/env.py
+++ b/alembic_db/env.py
@@ -8,7 +8,7 @@ from alembic import context
config = context.config
-from app.database.models import Base
+from app.database.models import Base, NAMING_CONVENTION
target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py,
@@ -51,7 +51,10 @@ def run_migrations_online() -> None:
with connectable.connect() as connection:
context.configure(
- connection=connection, target_metadata=target_metadata
+ connection=connection,
+ target_metadata=target_metadata,
+ render_as_batch=True,
+ naming_convention=NAMING_CONVENTION,
)
with context.begin_transaction():
diff --git a/alembic_db/versions/0003_add_metadata_job_id.py b/alembic_db/versions/0003_add_metadata_job_id.py
new file mode 100644
index 000000000..2a14ee924
--- /dev/null
+++ b/alembic_db/versions/0003_add_metadata_job_id.py
@@ -0,0 +1,98 @@
+"""
+Add system_metadata and job_id columns to asset_references.
+Change preview_id FK from assets.id to asset_references.id.
+
+Revision ID: 0003_add_metadata_job_id
+Revises: 0002_merge_to_asset_references
+Create Date: 2026-03-09
+"""
+
+from alembic import op
+import sqlalchemy as sa
+
+from app.database.models import NAMING_CONVENTION
+
+revision = "0003_add_metadata_job_id"
+down_revision = "0002_merge_to_asset_references"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+ with op.batch_alter_table("asset_references") as batch_op:
+ batch_op.add_column(
+ sa.Column("system_metadata", sa.JSON(), nullable=True)
+ )
+ batch_op.add_column(
+ sa.Column("job_id", sa.String(length=36), nullable=True)
+ )
+
+ # Change preview_id FK from assets.id to asset_references.id (self-ref).
+ # Existing values are asset-content IDs that won't match reference IDs,
+ # so null them out first.
+ op.execute("UPDATE asset_references SET preview_id = NULL WHERE preview_id IS NOT NULL")
+ with op.batch_alter_table(
+ "asset_references", naming_convention=NAMING_CONVENTION
+ ) as batch_op:
+ batch_op.drop_constraint(
+ "fk_asset_references_preview_id_assets", type_="foreignkey"
+ )
+ batch_op.create_foreign_key(
+ "fk_asset_references_preview_id_asset_references",
+ "asset_references",
+ ["preview_id"],
+ ["id"],
+ ondelete="SET NULL",
+ )
+ batch_op.create_index(
+ "ix_asset_references_preview_id", ["preview_id"]
+ )
+
+ # Purge any all-null meta rows before adding the constraint
+ op.execute(
+ "DELETE FROM asset_reference_meta"
+ " WHERE val_str IS NULL AND val_num IS NULL AND val_bool IS NULL AND val_json IS NULL"
+ )
+ with op.batch_alter_table("asset_reference_meta") as batch_op:
+ batch_op.create_check_constraint(
+ "ck_asset_reference_meta_has_value",
+ "val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
+ )
+
+
+def downgrade() -> None:
+ # SQLite doesn't reflect CHECK constraints, so we must declare it
+ # explicitly via table_args for the batch recreate to find it.
+ # Use the fully-rendered constraint name to avoid the naming convention
+ # doubling the prefix.
+ with op.batch_alter_table(
+ "asset_reference_meta",
+ table_args=[
+ sa.CheckConstraint(
+ "val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
+ name="ck_asset_reference_meta_has_value",
+ ),
+ ],
+ ) as batch_op:
+ batch_op.drop_constraint(
+ "ck_asset_reference_meta_has_value", type_="check"
+ )
+
+ with op.batch_alter_table(
+ "asset_references", naming_convention=NAMING_CONVENTION
+ ) as batch_op:
+ batch_op.drop_index("ix_asset_references_preview_id")
+ batch_op.drop_constraint(
+ "fk_asset_references_preview_id_asset_references", type_="foreignkey"
+ )
+ batch_op.create_foreign_key(
+ "fk_asset_references_preview_id_assets",
+ "assets",
+ ["preview_id"],
+ ["id"],
+ ondelete="SET NULL",
+ )
+
+ with op.batch_alter_table("asset_references") as batch_op:
+ batch_op.drop_column("job_id")
+ batch_op.drop_column("system_metadata")
diff --git a/app/assets/api/routes.py b/app/assets/api/routes.py
index 40dee9f46..68126b6a5 100644
--- a/app/assets/api/routes.py
+++ b/app/assets/api/routes.py
@@ -13,6 +13,7 @@ from pydantic import ValidationError
import folder_paths
from app import user_manager
from app.assets.api import schemas_in, schemas_out
+from app.assets.services import schemas
from app.assets.api.schemas_in import (
AssetValidationError,
UploadError,
@@ -38,6 +39,7 @@ from app.assets.services import (
update_asset_metadata,
upload_from_temp_path,
)
+from app.assets.services.tagging import list_tag_histogram
ROUTES = web.RouteTableDef()
USER_MANAGER: user_manager.UserManager | None = None
@@ -122,6 +124,61 @@ def _validate_sort_field(requested: str | None) -> str:
return "created_at"
+def _build_preview_url_from_view(tags: list[str], user_metadata: dict[str, Any] | None) -> str | None:
+ """Build a /api/view preview URL from asset tags and user_metadata filename."""
+ if not user_metadata:
+ return None
+ filename = user_metadata.get("filename")
+ if not filename:
+ return None
+
+ if "input" in tags:
+ view_type = "input"
+ elif "output" in tags:
+ view_type = "output"
+ else:
+ return None
+
+ subfolder = ""
+ if "/" in filename:
+ subfolder, filename = filename.rsplit("/", 1)
+
+ encoded_filename = urllib.parse.quote(filename, safe="")
+ url = f"/api/view?type={view_type}&filename={encoded_filename}"
+ if subfolder:
+ url += f"&subfolder={urllib.parse.quote(subfolder, safe='')}"
+ return url
+
+
+def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResult) -> schemas_out.Asset:
+ """Build an Asset response from a service result."""
+ if result.ref.preview_id:
+ preview_detail = get_asset_detail(result.ref.preview_id)
+ if preview_detail:
+ preview_url = _build_preview_url_from_view(preview_detail.tags, preview_detail.ref.user_metadata)
+ else:
+ preview_url = None
+ else:
+ preview_url = _build_preview_url_from_view(result.tags, result.ref.user_metadata)
+ return schemas_out.Asset(
+ id=result.ref.id,
+ name=result.ref.name,
+ asset_hash=result.asset.hash if result.asset else None,
+ size=int(result.asset.size_bytes) if result.asset else None,
+ mime_type=result.asset.mime_type if result.asset else None,
+ tags=result.tags,
+ preview_url=preview_url,
+ preview_id=result.ref.preview_id,
+ user_metadata=result.ref.user_metadata or {},
+ metadata=result.ref.system_metadata,
+ job_id=result.ref.job_id,
+ prompt_id=result.ref.job_id, # deprecated: mirrors job_id for cloud compat
+ created_at=result.ref.created_at,
+ updated_at=result.ref.updated_at,
+ last_access_time=result.ref.last_access_time,
+ )
+
+
@ROUTES.head("/api/assets/hash/{hash}")
@_require_assets_feature_enabled
async def head_asset_by_hash(request: web.Request) -> web.Response:
@@ -164,20 +221,7 @@ async def list_assets_route(request: web.Request) -> web.Response:
order=order,
)
- summaries = [
- schemas_out.AssetSummary(
- id=item.ref.id,
- name=item.ref.name,
- asset_hash=item.asset.hash if item.asset else None,
- size=int(item.asset.size_bytes) if item.asset else None,
- mime_type=item.asset.mime_type if item.asset else None,
- tags=item.tags,
- created_at=item.ref.created_at,
- updated_at=item.ref.updated_at,
- last_access_time=item.ref.last_access_time,
- )
- for item in result.items
- ]
+ summaries = [_build_asset_response(item) for item in result.items]
payload = schemas_out.AssetsList(
assets=summaries,
@@ -207,18 +251,7 @@ async def get_asset_route(request: web.Request) -> web.Response:
{"id": reference_id},
)
- payload = schemas_out.AssetDetail(
- id=result.ref.id,
- name=result.ref.name,
- asset_hash=result.asset.hash if result.asset else None,
- size=int(result.asset.size_bytes) if result.asset else None,
- mime_type=result.asset.mime_type if result.asset else None,
- tags=result.tags,
- user_metadata=result.ref.user_metadata or {},
- preview_id=result.ref.preview_id,
- created_at=result.ref.created_at,
- last_access_time=result.ref.last_access_time,
- )
+ payload = _build_asset_response(result)
except ValueError as e:
return _build_error_response(
404, "ASSET_NOT_FOUND", str(e), {"id": reference_id}
@@ -230,7 +263,7 @@ async def get_asset_route(request: web.Request) -> web.Response:
USER_MANAGER.get_request_user_id(request),
)
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
- return web.json_response(payload.model_dump(mode="json"), status=200)
+ return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
@@ -312,32 +345,31 @@ async def create_asset_from_hash_route(request: web.Request) -> web.Response:
400, "INVALID_JSON", "Request body must be valid JSON."
)
+ # Derive name from hash if not provided
+ name = body.name
+ if name is None:
+ name = body.hash.split(":", 1)[1] if ":" in body.hash else body.hash
+
result = create_from_hash(
hash_str=body.hash,
- name=body.name,
+ name=name,
tags=body.tags,
user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request),
+ mime_type=body.mime_type,
+ preview_id=body.preview_id,
)
if result is None:
return _build_error_response(
404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist"
)
+ asset = _build_asset_response(result)
payload_out = schemas_out.AssetCreated(
- id=result.ref.id,
- name=result.ref.name,
- asset_hash=result.asset.hash,
- size=int(result.asset.size_bytes),
- mime_type=result.asset.mime_type,
- tags=result.tags,
- user_metadata=result.ref.user_metadata or {},
- preview_id=result.ref.preview_id,
- created_at=result.ref.created_at,
- last_access_time=result.ref.last_access_time,
+ **asset.model_dump(),
created_new=result.created_new,
)
- return web.json_response(payload_out.model_dump(mode="json"), status=201)
+ return web.json_response(payload_out.model_dump(mode="json", exclude_none=True), status=201)
@ROUTES.post("/api/assets")
@@ -358,6 +390,8 @@ async def upload_asset(request: web.Request) -> web.Response:
"name": parsed.provided_name,
"user_metadata": parsed.user_metadata_raw,
"hash": parsed.provided_hash,
+ "mime_type": parsed.provided_mime_type,
+ "preview_id": parsed.provided_preview_id,
}
)
except ValidationError as ve:
@@ -386,6 +420,8 @@ async def upload_asset(request: web.Request) -> web.Response:
tags=spec.tags,
user_metadata=spec.user_metadata or {},
owner_id=owner_id,
+ mime_type=spec.mime_type,
+ preview_id=spec.preview_id,
)
if result is None:
delete_temp_file_if_exists(parsed.tmp_path)
@@ -410,6 +446,8 @@ async def upload_asset(request: web.Request) -> web.Response:
client_filename=parsed.file_client_name,
owner_id=owner_id,
expected_hash=spec.hash,
+ mime_type=spec.mime_type,
+ preview_id=spec.preview_id,
)
except AssetValidationError as e:
delete_temp_file_if_exists(parsed.tmp_path)
@@ -428,21 +466,13 @@ async def upload_asset(request: web.Request) -> web.Response:
logging.exception("upload_asset failed for owner_id=%s", owner_id)
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
- payload = schemas_out.AssetCreated(
- id=result.ref.id,
- name=result.ref.name,
- asset_hash=result.asset.hash,
- size=int(result.asset.size_bytes),
- mime_type=result.asset.mime_type,
- tags=result.tags,
- user_metadata=result.ref.user_metadata or {},
- preview_id=result.ref.preview_id,
- created_at=result.ref.created_at,
- last_access_time=result.ref.last_access_time,
+ asset = _build_asset_response(result)
+ payload_out = schemas_out.AssetCreated(
+ **asset.model_dump(),
created_new=result.created_new,
)
status = 201 if result.created_new else 200
- return web.json_response(payload.model_dump(mode="json"), status=status)
+ return web.json_response(payload_out.model_dump(mode="json", exclude_none=True), status=status)
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
@@ -464,15 +494,9 @@ async def update_asset_route(request: web.Request) -> web.Response:
name=body.name,
user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request),
+ preview_id=body.preview_id,
)
- payload = schemas_out.AssetUpdated(
- id=result.ref.id,
- name=result.ref.name,
- asset_hash=result.asset.hash if result.asset else None,
- tags=result.tags,
- user_metadata=result.ref.user_metadata or {},
- updated_at=result.ref.updated_at,
- )
+ payload = _build_asset_response(result)
except PermissionError as pe:
return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id})
except ValueError as ve:
@@ -486,7 +510,7 @@ async def update_asset_route(request: web.Request) -> web.Response:
USER_MANAGER.get_request_user_id(request),
)
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
- return web.json_response(payload.model_dump(mode="json"), status=200)
+ return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
@@ -555,7 +579,7 @@ async def get_tags(request: web.Request) -> web.Response:
payload = schemas_out.TagsList(
tags=tags, total=total, has_more=(query.offset + len(tags)) < total
)
- return web.json_response(payload.model_dump(mode="json"))
+ return web.json_response(payload.model_dump(mode="json", exclude_none=True))
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
@@ -603,7 +627,7 @@ async def add_asset_tags(request: web.Request) -> web.Response:
)
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
- return web.json_response(payload.model_dump(mode="json"), status=200)
+ return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
@@ -650,7 +674,29 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
)
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
- return web.json_response(payload.model_dump(mode="json"), status=200)
+ return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
+
+
+@ROUTES.get("/api/assets/tags/refine")
+@_require_assets_feature_enabled
+async def get_tags_refine(request: web.Request) -> web.Response:
+ """GET request to get tag histogram for filtered assets."""
+ query_dict = get_query_dict(request)
+ try:
+ q = schemas_in.TagsRefineQuery.model_validate(query_dict)
+ except ValidationError as ve:
+ return _build_validation_error_response("INVALID_QUERY", ve)
+
+ tag_counts = list_tag_histogram(
+ owner_id=USER_MANAGER.get_request_user_id(request),
+ include_tags=q.include_tags,
+ exclude_tags=q.exclude_tags,
+ name_contains=q.name_contains,
+ metadata_filter=q.metadata_filter,
+ limit=q.limit,
+ )
+ payload = schemas_out.TagHistogram(tag_counts=tag_counts)
+ return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
@ROUTES.post("/api/assets/seed")
diff --git a/app/assets/api/schemas_in.py b/app/assets/api/schemas_in.py
index d255c938e..186a6ae1e 100644
--- a/app/assets/api/schemas_in.py
+++ b/app/assets/api/schemas_in.py
@@ -45,6 +45,8 @@ class ParsedUpload:
user_metadata_raw: str | None
provided_hash: str | None
provided_hash_exists: bool | None
+ provided_mime_type: str | None = None
+ provided_preview_id: str | None = None
class ListAssetsQuery(BaseModel):
@@ -98,11 +100,17 @@ class ListAssetsQuery(BaseModel):
class UpdateAssetBody(BaseModel):
name: str | None = None
user_metadata: dict[str, Any] | None = None
+ preview_id: str | None = None # references an asset_reference id, not an asset id
@model_validator(mode="after")
def _validate_at_least_one_field(self):
- if self.name is None and self.user_metadata is None:
- raise ValueError("Provide at least one of: name, user_metadata.")
+ if all(
+ v is None
+ for v in (self.name, self.user_metadata, self.preview_id)
+ ):
+ raise ValueError(
+ "Provide at least one of: name, user_metadata, preview_id."
+ )
return self
@@ -110,9 +118,11 @@ class CreateFromHashBody(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
hash: str
- name: str
+ name: str | None = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
+ mime_type: str | None = None
+ preview_id: str | None = None # references an asset_reference id, not an asset id
@field_validator("hash")
@classmethod
@@ -138,6 +148,44 @@ class CreateFromHashBody(BaseModel):
return []
+class TagsRefineQuery(BaseModel):
+ include_tags: list[str] = Field(default_factory=list)
+ exclude_tags: list[str] = Field(default_factory=list)
+ name_contains: str | None = None
+ metadata_filter: dict[str, Any] | None = None
+ limit: conint(ge=1, le=1000) = 100
+
+ @field_validator("include_tags", "exclude_tags", mode="before")
+ @classmethod
+ def _split_csv_tags(cls, v):
+ if v is None:
+ return []
+ if isinstance(v, str):
+ return [t.strip() for t in v.split(",") if t.strip()]
+ if isinstance(v, list):
+ out: list[str] = []
+ for item in v:
+ if isinstance(item, str):
+ out.extend([t.strip() for t in item.split(",") if t.strip()])
+ return out
+ return v
+
+ @field_validator("metadata_filter", mode="before")
+ @classmethod
+ def _parse_metadata_json(cls, v):
+ if v is None or isinstance(v, dict):
+ return v
+ if isinstance(v, str) and v.strip():
+ try:
+ parsed = json.loads(v)
+ except Exception as e:
+ raise ValueError(f"metadata_filter must be JSON: {e}") from e
+ if not isinstance(parsed, dict):
+ raise ValueError("metadata_filter must be a JSON object")
+ return parsed
+ return None
+
+
class TagsListQuery(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
@@ -186,21 +234,25 @@ class TagsRemove(TagsAdd):
class UploadAssetSpec(BaseModel):
"""Upload Asset operation.
- - tags: ordered; first is root ('models'|'input'|'output');
+ - tags: optional list; if provided, first is root ('models'|'input'|'output');
if root == 'models', second must be a valid category
- name: display name
- user_metadata: arbitrary JSON object (optional)
- hash: optional canonical 'blake3:' for validation / fast-path
+ - mime_type: optional MIME type override
+ - preview_id: optional asset_reference ID for preview
Files are stored using the content hash as filename stem.
"""
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
- tags: list[str] = Field(..., min_length=1)
+ tags: list[str] = Field(default_factory=list)
name: str | None = Field(default=None, max_length=512, description="Display Name")
user_metadata: dict[str, Any] = Field(default_factory=dict)
hash: str | None = Field(default=None)
+ mime_type: str | None = Field(default=None)
+ preview_id: str | None = Field(default=None) # references an asset_reference id
@field_validator("hash", mode="before")
@classmethod
@@ -279,7 +331,7 @@ class UploadAssetSpec(BaseModel):
@model_validator(mode="after")
def _validate_order(self):
if not self.tags:
- raise ValueError("tags must be provided and non-empty")
+ raise ValueError("at least one tag is required for uploads")
root = self.tags[0]
if root not in {"models", "input", "output"}:
raise ValueError("first tag must be one of: models, input, output")
diff --git a/app/assets/api/schemas_out.py b/app/assets/api/schemas_out.py
index f36447856..d99b1098d 100644
--- a/app/assets/api/schemas_out.py
+++ b/app/assets/api/schemas_out.py
@@ -4,7 +4,10 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field, field_serializer
-class AssetSummary(BaseModel):
+class Asset(BaseModel):
+ """API view of an asset. Maps to DB ``AssetReference`` joined with its ``Asset`` blob;
+ ``id`` here is the AssetReference id, not the content-addressed Asset id."""
+
id: str
name: str
asset_hash: str | None = None
@@ -12,8 +15,14 @@ class AssetSummary(BaseModel):
mime_type: str | None = None
tags: list[str] = Field(default_factory=list)
preview_url: str | None = None
- created_at: datetime | None = None
- updated_at: datetime | None = None
+ preview_id: str | None = None # references an asset_reference id, not an asset id
+ user_metadata: dict[str, Any] = Field(default_factory=dict)
+ is_immutable: bool = False
+ metadata: dict[str, Any] | None = None
+ job_id: str | None = None
+ prompt_id: str | None = None # deprecated: use job_id
+ created_at: datetime
+ updated_at: datetime
last_access_time: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@@ -23,50 +32,16 @@ class AssetSummary(BaseModel):
return v.isoformat() if v else None
+class AssetCreated(Asset):
+ created_new: bool
+
+
class AssetsList(BaseModel):
- assets: list[AssetSummary]
+ assets: list[Asset]
total: int
has_more: bool
-class AssetUpdated(BaseModel):
- id: str
- name: str
- asset_hash: str | None = None
- tags: list[str] = Field(default_factory=list)
- user_metadata: dict[str, Any] = Field(default_factory=dict)
- updated_at: datetime | None = None
-
- model_config = ConfigDict(from_attributes=True)
-
- @field_serializer("updated_at")
- def _serialize_updated_at(self, v: datetime | None, _info):
- return v.isoformat() if v else None
-
-
-class AssetDetail(BaseModel):
- id: str
- name: str
- asset_hash: str | None = None
- size: int | None = None
- mime_type: str | None = None
- tags: list[str] = Field(default_factory=list)
- user_metadata: dict[str, Any] = Field(default_factory=dict)
- preview_id: str | None = None
- created_at: datetime | None = None
- last_access_time: datetime | None = None
-
- model_config = ConfigDict(from_attributes=True)
-
- @field_serializer("created_at", "last_access_time")
- def _serialize_datetime(self, v: datetime | None, _info):
- return v.isoformat() if v else None
-
-
-class AssetCreated(AssetDetail):
- created_new: bool
-
-
class TagUsage(BaseModel):
name: str
count: int
@@ -91,3 +66,7 @@ class TagsRemove(BaseModel):
removed: list[str] = Field(default_factory=list)
not_present: list[str] = Field(default_factory=list)
total_tags: list[str] = Field(default_factory=list)
+
+
+class TagHistogram(BaseModel):
+ tag_counts: dict[str, int]
diff --git a/app/assets/api/upload.py b/app/assets/api/upload.py
index 721c12f4d..13d3d372c 100644
--- a/app/assets/api/upload.py
+++ b/app/assets/api/upload.py
@@ -52,6 +52,8 @@ async def parse_multipart_upload(
user_metadata_raw: str | None = None
provided_hash: str | None = None
provided_hash_exists: bool | None = None
+ provided_mime_type: str | None = None
+ provided_preview_id: str | None = None
file_written = 0
tmp_path: str | None = None
@@ -128,6 +130,16 @@ async def parse_multipart_upload(
provided_name = (await field.text()) or None
elif fname == "user_metadata":
user_metadata_raw = (await field.text()) or None
+ elif fname == "id":
+ raise UploadError(
+ 400,
+ "UNSUPPORTED_FIELD",
+ "Client-provided 'id' is not supported. Asset IDs are assigned by the server.",
+ )
+ elif fname == "mime_type":
+ provided_mime_type = ((await field.text()) or "").strip() or None
+ elif fname == "preview_id":
+ provided_preview_id = ((await field.text()) or "").strip() or None
if not file_present and not (provided_hash and provided_hash_exists):
raise UploadError(
@@ -152,6 +164,8 @@ async def parse_multipart_upload(
user_metadata_raw=user_metadata_raw,
provided_hash=provided_hash,
provided_hash_exists=provided_hash_exists,
+ provided_mime_type=provided_mime_type,
+ provided_preview_id=provided_preview_id,
)
diff --git a/app/assets/database/models.py b/app/assets/database/models.py
index 03c1c1707..a3af8a192 100644
--- a/app/assets/database/models.py
+++ b/app/assets/database/models.py
@@ -45,13 +45,7 @@ class Asset(Base):
passive_deletes=True,
)
- preview_of: Mapped[list[AssetReference]] = relationship(
- "AssetReference",
- back_populates="preview_asset",
- primaryjoin=lambda: Asset.id == foreign(AssetReference.preview_id),
- foreign_keys=lambda: [AssetReference.preview_id],
- viewonly=True,
- )
+ # preview_id on AssetReference is a self-referential FK to asset_references.id
__table_args__ = (
Index("uq_assets_hash", "hash", unique=True),
@@ -91,11 +85,15 @@ class AssetReference(Base):
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
name: Mapped[str] = mapped_column(String(512), nullable=False)
preview_id: Mapped[str | None] = mapped_column(
- String(36), ForeignKey("assets.id", ondelete="SET NULL")
+ String(36), ForeignKey("asset_references.id", ondelete="SET NULL")
)
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(
JSON(none_as_null=True)
)
+ system_metadata: Mapped[dict[str, Any] | None] = mapped_column(
+ JSON(none_as_null=True), nullable=True, default=None
+ )
+ job_id: Mapped[str | None] = mapped_column(String(36), nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=get_utc_now
)
@@ -115,10 +113,10 @@ class AssetReference(Base):
foreign_keys=[asset_id],
lazy="selectin",
)
- preview_asset: Mapped[Asset | None] = relationship(
- "Asset",
- back_populates="preview_of",
+ preview_ref: Mapped[AssetReference | None] = relationship(
+ "AssetReference",
foreign_keys=[preview_id],
+ remote_side=lambda: [AssetReference.id],
)
metadata_entries: Mapped[list[AssetReferenceMeta]] = relationship(
@@ -152,6 +150,7 @@ class AssetReference(Base):
Index("ix_asset_references_created_at", "created_at"),
Index("ix_asset_references_last_access_time", "last_access_time"),
Index("ix_asset_references_deleted_at", "deleted_at"),
+ Index("ix_asset_references_preview_id", "preview_id"),
Index("ix_asset_references_owner_name", "owner_id", "name"),
CheckConstraint(
"(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg"
@@ -192,6 +191,10 @@ class AssetReferenceMeta(Base):
Index("ix_asset_reference_meta_key_val_str", "key", "val_str"),
Index("ix_asset_reference_meta_key_val_num", "key", "val_num"),
Index("ix_asset_reference_meta_key_val_bool", "key", "val_bool"),
+ CheckConstraint(
+ "val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
+ name="has_value",
+ ),
)
diff --git a/app/assets/database/queries/__init__.py b/app/assets/database/queries/__init__.py
index 7888d0645..1632937b2 100644
--- a/app/assets/database/queries/__init__.py
+++ b/app/assets/database/queries/__init__.py
@@ -31,16 +31,21 @@ from app.assets.database.queries.asset_reference import (
get_unenriched_references,
get_unreferenced_unhashed_asset_ids,
insert_reference,
+ list_all_file_paths_by_asset_id,
list_references_by_asset_id,
list_references_page,
mark_references_missing_outside_prefixes,
+ rebuild_metadata_projection,
+ reference_exists,
reference_exists_for_asset_id,
restore_references_by_paths,
set_reference_metadata,
set_reference_preview,
+ set_reference_system_metadata,
soft_delete_reference_by_id,
update_reference_access_time,
update_reference_name,
+ update_is_missing_by_asset_id,
update_reference_timestamps,
update_reference_updated_at,
upsert_reference,
@@ -54,6 +59,7 @@ from app.assets.database.queries.tags import (
bulk_insert_tags_and_meta,
ensure_tags_exist,
get_reference_tags,
+ list_tag_counts_for_filtered_assets,
list_tags_with_usage,
remove_missing_tag_for_asset_id,
remove_tags_from_reference,
@@ -97,20 +103,26 @@ __all__ = [
"get_unenriched_references",
"get_unreferenced_unhashed_asset_ids",
"insert_reference",
+ "list_all_file_paths_by_asset_id",
"list_references_by_asset_id",
"list_references_page",
+ "list_tag_counts_for_filtered_assets",
"list_tags_with_usage",
"mark_references_missing_outside_prefixes",
"reassign_asset_references",
+ "rebuild_metadata_projection",
+ "reference_exists",
"reference_exists_for_asset_id",
"remove_missing_tag_for_asset_id",
"remove_tags_from_reference",
"restore_references_by_paths",
"set_reference_metadata",
"set_reference_preview",
+ "set_reference_system_metadata",
"soft_delete_reference_by_id",
"set_reference_tags",
"update_asset_hash_and_mime",
+ "update_is_missing_by_asset_id",
"update_reference_access_time",
"update_reference_name",
"update_reference_timestamps",
diff --git a/app/assets/database/queries/asset.py b/app/assets/database/queries/asset.py
index a21f5b68f..594d1f1b2 100644
--- a/app/assets/database/queries/asset.py
+++ b/app/assets/database/queries/asset.py
@@ -69,7 +69,7 @@ def upsert_asset(
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
asset.size_bytes = int(size_bytes)
changed = True
- if mime_type and asset.mime_type != mime_type:
+ if mime_type and not asset.mime_type:
asset.mime_type = mime_type
changed = True
if changed:
@@ -118,7 +118,7 @@ def update_asset_hash_and_mime(
return False
if asset_hash is not None:
asset.hash = asset_hash
- if mime_type is not None:
+ if mime_type is not None and not asset.mime_type:
asset.mime_type = mime_type
return True
diff --git a/app/assets/database/queries/asset_reference.py b/app/assets/database/queries/asset_reference.py
index 6524791cc..084a32512 100644
--- a/app/assets/database/queries/asset_reference.py
+++ b/app/assets/database/queries/asset_reference.py
@@ -10,7 +10,7 @@ from decimal import Decimal
from typing import NamedTuple, Sequence
import sqlalchemy as sa
-from sqlalchemy import delete, exists, select
+from sqlalchemy import delete, select
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, noload
@@ -24,12 +24,14 @@ from app.assets.database.models import (
)
from app.assets.database.queries.common import (
MAX_BIND_PARAMS,
+ apply_metadata_filter,
+ apply_tag_filters,
build_prefix_like_conditions,
build_visible_owner_clause,
calculate_rows_per_statement,
iter_chunks,
)
-from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
+from app.assets.helpers import escape_sql_like_string, get_utc_now
def _check_is_scalar(v):
@@ -44,15 +46,6 @@ def _check_is_scalar(v):
def _scalar_to_row(key: str, ordinal: int, value) -> dict:
"""Convert a scalar value to a typed projection row."""
- if value is None:
- return {
- "key": key,
- "ordinal": ordinal,
- "val_str": None,
- "val_num": None,
- "val_bool": None,
- "val_json": None,
- }
if isinstance(value, bool):
return {"key": key, "ordinal": ordinal, "val_bool": bool(value)}
if isinstance(value, (int, float, Decimal)):
@@ -66,96 +59,19 @@ def _scalar_to_row(key: str, ordinal: int, value) -> dict:
def convert_metadata_to_rows(key: str, value) -> list[dict]:
"""Turn a metadata key/value into typed projection rows."""
if value is None:
- return [_scalar_to_row(key, 0, None)]
+ return []
if _check_is_scalar(value):
return [_scalar_to_row(key, 0, value)]
if isinstance(value, list):
if all(_check_is_scalar(x) for x in value):
- return [_scalar_to_row(key, i, x) for i, x in enumerate(value)]
- return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value)]
+ return [_scalar_to_row(key, i, x) for i, x in enumerate(value) if x is not None]
+ return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value) if x is not None]
return [{"key": key, "ordinal": 0, "val_json": value}]
-def _apply_tag_filters(
- stmt: sa.sql.Select,
- include_tags: Sequence[str] | None = None,
- exclude_tags: Sequence[str] | None = None,
-) -> sa.sql.Select:
- """include_tags: every tag must be present; exclude_tags: none may be present."""
- include_tags = normalize_tags(include_tags)
- exclude_tags = normalize_tags(exclude_tags)
-
- if include_tags:
- for tag_name in include_tags:
- stmt = stmt.where(
- exists().where(
- (AssetReferenceTag.asset_reference_id == AssetReference.id)
- & (AssetReferenceTag.tag_name == tag_name)
- )
- )
-
- if exclude_tags:
- stmt = stmt.where(
- ~exists().where(
- (AssetReferenceTag.asset_reference_id == AssetReference.id)
- & (AssetReferenceTag.tag_name.in_(exclude_tags))
- )
- )
- return stmt
-
-
-def _apply_metadata_filter(
- stmt: sa.sql.Select,
- metadata_filter: dict | None = None,
-) -> sa.sql.Select:
- """Apply filters using asset_reference_meta projection table."""
- if not metadata_filter:
- return stmt
-
- def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
- return sa.exists().where(
- AssetReferenceMeta.asset_reference_id == AssetReference.id,
- AssetReferenceMeta.key == key,
- *preds,
- )
-
- def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
- if value is None:
- no_row_for_key = sa.not_(
- sa.exists().where(
- AssetReferenceMeta.asset_reference_id == AssetReference.id,
- AssetReferenceMeta.key == key,
- )
- )
- null_row = _exists_for_pred(
- key,
- AssetReferenceMeta.val_json.is_(None),
- AssetReferenceMeta.val_str.is_(None),
- AssetReferenceMeta.val_num.is_(None),
- AssetReferenceMeta.val_bool.is_(None),
- )
- return sa.or_(no_row_for_key, null_row)
-
- if isinstance(value, bool):
- return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value))
- if isinstance(value, (int, float, Decimal)):
- num = value if isinstance(value, Decimal) else Decimal(str(value))
- return _exists_for_pred(key, AssetReferenceMeta.val_num == num)
- if isinstance(value, str):
- return _exists_for_pred(key, AssetReferenceMeta.val_str == value)
- return _exists_for_pred(key, AssetReferenceMeta.val_json == value)
-
- for k, v in metadata_filter.items():
- if isinstance(v, list):
- ors = [_exists_clause_for_value(k, elem) for elem in v]
- if ors:
- stmt = stmt.where(sa.or_(*ors))
- else:
- stmt = stmt.where(_exists_clause_for_value(k, v))
- return stmt
def get_reference_by_id(
@@ -212,6 +128,21 @@ def reference_exists_for_asset_id(
return session.execute(q).first() is not None
+def reference_exists(
+ session: Session,
+ reference_id: str,
+) -> bool:
+ """Return True if a reference with the given ID exists (not soft-deleted)."""
+ q = (
+ select(sa.literal(True))
+ .select_from(AssetReference)
+ .where(AssetReference.id == reference_id)
+ .where(AssetReference.deleted_at.is_(None))
+ .limit(1)
+ )
+ return session.execute(q).first() is not None
+
+
def insert_reference(
session: Session,
asset_id: str,
@@ -336,8 +267,8 @@ def list_references_page(
escaped, esc = escape_sql_like_string(name_contains)
base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
- base = _apply_tag_filters(base, include_tags, exclude_tags)
- base = _apply_metadata_filter(base, metadata_filter)
+ base = apply_tag_filters(base, include_tags, exclude_tags)
+ base = apply_metadata_filter(base, metadata_filter)
sort = (sort or "created_at").lower()
order = (order or "desc").lower()
@@ -366,8 +297,8 @@ def list_references_page(
count_stmt = count_stmt.where(
AssetReference.name.ilike(f"%{escaped}%", escape=esc)
)
- count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags)
- count_stmt = _apply_metadata_filter(count_stmt, metadata_filter)
+ count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
+ count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
total = int(session.execute(count_stmt).scalar_one() or 0)
refs = session.execute(base).unique().scalars().all()
@@ -379,7 +310,7 @@ def list_references_page(
select(AssetReferenceTag.asset_reference_id, Tag.name)
.join(Tag, Tag.name == AssetReferenceTag.tag_name)
.where(AssetReferenceTag.asset_reference_id.in_(id_list))
- .order_by(AssetReferenceTag.added_at)
+ .order_by(AssetReferenceTag.tag_name.asc())
)
for ref_id, tag_name in rows.all():
tag_map[ref_id].append(tag_name)
@@ -492,6 +423,42 @@ def update_reference_updated_at(
)
+def rebuild_metadata_projection(session: Session, ref: AssetReference) -> None:
+ """Delete and rebuild AssetReferenceMeta rows from merged system+user metadata.
+
+ The merged dict is ``{**system_metadata, **user_metadata}`` so user keys
+ override system keys of the same name.
+ """
+ session.execute(
+ delete(AssetReferenceMeta).where(
+ AssetReferenceMeta.asset_reference_id == ref.id
+ )
+ )
+ session.flush()
+
+ merged = {**(ref.system_metadata or {}), **(ref.user_metadata or {})}
+ if not merged:
+ return
+
+ rows: list[AssetReferenceMeta] = []
+ for k, v in merged.items():
+ for r in convert_metadata_to_rows(k, v):
+ rows.append(
+ AssetReferenceMeta(
+ asset_reference_id=ref.id,
+ key=r["key"],
+ ordinal=int(r["ordinal"]),
+ val_str=r.get("val_str"),
+ val_num=r.get("val_num"),
+ val_bool=r.get("val_bool"),
+ val_json=r.get("val_json"),
+ )
+ )
+ if rows:
+ session.add_all(rows)
+ session.flush()
+
+
def set_reference_metadata(
session: Session,
reference_id: str,
@@ -505,33 +472,24 @@ def set_reference_metadata(
ref.updated_at = get_utc_now()
session.flush()
- session.execute(
- delete(AssetReferenceMeta).where(
- AssetReferenceMeta.asset_reference_id == reference_id
- )
- )
+ rebuild_metadata_projection(session, ref)
+
+
+def set_reference_system_metadata(
+ session: Session,
+ reference_id: str,
+ system_metadata: dict | None = None,
+) -> None:
+ """Set system_metadata on a reference and rebuild the merged projection."""
+ ref = session.get(AssetReference, reference_id)
+ if not ref:
+ raise ValueError(f"AssetReference {reference_id} not found")
+
+ ref.system_metadata = system_metadata or {}
+ ref.updated_at = get_utc_now()
session.flush()
- if not user_metadata:
- return
-
- rows: list[AssetReferenceMeta] = []
- for k, v in user_metadata.items():
- for r in convert_metadata_to_rows(k, v):
- rows.append(
- AssetReferenceMeta(
- asset_reference_id=reference_id,
- key=r["key"],
- ordinal=int(r["ordinal"]),
- val_str=r.get("val_str"),
- val_num=r.get("val_num"),
- val_bool=r.get("val_bool"),
- val_json=r.get("val_json"),
- )
- )
- if rows:
- session.add_all(rows)
- session.flush()
+ rebuild_metadata_projection(session, ref)
def delete_reference_by_id(
@@ -571,19 +529,19 @@ def soft_delete_reference_by_id(
def set_reference_preview(
session: Session,
reference_id: str,
- preview_asset_id: str | None = None,
+ preview_reference_id: str | None = None,
) -> None:
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
ref = session.get(AssetReference, reference_id)
if not ref:
raise ValueError(f"AssetReference {reference_id} not found")
- if preview_asset_id is None:
+ if preview_reference_id is None:
ref.preview_id = None
else:
- if not session.get(Asset, preview_asset_id):
- raise ValueError(f"Preview Asset {preview_asset_id} not found")
- ref.preview_id = preview_asset_id
+ if not session.get(AssetReference, preview_reference_id):
+ raise ValueError(f"Preview AssetReference {preview_reference_id} not found")
+ ref.preview_id = preview_reference_id
ref.updated_at = get_utc_now()
session.flush()
@@ -609,6 +567,8 @@ def list_references_by_asset_id(
session.execute(
select(AssetReference)
.where(AssetReference.asset_id == asset_id)
+ .where(AssetReference.is_missing == False) # noqa: E712
+ .where(AssetReference.deleted_at.is_(None))
.order_by(AssetReference.id.asc())
)
.scalars()
@@ -616,6 +576,25 @@ def list_references_by_asset_id(
)
+def list_all_file_paths_by_asset_id(
+ session: Session,
+ asset_id: str,
+) -> list[str]:
+ """Return every file_path for an asset, including soft-deleted/missing refs.
+
+ Used for orphan cleanup where all on-disk files must be removed.
+ """
+ return list(
+ session.execute(
+ select(AssetReference.file_path)
+ .where(AssetReference.asset_id == asset_id)
+ .where(AssetReference.file_path.isnot(None))
+ )
+ .scalars()
+ .all()
+ )
+
+
def upsert_reference(
session: Session,
asset_id: str,
@@ -855,6 +834,22 @@ def bulk_update_is_missing(
return total
+def update_is_missing_by_asset_id(
+ session: Session, asset_id: str, value: bool
+) -> int:
+ """Set is_missing flag for ALL references belonging to an asset.
+
+ Returns: Number of rows updated
+ """
+ result = session.execute(
+ sa.update(AssetReference)
+ .where(AssetReference.asset_id == asset_id)
+ .where(AssetReference.deleted_at.is_(None))
+ .values(is_missing=value)
+ )
+ return result.rowcount
+
+
def delete_references_by_ids(session: Session, reference_ids: list[str]) -> int:
"""Delete references by their IDs.
diff --git a/app/assets/database/queries/common.py b/app/assets/database/queries/common.py
index 194c39a1e..89bb49327 100644
--- a/app/assets/database/queries/common.py
+++ b/app/assets/database/queries/common.py
@@ -1,12 +1,14 @@
"""Shared utilities for database query modules."""
import os
-from typing import Iterable
+from decimal import Decimal
+from typing import Iterable, Sequence
import sqlalchemy as sa
+from sqlalchemy import exists
-from app.assets.database.models import AssetReference
-from app.assets.helpers import escape_sql_like_string
+from app.assets.database.models import AssetReference, AssetReferenceMeta, AssetReferenceTag
+from app.assets.helpers import escape_sql_like_string, normalize_tags
MAX_BIND_PARAMS = 800
@@ -52,3 +54,74 @@ def build_prefix_like_conditions(
escaped, esc = escape_sql_like_string(base)
conds.append(AssetReference.file_path.like(escaped + "%", escape=esc))
return conds
+
+
+def apply_tag_filters(
+ stmt: sa.sql.Select,
+ include_tags: Sequence[str] | None = None,
+ exclude_tags: Sequence[str] | None = None,
+) -> sa.sql.Select:
+ """include_tags: every tag must be present; exclude_tags: none may be present."""
+ include_tags = normalize_tags(include_tags)
+ exclude_tags = normalize_tags(exclude_tags)
+
+ if include_tags:
+ for tag_name in include_tags:
+ stmt = stmt.where(
+ exists().where(
+ (AssetReferenceTag.asset_reference_id == AssetReference.id)
+ & (AssetReferenceTag.tag_name == tag_name)
+ )
+ )
+
+ if exclude_tags:
+ stmt = stmt.where(
+ ~exists().where(
+ (AssetReferenceTag.asset_reference_id == AssetReference.id)
+ & (AssetReferenceTag.tag_name.in_(exclude_tags))
+ )
+ )
+ return stmt
+
+
+def apply_metadata_filter(
+ stmt: sa.sql.Select,
+ metadata_filter: dict | None = None,
+) -> sa.sql.Select:
+ """Apply filters using asset_reference_meta projection table."""
+ if not metadata_filter:
+ return stmt
+
+ def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
+ return sa.exists().where(
+ AssetReferenceMeta.asset_reference_id == AssetReference.id,
+ AssetReferenceMeta.key == key,
+ *preds,
+ )
+
+ def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
+ if value is None:
+ return sa.not_(
+ sa.exists().where(
+ AssetReferenceMeta.asset_reference_id == AssetReference.id,
+ AssetReferenceMeta.key == key,
+ )
+ )
+
+ if isinstance(value, bool):
+ return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value))
+ if isinstance(value, (int, float, Decimal)):
+ num = value if isinstance(value, Decimal) else Decimal(str(value))
+ return _exists_for_pred(key, AssetReferenceMeta.val_num == num)
+ if isinstance(value, str):
+ return _exists_for_pred(key, AssetReferenceMeta.val_str == value)
+ return _exists_for_pred(key, AssetReferenceMeta.val_json == value)
+
+ for k, v in metadata_filter.items():
+ if isinstance(v, list):
+ ors = [_exists_clause_for_value(k, elem) for elem in v]
+ if ors:
+ stmt = stmt.where(sa.or_(*ors))
+ else:
+ stmt = stmt.where(_exists_clause_for_value(k, v))
+ return stmt
diff --git a/app/assets/database/queries/tags.py b/app/assets/database/queries/tags.py
index 8b25fee67..f4126dba8 100644
--- a/app/assets/database/queries/tags.py
+++ b/app/assets/database/queries/tags.py
@@ -8,12 +8,15 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from app.assets.database.models import (
+ Asset,
AssetReference,
AssetReferenceMeta,
AssetReferenceTag,
Tag,
)
from app.assets.database.queries.common import (
+ apply_metadata_filter,
+ apply_tag_filters,
build_visible_owner_clause,
iter_row_chunks,
)
@@ -72,9 +75,9 @@ def get_reference_tags(session: Session, reference_id: str) -> list[str]:
tag_name
for (tag_name,) in (
session.execute(
- select(AssetReferenceTag.tag_name).where(
- AssetReferenceTag.asset_reference_id == reference_id
- )
+ select(AssetReferenceTag.tag_name)
+ .where(AssetReferenceTag.asset_reference_id == reference_id)
+ .order_by(AssetReferenceTag.tag_name.asc())
)
).all()
]
@@ -117,7 +120,7 @@ def set_reference_tags(
)
session.flush()
- return SetTagsResult(added=to_add, removed=to_remove, total=desired)
+ return SetTagsResult(added=sorted(to_add), removed=sorted(to_remove), total=sorted(desired))
def add_tags_to_reference(
@@ -272,6 +275,12 @@ def list_tags_with_usage(
.select_from(AssetReferenceTag)
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
.where(build_visible_owner_clause(owner_id))
+ .where(
+ sa.or_(
+ AssetReference.is_missing == False, # noqa: E712
+ AssetReferenceTag.tag_name == "missing",
+ )
+ )
.where(AssetReference.deleted_at.is_(None))
.group_by(AssetReferenceTag.tag_name)
.subquery()
@@ -308,6 +317,12 @@ def list_tags_with_usage(
select(AssetReferenceTag.tag_name)
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
.where(build_visible_owner_clause(owner_id))
+ .where(
+ sa.or_(
+ AssetReference.is_missing == False, # noqa: E712
+ AssetReferenceTag.tag_name == "missing",
+ )
+ )
.where(AssetReference.deleted_at.is_(None))
.group_by(AssetReferenceTag.tag_name)
)
@@ -320,6 +335,53 @@ def list_tags_with_usage(
return rows_norm, int(total or 0)
+def list_tag_counts_for_filtered_assets(
+ session: Session,
+ owner_id: str = "",
+ include_tags: Sequence[str] | None = None,
+ exclude_tags: Sequence[str] | None = None,
+ name_contains: str | None = None,
+ metadata_filter: dict | None = None,
+ limit: int = 100,
+) -> dict[str, int]:
+ """Return tag counts for assets matching the given filters.
+
+ Uses the same filtering logic as list_references_page but returns
+ {tag_name: count} instead of paginated references.
+ """
+ # Build a subquery of matching reference IDs
+ ref_sq = (
+ select(AssetReference.id)
+ .join(Asset, Asset.id == AssetReference.asset_id)
+ .where(build_visible_owner_clause(owner_id))
+ .where(AssetReference.is_missing == False) # noqa: E712
+ .where(AssetReference.deleted_at.is_(None))
+ )
+
+ if name_contains:
+ escaped, esc = escape_sql_like_string(name_contains)
+ ref_sq = ref_sq.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
+
+ ref_sq = apply_tag_filters(ref_sq, include_tags, exclude_tags)
+ ref_sq = apply_metadata_filter(ref_sq, metadata_filter)
+ ref_sq = ref_sq.subquery()
+
+ # Count tags across those references
+ q = (
+ select(
+ AssetReferenceTag.tag_name,
+ func.count(AssetReferenceTag.asset_reference_id).label("cnt"),
+ )
+ .where(AssetReferenceTag.asset_reference_id.in_(select(ref_sq.c.id)))
+ .group_by(AssetReferenceTag.tag_name)
+ .order_by(func.count(AssetReferenceTag.asset_reference_id).desc(), AssetReferenceTag.tag_name.asc())
+ .limit(limit)
+ )
+
+ rows = session.execute(q).all()
+ return {tag_name: int(cnt) for tag_name, cnt in rows}
+
+
def bulk_insert_tags_and_meta(
session: Session,
tag_rows: list[dict],
diff --git a/app/assets/scanner.py b/app/assets/scanner.py
index e27ea5123..4e05a97b5 100644
--- a/app/assets/scanner.py
+++ b/app/assets/scanner.py
@@ -18,7 +18,7 @@ from app.assets.database.queries import (
mark_references_missing_outside_prefixes,
reassign_asset_references,
remove_missing_tag_for_asset_id,
- set_reference_metadata,
+ set_reference_system_metadata,
update_asset_hash_and_mime,
)
from app.assets.services.bulk_ingest import (
@@ -490,8 +490,8 @@ def enrich_asset(
logging.warning("Failed to hash %s: %s", file_path, e)
if extract_metadata and metadata:
- user_metadata = metadata.to_user_metadata()
- set_reference_metadata(session, reference_id, user_metadata)
+ system_metadata = metadata.to_user_metadata()
+ set_reference_system_metadata(session, reference_id, system_metadata)
if full_hash:
existing = get_asset_by_hash(session, full_hash)
diff --git a/app/assets/services/asset_management.py b/app/assets/services/asset_management.py
index 3fe7115c8..5aefd9956 100644
--- a/app/assets/services/asset_management.py
+++ b/app/assets/services/asset_management.py
@@ -16,10 +16,12 @@ from app.assets.database.queries import (
get_reference_by_id,
get_reference_with_owner_check,
list_references_page,
+ list_all_file_paths_by_asset_id,
list_references_by_asset_id,
set_reference_metadata,
set_reference_preview,
set_reference_tags,
+ update_asset_hash_and_mime,
update_reference_access_time,
update_reference_name,
update_reference_updated_at,
@@ -67,6 +69,8 @@ def update_asset_metadata(
user_metadata: UserMetadata = None,
tag_origin: str = "manual",
owner_id: str = "",
+ mime_type: str | None = None,
+ preview_id: str | None = None,
) -> AssetDetailResult:
with create_session() as session:
ref = get_reference_with_owner_check(session, reference_id, owner_id)
@@ -103,6 +107,21 @@ def update_asset_metadata(
)
touched = True
+ if mime_type is not None:
+ updated = update_asset_hash_and_mime(
+ session, asset_id=ref.asset_id, mime_type=mime_type
+ )
+ if updated:
+ touched = True
+
+ if preview_id is not None:
+ set_reference_preview(
+ session,
+ reference_id=reference_id,
+ preview_reference_id=preview_id,
+ )
+ touched = True
+
if touched and user_metadata is None:
update_reference_updated_at(session, reference_id=reference_id)
@@ -159,11 +178,9 @@ def delete_asset_reference(
session.commit()
return True
- # Orphaned asset - delete it and its files
- refs = list_references_by_asset_id(session, asset_id=asset_id)
- file_paths = [
- r.file_path for r in (refs or []) if getattr(r, "file_path", None)
- ]
+ # Orphaned asset - gather ALL file paths (including
+ # soft-deleted / missing refs) so their on-disk files get cleaned up.
+ file_paths = list_all_file_paths_by_asset_id(session, asset_id=asset_id)
# Also include the just-deleted file path
if file_path:
file_paths.append(file_path)
@@ -185,7 +202,7 @@ def delete_asset_reference(
def set_asset_preview(
reference_id: str,
- preview_asset_id: str | None = None,
+ preview_reference_id: str | None = None,
owner_id: str = "",
) -> AssetDetailResult:
with create_session() as session:
@@ -194,7 +211,7 @@ def set_asset_preview(
set_reference_preview(
session,
reference_id=reference_id,
- preview_asset_id=preview_asset_id,
+ preview_reference_id=preview_reference_id,
)
result = fetch_reference_asset_and_tags(
@@ -263,6 +280,47 @@ def list_assets_page(
return ListAssetsResult(items=items, total=total)
+def resolve_hash_to_path(
+ asset_hash: str,
+ owner_id: str = "",
+) -> DownloadResolutionResult | None:
+ """Resolve a blake3 hash to an on-disk file path.
+
+ Only references visible to *owner_id* are considered (owner-less
+ references are always visible).
+
+ Returns a DownloadResolutionResult with abs_path, content_type, and
+ download_name, or None if no asset or live path is found.
+ """
+ with create_session() as session:
+ asset = queries_get_asset_by_hash(session, asset_hash)
+ if not asset:
+ return None
+ refs = list_references_by_asset_id(session, asset_id=asset.id)
+ visible = [
+ r for r in refs
+ if r.owner_id == "" or r.owner_id == owner_id
+ ]
+ abs_path = select_best_live_path(visible)
+ if not abs_path:
+ return None
+ display_name = os.path.basename(abs_path)
+ for ref in visible:
+ if ref.file_path == abs_path and ref.name:
+ display_name = ref.name
+ break
+ ctype = (
+ asset.mime_type
+ or mimetypes.guess_type(display_name)[0]
+ or "application/octet-stream"
+ )
+ return DownloadResolutionResult(
+ abs_path=abs_path,
+ content_type=ctype,
+ download_name=display_name,
+ )
+
+
def resolve_asset_for_download(
reference_id: str,
owner_id: str = "",
diff --git a/app/assets/services/ingest.py b/app/assets/services/ingest.py
index 44d7aef36..90c51994f 100644
--- a/app/assets/services/ingest.py
+++ b/app/assets/services/ingest.py
@@ -11,13 +11,14 @@ from app.assets.database.queries import (
add_tags_to_reference,
fetch_reference_and_asset,
get_asset_by_hash,
- get_existing_asset_ids,
get_reference_by_file_path,
get_reference_tags,
get_or_create_reference,
+ reference_exists,
remove_missing_tag_for_asset_id,
set_reference_metadata,
set_reference_tags,
+ update_asset_hash_and_mime,
upsert_asset,
upsert_reference,
validate_tags_exist,
@@ -26,6 +27,7 @@ from app.assets.helpers import normalize_tags
from app.assets.services.file_utils import get_size_and_mtime_ns
from app.assets.services.path_utils import (
compute_relative_filename,
+ get_name_and_tags_from_asset_path,
resolve_destination_from_tags,
validate_path_within_base,
)
@@ -65,7 +67,7 @@ def _ingest_file_from_path(
with create_session() as session:
if preview_id:
- if preview_id not in get_existing_asset_ids(session, [preview_id]):
+ if not reference_exists(session, preview_id):
preview_id = None
asset, asset_created, asset_updated = upsert_asset(
@@ -135,6 +137,8 @@ def _register_existing_asset(
tags: list[str] | None = None,
tag_origin: str = "manual",
owner_id: str = "",
+ mime_type: str | None = None,
+ preview_id: str | None = None,
) -> RegisterAssetResult:
user_metadata = user_metadata or {}
@@ -143,14 +147,25 @@ def _register_existing_asset(
if not asset:
raise ValueError(f"No asset with hash {asset_hash}")
+ if mime_type and not asset.mime_type:
+ update_asset_hash_and_mime(session, asset_id=asset.id, mime_type=mime_type)
+
+ if preview_id:
+ if not reference_exists(session, preview_id):
+ preview_id = None
+
ref, ref_created = get_or_create_reference(
session,
asset_id=asset.id,
owner_id=owner_id,
name=name,
+ preview_id=preview_id,
)
if not ref_created:
+ if preview_id and ref.preview_id != preview_id:
+ ref.preview_id = preview_id
+
tag_names = get_reference_tags(session, reference_id=ref.id)
result = RegisterAssetResult(
ref=extract_reference_data(ref),
@@ -242,6 +257,8 @@ def upload_from_temp_path(
client_filename: str | None = None,
owner_id: str = "",
expected_hash: str | None = None,
+ mime_type: str | None = None,
+ preview_id: str | None = None,
) -> UploadResult:
try:
digest, _ = hashing.compute_blake3_hash(temp_path)
@@ -270,6 +287,8 @@ def upload_from_temp_path(
tags=tags or [],
tag_origin="manual",
owner_id=owner_id,
+ mime_type=mime_type,
+ preview_id=preview_id,
)
return UploadResult(
ref=result.ref,
@@ -291,7 +310,7 @@ def upload_from_temp_path(
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
validate_path_within_base(dest_abs, base_dir)
- content_type = (
+ content_type = mime_type or (
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
or mimetypes.guess_type(hashed_basename, strict=False)[0]
or "application/octet-stream"
@@ -315,7 +334,7 @@ def upload_from_temp_path(
mime_type=content_type,
info_name=_sanitize_filename(name or client_filename, fallback=digest),
owner_id=owner_id,
- preview_id=None,
+ preview_id=preview_id,
user_metadata=user_metadata or {},
tags=tags,
tag_origin="manual",
@@ -342,30 +361,99 @@ def upload_from_temp_path(
)
+def register_file_in_place(
+ abs_path: str,
+ name: str,
+ tags: list[str],
+ owner_id: str = "",
+ mime_type: str | None = None,
+) -> UploadResult:
+ """Register an already-saved file in the asset database without moving it.
+
+ Tags are derived from the filesystem path (root category + subfolder names),
+ merged with any caller-provided tags, matching the behavior of the scanner.
+ If the path is not under a known root, only the caller-provided tags are used.
+ """
+ try:
+ _, path_tags = get_name_and_tags_from_asset_path(abs_path)
+ except ValueError:
+ path_tags = []
+ merged_tags = normalize_tags([*path_tags, *tags])
+
+ try:
+ digest, _ = hashing.compute_blake3_hash(abs_path)
+ except ImportError as e:
+ raise DependencyMissingError(str(e))
+ except Exception as e:
+ raise RuntimeError(f"failed to hash file: {e}")
+ asset_hash = "blake3:" + digest
+
+ size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path)
+ content_type = mime_type or (
+ mimetypes.guess_type(abs_path, strict=False)[0]
+ or "application/octet-stream"
+ )
+
+ ingest_result = _ingest_file_from_path(
+ abs_path=abs_path,
+ asset_hash=asset_hash,
+ size_bytes=size_bytes,
+ mtime_ns=mtime_ns,
+ mime_type=content_type,
+ info_name=_sanitize_filename(name, fallback=digest),
+ owner_id=owner_id,
+ tags=merged_tags,
+ tag_origin="upload",
+ require_existing_tags=False,
+ )
+ reference_id = ingest_result.reference_id
+ if not reference_id:
+ raise RuntimeError("failed to create asset reference")
+
+ with create_session() as session:
+ pair = fetch_reference_and_asset(
+ session, reference_id=reference_id, owner_id=owner_id
+ )
+ if not pair:
+ raise RuntimeError("inconsistent DB state after ingest")
+ ref, asset = pair
+ tag_names = get_reference_tags(session, reference_id=ref.id)
+
+ return UploadResult(
+ ref=extract_reference_data(ref),
+ asset=extract_asset_data(asset),
+ tags=tag_names,
+ created_new=ingest_result.asset_created,
+ )
+
+
def create_from_hash(
hash_str: str,
name: str,
tags: list[str] | None = None,
user_metadata: dict | None = None,
owner_id: str = "",
+ mime_type: str | None = None,
+ preview_id: str | None = None,
) -> UploadResult | None:
canonical = hash_str.strip().lower()
- with create_session() as session:
- asset = get_asset_by_hash(session, asset_hash=canonical)
- if not asset:
- return None
-
- result = _register_existing_asset(
- asset_hash=canonical,
- name=_sanitize_filename(
- name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical
- ),
- user_metadata=user_metadata or {},
- tags=tags or [],
- tag_origin="manual",
- owner_id=owner_id,
- )
+ try:
+ result = _register_existing_asset(
+ asset_hash=canonical,
+ name=_sanitize_filename(
+ name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical
+ ),
+ user_metadata=user_metadata or {},
+ tags=tags or [],
+ tag_origin="manual",
+ owner_id=owner_id,
+ mime_type=mime_type,
+ preview_id=preview_id,
+ )
+ except ValueError:
+ logging.warning("create_from_hash: no asset found for hash %s", canonical)
+ return None
return UploadResult(
ref=result.ref,
diff --git a/app/assets/services/schemas.py b/app/assets/services/schemas.py
index 8b1f1f4dc..0eb128f58 100644
--- a/app/assets/services/schemas.py
+++ b/app/assets/services/schemas.py
@@ -25,7 +25,9 @@ class ReferenceData:
preview_id: str | None
created_at: datetime
updated_at: datetime
- last_access_time: datetime | None
+ system_metadata: dict[str, Any] | None = None
+ job_id: str | None = None
+ last_access_time: datetime | None = None
@dataclass(frozen=True)
@@ -93,6 +95,8 @@ def extract_reference_data(ref: AssetReference) -> ReferenceData:
file_path=ref.file_path,
user_metadata=ref.user_metadata,
preview_id=ref.preview_id,
+ system_metadata=ref.system_metadata,
+ job_id=ref.job_id,
created_at=ref.created_at,
updated_at=ref.updated_at,
last_access_time=ref.last_access_time,
diff --git a/app/assets/services/tagging.py b/app/assets/services/tagging.py
index 28900464d..37b612753 100644
--- a/app/assets/services/tagging.py
+++ b/app/assets/services/tagging.py
@@ -1,3 +1,5 @@
+from typing import Sequence
+
from app.assets.database.queries import (
AddTagsResult,
RemoveTagsResult,
@@ -6,6 +8,7 @@ from app.assets.database.queries import (
list_tags_with_usage,
remove_tags_from_reference,
)
+from app.assets.database.queries.tags import list_tag_counts_for_filtered_assets
from app.assets.services.schemas import TagUsage
from app.database.db import create_session
@@ -73,3 +76,23 @@ def list_tags(
)
return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total
+
+
+def list_tag_histogram(
+ owner_id: str = "",
+ include_tags: Sequence[str] | None = None,
+ exclude_tags: Sequence[str] | None = None,
+ name_contains: str | None = None,
+ metadata_filter: dict | None = None,
+ limit: int = 100,
+) -> dict[str, int]:
+ with create_session() as session:
+ return list_tag_counts_for_filtered_assets(
+ session,
+ owner_id=owner_id,
+ include_tags=include_tags,
+ exclude_tags=exclude_tags,
+ name_contains=name_contains,
+ metadata_filter=metadata_filter,
+ limit=limit,
+ )
diff --git a/app/database/models.py b/app/database/models.py
index e7572677a..b02856f6e 100644
--- a/app/database/models.py
+++ b/app/database/models.py
@@ -1,9 +1,18 @@
from typing import Any
from datetime import datetime
+from sqlalchemy import MetaData
from sqlalchemy.orm import DeclarativeBase
+NAMING_CONVENTION = {
+ "ix": "ix_%(table_name)s_%(column_0_N_name)s",
+ "uq": "uq_%(table_name)s_%(column_0_N_name)s",
+ "ck": "ck_%(table_name)s_%(constraint_name)s",
+ "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
+ "pk": "pk_%(table_name)s",
+}
+
class Base(DeclarativeBase):
- pass
+ metadata = MetaData(naming_convention=NAMING_CONVENTION)
def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
fields = obj.__table__.columns.keys()
diff --git a/server.py b/server.py
index 85a8964be..173a28376 100644
--- a/server.py
+++ b/server.py
@@ -35,6 +35,8 @@ from app.frontend_management import FrontendManager, parse_version
from comfy_api.internal import _ComfyNodeInternal
from app.assets.seeder import asset_seeder
from app.assets.api.routes import register_assets_routes
+from app.assets.services.ingest import register_file_in_place
+from app.assets.services.asset_management import resolve_hash_to_path
from app.user_manager import UserManager
from app.model_manager import ModelFileManager
@@ -419,7 +421,24 @@ class PromptServer():
with open(filepath, "wb") as f:
f.write(image.file.read())
- return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type})
+ resp = {"name" : filename, "subfolder": subfolder, "type": image_upload_type}
+
+ if args.enable_assets:
+ try:
+ tag = image_upload_type if image_upload_type in ("input", "output") else "input"
+ result = register_file_in_place(abs_path=filepath, name=filename, tags=[tag])
+ resp["asset"] = {
+ "id": result.ref.id,
+ "name": result.ref.name,
+ "asset_hash": result.asset.hash,
+ "size": result.asset.size_bytes,
+ "mime_type": result.asset.mime_type,
+ "tags": result.tags,
+ }
+ except Exception:
+ logging.warning("Failed to register uploaded image as asset", exc_info=True)
+
+ return web.json_response(resp)
else:
return web.Response(status=400)
@@ -479,30 +498,43 @@ class PromptServer():
async def view_image(request):
if "filename" in request.rel_url.query:
filename = request.rel_url.query["filename"]
- filename, output_dir = folder_paths.annotated_filepath(filename)
- if not filename:
- return web.Response(status=400)
+ # The frontend's LoadImage combo widget uses asset_hash values
+ # (e.g. "blake3:...") as widget values. When litegraph renders the
+ # node preview, it constructs /view?filename=, so this
+ # endpoint must resolve blake3 hashes to their on-disk file paths.
+ if filename.startswith("blake3:"):
+ owner_id = self.user_manager.get_request_user_id(request)
+ result = resolve_hash_to_path(filename, owner_id=owner_id)
+ if result is None:
+ return web.Response(status=404)
+ file, filename, resolved_content_type = result.abs_path, result.download_name, result.content_type
+ else:
+ resolved_content_type = None
+ filename, output_dir = folder_paths.annotated_filepath(filename)
- # validation for security: prevent accessing arbitrary path
- if filename[0] == '/' or '..' in filename:
- return web.Response(status=400)
+ if not filename:
+ return web.Response(status=400)
- if output_dir is None:
- type = request.rel_url.query.get("type", "output")
- output_dir = folder_paths.get_directory_by_type(type)
+ # validation for security: prevent accessing arbitrary path
+ if filename[0] == '/' or '..' in filename:
+ return web.Response(status=400)
- if output_dir is None:
- return web.Response(status=400)
+ if output_dir is None:
+ type = request.rel_url.query.get("type", "output")
+ output_dir = folder_paths.get_directory_by_type(type)
- if "subfolder" in request.rel_url.query:
- full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"])
- if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
- return web.Response(status=403)
- output_dir = full_output_dir
+ if output_dir is None:
+ return web.Response(status=400)
- filename = os.path.basename(filename)
- file = os.path.join(output_dir, filename)
+ if "subfolder" in request.rel_url.query:
+ full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"])
+ if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
+ return web.Response(status=403)
+ output_dir = full_output_dir
+
+ filename = os.path.basename(filename)
+ file = os.path.join(output_dir, filename)
if os.path.isfile(file):
if 'preview' in request.rel_url.query:
@@ -562,8 +594,13 @@ class PromptServer():
return web.Response(body=alpha_buffer.read(), content_type='image/png',
headers={"Content-Disposition": f"filename=\"{filename}\""})
else:
- # Get content type from mimetype, defaulting to 'application/octet-stream'
- content_type = mimetypes.guess_type(filename)[0] or 'application/octet-stream'
+ # Use the content type from asset resolution if available,
+ # otherwise guess from the filename.
+ content_type = (
+ resolved_content_type
+ or mimetypes.guess_type(filename)[0]
+ or 'application/octet-stream'
+ )
# For security, force certain mimetypes to download instead of display
if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}:
diff --git a/tests-unit/app_test/test_migrations.py b/tests-unit/app_test/test_migrations.py
new file mode 100644
index 000000000..fa10c1727
--- /dev/null
+++ b/tests-unit/app_test/test_migrations.py
@@ -0,0 +1,57 @@
+"""Test that Alembic migrations run cleanly on a file-backed SQLite DB.
+
+This catches problems like unnamed FK constraints that prevent batch-mode
+drop_constraint from working on real SQLite files (see MB-2).
+
+Migrations 0001 and 0002 are already shipped, so we only exercise
+upgrade/downgrade for 0003+.
+"""
+
+import os
+
+import pytest
+from alembic import command
+from alembic.config import Config
+
+
+# Oldest shipped revision — we upgrade to here as a baseline and never
+# downgrade past it.
+_BASELINE = "0002_merge_to_asset_references"
+
+
+def _make_config(db_path: str) -> Config:
+ root = os.path.join(os.path.dirname(__file__), "../..")
+ config_path = os.path.abspath(os.path.join(root, "alembic.ini"))
+ scripts_path = os.path.abspath(os.path.join(root, "alembic_db"))
+
+ cfg = Config(config_path)
+ cfg.set_main_option("script_location", scripts_path)
+ cfg.set_main_option("sqlalchemy.url", f"sqlite:///{db_path}")
+ return cfg
+
+
+@pytest.fixture
+def migration_db(tmp_path):
+ """Yield an alembic Config pre-upgraded to the baseline revision."""
+ db_path = str(tmp_path / "test_migration.db")
+ cfg = _make_config(db_path)
+ command.upgrade(cfg, _BASELINE)
+ yield cfg
+
+
+def test_upgrade_to_head(migration_db):
+ """Upgrade from baseline to head must succeed on a file-backed DB."""
+ command.upgrade(migration_db, "head")
+
+
+def test_downgrade_to_baseline(migration_db):
+ """Upgrade to head then downgrade back to baseline."""
+ command.upgrade(migration_db, "head")
+ command.downgrade(migration_db, _BASELINE)
+
+
+def test_upgrade_downgrade_cycle(migration_db):
+ """Full cycle: upgrade → downgrade → upgrade again."""
+ command.upgrade(migration_db, "head")
+ command.downgrade(migration_db, _BASELINE)
+ command.upgrade(migration_db, "head")
diff --git a/tests-unit/assets_test/queries/test_asset.py b/tests-unit/assets_test/queries/test_asset.py
index 08f84cd11..9b7eb4bac 100644
--- a/tests-unit/assets_test/queries/test_asset.py
+++ b/tests-unit/assets_test/queries/test_asset.py
@@ -10,6 +10,7 @@ from app.assets.database.queries import (
get_asset_by_hash,
upsert_asset,
bulk_insert_assets,
+ update_asset_hash_and_mime,
)
@@ -142,3 +143,45 @@ class TestBulkInsertAssets:
session.commit()
assert session.query(Asset).count() == 200
+
+
+class TestMimeTypeImmutability:
+ """mime_type on Asset is write-once: set on first ingest, never overwritten."""
+
+ @pytest.mark.parametrize(
+ "initial_mime,second_mime,expected_mime",
+ [
+ ("image/png", "image/jpeg", "image/png"),
+ (None, "image/png", "image/png"),
+ ],
+ ids=["preserves_existing", "fills_null"],
+ )
+ def test_upsert_mime_immutability(self, session: Session, initial_mime, second_mime, expected_mime):
+ h = f"blake3:upsert_{initial_mime}_{second_mime}"
+ upsert_asset(session, asset_hash=h, size_bytes=100, mime_type=initial_mime)
+ session.commit()
+
+ asset, created, _ = upsert_asset(session, asset_hash=h, size_bytes=100, mime_type=second_mime)
+ assert created is False
+ assert asset.mime_type == expected_mime
+
+ @pytest.mark.parametrize(
+ "initial_mime,update_mime,update_hash,expected_mime,expected_hash",
+ [
+ (None, "image/png", None, "image/png", "blake3:upd0"),
+ ("image/png", "image/jpeg", None, "image/png", "blake3:upd1"),
+ ("image/png", "image/jpeg", "blake3:upd2_new", "image/png", "blake3:upd2_new"),
+ ],
+ ids=["fills_null", "preserves_existing", "hash_updates_mime_locked"],
+ )
+ def test_update_asset_hash_and_mime_immutability(
+ self, session: Session, initial_mime, update_mime, update_hash, expected_mime, expected_hash,
+ ):
+ h = expected_hash.removesuffix("_new")
+ asset = Asset(hash=h, size_bytes=100, mime_type=initial_mime)
+ session.add(asset)
+ session.flush()
+
+ update_asset_hash_and_mime(session, asset_id=asset.id, mime_type=update_mime, asset_hash=update_hash)
+ assert asset.mime_type == expected_mime
+ assert asset.hash == expected_hash
diff --git a/tests-unit/assets_test/queries/test_asset_info.py b/tests-unit/assets_test/queries/test_asset_info.py
index 8f6c7fcdb..fe510e342 100644
--- a/tests-unit/assets_test/queries/test_asset_info.py
+++ b/tests-unit/assets_test/queries/test_asset_info.py
@@ -242,22 +242,24 @@ class TestSetReferencePreview:
asset = _make_asset(session, "hash1")
preview_asset = _make_asset(session, "preview_hash")
ref = _make_reference(session, asset)
+ preview_ref = _make_reference(session, preview_asset, name="preview.png")
session.commit()
- set_reference_preview(session, reference_id=ref.id, preview_asset_id=preview_asset.id)
+ set_reference_preview(session, reference_id=ref.id, preview_reference_id=preview_ref.id)
session.commit()
session.refresh(ref)
- assert ref.preview_id == preview_asset.id
+ assert ref.preview_id == preview_ref.id
def test_clears_preview(self, session: Session):
asset = _make_asset(session, "hash1")
preview_asset = _make_asset(session, "preview_hash")
ref = _make_reference(session, asset)
- ref.preview_id = preview_asset.id
+ preview_ref = _make_reference(session, preview_asset, name="preview.png")
+ ref.preview_id = preview_ref.id
session.commit()
- set_reference_preview(session, reference_id=ref.id, preview_asset_id=None)
+ set_reference_preview(session, reference_id=ref.id, preview_reference_id=None)
session.commit()
session.refresh(ref)
@@ -265,15 +267,15 @@ class TestSetReferencePreview:
def test_raises_for_nonexistent_reference(self, session: Session):
with pytest.raises(ValueError, match="not found"):
- set_reference_preview(session, reference_id="nonexistent", preview_asset_id=None)
+ set_reference_preview(session, reference_id="nonexistent", preview_reference_id=None)
def test_raises_for_nonexistent_preview(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
session.commit()
- with pytest.raises(ValueError, match="Preview Asset"):
- set_reference_preview(session, reference_id=ref.id, preview_asset_id="nonexistent")
+ with pytest.raises(ValueError, match="Preview AssetReference"):
+ set_reference_preview(session, reference_id=ref.id, preview_reference_id="nonexistent")
class TestInsertReference:
@@ -351,13 +353,14 @@ class TestUpdateReferenceTimestamps:
asset = _make_asset(session, "hash1")
preview_asset = _make_asset(session, "preview_hash")
ref = _make_reference(session, asset)
+ preview_ref = _make_reference(session, preview_asset, name="preview.png")
session.commit()
- update_reference_timestamps(session, ref, preview_id=preview_asset.id)
+ update_reference_timestamps(session, ref, preview_id=preview_ref.id)
session.commit()
session.refresh(ref)
- assert ref.preview_id == preview_asset.id
+ assert ref.preview_id == preview_ref.id
class TestSetReferenceMetadata:
diff --git a/tests-unit/assets_test/queries/test_metadata.py b/tests-unit/assets_test/queries/test_metadata.py
index 6a545e819..d7a747789 100644
--- a/tests-unit/assets_test/queries/test_metadata.py
+++ b/tests-unit/assets_test/queries/test_metadata.py
@@ -20,6 +20,7 @@ def _make_reference(
asset: Asset,
name: str,
metadata: dict | None = None,
+ system_metadata: dict | None = None,
) -> AssetReference:
now = get_utc_now()
ref = AssetReference(
@@ -27,6 +28,7 @@ def _make_reference(
name=name,
asset_id=asset.id,
user_metadata=metadata,
+ system_metadata=system_metadata,
created_at=now,
updated_at=now,
last_access_time=now,
@@ -34,8 +36,10 @@ def _make_reference(
session.add(ref)
session.flush()
- if metadata:
- for key, val in metadata.items():
+ # Build merged projection: {**system_metadata, **user_metadata}
+ merged = {**(system_metadata or {}), **(metadata or {})}
+ if merged:
+ for key, val in merged.items():
for row in convert_metadata_to_rows(key, val):
meta_row = AssetReferenceMeta(
asset_reference_id=ref.id,
@@ -182,3 +186,46 @@ class TestMetadataFilterEmptyDict:
refs, _, total = list_references_page(session, metadata_filter={})
assert total == 2
+
+
+class TestSystemMetadataProjection:
+ """Tests for system_metadata merging into the filter projection."""
+
+ def test_system_metadata_keys_are_filterable(self, session: Session):
+ """system_metadata keys should appear in the merged projection."""
+ asset = _make_asset(session, "hash1")
+ _make_reference(
+ session, asset, "with_sys",
+ system_metadata={"source": "scanner"},
+ )
+ _make_reference(session, asset, "without_sys")
+ session.commit()
+
+ refs, _, total = list_references_page(
+ session, metadata_filter={"source": "scanner"}
+ )
+ assert total == 1
+ assert refs[0].name == "with_sys"
+
+ def test_user_metadata_overrides_system_metadata(self, session: Session):
+ """user_metadata should win when both have the same key."""
+ asset = _make_asset(session, "hash1")
+ _make_reference(
+ session, asset, "overridden",
+ metadata={"origin": "user_upload"},
+ system_metadata={"origin": "auto_scan"},
+ )
+ session.commit()
+
+ # Should match the user value, not the system value
+ refs, _, total = list_references_page(
+ session, metadata_filter={"origin": "user_upload"}
+ )
+ assert total == 1
+ assert refs[0].name == "overridden"
+
+ # Should NOT match the system value (it was overridden)
+ refs, _, total = list_references_page(
+ session, metadata_filter={"origin": "auto_scan"}
+ )
+ assert total == 0
diff --git a/tests-unit/assets_test/services/test_asset_management.py b/tests-unit/assets_test/services/test_asset_management.py
index 101ef7292..e8ff989e9 100644
--- a/tests-unit/assets_test/services/test_asset_management.py
+++ b/tests-unit/assets_test/services/test_asset_management.py
@@ -11,6 +11,7 @@ from app.assets.services import (
delete_asset_reference,
set_asset_preview,
)
+from app.assets.services.asset_management import resolve_hash_to_path
def _make_asset(session: Session, hash_val: str = "blake3:test", size: int = 1024) -> Asset:
@@ -219,31 +220,33 @@ class TestSetAssetPreview:
asset = _make_asset(session, hash_val="blake3:main")
preview_asset = _make_asset(session, hash_val="blake3:preview")
ref = _make_reference(session, asset)
+ preview_ref = _make_reference(session, preview_asset, name="preview.png")
ref_id = ref.id
- preview_id = preview_asset.id
+ preview_ref_id = preview_ref.id
session.commit()
set_asset_preview(
reference_id=ref_id,
- preview_asset_id=preview_id,
+ preview_reference_id=preview_ref_id,
)
# Verify by re-fetching from DB
session.expire_all()
updated_ref = session.get(AssetReference, ref_id)
- assert updated_ref.preview_id == preview_id
+ assert updated_ref.preview_id == preview_ref_id
def test_clears_preview(self, mock_create_session, session: Session):
asset = _make_asset(session)
preview_asset = _make_asset(session, hash_val="blake3:preview")
ref = _make_reference(session, asset)
- ref.preview_id = preview_asset.id
+ preview_ref = _make_reference(session, preview_asset, name="preview.png")
+ ref.preview_id = preview_ref.id
ref_id = ref.id
session.commit()
set_asset_preview(
reference_id=ref_id,
- preview_asset_id=None,
+ preview_reference_id=None,
)
# Verify by re-fetching from DB
@@ -263,6 +266,45 @@ class TestSetAssetPreview:
with pytest.raises(PermissionError, match="not owner"):
set_asset_preview(
reference_id=ref.id,
- preview_asset_id=None,
+ preview_reference_id=None,
owner_id="user2",
)
+
+
+class TestResolveHashToPath:
+ def test_returns_none_for_unknown_hash(self, mock_create_session):
+ result = resolve_hash_to_path("blake3:" + "a" * 64)
+ assert result is None
+
+ @pytest.mark.parametrize(
+ "ref_owner, query_owner, expect_found",
+ [
+ ("user1", "user1", True),
+ ("user1", "user2", False),
+ ("", "anyone", True),
+ ("", "", True),
+ ],
+ ids=[
+ "owner_sees_own_ref",
+ "other_owner_blocked",
+ "ownerless_visible_to_anyone",
+ "ownerless_visible_to_empty",
+ ],
+ )
+ def test_owner_visibility(
+ self, ref_owner, query_owner, expect_found,
+ mock_create_session, session: Session, temp_dir,
+ ):
+ f = temp_dir / "file.bin"
+ f.write_bytes(b"data")
+ asset = _make_asset(session, hash_val="blake3:" + "b" * 64)
+ ref = _make_reference(session, asset, name="file.bin", owner_id=ref_owner)
+ ref.file_path = str(f)
+ session.commit()
+
+ result = resolve_hash_to_path(asset.hash, owner_id=query_owner)
+ if expect_found:
+ assert result is not None
+ assert result.abs_path == str(f)
+ else:
+ assert result is None
diff --git a/tests-unit/assets_test/services/test_ingest.py b/tests-unit/assets_test/services/test_ingest.py
index 367bc7721..dbb8441c2 100644
--- a/tests-unit/assets_test/services/test_ingest.py
+++ b/tests-unit/assets_test/services/test_ingest.py
@@ -113,11 +113,19 @@ class TestIngestFileFromPath:
file_path = temp_dir / "with_preview.bin"
file_path.write_bytes(b"data")
- # Create a preview asset first
+ # Create a preview asset and reference
preview_asset = Asset(hash="blake3:preview", size_bytes=100)
session.add(preview_asset)
+ session.flush()
+ from app.assets.helpers import get_utc_now
+ now = get_utc_now()
+ preview_ref = AssetReference(
+ asset_id=preview_asset.id, name="preview.png", owner_id="",
+ created_at=now, updated_at=now, last_access_time=now,
+ )
+ session.add(preview_ref)
session.commit()
- preview_id = preview_asset.id
+ preview_id = preview_ref.id
result = _ingest_file_from_path(
abs_path=str(file_path),
diff --git a/tests-unit/assets_test/services/test_tag_histogram.py b/tests-unit/assets_test/services/test_tag_histogram.py
new file mode 100644
index 000000000..7bcd518ec
--- /dev/null
+++ b/tests-unit/assets_test/services/test_tag_histogram.py
@@ -0,0 +1,123 @@
+"""Tests for list_tag_histogram service function."""
+from sqlalchemy.orm import Session
+
+from app.assets.database.models import Asset, AssetReference
+from app.assets.database.queries import ensure_tags_exist, add_tags_to_reference
+from app.assets.helpers import get_utc_now
+from app.assets.services.tagging import list_tag_histogram
+
+
+def _make_asset(session: Session, hash_val: str = "blake3:test") -> Asset:
+ asset = Asset(hash=hash_val, size_bytes=1024)
+ session.add(asset)
+ session.flush()
+ return asset
+
+
+def _make_reference(
+ session: Session,
+ asset: Asset,
+ name: str = "test",
+ owner_id: str = "",
+) -> AssetReference:
+ now = get_utc_now()
+ ref = AssetReference(
+ owner_id=owner_id,
+ name=name,
+ asset_id=asset.id,
+ created_at=now,
+ updated_at=now,
+ last_access_time=now,
+ )
+ session.add(ref)
+ session.flush()
+ return ref
+
+
+class TestListTagHistogram:
+ def test_returns_counts_for_all_tags(self, mock_create_session, session: Session):
+ ensure_tags_exist(session, ["alpha", "beta"])
+ a1 = _make_asset(session, "blake3:aaa")
+ r1 = _make_reference(session, a1, name="r1")
+ add_tags_to_reference(session, reference_id=r1.id, tags=["alpha", "beta"])
+
+ a2 = _make_asset(session, "blake3:bbb")
+ r2 = _make_reference(session, a2, name="r2")
+ add_tags_to_reference(session, reference_id=r2.id, tags=["alpha"])
+ session.commit()
+
+ result = list_tag_histogram()
+
+ assert result["alpha"] == 2
+ assert result["beta"] == 1
+
+ def test_empty_when_no_assets(self, mock_create_session, session: Session):
+ ensure_tags_exist(session, ["unused"])
+ session.commit()
+
+ result = list_tag_histogram()
+
+ assert result == {}
+
+ def test_include_tags_filter(self, mock_create_session, session: Session):
+ ensure_tags_exist(session, ["models", "loras", "input"])
+ a1 = _make_asset(session, "blake3:aaa")
+ r1 = _make_reference(session, a1, name="r1")
+ add_tags_to_reference(session, reference_id=r1.id, tags=["models", "loras"])
+
+ a2 = _make_asset(session, "blake3:bbb")
+ r2 = _make_reference(session, a2, name="r2")
+ add_tags_to_reference(session, reference_id=r2.id, tags=["input"])
+ session.commit()
+
+ result = list_tag_histogram(include_tags=["models"])
+
+ # Only r1 has "models", so only its tags appear
+ assert "models" in result
+ assert "loras" in result
+ assert "input" not in result
+
+ def test_exclude_tags_filter(self, mock_create_session, session: Session):
+ ensure_tags_exist(session, ["models", "loras", "input"])
+ a1 = _make_asset(session, "blake3:aaa")
+ r1 = _make_reference(session, a1, name="r1")
+ add_tags_to_reference(session, reference_id=r1.id, tags=["models", "loras"])
+
+ a2 = _make_asset(session, "blake3:bbb")
+ r2 = _make_reference(session, a2, name="r2")
+ add_tags_to_reference(session, reference_id=r2.id, tags=["input"])
+ session.commit()
+
+ result = list_tag_histogram(exclude_tags=["models"])
+
+ # r1 excluded, only r2's tags remain
+ assert "input" in result
+ assert "loras" not in result
+
+ def test_name_contains_filter(self, mock_create_session, session: Session):
+ ensure_tags_exist(session, ["alpha", "beta"])
+ a1 = _make_asset(session, "blake3:aaa")
+ r1 = _make_reference(session, a1, name="my_model.safetensors")
+ add_tags_to_reference(session, reference_id=r1.id, tags=["alpha"])
+
+ a2 = _make_asset(session, "blake3:bbb")
+ r2 = _make_reference(session, a2, name="picture.png")
+ add_tags_to_reference(session, reference_id=r2.id, tags=["beta"])
+ session.commit()
+
+ result = list_tag_histogram(name_contains="model")
+
+ assert "alpha" in result
+ assert "beta" not in result
+
+ def test_limit_caps_results(self, mock_create_session, session: Session):
+ tags = [f"tag{i}" for i in range(10)]
+ ensure_tags_exist(session, tags)
+ a = _make_asset(session, "blake3:aaa")
+ r = _make_reference(session, a, name="r1")
+ add_tags_to_reference(session, reference_id=r.id, tags=tags)
+ session.commit()
+
+ result = list_tag_histogram(limit=3)
+
+ assert len(result) == 3
diff --git a/tests-unit/assets_test/test_uploads.py b/tests-unit/assets_test/test_uploads.py
index d68e5b5d7..0f2b124a3 100644
--- a/tests-unit/assets_test/test_uploads.py
+++ b/tests-unit/assets_test/test_uploads.py
@@ -243,6 +243,15 @@ def test_upload_tags_traversal_guard(http: requests.Session, api_base: str):
assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY")
+def test_upload_empty_tags_rejected(http: requests.Session, api_base: str):
+ files = {"file": ("notags.bin", b"A" * 64, "application/octet-stream")}
+ form = {"tags": json.dumps([]), "name": "notags.bin", "user_metadata": json.dumps({})}
+ r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
+ body = r.json()
+ assert r.status_code == 400
+ assert body["error"]["code"] == "INVALID_BODY"
+
+
@pytest.mark.parametrize("root", ["input", "output"])
def test_duplicate_upload_same_display_name_does_not_clobber(
root: str,
From 7d5f5252c3dfdc8a6227e6f6ffb7aab5b3ec827c Mon Sep 17 00:00:00 2001
From: Christian Byrne
Date: Mon, 16 Mar 2026 12:53:13 -0700
Subject: [PATCH 36/80] ci: add check to block AI agent Co-authored-by trailers
in PRs (#12799)
Add a GitHub Actions workflow and shell script that scan all commits
in a pull request for Co-authored-by trailers from known AI coding
agents (Claude, Cursor, Copilot, Codex, Aider, Devin, Gemini, Jules,
Windsurf, Cline, Amazon Q, Continue, OpenCode, etc.).
The check fails with clear instructions on how to remove the trailers
via interactive rebase.
---
.github/scripts/check-ai-co-authors.sh | 103 ++++++++++++++++++++++
.github/workflows/check-ai-co-authors.yml | 19 ++++
2 files changed, 122 insertions(+)
create mode 100755 .github/scripts/check-ai-co-authors.sh
create mode 100644 .github/workflows/check-ai-co-authors.yml
diff --git a/.github/scripts/check-ai-co-authors.sh b/.github/scripts/check-ai-co-authors.sh
new file mode 100755
index 000000000..842b1f2d8
--- /dev/null
+++ b/.github/scripts/check-ai-co-authors.sh
@@ -0,0 +1,103 @@
+#!/usr/bin/env bash
+# Checks pull request commits for AI agent Co-authored-by trailers.
+# Exits non-zero when any are found and prints fix instructions.
+set -euo pipefail
+
+base_sha="${1:?usage: check-ai-co-authors.sh }"
+head_sha="${2:?usage: check-ai-co-authors.sh }"
+
+# Known AI coding-agent trailer patterns (case-insensitive).
+# Each entry is an extended-regex fragment matched against Co-authored-by lines.
+AGENT_PATTERNS=(
+ # Anthropic — Claude Code / Amp
+ 'noreply@anthropic\.com'
+ # Cursor
+ 'cursoragent@cursor\.com'
+ # GitHub Copilot
+ 'copilot-swe-agent\[bot\]'
+ 'copilot@github\.com'
+ # OpenAI Codex
+ 'noreply@openai\.com'
+ 'codex@openai\.com'
+ # Aider
+ 'aider@aider\.chat'
+ # Google — Gemini / Jules
+ 'gemini@google\.com'
+ 'jules@google\.com'
+ # Windsurf / Codeium
+ '@codeium\.com'
+ # Devin
+ 'devin-ai-integration\[bot\]'
+ 'devin@cognition\.ai'
+ 'devin@cognition-labs\.com'
+ # Amazon Q Developer
+ 'amazon-q-developer'
+ '@amazon\.com.*[Qq].[Dd]eveloper'
+ # Cline
+ 'cline-bot'
+ 'cline@cline\.ai'
+ # Continue
+ 'continue-agent'
+ 'continue@continue\.dev'
+ # Sourcegraph
+ 'noreply@sourcegraph\.com'
+ # Generic catch-alls for common agent name patterns
+ 'Co-authored-by:.*\b[Cc]laude\b'
+ 'Co-authored-by:.*\b[Cc]opilot\b'
+ 'Co-authored-by:.*\b[Cc]ursor\b'
+ 'Co-authored-by:.*\b[Cc]odex\b'
+ 'Co-authored-by:.*\b[Gg]emini\b'
+ 'Co-authored-by:.*\b[Aa]ider\b'
+ 'Co-authored-by:.*\b[Dd]evin\b'
+ 'Co-authored-by:.*\b[Ww]indsurf\b'
+ 'Co-authored-by:.*\b[Cc]line\b'
+ 'Co-authored-by:.*\b[Aa]mazon Q\b'
+ 'Co-authored-by:.*\b[Jj]ules\b'
+ 'Co-authored-by:.*\bOpenCode\b'
+)
+
+# Build a single alternation regex from all patterns.
+regex=""
+for pattern in "${AGENT_PATTERNS[@]}"; do
+ if [[ -n "$regex" ]]; then
+ regex="${regex}|${pattern}"
+ else
+ regex="$pattern"
+ fi
+done
+
+# Collect Co-authored-by lines from every commit in the PR range.
+violations=""
+while IFS= read -r sha; do
+ message="$(git log -1 --format='%B' "$sha")"
+ matched_lines="$(echo "$message" | grep -iE "^Co-authored-by:" || true)"
+ if [[ -z "$matched_lines" ]]; then
+ continue
+ fi
+
+ while IFS= read -r line; do
+ if echo "$line" | grep -iqE "$regex"; then
+ short="$(git log -1 --format='%h' "$sha")"
+ violations="${violations} ${short}: ${line}"$'\n'
+ fi
+ done <<< "$matched_lines"
+done < <(git rev-list "${base_sha}..${head_sha}")
+
+if [[ -n "$violations" ]]; then
+ echo "::error::AI agent Co-authored-by trailers detected in PR commits."
+ echo ""
+ echo "The following commits contain Co-authored-by trailers from AI coding agents:"
+ echo ""
+ echo "$violations"
+ echo "These trailers should be removed before merging."
+ echo ""
+ echo "To fix, rewrite the commit messages with:"
+ echo " git rebase -i ${base_sha}"
+ echo ""
+ echo "and remove the Co-authored-by lines, then force-push your branch."
+ echo ""
+ echo "If you believe this is a false positive, please open an issue."
+ exit 1
+fi
+
+echo "No AI agent Co-authored-by trailers found."
diff --git a/.github/workflows/check-ai-co-authors.yml b/.github/workflows/check-ai-co-authors.yml
new file mode 100644
index 000000000..2ad9ac972
--- /dev/null
+++ b/.github/workflows/check-ai-co-authors.yml
@@ -0,0 +1,19 @@
+name: Check AI Co-Authors
+
+on:
+ pull_request:
+ branches: ['*']
+
+jobs:
+ check-ai-co-authors:
+ name: Check for AI agent co-author trailers
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+
+ - name: Check commits for AI co-author trailers
+ run: bash .github/scripts/check-ai-co-authors.sh "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}"
From b202f842af10824b62a3158f0887ee371e16beb6 Mon Sep 17 00:00:00 2001
From: blepping <157360029+blepping@users.noreply.github.com>
Date: Mon, 16 Mar 2026 14:00:42 -0600
Subject: [PATCH 37/80] Skip running model finalizers at exit (#12994)
---
comfy/model_management.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/comfy/model_management.py b/comfy/model_management.py
index a4af5ddb2..2c250dacc 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -541,6 +541,7 @@ class LoadedModel:
if model.parent is not None:
self._parent_model = weakref.ref(model.parent)
self._patcher_finalizer = weakref.finalize(model, self._switch_parent)
+ self._patcher_finalizer.atexit = False
def _switch_parent(self):
model = self._parent_model()
@@ -587,6 +588,7 @@ class LoadedModel:
self.real_model = weakref.ref(real_model)
self.model_finalizer = weakref.finalize(real_model, cleanup_models)
+ self.model_finalizer.atexit = False
return real_model
def should_reload_model(self, force_patch_weights=False):
From 7a16e8aa4e4672733280887a38758be530ba13ea Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Mon, 16 Mar 2026 13:50:13 -0700
Subject: [PATCH 38/80] Add --enable-dynamic-vram options to force enable it.
(#13002)
---
comfy/cli_args.py | 3 +++
main.py | 4 ++--
2 files changed, 5 insertions(+), 2 deletions(-)
diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 0a0bf2f30..13612175e 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -149,6 +149,7 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")
+parser.add_argument("--enable-dynamic-vram", action="store_true", help="Enable dynamic VRAM on systems where it's not enabled by default.")
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
@@ -262,4 +263,6 @@ else:
args.fast = set(args.fast)
def enables_dynamic_vram():
+ if args.enable_dynamic_vram:
+ return True
return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu
diff --git a/main.py b/main.py
index 8905fd09a..f99aee38e 100644
--- a/main.py
+++ b/main.py
@@ -206,8 +206,8 @@ import hook_breaker_ac10a0
import comfy.memory_management
import comfy.model_patcher
-if enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl():
- if comfy.model_management.torch_version_numeric < (2, 8):
+if args.enable_dynamic_vram or (enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl()):
+ if (not args.enable_dynamic_vram) and (comfy.model_management.torch_version_numeric < (2, 8)):
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
if args.verbose == 'DEBUG':
From 20561aa91926508c6ad6db185193c9604cfdf3c9 Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Tue, 17 Mar 2026 09:31:50 +0800
Subject: [PATCH 39/80] [Trainer] FP4, 8, 16 training by native dtype support
and quant linear autograd function (#12681)
---
comfy/ops.py | 101 ++++++++++++++++++++++++++++++++++--
comfy/utils.py | 4 ++
comfy_extras/nodes_train.py | 68 +++++++++++++++++-------
3 files changed, 150 insertions(+), 23 deletions(-)
diff --git a/comfy/ops.py b/comfy/ops.py
index f47d4137a..1518ec9de 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -776,6 +776,71 @@ from .quant_ops import (
)
+class QuantLinearFunc(torch.autograd.Function):
+ """Custom autograd function for quantized linear: quantized forward, compute_dtype backward.
+ Handles any input rank by flattening to 2D for matmul and restoring shape after.
+ """
+
+ @staticmethod
+ def forward(ctx, input_float, weight, bias, layout_type, input_scale, compute_dtype):
+ input_shape = input_float.shape
+ inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D
+
+ # Quantize input (same as inference path)
+ if layout_type is not None:
+ q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
+ else:
+ q_input = inp
+
+ w = weight.detach() if weight.requires_grad else weight
+ b = bias.detach() if bias is not None and bias.requires_grad else bias
+
+ output = torch.nn.functional.linear(q_input, w, b)
+
+ # Restore original input shape
+ if len(input_shape) > 2:
+ output = output.unflatten(0, input_shape[:-1])
+
+ ctx.save_for_backward(input_float, weight)
+ ctx.input_shape = input_shape
+ ctx.has_bias = bias is not None
+ ctx.compute_dtype = compute_dtype
+ ctx.weight_requires_grad = weight.requires_grad
+
+ return output
+
+ @staticmethod
+ @torch.autograd.function.once_differentiable
+ def backward(ctx, grad_output):
+ input_float, weight = ctx.saved_tensors
+ compute_dtype = ctx.compute_dtype
+ grad_2d = grad_output.flatten(0, -2).to(compute_dtype)
+
+ # Dequantize weight to compute dtype for backward matmul
+ if isinstance(weight, QuantizedTensor):
+ weight_f = weight.dequantize().to(compute_dtype)
+ else:
+ weight_f = weight.to(compute_dtype)
+
+ # grad_input = grad_output @ weight
+ grad_input = torch.mm(grad_2d, weight_f)
+ if len(ctx.input_shape) > 2:
+ grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])
+
+ # grad_weight (only if weight requires grad, typically frozen for quantized training)
+ grad_weight = None
+ if ctx.weight_requires_grad:
+ input_f = input_float.flatten(0, -2).to(compute_dtype)
+ grad_weight = torch.mm(grad_2d.t(), input_f)
+
+ # grad_bias
+ grad_bias = None
+ if ctx.has_bias:
+ grad_bias = grad_2d.sum(dim=0)
+
+ return grad_input, grad_weight, grad_bias, None, None, None
+
+
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
class MixedPrecisionOps(manual_cast):
_quant_config = quant_config
@@ -970,10 +1035,37 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
#If cast needs to apply lora, it should be done in the compute dtype
compute_dtype = input.dtype
- if (getattr(self, 'layout_type', None) is not None and
+ _use_quantized = (
+ getattr(self, 'layout_type', None) is not None and
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
not getattr(self, 'comfy_force_cast_weights', False) and
- len(self.weight_function) == 0 and len(self.bias_function) == 0):
+ len(self.weight_function) == 0 and len(self.bias_function) == 0
+ )
+
+ # Training path: quantized forward with compute_dtype backward via autograd function
+ if (input.requires_grad and _use_quantized):
+
+ weight, bias, offload_stream = cast_bias_weight(
+ self,
+ input,
+ offloadable=True,
+ compute_dtype=compute_dtype,
+ want_requant=True
+ )
+
+ scale = getattr(self, 'input_scale', None)
+ if scale is not None:
+ scale = comfy.model_management.cast_to_device(scale, input.device, None)
+
+ output = QuantLinearFunc.apply(
+ input, weight, bias, self.layout_type, scale, compute_dtype
+ )
+
+ uncast_bias_weight(self, weight, bias, offload_stream)
+ return output
+
+ # Inference path (unchanged)
+ if _use_quantized:
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
@@ -1021,7 +1113,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
for key, param in self._parameters.items():
if param is None:
continue
- self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
+ p = fn(param)
+ if p.is_inference():
+ p = p.clone()
+ self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
for key, buf in self._buffers.items():
if buf is not None:
self._buffers[key] = fn(buf)
diff --git a/comfy/utils.py b/comfy/utils.py
index 9931fe3b4..e331b618b 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -897,6 +897,10 @@ def set_attr(obj, attr, value):
return prev
def set_attr_param(obj, attr, value):
+ # Clone inference tensors (created under torch.inference_mode) since
+ # their version counter is frozen and nn.Parameter() cannot wrap them.
+ if value.is_inference():
+ value = value.clone()
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
def set_attr_buffer(obj, attr, value):
diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py
index aa2d88673..0ad0acee6 100644
--- a/comfy_extras/nodes_train.py
+++ b/comfy_extras/nodes_train.py
@@ -15,6 +15,7 @@ import comfy.sampler_helpers
import comfy.sd
import comfy.utils
import comfy.model_management
+from comfy.cli_args import args, PerformanceFeature
import comfy_extras.nodes_custom_sampler
import folder_paths
import node_helpers
@@ -138,6 +139,7 @@ class TrainSampler(comfy.samplers.Sampler):
training_dtype=torch.bfloat16,
real_dataset=None,
bucket_latents=None,
+ use_grad_scaler=False,
):
self.loss_fn = loss_fn
self.optimizer = optimizer
@@ -152,6 +154,8 @@ class TrainSampler(comfy.samplers.Sampler):
self.bucket_latents: list[torch.Tensor] | None = (
bucket_latents # list of (Bi, C, Hi, Wi)
)
+ # GradScaler for fp16 training
+ self.grad_scaler = torch.amp.GradScaler() if use_grad_scaler else None
# Precompute bucket offsets and weights for sampling
if bucket_latents is not None:
self._init_bucket_data(bucket_latents)
@@ -204,10 +208,13 @@ class TrainSampler(comfy.samplers.Sampler):
batch_sigmas.requires_grad_(True),
**batch_extra_args,
)
- loss = self.loss_fn(x0_pred, x0)
+ loss = self.loss_fn(x0_pred.float(), x0.float())
if bwd:
bwd_loss = loss / self.grad_acc
- bwd_loss.backward()
+ if self.grad_scaler is not None:
+ self.grad_scaler.scale(bwd_loss).backward()
+ else:
+ bwd_loss.backward()
return loss
def _generate_batch_sigmas(self, model_wrap, batch_size, device):
@@ -307,7 +314,10 @@ class TrainSampler(comfy.samplers.Sampler):
)
total_loss += loss
total_loss = total_loss / self.grad_acc / len(indicies)
- total_loss.backward()
+ if self.grad_scaler is not None:
+ self.grad_scaler.scale(total_loss).backward()
+ else:
+ total_loss.backward()
if self.loss_callback:
self.loss_callback(total_loss.item())
pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
@@ -348,12 +358,18 @@ class TrainSampler(comfy.samplers.Sampler):
self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
if (i + 1) % self.grad_acc == 0:
+ if self.grad_scaler is not None:
+ self.grad_scaler.unscale_(self.optimizer)
for param_groups in self.optimizer.param_groups:
for param in param_groups["params"]:
if param.grad is None:
continue
param.grad.data = param.grad.data.to(param.data.dtype)
- self.optimizer.step()
+ if self.grad_scaler is not None:
+ self.grad_scaler.step(self.optimizer)
+ self.grad_scaler.update()
+ else:
+ self.optimizer.step()
self.optimizer.zero_grad()
ui_pbar.update(1)
torch.cuda.empty_cache()
@@ -1004,9 +1020,9 @@ class TrainLoraNode(io.ComfyNode):
),
io.Combo.Input(
"training_dtype",
- options=["bf16", "fp32"],
+ options=["bf16", "fp32", "none"],
default="bf16",
- tooltip="The dtype to use for training.",
+ tooltip="The dtype to use for training. 'none' preserves the model's native compute dtype instead of overriding it. For fp16 models, GradScaler is automatically enabled.",
),
io.Combo.Input(
"lora_dtype",
@@ -1035,7 +1051,7 @@ class TrainLoraNode(io.ComfyNode):
io.Boolean.Input(
"offloading",
default=False,
- tooltip="Offload the Model to RAM. Requires Bypass Mode.",
+ tooltip="Offload model weights to CPU during training to save GPU memory.",
),
io.Combo.Input(
"existing_lora",
@@ -1120,22 +1136,32 @@ class TrainLoraNode(io.ComfyNode):
# Setup model and dtype
mp = model.clone()
- dtype = node_helpers.string_to_torch_dtype(training_dtype)
+ use_grad_scaler = False
+ if training_dtype != "none":
+ dtype = node_helpers.string_to_torch_dtype(training_dtype)
+ mp.set_model_compute_dtype(dtype)
+ else:
+ # Detect model's native dtype for autocast
+ model_dtype = mp.model.get_dtype()
+ if model_dtype == torch.float16:
+ dtype = torch.float16
+ use_grad_scaler = True
+ # Warn about fp16 accumulation instability during training
+ if PerformanceFeature.Fp16Accumulation in args.fast:
+ logging.warning(
+ "WARNING: FP16 model detected with fp16_accumulation enabled. "
+ "This combination can be numerically unstable during training and may cause NaN values. "
+ "Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
+ )
+ else:
+ # For fp8, bf16, or other dtypes, use bf16 autocast
+ dtype = torch.bfloat16
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
- mp.set_model_compute_dtype(dtype)
-
- if mp.is_dynamic():
- if not bypass_mode:
- logging.info("Training MP is Dynamic - forcing bypass mode. Start comfy with --highvram to force weight diff mode")
- bypass_mode = True
- offloading = True
- elif offloading:
- if not bypass_mode:
- logging.info("Training Offload selected - forcing bypass mode. Set bypass = True to remove this message")
# Prepare latents and compute counts
+ latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
latents, num_images, multi_res = _prepare_latents_and_count(
- latents, dtype, bucket_mode
+ latents, latents_dtype, bucket_mode
)
# Validate and expand conditioning
@@ -1201,6 +1227,7 @@ class TrainLoraNode(io.ComfyNode):
seed=seed,
training_dtype=dtype,
bucket_latents=latents,
+ use_grad_scaler=use_grad_scaler,
)
else:
train_sampler = TrainSampler(
@@ -1213,6 +1240,7 @@ class TrainLoraNode(io.ComfyNode):
seed=seed,
training_dtype=dtype,
real_dataset=latents if multi_res else None,
+ use_grad_scaler=use_grad_scaler,
)
# Setup guider
@@ -1337,7 +1365,7 @@ class SaveLoRA(io.ComfyNode):
io.Int.Input(
"steps",
optional=True,
- tooltip="Optional: The number of steps to LoRA has been trained for, used to name the saved file.",
+ tooltip="Optional: The number of steps the LoRA has been trained for, used to name the saved file.",
),
],
outputs=[],
From ca17fc835593593f04b0aec04e266afc32a2ccfb Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Mon, 16 Mar 2026 18:38:40 -0700
Subject: [PATCH 40/80] Fix potential issue. (#13009)
---
comfy/utils.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/comfy/utils.py b/comfy/utils.py
index e331b618b..13b7ca6c8 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -899,7 +899,7 @@ def set_attr(obj, attr, value):
def set_attr_param(obj, attr, value):
# Clone inference tensors (created under torch.inference_mode) since
# their version counter is frozen and nn.Parameter() cannot wrap them.
- if value.is_inference():
+ if (not torch.is_inference_mode_enabled()) and value.is_inference():
value = value.clone()
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
From 9a870b5102fa831d805f53b255123623d063f660 Mon Sep 17 00:00:00 2001
From: Christian Byrne
Date: Mon, 16 Mar 2026 18:56:35 -0700
Subject: [PATCH 41/80] fix: atomic writes for userdata to prevent data loss on
crash (#12987)
Write to a temp file in the same directory then os.replace() onto the
target path. If the process crashes mid-write, the original file is
left intact instead of being truncated to zero bytes.
Fixes #11298
---
app/user_manager.py | 12 ++++++++++--
1 file changed, 10 insertions(+), 2 deletions(-)
diff --git a/app/user_manager.py b/app/user_manager.py
index e2c00dab2..e18afb71b 100644
--- a/app/user_manager.py
+++ b/app/user_manager.py
@@ -6,6 +6,7 @@ import uuid
import glob
import shutil
import logging
+import tempfile
from aiohttp import web
from urllib import parse
from comfy.cli_args import args
@@ -377,8 +378,15 @@ class UserManager():
try:
body = await request.read()
- with open(path, "wb") as f:
- f.write(body)
+ dir_name = os.path.dirname(path)
+ fd, tmp_path = tempfile.mkstemp(dir=dir_name)
+ try:
+ with os.fdopen(fd, "wb") as f:
+ f.write(body)
+ os.replace(tmp_path, path)
+ except:
+ os.unlink(tmp_path)
+ raise
except OSError as e:
logging.warning(f"Error saving file '{path}': {e}")
return web.Response(
From 8cc746a86411bd7a08d42829dc805f39f8bced65 Mon Sep 17 00:00:00 2001
From: Paulo Muggler Moreira
Date: Tue, 17 Mar 2026 03:27:27 +0100
Subject: [PATCH 42/80] fix: disable SageAttention for Hunyuan3D v2.1 DiT
(#12772)
---
comfy/ldm/hunyuan3dv2_1/hunyuandit.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py
index d48d9d642..f67ba84e9 100644
--- a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py
+++ b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py
@@ -343,6 +343,7 @@ class CrossAttention(nn.Module):
k.reshape(b, s2, self.num_heads * self.head_dim),
v,
heads=self.num_heads,
+ low_precision_attention=False,
)
out = self.out_proj(x)
@@ -412,6 +413,7 @@ class Attention(nn.Module):
key.reshape(B, N, self.num_heads * self.head_dim),
value,
heads=self.num_heads,
+ low_precision_attention=False,
)
x = self.out_proj(x)
From 379fbd1a827cd2ce97984a7e8ea8b7159780cd1c Mon Sep 17 00:00:00 2001
From: ComfyUI Wiki
Date: Tue, 17 Mar 2026 12:53:18 +0800
Subject: [PATCH 43/80] chore: update workflow templates to v0.9.26 (#13012)
---
requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/requirements.txt b/requirements.txt
index 7e59ef206..0ce163f71 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
comfyui-frontend-package==1.41.20
-comfyui-workflow-templates==0.9.21
+comfyui-workflow-templates==0.9.26
comfyui-embedded-docs==0.4.3
torch
torchsde
From ed7c2c65790c36871b90fff2bdd3de25a17a5431 Mon Sep 17 00:00:00 2001
From: Christian Byrne
Date: Tue, 17 Mar 2026 07:24:00 -0700
Subject: [PATCH 44/80] Mark weight_dtype as advanced input in Load Diffusion
Model node (#12769)
Mark the weight_dtype parameter in UNETLoader (Load Diffusion Model) as
an advanced input to reduce UI complexity for new users. The parameter
is now hidden behind an expandable Advanced section, matching the
pattern used for other advanced inputs like device, tile_size, and
overlap.
Amp-Thread-ID: https://ampcode.com/threads/T-019cbaf1-d3c0-718e-a325-318baba86dec
---
nodes.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/nodes.py b/nodes.py
index 03dcc9d4a..e93fa9767 100644
--- a/nodes.py
+++ b/nodes.py
@@ -952,7 +952,7 @@ class UNETLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
- "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],)
+ "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"], {"advanced": True})
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_unet"
From 1a157e1f97d32c27b3b8bd842bfc5e448c240fe7 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Tue, 17 Mar 2026 14:32:43 -0700
Subject: [PATCH 45/80] Reduce LTX VAE VRAM usage and save use cases from
OOMs/Tiler (#13013)
* ltx: vae: scale the chunk size with the users VRAM
Scale this linearly down for users with low VRAM.
* ltx: vae: free non-chunking recursive intermediates
* ltx: vae: cleanup some intermediates
The conv layer can be the VRAM peak and it does a torch.cat. So cleanup
the pieces of the cat. Also clear our the cache ASAP as each layer detect
its end as this VAE surges in VRAM at the end due to the ended padding
increasing the size of the final frame convolutions off-the-books to
the chunker. So if all the earlier layers free up their cache it can
offset that surge.
Its a fragmentation nightmare, and the chance of it having to recache the
pyt allocator is very high, but you wont OOM.
---
comfy/ldm/lightricks/vae/causal_conv3d.py | 4 ++
.../vae/causal_video_autoencoder.py | 41 +++++++++++++++----
2 files changed, 38 insertions(+), 7 deletions(-)
diff --git a/comfy/ldm/lightricks/vae/causal_conv3d.py b/comfy/ldm/lightricks/vae/causal_conv3d.py
index b8341edbc..356394239 100644
--- a/comfy/ldm/lightricks/vae/causal_conv3d.py
+++ b/comfy/ldm/lightricks/vae/causal_conv3d.py
@@ -65,9 +65,13 @@ class CausalConv3d(nn.Module):
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
x = torch.cat(pieces, dim=2)
+ del pieces
+ del cached
if needs_caching:
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
+ elif is_end:
+ self.temporal_cache_state[tid] = (None, True)
return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]
diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
index 9f14f64a5..0504140ef 100644
--- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
+++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
@@ -297,7 +297,23 @@ class Encoder(nn.Module):
module.temporal_cache_state.pop(tid, None)
-MAX_CHUNK_SIZE=(128 * 1024 ** 2)
+MIN_VRAM_FOR_CHUNK_SCALING = 6 * 1024 ** 3
+MAX_VRAM_FOR_CHUNK_SCALING = 24 * 1024 ** 3
+MIN_CHUNK_SIZE = 32 * 1024 ** 2
+MAX_CHUNK_SIZE = 128 * 1024 ** 2
+
+def get_max_chunk_size(device: torch.device) -> int:
+ total_memory = comfy.model_management.get_total_memory(dev=device)
+
+ if total_memory <= MIN_VRAM_FOR_CHUNK_SCALING:
+ return MIN_CHUNK_SIZE
+ if total_memory >= MAX_VRAM_FOR_CHUNK_SCALING:
+ return MAX_CHUNK_SIZE
+
+ interp = (total_memory - MIN_VRAM_FOR_CHUNK_SCALING) / (
+ MAX_VRAM_FOR_CHUNK_SCALING - MIN_VRAM_FOR_CHUNK_SCALING
+ )
+ return int(MIN_CHUNK_SIZE + interp * (MAX_CHUNK_SIZE - MIN_CHUNK_SIZE))
class Decoder(nn.Module):
r"""
@@ -525,8 +541,11 @@ class Decoder(nn.Module):
timestep_shift_scale = ada_values.unbind(dim=1)
output = []
+ max_chunk_size = get_max_chunk_size(sample.device)
- def run_up(idx, sample, ended):
+ def run_up(idx, sample_ref, ended):
+ sample = sample_ref[0]
+ sample_ref[0] = None
if idx >= len(self.up_blocks):
sample = self.conv_norm_out(sample)
if timestep_shift_scale is not None:
@@ -554,13 +573,21 @@ class Decoder(nn.Module):
return
total_bytes = sample.numel() * sample.element_size()
- num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE
- samples = torch.chunk(sample, chunks=num_chunks, dim=2)
+ num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size
- for chunk_idx, sample1 in enumerate(samples):
- run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1)
+ if num_chunks == 1:
+ # when we are not chunking, detach our x so the callee can free it as soon as they are done
+ next_sample_ref = [sample]
+ del sample
+ run_up(idx + 1, next_sample_ref, ended)
+ return
+ else:
+ samples = torch.chunk(sample, chunks=num_chunks, dim=2)
- run_up(0, sample, True)
+ for chunk_idx, sample1 in enumerate(samples):
+ run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1)
+
+ run_up(0, [sample], True)
sample = torch.cat(output, dim=2)
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
From 035414ede49c1b043ea6de054ca512bcbf0f6b35 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Tue, 17 Mar 2026 14:34:39 -0700
Subject: [PATCH 46/80] Reduce WAN VAE VRAM, Save use cases for OOM/Tiler
(#13014)
* wan: vae: encoder: Add feature cache layer that corks singles
If a downsample only gives you a single frame, save it to the feature
cache and return nothing to the top level. This increases the
efficiency of cacheability, but also prepares support for going two
by two rather than four by four on the frames.
* wan: remove all concatentation with the feature cache
The loopers are now responsible for ensuring that non-final frames are
processes at least two-by-two, elimiating the need for this cat case.
* wan: vae: recurse and chunk for 2+2 frames on decode
Avoid having to clone off slices of 4 frame chunks and reduce the size
of the big 6 frame convolutions down to 4. Save the VRAMs.
* wan: encode frames 2x2.
Reduce VRAM usage greatly by encoding frames 2 at a time rather than
4.
* wan: vae: remove cloning
The loopers now control the chunking such there is noever more than 2
frames, so just cache these slices directly and avoid the clone
allocations completely.
* wan: vae: free consumer caller tensors on recursion
* wan: vae: restyle a little to match LTX
---
comfy/ldm/wan/vae.py | 180 +++++++++++++++++++------------------------
1 file changed, 81 insertions(+), 99 deletions(-)
diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py
index 71f73c64e..a96b83c6c 100644
--- a/comfy/ldm/wan/vae.py
+++ b/comfy/ldm/wan/vae.py
@@ -99,7 +99,7 @@ class Resample(nn.Module):
else:
self.resample = nn.Identity()
- def forward(self, x, feat_cache=None, feat_idx=[0]):
+ def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
b, c, t, h, w = x.size()
if self.mode == 'upsample3d':
if feat_cache is not None:
@@ -109,22 +109,7 @@ class Resample(nn.Module):
feat_idx[0] += 1
else:
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
- if cache_x.shape[2] < 2 and feat_cache[
- idx] is not None and feat_cache[idx] != 'Rep':
- # cache last frame of last two chunk
- cache_x = torch.cat([
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
- cache_x.device), cache_x
- ],
- dim=2)
- if cache_x.shape[2] < 2 and feat_cache[
- idx] is not None and feat_cache[idx] == 'Rep':
- cache_x = torch.cat([
- torch.zeros_like(cache_x).to(cache_x.device),
- cache_x
- ],
- dim=2)
+ cache_x = x[:, :, -CACHE_T:, :, :]
if feat_cache[idx] == 'Rep':
x = self.time_conv(x)
else:
@@ -145,19 +130,24 @@ class Resample(nn.Module):
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
- feat_cache[idx] = x.clone()
- feat_idx[0] += 1
+ feat_cache[idx] = x
else:
- cache_x = x[:, :, -1:, :, :].clone()
- # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
- # # cache last frame of last two chunk
- # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
-
+ cache_x = x[:, :, -1:, :, :]
x = self.time_conv(
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
- feat_idx[0] += 1
+
+ deferred_x = feat_cache[idx + 1]
+ if deferred_x is not None:
+ x = torch.cat([deferred_x, x], 2)
+ feat_cache[idx + 1] = None
+
+ if x.shape[2] == 1 and not final:
+ feat_cache[idx + 1] = x
+ x = None
+
+ feat_idx[0] += 2
return x
@@ -177,19 +167,12 @@ class ResidualBlock(nn.Module):
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
if in_dim != out_dim else nn.Identity()
- def forward(self, x, feat_cache=None, feat_idx=[0]):
+ def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
old_x = x
for layer in self.residual:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
- # cache last frame of last two chunk
- cache_x = torch.cat([
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
- cache_x.device), cache_x
- ],
- dim=2)
+ cache_x = x[:, :, -CACHE_T:, :, :]
x = layer(x, cache_list=feat_cache, cache_idx=idx)
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -213,7 +196,7 @@ class AttentionBlock(nn.Module):
self.proj = ops.Conv2d(dim, dim, 1)
self.optimized_attention = vae_attention()
- def forward(self, x):
+ def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
identity = x
b, c, t, h, w = x.size()
x = rearrange(x, 'b c t h w -> (b t) c h w')
@@ -283,17 +266,10 @@ class Encoder3d(nn.Module):
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, z_dim, 3, padding=1))
- def forward(self, x, feat_cache=None, feat_idx=[0]):
+ def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
if feat_cache is not None:
idx = feat_idx[0]
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
- # cache last frame of last two chunk
- cache_x = torch.cat([
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
- cache_x.device), cache_x
- ],
- dim=2)
+ cache_x = x[:, :, -CACHE_T:, :, :]
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -303,14 +279,16 @@ class Encoder3d(nn.Module):
## downsamples
for layer in self.downsamples:
if feat_cache is not None:
- x = layer(x, feat_cache, feat_idx)
+ x = layer(x, feat_cache, feat_idx, final=final)
+ if x is None:
+ return None
else:
x = layer(x)
## middle
for layer in self.middle:
- if isinstance(layer, ResidualBlock) and feat_cache is not None:
- x = layer(x, feat_cache, feat_idx)
+ if feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx, final=final)
else:
x = layer(x)
@@ -318,14 +296,7 @@ class Encoder3d(nn.Module):
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
- # cache last frame of last two chunk
- cache_x = torch.cat([
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
- cache_x.device), cache_x
- ],
- dim=2)
+ cache_x = x[:, :, -CACHE_T:, :, :]
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -393,14 +364,7 @@ class Decoder3d(nn.Module):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
- # cache last frame of last two chunk
- cache_x = torch.cat([
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
- cache_x.device), cache_x
- ],
- dim=2)
+ cache_x = x[:, :, -CACHE_T:, :, :]
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -409,42 +373,56 @@ class Decoder3d(nn.Module):
## middle
for layer in self.middle:
- if isinstance(layer, ResidualBlock) and feat_cache is not None:
- x = layer(x, feat_cache, feat_idx)
- else:
- x = layer(x)
-
- ## upsamples
- for layer in self.upsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
- ## head
- for layer in self.head:
- if isinstance(layer, CausalConv3d) and feat_cache is not None:
- idx = feat_idx[0]
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
- # cache last frame of last two chunk
- cache_x = torch.cat([
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
- cache_x.device), cache_x
- ],
- dim=2)
- x = layer(x, feat_cache[idx])
- feat_cache[idx] = cache_x
- feat_idx[0] += 1
+ out_chunks = []
+
+ def run_up(layer_idx, x_ref, feat_idx):
+ x = x_ref[0]
+ x_ref[0] = None
+ if layer_idx >= len(self.upsamples):
+ for layer in self.head:
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
+ cache_x = x[:, :, -CACHE_T:, :, :]
+ x = layer(x, feat_cache[feat_idx[0]])
+ feat_cache[feat_idx[0]] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = layer(x)
+ out_chunks.append(x)
+ return
+
+ layer = self.upsamples[layer_idx]
+ if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1:
+ for frame_idx in range(x.shape[2]):
+ run_up(
+ layer_idx,
+ [x[:, :, frame_idx:frame_idx + 1, :, :]],
+ feat_idx.copy(),
+ )
+ del x
+ return
+
+ if feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
- return x
+
+ next_x_ref = [x]
+ del x
+ run_up(layer_idx + 1, next_x_ref, feat_idx)
+
+ run_up(0, [x], feat_idx)
+ return out_chunks
-def count_conv3d(model):
+def count_cache_layers(model):
count = 0
for m in model.modules():
- if isinstance(m, CausalConv3d):
+ if isinstance(m, CausalConv3d) or (isinstance(m, Resample) and m.mode == 'downsample3d'):
count += 1
return count
@@ -482,11 +460,12 @@ class WanVAE(nn.Module):
conv_idx = [0]
## cache
t = x.shape[2]
- iter_ = 1 + (t - 1) // 4
+ t = 1 + ((t - 1) // 4) * 4
+ iter_ = 1 + (t - 1) // 2
feat_map = None
if iter_ > 1:
- feat_map = [None] * count_conv3d(self.encoder)
- ## 对encode输入的x,按时间拆分为1、4、4、4....
+ feat_map = [None] * count_cache_layers(self.encoder)
+ ## 对encode输入的x,按时间拆分为1、2、2、2....(总帧数先按4N+1向下取整)
for i in range(iter_):
conv_idx = [0]
if i == 0:
@@ -496,20 +475,23 @@ class WanVAE(nn.Module):
feat_idx=conv_idx)
else:
out_ = self.encoder(
- x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
+ x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
feat_cache=feat_map,
- feat_idx=conv_idx)
+ feat_idx=conv_idx,
+ final=(i == (iter_ - 1)))
+ if out_ is None:
+ continue
out = torch.cat([out, out_], 2)
+
mu, log_var = self.conv1(out).chunk(2, dim=1)
return mu
def decode(self, z):
- conv_idx = [0]
# z: [b,c,t,h,w]
- iter_ = z.shape[2]
+ iter_ = 1 + z.shape[2] // 2
feat_map = None
if iter_ > 1:
- feat_map = [None] * count_conv3d(self.decoder)
+ feat_map = [None] * count_cache_layers(self.decoder)
x = self.conv2(z)
for i in range(iter_):
conv_idx = [0]
@@ -520,8 +502,8 @@ class WanVAE(nn.Module):
feat_idx=conv_idx)
else:
out_ = self.decoder(
- x[:, :, i:i + 1, :, :],
+ x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
feat_cache=feat_map,
feat_idx=conv_idx)
- out = torch.cat([out, out_], 2)
- return out
+ out += out_
+ return torch.cat(out, 2)
From 8b9d039f26f5230ab3d3d6d9dd5d55590681b970 Mon Sep 17 00:00:00 2001
From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com>
Date: Wed, 18 Mar 2026 07:17:03 +0900
Subject: [PATCH 47/80] bump manager version to 4.1b6 (#13022)
---
manager_requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/manager_requirements.txt b/manager_requirements.txt
index 1c5e8f071..5b06b56f6 100644
--- a/manager_requirements.txt
+++ b/manager_requirements.txt
@@ -1 +1 @@
-comfyui_manager==4.1b5
\ No newline at end of file
+comfyui_manager==4.1b6
\ No newline at end of file
From 735a0465e5daf1f77909b553b02a9d16d1671be9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?=
<40791699+kijai@users.noreply.github.com>
Date: Wed, 18 Mar 2026 02:20:49 +0200
Subject: [PATCH 48/80] Inplace VAE output processing to reduce peak RAM
consumption. (#13028)
---
comfy/sd.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/comfy/sd.py b/comfy/sd.py
index 4d427bb9a..652e76d3e 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -455,7 +455,7 @@ class VAE:
self.output_channels = 3
self.pad_channel_value = None
self.process_input = lambda image: image * 2.0 - 1.0
- self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
+ self.process_output = lambda image: image.add_(1.0).div_(2.0).clamp_(0.0, 1.0)
self.working_dtypes = [torch.bfloat16, torch.float32]
self.disable_offload = False
self.not_video = False
From 68d542cc0602132d3d2fe624ee7077e44b0fb0ab Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Tue, 17 Mar 2026 17:46:22 -0700
Subject: [PATCH 49/80] Fix case where pixel space VAE could cause issues.
(#13030)
---
comfy/sd.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/comfy/sd.py b/comfy/sd.py
index 652e76d3e..df0c4d1d1 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -952,8 +952,8 @@ class VAE:
batch_number = max(1, batch_number)
for x in range(0, samples_in.shape[0], batch_number):
- samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
- out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).to(dtype=self.vae_output_dtype()))
+ samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
+ out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True))
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
pixel_samples[x:x+batch_number] = out
From cad24ce26278a72095d33a2b4391572573201542 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Tue, 17 Mar 2026 17:59:10 -0700
Subject: [PATCH 50/80] cascade: remove dead weight init code (#13026)
This weight init process is fully shadowed be the weight load and
doesnt work in dynamic_vram were the weight allocation is deferred.
---
comfy/ldm/cascade/stage_a.py | 11 +----------
1 file changed, 1 insertion(+), 10 deletions(-)
diff --git a/comfy/ldm/cascade/stage_a.py b/comfy/ldm/cascade/stage_a.py
index 145e6e69a..e4e30cacd 100644
--- a/comfy/ldm/cascade/stage_a.py
+++ b/comfy/ldm/cascade/stage_a.py
@@ -136,16 +136,7 @@ class ResBlock(nn.Module):
ops.Linear(c_hidden, c),
)
- self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
-
- # Init weights
- def _basic_init(module):
- if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
- torch.nn.init.xavier_uniform_(module.weight)
- if module.bias is not None:
- nn.init.constant_(module.bias, 0)
-
- self.apply(_basic_init)
+ self.gammas = nn.Parameter(torch.zeros(6), requires_grad=False)
def _norm(self, x, norm):
return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
From b941913f1d2d11dc69c098a375309b13c13bca23 Mon Sep 17 00:00:00 2001
From: Anton Bukov
Date: Wed, 18 Mar 2026 05:21:32 +0400
Subject: [PATCH 51/80] fix: run text encoders on MPS GPU instead of CPU for
Apple Silicon (#12809)
On Apple Silicon, `vram_state` is set to `VRAMState.SHARED` because
CPU and GPU share unified memory. However, `text_encoder_device()`
only checked for `HIGH_VRAM` and `NORMAL_VRAM`, causing all text
encoders to fall back to CPU on MPS devices.
Adding `VRAMState.SHARED` to the condition allows non-quantized text
encoders (e.g. bf16 Gemma 3 12B) to run on the MPS GPU, providing
significant speedup for text encoding and prompt generation.
Note: quantized models (fp4/fp8) that use float8_e4m3fn internally
will still fall back to CPU via the `supports_cast()` check in
`CLIP.__init__()`, since MPS does not support fp8 dtypes.
---
comfy/model_management.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 2c250dacc..5f2e6ef67 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1003,7 +1003,7 @@ def text_encoder_offload_device():
def text_encoder_device():
if args.gpu_only:
return get_torch_device()
- elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM) or comfy.memory_management.aimdo_enabled:
+ elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM, VRAMState.SHARED) or comfy.memory_management.aimdo_enabled:
if should_use_fp16(prioritize_performance=False):
return get_torch_device()
else:
From 06957022d4cc6f91e101cf5afdd421e462f820c0 Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Wed, 18 Mar 2026 19:21:58 +0200
Subject: [PATCH 52/80] fix(api-nodes): add support for "thought_image" in Nano
Banana 2 and corrected price badges (#13038)
---
comfy_api_nodes/apis/gemini.py | 1 +
comfy_api_nodes/nodes_gemini.py | 17 ++++++++++++++---
2 files changed, 15 insertions(+), 3 deletions(-)
diff --git a/comfy_api_nodes/apis/gemini.py b/comfy_api_nodes/apis/gemini.py
index 639035fef..22879fe18 100644
--- a/comfy_api_nodes/apis/gemini.py
+++ b/comfy_api_nodes/apis/gemini.py
@@ -67,6 +67,7 @@ class GeminiPart(BaseModel):
inlineData: GeminiInlineData | None = Field(None)
fileData: GeminiFileData | None = Field(None)
text: str | None = Field(None)
+ thought: bool | None = Field(None)
class GeminiTextPart(BaseModel):
diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py
index 8225ea67e..25d747e76 100644
--- a/comfy_api_nodes/nodes_gemini.py
+++ b/comfy_api_nodes/nodes_gemini.py
@@ -63,7 +63,7 @@ GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge(
$m := widgets.model;
$r := widgets.resolution;
$isFlash := $contains($m, "nano banana 2");
- $flashPrices := {"1k": 0.0696, "2k": 0.0696, "4k": 0.123};
+ $flashPrices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154};
$proPrices := {"1k": 0.134, "2k": 0.134, "4k": 0.24};
$prices := $isFlash ? $flashPrices : $proPrices;
{"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
@@ -188,10 +188,12 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
return "\n".join([part.text for part in parts])
-async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
+async def get_image_from_response(response: GeminiGenerateContentResponse, thought: bool = False) -> Input.Image:
image_tensors: list[Input.Image] = []
parts = get_parts_by_type(response, "image/*")
for part in parts:
+ if (part.thought is True) != thought:
+ continue
if part.inlineData:
image_data = base64.b64decode(part.inlineData.data)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
@@ -931,6 +933,11 @@ class GeminiNanoBanana2(IO.ComfyNode):
outputs=[
IO.Image.Output(),
IO.String.Output(),
+ IO.Image.Output(
+ display_name="thought_image",
+ tooltip="First image from the model's thinking process. "
+ "Only available with thinking_level HIGH and IMAGE+TEXT modality.",
+ ),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -992,7 +999,11 @@ class GeminiNanoBanana2(IO.ComfyNode):
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
)
- return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
+ return IO.NodeOutput(
+ await get_image_from_response(response),
+ get_text_from_response(response),
+ await get_image_from_response(response, thought=True),
+ )
class GeminiExtension(ComfyExtension):
From b67ed2a45fad8322629289b3347ea15f8926cd45 Mon Sep 17 00:00:00 2001
From: Alexander Brown
Date: Wed, 18 Mar 2026 13:36:39 -0700
Subject: [PATCH 53/80] Update comfyui-frontend-package version to 1.41.21
(#13035)
---
requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/requirements.txt b/requirements.txt
index 0ce163f71..ad0344ed4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-comfyui-frontend-package==1.41.20
+comfyui-frontend-package==1.41.21
comfyui-workflow-templates==0.9.26
comfyui-embedded-docs==0.4.3
torch
From dcd659590faac35a1ac36393077f4ab8aac3fea8 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Wed, 18 Mar 2026 15:14:18 -0700
Subject: [PATCH 54/80] Make more intermediate values follow the intermediate
dtype. (#13051)
---
comfy/sample.py | 4 ++--
comfy/sd1_clip.py | 8 ++++----
2 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/comfy/sample.py b/comfy/sample.py
index a2a39b527..e9c2259ab 100644
--- a/comfy/sample.py
+++ b/comfy/sample.py
@@ -64,10 +64,10 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
- samples = samples.to(comfy.model_management.intermediate_device())
+ samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
return samples
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
- samples = samples.to(comfy.model_management.intermediate_device())
+ samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
return samples
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index d89550840..f970510ad 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -46,7 +46,7 @@ class ClipTokenWeightEncoder:
out, pooled = o[:2]
if pooled is not None:
- first_pooled = pooled[0:1].to(model_management.intermediate_device())
+ first_pooled = pooled[0:1].to(device=model_management.intermediate_device(), dtype=model_management.intermediate_dtype())
else:
first_pooled = pooled
@@ -63,16 +63,16 @@ class ClipTokenWeightEncoder:
output.append(z)
if (len(output) == 0):
- r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
+ r = (out[-1:].to(device=model_management.intermediate_device(), dtype=model_management.intermediate_dtype()), first_pooled)
else:
- r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)
+ r = (torch.cat(output, dim=-2).to(device=model_management.intermediate_device(), dtype=model_management.intermediate_dtype()), first_pooled)
if len(o) > 2:
extra = {}
for k in o[2]:
v = o[2][k]
if k == "attention_mask":
- v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device())
+ v = v[:sections].flatten().unsqueeze(dim=0).to(device=model_management.intermediate_device(), dtype=model_management.intermediate_dtype())
extra[k] = v
r = r + (extra,)
From 9fff091f354815378b913c6e0ee3a39c0ed79a70 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?=
<40791699+kijai@users.noreply.github.com>
Date: Thu, 19 Mar 2026 00:32:26 +0200
Subject: [PATCH 55/80] Further Reduce LTX VAE decode peak RAM usage (#13052)
---
.../vae/causal_video_autoencoder.py | 42 +++++++++++++++----
comfy/sd.py | 19 +++++++--
2 files changed, 48 insertions(+), 13 deletions(-)
diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
index 0504140ef..f7aae26da 100644
--- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
+++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
@@ -473,6 +473,17 @@ class Decoder(nn.Module):
self.gradient_checkpointing = False
+ # Precompute output scale factors: (channels, (t_scale, h_scale, w_scale), t_offset)
+ ts, hs, ws, to = 1, 1, 1, 0
+ for block in self.up_blocks:
+ if isinstance(block, DepthToSpaceUpsample):
+ ts *= block.stride[0]
+ hs *= block.stride[1]
+ ws *= block.stride[2]
+ if block.stride[0] > 1:
+ to = to * block.stride[0] + 1
+ self._output_scale = (out_channels // (patch_size ** 2), (ts, hs * patch_size, ws * patch_size), to)
+
self.timestep_conditioning = timestep_conditioning
if timestep_conditioning:
@@ -494,11 +505,15 @@ class Decoder(nn.Module):
)
- # def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
+ def decode_output_shape(self, input_shape):
+ c, (ts, hs, ws), to = self._output_scale
+ return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws)
+
def forward_orig(
self,
sample: torch.FloatTensor,
timestep: Optional[torch.Tensor] = None,
+ output_buffer: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
r"""The forward method of the `Decoder` class."""
batch_size = sample.shape[0]
@@ -540,7 +555,13 @@ class Decoder(nn.Module):
)
timestep_shift_scale = ada_values.unbind(dim=1)
- output = []
+ if output_buffer is None:
+ output_buffer = torch.empty(
+ self.decode_output_shape(sample.shape),
+ dtype=sample.dtype, device=comfy.model_management.intermediate_device(),
+ )
+ output_offset = [0]
+
max_chunk_size = get_max_chunk_size(sample.device)
def run_up(idx, sample_ref, ended):
@@ -556,7 +577,10 @@ class Decoder(nn.Module):
mark_conv3d_ended(self.conv_out)
sample = self.conv_out(sample, causal=self.causal)
if sample is not None and sample.shape[2] > 0:
- output.append(sample.to(comfy.model_management.intermediate_device()))
+ sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
+ t = sample.shape[2]
+ output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample)
+ output_offset[0] += t
return
up_block = self.up_blocks[idx]
@@ -588,11 +612,8 @@ class Decoder(nn.Module):
run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1)
run_up(0, [sample], True)
- sample = torch.cat(output, dim=2)
- sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
-
- return sample
+ return output_buffer
def forward(self, *args, **kwargs):
try:
@@ -1226,7 +1247,10 @@ class VideoVAE(nn.Module):
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
return self.per_channel_statistics.normalize(means)
- def decode(self, x):
+ def decode_output_shape(self, input_shape):
+ return self.decoder.decode_output_shape(input_shape)
+
+ def decode(self, x, output_buffer=None):
if self.timestep_conditioning: #TODO: seed
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
- return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)
+ return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep, output_buffer=output_buffer)
diff --git a/comfy/sd.py b/comfy/sd.py
index df0c4d1d1..1f9510959 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -951,12 +951,23 @@ class VAE:
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
+ # Pre-allocate output for VAEs that support direct buffer writes
+ preallocated = False
+ if hasattr(self.first_stage_model, 'decode_output_shape'):
+ pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
+ preallocated = True
+
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
- out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True))
- if pixel_samples is None:
- pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
- pixel_samples[x:x+batch_number] = out
+ if preallocated:
+ self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
+ else:
+ out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
+ if pixel_samples is None:
+ pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
+ pixel_samples[x:x+batch_number].copy_(out)
+ del out
+ self.process_output(pixel_samples[x:x+batch_number])
except Exception as e:
model_management.raise_non_oom(e)
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
From 56ff88f9511c4e25cd8ac08b2bfcd21c8ad83121 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Wed, 18 Mar 2026 15:35:25 -0700
Subject: [PATCH 56/80] Fix regression. (#13053)
---
comfy/sd1_clip.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index f970510ad..a85170b26 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -72,7 +72,7 @@ class ClipTokenWeightEncoder:
for k in o[2]:
v = o[2][k]
if k == "attention_mask":
- v = v[:sections].flatten().unsqueeze(dim=0).to(device=model_management.intermediate_device(), dtype=model_management.intermediate_dtype())
+ v = v[:sections].flatten().unsqueeze(dim=0).to(device=model_management.intermediate_device())
extra[k] = v
r = r + (extra,)
From f6b869d7d35f7160bf2fdeabaed378d737834540 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Wed, 18 Mar 2026 16:42:28 -0700
Subject: [PATCH 57/80] fp16 intermediates doen't work for some text enc
models. (#13056)
---
comfy/sd1_clip.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index a85170b26..0eb30df27 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -46,7 +46,7 @@ class ClipTokenWeightEncoder:
out, pooled = o[:2]
if pooled is not None:
- first_pooled = pooled[0:1].to(device=model_management.intermediate_device(), dtype=model_management.intermediate_dtype())
+ first_pooled = pooled[0:1].to(device=model_management.intermediate_device())
else:
first_pooled = pooled
@@ -63,9 +63,9 @@ class ClipTokenWeightEncoder:
output.append(z)
if (len(output) == 0):
- r = (out[-1:].to(device=model_management.intermediate_device(), dtype=model_management.intermediate_dtype()), first_pooled)
+ r = (out[-1:].to(device=model_management.intermediate_device()), first_pooled)
else:
- r = (torch.cat(output, dim=-2).to(device=model_management.intermediate_device(), dtype=model_management.intermediate_dtype()), first_pooled)
+ r = (torch.cat(output, dim=-2).to(device=model_management.intermediate_device()), first_pooled)
if len(o) > 2:
extra = {}
From fabed694a2198b1662d521b1c47e11e625601ebe Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Thu, 19 Mar 2026 09:58:47 -0700
Subject: [PATCH 58/80] ltx: vae: implement chunked encoder + CPU IO chunking
(Big VRAM reductions) (#13062)
* ltx: vae: add cache state to downsample block
* ltx: vae: Add time stride awareness to causal_conv_3d
* ltx: vae: Automate truncation for encoder
Other VAEs just truncate without error. Do the same.
* sd/ltx: Make chunked_io a flag in its own right
Taking this bi-direcitonal, so make it a for-purpose named flag.
* ltx: vae: implement chunked encoder + CPU IO chunking
People are doing things with big frame counts in LTX including V2V
flows. Implement the time-chunked encoder to keep the VRAM down, with
the converse of the new CPU pre-allocation technique, where the chunks
are brought from the CPU JIT.
* ltx: vae-encode: round chunk sizes more strictly
Only powers of 2 and multiple of 8 are valid due to cache slicing.
---
comfy/ldm/lightricks/vae/causal_conv3d.py | 16 +++-
.../vae/causal_video_autoencoder.py | 91 +++++++++++++++----
comfy/sd.py | 11 ++-
3 files changed, 92 insertions(+), 26 deletions(-)
diff --git a/comfy/ldm/lightricks/vae/causal_conv3d.py b/comfy/ldm/lightricks/vae/causal_conv3d.py
index 356394239..7515f0d4e 100644
--- a/comfy/ldm/lightricks/vae/causal_conv3d.py
+++ b/comfy/ldm/lightricks/vae/causal_conv3d.py
@@ -23,6 +23,11 @@ class CausalConv3d(nn.Module):
self.in_channels = in_channels
self.out_channels = out_channels
+ if isinstance(stride, int):
+ self.time_stride = stride
+ else:
+ self.time_stride = stride[0]
+
kernel_size = (kernel_size, kernel_size, kernel_size)
self.time_kernel_size = kernel_size[0]
@@ -58,18 +63,23 @@ class CausalConv3d(nn.Module):
pieces = [ cached, x ]
if is_end and not causal:
pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)))
+ input_length = sum([piece.shape[2] for piece in pieces])
+ cache_length = (self.time_kernel_size - self.time_stride) + ((input_length - self.time_kernel_size) % self.time_stride)
needs_caching = not is_end
- if needs_caching and x.shape[2] >= self.time_kernel_size - 1:
+ if needs_caching and cache_length == 0:
+ self.temporal_cache_state[tid] = (x[:, :, :0, :, :], False)
needs_caching = False
- self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
+ if needs_caching and x.shape[2] >= cache_length:
+ needs_caching = False
+ self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)
x = torch.cat(pieces, dim=2)
del pieces
del cached
if needs_caching:
- self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
+ self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)
elif is_end:
self.temporal_cache_state[tid] = (None, True)
diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
index f7aae26da..1a15cafd0 100644
--- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
+++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
@@ -233,10 +233,7 @@ class Encoder(nn.Module):
self.gradient_checkpointing = False
- def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor:
- r"""The forward method of the `Encoder` class."""
-
- sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
+ def _forward_chunk(self, sample: torch.FloatTensor) -> Optional[torch.FloatTensor]:
sample = self.conv_in(sample)
checkpoint_fn = (
@@ -247,10 +244,14 @@ class Encoder(nn.Module):
for down_block in self.down_blocks:
sample = checkpoint_fn(down_block)(sample)
+ if sample is None or sample.shape[2] == 0:
+ return None
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
+ if sample is None or sample.shape[2] == 0:
+ return None
if self.latent_log_var == "uniform":
last_channel = sample[:, -1:, ...]
@@ -282,9 +283,35 @@ class Encoder(nn.Module):
return sample
+ def forward_orig(self, sample: torch.FloatTensor, device=None) -> torch.FloatTensor:
+ r"""The forward method of the `Encoder` class."""
+
+ max_chunk_size = get_max_chunk_size(sample.device if device is None else device) * 2 # encoder is more memory-efficient than decoder
+ frame_size = sample[:, :, :1, :, :].numel() * sample.element_size()
+ frame_size = int(frame_size * (self.conv_in.out_channels / self.conv_in.in_channels))
+
+ outputs = []
+ samples = [sample[:, :, :1, :, :]]
+ if sample.shape[2] > 1:
+ chunk_t = max(2, max_chunk_size // frame_size)
+ if chunk_t < 4:
+ chunk_t = 2
+ elif chunk_t < 8:
+ chunk_t = 4
+ else:
+ chunk_t = (chunk_t // 8) * 8
+ samples += list(torch.split(sample[:, :, 1:, :, :], chunk_t, dim=2))
+ for chunk_idx, chunk in enumerate(samples):
+ if chunk_idx == len(samples) - 1:
+ mark_conv3d_ended(self)
+ chunk = patchify(chunk, patch_size_hw=self.patch_size, patch_size_t=1).to(device=device)
+ output = self._forward_chunk(chunk)
+ if output is not None:
+ outputs.append(output)
+
+ return torch_cat_if_needed(outputs, dim=2)
+
def forward(self, *args, **kwargs):
- #No encoder support so just flag the end so it doesnt use the cache.
- mark_conv3d_ended(self)
try:
return self.forward_orig(*args, **kwargs)
finally:
@@ -737,12 +764,25 @@ class SpaceToDepthDownsample(nn.Module):
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
+ self.temporal_cache_state = {}
def forward(self, x, causal: bool = True):
- if self.stride[0] == 2:
+ tid = threading.get_ident()
+ cached, pad_first, cached_x, cached_input = self.temporal_cache_state.get(tid, (None, True, None, None))
+ if cached_input is not None:
+ x = torch_cat_if_needed([cached_input, x], dim=2)
+ cached_input = None
+
+ if self.stride[0] == 2 and pad_first:
x = torch.cat(
[x[:, :, :1, :, :], x], dim=2
) # duplicate first frames for padding
+ pad_first = False
+
+ if x.shape[2] < self.stride[0]:
+ cached_input = x
+ self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)
+ return None
# skip connection
x_in = rearrange(
@@ -757,15 +797,26 @@ class SpaceToDepthDownsample(nn.Module):
# conv
x = self.conv(x, causal=causal)
- x = rearrange(
- x,
- "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
- p1=self.stride[0],
- p2=self.stride[1],
- p3=self.stride[2],
- )
+ if self.stride[0] == 2 and x.shape[2] == 1:
+ if cached_x is not None:
+ x = torch_cat_if_needed([cached_x, x], dim=2)
+ cached_x = None
+ else:
+ cached_x = x
+ x = None
- x = x + x_in
+ if x is not None:
+ x = rearrange(
+ x,
+ "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
+ p1=self.stride[0],
+ p2=self.stride[1],
+ p3=self.stride[2],
+ )
+
+ cached = add_exchange_cache(x, cached, x_in, dim=2)
+
+ self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)
return x
@@ -1098,6 +1149,8 @@ class processor(nn.Module):
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
class VideoVAE(nn.Module):
+ comfy_has_chunked_io = True
+
def __init__(self, version=0, config=None):
super().__init__()
@@ -1240,11 +1293,9 @@ class VideoVAE(nn.Module):
}
return config
- def encode(self, x):
- frames_count = x.shape[2]
- if ((frames_count - 1) % 8) != 0:
- raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). Please check your input.")
- means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
+ def encode(self, x, device=None):
+ x = x[:, :, :max(1, 1 + ((x.shape[2] - 1) // 8) * 8), :, :]
+ means, logvar = torch.chunk(self.encoder(x, device=device), 2, dim=1)
return self.per_channel_statistics.normalize(means)
def decode_output_shape(self, input_shape):
diff --git a/comfy/sd.py b/comfy/sd.py
index 1f9510959..b5e7c93a9 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -953,7 +953,7 @@ class VAE:
# Pre-allocate output for VAEs that support direct buffer writes
preallocated = False
- if hasattr(self.first_stage_model, 'decode_output_shape'):
+ if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
preallocated = True
@@ -1038,8 +1038,13 @@ class VAE:
batch_number = max(1, batch_number)
samples = None
for x in range(0, pixel_samples.shape[0], batch_number):
- pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
- out = self.first_stage_model.encode(pixels_in).to(self.output_device).to(dtype=self.vae_output_dtype())
+ pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
+ if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
+ out = self.first_stage_model.encode(pixels_in, device=self.device)
+ else:
+ pixels_in = pixels_in.to(self.device)
+ out = self.first_stage_model.encode(pixels_in)
+ out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
if samples is None:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
samples[x:x + batch_number] = out
From 6589562ae3e35dd7694f430629a805306157f530 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Thu, 19 Mar 2026 10:01:12 -0700
Subject: [PATCH 59/80] ltx: vae: implement chunked encoder + CPU IO chunking
(Big VRAM reductions) (#13062)
* ltx: vae: add cache state to downsample block
* ltx: vae: Add time stride awareness to causal_conv_3d
* ltx: vae: Automate truncation for encoder
Other VAEs just truncate without error. Do the same.
* sd/ltx: Make chunked_io a flag in its own right
Taking this bi-direcitonal, so make it a for-purpose named flag.
* ltx: vae: implement chunked encoder + CPU IO chunking
People are doing things with big frame counts in LTX including V2V
flows. Implement the time-chunked encoder to keep the VRAM down, with
the converse of the new CPU pre-allocation technique, where the chunks
are brought from the CPU JIT.
* ltx: vae-encode: round chunk sizes more strictly
Only powers of 2 and multiple of 8 are valid due to cache slicing.
From ab14541ef7965dc61956c447d3066dd3d5c9f33b Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Thu, 19 Mar 2026 10:03:20 -0700
Subject: [PATCH 60/80] memory: Add more exclusion criteria to pinned read
(#13067)
---
comfy/memory_management.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/comfy/memory_management.py b/comfy/memory_management.py
index 563224098..f9078fe7c 100644
--- a/comfy/memory_management.py
+++ b/comfy/memory_management.py
@@ -39,7 +39,10 @@ def read_tensor_file_slice_into(tensor, destination):
if (destination.device.type != "cpu"
or file_obj is None
or threading.get_ident() != info.thread_id
- or destination.numel() * destination.element_size() < info.size):
+ or destination.numel() * destination.element_size() < info.size
+ or tensor.numel() * tensor.element_size() != info.size
+ or tensor.storage_offset() != 0
+ or not tensor.is_contiguous()):
return False
if info.size == 0:
From fd0261d2bc0c32fa6c21d20994702f44fd927d4c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?=
<40791699+kijai@users.noreply.github.com>
Date: Thu, 19 Mar 2026 19:29:34 +0200
Subject: [PATCH 61/80] Reduce tiled decode peak memory (#13050)
---
comfy/utils.py | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/comfy/utils.py b/comfy/utils.py
index 13b7ca6c8..78c491b98 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -1135,8 +1135,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
pbar.update(1)
continue
- out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
- out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
+ out = output[b:b+1].zero_()
+ out_div = torch.zeros([s.shape[0], 1] + mult_list_upscale(s.shape[2:]), device=output_device)
positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
@@ -1151,7 +1151,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
upscaled.append(round(get_pos(d, pos)))
ps = function(s_in).to(output_device)
- mask = torch.ones_like(ps)
+ mask = torch.ones([1, 1] + list(ps.shape[2:]), device=output_device)
for d in range(2, dims + 2):
feather = round(get_scale(d - 2, overlap[d - 2]))
@@ -1174,7 +1174,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
if pbar is not None:
pbar.update(1)
- output[b:b+1] = out/out_div
+ out.div_(out_div)
return output
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
From 8458ae2686a8d62ee206d3903123868425a4e6a7 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Thu, 19 Mar 2026 12:27:55 -0700
Subject: [PATCH 62/80] =?UTF-8?q?Revert=20"fix:=20run=20text=20encoders=20?=
=?UTF-8?q?on=20MPS=20GPU=20instead=20of=20CPU=20for=20Apple=20Silicon=20(?=
=?UTF-8?q?#=E2=80=A6"=20(#13070)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
This reverts commit b941913f1d2d11dc69c098a375309b13c13bca23.
---
comfy/model_management.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 5f2e6ef67..2c250dacc 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1003,7 +1003,7 @@ def text_encoder_offload_device():
def text_encoder_device():
if args.gpu_only:
return get_torch_device()
- elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM, VRAMState.SHARED) or comfy.memory_management.aimdo_enabled:
+ elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM) or comfy.memory_management.aimdo_enabled:
if should_use_fp16(prioritize_performance=False):
return get_torch_device()
else:
From 82b868a45a753c875677091d0a91bb5bbaf04cbe Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Thu, 19 Mar 2026 19:30:27 -0700
Subject: [PATCH 63/80] Fix VRAM leak in tiler fallback in video VAEs (#13073)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* sd: soft_empty_cache on tiler fallback
This doesnt cost a lot and creates the expected VRAM reduction in
resource monitors when you fallback to tiler.
* wan: vae: Don't recursion in local fns (move run_up)
Moved Decoder3d’s recursive run_up out of forward into a class
method to avoid nested closure self-reference cycles. This avoids
cyclic garbage that delays garbage of tensors which in turn delays
VRAM release before tiled fallback.
* ltx: vae: Don't recursion in local fns (move run_up)
Mov the recursive run_up out of forward into a class
method to avoid nested closure self-reference cycles. This avoids
cyclic garbage that delays garbage of tensors which in turn delays
VRAM release before tiled fallback.
---
.../vae/causal_video_autoencoder.py | 96 +++++++++----------
comfy/ldm/wan/vae.py | 74 +++++++-------
comfy/sd.py | 2 +
3 files changed, 88 insertions(+), 84 deletions(-)
diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
index 1a15cafd0..dd1dfeba0 100644
--- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
+++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
@@ -536,6 +536,53 @@ class Decoder(nn.Module):
c, (ts, hs, ws), to = self._output_scale
return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws)
+ def run_up(self, idx, sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size):
+ sample = sample_ref[0]
+ sample_ref[0] = None
+ if idx >= len(self.up_blocks):
+ sample = self.conv_norm_out(sample)
+ if timestep_shift_scale is not None:
+ shift, scale = timestep_shift_scale
+ sample = sample * (1 + scale) + shift
+ sample = self.conv_act(sample)
+ if ended:
+ mark_conv3d_ended(self.conv_out)
+ sample = self.conv_out(sample, causal=self.causal)
+ if sample is not None and sample.shape[2] > 0:
+ sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
+ t = sample.shape[2]
+ output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample)
+ output_offset[0] += t
+ return
+
+ up_block = self.up_blocks[idx]
+ if ended:
+ mark_conv3d_ended(up_block)
+ if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
+ sample = checkpoint_fn(up_block)(
+ sample, causal=self.causal, timestep=scaled_timestep
+ )
+ else:
+ sample = checkpoint_fn(up_block)(sample, causal=self.causal)
+
+ if sample is None or sample.shape[2] == 0:
+ return
+
+ total_bytes = sample.numel() * sample.element_size()
+ num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size
+
+ if num_chunks == 1:
+ # when we are not chunking, detach our x so the callee can free it as soon as they are done
+ next_sample_ref = [sample]
+ del sample
+ self.run_up(idx + 1, next_sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
+ return
+ else:
+ samples = torch.chunk(sample, chunks=num_chunks, dim=2)
+
+ for chunk_idx, sample1 in enumerate(samples):
+ self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
+
def forward_orig(
self,
sample: torch.FloatTensor,
@@ -591,54 +638,7 @@ class Decoder(nn.Module):
max_chunk_size = get_max_chunk_size(sample.device)
- def run_up(idx, sample_ref, ended):
- sample = sample_ref[0]
- sample_ref[0] = None
- if idx >= len(self.up_blocks):
- sample = self.conv_norm_out(sample)
- if timestep_shift_scale is not None:
- shift, scale = timestep_shift_scale
- sample = sample * (1 + scale) + shift
- sample = self.conv_act(sample)
- if ended:
- mark_conv3d_ended(self.conv_out)
- sample = self.conv_out(sample, causal=self.causal)
- if sample is not None and sample.shape[2] > 0:
- sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
- t = sample.shape[2]
- output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample)
- output_offset[0] += t
- return
-
- up_block = self.up_blocks[idx]
- if (ended):
- mark_conv3d_ended(up_block)
- if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
- sample = checkpoint_fn(up_block)(
- sample, causal=self.causal, timestep=scaled_timestep
- )
- else:
- sample = checkpoint_fn(up_block)(sample, causal=self.causal)
-
- if sample is None or sample.shape[2] == 0:
- return
-
- total_bytes = sample.numel() * sample.element_size()
- num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size
-
- if num_chunks == 1:
- # when we are not chunking, detach our x so the callee can free it as soon as they are done
- next_sample_ref = [sample]
- del sample
- run_up(idx + 1, next_sample_ref, ended)
- return
- else:
- samples = torch.chunk(sample, chunks=num_chunks, dim=2)
-
- for chunk_idx, sample1 in enumerate(samples):
- run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1)
-
- run_up(0, [sample], True)
+ self.run_up(0, [sample], True, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
return output_buffer
diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py
index a96b83c6c..deeb8695b 100644
--- a/comfy/ldm/wan/vae.py
+++ b/comfy/ldm/wan/vae.py
@@ -360,6 +360,43 @@ class Decoder3d(nn.Module):
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, output_channels, 3, padding=1))
+ def run_up(self, layer_idx, x_ref, feat_cache, feat_idx, out_chunks):
+ x = x_ref[0]
+ x_ref[0] = None
+ if layer_idx >= len(self.upsamples):
+ for layer in self.head:
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
+ cache_x = x[:, :, -CACHE_T:, :, :]
+ x = layer(x, feat_cache[feat_idx[0]])
+ feat_cache[feat_idx[0]] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = layer(x)
+ out_chunks.append(x)
+ return
+
+ layer = self.upsamples[layer_idx]
+ if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1:
+ for frame_idx in range(x.shape[2]):
+ self.run_up(
+ layer_idx,
+ [x[:, :, frame_idx:frame_idx + 1, :, :]],
+ feat_cache,
+ feat_idx.copy(),
+ out_chunks,
+ )
+ del x
+ return
+
+ if feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
+ else:
+ x = layer(x)
+
+ next_x_ref = [x]
+ del x
+ self.run_up(layer_idx + 1, next_x_ref, feat_cache, feat_idx, out_chunks)
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
## conv1
if feat_cache is not None:
@@ -380,42 +417,7 @@ class Decoder3d(nn.Module):
out_chunks = []
- def run_up(layer_idx, x_ref, feat_idx):
- x = x_ref[0]
- x_ref[0] = None
- if layer_idx >= len(self.upsamples):
- for layer in self.head:
- if isinstance(layer, CausalConv3d) and feat_cache is not None:
- cache_x = x[:, :, -CACHE_T:, :, :]
- x = layer(x, feat_cache[feat_idx[0]])
- feat_cache[feat_idx[0]] = cache_x
- feat_idx[0] += 1
- else:
- x = layer(x)
- out_chunks.append(x)
- return
-
- layer = self.upsamples[layer_idx]
- if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1:
- for frame_idx in range(x.shape[2]):
- run_up(
- layer_idx,
- [x[:, :, frame_idx:frame_idx + 1, :, :]],
- feat_idx.copy(),
- )
- del x
- return
-
- if feat_cache is not None:
- x = layer(x, feat_cache, feat_idx)
- else:
- x = layer(x)
-
- next_x_ref = [x]
- del x
- run_up(layer_idx + 1, next_x_ref, feat_idx)
-
- run_up(0, [x], feat_idx)
+ self.run_up(0, [x], feat_cache, feat_idx, out_chunks)
return out_chunks
diff --git a/comfy/sd.py b/comfy/sd.py
index b5e7c93a9..e207bb0fd 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -978,6 +978,7 @@ class VAE:
do_tile = True
if do_tile:
+ comfy.model_management.soft_empty_cache()
dims = samples_in.ndim - 2
if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in)
@@ -1059,6 +1060,7 @@ class VAE:
do_tile = True
if do_tile:
+ comfy.model_management.soft_empty_cache()
if self.latent_dim == 3:
tile = 256
overlap = tile // 4
From f49856af57888f60d09f470a6509456f5ee23c99 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Thu, 19 Mar 2026 19:34:58 -0700
Subject: [PATCH 64/80] ltx: vae: Fix missing init variable (#13074)
Forgot to push this ammendment. Previous test results apply to this.
---
comfy/ldm/lightricks/vae/causal_video_autoencoder.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
index dd1dfeba0..998122c85 100644
--- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
+++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
@@ -602,6 +602,7 @@ class Decoder(nn.Module):
)
timestep_shift_scale = None
+ scaled_timestep = None
if self.timestep_conditioning:
assert (
timestep is not None
From e4455fd43acd3f975905455ace7497136962968a Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Fri, 20 Mar 2026 05:05:01 +0200
Subject: [PATCH 65/80] [API Nodes] mark seedream-3-0-t2i and seedance-1-0-lite
models as deprecated (#13060)
* chore(api-nodes): mark seedream-3-0-t2i and seedance-1-0-lite models as deprecated
* fix(api-nodes): fixed old regression in the ByteDanceImageReference node
---------
Co-authored-by: Jedrzej Kosinski
---
comfy_api_nodes/nodes_bytedance.py | 13 ++++++++++++-
1 file changed, 12 insertions(+), 1 deletion(-)
diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py
index 6dbd5984e..de0c22e70 100644
--- a/comfy_api_nodes/nodes_bytedance.py
+++ b/comfy_api_nodes/nodes_bytedance.py
@@ -47,6 +47,10 @@ SEEDREAM_MODELS = {
BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
+DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-250428"}
+
+logger = logging.getLogger(__name__)
+
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
if response.error:
@@ -135,6 +139,7 @@ class ByteDanceImageNode(IO.ComfyNode):
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.03}""",
),
+ is_deprecated=True,
)
@classmethod
@@ -942,7 +947,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
]
return await process_video_task(
cls,
- payload=Image2VideoTaskCreationRequest(model=model, content=x),
+ payload=Image2VideoTaskCreationRequest(model=model, content=x, generate_audio=None),
estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))),
)
@@ -952,6 +957,12 @@ async def process_video_task(
payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
estimated_duration: int | None,
) -> IO.NodeOutput:
+ if payload.model in DEPRECATED_MODELS:
+ logger.warning(
+ "Model '%s' is deprecated and will be deactivated on May 13, 2026. "
+ "Please switch to a newer model. Recommended: seedance-1-0-pro-fast-251015.",
+ payload.model,
+ )
initial_response = await sync_op(
cls,
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
From 589228e671e84518bf77919ee4e574749ab772c8 Mon Sep 17 00:00:00 2001
From: drozbay <17261091+drozbay@users.noreply.github.com>
Date: Thu, 19 Mar 2026 21:42:42 -0600
Subject: [PATCH 66/80] Add slice_cond and per-model context window cond
resizing (#12645)
* Add slice_cond and per-model context window cond resizing
* Fix cond_value.size() call in context window cond resizing
* Expose additional advanced inputs for ContextWindowsManualNode
Necessary for WanAnimate context windows workflow, which needs cond_retain_index_list = 0 to work properly with its reference input.
---------
---
comfy/context_windows.py | 54 ++++++++++++++++++++++++++-
comfy/model_base.py | 32 ++++++++++++++++
comfy_extras/nodes_context_windows.py | 4 +-
3 files changed, 87 insertions(+), 3 deletions(-)
diff --git a/comfy/context_windows.py b/comfy/context_windows.py
index b54f7f39a..cb44ee6e8 100644
--- a/comfy/context_windows.py
+++ b/comfy/context_windows.py
@@ -93,6 +93,50 @@ class IndexListCallbacks:
return {}
+def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, device, temporal_dim: int, temporal_scale: int=1, temporal_offset: int=0, retain_index_list: list[int]=[]):
+ if not (hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor)):
+ return None
+ cond_tensor = cond_value.cond
+ if temporal_dim >= cond_tensor.ndim:
+ return None
+
+ cond_size = cond_tensor.size(temporal_dim)
+
+ if temporal_scale == 1:
+ expected_size = x_in.size(window.dim) - temporal_offset
+ if cond_size != expected_size:
+ return None
+
+ if temporal_offset == 0 and temporal_scale == 1:
+ sliced = window.get_tensor(cond_tensor, device, dim=temporal_dim, retain_index_list=retain_index_list)
+ return cond_value._copy_with(sliced)
+
+ # skip leading latent positions that have no corresponding conditioning (e.g. reference frames)
+ if temporal_offset > 0:
+ indices = [i - temporal_offset for i in window.index_list[temporal_offset:]]
+ indices = [i for i in indices if 0 <= i]
+ else:
+ indices = list(window.index_list)
+
+ if not indices:
+ return None
+
+ if temporal_scale > 1:
+ scaled = []
+ for i in indices:
+ for k in range(temporal_scale):
+ si = i * temporal_scale + k
+ if si < cond_size:
+ scaled.append(si)
+ indices = scaled
+ if not indices:
+ return None
+
+ idx = tuple([slice(None)] * temporal_dim + [indices])
+ sliced = cond_tensor[idx].to(device)
+ return cond_value._copy_with(sliced)
+
+
@dataclass
class ContextSchedule:
name: str
@@ -177,10 +221,17 @@ class IndexListContextHandler(ContextHandlerABC):
new_cond_item[cond_key] = result
handled = True
break
+ if not handled and self._model is not None:
+ result = self._model.resize_cond_for_context_window(
+ cond_key, cond_value, window, x_in, device,
+ retain_index_list=self.cond_retain_index_list)
+ if result is not None:
+ new_cond_item[cond_key] = result
+ handled = True
if handled:
continue
if isinstance(cond_value, torch.Tensor):
- if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
+ if (self.dim < cond_value.ndim and cond_value.size(self.dim) == x_in.size(self.dim)) or \
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
# Handle audio_embed (temporal dim is 1)
@@ -224,6 +275,7 @@ class IndexListContextHandler(ContextHandlerABC):
return context_windows
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
+ self._model = model
self.set_step(timestep, model_options)
context_windows = self.get_context_windows(model, x_in, model_options)
enumerated_context_windows = list(enumerate(context_windows))
diff --git a/comfy/model_base.py b/comfy/model_base.py
index d9d5a9293..88905e191 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -285,6 +285,12 @@ class BaseModel(torch.nn.Module):
return data
return None
+ def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
+ """Override in subclasses to handle model-specific cond slicing for context windows.
+ Return a sliced cond object, or None to fall through to default handling.
+ Use comfy.context_windows.slice_cond() for common cases."""
+ return None
+
def extra_conds(self, **kwargs):
out = {}
concat_cond = self.concat_cond(**kwargs)
@@ -1375,6 +1381,12 @@ class WAN21_Vace(WAN21):
out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
return out
+ def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
+ if cond_key == "vace_context":
+ import comfy.context_windows
+ return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=3, retain_index_list=retain_index_list)
+ return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
+
class WAN21_Camera(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel)
@@ -1427,6 +1439,12 @@ class WAN21_HuMo(WAN21):
return out
+ def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
+ if cond_key == "audio_embed":
+ import comfy.context_windows
+ return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1)
+ return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
+
class WAN22_Animate(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_animate.AnimateWanModel)
@@ -1444,6 +1462,14 @@ class WAN22_Animate(WAN21):
out['pose_latents'] = comfy.conds.CONDRegular(self.process_latent_in(pose_latents))
return out
+ def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
+ import comfy.context_windows
+ if cond_key == "face_pixel_values":
+ return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_scale=4, temporal_offset=1)
+ if cond_key == "pose_latents":
+ return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=1)
+ return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
+
class WAN22_S2V(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
@@ -1480,6 +1506,12 @@ class WAN22_S2V(WAN21):
out['reference_motion'] = reference_motion.shape
return out
+ def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
+ if cond_key == "audio_embed":
+ import comfy.context_windows
+ return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1)
+ return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
+
class WAN22(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
diff --git a/comfy_extras/nodes_context_windows.py b/comfy_extras/nodes_context_windows.py
index 93a5204e1..0e43f2e44 100644
--- a/comfy_extras/nodes_context_windows.py
+++ b/comfy_extras/nodes_context_windows.py
@@ -27,8 +27,8 @@ class ContextWindowsManualNode(io.ComfyNode):
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
- #io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
- #io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
+ io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
+ io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
],
outputs=[
io.Model.Output(tooltip="The model with context windows applied during sampling."),
From c646d211be359df56617ffabcdd43cb53e191e97 Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Fri, 20 Mar 2026 21:23:16 +0200
Subject: [PATCH 67/80] feat(api-nodes): add Quiver SVG nodes (#13047)
---
comfy_api_nodes/apis/quiver.py | 43 +++++
comfy_api_nodes/nodes_quiver.py | 291 ++++++++++++++++++++++++++++++++
2 files changed, 334 insertions(+)
create mode 100644 comfy_api_nodes/apis/quiver.py
create mode 100644 comfy_api_nodes/nodes_quiver.py
diff --git a/comfy_api_nodes/apis/quiver.py b/comfy_api_nodes/apis/quiver.py
new file mode 100644
index 000000000..bc8708754
--- /dev/null
+++ b/comfy_api_nodes/apis/quiver.py
@@ -0,0 +1,43 @@
+from pydantic import BaseModel, Field
+
+
+class QuiverImageObject(BaseModel):
+ url: str = Field(...)
+
+
+class QuiverTextToSVGRequest(BaseModel):
+ model: str = Field(default="arrow-preview")
+ prompt: str = Field(...)
+ instructions: str | None = Field(default=None)
+ references: list[QuiverImageObject] | None = Field(default=None, max_length=4)
+ temperature: float | None = Field(default=None, ge=0, le=2)
+ top_p: float | None = Field(default=None, ge=0, le=1)
+ presence_penalty: float | None = Field(default=None, ge=-2, le=2)
+
+
+class QuiverImageToSVGRequest(BaseModel):
+ model: str = Field(default="arrow-preview")
+ image: QuiverImageObject = Field(...)
+ auto_crop: bool | None = Field(default=None)
+ target_size: int | None = Field(default=None, ge=128, le=4096)
+ temperature: float | None = Field(default=None, ge=0, le=2)
+ top_p: float | None = Field(default=None, ge=0, le=1)
+ presence_penalty: float | None = Field(default=None, ge=-2, le=2)
+
+
+class QuiverSVGResponseItem(BaseModel):
+ svg: str = Field(...)
+ mime_type: str | None = Field(default="image/svg+xml")
+
+
+class QuiverSVGUsage(BaseModel):
+ total_tokens: int | None = Field(default=None)
+ input_tokens: int | None = Field(default=None)
+ output_tokens: int | None = Field(default=None)
+
+
+class QuiverSVGResponse(BaseModel):
+ id: str | None = Field(default=None)
+ created: int | None = Field(default=None)
+ data: list[QuiverSVGResponseItem] = Field(...)
+ usage: QuiverSVGUsage | None = Field(default=None)
diff --git a/comfy_api_nodes/nodes_quiver.py b/comfy_api_nodes/nodes_quiver.py
new file mode 100644
index 000000000..61533263f
--- /dev/null
+++ b/comfy_api_nodes/nodes_quiver.py
@@ -0,0 +1,291 @@
+from io import BytesIO
+
+from typing_extensions import override
+
+from comfy_api.latest import IO, ComfyExtension
+from comfy_api_nodes.apis.quiver import (
+ QuiverImageObject,
+ QuiverImageToSVGRequest,
+ QuiverSVGResponse,
+ QuiverTextToSVGRequest,
+)
+from comfy_api_nodes.util import (
+ ApiEndpoint,
+ sync_op,
+ upload_image_to_comfyapi,
+ validate_string,
+)
+from comfy_extras.nodes_images import SVG
+
+
+class QuiverTextToSVGNode(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="QuiverTextToSVGNode",
+ display_name="Quiver Text to SVG",
+ category="api node/image/Quiver",
+ description="Generate an SVG from a text prompt using Quiver AI.",
+ inputs=[
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ default="",
+ tooltip="Text description of the desired SVG output.",
+ ),
+ IO.String.Input(
+ "instructions",
+ multiline=True,
+ default="",
+ tooltip="Additional style or formatting guidance.",
+ optional=True,
+ ),
+ IO.Autogrow.Input(
+ "reference_images",
+ template=IO.Autogrow.TemplatePrefix(
+ IO.Image.Input("image"),
+ prefix="ref_",
+ min=0,
+ max=4,
+ ),
+ tooltip="Up to 4 reference images to guide the generation.",
+ optional=True,
+ ),
+ IO.DynamicCombo.Input(
+ "model",
+ options=[
+ IO.DynamicCombo.Option(
+ "arrow-preview",
+ [
+ IO.Float.Input(
+ "temperature",
+ default=1.0,
+ min=0.0,
+ max=2.0,
+ step=0.1,
+ display_mode=IO.NumberDisplay.slider,
+ tooltip="Randomness control. Higher values increase randomness.",
+ advanced=True,
+ ),
+ IO.Float.Input(
+ "top_p",
+ default=1.0,
+ min=0.05,
+ max=1.0,
+ step=0.05,
+ display_mode=IO.NumberDisplay.slider,
+ tooltip="Nucleus sampling parameter.",
+ advanced=True,
+ ),
+ IO.Float.Input(
+ "presence_penalty",
+ default=0.0,
+ min=-2.0,
+ max=2.0,
+ step=0.1,
+ display_mode=IO.NumberDisplay.slider,
+ tooltip="Token presence penalty.",
+ advanced=True,
+ ),
+ ],
+ ),
+ ],
+ tooltip="Model to use for SVG generation.",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ control_after_generate=True,
+ tooltip="Seed to determine if node should re-run; "
+ "actual results are nondeterministic regardless of seed.",
+ ),
+ ],
+ outputs=[
+ IO.SVG.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ expr="""{"type":"usd","usd":0.429}""",
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ prompt: str,
+ model: dict,
+ seed: int,
+ instructions: str = None,
+ reference_images: IO.Autogrow.Type = None,
+ ) -> IO.NodeOutput:
+ validate_string(prompt, strip_whitespace=False, min_length=1)
+
+ references = None
+ if reference_images:
+ references = []
+ for key in reference_images:
+ url = await upload_image_to_comfyapi(cls, reference_images[key])
+ references.append(QuiverImageObject(url=url))
+ if len(references) > 4:
+ raise ValueError("Maximum 4 reference images are allowed.")
+
+ instructions_val = instructions.strip() if instructions else None
+ if instructions_val == "":
+ instructions_val = None
+
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/quiver/v1/svgs/generations", method="POST"),
+ response_model=QuiverSVGResponse,
+ data=QuiverTextToSVGRequest(
+ model=model["model"],
+ prompt=prompt,
+ instructions=instructions_val,
+ references=references,
+ temperature=model.get("temperature"),
+ top_p=model.get("top_p"),
+ presence_penalty=model.get("presence_penalty"),
+ ),
+ )
+
+ svg_data = [BytesIO(item.svg.encode("utf-8")) for item in response.data]
+ return IO.NodeOutput(SVG(svg_data))
+
+
+class QuiverImageToSVGNode(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="QuiverImageToSVGNode",
+ display_name="Quiver Image to SVG",
+ category="api node/image/Quiver",
+ description="Vectorize a raster image into SVG using Quiver AI.",
+ inputs=[
+ IO.Image.Input(
+ "image",
+ tooltip="Input image to vectorize.",
+ ),
+ IO.Boolean.Input(
+ "auto_crop",
+ default=False,
+ tooltip="Automatically crop to the dominant subject.",
+ ),
+ IO.DynamicCombo.Input(
+ "model",
+ options=[
+ IO.DynamicCombo.Option(
+ "arrow-preview",
+ [
+ IO.Int.Input(
+ "target_size",
+ default=1024,
+ min=128,
+ max=4096,
+ tooltip="Square resize target in pixels.",
+ ),
+ IO.Float.Input(
+ "temperature",
+ default=1.0,
+ min=0.0,
+ max=2.0,
+ step=0.1,
+ display_mode=IO.NumberDisplay.slider,
+ tooltip="Randomness control. Higher values increase randomness.",
+ advanced=True,
+ ),
+ IO.Float.Input(
+ "top_p",
+ default=1.0,
+ min=0.05,
+ max=1.0,
+ step=0.05,
+ display_mode=IO.NumberDisplay.slider,
+ tooltip="Nucleus sampling parameter.",
+ advanced=True,
+ ),
+ IO.Float.Input(
+ "presence_penalty",
+ default=0.0,
+ min=-2.0,
+ max=2.0,
+ step=0.1,
+ display_mode=IO.NumberDisplay.slider,
+ tooltip="Token presence penalty.",
+ advanced=True,
+ ),
+ ],
+ ),
+ ],
+ tooltip="Model to use for SVG vectorization.",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ control_after_generate=True,
+ tooltip="Seed to determine if node should re-run; "
+ "actual results are nondeterministic regardless of seed.",
+ ),
+ ],
+ outputs=[
+ IO.SVG.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ expr="""{"type":"usd","usd":0.429}""",
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ image,
+ auto_crop: bool,
+ model: dict,
+ seed: int,
+ ) -> IO.NodeOutput:
+ image_url = await upload_image_to_comfyapi(cls, image)
+
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/quiver/v1/svgs/vectorizations", method="POST"),
+ response_model=QuiverSVGResponse,
+ data=QuiverImageToSVGRequest(
+ model=model["model"],
+ image=QuiverImageObject(url=image_url),
+ auto_crop=auto_crop if auto_crop else None,
+ target_size=model.get("target_size"),
+ temperature=model.get("temperature"),
+ top_p=model.get("top_p"),
+ presence_penalty=model.get("presence_penalty"),
+ ),
+ )
+
+ svg_data = [BytesIO(item.svg.encode("utf-8")) for item in response.data]
+ return IO.NodeOutput(SVG(svg_data))
+
+
+class QuiverExtension(ComfyExtension):
+ @override
+ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
+ return [
+ QuiverTextToSVGNode,
+ QuiverImageToSVGNode,
+ ]
+
+
+async def comfy_entrypoint() -> QuiverExtension:
+ return QuiverExtension()
From 45d5c83a3005e7fc28ce9e4ff04b77875052eb51 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Fri, 20 Mar 2026 13:08:26 -0700
Subject: [PATCH 68/80] Make EmptyImage node follow intermediate device/dtype.
(#13079)
---
nodes.py | 8 +++++---
1 file changed, 5 insertions(+), 3 deletions(-)
diff --git a/nodes.py b/nodes.py
index e93fa9767..2c4650a20 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1966,9 +1966,11 @@ class EmptyImage:
CATEGORY = "image"
def generate(self, width, height, batch_size=1, color=0):
- r = torch.full([batch_size, height, width, 1], ((color >> 16) & 0xFF) / 0xFF)
- g = torch.full([batch_size, height, width, 1], ((color >> 8) & 0xFF) / 0xFF)
- b = torch.full([batch_size, height, width, 1], ((color) & 0xFF) / 0xFF)
+ dtype = comfy.model_management.intermediate_dtype()
+ device = comfy.model_management.intermediate_device()
+ r = torch.full([batch_size, height, width, 1], ((color >> 16) & 0xFF) / 0xFF, device=device, dtype=dtype)
+ g = torch.full([batch_size, height, width, 1], ((color >> 8) & 0xFF) / 0xFF, device=device, dtype=dtype)
+ b = torch.full([batch_size, height, width, 1], ((color) & 0xFF) / 0xFF, device=device, dtype=dtype)
return (torch.cat((r, g, b), dim=-1), )
class ImagePadForOutpaint:
From 87cda1fc25ca11a55ede88bf264cfe0a20d340ce Mon Sep 17 00:00:00 2001
From: Jedrzej Kosinski
Date: Fri, 20 Mar 2026 17:03:42 -0700
Subject: [PATCH 69/80] Move inline comfy.context_windows imports to top-level
in model_base.py (#13083)
The recent PR that added resize_cond_for_context_window methods to
model classes used inline 'import comfy.context_windows' in each
method body. This moves that import to the top-level import section,
replacing 4 duplicate inline imports with a single top-level one.
---
comfy/model_base.py | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 88905e191..43ec93324 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -21,6 +21,7 @@ import comfy.ldm.hunyuan3dv2_1.hunyuandit
import torch
import logging
import comfy.ldm.lightricks.av_model
+import comfy.context_windows
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from comfy.ldm.cascade.stage_c import StageC
from comfy.ldm.cascade.stage_b import StageB
@@ -1383,7 +1384,6 @@ class WAN21_Vace(WAN21):
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
if cond_key == "vace_context":
- import comfy.context_windows
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=3, retain_index_list=retain_index_list)
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
@@ -1441,7 +1441,6 @@ class WAN21_HuMo(WAN21):
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
if cond_key == "audio_embed":
- import comfy.context_windows
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1)
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
@@ -1463,7 +1462,6 @@ class WAN22_Animate(WAN21):
return out
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
- import comfy.context_windows
if cond_key == "face_pixel_values":
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_scale=4, temporal_offset=1)
if cond_key == "pose_latents":
@@ -1508,7 +1506,6 @@ class WAN22_S2V(WAN21):
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
if cond_key == "audio_embed":
- import comfy.context_windows
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1)
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
From dc719cde9c448c65242ae2d4ba400ba18c36846f Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Fri, 20 Mar 2026 20:09:15 -0400
Subject: [PATCH 70/80] ComfyUI version 0.18.0
---
comfyui_version.py | 2 +-
pyproject.toml | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/comfyui_version.py b/comfyui_version.py
index 701f4d66a..a3b7204dc 100644
--- a/comfyui_version.py
+++ b/comfyui_version.py
@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
-__version__ = "0.17.0"
+__version__ = "0.18.0"
diff --git a/pyproject.toml b/pyproject.toml
index e2ca79be7..6db9b1267 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
-version = "0.17.0"
+version = "0.18.0"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"
From a11f68dd3b5393b6afc37e01c91fa84963d2668a Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Fri, 20 Mar 2026 20:15:50 -0700
Subject: [PATCH 71/80] Fix canny node not working with fp16. (#13085)
---
comfy_extras/nodes_canny.py | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/comfy_extras/nodes_canny.py b/comfy_extras/nodes_canny.py
index 5e7c4eabb..648b4279d 100644
--- a/comfy_extras/nodes_canny.py
+++ b/comfy_extras/nodes_canny.py
@@ -3,6 +3,7 @@ from typing_extensions import override
import comfy.model_management
from comfy_api.latest import ComfyExtension, io
+import torch
class Canny(io.ComfyNode):
@@ -29,8 +30,8 @@ class Canny(io.ComfyNode):
@classmethod
def execute(cls, image, low_threshold, high_threshold) -> io.NodeOutput:
- output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold)
- img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1)
+ output = canny(image.to(device=comfy.model_management.get_torch_device(), dtype=torch.float32).movedim(-1, 1), low_threshold, high_threshold)
+ img_out = output[1].to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()).repeat(1, 3, 1, 1).movedim(1, -1)
return io.NodeOutput(img_out)
From b5d32e6ad23f3deb0cd16b5f2afa81ff92d89e6e Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Sat, 21 Mar 2026 14:47:42 -0700
Subject: [PATCH 72/80] Fix sampling issue with fp16 intermediates. (#13099)
---
comfy/samplers.py | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/comfy/samplers.py b/comfy/samplers.py
index 8be449ef7..0a4d062db 100755
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -985,8 +985,8 @@ class CFGGuider:
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device
- noise = noise.to(device)
- latent_image = latent_image.to(device)
+ noise = noise.to(device=device, dtype=torch.float32)
+ latent_image = latent_image.to(device=device, dtype=torch.float32)
sigmas = sigmas.to(device)
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
@@ -1028,6 +1028,7 @@ class CFGGuider:
denoise_mask, _ = comfy.utils.pack_latents(denoise_masks)
else:
denoise_mask = denoise_masks[0]
+ denoise_mask = denoise_mask.float()
self.conds = {}
for k in self.original_conds:
From 11c15d8832ab8a95ebe31f85c131429978668c76 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Sat, 21 Mar 2026 14:53:25 -0700
Subject: [PATCH 73/80] Fix fp16 intermediates giving different results.
(#13100)
---
comfy/sample.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/comfy/sample.py b/comfy/sample.py
index e9c2259ab..653829582 100644
--- a/comfy/sample.py
+++ b/comfy/sample.py
@@ -8,12 +8,12 @@ import comfy.nested_tensor
def prepare_noise_inner(latent_image, generator, noise_inds=None):
if noise_inds is None:
- return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
+ return torch.randn(latent_image.size(), dtype=torch.float32, layout=latent_image.layout, generator=generator, device="cpu").to(dtype=latent_image.dtype)
unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
noises = []
for i in range(unique_inds[-1]+1):
- noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
+ noise = torch.randn([1] + list(latent_image.size())[1:], dtype=torch.float32, layout=latent_image.layout, generator=generator, device="cpu").to(dtype=latent_image.dtype)
if i in unique_inds:
noises.append(noise)
noises = [noises[i] for i in inverse]
From 25b6d1d6298c380c1d4de90ff9f38484a84ada19 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Sat, 21 Mar 2026 15:44:35 -0700
Subject: [PATCH 74/80] wan: vae: Fix light/color change (#13101)
There was an issue where the resample split was too early and dropped one
of the rolling convolutions a frame early. This is most noticable as a
lighting/color change between pixel frames 5->6 (latent 2->3), or as a
lighting change between the first and last frame in an FLF wan flow.
---
comfy/ldm/wan/vae.py | 18 +++++++++---------
1 file changed, 9 insertions(+), 9 deletions(-)
diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py
index deeb8695b..57b0dabf7 100644
--- a/comfy/ldm/wan/vae.py
+++ b/comfy/ldm/wan/vae.py
@@ -376,11 +376,16 @@ class Decoder3d(nn.Module):
return
layer = self.upsamples[layer_idx]
- if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1:
- for frame_idx in range(x.shape[2]):
+ if feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
+ else:
+ x = layer(x)
+
+ if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 2:
+ for frame_idx in range(0, x.shape[2], 2):
self.run_up(
- layer_idx,
- [x[:, :, frame_idx:frame_idx + 1, :, :]],
+ layer_idx + 1,
+ [x[:, :, frame_idx:frame_idx + 2, :, :]],
feat_cache,
feat_idx.copy(),
out_chunks,
@@ -388,11 +393,6 @@ class Decoder3d(nn.Module):
del x
return
- if feat_cache is not None:
- x = layer(x, feat_cache, feat_idx)
- else:
- x = layer(x)
-
next_x_ref = [x]
del x
self.run_up(layer_idx + 1, next_x_ref, feat_cache, feat_idx, out_chunks)
From ebf6b52e322664af91fcdc8b8848d31d5fb98f66 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 21 Mar 2026 22:32:16 -0400
Subject: [PATCH 75/80] ComfyUI v0.18.1
---
comfyui_version.py | 2 +-
pyproject.toml | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/comfyui_version.py b/comfyui_version.py
index a3b7204dc..61d7672ca 100644
--- a/comfyui_version.py
+++ b/comfyui_version.py
@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
-__version__ = "0.18.0"
+__version__ = "0.18.1"
diff --git a/pyproject.toml b/pyproject.toml
index 6db9b1267..1fc9402a1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
-version = "0.18.0"
+version = "0.18.1"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"
From d49420b3c7daf86cae1d7419e37848a974e1b7be Mon Sep 17 00:00:00 2001
From: Talmaj
Date: Sun, 22 Mar 2026 04:51:05 +0100
Subject: [PATCH 76/80] LongCat-Image edit (#13003)
---
comfy/ldm/flux/model.py | 2 +-
comfy/model_base.py | 5 +++--
comfy/text_encoders/llama.py | 11 +++++++++--
comfy/text_encoders/longcat_image.py | 25 ++++++++++++++++++++-----
comfy/text_encoders/qwen_vl.py | 3 +++
5 files changed, 36 insertions(+), 10 deletions(-)
diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index 8e7912e6d..2020326c2 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -386,7 +386,7 @@ class Flux(nn.Module):
h = max(h, ref.shape[-2] + h_offset)
w = max(w, ref.shape[-1] + w_offset)
- kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
+ kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset, transformer_options=transformer_options)
img = torch.cat([img, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
ref_num_tokens.append(kontext.shape[1])
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 43ec93324..bfffe2402 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -937,9 +937,10 @@ class LongCatImage(Flux):
transformer_options = transformer_options.copy()
rope_opts = transformer_options.get("rope_options", {})
rope_opts = dict(rope_opts)
+ pe_len = float(c_crossattn.shape[1]) if c_crossattn is not None else 512.0
rope_opts.setdefault("shift_t", 1.0)
- rope_opts.setdefault("shift_y", 512.0)
- rope_opts.setdefault("shift_x", 512.0)
+ rope_opts.setdefault("shift_y", pe_len)
+ rope_opts.setdefault("shift_x", pe_len)
transformer_options["rope_options"] = rope_opts
return super()._apply_model(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py
index ccc200b7a..9fdea999c 100644
--- a/comfy/text_encoders/llama.py
+++ b/comfy/text_encoders/llama.py
@@ -1028,12 +1028,19 @@ class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module):
grid = e.get("extra", None)
start = e.get("index")
if position_ids is None:
- position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
+ position_ids = torch.ones((3, embeds.shape[1]), device=embeds.device, dtype=torch.long)
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
end = e.get("size") + start
len_max = int(grid.max()) // 2
start_next = len_max + start
- position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
+ if attention_mask is not None:
+ # Assign compact sequential positions to attended tokens only,
+ # skipping over padding so post-padding tokens aren't inflated.
+ after_mask = attention_mask[0, end:]
+ text_positions = after_mask.cumsum(0) - 1 + start_next + offset
+ position_ids[:, end:] = torch.where(after_mask.bool(), text_positions, position_ids[0, end:])
+ else:
+ position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
position_ids[0, start:end] = start + offset
max_d = int(grid[0][1]) // 2
position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
diff --git a/comfy/text_encoders/longcat_image.py b/comfy/text_encoders/longcat_image.py
index 882d80901..0962779e3 100644
--- a/comfy/text_encoders/longcat_image.py
+++ b/comfy/text_encoders/longcat_image.py
@@ -64,7 +64,13 @@ class LongCatImageBaseTokenizer(Qwen25_7BVLITokenizer):
return [output]
+IMAGE_PAD_TOKEN_ID = 151655
+
class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
+ T2I_PREFIX = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"
+ EDIT_PREFIX = "<|im_start|>system\nAs an image editing expert, first analyze the content and attributes of the input image(s). Then, based on the user's editing instructions, clearly and precisely determine how to modify the given image(s), ensuring that only the specified parts are altered and all other aspects remain consistent with the original(s).<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
+ SUFFIX = "<|im_end|>\n<|im_start|>assistant\n"
+
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(
embedding_directory=embedding_directory,
@@ -72,10 +78,8 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
name="qwen25_7b",
tokenizer=LongCatImageBaseTokenizer,
)
- self.longcat_template_prefix = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"
- self.longcat_template_suffix = "<|im_end|>\n<|im_start|>assistant\n"
- def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
+ def tokenize_with_weights(self, text, return_word_ids=False, images=None, **kwargs):
skip_template = False
if text.startswith("<|im_start|>"):
skip_template = True
@@ -90,11 +94,14 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
text, return_word_ids=return_word_ids, disable_weights=True, **kwargs
)
else:
+ has_images = images is not None and len(images) > 0
+ template_prefix = self.EDIT_PREFIX if has_images else self.T2I_PREFIX
+
prefix_ids = base_tok.tokenizer(
- self.longcat_template_prefix, add_special_tokens=False
+ template_prefix, add_special_tokens=False
)["input_ids"]
suffix_ids = base_tok.tokenizer(
- self.longcat_template_suffix, add_special_tokens=False
+ self.SUFFIX, add_special_tokens=False
)["input_ids"]
prompt_tokens = base_tok.tokenize_with_weights(
@@ -106,6 +113,14 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
suffix_pairs = [(t, 1.0) for t in suffix_ids]
combined = prefix_pairs + prompt_pairs + suffix_pairs
+
+ if has_images:
+ embed_count = 0
+ for i in range(len(combined)):
+ if combined[i][0] == IMAGE_PAD_TOKEN_ID and embed_count < len(images):
+ combined[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"}, combined[i][1])
+ embed_count += 1
+
tokens = {"qwen25_7b": [combined]}
return tokens
diff --git a/comfy/text_encoders/qwen_vl.py b/comfy/text_encoders/qwen_vl.py
index 3b18ce730..98c350a12 100644
--- a/comfy/text_encoders/qwen_vl.py
+++ b/comfy/text_encoders/qwen_vl.py
@@ -425,4 +425,7 @@ class Qwen2VLVisionTransformer(nn.Module):
hidden_states = block(hidden_states, position_embeddings, cu_seqlens_now, optimized_attention=optimized_attention)
hidden_states = self.merger(hidden_states)
+ # Potentially important for spatially precise edits. This is present in the HF implementation.
+ reverse_indices = torch.argsort(window_index)
+ hidden_states = hidden_states[reverse_indices, :]
return hidden_states
From 6265a239f379f1a5cf2bfdcd3a9631d4c11e50fb Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Sun, 22 Mar 2026 15:46:18 -0700
Subject: [PATCH 77/80] Add warning for users who disable dynamic vram.
(#13113)
---
main.py | 3 +++
1 file changed, 3 insertions(+)
diff --git a/main.py b/main.py
index f99aee38e..cd4483c67 100644
--- a/main.py
+++ b/main.py
@@ -471,6 +471,9 @@ if __name__ == "__main__":
if sys.version_info.major == 3 and sys.version_info.minor < 10:
logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.")
+ if args.disable_dynamic_vram:
+ logging.warning("Dynamic vram disabled with argument. If you have any issues with dynamic vram enabled please give us a detailed reports as this argument will be removed soon.")
+
event_loop, _, start_all_func = start_comfyui()
try:
x = start_all_func()
From da6edb5a4e5745869d64ae05b96263da42d5364e Mon Sep 17 00:00:00 2001
From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com>
Date: Tue, 24 Mar 2026 01:59:21 +0900
Subject: [PATCH 78/80] bump manager version to 4.1b8 (#13108)
---
manager_requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/manager_requirements.txt b/manager_requirements.txt
index 5b06b56f6..90a2be84e 100644
--- a/manager_requirements.txt
+++ b/manager_requirements.txt
@@ -1 +1 @@
-comfyui_manager==4.1b6
\ No newline at end of file
+comfyui_manager==4.1b8
From e87858e9743f92222cdb478f1f835135750b6a0b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?=
<40791699+kijai@users.noreply.github.com>
Date: Tue, 24 Mar 2026 00:22:24 +0200
Subject: [PATCH 79/80] feat: LTX2: Support reference audio (ID-LoRA) (#13111)
---
comfy/ldm/lightricks/av_model.py | 42 +++++++++++++++++
comfy/model_base.py | 4 ++
comfy_extras/nodes_lt.py | 80 ++++++++++++++++++++++++++++++++
3 files changed, 126 insertions(+)
diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py
index 08d686b7b..6f2ba41ef 100644
--- a/comfy/ldm/lightricks/av_model.py
+++ b/comfy/ldm/lightricks/av_model.py
@@ -681,6 +681,33 @@ class LTXAVModel(LTXVModel):
additional_args["has_spatial_mask"] = has_spatial_mask
ax, a_latent_coords = self.a_patchifier.patchify(ax)
+
+ # Inject reference audio for ID-LoRA in-context conditioning
+ ref_audio = kwargs.get("ref_audio", None)
+ ref_audio_seq_len = 0
+ if ref_audio is not None:
+ ref_tokens = ref_audio["tokens"].to(dtype=ax.dtype, device=ax.device)
+ if ref_tokens.shape[0] < ax.shape[0]:
+ ref_tokens = ref_tokens.expand(ax.shape[0], -1, -1)
+ ref_audio_seq_len = ref_tokens.shape[1]
+ B = ax.shape[0]
+
+ # Compute negative temporal positions matching ID-LoRA convention:
+ # offset by -(end_of_last_token + time_per_latent) so reference ends just before t=0
+ p = self.a_patchifier
+ tpl = p.hop_length * p.audio_latent_downsample_factor / p.sample_rate
+ ref_start = p._get_audio_latent_time_in_sec(0, ref_audio_seq_len, torch.float32, ax.device)
+ ref_end = p._get_audio_latent_time_in_sec(1, ref_audio_seq_len + 1, torch.float32, ax.device)
+ time_offset = ref_end[-1].item() + tpl
+ ref_start = (ref_start - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
+ ref_end = (ref_end - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
+ ref_pos = torch.stack([ref_start, ref_end], dim=-1)
+
+ additional_args["ref_audio_seq_len"] = ref_audio_seq_len
+ additional_args["target_audio_seq_len"] = ax.shape[1]
+ ax = torch.cat([ref_tokens, ax], dim=1)
+ a_latent_coords = torch.cat([ref_pos.to(a_latent_coords), a_latent_coords], dim=2)
+
ax = self.audio_patchify_proj(ax)
# additional_args.update({"av_orig_shape": list(x.shape)})
@@ -721,6 +748,14 @@ class LTXAVModel(LTXVModel):
# Prepare audio timestep
a_timestep = kwargs.get("a_timestep")
+ ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0)
+ if ref_audio_seq_len > 0 and a_timestep is not None:
+ # Reference tokens must have timestep=0, expand scalar/1D timestep to per-token so ref=0 and target=sigma.
+ target_len = kwargs.get("target_audio_seq_len")
+ if a_timestep.dim() <= 1:
+ a_timestep = a_timestep.view(-1, 1).expand(batch_size, target_len)
+ ref_ts = torch.zeros(batch_size, ref_audio_seq_len, *a_timestep.shape[2:], device=a_timestep.device, dtype=a_timestep.dtype)
+ a_timestep = torch.cat([ref_ts, a_timestep], dim=1)
if a_timestep is not None:
a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
a_timestep_flat = a_timestep_scaled.flatten()
@@ -955,6 +990,13 @@ class LTXAVModel(LTXVModel):
v_embedded_timestep = embedded_timestep[0]
a_embedded_timestep = embedded_timestep[1]
+ # Trim reference audio tokens before unpatchification
+ ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0)
+ if ref_audio_seq_len > 0:
+ ax = ax[:, ref_audio_seq_len:]
+ if a_embedded_timestep.shape[1] > 1:
+ a_embedded_timestep = a_embedded_timestep[:, ref_audio_seq_len:]
+
# Expand compressed video timestep if needed
if isinstance(v_embedded_timestep, CompressedTimestep):
v_embedded_timestep = v_embedded_timestep.expand()
diff --git a/comfy/model_base.py b/comfy/model_base.py
index bfffe2402..70aff886e 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1061,6 +1061,10 @@ class LTXAV(BaseModel):
if guide_attention_entries is not None:
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
+ ref_audio = kwargs.get("ref_audio", None)
+ if ref_audio is not None:
+ out['ref_audio'] = comfy.conds.CONDConstant(ref_audio)
+
return out
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py
index c05571143..d7c2e8744 100644
--- a/comfy_extras/nodes_lt.py
+++ b/comfy_extras/nodes_lt.py
@@ -3,6 +3,7 @@ import node_helpers
import torch
import comfy.model_management
import comfy.model_sampling
+import comfy.samplers
import comfy.utils
import math
import numpy as np
@@ -682,6 +683,84 @@ class LTXVSeparateAVLatent(io.ComfyNode):
return io.NodeOutput(video_latent, audio_latent)
+class LTXVReferenceAudio(io.ComfyNode):
+ @classmethod
+ def define_schema(cls) -> io.Schema:
+ return io.Schema(
+ node_id="LTXVReferenceAudio",
+ display_name="LTXV Reference Audio (ID-LoRA)",
+ category="conditioning/audio",
+ description="Set reference audio for ID-LoRA speaker identity transfer. Encodes a reference audio clip into the conditioning and optionally patches the model with identity guidance (extra forward pass without reference, amplifying the speaker identity effect).",
+ inputs=[
+ io.Model.Input("model"),
+ io.Conditioning.Input("positive"),
+ io.Conditioning.Input("negative"),
+ io.Audio.Input("reference_audio", tooltip="Reference audio clip whose speaker identity to transfer. ~5 seconds recommended (training duration). Shorter or longer clips may degrade voice identity transfer."),
+ io.Vae.Input(id="audio_vae", display_name="Audio VAE", tooltip="LTXV Audio VAE for encoding."),
+ io.Float.Input("identity_guidance_scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01, tooltip="Strength of identity guidance. Runs an extra forward pass without reference each step to amplify speaker identity. Set to 0 to disable (no extra pass)."),
+ io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="Start of the sigma range where identity guidance is active."),
+ io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="End of the sigma range where identity guidance is active."),
+ ],
+ outputs=[
+ io.Model.Output(),
+ io.Conditioning.Output(display_name="positive"),
+ io.Conditioning.Output(display_name="negative"),
+ ],
+ )
+
+ @classmethod
+ def execute(cls, model, positive, negative, reference_audio, audio_vae, identity_guidance_scale, start_percent, end_percent) -> io.NodeOutput:
+ # Encode reference audio to latents and patchify
+ audio_latents = audio_vae.encode(reference_audio)
+ b, c, t, f = audio_latents.shape
+ ref_tokens = audio_latents.permute(0, 2, 1, 3).reshape(b, t, c * f)
+ ref_audio = {"tokens": ref_tokens}
+
+ positive = node_helpers.conditioning_set_values(positive, {"ref_audio": ref_audio})
+ negative = node_helpers.conditioning_set_values(negative, {"ref_audio": ref_audio})
+
+ # Patch model with identity guidance
+ m = model.clone()
+ scale = identity_guidance_scale
+ model_sampling = m.get_model_object("model_sampling")
+ sigma_start = model_sampling.percent_to_sigma(start_percent)
+ sigma_end = model_sampling.percent_to_sigma(end_percent)
+
+ def post_cfg_function(args):
+ if scale == 0:
+ return args["denoised"]
+
+ sigma = args["sigma"]
+ sigma_ = sigma[0].item()
+ if sigma_ > sigma_start or sigma_ < sigma_end:
+ return args["denoised"]
+
+ cond_pred = args["cond_denoised"]
+ cond = args["cond"]
+ cfg_result = args["denoised"]
+ model_options = args["model_options"].copy()
+ x = args["input"]
+
+ # Strip ref_audio from conditioning for the no-reference pass
+ noref_cond = []
+ for entry in cond:
+ new_entry = entry.copy()
+ mc = new_entry.get("model_conds", {}).copy()
+ mc.pop("ref_audio", None)
+ new_entry["model_conds"] = mc
+ noref_cond.append(new_entry)
+
+ (pred_noref,) = comfy.samplers.calc_cond_batch(
+ args["model"], [noref_cond], x, sigma, model_options
+ )
+
+ return cfg_result + (cond_pred - pred_noref) * scale
+
+ m.set_model_sampler_post_cfg_function(post_cfg_function)
+
+ return io.NodeOutput(m, positive, negative)
+
+
class LtxvExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
@@ -697,6 +776,7 @@ class LtxvExtension(ComfyExtension):
LTXVCropGuides,
LTXVConcatAVLatent,
LTXVSeparateAVLatent,
+ LTXVReferenceAudio,
]
From 2d4970ff677970fbca9f9f562296eda46de8aa4c Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Mon, 23 Mar 2026 17:43:41 -0700
Subject: [PATCH 80/80] Update frontend version to 1.42.8 (#13126)
---
requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/requirements.txt b/requirements.txt
index ad0344ed4..26cc94354 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-comfyui-frontend-package==1.41.21
+comfyui-frontend-package==1.42.8
comfyui-workflow-templates==0.9.26
comfyui-embedded-docs==0.4.3
torch