LTX2 context windows - Refactor guide logic from context_windows into LTXAV model hooks

This commit is contained in:
ozbayb 2026-04-06 11:44:14 -06:00
parent 350237618d
commit 874690c01c
2 changed files with 76 additions and 63 deletions

View File

@ -1,5 +1,5 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Any, Callable
import torch
import numpy as np
import collections
@ -181,6 +181,12 @@ def _compute_guide_overlap(guide_entries, window_index_list):
return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices)
@dataclass
class WindowingContext:
    """Result of a model's prepare_for_windowing hook.

    Splits the primary latent into the portion that context windows are
    computed over and an optional trailing suffix (e.g. guide frames) that
    is stripped before windowing and re-appended after accumulation.
    """
    # Video-only portion of the primary latent; context windows are built over this.
    tensor: torch.Tensor
    # Frames stripped from the end of the window dim, re-appended to the final
    # output after accumulation; None when the model stripped nothing.
    suffix: torch.Tensor | None
    # Opaque model-specific payload handed back to prepare_window_input for
    # each window (e.g. guide entries / guide suffix for LTXAV).
    aux_data: Any
@dataclass
class ContextSchedule:
name: str
@ -242,18 +248,6 @@ class IndexListContextHandler(ContextHandlerABC):
if 'latent_shapes' in model_conds:
model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes)
def _get_guide_entries(self, conds):
    """Return the first non-empty guide_attention_entries cond list, or None."""
    flattened = (cd for cl in conds if cl is not None for cd in cl)
    for cond_dict in flattened:
        entries = cond_dict.get('model_conds', {}).get('guide_attention_entries')
        # Accept only objects that expose a truthy .cond payload.
        if entries is not None and getattr(entries, 'cond', None):
            return entries.cond
    return None
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
latent_shapes = self._get_latent_shapes(conds)
primary = self._decompose(x_in, latent_shapes)[0]
@ -379,24 +373,19 @@ class IndexListContextHandler(ContextHandlerABC):
is_multimodal = len(modalities) > 1
primary = modalities[0]
# Separate guide frames from primary modality (guides are appended at the end)
guide_count = model.get_guide_frame_count(primary, conds) if hasattr(model, 'get_guide_frame_count') else 0
if guide_count > 0:
video_len = primary.size(self.dim) - guide_count
video_primary = primary.narrow(self.dim, 0, video_len)
guide_suffix = primary.narrow(self.dim, video_len, guide_count)
else:
video_primary = primary
guide_suffix = None
# Let model strip auxiliary frames (e.g. guide frames)
window_data = model.prepare_for_windowing(primary, conds, self.dim)
video_primary = window_data.tensor
aux_count = window_data.suffix.size(self.dim) if window_data.suffix is not None else 0
# Windows from video portion only (excluding guide frames)
# Windows from video portion only
context_windows = self.get_context_windows(model, video_primary, model_options)
enumerated_context_windows = list(enumerate(context_windows))
total_windows = len(enumerated_context_windows)
# Accumulators sized to video portion for primary, full for other modalities
accum_modalities = list(modalities)
if guide_suffix is not None:
if window_data.suffix is not None:
accum_modalities[0] = video_primary
accum = [[torch.zeros_like(m) for _ in conds] for m in accum_modalities]
@ -406,25 +395,22 @@ class IndexListContextHandler(ContextHandlerABC):
counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_modalities]
biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in accum_modalities]
guide_entries = self._get_guide_entries(conds) if guide_count > 0 else None
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
callback(self, model, x_in, conds, timestep, model_options)
for window_idx, window in enumerated_context_windows:
comfy.model_management.throw_exception_if_processing_interrupted()
logging.info(f"Context window {window_idx + 1}/{total_windows}: frames {window.index_list[0]}-{window.index_list[-1]} of {video_primary.shape[self.dim]}"
+ (f" (+{guide_count} guide)" if guide_count > 0 else "")
+ (f" (+{aux_count} aux)" if aux_count > 0 else "")
+ (f" [{len(modalities)} modalities]" if is_multimodal else ""))
# Per-modality window indices
if is_multimodal:
# Adjust latent_shapes so video shape reflects video-only frames (excludes guides)
map_shapes = latent_shapes
if guide_count > 0:
if video_primary.size(self.dim) != primary.size(self.dim):
map_shapes = list(latent_shapes)
video_shape = list(latent_shapes[0])
video_shape[self.dim] = video_shape[self.dim] - guide_count
video_shape[self.dim] = video_primary.size(self.dim)
map_shapes[0] = torch.Size(video_shape)
per_mod_indices = model.map_context_window_to_modalities(
window.index_list, map_shapes, self.dim) if hasattr(model, 'map_context_window_to_modalities') else [window.index_list]
@ -446,30 +432,10 @@ class IndexListContextHandler(ContextHandlerABC):
for mod_idx in range(1, len(modalities)):
mod_windows.append(modality_windows[mod_idx])
# Slice video, then select overlapping guide frames
# Slice video, then let model inject auxiliary frames
sliced_video = mod_windows[0].get_tensor(video_primary, retain_index_list=self.cond_retain_index_list)
num_guide_in_window = 0
if guide_suffix is not None and guide_entries is not None:
overlap = _compute_guide_overlap(guide_entries, window.index_list)
if overlap[3] > 0:
suffix_idx, overlap_info, kf_local_pos, num_guide_in_window = overlap
idx = tuple([slice(None)] * self.dim + [suffix_idx])
sliced_guide = guide_suffix[idx]
window.guide_suffix_indices = suffix_idx
window.guide_overlap_info = overlap_info
window.guide_kf_local_positions = kf_local_pos
else:
sliced_guide = None
window.guide_suffix_indices = []
window.guide_overlap_info = []
window.guide_kf_local_positions = []
else:
sliced_guide = None
if sliced_guide is not None:
sliced_primary = torch.cat([sliced_video, sliced_guide], dim=self.dim)
else:
sliced_primary = sliced_video
sliced_primary, num_aux = model.prepare_window_input(
sliced_video, window, window_data.aux_data, self.dim)
sliced = [sliced_primary] + [mod_windows[mi].get_tensor(modalities[mi]) for mi in range(1, len(modalities))]
# Compose for pipeline
@ -481,7 +447,6 @@ class IndexListContextHandler(ContextHandlerABC):
model_options["transformer_options"]["context_window"] = window
sub_timestep = window.get_tensor(timestep, dim=0)
# Resize conds using video_primary as reference (excludes guide frames)
sub_conds = [self.get_resized_cond(cond, video_primary, window) for cond in conds]
if is_multimodal:
self._patch_latent_shapes(sub_conds, sub_shapes)
@ -490,14 +455,12 @@ class IndexListContextHandler(ContextHandlerABC):
# Decompose output per modality
out_per_mod = [self._decompose(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))]
# out_per_mod[cond_idx][mod_idx] = tensor
# Strip guide frames from primary output before accumulation
if num_guide_in_window > 0:
# Strip auxiliary frames from primary output before accumulation
if num_aux > 0:
window_len = len(window.index_list)
for ci in range(len(sub_conds_out)):
primary_out = out_per_mod[ci][0]
out_per_mod[ci][0] = primary_out.narrow(self.dim, 0, window_len)
out_per_mod[ci][0] = out_per_mod[ci][0].narrow(self.dim, 0, window_len)
# Accumulate per modality (using video-only sizes)
for mod_idx in range(len(accum_modalities)):
@ -516,10 +479,9 @@ class IndexListContextHandler(ContextHandlerABC):
if self.fuse_method.name != ContextFuseMethods.RELATIVE:
accum[mod_idx][ci] /= counts[mod_idx][ci]
f = accum[mod_idx][ci]
# Re-append original guide_suffix (not model output — sampling loop
# respects denoise_mask and never modifies guide frame positions)
if mod_idx == 0 and guide_suffix is not None:
f = torch.cat([f, guide_suffix], dim=self.dim)
# Re-append model's suffix (auxiliary frames stripped before windowing)
if mod_idx == 0 and window_data.suffix is not None:
f = torch.cat([f, window_data.suffix], dim=self.dim)
finalized.append(f)
composed, _ = self._compose(finalized)
result.append(composed)

View File

@ -287,6 +287,12 @@ class BaseModel(torch.nn.Module):
return data
return None
def prepare_for_windowing(self, primary, conds, dim):
    """Default hook: strip nothing; window over the full primary latent.

    Subclasses (e.g. LTXAV) override this to split auxiliary frames such as
    guide frames off the end of ``dim`` before windowing.
    """
    ctx = comfy.context_windows.WindowingContext(tensor=primary, suffix=None, aux_data=None)
    return ctx
def prepare_window_input(self, video_slice, window, aux_data, dim):
    """Default hook: inject no auxiliary frames.

    Returns the window's video slice unchanged together with an auxiliary
    frame count of 0. Subclasses override this to append model-specific
    frames (e.g. overlapping guide frames) to each window's input.
    """
    return (video_slice, 0)
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
"""Override in subclasses to handle model-specific cond slicing for context windows.
Return a sliced cond object, or None to fall through to default handling.
@ -1113,6 +1119,51 @@ class LTXAV(BaseModel):
return sum(e["latent_shape"][0] for e in gae.cond)
return 0
@staticmethod
def _get_guide_entries(conds):
    """Return the first non-empty guide_attention_entries cond list, or None."""
    flattened = (cd for cl in conds if cl is not None for cd in cl)
    for cond_dict in flattened:
        entries = cond_dict.get('model_conds', {}).get('guide_attention_entries')
        # Accept only objects that expose a truthy .cond payload.
        if entries is not None and getattr(entries, 'cond', None):
            return entries.cond
    return None
def prepare_for_windowing(self, primary, conds, dim):
    """Split trailing guide frames off the primary latent before windowing.

    Returns a WindowingContext whose ``tensor`` is the video-only portion,
    whose ``suffix`` holds the guide frames stripped from the end of
    ``dim`` (None when there are no guides), and whose ``aux_data`` carries
    what prepare_window_input needs to re-inject overlapping guides.
    """
    n_guide = self.get_guide_frame_count(primary, conds)
    if n_guide <= 0:
        # No guide frames appended: window over the full latent.
        return comfy.context_windows.WindowingContext(tensor=primary, suffix=None, aux_data=None)
    split_at = primary.size(dim) - n_guide
    video_part = primary.narrow(dim, 0, split_at)
    suffix_part = primary.narrow(dim, split_at, n_guide)
    aux = {"guide_entries": self._get_guide_entries(conds), "guide_suffix": suffix_part}
    return comfy.context_windows.WindowingContext(tensor=video_part, suffix=suffix_part, aux_data=aux)
def prepare_window_input(self, video_slice, window, aux_data, dim):
    """Append the guide frames that overlap this window onto its video slice.

    Records per-window guide bookkeeping on ``window`` and returns
    ``(window_input, num_guide_frames_appended)``. When no guides apply,
    the slice is returned unchanged with a count of 0.
    """
    if aux_data is None:
        return video_slice, 0
    entries = aux_data["guide_entries"]
    suffix = aux_data["guide_suffix"]
    if entries is None:
        # Guides were stripped but no attention entries exist: record empties.
        window.guide_suffix_indices = []
        window.guide_overlap_info = []
        window.guide_kf_local_positions = []
        return video_slice, 0
    suffix_idx, overlap_info, kf_local_pos, num_guide = comfy.context_windows._compute_guide_overlap(
        entries, window.index_list)
    window.guide_suffix_indices = suffix_idx
    window.guide_overlap_info = overlap_info
    window.guide_kf_local_positions = kf_local_pos
    if num_guide <= 0:
        return video_slice, 0
    # Select only the overlapping guide frames along `dim` and append them.
    selector = tuple([slice(None)] * dim + [suffix_idx])
    return torch.cat([video_slice, suffix[selector]], dim=dim), num_guide
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
# Audio denoise mask — slice using audio modality window
if cond_key == "audio_denoise_mask" and hasattr(window, 'modality_windows') and window.modality_windows: