From 874690c01ca949b0479ce09bb64c37ff68b0e6e1 Mon Sep 17 00:00:00 2001 From: ozbayb <17261091+ozbayb@users.noreply.github.com> Date: Mon, 6 Apr 2026 11:44:14 -0600 Subject: [PATCH] LTX2 context windows - Refactor guide logic from context_windows into LTXAV model hooks --- comfy/context_windows.py | 88 ++++++++++++---------------------------- comfy/model_base.py | 51 +++++++++++++++++++++++ 2 files changed, 76 insertions(+), 63 deletions(-) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 295f348e6..9e7282fda 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Any, Callable import torch import numpy as np import collections @@ -181,6 +181,12 @@ def _compute_guide_overlap(guide_entries, window_index_list): return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices) +@dataclass +class WindowingContext: + tensor: torch.Tensor + suffix: torch.Tensor | None + aux_data: Any + @dataclass class ContextSchedule: name: str @@ -242,18 +248,6 @@ class IndexListContextHandler(ContextHandlerABC): if 'latent_shapes' in model_conds: model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes) - def _get_guide_entries(self, conds): - """Extract guide_attention_entries list from conditioning. Returns None if absent.""" - for cond_list in conds: - if cond_list is None: - continue - for cond_dict in cond_list: - model_conds = cond_dict.get('model_conds', {}) - gae = model_conds.get('guide_attention_entries') - if gae is not None and hasattr(gae, 'cond') and gae.cond: - return gae.cond - return None - def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool: latent_shapes = self._get_latent_shapes(conds) primary = self._decompose(x_in, latent_shapes)[0] @@ -379,24 +373,19 @@ class IndexListContextHandler(ContextHandlerABC): is_multimodal = len(modalities) > 1 primary = modalities[0] - # Separate guide frames from primary modality (guides are appended at the end) - guide_count = model.get_guide_frame_count(primary, conds) if hasattr(model, 'get_guide_frame_count') else 0 - if guide_count > 0: - video_len = primary.size(self.dim) - guide_count - video_primary = primary.narrow(self.dim, 0, video_len) - guide_suffix = primary.narrow(self.dim, video_len, guide_count) - else: - video_primary = primary - guide_suffix = None + # Let model strip auxiliary frames (e.g. guide frames) + window_data = model.prepare_for_windowing(primary, conds, self.dim) + video_primary = window_data.tensor + aux_count = window_data.suffix.size(self.dim) if window_data.suffix is not None else 0 - # Windows from video portion only (excluding guide frames) + # Windows from video portion only context_windows = self.get_context_windows(model, video_primary, model_options) enumerated_context_windows = list(enumerate(context_windows)) total_windows = len(enumerated_context_windows) # Accumulators sized to video portion for primary, full for other modalities accum_modalities = list(modalities) - if guide_suffix is not None: + if window_data.suffix is not None: accum_modalities[0] = video_primary accum = [[torch.zeros_like(m) for _ in conds] for m in accum_modalities] @@ -406,25 +395,22 @@ class IndexListContextHandler(ContextHandlerABC): counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_modalities] biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in accum_modalities] - guide_entries = self._get_guide_entries(conds) if guide_count > 0 else None - for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks): callback(self, model, x_in, conds, timestep, model_options) for window_idx, window in enumerated_context_windows: comfy.model_management.throw_exception_if_processing_interrupted() logging.info(f"Context window {window_idx + 1}/{total_windows}: frames {window.index_list[0]}-{window.index_list[-1]} of {video_primary.shape[self.dim]}" - + (f" (+{guide_count} guide)" if guide_count > 0 else "") + + (f" (+{aux_count} aux)" if aux_count > 0 else "") + (f" [{len(modalities)} modalities]" if is_multimodal else "")) # Per-modality window indices if is_multimodal: - # Adjust latent_shapes so video shape reflects video-only frames (excludes guides) map_shapes = latent_shapes - if guide_count > 0: + if video_primary.size(self.dim) != primary.size(self.dim): map_shapes = list(latent_shapes) video_shape = list(latent_shapes[0]) - video_shape[self.dim] = video_shape[self.dim] - guide_count + video_shape[self.dim] = video_primary.size(self.dim) map_shapes[0] = torch.Size(video_shape) per_mod_indices = model.map_context_window_to_modalities( window.index_list, map_shapes, self.dim) if hasattr(model, 'map_context_window_to_modalities') else [window.index_list] @@ -446,30 +432,10 @@ class IndexListContextHandler(ContextHandlerABC): for mod_idx in range(1, len(modalities)): mod_windows.append(modality_windows[mod_idx]) - # Slice video, then select overlapping guide frames + # Slice video, then let model inject auxiliary frames sliced_video = mod_windows[0].get_tensor(video_primary, retain_index_list=self.cond_retain_index_list) - num_guide_in_window = 0 - if guide_suffix is not None and guide_entries is not None: - overlap = _compute_guide_overlap(guide_entries, window.index_list) - if overlap[3] > 0: - suffix_idx, overlap_info, kf_local_pos, num_guide_in_window = overlap - idx = tuple([slice(None)] * self.dim + [suffix_idx]) - sliced_guide = guide_suffix[idx] - window.guide_suffix_indices = suffix_idx - window.guide_overlap_info = overlap_info - window.guide_kf_local_positions = kf_local_pos - else: - sliced_guide = None - window.guide_suffix_indices = [] - window.guide_overlap_info = [] - window.guide_kf_local_positions = [] - else: - sliced_guide = None - - if sliced_guide is not None: - sliced_primary = torch.cat([sliced_video, sliced_guide], dim=self.dim) - else: - sliced_primary = sliced_video + sliced_primary, num_aux = model.prepare_window_input( + sliced_video, window, window_data.aux_data, self.dim) sliced = [sliced_primary] + [mod_windows[mi].get_tensor(modalities[mi]) for mi in range(1, len(modalities))] # Compose for pipeline @@ -481,7 +447,6 @@ class IndexListContextHandler(ContextHandlerABC): model_options["transformer_options"]["context_window"] = window sub_timestep = window.get_tensor(timestep, dim=0) - # Resize conds using video_primary as reference (excludes guide frames) sub_conds = [self.get_resized_cond(cond, video_primary, window) for cond in conds] if is_multimodal: self._patch_latent_shapes(sub_conds, sub_shapes) @@ -490,14 +455,12 @@ class IndexListContextHandler(ContextHandlerABC): # Decompose output per modality out_per_mod = [self._decompose(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))] - # out_per_mod[cond_idx][mod_idx] = tensor - # Strip guide frames from primary output before accumulation - if num_guide_in_window > 0: + # Strip auxiliary frames from primary output before accumulation + if num_aux > 0: window_len = len(window.index_list) for ci in range(len(sub_conds_out)): - primary_out = out_per_mod[ci][0] - out_per_mod[ci][0] = primary_out.narrow(self.dim, 0, window_len) + out_per_mod[ci][0] = out_per_mod[ci][0].narrow(self.dim, 0, window_len) # Accumulate per modality (using video-only sizes) for mod_idx in range(len(accum_modalities)): @@ -516,10 +479,9 @@ class IndexListContextHandler(ContextHandlerABC): if self.fuse_method.name != ContextFuseMethods.RELATIVE: accum[mod_idx][ci] /= counts[mod_idx][ci] f = accum[mod_idx][ci] - # Re-append original guide_suffix (not model output — sampling loop - # respects denoise_mask and never modifies guide frame positions) - if mod_idx == 0 and guide_suffix is not None: - f = torch.cat([f, guide_suffix], dim=self.dim) + # Re-append model's suffix (auxiliary frames stripped before windowing) + if mod_idx == 0 and window_data.suffix is not None: + f = torch.cat([f, window_data.suffix], dim=self.dim) finalized.append(f) composed, _ = self._compose(finalized) result.append(composed) diff --git a/comfy/model_base.py b/comfy/model_base.py index 066e3f8d8..c1e2ef8c2 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -287,6 +287,12 @@ class BaseModel(torch.nn.Module): return data return None + def prepare_for_windowing(self, primary, conds, dim): + return comfy.context_windows.WindowingContext(tensor=primary, suffix=None, aux_data=None) + + def prepare_window_input(self, video_slice, window, aux_data, dim): + return video_slice, 0 + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): """Override in subclasses to handle model-specific cond slicing for context windows. Return a sliced cond object, or None to fall through to default handling. @@ -1113,6 +1119,51 @@ class LTXAV(BaseModel): return sum(e["latent_shape"][0] for e in gae.cond) return 0 + @staticmethod + def _get_guide_entries(conds): + for cond_list in conds: + if cond_list is None: + continue + for cond_dict in cond_list: + model_conds = cond_dict.get('model_conds', {}) + gae = model_conds.get('guide_attention_entries') + if gae is not None and hasattr(gae, 'cond') and gae.cond: + return gae.cond + return None + + def prepare_for_windowing(self, primary, conds, dim): + guide_count = self.get_guide_frame_count(primary, conds) + if guide_count <= 0: + return comfy.context_windows.WindowingContext(tensor=primary, suffix=None, aux_data=None) + video_len = primary.size(dim) - guide_count + video_primary = primary.narrow(dim, 0, video_len) + guide_suffix = primary.narrow(dim, video_len, guide_count) + guide_entries = self._get_guide_entries(conds) + return comfy.context_windows.WindowingContext( + tensor=video_primary, suffix=guide_suffix, + aux_data={"guide_entries": guide_entries, "guide_suffix": guide_suffix}) + + def prepare_window_input(self, video_slice, window, aux_data, dim): + if aux_data is None: + return video_slice, 0 + guide_entries = aux_data["guide_entries"] + guide_suffix = aux_data["guide_suffix"] + if guide_entries is None: + window.guide_suffix_indices = [] + window.guide_overlap_info = [] + window.guide_kf_local_positions = [] + return video_slice, 0 + overlap = comfy.context_windows._compute_guide_overlap(guide_entries, window.index_list) + suffix_idx, overlap_info, kf_local_pos, num_guide = overlap + window.guide_suffix_indices = suffix_idx + window.guide_overlap_info = overlap_info + window.guide_kf_local_positions = kf_local_pos + if num_guide > 0: + idx = tuple([slice(None)] * dim + [suffix_idx]) + sliced_guide = guide_suffix[idx] + return torch.cat([video_slice, sliced_guide], dim=dim), num_guide + return video_slice, 0 + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): # Audio denoise mask — slice using audio modality window if cond_key == "audio_denoise_mask" and hasattr(window, 'modality_windows') and window.modality_windows: