From d5badc5f380729cd3fbef6a3742d06df3b36419f Mon Sep 17 00:00:00 2001
From: ozbayb <17261091+ozbayb@users.noreply.github.com>
Date: Tue, 7 Apr 2026 13:00:38 -0600
Subject: [PATCH] LTX2 context windows - Clean up unnecessary code

---
 comfy/context_windows.py | 30 ++++++++++++++++++------------
 comfy/model_base.py      | 17 -----------------
 2 files changed, 18 insertions(+), 29 deletions(-)

diff --git a/comfy/context_windows.py b/comfy/context_windows.py
index f955d4b67..6e21bdc81 100644
--- a/comfy/context_windows.py
+++ b/comfy/context_windows.py
@@ -182,22 +182,22 @@ def compute_guide_overlap(guide_entries: list[dict], window_index_list: list[int


 def inject_guide_frames_into_window(video_slice: torch.Tensor, window: ContextWindowABC,
-                                    aux_data: dict, dim: int) -> tuple[torch.Tensor, int]:
+                                    window_data: 'WindowingContext', dim: int) -> tuple[torch.Tensor, int]:
     """Inject overlapping guide frames into a context window slice.

-    Uses aux_data from WindowingContext to determine which guide frames overlap
-    with this window's indices, concatenates them onto the video slice, and sets
-    window attributes for downstream conditioning resize.
+    Determines which guide frames overlap with this window's indices, concatenates
+    them onto the video slice, and sets window attributes for downstream conditioning resize.

     Returns (augmented_slice, num_guide_frames_added).
     """
-    guide_entries = aux_data["guide_entries"]
-    guide_frames = aux_data["guide_frames"]
+    guide_entries = window_data.aux_data["guide_entries"]
+    guide_frames = window_data.guide_frames
     overlap = compute_guide_overlap(guide_entries, window.index_list)
     suffix_idx, overlap_info, kf_local_pos, guide_frame_count = overlap
     window.guide_frames_indices = suffix_idx
     window.guide_overlap_info = overlap_info
     window.guide_kf_local_positions = kf_local_pos
+
     # Derive per-overlap-entry latent_downscale_factor from guide entry latent_shape vs guide frame spatial dims.
     # guide_frames has full (post-dilation) spatial dims; entry["latent_shape"] has pre-dilation dims.
     guide_downscale_factors = []
@@ -207,6 +207,7 @@ def inject_guide_frames_into_window(video_slice: torch.Tensor, window: ContextWi
         entry_H = guide_entries[entry_idx]["latent_shape"][1]
         guide_downscale_factors.append(full_H // entry_H)
     window.guide_downscale_factors = guide_downscale_factors
+
     if guide_frame_count > 0:
         idx = tuple([slice(None)] * dim + [suffix_idx])
         sliced_guide = guide_frames[idx]
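# --- Reviewer sketch (illustration only; not part of the patch) --------------
# A minimal model of the injection pattern above: guide frames whose source
# indices overlap the current window are picked out and concatenated onto the
# window's video slice along the frame dimension. The list-membership overlap
# rule and the tensor shapes below are hypothetical stand-ins for what
# compute_guide_overlap() actually derives from guide_entries.
import torch

def inject_guides_sketch(video_slice, guide_frames, guide_indices, window_indices, dim):
    # keep only guide frames whose source frame index falls inside this window
    suffix_idx = [i for i, g in enumerate(guide_indices) if g in window_indices]
    if not suffix_idx:
        return video_slice, 0
    idx = tuple([slice(None)] * dim + [suffix_idx])  # same slicing trick as the patch
    return torch.cat([video_slice, guide_frames[idx]], dim=dim), len(suffix_idx)

# e.g. a [B, C, T, H, W] latent windowed on the frame dim (dim=2), two guides
window_slice = torch.zeros(1, 4, 8, 16, 16)
guides = torch.ones(1, 4, 2, 16, 16)
out, n = inject_guides_sketch(window_slice, guides, guide_indices=[5, 40],
                              window_indices=list(range(8)), dim=2)
assert n == 1 and out.size(2) == 8 + n  # one overlapping guide frame appended
# --- end sketch ---------------------------------------------------------------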
@@ -220,7 +221,6 @@ class WindowingContext:
     guide_frames: torch.Tensor | None
     aux_data: Any
     latent_shapes: list | None
-    is_multimodal: bool

 @dataclass
 class ContextSchedule:
@@ -310,13 +310,13 @@ class IndexListContextHandler(ContextHandlerABC):
         guide_frames = video_latent.narrow(self.dim, primary_frame_count, guide_frame_count) if guide_frame_count > 0 else None

         if guide_frame_count > 0:
-            aux_data = {"guide_entries": guide_entries, "guide_frames": guide_frames}
+            aux_data = {"guide_entries": guide_entries}
         else:
             aux_data = None
         return WindowingContext(
             tensor=primary_frames, guide_frames=guide_frames, aux_data=aux_data,
-            latent_shapes=latent_shapes, is_multimodal=is_multimodal)
+            latent_shapes=latent_shapes)

     def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
         self._window_data = self._build_window_data(x_in, conds)
@@ -437,9 +437,14 @@
         self.set_step(timestep, model_options)

         window_data = self._window_data
-        if window_data.is_multimodal or (window_data.guide_frames is not None and window_data.guide_frames.size(self.dim) > 0):
+        is_multimodal = window_data.latent_shapes is not None and len(window_data.latent_shapes) > 1
+        has_guide_frames = window_data.guide_frames is not None and window_data.guide_frames.size(self.dim) > 0
+
+        # if multimodal or has concatenated guide frames on noise latent, use the extended execute path
+        if is_multimodal or has_guide_frames:
             return self._execute_extended(calc_cond_batch, model, conds, x_in, timestep, model_options, window_data)

+        # basic legacy execution path for single-modal video latent with no guide frames concatenated
         context_windows = self.get_context_windows(model, x_in, model_options)
         enumerated_context_windows = list(enumerate(context_windows))
@@ -475,8 +480,9 @@
                           timestep: torch.Tensor, model_options: dict[str], window_data: WindowingContext):
         """Extended execute path for multimodal models and models with guide frames appended to the noise latent."""
+
         latents = self._unpack(x_in, window_data.latent_shapes)
-        is_multimodal = window_data.is_multimodal
+        is_multimodal = window_data.latent_shapes is not None and len(window_data.latent_shapes) > 1
         primary_frames = window_data.tensor
         num_guide_frames = window_data.guide_frames.size(self.dim) if window_data.guide_frames is not None else 0
@@ -538,7 +544,7 @@
             # Slice video, then inject overlapping guide frames if present
             sliced_video = per_modality_windows_list[0].get_tensor(primary_frames, retain_index_list=self.cond_retain_index_list)
             if window_data.aux_data is not None:
-                sliced_primary, num_guide_frames = inject_guide_frames_into_window(sliced_video, window, window_data.aux_data, self.dim)
+                sliced_primary, num_guide_frames = inject_guide_frames_into_window(sliced_video, window, window_data, self.dim)
             else:
                 sliced_primary, num_guide_frames = sliced_video, 0
             sliced = [sliced_primary] + [per_modality_windows_list[mi].get_tensor(latents[mi]) for mi in range(1, len(latents))]
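# --- Reviewer sketch (illustration only; not part of the patch) --------------
# With is_multimodal dropped from WindowingContext, both call sites now derive
# it from latent_shapes: more than one packed shape means extra modalities
# (e.g. video + audio). A runnable model of that invariant, with hypothetical
# shapes standing in for real LTX2 latents:
import torch
from dataclasses import dataclass
from typing import Any

@dataclass
class WindowingContextSketch:
    tensor: torch.Tensor
    guide_frames: torch.Tensor | None
    aux_data: Any
    latent_shapes: list | None

def is_multimodal(wd):
    return wd.latent_shapes is not None and len(wd.latent_shapes) > 1

def has_guide_frames(wd, dim):
    return wd.guide_frames is not None and wd.guide_frames.size(dim) > 0

video_only = WindowingContextSketch(torch.zeros(1, 4, 8, 16, 16), None, None,
                                    latent_shapes=[(1, 4, 8, 16, 16)])
audio_video = WindowingContextSketch(torch.zeros(1, 4, 8, 16, 16), None, None,
                                     latent_shapes=[(1, 4, 8, 16, 16), (1, 8, 128)])
assert not is_multimodal(video_only) and is_multimodal(audio_video)
# execute() takes the extended path iff either predicate holds
assert not (is_multimodal(video_only) or has_guide_frames(video_only, dim=2))
# --- end sketch ---------------------------------------------------------------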
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 65ce1bac5..8960ecd19 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1113,23 +1113,6 @@ class LTXAV(BaseModel):
         if entries is not None and hasattr(entries, 'cond') and entries.cond:
             return entries.cond
         return None
-
-    def prepare_window_data(self, x_in, conds, dim, window_data):
-        primary = comfy.utils.unpack_latents(x_in, window_data.latent_shapes)[0] if window_data.is_multimodal else x_in
-        guide_entries = self._get_guide_entries(conds)
-        guide_count = sum(e["latent_shape"][0] for e in guide_entries) if guide_entries else 0
-        if guide_count <= 0:
-            return comfy.context_windows.WindowingContext(
-                tensor=primary, guide_frames=None, aux_data=None,
-                latent_shapes=window_data.latent_shapes, is_multimodal=window_data.is_multimodal)
-        video_len = primary.size(dim) - guide_count
-        video_primary = primary.narrow(dim, 0, video_len)
-        guide_frames = primary.narrow(dim, video_len, guide_count)
-        return comfy.context_windows.WindowingContext(
-            tensor=video_primary, guide_frames=guide_frames,
-            aux_data={"guide_entries": guide_entries, "guide_frames": guide_frames},
-            latent_shapes=window_data.latent_shapes, is_multimodal=window_data.is_multimodal)
-
     def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
         # Audio denoise mask — slice using audio modality window
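# --- Reviewer sketch (illustration only; not part of the patch) --------------
# The deleted LTXAV.prepare_window_data duplicated the primary/guide split that
# IndexListContextHandler._build_window_data already performs: trailing guide
# frames appended to the noise latent are separated with two narrow() calls
# along the frame dim. Hypothetical sizes below:
import torch

dim, guide_count = 2, 3
latent = torch.randn(1, 4, 24 + guide_count, 16, 16)  # 24 video frames + 3 guide frames
video_len = latent.size(dim) - guide_count
primary_frames = latent.narrow(dim, 0, video_len)           # window-able video portion
guide_frames = latent.narrow(dim, video_len, guide_count)   # trailing guide frames
assert primary_frames.size(dim) == 24 and guide_frames.size(dim) == 3
# --- end sketch ---------------------------------------------------------------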