Compare commits

...

64 Commits

Author SHA1 Message Date
xmarre
95244bd97d
Merge 6d4f9e86ab into 025e6792ee 2026-05-03 08:55:45 -07:00
Jukka Seppänen
025e6792ee
Batch broadcasting in JoinImageWithAlpha node (#13686)
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Generate Pydantic Stubs from api.comfy.org / generate-models (push) Has been cancelled
* Batch broadcasting in JoinImageWithAlpha node
2026-05-03 16:30:00 +03:00
Luke Mino-Altherr
867b8d2408
fix: gracefully handle port-in-use error on server startup (#13001)
Catch EADDRINUSE OSError when binding the TCP site and exit with a clear error message instead of an unhandled traceback.
2026-05-03 20:44:20 +08:00
Alexis Rolland
d0f0b15cf5
Update ComfyUI screenshot in README (#13683)
Update ComfyUI screenshot to showcase a more modern workflow
2026-05-03 18:48:58 +08:00
Alexis Rolland
b5bb83c964
Fix issue blend images with alpha (#13615)
Make ImageBlend and ImageCompositeMasked nodes handle images with different channel counts
2026-05-03 18:17:08 +08:00
Alexis Rolland
f6d5068ac0
Update README (#13679)
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
Updated the README to include a new screenshot, improved description and add Ernie Image to supported models.
2026-05-03 12:20:17 +08:00
Jukka Seppänen
be95871adc
feat: Gemma4 text generation support (CORE-30) (#13376)
* initial gemma4 support

* parity with reference implementation

outputs can 100% match transformers with same sdpa flags, checkpoint this and then optimize

* Cleanup, video fixes

* cleanup, enable fused rms norm by default

* update comment

* Cleanup

* Update sd.py

* Various fixes

* Add fp8 scaled embedding support

* small fixes

* Translate think tokens

* Fix image encoder attention mask type

So it works with basic attention

* Handle thinking tokens different only for Gemma4

* Code cleanup

* Update nodes_textgen.py

* Use embed scale class instead of buffer

Slight difference to HF, but technically more accurate and simpler code

* Default to fused rms_norm

* Update gemma4.py
2026-05-02 22:46:15 -04:00
Alexander Piskun
f756d801a1
[Partner Nodes] Topaz Astra 2 model (#13672)
* feat(api-nodes): add Topaz Astra 2 model

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* feat(api-nodes): make Astra 2 the default Topaz upscaler model

Reorder UPSCALER_MODELS_MAP and the upscaler_model dynamic combo so
"Astra 2" appears first, surfacing it as the default selection.

---------

Signed-off-by: bigcat88 <bigcat88@icloud.com>
Co-authored-by: Marwan Mostafa <marawan206@gmail.com>
2026-05-02 19:29:00 -07:00
Daxiong (Lin)
1d23a875ed
chore: update workflow templates to v0.9.68 (#13678) 2026-05-03 10:06:55 +08:00
comfyanonymous
ef6722f6be
Some cleanups to the load image node. (#13677) 2026-05-02 20:34:27 -04:00
rattus
783782d5d7
Implement block prefetch + Lora Async load + and adopt in LTX (Speedup!) (CORE-111) (#13618)
* mm: Use Aimdo raw allocator for cast buffers

pytorch manages allocation of growing buffers on streams poorly. Pyt
has no windows support for the expandable segments allocator (which is
the right tool for this job), while also segmenting the memory by
stream such that it can be generally re-used. So kick the problem to
aimdo which can just grow a virtual region thats freed per stream.

* plan

* ops: move cpu handler up to the caller

* ops: split up prefetch from weight prep block prefetching API

Split up the casting and weight formating/lora stuff in prep for
arbitrary prefetch support.

* ops: implement block prefetching API

allow a model to construct a prefetch list and operate it for increased
async offload.

* ltxv2: Implement block prefetching

* Implement lora async offload

Implement async offload of loras.
2026-05-02 19:23:24 -04:00
xmarre
6d4f9e86ab
Merge branch 'Comfy-Org:master' into master 2026-04-18 09:20:41 +02:00
xmarre
1548aee40e
Merge pull request #3 from xmarre/codex/vae-encode-tiled-admission-fix
Fix 2D tiled VAE encode memory admission estimation
2026-04-16 12:59:16 +02:00
xmarre
9c210473fc Fix tiled VAE encode memory admission estimate 2026-04-16 12:49:49 +02:00
xmarre
c1e9164c63
Merge branch 'master' into master 2026-04-16 10:07:30 +02:00
xmarre
5e9a90186f
Merge branch 'master' into master 2026-04-14 20:11:29 +02:00
xmarre
ece906328a
Merge branch 'master' into master 2026-04-02 18:55:32 +02:00
xmarre
500ca8e02a
Merge branch 'Comfy-Org:master' into master 2026-03-25 17:45:49 +01:00
xmarre
3143b7981f
Merge branch 'Comfy-Org:master' into master 2026-03-23 02:13:08 +01:00
xmarre
c9b3f81e83
Merge branch 'master' into master 2026-03-18 14:06:06 +01:00
xmarre
5e74e9b3ed
Merge pull request #1 from xmarre/codex/fix-cache-signature-shallow-check
Enforce shallow is_changed signature handling
2026-03-18 13:29:34 +01:00
xmarre
c702cddf75 Fix shallow is_changed logic 2026-03-18 13:15:04 +01:00
xmarre
e13da8104c Fix shallow is_changed handling 2026-03-18 12:26:30 +01:00
xmarre
fdcc38b9ea Return Unhashable on missing node 2026-03-17 07:48:14 +01:00
xmarre
c1ce00287c Stop requeueing live containers 2026-03-16 19:21:24 +01:00
xmarre
6e3bd33665 Prevent dict key canonicalization 2026-03-16 17:06:09 +01:00
xmarre
ce05e377a8 Stop canonicalizing dict keys 2026-03-16 16:48:42 +01:00
xmarre
1a00f7743f Stop traversing dict keys 2026-03-16 16:10:01 +01:00
xmarre
a6472b1514 Fix to_hashable traversal stack handling 2026-03-16 15:34:15 +01:00
xmarre
6158cd5820 Prevent redundant signature rewalk 2026-03-16 13:31:02 +01:00
xmarre
bff714dda0 Fix non-link input cache signature 2026-03-16 10:13:04 +01:00
xmarre
fce22da313 Prevent signature traversal of raw 2026-03-16 09:29:00 +01:00
xmarre
9f9d37bd9a
Merge branch 'master' into master 2026-03-16 09:07:29 +01:00
xmarre
088778c35d Stop canonicalizing is_changed 2026-03-15 17:06:20 +01:00
xmarre
4c5f82971e Restrict is_changed canonicalization 2026-03-15 16:44:25 +01:00
xmarre
f1d91a4c8c Prevent canonicalizing is_changed 2026-03-15 16:14:23 +01:00
xmarre
dbed5a1b52 Replace sanitize and hash passes 2026-03-15 07:39:10 +01:00
xmarre
24fdbb9aca Replace sanitize hash two pass 2026-03-15 07:30:18 +01:00
xmarre
a6624a9afd Unify signature sanitize and hash 2026-03-15 07:09:24 +01:00
xmarre
0b512198e8 Adopt single-pass signature hashing 2026-03-15 05:41:39 +01:00
xmarre
9feb26928c Change signature cache to bail early 2026-03-15 04:31:32 +01:00
xmarre
fadd79ad48 Fix nondeterministic set signing 2026-03-15 03:29:59 +01:00
xmarre
77bc7bdd6b Merge branch 'master' of https://github.com/xmarre/ComfyUI 2026-03-15 02:56:09 +01:00
xmarre
117afbc1d7 Add docstrings and harden signature 2026-03-15 02:55:39 +01:00
xmarre
064eec2278
Merge branch 'master' into master 2026-03-15 02:32:56 +01:00
xmarre
aceaa5e579 fail closed on ambiguous container ordering in cache signatures 2026-03-15 02:32:25 +01:00
xmarre
763089f681
Merge branch 'master' into master 2026-03-15 01:48:10 +01:00
xmarre
1693dabc8f
Merge branch 'master' into master 2026-03-15 00:28:34 +01:00
xmarre
08063d2638
Merge branch 'Comfy-Org:master' into master 2026-03-14 23:38:46 +01:00
xmarre
e069617e54
Merge branch 'Comfy-Org:master' into master 2026-03-14 21:27:17 +01:00
xmarre
2bea0ee5d7 Simplify Unhashable sentinel implementation 2026-03-14 12:42:04 +01:00
xmarre
17863f603a Add comprehensive docstrings for cache key helpers 2026-03-14 12:26:27 +01:00
xmarre
31ba844624 Add cycle detection to signature input sanitization 2026-03-14 12:04:31 +01:00
xmarre
1451001f64 Add docstrings for cache signature hardening helpers 2026-03-14 10:57:45 +01:00
xmarre
1af99b2e81 Update caching hash recursion 2026-03-14 10:31:07 +01:00
xmarre
3568b82b76 Revert "Add missing docstrings"
This reverts commit 4b431ffc27.
2026-03-14 10:11:35 +01:00
xmarre
6728d4d439 Revert "Harden to_hashable against cycles"
This reverts commit 880b51ac4f.
2026-03-14 10:11:04 +01:00
xmarre
4b431ffc27 Add missing docstrings 2026-03-14 09:57:22 +01:00
xmarre
880b51ac4f Harden to_hashable against cycles 2026-03-14 09:46:27 +01:00
xmarre
4d9516b909 Fix caching sanitization logic 2026-03-14 07:06:39 +01:00
xmarre
39086890e2 Fix sanitize_signature_input 2026-03-14 06:56:49 +01:00
xmarre
2adde5a0e1 Keep container types in sanitizer 2026-03-14 06:36:06 +01:00
xmarre
0c1bfad0df
Merge branch 'Comfy-Org:master' into master 2026-03-14 06:13:25 +01:00
xmarre
7d76a4447e Sanitize execution cache inputs 2026-03-14 02:36:40 +01:00
30 changed files with 3246 additions and 161 deletions

View File

@ -1,7 +1,7 @@
<div align="center">
# ComfyUI
**The most powerful and modular visual AI engine and application.**
**The most powerful and modular AI engine for content creation.**
[![Website][website-shield]][website-url]
@ -31,10 +31,16 @@
[github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest
[github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases
![ComfyUI Screenshot](https://github.com/user-attachments/assets/7ccaf2c1-9b72-41ae-9a89-5688c94b7abe)
<img width="1590" height="795" alt="ComfyUI Screenshot" src="https://github.com/user-attachments/assets/36e065e0-bfae-4456-8c7f-8369d5ea48a2" />
<br>
</div>
ComfyUI lets you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. Available on Windows, Linux, and macOS.
ComfyUI is the AI creation engine for visual professionals who demand control over every model, every parameter, and every output. Its powerful and modular node graph interface empowers creatives to generate images, videos, 3D models, audio, and more...
- ComfyUI natively supports the latest open-source state of the art models.
- API nodes provide access to the best closed source models such as Nano Banana, Seedance, Hunyuan3D, etc.
- It is available on Windows, Linux, and macOS, locally with our desktop application or on our cloud.
- The most sophisticated workflows can be exposed through a simple UI thanks to App Mode.
- It integrates seamlessly into production pipelines with our API endpoints.
## Get Started
@ -77,6 +83,7 @@ See what ComfyUI can do with the [newer template workflows](https://comfy.org/wo
- [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
- [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/)
- [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/)
- Ernie Image
- Image Editing Models
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)

View File

@ -16,6 +16,7 @@ from comfy.ldm.lightricks.model import (
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
import comfy.ldm.common_dit
import comfy.model_prefetch
class CompressedTimestep:
"""Store video timestep embeddings in compressed form using per-frame indexing."""
@ -907,9 +908,11 @@ class LTXAVModel(LTXVModel):
"""Process transformer blocks for LTXAV."""
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
prefetch_queue = comfy.model_prefetch.make_prefetch_queue(list(self.transformer_blocks), vx.device, transformer_options)
# Process transformer blocks
for i, block in enumerate(self.transformer_blocks):
comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, block)
if ("double_block", i) in blocks_replace:
def block_wrap(args):
@ -982,6 +985,8 @@ class LTXAVModel(LTXVModel):
a_prompt_timestep=a_prompt_timestep,
)
comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, None)
return [vx, ax]
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):

View File

@ -14,6 +14,8 @@ from .sub_quadratic_attention import efficient_dot_product_attention
from comfy import model_management
TORCH_HAS_GQA = model_management.torch_version_numeric >= (2, 5)
if model_management.xformers_enabled():
import xformers
import xformers.ops
@ -150,7 +152,12 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
b, _, dim_head = q.shape
dim_head //= heads
scale = dim_head ** -0.5
if kwargs.get("enable_gqa", False) and q.shape[-3] != k.shape[-3]:
n_rep = q.shape[-3] // k.shape[-3]
k = k.repeat_interleave(n_rep, dim=-3)
v = v.repeat_interleave(n_rep, dim=-3)
scale = kwargs.get("scale", dim_head ** -0.5)
h = heads
if skip_reshape:
@ -219,6 +226,10 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
b, _, dim_head = query.shape
dim_head //= heads
if "scale" in kwargs:
# Pre-scale query to match requested scale (cancels internal 1/sqrt(dim_head))
query = query * (kwargs["scale"] * dim_head ** 0.5)
if skip_reshape:
query = query.reshape(b * heads, -1, dim_head)
value = value.reshape(b * heads, -1, dim_head)
@ -290,7 +301,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
b, _, dim_head = q.shape
dim_head //= heads
scale = dim_head ** -0.5
scale = kwargs.get("scale", dim_head ** -0.5)
if skip_reshape:
q, k, v = map(
@ -500,8 +511,13 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
if mask.ndim == 3:
mask = mask.unsqueeze(1)
# Pass through extra SDPA kwargs (scale, enable_gqa) if provided
# enable_gqa requires PyTorch 2.5+; older versions use manual KV expansion above
sdpa_keys = ("scale", "enable_gqa") if TORCH_HAS_GQA else ("scale",)
sdpa_extra = {k: v for k, v in kwargs.items() if k in sdpa_keys}
if SDP_BATCH_LIMIT >= b:
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, **sdpa_extra)
if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
@ -519,7 +535,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
k[i : i + SDP_BATCH_LIMIT],
v[i : i + SDP_BATCH_LIMIT],
attn_mask=m,
dropout_p=0.0, is_causal=False
dropout_p=0.0, is_causal=False, **sdpa_extra
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
return out

View File

@ -17,6 +17,7 @@
"""
from __future__ import annotations
import comfy.memory_management
import comfy.utils
import comfy.model_management
import comfy.model_base
@ -473,3 +474,17 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
weight = old_weight
return weight
def prefetch_prepared_value(value, allocate_buffer, stream):
if isinstance(value, torch.Tensor):
dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value))
comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
return comfy.memory_management.interpret_gathered_like([value], dest)[0]
elif isinstance(value, weight_adapter.WeightAdapterBase):
return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream))
elif isinstance(value, tuple):
return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value)
elif isinstance(value, list):
return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value]
return value

View File

@ -214,6 +214,11 @@ class BaseModel(torch.nn.Module):
if "latent_shapes" in extra_conds:
xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
transformer_options = transformer_options.copy()
transformer_options["prefetch_dynamic_vbars"] = (
self.current_patcher is not None and self.current_patcher.is_dynamic()
)
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
if len(model_output) > 1 and not torch.is_tensor(model_output):
model_output, _ = utils.pack_latents(model_output)

View File

@ -31,6 +31,7 @@ from contextlib import nullcontext
import comfy.memory_management
import comfy.utils
import comfy.quant_ops
import comfy_aimdo.vram_buffer
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
@ -1175,6 +1176,10 @@ stream_counters = {}
STREAM_CAST_BUFFERS = {}
LARGEST_CASTED_WEIGHT = (None, 0)
STREAM_AIMDO_CAST_BUFFERS = {}
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
def get_cast_buffer(offload_stream, device, size, ref):
global LARGEST_CASTED_WEIGHT
@ -1208,13 +1213,26 @@ def get_cast_buffer(offload_stream, device, size, ref):
return cast_buffer
def get_aimdo_cast_buffer(offload_stream, device):
cast_buffer = STREAM_AIMDO_CAST_BUFFERS.get(offload_stream, None)
if cast_buffer is None:
cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index)
STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
return cast_buffer
def reset_cast_buffers():
global LARGEST_CASTED_WEIGHT
global LARGEST_AIMDO_CASTED_WEIGHT
LARGEST_CASTED_WEIGHT = (None, 0)
for offload_stream in STREAM_CAST_BUFFERS:
offload_stream.synchronize()
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS):
if offload_stream is not None:
offload_stream.synchronize()
synchronize()
STREAM_CAST_BUFFERS.clear()
STREAM_AIMDO_CAST_BUFFERS.clear()
soft_empty_cache()
def get_offload_stream(device):

View File

@ -121,9 +121,20 @@ class LowVramPatch:
self.patches = patches
self.convert_func = convert_func # TODO: remove
self.set_func = set_func
self.prepared_patches = None
def prepare(self, allocate_buffer, stream):
self.prepared_patches = [
(patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4])
for patch in self.patches[self.key]
]
def clear_prepared(self):
self.prepared_patches = None
def __call__(self, weight):
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
patches = self.prepared_patches if self.prepared_patches is not None else self.patches[self.key]
return comfy.lora.calculate_weight(patches, weight, self.key, intermediate_dtype=weight.dtype)
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2

65
comfy/model_prefetch.py Normal file
View File

@ -0,0 +1,65 @@
import comfy_aimdo.model_vbar
import comfy.model_management
import comfy.ops
PREFETCH_QUEUES = []
def cleanup_prefetched_modules(comfy_modules):
for s in comfy_modules:
prefetch = getattr(s, "_prefetch", None)
if prefetch is None:
continue
for param_key in ("weight", "bias"):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
if lowvram_fn is not None:
lowvram_fn.clear_prepared()
if prefetch["signature"] is not None:
comfy_aimdo.model_vbar.vbar_unpin(s._v)
delattr(s, "_prefetch")
def cleanup_prefetch_queues():
global PREFETCH_QUEUES
for queue in PREFETCH_QUEUES:
for entry in queue:
if entry is None or not isinstance(entry, tuple):
continue
_, prefetch_state = entry
comfy_modules = prefetch_state[1]
if comfy_modules is not None:
cleanup_prefetched_modules(comfy_modules)
PREFETCH_QUEUES = []
def prefetch_queue_pop(queue, device, module):
if queue is None:
return
consumed = queue.pop(0)
if consumed is not None:
offload_stream, prefetch_state = consumed
offload_stream.wait_stream(comfy.model_management.current_stream(device))
_, comfy_modules = prefetch_state
if comfy_modules is not None:
cleanup_prefetched_modules(comfy_modules)
prefetch = queue[0]
if prefetch is not None:
comfy_modules = []
for s in prefetch.modules():
if hasattr(s, "_v"):
comfy_modules.append(s)
offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True)
comfy.model_management.sync_stream(device, offload_stream)
queue[0] = (offload_stream, (prefetch, comfy_modules))
def make_prefetch_queue(queue, device, transformer_options):
if (not transformer_options.get("prefetch_dynamic_vbars", False)
or comfy.model_management.NUM_STREAMS == 0
or comfy.model_management.is_device_cpu(device)
or not comfy.model_management.device_supports_non_blocking(device)):
return None
queue = [None] + queue + [None]
PREFETCH_QUEUES.append(queue)
return queue

View File

@ -86,38 +86,61 @@ def materialize_meta_param(s, param_keys):
setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad))
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
#vbar doesn't support CPU weights, but some custom nodes have weird paths
#that might switch the layer to the CPU and expect it to work. We have to take
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
#If you are a custom node author reading this, please move your layer to the GPU
#or declare your ModelPatcher as CPU in the first place.
if comfy.model_management.is_device_cpu(device):
materialize_meta_param(s, ["weight", "bias"])
weight = s.weight.to(dtype=dtype, copy=True)
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
bias = None
if s.bias is not None:
bias = s.bias.to(dtype=bias_dtype, copy=True)
return weight, bias, (None, None, None)
# FIXME: add n=1 cache hit fast path
def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blocking):
offload_stream = None
xfer_dest = None
cast_buffer = None
cast_buffer_offset = 0
def ensure_offload_stream(module, required_size, check_largest):
nonlocal offload_stream
nonlocal cast_buffer
if offload_stream is None:
offload_stream = comfy.model_management.get_offload_stream(device)
if offload_stream is None or not check_largest or len(comfy_modules) != 1:
return
current_size = 0 if cast_buffer is None else cast_buffer.size()
if current_size < required_size and module is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]:
offload_stream = comfy.model_management.get_offload_stream(device)
cast_buffer = None
if required_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]:
comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (module, required_size)
def get_cast_buffer(buffer_size):
nonlocal offload_stream
nonlocal cast_buffer
nonlocal cast_buffer_offset
if buffer_size == 0:
return None
if offload_stream is None:
return torch.empty((buffer_size,), dtype=torch.uint8, device=device)
cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device)
buffer = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(buffer_size, cast_buffer_offset), device)
cast_buffer_offset += buffer_size
return buffer
for s in comfy_modules:
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
prefetch = {
"signature": signature,
"resident": resident,
}
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
if signature is not None:
if resident:
weight = s._v_weight
bias = s._v_bias
else:
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
s._prefetch = prefetch
continue
if not resident:
materialize_meta_param(s, ["weight", "bias"])
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) if signature is not None else None
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
cast_dest = None
needs_cast = False
xfer_source = [ s.weight, s.bias ]
@ -129,22 +152,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
if data is None:
continue
if data.dtype != geometry.dtype:
needs_cast = True
cast_dest = xfer_dest
if cast_dest is None:
cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device)
xfer_dest = None
break
dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
offload_stream = comfy.model_management.get_offload_stream(device)
if xfer_dest is None and offload_stream is not None:
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
if xfer_dest is None:
offload_stream = comfy.model_management.get_offload_stream(device)
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
ensure_offload_stream(s, dest_size if xfer_dest is None else 0, True)
if xfer_dest is None:
xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device)
offload_stream = None
xfer_dest = get_cast_buffer(dest_size)
if signature is None and pin is None:
comfy.pinned_memory.pin_memory(s)
@ -157,27 +173,54 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
xfer_source = [ pin ]
#send it over
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
comfy.model_management.sync_stream(device, offload_stream)
if cast_dest is not None:
for param_key in ("weight", "bias"):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
if lowvram_fn is not None:
ensure_offload_stream(s, cast_buffer_offset, False)
lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream)
prefetch["xfer_dest"] = xfer_dest
prefetch["cast_dest"] = cast_dest
prefetch["cast_geometry"] = cast_geometry
prefetch["needs_cast"] = needs_cast
s._prefetch = prefetch
return offload_stream
def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant):
prefetch = getattr(s, "_prefetch", None)
if prefetch["resident"]:
weight = s._v_weight
bias = s._v_bias
else:
xfer_dest = prefetch["xfer_dest"]
if prefetch["needs_cast"]:
cast_dest = prefetch["cast_dest"] if prefetch["cast_dest"] is not None else torch.empty((comfy.memory_management.vram_aligned_size(prefetch["cast_geometry"]),), dtype=torch.uint8, device=device)
for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest),
comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)):
comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], cast_dest)):
if post_cast is not None:
post_cast.copy_(pre_cast)
xfer_dest = cast_dest
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
params = comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], xfer_dest)
weight = params[0]
bias = params[1]
if signature is not None:
if prefetch["signature"] is not None:
s._v_weight = weight
s._v_bias = bias
s._v_signature=signature
s._v_signature = prefetch["signature"]
def post_cast(s, param_key, x, dtype, resident, update_weight):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
fns = getattr(s, param_key + "_function", [])
if x is None:
return None
orig = x
def to_dequant(tensor, dtype):
@ -205,14 +248,12 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
x = f(x)
return x
update_weight = signature is not None
update_weight = prefetch["signature"] is not None
weight = post_cast(s, "weight", weight, dtype, prefetch["resident"], update_weight)
if bias is not None:
bias = post_cast(s, "bias", bias, bias_dtype, prefetch["resident"], update_weight)
weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
if s.bias is not None:
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
#FIXME: weird offload return protocol
return weight, bias, (offload_stream, device if signature is not None else None, None)
return weight, bias
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
@ -230,10 +271,46 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if device is None:
device = input.device
def format_return(result, offloadable):
weight, bias, offload_stream = result
return (weight, bias, offload_stream) if offloadable else (weight, bias)
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if hasattr(s, "_v"):
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
#vbar doesn't support CPU weights, but some custom nodes have weird paths
#that might switch the layer to the CPU and expect it to work. We have to take
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
#If you are a custom node author reading this, please move your layer to the GPU
#or declare your ModelPatcher as CPU in the first place.
if comfy.model_management.is_device_cpu(device):
materialize_meta_param(s, ["weight", "bias"])
weight = s.weight.to(dtype=dtype, copy=True)
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None
return format_return((weight, bias, (None, None, None)), offloadable)
prefetched = hasattr(s, "_prefetch")
offload_stream = None
offload_device = None
if not prefetched:
offload_stream = cast_modules_with_vbar([s], dtype, device, bias_dtype, non_blocking)
comfy.model_management.sync_stream(device, offload_stream)
weight, bias = resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant)
if not prefetched:
if getattr(s, "_prefetch")["signature"] is not None:
offload_device = device
for param_key in ("weight", "bias"):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
if lowvram_fn is not None:
lowvram_fn.clear_prepared()
delattr(s, "_prefetch")
return format_return((weight, bias, (offload_stream, offload_device, None)), offloadable)
if offloadable and (device != s.weight.device or
(s.bias is not None and device != s.bias.device)):
@ -280,11 +357,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
for f in s.weight_function:
weight = f(weight)
if offloadable:
return weight, bias, (offload_stream, weight_a, bias_a)
else:
#Legacy function signature
return weight, bias
return format_return((weight, bias, (offload_stream, weight_a, bias_a)), offloadable)
def uncast_bias_weight(s, weight, bias, offload_stream):
@ -1173,6 +1246,93 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
self._buffers[key] = fn(buf)
return self
class Embedding(manual_cast.Embedding):
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
weight_key = f"{prefix}weight"
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
if layer_conf is not None:
layer_conf = json.loads(layer_conf.numpy().tobytes())
# Only fp8 makes sense for embeddings (per-row dequant via index select).
# Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently.
quant_format = layer_conf.get("format", None) if layer_conf is not None else None
if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict:
self.quant_format = quant_format
qconfig = QUANT_ALGOS[quant_format]
layout_cls = get_layout_class(qconfig["comfy_tensor_layout"])
weight = state_dict.pop(weight_key)
manually_loaded_keys = [weight_key]
scale_key = f"{prefix}weight_scale"
scale = state_dict.pop(scale_key, None)
if scale is not None:
scale = scale.float()
manually_loaded_keys.append(scale_key)
params = layout_cls.Params(
scale=scale if scale is not None else torch.ones((), dtype=torch.float32),
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.num_embeddings, self.embedding_dim),
)
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params),
requires_grad=False)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
for k in manually_loaded_keys:
if k in missing_keys:
missing_keys.remove(k)
else:
if layer_conf is not None:
state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
def state_dict(self, *args, destination=None, prefix="", **kwargs):
if destination is not None:
sd = destination
else:
sd = {}
if not hasattr(self, 'weight') or self.weight is None:
return sd
if isinstance(self.weight, QuantizedTensor):
sd_out = self.weight.state_dict("{}weight".format(prefix))
for k in sd_out:
sd[k] = sd_out[k]
quant_conf = {"format": self.quant_format}
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
else:
sd["{}weight".format(prefix)] = self.weight
return sd
def forward_comfy_cast_weights(self, input, out_dtype=None):
weight = self.weight
# Optimized path: lookup in fp8, dequantize only the selected rows.
if isinstance(weight, QuantizedTensor) and len(self.weight_function) == 0:
qdata, _, offload_stream = cast_bias_weight(self, device=input.device, dtype=weight.dtype, offloadable=True)
if isinstance(qdata, QuantizedTensor):
scale = qdata._params.scale
qdata = qdata._qdata
else:
scale = None
x = torch.nn.functional.embedding(
input, qdata, self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.sparse)
uncast_bias_weight(self, qdata, None, offload_stream)
target_dtype = out_dtype if out_dtype is not None else weight._params.orig_dtype
x = x.to(dtype=target_dtype)
if scale is not None and scale != 1.0:
x = x * scale.to(dtype=target_dtype)
return x
# Fallback for non-quantized or weight_function (LoRA) case
return super().forward_comfy_cast_weights(input, out_dtype=out_dtype)
return MixedPrecisionOps
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):

View File

@ -3,6 +3,7 @@ import comfy.model_management
RMSNorm = torch.nn.RMSNorm
# Note: torch's fused F.rms_norm is faster but produces slightly different output than manual implementations (rsqrt/reduction rounding).
def rms_norm(x, weight=None, eps=1e-6):
if weight is None:
return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)

View File

@ -65,6 +65,7 @@ import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image
import comfy.text_encoders.qwen35
import comfy.text_encoders.ernie
import comfy.text_encoders.gemma4
import comfy.model_patcher
import comfy.lora
@ -1120,7 +1121,17 @@ class VAE:
else:
pixel_samples = pixel_samples.unsqueeze(2)
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) # TODO: calculate mem required for tile
if dims == 2:
default_tile_x = 512 if tile_x is None else tile_x
default_tile_y = 512 if tile_y is None else tile_y
tile_shapes = [
(1, pixel_samples.shape[1], min(pixel_samples.shape[2], max(1, default_tile_y)), min(pixel_samples.shape[3], max(1, default_tile_x))),
(1, pixel_samples.shape[1], min(pixel_samples.shape[2], max(1, default_tile_y // 2)), min(pixel_samples.shape[3], max(1, default_tile_x * 2))),
(1, pixel_samples.shape[1], min(pixel_samples.shape[2], max(1, default_tile_y * 2)), min(pixel_samples.shape[3], max(1, default_tile_x // 2))),
]
memory_used = max(self.memory_used_encode(shape, self.vae_dtype) for shape in tile_shapes)
else:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
args = {}
@ -1271,6 +1282,9 @@ class TEModel(Enum):
QWEN35_9B = 26
QWEN35_27B = 27
MINISTRAL_3_3B = 28
GEMMA_4_E4B = 29
GEMMA_4_E2B = 30
GEMMA_4_31B = 31
def detect_te_model(sd):
@ -1296,6 +1310,12 @@ def detect_te_model(sd):
return TEModel.BYT5_SMALL_GLYPH
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
if 'model.layers.59.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_4_31B
if 'model.layers.41.self_attn.q_norm.weight' in sd and 'model.layers.47.self_attn.q_norm.weight' not in sd:
return TEModel.GEMMA_4_E4B
if 'model.layers.34.self_attn.q_norm.weight' in sd and 'model.layers.41.self_attn.q_norm.weight' not in sd:
return TEModel.GEMMA_4_E2B
if 'model.layers.47.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_12B
if 'model.layers.0.self_attn.q_norm.weight' in sd:
@ -1435,6 +1455,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
else:
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B):
variant = {TEModel.GEMMA_4_E4B: comfy.text_encoders.gemma4.Gemma4_E4B,
TEModel.GEMMA_4_E2B: comfy.text_encoders.gemma4.Gemma4_E2B,
TEModel.GEMMA_4_31B: comfy.text_encoders.gemma4.Gemma4_31B}[te_model]
clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=variant)
clip_target.tokenizer = variant.tokenizer
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
elif te_model == TEModel.GEMMA_2_2B:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer

File diff suppressed because it is too large Load Diff

View File

@ -521,7 +521,7 @@ class Attention(nn.Module):
else:
present_key_value = (xk, xv, index + num_tokens)
if sliding_window is not None and xk.shape[2] > sliding_window:
if sliding_window is not None and xk.shape[2] > sliding_window and seq_length == 1:
xk = xk[:, :, -sliding_window:]
xv = xv[:, :, -sliding_window:]
attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None
@ -533,12 +533,12 @@ class Attention(nn.Module):
return self.o_proj(output), present_key_value
class MLP(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None, intermediate_size=None):
super().__init__()
ops = ops or nn
self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
intermediate_size = intermediate_size or config.intermediate_size
self.gate_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype)
self.up_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype)
self.down_proj = ops.Linear(intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
if config.mlp_activation == "silu":
self.activation = torch.nn.functional.silu
elif config.mlp_activation == "gelu_pytorch_tanh":
@ -647,24 +647,25 @@ class TransformerBlockGemma2(nn.Module):
return x, present_key_value
def _make_scaled_embedding(ops, vocab_size, hidden_size, scale, device, dtype):
class ScaledEmbedding(ops.Embedding):
def forward(self, input_ids, out_dtype=None):
return super().forward(input_ids, out_dtype=out_dtype) * scale
return ScaledEmbedding(vocab_size, hidden_size, device=device, dtype=dtype)
class Llama2_(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
self.embed_tokens = ops.Embedding(
config.vocab_size,
config.hidden_size,
device=device,
dtype=dtype
)
if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3":
transformer = TransformerBlockGemma2
self.normalize_in = True
self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype)
else:
transformer = TransformerBlock
self.normalize_in = False
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
self.layers = nn.ModuleList([
transformer(config, index=i, device=device, dtype=dtype, ops=ops)
@ -690,15 +691,12 @@ class Llama2_(nn.Module):
self.config.rope_dims,
device=device)
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None):
if embeds is not None:
x = embeds
else:
x = self.embed_tokens(x, out_dtype=dtype)
if self.normalize_in:
x *= self.config.hidden_size ** 0.5
seq_len = x.shape[1]
past_len = 0
if past_key_values is not None and len(past_key_values) > 0:
@ -850,7 +848,7 @@ class BaseGenerate:
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
return past_key_values
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0):
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None):
device = embeds.device
if stop_tokens is None:
@ -875,14 +873,16 @@ class BaseGenerate:
pbar = comfy.utils.ProgressBar(max_length)
# Generation loop
current_input_ids = initial_input_ids
for step in tqdm(range(max_length), desc="Generating tokens"):
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids)
logits = self.logits(x)[:, -1]
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
token_id = next_token[0].item()
generated_token_ids.append(token_id)
embeds = self.model.embed_tokens(next_token).to(execution_dtype)
current_input_ids = next_token if initial_input_ids is not None else None
pbar.update(1)
if token_id in stop_tokens:

View File

@ -93,8 +93,7 @@ class Gemma3_12BModel(sd1_clip.SDClipModel):
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
tokens_only = [[t[0] for t in b] for b in tokens]
embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
embeds, _, _, _ = self.process_tokens(tokens_only, self.execution_device)
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], presence_penalty=presence_penalty) # 106 is <end_of_turn>
class DualLinearProjection(torch.nn.Module):

View File

@ -50,8 +50,7 @@ class Gemma3_4B_Vision_Model(sd1_clip.SDClipModel):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B_Vision, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
def process_tokens(self, tokens, device):
embeds, _, _, embeds_info = super().process_tokens(tokens, device)
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
embeds, _, _, _ = super().process_tokens(tokens, device)
return embeds
class LuminaModel(sd1_clip.SD1ClipModel):

View File

@ -408,8 +408,6 @@ class Qwen35Transformer(Llama2_):
nn.Module.__init__(self)
self.config = config
self.vocab_size = config.vocab_size
self.normalize_in = False
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
self.layers = nn.ModuleList([
Qwen35TransformerBlock(config, index=i, device=device, dtype=dtype, ops=ops)

View File

@ -1446,10 +1446,3 @@ def deepcopy_list_dict(obj, memo=None):
memo[obj_id] = res
return res
def normalize_image_embeddings(embeds, embeds_info, scale_factor):
"""Normalize image embeddings to match text embedding scale"""
for info in embeds_info:
if info.get("type") == "image":
start_idx = info["index"]
end_idx = start_idx + info["size"]
embeds[:, start_idx:end_idx, :] /= scale_factor

View File

@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Optional
from pydantic import BaseModel, Field
@ -72,8 +72,11 @@ class VideoEnhancementFilter(BaseModel):
grain: Optional[float] = Field(None, description="Grain after AI model processing")
grainSize: Optional[float] = Field(None, description="Size of generated grain")
recoverOriginalDetailValue: Optional[float] = Field(None, description="Source details into the output video")
creativity: Optional[str] = Field(None, description="Creativity level(high, low) for slc-1 only")
creativity: float | str | None = Field(None, description="slc-1/slp-2.5: enum (low/middle/high). ast-2: decimal 0.0-1.0.")
isOptimizedMode: Optional[bool] = Field(None, description="Set to true for Starlight Creative (slc-1) only")
prompt: str | None = Field(None, description="Descriptive scene prompt (ast-2 only)")
sharp: float | None = Field(None, description="ast-2 pre-enhance sharpness")
realism: float | None = Field(None, description="ast-2 realism control")
class OutputInformationVideo(BaseModel):
@ -90,7 +93,7 @@ class Overrides(BaseModel):
class CreateVideoRequest(BaseModel):
source: CreateVideoRequestSource = Field(...)
filters: list[Union[VideoFrameInterpolationFilter, VideoEnhancementFilter]] = Field(...)
filters: list[VideoFrameInterpolationFilter | VideoEnhancementFilter] = Field(...)
output: OutputInformationVideo = Field(...)
overrides: Overrides = Field(Overrides(isPaidDiffusion=True))

View File

@ -36,11 +36,15 @@ from comfy_api_nodes.util import (
)
UPSCALER_MODELS_MAP = {
"Astra 2": "ast-2",
"Starlight (Astra) Fast": "slf-1",
"Starlight (Astra) Creative": "slc-1",
"Starlight Precise 2.5": "slp-2.5",
}
AST2_MAX_FRAMES = 9000
AST2_MAX_FRAMES_WITH_PROMPT = 450
class TopazImageEnhance(IO.ComfyNode):
@classmethod
@ -230,13 +234,20 @@ class TopazVideoEnhance(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="TopazVideoEnhance",
display_name="Topaz Video Enhance",
display_name="Topaz Video Enhance (Legacy)",
category="api node/video/Topaz",
description="Breathe new life into video with powerful upscaling and recovery technology.",
inputs=[
IO.Video.Input("video"),
IO.Boolean.Input("upscaler_enabled", default=True),
IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())),
IO.Combo.Input(
"upscaler_model",
options=[
"Starlight (Astra) Fast",
"Starlight (Astra) Creative",
"Starlight Precise 2.5",
],
),
IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
IO.Combo.Input(
"upscaler_creativity",
@ -304,6 +315,7 @@ class TopazVideoEnhance(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
@ -457,12 +469,357 @@ class TopazVideoEnhance(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_video_output(final_response.download.url))
class TopazVideoEnhanceV2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TopazVideoEnhanceV2",
display_name="Topaz Video Enhance",
category="api node/video/Topaz",
description="Breathe new life into video with powerful upscaling and recovery technology.",
inputs=[
IO.Video.Input("video"),
IO.DynamicCombo.Input(
"upscaler_model",
options=[
IO.DynamicCombo.Option(
"Astra 2",
[
IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
IO.Float.Input(
"creativity",
default=0.5,
min=0.0,
max=1.0,
step=0.1,
display_mode=IO.NumberDisplay.slider,
tooltip="Creative strength of the upscale.",
),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Optional descriptive (not instructive) scene prompt."
f"Capping input at {AST2_MAX_FRAMES_WITH_PROMPT} frames (~15s @ 30fps) when set.",
),
IO.Float.Input(
"sharp",
default=0.5,
min=0.0,
max=1.0,
step=0.01,
display_mode=IO.NumberDisplay.slider,
tooltip="Pre-enhance sharpness: "
"0.0=Gaussian blur, 0.5=passthrough (default), 1.0=USM sharpening.",
advanced=True,
),
IO.Float.Input(
"realism",
default=0.0,
min=0.0,
max=1.0,
step=0.01,
display_mode=IO.NumberDisplay.slider,
tooltip="Pulls output toward photographic realism."
"Leave at 0 for the model default.",
advanced=True,
),
],
),
IO.DynamicCombo.Option(
"Starlight (Astra) Fast",
[IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),],
),
IO.DynamicCombo.Option(
"Starlight (Astra) Creative",
[
IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
IO.Combo.Input(
"creativity",
options=["low", "middle", "high"],
default="low",
tooltip="Creative strength of the upscale.",
),
],
),
IO.DynamicCombo.Option(
"Starlight Precise 2.5",
[IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"])],
),
IO.DynamicCombo.Option("Disabled", []),
],
),
IO.DynamicCombo.Input(
"interpolation_model",
options=[
IO.DynamicCombo.Option("Disabled", []),
IO.DynamicCombo.Option(
"apo-8",
[
IO.Int.Input(
"interpolation_frame_rate",
default=60,
min=15,
max=240,
display_mode=IO.NumberDisplay.number,
tooltip="Output frame rate.",
),
IO.Int.Input(
"interpolation_slowmo",
default=1,
min=1,
max=16,
display_mode=IO.NumberDisplay.number,
tooltip="Slow-motion factor applied to the input video. "
"For example, 2 makes the output twice as slow and doubles the duration.",
advanced=True,
),
IO.Boolean.Input(
"interpolation_duplicate",
default=False,
tooltip="Analyze the input for duplicate frames and remove them.",
advanced=True,
),
IO.Float.Input(
"interpolation_duplicate_threshold",
default=0.01,
min=0.001,
max=0.1,
step=0.001,
display_mode=IO.NumberDisplay.number,
tooltip="Detection sensitivity for duplicate frames.",
advanced=True,
),
],
),
],
),
IO.Combo.Input(
"dynamic_compression_level",
options=["Low", "Mid", "High"],
default="Low",
tooltip="CQP level.",
optional=True,
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=[
"upscaler_model",
"upscaler_model.upscaler_resolution",
"interpolation_model",
]),
expr="""
(
$model := $lookup(widgets, "upscaler_model");
$res := $lookup(widgets, "upscaler_model.upscaler_resolution");
$interp := $lookup(widgets, "interpolation_model");
$is4k := $contains($res, "4k");
$hasInterp := $interp != "disabled";
$rates := {
"starlight (astra) fast": {"hd": 0.43, "uhd": 0.85},
"starlight precise 2.5": {"hd": 0.70, "uhd": 1.54},
"astra 2": {"hd": 1.72, "uhd": 2.85},
"starlight (astra) creative": {"hd": 2.25, "uhd": 3.99}
};
$surcharge := $is4k ? 0.28 : 0.14;
$entry := $lookup($rates, $model);
$base := $is4k ? $entry.uhd : $entry.hd;
$hi := $base + ($hasInterp ? $surcharge : 0);
$model = "disabled"
? {"type":"text","text":"Interpolation only"}
: ($hasInterp
? {"type":"text","text":"~" & $string($base) & "" & $string($hi) & " credits/src frame"}
: {"type":"text","text":"~" & $string($base) & " credits/src frame"})
)
""",
),
)
@classmethod
async def execute(
cls,
video: Input.Video,
upscaler_model: dict,
interpolation_model: dict,
dynamic_compression_level: str = "Low",
) -> IO.NodeOutput:
upscaler_choice = upscaler_model["upscaler_model"]
interpolation_choice = interpolation_model["interpolation_model"]
if upscaler_choice == "Disabled" and interpolation_choice == "Disabled":
raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.")
validate_container_format_is_mp4(video)
src_width, src_height = video.get_dimensions()
src_frame_rate = int(video.get_frame_rate())
duration_sec = video.get_duration()
src_video_stream = video.get_stream_source()
target_width = src_width
target_height = src_height
target_frame_rate = src_frame_rate
filters = []
if upscaler_choice != "Disabled":
if "1080p" in upscaler_model["upscaler_resolution"]:
target_pixel_p = 1080
max_long_side = 1920
else:
target_pixel_p = 2160
max_long_side = 3840
ar = src_width / src_height
if src_width >= src_height:
# Landscape or Square; Attempt to set height to target (e.g., 2160), calculate width
target_height = target_pixel_p
target_width = int(target_height * ar)
# Check if width exceeds standard bounds (for ultra-wide e.g., 21:9 ARs)
if target_width > max_long_side:
target_width = max_long_side
target_height = int(target_width / ar)
else:
# Portrait; Attempt to set width to target (e.g., 2160), calculate height
target_width = target_pixel_p
target_height = int(target_width / ar)
# Check if height exceeds standard bounds
if target_height > max_long_side:
target_height = max_long_side
target_width = int(target_height * ar)
if target_width % 2 != 0:
target_width += 1
if target_height % 2 != 0:
target_height += 1
model_id = UPSCALER_MODELS_MAP[upscaler_choice]
if model_id == "slc-1":
filters.append(
VideoEnhancementFilter(
model=model_id,
creativity=upscaler_model["creativity"],
isOptimizedMode=True,
)
)
elif model_id == "ast-2":
n_frames = video.get_frame_count()
ast2_prompt = (upscaler_model["prompt"] or "").strip()
if ast2_prompt and n_frames > AST2_MAX_FRAMES_WITH_PROMPT:
raise ValueError(
f"Astra 2 with a prompt is limited to {AST2_MAX_FRAMES_WITH_PROMPT} input frames "
f"(~15s @ 30fps); video has {n_frames}. Clear the prompt or shorten the clip."
)
if n_frames > AST2_MAX_FRAMES:
raise ValueError(f"Astra 2 is limited to {AST2_MAX_FRAMES} input frames; video has {n_frames}.")
realism = upscaler_model["realism"]
filters.append(
VideoEnhancementFilter(
model=model_id,
creativity=upscaler_model["creativity"],
prompt=(ast2_prompt or None),
sharp=upscaler_model["sharp"],
realism=(realism if realism > 0 else None),
)
)
else:
filters.append(VideoEnhancementFilter(model=model_id))
if interpolation_choice != "Disabled":
target_frame_rate = interpolation_model["interpolation_frame_rate"]
filters.append(
VideoFrameInterpolationFilter(
model=interpolation_choice,
slowmo=interpolation_model["interpolation_slowmo"],
fps=interpolation_model["interpolation_frame_rate"],
duplicate=interpolation_model["interpolation_duplicate"],
duplicate_threshold=interpolation_model["interpolation_duplicate_threshold"],
),
)
initial_res = await sync_op(
cls,
ApiEndpoint(path="/proxy/topaz/video/", method="POST"),
response_model=CreateVideoResponse,
data=CreateVideoRequest(
source=CreateVideoRequestSource(
container="mp4",
size=get_fs_object_size(src_video_stream),
duration=int(duration_sec),
frameCount=video.get_frame_count(),
frameRate=src_frame_rate,
resolution=Resolution(width=src_width, height=src_height),
),
filters=filters,
output=OutputInformationVideo(
resolution=Resolution(width=target_width, height=target_height),
frameRate=target_frame_rate,
audioCodec="AAC",
audioTransfer="Copy",
dynamicCompressionLevel=dynamic_compression_level,
),
),
wait_label="Creating task",
final_label_on_success="Task created",
)
upload_res = await sync_op(
cls,
ApiEndpoint(
path=f"/proxy/topaz/video/{initial_res.requestId}/accept",
method="PATCH",
),
response_model=VideoAcceptResponse,
wait_label="Preparing upload",
final_label_on_success="Upload started",
)
if len(upload_res.urls) > 1:
raise NotImplementedError(
"Large files are not currently supported. Please open an issue in the ComfyUI repository."
)
async with aiohttp.ClientSession(headers={"Content-Type": "video/mp4"}) as session:
if isinstance(src_video_stream, BytesIO):
src_video_stream.seek(0)
async with session.put(upload_res.urls[0], data=src_video_stream, raise_for_status=True) as res:
upload_etag = res.headers["Etag"]
else:
with builtins.open(src_video_stream, "rb") as video_file:
async with session.put(upload_res.urls[0], data=video_file, raise_for_status=True) as res:
upload_etag = res.headers["Etag"]
await sync_op(
cls,
ApiEndpoint(
path=f"/proxy/topaz/video/{initial_res.requestId}/complete-upload",
method="PATCH",
),
response_model=VideoCompleteUploadResponse,
data=VideoCompleteUploadRequest(
uploadResults=[
VideoCompleteUploadRequestPart(
partNum=1,
eTag=upload_etag,
),
],
),
wait_label="Finalizing upload",
final_label_on_success="Upload completed",
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/topaz/video/{initial_res.requestId}/status"),
response_model=VideoStatusResponse,
status_extractor=lambda x: x.status,
progress_extractor=lambda x: getattr(x, "progress", 0),
price_extractor=lambda x: (x.estimates.cost[0] * 0.08 if x.estimates and x.estimates.cost[0] else None),
poll_interval=10.0,
)
return IO.NodeOutput(await download_url_to_video_output(final_response.download.url))
class TopazExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
TopazImageEnhance,
TopazVideoEnhance,
TopazVideoEnhanceV2,
]

View File

@ -1,5 +1,6 @@
import asyncio
import bisect
import gc
import itertools
import psutil
import time
@ -17,6 +18,7 @@ NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}
def include_unique_id_in_input(class_type: str) -> bool:
"""Return whether a node class includes UNIQUE_ID among its hidden inputs."""
if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID:
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
@ -24,52 +26,412 @@ def include_unique_id_in_input(class_type: str) -> bool:
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
class CacheKeySet(ABC):
"""Base helper for building and storing cache keys for prompt nodes."""
def __init__(self, dynprompt, node_ids, is_changed_cache):
"""Initialize cache-key storage for a dynamic prompt execution pass."""
self.keys = {}
self.subcache_keys = {}
@abstractmethod
async def add_keys(self, node_ids):
"""Populate cache keys for the provided node ids."""
raise NotImplementedError()
def all_node_ids(self):
"""Return the set of node ids currently tracked by this key set."""
return set(self.keys.keys())
def get_used_keys(self):
"""Return the computed cache keys currently in use."""
return self.keys.values()
def get_used_subcache_keys(self):
"""Return the computed subcache keys currently in use."""
return self.subcache_keys.values()
def get_data_key(self, node_id):
"""Return the cache key for a node, if present."""
return self.keys.get(node_id, None)
def get_subcache_key(self, node_id):
"""Return the subcache key for a node, if present."""
return self.subcache_keys.get(node_id, None)
class Unhashable:
def __init__(self):
self.value = float("NaN")
"""Hashable identity sentinel for values that cannot be represented safely in cache keys."""
pass
def to_hashable(obj):
# So that we don't infinitely recurse since frozenset and tuples
# are Sequences.
if isinstance(obj, (int, float, str, bool, bytes, type(None))):
return obj
elif isinstance(obj, Mapping):
return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
elif isinstance(obj, Sequence):
return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj]))
else:
# TODO - Support other objects like tensors?
_PRIMITIVE_SIGNATURE_TYPES = (int, float, str, bool, bytes, type(None))
_CONTAINER_SIGNATURE_TYPES = (dict, list, tuple, set, frozenset)
_MAX_SIGNATURE_DEPTH = 32
_MAX_SIGNATURE_CONTAINER_VISITS = 10_000
_FAILED_SIGNATURE = object()
def _shallow_is_changed_signature(value):
"""Reduce execution-time `is_changed` values through a fail-closed builtin canonicalizer."""
value_type = type(value)
if value_type in _PRIMITIVE_SIGNATURE_TYPES:
return value
if value_type not in _CONTAINER_SIGNATURE_TYPES:
return Unhashable()
canonical = _signature_to_hashable(value, max_nodes=64)
if type(canonical) is Unhashable:
return canonical
if value_type is list or value_type is tuple:
container_tag = "is_changed_list" if value_type is list else "is_changed_tuple"
return (container_tag, canonical[1])
return canonical
def _primitive_signature_sort_key(obj):
"""Return a deterministic ordering key for primitive signature values."""
obj_type = type(obj)
return ("primitive", obj_type.__module__, obj_type.__qualname__, repr(obj))
def _sanitized_sort_key(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None):
"""Return a deterministic ordering key for sanitized built-in container content."""
if depth >= max_depth:
return ("MAX_DEPTH",)
if active is None:
active = set()
if memo is None:
memo = {}
obj_type = type(obj)
if obj_type is Unhashable:
return ("UNHASHABLE",)
elif obj_type in _PRIMITIVE_SIGNATURE_TYPES:
return (obj_type.__module__, obj_type.__qualname__, repr(obj))
elif obj_type not in _CONTAINER_SIGNATURE_TYPES:
return (obj_type.__module__, obj_type.__qualname__, "OPAQUE")
obj_id = id(obj)
if obj_id in memo:
return memo[obj_id]
if obj_id in active:
return ("CYCLE",)
active.add(obj_id)
try:
if obj_type is dict:
items = [
(
_sanitized_sort_key(k, depth + 1, max_depth, active, memo),
_sanitized_sort_key(v, depth + 1, max_depth, active, memo),
)
for k, v in obj.items()
]
items.sort()
result = ("dict", tuple(items))
elif obj_type is list:
result = ("list", tuple(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj))
elif obj_type is tuple:
result = ("tuple", tuple(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj))
elif obj_type is set:
result = ("set", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj)))
else:
result = ("frozenset", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj)))
finally:
active.discard(obj_id)
memo[obj_id] = result
return result
def _signature_to_hashable_impl(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None, budget=None):
"""Canonicalize signature inputs directly into their final hashable form."""
if depth >= max_depth:
return _FAILED_SIGNATURE
if active is None:
active = set()
if memo is None:
memo = {}
if budget is None:
budget = {"remaining": _MAX_SIGNATURE_CONTAINER_VISITS}
obj_type = type(obj)
if obj_type in _PRIMITIVE_SIGNATURE_TYPES:
return obj, _primitive_signature_sort_key(obj)
if obj_type is Unhashable or obj_type not in _CONTAINER_SIGNATURE_TYPES:
return _FAILED_SIGNATURE
obj_id = id(obj)
if obj_id in memo:
return memo[obj_id]
if obj_id in active:
return _FAILED_SIGNATURE
budget["remaining"] -= 1
if budget["remaining"] < 0:
return _FAILED_SIGNATURE
active.add(obj_id)
try:
if obj_type is dict:
try:
items = list(obj.items())
except RuntimeError:
return _FAILED_SIGNATURE
ordered_items = []
for key, value in items:
if type(key) not in _PRIMITIVE_SIGNATURE_TYPES:
return _FAILED_SIGNATURE
key_result = (key, _primitive_signature_sort_key(key))
value_result = _signature_to_hashable_impl(value, depth + 1, max_depth, active, memo, budget)
if value_result is _FAILED_SIGNATURE:
return _FAILED_SIGNATURE
key_value, key_sort = key_result
value_value, value_sort = value_result
ordered_items.append((key_sort, value_sort, key_value, value_value))
ordered_items.sort(key=lambda item: (item[0], item[1]))
for index in range(1, len(ordered_items)):
previous_key_sort = ordered_items[index - 1][0]
current_key_sort = ordered_items[index][0]
if previous_key_sort == current_key_sort:
return _FAILED_SIGNATURE
value = ("dict", tuple((key_value, value_value) for _, _, key_value, value_value in ordered_items))
sort_key = ("dict", tuple((key_sort, value_sort) for key_sort, value_sort, _, _ in ordered_items))
elif obj_type is list or obj_type is tuple:
try:
items = list(obj)
except RuntimeError:
return _FAILED_SIGNATURE
child_results = []
for item in items:
child_result = _signature_to_hashable_impl(item, depth + 1, max_depth, active, memo, budget)
if child_result is _FAILED_SIGNATURE:
return _FAILED_SIGNATURE
child_results.append(child_result)
container_tag = "list" if obj_type is list else "tuple"
value = (container_tag, tuple(child for child, _ in child_results))
sort_key = (container_tag, tuple(child_sort for _, child_sort in child_results))
else:
try:
items = list(obj)
except RuntimeError:
return _FAILED_SIGNATURE
ordered_items = []
for item in items:
child_result = _signature_to_hashable_impl(item, depth + 1, max_depth, active, memo, budget)
if child_result is _FAILED_SIGNATURE:
return _FAILED_SIGNATURE
child_value, child_sort = child_result
ordered_items.append((child_sort, child_value))
ordered_items.sort(key=lambda item: item[0])
for index in range(1, len(ordered_items)):
previous_sort_key, previous_value = ordered_items[index - 1]
current_sort_key, current_value = ordered_items[index]
if previous_sort_key == current_sort_key and previous_value != current_value:
return _FAILED_SIGNATURE
container_tag = "set" if obj_type is set else "frozenset"
value = (container_tag, tuple(child_value for _, child_value in ordered_items))
sort_key = (container_tag, tuple(child_sort for child_sort, _ in ordered_items))
finally:
active.discard(obj_id)
memo[obj_id] = (value, sort_key)
return memo[obj_id]
def _signature_to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS):
"""Build the final cache-signature representation in one fail-closed pass."""
try:
result = _signature_to_hashable_impl(obj, budget={"remaining": max_nodes})
except RuntimeError:
return Unhashable()
if result is _FAILED_SIGNATURE:
return Unhashable()
return result[0]
def to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS):
"""Convert sanitized prompt inputs into a stable hashable representation.
The input is expected to already be sanitized to plain built-in containers,
but this function still fails safe for anything unexpected. Traversal is
iterative and memoized so shared built-in substructures do not trigger
exponential re-walks during cache-key construction.
"""
obj_type = type(obj)
if obj_type in _PRIMITIVE_SIGNATURE_TYPES or obj_type is Unhashable:
return obj
if obj_type not in _CONTAINER_SIGNATURE_TYPES:
return Unhashable()
memo = {}
active = set()
snapshots = {}
sort_memo = {}
processed = 0
# Keep traversal state separate from container snapshots/results.
work_stack = [(obj, False)]
def resolve_value(value):
"""Resolve a child value from the completed memo table when available."""
value_type = type(value)
if value_type in _PRIMITIVE_SIGNATURE_TYPES or value_type is Unhashable:
return value
return memo.get(id(value), Unhashable())
def is_failed(value):
"""Return whether a resolved child value represents failed canonicalization."""
return type(value) is Unhashable
def resolve_unordered_values(current_items, container_tag):
"""Resolve a set-like container or fail closed if ordering is ambiguous."""
try:
ordered_items = [
(_sanitized_sort_key(item, memo=sort_memo), resolve_value(item))
for item in current_items
]
if any(is_failed(value) for _, value in ordered_items):
return Unhashable()
ordered_items.sort(key=lambda item: item[0])
except RuntimeError:
return Unhashable()
for index in range(1, len(ordered_items)):
previous_key, previous_value = ordered_items[index - 1]
current_key, current_value = ordered_items[index]
if previous_key == current_key and previous_value != current_value:
return Unhashable()
return (container_tag, tuple(value for _, value in ordered_items))
while work_stack:
entry = work_stack.pop()
if len(entry) == 3:
_, current_id, current_type = entry
current = None
expanded = True
else:
current, expanded = entry
current_type = type(current)
current_id = id(current)
if not expanded and (current_type in _PRIMITIVE_SIGNATURE_TYPES or current_type is Unhashable):
continue
if not expanded and current_type not in _CONTAINER_SIGNATURE_TYPES:
memo[current_id] = Unhashable()
continue
if current_id in memo:
continue
if expanded:
active.discard(current_id)
try:
items = snapshots.pop(current_id, None)
if items is None:
memo[current_id] = Unhashable()
continue
if current_type is dict:
ordered_items = [
(_sanitized_sort_key(k, memo=sort_memo), k, resolve_value(v))
for k, v in items
]
if any(type(key) not in _PRIMITIVE_SIGNATURE_TYPES or is_failed(value) for _, key, value in ordered_items):
memo[current_id] = Unhashable()
continue
ordered_items.sort(key=lambda item: item[0])
for index in range(1, len(ordered_items)):
if ordered_items[index - 1][0] == ordered_items[index][0]:
memo[current_id] = Unhashable()
break
else:
memo[current_id] = (
"dict",
tuple((key, value) for _, key, value in ordered_items),
)
elif current_type is list:
resolved_items = tuple(resolve_value(item) for item in items)
if any(is_failed(item) for item in resolved_items):
memo[current_id] = Unhashable()
else:
memo[current_id] = ("list", resolved_items)
elif current_type is tuple:
resolved_items = tuple(resolve_value(item) for item in items)
if any(is_failed(item) for item in resolved_items):
memo[current_id] = Unhashable()
else:
memo[current_id] = ("tuple", resolved_items)
elif current_type is set:
memo[current_id] = resolve_unordered_values(items, "set")
else:
memo[current_id] = resolve_unordered_values(items, "frozenset")
except RuntimeError:
memo[current_id] = Unhashable()
continue
if current_id in active:
memo[current_id] = Unhashable()
continue
processed += 1
if processed > max_nodes:
return Unhashable()
active.add(current_id)
if current_type is dict:
try:
items = list(current.items())
snapshots[current_id] = items
except RuntimeError:
memo[current_id] = Unhashable()
active.discard(current_id)
continue
for key, value in items:
if type(key) not in _PRIMITIVE_SIGNATURE_TYPES:
snapshots.pop(current_id, None)
memo[current_id] = Unhashable()
active.discard(current_id)
break
else:
work_stack.append(("EXPANDED", current_id, current_type))
for _, value in reversed(items):
work_stack.append((value, False))
continue
continue
else:
try:
items = list(current)
snapshots[current_id] = items
except RuntimeError:
memo[current_id] = Unhashable()
active.discard(current_id)
continue
work_stack.append(("EXPANDED", current_id, current_type))
for item in reversed(items):
work_stack.append((item, False))
return memo.get(id(obj), Unhashable())
class CacheKeySetID(CacheKeySet):
"""Cache-key strategy that keys nodes by node id and class type."""
def __init__(self, dynprompt, node_ids, is_changed_cache):
"""Initialize identity-based cache keys for the supplied dynamic prompt."""
super().__init__(dynprompt, node_ids, is_changed_cache)
self.dynprompt = dynprompt
async def add_keys(self, node_ids):
"""Populate identity-based keys for nodes that exist in the dynamic prompt."""
for node_id in node_ids:
if node_id in self.keys:
continue
@ -80,15 +442,19 @@ class CacheKeySetID(CacheKeySet):
self.subcache_keys[node_id] = (node_id, node["class_type"])
class CacheKeySetInputSignature(CacheKeySet):
"""Cache-key strategy that hashes a node's immediate inputs plus ancestor references."""
def __init__(self, dynprompt, node_ids, is_changed_cache):
"""Initialize input-signature-based cache keys for the supplied dynamic prompt."""
super().__init__(dynprompt, node_ids, is_changed_cache)
self.dynprompt = dynprompt
self.is_changed_cache = is_changed_cache
def include_node_id_in_input(self) -> bool:
"""Return whether node ids should be included in computed input signatures."""
return False
async def add_keys(self, node_ids):
"""Populate input-signature-based keys for nodes in the dynamic prompt."""
for node_id in node_ids:
if node_id in self.keys:
continue
@ -99,21 +465,37 @@ class CacheKeySetInputSignature(CacheKeySet):
self.subcache_keys[node_id] = (node_id, node["class_type"])
async def get_node_signature(self, dynprompt, node_id):
"""Build the full cache signature for a node and its ordered ancestors."""
signature = []
ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
immediate = await self.get_immediate_node_signature(dynprompt, node_id, order_mapping)
if type(immediate) is Unhashable:
return immediate
signature.append(immediate)
for ancestor_id in ancestors:
signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
return to_hashable(signature)
immediate = await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping)
if type(immediate) is Unhashable:
return immediate
signature.append(immediate)
return tuple(signature)
async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
"""Build the immediate cache-signature fragment for a node.
Link inputs are reduced to ancestor references here. Non-link values
are canonicalized or failed closed before being appended so the final
node signature is assembled from already-hashable fragments.
"""
if not dynprompt.has_node(node_id):
# This node doesn't exist -- we can't cache it.
return [float("NaN")]
return Unhashable()
node = dynprompt.get_node(node_id)
class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
signature = [class_type, await self.is_changed_cache.get(node_id)]
is_changed_signature = _shallow_is_changed_signature(await self.is_changed_cache.get(node_id))
if type(is_changed_signature) is Unhashable:
return is_changed_signature
signature = [class_type, is_changed_signature]
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
signature.append(node_id)
inputs = node["inputs"]
@ -123,18 +505,23 @@ class CacheKeySetInputSignature(CacheKeySet):
ancestor_index = ancestor_order_mapping[ancestor_id]
signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))
else:
signature.append((key, inputs[key]))
return signature
value_signature = to_hashable(inputs[key])
if type(value_signature) is Unhashable:
return value_signature
signature.append((key, value_signature))
return tuple(signature)
# This function returns a list of all ancestors of the given node. The order of the list is
# deterministic based on which specific inputs the ancestor is connected by.
def get_ordered_ancestry(self, dynprompt, node_id):
"""Return ancestors in deterministic traversal order and their index mapping."""
ancestors = []
order_mapping = {}
self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping)
return ancestors, order_mapping
def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
"""Recursively collect ancestors in input order without revisiting prior nodes."""
if not dynprompt.has_node(node_id):
return
inputs = dynprompt.get_node(node_id)["inputs"]

View File

@ -1,11 +1,17 @@
def is_link(obj):
if not isinstance(obj, list):
"""Return whether obj is a plain prompt link of the form [node_id, output_index]."""
# Prompt links produced by the frontend / GraphBuilder are plain Python
# lists in the form [node_id, output_index]. Some custom-node paths can
# inject foreign runtime objects into prompt inputs during on-prompt graph
# rewriting or subgraph construction. Be strict here so cache signature
# building never tries to treat list-like proxy objects as links.
if type(obj) is not list:
return False
if len(obj) != 2:
return False
if not isinstance(obj[0], str):
if type(obj[0]) is not str:
return False
if not isinstance(obj[1], int) and not isinstance(obj[1], float):
if type(obj[1]) is not int:
return False
return True

View File

@ -202,14 +202,11 @@ class JoinImageWithAlpha(io.ComfyNode):
@classmethod
def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
batch_size = min(len(image), len(alpha))
out_images = []
batch_size = max(len(image), len(alpha))
alpha = 1.0 - resize_mask(alpha, image.shape[1:])
for i in range(batch_size):
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
return io.NodeOutput(torch.stack(out_images))
alpha = comfy.utils.repeat_to_batch_size(alpha, batch_size)
image = comfy.utils.repeat_to_batch_size(image, batch_size)
return io.NodeOutput(torch.cat((image[..., :3], alpha.unsqueeze(-1)), dim=-1))
class CompositingExtension(ComfyExtension):

View File

@ -32,6 +32,8 @@ class TextGenerate(io.ComfyNode):
io.Clip.Input("clip"),
io.String.Input("prompt", multiline=True, dynamic_prompts=True, default=""),
io.Image.Input("image", optional=True),
io.Image.Input("video", optional=True, tooltip="Video frames as image batch. Assumed to be 24 FPS; subsampled to 1 FPS internally."),
io.Audio.Input("audio", optional=True),
io.Int.Input("max_length", default=256, min=1, max=2048),
io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."),
@ -43,9 +45,9 @@ class TextGenerate(io.ComfyNode):
)
@classmethod
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput:
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True, video=None, audio=None) -> io.NodeOutput:
tokens = clip.tokenize(prompt, image=image, skip_template=not use_default_template, min_length=1, thinking=thinking)
tokens = clip.tokenize(prompt, image=image, skip_template=not use_default_template, min_length=1, thinking=thinking, video=video, audio=audio)
# Get sampling parameters from dynamic combo
do_sample = sampling_mode.get("sampling_mode") == "on"
@ -70,7 +72,8 @@ class TextGenerate(io.ComfyNode):
seed=seed
)
generated_text = clip.decode(generated_ids, skip_special_tokens=True)
generated_text = clip.decode(generated_ids)
return io.NodeOutput(generated_text)
@ -161,12 +164,12 @@ class TextGenerateLTX2Prompt(TextGenerate):
)
@classmethod
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput:
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True, video=None, audio=None) -> io.NodeOutput:
if image is None:
formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
else:
formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking, use_default_template)
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image=image, thinking=thinking, use_default_template=use_default_template, video=video, audio=audio)
class TextgenExtension(ComfyExtension):

View File

@ -15,6 +15,7 @@ import torch
from comfy.cli_args import args
import comfy.memory_management
import comfy.model_management
import comfy.model_prefetch
import comfy_aimdo.model_vbar
from latent_preview import set_preview_method
@ -537,6 +538,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
if args.verbose == "DEBUG":
comfy_aimdo.control.analyze()
comfy.model_management.reset_cast_buffers()
comfy.model_prefetch.cleanup_prefetch_queues()
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
if has_pending_tasks:

View File

@ -86,6 +86,6 @@ def image_alpha_fix(destination, source):
if destination.shape[-1] < source.shape[-1]:
source = source[...,:destination.shape[-1]]
elif destination.shape[-1] > source.shape[-1]:
destination = torch.nn.functional.pad(destination, (0, 1))
destination[..., -1] = 1.0
source = torch.nn.functional.pad(source, (0, 1))
source[..., -1] = 1.0
return destination, source

View File

@ -1694,26 +1694,27 @@ class LoadImage:
RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "load_image"
def load_image(self, image):
image_path = folder_paths.get_annotated_filepath(image)
dtype = comfy.model_management.intermediate_dtype()
device = comfy.model_management.intermediate_device()
components = InputImpl.VideoFromFile(image_path).get_components()
if components.images.shape[0] > 0:
return (components.images, 1.0 - components.alpha[..., -1] if components.alpha is not None else torch.zeros((components.images.shape[0], 64, 64), dtype=torch.float32, device="cpu"))
return (components.images.to(device=device, dtype=dtype), (1.0 - components.alpha[..., -1]).to(device=device, dtype=dtype) if components.alpha is not None else torch.zeros((components.images.shape[0], 64, 64), dtype=dtype, device=device))
# This code is left here to handle animated webp which pyav does not support loading
img = node_helpers.pillow(Image.open, image_path)
output_images = []
output_masks = []
w, h = None, None
dtype = comfy.model_management.intermediate_dtype()
for i in ImageSequence.Iterator(img):
i = node_helpers.pillow(ImageOps.exif_transpose, i)
if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
image = i.convert("RGB")
if len(output_images) == 0:
@ -1728,25 +1729,15 @@ class LoadImage:
if 'A' in i.getbands():
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
elif i.mode == 'P' and 'transparency' in i.info:
mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
output_images.append(image.to(dtype=dtype))
output_masks.append(mask.unsqueeze(0).to(dtype=dtype))
if img.format == "MPO":
break # ignore all frames except the first one for MPO format
output_image = torch.cat(output_images, dim=0)
output_mask = torch.cat(output_masks, dim=0)
if len(output_images) > 1:
output_image = torch.cat(output_images, dim=0)
output_mask = torch.cat(output_masks, dim=0)
else:
output_image = output_images[0]
output_mask = output_masks[0]
return (output_image, output_mask)
return (output_image.to(device=device, dtype=dtype), output_mask.to(device=device, dtype=dtype))
@classmethod
def IS_CHANGED(s, image):

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.42.15
comfyui-workflow-templates==0.9.66
comfyui-workflow-templates==0.9.68
comfyui-embedded-docs==0.4.4
torch
torchsde

View File

@ -1,3 +1,4 @@
import errno
import os
import sys
import asyncio
@ -1245,7 +1246,13 @@ class PromptServer():
address = addr[0]
port = addr[1]
site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
await site.start()
try:
await site.start()
except OSError as e:
if e.errno == errno.EADDRINUSE:
logging.error(f"Port {port} is already in use on address {address}. Please close the other application or use a different port with --port.")
raise SystemExit(1)
raise
if not hasattr(self, 'address'):
self.address = address #TODO: remove this

View File

@ -0,0 +1,473 @@
"""Unit tests for cache-signature canonicalization hardening."""
import asyncio
import importlib
import sys
import types
import pytest
class _DummyNode:
"""Minimal node stub used to satisfy cache-signature class lookups."""
@staticmethod
def INPUT_TYPES():
"""Return a minimal empty input schema for unit tests."""
return {"required": {}}
class _FakeDynPrompt:
"""Small DynamicPrompt stand-in with only the methods these tests need."""
def __init__(self, nodes_by_id):
"""Store test nodes by id."""
self._nodes_by_id = nodes_by_id
def has_node(self, node_id):
"""Return whether the fake prompt contains the requested node."""
return node_id in self._nodes_by_id
def get_node(self, node_id):
"""Return the stored node payload for the requested id."""
return self._nodes_by_id[node_id]
class _FakeIsChangedCache:
"""Async stub for `is_changed` lookups used by cache-key generation."""
def __init__(self, values):
"""Store canned `is_changed` responses keyed by node id."""
self._values = values
async def get(self, node_id):
"""Return the canned `is_changed` value for a node."""
return self._values[node_id]
class _OpaqueValue:
"""Hashable opaque object used to exercise fail-closed unordered hashing paths."""
@pytest.fixture
def caching_module(monkeypatch):
"""Import `comfy_execution.caching` with lightweight stub dependencies."""
torch_module = types.ModuleType("torch")
psutil_module = types.ModuleType("psutil")
nodes_module = types.ModuleType("nodes")
nodes_module.NODE_CLASS_MAPPINGS = {}
graph_module = types.ModuleType("comfy_execution.graph")
class DynamicPrompt:
"""Placeholder graph type so the caching module can import cleanly."""
pass
graph_module.DynamicPrompt = DynamicPrompt
monkeypatch.setitem(sys.modules, "torch", torch_module)
monkeypatch.setitem(sys.modules, "psutil", psutil_module)
monkeypatch.setitem(sys.modules, "nodes", nodes_module)
monkeypatch.setitem(sys.modules, "comfy_execution.graph", graph_module)
monkeypatch.delitem(sys.modules, "comfy_execution.caching", raising=False)
module = importlib.import_module("comfy_execution.caching")
module = importlib.reload(module)
return module, nodes_module
def test_signature_to_hashable_handles_shared_builtin_substructures(caching_module):
"""Shared built-in substructures should canonicalize without collapsing to Unhashable."""
caching, _ = caching_module
shared = [{"value": 1}, {"value": 2}]
signature = caching._signature_to_hashable([shared, shared])
assert signature[0] == "list"
assert signature[1][0] == signature[1][1]
assert signature[1][0][0] == "list"
assert signature[1][0][1][0] == ("dict", (("value", 1),))
assert signature[1][0][1][1] == ("dict", (("value", 2),))
def test_signature_to_hashable_fails_closed_on_opaque_values(caching_module):
"""Opaque values should collapse the full signature to Unhashable immediately."""
caching, _ = caching_module
signature = caching._signature_to_hashable(["safe", object()])
assert isinstance(signature, caching.Unhashable)
def test_signature_to_hashable_stops_descending_after_failure(caching_module, monkeypatch):
"""Once canonicalization fails, later recursive descent should stop immediately."""
caching, _ = caching_module
original = caching._signature_to_hashable_impl
marker = object()
marker_seen = False
def tracking_canonicalize(obj, *args, **kwargs):
"""Track whether recursion reaches the nested marker after failure."""
nonlocal marker_seen
if obj is marker:
marker_seen = True
return original(obj, *args, **kwargs)
monkeypatch.setattr(caching, "_signature_to_hashable_impl", tracking_canonicalize)
signature = caching._signature_to_hashable([object(), [marker]])
assert isinstance(signature, caching.Unhashable)
assert marker_seen is False
def test_signature_to_hashable_snapshots_list_before_recursing(caching_module, monkeypatch):
"""List canonicalization should read a point-in-time snapshot before recursive descent."""
caching, _ = caching_module
original = caching._signature_to_hashable_impl
marker = ("marker",)
values = [marker, 2]
def mutating_canonicalize(obj, *args, **kwargs):
"""Mutate the live list during recursion to verify snapshot-based traversal."""
if obj is marker:
values[1] = 3
return original(obj, *args, **kwargs)
monkeypatch.setattr(caching, "_signature_to_hashable_impl", mutating_canonicalize)
signature = caching._signature_to_hashable(values)
assert signature == ("list", (("tuple", ("marker",)), 2))
assert values[1] == 3
def test_signature_to_hashable_snapshots_dict_before_recursing(caching_module, monkeypatch):
"""Dict canonicalization should read a point-in-time snapshot before recursive descent."""
caching, _ = caching_module
original = caching._signature_to_hashable_impl
marker = ("marker",)
values = {"first": marker, "second": 2}
def mutating_canonicalize(obj, *args, **kwargs):
"""Mutate the live dict during recursion to verify snapshot-based traversal."""
if obj is marker:
values["second"] = 3
return original(obj, *args, **kwargs)
monkeypatch.setattr(caching, "_signature_to_hashable_impl", mutating_canonicalize)
signature = caching._signature_to_hashable(values)
assert signature == ("dict", (("first", ("tuple", ("marker",))), ("second", 2)))
assert values["second"] == 3
@pytest.mark.parametrize(
"container_factory",
[
lambda marker: [marker],
lambda marker: (marker,),
lambda marker: {marker},
lambda marker: frozenset({marker}),
lambda marker: {"key": marker},
],
)
def test_signature_to_hashable_fails_closed_on_runtimeerror(caching_module, monkeypatch, container_factory):
"""Traversal RuntimeError should degrade canonicalization to Unhashable."""
caching, _ = caching_module
original = caching._signature_to_hashable_impl
marker = object()
def raising_canonicalize(obj, *args, **kwargs):
"""Raise a traversal RuntimeError for the marker value and delegate otherwise."""
if obj is marker:
raise RuntimeError("container changed during iteration")
return original(obj, *args, **kwargs)
monkeypatch.setattr(caching, "_signature_to_hashable_impl", raising_canonicalize)
signature = caching._signature_to_hashable(container_factory(marker))
assert isinstance(signature, caching.Unhashable)
def test_to_hashable_handles_shared_builtin_substructures(caching_module):
"""The legacy helper should still hash sanitized built-ins stably when used directly."""
caching, _ = caching_module
shared = [{"value": 1}, {"value": 2}]
sanitized = [shared, shared]
hashable = caching.to_hashable(sanitized)
assert hashable[0] == "list"
assert hashable[1][0] == hashable[1][1]
assert hashable[1][0][0] == "list"
def test_to_hashable_uses_parent_snapshot_during_expanded_phase(caching_module, monkeypatch):
"""Expanded-phase assembly should not reread a live parent container after snapshotting."""
caching, _ = caching_module
original_sort_key = caching._sanitized_sort_key
outer = [{"marker"}, 2]
def mutating_sort_key(obj, *args, **kwargs):
"""Mutate the live parent while a child container is being canonicalized."""
if obj == "marker":
outer[1] = 3
return original_sort_key(obj, *args, **kwargs)
monkeypatch.setattr(caching, "_sanitized_sort_key", mutating_sort_key)
hashable = caching.to_hashable(outer)
assert hashable == ("list", (("set", ("marker",)), 2))
assert outer[1] == 3
def test_to_hashable_fails_closed_for_ordered_container_with_opaque_child(caching_module):
"""Ordered containers should fail closed when a child cannot be canonicalized."""
caching, _ = caching_module
result = caching.to_hashable([object()])
assert isinstance(result, caching.Unhashable)
def test_to_hashable_canonicalizes_dict_insertion_order(caching_module):
"""Dicts with the same content should hash identically regardless of insertion order."""
caching, _ = caching_module
first = {"b": 2, "a": 1}
second = {"a": 1, "b": 2}
assert caching.to_hashable(first) == ("dict", (("a", 1), ("b", 2)))
assert caching.to_hashable(first) == caching.to_hashable(second)
def test_to_hashable_fails_closed_for_opaque_dict_key(caching_module):
"""Opaque dict keys should fail closed instead of being traversed during hashing."""
caching, _ = caching_module
hashable = caching.to_hashable({_OpaqueValue(): 1})
assert isinstance(hashable, caching.Unhashable)
@pytest.mark.parametrize(
"container_factory",
[
set,
frozenset,
],
)
def test_to_hashable_fails_closed_on_runtimeerror(caching_module, monkeypatch, container_factory):
"""Traversal RuntimeError should degrade unordered hash conversion to Unhashable."""
caching, _ = caching_module
def raising_sort_key(obj, *args, **kwargs):
"""Raise a traversal RuntimeError while unordered values are canonicalized."""
raise RuntimeError("container changed during iteration")
monkeypatch.setattr(caching, "_sanitized_sort_key", raising_sort_key)
hashable = caching.to_hashable(container_factory({"value"}))
assert isinstance(hashable, caching.Unhashable)
def test_to_hashable_fails_closed_for_ambiguous_dict_ordering(caching_module, monkeypatch):
"""Ambiguous dict key ordering should fail closed instead of using insertion order."""
caching, _ = caching_module
original_sort_key = caching._sanitized_sort_key
ambiguous = {"a": 1, "b": 1}
def colliding_sort_key(obj, *args, **kwargs):
"""Force two distinct primitive keys to share the same ordering key."""
if obj == "a" or obj == "b":
return ("COLLIDE",)
return original_sort_key(obj, *args, **kwargs)
monkeypatch.setattr(caching, "_sanitized_sort_key", colliding_sort_key)
hashable = caching.to_hashable(ambiguous)
assert isinstance(hashable, caching.Unhashable)
def test_signature_to_hashable_fails_closed_for_ambiguous_dict_ordering(caching_module, monkeypatch):
"""Ambiguous dict sort ties should fail closed instead of depending on input order."""
caching, _ = caching_module
original_sort_key = caching._primitive_signature_sort_key
ambiguous = {"a": 1, "b": 1}
def colliding_sort_key(obj):
"""Force two distinct primitive keys to share the same ordering key."""
if obj == "a" or obj == "b":
return ("COLLIDE",)
return original_sort_key(obj)
monkeypatch.setattr(caching, "_primitive_signature_sort_key", colliding_sort_key)
sanitized = caching._signature_to_hashable(ambiguous)
assert isinstance(sanitized, caching.Unhashable)
def test_signature_to_hashable_fails_closed_for_opaque_dict_key(caching_module):
"""Opaque dict keys should fail closed instead of being recursively canonicalized."""
caching, _ = caching_module
sanitized = caching._signature_to_hashable({_OpaqueValue(): 1})
assert isinstance(sanitized, caching.Unhashable)
def test_signature_to_hashable_fails_closed_on_dict_key_sort_collisions_even_with_distinct_values(caching_module, monkeypatch):
"""Different values must not mask dict key-sort collisions during canonicalization."""
caching, _ = caching_module
original_sort_key = caching._primitive_signature_sort_key
def colliding_sort_key(obj):
"""Force two distinct primitive keys to share the same ordering key."""
if obj == "a" or obj == "b":
return ("COLLIDE",)
return original_sort_key(obj)
monkeypatch.setattr(caching, "_primitive_signature_sort_key", colliding_sort_key)
sanitized = caching._signature_to_hashable({"a": 1, "b": 2})
assert isinstance(sanitized, caching.Unhashable)
@pytest.mark.parametrize(
"container_factory",
[
set,
frozenset,
],
)
def test_to_hashable_fails_closed_for_ambiguous_unordered_values(caching_module, monkeypatch, container_factory):
"""Ambiguous unordered values should fail closed instead of depending on iteration order."""
caching, _ = caching_module
original_sort_key = caching._sanitized_sort_key
container = container_factory({"a", "b"})
def colliding_sort_key(obj, *args, **kwargs):
"""Force two distinct primitive values to share the same ordering key."""
if obj == "a" or obj == "b":
return ("COLLIDE",)
return original_sort_key(obj, *args, **kwargs)
monkeypatch.setattr(caching, "_sanitized_sort_key", colliding_sort_key)
hashable = caching.to_hashable(container)
assert isinstance(hashable, caching.Unhashable)
def test_get_node_signature_returns_top_level_unhashable_for_tainted_signature(caching_module, monkeypatch):
"""Tainted full signatures should fail closed before `to_hashable()` runs."""
caching, nodes_module = caching_module
monkeypatch.setitem(nodes_module.NODE_CLASS_MAPPINGS, "UnitTestNode", _DummyNode)
monkeypatch.setattr(
caching,
"to_hashable",
lambda *_args, **_kwargs: pytest.fail("to_hashable should not run for tainted signatures"),
)
is_changed_value = []
is_changed_value.append(is_changed_value)
dynprompt = _FakeDynPrompt(
{
"node": {
"class_type": "UnitTestNode",
"inputs": {"value": 5},
}
}
)
key_set = caching.CacheKeySetInputSignature(
dynprompt,
["node"],
_FakeIsChangedCache({"node": is_changed_value}),
)
signature = asyncio.run(key_set.get_node_signature(dynprompt, "node"))
assert isinstance(signature, caching.Unhashable)
def test_shallow_is_changed_signature_accepts_primitive_lists(caching_module):
"""Primitive-only `is_changed` lists should stay hashable without deep descent."""
caching, _ = caching_module
sanitized = caching._shallow_is_changed_signature([1, "two", None, True])
assert sanitized == ("is_changed_list", (1, "two", None, True))
def test_shallow_is_changed_signature_accepts_structured_builtin_fingerprint_lists(caching_module):
"""Structured built-in `is_changed` fingerprints should remain representable."""
caching, _ = caching_module
sanitized = caching._shallow_is_changed_signature([("seed", 42), {"cfg": 8}])
assert sanitized == (
"is_changed_list",
(
("tuple", ("seed", 42)),
("dict", (("cfg", 8),)),
),
)
def test_shallow_is_changed_signature_fails_closed_for_opaque_payload(caching_module):
"""Opaque `is_changed` payloads should still fail closed."""
caching, _ = caching_module
sanitized = caching._shallow_is_changed_signature([_OpaqueValue()])
assert isinstance(sanitized, caching.Unhashable)
def test_get_immediate_node_signature_fails_closed_for_unhashable_is_changed(caching_module, monkeypatch):
"""Recursive `is_changed` payloads should fail the full fragment closed."""
caching, nodes_module = caching_module
monkeypatch.setitem(nodes_module.NODE_CLASS_MAPPINGS, "UnitTestNode", _DummyNode)
is_changed_value = []
is_changed_value.append(is_changed_value)
dynprompt = _FakeDynPrompt(
{
"node": {
"class_type": "UnitTestNode",
"inputs": {"value": 5},
}
}
)
key_set = caching.CacheKeySetInputSignature(
dynprompt,
["node"],
_FakeIsChangedCache({"node": is_changed_value}),
)
signature = asyncio.run(key_set.get_immediate_node_signature(dynprompt, "node", {}))
assert isinstance(signature, caching.Unhashable)
def test_get_immediate_node_signature_fails_closed_for_missing_node(caching_module):
"""Missing nodes should return the fail-closed sentinel instead of a NaN tuple."""
caching, _ = caching_module
dynprompt = _FakeDynPrompt({})
key_set = caching.CacheKeySetInputSignature(
dynprompt,
[],
_FakeIsChangedCache({}),
)
signature = asyncio.run(key_set.get_immediate_node_signature(dynprompt, "missing", {}))
assert isinstance(signature, caching.Unhashable)

View File

@ -0,0 +1,242 @@
import asyncio
from comfy_execution import caching
class _StubDynPrompt:
def __init__(self, nodes):
self._nodes = nodes
def has_node(self, node_id):
return node_id in self._nodes
def get_node(self, node_id):
return self._nodes[node_id]
class _StubIsChangedCache:
async def get(self, node_id):
return None
class _StubNode:
@classmethod
def INPUT_TYPES(cls):
return {"required": {}}
def test_shallow_is_changed_signature_keeps_primitive_only_list_shallow():
assert caching._shallow_is_changed_signature([1, "two", None, True]) == (
"is_changed_list",
(1, "two", None, True),
)
def test_shallow_is_changed_signature_keeps_primitive_only_tuple_shallow():
assert caching._shallow_is_changed_signature((1, "two", None, True)) == (
"is_changed_tuple",
(1, "two", None, True),
)
def test_shallow_is_changed_signature_keeps_structured_builtin_fingerprint_list():
assert caching._shallow_is_changed_signature([("seed", 42), {"cfg": 8}]) == (
"is_changed_list",
(
("tuple", ("seed", 42)),
("dict", (("cfg", 8),)),
),
)
def test_shallow_is_changed_signature_does_not_use_to_hashable(monkeypatch):
monkeypatch.setattr(
caching,
"to_hashable",
lambda *_args, **_kwargs: (_ for _ in ()).throw(
AssertionError("is_changed signature must not deep-canonicalize")
),
)
signature = caching._shallow_is_changed_signature([("seed", 42), {"cfg": 8}])
assert signature == (
"is_changed_list",
(
("tuple", ("seed", 42)),
("dict", (("cfg", 8),)),
),
)
def test_get_immediate_node_signature_canonicalizes_non_link_inputs(monkeypatch):
live_value = [1, {"nested": [2, 3]}]
dynprompt = _StubDynPrompt(
{
"1": {
"class_type": "TestCacheNode",
"inputs": {"value": live_value},
}
}
)
monkeypatch.setitem(caching.nodes.NODE_CLASS_MAPPINGS, "TestCacheNode", _StubNode)
monkeypatch.setattr(caching, "NODE_CLASS_CONTAINS_UNIQUE_ID", {})
keyset = caching.CacheKeySetInputSignature(dynprompt, [], _StubIsChangedCache())
signature = asyncio.run(keyset.get_immediate_node_signature(dynprompt, "1", {}))
assert signature == (
"TestCacheNode",
None,
("value", ("list", (1, ("dict", (("nested", ("list", (2, 3))),))))),
)
def test_to_hashable_walks_dicts_without_rebinding_traversal_stack():
live_value = {
"outer": {"nested": [2, 3]},
"items": [{"leaf": 4}],
}
assert caching.to_hashable(live_value) == (
"dict",
(
("items", ("list", (("dict", (("leaf", 4),)),))),
("outer", ("dict", (("nested", ("list", (2, 3))),))),
),
)
def test_get_immediate_node_signature_fails_closed_for_opaque_non_link_input(monkeypatch):
class OpaqueRuntimeValue:
pass
live_value = OpaqueRuntimeValue()
dynprompt = _StubDynPrompt(
{
"1": {
"class_type": "TestCacheNode",
"inputs": {"value": live_value},
}
}
)
monkeypatch.setitem(caching.nodes.NODE_CLASS_MAPPINGS, "TestCacheNode", _StubNode)
monkeypatch.setattr(caching, "NODE_CLASS_CONTAINS_UNIQUE_ID", {})
keyset = caching.CacheKeySetInputSignature(dynprompt, [], _StubIsChangedCache())
signature = asyncio.run(keyset.get_immediate_node_signature(dynprompt, "1", {}))
assert isinstance(signature, caching.Unhashable)
def test_get_node_signature_propagates_unhashable_immediate_fragment(monkeypatch):
class OpaqueRuntimeValue:
pass
dynprompt = _StubDynPrompt(
{
"1": {
"class_type": "TestCacheNode",
"inputs": {"value": OpaqueRuntimeValue()},
}
}
)
monkeypatch.setitem(caching.nodes.NODE_CLASS_MAPPINGS, "TestCacheNode", _StubNode)
monkeypatch.setattr(caching, "NODE_CLASS_CONTAINS_UNIQUE_ID", {})
keyset = caching.CacheKeySetInputSignature(dynprompt, [], _StubIsChangedCache())
signature = asyncio.run(keyset.get_node_signature(dynprompt, "1"))
assert isinstance(signature, caching.Unhashable)
def test_get_node_signature_never_visits_raw_non_link_input(monkeypatch):
live_value = [1, 2, 3]
dynprompt = _StubDynPrompt(
{
"1": {
"class_type": "TestCacheNode",
"inputs": {"value": live_value},
}
}
)
monkeypatch.setitem(caching.nodes.NODE_CLASS_MAPPINGS, "TestCacheNode", _StubNode)
monkeypatch.setattr(caching, "NODE_CLASS_CONTAINS_UNIQUE_ID", {})
monkeypatch.setattr(
caching,
"_signature_to_hashable",
lambda *_args, **_kwargs: (_ for _ in ()).throw(
AssertionError("outer signature canonicalizer should not run")
),
)
keyset = caching.CacheKeySetInputSignature(dynprompt, [], _StubIsChangedCache())
signature = asyncio.run(keyset.get_node_signature(dynprompt, "1"))
assert isinstance(signature, tuple)
def test_get_node_signature_keeps_deep_canonicalized_input_fragment(monkeypatch):
live_value = 1
for _ in range(8):
live_value = [live_value]
expected = caching.to_hashable(live_value)
dynprompt = _StubDynPrompt(
{
"1": {
"class_type": "TestCacheNode",
"inputs": {"value": live_value},
}
}
)
monkeypatch.setitem(caching.nodes.NODE_CLASS_MAPPINGS, "TestCacheNode", _StubNode)
monkeypatch.setattr(caching, "NODE_CLASS_CONTAINS_UNIQUE_ID", {})
keyset = caching.CacheKeySetInputSignature(dynprompt, [], _StubIsChangedCache())
signature = asyncio.run(keyset.get_node_signature(dynprompt, "1"))
assert isinstance(signature, tuple)
assert signature[0][2][0] == "value"
assert signature[0][2][1] == expected
def test_get_node_signature_keeps_large_precanonicalized_fragment(monkeypatch):
live_value = object()
canonical_fragment = ("tuple", tuple(("list", (index, index + 1)) for index in range(256)))
dynprompt = _StubDynPrompt(
{
"1": {
"class_type": "TestCacheNode",
"inputs": {"value": live_value},
}
}
)
monkeypatch.setitem(caching.nodes.NODE_CLASS_MAPPINGS, "TestCacheNode", _StubNode)
monkeypatch.setattr(caching, "NODE_CLASS_CONTAINS_UNIQUE_ID", {})
monkeypatch.setattr(
caching,
"to_hashable",
lambda value, max_nodes=caching._MAX_SIGNATURE_CONTAINER_VISITS: (
canonical_fragment if value is live_value else caching.Unhashable()
),
)
monkeypatch.setattr(
caching,
"_signature_to_hashable",
lambda *_args, **_kwargs: (_ for _ in ()).throw(
AssertionError("outer signature canonicalizer should not run")
),
)
keyset = caching.CacheKeySetInputSignature(dynprompt, [], _StubIsChangedCache())
signature = asyncio.run(keyset.get_node_signature(dynprompt, "1"))
assert isinstance(signature, tuple)
assert signature[0][2] == ("value", canonical_fragment)