LTX2 context windows - Clean up unnecessary code

This commit is contained in:
ozbayb 2026-04-07 13:00:38 -06:00
parent c9edd2d7c0
commit d5badc5f38
2 changed files with 18 additions and 29 deletions

View File

@@ -182,22 +182,22 @@ def compute_guide_overlap(guide_entries: list[dict], window_index_list: list[int
def inject_guide_frames_into_window(video_slice: torch.Tensor, window: ContextWindowABC,
aux_data: dict, dim: int) -> tuple[torch.Tensor, int]:
window_data: 'WindowingContext', dim: int) -> tuple[torch.Tensor, int]:
"""Inject overlapping guide frames into a context window slice.
Uses aux_data from WindowingContext to determine which guide frames overlap
with this window's indices, concatenates them onto the video slice, and sets
window attributes for downstream conditioning resize.
Determines which guide frames overlap with this window's indices, concatenates
them onto the video slice, and sets window attributes for downstream conditioning resize.
Returns (augmented_slice, num_guide_frames_added).
"""
guide_entries = aux_data["guide_entries"]
guide_frames = aux_data["guide_frames"]
guide_entries = window_data.aux_data["guide_entries"]
guide_frames = window_data.guide_frames
overlap = compute_guide_overlap(guide_entries, window.index_list)
suffix_idx, overlap_info, kf_local_pos, guide_frame_count = overlap
window.guide_frames_indices = suffix_idx
window.guide_overlap_info = overlap_info
window.guide_kf_local_positions = kf_local_pos
# Derive per-overlap-entry latent_downscale_factor from guide entry latent_shape vs guide frame spatial dims.
# guide_frames has full (post-dilation) spatial dims; entry["latent_shape"] has pre-dilation dims.
guide_downscale_factors = []
@@ -207,6 +207,7 @@ def inject_guide_frames_into_window(video_slice: torch.Tensor, window: ContextWi
entry_H = guide_entries[entry_idx]["latent_shape"][1]
guide_downscale_factors.append(full_H // entry_H)
window.guide_downscale_factors = guide_downscale_factors
if guide_frame_count > 0:
idx = tuple([slice(None)] * dim + [suffix_idx])
sliced_guide = guide_frames[idx]
@@ -220,7 +221,6 @@ class WindowingContext:
guide_frames: torch.Tensor | None
aux_data: Any
latent_shapes: list | None
is_multimodal: bool
@dataclass
class ContextSchedule:
@@ -310,13 +310,13 @@ class IndexListContextHandler(ContextHandlerABC):
guide_frames = video_latent.narrow(self.dim, primary_frame_count, guide_frame_count) if guide_frame_count > 0 else None
if guide_frame_count > 0:
aux_data = {"guide_entries": guide_entries, "guide_frames": guide_frames}
aux_data = {"guide_entries": guide_entries}
else:
aux_data = None
return WindowingContext(
tensor=primary_frames, guide_frames=guide_frames, aux_data=aux_data,
latent_shapes=latent_shapes, is_multimodal=is_multimodal)
latent_shapes=latent_shapes)
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
self._window_data = self._build_window_data(x_in, conds)
@@ -437,9 +437,14 @@ class IndexListContextHandler(ContextHandlerABC):
self.set_step(timestep, model_options)
window_data = self._window_data
if window_data.is_multimodal or (window_data.guide_frames is not None and window_data.guide_frames.size(self.dim) > 0):
is_multimodal = window_data.latent_shapes is not None and len(window_data.latent_shapes) > 1
has_guide_frames = window_data.guide_frames is not None and window_data.guide_frames.size(self.dim) > 0
# if multimodal or has concatenated guide frames on noise latent, use the extended execute path
if is_multimodal or has_guide_frames:
return self._execute_extended(calc_cond_batch, model, conds, x_in, timestep, model_options, window_data)
# basic legacy execution path for single-modal video latent with no guide frames concatenated
context_windows = self.get_context_windows(model, x_in, model_options)
enumerated_context_windows = list(enumerate(context_windows))
@@ -475,8 +480,9 @@ class IndexListContextHandler(ContextHandlerABC):
timestep: torch.Tensor, model_options: dict[str],
window_data: WindowingContext):
"""Extended execute path for multimodal models and models with guide frames appended to the noise latent."""
latents = self._unpack(x_in, window_data.latent_shapes)
is_multimodal = window_data.is_multimodal
is_multimodal = window_data.latent_shapes is not None and len(window_data.latent_shapes) > 1
primary_frames = window_data.tensor
num_guide_frames = window_data.guide_frames.size(self.dim) if window_data.guide_frames is not None else 0
@@ -538,7 +544,7 @@ class IndexListContextHandler(ContextHandlerABC):
# Slice video, then inject overlapping guide frames if present
sliced_video = per_modality_windows_list[0].get_tensor(primary_frames, retain_index_list=self.cond_retain_index_list)
if window_data.aux_data is not None:
sliced_primary, num_guide_frames = inject_guide_frames_into_window(sliced_video, window, window_data.aux_data, self.dim)
sliced_primary, num_guide_frames = inject_guide_frames_into_window(sliced_video, window, window_data, self.dim)
else:
sliced_primary, num_guide_frames = sliced_video, 0
sliced = [sliced_primary] + [per_modality_windows_list[mi].get_tensor(latents[mi]) for mi in range(1, len(latents))]

View File

@@ -1113,23 +1113,6 @@ class LTXAV(BaseModel):
if entries is not None and hasattr(entries, 'cond') and entries.cond:
return entries.cond
return None
def prepare_window_data(self, x_in, conds, dim, window_data):
primary = comfy.utils.unpack_latents(x_in, window_data.latent_shapes)[0] if window_data.is_multimodal else x_in
guide_entries = self._get_guide_entries(conds)
guide_count = sum(e["latent_shape"][0] for e in guide_entries) if guide_entries else 0
if guide_count <= 0:
return comfy.context_windows.WindowingContext(
tensor=primary, guide_frames=None, aux_data=None,
latent_shapes=window_data.latent_shapes, is_multimodal=window_data.is_multimodal)
video_len = primary.size(dim) - guide_count
video_primary = primary.narrow(dim, 0, video_len)
guide_frames = primary.narrow(dim, video_len, guide_count)
return comfy.context_windows.WindowingContext(
tensor=video_primary, guide_frames=guide_frames,
aux_data={"guide_entries": guide_entries, "guide_frames": guide_frames},
latent_shapes=window_data.latent_shapes, is_multimodal=window_data.is_multimodal)
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
# Audio denoise mask — slice using audio modality window