From 5bfe660b7ca2941200984377a151b3449d2926f2 Mon Sep 17 00:00:00 2001
From: ozbayb <17261091+ozbayb@users.noreply.github.com>
Date: Sun, 22 Mar 2026 07:46:45 -0600
Subject: [PATCH] Test implementation for LTX2 context windows

---
 comfy/context_windows.py | 156 ++++++++++++++++++++++++++++++++-------
 comfy/model_base.py      |  34 +++++++++
 2 files changed, 163 insertions(+), 27 deletions(-)

diff --git a/comfy/context_windows.py b/comfy/context_windows.py
index cb44ee6e8..29ee2b5b1 100644
--- a/comfy/context_windows.py
+++ b/comfy/context_windows.py
@@ -8,6 +8,8 @@ from abc import ABC, abstractmethod
 import logging
 import comfy.model_management
 import comfy.patcher_extension
+import comfy.utils
+import comfy.conds
 if TYPE_CHECKING:
     from comfy.model_base import BaseModel
     from comfy.model_patcher import ModelPatcher
@@ -51,12 +53,13 @@ class ContextHandlerABC(ABC):
 
 
 class IndexListContextWindow(ContextWindowABC):
-    def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0):
+    def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0, modality_windows: dict=None):
         self.index_list = index_list
         self.context_length = len(index_list)
         self.dim = dim
         self.total_frames = total_frames
         self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
+        self.modality_windows = modality_windows  # dict of {mod_idx: IndexListContextWindow}
 
     def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
         if dim is None:
@@ -165,10 +168,44 @@ class IndexListContextHandler(ContextHandlerABC):
 
         self.callbacks = {}
 
+    def _get_latent_shapes(self, conds):
+        """Extract latent_shapes from conditioning. Returns None if absent."""
+        for cond_list in conds:
+            if cond_list is None:
+                continue
+            for cond_dict in cond_list:
+                model_conds = cond_dict.get('model_conds', {})
+                if 'latent_shapes' in model_conds:
+                    return model_conds['latent_shapes'].cond
+        return None
+
+    def _decompose(self, x, latent_shapes):
+        """Packed tensor -> list of per-modality tensors."""
+        if latent_shapes is not None and len(latent_shapes) > 1:
+            return comfy.utils.unpack_latents(x, latent_shapes)
+        return [x]
+
+    def _compose(self, modalities):
+        """List of per-modality tensors -> single tensor for pipeline."""
+        if len(modalities) > 1:
+            return comfy.utils.pack_latents(modalities)
+        return modalities[0], [modalities[0].shape]
+
+    def _patch_latent_shapes(self, sub_conds, new_shapes):
+        """Patch latent_shapes CONDConstant in (already-copied) sub_conds."""
+        for cond_list in sub_conds:
+            if cond_list is None:
+                continue
+            for cond_dict in cond_list:
+                model_conds = cond_dict.get('model_conds', {})
+                if 'latent_shapes' in model_conds:
+                    model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes)
+
     def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
-        # for now, assume first dim is batch - should have stored on BaseModel in actual implementation
-        if x_in.size(self.dim) > self.context_length:
-            logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.")
+        latent_shapes = self._get_latent_shapes(conds)
+        primary = self._decompose(x_in, latent_shapes)[0]
+        if primary.size(self.dim) > self.context_length:
+            logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {primary.size(self.dim)} frames.")
             if self.cond_retain_index_list:
                 logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
             return True
@@ -277,36 +314,98 @@ class IndexListContextHandler(ContextHandlerABC):
     def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
         self._model = model
         self.set_step(timestep, model_options)
-        context_windows = self.get_context_windows(model, x_in, model_options)
-        enumerated_context_windows = list(enumerate(context_windows))
 
-        conds_final = [torch.zeros_like(x_in) for _ in conds]
+        # Decompose — single-modality: [x_in], multimodal: [video, audio, ...]
+        latent_shapes = self._get_latent_shapes(conds)
+        modalities = self._decompose(x_in, latent_shapes)
+        is_multimodal = len(modalities) > 1
+        primary = modalities[0]
+
+        # Windows from primary modality's temporal dim
+        context_windows = self.get_context_windows(model, primary, model_options)
+        enumerated_context_windows = list(enumerate(context_windows))
+        total_windows = len(enumerated_context_windows)
+
+        # Per-modality accumulators: accum[mod_idx][cond_idx]
+        accum = [[torch.zeros_like(m) for _ in conds] for m in modalities]
         if self.fuse_method.name == ContextFuseMethods.RELATIVE:
-            counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
+            counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in modalities]
         else:
-            counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
-        biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds]
+            counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in modalities]
+        biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in modalities]
 
         for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
             callback(self, model, x_in, conds, timestep, model_options)
 
-        for enum_window in enumerated_context_windows:
-            results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options)
-            for result in results:
-                self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep,
-                                            conds_final, counts_final, biases_final)
-        try:
-            # finalize conds
-            if self.fuse_method.name == ContextFuseMethods.RELATIVE:
-                # relative is already normalized, so return as is
-                del counts_final
-                return conds_final
+        for window_idx, window in enumerated_context_windows:
+            comfy.model_management.throw_exception_if_processing_interrupted()
+
+            # Per-modality window indices
+            if is_multimodal:
+                per_mod_indices = model.map_context_window_to_modalities(
+                    window.index_list, latent_shapes, self.dim)
+                # Build per-modality windows and attach to primary window
+                modality_windows = {}
+                for mod_idx in range(1, len(modalities)):
+                    modality_windows[mod_idx] = IndexListContextWindow(
+                        per_mod_indices[mod_idx], dim=self.dim,
+                        total_frames=modalities[mod_idx].shape[self.dim])
+                window = IndexListContextWindow(
+                    window.index_list, dim=self.dim, total_frames=primary.shape[self.dim],
+                    modality_windows=modality_windows)
             else:
-                # normalize conds via division by context usage counts
-                for i in range(len(conds_final)):
-                    conds_final[i] /= counts_final[i]
-                del counts_final
-                return conds_final
+                per_mod_indices = [window.index_list]
+
+            # Build per-modality windows list (including primary)
+            mod_windows = [window]  # primary window at index 0
+            if is_multimodal:
+                for mod_idx in range(1, len(modalities)):
+                    mod_windows.append(modality_windows[mod_idx])
+
+            # Slice each modality
+            sliced = [mod_windows[mi].get_tensor(modalities[mi]) for mi in range(len(modalities))]
+
+            # Compose for pipeline
+            sub_x, sub_shapes = self._compose(sliced)
+
+            # Callbacks
+            for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
+                callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, None, None)
+
+            model_options["transformer_options"]["context_window"] = window
+            sub_timestep = window.get_tensor(timestep, dim=0)
+            # Resize conds using primary tensor as reference (correct temporal dim)
+            sub_conds = [self.get_resized_cond(cond, primary, window) for cond in conds]
+            if is_multimodal:
+                self._patch_latent_shapes(sub_conds, sub_shapes)
+
+            sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options)
+
+            # Decompose output per modality
+            out_per_mod = [self._decompose(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))]
+            # out_per_mod[cond_idx][mod_idx] = tensor
+
+            # Accumulate per modality
+            for mod_idx in range(len(modalities)):
+                mw = mod_windows[mod_idx]
+                # Build per-modality sub_conds_out list for combine
+                mod_sub_out = [out_per_mod[ci][mod_idx] for ci in range(len(sub_conds_out))]
+                self.combine_context_window_results(
+                    modalities[mod_idx], mod_sub_out, sub_conds, mw,
+                    window_idx, total_windows, timestep,
+                    accum[mod_idx], counts[mod_idx], biases[mod_idx])
+
+        try:
+            result = []
+            for ci in range(len(conds)):
+                finalized = []
+                for mod_idx in range(len(modalities)):
+                    if self.fuse_method.name != ContextFuseMethods.RELATIVE:
+                        accum[mod_idx][ci] /= counts[mod_idx][ci]
+                    finalized.append(accum[mod_idx][ci])
+                composed, _ = self._compose(finalized)
+                result.append(composed)
+            return result
         finally:
             for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
                 callback(self, model, x_in, conds, timestep, model_options)
@@ -374,7 +473,10 @@ def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args,
     handler: IndexListContextHandler = model_options.get("context_handler", None)
     if handler is not None:
         noise_shape = list(noise_shape)
-        noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length)
+        # Guard: only clamp when dim is within bounds and the size exceeds the context length.
+        # NOTE(review): a packed multimodal noise_shape=[B,1,flat] still passes this check when dim == 2, although flat is not a frame count - confirm.
+        if handler.dim < len(noise_shape) and noise_shape[handler.dim] > handler.context_length:
+            noise_shape[handler.dim] = handler.context_length
     return executor(model, noise_shape, *args, **kwargs)
 
 
diff --git a/comfy/model_base.py b/comfy/model_base.py
index c2ae646aa..3096ca4fb 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -293,6 +293,11 @@ class BaseModel(torch.nn.Module):
         Use comfy.context_windows.slice_cond() for common cases."""
         return None
 
+    def map_context_window_to_modalities(self, primary_indices, latent_shapes, dim):
+        """Map primary modality's window indices to all modalities.
+        Returns list of index lists, one per modality (full index range for non-primary ones by default)."""
+        return [primary_indices] + [list(range(shape[dim])) for shape in (latent_shapes or [])[1:]]
+
     def extra_conds(self, **kwargs):
         out = {}
         concat_cond = self.concat_cond(**kwargs)
@@ -1082,6 +1087,35 @@ class LTXAV(BaseModel):
     def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
         return latent_image
 
+    def map_context_window_to_modalities(self, primary_indices, latent_shapes, dim):
+        result = [primary_indices]
+        if len(latent_shapes) < 2:
+            return result
+
+        video_total = latent_shapes[0][dim]
+        audio_total = latent_shapes[1][dim]
+
+        # Proportional mapping — video and audio cover same real-time duration
+        v_start, v_end = min(primary_indices), max(primary_indices) + 1
+        a_start = round(v_start * audio_total / video_total)
+        a_end = round(v_end * audio_total / video_total)
+        audio_indices = list(range(a_start, min(a_end, audio_total)))
+        if not audio_indices:
+            audio_indices = [min(a_start, audio_total - 1)]
+
+        result.append(audio_indices)
+        return result
+
+    def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
+        if cond_key == "audio_denoise_mask" and hasattr(window, 'modality_windows') and window.modality_windows:
+            audio_window = window.modality_windows.get(1)
+            if audio_window is not None:
+                import comfy.context_windows
+                return comfy.context_windows.slice_cond(
+                    cond_value, audio_window, x_in, device, temporal_dim=2)
+        return None
+
+
 class HunyuanVideo(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)