From 3a061f4bbfc6f61367656ef6e18d0caa5271805b Mon Sep 17 00:00:00 2001 From: ozbayb <17261091+ozbayb@users.noreply.github.com> Date: Mon, 6 Apr 2026 15:13:46 -0600 Subject: [PATCH] LTX2 context windows - Cleanup: Simplify IndexListContextHandler standard execute path --- comfy/context_windows.py | 66 +++++++++++++++++++++++++++++++--------- 1 file changed, 51 insertions(+), 15 deletions(-) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 9e7282fda..e89c9cee2 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -367,18 +367,60 @@ class IndexListContextHandler(ContextHandlerABC): self._model = model self.set_step(timestep, model_options) - # Decompose — single-modality: [x_in], multimodal: [video, audio, ...] + # Check if multimodal or model has auxiliary frames requiring the extended path latent_shapes = self._get_latent_shapes(conds) + is_multimodal = latent_shapes is not None and len(latent_shapes) > 1 + if is_multimodal: + return self._execute_extended(calc_cond_batch, model, conds, x_in, timestep, model_options, latent_shapes) + window_data = model.prepare_for_windowing(x_in, conds, self.dim) + if window_data.suffix is not None or window_data.aux_data is not None: + return self._execute_extended(calc_cond_batch, model, conds, x_in, timestep, model_options, + latent_shapes, window_data) + + context_windows = self.get_context_windows(model, x_in, model_options) + enumerated_context_windows = list(enumerate(context_windows)) + + conds_final = [torch.zeros_like(x_in) for _ in conds] + if self.fuse_method.name == ContextFuseMethods.RELATIVE: + counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds] + else: + counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds] + biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds] + + for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks): + callback(self, model, x_in, conds, timestep, model_options) + + for enum_window in enumerated_context_windows: + results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options) + for result in results: + self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep, + conds_final, counts_final, biases_final) + try: + if self.fuse_method.name == ContextFuseMethods.RELATIVE: + del counts_final + return conds_final + else: + for i in range(len(conds_final)): + conds_final[i] /= counts_final[i] + del counts_final + return conds_final + finally: + for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks): + callback(self, model, x_in, conds, timestep, model_options) + + def _execute_extended(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, + timestep: torch.Tensor, model_options: dict[str], + latent_shapes, window_data: WindowingContext=None): + """Extended execute path for multimodal models and models with auxiliary frames.""" modalities = self._decompose(x_in, latent_shapes) is_multimodal = len(modalities) > 1 - primary = modalities[0] - # Let model strip auxiliary frames (e.g. guide frames) - window_data = model.prepare_for_windowing(primary, conds, self.dim) + if window_data is None: + window_data = model.prepare_for_windowing(modalities[0], conds, self.dim) + video_primary = window_data.tensor aux_count = window_data.suffix.size(self.dim) if window_data.suffix is not None else 0 - # Windows from video portion only context_windows = self.get_context_windows(model, video_primary, model_options) enumerated_context_windows = list(enumerate(context_windows)) total_windows = len(enumerated_context_windows) @@ -407,14 +449,13 @@ class IndexListContextHandler(ContextHandlerABC): # Per-modality window indices if is_multimodal: map_shapes = latent_shapes - if video_primary.size(self.dim) != primary.size(self.dim): + if video_primary.size(self.dim) != modalities[0].size(self.dim): map_shapes = list(latent_shapes) video_shape = list(latent_shapes[0]) video_shape[self.dim] = video_primary.size(self.dim) map_shapes[0] = torch.Size(video_shape) per_mod_indices = model.map_context_window_to_modalities( window.index_list, map_shapes, self.dim) if hasattr(model, 'map_context_window_to_modalities') else [window.index_list] - # Build per-modality windows and attach to primary window modality_windows = {} for mod_idx in range(1, len(modalities)): modality_windows[mod_idx] = IndexListContextWindow( @@ -423,11 +464,9 @@ class IndexListContextHandler(ContextHandlerABC): window = IndexListContextWindow( window.index_list, dim=self.dim, total_frames=video_primary.shape[self.dim], modality_windows=modality_windows) - else: - per_mod_indices = [window.index_list] - # Build per-modality windows list (including primary) - mod_windows = [window] # primary window at index 0 + # Build per-modality windows list + mod_windows = [window] if is_multimodal: for mod_idx in range(1, len(modalities)): mod_windows.append(modality_windows[mod_idx]) @@ -438,10 +477,8 @@ class IndexListContextHandler(ContextHandlerABC): sliced_video, window, window_data.aux_data, self.dim) sliced = [sliced_primary] + [mod_windows[mi].get_tensor(modalities[mi]) for mi in range(1, len(modalities))] - # Compose for pipeline sub_x, sub_shapes = self._compose(sliced) - # Callbacks for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks): callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, None, None) @@ -462,7 +499,7 @@ class IndexListContextHandler(ContextHandlerABC): for ci in range(len(sub_conds_out)): out_per_mod[ci][0] = out_per_mod[ci][0].narrow(self.dim, 0, window_len) - # Accumulate per modality (using video-only sizes) + # Accumulate per modality for mod_idx in range(len(accum_modalities)): mw = mod_windows[mod_idx] mod_sub_out = [out_per_mod[ci][mod_idx] for ci in range(len(sub_conds_out))] @@ -479,7 +516,6 @@ class IndexListContextHandler(ContextHandlerABC): if self.fuse_method.name != ContextFuseMethods.RELATIVE: accum[mod_idx][ci] /= counts[mod_idx][ci] f = accum[mod_idx][ci] - # Re-append model's suffix (auxiliary frames stripped before windowing) if mod_idx == 0 and window_data.suffix is not None: f = torch.cat([f, window_data.suffix], dim=self.dim) finalized.append(f)