Merge branch 'master' into alexis/add_output_save_nodes
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled

This commit is contained in:
Alexis Rolland 2026-06-20 12:54:10 +08:00 committed by GitHub
commit ef4ee7b1d6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
37 changed files with 3012 additions and 237 deletions

View File

@ -140,7 +140,7 @@ ComfyUI follows a weekly release cycle targeting Monday but this regularly chang
- Commits outside of the stable release tags may be very unstable and break many custom nodes. - Commits outside of the stable release tags may be very unstable and break many custom nodes.
- Serves as the foundation for the desktop release - Serves as the foundation for the desktop release
2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)** 2. **[Comfy Desktop](https://github.com/Comfy-Org/Comfy-Desktop)**
- Builds a new release using the latest stable core version - Builds a new release using the latest stable core version
3. **[ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend)** 3. **[ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend)**

View File

@ -8,6 +8,8 @@ from abc import ABC, abstractmethod
import logging import logging
import comfy.model_management import comfy.model_management
import comfy.patcher_extension import comfy.patcher_extension
import comfy.utils
import comfy.conds
if TYPE_CHECKING: if TYPE_CHECKING:
from comfy.model_base import BaseModel from comfy.model_base import BaseModel
from comfy.model_patcher import ModelPatcher from comfy.model_patcher import ModelPatcher
@ -51,12 +53,18 @@ class ContextHandlerABC(ABC):
class IndexListContextWindow(ContextWindowABC): class IndexListContextWindow(ContextWindowABC):
def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0): def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0, modality_windows: dict=None, context_overlap: int=0):
self.index_list = index_list self.index_list = index_list
self.context_length = len(index_list) self.context_length = len(index_list)
self.context_overlap = context_overlap
self.dim = dim self.dim = dim
self.total_frames = total_frames self.total_frames = total_frames
self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames) self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
self.modality_windows = modality_windows # dict of {mod_idx: IndexListContextWindow}
self.guide_frames_indices: list[int] = []
self.guide_overlap_info: list[tuple[int, int]] = []
self.guide_kf_local_positions: list[int] = []
self.guide_downscale_factors: list[int] = []
def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor: def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
if dim is None: if dim is None:
@ -85,6 +93,11 @@ class IndexListContextWindow(ContextWindowABC):
region_idx = int(self.center_ratio * num_regions) region_idx = int(self.center_ratio * num_regions)
return min(max(region_idx, 0), num_regions - 1) return min(max(region_idx, 0), num_regions - 1)
def get_window_for_modality(self, modality_idx: int) -> 'IndexListContextWindow':
if modality_idx == 0:
return self
return self.modality_windows[modality_idx]
class IndexListCallbacks: class IndexListCallbacks:
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows" EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
@ -148,6 +161,172 @@ def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, d
return cond_value._copy_with(sliced) return cond_value._copy_with(sliced)
def compute_guide_overlap(guide_entries: list[dict], keyframe_idxs: torch.Tensor, temporal_downscale_ratio: int, window_index_list: list[int]):
    """Compute which concatenated guide frames overlap with a context window.

    Each guide's latent-space start is derived from its first token's pixel-t-start
    in keyframe_idxs (shape (B, [t,h,w], num_tokens, [start, end])), divided by the
    model's temporal_downscale_ratio (rounded up).

    Args:
        guide_entries: list of guide_attention_entry dicts; each entry is read for
            "latent_shape" (its first element is the guide's frame count) and
            "pre_filter_count" (how many tokens to skip to reach the next guide).
        keyframe_idxs: per-token pixel coords cond tensor for the modality
        temporal_downscale_ratio: model's pixel-to-latent temporal compression ratio
        window_index_list: the window's frame indices into the video portion

    Returns:
        suffix_indices: indices into the guide_frames tensor for frame selection
        overlap_info: list of (entry_idx, overlap_count) for guide_attention_entries adjustment
        kf_local_positions: window-local frame positions for keyframe_idxs regeneration
        total_overlap: total number of overlapping guide frames
    """
    # Map each video frame index to its first position inside the window.
    # setdefault keeps the first occurrence, matching list.index(), while making
    # every lookup O(1) instead of a linear scan per guide frame.
    window_position = {}
    for local_idx, video_pos in enumerate(window_index_list):
        window_position.setdefault(video_pos, local_idx)

    suffix_indices = []
    overlap_info = []
    kf_local_positions = []
    suffix_base = 0   # offset of the current guide inside the concatenated guide frames
    token_offset = 0  # offset of the current guide's first token in keyframe_idxs
    for entry_idx, entry in enumerate(guide_entries):
        first_t_pixel = int(keyframe_idxs[0, 0, token_offset, 0].item())
        # ceil-divide pixel time by the temporal compression ratio
        latent_start = (first_t_pixel + temporal_downscale_ratio - 1) // temporal_downscale_ratio
        guide_len = entry["latent_shape"][0]
        entry_overlap = 0
        for local_offset in range(guide_len):
            local_pos = window_position.get(latent_start + local_offset)
            if local_pos is None:
                continue
            suffix_indices.append(suffix_base + local_offset)
            kf_local_positions.append(local_pos)
            entry_overlap += 1
        if entry_overlap > 0:
            overlap_info.append((entry_idx, entry_overlap))
        suffix_base += guide_len
        token_offset += entry["pre_filter_count"]
    return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices)
@dataclass
class WindowingState:
    """Per-modality context windowing state for each step,
    built using IndexListContextHandler._build_window_state().

    For non-multimodal models the lists are length 1
    """
    latents: list[torch.Tensor]  # per-modality working latents (guide frames stripped)
    guide_latents: list[torch.Tensor | None]  # per-modality guide frames stripped from latents
    guide_entries: list[list[dict] | None]  # per-modality guide_attention_entry metadata
    keyframe_idxs: list[torch.Tensor | None]  # per-modality keyframe_idxs tensor for guide latent_start derivation
    latent_shapes: list | None  # original packed shapes for unpack/pack (None if not multimodal)
    dim: int = 0  # primary modality temporal dim for context windowing
    is_multimodal: bool = False
    temporal_downscale_ratio: int = 1  # model's pixel-to-latent temporal compression ratio

    def prepare_window(self, window: IndexListContextWindow, model) -> IndexListContextWindow:
        """Reformat window for multimodal contexts by deriving per-modality index lists.

        Non-multimodal contexts return the input window unchanged.

        Raises:
            NotImplementedError: if the model has no map_context_window_to_modalities.
        """
        if not self.is_multimodal:
            return window

        # Look the hook up explicitly rather than wrapping the call in
        # `except AttributeError`: that would also swallow AttributeErrors raised
        # inside the model's own implementation and misreport them as
        # "not implemented".
        mapper = getattr(model, "map_context_window_to_modalities", None)
        if mapper is None:
            raise NotImplementedError(
                f"{type(model).__name__} must implement map_context_window_to_modalities for multimodal context windows.")

        x = self.latents[0]
        primary_total = self.latent_shapes[0][self.dim]
        primary_overlap = window.context_overlap

        # If the working primary latent differs in length from the recorded shape,
        # hand the model a shape list that reflects the working length.
        map_shapes = self.latent_shapes
        if x.size(self.dim) != primary_total:
            map_shapes = list(self.latent_shapes)
            video_shape = list(self.latent_shapes[0])
            video_shape[self.dim] = x.size(self.dim)
            map_shapes[0] = torch.Size(video_shape)

        per_modality_indices = mapper(window.index_list, map_shapes, self.dim)

        modality_windows = {}
        for mod_idx in range(1, len(self.latents)):
            modality_total_frames = self.latents[mod_idx].shape[self.dim]
            # scale the primary overlap by the frame-count ratio of this modality
            ratio = modality_total_frames / primary_total if primary_total > 0 else 1
            modality_overlap = max(round(primary_overlap * ratio), 0)
            modality_windows[mod_idx] = IndexListContextWindow(
                per_modality_indices[mod_idx], dim=self.dim,
                total_frames=modality_total_frames,
                context_overlap=modality_overlap)
        return IndexListContextWindow(
            window.index_list, dim=self.dim, total_frames=x.shape[self.dim],
            modality_windows=modality_windows, context_overlap=primary_overlap)

    def slice_for_window(self, window: IndexListContextWindow, retain_index_list: list[int], device=None) -> tuple[list[torch.Tensor], list[int]]:
        """Slice latents for a context window, injecting guide frames where applicable.

        For multimodal contexts, uses the modality-specific windows derived in prepare_window().

        Returns:
            (sliced, guide_frame_counts): one sliced tensor and one injected
            guide-frame count per modality.
        """
        sliced = []
        guide_frame_counts = []
        for idx, modality_latent in enumerate(self.latents):
            modality_window = window.get_window_for_modality(idx)
            # retained indices are only passed for the primary modality
            retain = retain_index_list if idx == 0 else []
            s = modality_window.get_tensor(modality_latent, device, retain_index_list=retain)
            ng = 0
            if self.guide_entries[idx] is not None:
                s, ng = self._inject_guide_frames(s, modality_window, modality_idx=idx)
            sliced.append(s)
            guide_frame_counts.append(ng)
        return sliced, guide_frame_counts

    def strip_guide_frames(self, out_per_modality: list[list[torch.Tensor]], guide_frame_counts: list[int], window: IndexListContextWindow):
        """Strip injected guide frames from per-cond, per-modality outputs in place."""
        for idx, _ in enumerate(self.latents):
            if guide_frame_counts[idx] <= 0:
                continue
            # keep only the leading window-length frames; anything beyond is dropped
            window_len = len(window.get_window_for_modality(idx).index_list)
            for cond_outputs in out_per_modality:
                cond_outputs[idx] = cond_outputs[idx].narrow(self.dim, 0, window_len)

    def _inject_guide_frames(self, latent_slice: torch.Tensor, window: IndexListContextWindow, modality_idx: int = 0) -> tuple[torch.Tensor, int]:
        """Append the guide frames overlapping `window` to `latent_slice`.

        Also records the overlap bookkeeping on `window` (guide_frames_indices,
        guide_overlap_info, guide_kf_local_positions, guide_downscale_factors).

        Returns:
            (tensor, count): the possibly extended slice and the number of guide
            frames appended (0 when nothing overlaps).
        """
        guide_entries = self.guide_entries[modality_idx]
        guide_frames = self.guide_latents[modality_idx]
        keyframe_idxs = self.keyframe_idxs[modality_idx]
        suffix_idx, overlap_info, kf_local_pos, guide_frame_count = compute_guide_overlap(
            guide_entries, keyframe_idxs, self.temporal_downscale_ratio, window.index_list)
        # Shift keyframe positions to account for causal_window_fix anchor occupying sub-pos 0.
        anchor_idx = getattr(window, 'causal_anchor_index', None)
        if anchor_idx is not None and anchor_idx >= 0:
            kf_local_pos = [p + 1 for p in kf_local_pos]
        window.guide_frames_indices = suffix_idx
        window.guide_overlap_info = overlap_info
        window.guide_kf_local_positions = kf_local_pos

        if guide_frame_count <= 0:
            window.guide_downscale_factors = []
            return latent_slice, 0

        # Derive per-overlap-entry latent_downscale_factor from guide entry latent_shape vs guide frame spatial dims.
        # guide_frames has full (post-dilation) spatial dims; entry["latent_shape"] has pre-dilation dims.
        full_H = guide_frames.shape[3]
        window.guide_downscale_factors = [
            full_H // guide_entries[entry_idx]["latent_shape"][1]
            for entry_idx, _ in overlap_info
        ]
        # index the selected guide frames along self.dim, leaving leading dims untouched
        idx = tuple([slice(None)] * self.dim + [suffix_idx])
        return torch.cat([latent_slice, guide_frames[idx]], dim=self.dim), guide_frame_count

    def patch_latent_shapes(self, sub_conds, new_shapes):
        """Overwrite every 'latent_shapes' model cond in `sub_conds` with `new_shapes`.

        No-op for non-multimodal state. Only conds that already carry a
        'latent_shapes' entry are touched.
        """
        if not self.is_multimodal:
            return

        for cond_list in sub_conds:
            if cond_list is None:
                continue
            for cond_dict in cond_list:
                model_conds = cond_dict.get('model_conds', {})
                if 'latent_shapes' in model_conds:
                    model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes)
@dataclass @dataclass
class ContextSchedule: class ContextSchedule:
name: str name: str
@ -162,7 +341,7 @@ ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_co
class IndexListContextHandler(ContextHandlerABC): class IndexListContextHandler(ContextHandlerABC):
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False,
causal_window_fix: bool=True): latent_retain_index_list: list[int]=[], causal_window_fix: bool=True):
self.context_schedule = context_schedule self.context_schedule = context_schedule
self.fuse_method = fuse_method self.fuse_method = fuse_method
self.context_length = context_length self.context_length = context_length
@ -174,17 +353,118 @@ class IndexListContextHandler(ContextHandlerABC):
self.freenoise = freenoise self.freenoise = freenoise
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else [] self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
self.split_conds_to_windows = split_conds_to_windows self.split_conds_to_windows = split_conds_to_windows
self.latent_retain_index_list = [int(x.strip()) for x in latent_retain_index_list.split(",")] if latent_retain_index_list else []
self.causal_window_fix = causal_window_fix self.causal_window_fix = causal_window_fix
self.callbacks = {} self.callbacks = {}
@staticmethod
def _get_latent_shapes(conds):
for cond_list in conds:
if cond_list is None:
continue
for cond_dict in cond_list:
model_conds = cond_dict.get('model_conds', {})
if 'latent_shapes' in model_conds:
return model_conds['latent_shapes'].cond
return None
@staticmethod
def _get_guide_entries(conds):
for cond_list in conds:
if cond_list is None:
continue
for cond_dict in cond_list:
model_conds = cond_dict.get('model_conds', {})
entries = model_conds.get('guide_attention_entries')
if entries is not None and hasattr(entries, 'cond') and entries.cond:
return entries.cond
return None
@staticmethod
def _get_keyframe_idxs(conds):
for cond_list in conds:
if cond_list is None:
continue
for cond_dict in cond_list:
model_conds = cond_dict.get('model_conds', {})
kf = model_conds.get('keyframe_idxs')
if kf is not None and hasattr(kf, 'cond') and kf.cond is not None:
return kf.cond
return None
def _apply_freenoise(self, noise: torch.Tensor, conds: list[list[dict]], seed: int) -> torch.Tensor:
    """Apply FreeNoise shuffling, scaling context length/overlap per-modality by frame ratio.

    If guide frames are present on the primary modality, only the video portion is shuffled.
    """
    guide_entries = self._get_guide_entries(conds)
    guide_count = 0
    if guide_entries:
        guide_count = sum(e["latent_shape"][0] for e in guide_entries)

    latent_shapes = self._get_latent_shapes(conds)
    packed_multimodal = latent_shapes is not None and len(latent_shapes) > 1

    if not packed_multimodal:
        # single latent: shuffle everything in front of the trailing guide frames
        video_count = noise.size(self.dim) - guide_count
        video_part = noise.narrow(self.dim, 0, video_count)
        apply_freenoise(video_part, self.dim, self.context_length, self.context_overlap, seed)
        return noise

    modalities = comfy.utils.unpack_latents(noise, latent_shapes)
    primary_total = latent_shapes[0][self.dim]
    primary_video_count = modalities[0].size(self.dim) - guide_count
    # primary modality: the return value is not used here, only the narrowed view is passed in
    primary_video = modalities[0].narrow(self.dim, 0, primary_video_count)
    apply_freenoise(primary_video, self.dim, self.context_length, self.context_overlap, seed)
    for i in range(1, len(modalities)):
        mod_total = latent_shapes[i][self.dim]
        ratio = mod_total / primary_total if primary_total > 0 else 1
        # scale window length/overlap to this modality, never below 1 / 0
        mod_ctx_len = max(round(self.context_length * ratio), 1)
        mod_ctx_overlap = max(round(self.context_overlap * ratio), 0)
        modalities[i] = apply_freenoise(modalities[i], self.dim, mod_ctx_len, mod_ctx_overlap, seed)
    noise, _ = comfy.utils.pack_latents(modalities)
    return noise
def _build_window_state(self, x_in: torch.Tensor, conds: list[list[dict]], model: BaseModel) -> WindowingState:
    """Build windowing state for the current step, including unpacking latents and extracting guide frame info from conds."""
    latent_shapes = self._get_latent_shapes(conds)
    is_multimodal = latent_shapes is not None and len(latent_shapes) > 1
    if is_multimodal:
        per_modality_latents = list(comfy.utils.unpack_latents(x_in, latent_shapes))
    else:
        per_modality_latents = [x_in]

    modality_count = len(per_modality_latents)
    per_modality_guides = [None] * modality_count
    per_modality_entries = [None] * modality_count
    per_modality_keyframes = [None] * modality_count

    found_entries = self._get_guide_entries(conds)
    found_keyframe_idxs = self._get_keyframe_idxs(conds)

    # Strip guide frames (only from first modality for now)
    guide_count = 0
    if found_entries is not None:
        guide_count = sum(e["latent_shape"][0] for e in found_entries)
    if guide_count > 0:
        primary = per_modality_latents[0]
        latent_count = primary.size(self.dim) - guide_count
        # leading part is the working latent, trailing part holds the guide frames
        per_modality_latents[0] = primary.narrow(self.dim, 0, latent_count)
        per_modality_guides[0] = primary.narrow(self.dim, latent_count, guide_count)
        per_modality_entries[0] = found_entries
        per_modality_keyframes[0] = found_keyframe_idxs

    return WindowingState(
        latents=per_modality_latents,
        guide_latents=per_modality_guides,
        guide_entries=per_modality_entries,
        keyframe_idxs=per_modality_keyframes,
        latent_shapes=latent_shapes,
        dim=self.dim,
        is_multimodal=is_multimodal,
        temporal_downscale_ratio=model.latent_format.temporal_downscale_ratio)
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool: def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation window_state = self._build_window_state(x_in, conds, model) # build window_state to check frame counts, will be built again in execute
if x_in.size(self.dim) > self.context_length: total_frame_count = window_state.latents[0].size(self.dim)
logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.") if total_frame_count > self.context_length:
logging.info(f"\nUsing context windows: Context length {self.context_length} with overlap {self.context_overlap} for {total_frame_count} frames.")
if self.cond_retain_index_list: if self.cond_retain_index_list:
logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}") logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
if self.latent_retain_index_list:
logging.info(f"Retaining original latent for indexes: {self.latent_retain_index_list}")
return True return True
logging.info(f"\nNot using context windows since context length ({self.context_length}) exceeds input frames ({total_frame_count}).")
return False return False
def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase: def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase:
@ -275,7 +555,9 @@ class IndexListContextHandler(ContextHandlerABC):
return resized_cond return resized_cond
def set_step(self, timestep: torch.Tensor, model_options: dict[str]): def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001) sample_sigmas = model_options["transformer_options"]["sample_sigmas"]
current_timestep = timestep[0].to(sample_sigmas.dtype)
mask = torch.isclose(sample_sigmas, current_timestep, rtol=0.0001)
matches = torch.nonzero(mask) matches = torch.nonzero(mask)
if torch.numel(matches) == 0: if torch.numel(matches) == 0:
return # substep from multi-step sampler: keep self._step from the last full step return # substep from multi-step sampler: keep self._step from the last full step
@ -284,54 +566,98 @@ class IndexListContextHandler(ContextHandlerABC):
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
    """Ask the context schedule for index lists and wrap each one in an IndexListContextWindow."""
    full_length = x_in.size(self.dim)  # TODO: choose dim based on model
    index_lists = self.context_schedule.func(full_length, self, model_options)
    return [
        IndexListContextWindow(
            window,
            dim=self.dim,
            total_frames=full_length,
            context_overlap=self.context_overlap,
        )
        for window in index_lists
    ]
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
    """Run the model over every context window and fuse the results.

    Returns a list with one tensor per entry of `conds`. Guide latents that
    were stripped when building the window state are appended back, and
    modalities are re-packed when the state is multimodal.
    """
    self._model = model
    self.set_step(timestep, model_options)

    window_state = self._build_window_state(x_in, conds, model)
    modality_latents = window_state.latents
    num_modalities = len(modality_latents)

    context_windows = self.get_context_windows(model, modality_latents[0], model_options)
    enumerated_context_windows = list(enumerate(context_windows))
    total_windows = len(enumerated_context_windows)

    # Initialize per-modality accumulators (length 1 for single-modality)
    accum = [[torch.zeros_like(m) for _ in conds] for m in modality_latents]
    # RELATIVE fusing starts its counts at one, every other method at zero
    make_counts = torch.ones if self.fuse_method.name == ContextFuseMethods.RELATIVE else torch.zeros
    counts = [[make_counts(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in modality_latents]
    biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in modality_latents]

    for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
        callback(self, model, x_in, conds, timestep, model_options)

    # accumulate results from each context window
    for enum_window in enumerated_context_windows:
        results = self.evaluate_context_windows(
            calc_cond_batch, model, x_in, conds, timestep, [enum_window],
            model_options, window_state=window_state, total_windows=total_windows)
        for result in results:
            # result.sub_conds_out is per-cond, per-modality: list[list[Tensor]]
            for mod_idx in range(num_modalities):
                mod_out = [result.sub_conds_out[ci][mod_idx] for ci in range(len(conds))]
                modality_window = result.window.get_window_for_modality(mod_idx)
                self.combine_context_window_results(
                    modality_latents[mod_idx], mod_out, result.sub_conds, modality_window,
                    result.window_idx, total_windows, timestep,
                    accum[mod_idx], counts[mod_idx], biases[mod_idx])

    # fuse accumulated results into final conds
    try:
        result_out = []
        for ci in range(len(conds)):
            finalized = []
            for mod_idx in range(num_modalities):
                fused = accum[mod_idx][ci]
                if self.fuse_method.name != ContextFuseMethods.RELATIVE:
                    # in-place normalization by the accumulated counts (skipped for RELATIVE)
                    fused /= counts[mod_idx][ci]
                # if guide frames were injected, append them to the end of the fused latents for the next step
                guide = window_state.guide_latents[mod_idx]
                if guide is not None:
                    fused = torch.cat([fused, guide], dim=self.dim)
                finalized.append(fused)

            # pack modalities together if needed
            if window_state.is_multimodal and len(finalized) > 1:
                packed, _ = comfy.utils.pack_latents(finalized)
            else:
                packed = finalized[0]
            result_out.append(packed)
        return result_out
    finally:
        # cleanup callbacks run whether or not fusing succeeded
        for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
            callback(self, model, x_in, conds, timestep, model_options)
def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]], def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds,
model_options, device=None, first_device=None): timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
model_options, window_state: WindowingState, total_windows: int = None,
device=None, first_device=None):
"""Evaluate context windows and return per-cond, per-modality outputs in ContextResults.sub_conds_out
For each window:
1. Builds windows (for each modality if multimodal)
2. Slices window for each modality
3. Injects concatenated latent guide frames where present
4. Packs together if needed and calls model
5. Unpacks and strips any guides from outputs
"""
x = window_state.latents[0]
results: list[ContextResults] = [] results: list[ContextResults] = []
for window_idx, window in enumerated_context_windows: for window_idx, window in enumerated_context_windows:
# allow processing to end between context window executions for faster Cancel # allow processing to end between context window executions for faster Cancel
comfy.model_management.throw_exception_if_processing_interrupted() comfy.model_management.throw_exception_if_processing_interrupted()
# causal_window_fix: prepend a pre-window frame that will be stripped post-forward # prepare the window accounting for multimodal windows
window = window_state.prepare_window(window, model)
# causal_window_fix: prepend a pre-window frame that will be stripped post-forward.
# Set anchor before slice_for_window so the latent slice and downstream cond slices both pick it up.
anchor_applied = False anchor_applied = False
if self.causal_window_fix: if self.causal_window_fix:
anchor_idx = window.index_list[0] - 1 anchor_idx = window.index_list[0] - 1
@ -339,27 +665,46 @@ class IndexListContextHandler(ContextHandlerABC):
window.causal_anchor_index = anchor_idx window.causal_anchor_index = anchor_idx
anchor_applied = True anchor_applied = True
# slice the window for each modality, injecting guide frames where applicable
sliced, guide_frame_counts_per_modality = window_state.slice_for_window(window, self.latent_retain_index_list, device)
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks): for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device) callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
# update exposed params logging.info(f"Context window {window_idx + 1}/{total_windows or len(enumerated_context_windows)}: frames {window.index_list[0]}-{window.index_list[-1]} of {x.shape[self.dim]}"
+ (f" (+{guide_frame_counts_per_modality[0]} guide frames)" if guide_frame_counts_per_modality[0] > 0 else "")
)
# if multimodal, pack modalities together
if window_state.is_multimodal and len(sliced) > 1:
sub_x, sub_shapes = comfy.utils.pack_latents(sliced)
else:
sub_x, sub_shapes = sliced[0], [sliced[0].shape]
# get resized conds for window
model_options["transformer_options"]["context_window"] = window model_options["transformer_options"]["context_window"] = window
# get subsections of x, timestep, conds sub_timestep = window.get_tensor(timestep, dim=0)
sub_x = window.get_tensor(x_in, device) sub_conds = [self.get_resized_cond(cond, x, window) for cond in conds]
sub_timestep = window.get_tensor(timestep, device, dim=0)
sub_conds = [self.get_resized_cond(cond, x_in, window, device) for cond in conds]
# if multimodal, patch latent_shapes in conds for correct unpacking in model
window_state.patch_latent_shapes(sub_conds, sub_shapes)
# call model on window
sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options) sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options)
if device is not None:
for i in range(len(sub_conds_out)):
sub_conds_out[i] = sub_conds_out[i].to(x_in.device)
# strip causal_window_fix anchor if applied # unpack outputs
out_per_modality = [comfy.utils.unpack_latents(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))]
# strip causal_window_fix anchor from primary modality before guide strip so window_len math stays correct
if anchor_applied: if anchor_applied:
for i in range(len(sub_conds_out)): for ci in range(len(out_per_modality)):
sub_conds_out[i] = sub_conds_out[i].narrow(self.dim, 1, sub_conds_out[i].shape[self.dim] - 1) t = out_per_modality[ci][0]
out_per_modality[ci][0] = t.narrow(self.dim, 1, t.shape[self.dim] - 1)
results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window)) # strip injected guide frames
window_state.strip_guide_frames(out_per_modality, guide_frame_counts_per_modality, window)
results.append(ContextResults(window_idx, out_per_modality, sub_conds, window))
return results return results
@ -383,7 +728,7 @@ class IndexListContextHandler(ContextHandlerABC):
biases_final[i][idx] = bias_total + bias biases_final[i][idx] = bias_total + bias
else: else:
# add conds and counts based on weights of fuse method # add conds and counts based on weights of fuse method
weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep) weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep, context_overlap=window.context_overlap)
weights_tensor = match_weights_to_dim(weights, x_in, self.dim, device=x_in.device) weights_tensor = match_weights_to_dim(weights, x_in, self.dim, device=x_in.device)
for i in range(len(sub_conds_out)): for i in range(len(sub_conds_out)):
window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor) window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor)
@ -393,16 +738,22 @@ class IndexListContextHandler(ContextHandlerABC):
callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final) callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final)
def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, conds, *args, **kwargs):
    """Shrink ``noise_shape`` to one context window before delegating to ``executor``.

    The shape is only used for VRAM estimation, so it is budgeted per window
    instead of for the full latent.

    ``model_options`` must be present in ``kwargs``; its optional
    ``"context_handler"`` entry supplies ``dim`` and ``context_length``.
    When no handler is registered the arguments are forwarded untouched.

    Raises:
        Exception: if ``model_options`` is missing from ``kwargs``.
    """
    options = kwargs.get("model_options", None)
    if options is None:
        raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.")

    handler: IndexListContextHandler = options.get("context_handler", None)
    if handler is None:
        return executor(model, noise_shape, conds, *args, **kwargs)

    # From here on the shape is always handed over as a mutable list.
    noise_shape = list(noise_shape)
    packed = len(noise_shape) == 3 and noise_shape[1] == 1
    # TODO: latent_shapes cond isn't attached yet at this point, so we can't compute a
    # per-window flat latent here. Skipping the clamp over-estimates but prevents immediate OOM.
    if not packed:
        axis = handler.dim
        if axis < len(noise_shape) and noise_shape[axis] > handler.context_length:
            noise_shape[axis] = min(noise_shape[axis], handler.context_length)
    return executor(model, noise_shape, conds, *args, **kwargs)
def create_prepare_sampling_wrapper(model: ModelPatcher): def create_prepare_sampling_wrapper(model: ModelPatcher):
@ -422,11 +773,12 @@ def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, nois
raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.") raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
if not handler.freenoise: if not handler.freenoise:
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs) return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
conds = [guider.conds.get('positive', guider.conds.get('negative', []))]
noise = handler._apply_freenoise(noise, conds, extra_args["seed"])
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs) return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
def create_sampler_sample_wrapper(model: ModelPatcher): def create_sampler_sample_wrapper(model: ModelPatcher):
model.add_wrapper_with_key( model.add_wrapper_with_key(
comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
@ -434,7 +786,6 @@ def create_sampler_sample_wrapper(model: ModelPatcher):
_sampler_sample_wrapper _sampler_sample_wrapper
) )
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor: def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
total_dims = len(x_in.shape) total_dims = len(x_in.shape)
weights_tensor = torch.Tensor(weights).to(device=device) weights_tensor = torch.Tensor(weights).to(device=device)
@ -580,8 +931,9 @@ def get_matching_context_schedule(context_schedule: str) -> ContextSchedule:
return ContextSchedule(context_schedule, func) return ContextSchedule(context_schedule, func)
def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None, context_overlap: int=None):
    """Return the fuse weights for one window by calling ``handler.fuse_method.func``.

    ``context_overlap`` lets the caller override the overlap for this window;
    when it is ``None`` the handler's own ``context_overlap`` is used.
    All other arguments are forwarded to the fuse function as keywords.
    """
    if context_overlap is None:
        context_overlap = handler.context_overlap
    fuse = handler.fuse_method.func
    return fuse(
        length,
        sigma=sigma,
        handler=handler,
        full_length=full_length,
        idxs=idxs,
        context_overlap=context_overlap,
    )
def create_weights_flat(length: int, **kwargs) -> list[float]: def create_weights_flat(length: int, **kwargs) -> list[float]:
@ -599,18 +951,18 @@ def create_weights_pyramid(length: int, **kwargs) -> list[float]:
weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1)) weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1))
return weight_sequence return weight_sequence
def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], context_overlap: int, **kwargs):
    """Return per-frame blend weights that ramp linearly across the window overlap.

    Args:
        length: number of frames in this window (size of the returned tensor).
        full_length: total number of frames along the windowed dimension.
        idxs: frame indices covered by this window.
        context_overlap: number of frames to ramp on each blended side.
        **kwargs: extra fuse-method arguments; ignored here.

    Returns:
        A 1-D tensor of ``length`` weights. Frames outside the ramps keep weight 1.
        The left side ramps up unless the window contains frame 0; the right
        side ramps down unless it contains the last frame.
    """
    # based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302
    # only expected overlap is given different weights
    weights_torch = torch.ones((length))
    # Clamp the ramp to the window: a ramp longer than the window cannot be assigned,
    # and with an overlap of 0 the slice `[-0:]` would address the whole window
    # instead of nothing, so the empty ramp assignment would raise.
    overlap = max(0, min(context_overlap, length))
    if overlap == 0:
        return weights_torch
    # blend left-side on all except first window
    if min(idxs) > 0:
        weights_torch[:overlap] = torch.linspace(1e-37, 1, overlap)
    # blend right-side on all except last window
    if max(idxs) < full_length - 1:
        weights_torch[-overlap:] = torch.linspace(1, 1e-37, overlap)
    return weights_torch
class ContextFuseMethods: class ContextFuseMethods:

321
comfy/ldm/boogu/model.py Normal file
View File

@ -0,0 +1,321 @@
# Boogu-Image-0.1 transformer
# Architecture is an OmniGen2 derivative (see comfy/ldm/omnigen/omnigen2.py) with an
# added dual-stream ("double_stream") stage before the single-stream layers, conditioned
# by a Qwen3-VL multimodal LLM. Reuses the OmniGen2/Lumina building blocks and the Flux
# RoPE core; the only new components are the double-stream block and the hybrid forward order.
from typing import Optional, Tuple
import torch
import torch.nn as nn
from einops import rearrange
import comfy.ldm.common_dit
import comfy.ldm.omnigen.omnigen2
from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.omnigen.omnigen2 import (
OmniGen2RotaryPosEmbed,
Lumina2CombinedTimestepCaptionEmbedding,
LuminaRMSNormZero,
LuminaLayerNormContinuous,
LuminaFeedForward,
Attention,
OmniGen2TransformerBlock,
apply_rotary_emb,
)
class BooguDoubleStreamProcessor(nn.Module):
    """Joint attention over [instruct ; img] with separate per-stream q/k/v and output projections.

    This module only owns the per-stream linear layers. Head counts, the q/k
    norms and the final output projection are read from the ``attn`` module
    passed to ``forward``.
    """

    def __init__(self, dim, head_dim, heads, kv_heads, dtype=None, device=None, operations=None):
        super().__init__()
        # `dim` is accepted for signature parity but not used: every layer is sized from head_dim.
        q_width = head_dim * heads
        kv_width = head_dim * kv_heads

        def linear(out_features):
            # All projections read q_width features and carry no bias.
            return operations.Linear(q_width, out_features, bias=False, dtype=dtype, device=device)

        # Creation order is kept as-is so module registration order is unchanged.
        self.img_to_q = linear(q_width)
        self.img_to_k = linear(kv_width)
        self.img_to_v = linear(kv_width)
        self.instruct_to_q = linear(q_width)
        self.instruct_to_k = linear(kv_width)
        self.instruct_to_v = linear(kv_width)
        self.instruct_out = linear(q_width)
        self.img_out = linear(q_width)

    def forward(self, attn, img_hidden_states, instruct_hidden_states, rotary_emb, attention_mask=None, transformer_options={}):
        """Attend over the concatenated [instruct ; img] sequence and return it in that order."""
        bsz = img_hidden_states.shape[0]
        n_instruct = instruct_hidden_states.shape[1]

        img_qkv = (
            self.img_to_q(img_hidden_states),
            self.img_to_k(img_hidden_states),
            self.img_to_v(img_hidden_states),
        )
        instruct_qkv = (
            self.instruct_to_q(instruct_hidden_states),
            self.instruct_to_k(instruct_hidden_states),
            self.instruct_to_v(instruct_hidden_states),
        )

        # Concatenate instruction first, then image (matches reference processor order).
        q, k, v = (torch.cat(pair, dim=1) for pair in zip(instruct_qkv, img_qkv))

        # Split into heads; q/k are normalised per head, v is not.
        q = attn.norm_q(q.view(bsz, -1, attn.heads, attn.dim_head))
        k = attn.norm_k(k.view(bsz, -1, attn.kv_heads, attn.dim_head))
        v = v.view(bsz, -1, attn.kv_heads, attn.dim_head)

        if rotary_emb is not None:
            q = apply_rotary_emb(q, rotary_emb)
            k = apply_rotary_emb(k, rotary_emb)

        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        # Grouped-query attention: repeat k/v heads until they match the query head count.
        if attn.kv_heads < attn.heads:
            repeats = attn.heads // attn.kv_heads
            k = k.repeat_interleave(repeats, dim=1)
            v = v.repeat_interleave(repeats, dim=1)

        attended = optimized_attention_masked(q, k, v, attn.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options)

        # Split back to instruction/image, apply per-stream output projections, recombine.
        merged = torch.cat(
            [self.instruct_out(attended[:, :n_instruct]), self.img_out(attended[:, n_instruct:])],
            dim=1,
        )
        return attn.to_out[0](merged)
class BooguJointAttention(nn.Module):
    """Holds the shared q/k RMSNorm + final output projection for the joint attention.

    The per-stream projections and the attention call itself live in
    ``self.processor``, which receives this module as its ``attn`` argument.
    """

    def __init__(self, dim, head_dim, heads, kv_heads, eps=1e-5, dtype=None, device=None, operations=None):
        super().__init__()
        factory = {"dtype": dtype, "device": device}

        self.heads, self.kv_heads, self.dim_head = heads, kv_heads, head_dim
        self.scale = head_dim ** -0.5

        self.norm_q = operations.RMSNorm(head_dim, eps=eps, **factory)
        self.norm_k = operations.RMSNorm(head_dim, eps=eps, **factory)

        out_proj = operations.Linear(heads * head_dim, dim, bias=False, **factory)
        self.to_out = nn.Sequential(out_proj, nn.Dropout(0.0))

        self.processor = BooguDoubleStreamProcessor(dim, head_dim, heads, kv_heads, operations=operations, **factory)

    def forward(self, img_hidden_states, instruct_hidden_states, rotary_emb, attention_mask=None, transformer_options={}):
        """Delegate to the processor, passing this module so it can use the shared norms and projection."""
        return self.processor(
            self,
            img_hidden_states,
            instruct_hidden_states,
            rotary_emb,
            attention_mask,
            transformer_options=transformer_options,
        )
class BooguDoubleStreamBlock(nn.Module):
    """Dual-stream block: joint attention over [instruct ; img] + image self-attention.

    Each stream has its own modulation norms and MLP. All modulated norms are
    evaluated on the block inputs before any residual update is applied.
    """

    def __init__(self, dim, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, dtype=None, device=None, operations=None):
        super().__init__()
        # `ffn_dim_multiplier` is accepted but not used: the MLP inner size is fixed at 4 * dim.
        factory = {"dtype": dtype, "device": device, "operations": operations}
        head_dim = dim // num_attention_heads

        def feed_forward():
            return LuminaFeedForward(dim=dim, inner_dim=4 * dim, multiple_of=multiple_of, **factory)

        def modulated_norm():
            return LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, **factory)

        def rms_norm():
            return operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)

        # Creation order is kept as-is so module registration order is unchanged.
        self.img_instruct_attn = BooguJointAttention(dim, head_dim, num_attention_heads, num_kv_heads, eps=1e-5, **factory)
        self.img_self_attn = Attention(
            query_dim=dim, dim_head=head_dim, heads=num_attention_heads, kv_heads=num_kv_heads,
            eps=1e-5, bias=False, **factory,
        )

        self.img_feed_forward = feed_forward()
        self.instruct_feed_forward = feed_forward()

        self.img_norm1 = modulated_norm()
        self.img_norm2 = modulated_norm()
        self.img_norm3 = modulated_norm()
        self.instruct_norm1 = modulated_norm()
        self.instruct_norm2 = modulated_norm()

        self.img_attn_norm = rms_norm()
        self.img_self_attn_norm = rms_norm()
        self.img_ffn_norm1 = rms_norm()
        self.img_ffn_norm2 = rms_norm()
        self.instruct_attn_norm = rms_norm()
        self.instruct_ffn_norm1 = rms_norm()
        self.instruct_ffn_norm2 = rms_norm()

    @staticmethod
    def _gated(gate, value):
        # tanh-squashed gate, broadcast over the sequence dimension.
        return gate.unsqueeze(1).tanh() * value

    def forward(self, img_hidden_states, instruct_hidden_states, joint_rotary_emb, img_rotary_emb, temb, joint_attention_mask=None, img_attention_mask=None, transformer_options={}):
        """Run one dual-stream step and return ``(img_hidden_states, instruct_hidden_states)``."""
        n_instruct = instruct_hidden_states.shape[1]

        # norm1 supplies the joint-attention input plus three modulation terms; from
        # norm2/norm3 only the normalised output and the first modulation term are used.
        img_joint_in, img_gate_joint, img_mlp_scale, img_gate_mlp = self.img_norm1(img_hidden_states, temb)
        img_mlp_base, img_mlp_shift, _, _ = self.img_norm2(img_hidden_states, temb)
        img_self_in, img_gate_self, _, _ = self.img_norm3(img_hidden_states, temb)

        txt_joint_in, txt_gate_joint, txt_mlp_scale, txt_gate_mlp = self.instruct_norm1(instruct_hidden_states, temb)
        txt_mlp_base, txt_mlp_shift, _, _ = self.instruct_norm2(instruct_hidden_states, temb)

        # Joint attention returns [instruct ; img]; split it back per stream.
        joint = self.img_instruct_attn(img_joint_in, txt_joint_in, joint_rotary_emb, joint_attention_mask, transformer_options=transformer_options)
        txt_attn = joint[:, :n_instruct]
        img_attn = joint[:, n_instruct:]

        img_self = self.img_self_attn(img_self_in, img_self_in, img_attention_mask, img_rotary_emb, transformer_options=transformer_options)

        # Image stream: joint attention, then self-attention, then MLP residuals.
        img_hidden_states = img_hidden_states + self._gated(img_gate_joint, self.img_attn_norm(img_attn))
        img_hidden_states = img_hidden_states + self._gated(img_gate_self, self.img_self_attn_norm(img_self))
        img_mlp_in = (1 + img_mlp_scale.unsqueeze(1)) * img_mlp_base + img_mlp_shift.unsqueeze(1)
        img_mlp = self.img_feed_forward(self.img_ffn_norm1(img_mlp_in))
        img_hidden_states = img_hidden_states + self._gated(img_gate_mlp, self.img_ffn_norm2(img_mlp))

        # Instruction stream: joint attention, then MLP residuals.
        instruct_hidden_states = instruct_hidden_states + self._gated(txt_gate_joint, self.instruct_attn_norm(txt_attn))
        txt_mlp_in = (1 + txt_mlp_scale.unsqueeze(1)) * txt_mlp_base + txt_mlp_shift.unsqueeze(1)
        txt_mlp = self.instruct_feed_forward(self.instruct_ffn_norm1(txt_mlp_in))
        instruct_hidden_states = instruct_hidden_states + self._gated(txt_gate_mlp, self.instruct_ffn_norm2(txt_mlp))

        return img_hidden_states, instruct_hidden_states
class BooguTransformer2DModel(nn.Module):
    # Boogu image transformer: text and image refiners, then a stack of double-stream
    # blocks, then single-stream blocks over the concatenated [text ; image] sequence.
    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 16,
        out_channels: Optional[int] = None,
        hidden_size: int = 3360,
        num_layers: int = 32,
        num_double_stream_layers: int = 8,
        num_refiner_layers: int = 2,
        num_attention_heads: int = 28,
        num_kv_heads: int = 7,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: float = 1e-5,
        axes_dim_rope: Tuple[int, int, int] = (40, 40, 40),
        axes_lens: Tuple[int, int, int] = (2048, 1664, 1664),
        instruction_feat_dim: int = 4096,
        timestep_scale: float = 1000.0,
        image_model=None,
        device=None, dtype=None, operations=None,
    ):
        # num_layers sizes the single-stream stack; num_double_stream_layers sizes the
        # double-stream stack; num_refiner_layers sizes each of the three refiner stacks.
        # out_channels falls back to in_channels when not given.
        # image_model is accepted but not read anywhere in this class.
        super().__init__()
        self.patch_size = patch_size
        self.out_channels = out_channels or in_channels
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.rope_embedder = OmniGen2RotaryPosEmbed(
            theta=10000,
            axes_dim=axes_dim_rope,
            axes_lens=axes_lens,
            patch_size=patch_size,
        )
        # Separate patch embedders for the noise latent and for reference-image latents.
        self.x_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device)
        self.ref_image_patch_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device)
        self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
            hidden_size=hidden_size,
            text_feat_dim=instruction_feat_dim,
            norm_eps=norm_eps,
            timestep_scale=timestep_scale, dtype=dtype, device=device, operations=operations
        )
        self.noise_refiner = nn.ModuleList([
            OmniGen2TransformerBlock(hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations)
            for _ in range(num_refiner_layers)
        ])
        self.ref_image_refiner = nn.ModuleList([
            OmniGen2TransformerBlock(hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations)
            for _ in range(num_refiner_layers)
        ])
        # The context (text) refiner is the only stack built with modulation=False;
        # its layers are called without temb in forward().
        self.context_refiner = nn.ModuleList([
            OmniGen2TransformerBlock(hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=False, dtype=dtype, device=device, operations=operations)
            for _ in range(num_refiner_layers)
        ])
        self.double_stream_layers = nn.ModuleList([
            BooguDoubleStreamBlock(hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, dtype=dtype, device=device, operations=operations)
            for _ in range(num_double_stream_layers)
        ])
        self.single_stream_layers = nn.ModuleList([
            OmniGen2TransformerBlock(hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations)
            for _ in range(num_layers)
        ])
        self.norm_out = LuminaLayerNormContinuous(
            embedding_dim=hidden_size,
            conditioning_embedding_dim=min(hidden_size, 1024),
            elementwise_affine=False,
            eps=1e-6,
            out_dim=patch_size * patch_size * self.out_channels, dtype=dtype, device=device, operations=operations
        )
        # Allocated uninitialised (torch.empty) and not read directly in this class;
        # presumably filled from the checkpoint and consumed by the borrowed
        # OmniGen2 helper below. TODO confirm.
        self.image_index_embedding = nn.Parameter(torch.empty(5, hidden_size, device=device, dtype=dtype))
    # Patchify/refine helpers are identical to OmniGen2; reuse via bound methods.
    # NOTE(review): these are the OmniGen2 functions assigned as class attributes, so they
    # run with this instance as `self` and rely on whatever attributes they read existing
    # here under the same names. Verify against OmniGen2 before renaming attributes.
    flat_and_pad_to_seq = comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel.flat_and_pad_to_seq
    img_patch_embed_and_refine = comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel.img_patch_embed_and_refine
    def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, transformer_options={}, **kwargs):
        # x is unpacked as a 4-D (B, C, H, W) tensor; the result is cropped back to H x W
        # and returned negated. Extra **kwargs are accepted and ignored.
        B, C, H, W = x.shape
        hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
        _, _, H_padded, W_padded = hidden_states.shape
        # The time embedder is fed 1 - timesteps (presumably the reference model uses the
        # opposite time direction; confirm).
        timestep = 1.0 - timesteps
        text_hidden_states = context
        text_attention_mask = attention_mask
        ref_image_hidden_states = ref_latents
        device = hidden_states.device
        temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype)
        (
            hidden_states, ref_image_hidden_states,
            img_mask, ref_img_mask,
            l_effective_ref_img_len, l_effective_img_len,
            ref_img_sizes, img_sizes,
        ) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)
        # encoder_seq_lengths and seq_lengths are returned by the rope embedder but not used below.
        (
            context_rotary_emb, ref_img_rotary_emb, noise_rotary_emb,
            rotary_emb, encoder_seq_lengths, seq_lengths,
        ) = self.rope_embedder(
            hidden_states.shape[0], text_hidden_states.shape[1], [num_tokens] * text_hidden_states.shape[0],
            l_effective_ref_img_len, l_effective_img_len,
            ref_img_sizes, img_sizes, device,
        )
        for layer in self.context_refiner:
            text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb, transformer_options=transformer_options)
        # Sequence length of the flattened noise image, used at the end to take the
        # last img_len tokens of the joint sequence.
        img_len = hidden_states.shape[1]
        combined_img_hidden_states = self.img_patch_embed_and_refine(
            hidden_states, ref_image_hidden_states,
            img_mask, ref_img_mask,
            noise_rotary_emb, ref_img_rotary_emb,
            l_effective_ref_img_len, l_effective_img_len,
            temb,
            transformer_options=transformer_options,
        )
        # Double-stream stage: the image self-attention only sees the [ref ; noise] tokens,
        # which sit after the instruction tokens in the joint rope.
        L_instruct = text_hidden_states.shape[1]
        combined_img_rotary_emb = rotary_emb[:, L_instruct:]
        # No attention masks are passed to the double-stream or single-stream layers.
        for layer in self.double_stream_layers:
            combined_img_hidden_states, text_hidden_states = layer(
                combined_img_hidden_states, text_hidden_states,
                rotary_emb, combined_img_rotary_emb, temb,
                joint_attention_mask=None, img_attention_mask=None,
                transformer_options=transformer_options,
            )
        # Single-stream stage runs over [text ; image] concatenated along the sequence dim.
        hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1)
        for layer in self.single_stream_layers:
            hidden_states = layer(hidden_states, None, rotary_emb, temb, transformer_options=transformer_options)
        hidden_states = self.norm_out(hidden_states, temb)
        p = self.patch_size
        # Keep the last img_len tokens, un-patchify to (b, c, H_padded, W_padded), then crop the padding.
        output = rearrange(hidden_states[:, -img_len:], 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', h=H_padded // p, w=W_padded // p, p1=p, p2=p)[:, :, :H, :W]
        return -output

View File

@ -515,7 +515,7 @@ class Block(nn.Module):
h=H, h=H,
w=W, w=W,
) )
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype) x_B_T_H_W_D = torch.addcmul(x_B_T_H_W_D, gate_self_attn_B_T_1_1_D.to(residual_dtype), result_B_T_H_W_D.to(residual_dtype))
def _x_fn( def _x_fn(
_x_B_T_H_W_D: torch.Tensor, _x_B_T_H_W_D: torch.Tensor,
@ -548,7 +548,7 @@ class Block(nn.Module):
shift_cross_attn_B_T_1_1_D, shift_cross_attn_B_T_1_1_D,
transformer_options=transformer_options, transformer_options=transformer_options,
) )
x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D x_B_T_H_W_D = torch.addcmul(x_B_T_H_W_D, gate_cross_attn_B_T_1_1_D.to(residual_dtype), result_B_T_H_W_D.to(residual_dtype))
normalized_x_B_T_H_W_D = _fn( normalized_x_B_T_H_W_D = _fn(
x_B_T_H_W_D, x_B_T_H_W_D,
@ -557,7 +557,7 @@ class Block(nn.Module):
shift_mlp_B_T_1_1_D, shift_mlp_B_T_1_1_D,
) )
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype)) result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype))
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype) x_B_T_H_W_D = torch.addcmul(x_B_T_H_W_D, gate_mlp_B_T_1_1_D.to(residual_dtype), result_B_T_H_W_D.to(residual_dtype))
return x_B_T_H_W_D return x_B_T_H_W_D

View File

@ -1085,7 +1085,7 @@ class LTXVModel(LTXBaseModel):
) )
grid_mask = None grid_mask = None
if keyframe_idxs is not None: if keyframe_idxs is not None and keyframe_idxs.shape[2] > 0:
additional_args.update({ "orig_patchified_shape": list(x.shape)}) additional_args.update({ "orig_patchified_shape": list(x.shape)})
denoise_mask = self.patchifier.patchify(denoise_mask)[0] denoise_mask = self.patchifier.patchify(denoise_mask)[0]
grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0] grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0]
@ -1330,7 +1330,7 @@ class LTXVModel(LTXBaseModel):
x = x * (1 + scale) + shift x = x * (1 + scale) + shift
x = self.proj_out(x) x = self.proj_out(x)
if keyframe_idxs is not None: if keyframe_idxs is not None and keyframe_idxs.shape[2] > 0:
grid_mask = kwargs["grid_mask"] grid_mask = kwargs["grid_mask"]
orig_patchified_shape = kwargs["orig_patchified_shape"] orig_patchified_shape = kwargs["orig_patchified_shape"]
full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device) full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device)

View File

@ -22,7 +22,7 @@ def apply_rotary_emb(x, freqs_cis):
def swiglu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return ``silu(x) * y``, computed in place in ``x``'s storage.

    ``x`` is overwritten and is the tensor that gets returned, so callers
    must not reuse the original values of ``x`` afterwards.
    """
    gated = F.silu(x, inplace=True)
    gated.mul_(y)
    return gated
class TimestepEmbedding(nn.Module): class TimestepEmbedding(nn.Module):

View File

@ -1665,7 +1665,7 @@ class SCAILWanModel(WanModel):
# embeddings # embeddings
x = self.patch_embedding(x.float()).to(x.dtype) x = self.patch_embedding(x.float()).to(x.dtype)
if ref_mask_latents is not None: # SCAIL-2 additive mask stream if ref_mask_latents is not None: # SCAIL-2 additive mask stream (one identity mask frame per reference, then video)
x = x + self.patch_embedding_mask(ref_mask_latents.float()).to(x.dtype) x = x + self.patch_embedding_mask(ref_mask_latents.float()).to(x.dtype)
grid_sizes = x.shape[2:] grid_sizes = x.shape[2:]
transformer_options["grid_sizes"] = grid_sizes transformer_options["grid_sizes"] = grid_sizes
@ -1728,22 +1728,25 @@ class SCAILWanModel(WanModel):
# ref_mask_flag is a scalar bool (CONDConstant, SCAIL-2 only). False => replacement mode, # ref_mask_flag is a scalar bool (CONDConstant, SCAIL-2 only). False => replacement mode,
# which places ref/pose via H/W rope shifts instead of the animation-mode temporal offset. # which places ref/pose via H/W rope shifts instead of the animation-mode temporal offset.
# reference_latent may stack several frames: the last is the primary reference adjacent to the video, the earlier frames are additional references.
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, ref_mask_flag=None, transformer_options={}): def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, ref_mask_flag=None, transformer_options={}):
ref_t_patches = 0
if reference_latent is not None:
ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0]
if ref_mask_flag is not None and not bool(ref_mask_flag): if ref_mask_flag is not None and not bool(ref_mask_flag):
REF_ROPE_H = 120.0 REF_ROPE_H = 120.0
POSE_ROPE_W = 120.0 POSE_ROPE_W = 120.0
ref_t_patches = 0
if reference_latent is not None:
ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0]
main_t_patches = t - ref_t_patches main_t_patches = t - ref_t_patches
video_t_start = max(ref_t_patches - 1, 0)
parts = [] parts = []
if ref_t_patches > 0: if ref_t_patches > 0:
ref_tf = {"rope_options": {"shift_y": REF_ROPE_H, "shift_x": 0.0, "scale_y": 1.0, "scale_x": 1.0}} ref_tf = {"rope_options": {"shift_y": REF_ROPE_H, "shift_x": 0.0, "scale_y": 1.0, "scale_x": 1.0}}
parts.append(super().rope_encode(ref_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=ref_tf)) parts.append(super().rope_encode(ref_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=ref_tf))
if main_t_patches > 0: if main_t_patches > 0:
parts.append(super().rope_encode(main_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=transformer_options)) parts.append(super().rope_encode(main_t_patches, h, w, t_start=video_t_start, device=device, dtype=dtype, transformer_options=transformer_options))
if pose_latents is not None: if pose_latents is not None:
F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1] F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1]
@ -1752,7 +1755,7 @@ class SCAILWanModel(WanModel):
h_shift = (h_scale - 1) / 2 h_shift = (h_scale - 1) / 2
w_shift = (w_scale - 1) / 2 w_shift = (w_scale - 1) / 2
pose_tf = {"rope_options": {"shift_y": h_shift, "shift_x": POSE_ROPE_W + w_shift, "scale_y": h_scale, "scale_x": w_scale}} pose_tf = {"rope_options": {"shift_y": h_shift, "shift_x": POSE_ROPE_W + w_shift, "scale_y": h_scale, "scale_x": w_scale}}
parts.append(super().rope_encode(F_pose, H_pose, W_pose, t_start=0, device=device, dtype=dtype, transformer_options=pose_tf)) parts.append(super().rope_encode(F_pose, H_pose, W_pose, t_start=video_t_start, device=device, dtype=dtype, transformer_options=pose_tf))
return torch.cat(parts, dim=1) return torch.cat(parts, dim=1)
@ -1761,10 +1764,6 @@ class SCAILWanModel(WanModel):
if pose_latents is None: if pose_latents is None:
return main_freqs return main_freqs
ref_t_patches = 0
if reference_latent is not None:
ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0]
F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1] F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1]
# if pose is at half resolution, scale_y/scale_x=2 stretches the position range to cover the same RoPE extent as the main frames # if pose is at half resolution, scale_y/scale_x=2 stretches the position range to cover the same RoPE extent as the main frames

View File

@ -21,6 +21,7 @@ import comfy.ldm.hunyuan3dv2_1.hunyuandit
import torch import torch
import logging import logging
import comfy.ldm.lightricks.av_model import comfy.ldm.lightricks.av_model
import comfy.ldm.lightricks.symmetric_patchifier
import comfy.context_windows import comfy.context_windows
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from comfy.ldm.cascade.stage_c import StageC from comfy.ldm.cascade.stage_c import StageC
@ -54,6 +55,7 @@ import comfy.ldm.pixeldit.model
import comfy.ldm.pixeldit.pid import comfy.ldm.pixeldit.pid
import comfy.ldm.ace.model import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2 import comfy.ldm.omnigen.omnigen2
import comfy.ldm.boogu.model
import comfy.ldm.qwen_image.model import comfy.ldm.qwen_image.model
import comfy.ldm.ideogram4.model import comfy.ldm.ideogram4.model
import comfy.ldm.kandinsky5.model import comfy.ldm.kandinsky5.model
@ -1203,6 +1205,127 @@ class LTXAV(BaseModel):
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
    # Pass-through: latent_image is returned unchanged; sigma, noise and any
    # extra keyword arguments are not used here.
    return latent_image
def map_context_window_to_modalities(self, primary_indices, latent_shapes, dim):
    """Translate a primary-modality window into index lists for every modality.

    Returns a list whose first entry is ``primary_indices`` itself, followed by
    one index list per additional entry of ``latent_shapes``. Lengths are read
    from ``latent_shapes[i][dim]``.
    """
    mapped = [primary_indices]
    if len(latent_shapes) < 2:
        return mapped

    primary_len = latent_shapes[0][dim]
    for modality in range(1, len(latent_shapes)):
        other_len = latent_shapes[modality][dim]
        # Map each primary index to its proportional range of modality indices and
        # concatenate in order. Preserves wrapped/strided geometry so the modality
        # attends to the same temporal regions as the primary window.
        # A dict keeps first-seen order while dropping repeats.
        ordered = {}
        for frame in primary_indices:
            lo = min(int(round(frame * other_len / primary_len)), other_len - 1)
            hi = min(int(round((frame + 1) * other_len / primary_len)), other_len)
            # Every primary index contributes at least one modality index.
            hi = max(hi, lo + 1)
            ordered.update(dict.fromkeys(range(lo, hi)))
        mapped.append(list(ordered))
    return mapped
@staticmethod
def _get_guide_entries(conds):
for cond_list in conds:
if cond_list is None:
continue
for cond_dict in cond_list:
model_conds = cond_dict.get('model_conds', {})
entries = model_conds.get('guide_attention_entries')
if entries is not None and hasattr(entries, 'cond') and entries.cond:
return entries.cond
return None
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
    """Return a copy of ``cond_value`` cut down to one context window, or ``None``.

    Only the keys ``"audio_denoise_mask"``, ``"denoise_mask"``, ``"keyframe_idxs"``
    and ``"guide_attention_entries"`` are handled. Any other key, or a handled key
    whose guard conditions are not met, falls through to the final ``return None``.
    NOTE(review): callers presumably treat ``None`` as "no special handling" — confirm.

    Args:
        cond_key: name of the conditioning entry being resized.
        cond_value: conditioning wrapper; the payload is read from ``.cond`` and
            results are built with ``._copy_with(...)``.
        window: context window object. This method reads ``dim``, ``index_list``,
            ``modality_windows``, ``guide_frames_indices``, ``guide_kf_local_positions``,
            ``guide_downscale_factors``, ``guide_overlap_info`` and, if present,
            ``causal_anchor_index``.
        x_in: latent tensor for the primary modality; its size along ``window.dim``
            is used as the video length and ``shape[3]``/``shape[4]`` as ``H``/``W``.
        device: target device for sliced tensors.
        retain_index_list: forwarded to ``window.get_tensor`` for the video part of
            the denoise mask only. The list default is never mutated here.
    """
    # Audio denoise mask — slice using audio modality window
    if cond_key == "audio_denoise_mask" and hasattr(window, 'modality_windows') and window.modality_windows:
        # Key 1 is looked up as the audio window — presumably the modality index; TODO confirm.
        audio_window = window.modality_windows.get(1)
        if audio_window is not None and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
            sliced = audio_window.get_tensor(cond_value.cond, device, dim=2)
            return cond_value._copy_with(sliced)
    # Video denoise mask — split into video + guide portions, slice each
    if cond_key == "denoise_mask" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
        cond_tensor = cond_value.cond
        # Anything beyond the video length along window.dim is treated as guide frames.
        guide_count = cond_tensor.size(window.dim) - x_in.size(window.dim)
        if guide_count > 0:
            T_video = x_in.size(window.dim)
            video_mask = cond_tensor.narrow(window.dim, 0, T_video)
            guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count)
            sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list)
            suffix_indices = window.guide_frames_indices
            if suffix_indices:
                # Index only along window.dim, leaving the leading axes untouched.
                idx = tuple([slice(None)] * window.dim + [suffix_indices])
                sliced_guide = guide_mask[idx].to(device)
                return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim))
            else:
                return cond_value._copy_with(sliced_video)
        # With no guide portion (guide_count <= 0) this key falls through to `return None`.
    # Keyframe indices — regenerate pixel coords for window, select guide positions
    if cond_key == "keyframe_idxs":
        kf_local_pos = window.guide_kf_local_positions
        if not kf_local_pos:
            return cond_value._copy_with(cond_value.cond[:, :, :0, :])  # empty
        H, W = x_in.shape[3], x_in.shape[4]
        window_len = len(window.index_list)
        # account for causal_window_fix anchor in coord space size
        anchor_idx = getattr(window, 'causal_anchor_index', None)
        if anchor_idx is not None and anchor_idx >= 0:
            window_len += 1
        patchifier = self.diffusion_model.patchifier
        latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device)
        scale_factors = self.diffusion_model.vae_scale_factors
        pixel_coords = comfy.ldm.lightricks.symmetric_patchifier.latent_to_pixel_coords(
            latent_coords,
            scale_factors,
            causal_fix=self.diffusion_model.causal_temporal_positioning)
        # Each local frame position contributes a contiguous run of H * W token indices.
        tokens = []
        for pos in kf_local_pos:
            tokens.extend(range(pos * H * W, (pos + 1) * H * W))
        pixel_coords = pixel_coords[:, :, tokens, :]
        # Adjust spatial end positions for dilated (downscaled) guides.
        # Each guide entry may have a different downscale factor; expand the
        # per-entry factor to cover all tokens belonging to that entry.
        downscale_factors = window.guide_downscale_factors
        overlap_info = window.guide_overlap_info
        if downscale_factors:
            per_token_factor = []
            for (entry_idx, overlap_count), dsf in zip(overlap_info, downscale_factors):
                per_token_factor.extend([dsf] * (overlap_count * H * W))
            factor_tensor = torch.tensor(per_token_factor, device=pixel_coords.device, dtype=pixel_coords.dtype)
            # (factor - 1) scaled by every scale factor after the first.
            # NOTE(review): presumably axis 1 is (t, h, w) and the last axis is (start, end) — confirm.
            spatial_end_offset = (factor_tensor.unsqueeze(0).unsqueeze(0).unsqueeze(-1) - 1) * torch.tensor(
                scale_factors[1:], device=pixel_coords.device, dtype=pixel_coords.dtype,
            ).view(1, -1, 1, 1)
            pixel_coords[:, 1:, :, 1:] += spatial_end_offset
        # Coordinates are built for batch size 1; expand to match the cond's batch size.
        B = cond_value.cond.shape[0]
        if B > 1:
            pixel_coords = pixel_coords.expand(B, -1, -1, -1)
        return cond_value._copy_with(pixel_coords)
    # Guide attention entries — adjust per-guide counts based on window overlap
    if cond_key == "guide_attention_entries":
        overlap_info = window.guide_overlap_info
        H, W = x_in.shape[3], x_in.shape[4]
        new_entries = []
        for entry_idx, overlap_count in overlap_info:
            e = cond_value.cond[entry_idx]
            # Copy the entry, overriding only the two window-dependent fields.
            new_entries.append({**e,
                                "pre_filter_count": overlap_count * H * W,
                                "latent_shape": [overlap_count, H, W]})
        return cond_value._copy_with(new_entries)
    return None
class HunyuanVideo(BaseModel): class HunyuanVideo(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo) super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
@ -1747,10 +1870,14 @@ class WAN21_SCAIL(WAN21):
reference_latents = kwargs.get("reference_latents", None) reference_latents = kwargs.get("reference_latents", None)
if reference_latents is not None: if reference_latents is not None:
ref_latent = self.process_latent_in(reference_latents[-1]) # SCAIL-2 multi-reference: reference_latents[0] is the primary ref, [1:] are additional
ref_mask = torch.ones_like(ref_latent[:, :4]) # references. Stack as [additional..., primary] so the primary stays adjacent to the video.
ref_latent = torch.cat([ref_latent, ref_mask], dim=1) ordered = list(reference_latents[1:]) + list(reference_latents[:1])
out['reference_latent'] = comfy.conds.CONDRegular(ref_latent) stacked = []
for lat in ordered:
lat = self.process_latent_in(lat)
stacked.append(torch.cat([lat, torch.ones_like(lat[:, :4])], dim=1))
out['reference_latent'] = comfy.conds.CONDRegular(torch.cat(stacked, dim=2))
pose_latents = kwargs.get("pose_video_latent", None) pose_latents = kwargs.get("pose_video_latent", None)
if pose_latents is not None: if pose_latents is not None:
@ -1792,6 +1919,7 @@ class WAN21_SCAIL2(WAN21_SCAIL):
if driving_mask_28ch is not None: if driving_mask_28ch is not None:
out['sam_latents'] = comfy.conds.CONDRegular(driving_mask_28ch.movedim(1, 2).contiguous()) out['sam_latents'] = comfy.conds.CONDRegular(driving_mask_28ch.movedim(1, 2).contiguous())
# ref_mask_28ch holds one identity mask per stacked reference frame (additional refs first, then the primary ref), followed by zeros over the video frames.
ref_mask_28ch = kwargs.get("ref_mask_28ch", None) ref_mask_28ch = kwargs.get("ref_mask_28ch", None)
if ref_mask_28ch is not None: if ref_mask_28ch is not None:
out['ref_mask_latents'] = comfy.conds.CONDRegular(ref_mask_28ch.movedim(1, 2).contiguous()) out['ref_mask_latents'] = comfy.conds.CONDRegular(ref_mask_28ch.movedim(1, 2).contiguous())
@ -1819,10 +1947,11 @@ class WAN21_SCAIL2(WAN21_SCAIL):
# Return sliced view omitting retain_index_list # Return sliced view omitting retain_index_list
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=0) return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=0)
if cond_key == "ref_mask_latents" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): if cond_key == "ref_mask_latents" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
# The ref mask is just a single frame padded with frames of zeros, so just grab the first frames for all windows # The ref mask is N leading ref frames padded with frames of zeros, so just grab the first frames for all windows
full_ref_mask = cond_value.cond full_ref_mask = cond_value.cond
video_frame_count = x_in.shape[2] video_frame_count = x_in.shape[2]
if full_ref_mask.shape[2] != video_frame_count + 1: ref_frame_count = full_ref_mask.shape[2] - video_frame_count
if ref_frame_count < 1:
return None return None
window_length = len(window.index_list) window_length = len(window.index_list)
@ -1831,7 +1960,7 @@ class WAN21_SCAIL2(WAN21_SCAIL):
if anchor_index is not None and anchor_index >= 0: if anchor_index is not None and anchor_index >= 0:
window_length += 1 window_length += 1
window_ref_mask = full_ref_mask[:, :, :window_length + 1].to(device) window_ref_mask = full_ref_mask[:, :, :window_length + ref_frame_count].to(device)
return cond_value._copy_with(window_ref_mask) return cond_value._copy_with(window_ref_mask)
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
@ -2097,6 +2226,11 @@ class Omnigen2(BaseModel):
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
return out return out
class Boogu(Omnigen2):
    """Omnigen2 variant whose diffusion model is ``BooguTransformer2DModel``."""

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        transformer_cls = comfy.ldm.boogu.model.BooguTransformer2DModel
        # super(Omnigen2, self) starts the lookup *after* Omnigen2 in the MRO, so
        # Omnigen2.__init__ is bypassed and a different unet_model can be passed.
        super(Omnigen2, self).__init__(
            model_config,
            model_type,
            device=device,
            unet_model=transformer_cls,
        )
        self.memory_usage_factor_conds = ("ref_latents",)
class QwenImage(BaseModel): class QwenImage(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None): def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.qwen_image.model.QwenImageTransformer2DModel) super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.qwen_image.model.QwenImageTransformer2DModel)

View File

@ -761,6 +761,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
return dit_config return dit_config
if '{}double_stream_layers.0.img_instruct_attn.processor.img_to_q.weight'.format(key_prefix) in state_dict_keys: # Boogu-Image (OmniGen2 derivative + dual-stream stage)
dit_config = {}
dit_config["image_model"] = "boogu"
dit_config["hidden_size"] = state_dict['{}x_embedder.weight'.format(key_prefix)].shape[0]
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}single_stream_layers.'.format(key_prefix) + '{}.')
dit_config["num_double_stream_layers"] = count_blocks(state_dict_keys, '{}double_stream_layers.'.format(key_prefix) + '{}.')
dit_config["num_refiner_layers"] = count_blocks(state_dict_keys, '{}noise_refiner.'.format(key_prefix) + '{}.')
dit_config["instruction_feat_dim"] = state_dict['{}time_caption_embed.caption_embedder.0.weight'.format(key_prefix)].shape[0]
return dit_config
if '{}time_caption_embed.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: # Omnigen2 if '{}time_caption_embed.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: # Omnigen2
dit_config = {} dit_config = {}
dit_config["image_model"] = "omnigen2" dit_config["image_model"] = "omnigen2"

View File

@ -68,6 +68,7 @@ import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image import comfy.text_encoders.longcat_image
import comfy.text_encoders.qwen35 import comfy.text_encoders.qwen35
import comfy.text_encoders.qwen3vl import comfy.text_encoders.qwen3vl
import comfy.text_encoders.boogu
import comfy.text_encoders.ernie import comfy.text_encoders.ernie
import comfy.text_encoders.gemma4 import comfy.text_encoders.gemma4
import comfy.text_encoders.cogvideo import comfy.text_encoders.cogvideo
@ -1301,6 +1302,7 @@ class CLIPType(Enum):
LENS = 28 LENS = 28
PIXELDIT = 29 PIXELDIT = 29
IDEOGRAM4 = 30 IDEOGRAM4 = 30
BOOGU = 31
@ -1622,6 +1624,14 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."}) clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
clip_target.clip = comfy.text_encoders.ideogram4.te_qwen3vl(**llama_detect(clip_data)) clip_target.clip = comfy.text_encoders.ideogram4.te_qwen3vl(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.ideogram4.Ideogram4Qwen3VLTokenizer clip_target.tokenizer = comfy.text_encoders.ideogram4.Ideogram4Qwen3VLTokenizer
elif clip_type == CLIPType.BOOGU and te_model == TEModel.QWEN3VL_8B: # Boogu-Image: full Qwen3-VL-8B, last hidden state, no-think template.
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
clip_target.clip = comfy.text_encoders.boogu.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.boogu.BooguTokenizer
elif clip_type in (CLIPType.FLUX, CLIPType.FLUX2): # Flux2 Klein reuses the Qwen3-VL LM (3-layer tap -> 12288); visual unused.
klein_model_type = "qwen3_8b" if te_model == TEModel.QWEN3VL_8B else "qwen3_4b"
clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type=klein_model_type)
clip_target.tokenizer = comfy.text_encoders.flux.KleinTokenizer8B if te_model == TEModel.QWEN3VL_8B else comfy.text_encoders.flux.KleinTokenizer
else: else:
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."}) clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
qwen3vl_type = {TEModel.QWEN3VL_4B: "qwen3vl_4b", TEModel.QWEN3VL_8B: "qwen3vl_8b"}[te_model] qwen3vl_type = {TEModel.QWEN3VL_4B: "qwen3vl_4b", TEModel.QWEN3VL_8B: "qwen3vl_8b"}[te_model]

View File

@ -25,6 +25,7 @@ import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.kandinsky5 import comfy.text_encoders.kandinsky5
import comfy.text_encoders.z_image import comfy.text_encoders.z_image
import comfy.text_encoders.ideogram4 import comfy.text_encoders.ideogram4
import comfy.text_encoders.boogu
import comfy.text_encoders.anima import comfy.text_encoders.anima
import comfy.text_encoders.ace15 import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image import comfy.text_encoders.longcat_image
@ -1758,6 +1759,27 @@ class Omnigen2(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref)) hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect)) return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
class Boogu(Omnigen2):
    """Supported-model entry for configs detected as ``image_model == "boogu"``."""

    unet_config = {"image_model": "boogu"}

    sampling_settings = {"multiplier": 1.0, "shift": 3.16}

    memory_usage_factor = 2.15

    def get_model(self, state_dict, prefix="", device=None):
        """Build the ``model_base.Boogu`` wrapper; ``state_dict`` and ``prefix`` are not used."""
        return model_base.Boogu(self, device=device)

    def clip_target(self, state_dict={}):
        """Return the tokenizer / text-encoder pair for the bundled Qwen3-VL-8B weights."""
        key_prefix = self.text_encoder_key_prefix[0]
        detected = comfy.text_encoders.hunyuan_video.llama_detect(
            state_dict,
            "{}qwen3vl_8b.transformer.".format(key_prefix),
        )
        text_encoder_cls = comfy.text_encoders.boogu.te(**detected)
        return supported_models_base.ClipTarget(comfy.text_encoders.boogu.BooguTokenizer, text_encoder_cls)
class Ideogram4(supported_models_base.BASE): class Ideogram4(supported_models_base.BASE):
unet_config = { unet_config = {
"image_model": "ideogram4", "image_model": "ideogram4",
@ -2300,6 +2322,7 @@ models = [
ACEStep, ACEStep,
ACEStep15, ACEStep15,
Omnigen2, Omnigen2,
Boogu,
QwenImage, QwenImage,
Ideogram4, Ideogram4,
Flux2, Flux2,

View File

@ -0,0 +1,58 @@
"""Boogu-Image text encoder: full Qwen3-VL-8B, last hidden state (4096-dim).
Boogu uses the final hidden state of Qwen3-VL as the per-token instruction feature
(num_instruction_feature_layers=1, reduce_type=mean -> just the last layer).
The model itself is the standard Qwen3-VL TE, only the chat template differs
(a fixed system prompt and no <think> block).
"""
import comfy.text_encoders.qwen3vl
from comfy import sd1_clip
# System prompts from the reference pipeline (pipeline_boogu.py).
# T2I (non-empty instruction, no image) uses the helpful-assistant prompt;
# everything else (the CFG negative / "drop" condition, and any image case) uses the TI2I "describe" prompt.
BOOGU_T2I_SYSTEM = "You are a helpful assistant that generates high-quality images based on user instructions. The instructions are as follows."
BOOGU_DROP_SYSTEM = "Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate."
class BooguTokenizer(comfy.text_encoders.qwen3vl.Qwen3VLTokenizer):
    """Qwen3-VL-8B tokenizer preloaded with Boogu's fixed-system-prompt chat templates."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(
            embedding_directory=embedding_directory,
            tokenizer_data=tokenizer_data,
            model_type="qwen3vl_8b",
        )
        # Templates end after the user turn (apply_chat_template without add_generation_prompt).
        t2i_head = "<|im_start|>system\n" + BOOGU_T2I_SYSTEM
        drop_head = "<|im_start|>system\n" + BOOGU_DROP_SYSTEM
        self.llama_template = t2i_head + "<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n"
        self.llama_template_images = drop_head + "<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n"
        # Reference SYSTEM_PROMPT_DROP: used for the empty negative/uncond instruction.
        self.llama_template_drop = drop_head + "<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n"

    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=True, **kwargs):
        """Tokenize ``text``, switching to the "drop" template for blank, image-free input.

        An explicitly supplied ``llama_template`` is always passed through untouched.
        """
        if llama_template is None:
            if len(images) == 0 and not text.strip():
                llama_template = self.llama_template_drop
        # Boogu conditions on the no-think template. Per the original note, thinking=True
        # is what drops the empty <think> block the parent adds by default — confirm
        # against Qwen3VLTokenizer.
        return super().tokenize_with_weights(
            text,
            return_word_ids=return_word_ids,
            llama_template=llama_template,
            images=images,
            prevent_empty_text=prevent_empty_text,
            thinking=thinking,
            **kwargs,
        )
class BooguQwen3VLClipModel(comfy.text_encoders.qwen3vl.Qwen3VLClipModel):
    """Qwen3-VL clip model with ``layer_norm_hidden_state`` switched on."""

    def __init__(self, device="cpu", dtype=None, attention_mask=True, model_options={}, model_type="qwen3vl_8b"):
        super().__init__(
            device=device,
            dtype=dtype,
            attention_mask=attention_mask,
            model_options=model_options,
            model_type=model_type,
        )
        # Have the final RMSNorm applied to the tapped last layer.
        self.layer_norm_hidden_state = True
class BooguTEModel(sd1_clip.SD1ClipModel):
    """SD1ClipModel wrapper exposing the Boogu clip model under the name ``qwen3vl_8b``."""

    def __init__(self, device="cpu", dtype=None, model_options={}):
        def build_clip(**kw):
            # Pin the model type; everything else comes from SD1ClipModel.
            return BooguQwen3VLClipModel(**kw, model_type="qwen3vl_8b")

        super().__init__(
            device=device,
            dtype=dtype,
            name="qwen3vl_8b",
            clip_model=build_clip,
            model_options=model_options,
        )
def te(dtype_llama=None, llama_quantization_metadata=None):
    """Return a ``BooguTEModel`` subclass with dtype / quantization settings baked in.

    Args:
        dtype_llama: when not ``None``, replaces the ``dtype`` given at construction.
        llama_quantization_metadata: when not ``None``, stored under
            ``model_options["quantization_metadata"]`` on a copy of the options.
    """
    class BooguTEModel_(BooguTEModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            chosen_dtype = dtype if dtype_llama is None else dtype_llama
            options = model_options
            if llama_quantization_metadata is not None:
                # Build a new dict so the caller's options are left untouched.
                options = {**model_options, "quantization_metadata": llama_quantization_metadata}
            super().__init__(device=device, dtype=chosen_dtype, model_options=options)

    return BooguTEModel_

View File

@ -25,6 +25,11 @@ CLI_FEATURE_FLAG_REGISTRY: dict[str, FeatureFlagInfo] = {
"default": False, "default": False,
"description": "Show the sign-in button in the frontend even when not signed in", "description": "Show the sign-in button in the frontend even when not signed in",
}, },
"enable_telemetry": {
"type": "bool",
"default": False,
"description": "Signal the frontend that telemetry collection is enabled",
},
} }

View File

@ -149,3 +149,59 @@ class MotionControlRequest(BaseModel):
character_orientation: str = Field(...) character_orientation: str = Field(...)
mode: str = Field(..., description="'pro' or 'std'") mode: str = Field(..., description="'pro' or 'std'")
model_name: str = Field(...) model_name: str = Field(...)
class Kling3TurboSettings(BaseModel):
    """Output settings shared by the Kling 3 Turbo text-to-video and image-to-video requests."""

    # Defaults to "720p".
    resolution: str = Field("720p", description="'720p' or '1080p'")
    # Optional; per the description it applies to text-to-video only.
    aspect_ratio: str | None = Field(None, description="'16:9'/'9:16'/'1:1'; text-to-video only")
    # Defaults to 5; the description gives the range as 3-15 seconds (not enforced here).
    duration: int = Field(5, description="3-15 second")
class Kling3TurboText2VideoRequest(BaseModel):
    """Request body for a Kling 3 Turbo text-to-video task: a required prompt plus optional settings."""

    # Length limit and multi-shot syntax are documented only; no validator enforces them here.
    prompt: str = Field(..., description="<=3072 chars; may use multi-shot 'shot n, m, words; ...'")
    settings: Kling3TurboSettings | None = Field(None)
class Kling3TurboContent(BaseModel):
    """One content item of an image-to-video request, discriminated by ``type``."""

    type: str = Field(..., description="'prompt' or 'first_frame'")
    # Only one of text / url is expected depending on `type`; this is not validated here.
    text: str | None = Field(None, description="for type=prompt; <=2500 chars")
    url: str | None = Field(None, description="for type=first_frame")
class Kling3TurboImage2VideoRequest(BaseModel):
    """Request body for a Kling 3 Turbo image-to-video task: content items plus optional settings."""

    contents: list[Kling3TurboContent] = Field(..., description="prompt + first_frame materials")
    settings: Kling3TurboSettings | None = Field(None)
class Kling3TurboCreateData(BaseModel):
    """``data`` payload of a task-creation response; every field is optional."""

    id: str | None = Field(None, description="Task ID")
    status: str | None = Field(None)
    message: str | None = Field(None)
class Kling3TurboCreateResponse(BaseModel):
    """Envelope returned when a Kling 3 Turbo task is created; ``data`` holds a single task record."""

    code: int | None = Field(None)
    message: str | None = Field(None)
    request_id: str | None = Field(None)
    data: Kling3TurboCreateData | None = Field(None)
class Kling3TurboOutput(BaseModel):
    """One produced asset listed in a task's ``outputs``."""

    type: str | None = Field(None, description="'video', 'image', 'audio', ...")
    id: str | None = Field(None)
    url: str | None = Field(None)
    # Typed as a string, not a number.
    duration: str | None = Field(None)
class Kling3TurboTaskData(BaseModel):
    """State of one task as reported by a query response, including any outputs."""

    id: str | None = Field(None)
    status: str | None = Field(None, description="submitted | processing | succeeded | failed")
    message: str | None = Field(None)
    outputs: list[Kling3TurboOutput] | None = Field(None)
class Kling3TurboQueryResponse(BaseModel):
    """Envelope returned when querying Kling 3 Turbo tasks."""

    code: int | None = Field(None)
    message: str | None = Field(None)
    request_id: str | None = Field(None)
    # Unlike the create response, `data` is a list of task records.
    data: list[Kling3TurboTaskData] | None = Field(None)

View File

@ -10,6 +10,7 @@ from pydantic import BaseModel, Field, confloat
class LumaIO: class LumaIO:
LUMA_REF = "LUMA_REF" LUMA_REF = "LUMA_REF"
LUMA_CONCEPTS = "LUMA_CONCEPTS" LUMA_CONCEPTS = "LUMA_CONCEPTS"
LUMA_RAY32_KEYFRAME = "LUMA_RAY32_KEYFRAME"
class LumaReference: class LumaReference:
@ -20,13 +21,14 @@ class LumaReference:
def create_api_model(self, download_url: str): def create_api_model(self, download_url: str):
return LumaImageRef(url=download_url, weight=self.weight) return LumaImageRef(url=download_url, weight=self.weight)
class LumaReferenceChain: class LumaReferenceChain:
def __init__(self, first_ref: LumaReference=None): def __init__(self, first_ref: LumaReference = None):
self.refs: list[LumaReference] = [] self.refs: list[LumaReference] = []
if first_ref: if first_ref:
self.refs.append(first_ref) self.refs.append(first_ref)
def add(self, luma_ref: LumaReference=None): def add(self, luma_ref: LumaReference = None):
self.refs.append(luma_ref) self.refs.append(luma_ref)
def create_api_model(self, download_urls: list[str], max_refs=4): def create_api_model(self, download_urls: list[str], max_refs=4):
@ -124,7 +126,7 @@ def get_luma_concepts(include_none=False):
"pull_out", "pull_out",
"aerial", "aerial",
"crane_up", "crane_up",
"eye_level" "eye_level",
] ]
@ -162,8 +164,8 @@ class LumaVideoModelOutputDuration(str, Enum):
class LumaGenerationType(str, Enum): class LumaGenerationType(str, Enum):
video = 'video' video = "video"
image = 'image' image = "image"
class LumaState(str, Enum): class LumaState(str, Enum):
@ -174,86 +176,109 @@ class LumaState(str, Enum):
class LumaAssets(BaseModel): class LumaAssets(BaseModel):
video: Optional[str] = Field(None, description='The URL of the video') video: Optional[str] = Field(None, description="The URL of the video")
image: Optional[str] = Field(None, description='The URL of the image') image: Optional[str] = Field(None, description="The URL of the image")
progress_video: Optional[str] = Field(None, description='The URL of the progress video') progress_video: Optional[str] = Field(None, description="The URL of the progress video")
class LumaImageRef(BaseModel): class LumaImageRef(BaseModel):
"""Used for image gen""" """Used for image gen"""
url: str = Field(..., description='The URL of the image reference')
weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference') url: str = Field(..., description="The URL of the image reference")
weight: confloat(ge=0.0, le=1.0) = Field(..., description="The weight of the image reference")
class LumaImageReference(BaseModel): class LumaImageReference(BaseModel):
"""Used for video gen""" """Used for video gen"""
type: Optional[str] = Field('image', description='Input type, defaults to image')
url: str = Field(..., description='The URL of the image') type: Optional[str] = Field("image", description="Input type, defaults to image")
url: str = Field(..., description="The URL of the image")
class LumaModifyImageRef(BaseModel): class LumaModifyImageRef(BaseModel):
url: str = Field(..., description='The URL of the image reference') url: str = Field(..., description="The URL of the image reference")
weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference') weight: confloat(ge=0.0, le=1.0) = Field(..., description="The weight of the image reference")
class LumaCharacterRef(BaseModel): class LumaCharacterRef(BaseModel):
identity0: LumaImageIdentity = Field(..., description='The image identity object') identity0: LumaImageIdentity = Field(..., description="The image identity object")
class LumaImageIdentity(BaseModel): class LumaImageIdentity(BaseModel):
images: list[str] = Field(..., description='The URLs of the image identity') images: list[str] = Field(..., description="The URLs of the image identity")
class LumaGenerationReference(BaseModel): class LumaGenerationReference(BaseModel):
type: str = Field('generation', description='Input type, defaults to generation') type: str = Field("generation", description="Input type, defaults to generation")
id: str = Field(..., description='The ID of the generation') id: str = Field(..., description="The ID of the generation")
class LumaKeyframes(BaseModel): class LumaKeyframes(BaseModel):
frame0: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description='') frame0: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description="")
frame1: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description='') frame1: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description="")
class LumaConceptObject(BaseModel): class LumaConceptObject(BaseModel):
key: str = Field(..., description='Camera Concept name') key: str = Field(..., description="Camera Concept name")
class LumaImageGenerationRequest(BaseModel): class LumaImageGenerationRequest(BaseModel):
prompt: str = Field(..., description='The prompt of the generation') prompt: str = Field(..., description="The prompt of the generation")
model: LumaImageModel = Field(LumaImageModel.photon_1, description='The image model used for the generation') model: LumaImageModel = Field(LumaImageModel.photon_1, description="The image model used for the generation")
aspect_ratio: Optional[LumaAspectRatio] = Field(LumaAspectRatio.ratio_16_9, description='The aspect ratio of the generation') aspect_ratio: Optional[LumaAspectRatio] = Field(LumaAspectRatio.ratio_16_9)
image_ref: Optional[list[LumaImageRef]] = Field(None, description='List of image reference objects') image_ref: Optional[list[LumaImageRef]] = Field(None, description="List of image reference objects")
style_ref: Optional[list[LumaImageRef]] = Field(None, description='List of style reference objects') style_ref: Optional[list[LumaImageRef]] = Field(None, description="List of style reference objects")
character_ref: Optional[LumaCharacterRef] = Field(None, description='The image identity object') character_ref: Optional[LumaCharacterRef] = Field(None, description="The image identity object")
modify_image_ref: Optional[LumaModifyImageRef] = Field(None, description='The modify image reference object') modify_image_ref: Optional[LumaModifyImageRef] = Field(None, description="The modify image reference object")
class LumaGenerationRequest(BaseModel): class LumaGenerationRequest(BaseModel):
prompt: str = Field(..., description='The prompt of the generation') prompt: str = Field(..., description="The prompt of the generation")
model: LumaVideoModel = Field(LumaVideoModel.ray_2, description='The video model used for the generation') model: LumaVideoModel = Field(LumaVideoModel.ray_2, description="The video model used for the generation")
duration: Optional[LumaVideoModelOutputDuration] = Field(None, description='The duration of the generation') duration: Optional[LumaVideoModelOutputDuration] = Field(None, description="The duration of the generation")
aspect_ratio: Optional[LumaAspectRatio] = Field(None, description='The aspect ratio of the generation') aspect_ratio: Optional[LumaAspectRatio] = Field(None, description="The aspect ratio of the generation")
resolution: Optional[LumaVideoOutputResolution] = Field(None, description='The resolution of the generation') resolution: Optional[LumaVideoOutputResolution] = Field(None, description="The resolution of the generation")
loop: Optional[bool] = Field(None, description='Whether to loop the video') loop: Optional[bool] = Field(None, description="Whether to loop the video")
keyframes: Optional[LumaKeyframes] = Field(None, description='The keyframes of the generation') keyframes: Optional[LumaKeyframes] = Field(None, description="The keyframes of the generation")
concepts: Optional[list[LumaConceptObject]] = Field(None, description='Camera Concepts to apply to generation') concepts: Optional[list[LumaConceptObject]] = Field(None, description="Camera Concepts to apply to generation")
class LumaGeneration(BaseModel): class LumaGeneration(BaseModel):
id: str = Field(..., description='The ID of the generation') id: str = Field(..., description="The ID of the generation")
generation_type: LumaGenerationType = Field(..., description='Generation type, image or video') generation_type: LumaGenerationType = Field(..., description="Generation type, image or video")
state: LumaState = Field(..., description='The state of the generation') state: LumaState = Field(..., description="The state of the generation")
failure_reason: Optional[str] = Field(None, description='The reason for the state of the generation') failure_reason: Optional[str] = Field(None, description="The reason for the state of the generation")
created_at: str = Field(..., description='The date and time when the generation was created') created_at: str = Field(..., description="The date and time when the generation was created")
assets: Optional[LumaAssets] = Field(None, description='The assets of the generation') assets: Optional[LumaAssets] = Field(None, description="The assets of the generation")
model: str = Field(..., description='The model used for the generation') model: str = Field(..., description="The model used for the generation")
request: Union[LumaGenerationRequest, LumaImageGenerationRequest] = Field(..., description="The request used for the generation") request: Union[LumaGenerationRequest, LumaImageGenerationRequest] = Field(...)
class Luma2ImageRef(BaseModel): class Luma2ImageRef(BaseModel):
url: str | None = None url: str | None = None
data: str | None = None data: str | None = None
media_type: str | None = None media_type: str | None = None
generation_id: str | None = Field(None, description="reference a prior generation (extend / source reuse)")
class Luma2VideoEdit(BaseModel):
    """Edit controls for Ray 3.2 ``video_edit`` generations."""

    auto_controls: bool | None = Field(None, description="derive a conditioning schedule from the source (recommended)")
    # Plain string here; the description says the allowed values are constrained by IO.Combo, not by this model.
    strength: str | None = Field(None, description="'adhere_1' .. 'reimagine_3'; constrained by IO.Combo")
class Luma2VideoOptions(BaseModel):
    """Ray 3.2 ``video`` output settings (text / image / keyframe / edit / extend)."""

    # All fields are optional and default to None.
    resolution: str | None = Field(None, description="360p | 540p | 720p | 1080p")
    duration: str | None = Field(None, description="5s | 10s")
    loop: bool | None = Field(None)
    start_frame: Luma2ImageRef | None = Field(None)
    end_frame: Luma2ImageRef | None = Field(None)
    keyframes: list[Luma2ImageRef] | None = Field(None)
    # NOTE(review): presumably parallel to `keyframes` (one index per keyframe) — confirm against the API.
    keyframe_indexes: list[int] | None = Field(None)
    edit: Luma2VideoEdit | None = Field(None)
class Luma2GenerationRequest(BaseModel): class Luma2GenerationRequest(BaseModel):
@ -266,6 +291,7 @@ class Luma2GenerationRequest(BaseModel):
web_search: bool | None = None web_search: bool | None = None
image_ref: list[Luma2ImageRef] | None = None image_ref: list[Luma2ImageRef] | None = None
source: Luma2ImageRef | None = None source: Luma2ImageRef | None = None
video: Luma2VideoOptions | None = Field(None)
class Luma2Generation(BaseModel): class Luma2Generation(BaseModel):
@ -277,3 +303,31 @@ class Luma2Generation(BaseModel):
output: list[LumaImageReference] | None = None output: list[LumaImageReference] | None = None
failure_reason: str | None = None failure_reason: str | None = None
failure_code: str | None = None failure_code: str | None = None
# --- Ray 3.2 multi-keyframe chain ---
# Position modes stored in ``LumaRay32KeyframeItem.mode``; they tell the consumer how to read ``value``.
LUMA_KEYFRAME_MODE_FRACTION = "fraction" # value in [0.0, 1.0] of the output video duration
LUMA_KEYFRAME_MODE_SECONDS = "seconds" # absolute time, in seconds, from the start of the output
class LumaRay32KeyframeItem:
    """A single guide image together with its position on the Ray 3.2 output timeline.

    Attributes:
        image: The guide image.
        mode: How ``value`` is interpreted; LUMA_KEYFRAME_MODE_FRACTION or LUMA_KEYFRAME_MODE_SECONDS.
        value: The position, expressed in the unit selected by ``mode``.
    """

    def __init__(self, image: torch.Tensor, mode: str, value: float):
        self.image, self.mode, self.value = image, mode, value
class LumaRay32KeyframeChain:
    """Ordered collection of keyframe items, kept in insertion order."""

    def __init__(self):
        self.items: list[LumaRay32KeyframeItem] = []

    def add(self, item: LumaRay32KeyframeItem) -> None:
        """Append ``item`` to the end of the chain."""
        self.items.append(item)

    def clone(self) -> "LumaRay32KeyframeChain":
        """Return a new chain with its own list holding the same item objects (shallow copy)."""
        duplicate = LumaRay32KeyframeChain()
        duplicate.items = self.items[:]
        return duplicate

View File

@ -60,6 +60,12 @@ from comfy_api_nodes.apis.kling import (
OmniProImageRequest, OmniProImageRequest,
OmniProReferences2VideoRequest, OmniProReferences2VideoRequest,
OmniProText2VideoRequest, OmniProText2VideoRequest,
Kling3TurboSettings,
Kling3TurboText2VideoRequest,
Kling3TurboContent,
Kling3TurboImage2VideoRequest,
Kling3TurboCreateResponse,
Kling3TurboQueryResponse,
TaskStatusResponse, TaskStatusResponse,
TextToVideoWithAudioRequest, TextToVideoWithAudioRequest,
) )
@ -2847,6 +2853,67 @@ class MotionControl(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
def build_turbo_shot_prompt(multi_prompt: list[MultiPromptEntry]) -> str:
    """Render storyboard entries into the Turbo multi-shot prompt 'shot n, m, words; ...'.

    Shots are numbered from 1 in list order; each duration is truncated to an int.
    The result always ends with a trailing ';'.
    """
    shots = []
    for shot_number, entry in enumerate(multi_prompt, start=1):
        shots.append(f"shot {shot_number}, {int(entry.duration)}, {entry.prompt}")
    return "; ".join(shots) + ";"
def _turbo_video_url(response: Kling3TurboQueryResponse) -> str:
"""Extract the result video URL from a /tasks response (data[].outputs[] where type == 'video')."""
task = response.data[0] if response.data else None
if task and task.outputs:
for output in task.outputs:
if output.type == "video" and output.url:
return output.url
raise RuntimeError(f"Kling 3.0 Turbo task finished without a video output: {response.model_dump()}")
async def execute_kling_turbo(
    cls: type[IO.ComfyNode],
    *,
    prompt: str,
    resolution: str,
    aspect_ratio: str,
    duration: int,
    start_frame: torch.Tensor | None,
) -> IO.NodeOutput:
    """Create a Kling 3.0 Turbo task, poll it until it finishes and return the downloaded video.

    With ``start_frame`` the image-to-video endpoint is used (``aspect_ratio`` is not sent and the
    prompt is only included when non-empty); otherwise the text-to-video endpoint is used.

    Raises:
        RuntimeError: when the create call returns no task id, or the finished task has no video output.
    """
    if start_frame is None:
        endpoint_path = "/proxy/kling/text-to-video/kling-3.0-turbo"
        payload = Kling3TurboText2VideoRequest(
            prompt=prompt,
            settings=Kling3TurboSettings(resolution=resolution, aspect_ratio=aspect_ratio, duration=duration),
        )
    else:
        validate_image_dimensions(start_frame, min_width=300, min_height=300)
        validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1))
        frame_item = Kling3TurboContent(type="first_frame", url=tensor_to_base64_string(start_frame))
        # The prompt entry, when present, goes before the frame entry.
        contents = [Kling3TurboContent(type="prompt", text=prompt), frame_item] if prompt else [frame_item]
        endpoint_path = "/proxy/kling/image-to-video/kling-3.0-turbo"
        payload = Kling3TurboImage2VideoRequest(
            contents=contents,
            settings=Kling3TurboSettings(resolution=resolution, duration=duration),  # i2v: no aspect_ratio
        )

    create = await sync_op(
        cls,
        ApiEndpoint(path=endpoint_path, method="POST"),
        response_model=Kling3TurboCreateResponse,
        data=payload,
    )
    if not (create.data and create.data.id):
        raise RuntimeError(f"Kling 3.0 Turbo create failed. Code: {create.code}, Message: {create.message}")

    def _first_task_status(result):
        # The task list may be empty; report no status in that case.
        return result.data[0].status if result.data else None

    final_response = await poll_op(
        cls,
        ApiEndpoint(path="/proxy/kling/tasks", query_params={"task_ids": create.data.id}),
        response_model=Kling3TurboQueryResponse,
        status_extractor=_first_task_status,
    )
    video_url = _turbo_video_url(final_response)
    return IO.NodeOutput(await download_url_to_video_output(video_url))
class KlingVideoNode(IO.ComfyNode): class KlingVideoNode(IO.ComfyNode):
@classmethod @classmethod
@ -2884,7 +2951,11 @@ class KlingVideoNode(IO.ComfyNode):
], ],
tooltip="Generate a series of video segments with individual prompts and durations.", tooltip="Generate a series of video segments with individual prompts and durations.",
), ),
IO.Boolean.Input("generate_audio", default=True), IO.Boolean.Input(
"generate_audio",
default=True,
tooltip="'kling-3.0-turbo' always generates native audio, so the audio toggle is ignored.",
),
IO.DynamicCombo.Input( IO.DynamicCombo.Input(
"model", "model",
options=[ options=[
@ -2899,6 +2970,17 @@ class KlingVideoNode(IO.ComfyNode):
), ),
], ],
), ),
IO.DynamicCombo.Option(
"kling-3.0-turbo",
[
IO.Combo.Input("resolution", options=["1080p", "720p"], default="720p"),
IO.Combo.Input(
"aspect_ratio",
options=["16:9", "9:16", "1:1"],
tooltip="Ignored in image-to-video mode.",
),
],
),
], ],
tooltip="Model and generation settings.", tooltip="Model and generation settings.",
), ),
@ -2930,6 +3012,7 @@ class KlingVideoNode(IO.ComfyNode):
price_badge=IO.PriceBadge( price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends( depends_on=IO.PriceBadgeDepends(
widgets=[ widgets=[
"model",
"model.resolution", "model.resolution",
"generate_audio", "generate_audio",
"multi_shot", "multi_shot",
@ -2944,14 +3027,7 @@ class KlingVideoNode(IO.ComfyNode):
), ),
expr=""" expr="""
( (
$rates := {
"4k": {"off": 0.42, "on": 0.42},
"1080p": {"off": 0.112, "on": 0.168},
"720p": {"off": 0.084, "on": 0.126}
};
$res := $lookup(widgets, "model.resolution"); $res := $lookup(widgets, "model.resolution");
$audio := widgets.generate_audio ? "on" : "off";
$rate := $lookup($lookup($rates, $res), $audio);
$ms := widgets.multi_shot; $ms := widgets.multi_shot;
$isSb := $ms != "disabled"; $isSb := $ms != "disabled";
$n := $isSb ? $number($substring($ms, 0, 1)) : 0; $n := $isSb ? $number($substring($ms, 0, 1)) : 0;
@ -2962,8 +3038,19 @@ class KlingVideoNode(IO.ComfyNode):
$d5 := $n >= 5 ? $lookup(widgets, "multi_shot.storyboard_5_duration") : 0; $d5 := $n >= 5 ? $lookup(widgets, "multi_shot.storyboard_5_duration") : 0;
$d6 := $n >= 6 ? $lookup(widgets, "multi_shot.storyboard_6_duration") : 0; $d6 := $n >= 6 ? $lookup(widgets, "multi_shot.storyboard_6_duration") : 0;
$dur := $isSb ? $d1 + $d2 + $d3 + $d4 + $d5 + $d6 : $lookup(widgets, "multi_shot.duration"); $dur := $isSb ? $d1 + $d2 + $d3 + $d4 + $d5 + $d6 : $lookup(widgets, "multi_shot.duration");
widgets.model = "kling-3.0-turbo"
? {"type":"usd","usd": ($res = "1080p" ? 0.14 : 0.112) * $dur}
: (
$rates := {
"4k": {"off": 0.42, "on": 0.42},
"1080p": {"off": 0.112, "on": 0.168},
"720p": {"off": 0.084, "on": 0.126}
};
$audio := widgets.generate_audio ? "on" : "off";
$rate := $lookup($lookup($rates, $res), $audio);
{"type":"usd","usd": $rate * $dur} {"type":"usd","usd": $rate * $dur}
) )
)
""", """,
), ),
) )
@ -3015,6 +3102,17 @@ class KlingVideoNode(IO.ComfyNode):
duration = multi_shot["duration"] duration = multi_shot["duration"]
validate_string(multi_shot["prompt"], min_length=1, max_length=2500) validate_string(multi_shot["prompt"], min_length=1, max_length=2500)
if model["model"] == "kling-3.0-turbo":
turbo_prompt = build_turbo_shot_prompt(multi_prompt_list) if custom_multi_shot else multi_shot["prompt"]
return await execute_kling_turbo(
cls,
prompt=turbo_prompt,
resolution=model["resolution"],
aspect_ratio=model["aspect_ratio"],
duration=duration,
start_frame=start_frame,
)
if start_frame is not None: if start_frame is not None:
validate_image_dimensions(start_frame, min_width=300, min_height=300) validate_image_dimensions(start_frame, min_width=300, min_height=300)
validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1)) validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1))

View File

@ -3,9 +3,13 @@ from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.luma import ( from comfy_api_nodes.apis.luma import (
LUMA_KEYFRAME_MODE_FRACTION,
LUMA_KEYFRAME_MODE_SECONDS,
Luma2Generation, Luma2Generation,
Luma2GenerationRequest, Luma2GenerationRequest,
Luma2ImageRef, Luma2ImageRef,
Luma2VideoEdit,
Luma2VideoOptions,
LumaAspectRatio, LumaAspectRatio,
LumaCharacterRef, LumaCharacterRef,
LumaConceptChain, LumaConceptChain,
@ -18,6 +22,8 @@ from comfy_api_nodes.apis.luma import (
LumaIO, LumaIO,
LumaKeyframes, LumaKeyframes,
LumaModifyImageRef, LumaModifyImageRef,
LumaRay32KeyframeChain,
LumaRay32KeyframeItem,
LumaReference, LumaReference,
LumaReferenceChain, LumaReferenceChain,
LumaVideoModel, LumaVideoModel,
@ -33,6 +39,7 @@ from comfy_api_nodes.util import (
sync_op, sync_op,
upload_image_to_comfyapi, upload_image_to_comfyapi,
upload_images_to_comfyapi, upload_images_to_comfyapi,
upload_video_to_comfyapi,
validate_string, validate_string,
) )
@ -692,7 +699,10 @@ async def _luma2_upload_image_refs(
async def _luma2_submit_and_poll( async def _luma2_submit_and_poll(
cls: type[IO.ComfyNode], cls: type[IO.ComfyNode],
request: Luma2GenerationRequest, request: Luma2GenerationRequest,
) -> Input.Image: *,
estimated_duration: int | None = None,
) -> Luma2Generation:
"""Submit a Luma Agents generation and poll until done; returns the completed generation."""
initial = await sync_op( initial = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/luma_2/generations", method="POST"), ApiEndpoint(path="/proxy/luma_2/generations", method="POST"),
@ -700,21 +710,21 @@ async def _luma2_submit_and_poll(
data=request, data=request,
) )
if not initial.id: if not initial.id:
raise RuntimeError("Luma 2 API did not return a generation id.") raise RuntimeError("Luma API did not return a generation id.")
final = await poll_op( final = await poll_op(
cls, cls,
ApiEndpoint(path=f"/proxy/luma_2/generations/{initial.id}", method="GET"), ApiEndpoint(path=f"/proxy/luma_2/generations/{initial.id}", method="GET"),
response_model=Luma2Generation, response_model=Luma2Generation,
status_extractor=lambda r: r.state, status_extractor=lambda r: r.state,
progress_extractor=lambda r: None, progress_extractor=lambda r: None,
estimated_duration=estimated_duration,
) )
if not final.output: if not final.output or not final.output[0].url:
msg = final.failure_reason or "no output returned" msg = final.failure_reason or "no output returned"
raise RuntimeError(f"Luma 2 generation failed: {msg}") if final.failure_code:
url = final.output[0].url msg = f"{msg} [{final.failure_code}]"
if not url: raise RuntimeError(f"Luma generation failed: {msg}")
raise RuntimeError("Luma 2 generation completed without an output URL.") return final
return await download_url_to_image_tensor(url)
class LumaImageNode(IO.ComfyNode): class LumaImageNode(IO.ComfyNode):
@ -843,7 +853,8 @@ class LumaImageNode(IO.ComfyNode):
web_search=model["web_search"], web_search=model["web_search"],
image_ref=await _luma2_upload_image_refs(cls, model.get("image_ref"), max_count=9), image_ref=await _luma2_upload_image_refs(cls, model.get("image_ref"), max_count=9),
) )
return IO.NodeOutput(await _luma2_submit_and_poll(cls, request)) final = await _luma2_submit_and_poll(cls, request)
return IO.NodeOutput(await download_url_to_image_tensor(final.output[0].url))
class LumaImageEditNode(IO.ComfyNode): class LumaImageEditNode(IO.ComfyNode):
@ -929,7 +940,533 @@ class LumaImageEditNode(IO.ComfyNode):
web_search=model["web_search"], web_search=model["web_search"],
image_ref=await _luma2_upload_image_refs(cls, model.get("image_ref"), max_count=8), image_ref=await _luma2_upload_image_refs(cls, model.get("image_ref"), max_count=8),
) )
return IO.NodeOutput(await _luma2_submit_and_poll(cls, request)) final = await _luma2_submit_and_poll(cls, request)
return IO.NodeOutput(await download_url_to_image_tensor(final.output[0].url))
# Price badge for Ray 3.2 video nodes that expose both "resolution" and "duration" widgets
# (used by the text-to-video and keyframes-to-video nodes). The expression looks up a fixed
# USD amount in a resolution -> duration table and reports it as {"type": "usd"}.
_BADGE_RAY32_VIDEO = IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["resolution", "duration"]),
expr="""
(
$p := {
"360p": {"5s": 0.06, "10s": 0.18},
"540p": {"5s": 0.15, "10s": 0.45},
"720p": {"5s": 0.3, "10s": 0.9},
"1080p": {"5s": 1.2, "10s": 3.6}
};
{"type": "usd", "usd": $lookup($lookup($p, widgets.resolution), widgets.duration)}
)
""",
)
# Price badge for Ray 3.2 nodes whose output is fixed at 5 seconds (image-to-video and extend):
# depends only on the "resolution" widget and reports one USD amount per resolution.
_BADGE_RAY32_VIDEO_5S = IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["resolution"]),
expr="""
(
$p := {"360p": 0.06, "540p": 0.15, "720p": 0.3, "1080p": 1.2};
{"type": "usd", "usd": $lookup($p, widgets.resolution)}
)
""",
)
# Price badge for the Ray 3.2 video edit node: reports a USD range ("range_usd") with a min and a
# max per resolution, annotated with the note "(by source length)".
_BADGE_RAY32_EDIT = IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["resolution"]),
expr="""
(
$p := {
"360p": {"min": 0.54, "max": 1.08},
"540p": {"min": 0.72, "max": 1.44},
"720p": {"min": 1.08, "max": 2.16},
"1080p": {"min": 2.16, "max": 4.32}
};
$r := $lookup($p, widgets.resolution);
{"type": "range_usd", "min_usd": $r.min, "max_usd": $r.max, "format": {"note": "(by source length)"}}
)
""",
)
# Price badge for the Ray 3.2 video reframe node: one USD amount per resolution, displayed with a
# "/second" suffix.
_BADGE_RAY32_REFRAME = IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["resolution"]),
expr="""
(
$p := {"360p": 0.03, "540p": 0.06, "720p": 0.12, "1080p": 0.36};
{"type": "usd", "usd": $lookup($p, widgets.resolution), "format": {"suffix": "/second"}}
)
""",
)
def _ray32_seed_input() -> IO.Input:
    """Build the shared 'seed' integer widget used by the Ray 3.2 nodes (range 0 .. 2**64 - 1)."""
    seed_settings = {
        "default": 0,
        "min": 0,
        "max": 0xFFFFFFFFFFFFFFFF,
        "control_after_generate": True,
        "tooltip": "Seed to determine if node should re-run; results are nondeterministic regardless of seed.",
    }
    return IO.Int.Input("seed", **seed_settings)
async def _ray32_generate(cls: type[IO.ComfyNode], request: Luma2GenerationRequest) -> IO.NodeOutput:
    """Submit a ray-3.2 request, wait for it to finish and return (video, generation_id).

    The video is downloaded from the first output of the finished generation; the id output
    falls back to an empty string when the generation carries no id.
    """
    generation = await _luma2_submit_and_poll(cls, request, estimated_duration=120)
    first_output = generation.output[0]
    clip = await download_url_to_video_output(first_output.url)
    generation_id = generation.id if generation.id else ""
    return IO.NodeOutput(clip, generation_id)
class LumaRay32TextToVideoNode(IO.ComfyNode):
    """Node that generates a video from a text prompt with Luma's Ray 3.2 model."""

    @classmethod
    def define_schema(cls) -> IO.Schema:
        """Declare the node's widgets, its (video, generation_id) outputs and its price badge."""
        widgets = [
            IO.String.Input("prompt", multiline=True, default="", tooltip="Text prompt for the video generation."),
            IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1", "4:3", "3:4", "21:9"]),
            IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"),
            IO.Combo.Input("duration", options=["5s", "10s"]),
            IO.Boolean.Input(
                "loop",
                default=False,
                tooltip="Make the video loop seamlessly. Only available with 5s duration.",
            ),
            _ray32_seed_input(),
        ]
        hidden_inputs = [
            IO.Hidden.auth_token_comfy_org,
            IO.Hidden.api_key_comfy_org,
            IO.Hidden.unique_id,
        ]
        return IO.Schema(
            node_id="LumaRay32TextToVideoNode",
            display_name="Luma Ray 3.2 Text to Video",
            category="partner/video/Luma",
            description="Generate a video from a text prompt using Luma's Ray 3.2 model.",
            inputs=widgets,
            outputs=[IO.Video.Output(), IO.String.Output(display_name="generation_id")],
            hidden=hidden_inputs,
            is_api_node=True,
            price_badge=_BADGE_RAY32_VIDEO,
        )

    @classmethod
    async def execute(
        cls, prompt: str, aspect_ratio: str, resolution: str, duration: str, loop: bool, seed: int
    ) -> IO.NodeOutput:
        """Validate the inputs, build the ray-3.2 'video' request and run it.

        ``seed`` is accepted but not sent in the request.

        Raises:
            ValueError: when looping is requested together with the 10s duration.
        """
        validate_string(prompt, strip_whitespace=True, min_length=1, max_length=6000)
        if duration == "10s" and loop:
            raise ValueError("Looping is only available with 5s duration on Ray 3.2.")
        # A disabled loop is passed as None rather than False.
        options = Luma2VideoOptions(resolution=resolution, duration=duration, loop=loop if loop else None)
        request = Luma2GenerationRequest(
            prompt=prompt,
            model="ray-3.2",
            type="video",
            aspect_ratio=aspect_ratio,
            video=options,
        )
        return await _ray32_generate(cls, request)
class LumaRay32ImageToVideoNode(IO.ComfyNode):
    """Node that generates a 5-second Ray 3.2 video anchored on a start and/or end frame."""

    @classmethod
    def define_schema(cls) -> IO.Schema:
        """Declare the node's widgets, optional frame inputs, outputs and price badge."""
        widgets = [
            IO.String.Input("prompt", multiline=True, default="", tooltip="Text prompt for the video generation."),
            IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"),
            IO.Boolean.Input(
                "loop",
                default=False,
                tooltip="Make the video loop seamlessly. Not available when an end_frame is set.",
            ),
            _ray32_seed_input(),
            IO.Image.Input("start_frame", optional=True, tooltip="First frame of the generated video."),
            IO.Image.Input("end_frame", optional=True, tooltip="Last frame of the generated video."),
        ]
        hidden_inputs = [
            IO.Hidden.auth_token_comfy_org,
            IO.Hidden.api_key_comfy_org,
            IO.Hidden.unique_id,
        ]
        return IO.Schema(
            node_id="LumaRay32ImageToVideoNode",
            display_name="Luma Ray 3.2 Image to Video",
            category="partner/video/Luma",
            description="Generate a video from a start and/or end frame using Luma's Ray 3.2 model. "
            "Image-anchored generations are always 5 seconds.",
            inputs=widgets,
            outputs=[IO.Video.Output(), IO.String.Output(display_name="generation_id")],
            hidden=hidden_inputs,
            is_api_node=True,
            price_badge=_BADGE_RAY32_VIDEO_5S,
        )

    @classmethod
    async def execute(
        cls,
        prompt: str,
        resolution: str,
        loop: bool,
        seed: int,
        start_frame: torch.Tensor | None = None,
        end_frame: torch.Tensor | None = None,
    ) -> IO.NodeOutput:
        """Upload the provided frame(s) and run a 5s ray-3.2 'video' generation anchored on them.

        ``seed`` is accepted but not sent in the request.

        Raises:
            ValueError: when neither frame is given, or when looping is requested with an end frame.
        """
        validate_string(prompt, strip_whitespace=True, min_length=1, max_length=6000)
        has_start = start_frame is not None
        has_end = end_frame is not None
        if not (has_start or has_end):
            raise ValueError("Provide at least one of start_frame / end_frame.")
        if has_end and loop:
            raise ValueError("Looping is not available when an end_frame is set.")

        options = Luma2VideoOptions(resolution=resolution, duration="5s", loop=loop if loop else None)
        # Frames are uploaded in start -> end order and attached only when present.
        for field_name, frame in (("start_frame", start_frame), ("end_frame", end_frame)):
            if frame is None:
                continue
            frame_url = await upload_image_to_comfyapi(cls, frame, mime_type="image/png")
            setattr(options, field_name, Luma2ImageRef(url=frame_url))

        request = Luma2GenerationRequest(prompt=prompt, model="ray-3.2", type="video", video=options)
        return await _ray32_generate(cls, request)
class LumaRay32KeyframeNode(IO.ComfyNode):
    """Node that appends one positioned guide image to a Ray 3.2 keyframe chain."""

    @classmethod
    def define_schema(cls) -> IO.Schema:
        """Declare the image input, the position selector (fraction or seconds) and the chain I/O."""
        fraction_option = IO.DynamicCombo.Option(
            "Fraction of duration (0.0-1.0)",
            [
                IO.Float.Input(
                    "fraction",
                    default=0.0,
                    min=0.0,
                    max=1.0,
                    step=0.01,
                    display_mode=IO.NumberDisplay.number,
                    tooltip="Where in the output video this image applies (0.0 = start, 1.0 = end).",
                ),
            ],
        )
        seconds_option = IO.DynamicCombo.Option(
            "Absolute time (seconds)",
            [
                IO.Float.Input(
                    "seconds",
                    default=0.0,
                    min=0.0,
                    max=10.0,
                    step=0.1,
                    display_mode=IO.NumberDisplay.number,
                    tooltip="Time in seconds from the start of the output video where this image applies.",
                ),
            ],
        )
        return IO.Schema(
            node_id="LumaRay32KeyframeNode",
            display_name="Luma Ray 3.2 Keyframe",
            category="partner/video/Luma",
            description="Anchor a guide image to a position on the Ray 3.2 output video timeline. Connect this to "
            "the 'keyframes' input of the Luma Ray 3.2 Keyframes to Video node; chain several together via the "
            "optional 'keyframes' input below.",
            inputs=[
                IO.Image.Input("image", tooltip="Guide image to place at the chosen moment of the output video."),
                IO.DynamicCombo.Input(
                    "position",
                    options=[fraction_option, seconds_option],
                    tooltip="How to place this image on the output video's timeline.",
                ),
                IO.Custom(LumaIO.LUMA_RAY32_KEYFRAME).Input(
                    "keyframes",
                    optional=True,
                    tooltip="Optional earlier keyframes to chain with this one.",
                ),
            ],
            outputs=[IO.Custom(LumaIO.LUMA_RAY32_KEYFRAME).Output(display_name="keyframes")],
        )

    @classmethod
    def execute(
        cls,
        image: torch.Tensor,
        position: dict,
        keyframes: LumaRay32KeyframeChain | None = None,
    ) -> IO.NodeOutput:
        """Return a chain holding the incoming keyframes (if any) plus this image at its position.

        The incoming chain is cloned, so the upstream chain object is left untouched.
        """
        chain = LumaRay32KeyframeChain() if keyframes is None else keyframes.clone()
        # Any selection other than the seconds option is read as a fraction.
        if position["position"] != "Absolute time (seconds)":
            mode = LUMA_KEYFRAME_MODE_FRACTION
            value = float(position["fraction"])
        else:
            mode = LUMA_KEYFRAME_MODE_SECONDS
            value = float(position["seconds"])
        chain.add(LumaRay32KeyframeItem(image=image, mode=mode, value=value))
        return IO.NodeOutput(chain)
class LumaRay32KeyframesToVideoNode(IO.ComfyNode):
    """Node that generates a Ray 3.2 video guided by a chain of positioned keyframe images."""

    @classmethod
    def define_schema(cls) -> IO.Schema:
        """Declare the node's widgets, the keyframe-chain input, outputs and price badge."""
        widgets = [
            IO.String.Input("prompt", multiline=True, default="", tooltip="Text prompt for the video generation."),
            IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"),
            IO.Combo.Input("duration", options=["5s", "10s"]),
            _ray32_seed_input(),
            IO.Custom(LumaIO.LUMA_RAY32_KEYFRAME).Input(
                "keyframes",
                tooltip="Keyframe sequence from Luma Ray 3.2 Keyframe nodes (at least 2).",
            ),
        ]
        hidden_inputs = [
            IO.Hidden.auth_token_comfy_org,
            IO.Hidden.api_key_comfy_org,
            IO.Hidden.unique_id,
        ]
        return IO.Schema(
            node_id="LumaRay32KeyframesToVideoNode",
            display_name="Luma Ray 3.2 Keyframes to Video",
            category="partner/video/Luma",
            description="Generate a video that interpolates through a sequence of guide images, each anchored to a "
            "position on the timeline, using Luma Ray 3.2. Build the sequence with Luma Ray 3.2 Keyframe nodes "
            "(at least 2).",
            inputs=widgets,
            outputs=[IO.Video.Output(), IO.String.Output(display_name="generation_id")],
            hidden=hidden_inputs,
            is_api_node=True,
            price_badge=_BADGE_RAY32_VIDEO,
        )

    @classmethod
    async def execute(
        cls,
        prompt: str,
        resolution: str,
        duration: str,
        seed: int,
        keyframes: LumaRay32KeyframeChain | None = None,
    ) -> IO.NodeOutput:
        """Resolve every keyframe to a frame index, upload the images in timeline order and generate.

        Positions are converted at 24 frames per second of output (120 frames for "5s", 240
        otherwise) and clamped to the valid index range. ``seed`` is accepted but not sent.

        Raises:
            ValueError: with fewer than 2 or more than 64 keyframes, with a seconds position past
                the end of the video, or when two keyframes land on the same output frame.
        """
        validate_string(prompt, strip_whitespace=True, min_length=1, max_length=6000)
        items = [] if keyframes is None else keyframes.items
        count = len(items)
        if count < 2:
            raise ValueError(
                "Connect at least 2 Luma Ray 3.2 Keyframe nodes "
                "(use Luma Ray 3.2 Image to Video for a single start/end frame)."
            )
        if count > 64:
            raise ValueError(f"Ray 3.2 supports at most 64 keyframes; got {count}.")

        maxframe = 120 if duration == "5s" else 240
        duration_seconds = maxframe / 24  # 5.0 or 10.0

        # Chaining order does not matter: each keyframe is placed by its own position,
        # and the list is sorted by resolved frame index afterwards.
        placed: list[tuple[int, torch.Tensor]] = []
        for keyframe in items:
            if keyframe.mode != LUMA_KEYFRAME_MODE_SECONDS:
                frame_index = round(keyframe.value * maxframe)
            else:
                if keyframe.value > duration_seconds:
                    raise ValueError(
                        f"Keyframe position {keyframe.value:g}s is past the end of the {duration} video; "
                        f"use 0-{duration_seconds:g}s (or switch the keyframe to fraction mode)."
                    )
                frame_index = round(keyframe.value * 24)
            frame_index = max(0, min(maxframe, frame_index))
            placed.append((frame_index, keyframe.image))
        placed = sorted(placed, key=lambda pair: pair[0])

        indexes = [pair[0] for pair in placed]
        previous = None
        for current in indexes:
            # After sorting, duplicates are adjacent.
            if current == previous:
                raise ValueError(
                    f"Two keyframes resolve to the same output frame ({current}) for a {duration} video "
                    f"(valid range 0-{maxframe}); give each keyframe a distinct position."
                )
            previous = current

        refs: list[Luma2ImageRef] = []
        for pair in placed:
            image_url = await upload_image_to_comfyapi(cls, pair[1], mime_type="image/png")
            refs.append(Luma2ImageRef(url=image_url))

        options = Luma2VideoOptions(
            resolution=resolution,
            duration=duration,
            keyframes=refs,
            keyframe_indexes=indexes,
        )
        request = Luma2GenerationRequest(prompt=prompt, model="ray-3.2", type="video", video=options)
        return await _ray32_generate(cls, request)
class LumaRay32VideoEditNode(IO.ComfyNode):
    """Node that re-renders an existing video under a new prompt with Luma Ray 3.2."""

    @classmethod
    def define_schema(cls) -> IO.Schema:
        """Declare the source video input, prompt/resolution/strength widgets, outputs and price badge."""
        # "auto" followed by three levels for each of the adhere / flex / reimagine families.
        strength_options = ["auto"] + [
            f"{family}_{level}" for family in ("adhere", "flex", "reimagine") for level in (1, 2, 3)
        ]
        widgets = [
            IO.Video.Input("video", tooltip="Source video to edit. Up to 18 seconds."),
            IO.String.Input("prompt", multiline=True, default="", tooltip="Describes the desired edit."),
            IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"),
            IO.Combo.Input(
                "strength",
                options=strength_options,
                default="auto",
                tooltip="How strongly to preserve vs. reimagine the source. 'auto' lets Ray 3.2 choose; "
                "adhere_* preserves the most, flex_* is balanced, reimagine_* changes the most.",
            ),
            _ray32_seed_input(),
        ]
        hidden_inputs = [
            IO.Hidden.auth_token_comfy_org,
            IO.Hidden.api_key_comfy_org,
            IO.Hidden.unique_id,
        ]
        return IO.Schema(
            node_id="LumaRay32VideoEditNode",
            display_name="Luma Ray 3.2 Video Edit",
            category="partner/video/Luma",
            description="Re-render an existing video under a new prompt using Luma Ray 3.2 (restyle, relight, add "
            "or remove elements) while keeping the original motion. Source video up to 18 seconds; the edited "
            "video keeps the source's length.",
            inputs=widgets,
            outputs=[
                IO.Video.Output(),
                IO.String.Output(display_name="generation_id"),
            ],
            hidden=hidden_inputs,
            is_api_node=True,
            price_badge=_BADGE_RAY32_EDIT,
        )

    @classmethod
    async def execute(
        cls, video: Input.Video, prompt: str, resolution: str, strength: str, seed: int
    ) -> IO.NodeOutput:
        """Upload the source video and run a ray-3.2 'video_edit' generation on it.

        The request duration is "5s" for sources of at most 5 seconds and "10s" otherwise
        (also "10s" when the source length cannot be read). ``seed`` is accepted but not sent.
        """
        validate_string(prompt, strip_whitespace=True, min_length=1, max_length=6000)
        try:
            source_seconds = video.get_duration()
            duration = "10s" if source_seconds > 5.0 else "5s"
        except Exception:
            # Best effort: fall back to the longer bucket when the length is unavailable.
            duration = "10s"
        source_url = await upload_video_to_comfyapi(cls, video, max_duration=18)
        if strength == "auto":
            edit = Luma2VideoEdit(auto_controls=True)
        else:
            edit = Luma2VideoEdit(strength=strength)
        options = Luma2VideoOptions(resolution=resolution, duration=duration, edit=edit)
        request = Luma2GenerationRequest(
            prompt=prompt,
            model="ray-3.2",
            type="video_edit",
            source=Luma2ImageRef(url=source_url, media_type="video/mp4"),
            video=options,
        )
        return await _ray32_generate(cls, request)
class LumaRay32VideoReframeNode(IO.ComfyNode):
    """Node that changes the aspect ratio of an existing video with Luma Ray 3.2."""

    @classmethod
    def define_schema(cls) -> IO.Schema:
        """Declare the source video input, prompt/aspect/resolution widgets, outputs and price badge."""
        widgets = [
            IO.Video.Input("video", tooltip="Source video to reframe. Up to 30 seconds."),
            IO.String.Input(
                "prompt",
                multiline=True,
                default="",
                tooltip="Describes how the newly exposed canvas areas should be filled.",
            ),
            IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1", "4:3", "3:4", "21:9"]),
            IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"),
            _ray32_seed_input(),
        ]
        hidden_inputs = [
            IO.Hidden.auth_token_comfy_org,
            IO.Hidden.api_key_comfy_org,
            IO.Hidden.unique_id,
        ]
        return IO.Schema(
            node_id="LumaRay32VideoReframeNode",
            display_name="Luma Ray 3.2 Video Reframe",
            category="partner/video/Luma",
            description="Change the aspect ratio of an existing video, using Luma Ray 3.2 to fill the newly "
            "exposed canvas areas. Source video up to 30 seconds. Billed per second of output.",
            inputs=widgets,
            outputs=[
                IO.Video.Output(),
                IO.String.Output(display_name="generation_id"),
            ],
            hidden=hidden_inputs,
            is_api_node=True,
            price_badge=_BADGE_RAY32_REFRAME,
        )

    @classmethod
    async def execute(
        cls, video: Input.Video, prompt: str, aspect_ratio: str, resolution: str, seed: int
    ) -> IO.NodeOutput:
        """Upload the source video and run a ray-3.2 'video_reframe' generation on it.

        ``seed`` is accepted but not sent in the request.

        Raises:
            ValueError: when 1080p is combined with a vertical aspect ratio (9:16 or 3:4).
        """
        validate_string(prompt, strip_whitespace=False, min_length=1, max_length=6000)
        is_vertical = aspect_ratio in ("9:16", "3:4")
        if is_vertical and resolution == "1080p":
            raise ValueError("1080p is not available for vertical aspect ratios (9:16, 3:4) when reframing.")
        source_url = await upload_video_to_comfyapi(cls, video, max_duration=30)
        source_ref = Luma2ImageRef(url=source_url, media_type="video/mp4")
        request = Luma2GenerationRequest(
            prompt=prompt,
            model="ray-3.2",
            type="video_reframe",
            aspect_ratio=aspect_ratio,
            source=source_ref,
            video=Luma2VideoOptions(resolution=resolution),
        )
        return await _ray32_generate(cls, request)
class LumaRay32ExtendVideoNode(IO.ComfyNode):
    """Node that extends a prior Ray 3.2 generation forward or backward by 5 seconds."""

    @classmethod
    def define_schema(cls) -> IO.Schema:
        """Declare the generation-id input, direction selector, prompt/resolution widgets and outputs."""
        forward_option = IO.DynamicCombo.Option(
            "Forward (continue after)",
            [
                IO.Boolean.Input(
                    "loop",
                    default=False,
                    tooltip="Loop the extended video seamlessly (forward extend only).",
                ),
            ],
        )
        backward_option = IO.DynamicCombo.Option("Backward (lead-in before)", [])
        widgets = [
            IO.String.Input(
                "source_generation_id",
                default="",
                tooltip="generation_id of the prior Ray 3.2 video to extend."
                " Connect the generation_id output of another Luma Ray 3.2 node.",
            ),
            IO.DynamicCombo.Input(
                "direction",
                options=[forward_option, backward_option],
                tooltip="Forward continues after the prior clip; backward is prepended before it.",
            ),
            IO.String.Input("prompt", multiline=True, default="", tooltip="Text prompt for the new content."),
            IO.Combo.Input("resolution", options=["540p", "720p", "1080p"], default="720p"),
            _ray32_seed_input(),
        ]
        hidden_inputs = [
            IO.Hidden.auth_token_comfy_org,
            IO.Hidden.api_key_comfy_org,
            IO.Hidden.unique_id,
        ]
        return IO.Schema(
            node_id="LumaRay32ExtendVideoNode",
            display_name="Luma Ray 3.2 Extend Video",
            category="partner/video/Luma",
            description="Extend a previous Ray 3.2 generation forward (continue after it) or backward (lead-in "
            "before it). Connect the generation_id output of a prior Luma Ray 3.2 node."
            " Extensions are always 5 seconds.",
            inputs=widgets,
            outputs=[
                IO.Video.Output(),
                IO.String.Output(display_name="generation_id"),
            ],
            hidden=hidden_inputs,
            is_api_node=True,
            price_badge=_BADGE_RAY32_VIDEO_5S,
        )

    @classmethod
    async def execute(
        cls, source_generation_id: str, direction: dict, prompt: str, resolution: str, seed: int
    ) -> IO.NodeOutput:
        """Run a 5s ray-3.2 'video' generation anchored on a prior generation.

        Forward extension passes the prior generation as ``start_frame`` (optionally looping);
        any other direction passes it as ``end_frame``. ``seed`` is accepted but not sent.

        Raises:
            ValueError: when ``source_generation_id`` is empty or only whitespace.
        """
        validate_string(prompt, strip_whitespace=False, min_length=1, max_length=6000)
        gen_id = (source_generation_id or "").strip()
        if not gen_id:
            raise ValueError(
                "source_generation_id is required (connect the generation_id output of a prior Luma Ray 3.2 node)."
            )
        options = Luma2VideoOptions(resolution=resolution, duration="5s")
        anchor = Luma2ImageRef(generation_id=gen_id)
        extending_forward = direction["direction"] == "Forward (continue after)"
        if not extending_forward:
            options.end_frame = anchor
        else:
            options.start_frame = anchor
            # The loop widget only exists under the forward option.
            if direction.get("loop"):
                options.loop = True
        request = Luma2GenerationRequest(prompt=prompt, model="ray-3.2", type="video", video=options)
        return await _ray32_generate(cls, request)
class LumaExtension(ComfyExtension): class LumaExtension(ComfyExtension):
@ -944,6 +1481,13 @@ class LumaExtension(ComfyExtension):
LumaConceptsNode, LumaConceptsNode,
LumaImageNode, LumaImageNode,
LumaImageEditNode, LumaImageEditNode,
LumaRay32TextToVideoNode,
LumaRay32ImageToVideoNode,
LumaRay32KeyframeNode,
LumaRay32KeyframesToVideoNode,
LumaRay32VideoEditNode,
LumaRay32VideoReframeNode,
LumaRay32ExtendVideoNode,
] ]

View File

@ -4,6 +4,8 @@ import os
import re import re
import time import time
from collections.abc import Callable from collections.abc import Callable
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime
from io import BytesIO from io import BytesIO
from yarl import URL from yarl import URL
@ -91,6 +93,32 @@ async def sleep_with_interrupt(
await asyncio.sleep(min(1.0, end - now)) await asyncio.sleep(min(1.0, end - now))
def _retry_after_wait(value: str | None, fallback: float, max_wait: float) -> float:
"""Delay before the next retry, honoring a server ``Retry-After`` header."""
seconds: float | None = None
if value is not None:
value = value.strip()
if value.isascii() and value.isdigit():
# delay-seconds form. The ASCII-digit guard keeps exotic Unicode "digit" characters away from float()
# an all-digit string always converts (huge values become inf, never raising).
seconds = float(value)
elif value:
# HTTP-date form. parsedate_to_datetime raises OverflowError (not a ValueError) on absurd years/offsets
try:
parsed = parsedate_to_datetime(value)
except (TypeError, ValueError, OverflowError):
parsed = None
if parsed is not None:
if parsed.tzinfo is None: # naive datetime: HTTP-date is UTC
parsed = parsed.replace(tzinfo=timezone.utc)
delta = (parsed - datetime.now(timezone.utc)).total_seconds()
seconds = delta if delta > 0 else 0.0
if seconds is None:
return fallback
return min(seconds, max_wait)
def mimetype_to_extension(mime_type: str) -> str: def mimetype_to_extension(mime_type: str) -> str:
"""Converts a MIME type to a file extension.""" """Converts a MIME type to a file extension."""
return mime_type.split("/")[-1].lower() return mime_type.split("/")[-1].lower()

View File

@ -21,6 +21,7 @@ from server import PromptServer
from . import request_logger from . import request_logger
from ._helpers import ( from ._helpers import (
_retry_after_wait,
default_base_url, default_base_url,
get_comfy_api_headers, get_comfy_api_headers,
get_node_id, get_node_id,
@ -82,6 +83,7 @@ class _PollUIState:
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately _RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
_MAX_RETRY_AFTER_WAIT = 150.0 # Cap a server Retry-After at this many seconds so a large hint can't block execution
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"] COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"] FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait", "in_queue"] QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait", "in_queue"]
@ -747,6 +749,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
should_retry = True should_retry = True
if should_retry: if should_retry:
wait_time = _retry_after_wait(resp.headers.get("Retry-After"), wait_time, _MAX_RETRY_AFTER_WAIT)
logging.warning( logging.warning(
"HTTP %s %s -> %s. Waiting %.2fs (%s).", "HTTP %s %s -> %s. Waiting %.2fs (%s).",
method, method,

View File

@ -4,11 +4,22 @@ Provides normalization and helper functions for job status tracking.
""" """
import uuid import uuid
from typing import Optional from typing import Callable, Optional
from comfy_api.internal import prune_dict from comfy_api.internal import prune_dict
# Possible results of classifying a job for cancellation. The values are
# plain strings so they can be returned to (and compared by) callers directly.
#   'running'  -> job is currently executing (interrupt it)
#   'pending'  -> job is queued but not started (dequeue it)
#   'terminal' -> job already finished (present in history); cancel is a no-op
#   'unknown'  -> job id is not present anywhere
CANCEL_RUNNING = 'running'
CANCEL_PENDING = 'pending'
CANCEL_TERMINAL = 'terminal'
CANCEL_UNKNOWN = 'unknown'
class JobStatus: class JobStatus:
"""Job status constants.""" """Job status constants."""
PENDING = 'pending' PENDING = 'pending'
@ -407,3 +418,71 @@ def get_all_jobs(
jobs = jobs[:limit] jobs = jobs[:limit]
return (jobs, total_count) return (jobs, total_count)
def classify_job_for_cancel(prompt_id: str, running: list, queued: list, history: dict) -> str:
    """Work out which cancellation state a job id is currently in.

    ``running`` and ``queued`` hold queue items, i.e. tuples that carry the
    prompt_id at index 1. ``history`` is a dict keyed by prompt_id, so an id
    found there belongs to a job that has already finished.

    The lookups are done in priority order (running, then queued, then
    history) and the first hit wins.

    Returns:
        One of CANCEL_RUNNING, CANCEL_PENDING, CANCEL_TERMINAL or
        CANCEL_UNKNOWN (the id was not found anywhere).
    """
    def _holds(items: list) -> bool:
        # Queue items keep the prompt_id in their second slot.
        return any(entry[1] == prompt_id for entry in items)

    if _holds(running):
        return CANCEL_RUNNING
    if _holds(queued):
        return CANCEL_PENDING
    return CANCEL_TERMINAL if prompt_id in history else CANCEL_UNKNOWN
def cancel_job(
    prompt_id: str,
    running: list,
    queued: list,
    history: dict,
    interrupt: Callable[[str], bool],
    dequeue: Callable[[str], bool],
) -> str:
    """Cancel one job by id, whatever state it is in.

    The job is first classified with ``classify_job_for_cancel`` and the cancel
    is then mapped onto the runtime's existing mechanics:

    - a running job is interrupted via ``interrupt``
    - a pending job is removed from the queue via ``dequeue``
    - a job that already finished (terminal) is left alone
    - an unknown id is left alone (callers that need fail-fast behaviour
      should validate ids up front with ``classify_job_for_cancel``)

    ``interrupt`` and ``dequeue`` each receive the prompt id and report whether
    they acted on a job that was *actually* in that state. The returned value
    therefore describes what really happened, not the (possibly stale)
    classification. This matters in the narrow TOCTOU windows where a job
    changes state between the caller's snapshot and the action:

    - classified RUNNING but finished before ``interrupt`` fired:
      ``interrupt`` returns False and the result is CANCEL_UNKNOWN (no-op).
    - classified PENDING but started executing before ``dequeue`` fired:
      ``dequeue`` returns False, ``interrupt`` then catches the now-running
      job and the result is CANCEL_RUNNING. If the job had simply finished
      instead, both return False and the result is CANCEL_UNKNOWN.

    ``interrupt`` must be atomic: it may interrupt the job only if that job is
    still the one running, so that a cancel can never land on an unrelated
    prompt that started in the meantime (see
    ``execution.PromptQueue.interrupt_if_running``).

    Returns:
        CANCEL_RUNNING or CANCEL_PENDING when the job was interrupted or
        dequeued, CANCEL_TERMINAL for an already-finished job, and
        CANCEL_UNKNOWN when nothing was acted on.
    """
    state = classify_job_for_cancel(prompt_id, running, queued, history)

    # Terminal and unknown jobs are intentional no-ops.
    if state not in (CANCEL_RUNNING, CANCEL_PENDING):
        return state

    # Pending job: pulling it out of the queue is the whole cancel.
    if state == CANCEL_PENDING and dequeue(prompt_id):
        return CANCEL_PENDING

    # Either the job was classified as running, or it left the pending queue
    # between classification and dequeue. If it is executing now, interrupt
    # it; otherwise it has already finished and the cancel is a real no-op.
    if interrupt(prompt_id):
        return CANCEL_RUNNING
    return CANCEL_UNKNOWN

View File

@ -0,0 +1,97 @@
import math
import node_helpers
import comfy.utils
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class TextEncodeBooguEdit(io.ComfyNode):
    """Conditioning node for Boogu-Image Edit.

    Each edit image feeds two paths, matching the reference pipeline:
      - Qwen3-VL vision tokens (instruction understanding) -> positive only
      - VAE reference latent (image identity)              -> positive and negative

    Because the reference latent is attached to both conds it cancels under CFG
    (identity preserved), while the vision tokens live only in the positive so
    CFG amplifies the instruction. The tokenizer picks the system prompt on its
    own (image -> TI2I, empty negative -> DROP), so no template plumbing is
    needed here.
    """

    @classmethod
    def define_schema(cls):
        """Describe the node: clip/prompt/vae inputs plus up to 16 optional images."""
        inputs = [
            io.Clip.Input("clip"),
            io.String.Input("prompt", multiline=True, dynamic_prompts=True),
            io.String.Input("negative_prompt", multiline=True, dynamic_prompts=True, advanced=True),
            io.Vae.Input("vae"),
            io.Autogrow.Input(
                "images",
                template=io.Autogrow.TemplateNames(
                    io.Image.Input("image"),
                    names=[f"image_{i}" for i in range(1, 17)],
                    min=0,
                ),
                tooltip="Reference image(s) to edit. Boogu focuses on one reference per sample; more are allowed.",
            ),
        ]
        outputs = [
            io.Conditioning.Output(display_name="positive"),
            io.Conditioning.Output(display_name="negative"),
        ]
        return io.Schema(
            node_id="TextEncodeBooguEdit",
            category="model/conditioning/boogu",
            inputs=inputs,
            outputs=outputs,
        )

    @classmethod
    def execute(cls, clip, prompt, negative_prompt, vae=None, images: io.Autogrow.Type = None) -> io.NodeOutput:
        """Encode the prompts and attach vision tokens / reference latents.

        Images are processed in the numeric order of their slot suffix
        (``image_1``, ``image_2``, ...); empty slots are skipped. Reference
        latents are produced only when a VAE is supplied.
        """
        slots = images or {}
        ordered_names = sorted(slots, key=lambda slot: int(slot.rsplit("_", 1)[-1]))

        vision_inputs = []
        reference_latents = []
        for slot_name in ordered_names:
            picture = slots[slot_name]
            if picture is None:
                continue

            # Channels-first layout for common_upscale; dims 2/3 are height/width.
            chw = picture.movedim(-1, 1)
            src_h, src_w = chw.shape[2], chw.shape[3]

            # Vision tower input: the reference caps the VLM image at 384x384
            # (max_vlm_input_pil_pixels in pipeline_boogu.py).
            vlm_scale = math.sqrt((384 * 384) / (src_w * src_h))
            vlm_w = round(src_w * vlm_scale)
            vlm_h = round(src_h * vlm_scale)
            vlm_view = comfy.utils.common_upscale(chw, vlm_w, vlm_h, "area", "disabled")
            vision_inputs.append(vlm_view.movedim(1, -1)[:, :, :, :3])

            if vae is None:
                continue

            # Reference latent: ~1024x1024 pixel budget, aligned to 16 px
            # (VAE /8 * patch_size 2).
            ref_scale = math.sqrt((1024 * 1024) / (src_w * src_h))
            ref_w = round(src_w * ref_scale / 16.0) * 16
            ref_h = round(src_h * ref_scale / 16.0) * 16
            ref_view = comfy.utils.common_upscale(chw, ref_w, ref_h, "area", "disabled")
            reference_latents.append(vae.encode(ref_view.movedim(1, -1)[:, :, :, :3]))

        # positive: instruction + vision tokens; negative: text only (no vision).
        positive = clip.encode_from_tokens_scheduled(clip.tokenize(prompt, images=vision_inputs))
        negative = clip.encode_from_tokens_scheduled(clip.tokenize(negative_prompt))

        # The reference latents go on both conds.
        if reference_latents:
            extra = {"reference_latents": reference_latents}
            positive = node_helpers.conditioning_set_values(positive, extra, append=True)
            negative = node_helpers.conditioning_set_values(negative, extra, append=True)

        return io.NodeOutput(positive, negative)
class BooguExtension(ComfyExtension):
    """Extension that exposes the Boogu conditioning node."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension."""
        node_classes: list[type[io.ComfyNode]] = [TextEncodeBooguEdit]
        return node_classes
async def comfy_entrypoint() -> BooguExtension:
    """Build and return a fresh ``BooguExtension`` instance."""
    extension = BooguExtension()
    return extension

View File

@ -13,21 +13,22 @@ class ContextWindowsManualNode(io.ComfyNode):
description="Manually set context windows.", description="Manually set context windows.",
inputs=[ inputs=[
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."), io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window.", advanced=True), io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window."),
io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window.", advanced=True), io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window."),
io.Combo.Input("context_schedule", options=[ io.Combo.Input("context_schedule", options=[
comfy.context_windows.ContextSchedules.STATIC_STANDARD, comfy.context_windows.ContextSchedules.STATIC_STANDARD,
comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
comfy.context_windows.ContextSchedules.UNIFORM_LOOPED, comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
comfy.context_windows.ContextSchedules.BATCHED, comfy.context_windows.ContextSchedules.BATCHED,
], tooltip="The stride of the context window."), ], default=comfy.context_windows.ContextSchedules.STATIC_STANDARD, tooltip="Step-dependent scheduling algorithm for context windows."),
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True), io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."), io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."), io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window. For concat-style I2V models (e.g. Wan I2V, HunyuanVideo I2V, Cosmos I2V, SVD) the encoded start image lives in the c_concat conditioning channels; setting this to '0' will retain that start image content at sub-pos 0 of every window."),
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
io.String.Input("latent_retain_index_list", default="", tooltip="List of latent indices to retain in the noise latent itself for each window. Use for workflows where reference content (e.g. a start image) lives directly in the noise latent rather than in separate conditioning channels (e.g. inplace-style I2V like LTXV, AnimateDiff). Independent of cond_retain_index_list."),
io.Boolean.Input("causal_window_fix", default=True, tooltip="Whether to add a causal fix frame to non-0-indexed context windows."), io.Boolean.Input("causal_window_fix", default=True, tooltip="Whether to add a causal fix frame to non-0-indexed context windows."),
], ],
outputs=[ outputs=[
@ -38,7 +39,7 @@ class ContextWindowsManualNode(io.ComfyNode):
@classmethod @classmethod
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool, def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, causal_window_fix: bool=True) -> io.Model: cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, latent_retain_index_list: list[int]=[], causal_window_fix: bool=True) -> io.Model:
model = model.clone() model = model.clone()
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler( model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule), context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
@ -51,6 +52,7 @@ class ContextWindowsManualNode(io.ComfyNode):
freenoise=freenoise, freenoise=freenoise,
cond_retain_index_list=cond_retain_index_list, cond_retain_index_list=cond_retain_index_list,
split_conds_to_windows=split_conds_to_windows, split_conds_to_windows=split_conds_to_windows,
latent_retain_index_list=latent_retain_index_list,
causal_window_fix=causal_window_fix, causal_window_fix=causal_window_fix,
) )
# make memory usage calculation only take into account the context window latents # make memory usage calculation only take into account the context window latents
@ -65,33 +67,71 @@ class WanContextWindowsManualNode(ContextWindowsManualNode):
schema = super().define_schema() schema = super().define_schema()
schema.node_id = "WanContextWindowsManual" schema.node_id = "WanContextWindowsManual"
schema.display_name = "WAN Context Windows (Manual)" schema.display_name = "WAN Context Windows (Manual)"
schema.description = "Manually set context windows for WAN-like models (dim=2)." schema.display_name = "Wan Context Windows"
schema.description = "Set context windows for Wan-like models."
schema.category="model/patch/wan" schema.category="model/patch/wan"
schema.inputs = [ schema.inputs = [
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."), io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window.", advanced=True), io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window in real frames. Must be 4*n + 1."),
io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window.", advanced=True), io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window in real frames."),
io.Combo.Input("context_schedule", options=[ io.Combo.Input("context_schedule", options=[
comfy.context_windows.ContextSchedules.STATIC_STANDARD, comfy.context_windows.ContextSchedules.STATIC_STANDARD,
comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
comfy.context_windows.ContextSchedules.UNIFORM_LOOPED, comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
comfy.context_windows.ContextSchedules.BATCHED, comfy.context_windows.ContextSchedules.BATCHED,
], tooltip="The stride of the context window."), ], default=comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, tooltip="Step-dependent scheduling algorithm for context windows."),
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True), io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True),
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."), io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules.", advanced=True),
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), io.Boolean.Input("freenoise", default=True, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending.", advanced=True),
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), io.Boolean.Input("retain_first_frame", default=False, tooltip="Retain the first I2V frame in every context window (may help retain initial reference)."),
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index.", advanced=True),
] ]
return schema return schema
@classmethod @classmethod
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool, def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool,
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model: retain_first_frame: bool=False, split_conds_to_windows: bool=False) -> io.Model:
context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1 context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0 context_overlap = max(context_overlap // 4, 0) # at least overlap 0
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows) retain_index_list = "0" if retain_first_frame else ""
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=retain_index_list, split_conds_to_windows=split_conds_to_windows)
class LTXVContextWindowsNode(ContextWindowsManualNode):
    """Context-window node preset for LTXV-like models.

    Reuses the parent schema and execution, but exposes lengths in real frames
    and always windows along dim 2.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Start from the parent schema and replace id, labels and inputs."""
        schema = super().define_schema()
        schema.node_id = "LTXVContextWindows"
        schema.display_name = "LTXV Context Windows"
        schema.description = "Set context windows for LTXV-like models."

        schedules = comfy.context_windows.ContextSchedules
        fuse_methods = comfy.context_windows.ContextFuseMethods
        schedule_options = [
            schedules.STATIC_STANDARD,
            schedules.UNIFORM_STANDARD,
            schedules.UNIFORM_LOOPED,
            schedules.BATCHED,
        ]
        schema.inputs = [
            io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
            io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=8, default=145, tooltip="The length of the context window in real frames. Must be 8*n + 1."),
            io.Int.Input("context_overlap", min=0, step=8, default=40, tooltip="The overlap of the context window in real frames."),
            io.Combo.Input("context_schedule", options=schedule_options, default=schedules.UNIFORM_STANDARD, tooltip="Step-dependent scheduling algorithm for context windows."),
            io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True),
            io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules.", advanced=True),
            io.Combo.Input("fuse_method", options=fuse_methods.LIST_STATIC, default=fuse_methods.PYRAMID, tooltip="The method to use to fuse the context windows."),
            io.Boolean.Input("freenoise", default=True, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending.", advanced=True),
            io.Boolean.Input("retain_first_frame", default=False, tooltip="Retain the first latent frame in every context window (may help retain initial reference)."),
            io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index.", advanced=True),
        ]
        return schema

    @classmethod
    def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, fuse_method: str, freenoise: bool,
                retain_first_frame: bool=False, split_conds_to_windows: bool=False, context_stride: int=1, closed_loop: bool=False) -> io.Model:
        """Rescale the frame counts and delegate to the parent node.

        ``context_length`` and ``context_overlap`` are given in real frames and
        are reduced by a factor of 8 before being handed on (presumably the
        model's temporal compression; confirm against the LTXV VAE). When
        ``retain_first_frame`` is set, index "0" is retained in both the
        conditioning and the noise latent of every window.
        """
        latent_length = max(((context_length - 1) // 8) + 1, 1)  # never below one frame
        latent_overlap = max(context_overlap // 8, 0)  # never negative
        retained = "0" if retain_first_frame else ""
        return super().execute(
            model,
            latent_length,
            latent_overlap,
            context_schedule,
            context_stride,
            closed_loop,
            fuse_method,
            dim=2,
            freenoise=freenoise,
            cond_retain_index_list=retained,
            latent_retain_index_list=retained,
            split_conds_to_windows=split_conds_to_windows,
        )
class ContextWindowsExtension(ComfyExtension): class ContextWindowsExtension(ComfyExtension):
@ -99,6 +139,7 @@ class ContextWindowsExtension(ComfyExtension):
return [ return [
ContextWindowsManualNode, ContextWindowsManualNode,
WanContextWindowsManualNode, WanContextWindowsManualNode,
LTXVContextWindowsNode,
] ]
def comfy_entrypoint(): def comfy_entrypoint():

View File

@ -1583,7 +1583,7 @@ class LoadTrainingDataset(io.ComfyNode):
shard_path = os.path.join(dataset_dir, shard_file) shard_path = os.path.join(dataset_dir, shard_file)
with open(shard_path, "rb") as f: with open(shard_path, "rb") as f:
shard_data = torch.load(f) shard_data = torch.load(f, weights_only=True)
all_latents.extend(shard_data["latents"]) all_latents.extend(shard_data["latents"])
all_conditioning.extend(shard_data["conditioning"]) all_conditioning.extend(shard_data["conditioning"])

View File

@ -77,7 +77,7 @@ class FrameInterpolate(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="FrameInterpolate", node_id="FrameInterpolate",
display_name="Frame Interpolate", display_name="Run Frame Interpolation Model",
category="video", category="video",
search_aliases=["rife", "film", "frame interpolation", "slow motion", "interpolate frames", "vfi"], search_aliases=["rife", "film", "frame interpolation", "slow motion", "interpolate frames", "vfi"],
inputs=[ inputs=[

View File

@ -317,11 +317,74 @@ class PreviewPointCloud(IO.ComfyNode):
) )
MESH_EXTENSIONS = {'.gltf', '.glb', '.obj', '.fbx', '.stl'}
class Load3DAdvanced(IO.ComfyNode):
    """Experimental node that loads a mesh file and passes on viewport information."""

    @classmethod
    def define_schema(cls):
        """Build the schema, listing the mesh files found under ``<input>/3d``.

        The ``3d`` input folder is created if it does not exist. Files are
        offered relative to the input directory and filtered by
        ``MESH_EXTENSIONS``.
        """
        input_dir = os.path.join(folder_paths.get_input_directory(), "3d")
        os.makedirs(input_dir, exist_ok=True)

        base_path = Path(folder_paths.get_input_directory())
        mesh_files = []
        for candidate in Path(input_dir).rglob("*"):
            if candidate.suffix.lower() in MESH_EXTENSIONS:
                mesh_files.append(normalize_path(str(candidate.relative_to(base_path))))

        aliases = [
            "load mesh",
            "load gltf",
            "load glb",
            "load obj",
            "load fbx",
            "load stl",
        ]
        inputs = [
            IO.Combo.Input("model_file", options=["none"] + sorted(mesh_files), upload=IO.UploadType.model),
            IO.Load3D.Input("viewport_state"),
            IO.Int.Input("width", default=1024, min=1, max=4096, step=1),
            IO.Int.Input("height", default=1024, min=1, max=4096, step=1),
        ]
        outputs = [
            IO.File3DAny.Output(display_name="model_3d"),
            IO.Load3DModelInfo.Output(display_name="model_3d_info"),
            IO.Load3DCamera.Output(display_name="camera_info"),
            IO.Int.Output(display_name="width"),
            IO.Int.Output(display_name="height"),
        ]
        return IO.Schema(
            node_id="Load3DAdvanced",
            display_name="Load 3D (Advanced)",
            category="3d",
            search_aliases=aliases,
            is_experimental=True,
            inputs=inputs,
            outputs=outputs,
        )

    @classmethod
    def validate_inputs(cls, model_file, **kwargs) -> bool | str:
        """Accept an empty/"none" selection; otherwise require the file to exist.

        Returns ``True`` when valid, or an error message string when the
        selected file cannot be found.
        """
        is_selected = bool(model_file) and model_file != "none"
        if is_selected and not folder_paths.exists_annotated_filepath(model_file):
            return f"Invalid 3D model file: {model_file}"
        return True

    @classmethod
    def execute(cls, model_file, viewport_state, width: int, height: int, **kwargs) -> IO.NodeOutput:
        """Resolve the selected file (if any) and forward the viewport data.

        ``model_3d_info`` falls back to an empty list when the viewport state
        has no such key; ``camera_info`` is required to be present.
        """
        if model_file and model_file != "none":
            file_3d = Types.File3D(folder_paths.get_annotated_filepath(model_file))
        else:
            file_3d = None
        model_3d_info = viewport_state.get('model_3d_info', [])
        camera_info = viewport_state['camera_info']
        return IO.NodeOutput(file_3d, model_3d_info, camera_info, width, height)
class Load3DExtension(ComfyExtension): class Load3DExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[IO.ComfyNode]]: async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [ return [
Load3D, Load3D,
Load3DAdvanced,
Preview3D, Preview3D,
Preview3DAdvanced, Preview3DAdvanced,
PreviewGaussianSplat, PreviewGaussianSplat,

View File

@ -89,7 +89,8 @@ class SwitchNode(io.ComfyNode):
template = io.MatchType.Template("switch") template = io.MatchType.Template("switch")
return io.Schema( return io.Schema(
node_id="ComfySwitchNode", node_id="ComfySwitchNode",
display_name="Switch", search_aliases=["if", "then", "switch", "conditional", "branch"],
display_name="If/Else Switch",
category="utilities/logic", category="utilities/logic",
is_experimental=True, is_experimental=True,
inputs=[ inputs=[

View File

@ -10,12 +10,11 @@ class String(io.ComfyNode):
return io.Schema( return io.Schema(
node_id="PrimitiveString", node_id="PrimitiveString",
search_aliases=["text", "string", "text box", "prompt"], search_aliases=["text", "string", "text box", "prompt"],
display_name="Text String", display_name="Text String (DEPRECATED)",
category="utilities/primitive", category="utilities/primitive",
inputs=[ inputs=[io.String.Input("value")],
io.String.Input("value"),
],
outputs=[io.String.Output()], outputs=[io.String.Output()],
is_deprecated=True
) )
@classmethod @classmethod
@ -29,12 +28,10 @@ class StringMultiline(io.ComfyNode):
return io.Schema( return io.Schema(
node_id="PrimitiveStringMultiline", node_id="PrimitiveStringMultiline",
search_aliases=["text", "string", "text multiline", "string multiline", "text box", "prompt"], search_aliases=["text", "string", "text multiline", "string multiline", "text box", "prompt"],
display_name="Text String (Multiline)", display_name="Input Text",
category="utilities/primitive", category="utilities/primitive",
essentials_category="Basics", essentials_category="Basics",
inputs=[ inputs=[io.String.Input("value", multiline=True)],
io.String.Input("value", multiline=True),
],
outputs=[io.String.Output()], outputs=[io.String.Output()],
) )

View File

@ -34,14 +34,20 @@ def _unpack(track_data):
return unpack_masks(packed) return unpack_masks(packed)
def _first_frame_cx_area(masks_bool): def _first_appearance_cx_area(masks_bool):
first = masks_bool[0].float() """Per object: first frame it appears in, plus centroid-x and area in that frame."""
H, W = first.shape[-2], first.shape[-1] m = masks_bool.float()
n_pixels = H * W T, H, W = m.shape[0], m.shape[-2], m.shape[-1]
grid_x = torch.arange(W, device=first.device, dtype=first.dtype).view(1, W) grid_x = torch.arange(W, device=m.device, dtype=m.dtype).view(1, 1, 1, W)
area = first.sum(dim=(-1, -2)).clamp_(min=1) area_t = m.sum(dim=(-1, -2))
cx = (first * grid_x).sum(dim=(-1, -2)) / area cx_t = (m * grid_x).sum(dim=(-1, -2)) / area_t.clamp(min=1)
return (cx / W).tolist(), (area / n_pixels).tolist() present = area_t > 0
frame_idx = torch.arange(T, device=m.device).unsqueeze(1)
first_t = torch.where(present, frame_idx, T).amin(dim=0)
sel = first_t.clamp(max=T - 1).unsqueeze(0)
cx = cx_t.gather(0, sel).squeeze(0)
area = area_t.gather(0, sel).squeeze(0)
return first_t.tolist(), (cx / W).tolist(), (area / (H * W)).tolist()
def _subset_track_data(track_data, obj_indices): def _subset_track_data(track_data, obj_indices):
@ -81,12 +87,26 @@ def _render_colored_masks(track_data, background="black"):
masks_full.view(T * N_obj, 1, Hm, Wm), size=(H, W), mode="nearest" masks_full.view(T * N_obj, 1, Hm, Wm), size=(H, W), mode="nearest"
).view(T, N_obj, H, W) > 0.5 ).view(T, N_obj, H, W) > 0.5
any_mask = masks_full.any(dim=1) any_mask = masks_full.any(dim=1)
obj_idx_map = masks_full.to(torch.uint8).argmax(dim=1) color_overlay = colors[masks_full.to(torch.uint8).argmax(dim=1)]
color_overlay = colors[obj_idx_map]
bg_tensor = torch.tensor(bg_rgb, device=device, dtype=color_overlay.dtype).view(1, 1, 1, 3) bg_tensor = torch.tensor(bg_rgb, device=device, dtype=color_overlay.dtype).view(1, 1, 1, 3)
return torch.where(any_mask.unsqueeze(-1), color_overlay, bg_tensor.expand_as(color_overlay)) return torch.where(any_mask.unsqueeze(-1), color_overlay, bg_tensor.expand_as(color_overlay))
def _render_mask_as_identity(mask, background="black"):
    """Paint a plain comfy MASK as one identity colour on a flat background.

    Accepts a mask of shape (B,H,W) or (H,W) and returns a (B,H,W,3) image in
    which pixels above 0.5 take the first palette colour (DEFAULT_PALETTE[0])
    and all others take the background: white when ``background`` starts with
    "white", black otherwise. A batch is treated as several views of that one
    subject, so every view uses the same colour.
    """
    device = comfy.model_management.intermediate_device()
    dtype = comfy.model_management.intermediate_dtype()

    # Promote a single (H,W) mask to a batch of one.
    batched = mask.unsqueeze(0) if mask.ndim == 2 else mask
    batched = batched.to(device=device, dtype=dtype)
    n_views, height, width = batched.shape
    out_shape = (n_views, height, width, 3)

    def _rgb(values):
        # One RGB triple shaped so it broadcasts over (B,H,W,3).
        return torch.tensor(values, device=device, dtype=dtype).view(1, 1, 1, 3)

    if background.startswith("white"):
        bg_rgb = (1.0, 1.0, 1.0)
    else:
        bg_rgb = (0.0, 0.0, 0.0)

    foreground = _rgb(DEFAULT_PALETTE[0]).expand(*out_shape)
    backdrop = _rgb(bg_rgb).expand(*out_shape)
    inside = (batched > 0.5).unsqueeze(-1)
    return torch.where(inside, foreground, backdrop)
def _extract_mask_to_28ch(rgb_video): def _extract_mask_to_28ch(rgb_video):
"""Colored RGB mask (T, H, W, 3) in [0, 1] -> SCAIL-2 28-channel binary latent """Colored RGB mask (T, H, W, 3) in [0, 1] -> SCAIL-2 28-channel binary latent
(1, T_lat, 28, H_lat, W_lat). 7 per-color binary channels (white/r/g/b/y/m/c) (1, T_lat, 28, H_lat, W_lat). 7 per-color binary channels (white/r/g/b/y/m/c)
@ -138,8 +158,8 @@ class WanSCAILToVideo(io.ComfyNode):
io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."), io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."),
io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step of the pose conditioning."), io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step of the pose conditioning."),
io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step of the pose conditioning."), io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step of the pose conditioning."),
io.Image.Input("reference_image", optional=True, tooltip="Reference image, for multiple references composite all on single image."), io.Image.Input("reference_image", optional=True, tooltip="Reference image. The first image is the primary reference (composite all identities onto it). SCAIL-2: extra batch images are used as additional views (back view, close-up, occluded background), each needing a matching reference_image_mask in that identity's color."),
io.Image.Input("reference_image_mask", optional=True, tooltip="SCAIL-2 only. Colored reference mask at the same resolution as reference_image."), io.Image.Input("reference_image_mask", optional=True, tooltip="SCAIL-2 only. Colored reference mask, batch matching reference_image (first = primary reference mask, rest = identity masks for the additional reference_image)."),
io.ClipVisionOutput.Input("clip_vision_output", optional=True, tooltip="CLIP vision features for conditioning. Model is trained with stretch resize to aspect ratio."), io.ClipVisionOutput.Input("clip_vision_output", optional=True, tooltip="CLIP vision features for conditioning. Model is trained with stretch resize to aspect ratio."),
io.Int.Input("video_frame_offset", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1, tooltip="Cumulative output frame this chunk begins at. Wire from the previous chunk's video_frame_offset output."), io.Int.Input("video_frame_offset", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1, tooltip="Cumulative output frame this chunk begins at. Wire from the previous chunk's video_frame_offset output."),
io.Int.Input("previous_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=4, tooltip="Tail frames of previous_frames to anchor. SCAIL-2 trained at 5 (81-frame chunks, 76-frame step)."), io.Int.Input("previous_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=4, tooltip="Tail frames of previous_frames to anchor. SCAIL-2 trained at 5 (81-frame chunks, 76-frame step)."),
@ -171,19 +191,21 @@ class WanSCAILToVideo(io.ComfyNode):
video_frame_offset -= prev_trimmed.shape[0] video_frame_offset -= prev_trimmed.shape[0]
video_frame_offset = max(0, video_frame_offset) video_frame_offset = max(0, video_frame_offset)
ref_latent = None
if reference_image is not None: if reference_image is not None:
reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1) ref_imgs = comfy.utils.common_upscale(reference_image.movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1)
# Replacement Mode: composite ref on black bg using reference_image_mask as alpha matte n_ref = ref_imgs.shape[0]
if replacement_mode and reference_image_mask is not None: # SCAIL-2 multi-reference: the first image is the primary ref, the rest are additional references.
rm = comfy.utils.common_upscale(reference_image_mask[:1].movedim(-1, 1), width, height, "nearest-exact", "center").movedim(1, -1)
is_char = (rm[..., :3].max(dim=-1, keepdim=True).values > 0.1).to(reference_image.dtype)
reference_image = reference_image * is_char
ref_latent = vae.encode(reference_image[:, :, :, :3])
if ref_latent is not None: # Replacement Mode: composite each ref on black bg using its mask as alpha matte
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True) if replacement_mode and reference_image_mask is not None:
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True) rm = comfy.utils.common_upscale(reference_image_mask.movedim(-1, 1), width, height, "nearest-exact", "center").movedim(1, -1)
rm = rm[[min(i, rm.shape[0] - 1) for i in range(n_ref)]]
is_char = (rm[..., :3].max(dim=-1, keepdim=True).values > 0.1).to(ref_imgs.dtype)
ref_imgs = ref_imgs * is_char
# encode each ref individually so each stays a single latent frame (a batched encode would be treated as a video)
ref_latents = [vae.encode(ref_imgs[i:i + 1, :, :, :3]) for i in range(n_ref)]
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": ref_latents}, append=True)
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": ref_latents}, append=True)
if clip_vision_output is not None: if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
@ -221,11 +243,16 @@ class WanSCAILToVideo(io.ComfyNode):
positive = node_helpers.conditioning_set_values(positive, {"driving_mask_28ch": driving_mask_28ch}) positive = node_helpers.conditioning_set_values(positive, {"driving_mask_28ch": driving_mask_28ch})
negative = node_helpers.conditioning_set_values(negative, {"driving_mask_28ch": driving_mask_28ch}) negative = node_helpers.conditioning_set_values(negative, {"driving_mask_28ch": driving_mask_28ch})
if reference_image_mask is not None: # The ref mask binds reference frames to identities, so it only applies when there's a reference image.
ref_mask_hw = comfy.utils.common_upscale(reference_image_mask[:1].movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1) if reference_image_mask is not None and reference_image is not None:
ref_mask_1f = _extract_mask_to_28ch(ref_mask_hw) ref_mask_hw = comfy.utils.common_upscale(reference_image_mask.movedim(-1, 1), width, height, "nearest-exact", "center").movedim(1, -1)
n_masks = ref_mask_hw.shape[0]
n_ref = reference_image.shape[0]
add_masks = [_extract_mask_to_28ch(ref_mask_hw[min(i, n_masks - 1)][None]) for i in range(1, n_ref)]
ref_mask_1f = _extract_mask_to_28ch(ref_mask_hw[:1])
zeros = torch.zeros((1, latent.shape[2], 28, ref_mask_1f.shape[-2], ref_mask_1f.shape[-1]), device=ref_mask_1f.device, dtype=ref_mask_1f.dtype) zeros = torch.zeros((1, latent.shape[2], 28, ref_mask_1f.shape[-2], ref_mask_1f.shape[-1]), device=ref_mask_1f.device, dtype=ref_mask_1f.dtype)
ref_mask_28ch = torch.cat([ref_mask_1f, zeros], dim=1) ref_mask_28ch = torch.cat(add_masks + [ref_mask_1f, zeros], dim=1)
positive = node_helpers.conditioning_set_values(positive, {"ref_mask_28ch": ref_mask_28ch}) positive = node_helpers.conditioning_set_values(positive, {"ref_mask_28ch": ref_mask_28ch})
negative = node_helpers.conditioning_set_values(negative, {"ref_mask_28ch": ref_mask_28ch}) negative = node_helpers.conditioning_set_values(negative, {"ref_mask_28ch": ref_mask_28ch})
@ -244,12 +271,9 @@ class WanSCAILToVideo(io.ComfyNode):
class SCAIL2ColoredMask(io.ComfyNode): class SCAIL2ColoredMask(io.ComfyNode):
"""Render SAM3 tracks for the driving pose video and (optionally) the reference """Render SAM3 tracks for the driving pose video and reference image(s) into the
image into the two colored masks WanSCAILToVideo consumes. Shared `sort_by` colored masks WanSCAILToVideo consumes. Shared `sort_by` keeps each identity on the
across both outputs guarantees identity K maps to the same color on both same color across both outputs.
sides, for multi-person workflow consistency.
reference_image_mask is always rendered black-bg (model convention)
pose_video_mask bg follows replacement_mode: black = Animation Mode, white = Replacement Mode
""" """
@classmethod @classmethod
@ -260,10 +284,12 @@ class SCAIL2ColoredMask(io.ComfyNode):
category="model/conditioning/wan/scail", category="model/conditioning/wan/scail",
inputs=[ inputs=[
SAM3TrackData.Input("driving_track_data", tooltip="SAM3 track of the driving pose video. Will be rendered into the pose_video_mask output."), SAM3TrackData.Input("driving_track_data", tooltip="SAM3 track of the driving pose video. Will be rendered into the pose_video_mask output."),
SAM3TrackData.Input("ref_track_data", optional=True, tooltip="SAM3 track of the reference image."), io.MultiType.Input("ref_track_data", [SAM3TrackData, io.Mask], optional=True, display_name="reference_masks",
io.String.Input("object_indices", default="", tooltip="Comma-separated list of person indices to include (e.g. '0,2,3'). Applied to both reference and pose video masks. Empty = all."), tooltip="SAM3 track of the reference image(s) (one identity per object, colored in batch order), or a plain MASK of the reference subject (rendered as a single identity)."),
io.String.Input("object_indices", default="",
tooltip="Comma-separated list of person indices to include (e.g. '0,2,3'). Applied to both reference and pose video masks. Empty = all."),
io.Combo.Input("sort_by", options=["none", "left_to_right", "area"], default="left_to_right", io.Combo.Input("sort_by", options=["none", "left_to_right", "area"], default="left_to_right",
tooltip="Order in which palette colors are assigned to the tracked objects (applied to both reference and pose video so each identity keeps the same color). left_to_right = leftmost object (by first-frame centroid) gets the first color; area = biggest object (by first-frame mask area) gets the first color; none = keep SAM3's order."), tooltip="Order in which palette colors are assigned to the tracked objects (applied to both reference and pose video so each identity keeps the same color). Objects that appear in earlier frames always come first; within a frame, left_to_right = leftmost object (by centroid at first appearance) gets the first color, area = biggest object (by mask area at first appearance) gets the first color; none = keep SAM3's order."),
io.Boolean.Input("replacement_mode", default=False, io.Boolean.Input("replacement_mode", default=False,
tooltip="False = Animation Mode (pose_video_mask has black background, reference_image_mask has white background). " tooltip="False = Animation Mode (pose_video_mask has black background, reference_image_mask has white background). "
"True = Replacement Mode (pose_video_mask has white background, reference_image_mask has black background)."), "True = Replacement Mode (pose_video_mask has white background, reference_image_mask has black background)."),
@ -280,11 +306,11 @@ class SCAIL2ColoredMask(io.ComfyNode):
def _prep(td): def _prep(td):
masks_bool = _unpack(td) masks_bool = _unpack(td)
if sort_by != "none" and masks_bool is not None: if sort_by != "none" and masks_bool is not None:
cx, area = _first_frame_cx_area(masks_bool) first_t, cx, area = _first_appearance_cx_area(masks_bool)
if sort_by == "left_to_right": if sort_by == "left_to_right":
order = sorted(range(len(cx)), key=lambda i: cx[i]) order = sorted(range(len(cx)), key=lambda i: (first_t[i], cx[i]))
else: # "area" else: # "area"
order = sorted(range(len(area)), key=lambda i: -area[i]) order = sorted(range(len(area)), key=lambda i: (first_t[i], -area[i]))
td = _subset_track_data(td, order) td = _subset_track_data(td, order)
if object_indices.strip(): if object_indices.strip():
indices = [int(i.strip()) for i in object_indices.split(",") if i.strip().isdigit()] indices = [int(i.strip()) for i in object_indices.split(",") if i.strip().isdigit()]
@ -300,8 +326,10 @@ class SCAIL2ColoredMask(io.ComfyNode):
ref_bg = "black" if replacement_mode else "white" ref_bg = "black" if replacement_mode else "white"
if ref_track_data is not None: if ref_track_data is not None:
ref = _prep(ref_track_data) if isinstance(ref_track_data, torch.Tensor): # plain comfy MASK
reference_image_mask = _render_colored_masks(ref, ref_bg) reference_image_mask = _render_mask_as_identity(ref_track_data, ref_bg)
else:
reference_image_mask = _render_colored_masks(_prep(ref_track_data), ref_bg)
else: else:
H, W = drv["orig_size"] H, W = drv["orig_size"]
fill_value = 1.0 if ref_bg == "white" else 0.0 fill_value = 1.0 if ref_bg == "white" else 0.0

View File

@ -65,7 +65,7 @@ class TripoSplatPreprocessImage(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="TripoSplatPreprocessImage", node_id="TripoSplatPreprocessImage",
display_name="TripoSplat Preprocess Image", display_name="TripoSplat Preprocess Image",
category="3d/conditioning", category="model/conditioning/triposplat",
description="Crop center each image to a square canvas on a black background and add padding.", description="Crop center each image to a square canvas on a black background and add padding.",
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),
@ -95,7 +95,7 @@ class TripoSplatConditioning(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="TripoSplatConditioning", node_id="TripoSplatConditioning",
display_name="TripoSplat Conditioning", display_name="TripoSplat Conditioning",
category="3d/conditioning", category="model/conditioning/triposplat",
description="Encode the image with DINOv3 and the Flux2 VAE into TripoSplat positive/negative " description="Encode the image with DINOv3 and the Flux2 VAE into TripoSplat positive/negative "
"conditioning, and create the fixed size noise target (latent + camera) for the KSampler", "conditioning, and create the fixed size noise target (latent + camera) for the KSampler",
inputs=[ inputs=[

View File

@ -235,13 +235,8 @@ class VideoSlice(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="Video Slice", node_id="Video Slice",
display_name="Video Slice", display_name="Trim Video",
search_aliases=[ search_aliases=["trim video duration", "skip first frames", "frame load cap", "start time"],
"trim video duration",
"skip first frames",
"frame load cap",
"start time",
],
category="video", category="video",
essentials_category="Video Tools", essentials_category="Video Tools",
inputs=[ inputs=[

View File

@ -1308,6 +1308,25 @@ class PromptQueue:
queued = copy.copy(self.queue) queued = copy.copy(self.queue)
return (running, queued) return (running, queued)
def interrupt_if_running(self, prompt_id):
    """Signal an interrupt only if ``prompt_id`` is currently running.

    Both the lookup in ``self.currently_running`` and the call to
    ``nodes.interrupt_processing()`` happen while ``self.mutex`` is held,
    so the check and the signal form one atomic step with respect to any
    other code that takes the same mutex. This is what prevents a cancel
    from landing on an unrelated prompt that started after a separate
    "is it running?" check. Per the original author's note, the global
    interrupt flag is reset at the start of every prompt (execute_async),
    so a job that finishes before consuming the flag does not leak the
    interrupt onto its successor.

    Returns:
        True if a running entry whose index-1 element equals ``prompt_id``
        was found and an interrupt was signalled, False otherwise.
    """
    with self.mutex:
        is_running = any(
            entry[1] == prompt_id
            for entry in self.currently_running.values()
        )
        if is_running:
            # Signal exactly once, still under the mutex.
            nodes.interrupt_processing()
        return is_running
def get_tasks_remaining(self): def get_tasks_remaining(self):
with self.mutex: with self.mutex:
return len(self.queue) + len(self.currently_running) return len(self.queue) + len(self.currently_running)

View File

@ -20,8 +20,6 @@ from PIL.PngImagePlugin import PngInfo
import numpy as np import numpy as np
import safetensors.torch import safetensors.torch
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
import comfy.diffusers_load import comfy.diffusers_load
import comfy.samplers import comfy.samplers
import comfy.sample import comfy.sample
@ -971,7 +969,7 @@ class CLIPLoader:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit", "ideogram4"], ), "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit", "ideogram4", "boogu"], ),
}, },
"optional": { "optional": {
"device": (["default", "cpu"], {"advanced": True}), "device": (["default", "cpu"], {"advanced": True}),
@ -2305,6 +2303,9 @@ async def init_external_custom_nodes():
Returns: Returns:
None None
""" """
# TODO: remove at some point when custom nodes don't break.
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
base_node_names = set(NODE_CLASS_MAPPINGS.keys()) base_node_names = set(NODE_CLASS_MAPPINGS.keys())
node_paths = folder_paths.get_folder_paths("custom_nodes") node_paths = folder_paths.get_folder_paths("custom_nodes")
node_import_times = [] node_import_times = []
@ -2431,6 +2432,7 @@ async def init_builtin_extra_nodes():
"nodes_tcfg.py", "nodes_tcfg.py",
"nodes_context_windows.py", "nodes_context_windows.py",
"nodes_qwen.py", "nodes_qwen.py",
"nodes_boogu.py",
"nodes_chroma_radiance.py", "nodes_chroma_radiance.py",
"nodes_pid.py", "nodes_pid.py",
"nodes_model_patch.py", "nodes_model_patch.py",

View File

@ -55,6 +55,12 @@ components:
description: URL for asset preview/thumbnail description: URL for asset preview/thumbnail
format: uri format: uri
type: string type: string
short_url:
description: Durable, owner-gated short link to this asset's content (relative `/api/s/{id}` path). Stable across the underlying signed URL's expiry — resolving it re-mints a fresh signed URL on every request — so it is safe to persist or share into chat, unlike `preview_url`. Only the minting user can resolve it. Omitted when the short-link surface is disabled or the asset has no resolvable content hash.
nullable: true
type: string
x-runtime:
- cloud
size: size:
description: Size of the asset in bytes description: Size of the asset in bytes
format: int64 format: int64
@ -673,6 +679,35 @@ components:
- created_at - created_at
- updated_at - updated_at
type: object type: object
JobsCancelRequest:
additionalProperties: false
description: Request to cancel multiple jobs by ID.
properties:
job_ids:
description: Job identifiers (UUIDs) to cancel.
items:
format: uuid
type: string
maxItems: 100
minItems: 1
type: array
required:
- job_ids
type: object
JobsCancelResponse:
description: Response for POST /api/jobs/cancel.
properties:
cancelled:
description: |
Job IDs for which a cancel event was successfully dispatched by this
call. Jobs already in a terminal or cancelling state are idempotently
skipped and will not appear here.
items:
type: string
type: array
required:
- cancelled
type: object
JobsListResponse: JobsListResponse:
description: Paginated list of jobs for the authenticated user. description: Paginated list of jobs for the authenticated user.
properties: properties:
@ -1006,7 +1041,7 @@ components:
description: If true, clear all pending jobs from the queue description: If true, clear all pending jobs from the queue
type: boolean type: boolean
delete: delete:
description: Array of PENDING job IDs to cancel description: Array of job IDs to cancel; pending and running jobs transition to cancelled
items: items:
type: string type: string
type: array type: array
@ -1822,6 +1857,83 @@ paths:
summary: Update asset metadata summary: Update asset metadata
tags: tags:
- file - file
/api/assets/{id}/content:
get:
description: |
Returns the binary content of an asset by ID.
The contract is the same across runtimes — "GET this path and you
receive the asset's bytes" — but the mechanism differs:
- **Local ComfyUI** streams the bytes directly (`200`,
`application/octet-stream`).
- **Cloud** does not proxy large files; it responds `302` with a
`Location` redirect to a short-lived signed storage URL. Clients that
follow redirects (browsers, `fetch`/XHR, `<img>`/`<video>`) receive
the bytes transparently.
Prefer this over the filename-addressed `/api/view` when you have an
asset ID.
operationId: getAssetContent
parameters:
- description: Asset ID
in: path
name: id
required: true
schema:
type: string
- description: |
Content-Disposition for the response: `attachment` (download) or
`inline` (render in browser). Defaults to `attachment`.
in: query
name: disposition
schema:
default: attachment
enum:
- inline
- attachment
type: string
responses:
"200":
content:
application/octet-stream:
schema:
format: binary
type: string
description: Asset content stream (local runtime streams the bytes directly)
"302":
description: Redirect to a signed storage URL (cloud runtime)
headers:
Cache-Control:
description: Private caching directive scoped to the signed URL lifetime
schema:
type: string
Location:
description: Short-lived signed URL to the asset content in storage
schema:
type: string
Vary:
description: Partitions any cached redirect by auth credentials so a private redirect is not reused across users
schema:
type: string
"404":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Asset not found
"500":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Internal server error
security:
- ApiKeyAuth: []
- BearerAuth: []
- CookieAuth: []
summary: Get asset content
tags:
- file
/api/assets/{id}/tags: /api/assets/{id}/tags:
delete: delete:
description: Removes one or more tags from an existing asset description: Removes one or more tags from an existing asset
@ -2675,14 +2787,20 @@ paths:
summary: Get internationalisation translation strings summary: Get internationalisation translation strings
/api/interrupt: /api/interrupt:
post: post:
deprecated: true
description: | description: |
Cancel all currently RUNNING jobs for the authenticated user. Deprecated. Prefer the jobs-namespace cancel endpoints:
This will interrupt any job that is currently in 'in_progress' status. POST /api/jobs/{job_id}/cancel for a single job, or
Note: This endpoint only affects running jobs. To cancel pending jobs, use /api/queue. POST /api/jobs/cancel to cancel jobs by ID.
Cancels the first active job for the authenticated user (the currently
running job if there is one, otherwise the next pending job). Takes no
body and cannot target a specific job — use the jobs-namespace endpoints
for that.
operationId: interruptJob operationId: interruptJob
responses: responses:
"200": "200":
description: Success - Job interrupted or no running job found description: Success - first active job cancelled, or no active job found
"401": "401":
content: content:
application/json: application/json:
@ -2695,7 +2813,7 @@ paths:
schema: schema:
$ref: '#/components/schemas/ErrorResponse' $ref: '#/components/schemas/ErrorResponse'
description: Internal server error description: Internal server error
summary: Interrupt currently running jobs summary: Interrupt the first active job
tags: tags:
- queue - queue
/api/job/{job_id}/status: /api/job/{job_id}/status:
@ -2869,6 +2987,17 @@ paths:
schema: schema:
format: uuid format: uuid
type: string type: string
- description: |
When present, each output item in the response receives a `short_url` field containing an owner-gated durable link for that asset. Omit this parameter (the default) to receive a response identical to the no-param baseline. The value selects the link's lifetime: use `ephemeral_tool_chain` for short-lived machine-to-machine handoffs (~15 minutes); use `default` for durable human-revisitable links (30 days). Links are minted only for the authenticated request owner and are not resolvable by other users.
in: query
name: short_link
schema:
enum:
- ephemeral_tool_chain
- default
type: string
x-runtime:
- cloud
responses: responses:
"200": "200":
content: content:
@ -2954,6 +3083,64 @@ paths:
summary: Cancel a job summary: Cancel a job
tags: tags:
- workflow - workflow
/api/jobs/cancel:
post:
description: |
Cancel one or more jobs for the authenticated user in a single request.
State-agnostic: cancels both pending and running jobs (both transition to
the cancelled state via the same mechanism as the single-job endpoint).
Idempotent per job: a job already in a terminal or cancelling state is a
no-op and simply will not appear in the returned `cancelled` list.
Fail-fast on unknown IDs: if any provided job ID does not exist for this
user, the request returns 404 and no jobs are cancelled. This surfaces
bad IDs to the caller rather than silently dropping them.
This is the canonical batch-cancel endpoint. The delete operation on
POST /api/queue is deprecated in favour of this.
operationId: cancelJobs
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/JobsCancelRequest'
required: true
responses:
"200":
content:
application/json:
schema:
$ref: '#/components/schemas/JobsCancelResponse'
description: Success - cancel requests dispatched (or jobs were already terminal)
"400":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Bad Request - job_ids is missing, empty, exceeds the maximum count, or contains an invalid UUID
"401":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Unauthorized - Authentication required
"404":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: One or more job IDs not found for this user (no jobs cancelled)
"500":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Internal server error - cancellation failed
summary: Cancel multiple jobs
tags:
- workflow
/api/node_replacements: /api/node_replacements:
get: get:
description: | description: |
@ -3104,9 +3291,18 @@ paths:
tags: tags:
- queue - queue
post: post:
deprecated: true
description: | description: |
Cancel specific PENDING jobs by ID or clear all pending jobs in the queue. Deprecated. Prefer the jobs-namespace cancel endpoints:
Note: This endpoint only affects pending jobs. To cancel running jobs, use /api/interrupt. POST /api/jobs/cancel for cancelling jobs by ID, and
POST /api/jobs/{job_id}/cancel for a single job.
Cancel specific jobs by ID (the `delete` field) or clear all pending
jobs in the queue (the `clear` field). Despite the `delete` naming, this
does not delete anything — listed jobs transition to the cancelled state,
and `delete` cancels both pending and running jobs (not pending-only as
previously documented). Job-by-ID cancellation is superseded by
POST /api/jobs/cancel; `clear` has no jobs-namespace replacement yet.
operationId: manageQueue operationId: manageQueue
requestBody: requestBody:
content: content:

View File

@ -1,4 +1,4 @@
comfyui-frontend-package==1.45.15 comfyui-frontend-package==1.45.19
comfyui-workflow-templates==0.10.0 comfyui-workflow-templates==0.10.0
comfyui-embedded-docs==0.5.4 comfyui-embedded-docs==0.5.4
torch torch

111
server.py
View File

@ -8,7 +8,15 @@ import time
import nodes import nodes
import folder_paths import folder_paths
import execution import execution
from comfy_execution.jobs import JobStatus, get_job, get_all_jobs, validate_job_id from comfy_execution.jobs import (
JobStatus,
get_job,
get_all_jobs,
validate_job_id,
cancel_job,
CANCEL_PENDING,
CANCEL_RUNNING,
)
import uuid import uuid
import urllib import urllib
import json import json
@ -899,6 +907,107 @@ class PromptServer():
return web.json_response(job) return web.json_response(job)
def _cancel_job_by_id(job_id):
    """Cancel a single job by id using the queue's existing mechanics.

    Running jobs are interrupted (same mechanism as /interrupt); pending
    jobs are dequeued (same mechanism as /queue {"delete": [...]}).
    Already-finished or unknown ids are no-ops. State-agnostic.

    Returns True when a cancel was actually dispatched (running or
    pending job), False when the call was a no-op (terminal/unknown id).
    """
    # Point-in-time snapshot of the queue and history. It can be stale by
    # the time cancel_job acts on it, which is why the interrupt callback
    # below re-checks the running set itself.
    running, queued = self.prompt_queue.get_current_queue()
    history = self.prompt_queue.get_history()

    def interrupt(prompt_id):
        logging.info(f"Cancelling running prompt {prompt_id}")
        # Atomic: only interrupts if the job is still the one running,
        # so a cancel can't land on a prompt that started in the gap
        # since the snapshot above. Returns whether it actually fired.
        return self.prompt_queue.interrupt_if_running(prompt_id)

    def dequeue(prompt_id):
        logging.info(f"Cancelling pending prompt {prompt_id}")
        # Queue items carry the prompt id at index 1.
        return self.prompt_queue.delete_queue_item(lambda a: a[1] == prompt_id)

    # NOTE(review): the result below is derived from cancel_job's
    # classification only. Whether a False from interrupt()/dequeue()
    # (job finished or left the queue in the gap) changes that
    # classification depends on cancel_job, which is not visible here —
    # confirm that a lost race is not reported as "cancelled".
    classification = cancel_job(job_id, running, queued, history, interrupt, dequeue)
    return classification in (CANCEL_RUNNING, CANCEL_PENDING)
@routes.post("/api/jobs/{job_id}/cancel")
async def cancel_job_by_id(request):
    """Cancel a single job by id, regardless of state.

    Idempotent: cancelling a job that has already finished, or an id
    that is not known, returns 200 with {"cancelled": false} rather
    than an error.

    Responses:
        200 {"cancelled": <bool>} — True when a cancel was dispatched.
        400 {"error": ...} — when no job_id could be read from the path.
    """
    job_id = request.match_info.get("job_id", None)
    # Defensive guard for a missing or empty path segment.
    # NOTE(review): unlike the batch endpoint, the id is not passed through
    # validate_job_id here, so a malformed id is treated as an unknown job
    # (200, cancelled=false) instead of a 400 — confirm this is intended.
    if not job_id:
        return web.json_response(
            {"error": "job_id is required"},
            status=400
        )

    cancelled = _cancel_job_by_id(job_id)
    return web.json_response({"cancelled": cancelled})
@routes.post("/api/jobs/cancel")
async def cancel_jobs_batch(request):
    """Cancel a batch of jobs by id.

    Body: {"job_ids": ["<uuid>", ...]}

    Best-effort and idempotent: every well-formed id is cancelled if it
    is running or pending; ids that are already finished or unknown are
    no-ops, not errors. A batch of all no-ops still returns 200 with
    {"cancelled": false}. This matches the single-cancel endpoint and
    means "cancel all" still cancels the in-progress jobs even if some
    finished between the client's snapshot and the request. Malformed
    ids are still rejected up front with 400 (see below).

    Responses:
        200 {"cancelled": <bool>} — True if at least one id was cancelled.
        400 {"error": ...} — body is not valid JSON, job_ids is not a
            list, or one or more ids fail validate_job_id (the offending
            ids are echoed back under "invalid_ids").

    NOTE(review): the OpenAPI description of POST /api/jobs/cancel added in
    the same change documents a different contract — `cancelled` as a list
    of job ids, a 404 fail-fast when any id is unknown, and a job_ids array
    bounded by minItems 1 / maxItems 100. This handler returns a boolean,
    treats unknown ids as no-ops, and accepts an empty or arbitrarily long
    list. Confirm which contract is intended for this runtime and align
    the spec or the handler.
    """
    try:
        json_data = await request.json()
    except json.JSONDecodeError:
        return web.json_response(
            {"error": "Request body must be valid JSON"},
            status=400
        )

    # A non-object JSON body (list, string, number) has no job_ids key, so
    # it is funnelled into the same "must be a list" 400 below.
    job_ids = json_data.get("job_ids") if isinstance(json_data, dict) else None
    if not isinstance(job_ids, list):
        return web.json_response(
            {"error": "job_ids must be a list"},
            status=400
        )

    # Validate that every element is a well-formed job id before doing
    # anything else. An unhashable element (e.g. a nested dict or list)
    # would cause a TypeError when used as a history dict key; a
    # non-string or non-UUID value is never a valid id. Reject early
    # with 400 rather than letting the classify loop raise 500.
    invalid_ids = []
    for jid in job_ids:
        try:
            validate_job_id(jid)
        except (ValueError, AttributeError):
            # Non-string values are reported via repr() so the error
            # payload stays JSON-serialisable and shows what was sent.
            invalid_ids.append(jid if isinstance(jid, str) else repr(jid))
    if invalid_ids:
        return web.json_response(
            {"error": "job_ids contains invalid id(s)", "invalid_ids": invalid_ids},
            status=400,
        )

    # Best-effort: cancel each id that is still running/pending; an id
    # that has finished or never existed is a no-op rather than a reason
    # to fail the whole batch.
    cancelled = False
    for jid in job_ids:
        if _cancel_job_by_id(jid):
            cancelled = True

    return web.json_response({"cancelled": cancelled})
@routes.get("/history") @routes.get("/history")
async def get_history(request): async def get_history(request):
max_items = request.rel_url.query.get("max_items", None) max_items = request.rel_url.query.get("max_items", None)

View File

View File

@ -0,0 +1,453 @@
"""Tests for the jobs-namespace cancel endpoints.
Covers both layers:
* the pure cancel helpers in ``comfy_execution.jobs``
(``classify_job_for_cancel`` / ``cancel_job``), which hold the business
logic of mapping a cancel onto interrupt-vs-dequeue, and
* the HTTP contract of ``POST /api/jobs/{job_id}/cancel`` and
``POST /api/jobs/cancel`` (status codes, single-cancel idempotency, and
best-effort batch cancellation that treats unknown/finished ids as no-ops
while still rejecting malformed ids with 400).
The HTTP layer is exercised against a small aiohttp app whose handlers are a
faithful copy of the wiring in ``server.py`` driven by a fake queue that
mirrors ``execution.PromptQueue`` (``get_current_queue`` / ``get_history`` /
``delete_queue_item``). This keeps the test free of the heavy ComfyUI runtime
(torch, nodes, ...) while still testing the real cancel logic.
"""
import json
import pytest
from aiohttp import web
from comfy_execution.jobs import (
CANCEL_PENDING,
CANCEL_RUNNING,
CANCEL_TERMINAL,
CANCEL_UNKNOWN,
cancel_job,
classify_job_for_cancel,
validate_job_id,
)
# Classifications for which a cancel was actually dispatched (vs a no-op).
_CANCELLED = (CANCEL_RUNNING, CANCEL_PENDING)
# Canonical UUID ids for HTTP-layer tests (the batch endpoint validates UUID format).
_UUID_A = "aaaaaaaa-aaaa-4aaa-aaaa-aaaaaaaaaaaa"
_UUID_B = "bbbbbbbb-bbbb-4bbb-bbbb-bbbbbbbbbbbb"
_UUID_C = "cccccccc-cccc-4ccc-cccc-cccccccccccc"
_UUID_D = "dddddddd-dddd-4ddd-dddd-dddddddddddd"
_UUID_MISSING = "ffffffff-ffff-4fff-ffff-ffffffffffff"
def make_queue_item(prompt_id, number=0):
    """Return a five-element queue tuple with ``prompt_id`` at index 1.

    ``number`` goes in slot 0; the remaining three slots are fresh empty
    placeholders (two dicts and a list) created on every call.
    """
    item = [number, prompt_id]
    item.extend(({}, {}, []))
    return tuple(item)
class FakePromptQueue:
    """In-memory substitute for execution.PromptQueue on the cancel paths.

    Holds running items, pending items and history in plain containers and
    counts interrupts in ``interrupt_count`` so tests can assert on the
    side effects of a cancel.
    """

    def __init__(self, running=None, pending=None, history=None):
        # Copy the inputs so later mutation of the caller's containers
        # cannot affect the fake queue (and vice versa).
        self._running = list(running or [])
        self._pending = list(pending or [])
        self._history = dict(history or {})
        self.interrupt_count = 0

    def get_current_queue(self):
        """Return ``(running, pending)`` as fresh list copies."""
        running_copy = list(self._running)
        pending_copy = list(self._pending)
        return (running_copy, pending_copy)

    def get_history(self, prompt_id=None):
        """Return a copy of all history, or only the entry for ``prompt_id``.

        An id with no history entry yields an empty dict.
        """
        if prompt_id is None:
            return dict(self._history)
        if prompt_id not in self._history:
            return {}
        return {prompt_id: self._history[prompt_id]}

    def delete_queue_item(self, function):
        """Remove the first pending item for which ``function`` is truthy.

        Returns True if an item was removed, False if nothing matched.
        """
        match_index = next(
            (idx for idx, entry in enumerate(self._pending) if function(entry)),
            None,
        )
        if match_index is None:
            return False
        del self._pending[match_index]
        return True

    def interrupt_if_running(self, prompt_id):
        """Count an interrupt only when ``prompt_id`` is in the running set.

        Mirrors execution.PromptQueue.interrupt_if_running.
        """
        for entry in self._running:
            if entry[1] == prompt_id:
                self.interrupt_count += 1
                return True
        return False
def build_app(queue):
    """Build an aiohttp app exposing the cancel routes against ``queue``.

    Handler bodies mirror server.py (without its logging calls and route
    decorators), so the HTTP contract can be tested without importing the
    full server. ``queue`` must provide ``get_current_queue``,
    ``get_history``, ``delete_queue_item`` and ``interrupt_if_running``.

    Returns the ``web.Application`` with both POST routes registered.
    """

    def _cancel_job_by_id(job_id):
        # Snapshot the queue state, then let cancel_job pick the callback.
        running, pending = queue.get_current_queue()
        history = queue.get_history()

        def interrupt(prompt_id):
            return queue.interrupt_if_running(prompt_id)

        def dequeue(prompt_id):
            # Queue items carry the prompt id at index 1.
            return queue.delete_queue_item(lambda a: a[1] == prompt_id)

        classification = cancel_job(
            job_id, running, pending, history, interrupt, dequeue
        )
        # True only when the job was classified as running or pending.
        return classification in _CANCELLED

    async def cancel_job_by_id(request):
        # Single-job endpoint: unknown/finished ids give 200, cancelled=false.
        job_id = request.match_info.get("job_id", None)
        if not job_id:
            return web.json_response({"error": "job_id is required"}, status=400)
        cancelled = _cancel_job_by_id(job_id)
        return web.json_response({"cancelled": cancelled})

    async def cancel_jobs_batch(request):
        # Batch endpoint: malformed input is rejected with 400 up front,
        # then every id is cancelled best-effort.
        try:
            json_data = await request.json()
        except json.JSONDecodeError:
            return web.json_response(
                {"error": "Request body must be valid JSON"}, status=400
            )

        job_ids = json_data.get("job_ids") if isinstance(json_data, dict) else None
        if not isinstance(job_ids, list):
            return web.json_response({"error": "job_ids must be a list"}, status=400)

        invalid_ids = []
        for jid in job_ids:
            try:
                validate_job_id(jid)
            except (ValueError, AttributeError):
                # Non-string values are echoed back via repr().
                invalid_ids.append(jid if isinstance(jid, str) else repr(jid))
        if invalid_ids:
            return web.json_response(
                {"error": "job_ids contains invalid id(s)", "invalid_ids": invalid_ids},
                status=400,
            )

        cancelled = False
        for jid in job_ids:
            if _cancel_job_by_id(jid):
                cancelled = True

        return web.json_response({"cancelled": cancelled})

    app = web.Application()
    app.router.add_post("/api/jobs/{job_id}/cancel", cancel_job_by_id)
    app.router.add_post("/api/jobs/cancel", cancel_jobs_batch)
    return app
# ---------------------------------------------------------------------------
# Pure helper tests: classification + cancel side effects
# ---------------------------------------------------------------------------
class TestClassifyJobForCancel:
    """classify_job_for_cancel maps an id to one of the four cancel classes."""

    def test_running(self):
        # Id found in the running list.
        verdict = classify_job_for_cancel("a", [make_queue_item("a")], [], {})
        assert verdict == CANCEL_RUNNING

    def test_pending(self):
        # Id found in the pending list.
        verdict = classify_job_for_cancel("b", [], [make_queue_item("b")], {})
        assert verdict == CANCEL_PENDING

    def test_terminal(self):
        # Id found only in history.
        finished = {"prompt": make_queue_item("c"), "outputs": {}, "status": {}}
        verdict = classify_job_for_cancel("c", [], [], {"c": finished})
        assert verdict == CANCEL_TERMINAL

    def test_unknown(self):
        # Id found nowhere.
        verdict = classify_job_for_cancel("z", [], [], {})
        assert verdict == CANCEL_UNKNOWN
class TestCancelJobHelper:
    """cancel_job driven with recording ``interrupt`` / ``dequeue`` callbacks.

    Both callbacks receive the job id and return whether they actually
    acted, so the value cancel_job returns reflects the real outcome.
    """

    @staticmethod
    def _spy(calls, acted):
        """Return a callback that logs the id it receives and reports ``acted``."""

        def _callback(pid):
            calls.append(pid)
            return acted

        return _callback

    def test_running_is_interrupted_not_dequeued(self):
        interrupted, dequeued = [], []
        outcome = cancel_job(
            "a", [make_queue_item("a")], [], {},
            interrupt=self._spy(interrupted, True),
            dequeue=self._spy(dequeued, True),
        )
        assert outcome == CANCEL_RUNNING
        assert interrupted == ["a"]
        assert dequeued == []

    def test_pending_is_dequeued_not_interrupted(self):
        interrupted, dequeued = [], []
        outcome = cancel_job(
            "b", [], [make_queue_item("b")], {},
            interrupt=self._spy(interrupted, True),
            dequeue=self._spy(dequeued, True),
        )
        assert outcome == CANCEL_PENDING
        assert dequeued == ["b"]
        assert interrupted == []

    def test_terminal_is_noop(self):
        interrupted, dequeued = [], []
        finished = {"prompt": make_queue_item("c"), "outputs": {}, "status": {}}
        outcome = cancel_job(
            "c", [], [], {"c": finished},
            interrupt=self._spy(interrupted, True),
            dequeue=self._spy(dequeued, True),
        )
        assert outcome == CANCEL_TERMINAL
        assert interrupted == []
        assert dequeued == []

    def test_unknown_is_noop(self):
        interrupted, dequeued = [], []
        outcome = cancel_job(
            "z", [], [], {},
            interrupt=self._spy(interrupted, True),
            dequeue=self._spy(dequeued, True),
        )
        assert outcome == CANCEL_UNKNOWN
        assert interrupted == []
        assert dequeued == []

    def test_running_but_finished_before_interrupt_returns_unknown(self):
        """Stale RUNNING snapshot: the job finished before the interrupt fired.

        The interrupt callback returns False, so cancel_job reports UNKNOWN
        instead of claiming a cancel that did not happen. Because the
        interrupt is atomic, no unrelated job is hit either.
        """
        interrupted = []
        outcome = cancel_job(
            "a", [make_queue_item("a")], [], {},
            interrupt=self._spy(interrupted, False),
            dequeue=lambda _pid: True,
        )
        assert outcome == CANCEL_UNKNOWN
        assert interrupted == ["a"]  # interrupt was attempted atomically

    def test_pending_started_running_is_interrupted(self):
        """Pending->running race: the job left the queue because it started.

        The dequeue misses (returns False); the atomic interrupt then
        catches the now-running job, so cancel_job reports CANCEL_RUNNING.
        """
        interrupted, dequeued = [], []
        outcome = cancel_job(
            "b", [], [make_queue_item("b")], {},
            interrupt=self._spy(interrupted, True),
            dequeue=self._spy(dequeued, False),
        )
        assert outcome == CANCEL_RUNNING
        assert dequeued == ["b"]  # dequeue attempted first
        assert interrupted == ["b"]  # then the now-running job was interrupted

    def test_pending_dequeue_miss_not_running_returns_unknown(self):
        """Dequeue miss for a job that is no longer running (it finished).

        The atomic interrupt finds nothing to interrupt and returns False,
        so cancel_job is a no-op reporting UNKNOWN: it never reports a
        cancel that did not happen and never interrupts a bystander.
        """
        interrupted, dequeued = [], []
        outcome = cancel_job(
            "b", [], [make_queue_item("b")], {},
            interrupt=self._spy(interrupted, False),
            dequeue=self._spy(dequeued, False),
        )
        assert outcome == CANCEL_UNKNOWN
        assert dequeued == ["b"]
        assert interrupted == ["b"]  # interrupt attempted, found nothing running
# ---------------------------------------------------------------------------
# HTTP contract tests: POST /api/jobs/{job_id}/cancel
# ---------------------------------------------------------------------------
class TestSingleCancelEndpoint:
    """HTTP contract of the single-job cancel route, run against a fake queue."""

    @staticmethod
    async def _post(aiohttp_client, queue, path):
        """Serve build_app(queue) through the test client and POST to ``path``."""
        client = await aiohttp_client(build_app(queue))
        return await client.post(path)

    @pytest.mark.asyncio
    async def test_cancel_running_job_interrupts(self, aiohttp_client):
        queue = FakePromptQueue(running=[make_queue_item("a")])
        resp = await self._post(aiohttp_client, queue, "/api/jobs/a/cancel")
        assert resp.status == 200
        assert (await resp.json()) == {"cancelled": True}
        assert queue.interrupt_count == 1

    @pytest.mark.asyncio
    async def test_cancel_pending_job_dequeues(self, aiohttp_client):
        queue = FakePromptQueue(pending=[make_queue_item("b")])
        resp = await self._post(aiohttp_client, queue, "/api/jobs/b/cancel")
        assert resp.status == 200
        assert (await resp.json()) == {"cancelled": True}
        # The pending entry is gone and no interrupt was signalled.
        _, still_pending = queue.get_current_queue()
        assert still_pending == []
        assert queue.interrupt_count == 0

    @pytest.mark.asyncio
    async def test_cancel_terminal_job_is_idempotent_noop(self, aiohttp_client):
        finished = {"prompt": make_queue_item("c"), "outputs": {}, "status": {}}
        queue = FakePromptQueue(history={"c": finished})
        resp = await self._post(aiohttp_client, queue, "/api/jobs/c/cancel")
        # A job that already finished is a 200 no-op (cancelled=false), not an error.
        assert resp.status == 200
        assert (await resp.json()) == {"cancelled": False}
        assert queue.interrupt_count == 0

    @pytest.mark.asyncio
    async def test_cancel_unknown_id_is_200_noop(self, aiohttp_client):
        queue = FakePromptQueue()
        resp = await self._post(
            aiohttp_client, queue, "/api/jobs/does-not-exist/cancel"
        )
        # Cancelling an id nobody knows is handled as an idempotent no-op.
        assert resp.status == 200
        assert (await resp.json()) == {"cancelled": False}
        assert queue.interrupt_count == 0

    @pytest.mark.asyncio
    async def test_cancel_pending_that_started_running_interrupts(self, aiohttp_client):
        """Pending->running race, end to end.

        The job is pending when the snapshot is taken but has started
        executing by the time the dequeue runs, so the delete misses. The
        interrupt then sees it running and fires, so the cancel is not
        dropped and the caller still receives cancelled=True.
        """

        class RacingQueue(FakePromptQueue):
            def delete_queue_item(self, function):
                # Simulate the worker grabbing the job just before removal:
                # everything pending becomes running and the delete misses.
                self._running = list(self._pending)
                self._pending = []
                return False

        queue = RacingQueue(pending=[make_queue_item("b")])
        resp = await self._post(aiohttp_client, queue, "/api/jobs/b/cancel")
        assert resp.status == 200
        assert (await resp.json()) == {"cancelled": True}
        assert queue.interrupt_count == 1
# ---------------------------------------------------------------------------
# HTTP contract tests: POST /api/jobs/cancel (batch)
# ---------------------------------------------------------------------------
class TestBatchCancelEndpoint:
    """HTTP contract of the batch cancel route, run against a fake queue."""

    @staticmethod
    async def _post_batch(aiohttp_client, queue, payload):
        """Serve build_app(queue) and POST ``payload`` as JSON to the batch route."""
        client = await aiohttp_client(build_app(queue))
        return await client.post("/api/jobs/cancel", json=payload)

    @staticmethod
    def _running_and_pending_queue():
        """Return a queue with _UUID_A running and _UUID_B pending."""
        return FakePromptQueue(
            running=[make_queue_item(_UUID_A)],
            pending=[make_queue_item(_UUID_B, number=1)],
        )

    @pytest.mark.asyncio
    async def test_batch_happy_path(self, aiohttp_client):
        queue = self._running_and_pending_queue()
        resp = await self._post_batch(
            aiohttp_client, queue, {"job_ids": [_UUID_A, _UUID_B]}
        )
        assert resp.status == 200
        assert (await resp.json()) == {"cancelled": True}
        assert queue.interrupt_count == 1  # running job interrupted
        assert queue.get_current_queue()[1] == []  # pending job dequeued

    @pytest.mark.asyncio
    async def test_batch_best_effort_skips_unknown_id(self, aiohttp_client):
        """An unknown id inside the batch is skipped rather than aborting it.

        The running and pending jobs are still cancelled (200,
        cancelled=true). This is the "cancel all as a job finishes" case
        from review.
        """
        queue = self._running_and_pending_queue()
        resp = await self._post_batch(
            aiohttp_client, queue, {"job_ids": [_UUID_A, _UUID_MISSING, _UUID_B]}
        )
        assert resp.status == 200
        assert (await resp.json()) == {"cancelled": True}
        assert queue.interrupt_count == 1  # running job interrupted
        assert queue.get_current_queue()[1] == []  # pending job dequeued

    @pytest.mark.asyncio
    async def test_batch_all_terminal_is_idempotent_noop(self, aiohttp_client):
        finished = {
            job_id: {"prompt": make_queue_item(job_id), "outputs": {}, "status": {}}
            for job_id in (_UUID_C, _UUID_D)
        }
        queue = FakePromptQueue(history=finished)
        resp = await self._post_batch(
            aiohttp_client, queue, {"job_ids": [_UUID_C, _UUID_D]}
        )
        # Every id is known but already finished: 200, cancelled=false, no dispatch.
        assert resp.status == 200
        assert (await resp.json()) == {"cancelled": False}
        assert queue.interrupt_count == 0

    @pytest.mark.asyncio
    async def test_batch_missing_job_ids_is_400(self, aiohttp_client):
        queue = FakePromptQueue()
        resp = await self._post_batch(aiohttp_client, queue, {})
        assert resp.status == 400

    @pytest.mark.asyncio
    async def test_batch_unhashable_element_is_400_not_500(self, aiohttp_client):
        """An unhashable element (dict, list, ...) must produce 400, not 500.

        Regression guard: a body such as {"job_ids": [{}]} used to reach the
        classify loop, where ``prompt_id in history`` raises TypeError for an
        unhashable value and surfaced as an unhandled 500. Input validation
        has to reject it before any queue or history access.
        """
        queue = FakePromptQueue()
        resp = await self._post_batch(aiohttp_client, queue, {"job_ids": [{}]})
        assert resp.status == 400
        body = await resp.json()
        assert "invalid_ids" in body
        # The queue must be untouched.
        assert queue.interrupt_count == 0

    @pytest.mark.asyncio
    async def test_batch_non_uuid_string_element_is_400(self, aiohttp_client):
        """A string that is not a valid UUID is rejected with 400."""
        queue = FakePromptQueue()
        resp = await self._post_batch(
            aiohttp_client, queue, {"job_ids": ["not-a-uuid"]}
        )
        assert resp.status == 400
        body = await resp.json()
        assert "invalid_ids" in body