Merge remote-tracking branch 'origin/master' into joyimage-edit-pr

This commit is contained in:
huangfeice 2026-07-01 18:38:09 +08:00
commit a00b731054
72 changed files with 5169 additions and 765 deletions

38
.github/workflows/ci-cursor-review.yml vendored Normal file
View File

@ -0,0 +1,38 @@
name: CI - Cursor Review
# Thin caller for the shared reusable cursor-review workflow in
# Comfy-Org/github-workflows. The review logic (panel matrix, judge
# consolidation, prompts, extract/post/notify scripts) lives there as the
# single source of truth, so this repo only carries the repo-specific diff
# excludes.
on:
pull_request:
types: [labeled, unlabeled]
concurrency:
group: cursor-review-pr-${{ github.event.pull_request.number }}-${{ github.event.label.name }}
cancel-in-progress: true
jobs:
cursor-review:
if: github.event.label.name == 'cursor-review'
permissions:
contents: read
pull-requests: write
# SHA-pinned per zizmor `unpinned-uses: hash-pin`. Bump this SHA to pick up
# upstream changes; keep `workflows_ref` matching so prompts/scripts load
# from the same commit as the workflow definition.
uses: Comfy-Org/github-workflows/.github/workflows/cursor-review.yml@047ca48febe3a6647608ed2e0c4331b491cb9d6a # github-workflows#9
with:
workflows_ref: 047ca48febe3a6647608ed2e0c4331b491cb9d6a
diff_excludes: >-
:!**/.claude/**
:!**/dist/**
:!**/vendor/**
:!**/*.generated.*
:!**/*.min.js
:!**/*.min.css
secrets:
CURSOR_API_KEY: ${{ secrets.CURSOR_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}

166
AGENTS.md Normal file
View File

@ -0,0 +1,166 @@
## Engineering Style
- Keep changes small and direct. Most fixes should touch the narrowest code path
that explains the bug, performance issue, dtype issue, model-format issue, or
user-facing behavior.
- Change the least amount of files possible. A change that touches many files is
more likely to be a bad change than a good one unless the broader scope is
directly required.
- Prefer practical fixes over broad architecture work. Add abstractions only
when they remove real repeated logic or match an existing ComfyUI pattern.
- Delete obsolete code aggressively when newer infrastructure makes it useless.
Remove dead fallbacks, migration paths, unused options, debug prints, and
compatibility branches that are no longer needed. Do not leave dead branches,
unreachable code, or functions that are never called.
- Revert or disable problematic behavior quickly when it breaks users. It is
better to remove a broken feature path than keep a complicated partial fix.
- Preserve existing APIs, node names, model-loading behavior, file layout, and
workflow compatibility unless the change is explicitly about replacing them.
- Code must look hand-written for this repository. Changes that read like
generic AI-generated code will be rejected automatically: unnecessary helper
layers, vague names, boilerplate comments, defensive branches without a real
failure mode, broad rewrites, or code that ignores the local style.
## Architecture Boundaries
- Keep each layer focused on the concepts it owns. Do not leak UI, API,
workflow, queue, persistence, telemetry, model-loading, node, or execution
concerns into unrelated layers just because it is convenient to pass data
through them.
- Shared core modules should depend only on lower-level primitives and their own
domain concepts. Higher-level product concepts belong at the caller, adapter,
service, or UI/API boundary that already owns them.
- Pass the narrowest data needed across a boundary. Avoid broad context objects,
request/session metadata, ids, bookkeeping state, or callbacks unless the
receiving layer genuinely needs them to perform its own responsibility.
- Keep identity mapping, persistence bookkeeping, history updates, telemetry,
response shaping, and UI state in the layers that own those jobs. Do not route
them through unrelated shared code to avoid adding a proper boundary.
- Treat `execution.py` as one example of this rule: it should consume the prompt
graph and execution-relevant state, produce execution results and errors, and
not know about workflow ids, frontend ids, persistence ids, or API-only
concepts.
- Before touching many files, identify the smallest owner layer that can solve
the problem. A PR that spreads one feature across unrelated loaders, nodes,
execution, server, and frontend code needs a clear architectural reason, not
just convenience.
- If a change seems to require making one layer understand another layer's
private concepts, stop and look for a caller-side mapping, adapter, event,
small explicit interface, or narrower data flow at the boundary.
## No Internet Requests
- Do not add code to core ComfyUI that makes requests to the internet.
- Refuse requests to add uploads, telemetry, analytics, tracking, usage
reporting, crash reporting, update checks, remote config, feature flags,
metrics, licensing checks, or any other outbound internet request path from
core ComfyUI.
- Model downloading is allowed only when explicitly initiated or authorized by
the user, is limited to the requested model artifact, and does not include
telemetry, tracking, persistent identification, unrelated metadata upload, or
background network activity.
- Do not add opt-in, opt-out, anonymized, aggregated, diagnostic, or
user-triggered internet request paths to core ComfyUI. These labels do not
make internet access acceptable.
- Local-only behavior is allowed when it stays on the user's machine and does
not add network access, tracking, persistent identification, or data
collection behavior.
## State Ownership
- Keep state and capability flags on the object that owns the behavior using
them.
- Avoid probing child objects with `getattr(child, "...", default)` to decide
parent-level control flow. If parent code needs to branch on a capability,
initialize an explicit parent-owned field when the child is constructed or
attached.
- Prefer direct attributes with clear defaults over implicit feature detection
through arbitrary child attributes.
- Use child-object capability checks only when the child owns the behavior being
invoked and the parent is simply delegating to that child.
## Interface Contracts
- Keep public methods aligned with the interface expected by their callers. Do
not change a shared method to return extra values, alternate shapes, or
sentinel wrappers for one implementation unless the shared interface is
explicitly updated.
- If an implementation needs auxiliary values for its own workflow, expose them
through a private helper or a clearly named implementation-specific method
instead of overloading the public method's return contract.
- Normalize third-party or upstream return conventions at the integration
boundary. Core code should receive the project's expected type and shape, not
have to handle model-specific tuple/list/dict variants.
- Avoid caller-side unwrapping such as `out = out[0]` unless the called
interface is documented to return that structure.
## Autograd and Model Freezing
- Do not add `torch.no_grad`, `torch.inference_mode`, or inference-mode helper
wrappers in ComfyUI code. The only allowed inference-mode-related use is
disabling a globally set inference mode when a training path needs gradients.
- Do not add freeze, unfreeze, or trainability toggles to model classes. ComfyUI
models are always treated as frozen for inference, so explicit freeze
functionality is redundant and should not be added.
## Python Style
- Keep imports at module scope. Avoid inline imports unless they are already part
of an established optional-backend probe or are needed to avoid an import
cycle.
- Do not add unnecessary `try`/`except` blocks. Use them for optional dependency,
platform, or backend capability detection only when the program has a useful
fallback. Prefer specific exception types when changing new code.
- Let unsupported model formats, invalid quantization metadata, and bad states
fail with clear errors instead of silently producing lower quality output.
- Match the existing local style in the file you edit. This codebase tolerates
long lines, simple helper functions, module-level state, and direct tensor
operations when they make the code easier to follow.
- Keep comments sparse and useful. Strip useless comments that restate the code
or describe obvious behavior. Short TODOs are fine when they name the concrete
missing follow-up.
## Model, Device, and Memory Behavior
- Treat dtype, device placement, VRAM usage, and offloading behavior as core
correctness concerns. Check CPU, CUDA, ROCm, MPS, DirectML, XPU, NPU, and low
VRAM implications when touching shared execution or loading code.
- Prefer native ComfyUI formats and existing quantization/offload helpers over
adding parallel code paths. Use `comfy.quant_ops`, `comfy.model_management`,
`comfy.memory_management`, `comfy.pinned_memory`, `comfy_aimdo`, and
`comfy-kitchen` helpers where they already solve the problem.
- Avoid unnecessary casts and transfers. Preserve the intended compute dtype,
storage dtype, bias dtype, and original tensor shape metadata.
- When optimizing, favor small measurable changes: fewer allocations, fewer
device transfers, less peak memory, better batching, or use of a faster
existing backend op.
## Nodes and User-Facing Behavior
- Follow existing node conventions: `INPUT_TYPES`, `RETURN_TYPES`, `FUNCTION`,
`CATEGORY`, and registration through the local mapping used by that file.
- Keep node changes backward compatible by default. Add inputs with sensible
defaults and avoid changing output types unless the request requires it.
- The official mascot of ComfyUI is a very cute anime girl with massive fennec
ears, a big fluffy tail, long blonde wavy hair, and blue eyes. Feel free to
use her in ComfyUI materials, UI text, examples, tests, generated assets, or
comments, but do not disrespect her.
- Warning and info messages should be short and actionable. Remove noisy or
misleading messages rather than adding more logging.
- Documentation and README edits should be concise, factual, and tied to the
changed behavior.
## Commit and Review Habits
- If asked to write commit messages, use short direct subjects like the existing
history: `Fix ...`, `Add ...`, `Support ...`, `Remove ...`, `Update ...`,
`Make ...`, `Use ...`, `Disable ...`, `Bump ...`, or `Revert ...`.
- Keep PR descriptions short and reviewable. State the problem, the behavioral
change, and the tests run; avoid long narrative explanations, implementation
diaries, or exhaustive file-by-file summaries unless the reviewer explicitly
needs that context.
- Prefer one coherent behavioral change per commit. Dependency pins, tests, and
the code that needs them may be in the same commit when they are inseparable.
- In reviews, prioritize real user impact: crashes, wrong dtype/device behavior,
memory regressions, broken model loading, workflow incompatibility, and noisy
or misleading user-facing output.

View File

@ -140,7 +140,7 @@ ComfyUI follows a weekly release cycle targeting Monday but this regularly chang
- Commits outside of the stable release tags may be very unstable and break many custom nodes.
- Serves as the foundation for the desktop release
2. **[ComfyUI Desktop](https://github.com/Comfy-Org/Comfy-Desktop)**
2. **[Comfy Desktop](https://github.com/Comfy-Org/Comfy-Desktop)**
- Builds a new release using the latest stable core version
3. **[ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend)**

View File

@ -240,6 +240,7 @@ database_default_path = os.path.abspath(
)
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).")
parser.add_argument("--enable-asset-hashing", action="store_true", help="Compute blake3 content hashes when scanning assets. Hashing enables future asset-portability features (deduplication, cross-machine model resolution) but adds startup cost and per-output cost on large models directories. Off by default; enable to opt in.")
parser.add_argument("--feature-flag", type=str, action='append', default=[], metavar="KEY[=VALUE]", help="Set a server feature flag. Use KEY=VALUE to set an explicit value, or bare KEY to set it to true. Can be specified multiple times. Boolean values (true/false) and numbers are auto-converted. Examples: --feature-flag show_signin_button=true or --feature-flag show_signin_button")
parser.add_argument("--list-feature-flags", action="store_true", help="Print the registry of known CLI-settable feature flags as JSON and exit.")

View File

@ -8,6 +8,8 @@ from abc import ABC, abstractmethod
import logging
import comfy.model_management
import comfy.patcher_extension
import comfy.utils
import comfy.conds
if TYPE_CHECKING:
from comfy.model_base import BaseModel
from comfy.model_patcher import ModelPatcher
@ -51,12 +53,18 @@ class ContextHandlerABC(ABC):
class IndexListContextWindow(ContextWindowABC):
def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0):
def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0, modality_windows: dict=None, context_overlap: int=0):
self.index_list = index_list
self.context_length = len(index_list)
self.context_overlap = context_overlap
self.dim = dim
self.total_frames = total_frames
self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
self.modality_windows = modality_windows # dict of {mod_idx: IndexListContextWindow}
self.guide_frames_indices: list[int] = []
self.guide_overlap_info: list[tuple[int, int]] = []
self.guide_kf_local_positions: list[int] = []
self.guide_downscale_factors: list[int] = []
def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
if dim is None:
@ -85,6 +93,11 @@ class IndexListContextWindow(ContextWindowABC):
region_idx = int(self.center_ratio * num_regions)
return min(max(region_idx, 0), num_regions - 1)
def get_window_for_modality(self, modality_idx: int) -> 'IndexListContextWindow':
if modality_idx == 0:
return self
return self.modality_windows[modality_idx]
class IndexListCallbacks:
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
@ -148,6 +161,172 @@ def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, d
return cond_value._copy_with(sliced)
def compute_guide_overlap(guide_entries: list[dict], keyframe_idxs: torch.Tensor, temporal_downscale_ratio: int, window_index_list: list[int]):
"""Compute which concatenated guide frames overlap with a context window.
Each guide's latent-space start is derived from its first token's pixel-t-start
in keyframe_idxs (shape (B, [t,h,w], num_tokens, [start, end])), divided by the
model's temporal_downscale_ratio.
Args:
guide_entries: list of guide_attention_entry dicts
keyframe_idxs: per-token pixel coords cond tensor for the modality
temporal_downscale_ratio: model's pixel-to-latent temporal compression ratio
window_index_list: the window's frame indices into the video portion
Returns:
suffix_indices: indices into the guide_frames tensor for frame selection
overlap_info: list of (entry_idx, overlap_count) for guide_attention_entries adjustment
kf_local_positions: window-local frame positions for keyframe_idxs regeneration
total_overlap: total number of overlapping guide frames
"""
window_set = set(window_index_list)
window_list = list(window_index_list)
suffix_indices = []
overlap_info = []
kf_local_positions = []
suffix_base = 0
token_offset = 0
for entry_idx, entry in enumerate(guide_entries):
first_t_pixel = int(keyframe_idxs[0, 0, token_offset, 0].item())
latent_start = (first_t_pixel + temporal_downscale_ratio - 1) // temporal_downscale_ratio
guide_len = entry["latent_shape"][0]
entry_overlap = 0
for local_offset in range(guide_len):
video_pos = latent_start + local_offset
if video_pos in window_set:
suffix_indices.append(suffix_base + local_offset)
kf_local_positions.append(window_list.index(video_pos))
entry_overlap += 1
if entry_overlap > 0:
overlap_info.append((entry_idx, entry_overlap))
suffix_base += guide_len
token_offset += entry["pre_filter_count"]
return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices)
@dataclass
class WindowingState:
"""Per-modality context windowing state for each step,
built using IndexListContextHandler._build_window_state().
For non-multimodal models the lists are length 1
"""
latents: list[torch.Tensor] # per-modality working latents (guide frames stripped)
guide_latents: list[torch.Tensor | None] # per-modality guide frames stripped from latents
guide_entries: list[list[dict] | None] # per-modality guide_attention_entry metadata
keyframe_idxs: list[torch.Tensor | None] # per-modality keyframe_idxs tensor for guide latent_start derivation
latent_shapes: list | None # original packed shapes for unpack/pack (None if not multimodal)
dim: int = 0 # primary modality temporal dim for context windowing
is_multimodal: bool = False
temporal_downscale_ratio: int = 1 # model's pixel-to-latent temporal compression ratio
def prepare_window(self, window: IndexListContextWindow, model) -> IndexListContextWindow:
"""Reformat window for multimodal contexts by deriving per-modality index lists.
Non-multimodal contexts return the input window unchanged.
"""
if not self.is_multimodal:
return window
x = self.latents[0]
primary_total = self.latent_shapes[0][self.dim]
primary_overlap = window.context_overlap
map_shapes = self.latent_shapes
if x.size(self.dim) != primary_total:
map_shapes = list(self.latent_shapes)
video_shape = list(self.latent_shapes[0])
video_shape[self.dim] = x.size(self.dim)
map_shapes[0] = torch.Size(video_shape)
try:
per_modality_indices = model.map_context_window_to_modalities(
window.index_list, map_shapes, self.dim)
except AttributeError:
raise NotImplementedError(
f"{type(model).__name__} must implement map_context_window_to_modalities for multimodal context windows.")
modality_windows = {}
for mod_idx in range(1, len(self.latents)):
modality_total_frames = self.latents[mod_idx].shape[self.dim]
ratio = modality_total_frames / primary_total if primary_total > 0 else 1
modality_overlap = max(round(primary_overlap * ratio), 0)
modality_windows[mod_idx] = IndexListContextWindow(
per_modality_indices[mod_idx], dim=self.dim,
total_frames=modality_total_frames,
context_overlap=modality_overlap)
return IndexListContextWindow(
window.index_list, dim=self.dim, total_frames=x.shape[self.dim],
modality_windows=modality_windows, context_overlap=primary_overlap)
def slice_for_window(self, window: IndexListContextWindow, retain_index_list: list[int], device=None) -> tuple[list[torch.Tensor], list[int]]:
"""Slice latents for a context window, injecting guide frames where applicable.
For multimodal contexts, uses the modality-specific windows derived in prepare_window().
"""
sliced = []
guide_frame_counts = []
for idx in range(len(self.latents)):
modality_window = window.get_window_for_modality(idx)
retain = retain_index_list if idx == 0 else []
s = modality_window.get_tensor(self.latents[idx], device, retain_index_list=retain)
if self.guide_entries[idx] is not None:
s, ng = self._inject_guide_frames(s, modality_window, modality_idx=idx)
else:
ng = 0
sliced.append(s)
guide_frame_counts.append(ng)
return sliced, guide_frame_counts
def strip_guide_frames(self, out_per_modality: list[list[torch.Tensor]], guide_frame_counts: list[int], window: IndexListContextWindow):
"""Strip injected guide frames from per-cond, per-modality outputs in place."""
for idx in range(len(self.latents)):
if guide_frame_counts[idx] > 0:
window_len = len(window.get_window_for_modality(idx).index_list)
for ci in range(len(out_per_modality)):
out_per_modality[ci][idx] = out_per_modality[ci][idx].narrow(self.dim, 0, window_len)
def _inject_guide_frames(self, latent_slice: torch.Tensor, window: IndexListContextWindow, modality_idx: int = 0) -> tuple[torch.Tensor, int]:
guide_entries = self.guide_entries[modality_idx]
guide_frames = self.guide_latents[modality_idx]
keyframe_idxs = self.keyframe_idxs[modality_idx]
suffix_idx, overlap_info, kf_local_pos, guide_frame_count = compute_guide_overlap(
guide_entries, keyframe_idxs, self.temporal_downscale_ratio, window.index_list)
# Shift keyframe positions to account for causal_window_fix anchor occupying sub-pos 0.
anchor_idx = getattr(window, 'causal_anchor_index', None)
if anchor_idx is not None and anchor_idx >= 0:
kf_local_pos = [p + 1 for p in kf_local_pos]
window.guide_frames_indices = suffix_idx
window.guide_overlap_info = overlap_info
window.guide_kf_local_positions = kf_local_pos
# Derive per-overlap-entry latent_downscale_factor from guide entry latent_shape vs guide frame spatial dims.
# guide_frames has full (post-dilation) spatial dims; entry["latent_shape"] has pre-dilation dims.
guide_downscale_factors = []
if guide_frame_count > 0:
full_H = guide_frames.shape[3]
for entry_idx, _ in overlap_info:
entry_H = guide_entries[entry_idx]["latent_shape"][1]
guide_downscale_factors.append(full_H // entry_H)
window.guide_downscale_factors = guide_downscale_factors
if guide_frame_count > 0:
idx = tuple([slice(None)] * self.dim + [suffix_idx])
return torch.cat([latent_slice, guide_frames[idx]], dim=self.dim), guide_frame_count
return latent_slice, 0
def patch_latent_shapes(self, sub_conds, new_shapes):
if not self.is_multimodal:
return
for cond_list in sub_conds:
if cond_list is None:
continue
for cond_dict in cond_list:
model_conds = cond_dict.get('model_conds', {})
if 'latent_shapes' in model_conds:
model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes)
@dataclass
class ContextSchedule:
name: str
@ -162,7 +341,7 @@ ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_co
class IndexListContextHandler(ContextHandlerABC):
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False,
causal_window_fix: bool=True):
latent_retain_index_list: list[int]=[], causal_window_fix: bool=True):
self.context_schedule = context_schedule
self.fuse_method = fuse_method
self.context_length = context_length
@ -174,17 +353,118 @@ class IndexListContextHandler(ContextHandlerABC):
self.freenoise = freenoise
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
self.split_conds_to_windows = split_conds_to_windows
self.latent_retain_index_list = [int(x.strip()) for x in latent_retain_index_list.split(",")] if latent_retain_index_list else []
self.causal_window_fix = causal_window_fix
self.callbacks = {}
@staticmethod
def _get_latent_shapes(conds):
for cond_list in conds:
if cond_list is None:
continue
for cond_dict in cond_list:
model_conds = cond_dict.get('model_conds', {})
if 'latent_shapes' in model_conds:
return model_conds['latent_shapes'].cond
return None
@staticmethod
def _get_guide_entries(conds):
for cond_list in conds:
if cond_list is None:
continue
for cond_dict in cond_list:
model_conds = cond_dict.get('model_conds', {})
entries = model_conds.get('guide_attention_entries')
if entries is not None and hasattr(entries, 'cond') and entries.cond:
return entries.cond
return None
@staticmethod
def _get_keyframe_idxs(conds):
for cond_list in conds:
if cond_list is None:
continue
for cond_dict in cond_list:
model_conds = cond_dict.get('model_conds', {})
kf = model_conds.get('keyframe_idxs')
if kf is not None and hasattr(kf, 'cond') and kf.cond is not None:
return kf.cond
return None
def _apply_freenoise(self, noise: torch.Tensor, conds: list[list[dict]], seed: int) -> torch.Tensor:
"""Apply FreeNoise shuffling, scaling context length/overlap per-modality by frame ratio.
If guide frames are present on the primary modality, only the video portion is shuffled.
"""
guide_entries = self._get_guide_entries(conds)
guide_count = sum(e["latent_shape"][0] for e in guide_entries) if guide_entries else 0
latent_shapes = self._get_latent_shapes(conds)
if latent_shapes is not None and len(latent_shapes) > 1:
modalities = comfy.utils.unpack_latents(noise, latent_shapes)
primary_total = latent_shapes[0][self.dim]
primary_video_count = modalities[0].size(self.dim) - guide_count
apply_freenoise(modalities[0].narrow(self.dim, 0, primary_video_count), self.dim, self.context_length, self.context_overlap, seed)
for i in range(1, len(modalities)):
mod_total = latent_shapes[i][self.dim]
ratio = mod_total / primary_total if primary_total > 0 else 1
mod_ctx_len = max(round(self.context_length * ratio), 1)
mod_ctx_overlap = max(round(self.context_overlap * ratio), 0)
modalities[i] = apply_freenoise(modalities[i], self.dim, mod_ctx_len, mod_ctx_overlap, seed)
noise, _ = comfy.utils.pack_latents(modalities)
return noise
video_count = noise.size(self.dim) - guide_count
apply_freenoise(noise.narrow(self.dim, 0, video_count), self.dim, self.context_length, self.context_overlap, seed)
return noise
def _build_window_state(self, x_in: torch.Tensor, conds: list[list[dict]], model: BaseModel) -> WindowingState:
"""Build windowing state for the current step, including unpacking latents and extracting guide frame info from conds."""
latent_shapes = self._get_latent_shapes(conds)
is_multimodal = latent_shapes is not None and len(latent_shapes) > 1
unpacked_latents = comfy.utils.unpack_latents(x_in, latent_shapes) if is_multimodal else [x_in]
unpacked_latents_list = list(unpacked_latents)
guide_latents_list = [None] * len(unpacked_latents)
guide_entries_list = [None] * len(unpacked_latents)
keyframe_idxs_list = [None] * len(unpacked_latents)
extracted_guide_entries = self._get_guide_entries(conds)
extracted_keyframe_idxs = self._get_keyframe_idxs(conds)
# Strip guide frames (only from first modality for now)
if extracted_guide_entries is not None:
guide_count = sum(e["latent_shape"][0] for e in extracted_guide_entries)
if guide_count > 0:
x = unpacked_latents[0]
latent_count = x.size(self.dim) - guide_count
unpacked_latents_list[0] = x.narrow(self.dim, 0, latent_count)
guide_latents_list[0] = x.narrow(self.dim, latent_count, guide_count)
guide_entries_list[0] = extracted_guide_entries
keyframe_idxs_list[0] = extracted_keyframe_idxs
return WindowingState(
latents=unpacked_latents_list,
guide_latents=guide_latents_list,
guide_entries=guide_entries_list,
keyframe_idxs=keyframe_idxs_list,
latent_shapes=latent_shapes,
dim=self.dim,
is_multimodal=is_multimodal,
temporal_downscale_ratio=model.latent_format.temporal_downscale_ratio)
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
if x_in.size(self.dim) > self.context_length:
logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.")
window_state = self._build_window_state(x_in, conds, model) # build window_state to check frame counts, will be built again in execute
total_frame_count = window_state.latents[0].size(self.dim)
if total_frame_count > self.context_length:
logging.info(f"\nUsing context windows: Context length {self.context_length} with overlap {self.context_overlap} for {total_frame_count} frames.")
if self.cond_retain_index_list:
logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
if self.latent_retain_index_list:
logging.info(f"Retaining original latent for indexes: {self.latent_retain_index_list}")
return True
logging.info(f"\nNot using context windows since context length ({self.context_length}) exceeds input frames ({total_frame_count}).")
return False
def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase:
@ -275,7 +555,9 @@ class IndexListContextHandler(ContextHandlerABC):
return resized_cond
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
sample_sigmas = model_options["transformer_options"]["sample_sigmas"]
current_timestep = timestep[0].to(sample_sigmas.dtype)
mask = torch.isclose(sample_sigmas, current_timestep, rtol=0.0001)
matches = torch.nonzero(mask)
if torch.numel(matches) == 0:
return # substep from multi-step sampler: keep self._step from the last full step
@ -284,54 +566,98 @@ class IndexListContextHandler(ContextHandlerABC):
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
full_length = x_in.size(self.dim) # TODO: choose dim based on model
context_windows = self.context_schedule.func(full_length, self, model_options)
context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows]
context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length, context_overlap=self.context_overlap) for window in context_windows]
return context_windows
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
self._model = model
self.set_step(timestep, model_options)
context_windows = self.get_context_windows(model, x_in, model_options)
enumerated_context_windows = list(enumerate(context_windows))
conds_final = [torch.zeros_like(x_in) for _ in conds]
window_state = self._build_window_state(x_in, conds, model)
num_modalities = len(window_state.latents)
context_windows = self.get_context_windows(model, window_state.latents[0], model_options)
enumerated_context_windows = list(enumerate(context_windows))
total_windows = len(enumerated_context_windows)
# Initialize per-modality accumulators (length 1 for single-modality)
accum = [[torch.zeros_like(m) for _ in conds] for m in window_state.latents]
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in window_state.latents]
else:
counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds]
counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in window_state.latents]
biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in window_state.latents]
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
callback(self, model, x_in, conds, timestep, model_options)
# accumulate results from each context window
for enum_window in enumerated_context_windows:
results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options)
results = self.evaluate_context_windows(
calc_cond_batch, model, x_in, conds, timestep, [enum_window],
model_options, window_state=window_state, total_windows=total_windows)
for result in results:
self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep,
conds_final, counts_final, biases_final)
# result.sub_conds_out is per-cond, per-modality: list[list[Tensor]]
for mod_idx in range(num_modalities):
mod_out = [result.sub_conds_out[ci][mod_idx] for ci in range(len(conds))]
modality_window = result.window.get_window_for_modality(mod_idx)
self.combine_context_window_results(
window_state.latents[mod_idx], mod_out, result.sub_conds, modality_window,
result.window_idx, total_windows, timestep,
accum[mod_idx], counts[mod_idx], biases[mod_idx])
# fuse accumulated results into final conds
try:
# finalize conds
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
# relative is already normalized, so return as is
del counts_final
return conds_final
else:
# normalize conds via division by context usage counts
for i in range(len(conds_final)):
conds_final[i] /= counts_final[i]
del counts_final
return conds_final
result_out = []
for ci in range(len(conds)):
finalized = []
for mod_idx in range(num_modalities):
if self.fuse_method.name != ContextFuseMethods.RELATIVE:
accum[mod_idx][ci] /= counts[mod_idx][ci]
f = accum[mod_idx][ci]
# if guide frames were injected, append them to the end of the fused latents for the next step
if window_state.guide_latents[mod_idx] is not None:
f = torch.cat([f, window_state.guide_latents[mod_idx]], dim=self.dim)
finalized.append(f)
# pack modalities together if needed
if window_state.is_multimodal and len(finalized) > 1:
packed, _ = comfy.utils.pack_latents(finalized)
else:
packed = finalized[0]
result_out.append(packed)
return result_out
finally:
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
callback(self, model, x_in, conds, timestep, model_options)
def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
model_options, device=None, first_device=None):
def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds,
timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
model_options, window_state: WindowingState, total_windows: int = None,
device=None, first_device=None):
"""Evaluate context windows and return per-cond, per-modality outputs in ContextResults.sub_conds_out
For each window:
1. Builds windows (for each modality if multimodal)
2. Slices window for each modality
3. Injects concatenated latent guide frames where present
4. Packs together if needed and calls model
5. Unpacks and strips any guides from outputs
"""
x = window_state.latents[0]
results: list[ContextResults] = []
for window_idx, window in enumerated_context_windows:
# allow processing to end between context window executions for faster Cancel
comfy.model_management.throw_exception_if_processing_interrupted()
# causal_window_fix: prepend a pre-window frame that will be stripped post-forward
# prepare the window accounting for multimodal windows
window = window_state.prepare_window(window, model)
# causal_window_fix: prepend a pre-window frame that will be stripped post-forward.
# Set anchor before slice_for_window so the latent slice and downstream cond slices both pick it up.
anchor_applied = False
if self.causal_window_fix:
anchor_idx = window.index_list[0] - 1
@ -339,27 +665,46 @@ class IndexListContextHandler(ContextHandlerABC):
window.causal_anchor_index = anchor_idx
anchor_applied = True
# slice the window for each modality, injecting guide frames where applicable
sliced, guide_frame_counts_per_modality = window_state.slice_for_window(window, self.latent_retain_index_list, device)
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
# update exposed params
logging.info(f"Context window {window_idx + 1}/{total_windows or len(enumerated_context_windows)}: frames {window.index_list[0]}-{window.index_list[-1]} of {x.shape[self.dim]}"
+ (f" (+{guide_frame_counts_per_modality[0]} guide frames)" if guide_frame_counts_per_modality[0] > 0 else "")
)
# if multimodal, pack modalities together
if window_state.is_multimodal and len(sliced) > 1:
sub_x, sub_shapes = comfy.utils.pack_latents(sliced)
else:
sub_x, sub_shapes = sliced[0], [sliced[0].shape]
# get resized conds for window
model_options["transformer_options"]["context_window"] = window
# get subsections of x, timestep, conds
sub_x = window.get_tensor(x_in, device)
sub_timestep = window.get_tensor(timestep, device, dim=0)
sub_conds = [self.get_resized_cond(cond, x_in, window, device) for cond in conds]
sub_timestep = window.get_tensor(timestep, dim=0)
sub_conds = [self.get_resized_cond(cond, x, window) for cond in conds]
# if multimodal, patch latent_shapes in conds for correct unpacking in model
window_state.patch_latent_shapes(sub_conds, sub_shapes)
# call model on window
sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options)
if device is not None:
for i in range(len(sub_conds_out)):
sub_conds_out[i] = sub_conds_out[i].to(x_in.device)
# strip causal_window_fix anchor if applied
# unpack outputs
out_per_modality = [comfy.utils.unpack_latents(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))]
# strip causal_window_fix anchor from primary modality before guide strip so window_len math stays correct
if anchor_applied:
for i in range(len(sub_conds_out)):
sub_conds_out[i] = sub_conds_out[i].narrow(self.dim, 1, sub_conds_out[i].shape[self.dim] - 1)
for ci in range(len(out_per_modality)):
t = out_per_modality[ci][0]
out_per_modality[ci][0] = t.narrow(self.dim, 1, t.shape[self.dim] - 1)
results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
# strip injected guide frames
window_state.strip_guide_frames(out_per_modality, guide_frame_counts_per_modality, window)
results.append(ContextResults(window_idx, out_per_modality, sub_conds, window))
return results
@ -383,7 +728,7 @@ class IndexListContextHandler(ContextHandlerABC):
biases_final[i][idx] = bias_total + bias
else:
# add conds and counts based on weights of fuse method
weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep)
weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep, context_overlap=window.context_overlap)
weights_tensor = match_weights_to_dim(weights, x_in, self.dim, device=x_in.device)
for i in range(len(sub_conds_out)):
window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor)
@ -393,16 +738,22 @@ class IndexListContextHandler(ContextHandlerABC):
callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final)
def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args, **kwargs):
# limit noise_shape length to context_length for more accurate vram use estimation
def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, conds, *args, **kwargs):
# Scale noise_shape to a single context window so VRAM estimation budgets per-window.
model_options = kwargs.get("model_options", None)
if model_options is None:
raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.")
handler: IndexListContextHandler = model_options.get("context_handler", None)
if handler is not None:
noise_shape = list(noise_shape)
noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length)
return executor(model, noise_shape, *args, **kwargs)
is_packed = len(noise_shape) == 3 and noise_shape[1] == 1
if is_packed:
# TODO: latent_shapes cond isn't attached yet at this point, so we can't compute a
# per-window flat latent here. Skipping the clamp over-estimates but prevents immediate OOM.
pass
elif handler.dim < len(noise_shape) and noise_shape[handler.dim] > handler.context_length:
noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length)
return executor(model, noise_shape, conds, *args, **kwargs)
def create_prepare_sampling_wrapper(model: ModelPatcher):
@ -422,11 +773,12 @@ def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, nois
raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
if not handler.freenoise:
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
conds = [guider.conds.get('positive', guider.conds.get('negative', []))]
noise = handler._apply_freenoise(noise, conds, extra_args["seed"])
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
def create_sampler_sample_wrapper(model: ModelPatcher):
model.add_wrapper_with_key(
comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
@ -434,7 +786,6 @@ def create_sampler_sample_wrapper(model: ModelPatcher):
_sampler_sample_wrapper
)
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
total_dims = len(x_in.shape)
weights_tensor = torch.Tensor(weights).to(device=device)
@ -580,8 +931,9 @@ def get_matching_context_schedule(context_schedule: str) -> ContextSchedule:
return ContextSchedule(context_schedule, func)
def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None):
return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs)
def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None, context_overlap: int=None):
context_overlap = handler.context_overlap if context_overlap is None else context_overlap
return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs, context_overlap=context_overlap)
def create_weights_flat(length: int, **kwargs) -> list[float]:
@ -599,18 +951,18 @@ def create_weights_pyramid(length: int, **kwargs) -> list[float]:
weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1))
return weight_sequence
def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, **kwargs):
def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], context_overlap: int, **kwargs):
# based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302
# only expected overlap is given different weights
weights_torch = torch.ones((length))
# blend left-side on all except first window
if min(idxs) > 0:
ramp_up = torch.linspace(1e-37, 1, handler.context_overlap)
weights_torch[:handler.context_overlap] = ramp_up
ramp_up = torch.linspace(1e-37, 1, context_overlap)
weights_torch[:context_overlap] = ramp_up
# blend right-side on all except last window
if max(idxs) < full_length-1:
ramp_down = torch.linspace(1, 1e-37, handler.context_overlap)
weights_torch[-handler.context_overlap:] = ramp_down
ramp_down = torch.linspace(1, 1e-37, context_overlap)
weights_torch[-context_overlap:] = ramp_down
return weights_torch
class ContextFuseMethods:

321
comfy/ldm/boogu/model.py Normal file
View File

@ -0,0 +1,321 @@
# Boogu-Image-0.1 transformer
# Architecture is an OmniGen2 derivative (see comfy/ldm/omnigen/omnigen2.py) with an
# added dual-stream ("double_stream") stage before the single-stream layers, conditioned
# by a Qwen3-VL multimodal LLM. Reuses the OmniGen2/Lumina building blocks and the Flux
# RoPE core, the only new component is the double-stream block + the hybrid forward order.
from typing import Optional, Tuple
import torch
import torch.nn as nn
from einops import rearrange
import comfy.ldm.common_dit
import comfy.ldm.omnigen.omnigen2
from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.omnigen.omnigen2 import (
OmniGen2RotaryPosEmbed,
Lumina2CombinedTimestepCaptionEmbedding,
LuminaRMSNormZero,
LuminaLayerNormContinuous,
LuminaFeedForward,
Attention,
OmniGen2TransformerBlock,
apply_rotary_emb,
)
class BooguDoubleStreamProcessor(nn.Module):
# Joint attention over [instruct ; img] with separate per-stream q/k/v and output projections.
def __init__(self, dim, head_dim, heads, kv_heads, dtype=None, device=None, operations=None):
super().__init__()
query_dim = head_dim * heads
kv_dim = head_dim * kv_heads
self.img_to_q = operations.Linear(query_dim, query_dim, bias=False, dtype=dtype, device=device)
self.img_to_k = operations.Linear(query_dim, kv_dim, bias=False, dtype=dtype, device=device)
self.img_to_v = operations.Linear(query_dim, kv_dim, bias=False, dtype=dtype, device=device)
self.instruct_to_q = operations.Linear(query_dim, query_dim, bias=False, dtype=dtype, device=device)
self.instruct_to_k = operations.Linear(query_dim, kv_dim, bias=False, dtype=dtype, device=device)
self.instruct_to_v = operations.Linear(query_dim, kv_dim, bias=False, dtype=dtype, device=device)
self.instruct_out = operations.Linear(query_dim, query_dim, bias=False, dtype=dtype, device=device)
self.img_out = operations.Linear(query_dim, query_dim, bias=False, dtype=dtype, device=device)
def forward(self, attn, img_hidden_states, instruct_hidden_states, rotary_emb, attention_mask=None, transformer_options={}):
batch_size = img_hidden_states.shape[0]
L_instruct = instruct_hidden_states.shape[1]
img_q = self.img_to_q(img_hidden_states)
img_k = self.img_to_k(img_hidden_states)
img_v = self.img_to_v(img_hidden_states)
instruct_q = self.instruct_to_q(instruct_hidden_states)
instruct_k = self.instruct_to_k(instruct_hidden_states)
instruct_v = self.instruct_to_v(instruct_hidden_states)
# Concatenate instruction first, then image (matches reference processor order).
query = torch.cat([instruct_q, img_q], dim=1)
key = torch.cat([instruct_k, img_k], dim=1)
value = torch.cat([instruct_v, img_v], dim=1)
query = query.view(batch_size, -1, attn.heads, attn.dim_head)
key = key.view(batch_size, -1, attn.kv_heads, attn.dim_head)
value = value.view(batch_size, -1, attn.kv_heads, attn.dim_head)
query = attn.norm_q(query)
key = attn.norm_k(key)
if rotary_emb is not None:
query = apply_rotary_emb(query, rotary_emb)
key = apply_rotary_emb(key, rotary_emb)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
if attn.kv_heads < attn.heads:
key = key.repeat_interleave(attn.heads // attn.kv_heads, dim=1)
value = value.repeat_interleave(attn.heads // attn.kv_heads, dim=1)
hidden_states = optimized_attention_masked(query, key, value, attn.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options)
# Split back to instruction/image, apply per-stream output projections, recombine.
instruct_hidden_states = self.instruct_out(hidden_states[:, :L_instruct])
img_hidden_states = self.img_out(hidden_states[:, L_instruct:])
hidden_states = torch.cat([instruct_hidden_states, img_hidden_states], dim=1)
hidden_states = attn.to_out[0](hidden_states)
return hidden_states
class BooguJointAttention(nn.Module):
# Holds the shared q/k RMSNorm + final output projection
def __init__(self, dim, head_dim, heads, kv_heads, eps=1e-5, dtype=None, device=None, operations=None):
super().__init__()
self.heads = heads
self.kv_heads = kv_heads
self.dim_head = head_dim
self.scale = head_dim ** -0.5
self.norm_q = operations.RMSNorm(head_dim, eps=eps, dtype=dtype, device=device)
self.norm_k = operations.RMSNorm(head_dim, eps=eps, dtype=dtype, device=device)
self.to_out = nn.Sequential(
operations.Linear(heads * head_dim, dim, bias=False, dtype=dtype, device=device),
nn.Dropout(0.0),
)
self.processor = BooguDoubleStreamProcessor(dim, head_dim, heads, kv_heads, dtype=dtype, device=device, operations=operations)
def forward(self, img_hidden_states, instruct_hidden_states, rotary_emb, attention_mask=None, transformer_options={}):
return self.processor(self, img_hidden_states, instruct_hidden_states, rotary_emb, attention_mask, transformer_options=transformer_options)
class BooguDoubleStreamBlock(nn.Module):
# Dual-stream block: joint attention over [instruct ; img] + image self-attention, each stream with its own modulation/MLP.
def __init__(self, dim, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, dtype=None, device=None, operations=None):
super().__init__()
head_dim = dim // num_attention_heads
self.img_instruct_attn = BooguJointAttention(dim, head_dim, num_attention_heads, num_kv_heads, eps=1e-5, dtype=dtype, device=device, operations=operations)
self.img_self_attn = Attention(
query_dim=dim, dim_head=head_dim, heads=num_attention_heads, kv_heads=num_kv_heads,
eps=1e-5, bias=False, dtype=dtype, device=device, operations=operations,
)
self.img_feed_forward = LuminaFeedForward(dim=dim, inner_dim=4 * dim, multiple_of=multiple_of, dtype=dtype, device=device, operations=operations)
self.instruct_feed_forward = LuminaFeedForward(dim=dim, inner_dim=4 * dim, multiple_of=multiple_of, dtype=dtype, device=device, operations=operations)
self.img_norm1 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, dtype=dtype, device=device, operations=operations)
self.img_norm2 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, dtype=dtype, device=device, operations=operations)
self.img_norm3 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, dtype=dtype, device=device, operations=operations)
self.instruct_norm1 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, dtype=dtype, device=device, operations=operations)
self.instruct_norm2 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, dtype=dtype, device=device, operations=operations)
self.img_attn_norm = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
self.img_self_attn_norm = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
self.img_ffn_norm1 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
self.img_ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
self.instruct_attn_norm = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
self.instruct_ffn_norm1 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
self.instruct_ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
def forward(self, img_hidden_states, instruct_hidden_states, joint_rotary_emb, img_rotary_emb, temb, joint_attention_mask=None, img_attention_mask=None, transformer_options={}):
L_instruct = instruct_hidden_states.shape[1]
img_norm1_out, img_gate_msa, img_scale_mlp, img_gate_mlp = self.img_norm1(img_hidden_states, temb)
img_norm2_out, img_shift_mlp, _, _ = self.img_norm2(img_hidden_states, temb)
img_norm3_out, img_gate_self, _, _ = self.img_norm3(img_hidden_states, temb)
instruct_norm1_out, instruct_gate_msa, instruct_scale_mlp, instruct_gate_mlp = self.instruct_norm1(instruct_hidden_states, temb)
instruct_norm2_out, instruct_shift_mlp, _, _ = self.instruct_norm2(instruct_hidden_states, temb)
joint_attn_out = self.img_instruct_attn(img_norm1_out, instruct_norm1_out, joint_rotary_emb, joint_attention_mask, transformer_options=transformer_options)
instruct_attn_out = joint_attn_out[:, :L_instruct]
img_attn_out = joint_attn_out[:, L_instruct:]
img_self_attn_out = self.img_self_attn(img_norm3_out, img_norm3_out, img_attention_mask, img_rotary_emb, transformer_options=transformer_options)
img_hidden_states = img_hidden_states + img_gate_msa.unsqueeze(1).tanh() * self.img_attn_norm(img_attn_out)
img_hidden_states = img_hidden_states + img_gate_self.unsqueeze(1).tanh() * self.img_self_attn_norm(img_self_attn_out)
img_mlp_input = (1 + img_scale_mlp.unsqueeze(1)) * img_norm2_out + img_shift_mlp.unsqueeze(1)
img_mlp_out = self.img_feed_forward(self.img_ffn_norm1(img_mlp_input))
img_hidden_states = img_hidden_states + img_gate_mlp.unsqueeze(1).tanh() * self.img_ffn_norm2(img_mlp_out)
instruct_hidden_states = instruct_hidden_states + instruct_gate_msa.unsqueeze(1).tanh() * self.instruct_attn_norm(instruct_attn_out)
instruct_mlp_input = (1 + instruct_scale_mlp.unsqueeze(1)) * instruct_norm2_out + instruct_shift_mlp.unsqueeze(1)
instruct_mlp_out = self.instruct_feed_forward(self.instruct_ffn_norm1(instruct_mlp_input))
instruct_hidden_states = instruct_hidden_states + instruct_gate_mlp.unsqueeze(1).tanh() * self.instruct_ffn_norm2(instruct_mlp_out)
return img_hidden_states, instruct_hidden_states
class BooguTransformer2DModel(nn.Module):
def __init__(
self,
patch_size: int = 2,
in_channels: int = 16,
out_channels: Optional[int] = None,
hidden_size: int = 3360,
num_layers: int = 32,
num_double_stream_layers: int = 8,
num_refiner_layers: int = 2,
num_attention_heads: int = 28,
num_kv_heads: int = 7,
multiple_of: int = 256,
ffn_dim_multiplier: Optional[float] = None,
norm_eps: float = 1e-5,
axes_dim_rope: Tuple[int, int, int] = (40, 40, 40),
axes_lens: Tuple[int, int, int] = (2048, 1664, 1664),
instruction_feat_dim: int = 4096,
timestep_scale: float = 1000.0,
image_model=None,
device=None, dtype=None, operations=None,
):
super().__init__()
self.patch_size = patch_size
self.out_channels = out_channels or in_channels
self.hidden_size = hidden_size
self.dtype = dtype
self.rope_embedder = OmniGen2RotaryPosEmbed(
theta=10000,
axes_dim=axes_dim_rope,
axes_lens=axes_lens,
patch_size=patch_size,
)
self.x_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device)
self.ref_image_patch_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device)
self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
hidden_size=hidden_size,
text_feat_dim=instruction_feat_dim,
norm_eps=norm_eps,
timestep_scale=timestep_scale, dtype=dtype, device=device, operations=operations
)
self.noise_refiner = nn.ModuleList([
OmniGen2TransformerBlock(hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations)
for _ in range(num_refiner_layers)
])
self.ref_image_refiner = nn.ModuleList([
OmniGen2TransformerBlock(hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations)
for _ in range(num_refiner_layers)
])
self.context_refiner = nn.ModuleList([
OmniGen2TransformerBlock(hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=False, dtype=dtype, device=device, operations=operations)
for _ in range(num_refiner_layers)
])
self.double_stream_layers = nn.ModuleList([
BooguDoubleStreamBlock(hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, dtype=dtype, device=device, operations=operations)
for _ in range(num_double_stream_layers)
])
self.single_stream_layers = nn.ModuleList([
OmniGen2TransformerBlock(hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations)
for _ in range(num_layers)
])
self.norm_out = LuminaLayerNormContinuous(
embedding_dim=hidden_size,
conditioning_embedding_dim=min(hidden_size, 1024),
elementwise_affine=False,
eps=1e-6,
out_dim=patch_size * patch_size * self.out_channels, dtype=dtype, device=device, operations=operations
)
self.image_index_embedding = nn.Parameter(torch.empty(5, hidden_size, device=device, dtype=dtype))
# Patchify/refine helpers are identical to OmniGen2; reuse via bound methods.
flat_and_pad_to_seq = comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel.flat_and_pad_to_seq
img_patch_embed_and_refine = comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel.img_patch_embed_and_refine
def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, transformer_options={}, **kwargs):
B, C, H, W = x.shape
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
_, _, H_padded, W_padded = hidden_states.shape
timestep = 1.0 - timesteps
text_hidden_states = context
text_attention_mask = attention_mask
ref_image_hidden_states = ref_latents
device = hidden_states.device
temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype)
(
hidden_states, ref_image_hidden_states,
img_mask, ref_img_mask,
l_effective_ref_img_len, l_effective_img_len,
ref_img_sizes, img_sizes,
) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)
(
context_rotary_emb, ref_img_rotary_emb, noise_rotary_emb,
rotary_emb, encoder_seq_lengths, seq_lengths,
) = self.rope_embedder(
hidden_states.shape[0], text_hidden_states.shape[1], [num_tokens] * text_hidden_states.shape[0],
l_effective_ref_img_len, l_effective_img_len,
ref_img_sizes, img_sizes, device,
)
for layer in self.context_refiner:
text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb, transformer_options=transformer_options)
img_len = hidden_states.shape[1]
combined_img_hidden_states = self.img_patch_embed_and_refine(
hidden_states, ref_image_hidden_states,
img_mask, ref_img_mask,
noise_rotary_emb, ref_img_rotary_emb,
l_effective_ref_img_len, l_effective_img_len,
temb,
transformer_options=transformer_options,
)
# Double-stream stage: the image self-attention only sees the [ref ; noise] tokens,
# which sit after the instruction tokens in the joint rope.
L_instruct = text_hidden_states.shape[1]
combined_img_rotary_emb = rotary_emb[:, L_instruct:]
for layer in self.double_stream_layers:
combined_img_hidden_states, text_hidden_states = layer(
combined_img_hidden_states, text_hidden_states,
rotary_emb, combined_img_rotary_emb, temb,
joint_attention_mask=None, img_attention_mask=None,
transformer_options=transformer_options,
)
hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1)
for layer in self.single_stream_layers:
hidden_states = layer(hidden_states, None, rotary_emb, temb, transformer_options=transformer_options)
hidden_states = self.norm_out(hidden_states, temb)
p = self.patch_size
output = rearrange(hidden_states[:, -img_len:], 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', h=H_padded // p, w=W_padded // p, p1=p, p2=p)[:, :, :H, :W]
return -output

View File

@ -515,7 +515,7 @@ class Block(nn.Module):
h=H,
w=W,
)
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
x_B_T_H_W_D = torch.addcmul(x_B_T_H_W_D, gate_self_attn_B_T_1_1_D.to(residual_dtype), result_B_T_H_W_D.to(residual_dtype))
def _x_fn(
_x_B_T_H_W_D: torch.Tensor,
@ -548,7 +548,7 @@ class Block(nn.Module):
shift_cross_attn_B_T_1_1_D,
transformer_options=transformer_options,
)
x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D
x_B_T_H_W_D = torch.addcmul(x_B_T_H_W_D, gate_cross_attn_B_T_1_1_D.to(residual_dtype), result_B_T_H_W_D.to(residual_dtype))
normalized_x_B_T_H_W_D = _fn(
x_B_T_H_W_D,
@ -557,7 +557,7 @@ class Block(nn.Module):
shift_mlp_B_T_1_1_D,
)
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype))
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
x_B_T_H_W_D = torch.addcmul(x_B_T_H_W_D, gate_mlp_B_T_1_1_D.to(residual_dtype), result_B_T_H_W_D.to(residual_dtype))
return x_B_T_H_W_D

290
comfy/ldm/krea2/model.py Normal file
View File

@ -0,0 +1,290 @@
"""Krea 2 (K2) — single-stream MMDiT.
Text tokens produced by a Qwen3-VL-4B 12-layer ``txtfusion`` adapter and patchified image tokens are
concatenated into one sequence and run through ``layers`` shared transformer blocks with
AdaLN-single modulation, GQA + per-head QK-norm + sigmoid-gated attention, SwiGLU MLP, and 3-axis RoPE.
"""
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import comfy.model_management
import comfy.patcher_extension
import comfy.ldm.common_dit
from comfy.ldm.flux.layers import EmbedND, timestep_embedding
from comfy.ldm.flux.math import apply_rope
from comfy.ldm.modules.attention import optimized_attention_masked
class RMSNorm(nn.Module):
"""RMSNorm with the reference ``(1 + scale)`` weight convention (scale stored zero-centered)."""
def __init__(self, features: int, eps: float = 1e-5, device=None, dtype=None, operations=None):
super().__init__()
self.eps = eps
self.scale = nn.Parameter(torch.empty(features, device=device, dtype=dtype))
def forward(self, x: torch.Tensor) -> torch.Tensor:
dtype = x.dtype
weight = comfy.model_management.cast_to(self.scale, dtype=torch.float32, device=x.device) + 1.0
return F.rms_norm(x.float(), (x.shape[-1],), weight=weight, eps=self.eps).to(dtype)
class QKNorm(nn.Module):
def __init__(self, dim: int, device=None, dtype=None, operations=None):
super().__init__()
self.qnorm = RMSNorm(dim, device=device, dtype=dtype, operations=operations)
self.knorm = RMSNorm(dim, device=device, dtype=dtype, operations=operations)
def forward(self, q, k):
return self.qnorm(q), self.knorm(k)
class SwiGLU(nn.Module):
def __init__(self, features: int, multiplier: int, bias: bool = False, multiple: int = 128,
device=None, dtype=None, operations=None):
super().__init__()
mlpdim = int(2 * features / 3) * multiplier
mlpdim = multiple * ((mlpdim + multiple - 1) // multiple)
self.gate = operations.Linear(features, mlpdim, bias=bias, device=device, dtype=dtype)
self.up = operations.Linear(features, mlpdim, bias=bias, device=device, dtype=dtype)
self.down = operations.Linear(mlpdim, features, bias=bias, device=device, dtype=dtype)
def forward(self, x):
return self.down(F.silu(self.gate(x)).mul_(self.up(x)))
class Attention(nn.Module):
def __init__(self, dim: int, heads: int, kvheads: Optional[int] = None, bias: bool = False,
device=None, dtype=None, operations=None):
super().__init__()
self.heads = heads
self.kvheads = kvheads if kvheads is not None else heads
self.headdim = dim // self.heads
self.wq = operations.Linear(dim, self.headdim * self.heads, bias=bias, device=device, dtype=dtype)
self.wk = operations.Linear(dim, self.headdim * self.kvheads, bias=bias, device=device, dtype=dtype)
self.wv = operations.Linear(dim, self.headdim * self.kvheads, bias=bias, device=device, dtype=dtype)
self.gate = operations.Linear(dim, dim, bias=bias, device=device, dtype=dtype)
self.qknorm = QKNorm(self.headdim, device=device, dtype=dtype, operations=operations)
self.wo = operations.Linear(dim, dim, bias=bias, device=device, dtype=dtype)
def forward(self, x, freqs=None, mask=None, transformer_options={}):
q, k, v, gate = self.wq(x), self.wk(x), self.wv(x), self.gate(x)
q = rearrange(q, "B L (H D) -> B H L D", H=self.heads)
k = rearrange(k, "B L (H D) -> B H L D", H=self.kvheads)
v = rearrange(v, "B L (H D) -> B H L D", H=self.kvheads)
q, k = self.qknorm(q, k)
if freqs is not None:
q, k = apply_rope(q, k, freqs)
if self.kvheads != self.heads:
rep = self.heads // self.kvheads
k = k.repeat_interleave(rep, dim=1)
v = v.repeat_interleave(rep, dim=1)
out = optimized_attention_masked(q, k, v, self.heads, mask=mask, skip_reshape=True,
transformer_options=transformer_options)
return self.wo(out * F.sigmoid(gate))
class SimpleModulation(nn.Module):
def __init__(self, dim: int, device=None, dtype=None, operations=None):
super().__init__()
self.lin = nn.Parameter(torch.empty(2, dim, device=device, dtype=dtype))
def forward(self, vec):
out = vec + comfy.model_management.cast_to(self.lin, dtype=vec.dtype, device=vec.device).unsqueeze(0)
scale, shift = out.chunk(2, dim=1)
return scale, shift
class DoubleSharedModulation(nn.Module):
def __init__(self, dim: int, device=None, dtype=None, operations=None):
super().__init__()
self.lin = nn.Parameter(torch.empty(6 * dim, device=device, dtype=dtype))
def forward(self, vec):
out = vec + comfy.model_management.cast_to(self.lin, dtype=vec.dtype, device=vec.device)
return out.chunk(6, dim=-1)
class TextFusionBlock(nn.Module):
def __init__(self, features, heads, multiplier, bias=False, kvheads=None, device=None, dtype=None, operations=None):
super().__init__()
self.prenorm = RMSNorm(features, device=device, dtype=dtype, operations=operations)
self.postnorm = RMSNorm(features, device=device, dtype=dtype, operations=operations)
self.attn = Attention(features, heads, kvheads=kvheads, bias=bias, device=device, dtype=dtype, operations=operations)
self.mlp = SwiGLU(features, multiplier, bias, device=device, dtype=dtype, operations=operations)
def forward(self, x, mask=None, transformer_options={}):
x = x + self.attn(self.prenorm(x), mask=mask, transformer_options=transformer_options)
x = x + self.mlp(self.postnorm(x))
return x
class TextFusionTransformer(nn.Module):
def __init__(self, num_txt_layers, txt_dim, heads, multiplier, bias=False, kvheads=None, device=None, dtype=None, operations=None):
super().__init__()
self.layerwise_blocks = nn.ModuleList([
TextFusionBlock(txt_dim, heads, multiplier, bias, kvheads, device=device, dtype=dtype, operations=operations)
for _ in range(2)
])
self.projector = operations.Linear(num_txt_layers, 1, bias=False, device=device, dtype=dtype)
self.refiner_blocks = nn.ModuleList([
TextFusionBlock(txt_dim, heads, multiplier, bias, kvheads, device=device, dtype=dtype, operations=operations)
for _ in range(2)
])
def forward(self, x, mask=None, transformer_options={}):
b, l, n, d = x.shape
x = x.reshape(b * l, n, d)
for block in self.layerwise_blocks:
x = block(x.contiguous(), mask=None, transformer_options=transformer_options)
x = rearrange(x, "(b l) n d -> b l d n", b=b, l=l)
x = self.projector(x).squeeze(-1)
for block in self.refiner_blocks:
x = block(x, mask=mask, transformer_options=transformer_options)
return x
class SingleStreamBlock(nn.Module):
def __init__(self, features, heads, multiplier, bias=False, kvheads=None, device=None, dtype=None, operations=None):
super().__init__()
self.mod = DoubleSharedModulation(features, device=device, dtype=dtype, operations=operations)
self.prenorm = RMSNorm(features, device=device, dtype=dtype, operations=operations)
self.postnorm = RMSNorm(features, device=device, dtype=dtype, operations=operations)
self.attn = Attention(features, heads, kvheads=kvheads, bias=bias, device=device, dtype=dtype, operations=operations)
self.mlp = SwiGLU(features, multiplier, bias, device=device, dtype=dtype, operations=operations)
def forward(self, x, vec, freqs, mask=None, transformer_options={}):
prescale, preshift, pregate, postscale, postshift, postgate = self.mod(vec)
x = x + pregate * self.attn((1 + prescale) * self.prenorm(x) + preshift, freqs, mask, transformer_options=transformer_options)
x = x + postgate * self.mlp((1 + postscale) * self.postnorm(x) + postshift)
return x
class LastLayer(nn.Module):
def __init__(self, features, patch, channels, device=None, dtype=None, operations=None):
super().__init__()
self.norm = RMSNorm(features, device=device, dtype=dtype, operations=operations)
self.linear = operations.Linear(features, patch * patch * channels, bias=True, device=device, dtype=dtype)
self.modulation = SimpleModulation(features, device=device, dtype=dtype, operations=operations)
def forward(self, x, tvec):
scale, shift = self.modulation(tvec)
x = (1 + scale) * self.norm(x) + shift
return self.linear(x)
class SingleStreamDiT(nn.Module):
def __init__(self, features=6144, tdim=256, txtdim=2560, heads=48, kvheads=12, multiplier=4,
layers=28, patch=2, channels=16, bias=False, theta=1e3, txtlayers=12,
txtheads=20, txtkvheads=20, image_model=None,
device=None, dtype=None, operations=None, **kwargs):
super().__init__()
self.dtype = dtype
self.patch = patch
self.channels = channels
self.tdim = tdim
self.heads = heads
self.txtdim = txtdim
self.txtlayers = txtlayers
headdim = features // heads
axes = [headdim - 12 * (headdim // 16), 6 * (headdim // 16), 6 * (headdim // 16)]
assert sum(axes) == headdim, f"axes {axes} sum != headdim {headdim}"
self.pe_embedder = EmbedND(dim=headdim, theta=int(theta), axes_dim=axes)
self.first = operations.Linear(channels * patch ** 2, features, bias=True, device=device, dtype=dtype)
self.blocks = nn.ModuleList([
SingleStreamBlock(features, heads, multiplier, bias, kvheads, device=device, dtype=dtype, operations=operations)
for _ in range(layers)
])
self.tmlp = nn.Sequential(
operations.Linear(tdim, features, device=device, dtype=dtype),
nn.GELU(approximate="tanh"),
operations.Linear(features, features, device=device, dtype=dtype),
)
self.txtfusion = TextFusionTransformer(txtlayers, txtdim, txtheads, multiplier, bias, txtkvheads,
device=device, dtype=dtype, operations=operations)
self.txtmlp = nn.Sequential(
RMSNorm(txtdim, device=device, dtype=dtype, operations=operations),
operations.Linear(txtdim, features, device=device, dtype=dtype),
nn.GELU(approximate="tanh"),
operations.Linear(features, features, device=device, dtype=dtype),
)
self.last = LastLayer(features, patch, channels, device=device, dtype=dtype, operations=operations)
self.tproj = nn.Sequential(
nn.GELU(approximate="tanh"),
operations.Linear(features, features * 6, device=device, dtype=dtype),
)
def forward(self, x, timesteps, context, attention_mask=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
).execute(x, timesteps, context, attention_mask, transformer_options, **kwargs)
def _forward(self, x, timesteps, context, attention_mask=None, transformer_options={}, **kwargs):
temporal = x.ndim == 5
if temporal:
b5, c5, t5, h5, w5 = x.shape
x = x.reshape(b5 * t5, c5, h5, w5)
bs, c, H_orig, W_orig = x.shape
patch = self.patch
# Pad the latent up to a multiple of patch (as Flux/Lumina/QwenImage do); crop back at the end.
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch, patch))
H, W = x.shape[-2], x.shape[-1]
h_, w_ = H // patch, W // patch
# context arrives as (B, seq, txtlayers*txtdim); reshape to (B, txtlayers, seq, txtdim).
context = self._unpack_context(context)
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch, pw=patch)
img = self.first(img)
t = self.tmlp(timestep_embedding(timesteps, self.tdim).unsqueeze(1).to(img.dtype))
tvec = self.tproj(t)
context = self.txtfusion(context, mask=None, transformer_options=transformer_options)
context = self.txtmlp(context)
txtlen, imglen = context.shape[1], img.shape[1]
combined = torch.cat((context, img), dim=1)
# Position ids: text at 0, image at (0, h_idx, w_idx).
device = combined.device
txtpos = torch.zeros(bs, txtlen, 3, device=device, dtype=torch.float32)
imgids = torch.zeros(h_, w_, 3, device=device, dtype=torch.float32)
imgids[..., 1] = torch.arange(h_, device=device, dtype=torch.float32)[:, None]
imgids[..., 2] = torch.arange(w_, device=device, dtype=torch.float32)[None, :]
imgpos = imgids.reshape(1, h_ * w_, 3).repeat(bs, 1, 1)
pos = torch.cat((txtpos, imgpos), dim=1)
freqs = self.pe_embedder(pos)
for block in self.blocks:
combined = block(combined, tvec, freqs, None, transformer_options=transformer_options)
final = self.last(combined, t)
out = final[:, txtlen:txtlen + imglen, :]
out = rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=h_, w=w_, ph=patch, pw=patch, c=self.channels)
out = out[:, :, :H_orig, :W_orig] # crop padding back off
if temporal:
out = out.reshape(b5, t5, self.channels, H_orig, W_orig).movedim(1, 2)
return out
def _unpack_context(self, context):
# context: (B, seq, txtlayers*txtdim) -> (B, seq, txtlayers, txtdim).
b, seq, fused = context.shape
if fused != self.txtlayers * self.txtdim:
raise ValueError(
f"Krea2 expects conditioning with {self.txtlayers}x{self.txtdim}={self.txtlayers * self.txtdim} "
f"features (a {self.txtlayers}-layer Qwen3-VL stack) but got {fused}. "
f"Load the text encoder with CLIPLoader type 'krea2'."
)
return context.reshape(b, seq, self.txtlayers, self.txtdim)

View File

@ -1085,7 +1085,7 @@ class LTXVModel(LTXBaseModel):
)
grid_mask = None
if keyframe_idxs is not None:
if keyframe_idxs is not None and keyframe_idxs.shape[2] > 0:
additional_args.update({ "orig_patchified_shape": list(x.shape)})
denoise_mask = self.patchifier.patchify(denoise_mask)[0]
grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0]
@ -1330,7 +1330,7 @@ class LTXVModel(LTXBaseModel):
x = x * (1 + scale) + shift
x = self.proj_out(x)
if keyframe_idxs is not None:
if keyframe_idxs is not None and keyframe_idxs.shape[2] > 0:
grid_mask = kwargs["grid_mask"]
orig_patchified_shape = kwargs["orig_patchified_shape"]
full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device)

View File

@ -22,7 +22,7 @@ def apply_rotary_emb(x, freqs_cis):
def swiglu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return F.silu(x) * y
return F.silu(x, inplace=True).mul_(y)
class TimestepEmbedding(nn.Module):

View File

@ -1665,7 +1665,7 @@ class SCAILWanModel(WanModel):
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
if ref_mask_latents is not None: # SCAIL-2 additive mask stream
if ref_mask_latents is not None: # SCAIL-2 additive mask stream (one identity mask frame per reference, then video)
x = x + self.patch_embedding_mask(ref_mask_latents.float()).to(x.dtype)
grid_sizes = x.shape[2:]
transformer_options["grid_sizes"] = grid_sizes
@ -1728,22 +1728,25 @@ class SCAILWanModel(WanModel):
# ref_mask_flag is a scalar bool (CONDConstant, SCAIL-2 only). False => replacement mode,
# which places ref/pose via H/W rope shifts instead of the animation-mode temporal offset.
# reference_latent may stack several frames: the last is the primary reference adjacent to the video, the earlier frames are additional references.
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, ref_mask_flag=None, transformer_options={}):
ref_t_patches = 0
if reference_latent is not None:
ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0]
if ref_mask_flag is not None and not bool(ref_mask_flag):
REF_ROPE_H = 120.0
POSE_ROPE_W = 120.0
ref_t_patches = 0
if reference_latent is not None:
ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0]
main_t_patches = t - ref_t_patches
video_t_start = max(ref_t_patches - 1, 0)
parts = []
if ref_t_patches > 0:
ref_tf = {"rope_options": {"shift_y": REF_ROPE_H, "shift_x": 0.0, "scale_y": 1.0, "scale_x": 1.0}}
parts.append(super().rope_encode(ref_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=ref_tf))
if main_t_patches > 0:
parts.append(super().rope_encode(main_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=transformer_options))
parts.append(super().rope_encode(main_t_patches, h, w, t_start=video_t_start, device=device, dtype=dtype, transformer_options=transformer_options))
if pose_latents is not None:
F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1]
@ -1752,7 +1755,7 @@ class SCAILWanModel(WanModel):
h_shift = (h_scale - 1) / 2
w_shift = (w_scale - 1) / 2
pose_tf = {"rope_options": {"shift_y": h_shift, "shift_x": POSE_ROPE_W + w_shift, "scale_y": h_scale, "scale_x": w_scale}}
parts.append(super().rope_encode(F_pose, H_pose, W_pose, t_start=0, device=device, dtype=dtype, transformer_options=pose_tf))
parts.append(super().rope_encode(F_pose, H_pose, W_pose, t_start=video_t_start, device=device, dtype=dtype, transformer_options=pose_tf))
return torch.cat(parts, dim=1)
@ -1761,10 +1764,6 @@ class SCAILWanModel(WanModel):
if pose_latents is None:
return main_freqs
ref_t_patches = 0
if reference_latent is not None:
ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0]
F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1]
# if pose is at half resolution, scale_y/scale_x=2 stretches the position range to cover the same RoPE extent as the main frames

View File

@ -326,6 +326,17 @@ def model_lora_keys_unet(model, key_map={}):
key_map["transformer.{}".format(key_lora)] = k
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format
if isinstance(model, comfy.model_base.Krea2):
diffusers_keys = comfy.utils.krea2_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
for k in diffusers_keys:
if k.endswith(".weight"):
to = diffusers_keys[k]
key_lora = k[:-len(".weight")]
key_map["diffusion_model.{}".format(key_lora)] = to
key_map["transformer.{}".format(key_lora)] = to
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
key_map[key_lora] = to
if isinstance(model, comfy.model_base.Lumina2):
diffusers_keys = comfy.utils.z_image_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
for k in diffusers_keys:

View File

@ -21,6 +21,7 @@ import comfy.ldm.hunyuan3dv2_1.hunyuandit
import torch
import logging
import comfy.ldm.lightricks.av_model
import comfy.ldm.lightricks.symmetric_patchifier
import comfy.context_windows
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from comfy.ldm.cascade.stage_c import StageC
@ -54,9 +55,11 @@ import comfy.ldm.pixeldit.model
import comfy.ldm.pixeldit.pid
import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.ldm.boogu.model
import comfy.ldm.qwen_image.model
import comfy.ldm.joyimage.model
import comfy.ldm.ideogram4.model
import comfy.ldm.krea2.model
import comfy.ldm.kandinsky5.model
import comfy.ldm.anima.model
import comfy.ldm.ace.ace_step15
@ -1204,6 +1207,127 @@ class LTXAV(BaseModel):
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return latent_image
def map_context_window_to_modalities(self, primary_indices, latent_shapes, dim):
result = [primary_indices]
if len(latent_shapes) < 2:
return result
video_total = latent_shapes[0][dim]
for i in range(1, len(latent_shapes)):
mod_total = latent_shapes[i][dim]
# Map each primary index to its proportional range of modality indices and
# concatenate in order. Preserves wrapped/strided geometry so the modality
# attends to the same temporal regions as the primary window.
mod_indices = []
seen = set()
for v_idx in primary_indices:
a_start = min(int(round(v_idx * mod_total / video_total)), mod_total - 1)
a_end = min(int(round((v_idx + 1) * mod_total / video_total)), mod_total)
if a_end <= a_start:
a_end = a_start + 1
for a in range(a_start, a_end):
if a not in seen:
seen.add(a)
mod_indices.append(a)
result.append(mod_indices)
return result
@staticmethod
def _get_guide_entries(conds):
for cond_list in conds:
if cond_list is None:
continue
for cond_dict in cond_list:
model_conds = cond_dict.get('model_conds', {})
entries = model_conds.get('guide_attention_entries')
if entries is not None and hasattr(entries, 'cond') and entries.cond:
return entries.cond
return None
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
# Audio denoise mask — slice using audio modality window
if cond_key == "audio_denoise_mask" and hasattr(window, 'modality_windows') and window.modality_windows:
audio_window = window.modality_windows.get(1)
if audio_window is not None and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
sliced = audio_window.get_tensor(cond_value.cond, device, dim=2)
return cond_value._copy_with(sliced)
# Video denoise mask — split into video + guide portions, slice each
if cond_key == "denoise_mask" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
cond_tensor = cond_value.cond
guide_count = cond_tensor.size(window.dim) - x_in.size(window.dim)
if guide_count > 0:
T_video = x_in.size(window.dim)
video_mask = cond_tensor.narrow(window.dim, 0, T_video)
guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count)
sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list)
suffix_indices = window.guide_frames_indices
if suffix_indices:
idx = tuple([slice(None)] * window.dim + [suffix_indices])
sliced_guide = guide_mask[idx].to(device)
return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim))
else:
return cond_value._copy_with(sliced_video)
# Keyframe indices — regenerate pixel coords for window, select guide positions
if cond_key == "keyframe_idxs":
kf_local_pos = window.guide_kf_local_positions
if not kf_local_pos:
return cond_value._copy_with(cond_value.cond[:, :, :0, :]) # empty
H, W = x_in.shape[3], x_in.shape[4]
window_len = len(window.index_list)
# account for causal_window_fix anchor in coord space size
anchor_idx = getattr(window, 'causal_anchor_index', None)
if anchor_idx is not None and anchor_idx >= 0:
window_len += 1
patchifier = self.diffusion_model.patchifier
latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device)
scale_factors = self.diffusion_model.vae_scale_factors
pixel_coords = comfy.ldm.lightricks.symmetric_patchifier.latent_to_pixel_coords(
latent_coords,
scale_factors,
causal_fix=self.diffusion_model.causal_temporal_positioning)
tokens = []
for pos in kf_local_pos:
tokens.extend(range(pos * H * W, (pos + 1) * H * W))
pixel_coords = pixel_coords[:, :, tokens, :]
# Adjust spatial end positions for dilated (downscaled) guides.
# Each guide entry may have a different downscale factor; expand the
# per-entry factor to cover all tokens belonging to that entry.
downscale_factors = window.guide_downscale_factors
overlap_info = window.guide_overlap_info
if downscale_factors:
per_token_factor = []
for (entry_idx, overlap_count), dsf in zip(overlap_info, downscale_factors):
per_token_factor.extend([dsf] * (overlap_count * H * W))
factor_tensor = torch.tensor(per_token_factor, device=pixel_coords.device, dtype=pixel_coords.dtype)
spatial_end_offset = (factor_tensor.unsqueeze(0).unsqueeze(0).unsqueeze(-1) - 1) * torch.tensor(
scale_factors[1:], device=pixel_coords.device, dtype=pixel_coords.dtype,
).view(1, -1, 1, 1)
pixel_coords[:, 1:, :, 1:] += spatial_end_offset
B = cond_value.cond.shape[0]
if B > 1:
pixel_coords = pixel_coords.expand(B, -1, -1, -1)
return cond_value._copy_with(pixel_coords)
# Guide attention entries — adjust per-guide counts based on window overlap
if cond_key == "guide_attention_entries":
overlap_info = window.guide_overlap_info
H, W = x_in.shape[3], x_in.shape[4]
new_entries = []
for entry_idx, overlap_count in overlap_info:
e = cond_value.cond[entry_idx]
new_entries.append({**e,
"pre_filter_count": overlap_count * H * W,
"latent_shape": [overlap_count, H, W]})
return cond_value._copy_with(new_entries)
return None
class HunyuanVideo(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
@ -1748,10 +1872,14 @@ class WAN21_SCAIL(WAN21):
reference_latents = kwargs.get("reference_latents", None)
if reference_latents is not None:
ref_latent = self.process_latent_in(reference_latents[-1])
ref_mask = torch.ones_like(ref_latent[:, :4])
ref_latent = torch.cat([ref_latent, ref_mask], dim=1)
out['reference_latent'] = comfy.conds.CONDRegular(ref_latent)
# SCAIL-2 multi-reference: reference_latents[0] is the primary ref, [1:] are additional
# references. Stack as [additional..., primary] so the primary stays adjacent to the video.
ordered = list(reference_latents[1:]) + list(reference_latents[:1])
stacked = []
for lat in ordered:
lat = self.process_latent_in(lat)
stacked.append(torch.cat([lat, torch.ones_like(lat[:, :4])], dim=1))
out['reference_latent'] = comfy.conds.CONDRegular(torch.cat(stacked, dim=2))
pose_latents = kwargs.get("pose_video_latent", None)
if pose_latents is not None:
@ -1793,6 +1921,7 @@ class WAN21_SCAIL2(WAN21_SCAIL):
if driving_mask_28ch is not None:
out['sam_latents'] = comfy.conds.CONDRegular(driving_mask_28ch.movedim(1, 2).contiguous())
# ref_mask_28ch holds one identity mask per stacked reference frame (additional refs first, then the primary ref), followed by zeros over the video frames.
ref_mask_28ch = kwargs.get("ref_mask_28ch", None)
if ref_mask_28ch is not None:
out['ref_mask_latents'] = comfy.conds.CONDRegular(ref_mask_28ch.movedim(1, 2).contiguous())
@ -1820,10 +1949,11 @@ class WAN21_SCAIL2(WAN21_SCAIL):
# Return sliced view omitting retain_index_list
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=0)
if cond_key == "ref_mask_latents" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
# The ref mask is just a single frame padded with frames of zeros, so just grab the first frames for all windows
# The ref mask is N leading ref frames padded with frames of zeros, so just grab the first frames for all windows
full_ref_mask = cond_value.cond
video_frame_count = x_in.shape[2]
if full_ref_mask.shape[2] != video_frame_count + 1:
ref_frame_count = full_ref_mask.shape[2] - video_frame_count
if ref_frame_count < 1:
return None
window_length = len(window.index_list)
@ -1832,7 +1962,7 @@ class WAN21_SCAIL2(WAN21_SCAIL):
if anchor_index is not None and anchor_index >= 0:
window_length += 1
window_ref_mask = full_ref_mask[:, :, :window_length + 1].to(device)
window_ref_mask = full_ref_mask[:, :, :window_length + ref_frame_count].to(device)
return cond_value._copy_with(window_ref_mask)
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
@ -2098,6 +2228,11 @@ class Omnigen2(BaseModel):
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
return out
class Boogu(Omnigen2):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super(Omnigen2, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.boogu.model.BooguTransformer2DModel)
self.memory_usage_factor_conds = ("ref_latents",)
class QwenImage(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.qwen_image.model.QwenImageTransformer2DModel)
@ -2265,6 +2400,17 @@ class Ideogram4(BaseModel):
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
class Krea2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.krea2.model.SingleStreamDiT)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
class HunyuanImage21(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)

View File

@ -761,6 +761,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
return dit_config
if '{}double_stream_layers.0.img_instruct_attn.processor.img_to_q.weight'.format(key_prefix) in state_dict_keys: # Boogu-Image (OmniGen2 derivative + dual-stream stage)
dit_config = {}
dit_config["image_model"] = "boogu"
dit_config["hidden_size"] = state_dict['{}x_embedder.weight'.format(key_prefix)].shape[0]
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}single_stream_layers.'.format(key_prefix) + '{}.')
dit_config["num_double_stream_layers"] = count_blocks(state_dict_keys, '{}double_stream_layers.'.format(key_prefix) + '{}.')
dit_config["num_refiner_layers"] = count_blocks(state_dict_keys, '{}noise_refiner.'.format(key_prefix) + '{}.')
dit_config["instruction_feat_dim"] = state_dict['{}time_caption_embed.caption_embedder.0.weight'.format(key_prefix)].shape[0]
return dit_config
if '{}time_caption_embed.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: # Omnigen2
dit_config = {}
dit_config["image_model"] = "omnigen2"
@ -845,6 +855,21 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
return dit_config
if '{}txtfusion.projector.weight'.format(key_prefix) in state_dict_keys: # Krea 2 (K2)
dit_config = {}
dit_config["image_model"] = "krea2"
head_dim = 128
first_w = state_dict['{}first.weight'.format(key_prefix)] # (features, channels*patch^2)
dit_config["features"] = first_w.shape[0]
dit_config["channels"] = first_w.shape[1] // (2 * 2) # patch=2
dit_config["patch"] = 2
dit_config["layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
dit_config["heads"] = state_dict['{}blocks.0.attn.wq.weight'.format(key_prefix)].shape[0] // head_dim
dit_config["kvheads"] = state_dict['{}blocks.0.attn.wk.weight'.format(key_prefix)].shape[0] // head_dim
dit_config["txtlayers"] = state_dict['{}txtfusion.projector.weight'.format(key_prefix)].shape[1]
dit_config["txtdim"] = state_dict['{}txtfusion.layerwise_blocks.0.prenorm.scale'.format(key_prefix)].shape[0]
return dit_config
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
dit_config = {}
model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0]

View File

@ -256,7 +256,7 @@ def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, w
if (want_requant and len(fns) == 0 or update_weight):
seed = comfy.utils.string_to_seed(s.seed_key)
if isinstance(orig, QuantizedTensor):
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
y = orig.requantize_from_float(x, scale="recalculate", stochastic_rounding=seed)
else:
y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed)
if want_requant and len(fns) == 0:
@ -1089,6 +1089,19 @@ def _load_quantized_module(module, super_load, state_dict, prefix, local_metadat
if ts is None or bs is None:
raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
scales = {"scale": ts, "block_scale": bs}
elif module.quant_format == "int8_tensorwise":
scale = pop_scale("weight_scale")
if scale is None:
raise ValueError(f"Missing INT8 weight scale for layer {layer_name}")
scales = {"scale": scale}
params_conf = layer_conf.get("params", {})
if not isinstance(params_conf, dict):
params_conf = {}
if layer_conf.get("convrot", params_conf.get("convrot", False)):
scales["convrot"] = True
scales["convrot_groupsize"] = int(
layer_conf.get("convrot_groupsize", params_conf.get("convrot_groupsize", 256))
)
else:
raise ValueError(f"Unsupported quantization format: {module.quant_format}")
@ -1131,6 +1144,10 @@ def _quantized_weight_state_dict(module, sd, prefix, extra_quant_conf=None, extr
quant_conf = {"format": module.quant_format}
if getattr(module, '_full_precision_mm_config', False):
quant_conf["full_precision_matrix_mult"] = True
params = getattr(module.weight, "_params", None)
if module.quant_format == "int8_tensorwise" and getattr(params, "convrot", False):
quant_conf["convrot"] = True
quant_conf["convrot_groupsize"] = getattr(params, "convrot_groupsize", 256)
if extra_quant_conf:
quant_conf.update(extra_quant_conf)
sd[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8)
@ -1183,8 +1200,33 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)
def forward_comfy_cast_weights(self, input, compute_dtype=None, want_requant=False):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype, want_requant=want_requant)
def forward_comfy_cast_weights(
self,
input,
compute_dtype=None,
want_requant=False,
weight_only_quant=False,
):
if weight_only_quant:
weight, bias, offload_stream = cast_bias_weight(
self,
input=None,
dtype=self.weight.dtype,
device=input.device,
bias_dtype=input.dtype,
offloadable=True,
compute_dtype=compute_dtype,
want_requant=True,
)
weight = weight.to(dtype=input.dtype)
else:
weight, bias, offload_stream = cast_bias_weight(
self,
input,
offloadable=True,
compute_dtype=compute_dtype,
want_requant=want_requant,
)
x = self._forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
@ -1203,9 +1245,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
not getattr(self, 'comfy_force_cast_weights', False) and
len(self.weight_function) == 0 and len(self.bias_function) == 0
)
quantize_input = QUANT_ALGOS.get(getattr(self, 'quant_format', None), {}).get("quantize_input", True)
# Training path: quantized forward with compute_dtype backward via autograd function
if (input.requires_grad and _use_quantized):
if (input.requires_grad and _use_quantized and quantize_input):
weight, bias, offload_stream = cast_bias_weight(
self,
@ -1227,7 +1270,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
return output
# Inference path (unchanged)
if _use_quantized:
if _use_quantized and quantize_input:
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
@ -1241,7 +1284,13 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
scale = comfy.model_management.cast_to_device(scale, input.device, None)
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor))
weight_only_quant = _use_quantized and not quantize_input and isinstance(self.weight, QuantizedTensor)
output = self.forward_comfy_cast_weights(
input,
compute_dtype,
want_requant=isinstance(input, QuantizedTensor),
weight_only_quant=weight_only_quant,
)
# Reshape output back to 3D if input was 3D
if reshaped_3d:
@ -1257,8 +1306,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
if getattr(self, 'layout_type', None) is not None:
# dtype is now implicit in the layout class
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
weight = self.weight.requantize_from_float(weight, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
else:
weight = weight.to(self.weight.dtype)
if return_weight:

View File

@ -10,6 +10,7 @@ try:
QuantizedLayout,
TensorCoreFP8Layout as _CKFp8Layout,
TensorCoreNVFP4Layout as _CKNvfp4Layout,
TensorWiseINT8Layout as _CKTensorWiseINT8Layout,
register_layout_op,
register_layout_class,
get_layout_class,
@ -47,6 +48,9 @@ except ImportError as e:
class _CKNvfp4Layout:
pass
class _CKTensorWiseINT8Layout:
pass
def register_layout_class(name, cls):
pass
@ -174,6 +178,7 @@ class TensorCoreFP8E5M2Layout(_TensorCoreFP8LayoutBase):
# Backward compatibility alias - default to E4M3
TensorCoreFP8Layout = TensorCoreFP8E4M3Layout
TensorWiseINT8Layout = _CKTensorWiseINT8Layout
# ==============================================================================
@ -184,6 +189,7 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
register_layout_class("TensorWiseINT8Layout", _CKTensorWiseINT8Layout)
if _CK_MXFP8_AVAILABLE:
register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
@ -214,6 +220,13 @@ if _CK_MXFP8_AVAILABLE:
"group_size": 32,
}
QUANT_ALGOS["int8_tensorwise"] = {
"storage_t": torch.int8,
"parameters": {"weight_scale"},
"comfy_tensor_layout": "TensorWiseINT8Layout",
"quantize_input": False,
}
# ==============================================================================
# Re-exports for backward compatibility
@ -226,6 +239,7 @@ __all__ = [
"TensorCoreFP8E4M3Layout",
"TensorCoreFP8E5M2Layout",
"TensorCoreNVFP4Layout",
"TensorWiseINT8Layout",
"QUANT_ALGOS",
"register_layout_op",
]

View File

@ -58,6 +58,7 @@ import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.z_image
import comfy.text_encoders.krea2
import comfy.text_encoders.ideogram4
import comfy.text_encoders.ovis
import comfy.text_encoders.kandinsky5
@ -68,6 +69,7 @@ import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image
import comfy.text_encoders.qwen35
import comfy.text_encoders.qwen3vl
import comfy.text_encoders.boogu
import comfy.text_encoders.ernie
import comfy.text_encoders.gemma4
import comfy.text_encoders.cogvideo
@ -1302,7 +1304,9 @@ class CLIPType(Enum):
LENS = 28
PIXELDIT = 29
IDEOGRAM4 = 30
JOYIMAGE = 31
BOOGU = 31
KREA2 = 32
JOYIMAGE = 33
@ -1627,6 +1631,18 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
clip_target.clip = comfy.text_encoders.ideogram4.te_qwen3vl(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.ideogram4.Ideogram4Qwen3VLTokenizer
elif clip_type == CLIPType.BOOGU and te_model == TEModel.QWEN3VL_8B: # Boogu-Image: full Qwen3-VL-8B, last hidden state, no-think template.
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
clip_target.clip = comfy.text_encoders.boogu.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.boogu.BooguTokenizer
elif clip_type == CLIPType.KREA2 and te_model == TEModel.QWEN3VL_4B: # Krea2: full Qwen3-VL-4B (12-layer tap for conditioning + multimodal generate).
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
clip_target.clip = comfy.text_encoders.krea2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.krea2.Krea2Tokenizer
elif clip_type in (CLIPType.FLUX, CLIPType.FLUX2): # Flux2 Klein reuses the Qwen3-VL LM (3-layer tap -> 12288); visual unused.
klein_model_type = "qwen3_8b" if te_model == TEModel.QWEN3VL_8B else "qwen3_4b"
clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type=klein_model_type)
clip_target.tokenizer = comfy.text_encoders.flux.KleinTokenizer8B if te_model == TEModel.QWEN3VL_8B else comfy.text_encoders.flux.KleinTokenizer
else:
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
qwen3vl_type = {TEModel.QWEN3VL_4B: "qwen3vl_4b", TEModel.QWEN3VL_8B: "qwen3vl_8b"}[te_model]

View File

@ -25,6 +25,8 @@ import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.kandinsky5
import comfy.text_encoders.z_image
import comfy.text_encoders.ideogram4
import comfy.text_encoders.boogu
import comfy.text_encoders.krea2
import comfy.text_encoders.anima
import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image
@ -1758,6 +1760,27 @@ class Omnigen2(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
class Boogu(Omnigen2):
unet_config = {
"image_model": "boogu",
}
sampling_settings = {
"multiplier": 1.0,
"shift": 3.16,
}
memory_usage_factor = 2.15
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Boogu(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3vl_8b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.boogu.BooguTokenizer, comfy.text_encoders.boogu.te(**hunyuan_detect))
class Ideogram4(supported_models_base.BASE):
unet_config = {
"image_model": "ideogram4",
@ -1796,6 +1819,35 @@ class Ideogram4(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3vl_8b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.ideogram4.Ideogram4Tokenizer, comfy.text_encoders.ideogram4.te(**hunyuan_detect))
class Krea2(supported_models_base.BASE):
unet_config = {
"image_model": "krea2",
}
sampling_settings = {
"multiplier": 1.0,
"shift": 1.15,
}
memory_usage_factor = 2.2
latent_format = latent_formats.Wan21
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Krea2(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3vl_4b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.krea2.Krea2Tokenizer, comfy.text_encoders.krea2.te(**hunyuan_detect))
class QwenImage(supported_models_base.BASE):
unet_config = {
"image_model": "qwen_image",
@ -2339,9 +2391,11 @@ models = [
ACEStep,
ACEStep15,
Omnigen2,
Boogu,
QwenImage,
JoyImage,
Ideogram4,
Krea2,
Flux2,
Lens,
Kandinsky5Image,

View File

@ -0,0 +1,58 @@
"""Boogu-Image text encoder: full Qwen3-VL-8B, last hidden state (4096-dim).
Boogu uses the final hidden state of Qwen3-VL as the per-token instruction feature
(num_instruction_feature_layers=1, reduce_type=mean -> just the last layer).
The model itself is the standard Qwen3-VL TE, only the chat template differs
(a fixed system prompt and no <think> block).
"""
import comfy.text_encoders.qwen3vl
from comfy import sd1_clip
# System prompts from the reference pipeline (pipeline_boogu.py).
# T2I (non-empty instruction, no image) uses the helpful-assistant prompt
# everything else (the CFG negative / "drop" condition, and any image case) uses the TI2I "describe" prompt.
BOOGU_T2I_SYSTEM = "You are a helpful assistant that generates high-quality images based on user instructions. The instructions are as follows."
BOOGU_DROP_SYSTEM = "Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate."
class BooguTokenizer(comfy.text_encoders.qwen3vl.Qwen3VLTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, model_type="qwen3vl_8b")
# apply_chat_template without add_generation_prompt
self.llama_template = "<|im_start|>system\n" + BOOGU_T2I_SYSTEM + "<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n"
self.llama_template_images = "<|im_start|>system\n" + BOOGU_DROP_SYSTEM + "<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n"
# Reference SYSTEM_PROMPT_DROP: used for the empty negative/uncond instruction.
self.llama_template_drop = "<|im_start|>system\n" + BOOGU_DROP_SYSTEM + "<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n"
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=True, **kwargs):
if llama_template is None and len(images) == 0 and text.strip() == "":
llama_template = self.llama_template_drop
# Boogu conditions on the no-think template; thinking=True drops the empty <think> block qwen3vl adds by default.
return super().tokenize_with_weights(text, return_word_ids=return_word_ids, llama_template=llama_template, images=images, prevent_empty_text=prevent_empty_text, thinking=thinking, **kwargs)
class BooguQwen3VLClipModel(comfy.text_encoders.qwen3vl.Qwen3VLClipModel):
def __init__(self, device="cpu", dtype=None, attention_mask=True, model_options={}, model_type="qwen3vl_8b"):
super().__init__(device=device, dtype=dtype, attention_mask=attention_mask, model_options=model_options, model_type=model_type)
# apply the final RMSNorm to the tapped last layer
self.layer_norm_hidden_state = True
class BooguTEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
clip_model = lambda **kw: BooguQwen3VLClipModel(**kw, model_type="qwen3vl_8b")
super().__init__(device=device, dtype=dtype, name="qwen3vl_8b", clip_model=clip_model, model_options=model_options)
def te(dtype_llama=None, llama_quantization_metadata=None):
class BooguTEModel_(BooguTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, dtype=dtype, model_options=model_options)
return BooguTEModel_

View File

@ -0,0 +1,84 @@
"""Krea 2 (K2) text encoder: Qwen3-VL-4B, 12-layer tap.
K2 conditions on a stack of hidden states from 12 layers of Qwen3-VL-4B
(reference taps ``hidden_states[2,5,8,...,35]``), kept as a ``(B, 12, seq, 2560)`` tensor and
consumed by the DiT's internal ``txtfusion`` adapter. Comfy carries conditioning as a 3D tensor,
so the 12-layer stack is flattened to ``(B, seq, 12*2560)`` here and unpacked inside the model.
"""
import numbers
import torch
import comfy.text_encoders.qwen3vl
from comfy import sd1_clip
# tap k == hidden_states[k] (no offset).
KREA2_TAP_LAYERS = [2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35]
# Identical system template to Qwen-Image; Krea2 strips the system+user-opening prefix.
KREA2_TEMPLATE = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
class Krea2Tokenizer(comfy.text_encoders.qwen3vl.Qwen3VLTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, model_type="qwen3vl_4b")
self.llama_template = KREA2_TEMPLATE # conditioning template; image text-gen uses qwen3vl's default image template.
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=True, **kwargs):
# Krea2 conditions on the no-think template; thinking=True drops the empty <think> block qwen3vl adds.
return super().tokenize_with_weights(text, return_word_ids=return_word_ids, llama_template=llama_template, images=images, prevent_empty_text=prevent_empty_text, thinking=thinking, **kwargs)
class Krea2Qwen3VLClipModel(comfy.text_encoders.qwen3vl.Qwen3VLClipModel):
def __init__(self, device="cpu", dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=KREA2_TAP_LAYERS, layer_idx=None, dtype=dtype,
attention_mask=attention_mask, model_options=model_options, model_type="qwen3vl_4b")
class Krea2TEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="qwen3vl_4b", clip_model=Krea2Qwen3VLClipModel, model_options=model_options)
def encode_token_weights(self, token_weight_pairs, template_end=-1):
out, pooled, extra = super().encode_token_weights(token_weight_pairs) # out: (B, 12, seq, 2560)
tok_pairs = token_weight_pairs["qwen3vl_4b"][0]
# Strip the system + user-opening prefix
count_im_start = 0
if template_end == -1:
for i, v in enumerate(tok_pairs):
elem = v[0]
if not torch.is_tensor(elem) and isinstance(elem, numbers.Integral):
if elem == 151644 and count_im_start < 2:
template_end = i
count_im_start += 1
if out.shape[2] > (template_end + 3):
if tok_pairs[template_end + 1][0] == 872: # "user"
if tok_pairs[template_end + 2][0] == 198: # "\n"
template_end += 3
out = out[:, :, template_end:]
b, n, seq, h = out.shape
# Flatten the 12-layer axis into the feature dim: (B, seq, 12*2560). Unpacked in the model.
out = out.permute(0, 2, 1, 3).reshape(b, seq, n * h)
if "attention_mask" in extra:
extra["attention_mask"] = extra["attention_mask"][:, template_end:]
if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]):
extra.pop("attention_mask")
return out, pooled, extra
def te(dtype_llama=None, llama_quantization_metadata=None):
class Krea2TEModel_(Krea2TEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
return Krea2TEModel_

View File

@ -818,6 +818,44 @@ def z_image_to_diffusers(mmdit_config, output_prefix=""):
return key_map
def krea2_to_diffusers(mmdit_config, output_prefix=""):
n_layers = mmdit_config.get("layers", 0)
n_txt_layerwise = 2 # TextFusionTransformer hardcodes 2 layerwise + 2 refiner blocks
n_txt_refiner = 2
key_map = {}
def add_block(prefix_to, prefix_from):
block_map = {
"attn.to_q": "attn.wq", "attn.to_k": "attn.wk", "attn.to_v": "attn.wv",
"attn.to_gate": "attn.gate", "attn.to_out.0": "attn.wo",
"attn.to_out": "attn.wo", # some tools drop the ".0" on to_out
"ff.gate": "mlp.gate", "ff.up": "mlp.up", "ff.down": "mlp.down",
}
for d, c in block_map.items():
key_map["{}.{}.weight".format(prefix_to, d)] = "{}{}.{}.weight".format(output_prefix, prefix_from, c)
for i in range(n_layers):
add_block("transformer_blocks.{}".format(i), "blocks.{}".format(i))
for i in range(n_txt_layerwise):
add_block("text_fusion.layerwise_blocks.{}".format(i), "txtfusion.layerwise_blocks.{}".format(i))
for i in range(n_txt_refiner):
add_block("text_fusion.refiner_blocks.{}".format(i), "txtfusion.refiner_blocks.{}".format(i))
MAP_BASIC = [
("img_in", "first"),
("time_embed.linear_1", "tmlp.0"),
("time_embed.linear_2", "tmlp.2"),
("time_mod_proj", "tproj.1"),
("txt_in.linear_1", "txtmlp.1"),
("txt_in.linear_2", "txtmlp.3"),
("text_fusion.projector", "txtfusion.projector"),
("final_layer.linear", "last.linear"),
]
for d, c in MAP_BASIC:
key_map["{}.weight".format(d)] = "{}{}.weight".format(output_prefix, c)
return key_map
def repeat_to_batch_size(tensor, batch_size, dim=0):
if tensor.shape[dim] > batch_size:
return tensor.narrow(dim, 0, batch_size)

View File

@ -25,6 +25,11 @@ CLI_FEATURE_FLAG_REGISTRY: dict[str, FeatureFlagInfo] = {
"default": False,
"description": "Show the sign-in button in the frontend even when not signed in",
},
"enable_telemetry": {
"type": "bool",
"default": False,
"description": "Signal the frontend that telemetry collection is enabled",
},
}

View File

@ -891,6 +891,14 @@ class Tracks(ComfyTypeIO):
track_visibility: torch.Tensor
Type = TrackDict
@comfytype(io_type="DICT")
class Dict(ComfyTypeIO):
Type = dict
@comfytype(io_type="ARRAY")
class Array(ComfyTypeIO):
Type = list
@comfytype(io_type="COMFY_MULTITYPED_V3")
class MultiType:
Type = Any
@ -1279,6 +1287,19 @@ class Color(ComfyTypeIO):
def as_dict(self):
return super().as_dict()
@comfytype(io_type="COLORS")
class Colors(ComfyTypeIO):
Type = list[Color.Type]
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
socketless: bool=True, default: list[str]=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
if default is None:
self.default = []
@comfytype(io_type="BOUNDING_BOX")
class BoundingBox(ComfyTypeIO):
class BoundingBoxDict(TypedDict):
@ -1326,6 +1347,20 @@ class Curve(ComfyTypeIO):
return d
@comfytype(io_type="BOUNDING_BOXES")
class BoundingBoxes(ComfyTypeIO):
class BoundingBoxWithMetadata(BoundingBox.BoundingBoxDict):
metadata: dict
Type = list[BoundingBoxWithMetadata]
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
socketless: bool=True, default: list[dict]=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
if default is None:
self.default = []
@comfytype(io_type="HISTOGRAM")
class Histogram(ComfyTypeIO):
"""A histogram represented as a list of bin counts."""
@ -2376,6 +2411,8 @@ __all__ = [
"AnyType",
"MultiType",
"Tracks",
"Dict",
"Array",
"Color",
# Dynamic Types
"MatchType",
@ -2394,6 +2431,8 @@ __all__ = [
"PriceBadgeDepends",
"PriceBadge",
"BoundingBox",
"BoundingBoxes",
"Colors",
"Curve",
"Histogram",
"Range",

View File

@ -163,15 +163,31 @@ class SeedanceVirtualLibraryCreateAssetRequest(BaseModel):
asset_type: str | None = Field(None, description="BytePlus asset type. Defaults to Image server-side when omitted.")
# Dollars per 1K tokens, keyed by (model_id, has_video_input).
# Dollars per 1K tokens, keyed by (model_id, has_video_input, resolution).
SEEDANCE2_PRICE_PER_1K_TOKENS = {
("dreamina-seedance-2-0-260128", False): 0.007,
("dreamina-seedance-2-0-260128", True): 0.0043,
("dreamina-seedance-2-0-fast-260128", False): 0.0056,
("dreamina-seedance-2-0-fast-260128", True): 0.0033,
("dreamina-seedance-2-0-260128", False, "480p"): 0.007,
("dreamina-seedance-2-0-260128", True, "480p"): 0.0043,
("dreamina-seedance-2-0-260128", False, "720p"): 0.007,
("dreamina-seedance-2-0-260128", True, "720p"): 0.0043,
("dreamina-seedance-2-0-260128", False, "1080p"): 0.0077,
("dreamina-seedance-2-0-260128", True, "1080p"): 0.0047,
("dreamina-seedance-2-0-260128", False, "4k"): 0.004,
("dreamina-seedance-2-0-260128", True, "4k"): 0.0024,
("dreamina-seedance-2-0-fast-260128", False, "480p"): 0.0056,
("dreamina-seedance-2-0-fast-260128", True, "480p"): 0.0033,
("dreamina-seedance-2-0-fast-260128", False, "720p"): 0.0056,
("dreamina-seedance-2-0-fast-260128", True, "720p"): 0.0033,
("dreamina-seedance-2-0-mini", False, "480p"): 0.0035,
("dreamina-seedance-2-0-mini", True, "480p"): 0.0021,
("dreamina-seedance-2-0-mini", False, "720p"): 0.0035,
("dreamina-seedance-2-0-mini", True, "720p"): 0.0021,
}
def seedance2_price_per_1k_tokens(model_id: str, has_video_input: bool, resolution: str) -> float | None:
return SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input, resolution))
RECOMMENDED_PRESETS = [
("1024x1024 (1:1)", 1024, 1024),
("864x1152 (3:4)", 864, 1152),
@ -266,6 +282,10 @@ SEEDANCE2_REF_VIDEO_PIXEL_LIMITS = {
"480p": {"min": 409_600, "max": 927_408},
"720p": {"min": 409_600, "max": 927_408},
},
"dreamina-seedance-2-0-mini": {
"480p": {"min": 409_600, "max": 927_408},
"720p": {"min": 409_600, "max": 927_408},
},
}
# The time in this dictionary are given for 10 seconds duration.

View File

@ -121,6 +121,7 @@ class GeminiGenerationConfig(BaseModel):
topK: int | None = Field(None, ge=1)
topP: float | None = Field(None, ge=0.0, le=1.0)
thinkingConfig: GeminiThinkingConfig | None = Field(None)
responseModalities: list[str] | None = Field(None)
class GeminiImageOutputOptions(BaseModel):

View File

@ -149,3 +149,59 @@ class MotionControlRequest(BaseModel):
character_orientation: str = Field(...)
mode: str = Field(..., description="'pro' or 'std'")
model_name: str = Field(...)
class Kling3TurboSettings(BaseModel):
resolution: str = Field("720p", description="'720p' or '1080p'")
aspect_ratio: str | None = Field(None, description="'16:9'/'9:16'/'1:1'; text-to-video only")
duration: int = Field(5, description="3-15 second")
class Kling3TurboText2VideoRequest(BaseModel):
prompt: str = Field(..., description="<=3072 chars; may use multi-shot 'shot n, m, words; ...'")
settings: Kling3TurboSettings | None = Field(None)
class Kling3TurboContent(BaseModel):
type: str = Field(..., description="'prompt' or 'first_frame'")
text: str | None = Field(None, description="for type=prompt; <=2500 chars")
url: str | None = Field(None, description="for type=first_frame")
class Kling3TurboImage2VideoRequest(BaseModel):
contents: list[Kling3TurboContent] = Field(..., description="prompt + first_frame materials")
settings: Kling3TurboSettings | None = Field(None)
class Kling3TurboCreateData(BaseModel):
id: str | None = Field(None, description="Task ID")
status: str | None = Field(None)
message: str | None = Field(None)
class Kling3TurboCreateResponse(BaseModel):
code: int | None = Field(None)
message: str | None = Field(None)
request_id: str | None = Field(None)
data: Kling3TurboCreateData | None = Field(None)
class Kling3TurboOutput(BaseModel):
type: str | None = Field(None, description="'video', 'image', 'audio', ...")
id: str | None = Field(None)
url: str | None = Field(None)
duration: str | None = Field(None)
class Kling3TurboTaskData(BaseModel):
id: str | None = Field(None)
status: str | None = Field(None, description="submitted | processing | succeeded | failed")
message: str | None = Field(None)
outputs: list[Kling3TurboOutput] | None = Field(None)
class Kling3TurboQueryResponse(BaseModel):
code: int | None = Field(None)
message: str | None = Field(None)
request_id: str | None = Field(None)
data: list[Kling3TurboTaskData] | None = Field(None)

View File

@ -10,6 +10,7 @@ from pydantic import BaseModel, Field, confloat
class LumaIO:
LUMA_REF = "LUMA_REF"
LUMA_CONCEPTS = "LUMA_CONCEPTS"
LUMA_RAY32_KEYFRAME = "LUMA_RAY32_KEYFRAME"
class LumaReference:
@ -20,13 +21,14 @@ class LumaReference:
def create_api_model(self, download_url: str):
return LumaImageRef(url=download_url, weight=self.weight)
class LumaReferenceChain:
def __init__(self, first_ref: LumaReference=None):
def __init__(self, first_ref: LumaReference = None):
self.refs: list[LumaReference] = []
if first_ref:
self.refs.append(first_ref)
def add(self, luma_ref: LumaReference=None):
def add(self, luma_ref: LumaReference = None):
self.refs.append(luma_ref)
def create_api_model(self, download_urls: list[str], max_refs=4):
@ -124,7 +126,7 @@ def get_luma_concepts(include_none=False):
"pull_out",
"aerial",
"crane_up",
"eye_level"
"eye_level",
]
@ -162,8 +164,8 @@ class LumaVideoModelOutputDuration(str, Enum):
class LumaGenerationType(str, Enum):
video = 'video'
image = 'image'
video = "video"
image = "image"
class LumaState(str, Enum):
@ -174,86 +176,109 @@ class LumaState(str, Enum):
class LumaAssets(BaseModel):
video: Optional[str] = Field(None, description='The URL of the video')
image: Optional[str] = Field(None, description='The URL of the image')
progress_video: Optional[str] = Field(None, description='The URL of the progress video')
video: Optional[str] = Field(None, description="The URL of the video")
image: Optional[str] = Field(None, description="The URL of the image")
progress_video: Optional[str] = Field(None, description="The URL of the progress video")
class LumaImageRef(BaseModel):
"""Used for image gen"""
url: str = Field(..., description='The URL of the image reference')
weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference')
url: str = Field(..., description="The URL of the image reference")
weight: confloat(ge=0.0, le=1.0) = Field(..., description="The weight of the image reference")
class LumaImageReference(BaseModel):
"""Used for video gen"""
type: Optional[str] = Field('image', description='Input type, defaults to image')
url: str = Field(..., description='The URL of the image')
type: Optional[str] = Field("image", description="Input type, defaults to image")
url: str = Field(..., description="The URL of the image")
class LumaModifyImageRef(BaseModel):
url: str = Field(..., description='The URL of the image reference')
weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference')
url: str = Field(..., description="The URL of the image reference")
weight: confloat(ge=0.0, le=1.0) = Field(..., description="The weight of the image reference")
class LumaCharacterRef(BaseModel):
identity0: LumaImageIdentity = Field(..., description='The image identity object')
identity0: LumaImageIdentity = Field(..., description="The image identity object")
class LumaImageIdentity(BaseModel):
images: list[str] = Field(..., description='The URLs of the image identity')
images: list[str] = Field(..., description="The URLs of the image identity")
class LumaGenerationReference(BaseModel):
type: str = Field('generation', description='Input type, defaults to generation')
id: str = Field(..., description='The ID of the generation')
type: str = Field("generation", description="Input type, defaults to generation")
id: str = Field(..., description="The ID of the generation")
class LumaKeyframes(BaseModel):
frame0: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description='')
frame1: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description='')
frame0: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description="")
frame1: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description="")
class LumaConceptObject(BaseModel):
key: str = Field(..., description='Camera Concept name')
key: str = Field(..., description="Camera Concept name")
class LumaImageGenerationRequest(BaseModel):
prompt: str = Field(..., description='The prompt of the generation')
model: LumaImageModel = Field(LumaImageModel.photon_1, description='The image model used for the generation')
aspect_ratio: Optional[LumaAspectRatio] = Field(LumaAspectRatio.ratio_16_9, description='The aspect ratio of the generation')
image_ref: Optional[list[LumaImageRef]] = Field(None, description='List of image reference objects')
style_ref: Optional[list[LumaImageRef]] = Field(None, description='List of style reference objects')
character_ref: Optional[LumaCharacterRef] = Field(None, description='The image identity object')
modify_image_ref: Optional[LumaModifyImageRef] = Field(None, description='The modify image reference object')
prompt: str = Field(..., description="The prompt of the generation")
model: LumaImageModel = Field(LumaImageModel.photon_1, description="The image model used for the generation")
aspect_ratio: Optional[LumaAspectRatio] = Field(LumaAspectRatio.ratio_16_9)
image_ref: Optional[list[LumaImageRef]] = Field(None, description="List of image reference objects")
style_ref: Optional[list[LumaImageRef]] = Field(None, description="List of style reference objects")
character_ref: Optional[LumaCharacterRef] = Field(None, description="The image identity object")
modify_image_ref: Optional[LumaModifyImageRef] = Field(None, description="The modify image reference object")
class LumaGenerationRequest(BaseModel):
prompt: str = Field(..., description='The prompt of the generation')
model: LumaVideoModel = Field(LumaVideoModel.ray_2, description='The video model used for the generation')
duration: Optional[LumaVideoModelOutputDuration] = Field(None, description='The duration of the generation')
aspect_ratio: Optional[LumaAspectRatio] = Field(None, description='The aspect ratio of the generation')
resolution: Optional[LumaVideoOutputResolution] = Field(None, description='The resolution of the generation')
loop: Optional[bool] = Field(None, description='Whether to loop the video')
keyframes: Optional[LumaKeyframes] = Field(None, description='The keyframes of the generation')
concepts: Optional[list[LumaConceptObject]] = Field(None, description='Camera Concepts to apply to generation')
prompt: str = Field(..., description="The prompt of the generation")
model: LumaVideoModel = Field(LumaVideoModel.ray_2, description="The video model used for the generation")
duration: Optional[LumaVideoModelOutputDuration] = Field(None, description="The duration of the generation")
aspect_ratio: Optional[LumaAspectRatio] = Field(None, description="The aspect ratio of the generation")
resolution: Optional[LumaVideoOutputResolution] = Field(None, description="The resolution of the generation")
loop: Optional[bool] = Field(None, description="Whether to loop the video")
keyframes: Optional[LumaKeyframes] = Field(None, description="The keyframes of the generation")
concepts: Optional[list[LumaConceptObject]] = Field(None, description="Camera Concepts to apply to generation")
class LumaGeneration(BaseModel):
id: str = Field(..., description='The ID of the generation')
generation_type: LumaGenerationType = Field(..., description='Generation type, image or video')
state: LumaState = Field(..., description='The state of the generation')
failure_reason: Optional[str] = Field(None, description='The reason for the state of the generation')
created_at: str = Field(..., description='The date and time when the generation was created')
assets: Optional[LumaAssets] = Field(None, description='The assets of the generation')
model: str = Field(..., description='The model used for the generation')
request: Union[LumaGenerationRequest, LumaImageGenerationRequest] = Field(..., description="The request used for the generation")
id: str = Field(..., description="The ID of the generation")
generation_type: LumaGenerationType = Field(..., description="Generation type, image or video")
state: LumaState = Field(..., description="The state of the generation")
failure_reason: Optional[str] = Field(None, description="The reason for the state of the generation")
created_at: str = Field(..., description="The date and time when the generation was created")
assets: Optional[LumaAssets] = Field(None, description="The assets of the generation")
model: str = Field(..., description="The model used for the generation")
request: Union[LumaGenerationRequest, LumaImageGenerationRequest] = Field(...)
class Luma2ImageRef(BaseModel):
url: str | None = None
data: str | None = None
media_type: str | None = None
generation_id: str | None = Field(None, description="reference a prior generation (extend / source reuse)")
class Luma2VideoEdit(BaseModel):
"""Edit controls for Ray 3.2 ``video_edit`` generations."""
auto_controls: bool | None = Field(None, description="derive a conditioning schedule from the source (recommended)")
strength: str | None = Field(None, description="'adhere_1' .. 'reimagine_3'; constrained by IO.Combo")
class Luma2VideoOptions(BaseModel):
"""Ray 3.2 ``video`` output settings (text / image / keyframe / edit / extend)."""
resolution: str | None = Field(None, description="360p | 540p | 720p | 1080p")
duration: str | None = Field(None, description="5s | 10s")
loop: bool | None = Field(None)
start_frame: Luma2ImageRef | None = Field(None)
end_frame: Luma2ImageRef | None = Field(None)
keyframes: list[Luma2ImageRef] | None = Field(None)
keyframe_indexes: list[int] | None = Field(None)
edit: Luma2VideoEdit | None = Field(None)
class Luma2GenerationRequest(BaseModel):
@ -266,6 +291,7 @@ class Luma2GenerationRequest(BaseModel):
web_search: bool | None = None
image_ref: list[Luma2ImageRef] | None = None
source: Luma2ImageRef | None = None
video: Luma2VideoOptions | None = Field(None)
class Luma2Generation(BaseModel):
@ -277,3 +303,31 @@ class Luma2Generation(BaseModel):
output: list[LumaImageReference] | None = None
failure_reason: str | None = None
failure_code: str | None = None
# --- Ray 3.2 multi-keyframe chain ---
LUMA_KEYFRAME_MODE_FRACTION = "fraction" # value in [0.0, 1.0] of the output video duration
LUMA_KEYFRAME_MODE_SECONDS = "seconds" # absolute time, in seconds, from the start of the output
class LumaRay32KeyframeItem:
"""One guide image anchored at a position on the Ray 3.2 output timeline."""
def __init__(self, image: torch.Tensor, mode: str, value: float):
self.image = image
self.mode = mode # LUMA_KEYFRAME_MODE_FRACTION | LUMA_KEYFRAME_MODE_SECONDS
self.value = value
class LumaRay32KeyframeChain:
def __init__(self):
self.items: list[LumaRay32KeyframeItem] = []
def add(self, item: LumaRay32KeyframeItem) -> None:
self.items.append(item)
def clone(self) -> "LumaRay32KeyframeChain":
c = LumaRay32KeyframeChain()
c.items = list(self.items)
return c

View File

@ -15,7 +15,6 @@ from comfy_api_nodes.apis.bytedance import (
RECOMMENDED_PRESETS_SEEDREAM_4_0,
RECOMMENDED_PRESETS_SEEDREAM_4_5,
RECOMMENDED_PRESETS_SEEDREAM_5_LITE,
SEEDANCE2_PRICE_PER_1K_TOKENS,
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS,
VIDEO_TASKS_EXECUTION_TIME,
GetAssetResponse,
@ -40,6 +39,7 @@ from comfy_api_nodes.apis.bytedance import (
TaskVideoContentUrl,
Text2ImageTaskCreationRequest,
Text2VideoTaskCreationRequest,
seedance2_price_per_1k_tokens,
)
from comfy_api_nodes.util import (
ApiEndpoint,
@ -89,6 +89,7 @@ BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT = "/proxy/byteplus-seedance2/api/v3/cont
SEEDANCE_MODELS = {
"Seedance 2.0": "dreamina-seedance-2-0-260128",
"Seedance 2.0 Fast": "dreamina-seedance-2-0-fast-260128",
"Seedance 2.0 Mini": "dreamina-seedance-2-0-mini",
}
DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-250428"}
@ -141,7 +142,7 @@ SEEDANCE2_RATIO_WH = {
"9:16": (9, 16),
"21:9": (21, 9),
}
SEEDANCE2_RES_SHORT_SIDE = {"480p": 480, "720p": 720, "1080p": 1080}
SEEDANCE2_RES_SHORT_SIDE = {"480p": 480, "720p": 720, "1080p": 1080, "4k": 2160}
def _seedance2_target_dims(resolution: str, ratio: str, image: torch.Tensor) -> tuple[int, int]:
@ -377,9 +378,9 @@ async def _seedance_virtual_library_upload_video_asset(
return f"asset://{create_resp.asset_id}"
def _seedance2_price_extractor(model_id: str, has_video_input: bool):
def _seedance2_price_extractor(model_id: str, has_video_input: bool, resolution: str):
"""Returns a price_extractor closure for Seedance 2.0 poll_op."""
rate = SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input))
rate = seedance2_price_per_1k_tokens(model_id, has_video_input, resolution)
if rate is None:
return None
@ -1621,10 +1622,12 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p"])),
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p", "4k"])),
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs(["480p", "720p"])),
IO.DynamicCombo.Option("Seedance 2.0 Mini", _seedance2_text_inputs(["480p", "720p"])),
],
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
tooltip="Seedance 2.0 for maximum quality; Fast for speed optimization; "
"Mini for the fastest, lowest-cost generation.",
),
IO.Int.Input(
"seed",
@ -1660,11 +1663,16 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
$rate480 := 10044;
$rate720 := 21600;
$rate1080 := 48800;
$rate4k := 195200;
$m := widgets.model;
$pricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001;
$res := $lookup(widgets, "model.resolution");
$dur := $lookup(widgets, "model.duration");
$rate := $res = "1080p" ? $rate1080 :
$pricePer1K := $res = "4k" ? 0.00572 :
$res = "1080p" ? 0.011011 :
$contains($m, "mini") ? 0.005005 :
$contains($m, "fast") ? 0.008008 : 0.01001;
$rate := $res = "4k" ? $rate4k :
$res = "1080p" ? $rate1080 :
$res = "720p" ? $rate720 :
$rate480;
$cost := $dur * $rate * $pricePer1K / 1000;
@ -1703,7 +1711,7 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: r.status,
price_extractor=_seedance2_price_extractor(model_id, has_video_input=False),
price_extractor=_seedance2_price_extractor(model_id, has_video_input=False, resolution=model["resolution"]),
poll_interval=9,
)
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
@ -1724,14 +1732,19 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
options=[
IO.DynamicCombo.Option(
"Seedance 2.0",
_seedance2_text_inputs(["480p", "720p", "1080p"], default_ratio="adaptive"),
_seedance2_text_inputs(["480p", "720p", "1080p", "4k"], default_ratio="adaptive"),
),
IO.DynamicCombo.Option(
"Seedance 2.0 Fast",
_seedance2_text_inputs(["480p", "720p"], default_ratio="adaptive"),
),
IO.DynamicCombo.Option(
"Seedance 2.0 Mini",
_seedance2_text_inputs(["480p", "720p"], default_ratio="adaptive"),
),
],
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
tooltip="Seedance 2.0 for maximum quality; Fast for speed optimization; "
"Mini for the fastest, lowest-cost generation.",
),
IO.Image.Input(
"first_frame",
@ -1791,11 +1804,16 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
$rate480 := 10044;
$rate720 := 21600;
$rate1080 := 48800;
$rate4k := 195200;
$m := widgets.model;
$pricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001;
$res := $lookup(widgets, "model.resolution");
$dur := $lookup(widgets, "model.duration");
$rate := $res = "1080p" ? $rate1080 :
$pricePer1K := $res = "4k" ? 0.00572 :
$res = "1080p" ? 0.011011 :
$contains($m, "mini") ? 0.005005 :
$contains($m, "fast") ? 0.008008 : 0.01001;
$rate := $res = "4k" ? $rate4k :
$res = "1080p" ? $rate1080 :
$res = "720p" ? $rate720 :
$rate480;
$cost := $dur * $rate * $pricePer1K / 1000;
@ -1913,7 +1931,7 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: r.status,
price_extractor=_seedance2_price_extractor(model_id, has_video_input=False),
price_extractor=_seedance2_price_extractor(model_id, has_video_input=False, resolution=model["resolution"]),
poll_interval=9,
)
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
@ -2010,14 +2028,19 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
options=[
IO.DynamicCombo.Option(
"Seedance 2.0",
_seedance2_reference_inputs(["480p", "720p", "1080p"], default_ratio="adaptive"),
_seedance2_reference_inputs(["480p", "720p", "1080p", "4k"], default_ratio="adaptive"),
),
IO.DynamicCombo.Option(
"Seedance 2.0 Fast",
_seedance2_reference_inputs(["480p", "720p"], default_ratio="adaptive"),
),
IO.DynamicCombo.Option(
"Seedance 2.0 Mini",
_seedance2_reference_inputs(["480p", "720p"], default_ratio="adaptive"),
),
],
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
tooltip="Seedance 2.0 for maximum quality; Fast for speed optimization; "
"Mini for the fastest, lowest-cost generation.",
),
IO.Int.Input(
"seed",
@ -2056,13 +2079,21 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
$rate480 := 10044;
$rate720 := 21600;
$rate1080 := 48800;
$rate4k := 195200;
$m := widgets.model;
$hasVideo := $lookup(inputGroups, "model.reference_videos") > 0;
$noVideoPricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001;
$videoPricePer1K := $contains($m, "fast") ? 0.004719 : 0.006149;
$res := $lookup(widgets, "model.resolution");
$dur := $lookup(widgets, "model.duration");
$rate := $res = "1080p" ? $rate1080 :
$noVideoPricePer1K := $res = "4k" ? 0.00572 :
$res = "1080p" ? 0.011011 :
$contains($m, "mini") ? 0.005005 :
$contains($m, "fast") ? 0.008008 : 0.01001;
$videoPricePer1K := $res = "4k" ? 0.003432 :
$res = "1080p" ? 0.006721 :
$contains($m, "mini") ? 0.003003 :
$contains($m, "fast") ? 0.004719 : 0.006149;
$rate := $res = "4k" ? $rate4k :
$res = "1080p" ? $rate1080 :
$res = "720p" ? $rate720 :
$rate480;
$noVideoCost := $dur * $rate * $noVideoPricePer1K / 1000;
@ -2258,7 +2289,9 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: r.status,
price_extractor=_seedance2_price_extractor(model_id, has_video_input=has_video_input),
price_extractor=_seedance2_price_extractor(
model_id, has_video_input=has_video_input, resolution=model["resolution"]
),
poll_interval=9,
)
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))

View File

@ -5,7 +5,6 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer
import base64
import os
from enum import Enum
from fnmatch import fnmatch
from io import BytesIO
from typing import Any, Literal
@ -14,7 +13,7 @@ import torch
from typing_extensions import override
import folder_paths
from comfy_api.latest import IO, ComfyExtension, Input, Types
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl, Types
from comfy_api_nodes.apis.gemini import (
GeminiContent,
GeminiFileData,
@ -38,6 +37,7 @@ from comfy_api_nodes.util import (
audio_to_base64_string,
bytesio_to_image_tensor,
download_url_to_image_tensor,
download_url_to_video_output,
get_number_of_images,
sync_op,
tensor_to_base64_string,
@ -46,6 +46,7 @@ from comfy_api_nodes.util import (
upload_images_to_comfyapi,
upload_video_to_comfyapi,
validate_string,
validate_video_duration,
video_to_base64_string,
)
@ -78,15 +79,6 @@ GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge(
)
class GeminiImageModel(str, Enum):
"""
Gemini Image Model Names allowed by comfy-api
"""
gemini_2_5_flash_image_preview = "gemini-2.5-flash-image-preview"
gemini_2_5_flash_image = "gemini-2.5-flash-image"
async def create_image_parts(
cls: type[IO.ComfyNode],
images: Input.Image | list[Input.Image],
@ -239,25 +231,38 @@ async def get_image_from_response(response: GeminiGenerateContentResponse, thoug
return torch.cat(image_tensors, dim=0)
async def get_video_from_response(
response: GeminiGenerateContentResponse, cls: type[IO.ComfyNode] | None = None
) -> InputImpl.VideoFromFile:
parts = get_parts_by_type(response, "video/*")
for part in parts:
if part.inlineData and part.inlineData.data:
return InputImpl.VideoFromFile(BytesIO(base64.b64decode(part.inlineData.data)))
if part.fileData and part.fileData.fileUri:
return await download_url_to_video_output(part.fileData.fileUri, cls=cls)
model_message = get_text_from_response(response).strip()
if model_message:
raise ValueError(f"Gemini did not generate a video. Model response: {model_message}")
raise ValueError(
"Gemini did not generate a video. Try rephrasing your prompt, "
"shortening the requested duration, or reducing the number of input images/videos."
)
def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | None:
if not response.modelVersion:
return None
# Define prices (Cost per 1,000,000 tokens), see https://cloud.google.com/vertex-ai/generative-ai/pricing
if response.modelVersion in ("gemini-2.5-pro-preview-05-06", "gemini-2.5-pro"):
output_video_tokens_price = 0.0
if response.modelVersion == "gemini-2.5-pro":
input_tokens_price = 1.25
output_text_tokens_price = 10.0
output_image_tokens_price = 0.0
elif response.modelVersion in (
"gemini-2.5-flash-preview-04-17",
"gemini-2.5-flash",
):
elif response.modelVersion == "gemini-2.5-flash":
input_tokens_price = 0.30
output_text_tokens_price = 2.50
output_image_tokens_price = 0.0
elif response.modelVersion in (
"gemini-2.5-flash-image-preview",
"gemini-2.5-flash-image",
):
elif response.modelVersion == "gemini-2.5-flash-image":
input_tokens_price = 0.30
output_text_tokens_price = 2.50
output_image_tokens_price = 30.0
@ -265,18 +270,27 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N
input_tokens_price = 2
output_text_tokens_price = 12.0
output_image_tokens_price = 0.0
elif response.modelVersion == "gemini-3.1-flash-lite-preview":
elif response.modelVersion in ("gemini-3.1-flash-lite-preview", "gemini-3.1-flash-lite"):
input_tokens_price = 0.25
output_text_tokens_price = 1.50
output_image_tokens_price = 0.0
elif response.modelVersion == "gemini-3-pro-image-preview":
elif response.modelVersion in ("gemini-3-pro-image-preview", "gemini-3-pro-image"):
input_tokens_price = 2
output_text_tokens_price = 12.0
output_image_tokens_price = 120.0
elif response.modelVersion == "gemini-3.1-flash-image-preview":
elif response.modelVersion in ("gemini-3.1-flash-image-preview", "gemini-3.1-flash-image"):
input_tokens_price = 0.5
output_text_tokens_price = 3.0
output_image_tokens_price = 60.0
elif response.modelVersion == "gemini-3.1-flash-lite-image":
input_tokens_price = 0.25
output_text_tokens_price = 1.50
output_image_tokens_price = 30.0
elif response.modelVersion == "gemini-omni-flash-preview":
input_tokens_price = 2.145
output_text_tokens_price = 12.87
output_image_tokens_price = 0.0
output_video_tokens_price = 25.025
else:
return None
final_price = response.usageMetadata.promptTokenCount * input_tokens_price
@ -284,6 +298,8 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N
for i in response.usageMetadata.candidatesTokensDetails:
if i.modality == Modality.IMAGE:
final_price += output_image_tokens_price * i.tokenCount # for Nano Banana models
elif i.modality == Modality.VIDEO:
final_price += output_video_tokens_price * i.tokenCount # for Omni Flash
else:
final_price += output_text_tokens_price * i.tokenCount
if response.usageMetadata.thoughtsTokenCount:
@ -455,8 +471,6 @@ class GeminiNode(IO.ComfyNode):
IO.Combo.Input(
"model",
options=[
"gemini-2.5-pro-preview-05-06",
"gemini-2.5-flash-preview-04-17",
"gemini-2.5-pro",
"gemini-2.5-flash",
"gemini-3-pro-preview",
@ -904,8 +918,7 @@ class GeminiImage(IO.ComfyNode):
),
IO.Combo.Input(
"model",
options=GeminiImageModel,
default=GeminiImageModel.gemini_2_5_flash_image,
options=["gemini-2.5-flash-image"],
tooltip="The Gemini model to use for generating responses.",
),
IO.Int.Input(
@ -1321,7 +1334,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
)
def _nano_banana_2_v2_model_inputs():
def _nano_banana_2_v2_model_inputs(resolutions: list[str]):
return [
IO.Combo.Input(
"aspect_ratio",
@ -1348,8 +1361,8 @@ def _nano_banana_2_v2_model_inputs():
),
IO.Combo.Input(
"resolution",
options=["1K", "2K", "4K"],
tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.",
options=resolutions,
tooltip="Target output resolution.",
),
IO.Combo.Input(
"thinking_level",
@ -1395,7 +1408,11 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
options=[
IO.DynamicCombo.Option(
"Nano Banana 2 (Gemini 3.1 Flash Image)",
_nano_banana_2_v2_model_inputs(),
_nano_banana_2_v2_model_inputs(resolutions=["1K", "2K", "4K"]),
),
IO.DynamicCombo.Option(
"Nano Banana 2 Lite",
_nano_banana_2_v2_model_inputs(resolutions=["1K"]),
),
],
),
@ -1464,9 +1481,13 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution"]),
expr="""
(
$r := $lookup(widgets, "model.resolution");
$prices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154};
{"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
$contains(widgets.model, "lite")
? {"type":"usd","usd": 0.034, "format":{"suffix":"/Image","approximate":true}}
: (
$r := $lookup(widgets, "model.resolution");
$prices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154};
{"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
)
)
""",
),
@ -1487,6 +1508,8 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
model_choice = model["model"]
if model_choice == "Nano Banana 2 (Gemini 3.1 Flash Image)":
model_id = "gemini-3.1-flash-image-preview"
elif model_choice == "Nano Banana 2 Lite":
model_id = "gemini-3.1-flash-lite-image"
else:
model_id = model_choice
@ -1536,6 +1559,149 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
)
OMNI_MAX_IMAGES = 14
OMNI_MAX_VIDEOS = 3
OMNI_MODELS: dict[str, str] = {
"Omni Flash": "gemini-omni-flash-preview",
}
def _omni_flash_inputs() -> list[Input]:
"""Per-model inputs for the Omni video DynamicCombo (prompt + reference media + sampling)."""
return [
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Describe the video to generate. Specify the length and aspect ratio directly in the "
'prompt, e.g. "a 6-second clip in 16:9". Length may be 3-10 seconds; the aspect ratio must be '
"16:9 (landscape) or 9:16 (portrait). The output is 720p, 24 FPS, with audio.",
),
IO.Autogrow.Input(
"images",
template=IO.Autogrow.TemplateNames(
IO.Image.Input("image"),
names=[f"image_{i}" for i in range(1, OMNI_MAX_IMAGES + 1)],
min=0,
),
tooltip=f"Optional reference image(s) to guide or animate the video. Up to {OMNI_MAX_IMAGES} images.",
),
IO.Autogrow.Input(
"videos",
template=IO.Autogrow.TemplateNames(
IO.Video.Input("video"),
names=[f"video_{i}" for i in range(1, OMNI_MAX_VIDEOS + 1)],
min=0,
),
tooltip=f"Optional reference video(s) to guide or edit. Up to {OMNI_MAX_VIDEOS} videos, "
f"each up to 10 seconds long.",
),
IO.Float.Input(
"temperature",
default=1.0,
min=0.0,
max=2.0,
step=0.01,
tooltip="Controls randomness. Lower is more focused/deterministic, higher is more varied.",
advanced=True,
),
IO.Float.Input(
"top_p",
default=0.95,
min=0.0,
max=1.0,
step=0.01,
tooltip="Nucleus sampling: sample from the smallest token set whose cumulative probability reaches top_p.",
advanced=True,
),
]
class GeminiVideoOmni(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="GeminiVideoOmni",
display_name="Google Gemini Omni (Video)",
category="partner/video/Gemini",
essentials_category="Video Generation",
description="Generate a video with audio from a text prompt using Google's Gemini Omni Flash model. "
"Optionally provide reference images and/or videos to guide or edit the result. Describe the desired "
"length (3-10s) and aspect ratio (16:9 or 9:16) directly in the prompt.",
inputs=[
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option("Omni Flash", _omni_flash_inputs()),
],
tooltip="The Gemini video model used to generate the video.",
),
IO.Int.Input(
"seed",
default=42,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[
IO.Video.Output(),
IO.String.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr='{"type":"usd","usd":0.146,"format":{"suffix":"/second","approximate":true}}'
),
)
@classmethod
async def execute(cls, model: dict, seed: int) -> IO.NodeOutput:
prompt = model.get("prompt") or ""
validate_string(prompt, strip_whitespace=True, min_length=1)
model_id = OMNI_MODELS[model["model"]]
images = [t for t in (model.get("images") or {}).values() if t is not None]
videos = [v for v in (model.get("videos") or {}).values() if v is not None]
if sum(get_number_of_images(t) for t in images) > OMNI_MAX_IMAGES:
raise ValueError(f"The current maximum number of supported images is {OMNI_MAX_IMAGES}.")
if len(videos) > OMNI_MAX_VIDEOS:
raise ValueError(f"The current maximum number of supported videos is {OMNI_MAX_VIDEOS}.")
for video in videos:
validate_video_duration(video, max_duration=10)
parts: list[GeminiPart] = []
if images or videos:
parts.extend(await build_gemini_media_parts(cls, images, [], videos))
parts.append(GeminiPart(text=prompt))
response = await sync_op(
cls,
ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model_id}", method="POST"),
data=GeminiGenerateContentRequest(
contents=[GeminiContent(role=GeminiRole.user, parts=parts)],
generationConfig=GeminiGenerationConfig(
responseModalities=["TEXT", "VIDEO"],
temperature=model.get("temperature", 1.0),
topP=model.get("top_p", 0.95),
),
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
)
return IO.NodeOutput(
await get_video_from_response(response, cls=cls),
get_text_from_response(response),
)
class GeminiExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -1546,6 +1712,7 @@ class GeminiExtension(ComfyExtension):
GeminiImage2,
GeminiNanoBanana2,
GeminiNanoBanana2V2,
GeminiVideoOmni,
GeminiInputFiles,
]

View File

@ -30,7 +30,7 @@ from comfy_api_nodes.util import (
_GROK_VIDEO_MODEL_API_IDS = {
"grok-imagine-video-1.5": "grok-imagine-video-1.5-preview",
"grok-imagine-video-1.5": "grok-imagine-video-1.5",
}
@ -521,8 +521,8 @@ class GrokVideoNode(IO.ComfyNode):
),
IO.Combo.Input(
"resolution",
options=["480p", "720p"],
tooltip="The resolution of the output video.",
options=["480p", "720p", "1080p"],
tooltip="The resolution of the output video. 1080p is only available for grok-imagine-video-1.5.",
),
IO.Combo.Input(
"aspect_ratio",
@ -570,11 +570,12 @@ class GrokVideoNode(IO.ComfyNode):
(
$is15 := $contains(widgets.model, "1.5");
$rate := $is15
? (widgets.resolution = "720p" ? 0.2002 : 0.1144)
? (widgets.resolution = "1080p" ? 0.25 : (widgets.resolution = "720p" ? 0.14 : 0.08))
: (widgets.resolution = "720p" ? 0.07 : 0.05);
$imgCost := $is15 ? 0.0143 : 0.002;
$imgCost := $is15 ? 0.01 : 0.002;
$base := $rate * widgets.duration;
{"type":"usd","usd": inputs.image.connected ? $base + $imgCost : $base}
$total := inputs.image.connected ? $base + $imgCost : $base;
{"type":"usd","usd": $is15 ? $total * 1.43 : $total}
)
""",
),
@ -593,6 +594,8 @@ class GrokVideoNode(IO.ComfyNode):
) -> IO.NodeOutput:
if image is None and model == "grok-imagine-video-1.5":
raise ValueError(f"The '{model}' model requires an input image; connect one to the 'image' input.")
if resolution == "1080p" and model != "grok-imagine-video-1.5":
raise ValueError(f"1080p resolution is only available for grok-imagine-video-1.5, not '{model}'.")
image_url = None
if image is not None:
if get_number_of_images(image) != 1:

View File

@ -60,6 +60,12 @@ from comfy_api_nodes.apis.kling import (
OmniProImageRequest,
OmniProReferences2VideoRequest,
OmniProText2VideoRequest,
Kling3TurboSettings,
Kling3TurboText2VideoRequest,
Kling3TurboContent,
Kling3TurboImage2VideoRequest,
Kling3TurboCreateResponse,
Kling3TurboQueryResponse,
TaskStatusResponse,
TextToVideoWithAudioRequest,
)
@ -2847,6 +2853,67 @@ class MotionControl(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
def build_turbo_shot_prompt(multi_prompt: list[MultiPromptEntry]) -> str:
"""Render storyboard entries into the Turbo multi-shot prompt 'shot n, m, words; ...'."""
return "; ".join(f"shot {i}, {int(e.duration)}, {e.prompt}" for i, e in enumerate(multi_prompt, 1)) + ";"
def _turbo_video_url(response: Kling3TurboQueryResponse) -> str:
"""Extract the result video URL from a /tasks response (data[].outputs[] where type == 'video')."""
task = response.data[0] if response.data else None
if task and task.outputs:
for output in task.outputs:
if output.type == "video" and output.url:
return output.url
raise RuntimeError(f"Kling 3.0 Turbo task finished without a video output: {response.model_dump()}")
async def execute_kling_turbo(
cls: type[IO.ComfyNode],
*,
prompt: str,
resolution: str,
aspect_ratio: str,
duration: int,
start_frame: torch.Tensor | None,
) -> IO.NodeOutput:
"""Create + poll a Kling 3.0 Turbo task. Image-to-video when start_frame is given, else text-to-video."""
if start_frame is not None:
validate_image_dimensions(start_frame, min_width=300, min_height=300)
validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1))
contents = [Kling3TurboContent(type="first_frame", url=tensor_to_base64_string(start_frame))]
if prompt:
contents.insert(0, Kling3TurboContent(type="prompt", text=prompt))
create = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/image-to-video/kling-3.0-turbo", method="POST"),
response_model=Kling3TurboCreateResponse,
data=Kling3TurboImage2VideoRequest(
contents=contents,
settings=Kling3TurboSettings(resolution=resolution, duration=duration), # i2v: no aspect_ratio
),
)
else:
create = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/text-to-video/kling-3.0-turbo", method="POST"),
response_model=Kling3TurboCreateResponse,
data=Kling3TurboText2VideoRequest(
prompt=prompt,
settings=Kling3TurboSettings(resolution=resolution, aspect_ratio=aspect_ratio, duration=duration),
),
)
if not (create.data and create.data.id):
raise RuntimeError(f"Kling 3.0 Turbo create failed. Code: {create.code}, Message: {create.message}")
final_response = await poll_op(
cls,
ApiEndpoint(path="/proxy/kling/tasks", query_params={"task_ids": create.data.id}),
response_model=Kling3TurboQueryResponse,
status_extractor=lambda r: (r.data[0].status if r.data else None),
)
return IO.NodeOutput(await download_url_to_video_output(_turbo_video_url(final_response)))
class KlingVideoNode(IO.ComfyNode):
@classmethod
@ -2884,7 +2951,11 @@ class KlingVideoNode(IO.ComfyNode):
],
tooltip="Generate a series of video segments with individual prompts and durations.",
),
IO.Boolean.Input("generate_audio", default=True),
IO.Boolean.Input(
"generate_audio",
default=True,
tooltip="'kling-3.0-turbo' always generates native audio, so the audio toggle is ignored.",
),
IO.DynamicCombo.Input(
"model",
options=[
@ -2899,6 +2970,17 @@ class KlingVideoNode(IO.ComfyNode):
),
],
),
IO.DynamicCombo.Option(
"kling-3.0-turbo",
[
IO.Combo.Input("resolution", options=["1080p", "720p"], default="720p"),
IO.Combo.Input(
"aspect_ratio",
options=["16:9", "9:16", "1:1"],
tooltip="Ignored in image-to-video mode.",
),
],
),
],
tooltip="Model and generation settings.",
),
@ -2930,6 +3012,7 @@ class KlingVideoNode(IO.ComfyNode):
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=[
"model",
"model.resolution",
"generate_audio",
"multi_shot",
@ -2944,14 +3027,7 @@ class KlingVideoNode(IO.ComfyNode):
),
expr="""
(
$rates := {
"4k": {"off": 0.42, "on": 0.42},
"1080p": {"off": 0.112, "on": 0.168},
"720p": {"off": 0.084, "on": 0.126}
};
$res := $lookup(widgets, "model.resolution");
$audio := widgets.generate_audio ? "on" : "off";
$rate := $lookup($lookup($rates, $res), $audio);
$ms := widgets.multi_shot;
$isSb := $ms != "disabled";
$n := $isSb ? $number($substring($ms, 0, 1)) : 0;
@ -2962,7 +3038,18 @@ class KlingVideoNode(IO.ComfyNode):
$d5 := $n >= 5 ? $lookup(widgets, "multi_shot.storyboard_5_duration") : 0;
$d6 := $n >= 6 ? $lookup(widgets, "multi_shot.storyboard_6_duration") : 0;
$dur := $isSb ? $d1 + $d2 + $d3 + $d4 + $d5 + $d6 : $lookup(widgets, "multi_shot.duration");
{"type":"usd","usd": $rate * $dur}
widgets.model = "kling-3.0-turbo"
? {"type":"usd","usd": ($res = "1080p" ? 0.14 : 0.112) * $dur}
: (
$rates := {
"4k": {"off": 0.42, "on": 0.42},
"1080p": {"off": 0.112, "on": 0.168},
"720p": {"off": 0.084, "on": 0.126}
};
$audio := widgets.generate_audio ? "on" : "off";
$rate := $lookup($lookup($rates, $res), $audio);
{"type":"usd","usd": $rate * $dur}
)
)
""",
),
@ -3015,6 +3102,17 @@ class KlingVideoNode(IO.ComfyNode):
duration = multi_shot["duration"]
validate_string(multi_shot["prompt"], min_length=1, max_length=2500)
if model["model"] == "kling-3.0-turbo":
turbo_prompt = build_turbo_shot_prompt(multi_prompt_list) if custom_multi_shot else multi_shot["prompt"]
return await execute_kling_turbo(
cls,
prompt=turbo_prompt,
resolution=model["resolution"],
aspect_ratio=model["aspect_ratio"],
duration=duration,
start_frame=start_frame,
)
if start_frame is not None:
validate_image_dimensions(start_frame, min_width=300, min_height=300)
validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1))

View File

@ -3,9 +3,13 @@ from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.luma import (
LUMA_KEYFRAME_MODE_FRACTION,
LUMA_KEYFRAME_MODE_SECONDS,
Luma2Generation,
Luma2GenerationRequest,
Luma2ImageRef,
Luma2VideoEdit,
Luma2VideoOptions,
LumaAspectRatio,
LumaCharacterRef,
LumaConceptChain,
@ -18,6 +22,8 @@ from comfy_api_nodes.apis.luma import (
LumaIO,
LumaKeyframes,
LumaModifyImageRef,
LumaRay32KeyframeChain,
LumaRay32KeyframeItem,
LumaReference,
LumaReferenceChain,
LumaVideoModel,
@ -33,6 +39,7 @@ from comfy_api_nodes.util import (
sync_op,
upload_image_to_comfyapi,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
validate_string,
)
@ -692,7 +699,10 @@ async def _luma2_upload_image_refs(
async def _luma2_submit_and_poll(
cls: type[IO.ComfyNode],
request: Luma2GenerationRequest,
) -> Input.Image:
*,
estimated_duration: int | None = None,
) -> Luma2Generation:
"""Submit a Luma Agents generation and poll until done; returns the completed generation."""
initial = await sync_op(
cls,
ApiEndpoint(path="/proxy/luma_2/generations", method="POST"),
@ -700,21 +710,21 @@ async def _luma2_submit_and_poll(
data=request,
)
if not initial.id:
raise RuntimeError("Luma 2 API did not return a generation id.")
raise RuntimeError("Luma API did not return a generation id.")
final = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/luma_2/generations/{initial.id}", method="GET"),
response_model=Luma2Generation,
status_extractor=lambda r: r.state,
progress_extractor=lambda r: None,
estimated_duration=estimated_duration,
)
if not final.output:
if not final.output or not final.output[0].url:
msg = final.failure_reason or "no output returned"
raise RuntimeError(f"Luma 2 generation failed: {msg}")
url = final.output[0].url
if not url:
raise RuntimeError("Luma 2 generation completed without an output URL.")
return await download_url_to_image_tensor(url)
if final.failure_code:
msg = f"{msg} [{final.failure_code}]"
raise RuntimeError(f"Luma generation failed: {msg}")
return final
class LumaImageNode(IO.ComfyNode):
@ -843,7 +853,8 @@ class LumaImageNode(IO.ComfyNode):
web_search=model["web_search"],
image_ref=await _luma2_upload_image_refs(cls, model.get("image_ref"), max_count=9),
)
return IO.NodeOutput(await _luma2_submit_and_poll(cls, request))
final = await _luma2_submit_and_poll(cls, request)
return IO.NodeOutput(await download_url_to_image_tensor(final.output[0].url))
class LumaImageEditNode(IO.ComfyNode):
@ -929,7 +940,533 @@ class LumaImageEditNode(IO.ComfyNode):
web_search=model["web_search"],
image_ref=await _luma2_upload_image_refs(cls, model.get("image_ref"), max_count=8),
)
return IO.NodeOutput(await _luma2_submit_and_poll(cls, request))
final = await _luma2_submit_and_poll(cls, request)
return IO.NodeOutput(await download_url_to_image_tensor(final.output[0].url))
_BADGE_RAY32_VIDEO = IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["resolution", "duration"]),
expr="""
(
$p := {
"360p": {"5s": 0.06, "10s": 0.18},
"540p": {"5s": 0.15, "10s": 0.45},
"720p": {"5s": 0.3, "10s": 0.9},
"1080p": {"5s": 1.2, "10s": 3.6}
};
{"type": "usd", "usd": $lookup($lookup($p, widgets.resolution), widgets.duration)}
)
""",
)
_BADGE_RAY32_VIDEO_5S = IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["resolution"]),
expr="""
(
$p := {"360p": 0.06, "540p": 0.15, "720p": 0.3, "1080p": 1.2};
{"type": "usd", "usd": $lookup($p, widgets.resolution)}
)
""",
)
_BADGE_RAY32_EDIT = IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["resolution"]),
expr="""
(
$p := {
"360p": {"min": 0.54, "max": 1.08},
"540p": {"min": 0.72, "max": 1.44},
"720p": {"min": 1.08, "max": 2.16},
"1080p": {"min": 2.16, "max": 4.32}
};
$r := $lookup($p, widgets.resolution);
{"type": "range_usd", "min_usd": $r.min, "max_usd": $r.max, "format": {"note": "(by source length)"}}
)
""",
)
_BADGE_RAY32_REFRAME = IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["resolution"]),
expr="""
(
$p := {"360p": 0.03, "540p": 0.06, "720p": 0.12, "1080p": 0.36};
{"type": "usd", "usd": $lookup($p, widgets.resolution), "format": {"suffix": "/second"}}
)
""",
)
def _ray32_seed_input() -> IO.Input:
return IO.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; results are nondeterministic regardless of seed.",
)
async def _ray32_generate(cls: type[IO.ComfyNode], request: Luma2GenerationRequest) -> IO.NodeOutput:
"""Run a ray-3.2 generation and return (video, generation_id)."""
final = await _luma2_submit_and_poll(cls, request, estimated_duration=120)
video = await download_url_to_video_output(final.output[0].url)
return IO.NodeOutput(video, final.id or "")
class LumaRay32TextToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaRay32TextToVideoNode",
display_name="Luma Ray 3.2 Text to Video",
category="partner/video/Luma",
description="Generate a video from a text prompt using Luma's Ray 3.2 model.",
inputs=[
IO.String.Input("prompt", multiline=True, default="", tooltip="Text prompt for the video generation."),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1", "4:3", "3:4", "21:9"]),
IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"),
IO.Combo.Input("duration", options=["5s", "10s"]),
IO.Boolean.Input(
"loop",
default=False,
tooltip="Make the video loop seamlessly. Only available with 5s duration.",
),
_ray32_seed_input(),
],
outputs=[IO.Video.Output(), IO.String.Output(display_name="generation_id")],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=_BADGE_RAY32_VIDEO,
)
@classmethod
async def execute(
cls, prompt: str, aspect_ratio: str, resolution: str, duration: str, loop: bool, seed: int
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1, max_length=6000)
if loop and duration == "10s":
raise ValueError("Looping is only available with 5s duration on Ray 3.2.")
request = Luma2GenerationRequest(
prompt=prompt,
model="ray-3.2",
type="video",
aspect_ratio=aspect_ratio,
video=Luma2VideoOptions(resolution=resolution, duration=duration, loop=loop or None),
)
return await _ray32_generate(cls, request)
class LumaRay32ImageToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaRay32ImageToVideoNode",
display_name="Luma Ray 3.2 Image to Video",
category="partner/video/Luma",
description="Generate a video from a start and/or end frame using Luma's Ray 3.2 model. "
"Image-anchored generations are always 5 seconds.",
inputs=[
IO.String.Input("prompt", multiline=True, default="", tooltip="Text prompt for the video generation."),
IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"),
IO.Boolean.Input(
"loop",
default=False,
tooltip="Make the video loop seamlessly. Not available when an end_frame is set.",
),
_ray32_seed_input(),
IO.Image.Input("start_frame", optional=True, tooltip="First frame of the generated video."),
IO.Image.Input("end_frame", optional=True, tooltip="Last frame of the generated video."),
],
outputs=[IO.Video.Output(), IO.String.Output(display_name="generation_id")],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=_BADGE_RAY32_VIDEO_5S,
)
@classmethod
async def execute(
cls,
prompt: str,
resolution: str,
loop: bool,
seed: int,
start_frame: torch.Tensor | None = None,
end_frame: torch.Tensor | None = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1, max_length=6000)
if start_frame is None and end_frame is None:
raise ValueError("Provide at least one of start_frame / end_frame.")
if loop and end_frame is not None:
raise ValueError("Looping is not available when an end_frame is set.")
video = Luma2VideoOptions(resolution=resolution, duration="5s", loop=loop or None)
if start_frame is not None:
url = await upload_image_to_comfyapi(cls, start_frame, mime_type="image/png")
video.start_frame = Luma2ImageRef(url=url)
if end_frame is not None:
url = await upload_image_to_comfyapi(cls, end_frame, mime_type="image/png")
video.end_frame = Luma2ImageRef(url=url)
request = Luma2GenerationRequest(prompt=prompt, model="ray-3.2", type="video", video=video)
return await _ray32_generate(cls, request)
class LumaRay32KeyframeNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaRay32KeyframeNode",
display_name="Luma Ray 3.2 Keyframe",
category="partner/video/Luma",
description="Anchor a guide image to a position on the Ray 3.2 output video timeline. Connect this to "
"the 'keyframes' input of the Luma Ray 3.2 Keyframes to Video node; chain several together via the "
"optional 'keyframes' input below.",
inputs=[
IO.Image.Input("image", tooltip="Guide image to place at the chosen moment of the output video."),
IO.DynamicCombo.Input(
"position",
options=[
IO.DynamicCombo.Option(
"Fraction of duration (0.0-1.0)",
[
IO.Float.Input(
"fraction",
default=0.0,
min=0.0,
max=1.0,
step=0.01,
display_mode=IO.NumberDisplay.number,
tooltip="Where in the output video this image applies " "(0.0 = start, 1.0 = end).",
),
],
),
IO.DynamicCombo.Option(
"Absolute time (seconds)",
[
IO.Float.Input(
"seconds",
default=0.0,
min=0.0,
max=10.0,
step=0.1,
display_mode=IO.NumberDisplay.number,
tooltip="Time in seconds from the start of the output video where this "
"image applies.",
),
],
),
],
tooltip="How to place this image on the output video's timeline.",
),
IO.Custom(LumaIO.LUMA_RAY32_KEYFRAME).Input(
"keyframes",
optional=True,
tooltip="Optional earlier keyframes to chain with this one.",
),
],
outputs=[IO.Custom(LumaIO.LUMA_RAY32_KEYFRAME).Output(display_name="keyframes")],
)
@classmethod
def execute(
cls,
image: torch.Tensor,
position: dict,
keyframes: LumaRay32KeyframeChain | None = None,
) -> IO.NodeOutput:
chain = keyframes.clone() if keyframes is not None else LumaRay32KeyframeChain()
if position["position"] == "Absolute time (seconds)":
mode, value = LUMA_KEYFRAME_MODE_SECONDS, float(position["seconds"])
else:
mode, value = LUMA_KEYFRAME_MODE_FRACTION, float(position["fraction"])
chain.add(LumaRay32KeyframeItem(image=image, mode=mode, value=value))
return IO.NodeOutput(chain)
class LumaRay32KeyframesToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaRay32KeyframesToVideoNode",
display_name="Luma Ray 3.2 Keyframes to Video",
category="partner/video/Luma",
description="Generate a video that interpolates through a sequence of guide images, each anchored to a "
"position on the timeline, using Luma Ray 3.2. Build the sequence with Luma Ray 3.2 Keyframe nodes "
"(at least 2).",
inputs=[
IO.String.Input("prompt", multiline=True, default="", tooltip="Text prompt for the video generation."),
IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"),
IO.Combo.Input("duration", options=["5s", "10s"]),
_ray32_seed_input(),
IO.Custom(LumaIO.LUMA_RAY32_KEYFRAME).Input(
"keyframes",
tooltip="Keyframe sequence from Luma Ray 3.2 Keyframe nodes (at least 2).",
),
],
outputs=[IO.Video.Output(), IO.String.Output(display_name="generation_id")],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=_BADGE_RAY32_VIDEO,
)
@classmethod
async def execute(
cls,
prompt: str,
resolution: str,
duration: str,
seed: int,
keyframes: LumaRay32KeyframeChain | None = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1, max_length=6000)
items = keyframes.items if keyframes is not None else []
if len(items) < 2:
raise ValueError(
"Connect at least 2 Luma Ray 3.2 Keyframe nodes "
"(use Luma Ray 3.2 Image to Video for a single start/end frame)."
)
if len(items) > 64:
raise ValueError(f"Ray 3.2 supports at most 64 keyframes; got {len(items)}.")
maxframe = 120 if duration == "5s" else 240
duration_seconds = maxframe / 24 # 5.0 or 10.0
# Resolve each keyframe to an output-frame index, then order by position
# (so the user can chain keyframes in any order — the position is what places them)
placed: list[tuple[int, torch.Tensor]] = []
for item in items:
if item.mode == LUMA_KEYFRAME_MODE_SECONDS:
if item.value > duration_seconds:
raise ValueError(
f"Keyframe position {item.value:g}s is past the end of the {duration} video; "
f"use 0-{duration_seconds:g}s (or switch the keyframe to fraction mode)."
)
idx = round(item.value * 24)
else:
idx = round(item.value * maxframe)
placed.append((max(0, min(maxframe, idx)), item.image))
placed.sort(key=lambda p: p[0])
indexes = [idx for idx, _ in placed]
for a, b in zip(indexes, indexes[1:]):
if a == b:
raise ValueError(
f"Two keyframes resolve to the same output frame ({a}) for a {duration} video "
f"(valid range 0-{maxframe}); give each keyframe a distinct position."
)
refs: list[Luma2ImageRef] = []
for _, image in placed:
url = await upload_image_to_comfyapi(cls, image, mime_type="image/png")
refs.append(Luma2ImageRef(url=url))
request = Luma2GenerationRequest(
prompt=prompt,
model="ray-3.2",
type="video",
video=Luma2VideoOptions(resolution=resolution, duration=duration, keyframes=refs, keyframe_indexes=indexes),
)
return await _ray32_generate(cls, request)
class LumaRay32VideoEditNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaRay32VideoEditNode",
display_name="Luma Ray 3.2 Video Edit",
category="partner/video/Luma",
description="Re-render an existing video under a new prompt using Luma Ray 3.2 (restyle, relight, add "
"or remove elements) while keeping the original motion. Source video up to 18 seconds; the edited "
"video keeps the source's length.",
inputs=[
IO.Video.Input("video", tooltip="Source video to edit. Up to 18 seconds."),
IO.String.Input("prompt", multiline=True, default="", tooltip="Describes the desired edit."),
IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"),
IO.Combo.Input(
"strength",
options=[
"auto",
"adhere_1",
"adhere_2",
"adhere_3",
"flex_1",
"flex_2",
"flex_3",
"reimagine_1",
"reimagine_2",
"reimagine_3",
],
default="auto",
tooltip="How strongly to preserve vs. reimagine the source. 'auto' lets Ray 3.2 choose; "
"adhere_* preserves the most, flex_* is balanced, reimagine_* changes the most.",
),
_ray32_seed_input(),
],
outputs=[
IO.Video.Output(),
IO.String.Output(display_name="generation_id"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=_BADGE_RAY32_EDIT,
)
@classmethod
async def execute(
cls, video: Input.Video, prompt: str, resolution: str, strength: str, seed: int
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1, max_length=6000)
try:
duration = "5s" if video.get_duration() <= 5.0 else "10s"
except Exception:
duration = "10s"
source_url = await upload_video_to_comfyapi(cls, video, max_duration=18)
edit = Luma2VideoEdit(auto_controls=True) if strength == "auto" else Luma2VideoEdit(strength=strength)
request = Luma2GenerationRequest(
prompt=prompt,
model="ray-3.2",
type="video_edit",
source=Luma2ImageRef(url=source_url, media_type="video/mp4"),
video=Luma2VideoOptions(resolution=resolution, duration=duration, edit=edit),
)
return await _ray32_generate(cls, request)
class LumaRay32VideoReframeNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaRay32VideoReframeNode",
display_name="Luma Ray 3.2 Video Reframe",
category="partner/video/Luma",
description="Change the aspect ratio of an existing video, using Luma Ray 3.2 to fill the newly "
"exposed canvas areas. Source video up to 30 seconds. Billed per second of output.",
inputs=[
IO.Video.Input("video", tooltip="Source video to reframe. Up to 30 seconds."),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Describes how the newly exposed canvas areas should be filled.",
),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1", "4:3", "3:4", "21:9"]),
IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"),
_ray32_seed_input(),
],
outputs=[
IO.Video.Output(),
IO.String.Output(display_name="generation_id"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=_BADGE_RAY32_REFRAME,
)
@classmethod
async def execute(
cls, video: Input.Video, prompt: str, aspect_ratio: str, resolution: str, seed: int
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False, min_length=1, max_length=6000)
if resolution == "1080p" and aspect_ratio in {"9:16", "3:4"}:
raise ValueError("1080p is not available for vertical aspect ratios (9:16, 3:4) when reframing.")
source_url = await upload_video_to_comfyapi(cls, video, max_duration=30)
request = Luma2GenerationRequest(
prompt=prompt,
model="ray-3.2",
type="video_reframe",
aspect_ratio=aspect_ratio,
source=Luma2ImageRef(url=source_url, media_type="video/mp4"),
video=Luma2VideoOptions(resolution=resolution),
)
return await _ray32_generate(cls, request)
class LumaRay32ExtendVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaRay32ExtendVideoNode",
display_name="Luma Ray 3.2 Extend Video",
category="partner/video/Luma",
description="Extend a previous Ray 3.2 generation forward (continue after it) or backward (lead-in "
"before it). Connect the generation_id output of a prior Luma Ray 3.2 node."
" Extensions are always 5 seconds.",
inputs=[
IO.String.Input(
"source_generation_id",
default="",
tooltip="generation_id of the prior Ray 3.2 video to extend."
" Connect the generation_id output of another Luma Ray 3.2 node.",
),
IO.DynamicCombo.Input(
"direction",
options=[
IO.DynamicCombo.Option(
"Forward (continue after)",
[
IO.Boolean.Input(
"loop",
default=False,
tooltip="Loop the extended video seamlessly (forward extend only).",
),
],
),
IO.DynamicCombo.Option("Backward (lead-in before)", []),
],
tooltip="Forward continues after the prior clip; backward is prepended before it.",
),
IO.String.Input("prompt", multiline=True, default="", tooltip="Text prompt for the new content."),
IO.Combo.Input("resolution", options=["540p", "720p", "1080p"], default="720p"),
_ray32_seed_input(),
],
outputs=[
IO.Video.Output(),
IO.String.Output(display_name="generation_id"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=_BADGE_RAY32_VIDEO_5S,
)
@classmethod
async def execute(
cls, source_generation_id: str, direction: dict, prompt: str, resolution: str, seed: int
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False, min_length=1, max_length=6000)
gen_id = (source_generation_id or "").strip()
if not gen_id:
raise ValueError(
"source_generation_id is required (connect the generation_id output of a prior Luma Ray 3.2 node)."
)
video = Luma2VideoOptions(resolution=resolution, duration="5s")
ref = Luma2ImageRef(generation_id=gen_id)
if direction["direction"] == "Forward (continue after)":
video.start_frame = ref
if direction.get("loop"):
video.loop = True
else:
video.end_frame = ref
request = Luma2GenerationRequest(prompt=prompt, model="ray-3.2", type="video", video=video)
return await _ray32_generate(cls, request)
class LumaExtension(ComfyExtension):
@ -944,6 +1481,13 @@ class LumaExtension(ComfyExtension):
LumaConceptsNode,
LumaImageNode,
LumaImageEditNode,
LumaRay32TextToVideoNode,
LumaRay32ImageToVideoNode,
LumaRay32KeyframeNode,
LumaRay32KeyframesToVideoNode,
LumaRay32VideoEditNode,
LumaRay32VideoReframeNode,
LumaRay32ExtendVideoNode,
]

View File

@ -48,10 +48,13 @@ from comfy_api_nodes.util import (
upload_image_to_comfyapi,
upload_video_to_comfyapi,
validate_audio_duration,
validate_image_aspect_ratio,
validate_image_dimensions,
validate_string,
validate_video_duration,
)
RES_IN_PARENS = re.compile(r"\((\d+)\s*[x×]\s*(\d+)\)")
@ -1657,6 +1660,44 @@ class HappyHorseTextToVideoApi(IO.ComfyNode):
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"happyhorse-1.1-t2v",
[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt describing the elements and visual features. "
"Supports English and Chinese.",
),
IO.Combo.Input(
"resolution",
options=["720P", "1080P"],
),
IO.Combo.Input(
"ratio",
options=[
"16:9",
"9:16",
"1:1",
"4:3",
"3:4",
"21:9",
"9:21",
"5:4",
"4:5",
],
),
IO.Int.Input(
"duration",
default=5,
min=3,
max=15,
step=1,
display_mode=IO.NumberDisplay.number,
),
],
),
IO.DynamicCombo.Option(
"happyhorse-1.0-t2v",
[
@ -1719,7 +1760,9 @@ class HappyHorseTextToVideoApi(IO.ComfyNode):
(
$res := $lookup(widgets, "model.resolution");
$dur := $lookup(widgets, "model.duration");
$ppsTable := { "720p": 0.14, "1080p": 0.24 };
$ppsTable := $contains(widgets.model, "1.1")
? { "720p": 0.2002, "1080p": 0.2574 }
: { "720p": 0.14, "1080p": 0.24 };
$pps := $lookup($ppsTable, $res);
{ "type": "usd", "usd": $pps * $dur }
)
@ -1781,6 +1824,30 @@ class HappyHorseImageToVideoApi(IO.ComfyNode):
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"happyhorse-1.1-i2v",
[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt describing the elements and visual features. "
"Supports English and Chinese.",
),
IO.Combo.Input(
"resolution",
options=["720P", "1080P"],
),
IO.Int.Input(
"duration",
default=5,
min=3,
max=15,
step=1,
display_mode=IO.NumberDisplay.number,
),
],
),
IO.DynamicCombo.Option(
"happyhorse-1.0-i2v",
[
@ -1843,7 +1910,9 @@ class HappyHorseImageToVideoApi(IO.ComfyNode):
(
$res := $lookup(widgets, "model.resolution");
$dur := $lookup(widgets, "model.duration");
$ppsTable := { "720p": 0.14, "1080p": 0.24 };
$ppsTable := $contains(widgets.model, "1.1")
? { "720p": 0.2002, "1080p": 0.2574 }
: { "720p": 0.14, "1080p": 0.24 };
$pps := $lookup($ppsTable, $res);
{ "type": "usd", "usd": $pps * $dur }
)
@ -1859,6 +1928,8 @@ class HappyHorseImageToVideoApi(IO.ComfyNode):
seed: int,
watermark: bool,
):
validate_image_dimensions(first_frame, min_width=300, min_height=300)
validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1), strict=False)
media = [
Wan27MediaItem(
type="first_frame",
@ -2053,6 +2124,62 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode):
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"happyhorse-1.1-r2v",
[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt describing the video. Use identifiers such as 'character1' and "
"'character2' to refer to the reference characters.",
),
IO.Combo.Input(
"resolution",
options=["720P", "1080P"],
),
IO.Combo.Input(
"ratio",
options=[
"16:9",
"9:16",
"1:1",
"4:3",
"3:4",
"21:9",
"9:21",
"5:4",
"4:5",
],
),
IO.Int.Input(
"duration",
default=5,
min=3,
max=15,
step=1,
display_mode=IO.NumberDisplay.number,
),
IO.Autogrow.Input(
"reference_images",
template=IO.Autogrow.TemplateNames(
IO.Image.Input("reference_image"),
names=[
"image1",
"image2",
"image3",
"image4",
"image5",
"image6",
"image7",
"image8",
"image9",
],
min=1,
),
),
],
),
IO.DynamicCombo.Option(
"happyhorse-1.0-r2v",
[
@ -2133,7 +2260,9 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode):
(
$res := $lookup(widgets, "model.resolution");
$dur := $lookup(widgets, "model.duration");
$ppsTable := { "720p": 0.14, "1080p": 0.24 };
$ppsTable := $contains(widgets.model, "1.1")
? { "720p": 0.2002, "1080p": 0.2574 }
: { "720p": 0.14, "1080p": 0.24 };
$pps := $lookup($ppsTable, $res);
{ "type": "usd", "usd": $pps * $dur }
)
@ -2149,8 +2278,11 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode):
watermark: bool,
):
validate_string(model["prompt"], strip_whitespace=False, min_length=1)
media = []
reference_images = model.get("reference_images", {})
for key in reference_images:
validate_image_dimensions(reference_images[key], min_width=400, min_height=400)
validate_image_aspect_ratio(reference_images[key], (1, 2.5), (2.5, 1), strict=False)
media = []
for key in reference_images:
media.append(
Wan27MediaItem(
@ -2159,7 +2291,7 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode):
)
)
if not media:
raise ValueError("At least one reference reference image must be provided.")
raise ValueError("At least one reference image must be provided.")
initial_response = await sync_op(
cls,

View File

@ -4,6 +4,8 @@ import os
import re
import time
from collections.abc import Callable
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime
from io import BytesIO
from yarl import URL
@ -91,6 +93,32 @@ async def sleep_with_interrupt(
await asyncio.sleep(min(1.0, end - now))
def _retry_after_wait(value: str | None, fallback: float, max_wait: float) -> float:
"""Delay before the next retry, honoring a server ``Retry-After`` header."""
seconds: float | None = None
if value is not None:
value = value.strip()
if value.isascii() and value.isdigit():
# delay-seconds form. The ASCII-digit guard keeps exotic Unicode "digit" characters away from float()
# an all-digit string always converts (huge values become inf, never raising).
seconds = float(value)
elif value:
# HTTP-date form. parsedate_to_datetime raises OverflowError (not a ValueError) on absurd years/offsets
try:
parsed = parsedate_to_datetime(value)
except (TypeError, ValueError, OverflowError):
parsed = None
if parsed is not None:
if parsed.tzinfo is None: # naive datetime: HTTP-date is UTC
parsed = parsed.replace(tzinfo=timezone.utc)
delta = (parsed - datetime.now(timezone.utc)).total_seconds()
seconds = delta if delta > 0 else 0.0
if seconds is None:
return fallback
return min(seconds, max_wait)
def mimetype_to_extension(mime_type: str) -> str:
"""Converts a MIME type to a file extension."""
return mime_type.split("/")[-1].lower()

View File

@ -21,6 +21,7 @@ from server import PromptServer
from . import request_logger
from ._helpers import (
_retry_after_wait,
default_base_url,
get_comfy_api_headers,
get_node_id,
@ -82,6 +83,7 @@ class _PollUIState:
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
_MAX_RETRY_AFTER_WAIT = 150.0 # Cap a server Retry-After at this many seconds so a large hint can't block execution
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait", "in_queue"]
@ -747,6 +749,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
should_retry = True
if should_retry:
wait_time = _retry_after_wait(resp.headers.get("Retry-After"), wait_time, _MAX_RETRY_AFTER_WAIT)
logging.warning(
"HTTP %s %s -> %s. Waiting %.2fs (%s).",
method,

View File

@ -4,11 +4,22 @@ Provides normalization and helper functions for job status tracking.
"""
import uuid
from typing import Optional
from typing import Callable, Optional
from comfy_api.internal import prune_dict
# Result of classifying a job for cancellation.
# 'running' -> job is currently executing (interrupt it)
# 'pending' -> job is queued but not started (dequeue it)
# 'terminal' -> job already finished (present in history); cancel is a no-op
# 'unknown' -> job id is not present anywhere
CANCEL_RUNNING = 'running'
CANCEL_PENDING = 'pending'
CANCEL_TERMINAL = 'terminal'
CANCEL_UNKNOWN = 'unknown'
class JobStatus:
"""Job status constants."""
PENDING = 'pending'
@ -407,3 +418,71 @@ def get_all_jobs(
jobs = jobs[:limit]
return (jobs, total_count)
def classify_job_for_cancel(prompt_id: str, running: list, queued: list, history: dict) -> str:
"""Classify a job id for cancellation.
Returns one of CANCEL_RUNNING, CANCEL_PENDING, CANCEL_TERMINAL, CANCEL_UNKNOWN.
Queue items are tuples whose second element (index 1) is the prompt_id.
History is a dict keyed by prompt_id, so a job present there has already
finished and cancelling it is a no-op.
"""
for item in running:
if item[1] == prompt_id:
return CANCEL_RUNNING
for item in queued:
if item[1] == prompt_id:
return CANCEL_PENDING
if prompt_id in history:
return CANCEL_TERMINAL
return CANCEL_UNKNOWN
def cancel_job(
prompt_id: str,
running: list,
queued: list,
history: dict,
interrupt: Callable[[str], bool],
dequeue: Callable[[str], bool],
) -> str:
"""Cancel a single job by id, regardless of state.
Maps the cancel onto the runtime's existing mechanics:
- a running job is interrupted via ``interrupt``
- a pending job is removed from the queue via ``dequeue``
- a job that already finished (terminal) is a no-op
- an unknown id is a no-op (callers that need fail-fast behaviour should
validate ids up front with ``classify_job_for_cancel``)
Both ``interrupt`` and ``dequeue`` take the prompt id and return whether
they acted on a job that was *actually* in that state, so the value returned
here reflects what truly happened rather than the (possibly stale)
classification. This matters around the narrow TOCTOU windows where a job
changes state between the caller's snapshot and the action:
- a job classified RUNNING may have finished before ``interrupt`` fires:
``interrupt`` returns False and this returns CANCEL_UNKNOWN (no-op).
- a job classified PENDING may have started executing before ``dequeue``
fires: ``dequeue`` returns False, ``interrupt`` then catches the now-
running job and this returns CANCEL_RUNNING. If it had simply finished
instead, both return False and this returns CANCEL_UNKNOWN.
``interrupt`` must be atomic interrupt the job only if it is still the one
running so a cancel can never land on an unrelated prompt that started in
the meantime (see ``execution.PromptQueue.interrupt_if_running``).
"""
classification = classify_job_for_cancel(prompt_id, running, queued, history)
if classification == CANCEL_RUNNING:
return CANCEL_RUNNING if interrupt(prompt_id) else CANCEL_UNKNOWN
if classification == CANCEL_PENDING:
if dequeue(prompt_id):
return CANCEL_PENDING
# Left the pending queue between classification and dequeue: if it
# started executing, interrupt the now-running job; otherwise it has
# already finished and the cancel is a genuine no-op.
return CANCEL_RUNNING if interrupt(prompt_id) else CANCEL_UNKNOWN
# CANCEL_TERMINAL and CANCEL_UNKNOWN are intentional no-ops.
return classification

View File

@ -0,0 +1,23 @@
def hex_to_rgb(value: str) -> tuple[int, int, int]:
h = value.lstrip("#")
if len(h) != 6:
return (255, 255, 255)
try:
return (int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16))
except ValueError:
return (255, 255, 255)
def readable_color(rgb: tuple[int, int, int]) -> tuple[int, int, int]:
r, g, b = rgb
lum = 0.299 * r + 0.587 * g + 0.114 * b
if lum >= 130:
return (r, g, b)
t = (130 - lum) / (255 - lum)
return (round(r + (255 - r) * t), round(g + (255 - g) * t), round(b + (255 - b) * t))
def normalize_palette(colors) -> list[str]:
if isinstance(colors, dict):
colors = colors.values()
return [c.upper() for c in colors if isinstance(c, str) and c]

View File

@ -158,7 +158,7 @@ class SaveAudio(IO.ComfyNode):
return IO.Schema(
node_id="SaveAudio",
search_aliases=["export flac"],
display_name="Save Audio (FLAC) (Deprecated)",
display_name="Save Audio (FLAC) (DEPRECATED)",
category="audio",
essentials_category="Audio",
inputs=[
@ -166,8 +166,9 @@ class SaveAudio(IO.ComfyNode):
IO.String.Input("filename_prefix", default="audio/ComfyUI"),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
is_deprecated=True,
is_output_node=True,
outputs=[IO.Audio.Output("audio")]
)
@classmethod
@ -175,11 +176,10 @@ class SaveAudio(IO.ComfyNode):
if audio is None:
raise ValueError("SaveAudio: input audio is None (source video may have no audio track).")
return IO.NodeOutput(
audio,
ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=format)
)
save_flac = execute # TODO: remove
class SaveAudioMP3(IO.ComfyNode):
@classmethod
@ -187,7 +187,7 @@ class SaveAudioMP3(IO.ComfyNode):
return IO.Schema(
node_id="SaveAudioMP3",
search_aliases=["export mp3"],
display_name="Save Audio (MP3) (Deprecated)",
display_name="Save Audio (MP3) (DEPRECATED)",
category="audio",
essentials_category="Audio",
inputs=[
@ -196,8 +196,9 @@ class SaveAudioMP3(IO.ComfyNode):
IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
is_deprecated=True,
is_output_node=True,
outputs=[IO.Audio.Output("audio")]
)
@classmethod
@ -205,13 +206,12 @@ class SaveAudioMP3(IO.ComfyNode):
if audio is None:
raise ValueError("SaveAudioMP3: input audio is None (source video may have no audio track).")
return IO.NodeOutput(
audio,
ui=UI.AudioSaveHelper.get_save_audio_ui(
audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
)
)
save_mp3 = execute # TODO: remove
class SaveAudioOpus(IO.ComfyNode):
@classmethod
@ -219,7 +219,7 @@ class SaveAudioOpus(IO.ComfyNode):
return IO.Schema(
node_id="SaveAudioOpus",
search_aliases=["export opus"],
display_name="Save Audio (Opus) (Deprecated)",
display_name="Save Audio (Opus) (DEPRECATED)",
category="audio",
inputs=[
IO.Audio.Input("audio"),
@ -227,8 +227,9 @@ class SaveAudioOpus(IO.ComfyNode):
IO.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
is_deprecated=True,
is_output_node=True,
outputs=[IO.Audio.Output("audio")]
)
@classmethod
@ -236,13 +237,12 @@ class SaveAudioOpus(IO.ComfyNode):
if audio is None:
raise ValueError("SaveAudioOpus: input audio is None (source video may have no audio track).")
return IO.NodeOutput(
audio,
ui=UI.AudioSaveHelper.get_save_audio_ui(
audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
)
)
save_opus = execute # TODO: remove
class SaveAudioAdvanced(IO.ComfyNode):
@classmethod
@ -258,10 +258,7 @@ class SaveAudioAdvanced(IO.ComfyNode):
IO.String.Input(
"filename_prefix",
default="audio/ComfyUI",
tooltip=(
"The prefix for the file to save. May include formatting tokens "
"such as %date:yyyy-MM-dd%."
),
tooltip=("The prefix for the file to save. May include formatting tokens such as %date:yyyy-MM-dd%."),
),
IO.DynamicCombo.Input(
"format",
@ -279,6 +276,7 @@ class SaveAudioAdvanced(IO.ComfyNode):
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
outputs=[IO.Audio.Output("audio")],
)
@classmethod
@ -289,7 +287,7 @@ class SaveAudioAdvanced(IO.ComfyNode):
ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=file_format, quality=quality)
else:
ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=file_format)
return IO.NodeOutput(ui=ui)
return IO.NodeOutput(audio, ui=ui)
class PreviewAudio(IO.ComfyNode):
@ -305,13 +303,14 @@ class PreviewAudio(IO.ComfyNode):
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
outputs=[IO.Audio.Output("audio")]
)
@classmethod
def execute(cls, audio) -> IO.NodeOutput:
if audio is None:
raise ValueError("PreviewAudio: input audio is None (source video may have no audio track).")
return IO.NodeOutput(ui=UI.PreviewAudio(audio, cls=cls))
return IO.NodeOutput(audio, ui=UI.PreviewAudio(audio, cls=cls))
save_flac = execute # TODO: remove

View File

@ -0,0 +1,97 @@
import math
import node_helpers
import comfy.utils
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class TextEncodeBooguEdit(io.ComfyNode):
"""Boogu-Image Edit conditioning.
The edit image is used twice, matching the reference pipeline:
- Qwen3-VL vision tokens (instruction understanding) -> positive only
- VAE reference latent (image identity) -> positive and negative
The ref latent is in both conds so it cancels under CFG (identity preserved);
the vision tokens are only in the positive so CFG amplifies the instruction.
The tokenizer selects the right system prompt automatically (image -> TI2I,
empty negative -> DROP), so no template plumbing is needed here.
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="TextEncodeBooguEdit",
category="model/conditioning/boogu",
inputs=[
io.Clip.Input("clip"),
io.String.Input("prompt", multiline=True, dynamic_prompts=True),
io.String.Input("negative_prompt", multiline=True, dynamic_prompts=True, advanced=True),
io.Vae.Input("vae"),
io.Autogrow.Input(
"images",
template=io.Autogrow.TemplateNames(
io.Image.Input("image"),
names=[f"image_{i}" for i in range(1, 17)],
min=0,
),
tooltip="Reference image(s) to edit. Boogu focuses on one reference per sample; more are allowed.",
),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
],
)
@classmethod
def execute(cls, clip, prompt, negative_prompt, vae=None, images: io.Autogrow.Type = None) -> io.NodeOutput:
ref_latents = []
images_vl = []
images = images or {}
for name in sorted(images, key=lambda n: int(n.rsplit("_", 1)[-1])):
image = images[name]
if image is None:
continue
samples = image.movedim(-1, 1)
# Vision tower input: the reference caps the VLM image at 384x384
# (max_vlm_input_pil_pixels in pipeline_boogu.py).
total = int(384 * 384)
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
width = round(samples.shape[3] * scale_by)
height = round(samples.shape[2] * scale_by)
s = comfy.utils.common_upscale(samples, width, height, "area", "disabled")
images_vl.append(s.movedim(1, -1)[:, :, :, :3])
# Reference latent: align to 16 px (VAE /8 * patch_size 2).
if vae is not None:
total = int(1024 * 1024)
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
width = round(samples.shape[3] * scale_by / 16.0) * 16
height = round(samples.shape[2] * scale_by / 16.0) * 16
s = comfy.utils.common_upscale(samples, width, height, "area", "disabled")
ref_latents.append(vae.encode(s.movedim(1, -1)[:, :, :, :3]))
# positive: instruction + vision tokens; negative: empty (no vision). Ref latent on both.
positive = clip.encode_from_tokens_scheduled(clip.tokenize(prompt, images=images_vl))
negative = clip.encode_from_tokens_scheduled(clip.tokenize(negative_prompt))
if len(ref_latents) > 0:
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": ref_latents}, append=True)
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": ref_latents}, append=True)
return io.NodeOutput(positive, negative)
class BooguExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
TextEncodeBooguEdit,
]
async def comfy_entrypoint() -> BooguExtension:
return BooguExtension()

View File

@ -0,0 +1,253 @@
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageEnhance, ImageFont
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy_extras.color_util import hex_to_rgb, normalize_palette, readable_color
_PREVIEW_LONG_EDGE = 1024
_PREVIEW_DIM = 0.25
def pixels_to_fractions(box: dict, width: int, height: int) -> dict:
w = width or 1
h = height or 1
return {
"x": box.get("x", 0) / w,
"y": box.get("y", 0) / h,
"w": box.get("width", 0) / w,
"h": box.get("height", 0) / h,
}
def fractions_to_pixels(box: dict, width: int, height: int) -> dict:
x, y = box.get("x", 0.0), box.get("y", 0.0)
w, h = box.get("w", 0.0), box.get("h", 0.0)
if w < 0:
x, w = x + w, -w
if h < 0:
y, h = y + h, -h
return {
"x": round(x * width),
"y": round(y * height),
"width": round(w * width),
"height": round(h * height),
}
def fractions_to_bbox_frame(boxes: list, width: int, height: int) -> list:
pixels = [
fractions_to_pixels(box, width, height)
for box in boxes
if isinstance(box, dict)
]
return [pixels] if pixels else []
def _font(size: int):
try:
return ImageFont.load_default(size)
except Exception:
return ImageFont.load_default()
def _wrap(draw, text: str, font, max_w: float) -> list[str]:
lines = []
for para in text.split("\n"):
line = ""
for word in para.split():
test = word if not line else line + " " + word
if line and draw.textlength(test, font=font) > max_w:
lines.append(line)
line = word
else:
line = test
lines.append(line)
return lines
def _bg_from_image(image) -> Image.Image | None:
if image is None:
return None
try:
arr = (image[0].detach().cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
return Image.fromarray(arr)
except Exception:
return None
def render_preview(regions, width, height, bg=None):
if bg is not None:
iw, ih = bg.size
long_edge = max(iw, ih) or 1
scale = min(1.0, _PREVIEW_LONG_EDGE / long_edge)
rw, rh = max(1, round(iw * scale)), max(1, round(ih * scale))
base = bg.convert("RGB").resize((rw, rh), Image.LANCZOS)
base = ImageEnhance.Brightness(base).enhance(_PREVIEW_DIM)
img = base.convert("RGBA")
else:
long_edge = max(width, height) or 1
scale = min(1.0, _PREVIEW_LONG_EDGE / long_edge)
rw, rh = max(1, round(width * scale)), max(1, round(height * scale))
grey = round(_PREVIEW_DIM * 128)
img = Image.new("RGBA", (rw, rh), (grey, grey, grey, 255))
overlay = Image.new("RGBA", (rw, rh), (0, 0, 0, 0))
draw = ImageDraw.Draw(overlay)
fs = max(10, round(rh / 64))
font = _font(fs)
tag_font = _font(max(9, fs - 2))
line_h = fs + 2
for i, region in enumerate(regions):
if not isinstance(region, dict):
continue
palette = [c for c in (region.get("palette") or []) if c]
r, g, b = hex_to_rgb(palette[0]) if palette else (140, 140, 140)
x1 = max(0, min(rw, round(region.get("x", 0) * rw)))
y1 = max(0, min(rh, round(region.get("y", 0) * rh)))
x2 = max(0, min(rw, round((region.get("x", 0) + region.get("w", 0)) * rw)))
y2 = max(0, min(rh, round((region.get("y", 0) + region.get("h", 0)) * rh)))
if x2 < x1:
x1, x2 = x2, x1
if y2 < y1:
y1, y2 = y2, y1
draw.rectangle([x1, y1, x2, y2], outline=(r, g, b, 255), width=2)
swatches = palette[:5]
if swatches and (x2 - x1) > 2:
sh = max(5, fs // 2)
seg = (x2 - x1) / len(swatches)
for p, hexc in enumerate(swatches):
sx = x1 + round(p * seg)
draw.rectangle([sx, y1, x1 + round((p + 1) * seg), y1 + sh], fill=hex_to_rgb(hexc))
etype = "text" if region.get("type") == "text" else "obj"
tag = str(i + 1).zfill(2)
tw = draw.textlength(tag, font=tag_font)
draw.rectangle([x1, y1, x1 + tw + 6, y1 + fs + 2], fill=(r, g, b, 255))
tag_fill = (0, 0, 0, 255) if (0.299 * r + 0.587 * g + 0.114 * b) > 140 else (255, 255, 255, 255)
draw.text((x1 + 3, y1 + 1), tag, fill=tag_fill, font=tag_font)
body = region.get("desc", "") or ""
if etype == "text" and region.get("text"):
body = '"%s"%s' % (region["text"], "" + body if body else "")
if body and (x2 - x1) > 8:
ty = y1 + fs + 5
for line in _wrap(draw, body, font, x2 - x1 - 8):
if ty > y2:
break
draw.text((x1 + 4, ty), line, fill=readable_color((r, g, b)) + (255,), font=font)
ty += line_h
composed = Image.alpha_composite(img, overlay).convert("RGB")
arr = np.asarray(composed, dtype=np.float32) / 255.0
return torch.from_numpy(arr).unsqueeze(0)
def boxes_to_regions(boxes, width: int, height: int) -> list:
regions: list = []
if not isinstance(boxes, list):
return regions
for box in boxes:
if not isinstance(box, dict):
continue
meta = box.get("metadata")
meta = meta if isinstance(meta, dict) else {}
regions.append({
**pixels_to_fractions(box, width, height),
"type": meta.get("type", "obj"),
"text": meta.get("text", ""),
"desc": meta.get("desc", ""),
"palette": meta.get("palette", []),
})
return regions
def _norm_bbox(region: dict) -> list[int]:
def grid(value: float) -> int:
return max(0, min(1000, round(value * 1000)))
x, y = region.get("x", 0.0), region.get("y", 0.0)
w, h = region.get("w", 0.0), region.get("h", 0.0)
ymin, xmin, ymax, xmax = grid(y), grid(x), grid(y + h), grid(x + w)
if ymin > ymax:
ymin, ymax = ymax, ymin
if xmin > xmax:
xmin, xmax = xmax, xmin
return [ymin, xmin, ymax, xmax]
def build_elements(regions: list) -> list:
elements = []
for region in regions:
if not isinstance(region, dict):
continue
etype = "text" if region.get("type") == "text" else "obj"
element = {"type": etype}
element["bbox"] = _norm_bbox(region)
if etype == "text":
element["text"] = region.get("text", "")
element["desc"] = region.get("desc", "")
palette = normalize_palette(region.get("palette", []))
if palette:
element["color_palette"] = palette[:5]
elements.append(element)
return elements
class CreateBoundingBoxes(io.ComfyNode):
@classmethod
def define_schema(cls):
editor_state = io.BoundingBoxes.Input(
"editor_state",
socketless=False,
tooltip="Draw bounding boxes and set each box type, text, description, color palette. Start with background element first and foreground last.",
)
return io.Schema(
node_id="CreateBoundingBoxes",
display_name="Create Bounding Boxes",
category="utilities",
description="Draw bounding boxes in a canvas. Outputs Ideogram prompt elements, pixel-space bounding boxes, and a preview image.",
inputs=[
io.Image.Input(
"background",
optional=True,
tooltip="Optional image used as background in the canvas and preview.",
),
io.Int.Input("width", default=1024, min=64, max=16384, step=16,
tooltip="Width of the canvas and the pixel grid for the bounding boxes."),
io.Int.Input("height", default=1024, min=64, max=16384, step=16,
tooltip="Height of the canvas and the pixel grid for the bounding boxes."),
editor_state,
],
outputs=[
io.Image.Output(display_name="preview"),
io.BoundingBox.Output(display_name="bboxes"),
io.Array.Output(display_name="elements"),
],
is_experimental=True,
)
@classmethod
def execute(cls, width, height, editor_state=None, background=None) -> io.NodeOutput:
regions = boxes_to_regions(editor_state, width, height)
preview = render_preview(regions, width, height, _bg_from_image(background))
return io.NodeOutput(
preview,
fractions_to_bbox_frame(regions, width, height),
build_elements(regions),
ui={"dims": [width, height]},
)
class BoundingBoxesExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [CreateBoundingBoxes]
async def comfy_entrypoint() -> BoundingBoxesExtension:
return BoundingBoxesExtension()

View File

@ -1,5 +1,6 @@
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy_extras.color_util import hex_to_rgb
class ColorToRGBInt(io.ComfyNode):
@ -24,9 +25,11 @@ class ColorToRGBInt(io.ComfyNode):
# expect format #RRGGBB
if len(color) != 7 or color[0] != "#":
raise ValueError("Color must be in format #RRGGBB")
r = int(color[1:3], 16)
g = int(color[3:5], 16)
b = int(color[5:7], 16)
try:
int(color[1:], 16)
except ValueError:
raise ValueError("Color must be in format #RRGGBB") from None
r, g, b = hex_to_rgb(color)
rgb_int = r * 256 * 256 + g * 256 + b
return io.NodeOutput(rgb_int, color)

View File

@ -8,7 +8,8 @@ class CLIPTextEncodeControlnet(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="CLIPTextEncodeControlnet",
category="experimental/conditioning",
display_name="CLIP Text Encode (Controlnet)",
category="model/conditioning",
inputs=[
io.Clip.Input("clip"),
io.Conditioning.Input("conditioning"),
@ -35,11 +36,12 @@ class T5TokenizerOptions(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="T5TokenizerOptions",
category="experimental/conditioning",
display_name="T5 Tokenizer Options",
category="model/conditioning",
inputs=[
io.Clip.Input("clip"),
io.Int.Input("min_padding", default=0, min=0, max=10000, step=1, advanced=True),
io.Int.Input("min_length", default=0, min=0, max=10000, step=1, advanced=True),
io.Int.Input("min_padding", default=0, min=0, max=10000, step=1),
io.Int.Input("min_length", default=0, min=0, max=10000, step=1),
],
outputs=[io.Clip.Output()],
is_experimental=True,

View File

@ -13,21 +13,22 @@ class ContextWindowsManualNode(io.ComfyNode):
description="Manually set context windows.",
inputs=[
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window.", advanced=True),
io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window.", advanced=True),
io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window."),
io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window."),
io.Combo.Input("context_schedule", options=[
comfy.context_windows.ContextSchedules.STATIC_STANDARD,
comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
comfy.context_windows.ContextSchedules.BATCHED,
], tooltip="The stride of the context window."),
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True),
], default=comfy.context_windows.ContextSchedules.STATIC_STANDARD, tooltip="Step-dependent scheduling algorithm for context windows."),
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window. For concat-style I2V models (e.g. Wan I2V, HunyuanVideo I2V, Cosmos I2V, SVD) the encoded start image lives in the c_concat conditioning channels; setting this to '0' will retain that start image content at sub-pos 0 of every window."),
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
io.String.Input("latent_retain_index_list", default="", tooltip="List of latent indices to retain in the noise latent itself for each window. Use for workflows where reference content (e.g. a start image) lives directly in the noise latent rather than in separate conditioning channels (e.g. inplace-style I2V like LTXV, AnimateDiff). Independent of cond_retain_index_list."),
io.Boolean.Input("causal_window_fix", default=True, tooltip="Whether to add a causal fix frame to non-0-indexed context windows."),
],
outputs=[
@ -38,7 +39,7 @@ class ContextWindowsManualNode(io.ComfyNode):
@classmethod
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, causal_window_fix: bool=True) -> io.Model:
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, latent_retain_index_list: list[int]=[], causal_window_fix: bool=True) -> io.Model:
model = model.clone()
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
@ -51,6 +52,7 @@ class ContextWindowsManualNode(io.ComfyNode):
freenoise=freenoise,
cond_retain_index_list=cond_retain_index_list,
split_conds_to_windows=split_conds_to_windows,
latent_retain_index_list=latent_retain_index_list,
causal_window_fix=causal_window_fix,
)
# make memory usage calculation only take into account the context window latents
@ -65,33 +67,71 @@ class WanContextWindowsManualNode(ContextWindowsManualNode):
schema = super().define_schema()
schema.node_id = "WanContextWindowsManual"
schema.display_name = "WAN Context Windows (Manual)"
schema.description = "Manually set context windows for WAN-like models (dim=2)."
schema.display_name = "Wan Context Windows"
schema.description = "Set context windows for Wan-like models."
schema.category="model/patch/wan"
schema.inputs = [
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window.", advanced=True),
io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window.", advanced=True),
io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window in real frames. Must be 4*n + 1."),
io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window in real frames."),
io.Combo.Input("context_schedule", options=[
comfy.context_windows.ContextSchedules.STATIC_STANDARD,
comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
comfy.context_windows.ContextSchedules.BATCHED,
], tooltip="The stride of the context window."),
], default=comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, tooltip="Step-dependent scheduling algorithm for context windows."),
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True),
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules.", advanced=True),
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
io.Boolean.Input("freenoise", default=True, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending.", advanced=True),
io.Boolean.Input("retain_first_frame", default=False, tooltip="Retain the first I2V frame in every context window (may help retain initial reference)."),
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index.", advanced=True),
]
return schema
@classmethod
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool,
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows)
retain_first_frame: bool=False, split_conds_to_windows: bool=False) -> io.Model:
context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
context_overlap = max(context_overlap // 4, 0) # at least overlap 0
retain_index_list = "0" if retain_first_frame else ""
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=retain_index_list, split_conds_to_windows=split_conds_to_windows)
class LTXVContextWindowsNode(ContextWindowsManualNode):
@classmethod
def define_schema(cls) -> io.Schema:
schema = super().define_schema()
schema.node_id = "LTXVContextWindows"
schema.display_name = "LTXV Context Windows"
schema.description = "Set context windows for LTXV-like models."
schema.inputs = [
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=8, default=145, tooltip="The length of the context window in real frames. Must be 8*n + 1."),
io.Int.Input("context_overlap", min=0, step=8, default=40, tooltip="The overlap of the context window in real frames."),
io.Combo.Input("context_schedule", options=[
comfy.context_windows.ContextSchedules.STATIC_STANDARD,
comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
comfy.context_windows.ContextSchedules.BATCHED,
], default=comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, tooltip="Step-dependent scheduling algorithm for context windows."),
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True),
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules.", advanced=True),
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
io.Boolean.Input("freenoise", default=True, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending.", advanced=True),
io.Boolean.Input("retain_first_frame", default=False, tooltip="Retain the first latent frame in every context window (may help retain initial reference)."),
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index.", advanced=True),
]
return schema
@classmethod
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, fuse_method: str, freenoise: bool,
retain_first_frame: bool=False, split_conds_to_windows: bool=False, context_stride: int=1, closed_loop: bool=False) -> io.Model:
context_length = max(((context_length - 1) // 8) + 1, 1) # at least length 1
context_overlap = max(context_overlap // 8, 0) # at least overlap 0
retain_index_list = "0" if retain_first_frame else ""
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise,
cond_retain_index_list=retain_index_list, latent_retain_index_list=retain_index_list, split_conds_to_windows=split_conds_to_windows)
class ContextWindowsExtension(ComfyExtension):
@ -99,6 +139,7 @@ class ContextWindowsExtension(ComfyExtension):
return [
ContextWindowsManualNode,
WanContextWindowsManualNode,
LTXVContextWindowsNode,
]
def comfy_entrypoint():

View File

@ -1070,7 +1070,7 @@ class AddNoise(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="AddNoise",
category="experimental/custom_sampling/noise",
category="model/sampling/noise",
is_experimental=True,
inputs=[
io.Model.Input("model"),
@ -1120,7 +1120,7 @@ class ManualSigmas(io.ComfyNode):
return io.Schema(
node_id="ManualSigmas",
search_aliases=["custom noise schedule", "define sigmas"],
category="experimental/custom_sampling",
category="model/sampling/sigmas",
is_experimental=True,
inputs=[
io.String.Input("sigmas", default="1, 0.5", multiline=False)

View File

@ -1583,7 +1583,7 @@ class LoadTrainingDataset(io.ComfyNode):
shard_path = os.path.join(dataset_dir, shard_file)
with open(shard_path, "rb") as f:
shard_data = torch.load(f)
shard_data = torch.load(f, weights_only=True)
all_latents.extend(shard_data["latents"])
all_conditioning.extend(shard_data["conditioning"])

View File

@ -77,7 +77,7 @@ class FrameInterpolate(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="FrameInterpolate",
display_name="Frame Interpolate",
display_name="Run Frame Interpolation Model",
category="video",
search_aliases=["rife", "film", "frame interpolation", "slow motion", "interpolate frames", "vfi"],
inputs=[

View File

@ -1,85 +1,68 @@
import os
import sys
import re
import ctypes
import logging
import ctypes.util
import importlib.util
from typing import TypedDict
import numpy as np
import torch
import nodes
import comfy_angle
from comfy_api.latest import ComfyExtension, io, ui
from typing_extensions import override
from utils.install_util import get_missing_requirements_message
logger = logging.getLogger(__name__)
def _check_opengl_availability():
"""Early check for OpenGL availability. Raises RuntimeError if unlikely to work."""
logger.debug("_check_opengl_availability: starting")
missing = []
def _preload_angle():
egl_path = comfy_angle.get_egl_path()
gles_path = comfy_angle.get_glesv2_path()
# Check Python packages (using find_spec to avoid importing)
logger.debug("_check_opengl_availability: checking for glfw package")
if importlib.util.find_spec("glfw") is None:
missing.append("glfw")
if sys.platform == "win32":
angle_dir = comfy_angle.get_lib_dir()
os.add_dll_directory(angle_dir)
os.environ["PATH"] = angle_dir + os.pathsep + os.environ.get("PATH", "")
logger.debug("_check_opengl_availability: checking for OpenGL package")
if importlib.util.find_spec("OpenGL") is None:
missing.append("PyOpenGL")
if missing:
raise RuntimeError(
f"OpenGL dependencies not available.\n{get_missing_requirements_message()}\n"
)
# On Linux without display, check if headless backends are available
logger.debug(f"_check_opengl_availability: platform={sys.platform}")
if sys.platform.startswith("linux"):
has_display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY")
logger.debug(f"_check_opengl_availability: has_display={bool(has_display)}")
if not has_display:
# Check for EGL or OSMesa libraries
logger.debug("_check_opengl_availability: checking for EGL library")
has_egl = ctypes.util.find_library("EGL")
logger.debug("_check_opengl_availability: checking for OSMesa library")
has_osmesa = ctypes.util.find_library("OSMesa")
# Error disabled for CI as it fails this check
# if not has_egl and not has_osmesa:
# raise RuntimeError(
# "GLSL Shader node: No display and no headless backend (EGL/OSMesa) found.\n"
# "See error below for installation instructions."
# )
logger.debug(f"Headless mode: EGL={'yes' if has_egl else 'no'}, OSMesa={'yes' if has_osmesa else 'no'}")
logger.debug("_check_opengl_availability: completed")
mode = 0 if sys.platform == "win32" else ctypes.RTLD_GLOBAL
ctypes.CDLL(str(egl_path), mode=mode)
ctypes.CDLL(str(gles_path), mode=mode)
# Run early check at import time
logger.debug("nodes_glsl: running _check_opengl_availability at import time")
_check_opengl_availability()
# OpenGL modules - initialized lazily when context is created
gl = None
glfw = None
EGL = None
# Pre-load ANGLE *before* any PyOpenGL import so that the EGL platform
# plugin picks up ANGLE's libEGL / libGLESv2 instead of system libs.
_preload_angle()
os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
def _import_opengl():
"""Import OpenGL module. Called after context is created."""
global gl
if gl is None:
logger.debug("_import_opengl: importing OpenGL.GL")
import OpenGL.GL as _gl
gl = _gl
logger.debug("_import_opengl: import completed")
return gl
import OpenGL
OpenGL.USE_ACCELERATE = False
def _patch_find_library():
"""PyOpenGL's EGL platform looks for 'EGL' and 'GLESv2' by short name
via ctypes.util.find_library, but ANGLE ships as 'libEGL' and
'libGLESv2'. Patch find_library to return the full ANGLE paths so
PyOpenGL loads the same libraries we pre-loaded."""
if sys.platform == "linux":
return
import ctypes.util
_orig = ctypes.util.find_library
def _patched(name):
if name == 'EGL':
return comfy_angle.get_egl_path()
if name == 'GLESv2':
return comfy_angle.get_glesv2_path()
return _orig(name)
ctypes.util.find_library = _patched
_patch_find_library()
from OpenGL import EGL
from OpenGL import GLES3 as gl
class SizeModeInput(TypedDict):
size_mode: str
width: int
@ -102,7 +85,7 @@ MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
# (-1,-1)---(3,-1)
#
# v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1)
VERTEX_SHADER = """#version 330 core
VERTEX_SHADER = """#version 300 es
out vec2 v_texCoord;
void main() {
vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3));
@ -126,14 +109,99 @@ void main() {
"""
def _convert_es_to_desktop(source: str) -> str:
"""Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core."""
# Remove any existing #version directive
source = re.sub(r"#version\s+\d+(\s+es)?\s*\n?", "", source, flags=re.IGNORECASE)
# Remove precision qualifiers (not needed in desktop GLSL)
source = re.sub(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?", "", source)
# Prepend desktop GLSL version
return "#version 330 core\n" + source
def _egl_attribs(*values):
"""Build an EGL_NONE-terminated EGLint attribute array."""
vals = list(values) + [EGL.EGL_NONE]
return (ctypes.c_int32 * len(vals))(*vals)
# EGL platform extension constants
EGL_PLATFORM_ANGLE_ANGLE = 0x3202
EGL_PLATFORM_ANGLE_TYPE_ANGLE = 0x3203
EGL_PLATFORM_ANGLE_TYPE_VULKAN_ANGLE = 0x3450
EGL_MESA_PLATFORM_SURFACELESS = 0x31DD
_eglGetPlatformDisplayEXT = None
def _get_egl_platform_display_ext(platform, native_display, attribs):
"""Call eglGetPlatformDisplayEXT via ctypes (extension, not in PyOpenGL)."""
global _eglGetPlatformDisplayEXT
if _eglGetPlatformDisplayEXT is None:
from OpenGL import platform as _plat
egl_lib = _plat.PLATFORM.EGL
_get_proc = egl_lib.eglGetProcAddress
_get_proc.restype = ctypes.c_void_p
_get_proc.argtypes = [ctypes.c_char_p]
ptr = _get_proc(b"eglGetPlatformDisplayEXT")
if not ptr:
return None
func_type = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_uint32, ctypes.c_void_p, ctypes.c_void_p)
_eglGetPlatformDisplayEXT = func_type(ptr)
raw = _eglGetPlatformDisplayEXT(platform, native_display, attribs)
if not raw:
return None
return ctypes.cast(raw, EGL.EGLDisplay)
def _get_egl_display():
"""Get an EGL display, trying the default first then ANGLE's Vulkan
platform for headless environments without a display server."""
failures = []
# Try the default display first (works when X11/Wayland is available)
display = EGL.eglGetDisplay(EGL.EGL_DEFAULT_DISPLAY)
if display:
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
try:
if EGL.eglInitialize(display, ctypes.byref(major), ctypes.byref(minor)):
return display, major.value, minor.value
except Exception as e:
failures.append(f"default: {e}")
logger.info("Default EGL display unavailable, trying headless fallbacks")
# Headless fallback strategies, tried in order:
headless_strategies = [
("surfaceless", EGL_MESA_PLATFORM_SURFACELESS, None, None),
("ANGLE Vulkan", EGL_PLATFORM_ANGLE_ANGLE, None,
_egl_attribs(EGL_PLATFORM_ANGLE_TYPE_ANGLE, EGL_PLATFORM_ANGLE_TYPE_VULKAN_ANGLE)),
]
for name, platform, native_display, attribs in headless_strategies:
display = _get_egl_platform_display_ext(platform, native_display, attribs)
if not display:
failures.append(f"{name}: eglGetPlatformDisplayEXT returned no display")
continue
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
try:
if EGL.eglInitialize(display, ctypes.byref(major), ctypes.byref(minor)):
logger.info(f"Using EGL {name} platform (headless)")
return display, major.value, minor.value
failures.append(f"{name}: eglInitialize returned false")
except Exception as e:
failures.append(f"{name}: {e}")
continue
details = "\n".join(f" - {f}" for f in failures)
raise RuntimeError(
"Failed to initialize EGL display.\n"
"No display server and no headless EGL platform available.\n"
f"Tried:\n{details}\n"
"Ensure GPU drivers are installed or set DISPLAY for a virtual framebuffer."
)
def _gl_str(name):
"""Get an OpenGL string parameter."""
v = gl.glGetString(name)
if not v:
return "Unknown"
if isinstance(v, bytes):
return v.decode(errors="replace")
return ctypes.string_at(v).decode(errors="replace")
def _detect_output_count(source: str) -> int:
@ -159,163 +227,8 @@ def _detect_pass_count(source: str) -> int:
return 1
def _init_glfw():
"""Initialize GLFW. Returns (window, glfw_module). Raises RuntimeError on failure."""
logger.debug("_init_glfw: starting")
# On macOS, glfw.init() must be called from main thread or it hangs forever
if sys.platform == "darwin":
logger.debug("_init_glfw: skipping on macOS")
raise RuntimeError("GLFW backend not supported on macOS")
logger.debug("_init_glfw: importing glfw module")
import glfw as _glfw
logger.debug("_init_glfw: calling glfw.init()")
if not _glfw.init():
raise RuntimeError("glfw.init() failed")
try:
logger.debug("_init_glfw: setting window hints")
_glfw.window_hint(_glfw.VISIBLE, _glfw.FALSE)
_glfw.window_hint(_glfw.CONTEXT_VERSION_MAJOR, 3)
_glfw.window_hint(_glfw.CONTEXT_VERSION_MINOR, 3)
_glfw.window_hint(_glfw.OPENGL_PROFILE, _glfw.OPENGL_CORE_PROFILE)
logger.debug("_init_glfw: calling create_window()")
window = _glfw.create_window(64, 64, "ComfyUI GLSL", None, None)
if not window:
raise RuntimeError("glfw.create_window() failed")
logger.debug("_init_glfw: calling make_context_current()")
_glfw.make_context_current(window)
logger.debug("_init_glfw: completed successfully")
return window, _glfw
except Exception:
logger.debug("_init_glfw: failed, terminating glfw")
_glfw.terminate()
raise
def _init_egl():
"""Initialize EGL for headless rendering. Returns (display, context, surface, EGL_module). Raises RuntimeError on failure."""
logger.debug("_init_egl: starting")
from OpenGL import EGL as _EGL
from OpenGL.EGL import (
eglGetDisplay, eglInitialize, eglChooseConfig, eglCreateContext,
eglMakeCurrent, eglCreatePbufferSurface, eglBindAPI,
eglTerminate, eglDestroyContext, eglDestroySurface,
EGL_DEFAULT_DISPLAY, EGL_NO_CONTEXT, EGL_NONE,
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
EGL_RED_SIZE, EGL_GREEN_SIZE, EGL_BLUE_SIZE, EGL_ALPHA_SIZE, EGL_DEPTH_SIZE,
EGL_WIDTH, EGL_HEIGHT, EGL_OPENGL_API,
)
logger.debug("_init_egl: imports completed")
display = None
context = None
surface = None
try:
logger.debug("_init_egl: calling eglGetDisplay()")
display = eglGetDisplay(EGL_DEFAULT_DISPLAY)
if display == _EGL.EGL_NO_DISPLAY:
raise RuntimeError("eglGetDisplay() failed")
logger.debug("_init_egl: calling eglInitialize()")
major, minor = _EGL.EGLint(), _EGL.EGLint()
if not eglInitialize(display, major, minor):
display = None # Not initialized, don't terminate
raise RuntimeError("eglInitialize() failed")
logger.debug(f"_init_egl: EGL version {major.value}.{minor.value}")
config_attribs = [
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT,
EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
EGL_RED_SIZE, 8, EGL_GREEN_SIZE, 8, EGL_BLUE_SIZE, 8, EGL_ALPHA_SIZE, 8,
EGL_DEPTH_SIZE, 0, EGL_NONE
]
configs = (_EGL.EGLConfig * 1)()
num_configs = _EGL.EGLint()
if not eglChooseConfig(display, config_attribs, configs, 1, num_configs) or num_configs.value == 0:
raise RuntimeError("eglChooseConfig() failed")
config = configs[0]
logger.debug(f"_init_egl: config chosen, num_configs={num_configs.value}")
if not eglBindAPI(EGL_OPENGL_API):
raise RuntimeError("eglBindAPI() failed")
logger.debug("_init_egl: calling eglCreateContext()")
context_attribs = [
_EGL.EGL_CONTEXT_MAJOR_VERSION, 3,
_EGL.EGL_CONTEXT_MINOR_VERSION, 3,
_EGL.EGL_CONTEXT_OPENGL_PROFILE_MASK, _EGL.EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT,
EGL_NONE
]
context = eglCreateContext(display, config, EGL_NO_CONTEXT, context_attribs)
if context == EGL_NO_CONTEXT:
raise RuntimeError("eglCreateContext() failed")
logger.debug("_init_egl: calling eglCreatePbufferSurface()")
pbuffer_attribs = [EGL_WIDTH, 64, EGL_HEIGHT, 64, EGL_NONE]
surface = eglCreatePbufferSurface(display, config, pbuffer_attribs)
if surface == _EGL.EGL_NO_SURFACE:
raise RuntimeError("eglCreatePbufferSurface() failed")
logger.debug("_init_egl: calling eglMakeCurrent()")
if not eglMakeCurrent(display, surface, surface, context):
raise RuntimeError("eglMakeCurrent() failed")
logger.debug("_init_egl: completed successfully")
return display, context, surface, _EGL
except Exception:
logger.debug("_init_egl: failed, cleaning up")
# Clean up any resources on failure
if surface is not None:
eglDestroySurface(display, surface)
if context is not None:
eglDestroyContext(display, context)
if display is not None:
eglTerminate(display)
raise
def _init_osmesa():
"""Initialize OSMesa for software rendering. Returns (context, buffer). Raises RuntimeError on failure."""
import ctypes
logger.debug("_init_osmesa: starting")
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
logger.debug("_init_osmesa: importing OpenGL.osmesa")
from OpenGL import GL as _gl
from OpenGL.osmesa import (
OSMesaCreateContextExt, OSMesaMakeCurrent, OSMesaDestroyContext,
OSMESA_RGBA,
)
logger.debug("_init_osmesa: imports completed")
ctx = OSMesaCreateContextExt(OSMESA_RGBA, 24, 0, 0, None)
if not ctx:
raise RuntimeError("OSMesaCreateContextExt() failed")
width, height = 64, 64
buffer = (ctypes.c_ubyte * (width * height * 4))()
logger.debug("_init_osmesa: calling OSMesaMakeCurrent()")
if not OSMesaMakeCurrent(ctx, buffer, _gl.GL_UNSIGNED_BYTE, width, height):
OSMesaDestroyContext(ctx)
raise RuntimeError("OSMesaMakeCurrent() failed")
logger.debug("_init_osmesa: completed successfully")
return ctx, buffer
class GLContext:
"""Manages OpenGL context and resources for shader execution.
Tries backends in order: GLFW (desktop) EGL (headless GPU) OSMesa (software).
"""
"""Manages an OpenGL ES 3.0 context via EGL/ANGLE (singleton)."""
_instance = None
_initialized = False
@ -327,131 +240,105 @@ class GLContext:
def __init__(self):
if GLContext._initialized:
logger.debug("GLContext.__init__: already initialized, skipping")
return
logger.debug("GLContext.__init__: starting initialization")
global glfw, EGL
import time
start = time.perf_counter()
self._backend = None
self._window = None
self._egl_display = None
self._egl_context = None
self._egl_surface = None
self._osmesa_ctx = None
self._osmesa_buffer = None
self._display = None
self._surface = None
self._context = None
self._vao = None
# Try backends in order: GLFW → EGL → OSMesa
errors = []
logger.debug("GLContext.__init__: trying GLFW backend")
try:
self._window, glfw = _init_glfw()
self._backend = "glfw"
logger.debug("GLContext.__init__: GLFW backend succeeded")
except Exception as e:
logger.debug(f"GLContext.__init__: GLFW backend failed: {e}")
errors.append(("GLFW", e))
self._display, self._egl_major, self._egl_minor = _get_egl_display()
if self._backend is None:
logger.debug("GLContext.__init__: trying EGL backend")
try:
self._egl_display, self._egl_context, self._egl_surface, EGL = _init_egl()
self._backend = "egl"
logger.debug("GLContext.__init__: EGL backend succeeded")
except Exception as e:
logger.debug(f"GLContext.__init__: EGL backend failed: {e}")
errors.append(("EGL", e))
if not EGL.eglBindAPI(EGL.EGL_OPENGL_ES_API):
raise RuntimeError("eglBindAPI(EGL_OPENGL_ES_API) failed")
if self._backend is None:
logger.debug("GLContext.__init__: trying OSMesa backend")
try:
self._osmesa_ctx, self._osmesa_buffer = _init_osmesa()
self._backend = "osmesa"
logger.debug("GLContext.__init__: OSMesa backend succeeded")
except Exception as e:
logger.debug(f"GLContext.__init__: OSMesa backend failed: {e}")
errors.append(("OSMesa", e))
config = EGL.EGLConfig()
n_configs = ctypes.c_int32(0)
if not EGL.eglChooseConfig(
self._display,
_egl_attribs(
EGL.EGL_RENDERABLE_TYPE, EGL.EGL_OPENGL_ES3_BIT,
EGL.EGL_SURFACE_TYPE, EGL.EGL_PBUFFER_BIT,
EGL.EGL_RED_SIZE, 8, EGL.EGL_GREEN_SIZE, 8,
EGL.EGL_BLUE_SIZE, 8, EGL.EGL_ALPHA_SIZE, 8,
),
ctypes.byref(config), 1, ctypes.byref(n_configs),
) or n_configs.value == 0:
raise RuntimeError("eglChooseConfig() failed")
if self._backend is None:
if sys.platform == "win32":
platform_help = (
"Windows: Ensure GPU drivers are installed and display is available.\n"
" CPU-only/headless mode is not supported on Windows."
)
elif sys.platform == "darwin":
platform_help = (
"macOS: GLFW is not supported.\n"
" Install OSMesa via Homebrew: brew install mesa\n"
" Then: pip install PyOpenGL PyOpenGL-accelerate"
)
else:
platform_help = (
"Linux: Install one of these backends:\n"
" Desktop: sudo apt install libgl1-mesa-glx libglfw3\n"
" Headless with GPU: sudo apt install libegl1-mesa libgl1-mesa-dri\n"
" Headless (CPU): sudo apt install libosmesa6"
)
error_details = "\n".join(f" {name}: {err}" for name, err in errors)
raise RuntimeError(
f"Failed to create OpenGL context.\n\n"
f"Backend errors:\n{error_details}\n\n"
f"{platform_help}"
self._surface = EGL.eglCreatePbufferSurface(
self._display, config,
_egl_attribs(EGL.EGL_WIDTH, 64, EGL.EGL_HEIGHT, 64),
)
if not self._surface:
raise RuntimeError("eglCreatePbufferSurface() failed")
# Now import OpenGL.GL (after context is current)
logger.debug("GLContext.__init__: importing OpenGL.GL")
_import_opengl()
self._context = EGL.eglCreateContext(
self._display, config, EGL.EGL_NO_CONTEXT,
_egl_attribs(EGL.EGL_CONTEXT_CLIENT_VERSION, 3),
)
if not self._context:
raise RuntimeError("eglCreateContext() failed")
# Create VAO (required for core profile, but OSMesa may use compat profile)
logger.debug("GLContext.__init__: creating VAO")
try:
vao = gl.glGenVertexArrays(1)
gl.glBindVertexArray(vao)
self._vao = vao # Only store after successful bind
logger.debug("GLContext.__init__: VAO created successfully")
except Exception as e:
logger.debug(f"GLContext.__init__: VAO creation failed (may be expected for OSMesa): {e}")
# OSMesa with older Mesa may not support VAOs
# Clean up if we created but couldn't bind
if vao:
try:
gl.glDeleteVertexArrays(1, [vao])
except Exception:
pass
if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
raise RuntimeError("eglMakeCurrent() failed")
self._vao = gl.glGenVertexArrays(1)
gl.glBindVertexArray(self._vao)
except Exception:
self._cleanup()
raise
elapsed = (time.perf_counter() - start) * 1000
# Log device info
renderer = gl.glGetString(gl.GL_RENDERER)
vendor = gl.glGetString(gl.GL_VENDOR)
version = gl.glGetString(gl.GL_VERSION)
renderer = renderer.decode() if renderer else "Unknown"
vendor = vendor.decode() if vendor else "Unknown"
version = version.decode() if version else "Unknown"
renderer = _gl_str(gl.GL_RENDERER)
vendor = _gl_str(gl.GL_VENDOR)
version = _gl_str(gl.GL_VERSION)
GLContext._initialized = True
logger.info(f"GLSL context initialized in {elapsed:.1f}ms ({self._backend}) - {renderer} ({vendor}), GL {version}")
logger.info(f"GLSL context initialized in {elapsed:.1f}ms - EGL {self._egl_major}.{self._egl_minor}, {renderer} ({vendor}), GL {version}")
def make_current(self):
if self._backend == "glfw":
glfw.make_context_current(self._window)
elif self._backend == "egl":
from OpenGL.EGL import eglMakeCurrent
eglMakeCurrent(self._egl_display, self._egl_surface, self._egl_surface, self._egl_context)
elif self._backend == "osmesa":
from OpenGL.osmesa import OSMesaMakeCurrent
OSMesaMakeCurrent(self._osmesa_ctx, self._osmesa_buffer, gl.GL_UNSIGNED_BYTE, 64, 64)
if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
err = EGL.eglGetError()
raise RuntimeError(f"eglMakeCurrent() failed (EGL error: 0x{err:04X})")
if self._vao is not None:
gl.glBindVertexArray(self._vao)
def _cleanup(self):
if not self._display:
return
try:
if self._vao is not None:
gl.glDeleteVertexArrays(1, [self._vao])
self._vao = None
except Exception:
pass
try:
EGL.eglMakeCurrent(self._display, EGL.EGL_NO_SURFACE, EGL.EGL_NO_SURFACE, EGL.EGL_NO_CONTEXT)
except Exception:
pass
try:
if self._context:
EGL.eglDestroyContext(self._display, self._context)
except Exception:
pass
try:
if self._surface:
EGL.eglDestroySurface(self._display, self._surface)
except Exception:
pass
try:
EGL.eglTerminate(self._display)
except Exception:
pass
self._display = None
def _compile_shader(source: str, shader_type: int) -> int:
"""Compile a shader and return its ID."""
@ -459,8 +346,10 @@ def _compile_shader(source: str, shader_type: int) -> int:
gl.glShaderSource(shader, source)
gl.glCompileShader(shader)
if gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS) != gl.GL_TRUE:
error = gl.glGetShaderInfoLog(shader).decode()
if not gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS):
error = gl.glGetShaderInfoLog(shader)
if isinstance(error, bytes):
error = error.decode(errors="replace")
gl.glDeleteShader(shader)
raise RuntimeError(f"Shader compilation failed:\n{error}")
@ -484,8 +373,10 @@ def _create_program(vertex_source: str, fragment_source: str) -> int:
gl.glDeleteShader(vertex_shader)
gl.glDeleteShader(fragment_shader)
if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE:
error = gl.glGetProgramInfoLog(program).decode()
if not gl.glGetProgramiv(program, gl.GL_LINK_STATUS):
error = gl.glGetProgramInfoLog(program)
if isinstance(error, bytes):
error = error.decode(errors="replace")
gl.glDeleteProgram(program)
raise RuntimeError(f"Program linking failed:\n{error}")
@ -530,9 +421,6 @@ def _render_shader_batch(
ctx = GLContext()
ctx.make_current()
# Convert from GLSL ES to desktop GLSL 330
fragment_source = _convert_es_to_desktop(fragment_code)
# Detect how many outputs the shader actually uses
num_outputs = _detect_output_count(fragment_code)
@ -558,9 +446,9 @@ def _render_shader_batch(
try:
# Compile shaders (once for all batches)
try:
program = _create_program(VERTEX_SHADER, fragment_source)
program = _create_program(VERTEX_SHADER, fragment_code)
except RuntimeError:
logger.error(f"Fragment shader:\n{fragment_source}")
logger.error(f"Fragment shader:\n{fragment_code}")
raise
gl.glUseProgram(program)
@ -723,13 +611,13 @@ def _render_shader_batch(
gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)
# Read back outputs for this batch
# (glGetTexImage is synchronous, implicitly waits for rendering)
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
batch_outputs = []
for tex in output_textures:
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
data = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT)
img = np.frombuffer(data, dtype=np.float32).reshape(height, width, 4)
batch_outputs.append(img[::-1, :, :].copy())
for i in range(num_outputs):
gl.glReadBuffer(gl.GL_COLOR_ATTACHMENT0 + i)
buf = np.empty((height, width, 4), dtype=np.float32)
gl.glReadPixels(0, 0, width, height, gl.GL_RGBA, gl.GL_FLOAT, buf)
batch_outputs.append(buf[::-1, :, :].copy())
# Pad with black images for unused outputs
black_img = np.zeros((height, width, 4), dtype=np.float32)
@ -750,18 +638,18 @@ def _render_shader_batch(
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
gl.glUseProgram(0)
for tex in input_textures:
gl.glDeleteTextures(int(tex))
for tex in curve_textures:
gl.glDeleteTextures(int(tex))
for tex in output_textures:
gl.glDeleteTextures(int(tex))
for tex in ping_pong_textures:
gl.glDeleteTextures(int(tex))
if input_textures:
gl.glDeleteTextures(len(input_textures), input_textures)
if curve_textures:
gl.glDeleteTextures(len(curve_textures), curve_textures)
if output_textures:
gl.glDeleteTextures(len(output_textures), output_textures)
if ping_pong_textures:
gl.glDeleteTextures(len(ping_pong_textures), ping_pong_textures)
if fbo is not None:
gl.glDeleteFramebuffers(1, [fbo])
for pp_fbo in ping_pong_fbos:
gl.glDeleteFramebuffers(1, [pp_fbo])
if ping_pong_fbos:
gl.glDeleteFramebuffers(len(ping_pong_fbos), ping_pong_fbos)
if program is not None:
gl.glDeleteProgram(program)

View File

@ -214,11 +214,13 @@ class SaveAnimatedWEBP(IO.ComfyNode):
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
outputs=[IO.Image.Output(display_name="images")]
)
@classmethod
def execute(cls, images, fps, filename_prefix, lossless, quality, method, num_frames=0) -> IO.NodeOutput:
return IO.NodeOutput(
images,
ui=UI.ImageSaveHelper.get_save_animated_webp_ui(
images=images,
filename_prefix=filename_prefix,
@ -230,8 +232,6 @@ class SaveAnimatedWEBP(IO.ComfyNode):
)
)
save_images = execute # TODO: remove
class SaveAnimatedPNG(IO.ComfyNode):
@ -249,11 +249,13 @@ class SaveAnimatedPNG(IO.ComfyNode):
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
outputs=[IO.Image.Output(display_name="images")]
)
@classmethod
def execute(cls, images, fps, compress_level, filename_prefix="ComfyUI") -> IO.NodeOutput:
return IO.NodeOutput(
images,
ui=UI.ImageSaveHelper.get_save_animated_png_ui(
images=images,
filename_prefix=filename_prefix,
@ -263,8 +265,6 @@ class SaveAnimatedPNG(IO.ComfyNode):
)
)
save_images = execute # TODO: remove
class ImageStitch(IO.ComfyNode):
"""Upstreamed from https://github.com/kijai/ComfyUI-KJNodes"""
@ -513,6 +513,7 @@ class SaveSVGNode(IO.ComfyNode):
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
outputs=[IO.SVG.Output("svg")],
)
@classmethod
@ -562,9 +563,7 @@ class SaveSVGNode(IO.ComfyNode):
results.append(UI.SavedResult(filename=file, subfolder=subfolder, type=IO.FolderType.output))
counter += 1
return IO.NodeOutput(ui={"images": results})
save_svg = execute # TODO: remove
return IO.NodeOutput(svg, ui={"images": results})
class GetImageSize(IO.ComfyNode):
@ -1157,40 +1156,27 @@ class SaveImageAdvanced(IO.ComfyNode):
IO.String.Input(
"filename_prefix",
default="ComfyUI",
tooltip=(
"The prefix for the file to save. May include formatting tokens "
"such as %date:yyyy-MM-dd% or %Empty Latent Image.width%."
),
tooltip=("The prefix for the file to save. May include formatting tokens such as %date:yyyy-MM-dd% or %Empty Latent Image.width%."),
),
IO.DynamicCombo.Input(
"format",
options=[
IO.DynamicCombo.Option("png", [
IO.Combo.Input("bit_depth", options=["8-bit", "16-bit"],
default="8-bit", advanced=True),
IO.Combo.Input("input_color_space", options=["sRGB"],
default="sRGB", advanced=True),
IO.Combo.Input("bit_depth", options=["8-bit", "16-bit"], default="8-bit", advanced=True),
IO.Combo.Input("input_color_space", options=["sRGB"], default="sRGB", advanced=True),
]),
IO.DynamicCombo.Option("exr", [
IO.Combo.Input("bit_depth", options=["32-bit float"],
default="32-bit float", advanced=True),
IO.Combo.Input("bit_depth", options=["32-bit float"], default="32-bit float", advanced=True),
IO.Combo.Input(
"input_color_space",
options=["sRGB", "HDR", "linear"],
default="sRGB",
advanced=True,
tooltip=(
"Colorspace of the input tensor. The EXR is "
"always written as scene-linear in the matching "
"gamut.\n"
" 'sRGB' — input is sRGB-encoded Rec.709; "
"the inverse sRGB EOTF is applied.\n"
" 'HDR' — input is HLG-encoded Rec.2020 "
"(BT.2100); the inverse HLG OETF is applied "
"to get scene-linear light.\n"
" 'linear' — input is already scene-linear "
"(Rec.709 primaries); written through unchanged. "
"Use this for renderer/compositor output."
"Colorspace of the input tensor. The EXR is always written as scene-linear in the matching gamut.\n"
"sRGB — input is sRGB-encoded Rec.709; the inverse sRGB EOTF is applied.\n"
"HDR — input is HLG-encoded Rec.2020 (BT.2100); the inverse HLG OETF is applied to get scene-linear light.\n"
"linear — input is already scene-linear (Rec.709 primaries); written through unchanged. Use this for renderer/compositor output."
),
),
]),
@ -1200,6 +1186,7 @@ class SaveImageAdvanced(IO.ComfyNode):
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
outputs=[IO.Image.Output(display_name="images")]
)
@classmethod
@ -1237,7 +1224,7 @@ class SaveImageAdvanced(IO.ComfyNode):
results.append({"filename": file, "subfolder": subfolder, "type": "output"})
counter += 1
return IO.NodeOutput(ui={"images": results})
return IO.NodeOutput(images, ui={"images": results})
class ImagesExtension(ComfyExtension):

View File

@ -0,0 +1,77 @@
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy_extras.color_util import normalize_palette
class BuildJsonPromptIdeogram(io.ComfyNode):
@classmethod
def define_schema(cls):
color_palette = io.Colors.Input(
"color_palette",
socketless=False,
tooltip="Hex color codes that steer the image's dominant colors. Up to 16 entries.",
)
return io.Schema(
node_id="BuildJsonPromptIdeogram",
display_name="Build JSON Prompt (Ideogram)",
category="text",
description="Build a JSON prompt for the Ideogram 4 model.",
inputs=[
io.Array.Input("element", tooltip="Prompt elements from the node Create Bounding Boxes."),
io.String.Input("high_level_description", multiline=True, default="",
tooltip="Optional description of the image in one or two sentences. Strongly recommended."),
io.String.Input("background", multiline=True, default="",
tooltip="Mandatory description of the image background or environment."),
io.DynamicCombo.Input("style", options=[
io.DynamicCombo.Option("none", []),
io.DynamicCombo.Option("photo", [io.String.Input("photo", default="", tooltip="Camera or lens details for photographic outputs (e.g. 35mm, f/1.4, bokeh).")]),
io.DynamicCombo.Option("art_style", [io.String.Input("art_style", default="", tooltip="Art style description (e.g. flat vector illustration, bold outlines).")]),
]),
io.String.Input("aesthetics", default="", tooltip="Mandatory aesthetic keywords (e.g. moody, cinematic, desaturated)."),
io.String.Input("lighting", default="", tooltip="Mandatory lighting description (e.g. golden hour, rim light, dramatic shadows)."),
io.String.Input("medium", default="", tooltip="Mandatory medium type (e.g. photograph, illustration, 3d_render, painting, graphic_design). When style = photo, set to photograph."),
color_palette,
],
outputs=[io.Dict.Output(display_name="prompt")],
is_experimental=True,
)
@classmethod
def execute(cls, element, style, high_level_description="", background="",
aesthetics="", lighting="", medium="", color_palette=None) -> io.NodeOutput:
elements = element if isinstance(element, list) else []
kind = style.get("style", "none") if isinstance(style, dict) else "none"
photo = style.get("photo", "") if isinstance(style, dict) else ""
art_style = style.get("art_style", "") if isinstance(style, dict) else ""
palette = normalize_palette(color_palette or [])
caption: dict = {}
if high_level_description.strip():
caption["high_level_description"] = high_level_description
if kind != "none":
style_desc: dict = {"aesthetics": aesthetics, "lighting": lighting}
if kind == "photo":
style_desc["photo"] = photo
style_desc["medium"] = medium
else:
style_desc["medium"] = medium
style_desc["art_style"] = art_style
if palette:
style_desc["color_palette"] = palette
caption["style_description"] = style_desc
caption["compositional_deconstruction"] = {
"background": background,
"elements": elements,
}
return io.NodeOutput(caption)
class JsonPromptExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [BuildJsonPromptIdeogram]
async def comfy_entrypoint() -> JsonPromptExtension:
return JsonPromptExtension()

View File

@ -317,11 +317,74 @@ class PreviewPointCloud(IO.ComfyNode):
)
MESH_EXTENSIONS = {'.gltf', '.glb', '.obj', '.fbx', '.stl'}
class Load3DAdvanced(IO.ComfyNode):
@classmethod
def define_schema(cls):
input_dir = os.path.join(folder_paths.get_input_directory(), "3d")
os.makedirs(input_dir, exist_ok=True)
input_path = Path(input_dir)
base_path = Path(folder_paths.get_input_directory())
files = [
normalize_path(str(file_path.relative_to(base_path)))
for file_path in input_path.rglob("*")
if file_path.suffix.lower() in MESH_EXTENSIONS
]
return IO.Schema(
node_id="Load3DAdvanced",
display_name="Load 3D (Advanced)",
category="3d",
search_aliases=[
"load mesh",
"load gltf",
"load glb",
"load obj",
"load fbx",
"load stl",
],
is_experimental=True,
inputs=[
IO.Combo.Input("model_file", options=["none"] + sorted(files), upload=IO.UploadType.model),
IO.Load3D.Input("viewport_state"),
IO.Int.Input("width", default=1024, min=1, max=4096, step=1),
IO.Int.Input("height", default=1024, min=1, max=4096, step=1),
],
outputs=[
IO.File3DAny.Output(display_name="model_3d"),
IO.Load3DModelInfo.Output(display_name="model_3d_info"),
IO.Load3DCamera.Output(display_name="camera_info"),
IO.Int.Output(display_name="width"),
IO.Int.Output(display_name="height"),
],
)
@classmethod
def validate_inputs(cls, model_file, **kwargs) -> bool | str:
if not model_file or model_file == "none":
return True
if not folder_paths.exists_annotated_filepath(model_file):
return f"Invalid 3D model file: {model_file}"
return True
@classmethod
def execute(cls, model_file, viewport_state, width: int, height: int, **kwargs) -> IO.NodeOutput:
file_3d = None
if model_file and model_file != "none":
file_3d = Types.File3D(folder_paths.get_annotated_filepath(model_file))
model_3d_info = viewport_state.get('model_3d_info', [])
return IO.NodeOutput(file_3d, model_3d_info, viewport_state['camera_info'], width, height)
class Load3DExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
Load3D,
Load3DAdvanced,
Preview3D,
Preview3DAdvanced,
PreviewGaussianSplat,

View File

@ -89,7 +89,8 @@ class SwitchNode(io.ComfyNode):
template = io.MatchType.Template("switch")
return io.Schema(
node_id="ComfySwitchNode",
display_name="Switch",
search_aliases=["if", "then", "switch", "conditional", "branch"],
display_name="If/Else Switch",
category="utilities/logic",
is_experimental=True,
inputs=[

View File

@ -337,6 +337,36 @@ class ModelMergeQwenImage(comfy_extras.nodes_model_merging.ModelMergeBlocks):
return {"required": arg_dict}
class ModelMergeKrea2(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "model/merging/model specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["first."] = argument
arg_dict["tmlp."] = argument
arg_dict["txtmlp."] = argument
arg_dict["tproj."] = argument
for i in range(2):
arg_dict["txtfusion.layerwise_blocks.{}.".format(i)] = argument
arg_dict["txtfusion.projector."] = argument
for i in range(2):
arg_dict["txtfusion.refiner_blocks.{}.".format(i)] = argument
for i in range(28):
arg_dict["blocks.{}.".format(i)] = argument
arg_dict["last."] = argument
return {"required": arg_dict}
NODE_CLASS_MAPPINGS = {
"ModelMergeSD1": ModelMergeSD1,
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
@ -353,4 +383,5 @@ NODE_CLASS_MAPPINGS = {
"ModelMergeCosmosPredict2_2B": ModelMergeCosmosPredict2_2B,
"ModelMergeCosmosPredict2_14B": ModelMergeCosmosPredict2_14B,
"ModelMergeQwenImage": ModelMergeQwenImage,
"ModelMergeKrea2": ModelMergeKrea2,
}

View File

@ -123,7 +123,8 @@ class PhotoMakerLoader(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="PhotoMakerLoader",
category="experimental/photomaker",
display_name="Load PhotoMaker Model",
category="model/loaders",
inputs=[
io.Combo.Input("photomaker_model_name", options=folder_paths.get_filename_list("photomaker")),
],
@ -149,7 +150,8 @@ class PhotoMakerEncode(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="PhotoMakerEncode",
category="experimental/photomaker",
display_name="PhotoMaker Encode",
category="model/conditioning/photomaker",
inputs=[
io.Photomaker.Input("photomaker"),
io.Image.Input("image"),

View File

@ -10,12 +10,11 @@ class String(io.ComfyNode):
return io.Schema(
node_id="PrimitiveString",
search_aliases=["text", "string", "text box", "prompt"],
display_name="Text String",
display_name="Text String (DEPRECATED)",
category="utilities/primitive",
inputs=[
io.String.Input("value"),
],
inputs=[io.String.Input("value")],
outputs=[io.String.Output()],
is_deprecated=True
)
@classmethod
@ -29,12 +28,10 @@ class StringMultiline(io.ComfyNode):
return io.Schema(
node_id="PrimitiveStringMultiline",
search_aliases=["text", "string", "text multiline", "string multiline", "text box", "prompt"],
display_name="Text String (Multiline)",
display_name="Input Text",
category="utilities/primitive",
essentials_category="Basics",
inputs=[
io.String.Input("value", multiline=True),
],
inputs=[io.String.Input("value", multiline=True)],
outputs=[io.String.Output()],
)

View File

@ -34,14 +34,20 @@ def _unpack(track_data):
return unpack_masks(packed)
def _first_frame_cx_area(masks_bool):
first = masks_bool[0].float()
H, W = first.shape[-2], first.shape[-1]
n_pixels = H * W
grid_x = torch.arange(W, device=first.device, dtype=first.dtype).view(1, W)
area = first.sum(dim=(-1, -2)).clamp_(min=1)
cx = (first * grid_x).sum(dim=(-1, -2)) / area
return (cx / W).tolist(), (area / n_pixels).tolist()
def _first_appearance_cx_area(masks_bool):
"""Per object: first frame it appears in, plus centroid-x and area in that frame."""
m = masks_bool.float()
T, H, W = m.shape[0], m.shape[-2], m.shape[-1]
grid_x = torch.arange(W, device=m.device, dtype=m.dtype).view(1, 1, 1, W)
area_t = m.sum(dim=(-1, -2))
cx_t = (m * grid_x).sum(dim=(-1, -2)) / area_t.clamp(min=1)
present = area_t > 0
frame_idx = torch.arange(T, device=m.device).unsqueeze(1)
first_t = torch.where(present, frame_idx, T).amin(dim=0)
sel = first_t.clamp(max=T - 1).unsqueeze(0)
cx = cx_t.gather(0, sel).squeeze(0)
area = area_t.gather(0, sel).squeeze(0)
return first_t.tolist(), (cx / W).tolist(), (area / (H * W)).tolist()
def _subset_track_data(track_data, obj_indices):
@ -81,12 +87,26 @@ def _render_colored_masks(track_data, background="black"):
masks_full.view(T * N_obj, 1, Hm, Wm), size=(H, W), mode="nearest"
).view(T, N_obj, H, W) > 0.5
any_mask = masks_full.any(dim=1)
obj_idx_map = masks_full.to(torch.uint8).argmax(dim=1)
color_overlay = colors[obj_idx_map]
color_overlay = colors[masks_full.to(torch.uint8).argmax(dim=1)]
bg_tensor = torch.tensor(bg_rgb, device=device, dtype=color_overlay.dtype).view(1, 1, 1, 3)
return torch.where(any_mask.unsqueeze(-1), color_overlay, bg_tensor.expand_as(color_overlay))
def _render_mask_as_identity(mask, background="black"):
"""Plain comfy MASK (B,H,W) or (H,W) -> (B,H,W,3) rendered as a single identity (palette[0])
on the given background. A batch is treated as multiple views of that one subject."""
device = comfy.model_management.intermediate_device()
dtype = comfy.model_management.intermediate_dtype()
if mask.ndim == 2:
mask = mask.unsqueeze(0)
mask = mask.to(device=device, dtype=dtype)
B, H, W = mask.shape
bg_rgb = (1.0, 1.0, 1.0) if background.startswith("white") else (0.0, 0.0, 0.0)
color = torch.tensor(DEFAULT_PALETTE[0], device=device, dtype=dtype).view(1, 1, 1, 3)
bg = torch.tensor(bg_rgb, device=device, dtype=dtype).view(1, 1, 1, 3)
return torch.where((mask > 0.5).unsqueeze(-1), color.expand(B, H, W, 3), bg.expand(B, H, W, 3))
def _extract_mask_to_28ch(rgb_video):
"""Colored RGB mask (T, H, W, 3) in [0, 1] -> SCAIL-2 28-channel binary latent
(1, T_lat, 28, H_lat, W_lat). 7 per-color binary channels (white/r/g/b/y/m/c)
@ -138,8 +158,8 @@ class WanSCAILToVideo(io.ComfyNode):
io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."),
io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step of the pose conditioning."),
io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step of the pose conditioning."),
io.Image.Input("reference_image", optional=True, tooltip="Reference image, for multiple references composite all on single image."),
io.Image.Input("reference_image_mask", optional=True, tooltip="SCAIL-2 only. Colored reference mask at the same resolution as reference_image."),
io.Image.Input("reference_image", optional=True, tooltip="Reference image. The first image is the primary reference (composite all identities onto it). SCAIL-2: extra batch images are used as additional views (back view, close-up, occluded background), each needing a matching reference_image_mask in that identity's color."),
io.Image.Input("reference_image_mask", optional=True, tooltip="SCAIL-2 only. Colored reference mask, batch matching reference_image (first = primary reference mask, rest = identity masks for the additional reference_image)."),
io.ClipVisionOutput.Input("clip_vision_output", optional=True, tooltip="CLIP vision features for conditioning. Model is trained with stretch resize to aspect ratio."),
io.Int.Input("video_frame_offset", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1, tooltip="Cumulative output frame this chunk begins at. Wire from the previous chunk's video_frame_offset output."),
io.Int.Input("previous_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=4, tooltip="Tail frames of previous_frames to anchor. SCAIL-2 trained at 5 (81-frame chunks, 76-frame step)."),
@ -171,19 +191,21 @@ class WanSCAILToVideo(io.ComfyNode):
video_frame_offset -= prev_trimmed.shape[0]
video_frame_offset = max(0, video_frame_offset)
ref_latent = None
if reference_image is not None:
reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1)
# Replacement Mode: composite ref on black bg using reference_image_mask as alpha matte
if replacement_mode and reference_image_mask is not None:
rm = comfy.utils.common_upscale(reference_image_mask[:1].movedim(-1, 1), width, height, "nearest-exact", "center").movedim(1, -1)
is_char = (rm[..., :3].max(dim=-1, keepdim=True).values > 0.1).to(reference_image.dtype)
reference_image = reference_image * is_char
ref_latent = vae.encode(reference_image[:, :, :, :3])
ref_imgs = comfy.utils.common_upscale(reference_image.movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1)
n_ref = ref_imgs.shape[0]
# SCAIL-2 multi-reference: the first image is the primary ref, the rest are additional references.
if ref_latent is not None:
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True)
# Replacement Mode: composite each ref on black bg using its mask as alpha matte
if replacement_mode and reference_image_mask is not None:
rm = comfy.utils.common_upscale(reference_image_mask.movedim(-1, 1), width, height, "nearest-exact", "center").movedim(1, -1)
rm = rm[[min(i, rm.shape[0] - 1) for i in range(n_ref)]]
is_char = (rm[..., :3].max(dim=-1, keepdim=True).values > 0.1).to(ref_imgs.dtype)
ref_imgs = ref_imgs * is_char
# encode each ref individually so each stays a single latent frame (a batched encode would be treated as a video)
ref_latents = [vae.encode(ref_imgs[i:i + 1, :, :, :3]) for i in range(n_ref)]
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": ref_latents}, append=True)
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": ref_latents}, append=True)
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
@ -221,11 +243,16 @@ class WanSCAILToVideo(io.ComfyNode):
positive = node_helpers.conditioning_set_values(positive, {"driving_mask_28ch": driving_mask_28ch})
negative = node_helpers.conditioning_set_values(negative, {"driving_mask_28ch": driving_mask_28ch})
if reference_image_mask is not None:
ref_mask_hw = comfy.utils.common_upscale(reference_image_mask[:1].movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1)
ref_mask_1f = _extract_mask_to_28ch(ref_mask_hw)
# The ref mask binds reference frames to identities, so it only applies when there's a reference image.
if reference_image_mask is not None and reference_image is not None:
ref_mask_hw = comfy.utils.common_upscale(reference_image_mask.movedim(-1, 1), width, height, "nearest-exact", "center").movedim(1, -1)
n_masks = ref_mask_hw.shape[0]
n_ref = reference_image.shape[0]
add_masks = [_extract_mask_to_28ch(ref_mask_hw[min(i, n_masks - 1)][None]) for i in range(1, n_ref)]
ref_mask_1f = _extract_mask_to_28ch(ref_mask_hw[:1])
zeros = torch.zeros((1, latent.shape[2], 28, ref_mask_1f.shape[-2], ref_mask_1f.shape[-1]), device=ref_mask_1f.device, dtype=ref_mask_1f.dtype)
ref_mask_28ch = torch.cat([ref_mask_1f, zeros], dim=1)
ref_mask_28ch = torch.cat(add_masks + [ref_mask_1f, zeros], dim=1)
positive = node_helpers.conditioning_set_values(positive, {"ref_mask_28ch": ref_mask_28ch})
negative = node_helpers.conditioning_set_values(negative, {"ref_mask_28ch": ref_mask_28ch})
@ -244,12 +271,9 @@ class WanSCAILToVideo(io.ComfyNode):
class SCAIL2ColoredMask(io.ComfyNode):
"""Render SAM3 tracks for the driving pose video and (optionally) the reference
image into the two colored masks WanSCAILToVideo consumes. Shared `sort_by`
across both outputs guarantees identity K maps to the same color on both
sides, for multi-person workflow consistency.
reference_image_mask is always rendered black-bg (model convention)
pose_video_mask bg follows replacement_mode: black = Animation Mode, white = Replacement Mode
"""Render SAM3 tracks for the driving pose video and reference image(s) into the
colored masks WanSCAILToVideo consumes. Shared `sort_by` keeps each identity on the
same color across both outputs.
"""
@classmethod
@ -260,10 +284,12 @@ class SCAIL2ColoredMask(io.ComfyNode):
category="model/conditioning/wan/scail",
inputs=[
SAM3TrackData.Input("driving_track_data", tooltip="SAM3 track of the driving pose video. Will be rendered into the pose_video_mask output."),
SAM3TrackData.Input("ref_track_data", optional=True, tooltip="SAM3 track of the reference image."),
io.String.Input("object_indices", default="", tooltip="Comma-separated list of person indices to include (e.g. '0,2,3'). Applied to both reference and pose video masks. Empty = all."),
io.MultiType.Input("ref_track_data", [SAM3TrackData, io.Mask], optional=True, display_name="reference_masks",
tooltip="SAM3 track of the reference image(s) (one identity per object, colored in batch order), or a plain MASK of the reference subject (rendered as a single identity)."),
io.String.Input("object_indices", default="",
tooltip="Comma-separated list of person indices to include (e.g. '0,2,3'). Applied to both reference and pose video masks. Empty = all."),
io.Combo.Input("sort_by", options=["none", "left_to_right", "area"], default="left_to_right",
tooltip="Order in which palette colors are assigned to the tracked objects (applied to both reference and pose video so each identity keeps the same color). left_to_right = leftmost object (by first-frame centroid) gets the first color; area = biggest object (by first-frame mask area) gets the first color; none = keep SAM3's order."),
tooltip="Order in which palette colors are assigned to the tracked objects (applied to both reference and pose video so each identity keeps the same color). Objects that appear in earlier frames always come first; within a frame, left_to_right = leftmost object (by centroid at first appearance) gets the first color, area = biggest object (by mask area at first appearance) gets the first color; none = keep SAM3's order."),
io.Boolean.Input("replacement_mode", default=False,
tooltip="False = Animation Mode (pose_video_mask has black background, reference_image_mask has white background). "
"True = Replacement Mode (pose_video_mask has white background, reference_image_mask has black background)."),
@ -280,11 +306,11 @@ class SCAIL2ColoredMask(io.ComfyNode):
def _prep(td):
masks_bool = _unpack(td)
if sort_by != "none" and masks_bool is not None:
cx, area = _first_frame_cx_area(masks_bool)
first_t, cx, area = _first_appearance_cx_area(masks_bool)
if sort_by == "left_to_right":
order = sorted(range(len(cx)), key=lambda i: cx[i])
order = sorted(range(len(cx)), key=lambda i: (first_t[i], cx[i]))
else: # "area"
order = sorted(range(len(area)), key=lambda i: -area[i])
order = sorted(range(len(area)), key=lambda i: (first_t[i], -area[i]))
td = _subset_track_data(td, order)
if object_indices.strip():
indices = [int(i.strip()) for i in object_indices.split(",") if i.strip().isdigit()]
@ -300,8 +326,10 @@ class SCAIL2ColoredMask(io.ComfyNode):
ref_bg = "black" if replacement_mode else "white"
if ref_track_data is not None:
ref = _prep(ref_track_data)
reference_image_mask = _render_colored_masks(ref, ref_bg)
if isinstance(ref_track_data, torch.Tensor): # plain comfy MASK
reference_image_mask = _render_mask_as_identity(ref_track_data, ref_bg)
else:
reference_image_mask = _render_colored_masks(_prep(ref_track_data), ref_bg)
else:
H, W = drv["orig_size"]
fill_value = 1.0 if ref_bg == "white" else 0.0

View File

@ -0,0 +1,33 @@
import sys
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class SeedNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SeedNode",
display_name="Seed",
search_aliases=["seed", "random"],
category="utilities",
inputs=[
io.Int.Input("seed", min=0, max=sys.maxsize, control_after_generate=io.ControlAfterGenerate.fixed),
],
outputs=[io.Int.Output(display_name="seed")],
)
@classmethod
def execute(cls, seed: int) -> io.NodeOutput:
return io.NodeOutput(seed)
class SeedExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [SeedNode]
async def comfy_entrypoint() -> SeedExtension:
return SeedExtension()

View File

@ -119,7 +119,7 @@ class StableCascade_SuperResolutionControlnet(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="StableCascade_SuperResolutionControlnet",
category="experimental/stable_cascade",
category="experimental/stable cascade",
is_experimental=True,
inputs=[
io.Image.Input("image"),

View File

@ -440,6 +440,57 @@ class JsonExtractString(io.ComfyNode):
except (json.JSONDecodeError, TypeError):
return io.NodeOutput("")
def _dump_json(value, indent):
return json.dumps(value, ensure_ascii=False, indent=indent or None)
class ConvertDictionaryToString(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ConvertDictionaryToString",
display_name="Convert Dictionary to String",
category="text",
search_aliases=["json", "dict to json", "stringify", "serialize", "dict to string"],
inputs=[
io.Dict.Input("dictionary"),
io.Int.Input("indent", default=2, min=0, max=8,
tooltip="Spaces per indent level. 0 produces compact single-line string."),
],
outputs=[
io.String.Output(),
],
)
@classmethod
def execute(cls, dictionary, indent=2):
return io.NodeOutput(_dump_json(dictionary, indent))
class ConvertArrayToString(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ConvertArrayToString",
display_name="Convert Array to String",
category="text",
search_aliases=["json", "list to json", "stringify", "serialize", "list to string", "array to json"],
inputs=[
io.Array.Input("array"),
io.Int.Input("indent", default=2, min=0, max=8,
tooltip="Spaces per indent level. 0 produces compact single-line string."),
],
outputs=[
io.String.Output(),
],
)
@classmethod
def execute(cls, array, indent=2):
return io.NodeOutput(_dump_json(array, indent))
class StringExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
@ -457,6 +508,8 @@ class StringExtension(ComfyExtension):
RegexExtract,
RegexReplace,
JsonExtractString,
ConvertDictionaryToString,
ConvertArrayToString,
]
async def comfy_entrypoint() -> StringExtension:

View File

@ -65,7 +65,7 @@ class TripoSplatPreprocessImage(IO.ComfyNode):
return IO.Schema(
node_id="TripoSplatPreprocessImage",
display_name="TripoSplat Preprocess Image",
category="3d/conditioning",
category="model/conditioning/triposplat",
description="Crop center each image to a square canvas on a black background and add padding.",
inputs=[
IO.Image.Input("image"),
@ -95,7 +95,7 @@ class TripoSplatConditioning(IO.ComfyNode):
return IO.Schema(
node_id="TripoSplatConditioning",
display_name="TripoSplat Conditioning",
category="3d/conditioning",
category="model/conditioning/triposplat",
description="Encode the image with DINOv3 and the Flux2 VAE into TripoSplat positive/negative "
"conditioning, and create the fixed size noise target (latent + camera) for the KSampler",
inputs=[
@ -143,7 +143,7 @@ class VAEDecodeTripoSplat(IO.ComfyNode):
return IO.Schema(
node_id="VAEDecodeTripoSplat",
display_name="TripoSplat Decode",
category="3d/latent",
category="model/latent/triposplat",
description="Decode the sampled TripoSplat latent into a 3D gaussian splat. "
"Modify the number of gaussians to vary the density.",
inputs=[
@ -188,7 +188,7 @@ class TripoSplatSamplingPreview(IO.ComfyNode):
return IO.Schema(
node_id="TripoSplatSamplingPreview",
display_name="TripoSplat Sampling Preview",
category="3d/latent",
category="model/latent/triposplat",
description="Patch the TripoSplat model for the standard Ksampler node to show a live decoded "
"gaussian splat preview at each step.",
inputs=[

View File

@ -27,6 +27,7 @@ class SaveWEBM(io.ComfyNode):
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
outputs=[io.Image.Output(display_name="images")]
)
@classmethod
@ -69,7 +70,7 @@ class SaveWEBM(io.ComfyNode):
container.mux(stream.encode())
container.close()
return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]))
return io.NodeOutput(images, ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]))
class SaveVideo(io.ComfyNode):
@classmethod
@ -89,6 +90,7 @@ class SaveVideo(io.ComfyNode):
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
outputs=[io.Video.Output("video")],
)
@classmethod
@ -117,7 +119,7 @@ class SaveVideo(io.ComfyNode):
metadata=saved_metadata
)
return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]))
return io.NodeOutput(video, ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]))
class CreateVideo(io.ComfyNode):
@ -233,13 +235,8 @@ class VideoSlice(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="Video Slice",
display_name="Video Slice",
search_aliases=[
"trim video duration",
"skip first frames",
"frame load cap",
"start time",
],
display_name="Trim Video",
search_aliases=["trim video duration", "skip first frames", "frame load cap", "start time"],
category="video",
essentials_category="Video Tools",
inputs=[

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.25.0"
__version__ = "0.27.0"

View File

@ -1308,6 +1308,25 @@ class PromptQueue:
queued = copy.copy(self.queue)
return (running, queued)
def interrupt_if_running(self, prompt_id):
"""Interrupt the running prompt with this id, atomically.
Checks the live running set and signals the interrupt under the queue
mutex, so the worker cannot move the job to done (and start the next
prompt) in between. Returns True if a matching job was running and an
interrupt was signalled, False otherwise. The atomicity is what keeps a
cancel from landing on an unrelated prompt that started after a separate
is-running check: the global interrupt flag is reset at the start of
every prompt (execute_async), so a job that finishes before consuming
the flag cannot leak the interrupt onto its successor.
"""
with self.mutex:
for item in self.currently_running.values():
if item[1] == prompt_id:
nodes.interrupt_processing()
return True
return False
def get_tasks_remaining(self):
with self.mutex:
return len(self.queue) + len(self.currently_running)

View File

@ -8,21 +8,37 @@
# # You can use is_default to mark that these folders should be listed first, and used as the default dirs for eg downloads
# #is_default: true
# checkpoints: models/checkpoints/
# configs: models/configs/
# loras: models/loras/
# vae: models/vae/
# text_encoders: |
# models/text_encoders/
# models/clip/ # legacy location still supported
# clip_vision: models/clip_vision/
# configs: models/configs/
# controlnet: models/controlnet/
# models/clip/
# diffusion_models: |
# models/diffusion_models
# models/unet
# models/unet/
# models/diffusion_models/
# clip_vision: models/clip_vision/
# style_models: models/style_models/
# embeddings: models/embeddings/
# loras: models/loras/
# diffusers: models/diffusers/
# vae_approx: models/vae_approx/
# controlnet: |
# models/controlnet/
# models/t2i_adapter/
# gligen: models/gligen/
# upscale_models: models/upscale_models/
# vae: models/vae/
# audio_encoders: models/audio_encoders/
# latent_upscale_models: models/latent_upscale_models/
# custom_nodes: custom_nodes/
# hypernetworks: models/hypernetworks/
# photomaker: models/photomaker/
# classifiers: models/classifiers/
# model_patches: models/model_patches/
# audio_encoders: models/audio_encoders/
# background_removal: models/background_removal/
# frame_interpolation: models/frame_interpolation/
# geometry_estimation: models/geometry_estimation/
# optical_flow: models/optical_flow/
# detection: models/detection/
#config for a1111 ui
@ -45,8 +61,7 @@
# controlnet: models/ControlNet
# For a full list of supported keys (style_models, vae_approx, hypernetworks, photomaker,
# model_patches, audio_encoders, classifiers, etc.) see folder_paths.py.
# For the canonical list of supported keys and extensions, see folder_paths.py.
#other_ui:
# base_path: path/to/ui

13
main.py
View File

@ -403,7 +403,7 @@ def prompt_worker(q, server_instance):
hook_breaker_ac10a0.restore_functions()
if not asset_seeder.is_disabled():
asset_seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
asset_seeder.enqueue_enrich(roots=("output",), compute_hashes=args.enable_asset_hashing)
asset_seeder.resume()
@ -458,7 +458,7 @@ def setup_database():
if dependencies_available():
init_db()
if args.enable_assets:
if asset_seeder.start(roots=("models", "input", "output"), prune_first=True, compute_hashes=True):
if asset_seeder.start(roots=("models", "input", "output"), prune_first=True, compute_hashes=args.enable_asset_hashing):
logging.info("Background asset scan initiated for models, input, output")
except Exception as e:
if "database is locked" in str(e):
@ -557,8 +557,13 @@ if __name__ == "__main__":
logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.")
if args.disable_dynamic_vram:
logging.warning("Dynamic vram disabled with argument. If you have any issues with dynamic vram enabled please give us a detailed reports as this argument will be removed soon.")
logging.warning(
"Dynamic vram disabled with argument. If you have any issues with "
"dynamic vram enabled please give us a detailed reports as this "
"argument will be removed soon. If you use gguf we recommend keeping "
"dynamic vram enabled and using native ComfyUI model formats instead. "
"ComfyUI native formats like fp8 will be faster even if they are larger than your memory."
)
event_loop, _, start_all_func = start_comfyui()
try:
x = start_all_func()

View File

@ -20,8 +20,6 @@ from PIL.PngImagePlugin import PngInfo
import numpy as np
import safetensors.torch
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
import comfy.diffusers_load
import comfy.samplers
import comfy.sample
@ -161,6 +159,29 @@ class ConditioningConcat:
return (out, )
class ConditioningMultiply:
SEARCH_ALIASES = ["scale conditioning", "scale prompt", "multiply conditioning", "multiply prompt"]
@classmethod
def INPUT_TYPES(cls):
return {"required": {"conditioning": ("CONDITIONING", ),
"multiplier": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01})
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "multiply"
CATEGORY = "model/conditioning/transform"
def multiply(self, conditioning, multiplier):
c = []
for t in conditioning:
values = {}
pooled_output = t[1].get("pooled_output", None)
if pooled_output is not None:
values["pooled_output"] = pooled_output * multiplier
scaled = node_helpers.conditioning_set_values([[t[0] * multiplier, t[1]]], values)[0]
c.append(scaled)
return (c,)
class ConditioningSetArea:
SEARCH_ALIASES = ["regional prompt", "area prompt", "spatial conditioning", "localized prompt"]
@ -328,7 +349,7 @@ class VAEDecodeTiled:
RETURN_TYPES = ("IMAGE",)
FUNCTION = "decode"
CATEGORY = "experimental"
CATEGORY = "model/latent"
def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8):
if tile_size < overlap * 4:
@ -375,7 +396,7 @@ class VAEEncodeTiled:
RETURN_TYPES = ("LATENT",)
FUNCTION = "encode"
CATEGORY = "experimental"
CATEGORY = "model/latent"
def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
@ -482,16 +503,18 @@ class SaveLatent:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT", ),
"filename_prefix": ("STRING", {"default": "latents/ComfyUI"})},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
return { "required": {
"samples": ("LATENT",),
"filename_prefix": ("STRING", {"default": "latents/ComfyUI"})},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ("LATENT",)
RETURN_NAMES = ("samples",)
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "experimental"
CATEGORY = "model/latent"
def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
@ -524,7 +547,7 @@ class SaveLatent:
output["latent_format_version_0"] = torch.tensor([])
comfy.utils.save_torch_file(output, file, metadata=metadata)
return { "ui": { "latents": results } }
return { "ui": { "latents": results }, "result": (samples,) }
class LoadLatent:
@ -536,7 +559,7 @@ class LoadLatent:
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
return {"required": {"latent": [sorted(files), ]}, }
CATEGORY = "experimental"
CATEGORY = "model/latent"
RETURN_TYPES = ("LATENT", )
FUNCTION = "load"
@ -969,7 +992,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit", "ideogram4", "joyimage"], ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit", "ideogram4", "boogu", "krea2", "joyimage"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
@ -1629,14 +1652,18 @@ class SaveImage:
return {
"required": {
"images": ("IMAGE", {"tooltip": "The images to save."}),
"filename_prefix": ("STRING", {"default": "ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
"filename_prefix": ("STRING", {
"default": "ComfyUI",
"tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."
})
},
"hidden": {
"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"
},
}
RETURN_TYPES = ()
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("images",)
FUNCTION = "save_images"
OUTPUT_NODE = True
@ -1672,7 +1699,7 @@ class SaveImage:
})
counter += 1
return { "ui": { "images": results } }
return { "ui": { "images": results }, "result" : (images,) }
class PreviewImage(SaveImage):
def __init__(self):
@ -2046,6 +2073,7 @@ NODE_CLASS_MAPPINGS = {
"ConditioningAverage": ConditioningAverage,
"ConditioningCombine": ConditioningCombine,
"ConditioningConcat": ConditioningConcat,
"ConditioningMultiply": ConditioningMultiply,
"ConditioningSetArea": ConditioningSetArea,
"ConditioningSetAreaPercentage": ConditioningSetAreaPercentage,
"ConditioningSetAreaStrength": ConditioningSetAreaStrength,
@ -2117,6 +2145,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ConditioningAverage ": "Conditioning (Average)",
"ConditioningAverage": "Conditioning (Average)",
"ConditioningConcat": "Conditioning (Concat)",
"ConditioningMultiply": "Conditioning (Multiply)",
"ConditioningSetArea": "Conditioning (Set Area)",
"ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)",
"ConditioningSetAreaStrength": "Conditioning (Set Area Strength)",
@ -2126,6 +2155,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"GLIGENTextBoxApply": "Apply GLIGEN Text Box",
"ConditioningZeroOut": "Conditioning Zero Out",
# Latent
"LoadLatent": "Load Latent",
"SaveLatent": "Save Latent",
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
"SetLatentNoiseMask": "Set Latent Noise Mask",
"VAEDecode": "VAE Decode",
@ -2160,7 +2191,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ImageSharpen": "Sharpen Image",
"ImageScaleToTotalPixels": "Scale Image to Total Pixels",
"GetImageSize": "Get Image Size",
# experimental
"VAEDecodeTiled": "VAE Decode (Tiled)",
"VAEEncodeTiled": "VAE Encode (Tiled)",
}
@ -2299,6 +2329,9 @@ async def init_external_custom_nodes():
Returns:
None
"""
# TODO: remove at some point when custom nodes don't break.
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
base_node_names = set(NODE_CLASS_MAPPINGS.keys())
node_paths = folder_paths.get_folder_paths("custom_nodes")
node_import_times = []
@ -2367,6 +2400,8 @@ async def init_builtin_extra_nodes():
"nodes_images.py",
"nodes_video_model.py",
"nodes_ideogram4.py",
"nodes_bounding_boxes.py",
"nodes_json_prompt.py",
"nodes_train.py",
"nodes_dataset.py",
"nodes_sag.py",
@ -2426,6 +2461,7 @@ async def init_builtin_extra_nodes():
"nodes_context_windows.py",
"nodes_qwen.py",
"nodes_joyimage.py",
"nodes_boogu.py",
"nodes_chroma_radiance.py",
"nodes_pid.py",
"nodes_model_patch.py",
@ -2466,6 +2502,7 @@ async def init_builtin_extra_nodes():
"nodes_gaussian_splat.py",
"nodes_triposplat.py",
"nodes_depth_anything_3.py",
"nodes_seed.py",
]
import_failed = []

View File

@ -55,6 +55,12 @@ components:
description: URL for asset preview/thumbnail
format: uri
type: string
short_url:
description: Durable, owner-gated short link to this asset's content (relative `/api/s/{id}` path). Stable across the underlying signed URL's expiry — resolving it re-mints a fresh signed URL on every request — so it is safe to persist or share into chat, unlike `preview_url`. Only the minting user can resolve it. Omitted when the short-link surface is disabled or the asset has no resolvable content hash.
nullable: true
type: string
x-runtime:
- cloud
size:
description: Size of the asset in bytes
format: int64
@ -673,6 +679,35 @@ components:
- created_at
- updated_at
type: object
JobsCancelRequest:
additionalProperties: false
description: Request to cancel multiple jobs by ID.
properties:
job_ids:
description: Job identifiers (UUIDs) to cancel.
items:
format: uuid
type: string
maxItems: 100
minItems: 1
type: array
required:
- job_ids
type: object
JobsCancelResponse:
description: Response for POST /api/jobs/cancel.
properties:
cancelled:
description: |
Job IDs for which a cancel event was successfully dispatched by this
call. Jobs already in a terminal or cancelling state are idempotently
skipped and will not appear here.
items:
type: string
type: array
required:
- cancelled
type: object
JobsListResponse:
description: Paginated list of jobs for the authenticated user.
properties:
@ -1006,7 +1041,7 @@ components:
description: If true, clear all pending jobs from the queue
type: boolean
delete:
description: Array of PENDING job IDs to cancel
description: Array of job IDs to cancel; pending and running jobs transition to cancelled
items:
type: string
type: array
@ -1657,6 +1692,12 @@ paths:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Unsupported media type
"422":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Validation error (e.g., disallowed model_type tag)
"500":
content:
application/json:
@ -1822,6 +1863,83 @@ paths:
summary: Update asset metadata
tags:
- file
/api/assets/{id}/content:
get:
description: |
Returns the binary content of an asset by ID.
The contract is the same across runtimes — "GET this path and you
receive the asset's bytes" — but the mechanism differs:
- **Local ComfyUI** streams the bytes directly (`200`,
`application/octet-stream`).
- **Cloud** does not proxy large files; it responds `302` with a
`Location` redirect to a short-lived signed storage URL. Clients that
follow redirects (browsers, `fetch`/XHR, `<img>`/`<video>`) receive
the bytes transparently.
Prefer this over the filename-addressed `/api/view` when you have an
asset ID.
operationId: getAssetContent
parameters:
- description: Asset ID
in: path
name: id
required: true
schema:
type: string
- description: |
Content-Disposition for the response: `attachment` (download) or
`inline` (render in browser). Defaults to `attachment`.
in: query
name: disposition
schema:
default: attachment
enum:
- inline
- attachment
type: string
responses:
"200":
content:
application/octet-stream:
schema:
format: binary
type: string
description: Asset content stream (local runtime streams the bytes directly)
"302":
description: Redirect to a signed storage URL (cloud runtime)
headers:
Cache-Control:
description: Private caching directive scoped to the signed URL lifetime
schema:
type: string
Location:
description: Short-lived signed URL to the asset content in storage
schema:
type: string
Vary:
description: Partitions any cached redirect by auth credentials so a private redirect is not reused across users
schema:
type: string
"404":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Asset not found
"500":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Internal server error
security:
- ApiKeyAuth: []
- BearerAuth: []
- CookieAuth: []
summary: Get asset content
tags:
- file
/api/assets/{id}/tags:
delete:
description: Removes one or more tags from an existing asset
@ -2025,6 +2143,12 @@ paths:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Source asset with given hash not found
"422":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Validation error (e.g., disallowed model_type tag)
"500":
content:
application/json:
@ -2245,6 +2369,10 @@ paths:
description: |
Returns a list of model folders available in the system.
This is an experimental endpoint that replaces the legacy /models endpoint.
Each folder's name is the identifier to pass to /api/experiment/models/{folder}.
Once the model_type migration is active the names are model_type folder_names
(e.g. `ultralytics_bbox`); a folder with no folder_name mapping is returned by
its directory path.
operationId: getModelFolders
responses:
"200":
@ -2675,14 +2803,20 @@ paths:
summary: Get internationalisation translation strings
/api/interrupt:
post:
deprecated: true
description: |
Cancel all currently RUNNING jobs for the authenticated user.
This will interrupt any job that is currently in 'in_progress' status.
Note: This endpoint only affects running jobs. To cancel pending jobs, use /api/queue.
Deprecated. Prefer the jobs-namespace cancel endpoints:
POST /api/jobs/{job_id}/cancel for a single job, or
POST /api/jobs/cancel to cancel jobs by ID.
Cancels the first active job for the authenticated user (the currently
running job if there is one, otherwise the next pending job). Takes no
body and cannot target a specific job — use the jobs-namespace endpoints
for that.
operationId: interruptJob
responses:
"200":
description: Success - Job interrupted or no running job found
description: Success - first active job cancelled, or no active job found
"401":
content:
application/json:
@ -2695,7 +2829,7 @@ paths:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Internal server error
summary: Interrupt currently running jobs
summary: Interrupt the first active job
tags:
- queue
/api/job/{job_id}/status:
@ -2869,6 +3003,17 @@ paths:
schema:
format: uuid
type: string
- description: |
When present, each output item in the response receives a `short_url` field containing a short link for that asset. Omit this parameter (the default) to receive a response identical to the no-param baseline. The value selects the link's lifetime and auth model: use `ephemeral_tool_chain` for short-lived (≤5 minute) machine-to-machine handoffs — these are public bearer links where the link ID itself is the credential, so anyone holding the link can resolve it (intended for pasting into an agent/MCP tool chain); use `default` for durable (30 day) human-revisitable links, which are owner-gated and resolvable only by the authenticated owner. Links are always minted under the authenticated request owner's identity; the auth model is selected by the server and is never settable by the caller.
in: query
name: short_link
schema:
enum:
- ephemeral_tool_chain
- default
type: string
x-runtime:
- cloud
responses:
"200":
content:
@ -2954,6 +3099,64 @@ paths:
summary: Cancel a job
tags:
- workflow
/api/jobs/cancel:
post:
description: |
Cancel one or more jobs for the authenticated user in a single request.
State-agnostic: cancels both pending and running jobs (both transition to
the cancelled state via the same mechanism as the single-job endpoint).
Idempotent per job: a job already in a terminal or cancelling state is a
no-op and simply will not appear in the returned `cancelled` list.
Fail-fast on unknown IDs: if any provided job ID does not exist for this
user, the request returns 404 and no jobs are cancelled. This surfaces
bad IDs to the caller rather than silently dropping them.
This is the canonical batch-cancel endpoint. The delete operation on
POST /api/queue is deprecated in favour of this.
operationId: cancelJobs
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/JobsCancelRequest'
required: true
responses:
"200":
content:
application/json:
schema:
$ref: '#/components/schemas/JobsCancelResponse'
description: Success - cancel requests dispatched (or jobs were already terminal)
"400":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Bad Request - job_ids is missing, empty, exceeds the maximum count, or contains an invalid UUID
"401":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Unauthorized - Authentication required
"404":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: One or more job IDs not found for this user (no jobs cancelled)
"500":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Internal server error - cancellation failed
summary: Cancel multiple jobs
tags:
- workflow
/api/node_replacements:
get:
description: |
@ -3104,9 +3307,18 @@ paths:
tags:
- queue
post:
deprecated: true
description: |
Cancel specific PENDING jobs by ID or clear all pending jobs in the queue.
Note: This endpoint only affects pending jobs. To cancel running jobs, use /api/interrupt.
Deprecated. Prefer the jobs-namespace cancel endpoints:
POST /api/jobs/cancel for cancelling jobs by ID, and
POST /api/jobs/{job_id}/cancel for a single job.
Cancel specific jobs by ID (the `delete` field) or clear all pending
jobs in the queue (the `clear` field). Despite the `delete` naming, this
does not delete anything — listed jobs transition to the cancelled state,
and `delete` cancels both pending and running jobs (not pending-only as
previously documented). Job-by-ID cancellation is superseded by
POST /api/jobs/cancel; `clear` has no jobs-namespace replacement yet.
operationId: manageQueue
requestBody:
content:

View File

@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.25.0"
version = "0.27.0"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"

View File

@ -1,6 +1,6 @@
comfyui-frontend-package==1.45.15
comfyui-workflow-templates==0.10.0
comfyui-embedded-docs==0.5.4
comfyui-frontend-package==1.45.20
comfyui-workflow-templates==0.11.1
comfyui-embedded-docs==0.5.6
torch
torchsde
torchvision
@ -22,7 +22,7 @@ alembic
SQLAlchemy>=2.0.0
filelock
av>=16.0.0
comfy-kitchen==0.2.10
comfy-kitchen==0.2.16
comfy-aimdo==0.4.10
requests
simpleeval>=1.0.0
@ -33,5 +33,5 @@ kornia>=0.7.1
spandrel
pydantic~=2.0
pydantic-settings~=2.0
PyOpenGL
glfw
PyOpenGL>=3.1.8
comfy-angle

111
server.py
View File

@ -8,7 +8,15 @@ import time
import nodes
import folder_paths
import execution
from comfy_execution.jobs import JobStatus, get_job, get_all_jobs, validate_job_id
from comfy_execution.jobs import (
JobStatus,
get_job,
get_all_jobs,
validate_job_id,
cancel_job,
CANCEL_PENDING,
CANCEL_RUNNING,
)
import uuid
import urllib
import json
@ -899,6 +907,107 @@ class PromptServer():
return web.json_response(job)
def _cancel_job_by_id(job_id):
"""Cancel a single job by id using the queue's existing mechanics.
Running jobs are interrupted (same mechanism as /interrupt); pending
jobs are dequeued (same mechanism as /queue {"delete": [...]}).
Already-finished or unknown ids are no-ops. State-agnostic.
Returns True when a cancel was actually dispatched (running or
pending job), False when the call was a no-op (terminal/unknown id).
"""
running, queued = self.prompt_queue.get_current_queue()
history = self.prompt_queue.get_history()
def interrupt(prompt_id):
logging.info(f"Cancelling running prompt {prompt_id}")
# Atomic: only interrupts if the job is still the one running,
# so a cancel can't land on a prompt that started in the gap
# since the snapshot above. Returns whether it actually fired.
return self.prompt_queue.interrupt_if_running(prompt_id)
def dequeue(prompt_id):
logging.info(f"Cancelling pending prompt {prompt_id}")
return self.prompt_queue.delete_queue_item(lambda a: a[1] == prompt_id)
classification = cancel_job(job_id, running, queued, history, interrupt, dequeue)
return classification in (CANCEL_RUNNING, CANCEL_PENDING)
@routes.post("/api/jobs/{job_id}/cancel")
async def cancel_job_by_id(request):
"""Cancel a single job by id, regardless of state.
Idempotent: cancelling a job that has already finished, or an id
that is not known, returns 200 with {"cancelled": false} rather
than an error.
"""
job_id = request.match_info.get("job_id", None)
if not job_id:
return web.json_response(
{"error": "job_id is required"},
status=400
)
cancelled = _cancel_job_by_id(job_id)
return web.json_response({"cancelled": cancelled})
@routes.post("/api/jobs/cancel")
async def cancel_jobs_batch(request):
"""Cancel a batch of jobs by id.
Body: {"job_ids": ["<uuid>", ...]}
Best-effort and idempotent: every well-formed id is cancelled if it
is running or pending; ids that are already finished or unknown are
no-ops, not errors. A batch of all no-ops still returns 200 with
{"cancelled": false}. This matches the single-cancel endpoint and
means "cancel all" still cancels the in-progress jobs even if some
finished between the client's snapshot and the request. Malformed
ids are still rejected up front with 400 (see below).
"""
try:
json_data = await request.json()
except json.JSONDecodeError:
return web.json_response(
{"error": "Request body must be valid JSON"},
status=400
)
job_ids = json_data.get("job_ids") if isinstance(json_data, dict) else None
if not isinstance(job_ids, list):
return web.json_response(
{"error": "job_ids must be a list"},
status=400
)
# Validate that every element is a well-formed job id before doing
# anything else. An unhashable element (e.g. a nested dict or list)
# would cause a TypeError when used as a history dict key; a
# non-string or non-UUID value is never a valid id. Reject early
# with 400 rather than letting the classify loop raise 500.
invalid_ids = []
for jid in job_ids:
try:
validate_job_id(jid)
except (ValueError, AttributeError):
invalid_ids.append(jid if isinstance(jid, str) else repr(jid))
if invalid_ids:
return web.json_response(
{"error": "job_ids contains invalid id(s)", "invalid_ids": invalid_ids},
status=400,
)
# Best-effort: cancel each id that is still running/pending; an id
# that has finished or never existed is a no-op rather than a reason
# to fail the whole batch.
cancelled = False
for jid in job_ids:
if _cancel_job_by_id(jid):
cancelled = True
return web.json_response({"cancelled": cancelled})
@routes.get("/history")
async def get_history(request):
max_items = request.rel_url.query.get("max_items", None)

View File

@ -228,6 +228,62 @@ class TestMixedPrecisionOps(unittest.TestCase):
with self.assertRaises(KeyError):
model.load_state_dict(state_dict, strict=False)
def test_int8_convrot_metadata_loads_into_params(self):
"""ConvRot metadata must reach TensorWiseINT8Layout params."""
torch.manual_seed(123)
layer_quant_config = {
"layer": {
"format": "int8_tensorwise",
"convrot": True,
"convrot_groupsize": 256,
}
}
weight = torch.randn(16, 256, dtype=torch.bfloat16)
bias = torch.randn(16, dtype=torch.bfloat16)
q_weight = QuantizedTensor.from_float(
weight,
"TensorWiseINT8Layout",
per_channel=True,
convrot=True,
convrot_groupsize=256,
)
state_dict = {
"layer.weight": q_weight._qdata,
"layer.bias": bias,
"layer.weight_scale": q_weight._params.scale,
}
state_dict, _ = comfy.utils.convert_old_quants(
state_dict,
metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})},
)
model = torch.nn.Module()
model.layer = ops.mixed_precision_ops({}).Linear(256, 16, device="cpu", dtype=torch.bfloat16)
model.load_state_dict(state_dict, strict=False)
self.assertIsInstance(model.layer.weight, QuantizedTensor)
self.assertEqual(model.layer.weight._layout_cls, "TensorWiseINT8Layout")
self.assertTrue(model.layer.weight._params.convrot)
self.assertEqual(model.layer.weight._params.convrot_groupsize, 256)
input_tensor = torch.randn(4, 256, dtype=torch.bfloat16)
loaded_out = model.layer(input_tensor)
ref_out = torch.nn.functional.linear(input_tensor, q_weight, bias)
self.assertTrue(torch.equal(loaded_out, ref_out))
fp16_input = input_tensor.to(torch.float16)
loaded_fp16_out = model.layer(fp16_input)
ref_fp16_out = torch.nn.functional.linear(
fp16_input,
q_weight.to(dtype=torch.float16),
bias.to(dtype=torch.float16),
)
self.assertTrue(torch.equal(loaded_fp16_out, ref_fp16_out))
saved = model.state_dict()
saved_conf = json.loads(saved["layer.comfy_quant"].numpy().tobytes())
self.assertTrue(saved_conf["convrot"])
self.assertEqual(saved_conf["convrot_groupsize"], 256)
if __name__ == "__main__":
unittest.main()

View File

View File

@ -0,0 +1,453 @@
"""Tests for the jobs-namespace cancel endpoints.
Covers both layers:
* the pure cancel helpers in ``comfy_execution.jobs``
(``classify_job_for_cancel`` / ``cancel_job``), which hold the business
logic of mapping a cancel onto interrupt-vs-dequeue, and
* the HTTP contract of ``POST /api/jobs/{job_id}/cancel`` and
``POST /api/jobs/cancel`` (status codes, single-cancel idempotency, and
best-effort batch cancellation that treats unknown/finished ids as no-ops
while still rejecting malformed ids with 400).
The HTTP layer is exercised against a small aiohttp app whose handlers are a
faithful copy of the wiring in ``server.py`` driven by a fake queue that
mirrors ``execution.PromptQueue`` (``get_current_queue`` / ``get_history`` /
``delete_queue_item``). This keeps the test free of the heavy ComfyUI runtime
(torch, nodes, ...) while still testing the real cancel logic.
"""
import json
import pytest
from aiohttp import web
from comfy_execution.jobs import (
CANCEL_PENDING,
CANCEL_RUNNING,
CANCEL_TERMINAL,
CANCEL_UNKNOWN,
cancel_job,
classify_job_for_cancel,
validate_job_id,
)
# Classifications for which a cancel was actually dispatched (vs a no-op).
_CANCELLED = (CANCEL_RUNNING, CANCEL_PENDING)
# Canonical UUID ids for HTTP-layer tests (the batch endpoint validates UUID format).
_UUID_A = "aaaaaaaa-aaaa-4aaa-aaaa-aaaaaaaaaaaa"
_UUID_B = "bbbbbbbb-bbbb-4bbb-bbbb-bbbbbbbbbbbb"
_UUID_C = "cccccccc-cccc-4ccc-cccc-cccccccccccc"
_UUID_D = "dddddddd-dddd-4ddd-dddd-dddddddddddd"
_UUID_MISSING = "ffffffff-ffff-4fff-ffff-ffffffffffff"
def make_queue_item(prompt_id, number=0):
"""Build a queue tuple shaped like the real ones: index 1 is the id."""
return (number, prompt_id, {}, {}, [])
class FakePromptQueue:
"""Minimal stand-in for execution.PromptQueue for the cancel paths.
Tracks interrupts and dequeues so tests can assert side effects.
"""
def __init__(self, running=None, pending=None, history=None):
self._running = list(running or [])
self._pending = list(pending or [])
self._history = dict(history or {})
self.interrupt_count = 0
def get_current_queue(self):
return (list(self._running), list(self._pending))
def get_history(self, prompt_id=None):
if prompt_id is None:
return dict(self._history)
if prompt_id in self._history:
return {prompt_id: self._history[prompt_id]}
return {}
def delete_queue_item(self, function):
for i, item in enumerate(self._pending):
if function(item):
self._pending.pop(i)
return True
return False
def interrupt_if_running(self, prompt_id):
# Mirrors execution.PromptQueue.interrupt_if_running: only signals an
# interrupt when the id is actually in the running set.
if any(item[1] == prompt_id for item in self._running):
self.interrupt_count += 1
return True
return False
def build_app(queue):
"""Build an aiohttp app exposing the cancel routes against ``queue``.
Handler bodies mirror server.py exactly.
"""
def _cancel_job_by_id(job_id):
running, pending = queue.get_current_queue()
history = queue.get_history()
def interrupt(prompt_id):
return queue.interrupt_if_running(prompt_id)
def dequeue(prompt_id):
return queue.delete_queue_item(lambda a: a[1] == prompt_id)
classification = cancel_job(
job_id, running, pending, history, interrupt, dequeue
)
return classification in _CANCELLED
async def cancel_job_by_id(request):
job_id = request.match_info.get("job_id", None)
if not job_id:
return web.json_response({"error": "job_id is required"}, status=400)
cancelled = _cancel_job_by_id(job_id)
return web.json_response({"cancelled": cancelled})
async def cancel_jobs_batch(request):
try:
json_data = await request.json()
except json.JSONDecodeError:
return web.json_response(
{"error": "Request body must be valid JSON"}, status=400
)
job_ids = json_data.get("job_ids") if isinstance(json_data, dict) else None
if not isinstance(job_ids, list):
return web.json_response({"error": "job_ids must be a list"}, status=400)
invalid_ids = []
for jid in job_ids:
try:
validate_job_id(jid)
except (ValueError, AttributeError):
invalid_ids.append(jid if isinstance(jid, str) else repr(jid))
if invalid_ids:
return web.json_response(
{"error": "job_ids contains invalid id(s)", "invalid_ids": invalid_ids},
status=400,
)
cancelled = False
for jid in job_ids:
if _cancel_job_by_id(jid):
cancelled = True
return web.json_response({"cancelled": cancelled})
app = web.Application()
app.router.add_post("/api/jobs/{job_id}/cancel", cancel_job_by_id)
app.router.add_post("/api/jobs/cancel", cancel_jobs_batch)
return app
# ---------------------------------------------------------------------------
# Pure helper tests: classification + cancel side effects
# ---------------------------------------------------------------------------
class TestClassifyJobForCancel:
def test_running(self):
running = [make_queue_item("a")]
assert classify_job_for_cancel("a", running, [], {}) == CANCEL_RUNNING
def test_pending(self):
pending = [make_queue_item("b")]
assert classify_job_for_cancel("b", [], pending, {}) == CANCEL_PENDING
def test_terminal(self):
history = {"c": {"prompt": make_queue_item("c"), "outputs": {}, "status": {}}}
assert classify_job_for_cancel("c", [], [], history) == CANCEL_TERMINAL
def test_unknown(self):
assert classify_job_for_cancel("z", [], [], {}) == CANCEL_UNKNOWN
class TestCancelJobHelper:
"""``interrupt`` and ``dequeue`` both take the id and return whether they
actually acted, so cancel_job's return reflects the real outcome."""
def test_running_is_interrupted_not_dequeued(self):
interrupts = []
dequeues = []
result = cancel_job(
"a", [make_queue_item("a")], [], {},
interrupt=lambda pid: interrupts.append(pid) or True,
dequeue=lambda pid: dequeues.append(pid) or True,
)
assert result == CANCEL_RUNNING
assert interrupts == ["a"]
assert dequeues == []
def test_pending_is_dequeued_not_interrupted(self):
interrupts = []
dequeues = []
result = cancel_job(
"b", [], [make_queue_item("b")], {},
interrupt=lambda pid: interrupts.append(pid) or True,
dequeue=lambda pid: dequeues.append(pid) or True,
)
assert result == CANCEL_PENDING
assert dequeues == ["b"]
assert interrupts == []
def test_terminal_is_noop(self):
history = {"c": {"prompt": make_queue_item("c"), "outputs": {}, "status": {}}}
interrupts = []
dequeues = []
result = cancel_job(
"c", [], [], history,
interrupt=lambda pid: interrupts.append(pid) or True,
dequeue=lambda pid: dequeues.append(pid) or True,
)
assert result == CANCEL_TERMINAL
assert interrupts == []
assert dequeues == []
def test_unknown_is_noop(self):
interrupts = []
dequeues = []
result = cancel_job(
"z", [], [], {},
interrupt=lambda pid: interrupts.append(pid) or True,
dequeue=lambda pid: dequeues.append(pid) or True,
)
assert result == CANCEL_UNKNOWN
assert interrupts == []
assert dequeues == []
def test_running_but_finished_before_interrupt_returns_unknown(self):
"""Classified RUNNING from a stale snapshot, but the job finished before
the atomic interrupt fired (interrupt returns False). cancel_job reports
UNKNOWN rather than claiming a cancel that did not happen and the
atomic interrupt guarantees no unrelated job was hit."""
interrupts = []
result = cancel_job(
"a", [make_queue_item("a")], [], {},
interrupt=lambda pid: interrupts.append(pid) or False,
dequeue=lambda pid: True,
)
assert result == CANCEL_UNKNOWN
assert interrupts == ["a"] # interrupt was attempted atomically
def test_pending_started_running_is_interrupted(self):
"""Pending->running race: the job leaves the queue (dequeue False)
because it started executing. The atomic interrupt catches the now-
running job, so cancel_job interrupts it and reports CANCEL_RUNNING."""
interrupts = []
dequeues = []
result = cancel_job(
"b", [], [make_queue_item("b")], {},
interrupt=lambda pid: interrupts.append(pid) or True,
dequeue=lambda pid: (dequeues.append(pid), False)[1],
)
assert result == CANCEL_RUNNING
assert dequeues == ["b"] # dequeue attempted first
assert interrupts == ["b"] # then the now-running job was interrupted
def test_pending_dequeue_miss_not_running_returns_unknown(self):
"""Dequeue miss where the job is not running anymore (it finished): the
atomic interrupt finds nothing to interrupt and returns False, so
cancel_job is a no-op reporting UNKNOWN never reporting a cancel that
did not happen, and never interrupting a bystander."""
interrupts = []
dequeues = []
result = cancel_job(
"b", [], [make_queue_item("b")], {},
interrupt=lambda pid: interrupts.append(pid) or False,
dequeue=lambda pid: (dequeues.append(pid), False)[1],
)
assert result == CANCEL_UNKNOWN
assert dequeues == ["b"]
assert interrupts == ["b"] # interrupt attempted, found nothing running
# ---------------------------------------------------------------------------
# HTTP contract tests: POST /api/jobs/{job_id}/cancel
# ---------------------------------------------------------------------------
class TestSingleCancelEndpoint:
@pytest.mark.asyncio
async def test_cancel_running_job_interrupts(self, aiohttp_client):
queue = FakePromptQueue(running=[make_queue_item("a")])
client = await aiohttp_client(build_app(queue))
resp = await client.post("/api/jobs/a/cancel")
assert resp.status == 200
assert (await resp.json()) == {"cancelled": True}
assert queue.interrupt_count == 1
@pytest.mark.asyncio
async def test_cancel_pending_job_dequeues(self, aiohttp_client):
queue = FakePromptQueue(pending=[make_queue_item("b")])
client = await aiohttp_client(build_app(queue))
resp = await client.post("/api/jobs/b/cancel")
assert resp.status == 200
assert (await resp.json()) == {"cancelled": True}
# Pending job removed from the queue; nothing interrupted.
assert queue.get_current_queue()[1] == []
assert queue.interrupt_count == 0
@pytest.mark.asyncio
async def test_cancel_terminal_job_is_idempotent_noop(self, aiohttp_client):
history = {"c": {"prompt": make_queue_item("c"), "outputs": {}, "status": {}}}
queue = FakePromptQueue(history=history)
client = await aiohttp_client(build_app(queue))
resp = await client.post("/api/jobs/c/cancel")
# Already-finished job: 200 no-op (cancelled=false), not an error.
assert resp.status == 200
assert (await resp.json()) == {"cancelled": False}
assert queue.interrupt_count == 0
@pytest.mark.asyncio
async def test_cancel_unknown_id_is_200_noop(self, aiohttp_client):
queue = FakePromptQueue()
client = await aiohttp_client(build_app(queue))
resp = await client.post("/api/jobs/does-not-exist/cancel")
# Single-cancel of an unknown id is treated as an idempotent no-op.
assert resp.status == 200
assert (await resp.json()) == {"cancelled": False}
assert queue.interrupt_count == 0
@pytest.mark.asyncio
async def test_cancel_pending_that_started_running_interrupts(self, aiohttp_client):
"""Pending->running race end to end: the job is pending at snapshot time
but starts executing by the time we dequeue (delete misses). The live
re-check sees it running and interrupts it, so the cancel is not dropped
and the caller still gets cancelled=True."""
class RacingQueue(FakePromptQueue):
def delete_queue_item(self, function):
# The worker picked the job up just before we removed it: it
# leaves the pending queue (delete misses) and is now running.
self._running = list(self._pending)
self._pending = []
return False
queue = RacingQueue(pending=[make_queue_item("b")])
client = await aiohttp_client(build_app(queue))
resp = await client.post("/api/jobs/b/cancel")
assert resp.status == 200
assert (await resp.json()) == {"cancelled": True}
assert queue.interrupt_count == 1
# ---------------------------------------------------------------------------
# HTTP contract tests: POST /api/jobs/cancel (batch)
# ---------------------------------------------------------------------------
class TestBatchCancelEndpoint:
@pytest.mark.asyncio
async def test_batch_happy_path(self, aiohttp_client):
queue = FakePromptQueue(
running=[make_queue_item(_UUID_A)],
pending=[make_queue_item(_UUID_B, number=1)],
)
client = await aiohttp_client(build_app(queue))
resp = await client.post("/api/jobs/cancel", json={"job_ids": [_UUID_A, _UUID_B]})
assert resp.status == 200
assert (await resp.json()) == {"cancelled": True}
assert queue.interrupt_count == 1 # running job interrupted
assert queue.get_current_queue()[1] == [] # pending job dequeued
@pytest.mark.asyncio
async def test_batch_best_effort_skips_unknown_id(self, aiohttp_client):
"""An unknown id in the batch is a no-op, not a reason to abort: the
running and pending jobs are still cancelled (200, cancelled=true). This
is the "cancel all as a job finishes" case from review."""
queue = FakePromptQueue(
running=[make_queue_item(_UUID_A)],
pending=[make_queue_item(_UUID_B, number=1)],
)
client = await aiohttp_client(build_app(queue))
resp = await client.post(
"/api/jobs/cancel", json={"job_ids": [_UUID_A, _UUID_MISSING, _UUID_B]}
)
assert resp.status == 200
assert (await resp.json()) == {"cancelled": True}
assert queue.interrupt_count == 1 # running job interrupted
assert queue.get_current_queue()[1] == [] # pending job dequeued
@pytest.mark.asyncio
async def test_batch_all_terminal_is_idempotent_noop(self, aiohttp_client):
history = {
_UUID_C: {"prompt": make_queue_item(_UUID_C), "outputs": {}, "status": {}},
_UUID_D: {"prompt": make_queue_item(_UUID_D), "outputs": {}, "status": {}},
}
queue = FakePromptQueue(history=history)
client = await aiohttp_client(build_app(queue))
resp = await client.post("/api/jobs/cancel", json={"job_ids": [_UUID_C, _UUID_D]})
# All known but terminal: 200 with cancelled=false, nothing dispatched.
assert resp.status == 200
assert (await resp.json()) == {"cancelled": False}
assert queue.interrupt_count == 0
@pytest.mark.asyncio
async def test_batch_missing_job_ids_is_400(self, aiohttp_client):
queue = FakePromptQueue()
client = await aiohttp_client(build_app(queue))
resp = await client.post("/api/jobs/cancel", json={})
assert resp.status == 400
@pytest.mark.asyncio
async def test_batch_unhashable_element_is_400_not_500(self, aiohttp_client):
"""An unhashable element such as a dict or list must yield 400, not 500.
Previously, passing e.g. {"job_ids": [{}]} would reach the classify
loop where ``prompt_id in history`` raises TypeError on an unhashable
type, resulting in an unhandled 500. The input-validation guard must
catch this before any queue or history access.
"""
queue = FakePromptQueue()
client = await aiohttp_client(build_app(queue))
resp = await client.post("/api/jobs/cancel", json={"job_ids": [{}]})
assert resp.status == 400
body = await resp.json()
assert "invalid_ids" in body
# No queue side effects.
assert queue.interrupt_count == 0
@pytest.mark.asyncio
async def test_batch_non_uuid_string_element_is_400(self, aiohttp_client):
"""A string that is not a valid UUID must be rejected with 400."""
queue = FakePromptQueue()
client = await aiohttp_client(build_app(queue))
resp = await client.post(
"/api/jobs/cancel", json={"job_ids": ["not-a-uuid"]}
)
assert resp.status == 400
body = await resp.json()
assert "invalid_ids" in body