diff --git a/README.md b/README.md index c75353d36..bcec86377 100644 --- a/README.md +++ b/README.md @@ -140,7 +140,7 @@ ComfyUI follows a weekly release cycle targeting Monday but this regularly chang - Commits outside of the stable release tags may be very unstable and break many custom nodes. - Serves as the foundation for the desktop release -2. **[ComfyUI Desktop](https://github.com/Comfy-Org/Comfy-Desktop)** +2. **[Comfy Desktop](https://github.com/Comfy-Org/Comfy-Desktop)** - Builds a new release using the latest stable core version 3. **[ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend)** diff --git a/comfy/context_windows.py b/comfy/context_windows.py index db57537a2..5f9899c67 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -8,6 +8,8 @@ from abc import ABC, abstractmethod import logging import comfy.model_management import comfy.patcher_extension +import comfy.utils +import comfy.conds if TYPE_CHECKING: from comfy.model_base import BaseModel from comfy.model_patcher import ModelPatcher @@ -51,12 +53,18 @@ class ContextHandlerABC(ABC): class IndexListContextWindow(ContextWindowABC): - def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0): + def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0, modality_windows: dict=None, context_overlap: int=0): self.index_list = index_list self.context_length = len(index_list) + self.context_overlap = context_overlap self.dim = dim self.total_frames = total_frames self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames) + self.modality_windows = modality_windows # dict of {mod_idx: IndexListContextWindow} + self.guide_frames_indices: list[int] = [] + self.guide_overlap_info: list[tuple[int, int]] = [] + self.guide_kf_local_positions: list[int] = [] + self.guide_downscale_factors: list[int] = [] def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor: if dim is None: @@ -85,6 +93,11 @@ class IndexListContextWindow(ContextWindowABC): region_idx = int(self.center_ratio * num_regions) return min(max(region_idx, 0), num_regions - 1) + def get_window_for_modality(self, modality_idx: int) -> 'IndexListContextWindow': + if modality_idx == 0: + return self + return self.modality_windows[modality_idx] + class IndexListCallbacks: EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows" @@ -148,6 +161,172 @@ def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, d return cond_value._copy_with(sliced) +def compute_guide_overlap(guide_entries: list[dict], keyframe_idxs: torch.Tensor, temporal_downscale_ratio: int, window_index_list: list[int]): + """Compute which concatenated guide frames overlap with a context window. + + Each guide's latent-space start is derived from its first token's pixel-t-start + in keyframe_idxs (shape (B, [t,h,w], num_tokens, [start, end])), divided by the + model's temporal_downscale_ratio. + + Args: + guide_entries: list of guide_attention_entry dicts + keyframe_idxs: per-token pixel coords cond tensor for the modality + temporal_downscale_ratio: model's pixel-to-latent temporal compression ratio + window_index_list: the window's frame indices into the video portion + + Returns: + suffix_indices: indices into the guide_frames tensor for frame selection + overlap_info: list of (entry_idx, overlap_count) for guide_attention_entries adjustment + kf_local_positions: window-local frame positions for keyframe_idxs regeneration + total_overlap: total number of overlapping guide frames + """ + window_set = set(window_index_list) + window_list = list(window_index_list) + suffix_indices = [] + overlap_info = [] + kf_local_positions = [] + suffix_base = 0 + token_offset = 0 + + for entry_idx, entry in enumerate(guide_entries): + first_t_pixel = int(keyframe_idxs[0, 0, token_offset, 0].item()) + latent_start = (first_t_pixel + temporal_downscale_ratio - 1) // temporal_downscale_ratio + guide_len = entry["latent_shape"][0] + entry_overlap = 0 + + for local_offset in range(guide_len): + video_pos = latent_start + local_offset + if video_pos in window_set: + suffix_indices.append(suffix_base + local_offset) + kf_local_positions.append(window_list.index(video_pos)) + entry_overlap += 1 + + if entry_overlap > 0: + overlap_info.append((entry_idx, entry_overlap)) + suffix_base += guide_len + token_offset += entry["pre_filter_count"] + + return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices) + + +@dataclass +class WindowingState: + """Per-modality context windowing state for each step, + built using IndexListContextHandler._build_window_state(). + For non-multimodal models the lists are length 1 + """ + latents: list[torch.Tensor] # per-modality working latents (guide frames stripped) + guide_latents: list[torch.Tensor | None] # per-modality guide frames stripped from latents + guide_entries: list[list[dict] | None] # per-modality guide_attention_entry metadata + keyframe_idxs: list[torch.Tensor | None] # per-modality keyframe_idxs tensor for guide latent_start derivation + latent_shapes: list | None # original packed shapes for unpack/pack (None if not multimodal) + dim: int = 0 # primary modality temporal dim for context windowing + is_multimodal: bool = False + temporal_downscale_ratio: int = 1 # model's pixel-to-latent temporal compression ratio + + def prepare_window(self, window: IndexListContextWindow, model) -> IndexListContextWindow: + """Reformat window for multimodal contexts by deriving per-modality index lists. + Non-multimodal contexts return the input window unchanged. + """ + if not self.is_multimodal: + return window + + x = self.latents[0] + primary_total = self.latent_shapes[0][self.dim] + primary_overlap = window.context_overlap + map_shapes = self.latent_shapes + if x.size(self.dim) != primary_total: + map_shapes = list(self.latent_shapes) + video_shape = list(self.latent_shapes[0]) + video_shape[self.dim] = x.size(self.dim) + map_shapes[0] = torch.Size(video_shape) + try: + per_modality_indices = model.map_context_window_to_modalities( + window.index_list, map_shapes, self.dim) + except AttributeError: + raise NotImplementedError( + f"{type(model).__name__} must implement map_context_window_to_modalities for multimodal context windows.") + modality_windows = {} + for mod_idx in range(1, len(self.latents)): + modality_total_frames = self.latents[mod_idx].shape[self.dim] + ratio = modality_total_frames / primary_total if primary_total > 0 else 1 + modality_overlap = max(round(primary_overlap * ratio), 0) + modality_windows[mod_idx] = IndexListContextWindow( + per_modality_indices[mod_idx], dim=self.dim, + total_frames=modality_total_frames, + context_overlap=modality_overlap) + return IndexListContextWindow( + window.index_list, dim=self.dim, total_frames=x.shape[self.dim], + modality_windows=modality_windows, context_overlap=primary_overlap) + + def slice_for_window(self, window: IndexListContextWindow, retain_index_list: list[int], device=None) -> tuple[list[torch.Tensor], list[int]]: + """Slice latents for a context window, injecting guide frames where applicable. + For multimodal contexts, uses the modality-specific windows derived in prepare_window(). + """ + sliced = [] + guide_frame_counts = [] + for idx in range(len(self.latents)): + modality_window = window.get_window_for_modality(idx) + retain = retain_index_list if idx == 0 else [] + s = modality_window.get_tensor(self.latents[idx], device, retain_index_list=retain) + if self.guide_entries[idx] is not None: + s, ng = self._inject_guide_frames(s, modality_window, modality_idx=idx) + else: + ng = 0 + sliced.append(s) + guide_frame_counts.append(ng) + return sliced, guide_frame_counts + + def strip_guide_frames(self, out_per_modality: list[list[torch.Tensor]], guide_frame_counts: list[int], window: IndexListContextWindow): + """Strip injected guide frames from per-cond, per-modality outputs in place.""" + for idx in range(len(self.latents)): + if guide_frame_counts[idx] > 0: + window_len = len(window.get_window_for_modality(idx).index_list) + for ci in range(len(out_per_modality)): + out_per_modality[ci][idx] = out_per_modality[ci][idx].narrow(self.dim, 0, window_len) + + def _inject_guide_frames(self, latent_slice: torch.Tensor, window: IndexListContextWindow, modality_idx: int = 0) -> tuple[torch.Tensor, int]: + guide_entries = self.guide_entries[modality_idx] + guide_frames = self.guide_latents[modality_idx] + keyframe_idxs = self.keyframe_idxs[modality_idx] + suffix_idx, overlap_info, kf_local_pos, guide_frame_count = compute_guide_overlap( + guide_entries, keyframe_idxs, self.temporal_downscale_ratio, window.index_list) + # Shift keyframe positions to account for causal_window_fix anchor occupying sub-pos 0. + anchor_idx = getattr(window, 'causal_anchor_index', None) + if anchor_idx is not None and anchor_idx >= 0: + kf_local_pos = [p + 1 for p in kf_local_pos] + window.guide_frames_indices = suffix_idx + window.guide_overlap_info = overlap_info + window.guide_kf_local_positions = kf_local_pos + + # Derive per-overlap-entry latent_downscale_factor from guide entry latent_shape vs guide frame spatial dims. + # guide_frames has full (post-dilation) spatial dims; entry["latent_shape"] has pre-dilation dims. + guide_downscale_factors = [] + if guide_frame_count > 0: + full_H = guide_frames.shape[3] + for entry_idx, _ in overlap_info: + entry_H = guide_entries[entry_idx]["latent_shape"][1] + guide_downscale_factors.append(full_H // entry_H) + window.guide_downscale_factors = guide_downscale_factors + + if guide_frame_count > 0: + idx = tuple([slice(None)] * self.dim + [suffix_idx]) + return torch.cat([latent_slice, guide_frames[idx]], dim=self.dim), guide_frame_count + return latent_slice, 0 + + def patch_latent_shapes(self, sub_conds, new_shapes): + if not self.is_multimodal: + return + + for cond_list in sub_conds: + if cond_list is None: + continue + for cond_dict in cond_list: + model_conds = cond_dict.get('model_conds', {}) + if 'latent_shapes' in model_conds: + model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes) + + @dataclass class ContextSchedule: name: str @@ -162,7 +341,7 @@ ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_co class IndexListContextHandler(ContextHandlerABC): def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, - causal_window_fix: bool=True): + latent_retain_index_list: list[int]=[], causal_window_fix: bool=True): self.context_schedule = context_schedule self.fuse_method = fuse_method self.context_length = context_length @@ -174,17 +353,118 @@ class IndexListContextHandler(ContextHandlerABC): self.freenoise = freenoise self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else [] self.split_conds_to_windows = split_conds_to_windows + self.latent_retain_index_list = [int(x.strip()) for x in latent_retain_index_list.split(",")] if latent_retain_index_list else [] self.causal_window_fix = causal_window_fix self.callbacks = {} + @staticmethod + def _get_latent_shapes(conds): + for cond_list in conds: + if cond_list is None: + continue + for cond_dict in cond_list: + model_conds = cond_dict.get('model_conds', {}) + if 'latent_shapes' in model_conds: + return model_conds['latent_shapes'].cond + return None + + @staticmethod + def _get_guide_entries(conds): + for cond_list in conds: + if cond_list is None: + continue + for cond_dict in cond_list: + model_conds = cond_dict.get('model_conds', {}) + entries = model_conds.get('guide_attention_entries') + if entries is not None and hasattr(entries, 'cond') and entries.cond: + return entries.cond + return None + + @staticmethod + def _get_keyframe_idxs(conds): + for cond_list in conds: + if cond_list is None: + continue + for cond_dict in cond_list: + model_conds = cond_dict.get('model_conds', {}) + kf = model_conds.get('keyframe_idxs') + if kf is not None and hasattr(kf, 'cond') and kf.cond is not None: + return kf.cond + return None + + def _apply_freenoise(self, noise: torch.Tensor, conds: list[list[dict]], seed: int) -> torch.Tensor: + """Apply FreeNoise shuffling, scaling context length/overlap per-modality by frame ratio. + If guide frames are present on the primary modality, only the video portion is shuffled. + """ + guide_entries = self._get_guide_entries(conds) + guide_count = sum(e["latent_shape"][0] for e in guide_entries) if guide_entries else 0 + + latent_shapes = self._get_latent_shapes(conds) + if latent_shapes is not None and len(latent_shapes) > 1: + modalities = comfy.utils.unpack_latents(noise, latent_shapes) + primary_total = latent_shapes[0][self.dim] + primary_video_count = modalities[0].size(self.dim) - guide_count + apply_freenoise(modalities[0].narrow(self.dim, 0, primary_video_count), self.dim, self.context_length, self.context_overlap, seed) + for i in range(1, len(modalities)): + mod_total = latent_shapes[i][self.dim] + ratio = mod_total / primary_total if primary_total > 0 else 1 + mod_ctx_len = max(round(self.context_length * ratio), 1) + mod_ctx_overlap = max(round(self.context_overlap * ratio), 0) + modalities[i] = apply_freenoise(modalities[i], self.dim, mod_ctx_len, mod_ctx_overlap, seed) + noise, _ = comfy.utils.pack_latents(modalities) + return noise + video_count = noise.size(self.dim) - guide_count + apply_freenoise(noise.narrow(self.dim, 0, video_count), self.dim, self.context_length, self.context_overlap, seed) + return noise + + def _build_window_state(self, x_in: torch.Tensor, conds: list[list[dict]], model: BaseModel) -> WindowingState: + """Build windowing state for the current step, including unpacking latents and extracting guide frame info from conds.""" + latent_shapes = self._get_latent_shapes(conds) + is_multimodal = latent_shapes is not None and len(latent_shapes) > 1 + unpacked_latents = comfy.utils.unpack_latents(x_in, latent_shapes) if is_multimodal else [x_in] + + unpacked_latents_list = list(unpacked_latents) + guide_latents_list = [None] * len(unpacked_latents) + guide_entries_list = [None] * len(unpacked_latents) + keyframe_idxs_list = [None] * len(unpacked_latents) + + extracted_guide_entries = self._get_guide_entries(conds) + extracted_keyframe_idxs = self._get_keyframe_idxs(conds) + + # Strip guide frames (only from first modality for now) + if extracted_guide_entries is not None: + guide_count = sum(e["latent_shape"][0] for e in extracted_guide_entries) + if guide_count > 0: + x = unpacked_latents[0] + latent_count = x.size(self.dim) - guide_count + unpacked_latents_list[0] = x.narrow(self.dim, 0, latent_count) + guide_latents_list[0] = x.narrow(self.dim, latent_count, guide_count) + guide_entries_list[0] = extracted_guide_entries + keyframe_idxs_list[0] = extracted_keyframe_idxs + + + return WindowingState( + latents=unpacked_latents_list, + guide_latents=guide_latents_list, + guide_entries=guide_entries_list, + keyframe_idxs=keyframe_idxs_list, + latent_shapes=latent_shapes, + dim=self.dim, + is_multimodal=is_multimodal, + temporal_downscale_ratio=model.latent_format.temporal_downscale_ratio) + def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool: - # for now, assume first dim is batch - should have stored on BaseModel in actual implementation - if x_in.size(self.dim) > self.context_length: - logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.") + window_state = self._build_window_state(x_in, conds, model) # build window_state to check frame counts, will be built again in execute + total_frame_count = window_state.latents[0].size(self.dim) + if total_frame_count > self.context_length: + logging.info(f"\nUsing context windows: Context length {self.context_length} with overlap {self.context_overlap} for {total_frame_count} frames.") if self.cond_retain_index_list: logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}") + if self.latent_retain_index_list: + logging.info(f"Retaining original latent for indexes: {self.latent_retain_index_list}") return True + logging.info(f"\nNot using context windows since context length ({self.context_length}) exceeds input frames ({total_frame_count}).") return False def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase: @@ -275,7 +555,9 @@ class IndexListContextHandler(ContextHandlerABC): return resized_cond def set_step(self, timestep: torch.Tensor, model_options: dict[str]): - mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001) + sample_sigmas = model_options["transformer_options"]["sample_sigmas"] + current_timestep = timestep[0].to(sample_sigmas.dtype) + mask = torch.isclose(sample_sigmas, current_timestep, rtol=0.0001) matches = torch.nonzero(mask) if torch.numel(matches) == 0: return # substep from multi-step sampler: keep self._step from the last full step @@ -284,54 +566,98 @@ class IndexListContextHandler(ContextHandlerABC): def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]: full_length = x_in.size(self.dim) # TODO: choose dim based on model context_windows = self.context_schedule.func(full_length, self, model_options) - context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows] + context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length, context_overlap=self.context_overlap) for window in context_windows] return context_windows def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): self._model = model self.set_step(timestep, model_options) - context_windows = self.get_context_windows(model, x_in, model_options) - enumerated_context_windows = list(enumerate(context_windows)) - conds_final = [torch.zeros_like(x_in) for _ in conds] + window_state = self._build_window_state(x_in, conds, model) + num_modalities = len(window_state.latents) + + context_windows = self.get_context_windows(model, window_state.latents[0], model_options) + enumerated_context_windows = list(enumerate(context_windows)) + total_windows = len(enumerated_context_windows) + + # Initialize per-modality accumulators (length 1 for single-modality) + accum = [[torch.zeros_like(m) for _ in conds] for m in window_state.latents] if self.fuse_method.name == ContextFuseMethods.RELATIVE: - counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds] + counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in window_state.latents] else: - counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds] - biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds] + counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in window_state.latents] + biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in window_state.latents] for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks): callback(self, model, x_in, conds, timestep, model_options) + # accumulate results from each context window for enum_window in enumerated_context_windows: - results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options) + results = self.evaluate_context_windows( + calc_cond_batch, model, x_in, conds, timestep, [enum_window], + model_options, window_state=window_state, total_windows=total_windows) for result in results: - self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep, - conds_final, counts_final, biases_final) + # result.sub_conds_out is per-cond, per-modality: list[list[Tensor]] + for mod_idx in range(num_modalities): + mod_out = [result.sub_conds_out[ci][mod_idx] for ci in range(len(conds))] + modality_window = result.window.get_window_for_modality(mod_idx) + self.combine_context_window_results( + window_state.latents[mod_idx], mod_out, result.sub_conds, modality_window, + result.window_idx, total_windows, timestep, + accum[mod_idx], counts[mod_idx], biases[mod_idx]) + + # fuse accumulated results into final conds try: - # finalize conds - if self.fuse_method.name == ContextFuseMethods.RELATIVE: - # relative is already normalized, so return as is - del counts_final - return conds_final - else: - # normalize conds via division by context usage counts - for i in range(len(conds_final)): - conds_final[i] /= counts_final[i] - del counts_final - return conds_final + result_out = [] + for ci in range(len(conds)): + finalized = [] + for mod_idx in range(num_modalities): + if self.fuse_method.name != ContextFuseMethods.RELATIVE: + accum[mod_idx][ci] /= counts[mod_idx][ci] + f = accum[mod_idx][ci] + + # if guide frames were injected, append them to the end of the fused latents for the next step + if window_state.guide_latents[mod_idx] is not None: + f = torch.cat([f, window_state.guide_latents[mod_idx]], dim=self.dim) + finalized.append(f) + + # pack modalities together if needed + if window_state.is_multimodal and len(finalized) > 1: + packed, _ = comfy.utils.pack_latents(finalized) + else: + packed = finalized[0] + + result_out.append(packed) + return result_out finally: for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks): callback(self, model, x_in, conds, timestep, model_options) - def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]], - model_options, device=None, first_device=None): + def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, + timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]], + model_options, window_state: WindowingState, total_windows: int = None, + device=None, first_device=None): + """Evaluate context windows and return per-cond, per-modality outputs in ContextResults.sub_conds_out + + For each window: + 1. Builds windows (for each modality if multimodal) + 2. Slices window for each modality + 3. Injects concatenated latent guide frames where present + 4. Packs together if needed and calls model + 5. Unpacks and strips any guides from outputs + """ + x = window_state.latents[0] + results: list[ContextResults] = [] for window_idx, window in enumerated_context_windows: # allow processing to end between context window executions for faster Cancel comfy.model_management.throw_exception_if_processing_interrupted() - # causal_window_fix: prepend a pre-window frame that will be stripped post-forward + # prepare the window accounting for multimodal windows + window = window_state.prepare_window(window, model) + + # causal_window_fix: prepend a pre-window frame that will be stripped post-forward. + # Set anchor before slice_for_window so the latent slice and downstream cond slices both pick it up. anchor_applied = False if self.causal_window_fix: anchor_idx = window.index_list[0] - 1 @@ -339,27 +665,46 @@ class IndexListContextHandler(ContextHandlerABC): window.causal_anchor_index = anchor_idx anchor_applied = True + # slice the window for each modality, injecting guide frames where applicable + sliced, guide_frame_counts_per_modality = window_state.slice_for_window(window, self.latent_retain_index_list, device) + for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks): callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device) - # update exposed params + logging.info(f"Context window {window_idx + 1}/{total_windows or len(enumerated_context_windows)}: frames {window.index_list[0]}-{window.index_list[-1]} of {x.shape[self.dim]}" + + (f" (+{guide_frame_counts_per_modality[0]} guide frames)" if guide_frame_counts_per_modality[0] > 0 else "") + ) + + # if multimodal, pack modalities together + if window_state.is_multimodal and len(sliced) > 1: + sub_x, sub_shapes = comfy.utils.pack_latents(sliced) + else: + sub_x, sub_shapes = sliced[0], [sliced[0].shape] + + # get resized conds for window model_options["transformer_options"]["context_window"] = window - # get subsections of x, timestep, conds - sub_x = window.get_tensor(x_in, device) - sub_timestep = window.get_tensor(timestep, device, dim=0) - sub_conds = [self.get_resized_cond(cond, x_in, window, device) for cond in conds] + sub_timestep = window.get_tensor(timestep, dim=0) + sub_conds = [self.get_resized_cond(cond, x, window) for cond in conds] + # if multimodal, patch latent_shapes in conds for correct unpacking in model + window_state.patch_latent_shapes(sub_conds, sub_shapes) + + # call model on window sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options) - if device is not None: - for i in range(len(sub_conds_out)): - sub_conds_out[i] = sub_conds_out[i].to(x_in.device) - # strip causal_window_fix anchor if applied + # unpack outputs + out_per_modality = [comfy.utils.unpack_latents(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))] + + # strip causal_window_fix anchor from primary modality before guide strip so window_len math stays correct if anchor_applied: - for i in range(len(sub_conds_out)): - sub_conds_out[i] = sub_conds_out[i].narrow(self.dim, 1, sub_conds_out[i].shape[self.dim] - 1) + for ci in range(len(out_per_modality)): + t = out_per_modality[ci][0] + out_per_modality[ci][0] = t.narrow(self.dim, 1, t.shape[self.dim] - 1) - results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window)) + # strip injected guide frames + window_state.strip_guide_frames(out_per_modality, guide_frame_counts_per_modality, window) + + results.append(ContextResults(window_idx, out_per_modality, sub_conds, window)) return results @@ -383,7 +728,7 @@ class IndexListContextHandler(ContextHandlerABC): biases_final[i][idx] = bias_total + bias else: # add conds and counts based on weights of fuse method - weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep) + weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep, context_overlap=window.context_overlap) weights_tensor = match_weights_to_dim(weights, x_in, self.dim, device=x_in.device) for i in range(len(sub_conds_out)): window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor) @@ -393,16 +738,22 @@ class IndexListContextHandler(ContextHandlerABC): callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final) -def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args, **kwargs): - # limit noise_shape length to context_length for more accurate vram use estimation +def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, conds, *args, **kwargs): + # Scale noise_shape to a single context window so VRAM estimation budgets per-window. model_options = kwargs.get("model_options", None) if model_options is None: raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.") handler: IndexListContextHandler = model_options.get("context_handler", None) if handler is not None: noise_shape = list(noise_shape) - noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length) - return executor(model, noise_shape, *args, **kwargs) + is_packed = len(noise_shape) == 3 and noise_shape[1] == 1 + if is_packed: + # TODO: latent_shapes cond isn't attached yet at this point, so we can't compute a + # per-window flat latent here. Skipping the clamp over-estimates but prevents immediate OOM. + pass + elif handler.dim < len(noise_shape) and noise_shape[handler.dim] > handler.context_length: + noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length) + return executor(model, noise_shape, conds, *args, **kwargs) def create_prepare_sampling_wrapper(model: ModelPatcher): @@ -422,11 +773,12 @@ def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, nois raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.") if not handler.freenoise: return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs) - noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"]) + + conds = [guider.conds.get('positive', guider.conds.get('negative', []))] + noise = handler._apply_freenoise(noise, conds, extra_args["seed"]) return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs) - def create_sampler_sample_wrapper(model: ModelPatcher): model.add_wrapper_with_key( comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, @@ -434,7 +786,6 @@ def create_sampler_sample_wrapper(model: ModelPatcher): _sampler_sample_wrapper ) - def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor: total_dims = len(x_in.shape) weights_tensor = torch.Tensor(weights).to(device=device) @@ -580,8 +931,9 @@ def get_matching_context_schedule(context_schedule: str) -> ContextSchedule: return ContextSchedule(context_schedule, func) -def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None): - return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs) +def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None, context_overlap: int=None): + context_overlap = handler.context_overlap if context_overlap is None else context_overlap + return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs, context_overlap=context_overlap) def create_weights_flat(length: int, **kwargs) -> list[float]: @@ -599,18 +951,18 @@ def create_weights_pyramid(length: int, **kwargs) -> list[float]: weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1)) return weight_sequence -def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, **kwargs): +def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], context_overlap: int, **kwargs): # based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302 # only expected overlap is given different weights weights_torch = torch.ones((length)) # blend left-side on all except first window if min(idxs) > 0: - ramp_up = torch.linspace(1e-37, 1, handler.context_overlap) - weights_torch[:handler.context_overlap] = ramp_up + ramp_up = torch.linspace(1e-37, 1, context_overlap) + weights_torch[:context_overlap] = ramp_up # blend right-side on all except last window if max(idxs) < full_length-1: - ramp_down = torch.linspace(1, 1e-37, handler.context_overlap) - weights_torch[-handler.context_overlap:] = ramp_down + ramp_down = torch.linspace(1, 1e-37, context_overlap) + weights_torch[-context_overlap:] = ramp_down return weights_torch class ContextFuseMethods: diff --git a/comfy/ldm/boogu/model.py b/comfy/ldm/boogu/model.py new file mode 100644 index 000000000..966f3c583 --- /dev/null +++ b/comfy/ldm/boogu/model.py @@ -0,0 +1,321 @@ +# Boogu-Image-0.1 transformer +# Architecture is an OmniGen2 derivative (see comfy/ldm/omnigen/omnigen2.py) with an +# added dual-stream ("double_stream") stage before the single-stream layers, conditioned +# by a Qwen3-VL multimodal LLM. Reuses the OmniGen2/Lumina building blocks and the Flux +# RoPE core, the only new component is the double-stream block + the hybrid forward order. + +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from einops import rearrange + +import comfy.ldm.common_dit +import comfy.ldm.omnigen.omnigen2 +from comfy.ldm.modules.attention import optimized_attention_masked +from comfy.ldm.omnigen.omnigen2 import ( + OmniGen2RotaryPosEmbed, + Lumina2CombinedTimestepCaptionEmbedding, + LuminaRMSNormZero, + LuminaLayerNormContinuous, + LuminaFeedForward, + Attention, + OmniGen2TransformerBlock, + apply_rotary_emb, +) + +class BooguDoubleStreamProcessor(nn.Module): + # Joint attention over [instruct ; img] with separate per-stream q/k/v and output projections. + def __init__(self, dim, head_dim, heads, kv_heads, dtype=None, device=None, operations=None): + super().__init__() + query_dim = head_dim * heads + kv_dim = head_dim * kv_heads + + self.img_to_q = operations.Linear(query_dim, query_dim, bias=False, dtype=dtype, device=device) + self.img_to_k = operations.Linear(query_dim, kv_dim, bias=False, dtype=dtype, device=device) + self.img_to_v = operations.Linear(query_dim, kv_dim, bias=False, dtype=dtype, device=device) + + self.instruct_to_q = operations.Linear(query_dim, query_dim, bias=False, dtype=dtype, device=device) + self.instruct_to_k = operations.Linear(query_dim, kv_dim, bias=False, dtype=dtype, device=device) + self.instruct_to_v = operations.Linear(query_dim, kv_dim, bias=False, dtype=dtype, device=device) + + self.instruct_out = operations.Linear(query_dim, query_dim, bias=False, dtype=dtype, device=device) + self.img_out = operations.Linear(query_dim, query_dim, bias=False, dtype=dtype, device=device) + + def forward(self, attn, img_hidden_states, instruct_hidden_states, rotary_emb, attention_mask=None, transformer_options={}): + batch_size = img_hidden_states.shape[0] + L_instruct = instruct_hidden_states.shape[1] + + img_q = self.img_to_q(img_hidden_states) + img_k = self.img_to_k(img_hidden_states) + img_v = self.img_to_v(img_hidden_states) + + instruct_q = self.instruct_to_q(instruct_hidden_states) + instruct_k = self.instruct_to_k(instruct_hidden_states) + instruct_v = self.instruct_to_v(instruct_hidden_states) + + # Concatenate instruction first, then image (matches reference processor order). + query = torch.cat([instruct_q, img_q], dim=1) + key = torch.cat([instruct_k, img_k], dim=1) + value = torch.cat([instruct_v, img_v], dim=1) + + query = query.view(batch_size, -1, attn.heads, attn.dim_head) + key = key.view(batch_size, -1, attn.kv_heads, attn.dim_head) + value = value.view(batch_size, -1, attn.kv_heads, attn.dim_head) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if rotary_emb is not None: + query = apply_rotary_emb(query, rotary_emb) + key = apply_rotary_emb(key, rotary_emb) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + if attn.kv_heads < attn.heads: + key = key.repeat_interleave(attn.heads // attn.kv_heads, dim=1) + value = value.repeat_interleave(attn.heads // attn.kv_heads, dim=1) + + hidden_states = optimized_attention_masked(query, key, value, attn.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options) + + # Split back to instruction/image, apply per-stream output projections, recombine. + instruct_hidden_states = self.instruct_out(hidden_states[:, :L_instruct]) + img_hidden_states = self.img_out(hidden_states[:, L_instruct:]) + hidden_states = torch.cat([instruct_hidden_states, img_hidden_states], dim=1) + + hidden_states = attn.to_out[0](hidden_states) + return hidden_states + + +class BooguJointAttention(nn.Module): + # Holds the shared q/k RMSNorm + final output projection + def __init__(self, dim, head_dim, heads, kv_heads, eps=1e-5, dtype=None, device=None, operations=None): + super().__init__() + self.heads = heads + self.kv_heads = kv_heads + self.dim_head = head_dim + self.scale = head_dim ** -0.5 + + self.norm_q = operations.RMSNorm(head_dim, eps=eps, dtype=dtype, device=device) + self.norm_k = operations.RMSNorm(head_dim, eps=eps, dtype=dtype, device=device) + self.to_out = nn.Sequential( + operations.Linear(heads * head_dim, dim, bias=False, dtype=dtype, device=device), + nn.Dropout(0.0), + ) + self.processor = BooguDoubleStreamProcessor(dim, head_dim, heads, kv_heads, dtype=dtype, device=device, operations=operations) + + def forward(self, img_hidden_states, instruct_hidden_states, rotary_emb, attention_mask=None, transformer_options={}): + return self.processor(self, img_hidden_states, instruct_hidden_states, rotary_emb, attention_mask, transformer_options=transformer_options) + + +class BooguDoubleStreamBlock(nn.Module): + # Dual-stream block: joint attention over [instruct ; img] + image self-attention, each stream with its own modulation/MLP. + def __init__(self, dim, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, dtype=None, device=None, operations=None): + super().__init__() + head_dim = dim // num_attention_heads + + self.img_instruct_attn = BooguJointAttention(dim, head_dim, num_attention_heads, num_kv_heads, eps=1e-5, dtype=dtype, device=device, operations=operations) + self.img_self_attn = Attention( + query_dim=dim, dim_head=head_dim, heads=num_attention_heads, kv_heads=num_kv_heads, + eps=1e-5, bias=False, dtype=dtype, device=device, operations=operations, + ) + + self.img_feed_forward = LuminaFeedForward(dim=dim, inner_dim=4 * dim, multiple_of=multiple_of, dtype=dtype, device=device, operations=operations) + self.instruct_feed_forward = LuminaFeedForward(dim=dim, inner_dim=4 * dim, multiple_of=multiple_of, dtype=dtype, device=device, operations=operations) + + self.img_norm1 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, dtype=dtype, device=device, operations=operations) + self.img_norm2 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, dtype=dtype, device=device, operations=operations) + self.img_norm3 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, dtype=dtype, device=device, operations=operations) + self.instruct_norm1 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, dtype=dtype, device=device, operations=operations) + self.instruct_norm2 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, dtype=dtype, device=device, operations=operations) + + self.img_attn_norm = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device) + self.img_self_attn_norm = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device) + self.img_ffn_norm1 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device) + self.img_ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device) + + self.instruct_attn_norm = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device) + self.instruct_ffn_norm1 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device) + self.instruct_ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device) + + def forward(self, img_hidden_states, instruct_hidden_states, joint_rotary_emb, img_rotary_emb, temb, joint_attention_mask=None, img_attention_mask=None, transformer_options={}): + L_instruct = instruct_hidden_states.shape[1] + + img_norm1_out, img_gate_msa, img_scale_mlp, img_gate_mlp = self.img_norm1(img_hidden_states, temb) + img_norm2_out, img_shift_mlp, _, _ = self.img_norm2(img_hidden_states, temb) + img_norm3_out, img_gate_self, _, _ = self.img_norm3(img_hidden_states, temb) + + instruct_norm1_out, instruct_gate_msa, instruct_scale_mlp, instruct_gate_mlp = self.instruct_norm1(instruct_hidden_states, temb) + instruct_norm2_out, instruct_shift_mlp, _, _ = self.instruct_norm2(instruct_hidden_states, temb) + + joint_attn_out = self.img_instruct_attn(img_norm1_out, instruct_norm1_out, joint_rotary_emb, joint_attention_mask, transformer_options=transformer_options) + instruct_attn_out = joint_attn_out[:, :L_instruct] + img_attn_out = joint_attn_out[:, L_instruct:] + + img_self_attn_out = self.img_self_attn(img_norm3_out, img_norm3_out, img_attention_mask, img_rotary_emb, transformer_options=transformer_options) + + img_hidden_states = img_hidden_states + img_gate_msa.unsqueeze(1).tanh() * self.img_attn_norm(img_attn_out) + img_hidden_states = img_hidden_states + img_gate_self.unsqueeze(1).tanh() * self.img_self_attn_norm(img_self_attn_out) + img_mlp_input = (1 + img_scale_mlp.unsqueeze(1)) * img_norm2_out + img_shift_mlp.unsqueeze(1) + img_mlp_out = self.img_feed_forward(self.img_ffn_norm1(img_mlp_input)) + img_hidden_states = img_hidden_states + img_gate_mlp.unsqueeze(1).tanh() * self.img_ffn_norm2(img_mlp_out) + + instruct_hidden_states = instruct_hidden_states + instruct_gate_msa.unsqueeze(1).tanh() * self.instruct_attn_norm(instruct_attn_out) + instruct_mlp_input = (1 + instruct_scale_mlp.unsqueeze(1)) * instruct_norm2_out + instruct_shift_mlp.unsqueeze(1) + instruct_mlp_out = self.instruct_feed_forward(self.instruct_ffn_norm1(instruct_mlp_input)) + instruct_hidden_states = instruct_hidden_states + instruct_gate_mlp.unsqueeze(1).tanh() * self.instruct_ffn_norm2(instruct_mlp_out) + + return img_hidden_states, instruct_hidden_states + + +class BooguTransformer2DModel(nn.Module): + def __init__( + self, + patch_size: int = 2, + in_channels: int = 16, + out_channels: Optional[int] = None, + hidden_size: int = 3360, + num_layers: int = 32, + num_double_stream_layers: int = 8, + num_refiner_layers: int = 2, + num_attention_heads: int = 28, + num_kv_heads: int = 7, + multiple_of: int = 256, + ffn_dim_multiplier: Optional[float] = None, + norm_eps: float = 1e-5, + axes_dim_rope: Tuple[int, int, int] = (40, 40, 40), + axes_lens: Tuple[int, int, int] = (2048, 1664, 1664), + instruction_feat_dim: int = 4096, + timestep_scale: float = 1000.0, + image_model=None, + device=None, dtype=None, operations=None, + ): + super().__init__() + + self.patch_size = patch_size + self.out_channels = out_channels or in_channels + self.hidden_size = hidden_size + self.dtype = dtype + + self.rope_embedder = OmniGen2RotaryPosEmbed( + theta=10000, + axes_dim=axes_dim_rope, + axes_lens=axes_lens, + patch_size=patch_size, + ) + + self.x_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device) + self.ref_image_patch_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device) + + self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding( + hidden_size=hidden_size, + text_feat_dim=instruction_feat_dim, + norm_eps=norm_eps, + timestep_scale=timestep_scale, dtype=dtype, device=device, operations=operations + ) + + self.noise_refiner = nn.ModuleList([ + OmniGen2TransformerBlock(hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations) + for _ in range(num_refiner_layers) + ]) + + self.ref_image_refiner = nn.ModuleList([ + OmniGen2TransformerBlock(hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations) + for _ in range(num_refiner_layers) + ]) + + self.context_refiner = nn.ModuleList([ + OmniGen2TransformerBlock(hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=False, dtype=dtype, device=device, operations=operations) + for _ in range(num_refiner_layers) + ]) + + self.double_stream_layers = nn.ModuleList([ + BooguDoubleStreamBlock(hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, dtype=dtype, device=device, operations=operations) + for _ in range(num_double_stream_layers) + ]) + + self.single_stream_layers = nn.ModuleList([ + OmniGen2TransformerBlock(hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations) + for _ in range(num_layers) + ]) + + self.norm_out = LuminaLayerNormContinuous( + embedding_dim=hidden_size, + conditioning_embedding_dim=min(hidden_size, 1024), + elementwise_affine=False, + eps=1e-6, + out_dim=patch_size * patch_size * self.out_channels, dtype=dtype, device=device, operations=operations + ) + + self.image_index_embedding = nn.Parameter(torch.empty(5, hidden_size, device=device, dtype=dtype)) + + # Patchify/refine helpers are identical to OmniGen2; reuse via bound methods. + flat_and_pad_to_seq = comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel.flat_and_pad_to_seq + img_patch_embed_and_refine = comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel.img_patch_embed_and_refine + + def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, transformer_options={}, **kwargs): + B, C, H, W = x.shape + hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) + _, _, H_padded, W_padded = hidden_states.shape + timestep = 1.0 - timesteps + text_hidden_states = context + text_attention_mask = attention_mask + ref_image_hidden_states = ref_latents + device = hidden_states.device + + temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype) + + ( + hidden_states, ref_image_hidden_states, + img_mask, ref_img_mask, + l_effective_ref_img_len, l_effective_img_len, + ref_img_sizes, img_sizes, + ) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states) + + ( + context_rotary_emb, ref_img_rotary_emb, noise_rotary_emb, + rotary_emb, encoder_seq_lengths, seq_lengths, + ) = self.rope_embedder( + hidden_states.shape[0], text_hidden_states.shape[1], [num_tokens] * text_hidden_states.shape[0], + l_effective_ref_img_len, l_effective_img_len, + ref_img_sizes, img_sizes, device, + ) + + for layer in self.context_refiner: + text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb, transformer_options=transformer_options) + + img_len = hidden_states.shape[1] + combined_img_hidden_states = self.img_patch_embed_and_refine( + hidden_states, ref_image_hidden_states, + img_mask, ref_img_mask, + noise_rotary_emb, ref_img_rotary_emb, + l_effective_ref_img_len, l_effective_img_len, + temb, + transformer_options=transformer_options, + ) + + # Double-stream stage: the image self-attention only sees the [ref ; noise] tokens, + # which sit after the instruction tokens in the joint rope. + L_instruct = text_hidden_states.shape[1] + combined_img_rotary_emb = rotary_emb[:, L_instruct:] + for layer in self.double_stream_layers: + combined_img_hidden_states, text_hidden_states = layer( + combined_img_hidden_states, text_hidden_states, + rotary_emb, combined_img_rotary_emb, temb, + joint_attention_mask=None, img_attention_mask=None, + transformer_options=transformer_options, + ) + + hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1) + + for layer in self.single_stream_layers: + hidden_states = layer(hidden_states, None, rotary_emb, temb, transformer_options=transformer_options) + + hidden_states = self.norm_out(hidden_states, temb) + + p = self.patch_size + output = rearrange(hidden_states[:, -img_len:], 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', h=H_padded // p, w=W_padded // p, p1=p, p2=p)[:, :, :H, :W] + + return -output diff --git a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py index 671fe834d..aec874815 100644 --- a/comfy/ldm/cosmos/predict2.py +++ b/comfy/ldm/cosmos/predict2.py @@ -515,7 +515,7 @@ class Block(nn.Module): h=H, w=W, ) - x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype) + x_B_T_H_W_D = torch.addcmul(x_B_T_H_W_D, gate_self_attn_B_T_1_1_D.to(residual_dtype), result_B_T_H_W_D.to(residual_dtype)) def _x_fn( _x_B_T_H_W_D: torch.Tensor, @@ -548,7 +548,7 @@ class Block(nn.Module): shift_cross_attn_B_T_1_1_D, transformer_options=transformer_options, ) - x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D + x_B_T_H_W_D = torch.addcmul(x_B_T_H_W_D, gate_cross_attn_B_T_1_1_D.to(residual_dtype), result_B_T_H_W_D.to(residual_dtype)) normalized_x_B_T_H_W_D = _fn( x_B_T_H_W_D, @@ -557,7 +557,7 @@ class Block(nn.Module): shift_mlp_B_T_1_1_D, ) result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype)) - x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype) + x_B_T_H_W_D = torch.addcmul(x_B_T_H_W_D, gate_mlp_B_T_1_1_D.to(residual_dtype), result_B_T_H_W_D.to(residual_dtype)) return x_B_T_H_W_D diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index e0a4a0f9b..9953b6679 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -1085,7 +1085,7 @@ class LTXVModel(LTXBaseModel): ) grid_mask = None - if keyframe_idxs is not None: + if keyframe_idxs is not None and keyframe_idxs.shape[2] > 0: additional_args.update({ "orig_patchified_shape": list(x.shape)}) denoise_mask = self.patchifier.patchify(denoise_mask)[0] grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0] @@ -1330,7 +1330,7 @@ class LTXVModel(LTXBaseModel): x = x * (1 + scale) + shift x = self.proj_out(x) - if keyframe_idxs is not None: + if keyframe_idxs is not None and keyframe_idxs.shape[2] > 0: grid_mask = kwargs["grid_mask"] orig_patchified_shape = kwargs["orig_patchified_shape"] full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device) diff --git a/comfy/ldm/omnigen/omnigen2.py b/comfy/ldm/omnigen/omnigen2.py index e9ca5229d..b8da4cf39 100644 --- a/comfy/ldm/omnigen/omnigen2.py +++ b/comfy/ldm/omnigen/omnigen2.py @@ -22,7 +22,7 @@ def apply_rotary_emb(x, freqs_cis): def swiglu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - return F.silu(x) * y + return F.silu(x, inplace=True).mul_(y) class TimestepEmbedding(nn.Module): diff --git a/comfy/model_base.py b/comfy/model_base.py index d143dc06f..264dbb9b3 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -21,6 +21,7 @@ import comfy.ldm.hunyuan3dv2_1.hunyuandit import torch import logging import comfy.ldm.lightricks.av_model +import comfy.ldm.lightricks.symmetric_patchifier import comfy.context_windows from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep from comfy.ldm.cascade.stage_c import StageC @@ -54,6 +55,7 @@ import comfy.ldm.pixeldit.model import comfy.ldm.pixeldit.pid import comfy.ldm.ace.model import comfy.ldm.omnigen.omnigen2 +import comfy.ldm.boogu.model import comfy.ldm.qwen_image.model import comfy.ldm.ideogram4.model import comfy.ldm.kandinsky5.model @@ -1203,6 +1205,127 @@ class LTXAV(BaseModel): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): return latent_image + def map_context_window_to_modalities(self, primary_indices, latent_shapes, dim): + result = [primary_indices] + if len(latent_shapes) < 2: + return result + + video_total = latent_shapes[0][dim] + + for i in range(1, len(latent_shapes)): + mod_total = latent_shapes[i][dim] + # Map each primary index to its proportional range of modality indices and + # concatenate in order. Preserves wrapped/strided geometry so the modality + # attends to the same temporal regions as the primary window. + mod_indices = [] + seen = set() + for v_idx in primary_indices: + a_start = min(int(round(v_idx * mod_total / video_total)), mod_total - 1) + a_end = min(int(round((v_idx + 1) * mod_total / video_total)), mod_total) + if a_end <= a_start: + a_end = a_start + 1 + for a in range(a_start, a_end): + if a not in seen: + seen.add(a) + mod_indices.append(a) + result.append(mod_indices) + + return result + + @staticmethod + def _get_guide_entries(conds): + for cond_list in conds: + if cond_list is None: + continue + for cond_dict in cond_list: + model_conds = cond_dict.get('model_conds', {}) + entries = model_conds.get('guide_attention_entries') + if entries is not None and hasattr(entries, 'cond') and entries.cond: + return entries.cond + return None + + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + # Audio denoise mask — slice using audio modality window + if cond_key == "audio_denoise_mask" and hasattr(window, 'modality_windows') and window.modality_windows: + audio_window = window.modality_windows.get(1) + if audio_window is not None and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): + sliced = audio_window.get_tensor(cond_value.cond, device, dim=2) + return cond_value._copy_with(sliced) + + # Video denoise mask — split into video + guide portions, slice each + if cond_key == "denoise_mask" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): + cond_tensor = cond_value.cond + guide_count = cond_tensor.size(window.dim) - x_in.size(window.dim) + if guide_count > 0: + T_video = x_in.size(window.dim) + video_mask = cond_tensor.narrow(window.dim, 0, T_video) + guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count) + sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list) + suffix_indices = window.guide_frames_indices + if suffix_indices: + idx = tuple([slice(None)] * window.dim + [suffix_indices]) + sliced_guide = guide_mask[idx].to(device) + return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim)) + else: + return cond_value._copy_with(sliced_video) + + # Keyframe indices — regenerate pixel coords for window, select guide positions + if cond_key == "keyframe_idxs": + kf_local_pos = window.guide_kf_local_positions + if not kf_local_pos: + return cond_value._copy_with(cond_value.cond[:, :, :0, :]) # empty + H, W = x_in.shape[3], x_in.shape[4] + window_len = len(window.index_list) + # account for causal_window_fix anchor in coord space size + anchor_idx = getattr(window, 'causal_anchor_index', None) + if anchor_idx is not None and anchor_idx >= 0: + window_len += 1 + patchifier = self.diffusion_model.patchifier + latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device) + scale_factors = self.diffusion_model.vae_scale_factors + pixel_coords = comfy.ldm.lightricks.symmetric_patchifier.latent_to_pixel_coords( + latent_coords, + scale_factors, + causal_fix=self.diffusion_model.causal_temporal_positioning) + tokens = [] + for pos in kf_local_pos: + tokens.extend(range(pos * H * W, (pos + 1) * H * W)) + pixel_coords = pixel_coords[:, :, tokens, :] + + # Adjust spatial end positions for dilated (downscaled) guides. + # Each guide entry may have a different downscale factor; expand the + # per-entry factor to cover all tokens belonging to that entry. + downscale_factors = window.guide_downscale_factors + overlap_info = window.guide_overlap_info + if downscale_factors: + per_token_factor = [] + for (entry_idx, overlap_count), dsf in zip(overlap_info, downscale_factors): + per_token_factor.extend([dsf] * (overlap_count * H * W)) + factor_tensor = torch.tensor(per_token_factor, device=pixel_coords.device, dtype=pixel_coords.dtype) + spatial_end_offset = (factor_tensor.unsqueeze(0).unsqueeze(0).unsqueeze(-1) - 1) * torch.tensor( + scale_factors[1:], device=pixel_coords.device, dtype=pixel_coords.dtype, + ).view(1, -1, 1, 1) + pixel_coords[:, 1:, :, 1:] += spatial_end_offset + + B = cond_value.cond.shape[0] + if B > 1: + pixel_coords = pixel_coords.expand(B, -1, -1, -1) + return cond_value._copy_with(pixel_coords) + + # Guide attention entries — adjust per-guide counts based on window overlap + if cond_key == "guide_attention_entries": + overlap_info = window.guide_overlap_info + H, W = x_in.shape[3], x_in.shape[4] + new_entries = [] + for entry_idx, overlap_count in overlap_info: + e = cond_value.cond[entry_idx] + new_entries.append({**e, + "pre_filter_count": overlap_count * H * W, + "latent_shape": [overlap_count, H, W]}) + return cond_value._copy_with(new_entries) + + return None + class HunyuanVideo(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo) @@ -2103,6 +2226,11 @@ class Omnigen2(BaseModel): out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) return out +class Boogu(Omnigen2): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super(Omnigen2, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.boogu.model.BooguTransformer2DModel) + self.memory_usage_factor_conds = ("ref_latents",) + class QwenImage(BaseModel): def __init__(self, model_config, model_type=ModelType.FLUX, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.qwen_image.model.QwenImageTransformer2DModel) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 7d0cab308..b773f0393 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -761,6 +761,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): return dit_config + if '{}double_stream_layers.0.img_instruct_attn.processor.img_to_q.weight'.format(key_prefix) in state_dict_keys: # Boogu-Image (OmniGen2 derivative + dual-stream stage) + dit_config = {} + dit_config["image_model"] = "boogu" + dit_config["hidden_size"] = state_dict['{}x_embedder.weight'.format(key_prefix)].shape[0] + dit_config["num_layers"] = count_blocks(state_dict_keys, '{}single_stream_layers.'.format(key_prefix) + '{}.') + dit_config["num_double_stream_layers"] = count_blocks(state_dict_keys, '{}double_stream_layers.'.format(key_prefix) + '{}.') + dit_config["num_refiner_layers"] = count_blocks(state_dict_keys, '{}noise_refiner.'.format(key_prefix) + '{}.') + dit_config["instruction_feat_dim"] = state_dict['{}time_caption_embed.caption_embedder.0.weight'.format(key_prefix)].shape[0] + return dit_config + if '{}time_caption_embed.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: # Omnigen2 dit_config = {} dit_config["image_model"] = "omnigen2" diff --git a/comfy/sd.py b/comfy/sd.py index 688e6db90..d9b1c0553 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -68,6 +68,7 @@ import comfy.text_encoders.ace15 import comfy.text_encoders.longcat_image import comfy.text_encoders.qwen35 import comfy.text_encoders.qwen3vl +import comfy.text_encoders.boogu import comfy.text_encoders.ernie import comfy.text_encoders.gemma4 import comfy.text_encoders.cogvideo @@ -1301,6 +1302,7 @@ class CLIPType(Enum): LENS = 28 PIXELDIT = 29 IDEOGRAM4 = 30 + BOOGU = 31 @@ -1622,6 +1624,14 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."}) clip_target.clip = comfy.text_encoders.ideogram4.te_qwen3vl(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.ideogram4.Ideogram4Qwen3VLTokenizer + elif clip_type == CLIPType.BOOGU and te_model == TEModel.QWEN3VL_8B: # Boogu-Image: full Qwen3-VL-8B, last hidden state, no-think template. + clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."}) + clip_target.clip = comfy.text_encoders.boogu.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.boogu.BooguTokenizer + elif clip_type in (CLIPType.FLUX, CLIPType.FLUX2): # Flux2 Klein reuses the Qwen3-VL LM (3-layer tap -> 12288); visual unused. + klein_model_type = "qwen3_8b" if te_model == TEModel.QWEN3VL_8B else "qwen3_4b" + clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type=klein_model_type) + clip_target.tokenizer = comfy.text_encoders.flux.KleinTokenizer8B if te_model == TEModel.QWEN3VL_8B else comfy.text_encoders.flux.KleinTokenizer else: clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."}) qwen3vl_type = {TEModel.QWEN3VL_4B: "qwen3vl_4b", TEModel.QWEN3VL_8B: "qwen3vl_8b"}[te_model] diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 3be935577..cc05908ee 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -25,6 +25,7 @@ import comfy.text_encoders.hunyuan_image import comfy.text_encoders.kandinsky5 import comfy.text_encoders.z_image import comfy.text_encoders.ideogram4 +import comfy.text_encoders.boogu import comfy.text_encoders.anima import comfy.text_encoders.ace15 import comfy.text_encoders.longcat_image @@ -1758,6 +1759,27 @@ class Omnigen2(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect)) +class Boogu(Omnigen2): + unet_config = { + "image_model": "boogu", + } + + sampling_settings = { + "multiplier": 1.0, + "shift": 3.16, + } + + memory_usage_factor = 2.15 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Boogu(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3vl_8b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.boogu.BooguTokenizer, comfy.text_encoders.boogu.te(**hunyuan_detect)) + class Ideogram4(supported_models_base.BASE): unet_config = { "image_model": "ideogram4", @@ -2300,6 +2322,7 @@ models = [ ACEStep, ACEStep15, Omnigen2, + Boogu, QwenImage, Ideogram4, Flux2, diff --git a/comfy/text_encoders/boogu.py b/comfy/text_encoders/boogu.py new file mode 100644 index 000000000..d9de92f10 --- /dev/null +++ b/comfy/text_encoders/boogu.py @@ -0,0 +1,58 @@ +"""Boogu-Image text encoder: full Qwen3-VL-8B, last hidden state (4096-dim). + +Boogu uses the final hidden state of Qwen3-VL as the per-token instruction feature +(num_instruction_feature_layers=1, reduce_type=mean -> just the last layer). +The model itself is the standard Qwen3-VL TE, only the chat template differs +(a fixed system prompt and no block). +""" + +import comfy.text_encoders.qwen3vl +from comfy import sd1_clip + + +# System prompts from the reference pipeline (pipeline_boogu.py). +# T2I (non-empty instruction, no image) uses the helpful-assistant prompt +# everything else (the CFG negative / "drop" condition, and any image case) uses the TI2I "describe" prompt. +BOOGU_T2I_SYSTEM = "You are a helpful assistant that generates high-quality images based on user instructions. The instructions are as follows." +BOOGU_DROP_SYSTEM = "Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate." + + +class BooguTokenizer(comfy.text_encoders.qwen3vl.Qwen3VLTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, model_type="qwen3vl_8b") + # apply_chat_template without add_generation_prompt + self.llama_template = "<|im_start|>system\n" + BOOGU_T2I_SYSTEM + "<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n" + self.llama_template_images = "<|im_start|>system\n" + BOOGU_DROP_SYSTEM + "<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n" + # Reference SYSTEM_PROMPT_DROP: used for the empty negative/uncond instruction. + self.llama_template_drop = "<|im_start|>system\n" + BOOGU_DROP_SYSTEM + "<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n" + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=True, **kwargs): + if llama_template is None and len(images) == 0 and text.strip() == "": + llama_template = self.llama_template_drop + # Boogu conditions on the no-think template; thinking=True drops the empty block qwen3vl adds by default. + return super().tokenize_with_weights(text, return_word_ids=return_word_ids, llama_template=llama_template, images=images, prevent_empty_text=prevent_empty_text, thinking=thinking, **kwargs) + + +class BooguQwen3VLClipModel(comfy.text_encoders.qwen3vl.Qwen3VLClipModel): + def __init__(self, device="cpu", dtype=None, attention_mask=True, model_options={}, model_type="qwen3vl_8b"): + super().__init__(device=device, dtype=dtype, attention_mask=attention_mask, model_options=model_options, model_type=model_type) + # apply the final RMSNorm to the tapped last layer + self.layer_norm_hidden_state = True + + +class BooguTEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + clip_model = lambda **kw: BooguQwen3VLClipModel(**kw, model_type="qwen3vl_8b") + super().__init__(device=device, dtype=dtype, name="qwen3vl_8b", clip_model=clip_model, model_options=model_options) + + +def te(dtype_llama=None, llama_quantization_metadata=None): + class BooguTEModel_(BooguTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if dtype_llama is not None: + dtype = dtype_llama + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, dtype=dtype, model_options=model_options) + return BooguTEModel_ diff --git a/comfy_api/feature_flags.py b/comfy_api/feature_flags.py index adb5a3144..0f30608a9 100644 --- a/comfy_api/feature_flags.py +++ b/comfy_api/feature_flags.py @@ -25,6 +25,11 @@ CLI_FEATURE_FLAG_REGISTRY: dict[str, FeatureFlagInfo] = { "default": False, "description": "Show the sign-in button in the frontend even when not signed in", }, + "enable_telemetry": { + "type": "bool", + "default": False, + "description": "Signal the frontend that telemetry collection is enabled", + }, } diff --git a/comfy_api_nodes/apis/kling.py b/comfy_api_nodes/apis/kling.py index fe0f97cb3..2c98c23b7 100644 --- a/comfy_api_nodes/apis/kling.py +++ b/comfy_api_nodes/apis/kling.py @@ -149,3 +149,59 @@ class MotionControlRequest(BaseModel): character_orientation: str = Field(...) mode: str = Field(..., description="'pro' or 'std'") model_name: str = Field(...) + + +class Kling3TurboSettings(BaseModel): + resolution: str = Field("720p", description="'720p' or '1080p'") + aspect_ratio: str | None = Field(None, description="'16:9'/'9:16'/'1:1'; text-to-video only") + duration: int = Field(5, description="3-15 second") + + +class Kling3TurboText2VideoRequest(BaseModel): + prompt: str = Field(..., description="<=3072 chars; may use multi-shot 'shot n, m, words; ...'") + settings: Kling3TurboSettings | None = Field(None) + + +class Kling3TurboContent(BaseModel): + type: str = Field(..., description="'prompt' or 'first_frame'") + text: str | None = Field(None, description="for type=prompt; <=2500 chars") + url: str | None = Field(None, description="for type=first_frame") + + +class Kling3TurboImage2VideoRequest(BaseModel): + contents: list[Kling3TurboContent] = Field(..., description="prompt + first_frame materials") + settings: Kling3TurboSettings | None = Field(None) + + +class Kling3TurboCreateData(BaseModel): + id: str | None = Field(None, description="Task ID") + status: str | None = Field(None) + message: str | None = Field(None) + + +class Kling3TurboCreateResponse(BaseModel): + code: int | None = Field(None) + message: str | None = Field(None) + request_id: str | None = Field(None) + data: Kling3TurboCreateData | None = Field(None) + + +class Kling3TurboOutput(BaseModel): + type: str | None = Field(None, description="'video', 'image', 'audio', ...") + id: str | None = Field(None) + url: str | None = Field(None) + duration: str | None = Field(None) + + +class Kling3TurboTaskData(BaseModel): + id: str | None = Field(None) + status: str | None = Field(None, description="submitted | processing | succeeded | failed") + message: str | None = Field(None) + outputs: list[Kling3TurboOutput] | None = Field(None) + + +class Kling3TurboQueryResponse(BaseModel): + code: int | None = Field(None) + message: str | None = Field(None) + request_id: str | None = Field(None) + data: list[Kling3TurboTaskData] | None = Field(None) diff --git a/comfy_api_nodes/apis/luma.py b/comfy_api_nodes/apis/luma.py index 8c6db2022..2465c3b37 100644 --- a/comfy_api_nodes/apis/luma.py +++ b/comfy_api_nodes/apis/luma.py @@ -10,6 +10,7 @@ from pydantic import BaseModel, Field, confloat class LumaIO: LUMA_REF = "LUMA_REF" LUMA_CONCEPTS = "LUMA_CONCEPTS" + LUMA_RAY32_KEYFRAME = "LUMA_RAY32_KEYFRAME" class LumaReference: @@ -20,13 +21,14 @@ class LumaReference: def create_api_model(self, download_url: str): return LumaImageRef(url=download_url, weight=self.weight) + class LumaReferenceChain: - def __init__(self, first_ref: LumaReference=None): + def __init__(self, first_ref: LumaReference = None): self.refs: list[LumaReference] = [] if first_ref: self.refs.append(first_ref) - def add(self, luma_ref: LumaReference=None): + def add(self, luma_ref: LumaReference = None): self.refs.append(luma_ref) def create_api_model(self, download_urls: list[str], max_refs=4): @@ -124,7 +126,7 @@ def get_luma_concepts(include_none=False): "pull_out", "aerial", "crane_up", - "eye_level" + "eye_level", ] @@ -162,8 +164,8 @@ class LumaVideoModelOutputDuration(str, Enum): class LumaGenerationType(str, Enum): - video = 'video' - image = 'image' + video = "video" + image = "image" class LumaState(str, Enum): @@ -174,86 +176,109 @@ class LumaState(str, Enum): class LumaAssets(BaseModel): - video: Optional[str] = Field(None, description='The URL of the video') - image: Optional[str] = Field(None, description='The URL of the image') - progress_video: Optional[str] = Field(None, description='The URL of the progress video') + video: Optional[str] = Field(None, description="The URL of the video") + image: Optional[str] = Field(None, description="The URL of the image") + progress_video: Optional[str] = Field(None, description="The URL of the progress video") class LumaImageRef(BaseModel): """Used for image gen""" - url: str = Field(..., description='The URL of the image reference') - weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference') + + url: str = Field(..., description="The URL of the image reference") + weight: confloat(ge=0.0, le=1.0) = Field(..., description="The weight of the image reference") class LumaImageReference(BaseModel): """Used for video gen""" - type: Optional[str] = Field('image', description='Input type, defaults to image') - url: str = Field(..., description='The URL of the image') + + type: Optional[str] = Field("image", description="Input type, defaults to image") + url: str = Field(..., description="The URL of the image") class LumaModifyImageRef(BaseModel): - url: str = Field(..., description='The URL of the image reference') - weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference') + url: str = Field(..., description="The URL of the image reference") + weight: confloat(ge=0.0, le=1.0) = Field(..., description="The weight of the image reference") class LumaCharacterRef(BaseModel): - identity0: LumaImageIdentity = Field(..., description='The image identity object') + identity0: LumaImageIdentity = Field(..., description="The image identity object") class LumaImageIdentity(BaseModel): - images: list[str] = Field(..., description='The URLs of the image identity') + images: list[str] = Field(..., description="The URLs of the image identity") class LumaGenerationReference(BaseModel): - type: str = Field('generation', description='Input type, defaults to generation') - id: str = Field(..., description='The ID of the generation') + type: str = Field("generation", description="Input type, defaults to generation") + id: str = Field(..., description="The ID of the generation") class LumaKeyframes(BaseModel): - frame0: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description='') - frame1: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description='') + frame0: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description="") + frame1: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description="") class LumaConceptObject(BaseModel): - key: str = Field(..., description='Camera Concept name') + key: str = Field(..., description="Camera Concept name") class LumaImageGenerationRequest(BaseModel): - prompt: str = Field(..., description='The prompt of the generation') - model: LumaImageModel = Field(LumaImageModel.photon_1, description='The image model used for the generation') - aspect_ratio: Optional[LumaAspectRatio] = Field(LumaAspectRatio.ratio_16_9, description='The aspect ratio of the generation') - image_ref: Optional[list[LumaImageRef]] = Field(None, description='List of image reference objects') - style_ref: Optional[list[LumaImageRef]] = Field(None, description='List of style reference objects') - character_ref: Optional[LumaCharacterRef] = Field(None, description='The image identity object') - modify_image_ref: Optional[LumaModifyImageRef] = Field(None, description='The modify image reference object') + prompt: str = Field(..., description="The prompt of the generation") + model: LumaImageModel = Field(LumaImageModel.photon_1, description="The image model used for the generation") + aspect_ratio: Optional[LumaAspectRatio] = Field(LumaAspectRatio.ratio_16_9) + image_ref: Optional[list[LumaImageRef]] = Field(None, description="List of image reference objects") + style_ref: Optional[list[LumaImageRef]] = Field(None, description="List of style reference objects") + character_ref: Optional[LumaCharacterRef] = Field(None, description="The image identity object") + modify_image_ref: Optional[LumaModifyImageRef] = Field(None, description="The modify image reference object") class LumaGenerationRequest(BaseModel): - prompt: str = Field(..., description='The prompt of the generation') - model: LumaVideoModel = Field(LumaVideoModel.ray_2, description='The video model used for the generation') - duration: Optional[LumaVideoModelOutputDuration] = Field(None, description='The duration of the generation') - aspect_ratio: Optional[LumaAspectRatio] = Field(None, description='The aspect ratio of the generation') - resolution: Optional[LumaVideoOutputResolution] = Field(None, description='The resolution of the generation') - loop: Optional[bool] = Field(None, description='Whether to loop the video') - keyframes: Optional[LumaKeyframes] = Field(None, description='The keyframes of the generation') - concepts: Optional[list[LumaConceptObject]] = Field(None, description='Camera Concepts to apply to generation') + prompt: str = Field(..., description="The prompt of the generation") + model: LumaVideoModel = Field(LumaVideoModel.ray_2, description="The video model used for the generation") + duration: Optional[LumaVideoModelOutputDuration] = Field(None, description="The duration of the generation") + aspect_ratio: Optional[LumaAspectRatio] = Field(None, description="The aspect ratio of the generation") + resolution: Optional[LumaVideoOutputResolution] = Field(None, description="The resolution of the generation") + loop: Optional[bool] = Field(None, description="Whether to loop the video") + keyframes: Optional[LumaKeyframes] = Field(None, description="The keyframes of the generation") + concepts: Optional[list[LumaConceptObject]] = Field(None, description="Camera Concepts to apply to generation") class LumaGeneration(BaseModel): - id: str = Field(..., description='The ID of the generation') - generation_type: LumaGenerationType = Field(..., description='Generation type, image or video') - state: LumaState = Field(..., description='The state of the generation') - failure_reason: Optional[str] = Field(None, description='The reason for the state of the generation') - created_at: str = Field(..., description='The date and time when the generation was created') - assets: Optional[LumaAssets] = Field(None, description='The assets of the generation') - model: str = Field(..., description='The model used for the generation') - request: Union[LumaGenerationRequest, LumaImageGenerationRequest] = Field(..., description="The request used for the generation") + id: str = Field(..., description="The ID of the generation") + generation_type: LumaGenerationType = Field(..., description="Generation type, image or video") + state: LumaState = Field(..., description="The state of the generation") + failure_reason: Optional[str] = Field(None, description="The reason for the state of the generation") + created_at: str = Field(..., description="The date and time when the generation was created") + assets: Optional[LumaAssets] = Field(None, description="The assets of the generation") + model: str = Field(..., description="The model used for the generation") + request: Union[LumaGenerationRequest, LumaImageGenerationRequest] = Field(...) class Luma2ImageRef(BaseModel): url: str | None = None data: str | None = None media_type: str | None = None + generation_id: str | None = Field(None, description="reference a prior generation (extend / source reuse)") + + +class Luma2VideoEdit(BaseModel): + """Edit controls for Ray 3.2 ``video_edit`` generations.""" + + auto_controls: bool | None = Field(None, description="derive a conditioning schedule from the source (recommended)") + strength: str | None = Field(None, description="'adhere_1' .. 'reimagine_3'; constrained by IO.Combo") + + +class Luma2VideoOptions(BaseModel): + """Ray 3.2 ``video`` output settings (text / image / keyframe / edit / extend).""" + + resolution: str | None = Field(None, description="360p | 540p | 720p | 1080p") + duration: str | None = Field(None, description="5s | 10s") + loop: bool | None = Field(None) + start_frame: Luma2ImageRef | None = Field(None) + end_frame: Luma2ImageRef | None = Field(None) + keyframes: list[Luma2ImageRef] | None = Field(None) + keyframe_indexes: list[int] | None = Field(None) + edit: Luma2VideoEdit | None = Field(None) class Luma2GenerationRequest(BaseModel): @@ -266,6 +291,7 @@ class Luma2GenerationRequest(BaseModel): web_search: bool | None = None image_ref: list[Luma2ImageRef] | None = None source: Luma2ImageRef | None = None + video: Luma2VideoOptions | None = Field(None) class Luma2Generation(BaseModel): @@ -277,3 +303,31 @@ class Luma2Generation(BaseModel): output: list[LumaImageReference] | None = None failure_reason: str | None = None failure_code: str | None = None + + +# --- Ray 3.2 multi-keyframe chain --- + +LUMA_KEYFRAME_MODE_FRACTION = "fraction" # value in [0.0, 1.0] of the output video duration +LUMA_KEYFRAME_MODE_SECONDS = "seconds" # absolute time, in seconds, from the start of the output + + +class LumaRay32KeyframeItem: + """One guide image anchored at a position on the Ray 3.2 output timeline.""" + + def __init__(self, image: torch.Tensor, mode: str, value: float): + self.image = image + self.mode = mode # LUMA_KEYFRAME_MODE_FRACTION | LUMA_KEYFRAME_MODE_SECONDS + self.value = value + + +class LumaRay32KeyframeChain: + def __init__(self): + self.items: list[LumaRay32KeyframeItem] = [] + + def add(self, item: LumaRay32KeyframeItem) -> None: + self.items.append(item) + + def clone(self) -> "LumaRay32KeyframeChain": + c = LumaRay32KeyframeChain() + c.items = list(self.items) + return c diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 3d4be6065..a63625ada 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -5,7 +5,6 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer import base64 import os -from enum import Enum from fnmatch import fnmatch from io import BytesIO from typing import Any, Literal @@ -78,15 +77,6 @@ GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge( ) -class GeminiImageModel(str, Enum): - """ - Gemini Image Model Names allowed by comfy-api - """ - - gemini_2_5_flash_image_preview = "gemini-2.5-flash-image-preview" - gemini_2_5_flash_image = "gemini-2.5-flash-image" - - async def create_image_parts( cls: type[IO.ComfyNode], images: Input.Image | list[Input.Image], @@ -243,21 +233,15 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N if not response.modelVersion: return None # Define prices (Cost per 1,000,000 tokens), see https://cloud.google.com/vertex-ai/generative-ai/pricing - if response.modelVersion in ("gemini-2.5-pro-preview-05-06", "gemini-2.5-pro"): + if response.modelVersion == "gemini-2.5-pro": input_tokens_price = 1.25 output_text_tokens_price = 10.0 output_image_tokens_price = 0.0 - elif response.modelVersion in ( - "gemini-2.5-flash-preview-04-17", - "gemini-2.5-flash", - ): + elif response.modelVersion == "gemini-2.5-flash": input_tokens_price = 0.30 output_text_tokens_price = 2.50 output_image_tokens_price = 0.0 - elif response.modelVersion in ( - "gemini-2.5-flash-image-preview", - "gemini-2.5-flash-image", - ): + elif response.modelVersion == "gemini-2.5-flash-image": input_tokens_price = 0.30 output_text_tokens_price = 2.50 output_image_tokens_price = 30.0 @@ -455,8 +439,6 @@ class GeminiNode(IO.ComfyNode): IO.Combo.Input( "model", options=[ - "gemini-2.5-pro-preview-05-06", - "gemini-2.5-flash-preview-04-17", "gemini-2.5-pro", "gemini-2.5-flash", "gemini-3-pro-preview", @@ -904,8 +886,7 @@ class GeminiImage(IO.ComfyNode): ), IO.Combo.Input( "model", - options=GeminiImageModel, - default=GeminiImageModel.gemini_2_5_flash_image, + options=["gemini-2.5-flash-image"], tooltip="The Gemini model to use for generating responses.", ), IO.Int.Input( diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index c81d3503d..b27de2549 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -60,6 +60,12 @@ from comfy_api_nodes.apis.kling import ( OmniProImageRequest, OmniProReferences2VideoRequest, OmniProText2VideoRequest, + Kling3TurboSettings, + Kling3TurboText2VideoRequest, + Kling3TurboContent, + Kling3TurboImage2VideoRequest, + Kling3TurboCreateResponse, + Kling3TurboQueryResponse, TaskStatusResponse, TextToVideoWithAudioRequest, ) @@ -2847,6 +2853,67 @@ class MotionControl(IO.ComfyNode): return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) +def build_turbo_shot_prompt(multi_prompt: list[MultiPromptEntry]) -> str: + """Render storyboard entries into the Turbo multi-shot prompt 'shot n, m, words; ...'.""" + return "; ".join(f"shot {i}, {int(e.duration)}, {e.prompt}" for i, e in enumerate(multi_prompt, 1)) + ";" + + +def _turbo_video_url(response: Kling3TurboQueryResponse) -> str: + """Extract the result video URL from a /tasks response (data[].outputs[] where type == 'video').""" + task = response.data[0] if response.data else None + if task and task.outputs: + for output in task.outputs: + if output.type == "video" and output.url: + return output.url + raise RuntimeError(f"Kling 3.0 Turbo task finished without a video output: {response.model_dump()}") + + +async def execute_kling_turbo( + cls: type[IO.ComfyNode], + *, + prompt: str, + resolution: str, + aspect_ratio: str, + duration: int, + start_frame: torch.Tensor | None, +) -> IO.NodeOutput: + """Create + poll a Kling 3.0 Turbo task. Image-to-video when start_frame is given, else text-to-video.""" + if start_frame is not None: + validate_image_dimensions(start_frame, min_width=300, min_height=300) + validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1)) + contents = [Kling3TurboContent(type="first_frame", url=tensor_to_base64_string(start_frame))] + if prompt: + contents.insert(0, Kling3TurboContent(type="prompt", text=prompt)) + create = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/image-to-video/kling-3.0-turbo", method="POST"), + response_model=Kling3TurboCreateResponse, + data=Kling3TurboImage2VideoRequest( + contents=contents, + settings=Kling3TurboSettings(resolution=resolution, duration=duration), # i2v: no aspect_ratio + ), + ) + else: + create = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/text-to-video/kling-3.0-turbo", method="POST"), + response_model=Kling3TurboCreateResponse, + data=Kling3TurboText2VideoRequest( + prompt=prompt, + settings=Kling3TurboSettings(resolution=resolution, aspect_ratio=aspect_ratio, duration=duration), + ), + ) + if not (create.data and create.data.id): + raise RuntimeError(f"Kling 3.0 Turbo create failed. Code: {create.code}, Message: {create.message}") + final_response = await poll_op( + cls, + ApiEndpoint(path="/proxy/kling/tasks", query_params={"task_ids": create.data.id}), + response_model=Kling3TurboQueryResponse, + status_extractor=lambda r: (r.data[0].status if r.data else None), + ) + return IO.NodeOutput(await download_url_to_video_output(_turbo_video_url(final_response))) + + class KlingVideoNode(IO.ComfyNode): @classmethod @@ -2884,7 +2951,11 @@ class KlingVideoNode(IO.ComfyNode): ], tooltip="Generate a series of video segments with individual prompts and durations.", ), - IO.Boolean.Input("generate_audio", default=True), + IO.Boolean.Input( + "generate_audio", + default=True, + tooltip="'kling-3.0-turbo' always generates native audio, so the audio toggle is ignored.", + ), IO.DynamicCombo.Input( "model", options=[ @@ -2899,6 +2970,17 @@ class KlingVideoNode(IO.ComfyNode): ), ], ), + IO.DynamicCombo.Option( + "kling-3.0-turbo", + [ + IO.Combo.Input("resolution", options=["1080p", "720p"], default="720p"), + IO.Combo.Input( + "aspect_ratio", + options=["16:9", "9:16", "1:1"], + tooltip="Ignored in image-to-video mode.", + ), + ], + ), ], tooltip="Model and generation settings.", ), @@ -2930,6 +3012,7 @@ class KlingVideoNode(IO.ComfyNode): price_badge=IO.PriceBadge( depends_on=IO.PriceBadgeDepends( widgets=[ + "model", "model.resolution", "generate_audio", "multi_shot", @@ -2944,14 +3027,7 @@ class KlingVideoNode(IO.ComfyNode): ), expr=""" ( - $rates := { - "4k": {"off": 0.42, "on": 0.42}, - "1080p": {"off": 0.112, "on": 0.168}, - "720p": {"off": 0.084, "on": 0.126} - }; $res := $lookup(widgets, "model.resolution"); - $audio := widgets.generate_audio ? "on" : "off"; - $rate := $lookup($lookup($rates, $res), $audio); $ms := widgets.multi_shot; $isSb := $ms != "disabled"; $n := $isSb ? $number($substring($ms, 0, 1)) : 0; @@ -2962,7 +3038,18 @@ class KlingVideoNode(IO.ComfyNode): $d5 := $n >= 5 ? $lookup(widgets, "multi_shot.storyboard_5_duration") : 0; $d6 := $n >= 6 ? $lookup(widgets, "multi_shot.storyboard_6_duration") : 0; $dur := $isSb ? $d1 + $d2 + $d3 + $d4 + $d5 + $d6 : $lookup(widgets, "multi_shot.duration"); - {"type":"usd","usd": $rate * $dur} + widgets.model = "kling-3.0-turbo" + ? {"type":"usd","usd": ($res = "1080p" ? 0.14 : 0.112) * $dur} + : ( + $rates := { + "4k": {"off": 0.42, "on": 0.42}, + "1080p": {"off": 0.112, "on": 0.168}, + "720p": {"off": 0.084, "on": 0.126} + }; + $audio := widgets.generate_audio ? "on" : "off"; + $rate := $lookup($lookup($rates, $res), $audio); + {"type":"usd","usd": $rate * $dur} + ) ) """, ), @@ -3015,6 +3102,17 @@ class KlingVideoNode(IO.ComfyNode): duration = multi_shot["duration"] validate_string(multi_shot["prompt"], min_length=1, max_length=2500) + if model["model"] == "kling-3.0-turbo": + turbo_prompt = build_turbo_shot_prompt(multi_prompt_list) if custom_multi_shot else multi_shot["prompt"] + return await execute_kling_turbo( + cls, + prompt=turbo_prompt, + resolution=model["resolution"], + aspect_ratio=model["aspect_ratio"], + duration=duration, + start_frame=start_frame, + ) + if start_frame is not None: validate_image_dimensions(start_frame, min_width=300, min_height=300) validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1)) diff --git a/comfy_api_nodes/nodes_luma.py b/comfy_api_nodes/nodes_luma.py index 0d31ac77e..cdfa32d8b 100644 --- a/comfy_api_nodes/nodes_luma.py +++ b/comfy_api_nodes/nodes_luma.py @@ -3,9 +3,13 @@ from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input from comfy_api_nodes.apis.luma import ( + LUMA_KEYFRAME_MODE_FRACTION, + LUMA_KEYFRAME_MODE_SECONDS, Luma2Generation, Luma2GenerationRequest, Luma2ImageRef, + Luma2VideoEdit, + Luma2VideoOptions, LumaAspectRatio, LumaCharacterRef, LumaConceptChain, @@ -18,6 +22,8 @@ from comfy_api_nodes.apis.luma import ( LumaIO, LumaKeyframes, LumaModifyImageRef, + LumaRay32KeyframeChain, + LumaRay32KeyframeItem, LumaReference, LumaReferenceChain, LumaVideoModel, @@ -33,6 +39,7 @@ from comfy_api_nodes.util import ( sync_op, upload_image_to_comfyapi, upload_images_to_comfyapi, + upload_video_to_comfyapi, validate_string, ) @@ -692,7 +699,10 @@ async def _luma2_upload_image_refs( async def _luma2_submit_and_poll( cls: type[IO.ComfyNode], request: Luma2GenerationRequest, -) -> Input.Image: + *, + estimated_duration: int | None = None, +) -> Luma2Generation: + """Submit a Luma Agents generation and poll until done; returns the completed generation.""" initial = await sync_op( cls, ApiEndpoint(path="/proxy/luma_2/generations", method="POST"), @@ -700,21 +710,21 @@ async def _luma2_submit_and_poll( data=request, ) if not initial.id: - raise RuntimeError("Luma 2 API did not return a generation id.") + raise RuntimeError("Luma API did not return a generation id.") final = await poll_op( cls, ApiEndpoint(path=f"/proxy/luma_2/generations/{initial.id}", method="GET"), response_model=Luma2Generation, status_extractor=lambda r: r.state, progress_extractor=lambda r: None, + estimated_duration=estimated_duration, ) - if not final.output: + if not final.output or not final.output[0].url: msg = final.failure_reason or "no output returned" - raise RuntimeError(f"Luma 2 generation failed: {msg}") - url = final.output[0].url - if not url: - raise RuntimeError("Luma 2 generation completed without an output URL.") - return await download_url_to_image_tensor(url) + if final.failure_code: + msg = f"{msg} [{final.failure_code}]" + raise RuntimeError(f"Luma generation failed: {msg}") + return final class LumaImageNode(IO.ComfyNode): @@ -843,7 +853,8 @@ class LumaImageNode(IO.ComfyNode): web_search=model["web_search"], image_ref=await _luma2_upload_image_refs(cls, model.get("image_ref"), max_count=9), ) - return IO.NodeOutput(await _luma2_submit_and_poll(cls, request)) + final = await _luma2_submit_and_poll(cls, request) + return IO.NodeOutput(await download_url_to_image_tensor(final.output[0].url)) class LumaImageEditNode(IO.ComfyNode): @@ -929,7 +940,533 @@ class LumaImageEditNode(IO.ComfyNode): web_search=model["web_search"], image_ref=await _luma2_upload_image_refs(cls, model.get("image_ref"), max_count=8), ) - return IO.NodeOutput(await _luma2_submit_and_poll(cls, request)) + final = await _luma2_submit_and_poll(cls, request) + return IO.NodeOutput(await download_url_to_image_tensor(final.output[0].url)) + + +_BADGE_RAY32_VIDEO = IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["resolution", "duration"]), + expr=""" + ( + $p := { + "360p": {"5s": 0.06, "10s": 0.18}, + "540p": {"5s": 0.15, "10s": 0.45}, + "720p": {"5s": 0.3, "10s": 0.9}, + "1080p": {"5s": 1.2, "10s": 3.6} + }; + {"type": "usd", "usd": $lookup($lookup($p, widgets.resolution), widgets.duration)} + ) + """, +) + +_BADGE_RAY32_VIDEO_5S = IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["resolution"]), + expr=""" + ( + $p := {"360p": 0.06, "540p": 0.15, "720p": 0.3, "1080p": 1.2}; + {"type": "usd", "usd": $lookup($p, widgets.resolution)} + ) + """, +) + +_BADGE_RAY32_EDIT = IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["resolution"]), + expr=""" + ( + $p := { + "360p": {"min": 0.54, "max": 1.08}, + "540p": {"min": 0.72, "max": 1.44}, + "720p": {"min": 1.08, "max": 2.16}, + "1080p": {"min": 2.16, "max": 4.32} + }; + $r := $lookup($p, widgets.resolution); + {"type": "range_usd", "min_usd": $r.min, "max_usd": $r.max, "format": {"note": "(by source length)"}} + ) + """, +) + +_BADGE_RAY32_REFRAME = IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["resolution"]), + expr=""" + ( + $p := {"360p": 0.03, "540p": 0.06, "720p": 0.12, "1080p": 0.36}; + {"type": "usd", "usd": $lookup($p, widgets.resolution), "format": {"suffix": "/second"}} + ) + """, +) + + +def _ray32_seed_input() -> IO.Input: + return IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; results are nondeterministic regardless of seed.", + ) + + +async def _ray32_generate(cls: type[IO.ComfyNode], request: Luma2GenerationRequest) -> IO.NodeOutput: + """Run a ray-3.2 generation and return (video, generation_id).""" + final = await _luma2_submit_and_poll(cls, request, estimated_duration=120) + video = await download_url_to_video_output(final.output[0].url) + return IO.NodeOutput(video, final.id or "") + + +class LumaRay32TextToVideoNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="LumaRay32TextToVideoNode", + display_name="Luma Ray 3.2 Text to Video", + category="partner/video/Luma", + description="Generate a video from a text prompt using Luma's Ray 3.2 model.", + inputs=[ + IO.String.Input("prompt", multiline=True, default="", tooltip="Text prompt for the video generation."), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1", "4:3", "3:4", "21:9"]), + IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"), + IO.Combo.Input("duration", options=["5s", "10s"]), + IO.Boolean.Input( + "loop", + default=False, + tooltip="Make the video loop seamlessly. Only available with 5s duration.", + ), + _ray32_seed_input(), + ], + outputs=[IO.Video.Output(), IO.String.Output(display_name="generation_id")], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=_BADGE_RAY32_VIDEO, + ) + + @classmethod + async def execute( + cls, prompt: str, aspect_ratio: str, resolution: str, duration: str, loop: bool, seed: int + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1, max_length=6000) + if loop and duration == "10s": + raise ValueError("Looping is only available with 5s duration on Ray 3.2.") + request = Luma2GenerationRequest( + prompt=prompt, + model="ray-3.2", + type="video", + aspect_ratio=aspect_ratio, + video=Luma2VideoOptions(resolution=resolution, duration=duration, loop=loop or None), + ) + return await _ray32_generate(cls, request) + + +class LumaRay32ImageToVideoNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="LumaRay32ImageToVideoNode", + display_name="Luma Ray 3.2 Image to Video", + category="partner/video/Luma", + description="Generate a video from a start and/or end frame using Luma's Ray 3.2 model. " + "Image-anchored generations are always 5 seconds.", + inputs=[ + IO.String.Input("prompt", multiline=True, default="", tooltip="Text prompt for the video generation."), + IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"), + IO.Boolean.Input( + "loop", + default=False, + tooltip="Make the video loop seamlessly. Not available when an end_frame is set.", + ), + _ray32_seed_input(), + IO.Image.Input("start_frame", optional=True, tooltip="First frame of the generated video."), + IO.Image.Input("end_frame", optional=True, tooltip="Last frame of the generated video."), + ], + outputs=[IO.Video.Output(), IO.String.Output(display_name="generation_id")], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=_BADGE_RAY32_VIDEO_5S, + ) + + @classmethod + async def execute( + cls, + prompt: str, + resolution: str, + loop: bool, + seed: int, + start_frame: torch.Tensor | None = None, + end_frame: torch.Tensor | None = None, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1, max_length=6000) + if start_frame is None and end_frame is None: + raise ValueError("Provide at least one of start_frame / end_frame.") + if loop and end_frame is not None: + raise ValueError("Looping is not available when an end_frame is set.") + video = Luma2VideoOptions(resolution=resolution, duration="5s", loop=loop or None) + if start_frame is not None: + url = await upload_image_to_comfyapi(cls, start_frame, mime_type="image/png") + video.start_frame = Luma2ImageRef(url=url) + if end_frame is not None: + url = await upload_image_to_comfyapi(cls, end_frame, mime_type="image/png") + video.end_frame = Luma2ImageRef(url=url) + request = Luma2GenerationRequest(prompt=prompt, model="ray-3.2", type="video", video=video) + return await _ray32_generate(cls, request) + + +class LumaRay32KeyframeNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="LumaRay32KeyframeNode", + display_name="Luma Ray 3.2 Keyframe", + category="partner/video/Luma", + description="Anchor a guide image to a position on the Ray 3.2 output video timeline. Connect this to " + "the 'keyframes' input of the Luma Ray 3.2 Keyframes to Video node; chain several together via the " + "optional 'keyframes' input below.", + inputs=[ + IO.Image.Input("image", tooltip="Guide image to place at the chosen moment of the output video."), + IO.DynamicCombo.Input( + "position", + options=[ + IO.DynamicCombo.Option( + "Fraction of duration (0.0-1.0)", + [ + IO.Float.Input( + "fraction", + default=0.0, + min=0.0, + max=1.0, + step=0.01, + display_mode=IO.NumberDisplay.number, + tooltip="Where in the output video this image applies " "(0.0 = start, 1.0 = end).", + ), + ], + ), + IO.DynamicCombo.Option( + "Absolute time (seconds)", + [ + IO.Float.Input( + "seconds", + default=0.0, + min=0.0, + max=10.0, + step=0.1, + display_mode=IO.NumberDisplay.number, + tooltip="Time in seconds from the start of the output video where this " + "image applies.", + ), + ], + ), + ], + tooltip="How to place this image on the output video's timeline.", + ), + IO.Custom(LumaIO.LUMA_RAY32_KEYFRAME).Input( + "keyframes", + optional=True, + tooltip="Optional earlier keyframes to chain with this one.", + ), + ], + outputs=[IO.Custom(LumaIO.LUMA_RAY32_KEYFRAME).Output(display_name="keyframes")], + ) + + @classmethod + def execute( + cls, + image: torch.Tensor, + position: dict, + keyframes: LumaRay32KeyframeChain | None = None, + ) -> IO.NodeOutput: + chain = keyframes.clone() if keyframes is not None else LumaRay32KeyframeChain() + if position["position"] == "Absolute time (seconds)": + mode, value = LUMA_KEYFRAME_MODE_SECONDS, float(position["seconds"]) + else: + mode, value = LUMA_KEYFRAME_MODE_FRACTION, float(position["fraction"]) + chain.add(LumaRay32KeyframeItem(image=image, mode=mode, value=value)) + return IO.NodeOutput(chain) + + +class LumaRay32KeyframesToVideoNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="LumaRay32KeyframesToVideoNode", + display_name="Luma Ray 3.2 Keyframes to Video", + category="partner/video/Luma", + description="Generate a video that interpolates through a sequence of guide images, each anchored to a " + "position on the timeline, using Luma Ray 3.2. Build the sequence with Luma Ray 3.2 Keyframe nodes " + "(at least 2).", + inputs=[ + IO.String.Input("prompt", multiline=True, default="", tooltip="Text prompt for the video generation."), + IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"), + IO.Combo.Input("duration", options=["5s", "10s"]), + _ray32_seed_input(), + IO.Custom(LumaIO.LUMA_RAY32_KEYFRAME).Input( + "keyframes", + tooltip="Keyframe sequence from Luma Ray 3.2 Keyframe nodes (at least 2).", + ), + ], + outputs=[IO.Video.Output(), IO.String.Output(display_name="generation_id")], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=_BADGE_RAY32_VIDEO, + ) + + @classmethod + async def execute( + cls, + prompt: str, + resolution: str, + duration: str, + seed: int, + keyframes: LumaRay32KeyframeChain | None = None, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1, max_length=6000) + items = keyframes.items if keyframes is not None else [] + if len(items) < 2: + raise ValueError( + "Connect at least 2 Luma Ray 3.2 Keyframe nodes " + "(use Luma Ray 3.2 Image to Video for a single start/end frame)." + ) + if len(items) > 64: + raise ValueError(f"Ray 3.2 supports at most 64 keyframes; got {len(items)}.") + maxframe = 120 if duration == "5s" else 240 + duration_seconds = maxframe / 24 # 5.0 or 10.0 + # Resolve each keyframe to an output-frame index, then order by position + # (so the user can chain keyframes in any order — the position is what places them) + placed: list[tuple[int, torch.Tensor]] = [] + for item in items: + if item.mode == LUMA_KEYFRAME_MODE_SECONDS: + if item.value > duration_seconds: + raise ValueError( + f"Keyframe position {item.value:g}s is past the end of the {duration} video; " + f"use 0-{duration_seconds:g}s (or switch the keyframe to fraction mode)." + ) + idx = round(item.value * 24) + else: + idx = round(item.value * maxframe) + placed.append((max(0, min(maxframe, idx)), item.image)) + placed.sort(key=lambda p: p[0]) + indexes = [idx for idx, _ in placed] + for a, b in zip(indexes, indexes[1:]): + if a == b: + raise ValueError( + f"Two keyframes resolve to the same output frame ({a}) for a {duration} video " + f"(valid range 0-{maxframe}); give each keyframe a distinct position." + ) + refs: list[Luma2ImageRef] = [] + for _, image in placed: + url = await upload_image_to_comfyapi(cls, image, mime_type="image/png") + refs.append(Luma2ImageRef(url=url)) + request = Luma2GenerationRequest( + prompt=prompt, + model="ray-3.2", + type="video", + video=Luma2VideoOptions(resolution=resolution, duration=duration, keyframes=refs, keyframe_indexes=indexes), + ) + return await _ray32_generate(cls, request) + + +class LumaRay32VideoEditNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="LumaRay32VideoEditNode", + display_name="Luma Ray 3.2 Video Edit", + category="partner/video/Luma", + description="Re-render an existing video under a new prompt using Luma Ray 3.2 (restyle, relight, add " + "or remove elements) while keeping the original motion. Source video up to 18 seconds; the edited " + "video keeps the source's length.", + inputs=[ + IO.Video.Input("video", tooltip="Source video to edit. Up to 18 seconds."), + IO.String.Input("prompt", multiline=True, default="", tooltip="Describes the desired edit."), + IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"), + IO.Combo.Input( + "strength", + options=[ + "auto", + "adhere_1", + "adhere_2", + "adhere_3", + "flex_1", + "flex_2", + "flex_3", + "reimagine_1", + "reimagine_2", + "reimagine_3", + ], + default="auto", + tooltip="How strongly to preserve vs. reimagine the source. 'auto' lets Ray 3.2 choose; " + "adhere_* preserves the most, flex_* is balanced, reimagine_* changes the most.", + ), + _ray32_seed_input(), + ], + outputs=[ + IO.Video.Output(), + IO.String.Output(display_name="generation_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=_BADGE_RAY32_EDIT, + ) + + @classmethod + async def execute( + cls, video: Input.Video, prompt: str, resolution: str, strength: str, seed: int + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1, max_length=6000) + try: + duration = "5s" if video.get_duration() <= 5.0 else "10s" + except Exception: + duration = "10s" + source_url = await upload_video_to_comfyapi(cls, video, max_duration=18) + edit = Luma2VideoEdit(auto_controls=True) if strength == "auto" else Luma2VideoEdit(strength=strength) + request = Luma2GenerationRequest( + prompt=prompt, + model="ray-3.2", + type="video_edit", + source=Luma2ImageRef(url=source_url, media_type="video/mp4"), + video=Luma2VideoOptions(resolution=resolution, duration=duration, edit=edit), + ) + return await _ray32_generate(cls, request) + + +class LumaRay32VideoReframeNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="LumaRay32VideoReframeNode", + display_name="Luma Ray 3.2 Video Reframe", + category="partner/video/Luma", + description="Change the aspect ratio of an existing video, using Luma Ray 3.2 to fill the newly " + "exposed canvas areas. Source video up to 30 seconds. Billed per second of output.", + inputs=[ + IO.Video.Input("video", tooltip="Source video to reframe. Up to 30 seconds."), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Describes how the newly exposed canvas areas should be filled.", + ), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1", "4:3", "3:4", "21:9"]), + IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"), + _ray32_seed_input(), + ], + outputs=[ + IO.Video.Output(), + IO.String.Output(display_name="generation_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=_BADGE_RAY32_REFRAME, + ) + + @classmethod + async def execute( + cls, video: Input.Video, prompt: str, aspect_ratio: str, resolution: str, seed: int + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=False, min_length=1, max_length=6000) + if resolution == "1080p" and aspect_ratio in {"9:16", "3:4"}: + raise ValueError("1080p is not available for vertical aspect ratios (9:16, 3:4) when reframing.") + source_url = await upload_video_to_comfyapi(cls, video, max_duration=30) + request = Luma2GenerationRequest( + prompt=prompt, + model="ray-3.2", + type="video_reframe", + aspect_ratio=aspect_ratio, + source=Luma2ImageRef(url=source_url, media_type="video/mp4"), + video=Luma2VideoOptions(resolution=resolution), + ) + return await _ray32_generate(cls, request) + + +class LumaRay32ExtendVideoNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="LumaRay32ExtendVideoNode", + display_name="Luma Ray 3.2 Extend Video", + category="partner/video/Luma", + description="Extend a previous Ray 3.2 generation forward (continue after it) or backward (lead-in " + "before it). Connect the generation_id output of a prior Luma Ray 3.2 node." + " Extensions are always 5 seconds.", + inputs=[ + IO.String.Input( + "source_generation_id", + default="", + tooltip="generation_id of the prior Ray 3.2 video to extend." + " Connect the generation_id output of another Luma Ray 3.2 node.", + ), + IO.DynamicCombo.Input( + "direction", + options=[ + IO.DynamicCombo.Option( + "Forward (continue after)", + [ + IO.Boolean.Input( + "loop", + default=False, + tooltip="Loop the extended video seamlessly (forward extend only).", + ), + ], + ), + IO.DynamicCombo.Option("Backward (lead-in before)", []), + ], + tooltip="Forward continues after the prior clip; backward is prepended before it.", + ), + IO.String.Input("prompt", multiline=True, default="", tooltip="Text prompt for the new content."), + IO.Combo.Input("resolution", options=["540p", "720p", "1080p"], default="720p"), + _ray32_seed_input(), + ], + outputs=[ + IO.Video.Output(), + IO.String.Output(display_name="generation_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=_BADGE_RAY32_VIDEO_5S, + ) + + @classmethod + async def execute( + cls, source_generation_id: str, direction: dict, prompt: str, resolution: str, seed: int + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=False, min_length=1, max_length=6000) + gen_id = (source_generation_id or "").strip() + if not gen_id: + raise ValueError( + "source_generation_id is required (connect the generation_id output of a prior Luma Ray 3.2 node)." + ) + video = Luma2VideoOptions(resolution=resolution, duration="5s") + ref = Luma2ImageRef(generation_id=gen_id) + if direction["direction"] == "Forward (continue after)": + video.start_frame = ref + if direction.get("loop"): + video.loop = True + else: + video.end_frame = ref + request = Luma2GenerationRequest(prompt=prompt, model="ray-3.2", type="video", video=video) + return await _ray32_generate(cls, request) class LumaExtension(ComfyExtension): @@ -944,6 +1481,13 @@ class LumaExtension(ComfyExtension): LumaConceptsNode, LumaImageNode, LumaImageEditNode, + LumaRay32TextToVideoNode, + LumaRay32ImageToVideoNode, + LumaRay32KeyframeNode, + LumaRay32KeyframesToVideoNode, + LumaRay32VideoEditNode, + LumaRay32VideoReframeNode, + LumaRay32ExtendVideoNode, ] diff --git a/comfy_api_nodes/util/_helpers.py b/comfy_api_nodes/util/_helpers.py index 83cf7b001..6b8121cab 100644 --- a/comfy_api_nodes/util/_helpers.py +++ b/comfy_api_nodes/util/_helpers.py @@ -4,6 +4,8 @@ import os import re import time from collections.abc import Callable +from datetime import datetime, timezone +from email.utils import parsedate_to_datetime from io import BytesIO from yarl import URL @@ -91,6 +93,32 @@ async def sleep_with_interrupt( await asyncio.sleep(min(1.0, end - now)) +def _retry_after_wait(value: str | None, fallback: float, max_wait: float) -> float: + """Delay before the next retry, honoring a server ``Retry-After`` header.""" + + seconds: float | None = None + if value is not None: + value = value.strip() + if value.isascii() and value.isdigit(): + # delay-seconds form. The ASCII-digit guard keeps exotic Unicode "digit" characters away from float() + # an all-digit string always converts (huge values become inf, never raising). + seconds = float(value) + elif value: + # HTTP-date form. parsedate_to_datetime raises OverflowError (not a ValueError) on absurd years/offsets + try: + parsed = parsedate_to_datetime(value) + except (TypeError, ValueError, OverflowError): + parsed = None + if parsed is not None: + if parsed.tzinfo is None: # naive datetime: HTTP-date is UTC + parsed = parsed.replace(tzinfo=timezone.utc) + delta = (parsed - datetime.now(timezone.utc)).total_seconds() + seconds = delta if delta > 0 else 0.0 + if seconds is None: + return fallback + return min(seconds, max_wait) + + def mimetype_to_extension(mime_type: str) -> str: """Converts a MIME type to a file extension.""" return mime_type.split("/")[-1].lower() diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index adcde7bcb..66aab17f8 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -21,6 +21,7 @@ from server import PromptServer from . import request_logger from ._helpers import ( + _retry_after_wait, default_base_url, get_comfy_api_headers, get_node_id, @@ -82,6 +83,7 @@ class _PollUIState: _RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately +_MAX_RETRY_AFTER_WAIT = 150.0 # Cap a server Retry-After at this many seconds so a large hint can't block execution COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"] FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"] QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait", "in_queue"] @@ -747,6 +749,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): should_retry = True if should_retry: + wait_time = _retry_after_wait(resp.headers.get("Retry-After"), wait_time, _MAX_RETRY_AFTER_WAIT) logging.warning( "HTTP %s %s -> %s. Waiting %.2fs (%s).", method, diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py index 20ebae155..fa3ab0faf 100644 --- a/comfy_execution/jobs.py +++ b/comfy_execution/jobs.py @@ -4,11 +4,22 @@ Provides normalization and helper functions for job status tracking. """ import uuid -from typing import Optional +from typing import Callable, Optional from comfy_api.internal import prune_dict +# Result of classifying a job for cancellation. +# 'running' -> job is currently executing (interrupt it) +# 'pending' -> job is queued but not started (dequeue it) +# 'terminal' -> job already finished (present in history); cancel is a no-op +# 'unknown' -> job id is not present anywhere +CANCEL_RUNNING = 'running' +CANCEL_PENDING = 'pending' +CANCEL_TERMINAL = 'terminal' +CANCEL_UNKNOWN = 'unknown' + + class JobStatus: """Job status constants.""" PENDING = 'pending' @@ -407,3 +418,71 @@ def get_all_jobs( jobs = jobs[:limit] return (jobs, total_count) + + +def classify_job_for_cancel(prompt_id: str, running: list, queued: list, history: dict) -> str: + """Classify a job id for cancellation. + + Returns one of CANCEL_RUNNING, CANCEL_PENDING, CANCEL_TERMINAL, CANCEL_UNKNOWN. + + Queue items are tuples whose second element (index 1) is the prompt_id. + History is a dict keyed by prompt_id, so a job present there has already + finished and cancelling it is a no-op. + """ + for item in running: + if item[1] == prompt_id: + return CANCEL_RUNNING + for item in queued: + if item[1] == prompt_id: + return CANCEL_PENDING + if prompt_id in history: + return CANCEL_TERMINAL + return CANCEL_UNKNOWN + + +def cancel_job( + prompt_id: str, + running: list, + queued: list, + history: dict, + interrupt: Callable[[str], bool], + dequeue: Callable[[str], bool], +) -> str: + """Cancel a single job by id, regardless of state. + + Maps the cancel onto the runtime's existing mechanics: + - a running job is interrupted via ``interrupt`` + - a pending job is removed from the queue via ``dequeue`` + - a job that already finished (terminal) is a no-op + - an unknown id is a no-op (callers that need fail-fast behaviour should + validate ids up front with ``classify_job_for_cancel``) + + Both ``interrupt`` and ``dequeue`` take the prompt id and return whether + they acted on a job that was *actually* in that state, so the value returned + here reflects what truly happened rather than the (possibly stale) + classification. This matters around the narrow TOCTOU windows where a job + changes state between the caller's snapshot and the action: + + - a job classified RUNNING may have finished before ``interrupt`` fires: + ``interrupt`` returns False and this returns CANCEL_UNKNOWN (no-op). + - a job classified PENDING may have started executing before ``dequeue`` + fires: ``dequeue`` returns False, ``interrupt`` then catches the now- + running job and this returns CANCEL_RUNNING. If it had simply finished + instead, both return False and this returns CANCEL_UNKNOWN. + + ``interrupt`` must be atomic — interrupt the job only if it is still the one + running — so a cancel can never land on an unrelated prompt that started in + the meantime (see ``execution.PromptQueue.interrupt_if_running``). + """ + classification = classify_job_for_cancel(prompt_id, running, queued, history) + if classification == CANCEL_RUNNING: + return CANCEL_RUNNING if interrupt(prompt_id) else CANCEL_UNKNOWN + if classification == CANCEL_PENDING: + if dequeue(prompt_id): + return CANCEL_PENDING + # Left the pending queue between classification and dequeue: if it + # started executing, interrupt the now-running job; otherwise it has + # already finished and the cancel is a genuine no-op. + return CANCEL_RUNNING if interrupt(prompt_id) else CANCEL_UNKNOWN + # CANCEL_TERMINAL and CANCEL_UNKNOWN are intentional no-ops. + return classification diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index 77f124e28..6adcc95fa 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -158,7 +158,7 @@ class SaveAudio(IO.ComfyNode): return IO.Schema( node_id="SaveAudio", search_aliases=["export flac"], - display_name="Save Audio (FLAC) (Deprecated)", + display_name="Save Audio (FLAC) (DEPRECATED)", category="audio", essentials_category="Audio", inputs=[ @@ -166,8 +166,9 @@ class SaveAudio(IO.ComfyNode): IO.String.Input("filename_prefix", default="audio/ComfyUI"), ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], - is_output_node=True, is_deprecated=True, + is_output_node=True, + outputs=[IO.Audio.Output("audio")] ) @classmethod @@ -175,11 +176,10 @@ class SaveAudio(IO.ComfyNode): if audio is None: raise ValueError("SaveAudio: input audio is None (source video may have no audio track).") return IO.NodeOutput( + audio, ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=format) ) - save_flac = execute # TODO: remove - class SaveAudioMP3(IO.ComfyNode): @classmethod @@ -187,7 +187,7 @@ class SaveAudioMP3(IO.ComfyNode): return IO.Schema( node_id="SaveAudioMP3", search_aliases=["export mp3"], - display_name="Save Audio (MP3) (Deprecated)", + display_name="Save Audio (MP3) (DEPRECATED)", category="audio", essentials_category="Audio", inputs=[ @@ -196,8 +196,9 @@ class SaveAudioMP3(IO.ComfyNode): IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"), ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], - is_output_node=True, is_deprecated=True, + is_output_node=True, + outputs=[IO.Audio.Output("audio")] ) @classmethod @@ -205,13 +206,12 @@ class SaveAudioMP3(IO.ComfyNode): if audio is None: raise ValueError("SaveAudioMP3: input audio is None (source video may have no audio track).") return IO.NodeOutput( + audio, ui=UI.AudioSaveHelper.get_save_audio_ui( audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality ) ) - save_mp3 = execute # TODO: remove - class SaveAudioOpus(IO.ComfyNode): @classmethod @@ -219,7 +219,7 @@ class SaveAudioOpus(IO.ComfyNode): return IO.Schema( node_id="SaveAudioOpus", search_aliases=["export opus"], - display_name="Save Audio (Opus) (Deprecated)", + display_name="Save Audio (Opus) (DEPRECATED)", category="audio", inputs=[ IO.Audio.Input("audio"), @@ -227,8 +227,9 @@ class SaveAudioOpus(IO.ComfyNode): IO.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"), ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], - is_output_node=True, is_deprecated=True, + is_output_node=True, + outputs=[IO.Audio.Output("audio")] ) @classmethod @@ -236,13 +237,12 @@ class SaveAudioOpus(IO.ComfyNode): if audio is None: raise ValueError("SaveAudioOpus: input audio is None (source video may have no audio track).") return IO.NodeOutput( + audio, ui=UI.AudioSaveHelper.get_save_audio_ui( audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality ) ) - save_opus = execute # TODO: remove - class SaveAudioAdvanced(IO.ComfyNode): @classmethod @@ -258,10 +258,7 @@ class SaveAudioAdvanced(IO.ComfyNode): IO.String.Input( "filename_prefix", default="audio/ComfyUI", - tooltip=( - "The prefix for the file to save. May include formatting tokens " - "such as %date:yyyy-MM-dd%." - ), + tooltip=("The prefix for the file to save. May include formatting tokens such as %date:yyyy-MM-dd%."), ), IO.DynamicCombo.Input( "format", @@ -279,6 +276,7 @@ class SaveAudioAdvanced(IO.ComfyNode): ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], is_output_node=True, + outputs=[IO.Audio.Output("audio")], ) @classmethod @@ -289,7 +287,7 @@ class SaveAudioAdvanced(IO.ComfyNode): ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=file_format, quality=quality) else: ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=file_format) - return IO.NodeOutput(ui=ui) + return IO.NodeOutput(audio, ui=ui) class PreviewAudio(IO.ComfyNode): @@ -305,13 +303,14 @@ class PreviewAudio(IO.ComfyNode): ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], is_output_node=True, + outputs=[IO.Audio.Output("audio")] ) @classmethod def execute(cls, audio) -> IO.NodeOutput: if audio is None: raise ValueError("PreviewAudio: input audio is None (source video may have no audio track).") - return IO.NodeOutput(ui=UI.PreviewAudio(audio, cls=cls)) + return IO.NodeOutput(audio, ui=UI.PreviewAudio(audio, cls=cls)) save_flac = execute # TODO: remove diff --git a/comfy_extras/nodes_boogu.py b/comfy_extras/nodes_boogu.py new file mode 100644 index 000000000..f3951c290 --- /dev/null +++ b/comfy_extras/nodes_boogu.py @@ -0,0 +1,97 @@ +import math + +import node_helpers +import comfy.utils +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + + +class TextEncodeBooguEdit(io.ComfyNode): + """Boogu-Image Edit conditioning. + + The edit image is used twice, matching the reference pipeline: + - Qwen3-VL vision tokens (instruction understanding) -> positive only + - VAE reference latent (image identity) -> positive and negative + The ref latent is in both conds so it cancels under CFG (identity preserved); + the vision tokens are only in the positive so CFG amplifies the instruction. + The tokenizer selects the right system prompt automatically (image -> TI2I, + empty negative -> DROP), so no template plumbing is needed here. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="TextEncodeBooguEdit", + category="model/conditioning/boogu", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("prompt", multiline=True, dynamic_prompts=True), + io.String.Input("negative_prompt", multiline=True, dynamic_prompts=True, advanced=True), + io.Vae.Input("vae"), + io.Autogrow.Input( + "images", + template=io.Autogrow.TemplateNames( + io.Image.Input("image"), + names=[f"image_{i}" for i in range(1, 17)], + min=0, + ), + tooltip="Reference image(s) to edit. Boogu focuses on one reference per sample; more are allowed.", + ), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + ], + ) + + @classmethod + def execute(cls, clip, prompt, negative_prompt, vae=None, images: io.Autogrow.Type = None) -> io.NodeOutput: + ref_latents = [] + images_vl = [] + + images = images or {} + for name in sorted(images, key=lambda n: int(n.rsplit("_", 1)[-1])): + image = images[name] + if image is None: + continue + samples = image.movedim(-1, 1) + + # Vision tower input: the reference caps the VLM image at 384x384 + # (max_vlm_input_pil_pixels in pipeline_boogu.py). + total = int(384 * 384) + scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) + width = round(samples.shape[3] * scale_by) + height = round(samples.shape[2] * scale_by) + s = comfy.utils.common_upscale(samples, width, height, "area", "disabled") + images_vl.append(s.movedim(1, -1)[:, :, :, :3]) + + # Reference latent: align to 16 px (VAE /8 * patch_size 2). + if vae is not None: + total = int(1024 * 1024) + scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) + width = round(samples.shape[3] * scale_by / 16.0) * 16 + height = round(samples.shape[2] * scale_by / 16.0) * 16 + s = comfy.utils.common_upscale(samples, width, height, "area", "disabled") + ref_latents.append(vae.encode(s.movedim(1, -1)[:, :, :, :3])) + + # positive: instruction + vision tokens; negative: empty (no vision). Ref latent on both. + positive = clip.encode_from_tokens_scheduled(clip.tokenize(prompt, images=images_vl)) + negative = clip.encode_from_tokens_scheduled(clip.tokenize(negative_prompt)) + + if len(ref_latents) > 0: + positive = node_helpers.conditioning_set_values(positive, {"reference_latents": ref_latents}, append=True) + negative = node_helpers.conditioning_set_values(negative, {"reference_latents": ref_latents}, append=True) + + return io.NodeOutput(positive, negative) + + +class BooguExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TextEncodeBooguEdit, + ] + + +async def comfy_entrypoint() -> BooguExtension: + return BooguExtension() diff --git a/comfy_extras/nodes_context_windows.py b/comfy_extras/nodes_context_windows.py index 098c26f23..15d2dc506 100644 --- a/comfy_extras/nodes_context_windows.py +++ b/comfy_extras/nodes_context_windows.py @@ -13,21 +13,22 @@ class ContextWindowsManualNode(io.ComfyNode): description="Manually set context windows.", inputs=[ io.Model.Input("model", tooltip="The model to apply context windows to during sampling."), - io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window.", advanced=True), - io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window.", advanced=True), + io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window."), + io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window."), io.Combo.Input("context_schedule", options=[ comfy.context_windows.ContextSchedules.STATIC_STANDARD, comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, comfy.context_windows.ContextSchedules.UNIFORM_LOOPED, comfy.context_windows.ContextSchedules.BATCHED, - ], tooltip="The stride of the context window."), - io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True), + ], default=comfy.context_windows.ContextSchedules.STATIC_STANDARD, tooltip="Step-dependent scheduling algorithm for context windows."), + io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."), io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."), io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."), io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), - io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), + io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window. For concat-style I2V models (e.g. Wan I2V, HunyuanVideo I2V, Cosmos I2V, SVD) the encoded start image lives in the c_concat conditioning channels; setting this to '0' will retain that start image content at sub-pos 0 of every window."), io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), + io.String.Input("latent_retain_index_list", default="", tooltip="List of latent indices to retain in the noise latent itself for each window. Use for workflows where reference content (e.g. a start image) lives directly in the noise latent rather than in separate conditioning channels (e.g. inplace-style I2V like LTXV, AnimateDiff). Independent of cond_retain_index_list."), io.Boolean.Input("causal_window_fix", default=True, tooltip="Whether to add a causal fix frame to non-0-indexed context windows."), ], outputs=[ @@ -38,7 +39,7 @@ class ContextWindowsManualNode(io.ComfyNode): @classmethod def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool, - cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, causal_window_fix: bool=True) -> io.Model: + cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, latent_retain_index_list: list[int]=[], causal_window_fix: bool=True) -> io.Model: model = model.clone() model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler( context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule), @@ -51,6 +52,7 @@ class ContextWindowsManualNode(io.ComfyNode): freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows, + latent_retain_index_list=latent_retain_index_list, causal_window_fix=causal_window_fix, ) # make memory usage calculation only take into account the context window latents @@ -65,33 +67,71 @@ class WanContextWindowsManualNode(ContextWindowsManualNode): schema = super().define_schema() schema.node_id = "WanContextWindowsManual" schema.display_name = "WAN Context Windows (Manual)" - schema.description = "Manually set context windows for WAN-like models (dim=2)." + schema.display_name = "Wan Context Windows" + schema.description = "Set context windows for Wan-like models." schema.category="model/patch/wan" schema.inputs = [ io.Model.Input("model", tooltip="The model to apply context windows to during sampling."), - io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window.", advanced=True), - io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window.", advanced=True), + io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window in real frames. Must be 4*n + 1."), + io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window in real frames."), io.Combo.Input("context_schedule", options=[ comfy.context_windows.ContextSchedules.STATIC_STANDARD, comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, comfy.context_windows.ContextSchedules.UNIFORM_LOOPED, comfy.context_windows.ContextSchedules.BATCHED, - ], tooltip="The stride of the context window."), + ], default=comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, tooltip="Step-dependent scheduling algorithm for context windows."), io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True), - io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."), + io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules.", advanced=True), io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), - io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), - #io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), - #io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), + io.Boolean.Input("freenoise", default=True, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending.", advanced=True), + io.Boolean.Input("retain_first_frame", default=False, tooltip="Retain the first I2V frame in every context window (may help retain initial reference)."), + io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index.", advanced=True), ] return schema @classmethod def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool, - cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model: - context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1 - context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0 - return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows) + retain_first_frame: bool=False, split_conds_to_windows: bool=False) -> io.Model: + context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1 + context_overlap = max(context_overlap // 4, 0) # at least overlap 0 + retain_index_list = "0" if retain_first_frame else "" + return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=retain_index_list, split_conds_to_windows=split_conds_to_windows) + + +class LTXVContextWindowsNode(ContextWindowsManualNode): + @classmethod + def define_schema(cls) -> io.Schema: + schema = super().define_schema() + schema.node_id = "LTXVContextWindows" + schema.display_name = "LTXV Context Windows" + schema.description = "Set context windows for LTXV-like models." + schema.inputs = [ + io.Model.Input("model", tooltip="The model to apply context windows to during sampling."), + io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=8, default=145, tooltip="The length of the context window in real frames. Must be 8*n + 1."), + io.Int.Input("context_overlap", min=0, step=8, default=40, tooltip="The overlap of the context window in real frames."), + io.Combo.Input("context_schedule", options=[ + comfy.context_windows.ContextSchedules.STATIC_STANDARD, + comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, + comfy.context_windows.ContextSchedules.UNIFORM_LOOPED, + comfy.context_windows.ContextSchedules.BATCHED, + ], default=comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, tooltip="Step-dependent scheduling algorithm for context windows."), + io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True), + io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules.", advanced=True), + io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), + io.Boolean.Input("freenoise", default=True, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending.", advanced=True), + io.Boolean.Input("retain_first_frame", default=False, tooltip="Retain the first latent frame in every context window (may help retain initial reference)."), + io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index.", advanced=True), + ] + return schema + + @classmethod + def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, fuse_method: str, freenoise: bool, + retain_first_frame: bool=False, split_conds_to_windows: bool=False, context_stride: int=1, closed_loop: bool=False) -> io.Model: + context_length = max(((context_length - 1) // 8) + 1, 1) # at least length 1 + context_overlap = max(context_overlap // 8, 0) # at least overlap 0 + retain_index_list = "0" if retain_first_frame else "" + return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, + cond_retain_index_list=retain_index_list, latent_retain_index_list=retain_index_list, split_conds_to_windows=split_conds_to_windows) class ContextWindowsExtension(ComfyExtension): @@ -99,6 +139,7 @@ class ContextWindowsExtension(ComfyExtension): return [ ContextWindowsManualNode, WanContextWindowsManualNode, + LTXVContextWindowsNode, ] def comfy_entrypoint(): diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py index 0253b4b4f..73fe75b7f 100644 --- a/comfy_extras/nodes_dataset.py +++ b/comfy_extras/nodes_dataset.py @@ -1583,7 +1583,7 @@ class LoadTrainingDataset(io.ComfyNode): shard_path = os.path.join(dataset_dir, shard_file) with open(shard_path, "rb") as f: - shard_data = torch.load(f) + shard_data = torch.load(f, weights_only=True) all_latents.extend(shard_data["latents"]) all_conditioning.extend(shard_data["conditioning"]) diff --git a/comfy_extras/nodes_frame_interpolation.py b/comfy_extras/nodes_frame_interpolation.py index 4d5bca17e..44708e5ec 100644 --- a/comfy_extras/nodes_frame_interpolation.py +++ b/comfy_extras/nodes_frame_interpolation.py @@ -77,7 +77,7 @@ class FrameInterpolate(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="FrameInterpolate", - display_name="Frame Interpolate", + display_name="Run Frame Interpolation Model", category="video", search_aliases=["rife", "film", "frame interpolation", "slow motion", "interpolate frames", "vfi"], inputs=[ diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 469a7be55..fe1937ba5 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -214,11 +214,13 @@ class SaveAnimatedWEBP(IO.ComfyNode): ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], is_output_node=True, + outputs=[IO.Image.Output(display_name="images")] ) @classmethod def execute(cls, images, fps, filename_prefix, lossless, quality, method, num_frames=0) -> IO.NodeOutput: return IO.NodeOutput( + images, ui=UI.ImageSaveHelper.get_save_animated_webp_ui( images=images, filename_prefix=filename_prefix, @@ -230,8 +232,6 @@ class SaveAnimatedWEBP(IO.ComfyNode): ) ) - save_images = execute # TODO: remove - class SaveAnimatedPNG(IO.ComfyNode): @@ -249,11 +249,13 @@ class SaveAnimatedPNG(IO.ComfyNode): ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], is_output_node=True, + outputs=[IO.Image.Output(display_name="images")] ) @classmethod def execute(cls, images, fps, compress_level, filename_prefix="ComfyUI") -> IO.NodeOutput: return IO.NodeOutput( + images, ui=UI.ImageSaveHelper.get_save_animated_png_ui( images=images, filename_prefix=filename_prefix, @@ -263,8 +265,6 @@ class SaveAnimatedPNG(IO.ComfyNode): ) ) - save_images = execute # TODO: remove - class ImageStitch(IO.ComfyNode): """Upstreamed from https://github.com/kijai/ComfyUI-KJNodes""" @@ -513,6 +513,7 @@ class SaveSVGNode(IO.ComfyNode): ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], is_output_node=True, + outputs=[IO.SVG.Output("svg")], ) @classmethod @@ -562,9 +563,7 @@ class SaveSVGNode(IO.ComfyNode): results.append(UI.SavedResult(filename=file, subfolder=subfolder, type=IO.FolderType.output)) counter += 1 - return IO.NodeOutput(ui={"images": results}) - - save_svg = execute # TODO: remove + return IO.NodeOutput(svg, ui={"images": results}) class GetImageSize(IO.ComfyNode): @@ -1157,40 +1156,27 @@ class SaveImageAdvanced(IO.ComfyNode): IO.String.Input( "filename_prefix", default="ComfyUI", - tooltip=( - "The prefix for the file to save. May include formatting tokens " - "such as %date:yyyy-MM-dd% or %Empty Latent Image.width%." - ), + tooltip=("The prefix for the file to save. May include formatting tokens such as %date:yyyy-MM-dd% or %Empty Latent Image.width%."), ), IO.DynamicCombo.Input( "format", options=[ IO.DynamicCombo.Option("png", [ - IO.Combo.Input("bit_depth", options=["8-bit", "16-bit"], - default="8-bit", advanced=True), - IO.Combo.Input("input_color_space", options=["sRGB"], - default="sRGB", advanced=True), + IO.Combo.Input("bit_depth", options=["8-bit", "16-bit"], default="8-bit", advanced=True), + IO.Combo.Input("input_color_space", options=["sRGB"], default="sRGB", advanced=True), ]), IO.DynamicCombo.Option("exr", [ - IO.Combo.Input("bit_depth", options=["32-bit float"], - default="32-bit float", advanced=True), + IO.Combo.Input("bit_depth", options=["32-bit float"], default="32-bit float", advanced=True), IO.Combo.Input( "input_color_space", options=["sRGB", "HDR", "linear"], default="sRGB", advanced=True, tooltip=( - "Colorspace of the input tensor. The EXR is " - "always written as scene-linear in the matching " - "gamut.\n" - " 'sRGB' — input is sRGB-encoded Rec.709; " - "the inverse sRGB EOTF is applied.\n" - " 'HDR' — input is HLG-encoded Rec.2020 " - "(BT.2100); the inverse HLG OETF is applied " - "to get scene-linear light.\n" - " 'linear' — input is already scene-linear " - "(Rec.709 primaries); written through unchanged. " - "Use this for renderer/compositor output." + "Colorspace of the input tensor. The EXR is always written as scene-linear in the matching gamut.\n" + "sRGB — input is sRGB-encoded Rec.709; the inverse sRGB EOTF is applied.\n" + "HDR — input is HLG-encoded Rec.2020 (BT.2100); the inverse HLG OETF is applied to get scene-linear light.\n" + "linear — input is already scene-linear (Rec.709 primaries); written through unchanged. Use this for renderer/compositor output." ), ), ]), @@ -1200,6 +1186,7 @@ class SaveImageAdvanced(IO.ComfyNode): ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], is_output_node=True, + outputs=[IO.Image.Output(display_name="images")] ) @classmethod @@ -1237,7 +1224,7 @@ class SaveImageAdvanced(IO.ComfyNode): results.append({"filename": file, "subfolder": subfolder, "type": "output"}) counter += 1 - return IO.NodeOutput(ui={"images": results}) + return IO.NodeOutput(images, ui={"images": results}) class ImagesExtension(ComfyExtension): diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py index 455897859..6e3e88471 100644 --- a/comfy_extras/nodes_load_3d.py +++ b/comfy_extras/nodes_load_3d.py @@ -317,11 +317,74 @@ class PreviewPointCloud(IO.ComfyNode): ) +MESH_EXTENSIONS = {'.gltf', '.glb', '.obj', '.fbx', '.stl'} + + +class Load3DAdvanced(IO.ComfyNode): + @classmethod + def define_schema(cls): + input_dir = os.path.join(folder_paths.get_input_directory(), "3d") + os.makedirs(input_dir, exist_ok=True) + + input_path = Path(input_dir) + base_path = Path(folder_paths.get_input_directory()) + + files = [ + normalize_path(str(file_path.relative_to(base_path))) + for file_path in input_path.rglob("*") + if file_path.suffix.lower() in MESH_EXTENSIONS + ] + return IO.Schema( + node_id="Load3DAdvanced", + display_name="Load 3D (Advanced)", + category="3d", + search_aliases=[ + "load mesh", + "load gltf", + "load glb", + "load obj", + "load fbx", + "load stl", + ], + is_experimental=True, + inputs=[ + IO.Combo.Input("model_file", options=["none"] + sorted(files), upload=IO.UploadType.model), + IO.Load3D.Input("viewport_state"), + IO.Int.Input("width", default=1024, min=1, max=4096, step=1), + IO.Int.Input("height", default=1024, min=1, max=4096, step=1), + ], + outputs=[ + IO.File3DAny.Output(display_name="model_3d"), + IO.Load3DModelInfo.Output(display_name="model_3d_info"), + IO.Load3DCamera.Output(display_name="camera_info"), + IO.Int.Output(display_name="width"), + IO.Int.Output(display_name="height"), + ], + ) + + @classmethod + def validate_inputs(cls, model_file, **kwargs) -> bool | str: + if not model_file or model_file == "none": + return True + if not folder_paths.exists_annotated_filepath(model_file): + return f"Invalid 3D model file: {model_file}" + return True + + @classmethod + def execute(cls, model_file, viewport_state, width: int, height: int, **kwargs) -> IO.NodeOutput: + file_3d = None + if model_file and model_file != "none": + file_3d = Types.File3D(folder_paths.get_annotated_filepath(model_file)) + model_3d_info = viewport_state.get('model_3d_info', []) + return IO.NodeOutput(file_3d, model_3d_info, viewport_state['camera_info'], width, height) + + class Load3DExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ Load3D, + Load3DAdvanced, Preview3D, Preview3DAdvanced, PreviewGaussianSplat, diff --git a/comfy_extras/nodes_logic.py b/comfy_extras/nodes_logic.py index 95f6ab848..13c1685f7 100644 --- a/comfy_extras/nodes_logic.py +++ b/comfy_extras/nodes_logic.py @@ -89,7 +89,8 @@ class SwitchNode(io.ComfyNode): template = io.MatchType.Template("switch") return io.Schema( node_id="ComfySwitchNode", - display_name="Switch", + search_aliases=["if", "then", "switch", "conditional", "branch"], + display_name="If/Else Switch", category="utilities/logic", is_experimental=True, inputs=[ diff --git a/comfy_extras/nodes_primitive.py b/comfy_extras/nodes_primitive.py index c44b09098..7f90daf14 100644 --- a/comfy_extras/nodes_primitive.py +++ b/comfy_extras/nodes_primitive.py @@ -10,12 +10,11 @@ class String(io.ComfyNode): return io.Schema( node_id="PrimitiveString", search_aliases=["text", "string", "text box", "prompt"], - display_name="Text String", + display_name="Text String (DEPRECATED)", category="utilities/primitive", - inputs=[ - io.String.Input("value"), - ], + inputs=[io.String.Input("value")], outputs=[io.String.Output()], + is_deprecated=True ) @classmethod @@ -29,12 +28,10 @@ class StringMultiline(io.ComfyNode): return io.Schema( node_id="PrimitiveStringMultiline", search_aliases=["text", "string", "text multiline", "string multiline", "text box", "prompt"], - display_name="Text String (Multiline)", + display_name="Input Text", category="utilities/primitive", essentials_category="Basics", - inputs=[ - io.String.Input("value", multiline=True), - ], + inputs=[io.String.Input("value", multiline=True)], outputs=[io.String.Output()], ) diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index 050a897dd..d3acc9ad0 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -27,6 +27,7 @@ class SaveWEBM(io.ComfyNode): ], hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], is_output_node=True, + outputs=[io.Image.Output(display_name="images")] ) @classmethod @@ -69,7 +70,7 @@ class SaveWEBM(io.ComfyNode): container.mux(stream.encode()) container.close() - return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)])) + return io.NodeOutput(images, ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)])) class SaveVideo(io.ComfyNode): @classmethod @@ -89,6 +90,7 @@ class SaveVideo(io.ComfyNode): ], hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], is_output_node=True, + outputs=[io.Video.Output("video")], ) @classmethod @@ -117,7 +119,7 @@ class SaveVideo(io.ComfyNode): metadata=saved_metadata ) - return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)])) + return io.NodeOutput(video, ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)])) class CreateVideo(io.ComfyNode): @@ -233,13 +235,8 @@ class VideoSlice(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="Video Slice", - display_name="Video Slice", - search_aliases=[ - "trim video duration", - "skip first frames", - "frame load cap", - "start time", - ], + display_name="Trim Video", + search_aliases=["trim video duration", "skip first frames", "frame load cap", "start time"], category="video", essentials_category="Video Tools", inputs=[ diff --git a/execution.py b/execution.py index 9e16e451d..c45317593 100644 --- a/execution.py +++ b/execution.py @@ -1308,6 +1308,25 @@ class PromptQueue: queued = copy.copy(self.queue) return (running, queued) + def interrupt_if_running(self, prompt_id): + """Interrupt the running prompt with this id, atomically. + + Checks the live running set and signals the interrupt under the queue + mutex, so the worker cannot move the job to done (and start the next + prompt) in between. Returns True if a matching job was running and an + interrupt was signalled, False otherwise. The atomicity is what keeps a + cancel from landing on an unrelated prompt that started after a separate + is-running check: the global interrupt flag is reset at the start of + every prompt (execute_async), so a job that finishes before consuming + the flag cannot leak the interrupt onto its successor. + """ + with self.mutex: + for item in self.currently_running.values(): + if item[1] == prompt_id: + nodes.interrupt_processing() + return True + return False + def get_tasks_remaining(self): with self.mutex: return len(self.queue) + len(self.currently_running) diff --git a/extra_model_paths.yaml.example b/extra_model_paths.yaml.example index 9c395c0b2..6a31d8a63 100644 --- a/extra_model_paths.yaml.example +++ b/extra_model_paths.yaml.example @@ -8,21 +8,37 @@ # # You can use is_default to mark that these folders should be listed first, and used as the default dirs for eg downloads # #is_default: true # checkpoints: models/checkpoints/ +# configs: models/configs/ +# loras: models/loras/ +# vae: models/vae/ # text_encoders: | # models/text_encoders/ -# models/clip/ # legacy location still supported -# clip_vision: models/clip_vision/ -# configs: models/configs/ -# controlnet: models/controlnet/ +# models/clip/ # diffusion_models: | -# models/diffusion_models -# models/unet +# models/unet/ +# models/diffusion_models/ +# clip_vision: models/clip_vision/ +# style_models: models/style_models/ # embeddings: models/embeddings/ -# loras: models/loras/ +# diffusers: models/diffusers/ +# vae_approx: models/vae_approx/ +# controlnet: | +# models/controlnet/ +# models/t2i_adapter/ +# gligen: models/gligen/ # upscale_models: models/upscale_models/ -# vae: models/vae/ -# audio_encoders: models/audio_encoders/ +# latent_upscale_models: models/latent_upscale_models/ +# custom_nodes: custom_nodes/ +# hypernetworks: models/hypernetworks/ +# photomaker: models/photomaker/ +# classifiers: models/classifiers/ # model_patches: models/model_patches/ +# audio_encoders: models/audio_encoders/ +# background_removal: models/background_removal/ +# frame_interpolation: models/frame_interpolation/ +# geometry_estimation: models/geometry_estimation/ +# optical_flow: models/optical_flow/ +# detection: models/detection/ #config for a1111 ui @@ -45,8 +61,7 @@ # controlnet: models/ControlNet -# For a full list of supported keys (style_models, vae_approx, hypernetworks, photomaker, -# model_patches, audio_encoders, classifiers, etc.) see folder_paths.py. +# For the canonical list of supported keys and extensions, see folder_paths.py. #other_ui: # base_path: path/to/ui diff --git a/nodes.py b/nodes.py index 942f83b23..f92b8819d 100644 --- a/nodes.py +++ b/nodes.py @@ -20,8 +20,6 @@ from PIL.PngImagePlugin import PngInfo import numpy as np import safetensors.torch -sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) - import comfy.diffusers_load import comfy.samplers import comfy.sample @@ -482,11 +480,13 @@ class SaveLatent: @classmethod def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT", ), - "filename_prefix": ("STRING", {"default": "latents/ComfyUI"})}, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } - RETURN_TYPES = () + return { "required": { + "samples": ("LATENT",), + "filename_prefix": ("STRING", {"default": "latents/ComfyUI"})}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + RETURN_TYPES = ("LATENT",) + RETURN_NAMES = ("samples",) FUNCTION = "save" OUTPUT_NODE = True @@ -524,7 +524,7 @@ class SaveLatent: output["latent_format_version_0"] = torch.tensor([]) comfy.utils.save_torch_file(output, file, metadata=metadata) - return { "ui": { "latents": results } } + return { "ui": { "latents": results }, "result": (samples,) } class LoadLatent: @@ -969,7 +969,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit", "ideogram4"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit", "ideogram4", "boogu"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -1629,14 +1629,18 @@ class SaveImage: return { "required": { "images": ("IMAGE", {"tooltip": "The images to save."}), - "filename_prefix": ("STRING", {"default": "ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}) + "filename_prefix": ("STRING", { + "default": "ComfyUI", + "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes." + }) }, "hidden": { "prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO" }, } - RETURN_TYPES = () + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("images",) FUNCTION = "save_images" OUTPUT_NODE = True @@ -1672,7 +1676,7 @@ class SaveImage: }) counter += 1 - return { "ui": { "images": results } } + return { "ui": { "images": results }, "result" : (images,) } class PreviewImage(SaveImage): def __init__(self): @@ -2299,6 +2303,9 @@ async def init_external_custom_nodes(): Returns: None """ + # TODO: remove at some point when custom nodes don't break. + sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) + base_node_names = set(NODE_CLASS_MAPPINGS.keys()) node_paths = folder_paths.get_folder_paths("custom_nodes") node_import_times = [] @@ -2425,6 +2432,7 @@ async def init_builtin_extra_nodes(): "nodes_tcfg.py", "nodes_context_windows.py", "nodes_qwen.py", + "nodes_boogu.py", "nodes_chroma_radiance.py", "nodes_pid.py", "nodes_model_patch.py", diff --git a/openapi.yaml b/openapi.yaml index 82ff5b003..380e4476e 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -55,6 +55,12 @@ components: description: URL for asset preview/thumbnail format: uri type: string + short_url: + description: Durable, owner-gated short link to this asset's content (relative `/api/s/{id}` path). Stable across the underlying signed URL's expiry — resolving it re-mints a fresh signed URL on every request — so it is safe to persist or share into chat, unlike `preview_url`. Only the minting user can resolve it. Omitted when the short-link surface is disabled or the asset has no resolvable content hash. + nullable: true + type: string + x-runtime: + - cloud size: description: Size of the asset in bytes format: int64 @@ -673,6 +679,35 @@ components: - created_at - updated_at type: object + JobsCancelRequest: + additionalProperties: false + description: Request to cancel multiple jobs by ID. + properties: + job_ids: + description: Job identifiers (UUIDs) to cancel. + items: + format: uuid + type: string + maxItems: 100 + minItems: 1 + type: array + required: + - job_ids + type: object + JobsCancelResponse: + description: Response for POST /api/jobs/cancel. + properties: + cancelled: + description: | + Job IDs for which a cancel event was successfully dispatched by this + call. Jobs already in a terminal or cancelling state are idempotently + skipped and will not appear here. + items: + type: string + type: array + required: + - cancelled + type: object JobsListResponse: description: Paginated list of jobs for the authenticated user. properties: @@ -1006,7 +1041,7 @@ components: description: If true, clear all pending jobs from the queue type: boolean delete: - description: Array of PENDING job IDs to cancel + description: Array of job IDs to cancel; pending and running jobs transition to cancelled items: type: string type: array @@ -1822,6 +1857,83 @@ paths: summary: Update asset metadata tags: - file + /api/assets/{id}/content: + get: + description: | + Returns the binary content of an asset by ID. + + The contract is the same across runtimes — "GET this path and you + receive the asset's bytes" — but the mechanism differs: + - **Local ComfyUI** streams the bytes directly (`200`, + `application/octet-stream`). + - **Cloud** does not proxy large files; it responds `302` with a + `Location` redirect to a short-lived signed storage URL. Clients that + follow redirects (browsers, `fetch`/XHR, ``/`