From 75e8e4b6dcd76783079ca304782d5b88231a97c3 Mon Sep 17 00:00:00 2001 From: ozbayb <17261091+ozbayb@users.noreply.github.com> Date: Tue, 5 May 2026 22:46:52 -0600 Subject: [PATCH] Update compute_guide_overlap to include keyframe indices and temporal downscale ratio instead of using latent_start key added to guide_attention_entries. - This makes it possible for existing nodes like the ComfyUI-LTXVideo nodes to work with context windows without needing modifications --- comfy/context_windows.py | 46 ++++++++++++++++++++++++++++++++-------- comfy_extras/nodes_lt.py | 4 +--- 2 files changed, 38 insertions(+), 12 deletions(-) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index b15799dde..5f9899c67 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -161,11 +161,17 @@ def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, d return cond_value._copy_with(sliced) -def compute_guide_overlap(guide_entries: list[dict], window_index_list: list[int]): +def compute_guide_overlap(guide_entries: list[dict], keyframe_idxs: torch.Tensor, temporal_downscale_ratio: int, window_index_list: list[int]): """Compute which concatenated guide frames overlap with a context window. + Each guide's latent-space start is derived from its first token's pixel-t-start + in keyframe_idxs (shape (B, [t,h,w], num_tokens, [start, end])), divided by the + model's temporal_downscale_ratio.
+ Args: guide_entries: list of guide_attention_entry dicts + keyframe_idxs: per-token pixel coords cond tensor for the modality + temporal_downscale_ratio: model's pixel-to-latent temporal compression ratio window_index_list: the window's frame indices into the video portion Returns: @@ -180,11 +186,11 @@ def compute_guide_overlap(guide_entries: list[dict], window_index_list: list[int overlap_info = [] kf_local_positions = [] suffix_base = 0 + token_offset = 0 for entry_idx, entry in enumerate(guide_entries): - latent_start = entry.get("latent_start", None) - if latent_start is None: - raise ValueError("guide_attention_entry missing required 'latent_start'.") + first_t_pixel = int(keyframe_idxs[0, 0, token_offset, 0].item()) + latent_start = (first_t_pixel + temporal_downscale_ratio - 1) // temporal_downscale_ratio guide_len = entry["latent_shape"][0] entry_overlap = 0 @@ -198,6 +204,7 @@ def compute_guide_overlap(guide_entries: list[dict], window_index_list: list[int if entry_overlap > 0: overlap_info.append((entry_idx, entry_overlap)) suffix_base += guide_len + token_offset += entry["pre_filter_count"] return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices) @@ -211,9 +218,11 @@ class WindowingState: latents: list[torch.Tensor] # per-modality working latents (guide frames stripped) guide_latents: list[torch.Tensor | None] # per-modality guide frames stripped from latents guide_entries: list[list[dict] | None] # per-modality guide_attention_entry metadata + keyframe_idxs: list[torch.Tensor | None] # per-modality keyframe_idxs tensor for guide latent_start derivation latent_shapes: list | None # original packed shapes for unpack/pack (None if not multimodal) dim: int = 0 # primary modality temporal dim for context windowing is_multimodal: bool = False + temporal_downscale_ratio: int = 1 # model's pixel-to-latent temporal compression ratio def prepare_window(self, window: IndexListContextWindow, model) -> IndexListContextWindow: """Reformat 
window for multimodal contexts by deriving per-modality index lists. @@ -279,7 +288,9 @@ class WindowingState: def _inject_guide_frames(self, latent_slice: torch.Tensor, window: IndexListContextWindow, modality_idx: int = 0) -> tuple[torch.Tensor, int]: guide_entries = self.guide_entries[modality_idx] guide_frames = self.guide_latents[modality_idx] - suffix_idx, overlap_info, kf_local_pos, guide_frame_count = compute_guide_overlap(guide_entries, window.index_list) + keyframe_idxs = self.keyframe_idxs[modality_idx] + suffix_idx, overlap_info, kf_local_pos, guide_frame_count = compute_guide_overlap( + guide_entries, keyframe_idxs, self.temporal_downscale_ratio, window.index_list) # Shift keyframe positions to account for causal_window_fix anchor occupying sub-pos 0. anchor_idx = getattr(window, 'causal_anchor_index', None) if anchor_idx is not None and anchor_idx >= 0: @@ -370,6 +381,18 @@ class IndexListContextHandler(ContextHandlerABC): return entries.cond return None + @staticmethod + def _get_keyframe_idxs(conds): + for cond_list in conds: + if cond_list is None: + continue + for cond_dict in cond_list: + model_conds = cond_dict.get('model_conds', {}) + kf = model_conds.get('keyframe_idxs') + if kf is not None and hasattr(kf, 'cond') and kf.cond is not None: + return kf.cond + return None + def _apply_freenoise(self, noise: torch.Tensor, conds: list[list[dict]], seed: int) -> torch.Tensor: """Apply FreeNoise shuffling, scaling context length/overlap per-modality by frame ratio. If guide frames are present on the primary modality, only the video portion is shuffled. 
@@ -395,7 +418,7 @@ class IndexListContextHandler(ContextHandlerABC): apply_freenoise(noise.narrow(self.dim, 0, video_count), self.dim, self.context_length, self.context_overlap, seed) return noise - def _build_window_state(self, x_in: torch.Tensor, conds: list[list[dict]]) -> WindowingState: + def _build_window_state(self, x_in: torch.Tensor, conds: list[list[dict]], model: BaseModel) -> WindowingState: """Build windowing state for the current step, including unpacking latents and extracting guide frame info from conds.""" latent_shapes = self._get_latent_shapes(conds) is_multimodal = latent_shapes is not None and len(latent_shapes) > 1 @@ -404,8 +427,10 @@ class IndexListContextHandler(ContextHandlerABC): unpacked_latents_list = list(unpacked_latents) guide_latents_list = [None] * len(unpacked_latents) guide_entries_list = [None] * len(unpacked_latents) + keyframe_idxs_list = [None] * len(unpacked_latents) extracted_guide_entries = self._get_guide_entries(conds) + extracted_keyframe_idxs = self._get_keyframe_idxs(conds) # Strip guide frames (only from first modality for now) if extracted_guide_entries is not None: @@ -416,18 +441,21 @@ class IndexListContextHandler(ContextHandlerABC): unpacked_latents_list[0] = x.narrow(self.dim, 0, latent_count) guide_latents_list[0] = x.narrow(self.dim, latent_count, guide_count) guide_entries_list[0] = extracted_guide_entries + keyframe_idxs_list[0] = extracted_keyframe_idxs return WindowingState( latents=unpacked_latents_list, guide_latents=guide_latents_list, guide_entries=guide_entries_list, + keyframe_idxs=keyframe_idxs_list, latent_shapes=latent_shapes, dim=self.dim, - is_multimodal=is_multimodal) + is_multimodal=is_multimodal, + temporal_downscale_ratio=model.latent_format.temporal_downscale_ratio) def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool: - window_state = self._build_window_state(x_in, conds) # build window_state 
to check frame counts, will be built again in execute + window_state = self._build_window_state(x_in, conds, model) # build window_state to check frame counts, will be built again in execute total_frame_count = window_state.latents[0].size(self.dim) if total_frame_count > self.context_length: logging.info(f"\nUsing context windows: Context length {self.context_length} with overlap {self.context_overlap} for {total_frame_count} frames.") @@ -545,7 +573,7 @@ class IndexListContextHandler(ContextHandlerABC): self._model = model self.set_step(timestep, model_options) - window_state = self._build_window_state(x_in, conds) + window_state = self._build_window_state(x_in, conds, model) num_modalities = len(window_state.latents) context_windows = self.get_context_windows(model, window_state.latents[0], model_options) diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index 8d37689ce..19d8a387f 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -136,7 +136,7 @@ class LTXVImgToVideoInplace(io.ComfyNode): generate = execute # TODO: remove -def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0, latent_start=0): +def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0): """Append a guide_attention_entry to both positive and negative conditioning. Each entry tracks one guide reference for per-reference attention control. 
@@ -147,7 +147,6 @@ def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_s "strength": strength, "pixel_mask": None, "latent_shape": latent_shape, - "latent_start": latent_start, } results = [] for cond in (positive, negative): @@ -364,7 +363,6 @@ class LTXVAddGuide(io.ComfyNode): guide_latent_shape = list(t.shape[2:]) # [F, H, W] positive, negative = _append_guide_attention_entry( positive, negative, pre_filter_count, guide_latent_shape, strength=strength, - latent_start=latent_idx, ) return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})