mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-13 18:47:29 +08:00
Compare commits
67 Commits
5ca7ab26e5
...
ab8414eba3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ab8414eba3 | ||
|
|
56c74094c7 | ||
|
|
594de378fe | ||
|
|
c8673542f7 | ||
|
|
df7bf1d3dc | ||
|
|
ef8f25601a | ||
|
|
8dc3f3f209 | ||
|
|
c011fb520c | ||
|
|
c945a433ae | ||
|
|
25757a53c9 | ||
|
|
1b25f1289e | ||
|
|
e35348aa53 | ||
|
|
cd8c7a2306 | ||
|
|
6bcd8b96ab | ||
|
|
6d4f9e86ab | ||
|
|
1548aee40e | ||
|
|
9c210473fc | ||
|
|
c1e9164c63 | ||
|
|
5e9a90186f | ||
|
|
ece906328a | ||
|
|
500ca8e02a | ||
|
|
3143b7981f | ||
|
|
c9b3f81e83 | ||
|
|
5e74e9b3ed | ||
|
|
c702cddf75 | ||
|
|
e13da8104c | ||
|
|
fdcc38b9ea | ||
|
|
c1ce00287c | ||
|
|
6e3bd33665 | ||
|
|
ce05e377a8 | ||
|
|
1a00f7743f | ||
|
|
a6472b1514 | ||
|
|
6158cd5820 | ||
|
|
bff714dda0 | ||
|
|
fce22da313 | ||
|
|
9f9d37bd9a | ||
|
|
088778c35d | ||
|
|
4c5f82971e | ||
|
|
f1d91a4c8c | ||
|
|
dbed5a1b52 | ||
|
|
24fdbb9aca | ||
|
|
a6624a9afd | ||
|
|
0b512198e8 | ||
|
|
9feb26928c | ||
|
|
fadd79ad48 | ||
|
|
77bc7bdd6b | ||
|
|
117afbc1d7 | ||
|
|
064eec2278 | ||
|
|
aceaa5e579 | ||
|
|
763089f681 | ||
|
|
1693dabc8f | ||
|
|
08063d2638 | ||
|
|
e069617e54 | ||
|
|
2bea0ee5d7 | ||
|
|
17863f603a | ||
|
|
31ba844624 | ||
|
|
1451001f64 | ||
|
|
1af99b2e81 | ||
|
|
3568b82b76 | ||
|
|
6728d4d439 | ||
|
|
4b431ffc27 | ||
|
|
880b51ac4f | ||
|
|
4d9516b909 | ||
|
|
39086890e2 | ||
|
|
2adde5a0e1 | ||
|
|
0c1bfad0df | ||
|
|
7d76a4447e |
2
.github/workflows/stable-release.yml
vendored
2
.github/workflows/stable-release.yml
vendored
@ -145,6 +145,8 @@ jobs:
|
||||
cp -r ComfyUI/.ci/windows_${{ inputs.rel_name }}_base_files/* ./
|
||||
cp ../update_comfyui_and_python_dependencies.bat ./update/
|
||||
|
||||
echo 'local-portable' > ComfyUI/.comfy_environment
|
||||
|
||||
cd ..
|
||||
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
|
||||
@ -27,7 +27,7 @@ def frontend_install_warning_message():
|
||||
return f"""
|
||||
{get_missing_requirements_message()}
|
||||
|
||||
This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
|
||||
The ComfyUI frontend is shipped in a pip package so it needs to be updated separately from the ComfyUI code.
|
||||
""".strip()
|
||||
|
||||
def parse_version(version: str) -> tuple[int, int, int]:
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from typing import TYPE_CHECKING, TypedDict
|
||||
@ -31,8 +33,22 @@ class NodeReplaceManager:
|
||||
self._replacements: dict[str, list[NodeReplace]] = {}
|
||||
|
||||
def register(self, node_replace: NodeReplace):
|
||||
"""Register a node replacement mapping."""
|
||||
self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)
|
||||
"""Register a node replacement mapping.
|
||||
|
||||
Idempotent: if a replacement with the same (old_node_id, new_node_id)
|
||||
is already registered, the duplicate is ignored. This prevents stale
|
||||
entries from accumulating when custom nodes are reloaded in the same
|
||||
process (e.g. via ComfyUI-Manager).
|
||||
"""
|
||||
existing = self._replacements.setdefault(node_replace.old_node_id, [])
|
||||
for entry in existing:
|
||||
if entry.new_node_id == node_replace.new_node_id:
|
||||
logging.debug(
|
||||
"Node replacement %s -> %s already registered, ignoring duplicate.",
|
||||
node_replace.old_node_id, node_replace.new_node_id,
|
||||
)
|
||||
return
|
||||
existing.append(node_replace)
|
||||
|
||||
def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
|
||||
"""Get replacements for an old node ID."""
|
||||
|
||||
@ -1859,6 +1859,23 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No
|
||||
output = torch.zeros_like(x)
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
current_start_frame = 0
|
||||
|
||||
# I2V: seed KV cache with the initial image latent before the denoising loop
|
||||
initial_latent = transformer_options.get("ar_config", {}).get("initial_latent", None)
|
||||
if initial_latent is not None:
|
||||
initial_latent = inner_model.process_latent_in(initial_latent).to(device=device, dtype=model_dtype)
|
||||
n_init = initial_latent.shape[2]
|
||||
output[:, :, :n_init] = initial_latent
|
||||
|
||||
ar_state = {"start_frame": 0, "kv_caches": kv_caches, "crossattn_caches": crossattn_caches}
|
||||
transformer_options["ar_state"] = ar_state
|
||||
zero_sigma = sigmas.new_zeros([1])
|
||||
_ = model(initial_latent, zero_sigma * s_in, **extra_args)
|
||||
|
||||
current_start_frame = n_init
|
||||
remaining = lat_t - n_init
|
||||
num_blocks = -(-remaining // num_frame_per_block)
|
||||
|
||||
num_sigma_steps = len(sigmas) - 1
|
||||
total_real_steps = num_blocks * num_sigma_steps
|
||||
step_count = 0
|
||||
|
||||
@ -561,7 +561,8 @@ class SAM3Model(nn.Module):
|
||||
return high_res_masks
|
||||
|
||||
def forward_video(self, images, initial_masks, pbar=None, text_prompts=None,
|
||||
new_det_thresh=0.5, max_objects=0, detect_interval=1):
|
||||
new_det_thresh=0.5, max_objects=0, detect_interval=1,
|
||||
target_device=None, target_dtype=None):
|
||||
"""Track video with optional per-frame text-prompted detection."""
|
||||
bb = self.detector.backbone["vision_backbone"]
|
||||
|
||||
@ -589,8 +590,10 @@ class SAM3Model(nn.Module):
|
||||
return self.tracker.track_video_with_detection(
|
||||
backbone_fn, images, initial_masks, detect_fn,
|
||||
new_det_thresh=new_det_thresh, max_objects=max_objects,
|
||||
detect_interval=detect_interval, backbone_obj=bb, pbar=pbar)
|
||||
detect_interval=detect_interval, backbone_obj=bb, pbar=pbar,
|
||||
target_device=target_device, target_dtype=target_dtype)
|
||||
# SAM3 (non-multiplex) — no detection support, requires initial masks
|
||||
if initial_masks is None:
|
||||
raise ValueError("SAM3 (non-multiplex) requires initial_mask for video tracking")
|
||||
return self.tracker.track_video(backbone_fn, images, initial_masks, pbar=pbar, backbone_obj=bb)
|
||||
return self.tracker.track_video(backbone_fn, images, initial_masks, pbar=pbar, backbone_obj=bb,
|
||||
target_device=target_device, target_dtype=target_dtype)
|
||||
|
||||
@ -200,8 +200,13 @@ def pack_masks(masks):
|
||||
|
||||
def unpack_masks(packed):
|
||||
"""Unpack bit-packed [*, H, W//8] uint8 to bool [*, H, W*8]."""
|
||||
shifts = torch.arange(8, device=packed.device)
|
||||
return ((packed.unsqueeze(-1) >> shifts) & 1).view(*packed.shape[:-1], -1).bool()
|
||||
bits = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8, device=packed.device)
|
||||
return (packed.unsqueeze(-1) & bits).bool().view(*packed.shape[:-1], -1)
|
||||
|
||||
|
||||
def _prep_frame(images, idx, device, dt, size):
|
||||
"""Slice CPU full-res frames, transfer to GPU in target dtype, and resize to (size, size)."""
|
||||
return comfy.utils.common_upscale(images[idx].to(device=device, dtype=dt), size, size, "bicubic", crop="disabled")
|
||||
|
||||
|
||||
def _compute_backbone(backbone_fn, frame, frame_idx=None):
|
||||
@ -1078,16 +1083,19 @@ class SAM3Tracker(nn.Module):
|
||||
# SAM3: drop last FPN level
|
||||
return vision_feats[:-1], vision_pos[:-1], feat_sizes[:-1]
|
||||
|
||||
def _track_single_object(self, backbone_fn, images, initial_mask, pbar=None):
|
||||
def _track_single_object(self, backbone_fn, images, initial_mask, pbar=None,
|
||||
target_device=None, target_dtype=None):
|
||||
"""Track one object, computing backbone per frame to save VRAM."""
|
||||
N = images.shape[0]
|
||||
device, dt = images.device, images.dtype
|
||||
device = target_device if target_device is not None else images.device
|
||||
dt = target_dtype if target_dtype is not None else images.dtype
|
||||
size = self.image_size
|
||||
output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}}
|
||||
all_masks = []
|
||||
|
||||
for frame_idx in tqdm(range(N), desc="tracking"):
|
||||
vision_feats, vision_pos, feat_sizes = self._compute_backbone_frame(
|
||||
backbone_fn, images[frame_idx:frame_idx + 1], frame_idx=frame_idx)
|
||||
backbone_fn, _prep_frame(images, slice(frame_idx, frame_idx + 1), device, dt, size), frame_idx=frame_idx)
|
||||
mask_input = None
|
||||
if frame_idx == 0:
|
||||
mask_input = F.interpolate(initial_mask.to(device=device, dtype=dt),
|
||||
@ -1114,12 +1122,13 @@ class SAM3Tracker(nn.Module):
|
||||
|
||||
return torch.cat(all_masks, dim=0) # [N, 1, H, W]
|
||||
|
||||
def track_video(self, backbone_fn, images, initial_masks, pbar=None, **kwargs):
|
||||
def track_video(self, backbone_fn, images, initial_masks, pbar=None,
|
||||
target_device=None, target_dtype=None, **kwargs):
|
||||
"""Track one or more objects across video frames.
|
||||
|
||||
Args:
|
||||
backbone_fn: callable that returns (sam2_features, sam2_positions, trunk_out) for a frame
|
||||
images: [N, 3, 1008, 1008] video frames
|
||||
images: [N, 3, H, W] CPU full-res video frames (resized per-frame to self.image_size)
|
||||
initial_masks: [N_obj, 1, H, W] binary masks for first frame (one per object)
|
||||
pbar: optional progress bar
|
||||
|
||||
@ -1130,7 +1139,8 @@ class SAM3Tracker(nn.Module):
|
||||
per_object = []
|
||||
for obj_idx in range(N_obj):
|
||||
obj_masks = self._track_single_object(
|
||||
backbone_fn, images, initial_masks[obj_idx:obj_idx + 1], pbar=pbar)
|
||||
backbone_fn, images, initial_masks[obj_idx:obj_idx + 1], pbar=pbar,
|
||||
target_device=target_device, target_dtype=target_dtype)
|
||||
per_object.append(obj_masks)
|
||||
|
||||
return torch.cat(per_object, dim=1) # [N, N_obj, H, W]
|
||||
@ -1632,11 +1642,18 @@ class SAM31Tracker(nn.Module):
|
||||
return det_scores[new_dets].tolist() if det_scores is not None else [0.0] * new_dets.sum().item()
|
||||
return []
|
||||
|
||||
INTERNAL_MAX_OBJECTS = 64 # Hard ceiling on accumulated tracks; max_objects=0 or any value above this is clamped here.
|
||||
|
||||
def track_video_with_detection(self, backbone_fn, images, initial_masks, detect_fn=None,
|
||||
new_det_thresh=0.5, max_objects=0, detect_interval=1,
|
||||
backbone_obj=None, pbar=None):
|
||||
backbone_obj=None, pbar=None, target_device=None, target_dtype=None):
|
||||
"""Track with optional per-frame detection. Returns [N, max_N_obj, H, W] mask logits."""
|
||||
N, device, dt = images.shape[0], images.device, images.dtype
|
||||
if max_objects <= 0 or max_objects > self.INTERNAL_MAX_OBJECTS:
|
||||
max_objects = self.INTERNAL_MAX_OBJECTS
|
||||
N = images.shape[0]
|
||||
device = target_device if target_device is not None else images.device
|
||||
dt = target_dtype if target_dtype is not None else images.dtype
|
||||
size = self.image_size
|
||||
output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}}
|
||||
all_masks = []
|
||||
idev = comfy.model_management.intermediate_device()
|
||||
@ -1656,7 +1673,7 @@ class SAM31Tracker(nn.Module):
|
||||
prefetch = True
|
||||
except RuntimeError:
|
||||
pass
|
||||
cur_bb = self._compute_backbone_frame(backbone_fn, images[0:1], frame_idx=0)
|
||||
cur_bb = self._compute_backbone_frame(backbone_fn, _prep_frame(images, slice(0, 1), device, dt, size), frame_idx=0)
|
||||
|
||||
for frame_idx in tqdm(range(N), desc="tracking"):
|
||||
vision_feats, vision_pos, feat_sizes, high_res_prop, trunk_out = cur_bb
|
||||
@ -1666,7 +1683,7 @@ class SAM31Tracker(nn.Module):
|
||||
backbone_stream.wait_stream(torch.cuda.current_stream(device))
|
||||
with torch.cuda.stream(backbone_stream):
|
||||
next_bb = self._compute_backbone_frame(
|
||||
backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1)
|
||||
backbone_fn, _prep_frame(images, slice(frame_idx + 1, frame_idx + 2), device, dt, size), frame_idx=frame_idx + 1)
|
||||
|
||||
# Per-frame detection with NMS (skip if no detect_fn, or interval/max not met)
|
||||
det_masks = torch.empty(0, device=device)
|
||||
@ -1687,7 +1704,7 @@ class SAM31Tracker(nn.Module):
|
||||
current_out = self._condition_with_masks(
|
||||
initial_masks.to(device=device, dtype=dt), frame_idx, vision_feats, vision_pos,
|
||||
feat_sizes, high_res_prop, output_dict, N, mux_state, backbone_obj,
|
||||
images[frame_idx:frame_idx + 1], trunk_out)
|
||||
_prep_frame(images, slice(frame_idx, frame_idx + 1), device, dt, size), trunk_out)
|
||||
last_occluded = torch.full((mux_state.total_valid_entries,), -1, device=device, dtype=torch.long)
|
||||
obj_scores = [1.0] * mux_state.total_valid_entries
|
||||
if keep_alive is not None:
|
||||
@ -1702,7 +1719,7 @@ class SAM31Tracker(nn.Module):
|
||||
current_out = self._condition_with_masks(
|
||||
det_masks, frame_idx, vision_feats, vision_pos, feat_sizes, high_res_prop,
|
||||
output_dict, N, mux_state, backbone_obj,
|
||||
images[frame_idx:frame_idx + 1], trunk_out, threshold=0.0)
|
||||
_prep_frame(images, slice(frame_idx, frame_idx + 1), device, dt, size), trunk_out, threshold=0.0)
|
||||
last_occluded = torch.full((mux_state.total_valid_entries,), -1, device=device, dtype=torch.long)
|
||||
obj_scores = det_scores[:mux_state.total_valid_entries].tolist()
|
||||
if keep_alive is not None:
|
||||
@ -1718,7 +1735,7 @@ class SAM31Tracker(nn.Module):
|
||||
torch.cuda.current_stream(device).wait_stream(backbone_stream)
|
||||
cur_bb = next_bb
|
||||
else:
|
||||
cur_bb = self._compute_backbone_frame(backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1)
|
||||
cur_bb = self._compute_backbone_frame(backbone_fn, _prep_frame(images, slice(frame_idx + 1, frame_idx + 2), device, dt, size), frame_idx=frame_idx + 1)
|
||||
continue
|
||||
else:
|
||||
N_obj = mux_state.total_valid_entries
|
||||
@ -1768,7 +1785,7 @@ class SAM31Tracker(nn.Module):
|
||||
torch.cuda.current_stream(device).wait_stream(backbone_stream)
|
||||
cur_bb = next_bb
|
||||
else:
|
||||
cur_bb = self._compute_backbone_frame(backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1)
|
||||
cur_bb = self._compute_backbone_frame(backbone_fn, _prep_frame(images, slice(frame_idx + 1, frame_idx + 2), device, dt, size), frame_idx=frame_idx + 1)
|
||||
|
||||
if not all_masks or all(m is None for m in all_masks):
|
||||
return {"packed_masks": None, "n_frames": N, "scores": []}
|
||||
|
||||
@ -26,6 +26,7 @@ import uuid
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
import comfy.float
|
||||
import comfy.hooks
|
||||
@ -1651,7 +1652,11 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
self.model.model_loaded_weight_memory += casted_buf.numel() * casted_buf.element_size()
|
||||
|
||||
force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else ""
|
||||
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}")
|
||||
log_key = (self.patches_uuid, allocated_size, num_patches, len(self.backup), self.model.model_loaded_weight_memory)
|
||||
in_loop = bool(getattr(tqdm.tqdm, "_instances", None))
|
||||
level = logging.DEBUG if in_loop and getattr(self, "_last_prepare_log_key", None) == log_key else logging.INFO
|
||||
self._last_prepare_log_key = log_key
|
||||
logging.log(level, f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}")
|
||||
|
||||
self.model.device = device_to
|
||||
self.model.current_weight_patches_uuid = self.patches_uuid
|
||||
|
||||
12
comfy/sd.py
12
comfy/sd.py
@ -1122,7 +1122,17 @@ class VAE:
|
||||
else:
|
||||
pixel_samples = pixel_samples.unsqueeze(2)
|
||||
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) # TODO: calculate mem required for tile
|
||||
if dims == 2:
|
||||
default_tile_x = 512 if tile_x is None else tile_x
|
||||
default_tile_y = 512 if tile_y is None else tile_y
|
||||
tile_shapes = [
|
||||
(1, pixel_samples.shape[1], min(pixel_samples.shape[2], max(1, default_tile_y)), min(pixel_samples.shape[3], max(1, default_tile_x))),
|
||||
(1, pixel_samples.shape[1], min(pixel_samples.shape[2], max(1, default_tile_y // 2)), min(pixel_samples.shape[3], max(1, default_tile_x * 2))),
|
||||
(1, pixel_samples.shape[1], min(pixel_samples.shape[2], max(1, default_tile_y * 2)), min(pixel_samples.shape[3], max(1, default_tile_x // 2))),
|
||||
]
|
||||
memory_used = max(self.memory_used_encode(shape, self.vae_dtype) for shape in tile_shapes)
|
||||
else:
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
|
||||
args = {}
|
||||
|
||||
@ -1271,7 +1271,7 @@ PRICE_BADGE_VIDEO = IO.PriceBadge(
|
||||
)
|
||||
|
||||
|
||||
def _seedance2_text_inputs(resolutions: list[str]):
|
||||
def _seedance2_text_inputs(resolutions: list[str], default_ratio: str = "16:9"):
|
||||
return [
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -1287,6 +1287,7 @@ def _seedance2_text_inputs(resolutions: list[str]):
|
||||
IO.Combo.Input(
|
||||
"ratio",
|
||||
options=["16:9", "4:3", "1:1", "3:4", "9:16", "21:9", "adaptive"],
|
||||
default=default_ratio,
|
||||
tooltip="Aspect ratio of the output video.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
@ -1420,8 +1421,14 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p"])),
|
||||
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs(["480p", "720p"])),
|
||||
IO.DynamicCombo.Option(
|
||||
"Seedance 2.0",
|
||||
_seedance2_text_inputs(["480p", "720p", "1080p"], default_ratio="adaptive"),
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"Seedance 2.0 Fast",
|
||||
_seedance2_text_inputs(["480p", "720p"], default_ratio="adaptive"),
|
||||
),
|
||||
],
|
||||
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||
),
|
||||
@ -1588,9 +1595,9 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||
|
||||
|
||||
def _seedance2_reference_inputs(resolutions: list[str]):
|
||||
def _seedance2_reference_inputs(resolutions: list[str], default_ratio: str = "16:9"):
|
||||
return [
|
||||
*_seedance2_text_inputs(resolutions),
|
||||
*_seedance2_text_inputs(resolutions, default_ratio=default_ratio),
|
||||
IO.Autogrow.Input(
|
||||
"reference_images",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
@ -1668,8 +1675,14 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_reference_inputs(["480p", "720p", "1080p"])),
|
||||
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_reference_inputs(["480p", "720p"])),
|
||||
IO.DynamicCombo.Option(
|
||||
"Seedance 2.0",
|
||||
_seedance2_reference_inputs(["480p", "720p", "1080p"], default_ratio="adaptive"),
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"Seedance 2.0 Fast",
|
||||
_seedance2_reference_inputs(["480p", "720p"], default_ratio="adaptive"),
|
||||
),
|
||||
],
|
||||
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||
),
|
||||
|
||||
@ -83,13 +83,16 @@ class GeminiImageModel(str, Enum):
|
||||
|
||||
async def create_image_parts(
|
||||
cls: type[IO.ComfyNode],
|
||||
images: Input.Image,
|
||||
images: Input.Image | list[Input.Image],
|
||||
image_limit: int = 0,
|
||||
) -> list[GeminiPart]:
|
||||
image_parts: list[GeminiPart] = []
|
||||
if image_limit < 0:
|
||||
raise ValueError("image_limit must be greater than or equal to 0 when creating Gemini image parts.")
|
||||
total_images = get_number_of_images(images)
|
||||
|
||||
# Accept either a single (possibly-batched) tensor or a list of them; share URL budget across all.
|
||||
images_list: list[Input.Image] = images if isinstance(images, list) else [images]
|
||||
total_images = sum(get_number_of_images(img) for img in images_list)
|
||||
if total_images <= 0:
|
||||
raise ValueError("No images provided to create_image_parts; at least one image is required.")
|
||||
|
||||
@ -98,10 +101,18 @@ async def create_image_parts(
|
||||
|
||||
# Number of images we'll send as URLs (fileData)
|
||||
num_url_images = min(effective_max, 10) # Vertex API max number of image links
|
||||
upload_kwargs: dict = {"wait_label": "Uploading reference images"}
|
||||
if effective_max > num_url_images:
|
||||
# Split path (e.g. 11+ images): suppress per-image counter to avoid a confusing dual-fraction label.
|
||||
upload_kwargs = {
|
||||
"wait_label": f"Uploading reference images ({num_url_images}+)",
|
||||
"show_batch_index": False,
|
||||
}
|
||||
reference_images_urls = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
images,
|
||||
images_list,
|
||||
max_images=num_url_images,
|
||||
**upload_kwargs,
|
||||
)
|
||||
for reference_image_url in reference_images_urls:
|
||||
image_parts.append(
|
||||
@ -112,15 +123,22 @@ async def create_image_parts(
|
||||
)
|
||||
)
|
||||
)
|
||||
for idx in range(num_url_images, effective_max):
|
||||
image_parts.append(
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
mimeType=GeminiMimeType.image_png,
|
||||
data=tensor_to_base64_string(images[idx]),
|
||||
if effective_max > num_url_images:
|
||||
flat: list[torch.Tensor] = []
|
||||
for tensor in images_list:
|
||||
if len(tensor.shape) == 4:
|
||||
flat.extend(tensor[i] for i in range(tensor.shape[0]))
|
||||
else:
|
||||
flat.append(tensor)
|
||||
for idx in range(num_url_images, effective_max):
|
||||
image_parts.append(
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
mimeType=GeminiMimeType.image_png,
|
||||
data=tensor_to_base64_string(flat[idx]),
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
return image_parts
|
||||
|
||||
|
||||
@ -891,10 +909,6 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
||||
"9:16",
|
||||
"16:9",
|
||||
"21:9",
|
||||
# "1:4",
|
||||
# "4:1",
|
||||
# "8:1",
|
||||
# "1:8",
|
||||
],
|
||||
default="auto",
|
||||
tooltip="If set to 'auto', matches your input image's aspect ratio; "
|
||||
@ -902,12 +916,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=[
|
||||
# "512px",
|
||||
"1K",
|
||||
"2K",
|
||||
"4K",
|
||||
],
|
||||
options=["1K", "2K", "4K"],
|
||||
tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
@ -956,6 +965,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=GEMINI_IMAGE_2_PRICE_BADGE,
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -1016,6 +1026,197 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
||||
)
|
||||
|
||||
|
||||
def _nano_banana_2_v2_model_inputs():
|
||||
return [
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=[
|
||||
"auto",
|
||||
"1:1",
|
||||
"2:3",
|
||||
"3:2",
|
||||
"3:4",
|
||||
"4:3",
|
||||
"4:5",
|
||||
"5:4",
|
||||
"9:16",
|
||||
"16:9",
|
||||
"21:9",
|
||||
"1:4",
|
||||
"4:1",
|
||||
"8:1",
|
||||
"1:8",
|
||||
],
|
||||
default="auto",
|
||||
tooltip="If set to 'auto', matches your input image's aspect ratio; "
|
||||
"if no image is provided, a 16:9 square is usually generated.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["1K", "2K", "4K"],
|
||||
tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"thinking_level",
|
||||
options=["MINIMAL", "HIGH"],
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"images",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Image.Input("image"),
|
||||
names=[f"image_{i}" for i in range(1, 15)],
|
||||
min=0,
|
||||
),
|
||||
tooltip="Optional reference image(s). Up to 14 images total.",
|
||||
),
|
||||
IO.Custom("GEMINI_INPUT_FILES").Input(
|
||||
"files",
|
||||
optional=True,
|
||||
tooltip="Optional file(s) to use as context for the model. "
|
||||
"Accepts inputs from the Gemini Generate Content Input Files node.",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class GeminiNanoBanana2V2(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GeminiNanoBanana2V2",
|
||||
display_name="Nano Banana 2",
|
||||
category="api node/image/Gemini",
|
||||
description="Generate or edit images synchronously via Google Vertex API.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="Text prompt describing the image to generate or the edits to apply. "
|
||||
"Include any constraints, styles, or details the model should follow.",
|
||||
default="",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"Nano Banana 2 (Gemini 3.1 Flash Image)",
|
||||
_nano_banana_2_v2_model_inputs(),
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=42,
|
||||
min=0,
|
||||
max=0xFFFFFFFFFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
tooltip="When the seed is fixed to a specific value, the model makes a best effort to provide "
|
||||
"the same response for repeated requests. Deterministic output isn't guaranteed. "
|
||||
"Also, changing the model or parameter settings, such as the temperature, "
|
||||
"can cause variations in the response even when you use the same seed value. "
|
||||
"By default, a random seed value is used.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"response_modalities",
|
||||
options=["IMAGE", "IMAGE+TEXT"],
|
||||
advanced=True,
|
||||
),
|
||||
IO.String.Input(
|
||||
"system_prompt",
|
||||
multiline=True,
|
||||
default=GEMINI_IMAGE_SYS_PROMPT,
|
||||
optional=True,
|
||||
tooltip="Foundational instructions that dictate an AI's behavior.",
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
IO.String.Output(),
|
||||
IO.Image.Output(
|
||||
display_name="thought_image",
|
||||
tooltip="First image from the model's thinking process. "
|
||||
"Only available with thinking_level HIGH and IMAGE+TEXT modality.",
|
||||
),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$r := $lookup(widgets, "model.resolution");
|
||||
$prices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154};
|
||||
{"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
seed: int,
|
||||
response_modalities: str,
|
||||
system_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
model_choice = model["model"]
|
||||
if model_choice == "Nano Banana 2 (Gemini 3.1 Flash Image)":
|
||||
model_id = "gemini-3.1-flash-image-preview"
|
||||
else:
|
||||
model_id = model_choice
|
||||
|
||||
images = model.get("images") or {}
|
||||
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
|
||||
if images:
|
||||
image_tensors: list[Input.Image] = [t for t in images.values() if t is not None]
|
||||
if image_tensors:
|
||||
if sum(get_number_of_images(t) for t in image_tensors) > 14:
|
||||
raise ValueError("The current maximum number of supported images is 14.")
|
||||
parts.extend(await create_image_parts(cls, image_tensors))
|
||||
files = model.get("files")
|
||||
if files is not None:
|
||||
parts.extend(files)
|
||||
|
||||
image_config = GeminiImageConfig(imageSize=model["resolution"])
|
||||
if model["aspect_ratio"] != "auto":
|
||||
image_config.aspectRatio = model["aspect_ratio"]
|
||||
|
||||
gemini_system_prompt = None
|
||||
if system_prompt:
|
||||
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/vertexai/gemini/{model_id}", method="POST"),
|
||||
data=GeminiImageGenerateContentRequest(
|
||||
contents=[
|
||||
GeminiContent(role=GeminiRole.user, parts=parts),
|
||||
],
|
||||
generationConfig=GeminiImageGenerationConfig(
|
||||
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
|
||||
imageConfig=image_config,
|
||||
thinkingConfig=GeminiThinkingConfig(thinkingLevel=model["thinking_level"]),
|
||||
),
|
||||
systemInstruction=gemini_system_prompt,
|
||||
),
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
)
|
||||
return IO.NodeOutput(
|
||||
await get_image_from_response(response),
|
||||
get_text_from_response(response),
|
||||
await get_image_from_response(response, thought=True),
|
||||
)
|
||||
|
||||
|
||||
class GeminiExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -1024,6 +1225,7 @@ class GeminiExtension(ComfyExtension):
|
||||
GeminiImage,
|
||||
GeminiImage2,
|
||||
GeminiNanoBanana2,
|
||||
GeminiNanoBanana2V2,
|
||||
GeminiInputFiles,
|
||||
]
|
||||
|
||||
|
||||
@ -54,7 +54,12 @@ class GrokImageNode(IO.ComfyNode):
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"],
|
||||
options=[
|
||||
"grok-imagine-image-quality",
|
||||
"grok-imagine-image-pro",
|
||||
"grok-imagine-image",
|
||||
"grok-imagine-image-beta",
|
||||
],
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -111,10 +116,12 @@ class GrokImageNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]),
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images", "resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$rate := $contains(widgets.model, "pro") ? 0.07 : 0.02;
|
||||
$rate := widgets.model = "grok-imagine-image-quality"
|
||||
? (widgets.resolution = "1k" ? 0.05 : 0.07)
|
||||
: ($contains(widgets.model, "pro") ? 0.07 : 0.02);
|
||||
{"type":"usd","usd": $rate * widgets.number_of_images}
|
||||
)
|
||||
""",
|
||||
@ -167,7 +174,12 @@ class GrokImageEditNode(IO.ComfyNode):
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"],
|
||||
options=[
|
||||
"grok-imagine-image-quality",
|
||||
"grok-imagine-image-pro",
|
||||
"grok-imagine-image",
|
||||
"grok-imagine-image-beta",
|
||||
],
|
||||
),
|
||||
IO.Image.Input("image", display_name="images"),
|
||||
IO.String.Input(
|
||||
@ -228,11 +240,19 @@ class GrokImageEditNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]),
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images", "resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$rate := $contains(widgets.model, "pro") ? 0.07 : 0.02;
|
||||
{"type":"usd","usd": 0.002 + $rate * widgets.number_of_images}
|
||||
$isQualityModel := widgets.model = "grok-imagine-image-quality";
|
||||
$isPro := $contains(widgets.model, "pro");
|
||||
$rate := $isQualityModel
|
||||
? (widgets.resolution = "1k" ? 0.05 : 0.07)
|
||||
: ($isPro ? 0.07 : 0.02);
|
||||
$base := $isQualityModel ? 0.01 : 0.002;
|
||||
$output := $rate * widgets.number_of_images;
|
||||
$isPro
|
||||
? {"type":"usd","usd": $base + $output}
|
||||
: {"type":"range_usd","min_usd": $base + $output, "max_usd": 3 * $base + $output}
|
||||
)
|
||||
""",
|
||||
),
|
||||
|
||||
@ -2787,11 +2787,15 @@ class MotionControl(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["mode"]),
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["mode", "model"]),
|
||||
expr="""
|
||||
(
|
||||
$prices := {"std": 0.07, "pro": 0.112};
|
||||
{"type":"usd","usd": $lookup($prices, widgets.mode), "format":{"suffix":"/second"}}
|
||||
$prices := {
|
||||
"kling-v3": {"std": 0.126, "pro": 0.168},
|
||||
"kling-v2-6": {"std": 0.07, "pro": 0.112}
|
||||
};
|
||||
$modelPrices := $lookup($prices, widgets.model);
|
||||
{"type":"usd","usd": $lookup($modelPrices, widgets.mode), "format":{"suffix":"/second"}}
|
||||
)
|
||||
""",
|
||||
),
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import bisect
|
||||
import gc
|
||||
import itertools
|
||||
import psutil
|
||||
import time
|
||||
@ -17,6 +18,7 @@ NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}
|
||||
|
||||
|
||||
def include_unique_id_in_input(class_type: str) -> bool:
|
||||
"""Return whether a node class includes UNIQUE_ID among its hidden inputs."""
|
||||
if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID:
|
||||
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
@ -24,52 +26,412 @@ def include_unique_id_in_input(class_type: str) -> bool:
|
||||
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
|
||||
|
||||
class CacheKeySet(ABC):
|
||||
"""Base helper for building and storing cache keys for prompt nodes."""
|
||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||
"""Initialize cache-key storage for a dynamic prompt execution pass."""
|
||||
self.keys = {}
|
||||
self.subcache_keys = {}
|
||||
|
||||
@abstractmethod
|
||||
async def add_keys(self, node_ids):
|
||||
"""Populate cache keys for the provided node ids."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def all_node_ids(self):
|
||||
"""Return the set of node ids currently tracked by this key set."""
|
||||
return set(self.keys.keys())
|
||||
|
||||
def get_used_keys(self):
|
||||
"""Return the computed cache keys currently in use."""
|
||||
return self.keys.values()
|
||||
|
||||
def get_used_subcache_keys(self):
|
||||
"""Return the computed subcache keys currently in use."""
|
||||
return self.subcache_keys.values()
|
||||
|
||||
def get_data_key(self, node_id):
|
||||
"""Return the cache key for a node, if present."""
|
||||
return self.keys.get(node_id, None)
|
||||
|
||||
def get_subcache_key(self, node_id):
|
||||
"""Return the subcache key for a node, if present."""
|
||||
return self.subcache_keys.get(node_id, None)
|
||||
|
||||
class Unhashable:
|
||||
def __init__(self):
|
||||
self.value = float("NaN")
|
||||
"""Hashable identity sentinel for values that cannot be represented safely in cache keys."""
|
||||
pass
|
||||
|
||||
def to_hashable(obj):
|
||||
# So that we don't infinitely recurse since frozenset and tuples
|
||||
# are Sequences.
|
||||
if isinstance(obj, (int, float, str, bool, bytes, type(None))):
|
||||
return obj
|
||||
elif isinstance(obj, Mapping):
|
||||
return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
|
||||
elif isinstance(obj, Sequence):
|
||||
return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj]))
|
||||
else:
|
||||
# TODO - Support other objects like tensors?
|
||||
|
||||
_PRIMITIVE_SIGNATURE_TYPES = (int, float, str, bool, bytes, type(None))
|
||||
_CONTAINER_SIGNATURE_TYPES = (dict, list, tuple, set, frozenset)
|
||||
_MAX_SIGNATURE_DEPTH = 32
|
||||
_MAX_SIGNATURE_CONTAINER_VISITS = 10_000
|
||||
_FAILED_SIGNATURE = object()
|
||||
|
||||
|
||||
def _shallow_is_changed_signature(value):
|
||||
"""Reduce execution-time `is_changed` values through a fail-closed builtin canonicalizer."""
|
||||
value_type = type(value)
|
||||
if value_type in _PRIMITIVE_SIGNATURE_TYPES:
|
||||
return value
|
||||
|
||||
if value_type not in _CONTAINER_SIGNATURE_TYPES:
|
||||
return Unhashable()
|
||||
|
||||
canonical = _signature_to_hashable(value, max_nodes=64)
|
||||
if type(canonical) is Unhashable:
|
||||
return canonical
|
||||
if value_type is list or value_type is tuple:
|
||||
container_tag = "is_changed_list" if value_type is list else "is_changed_tuple"
|
||||
return (container_tag, canonical[1])
|
||||
|
||||
return canonical
|
||||
|
||||
|
||||
def _primitive_signature_sort_key(obj):
|
||||
"""Return a deterministic ordering key for primitive signature values."""
|
||||
obj_type = type(obj)
|
||||
return ("primitive", obj_type.__module__, obj_type.__qualname__, repr(obj))
|
||||
|
||||
|
||||
def _sanitized_sort_key(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None):
|
||||
"""Return a deterministic ordering key for sanitized built-in container content."""
|
||||
if depth >= max_depth:
|
||||
return ("MAX_DEPTH",)
|
||||
|
||||
if active is None:
|
||||
active = set()
|
||||
if memo is None:
|
||||
memo = {}
|
||||
|
||||
obj_type = type(obj)
|
||||
if obj_type is Unhashable:
|
||||
return ("UNHASHABLE",)
|
||||
elif obj_type in _PRIMITIVE_SIGNATURE_TYPES:
|
||||
return (obj_type.__module__, obj_type.__qualname__, repr(obj))
|
||||
elif obj_type not in _CONTAINER_SIGNATURE_TYPES:
|
||||
return (obj_type.__module__, obj_type.__qualname__, "OPAQUE")
|
||||
|
||||
obj_id = id(obj)
|
||||
if obj_id in memo:
|
||||
return memo[obj_id]
|
||||
if obj_id in active:
|
||||
return ("CYCLE",)
|
||||
|
||||
active.add(obj_id)
|
||||
try:
|
||||
if obj_type is dict:
|
||||
items = [
|
||||
(
|
||||
_sanitized_sort_key(k, depth + 1, max_depth, active, memo),
|
||||
_sanitized_sort_key(v, depth + 1, max_depth, active, memo),
|
||||
)
|
||||
for k, v in obj.items()
|
||||
]
|
||||
items.sort()
|
||||
result = ("dict", tuple(items))
|
||||
elif obj_type is list:
|
||||
result = ("list", tuple(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj))
|
||||
elif obj_type is tuple:
|
||||
result = ("tuple", tuple(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj))
|
||||
elif obj_type is set:
|
||||
result = ("set", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj)))
|
||||
else:
|
||||
result = ("frozenset", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj)))
|
||||
finally:
|
||||
active.discard(obj_id)
|
||||
|
||||
memo[obj_id] = result
|
||||
return result
|
||||
|
||||
|
||||
def _signature_to_hashable_impl(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None, budget=None):
|
||||
"""Canonicalize signature inputs directly into their final hashable form."""
|
||||
if depth >= max_depth:
|
||||
return _FAILED_SIGNATURE
|
||||
|
||||
if active is None:
|
||||
active = set()
|
||||
if memo is None:
|
||||
memo = {}
|
||||
if budget is None:
|
||||
budget = {"remaining": _MAX_SIGNATURE_CONTAINER_VISITS}
|
||||
|
||||
obj_type = type(obj)
|
||||
if obj_type in _PRIMITIVE_SIGNATURE_TYPES:
|
||||
return obj, _primitive_signature_sort_key(obj)
|
||||
if obj_type is Unhashable or obj_type not in _CONTAINER_SIGNATURE_TYPES:
|
||||
return _FAILED_SIGNATURE
|
||||
|
||||
obj_id = id(obj)
|
||||
if obj_id in memo:
|
||||
return memo[obj_id]
|
||||
if obj_id in active:
|
||||
return _FAILED_SIGNATURE
|
||||
|
||||
budget["remaining"] -= 1
|
||||
if budget["remaining"] < 0:
|
||||
return _FAILED_SIGNATURE
|
||||
|
||||
active.add(obj_id)
|
||||
try:
|
||||
if obj_type is dict:
|
||||
try:
|
||||
items = list(obj.items())
|
||||
except RuntimeError:
|
||||
return _FAILED_SIGNATURE
|
||||
|
||||
ordered_items = []
|
||||
for key, value in items:
|
||||
if type(key) not in _PRIMITIVE_SIGNATURE_TYPES:
|
||||
return _FAILED_SIGNATURE
|
||||
key_result = (key, _primitive_signature_sort_key(key))
|
||||
value_result = _signature_to_hashable_impl(value, depth + 1, max_depth, active, memo, budget)
|
||||
if value_result is _FAILED_SIGNATURE:
|
||||
return _FAILED_SIGNATURE
|
||||
key_value, key_sort = key_result
|
||||
value_value, value_sort = value_result
|
||||
ordered_items.append((key_sort, value_sort, key_value, value_value))
|
||||
|
||||
ordered_items.sort(key=lambda item: (item[0], item[1]))
|
||||
for index in range(1, len(ordered_items)):
|
||||
previous_key_sort = ordered_items[index - 1][0]
|
||||
current_key_sort = ordered_items[index][0]
|
||||
if previous_key_sort == current_key_sort:
|
||||
return _FAILED_SIGNATURE
|
||||
|
||||
value = ("dict", tuple((key_value, value_value) for _, _, key_value, value_value in ordered_items))
|
||||
sort_key = ("dict", tuple((key_sort, value_sort) for key_sort, value_sort, _, _ in ordered_items))
|
||||
elif obj_type is list or obj_type is tuple:
|
||||
try:
|
||||
items = list(obj)
|
||||
except RuntimeError:
|
||||
return _FAILED_SIGNATURE
|
||||
|
||||
child_results = []
|
||||
for item in items:
|
||||
child_result = _signature_to_hashable_impl(item, depth + 1, max_depth, active, memo, budget)
|
||||
if child_result is _FAILED_SIGNATURE:
|
||||
return _FAILED_SIGNATURE
|
||||
child_results.append(child_result)
|
||||
|
||||
container_tag = "list" if obj_type is list else "tuple"
|
||||
value = (container_tag, tuple(child for child, _ in child_results))
|
||||
sort_key = (container_tag, tuple(child_sort for _, child_sort in child_results))
|
||||
else:
|
||||
try:
|
||||
items = list(obj)
|
||||
except RuntimeError:
|
||||
return _FAILED_SIGNATURE
|
||||
|
||||
ordered_items = []
|
||||
for item in items:
|
||||
child_result = _signature_to_hashable_impl(item, depth + 1, max_depth, active, memo, budget)
|
||||
if child_result is _FAILED_SIGNATURE:
|
||||
return _FAILED_SIGNATURE
|
||||
child_value, child_sort = child_result
|
||||
ordered_items.append((child_sort, child_value))
|
||||
|
||||
ordered_items.sort(key=lambda item: item[0])
|
||||
for index in range(1, len(ordered_items)):
|
||||
previous_sort_key, previous_value = ordered_items[index - 1]
|
||||
current_sort_key, current_value = ordered_items[index]
|
||||
if previous_sort_key == current_sort_key and previous_value != current_value:
|
||||
return _FAILED_SIGNATURE
|
||||
|
||||
container_tag = "set" if obj_type is set else "frozenset"
|
||||
value = (container_tag, tuple(child_value for _, child_value in ordered_items))
|
||||
sort_key = (container_tag, tuple(child_sort for child_sort, _ in ordered_items))
|
||||
finally:
|
||||
active.discard(obj_id)
|
||||
|
||||
memo[obj_id] = (value, sort_key)
|
||||
return memo[obj_id]
|
||||
|
||||
|
||||
def _signature_to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS):
|
||||
"""Build the final cache-signature representation in one fail-closed pass."""
|
||||
try:
|
||||
result = _signature_to_hashable_impl(obj, budget={"remaining": max_nodes})
|
||||
except RuntimeError:
|
||||
return Unhashable()
|
||||
if result is _FAILED_SIGNATURE:
|
||||
return Unhashable()
|
||||
return result[0]
|
||||
|
||||
|
||||
def to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS):
|
||||
"""Convert sanitized prompt inputs into a stable hashable representation.
|
||||
|
||||
The input is expected to already be sanitized to plain built-in containers,
|
||||
but this function still fails safe for anything unexpected. Traversal is
|
||||
iterative and memoized so shared built-in substructures do not trigger
|
||||
exponential re-walks during cache-key construction.
|
||||
"""
|
||||
obj_type = type(obj)
|
||||
if obj_type in _PRIMITIVE_SIGNATURE_TYPES or obj_type is Unhashable:
|
||||
return obj
|
||||
if obj_type not in _CONTAINER_SIGNATURE_TYPES:
|
||||
return Unhashable()
|
||||
|
||||
memo = {}
|
||||
active = set()
|
||||
snapshots = {}
|
||||
sort_memo = {}
|
||||
processed = 0
|
||||
# Keep traversal state separate from container snapshots/results.
|
||||
work_stack = [(obj, False)]
|
||||
|
||||
def resolve_value(value):
|
||||
"""Resolve a child value from the completed memo table when available."""
|
||||
value_type = type(value)
|
||||
if value_type in _PRIMITIVE_SIGNATURE_TYPES or value_type is Unhashable:
|
||||
return value
|
||||
return memo.get(id(value), Unhashable())
|
||||
|
||||
def is_failed(value):
|
||||
"""Return whether a resolved child value represents failed canonicalization."""
|
||||
return type(value) is Unhashable
|
||||
|
||||
def resolve_unordered_values(current_items, container_tag):
|
||||
"""Resolve a set-like container or fail closed if ordering is ambiguous."""
|
||||
try:
|
||||
ordered_items = [
|
||||
(_sanitized_sort_key(item, memo=sort_memo), resolve_value(item))
|
||||
for item in current_items
|
||||
]
|
||||
if any(is_failed(value) for _, value in ordered_items):
|
||||
return Unhashable()
|
||||
ordered_items.sort(key=lambda item: item[0])
|
||||
except RuntimeError:
|
||||
return Unhashable()
|
||||
|
||||
for index in range(1, len(ordered_items)):
|
||||
previous_key, previous_value = ordered_items[index - 1]
|
||||
current_key, current_value = ordered_items[index]
|
||||
if previous_key == current_key and previous_value != current_value:
|
||||
return Unhashable()
|
||||
|
||||
return (container_tag, tuple(value for _, value in ordered_items))
|
||||
|
||||
while work_stack:
|
||||
entry = work_stack.pop()
|
||||
if len(entry) == 3:
|
||||
_, current_id, current_type = entry
|
||||
current = None
|
||||
expanded = True
|
||||
else:
|
||||
current, expanded = entry
|
||||
current_type = type(current)
|
||||
current_id = id(current)
|
||||
|
||||
if not expanded and (current_type in _PRIMITIVE_SIGNATURE_TYPES or current_type is Unhashable):
|
||||
continue
|
||||
if not expanded and current_type not in _CONTAINER_SIGNATURE_TYPES:
|
||||
memo[current_id] = Unhashable()
|
||||
continue
|
||||
|
||||
if current_id in memo:
|
||||
continue
|
||||
|
||||
if expanded:
|
||||
active.discard(current_id)
|
||||
try:
|
||||
items = snapshots.pop(current_id, None)
|
||||
if items is None:
|
||||
memo[current_id] = Unhashable()
|
||||
continue
|
||||
|
||||
if current_type is dict:
|
||||
ordered_items = [
|
||||
(_sanitized_sort_key(k, memo=sort_memo), k, resolve_value(v))
|
||||
for k, v in items
|
||||
]
|
||||
if any(type(key) not in _PRIMITIVE_SIGNATURE_TYPES or is_failed(value) for _, key, value in ordered_items):
|
||||
memo[current_id] = Unhashable()
|
||||
continue
|
||||
ordered_items.sort(key=lambda item: item[0])
|
||||
for index in range(1, len(ordered_items)):
|
||||
if ordered_items[index - 1][0] == ordered_items[index][0]:
|
||||
memo[current_id] = Unhashable()
|
||||
break
|
||||
else:
|
||||
memo[current_id] = (
|
||||
"dict",
|
||||
tuple((key, value) for _, key, value in ordered_items),
|
||||
)
|
||||
elif current_type is list:
|
||||
resolved_items = tuple(resolve_value(item) for item in items)
|
||||
if any(is_failed(item) for item in resolved_items):
|
||||
memo[current_id] = Unhashable()
|
||||
else:
|
||||
memo[current_id] = ("list", resolved_items)
|
||||
elif current_type is tuple:
|
||||
resolved_items = tuple(resolve_value(item) for item in items)
|
||||
if any(is_failed(item) for item in resolved_items):
|
||||
memo[current_id] = Unhashable()
|
||||
else:
|
||||
memo[current_id] = ("tuple", resolved_items)
|
||||
elif current_type is set:
|
||||
memo[current_id] = resolve_unordered_values(items, "set")
|
||||
else:
|
||||
memo[current_id] = resolve_unordered_values(items, "frozenset")
|
||||
except RuntimeError:
|
||||
memo[current_id] = Unhashable()
|
||||
continue
|
||||
|
||||
if current_id in active:
|
||||
memo[current_id] = Unhashable()
|
||||
continue
|
||||
|
||||
processed += 1
|
||||
if processed > max_nodes:
|
||||
return Unhashable()
|
||||
|
||||
active.add(current_id)
|
||||
if current_type is dict:
|
||||
try:
|
||||
items = list(current.items())
|
||||
snapshots[current_id] = items
|
||||
except RuntimeError:
|
||||
memo[current_id] = Unhashable()
|
||||
active.discard(current_id)
|
||||
continue
|
||||
for key, value in items:
|
||||
if type(key) not in _PRIMITIVE_SIGNATURE_TYPES:
|
||||
snapshots.pop(current_id, None)
|
||||
memo[current_id] = Unhashable()
|
||||
active.discard(current_id)
|
||||
break
|
||||
else:
|
||||
work_stack.append(("EXPANDED", current_id, current_type))
|
||||
for _, value in reversed(items):
|
||||
work_stack.append((value, False))
|
||||
continue
|
||||
continue
|
||||
else:
|
||||
try:
|
||||
items = list(current)
|
||||
snapshots[current_id] = items
|
||||
except RuntimeError:
|
||||
memo[current_id] = Unhashable()
|
||||
active.discard(current_id)
|
||||
continue
|
||||
work_stack.append(("EXPANDED", current_id, current_type))
|
||||
for item in reversed(items):
|
||||
work_stack.append((item, False))
|
||||
|
||||
return memo.get(id(obj), Unhashable())
|
||||
|
||||
class CacheKeySetID(CacheKeySet):
|
||||
"""Cache-key strategy that keys nodes by node id and class type."""
|
||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||
"""Initialize identity-based cache keys for the supplied dynamic prompt."""
|
||||
super().__init__(dynprompt, node_ids, is_changed_cache)
|
||||
self.dynprompt = dynprompt
|
||||
|
||||
async def add_keys(self, node_ids):
|
||||
"""Populate identity-based keys for nodes that exist in the dynamic prompt."""
|
||||
for node_id in node_ids:
|
||||
if node_id in self.keys:
|
||||
continue
|
||||
@ -80,15 +442,19 @@ class CacheKeySetID(CacheKeySet):
|
||||
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
||||
|
||||
class CacheKeySetInputSignature(CacheKeySet):
|
||||
"""Cache-key strategy that hashes a node's immediate inputs plus ancestor references."""
|
||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||
"""Initialize input-signature-based cache keys for the supplied dynamic prompt."""
|
||||
super().__init__(dynprompt, node_ids, is_changed_cache)
|
||||
self.dynprompt = dynprompt
|
||||
self.is_changed_cache = is_changed_cache
|
||||
|
||||
def include_node_id_in_input(self) -> bool:
|
||||
"""Return whether node ids should be included in computed input signatures."""
|
||||
return False
|
||||
|
||||
async def add_keys(self, node_ids):
|
||||
"""Populate input-signature-based keys for nodes in the dynamic prompt."""
|
||||
for node_id in node_ids:
|
||||
if node_id in self.keys:
|
||||
continue
|
||||
@ -99,21 +465,37 @@ class CacheKeySetInputSignature(CacheKeySet):
|
||||
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
||||
|
||||
async def get_node_signature(self, dynprompt, node_id):
|
||||
"""Build the full cache signature for a node and its ordered ancestors."""
|
||||
signature = []
|
||||
ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
|
||||
signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
|
||||
immediate = await self.get_immediate_node_signature(dynprompt, node_id, order_mapping)
|
||||
if type(immediate) is Unhashable:
|
||||
return immediate
|
||||
signature.append(immediate)
|
||||
for ancestor_id in ancestors:
|
||||
signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
|
||||
return to_hashable(signature)
|
||||
immediate = await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping)
|
||||
if type(immediate) is Unhashable:
|
||||
return immediate
|
||||
signature.append(immediate)
|
||||
return tuple(signature)
|
||||
|
||||
async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
|
||||
"""Build the immediate cache-signature fragment for a node.
|
||||
|
||||
Link inputs are reduced to ancestor references here. Non-link values
|
||||
are canonicalized or failed closed before being appended so the final
|
||||
node signature is assembled from already-hashable fragments.
|
||||
"""
|
||||
if not dynprompt.has_node(node_id):
|
||||
# This node doesn't exist -- we can't cache it.
|
||||
return [float("NaN")]
|
||||
return Unhashable()
|
||||
node = dynprompt.get_node(node_id)
|
||||
class_type = node["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
signature = [class_type, await self.is_changed_cache.get(node_id)]
|
||||
is_changed_signature = _shallow_is_changed_signature(await self.is_changed_cache.get(node_id))
|
||||
if type(is_changed_signature) is Unhashable:
|
||||
return is_changed_signature
|
||||
signature = [class_type, is_changed_signature]
|
||||
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
|
||||
signature.append(node_id)
|
||||
inputs = node["inputs"]
|
||||
@ -123,18 +505,23 @@ class CacheKeySetInputSignature(CacheKeySet):
|
||||
ancestor_index = ancestor_order_mapping[ancestor_id]
|
||||
signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))
|
||||
else:
|
||||
signature.append((key, inputs[key]))
|
||||
return signature
|
||||
value_signature = to_hashable(inputs[key])
|
||||
if type(value_signature) is Unhashable:
|
||||
return value_signature
|
||||
signature.append((key, value_signature))
|
||||
return tuple(signature)
|
||||
|
||||
# This function returns a list of all ancestors of the given node. The order of the list is
|
||||
# deterministic based on which specific inputs the ancestor is connected by.
|
||||
def get_ordered_ancestry(self, dynprompt, node_id):
|
||||
"""Return ancestors in deterministic traversal order and their index mapping."""
|
||||
ancestors = []
|
||||
order_mapping = {}
|
||||
self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping)
|
||||
return ancestors, order_mapping
|
||||
|
||||
def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
|
||||
"""Recursively collect ancestors in input order without revisiting prior nodes."""
|
||||
if not dynprompt.has_node(node_id):
|
||||
return
|
||||
inputs = dynprompt.get_node(node_id)["inputs"]
|
||||
|
||||
@ -1,11 +1,17 @@
|
||||
def is_link(obj):
|
||||
if not isinstance(obj, list):
|
||||
"""Return whether obj is a plain prompt link of the form [node_id, output_index]."""
|
||||
# Prompt links produced by the frontend / GraphBuilder are plain Python
|
||||
# lists in the form [node_id, output_index]. Some custom-node paths can
|
||||
# inject foreign runtime objects into prompt inputs during on-prompt graph
|
||||
# rewriting or subgraph construction. Be strict here so cache signature
|
||||
# building never tries to treat list-like proxy objects as links.
|
||||
if type(obj) is not list:
|
||||
return False
|
||||
if len(obj) != 2:
|
||||
return False
|
||||
if not isinstance(obj[0], str):
|
||||
if type(obj[0]) is not str:
|
||||
return False
|
||||
if not isinstance(obj[1], int) and not isinstance(obj[1], float):
|
||||
if type(obj[1]) is not int:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@ -92,7 +92,7 @@ class SamplerEulerCFGpp(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="SamplerEulerCFGpp",
|
||||
display_name="SamplerEulerCFG++",
|
||||
category="_for_testing", # "sampling/custom_sampling/samplers"
|
||||
category="experimental", # "sampling/custom_sampling/samplers"
|
||||
inputs=[
|
||||
io.Combo.Input("version", options=["regular", "alternative"], advanced=True),
|
||||
],
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.).
|
||||
- EmptyARVideoLatent: create 5D [B, C, T, H, W] video latent tensors
|
||||
- SamplerARVideo: SAMPLER for the block-by-block autoregressive denoising loop
|
||||
- ARVideoI2V: image-to-video conditioning for AR models (seeds KV cache with start image)
|
||||
"""
|
||||
|
||||
import torch
|
||||
@ -9,6 +10,7 @@ from typing_extensions import override
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.samplers
|
||||
import comfy.utils
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
@ -71,12 +73,62 @@ class SamplerARVideo(io.ComfyNode):
|
||||
return io.NodeOutput(comfy.samplers.ksampler("ar_video", extra_options))
|
||||
|
||||
|
||||
class ARVideoI2V(io.ComfyNode):
|
||||
"""Image-to-video setup for AR video models (Causal Forcing, Self-Forcing).
|
||||
|
||||
VAE-encodes the start image and stores it in the model's transformer_options
|
||||
so that sample_ar_video can seed the KV cache before denoising.
|
||||
Uses the same T2V model checkpoint -- no separate I2V architecture needed.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ARVideoI2V",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Image.Input("start_image"),
|
||||
io.Int.Input("width", default=832, min=16, max=8192, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=8192, step=16),
|
||||
io.Int.Input("length", default=81, min=1, max=1024, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=64),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(display_name="MODEL"),
|
||||
io.Latent.Output(display_name="LATENT"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, vae, start_image, width, height, length, batch_size) -> io.NodeOutput:
|
||||
start_image = comfy.utils.common_upscale(
|
||||
start_image[:1].movedim(-1, 1), width, height, "bilinear", "center"
|
||||
).movedim(1, -1)
|
||||
|
||||
initial_latent = vae.encode(start_image[:, :, :, :3])
|
||||
|
||||
m = model.clone()
|
||||
to = m.model_options.setdefault("transformer_options", {})
|
||||
ar_cfg = to.setdefault("ar_config", {})
|
||||
ar_cfg["initial_latent"] = initial_latent
|
||||
|
||||
lat_t = ((length - 1) // 4) + 1
|
||||
latent = torch.zeros(
|
||||
[batch_size, 16, lat_t, height // 8, width // 8],
|
||||
device=comfy.model_management.intermediate_device(),
|
||||
)
|
||||
return io.NodeOutput(m, {"samples": latent})
|
||||
|
||||
|
||||
class ARVideoExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
EmptyARVideoLatent,
|
||||
SamplerARVideo,
|
||||
ARVideoI2V,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -25,7 +25,7 @@ class UNetSelfAttentionMultiply(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="UNetSelfAttentionMultiply",
|
||||
category="_for_testing/attention_experiments",
|
||||
category="experimental/attention_experiments",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
||||
@ -48,7 +48,7 @@ class UNetCrossAttentionMultiply(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="UNetCrossAttentionMultiply",
|
||||
category="_for_testing/attention_experiments",
|
||||
category="experimental/attention_experiments",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
||||
@ -72,7 +72,7 @@ class CLIPAttentionMultiply(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="CLIPAttentionMultiply",
|
||||
search_aliases=["clip attention scale", "text encoder attention"],
|
||||
category="_for_testing/attention_experiments",
|
||||
category="experimental/attention_experiments",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
||||
@ -106,7 +106,7 @@ class UNetTemporalAttentionMultiply(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="UNetTemporalAttentionMultiply",
|
||||
category="_for_testing/attention_experiments",
|
||||
category="experimental/attention_experiments",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input("self_structural", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
||||
|
||||
@ -10,6 +10,7 @@ class AudioEncoderLoader(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="AudioEncoderLoader",
|
||||
display_name="Load Audio Encoder",
|
||||
category="loaders",
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
|
||||
@ -153,7 +153,7 @@ class WanCameraEmbedding(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanCameraEmbedding",
|
||||
category="camera",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
"camera_pose",
|
||||
|
||||
@ -8,7 +8,7 @@ class CLIPTextEncodeControlnet(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="CLIPTextEncodeControlnet",
|
||||
category="_for_testing/conditioning",
|
||||
category="experimental/conditioning",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.Conditioning.Input("conditioning"),
|
||||
@ -35,7 +35,7 @@ class T5TokenizerOptions(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="T5TokenizerOptions",
|
||||
category="_for_testing/conditioning",
|
||||
category="experimental/conditioning",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.Int.Input("min_padding", default=0, min=0, max=10000, step=1, advanced=True),
|
||||
|
||||
@ -10,7 +10,7 @@ class ContextWindowsManualNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ContextWindowsManual",
|
||||
display_name="Context Windows (Manual)",
|
||||
category="context",
|
||||
category="model_patches",
|
||||
description="Manually set context windows.",
|
||||
inputs=[
|
||||
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
|
||||
|
||||
@ -984,7 +984,7 @@ class AddNoise(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="AddNoise",
|
||||
category="_for_testing/custom_sampling/noise",
|
||||
category="experimental/custom_sampling/noise",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
@ -1034,7 +1034,7 @@ class ManualSigmas(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ManualSigmas",
|
||||
search_aliases=["custom noise schedule", "define sigmas"],
|
||||
category="_for_testing/custom_sampling",
|
||||
category="experimental/custom_sampling",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.String.Input("sigmas", default="1, 0.5", multiline=False)
|
||||
|
||||
@ -13,7 +13,7 @@ class DifferentialDiffusion(io.ComfyNode):
|
||||
node_id="DifferentialDiffusion",
|
||||
search_aliases=["inpaint gradient", "variable denoise strength"],
|
||||
display_name="Differential Diffusion",
|
||||
category="_for_testing",
|
||||
category="experimental",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input(
|
||||
|
||||
@ -60,7 +60,7 @@ class FreSca(io.ComfyNode):
|
||||
node_id="FreSca",
|
||||
search_aliases=["frequency guidance"],
|
||||
display_name="FreSca",
|
||||
category="_for_testing",
|
||||
category="experimental",
|
||||
description="Applies frequency-dependent scaling to the guidance",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
|
||||
@ -131,6 +131,8 @@ class HunyuanVideo15SuperResolution(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="HunyuanVideo15SuperResolution",
|
||||
display_name="Hunyuan Video 1.5 Super Resolution",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
@ -381,6 +383,8 @@ class HunyuanRefinerLatent(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="HunyuanRefinerLatent",
|
||||
display_name="Hunyuan Latent Refiner",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
|
||||
@ -40,7 +40,7 @@ class Hunyuan3Dv2Conditioning(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Hunyuan3Dv2Conditioning",
|
||||
category="conditioning/video_models",
|
||||
category="conditioning/3d_models",
|
||||
inputs=[
|
||||
IO.ClipVisionOutput.Input("clip_vision_output"),
|
||||
],
|
||||
@ -65,7 +65,7 @@ class Hunyuan3Dv2ConditioningMultiView(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Hunyuan3Dv2ConditioningMultiView",
|
||||
category="conditioning/video_models",
|
||||
category="conditioning/3d_models",
|
||||
inputs=[
|
||||
IO.ClipVisionOutput.Input("front", optional=True),
|
||||
IO.ClipVisionOutput.Input("left", optional=True),
|
||||
@ -424,6 +424,7 @@ class VoxelToMeshBasic(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VoxelToMeshBasic",
|
||||
display_name="Voxel to Mesh (Basic)",
|
||||
category="3d",
|
||||
inputs=[
|
||||
IO.Voxel.Input("voxel"),
|
||||
@ -453,6 +454,7 @@ class VoxelToMesh(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VoxelToMesh",
|
||||
display_name="Voxel to Mesh",
|
||||
category="3d",
|
||||
inputs=[
|
||||
IO.Voxel.Input("voxel"),
|
||||
|
||||
@ -102,6 +102,7 @@ class HypernetworkLoader(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="HypernetworkLoader",
|
||||
display_name="Load Hypernetwork",
|
||||
category="loaders",
|
||||
inputs=[
|
||||
IO.Model.Input("model"),
|
||||
|
||||
@ -91,7 +91,7 @@ class LoraSave(io.ComfyNode):
|
||||
node_id="LoraSave",
|
||||
search_aliases=["export lora"],
|
||||
display_name="Extract and Save Lora",
|
||||
category="_for_testing",
|
||||
category="experimental",
|
||||
inputs=[
|
||||
io.String.Input("filename_prefix", default="loras/ComfyUI_extracted_lora"),
|
||||
io.Int.Input("rank", default=8, min=1, max=4096, step=1, advanced=True),
|
||||
|
||||
@ -594,7 +594,8 @@ class LTXVPreprocess(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LTXVPreprocess",
|
||||
category="image",
|
||||
display_name="LTXV Preprocess",
|
||||
category="video/preprocessors",
|
||||
inputs=[
|
||||
io.Image.Input("image"),
|
||||
io.Int.Input(
|
||||
|
||||
@ -11,7 +11,7 @@ class Mahiro(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="Mahiro",
|
||||
display_name="Positive-Biased Guidance",
|
||||
category="_for_testing",
|
||||
category="experimental",
|
||||
description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
|
||||
@ -70,7 +70,7 @@ class MathExpressionNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ComfyMathExpression",
|
||||
display_name="Math Expression",
|
||||
category="math",
|
||||
category="logic",
|
||||
search_aliases=[
|
||||
"expression", "formula", "calculate", "calculator",
|
||||
"eval", "math",
|
||||
|
||||
@ -21,7 +21,7 @@ class NumberConvertNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ComfyNumberConvert",
|
||||
display_name="Number Convert",
|
||||
category="math",
|
||||
category="utils",
|
||||
search_aliases=[
|
||||
"int to float", "float to int", "number convert",
|
||||
"int2float", "float2int", "cast", "parse number",
|
||||
|
||||
@ -24,8 +24,8 @@ class PerpNeg(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="PerpNeg",
|
||||
display_name="Perp-Neg (DEPRECATED by PerpNegGuider)",
|
||||
category="_for_testing",
|
||||
display_name="Perp-Neg (DEPRECATED by Perp-Neg Guider)",
|
||||
category="experimental",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Conditioning.Input("empty_conditioning"),
|
||||
@ -127,7 +127,8 @@ class PerpNegGuider(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="PerpNegGuider",
|
||||
category="_for_testing",
|
||||
display_name="Perp-Neg Guider",
|
||||
category="experimental",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Conditioning.Input("positive"),
|
||||
|
||||
@ -123,7 +123,7 @@ class PhotoMakerLoader(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="PhotoMakerLoader",
|
||||
category="_for_testing/photomaker",
|
||||
category="experimental/photomaker",
|
||||
inputs=[
|
||||
io.Combo.Input("photomaker_model_name", options=folder_paths.get_filename_list("photomaker")),
|
||||
],
|
||||
@ -149,7 +149,7 @@ class PhotoMakerEncode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="PhotoMakerEncode",
|
||||
category="_for_testing/photomaker",
|
||||
category="experimental/photomaker",
|
||||
inputs=[
|
||||
io.Photomaker.Input("photomaker"),
|
||||
io.Image.Input("image"),
|
||||
|
||||
@ -116,6 +116,7 @@ class Quantize(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ImageQuantize",
|
||||
display_name="Quantize Image",
|
||||
category="image/postprocessing",
|
||||
inputs=[
|
||||
io.Image.Input("image"),
|
||||
@ -181,6 +182,7 @@ class Sharpen(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ImageSharpen",
|
||||
display_name="Sharpen Image",
|
||||
category="image/postprocessing",
|
||||
inputs=[
|
||||
io.Image.Input("image"),
|
||||
@ -436,7 +438,7 @@ class ResizeImageMaskNode(io.ComfyNode):
|
||||
node_id="ResizeImageMaskNode",
|
||||
display_name="Resize Image/Mask",
|
||||
description="Resize an image or mask using various scaling methods.",
|
||||
category="transform",
|
||||
category="image/transform",
|
||||
search_aliases=["resize", "resize image", "resize mask", "scale", "scale image", "scale mask", "image resize", "change size", "dimensions", "shrink", "enlarge"],
|
||||
inputs=[
|
||||
io.MatchType.Input("input", template=template),
|
||||
|
||||
@ -15,7 +15,7 @@ class RTDETR_detect(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="RTDETR_detect",
|
||||
display_name="RT-DETR Detect",
|
||||
category="detection/",
|
||||
category="detection",
|
||||
search_aliases=["bbox", "bounding box", "object detection", "coco"],
|
||||
inputs=[
|
||||
io.Model.Input("model", display_name="model"),
|
||||
@ -71,7 +71,7 @@ class DrawBBoxes(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="DrawBBoxes",
|
||||
display_name="Draw BBoxes",
|
||||
category="detection/",
|
||||
category="detection",
|
||||
search_aliases=["bbox", "bounding box", "object detection", "rt_detr", "visualize detections", "coco"],
|
||||
inputs=[
|
||||
io.Image.Input("image", optional=True),
|
||||
|
||||
@ -113,7 +113,7 @@ class SelfAttentionGuidance(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="SelfAttentionGuidance",
|
||||
display_name="Self-Attention Guidance",
|
||||
category="_for_testing",
|
||||
category="experimental",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input("scale", default=0.5, min=-2.0, max=5.0, step=0.01),
|
||||
|
||||
@ -93,7 +93,7 @@ class SAM3_Detect(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="SAM3_Detect",
|
||||
display_name="SAM3 Detect",
|
||||
category="detection/",
|
||||
category="detection",
|
||||
search_aliases=["sam3", "segment anything", "open vocabulary", "text detection", "segment"],
|
||||
inputs=[
|
||||
io.Model.Input("model", display_name="model"),
|
||||
@ -265,15 +265,15 @@ class SAM3_VideoTrack(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="SAM3_VideoTrack",
|
||||
display_name="SAM3 Video Track",
|
||||
category="detection/",
|
||||
category="detection",
|
||||
search_aliases=["sam3", "video", "track", "propagate"],
|
||||
inputs=[
|
||||
io.Image.Input("images", display_name="images", tooltip="Video frames as batched images"),
|
||||
io.Model.Input("model", display_name="model"),
|
||||
io.Mask.Input("initial_mask", display_name="initial_mask", optional=True, tooltip="Mask(s) for the first frame to track (one per object)"),
|
||||
io.Conditioning.Input("conditioning", display_name="conditioning", optional=True, tooltip="Text conditioning for detecting new objects during tracking"),
|
||||
io.Float.Input("detection_threshold", display_name="detection_threshold", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Score threshold for text-prompted detection"),
|
||||
io.Int.Input("max_objects", display_name="max_objects", default=0, min=0, tooltip="Max tracked objects (0=unlimited). Initial masks count toward this limit."),
|
||||
io.Float.Input("detection_threshold", display_name="detection_threshold", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Score threshold for text-prompted detection."),
|
||||
io.Int.Input("max_objects", display_name="max_objects", default=4, min=0, max=64, tooltip="Max tracked objects. Initial masks count toward this limit. 0 uses the internal cap of 64."),
|
||||
io.Int.Input("detect_interval", display_name="detect_interval", default=1, min=1, tooltip="Run detection every N frames (1=every frame). Higher values save compute."),
|
||||
],
|
||||
outputs=[
|
||||
@ -290,8 +290,7 @@ class SAM3_VideoTrack(io.ComfyNode):
|
||||
dtype = model.model.get_dtype()
|
||||
sam3_model = model.model.diffusion_model
|
||||
|
||||
frames = images[..., :3].movedim(-1, 1)
|
||||
frames_in = comfy.utils.common_upscale(frames, 1008, 1008, "bilinear", crop="disabled").to(device=device, dtype=dtype)
|
||||
frames_in = images[..., :3].movedim(-1, 1)
|
||||
|
||||
init_masks = None
|
||||
if initial_mask is not None:
|
||||
@ -308,7 +307,7 @@ class SAM3_VideoTrack(io.ComfyNode):
|
||||
result = sam3_model.forward_video(
|
||||
images=frames_in, initial_masks=init_masks, pbar=pbar, text_prompts=text_prompts,
|
||||
new_det_thresh=detection_threshold, max_objects=max_objects,
|
||||
detect_interval=detect_interval)
|
||||
detect_interval=detect_interval, target_device=device, target_dtype=dtype)
|
||||
result["orig_size"] = (H, W)
|
||||
return io.NodeOutput(result)
|
||||
|
||||
@ -321,7 +320,7 @@ class SAM3_TrackPreview(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="SAM3_TrackPreview",
|
||||
display_name="SAM3 Track Preview",
|
||||
category="detection/",
|
||||
category="detection",
|
||||
inputs=[
|
||||
SAM3TrackData.Input("track_data", display_name="track_data"),
|
||||
io.Image.Input("images", display_name="images", optional=True),
|
||||
@ -449,14 +448,18 @@ class SAM3_TrackPreview(io.ComfyNode):
|
||||
cx = (bool_masks * grid_x).sum(dim=(-1, -2)) // area
|
||||
has = area > 1
|
||||
scores = track_data.get("scores", [])
|
||||
label_scale = max(3, H // 240) # Scale font with resolutio
|
||||
size_caps = (area.float().sqrt() / 15).clamp_(min=1).long().tolist() #cap per-object so the number doesn't dwarf small masks
|
||||
for obj_idx in range(N_obj):
|
||||
if has[obj_idx]:
|
||||
_cx, _cy = int(cx[obj_idx]), int(cy[obj_idx])
|
||||
color = cls.COLORS[obj_idx % len(cls.COLORS)]
|
||||
SAM3_TrackPreview._draw_number_gpu(frame_gpu, obj_idx, _cx, _cy, color)
|
||||
obj_scale = min(label_scale, size_caps[obj_idx])
|
||||
score_scale = max(1, obj_scale * 2 // 3)
|
||||
SAM3_TrackPreview._draw_number_gpu(frame_gpu, obj_idx, _cx, _cy, color, scale=obj_scale)
|
||||
if obj_idx < len(scores) and scores[obj_idx] < 1.0:
|
||||
SAM3_TrackPreview._draw_number_gpu(frame_gpu, int(scores[obj_idx] * 100),
|
||||
_cx, _cy + 5 * 3 + 3, color, scale=2)
|
||||
_cx, _cy + 5 * obj_scale + 3, color, scale=score_scale)
|
||||
frame_cpu.copy_(frame_gpu.clamp_(0, 1).mul_(255).byte())
|
||||
else:
|
||||
frame_cpu.copy_(frame.clamp_(0, 1).mul_(255).byte())
|
||||
@ -475,7 +478,7 @@ class SAM3_TrackToMask(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="SAM3_TrackToMask",
|
||||
display_name="SAM3 Track to Mask",
|
||||
category="detection/",
|
||||
category="detection",
|
||||
inputs=[
|
||||
SAM3TrackData.Input("track_data", display_name="track_data"),
|
||||
io.String.Input("object_indices", display_name="object_indices", default="",
|
||||
@ -507,9 +510,10 @@ class SAM3_TrackToMask(io.ComfyNode):
|
||||
if not indices:
|
||||
return io.NodeOutput(torch.zeros(N, H, W, device=comfy.model_management.intermediate_device()))
|
||||
|
||||
selected = packed[:, indices]
|
||||
binary = unpack_masks(selected) # [N, len(indices), Hm, Wm] bool
|
||||
union = binary.any(dim=1, keepdim=True).float()
|
||||
union_packed = packed[:, indices[0]].clone()
|
||||
for i in indices[1:]:
|
||||
union_packed |= packed[:, i]
|
||||
union = unpack_masks(union_packed).unsqueeze(1).float() # [N, 1, Hm, Wm]
|
||||
mask_out = F.interpolate(union, size=(H, W), mode="bilinear", align_corners=False)[:, 0]
|
||||
return io.NodeOutput(mask_out)
|
||||
|
||||
|
||||
@ -119,7 +119,7 @@ class StableCascade_SuperResolutionControlnet(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StableCascade_SuperResolutionControlnet",
|
||||
category="_for_testing/stable_cascade",
|
||||
category="experimental/stable_cascade",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Image.Input("image"),
|
||||
|
||||
@ -26,7 +26,8 @@ class TextGenerate(io.ComfyNode):
|
||||
|
||||
return io.Schema(
|
||||
node_id="TextGenerate",
|
||||
category="textgen",
|
||||
display_name="Generate Text",
|
||||
category="text",
|
||||
search_aliases=["LLM", "gemma"],
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
@ -157,6 +158,7 @@ class TextGenerateLTX2Prompt(TextGenerate):
|
||||
parent_schema = super().define_schema()
|
||||
return io.Schema(
|
||||
node_id="TextGenerateLTX2Prompt",
|
||||
display_name="Generate LTX2 Prompt",
|
||||
category=parent_schema.category,
|
||||
inputs=parent_schema.inputs,
|
||||
outputs=parent_schema.outputs,
|
||||
|
||||
@ -10,7 +10,7 @@ class TorchCompileModel(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="TorchCompileModel",
|
||||
category="_for_testing",
|
||||
category="experimental",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Combo.Input(
|
||||
|
||||
@ -1361,7 +1361,7 @@ class SaveLoRA(io.ComfyNode):
|
||||
node_id="SaveLoRA",
|
||||
search_aliases=["export lora"],
|
||||
display_name="Save LoRA Weights",
|
||||
category="loaders",
|
||||
category="advanced/model_merging",
|
||||
is_experimental=True,
|
||||
is_output_node=True,
|
||||
inputs=[
|
||||
|
||||
@ -15,7 +15,7 @@ class ImageOnlyCheckpointLoader:
|
||||
RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
|
||||
FUNCTION = "load_checkpoint"
|
||||
|
||||
CATEGORY = "loaders/video_models"
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
||||
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||
|
||||
@ -22,7 +22,7 @@ class SaveImageWebsocket:
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "api/image"
|
||||
CATEGORY = "image"
|
||||
|
||||
def save_images(self, images):
|
||||
pbar = comfy.utils.ProgressBar(images.shape[0])
|
||||
@ -42,3 +42,7 @@ class SaveImageWebsocket:
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"SaveImageWebsocket": SaveImageWebsocket,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"SaveImageWebsocket": "Save Image (Websocket)",
|
||||
}
|
||||
14
nodes.py
14
nodes.py
@ -330,7 +330,7 @@ class VAEDecodeTiled:
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "decode"
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
CATEGORY = "experimental"
|
||||
|
||||
def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8):
|
||||
if tile_size < overlap * 4:
|
||||
@ -377,7 +377,7 @@ class VAEEncodeTiled:
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
CATEGORY = "experimental"
|
||||
|
||||
def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
|
||||
t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
|
||||
@ -493,7 +493,7 @@ class SaveLatent:
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
CATEGORY = "experimental"
|
||||
|
||||
def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||
@ -538,7 +538,7 @@ class LoadLatent:
|
||||
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
|
||||
return {"required": {"latent": [sorted(files), ]}, }
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
CATEGORY = "experimental"
|
||||
|
||||
RETURN_TYPES = ("LATENT", )
|
||||
FUNCTION = "load"
|
||||
@ -1443,7 +1443,7 @@ class LatentBlend:
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "blend"
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
CATEGORY = "experimental"
|
||||
|
||||
def blend(self, samples1, samples2, blend_factor:float, blend_mode: str="normal"):
|
||||
|
||||
@ -2092,6 +2092,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"StyleModelLoader": "Load Style Model",
|
||||
"CLIPVisionLoader": "Load CLIP Vision",
|
||||
"UNETLoader": "Load Diffusion Model",
|
||||
"unCLIPCheckpointLoader": "Load unCLIP Checkpoint",
|
||||
"GLIGENLoader": "Load GLIGEN Model",
|
||||
# Conditioning
|
||||
"CLIPVisionEncode": "CLIP Vision Encode",
|
||||
"StyleModelApply": "Apply Style Model",
|
||||
@ -2140,7 +2142,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"ImageSharpen": "Sharpen Image",
|
||||
"ImageScaleToTotalPixels": "Scale Image to Total Pixels",
|
||||
"GetImageSize": "Get Image Size",
|
||||
# _for_testing
|
||||
# experimental
|
||||
"VAEDecodeTiled": "VAE Decode (Tiled)",
|
||||
"VAEEncodeTiled": "VAE Encode (Tiled)",
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.43.17
|
||||
comfyui-workflow-templates==0.9.69
|
||||
comfyui-workflow-templates==0.9.72
|
||||
comfyui-embedded-docs==0.4.4
|
||||
torch
|
||||
torchsde
|
||||
|
||||
@ -560,7 +560,7 @@ class PromptServer():
|
||||
buffer.seek(0)
|
||||
|
||||
return web.Response(body=buffer.read(), content_type=f'image/{image_format}',
|
||||
headers={"Content-Disposition": f"attachment; filename=\"{filename}\""})
|
||||
headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||
|
||||
if 'channel' not in request.rel_url.query:
|
||||
channel = 'rgba'
|
||||
@ -580,7 +580,7 @@ class PromptServer():
|
||||
buffer.seek(0)
|
||||
|
||||
return web.Response(body=buffer.read(), content_type='image/png',
|
||||
headers={"Content-Disposition": f"attachment; filename=\"{filename}\""})
|
||||
headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||
|
||||
elif channel == 'a':
|
||||
with Image.open(file) as img:
|
||||
@ -597,7 +597,7 @@ class PromptServer():
|
||||
alpha_buffer.seek(0)
|
||||
|
||||
return web.Response(body=alpha_buffer.read(), content_type='image/png',
|
||||
headers={"Content-Disposition": f"attachment; filename=\"{filename}\""})
|
||||
headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||
else:
|
||||
# Use the content type from asset resolution if available,
|
||||
# otherwise guess from the filename.
|
||||
@ -614,7 +614,7 @@ class PromptServer():
|
||||
return web.FileResponse(
|
||||
file,
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename=\"{filename}\"",
|
||||
"Content-Disposition": f"filename=\"{filename}\"",
|
||||
"Content-Type": content_type
|
||||
}
|
||||
)
|
||||
|
||||
90
tests-unit/app_test/node_replace_manager_test.py
Normal file
90
tests-unit/app_test/node_replace_manager_test.py
Normal file
@ -0,0 +1,90 @@
|
||||
"""Tests for NodeReplaceManager registration behavior."""
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def NodeReplaceManager(monkeypatch):
|
||||
"""Provide NodeReplaceManager with `nodes` stubbed.
|
||||
|
||||
`app.node_replace_manager` does `import nodes` at module level, which pulls in
|
||||
torch + the full ComfyUI graph. register() doesn't actually need it, so we
|
||||
stub `nodes` per-test (via monkeypatch so it's torn down) and reload the
|
||||
module so it picks up the stub instead of any cached real import.
|
||||
"""
|
||||
fake_nodes = types.ModuleType("nodes")
|
||||
fake_nodes.NODE_CLASS_MAPPINGS = {}
|
||||
monkeypatch.setitem(sys.modules, "nodes", fake_nodes)
|
||||
monkeypatch.delitem(sys.modules, "app.node_replace_manager", raising=False)
|
||||
module = importlib.import_module("app.node_replace_manager")
|
||||
yield module.NodeReplaceManager
|
||||
# Drop the freshly-imported module so the next test (or a later real import
|
||||
# of `nodes`) starts from a clean slate.
|
||||
sys.modules.pop("app.node_replace_manager", None)
|
||||
|
||||
|
||||
class FakeNodeReplace:
|
||||
"""Lightweight stand-in for comfy_api.latest._io.NodeReplace."""
|
||||
def __init__(self, new_node_id, old_node_id, old_widget_ids=None,
|
||||
input_mapping=None, output_mapping=None):
|
||||
self.new_node_id = new_node_id
|
||||
self.old_node_id = old_node_id
|
||||
self.old_widget_ids = old_widget_ids
|
||||
self.input_mapping = input_mapping
|
||||
self.output_mapping = output_mapping
|
||||
|
||||
|
||||
def test_register_adds_replacement(NodeReplaceManager):
|
||||
manager = NodeReplaceManager()
|
||||
manager.register(FakeNodeReplace(new_node_id="NewNode", old_node_id="OldNode"))
|
||||
assert manager.has_replacement("OldNode")
|
||||
assert len(manager.get_replacement("OldNode")) == 1
|
||||
|
||||
|
||||
def test_register_allows_multiple_alternatives_for_same_old_node(NodeReplaceManager):
|
||||
"""Different new_node_ids for the same old_node_id should all be kept."""
|
||||
manager = NodeReplaceManager()
|
||||
manager.register(FakeNodeReplace(new_node_id="AltA", old_node_id="OldNode"))
|
||||
manager.register(FakeNodeReplace(new_node_id="AltB", old_node_id="OldNode"))
|
||||
replacements = manager.get_replacement("OldNode")
|
||||
assert len(replacements) == 2
|
||||
assert {r.new_node_id for r in replacements} == {"AltA", "AltB"}
|
||||
|
||||
|
||||
def test_register_is_idempotent_for_duplicate_pair(NodeReplaceManager):
|
||||
"""Re-registering the same (old_node_id, new_node_id) should be a no-op."""
|
||||
manager = NodeReplaceManager()
|
||||
manager.register(FakeNodeReplace(new_node_id="NewNode", old_node_id="OldNode"))
|
||||
manager.register(FakeNodeReplace(new_node_id="NewNode", old_node_id="OldNode"))
|
||||
manager.register(FakeNodeReplace(new_node_id="NewNode", old_node_id="OldNode"))
|
||||
assert len(manager.get_replacement("OldNode")) == 1
|
||||
|
||||
|
||||
def test_register_idempotent_preserves_first_registration(NodeReplaceManager):
|
||||
"""First registration wins; later duplicates with different mappings are ignored."""
|
||||
manager = NodeReplaceManager()
|
||||
first = FakeNodeReplace(
|
||||
new_node_id="NewNode", old_node_id="OldNode",
|
||||
input_mapping=[{"new_id": "a", "old_id": "x"}],
|
||||
)
|
||||
second = FakeNodeReplace(
|
||||
new_node_id="NewNode", old_node_id="OldNode",
|
||||
input_mapping=[{"new_id": "b", "old_id": "y"}],
|
||||
)
|
||||
manager.register(first)
|
||||
manager.register(second)
|
||||
replacements = manager.get_replacement("OldNode")
|
||||
assert len(replacements) == 1
|
||||
assert replacements[0] is first
|
||||
|
||||
|
||||
def test_register_dedupe_does_not_affect_other_old_nodes(NodeReplaceManager):
|
||||
manager = NodeReplaceManager()
|
||||
manager.register(FakeNodeReplace(new_node_id="NewA", old_node_id="OldA"))
|
||||
manager.register(FakeNodeReplace(new_node_id="NewA", old_node_id="OldA"))
|
||||
manager.register(FakeNodeReplace(new_node_id="NewB", old_node_id="OldB"))
|
||||
assert len(manager.get_replacement("OldA")) == 1
|
||||
assert len(manager.get_replacement("OldB")) == 1
|
||||
473
tests-unit/execution_test/caching_test.py
Normal file
473
tests-unit/execution_test/caching_test.py
Normal file
@ -0,0 +1,473 @@
|
||||
"""Unit tests for cache-signature canonicalization hardening."""
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class _DummyNode:
|
||||
"""Minimal node stub used to satisfy cache-signature class lookups."""
|
||||
|
||||
@staticmethod
|
||||
def INPUT_TYPES():
|
||||
"""Return a minimal empty input schema for unit tests."""
|
||||
return {"required": {}}
|
||||
|
||||
|
||||
class _FakeDynPrompt:
|
||||
"""Small DynamicPrompt stand-in with only the methods these tests need."""
|
||||
|
||||
def __init__(self, nodes_by_id):
|
||||
"""Store test nodes by id."""
|
||||
self._nodes_by_id = nodes_by_id
|
||||
|
||||
def has_node(self, node_id):
|
||||
"""Return whether the fake prompt contains the requested node."""
|
||||
return node_id in self._nodes_by_id
|
||||
|
||||
def get_node(self, node_id):
|
||||
"""Return the stored node payload for the requested id."""
|
||||
return self._nodes_by_id[node_id]
|
||||
|
||||
|
||||
class _FakeIsChangedCache:
|
||||
"""Async stub for `is_changed` lookups used by cache-key generation."""
|
||||
|
||||
def __init__(self, values):
|
||||
"""Store canned `is_changed` responses keyed by node id."""
|
||||
self._values = values
|
||||
|
||||
async def get(self, node_id):
|
||||
"""Return the canned `is_changed` value for a node."""
|
||||
return self._values[node_id]
|
||||
|
||||
|
||||
class _OpaqueValue:
|
||||
"""Hashable opaque object used to exercise fail-closed unordered hashing paths."""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def caching_module(monkeypatch):
|
||||
"""Import `comfy_execution.caching` with lightweight stub dependencies."""
|
||||
torch_module = types.ModuleType("torch")
|
||||
psutil_module = types.ModuleType("psutil")
|
||||
nodes_module = types.ModuleType("nodes")
|
||||
nodes_module.NODE_CLASS_MAPPINGS = {}
|
||||
graph_module = types.ModuleType("comfy_execution.graph")
|
||||
|
||||
class DynamicPrompt:
|
||||
"""Placeholder graph type so the caching module can import cleanly."""
|
||||
|
||||
pass
|
||||
|
||||
graph_module.DynamicPrompt = DynamicPrompt
|
||||
|
||||
monkeypatch.setitem(sys.modules, "torch", torch_module)
|
||||
monkeypatch.setitem(sys.modules, "psutil", psutil_module)
|
||||
monkeypatch.setitem(sys.modules, "nodes", nodes_module)
|
||||
monkeypatch.setitem(sys.modules, "comfy_execution.graph", graph_module)
|
||||
monkeypatch.delitem(sys.modules, "comfy_execution.caching", raising=False)
|
||||
|
||||
module = importlib.import_module("comfy_execution.caching")
|
||||
module = importlib.reload(module)
|
||||
return module, nodes_module
|
||||
|
||||
|
||||
def test_signature_to_hashable_handles_shared_builtin_substructures(caching_module):
|
||||
"""Shared built-in substructures should canonicalize without collapsing to Unhashable."""
|
||||
caching, _ = caching_module
|
||||
shared = [{"value": 1}, {"value": 2}]
|
||||
|
||||
signature = caching._signature_to_hashable([shared, shared])
|
||||
|
||||
assert signature[0] == "list"
|
||||
assert signature[1][0] == signature[1][1]
|
||||
assert signature[1][0][0] == "list"
|
||||
assert signature[1][0][1][0] == ("dict", (("value", 1),))
|
||||
assert signature[1][0][1][1] == ("dict", (("value", 2),))
|
||||
|
||||
|
||||
def test_signature_to_hashable_fails_closed_on_opaque_values(caching_module):
|
||||
"""Opaque values should collapse the full signature to Unhashable immediately."""
|
||||
caching, _ = caching_module
|
||||
|
||||
signature = caching._signature_to_hashable(["safe", object()])
|
||||
|
||||
assert isinstance(signature, caching.Unhashable)
|
||||
|
||||
|
||||
def test_signature_to_hashable_stops_descending_after_failure(caching_module, monkeypatch):
|
||||
"""Once canonicalization fails, later recursive descent should stop immediately."""
|
||||
caching, _ = caching_module
|
||||
original = caching._signature_to_hashable_impl
|
||||
marker = object()
|
||||
marker_seen = False
|
||||
|
||||
def tracking_canonicalize(obj, *args, **kwargs):
|
||||
"""Track whether recursion reaches the nested marker after failure."""
|
||||
nonlocal marker_seen
|
||||
if obj is marker:
|
||||
marker_seen = True
|
||||
return original(obj, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(caching, "_signature_to_hashable_impl", tracking_canonicalize)
|
||||
|
||||
signature = caching._signature_to_hashable([object(), [marker]])
|
||||
|
||||
assert isinstance(signature, caching.Unhashable)
|
||||
assert marker_seen is False
|
||||
|
||||
|
||||
def test_signature_to_hashable_snapshots_list_before_recursing(caching_module, monkeypatch):
|
||||
"""List canonicalization should read a point-in-time snapshot before recursive descent."""
|
||||
caching, _ = caching_module
|
||||
original = caching._signature_to_hashable_impl
|
||||
marker = ("marker",)
|
||||
values = [marker, 2]
|
||||
|
||||
def mutating_canonicalize(obj, *args, **kwargs):
|
||||
"""Mutate the live list during recursion to verify snapshot-based traversal."""
|
||||
if obj is marker:
|
||||
values[1] = 3
|
||||
return original(obj, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(caching, "_signature_to_hashable_impl", mutating_canonicalize)
|
||||
|
||||
signature = caching._signature_to_hashable(values)
|
||||
|
||||
assert signature == ("list", (("tuple", ("marker",)), 2))
|
||||
assert values[1] == 3
|
||||
|
||||
|
||||
def test_signature_to_hashable_snapshots_dict_before_recursing(caching_module, monkeypatch):
|
||||
"""Dict canonicalization should read a point-in-time snapshot before recursive descent."""
|
||||
caching, _ = caching_module
|
||||
original = caching._signature_to_hashable_impl
|
||||
marker = ("marker",)
|
||||
values = {"first": marker, "second": 2}
|
||||
|
||||
def mutating_canonicalize(obj, *args, **kwargs):
|
||||
"""Mutate the live dict during recursion to verify snapshot-based traversal."""
|
||||
if obj is marker:
|
||||
values["second"] = 3
|
||||
return original(obj, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(caching, "_signature_to_hashable_impl", mutating_canonicalize)
|
||||
|
||||
signature = caching._signature_to_hashable(values)
|
||||
|
||||
assert signature == ("dict", (("first", ("tuple", ("marker",))), ("second", 2)))
|
||||
assert values["second"] == 3
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"container_factory",
|
||||
[
|
||||
lambda marker: [marker],
|
||||
lambda marker: (marker,),
|
||||
lambda marker: {marker},
|
||||
lambda marker: frozenset({marker}),
|
||||
lambda marker: {"key": marker},
|
||||
],
|
||||
)
|
||||
def test_signature_to_hashable_fails_closed_on_runtimeerror(caching_module, monkeypatch, container_factory):
|
||||
"""Traversal RuntimeError should degrade canonicalization to Unhashable."""
|
||||
caching, _ = caching_module
|
||||
original = caching._signature_to_hashable_impl
|
||||
marker = object()
|
||||
|
||||
def raising_canonicalize(obj, *args, **kwargs):
|
||||
"""Raise a traversal RuntimeError for the marker value and delegate otherwise."""
|
||||
if obj is marker:
|
||||
raise RuntimeError("container changed during iteration")
|
||||
return original(obj, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(caching, "_signature_to_hashable_impl", raising_canonicalize)
|
||||
|
||||
signature = caching._signature_to_hashable(container_factory(marker))
|
||||
|
||||
assert isinstance(signature, caching.Unhashable)
|
||||
|
||||
|
||||
def test_to_hashable_handles_shared_builtin_substructures(caching_module):
|
||||
"""The legacy helper should still hash sanitized built-ins stably when used directly."""
|
||||
caching, _ = caching_module
|
||||
shared = [{"value": 1}, {"value": 2}]
|
||||
|
||||
sanitized = [shared, shared]
|
||||
hashable = caching.to_hashable(sanitized)
|
||||
|
||||
assert hashable[0] == "list"
|
||||
assert hashable[1][0] == hashable[1][1]
|
||||
assert hashable[1][0][0] == "list"
|
||||
|
||||
|
||||
def test_to_hashable_uses_parent_snapshot_during_expanded_phase(caching_module, monkeypatch):
|
||||
"""Expanded-phase assembly should not reread a live parent container after snapshotting."""
|
||||
caching, _ = caching_module
|
||||
original_sort_key = caching._sanitized_sort_key
|
||||
outer = [{"marker"}, 2]
|
||||
|
||||
def mutating_sort_key(obj, *args, **kwargs):
|
||||
"""Mutate the live parent while a child container is being canonicalized."""
|
||||
if obj == "marker":
|
||||
outer[1] = 3
|
||||
return original_sort_key(obj, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(caching, "_sanitized_sort_key", mutating_sort_key)
|
||||
|
||||
hashable = caching.to_hashable(outer)
|
||||
|
||||
assert hashable == ("list", (("set", ("marker",)), 2))
|
||||
assert outer[1] == 3
|
||||
|
||||
|
||||
def test_to_hashable_fails_closed_for_ordered_container_with_opaque_child(caching_module):
|
||||
"""Ordered containers should fail closed when a child cannot be canonicalized."""
|
||||
caching, _ = caching_module
|
||||
|
||||
result = caching.to_hashable([object()])
|
||||
|
||||
assert isinstance(result, caching.Unhashable)
|
||||
|
||||
|
||||
def test_to_hashable_canonicalizes_dict_insertion_order(caching_module):
|
||||
"""Dicts with the same content should hash identically regardless of insertion order."""
|
||||
caching, _ = caching_module
|
||||
|
||||
first = {"b": 2, "a": 1}
|
||||
second = {"a": 1, "b": 2}
|
||||
|
||||
assert caching.to_hashable(first) == ("dict", (("a", 1), ("b", 2)))
|
||||
assert caching.to_hashable(first) == caching.to_hashable(second)
|
||||
|
||||
|
||||
def test_to_hashable_fails_closed_for_opaque_dict_key(caching_module):
|
||||
"""Opaque dict keys should fail closed instead of being traversed during hashing."""
|
||||
caching, _ = caching_module
|
||||
|
||||
hashable = caching.to_hashable({_OpaqueValue(): 1})
|
||||
|
||||
assert isinstance(hashable, caching.Unhashable)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"container_factory",
|
||||
[
|
||||
set,
|
||||
frozenset,
|
||||
],
|
||||
)
|
||||
def test_to_hashable_fails_closed_on_runtimeerror(caching_module, monkeypatch, container_factory):
|
||||
"""Traversal RuntimeError should degrade unordered hash conversion to Unhashable."""
|
||||
caching, _ = caching_module
|
||||
|
||||
def raising_sort_key(obj, *args, **kwargs):
|
||||
"""Raise a traversal RuntimeError while unordered values are canonicalized."""
|
||||
raise RuntimeError("container changed during iteration")
|
||||
|
||||
monkeypatch.setattr(caching, "_sanitized_sort_key", raising_sort_key)
|
||||
|
||||
hashable = caching.to_hashable(container_factory({"value"}))
|
||||
|
||||
assert isinstance(hashable, caching.Unhashable)
|
||||
|
||||
|
||||
def test_to_hashable_fails_closed_for_ambiguous_dict_ordering(caching_module, monkeypatch):
|
||||
"""Ambiguous dict key ordering should fail closed instead of using insertion order."""
|
||||
caching, _ = caching_module
|
||||
original_sort_key = caching._sanitized_sort_key
|
||||
ambiguous = {"a": 1, "b": 1}
|
||||
|
||||
def colliding_sort_key(obj, *args, **kwargs):
|
||||
"""Force two distinct primitive keys to share the same ordering key."""
|
||||
if obj == "a" or obj == "b":
|
||||
return ("COLLIDE",)
|
||||
return original_sort_key(obj, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(caching, "_sanitized_sort_key", colliding_sort_key)
|
||||
|
||||
hashable = caching.to_hashable(ambiguous)
|
||||
|
||||
assert isinstance(hashable, caching.Unhashable)
|
||||
|
||||
|
||||
def test_signature_to_hashable_fails_closed_for_ambiguous_dict_ordering(caching_module, monkeypatch):
|
||||
"""Ambiguous dict sort ties should fail closed instead of depending on input order."""
|
||||
caching, _ = caching_module
|
||||
original_sort_key = caching._primitive_signature_sort_key
|
||||
ambiguous = {"a": 1, "b": 1}
|
||||
|
||||
def colliding_sort_key(obj):
|
||||
"""Force two distinct primitive keys to share the same ordering key."""
|
||||
if obj == "a" or obj == "b":
|
||||
return ("COLLIDE",)
|
||||
return original_sort_key(obj)
|
||||
|
||||
monkeypatch.setattr(caching, "_primitive_signature_sort_key", colliding_sort_key)
|
||||
|
||||
sanitized = caching._signature_to_hashable(ambiguous)
|
||||
|
||||
assert isinstance(sanitized, caching.Unhashable)
|
||||
|
||||
|
||||
def test_signature_to_hashable_fails_closed_for_opaque_dict_key(caching_module):
|
||||
"""Opaque dict keys should fail closed instead of being recursively canonicalized."""
|
||||
caching, _ = caching_module
|
||||
|
||||
sanitized = caching._signature_to_hashable({_OpaqueValue(): 1})
|
||||
|
||||
assert isinstance(sanitized, caching.Unhashable)
|
||||
|
||||
|
||||
def test_signature_to_hashable_fails_closed_on_dict_key_sort_collisions_even_with_distinct_values(caching_module, monkeypatch):
|
||||
"""Different values must not mask dict key-sort collisions during canonicalization."""
|
||||
caching, _ = caching_module
|
||||
original_sort_key = caching._primitive_signature_sort_key
|
||||
|
||||
def colliding_sort_key(obj):
|
||||
"""Force two distinct primitive keys to share the same ordering key."""
|
||||
if obj == "a" or obj == "b":
|
||||
return ("COLLIDE",)
|
||||
return original_sort_key(obj)
|
||||
|
||||
monkeypatch.setattr(caching, "_primitive_signature_sort_key", colliding_sort_key)
|
||||
|
||||
sanitized = caching._signature_to_hashable({"a": 1, "b": 2})
|
||||
|
||||
assert isinstance(sanitized, caching.Unhashable)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"container_factory",
|
||||
[
|
||||
set,
|
||||
frozenset,
|
||||
],
|
||||
)
|
||||
def test_to_hashable_fails_closed_for_ambiguous_unordered_values(caching_module, monkeypatch, container_factory):
|
||||
"""Ambiguous unordered values should fail closed instead of depending on iteration order."""
|
||||
caching, _ = caching_module
|
||||
original_sort_key = caching._sanitized_sort_key
|
||||
container = container_factory({"a", "b"})
|
||||
|
||||
def colliding_sort_key(obj, *args, **kwargs):
|
||||
"""Force two distinct primitive values to share the same ordering key."""
|
||||
if obj == "a" or obj == "b":
|
||||
return ("COLLIDE",)
|
||||
return original_sort_key(obj, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(caching, "_sanitized_sort_key", colliding_sort_key)
|
||||
|
||||
hashable = caching.to_hashable(container)
|
||||
|
||||
assert isinstance(hashable, caching.Unhashable)
|
||||
|
||||
|
||||
def test_get_node_signature_returns_top_level_unhashable_for_tainted_signature(caching_module, monkeypatch):
|
||||
"""Tainted full signatures should fail closed before `to_hashable()` runs."""
|
||||
caching, nodes_module = caching_module
|
||||
monkeypatch.setitem(nodes_module.NODE_CLASS_MAPPINGS, "UnitTestNode", _DummyNode)
|
||||
monkeypatch.setattr(
|
||||
caching,
|
||||
"to_hashable",
|
||||
lambda *_args, **_kwargs: pytest.fail("to_hashable should not run for tainted signatures"),
|
||||
)
|
||||
|
||||
is_changed_value = []
|
||||
is_changed_value.append(is_changed_value)
|
||||
|
||||
dynprompt = _FakeDynPrompt(
|
||||
{
|
||||
"node": {
|
||||
"class_type": "UnitTestNode",
|
||||
"inputs": {"value": 5},
|
||||
}
|
||||
}
|
||||
)
|
||||
key_set = caching.CacheKeySetInputSignature(
|
||||
dynprompt,
|
||||
["node"],
|
||||
_FakeIsChangedCache({"node": is_changed_value}),
|
||||
)
|
||||
|
||||
signature = asyncio.run(key_set.get_node_signature(dynprompt, "node"))
|
||||
|
||||
assert isinstance(signature, caching.Unhashable)
|
||||
|
||||
|
||||
def test_shallow_is_changed_signature_accepts_primitive_lists(caching_module):
|
||||
"""Primitive-only `is_changed` lists should stay hashable without deep descent."""
|
||||
caching, _ = caching_module
|
||||
|
||||
sanitized = caching._shallow_is_changed_signature([1, "two", None, True])
|
||||
|
||||
assert sanitized == ("is_changed_list", (1, "two", None, True))
|
||||
|
||||
|
||||
def test_shallow_is_changed_signature_accepts_structured_builtin_fingerprint_lists(caching_module):
|
||||
"""Structured built-in `is_changed` fingerprints should remain representable."""
|
||||
caching, _ = caching_module
|
||||
|
||||
sanitized = caching._shallow_is_changed_signature([("seed", 42), {"cfg": 8}])
|
||||
|
||||
assert sanitized == (
|
||||
"is_changed_list",
|
||||
(
|
||||
("tuple", ("seed", 42)),
|
||||
("dict", (("cfg", 8),)),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_shallow_is_changed_signature_fails_closed_for_opaque_payload(caching_module):
|
||||
"""Opaque `is_changed` payloads should still fail closed."""
|
||||
caching, _ = caching_module
|
||||
|
||||
sanitized = caching._shallow_is_changed_signature([_OpaqueValue()])
|
||||
|
||||
assert isinstance(sanitized, caching.Unhashable)
|
||||
|
||||
|
||||
def test_get_immediate_node_signature_fails_closed_for_unhashable_is_changed(caching_module, monkeypatch):
|
||||
"""Recursive `is_changed` payloads should fail the full fragment closed."""
|
||||
caching, nodes_module = caching_module
|
||||
monkeypatch.setitem(nodes_module.NODE_CLASS_MAPPINGS, "UnitTestNode", _DummyNode)
|
||||
|
||||
is_changed_value = []
|
||||
is_changed_value.append(is_changed_value)
|
||||
dynprompt = _FakeDynPrompt(
|
||||
{
|
||||
"node": {
|
||||
"class_type": "UnitTestNode",
|
||||
"inputs": {"value": 5},
|
||||
}
|
||||
}
|
||||
)
|
||||
key_set = caching.CacheKeySetInputSignature(
|
||||
dynprompt,
|
||||
["node"],
|
||||
_FakeIsChangedCache({"node": is_changed_value}),
|
||||
)
|
||||
|
||||
signature = asyncio.run(key_set.get_immediate_node_signature(dynprompt, "node", {}))
|
||||
|
||||
assert isinstance(signature, caching.Unhashable)
|
||||
|
||||
|
||||
def test_get_immediate_node_signature_fails_closed_for_missing_node(caching_module):
|
||||
"""Missing nodes should return the fail-closed sentinel instead of a NaN tuple."""
|
||||
caching, _ = caching_module
|
||||
dynprompt = _FakeDynPrompt({})
|
||||
key_set = caching.CacheKeySetInputSignature(
|
||||
dynprompt,
|
||||
[],
|
||||
_FakeIsChangedCache({}),
|
||||
)
|
||||
|
||||
signature = asyncio.run(key_set.get_immediate_node_signature(dynprompt, "missing", {}))
|
||||
|
||||
assert isinstance(signature, caching.Unhashable)
|
||||
242
tests/execution/test_caching.py
Normal file
242
tests/execution/test_caching.py
Normal file
@ -0,0 +1,242 @@
|
||||
import asyncio
|
||||
|
||||
from comfy_execution import caching
|
||||
|
||||
|
||||
class _StubDynPrompt:
|
||||
def __init__(self, nodes):
|
||||
self._nodes = nodes
|
||||
|
||||
def has_node(self, node_id):
|
||||
return node_id in self._nodes
|
||||
|
||||
def get_node(self, node_id):
|
||||
return self._nodes[node_id]
|
||||
|
||||
|
||||
class _StubIsChangedCache:
|
||||
async def get(self, node_id):
|
||||
return None
|
||||
|
||||
|
||||
class _StubNode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {"required": {}}
|
||||
|
||||
|
||||
def test_shallow_is_changed_signature_keeps_primitive_only_list_shallow():
|
||||
assert caching._shallow_is_changed_signature([1, "two", None, True]) == (
|
||||
"is_changed_list",
|
||||
(1, "two", None, True),
|
||||
)
|
||||
|
||||
|
||||
def test_shallow_is_changed_signature_keeps_primitive_only_tuple_shallow():
|
||||
assert caching._shallow_is_changed_signature((1, "two", None, True)) == (
|
||||
"is_changed_tuple",
|
||||
(1, "two", None, True),
|
||||
)
|
||||
|
||||
|
||||
def test_shallow_is_changed_signature_keeps_structured_builtin_fingerprint_list():
|
||||
assert caching._shallow_is_changed_signature([("seed", 42), {"cfg": 8}]) == (
|
||||
"is_changed_list",
|
||||
(
|
||||
("tuple", ("seed", 42)),
|
||||
("dict", (("cfg", 8),)),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_shallow_is_changed_signature_does_not_use_to_hashable(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
caching,
|
||||
"to_hashable",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(
|
||||
AssertionError("is_changed signature must not deep-canonicalize")
|
||||
),
|
||||
)
|
||||
|
||||
signature = caching._shallow_is_changed_signature([("seed", 42), {"cfg": 8}])
|
||||
|
||||
assert signature == (
|
||||
"is_changed_list",
|
||||
(
|
||||
("tuple", ("seed", 42)),
|
||||
("dict", (("cfg", 8),)),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_get_immediate_node_signature_canonicalizes_non_link_inputs(monkeypatch):
|
||||
live_value = [1, {"nested": [2, 3]}]
|
||||
dynprompt = _StubDynPrompt(
|
||||
{
|
||||
"1": {
|
||||
"class_type": "TestCacheNode",
|
||||
"inputs": {"value": live_value},
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
monkeypatch.setitem(caching.nodes.NODE_CLASS_MAPPINGS, "TestCacheNode", _StubNode)
|
||||
monkeypatch.setattr(caching, "NODE_CLASS_CONTAINS_UNIQUE_ID", {})
|
||||
|
||||
keyset = caching.CacheKeySetInputSignature(dynprompt, [], _StubIsChangedCache())
|
||||
signature = asyncio.run(keyset.get_immediate_node_signature(dynprompt, "1", {}))
|
||||
|
||||
assert signature == (
|
||||
"TestCacheNode",
|
||||
None,
|
||||
("value", ("list", (1, ("dict", (("nested", ("list", (2, 3))),))))),
|
||||
)
|
||||
|
||||
|
||||
def test_to_hashable_walks_dicts_without_rebinding_traversal_stack():
|
||||
live_value = {
|
||||
"outer": {"nested": [2, 3]},
|
||||
"items": [{"leaf": 4}],
|
||||
}
|
||||
|
||||
assert caching.to_hashable(live_value) == (
|
||||
"dict",
|
||||
(
|
||||
("items", ("list", (("dict", (("leaf", 4),)),))),
|
||||
("outer", ("dict", (("nested", ("list", (2, 3))),))),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_get_immediate_node_signature_fails_closed_for_opaque_non_link_input(monkeypatch):
|
||||
class OpaqueRuntimeValue:
|
||||
pass
|
||||
|
||||
live_value = OpaqueRuntimeValue()
|
||||
dynprompt = _StubDynPrompt(
|
||||
{
|
||||
"1": {
|
||||
"class_type": "TestCacheNode",
|
||||
"inputs": {"value": live_value},
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
monkeypatch.setitem(caching.nodes.NODE_CLASS_MAPPINGS, "TestCacheNode", _StubNode)
|
||||
monkeypatch.setattr(caching, "NODE_CLASS_CONTAINS_UNIQUE_ID", {})
|
||||
|
||||
keyset = caching.CacheKeySetInputSignature(dynprompt, [], _StubIsChangedCache())
|
||||
signature = asyncio.run(keyset.get_immediate_node_signature(dynprompt, "1", {}))
|
||||
|
||||
assert isinstance(signature, caching.Unhashable)
|
||||
|
||||
|
||||
def test_get_node_signature_propagates_unhashable_immediate_fragment(monkeypatch):
|
||||
class OpaqueRuntimeValue:
|
||||
pass
|
||||
|
||||
dynprompt = _StubDynPrompt(
|
||||
{
|
||||
"1": {
|
||||
"class_type": "TestCacheNode",
|
||||
"inputs": {"value": OpaqueRuntimeValue()},
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
monkeypatch.setitem(caching.nodes.NODE_CLASS_MAPPINGS, "TestCacheNode", _StubNode)
|
||||
monkeypatch.setattr(caching, "NODE_CLASS_CONTAINS_UNIQUE_ID", {})
|
||||
|
||||
keyset = caching.CacheKeySetInputSignature(dynprompt, [], _StubIsChangedCache())
|
||||
signature = asyncio.run(keyset.get_node_signature(dynprompt, "1"))
|
||||
|
||||
assert isinstance(signature, caching.Unhashable)
|
||||
|
||||
|
||||
def test_get_node_signature_never_visits_raw_non_link_input(monkeypatch):
|
||||
live_value = [1, 2, 3]
|
||||
dynprompt = _StubDynPrompt(
|
||||
{
|
||||
"1": {
|
||||
"class_type": "TestCacheNode",
|
||||
"inputs": {"value": live_value},
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
monkeypatch.setitem(caching.nodes.NODE_CLASS_MAPPINGS, "TestCacheNode", _StubNode)
|
||||
monkeypatch.setattr(caching, "NODE_CLASS_CONTAINS_UNIQUE_ID", {})
|
||||
monkeypatch.setattr(
|
||||
caching,
|
||||
"_signature_to_hashable",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(
|
||||
AssertionError("outer signature canonicalizer should not run")
|
||||
),
|
||||
)
|
||||
|
||||
keyset = caching.CacheKeySetInputSignature(dynprompt, [], _StubIsChangedCache())
|
||||
signature = asyncio.run(keyset.get_node_signature(dynprompt, "1"))
|
||||
|
||||
assert isinstance(signature, tuple)
|
||||
|
||||
|
||||
def test_get_node_signature_keeps_deep_canonicalized_input_fragment(monkeypatch):
|
||||
live_value = 1
|
||||
for _ in range(8):
|
||||
live_value = [live_value]
|
||||
expected = caching.to_hashable(live_value)
|
||||
|
||||
dynprompt = _StubDynPrompt(
|
||||
{
|
||||
"1": {
|
||||
"class_type": "TestCacheNode",
|
||||
"inputs": {"value": live_value},
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
monkeypatch.setitem(caching.nodes.NODE_CLASS_MAPPINGS, "TestCacheNode", _StubNode)
|
||||
monkeypatch.setattr(caching, "NODE_CLASS_CONTAINS_UNIQUE_ID", {})
|
||||
|
||||
keyset = caching.CacheKeySetInputSignature(dynprompt, [], _StubIsChangedCache())
|
||||
signature = asyncio.run(keyset.get_node_signature(dynprompt, "1"))
|
||||
|
||||
assert isinstance(signature, tuple)
|
||||
assert signature[0][2][0] == "value"
|
||||
assert signature[0][2][1] == expected
|
||||
|
||||
|
||||
def test_get_node_signature_keeps_large_precanonicalized_fragment(monkeypatch):
|
||||
live_value = object()
|
||||
canonical_fragment = ("tuple", tuple(("list", (index, index + 1)) for index in range(256)))
|
||||
dynprompt = _StubDynPrompt(
|
||||
{
|
||||
"1": {
|
||||
"class_type": "TestCacheNode",
|
||||
"inputs": {"value": live_value},
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
monkeypatch.setitem(caching.nodes.NODE_CLASS_MAPPINGS, "TestCacheNode", _StubNode)
|
||||
monkeypatch.setattr(caching, "NODE_CLASS_CONTAINS_UNIQUE_ID", {})
|
||||
monkeypatch.setattr(
|
||||
caching,
|
||||
"to_hashable",
|
||||
lambda value, max_nodes=caching._MAX_SIGNATURE_CONTAINER_VISITS: (
|
||||
canonical_fragment if value is live_value else caching.Unhashable()
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
caching,
|
||||
"_signature_to_hashable",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(
|
||||
AssertionError("outer signature canonicalizer should not run")
|
||||
),
|
||||
)
|
||||
|
||||
keyset = caching.CacheKeySetInputSignature(dynprompt, [], _StubIsChangedCache())
|
||||
signature = asyncio.run(keyset.get_node_signature(dynprompt, "1"))
|
||||
|
||||
assert isinstance(signature, tuple)
|
||||
assert signature[0][2] == ("value", canonical_fragment)
|
||||
@ -21,7 +21,7 @@ class TestAsyncProgressUpdate(ComfyNodeABC):
|
||||
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "_for_testing/async"
|
||||
CATEGORY = "experimental/async"
|
||||
|
||||
async def execute(self, value, sleep_seconds):
|
||||
start = time.time()
|
||||
@ -51,7 +51,7 @@ class TestSyncProgressUpdate(ComfyNodeABC):
|
||||
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "_for_testing/async"
|
||||
CATEGORY = "experimental/async"
|
||||
|
||||
def execute(self, value, sleep_seconds):
|
||||
start = time.time()
|
||||
|
||||
@ -21,7 +21,7 @@ class TestAsyncValidation(ComfyNodeABC):
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "_for_testing/async"
|
||||
CATEGORY = "experimental/async"
|
||||
|
||||
@classmethod
|
||||
async def VALIDATE_INPUTS(cls, value, threshold):
|
||||
@ -53,7 +53,7 @@ class TestAsyncError(ComfyNodeABC):
|
||||
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "error_execution"
|
||||
CATEGORY = "_for_testing/async"
|
||||
CATEGORY = "experimental/async"
|
||||
|
||||
async def error_execution(self, value, error_after):
|
||||
await asyncio.sleep(error_after)
|
||||
@ -74,7 +74,7 @@ class TestAsyncValidationError(ComfyNodeABC):
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "_for_testing/async"
|
||||
CATEGORY = "experimental/async"
|
||||
|
||||
@classmethod
|
||||
async def VALIDATE_INPUTS(cls, value, max_value):
|
||||
@ -105,7 +105,7 @@ class TestAsyncTimeout(ComfyNodeABC):
|
||||
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "timeout_execution"
|
||||
CATEGORY = "_for_testing/async"
|
||||
CATEGORY = "experimental/async"
|
||||
|
||||
async def timeout_execution(self, value, timeout, operation_time):
|
||||
try:
|
||||
@ -129,7 +129,7 @@ class TestSyncError(ComfyNodeABC):
|
||||
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "sync_error"
|
||||
CATEGORY = "_for_testing/async"
|
||||
CATEGORY = "experimental/async"
|
||||
|
||||
def sync_error(self, value):
|
||||
raise RuntimeError("Intentional sync execution error for testing")
|
||||
@ -150,7 +150,7 @@ class TestAsyncLazyCheck(ComfyNodeABC):
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "_for_testing/async"
|
||||
CATEGORY = "experimental/async"
|
||||
|
||||
async def check_lazy_status(self, condition, input1, input2):
|
||||
# Simulate async checking (e.g., querying remote service)
|
||||
@ -184,7 +184,7 @@ class TestDynamicAsyncGeneration(ComfyNodeABC):
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "generate_async_workflow"
|
||||
CATEGORY = "_for_testing/async"
|
||||
CATEGORY = "experimental/async"
|
||||
|
||||
def generate_async_workflow(self, image1, image2, num_async_nodes, sleep_duration):
|
||||
g = GraphBuilder()
|
||||
@ -229,7 +229,7 @@ class TestAsyncResourceUser(ComfyNodeABC):
|
||||
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "use_resource"
|
||||
CATEGORY = "_for_testing/async"
|
||||
CATEGORY = "experimental/async"
|
||||
|
||||
async def use_resource(self, value, resource_id, duration):
|
||||
# Check if resource is already in use
|
||||
@ -265,7 +265,7 @@ class TestAsyncBatchProcessing(ComfyNodeABC):
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "process_batch"
|
||||
CATEGORY = "_for_testing/async"
|
||||
CATEGORY = "experimental/async"
|
||||
|
||||
async def process_batch(self, images, process_time_per_item, unique_id):
|
||||
batch_size = images.shape[0]
|
||||
@ -305,7 +305,7 @@ class TestAsyncConcurrentLimit(ComfyNodeABC):
|
||||
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "limited_execution"
|
||||
CATEGORY = "_for_testing/async"
|
||||
CATEGORY = "experimental/async"
|
||||
|
||||
async def limited_execution(self, value, duration, node_id):
|
||||
async with self._semaphore:
|
||||
|
||||
@ -409,7 +409,7 @@ class TestSleep(ComfyNodeABC):
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "sleep"
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
CATEGORY = "experimental"
|
||||
|
||||
async def sleep(self, value, seconds, unique_id):
|
||||
pbar = ProgressBar(seconds, node_id=unique_id)
|
||||
@ -440,7 +440,7 @@ class TestParallelSleep(ComfyNodeABC):
|
||||
}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "parallel_sleep"
|
||||
CATEGORY = "_for_testing"
|
||||
CATEGORY = "experimental"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
def parallel_sleep(self, image1, image2, image3, sleep1, sleep2, sleep3, unique_id):
|
||||
@ -474,7 +474,7 @@ class TestOutputNodeWithSocketOutput:
|
||||
}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "_for_testing"
|
||||
CATEGORY = "experimental"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
def process(self, image, value):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user