Update compute_guide_overlap to take keyframe indices and the temporal downscale ratio instead of relying on the latent_start key previously added to guide_attention_entry metadata.

- This makes it possible for existing nodes like the ComfyUI-LTXVideo nodes to work with context windows without needing modifications
This commit is contained in:
ozbayb 2026-05-05 22:46:52 -06:00
parent cbe2a1ba42
commit 75e8e4b6dc
2 changed files with 38 additions and 12 deletions
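The core of the change is a derivation: instead of trusting a latent_start value stored on each entry, compute_guide_overlap now reads the guide's first-token pixel-space t-start out of keyframe_idxs and ceil-divides it by the model's temporal compression ratio. A minimal standalone sketch of that derivation, mirroring the new code in the first hunk below (the helper name is illustrative):

import torch

def derive_latent_start(keyframe_idxs: torch.Tensor, token_offset: int,
                        temporal_downscale_ratio: int) -> int:
    # keyframe_idxs has shape (B, [t, h, w], num_tokens, [start, end]);
    # [0, 0, token_offset, 0] is the pixel-space t-start of the entry's first token.
    first_t_pixel = int(keyframe_idxs[0, 0, token_offset, 0].item())
    # ceil-divide to map pixel frames onto latent frames
    return (first_t_pixel + temporal_downscale_ratio - 1) // temporal_downscale_ratio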

View File

@@ -161,11 +161,17 @@ def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, d
     return cond_value._copy_with(sliced)
 
-def compute_guide_overlap(guide_entries: list[dict], window_index_list: list[int]):
+def compute_guide_overlap(guide_entries: list[dict], keyframe_idxs: torch.Tensor, temporal_downscale_ratio: int, window_index_list: list[int]):
     """Compute which concatenated guide frames overlap with a context window.
 
+    Each guide's latent-space start is derived from its first token's pixel-t-start
+    in keyframe_idxs (shape (B, [t,h,w], num_tokens, [start, end])), divided by the
+    model's temporal_downscale_ratio.
+
     Args:
         guide_entries: list of guide_attention_entry dicts
+        keyframe_idxs: per-token pixel coords cond tensor for the modality
+        temporal_downscale_ratio: model's pixel-to-latent temporal compression ratio
         window_index_list: the window's frame indices into the video portion
 
     Returns:
@@ -180,11 +186,11 @@ def compute_guide_overlap(guide_entries: list[dict], window_index_list: list[int
     overlap_info = []
     kf_local_positions = []
     suffix_base = 0
+    token_offset = 0
     for entry_idx, entry in enumerate(guide_entries):
-        latent_start = entry.get("latent_start", None)
-        if latent_start is None:
-            raise ValueError("guide_attention_entry missing required 'latent_start'.")
+        first_t_pixel = int(keyframe_idxs[0, 0, token_offset, 0].item())
+        latent_start = (first_t_pixel + temporal_downscale_ratio - 1) // temporal_downscale_ratio
         guide_len = entry["latent_shape"][0]
 
         entry_overlap = 0
@@ -198,6 +204,7 @@ def compute_guide_overlap(guide_entries: list[dict], window_index_list: list[int
         if entry_overlap > 0:
             overlap_info.append((entry_idx, entry_overlap))
         suffix_base += guide_len
+        token_offset += entry["pre_filter_count"]
 
     return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices)
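The token_offset bookkeeping above depends on each entry's pre_filter_count recording how many tokens that guide contributes to keyframe_idxs, so the running offset always lands on the first token of the next guide. A toy walk-through with invented values:

# Hypothetical: two guide entries; counts are illustrative only.
entries = [
    {"latent_shape": [2, 16, 16], "pre_filter_count": 512},
    {"latent_shape": [1, 16, 16], "pre_filter_count": 256},
]
token_offset = 0
for entry in entries:
    # compute_guide_overlap reads keyframe_idxs[0, 0, token_offset, 0] here
    print(f"entry's first token is at offset {token_offset}")
    token_offset += entry["pre_filter_count"]
# prints offsets 0 and 512; a third guide would start at 768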
@@ -211,9 +218,11 @@ class WindowingState:
     latents: list[torch.Tensor]  # per-modality working latents (guide frames stripped)
     guide_latents: list[torch.Tensor | None]  # per-modality guide frames stripped from latents
     guide_entries: list[list[dict] | None]  # per-modality guide_attention_entry metadata
+    keyframe_idxs: list[torch.Tensor | None]  # per-modality keyframe_idxs tensor for guide latent_start derivation
     latent_shapes: list | None  # original packed shapes for unpack/pack (None if not multimodal)
     dim: int = 0  # primary modality temporal dim for context windowing
     is_multimodal: bool = False
+    temporal_downscale_ratio: int = 1  # model's pixel-to-latent temporal compression ratio
 
     def prepare_window(self, window: IndexListContextWindow, model) -> IndexListContextWindow:
         """Reformat window for multimodal contexts by deriving per-modality index lists.
@@ -279,7 +288,9 @@ class WindowingState:
     def _inject_guide_frames(self, latent_slice: torch.Tensor, window: IndexListContextWindow, modality_idx: int = 0) -> tuple[torch.Tensor, int]:
         guide_entries = self.guide_entries[modality_idx]
         guide_frames = self.guide_latents[modality_idx]
-        suffix_idx, overlap_info, kf_local_pos, guide_frame_count = compute_guide_overlap(guide_entries, window.index_list)
+        keyframe_idxs = self.keyframe_idxs[modality_idx]
+        suffix_idx, overlap_info, kf_local_pos, guide_frame_count = compute_guide_overlap(
+            guide_entries, keyframe_idxs, self.temporal_downscale_ratio, window.index_list)
         # Shift keyframe positions to account for causal_window_fix anchor occupying sub-pos 0.
         anchor_idx = getattr(window, 'causal_anchor_index', None)
         if anchor_idx is not None and anchor_idx >= 0:
@@ -370,6 +381,18 @@ class IndexListContextHandler(ContextHandlerABC):
                 return entries.cond
         return None
 
+    @staticmethod
+    def _get_keyframe_idxs(conds):
+        for cond_list in conds:
+            if cond_list is None:
+                continue
+            for cond_dict in cond_list:
+                model_conds = cond_dict.get('model_conds', {})
+                kf = model_conds.get('keyframe_idxs')
+                if kf is not None and hasattr(kf, 'cond') and kf.cond is not None:
+                    return kf.cond
+        return None
+
     def _apply_freenoise(self, noise: torch.Tensor, conds: list[list[dict]], seed: int) -> torch.Tensor:
         """Apply FreeNoise shuffling, scaling context length/overlap per-modality by frame ratio.
 
         If guide frames are present on the primary modality, only the video portion is shuffled.
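_get_keyframe_idxs only needs duck typing: any object exposing a non-None .cond under model_conds['keyframe_idxs'] will match. A self-contained sketch of the structure it traverses (SimpleNamespace stands in for ComfyUI's cond wrapper object, which is an assumption here):

from types import SimpleNamespace
import torch

kf_tensor = torch.zeros(1, 3, 4, 2, dtype=torch.long)  # (B, [t,h,w], num_tokens, [start,end])
conds = [
    [{"model_conds": {"keyframe_idxs": SimpleNamespace(cond=kf_tensor)}}],  # e.g. positive
    None,  # cond lists without entries are skipped
]
# the loops in the method above would return kf_tensor from the first matching entry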
@@ -395,7 +418,7 @@ class IndexListContextHandler(ContextHandlerABC):
             apply_freenoise(noise.narrow(self.dim, 0, video_count), self.dim, self.context_length, self.context_overlap, seed)
         return noise
 
-    def _build_window_state(self, x_in: torch.Tensor, conds: list[list[dict]]) -> WindowingState:
+    def _build_window_state(self, x_in: torch.Tensor, conds: list[list[dict]], model: BaseModel) -> WindowingState:
         """Build windowing state for the current step, including unpacking latents and extracting guide frame info from conds."""
         latent_shapes = self._get_latent_shapes(conds)
         is_multimodal = latent_shapes is not None and len(latent_shapes) > 1
@@ -404,8 +427,10 @@ class IndexListContextHandler(ContextHandlerABC):
         unpacked_latents_list = list(unpacked_latents)
         guide_latents_list = [None] * len(unpacked_latents)
         guide_entries_list = [None] * len(unpacked_latents)
+        keyframe_idxs_list = [None] * len(unpacked_latents)
         extracted_guide_entries = self._get_guide_entries(conds)
+        extracted_keyframe_idxs = self._get_keyframe_idxs(conds)
 
         # Strip guide frames (only from first modality for now)
         if extracted_guide_entries is not None:
@@ -416,18 +441,21 @@ class IndexListContextHandler(ContextHandlerABC):
                 unpacked_latents_list[0] = x.narrow(self.dim, 0, latent_count)
                 guide_latents_list[0] = x.narrow(self.dim, latent_count, guide_count)
                 guide_entries_list[0] = extracted_guide_entries
+                keyframe_idxs_list[0] = extracted_keyframe_idxs
 
         return WindowingState(
             latents=unpacked_latents_list,
             guide_latents=guide_latents_list,
             guide_entries=guide_entries_list,
+            keyframe_idxs=keyframe_idxs_list,
             latent_shapes=latent_shapes,
             dim=self.dim,
-            is_multimodal=is_multimodal)
+            is_multimodal=is_multimodal,
+            temporal_downscale_ratio=model.latent_format.temporal_downscale_ratio)
 
     def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
-        window_state = self._build_window_state(x_in, conds)  # build window_state to check frame counts, will be built again in execute
+        window_state = self._build_window_state(x_in, conds, model)  # build window_state to check frame counts, will be built again in execute
         total_frame_count = window_state.latents[0].size(self.dim)
         if total_frame_count > self.context_length:
             logging.info(f"\nUsing context windows: Context length {self.context_length} with overlap {self.context_overlap} for {total_frame_count} frames.")
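The model argument threaded into _build_window_state exists only so the state can capture one attribute, model.latent_format.temporal_downscale_ratio. A stub makes the dependency explicit (values hypothetical; LTXV-style video VAEs compress time on the order of 1:8):

from types import SimpleNamespace

model = SimpleNamespace(latent_format=SimpleNamespace(temporal_downscale_ratio=8))
ratio = model.latent_format.temporal_downscale_ratio  # 1 latent frame per 8 pixel frames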
@@ -545,7 +573,7 @@ class IndexListContextHandler(ContextHandlerABC):
         self._model = model
         self.set_step(timestep, model_options)
 
-        window_state = self._build_window_state(x_in, conds)
+        window_state = self._build_window_state(x_in, conds, model)
         num_modalities = len(window_state.latents)
         context_windows = self.get_context_windows(model, window_state.latents[0], model_options)

View File

@@ -136,7 +136,7 @@ class LTXVImgToVideoInplace(io.ComfyNode):
     generate = execute  # TODO: remove
 
-def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0, latent_start=0):
+def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0):
     """Append a guide_attention_entry to both positive and negative conditioning.
 
     Each entry tracks one guide reference for per-reference attention control.
@@ -147,7 +147,6 @@ def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_s
         "strength": strength,
         "pixel_mask": None,
         "latent_shape": latent_shape,
-        "latent_start": latent_start,
     }
     results = []
     for cond in (positive, negative):
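With latent_start gone, a guide_attention_entry carries only metadata that is stable across context windows. A representative dict (values illustrative; pre_filter_count is presumably set earlier in the same dict literal, outside the lines this hunk shows, since compute_guide_overlap reads entry["pre_filter_count"]):

entry = {
    "pre_filter_count": 512,       # tokens this guide contributes to keyframe_idxs
    "strength": 1.0,
    "pixel_mask": None,
    "latent_shape": [2, 16, 16],   # [F, H, W] in latent space
}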
@@ -364,7 +363,6 @@ class LTXVAddGuide(io.ComfyNode):
         guide_latent_shape = list(t.shape[2:])  # [F, H, W]
         positive, negative = _append_guide_attention_entry(
             positive, negative, pre_filter_count, guide_latent_shape, strength=strength,
-            latent_start=latent_idx,
         )
 
         return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})