From 5bfe660b7ca2941200984377a151b3449d2926f2 Mon Sep 17 00:00:00 2001
From: ozbayb <17261091+ozbayb@users.noreply.github.com>
Date: Sun, 22 Mar 2026 07:46:45 -0600
Subject: [PATCH 01/23] Test implementation for LTX2 context windows

---
 comfy/context_windows.py | 156 ++++++++++++++++++++++++++++++++-------
 comfy/model_base.py      |  33 +++++++++
 2 files changed, 162 insertions(+), 27 deletions(-)

diff --git a/comfy/context_windows.py b/comfy/context_windows.py
index cb44ee6e8..29ee2b5b1 100644
--- a/comfy/context_windows.py
+++ b/comfy/context_windows.py
@@ -8,6 +8,8 @@ from abc import ABC, abstractmethod
 import logging
 import comfy.model_management
 import comfy.patcher_extension
+import comfy.utils
+import comfy.conds
 if TYPE_CHECKING:
     from comfy.model_base import BaseModel
     from comfy.model_patcher import ModelPatcher
@@ -51,12 +53,13 @@ class ContextHandlerABC(ABC):
 
 
 class IndexListContextWindow(ContextWindowABC):
-    def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0):
+    def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0, modality_windows: dict=None):
         self.index_list = index_list
         self.context_length = len(index_list)
         self.dim = dim
         self.total_frames = total_frames
         self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
+        self.modality_windows = modality_windows  # dict of {mod_idx: IndexListContextWindow}
 
     def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
         if dim is None:
@@ -165,10 +168,44 @@ class IndexListContextHandler(ContextHandlerABC):
         self.callbacks = {}
 
+    def _get_latent_shapes(self, conds):
+        """Extract latent_shapes from conditioning. Returns None if absent."""
+        for cond_list in conds:
+            if cond_list is None:
+                continue
+            for cond_dict in cond_list:
+                model_conds = cond_dict.get('model_conds', {})
+                if 'latent_shapes' in model_conds:
+                    return model_conds['latent_shapes'].cond
+        return None
+
+    def _decompose(self, x, latent_shapes):
+        """Packed tensor -> list of per-modality tensors."""
+        if latent_shapes is not None and len(latent_shapes) > 1:
+            return comfy.utils.unpack_latents(x, latent_shapes)
+        return [x]
+
+    def _compose(self, modalities):
+        """List of per-modality tensors -> single tensor for pipeline."""
+        if len(modalities) > 1:
+            return comfy.utils.pack_latents(modalities)
+        return modalities[0], [modalities[0].shape]
+
+    def _patch_latent_shapes(self, sub_conds, new_shapes):
+        """Patch latent_shapes CONDConstant in (already-copied) sub_conds."""
+        for cond_list in sub_conds:
+            if cond_list is None:
+                continue
+            for cond_dict in cond_list:
+                model_conds = cond_dict.get('model_conds', {})
+                if 'latent_shapes' in model_conds:
+                    model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes)
+
     def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
-        # for now, assume first dim is batch - should have stored on BaseModel in actual implementation
-        if x_in.size(self.dim) > self.context_length:
-            logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.")
+        latent_shapes = self._get_latent_shapes(conds)
+        primary = self._decompose(x_in, latent_shapes)[0]
+        if primary.size(self.dim) > self.context_length:
+            logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {primary.size(self.dim)} frames.")
             if self.cond_retain_index_list:
                 logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
             return True
@@ -277,36 +314,98 @@ class IndexListContextHandler(ContextHandlerABC):
     def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
         self._model = model
         self.set_step(timestep, model_options)
-        context_windows = self.get_context_windows(model, x_in, model_options)
-        enumerated_context_windows = list(enumerate(context_windows))
-        conds_final = [torch.zeros_like(x_in) for _ in conds]
+        # Decompose — single-modality: [x_in], multimodal: [video, audio, ...]
+        latent_shapes = self._get_latent_shapes(conds)
+        modalities = self._decompose(x_in, latent_shapes)
+        is_multimodal = len(modalities) > 1
+        primary = modalities[0]
+
+        # Windows from primary modality's temporal dim
+        context_windows = self.get_context_windows(model, primary, model_options)
+        enumerated_context_windows = list(enumerate(context_windows))
+        total_windows = len(enumerated_context_windows)
+
+        # Per-modality accumulators: accum[mod_idx][cond_idx]
+        accum = [[torch.zeros_like(m) for _ in conds] for m in modalities]
         if self.fuse_method.name == ContextFuseMethods.RELATIVE:
-            counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
+            counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in modalities]
         else:
-            counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
-        biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds]
+            counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in modalities]
+        biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in modalities]
 
         for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
             callback(self, model, x_in, conds, timestep, model_options)
 
-        for enum_window in enumerated_context_windows:
-            results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options)
-            for result in results:
-                self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep,
-                                                    conds_final, counts_final, biases_final)
+        for window_idx, window in enumerated_context_windows:
+            comfy.model_management.throw_exception_if_processing_interrupted()
+
+            # Per-modality window indices
+            if is_multimodal:
+                per_mod_indices = model.map_context_window_to_modalities(
+                    window.index_list, latent_shapes, self.dim)
+                # Build per-modality windows and attach to primary window
+                modality_windows = {}
+                for mod_idx in range(1, len(modalities)):
+                    modality_windows[mod_idx] = IndexListContextWindow(
+                        per_mod_indices[mod_idx], dim=self.dim,
+                        total_frames=modalities[mod_idx].shape[self.dim])
+                window = IndexListContextWindow(
+                    window.index_list, dim=self.dim, total_frames=primary.shape[self.dim],
+                    modality_windows=modality_windows)
+            else:
+                per_mod_indices = [window.index_list]
+
+            # Build per-modality windows list (including primary)
+            mod_windows = [window]  # primary window at index 0
+            if is_multimodal:
+                for mod_idx in range(1, len(modalities)):
+                    mod_windows.append(modality_windows[mod_idx])
+
+            # Slice each modality
+            sliced = [mod_windows[mi].get_tensor(modalities[mi]) for mi in range(len(modalities))]
+
+            # Compose for pipeline
+            sub_x, sub_shapes = self._compose(sliced)
+
+            # Callbacks
+            for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
+                callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, None, None)
+
+            model_options["transformer_options"]["context_window"] = window
+            sub_timestep = window.get_tensor(timestep, dim=0)
+            # Resize conds using primary tensor as reference (correct temporal dim)
+            sub_conds = [self.get_resized_cond(cond, primary, window) for cond in conds]
+            if is_multimodal:
+                self._patch_latent_shapes(sub_conds, sub_shapes)
+
+            sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options)
+
+            # Decompose output per modality
+            out_per_mod = [self._decompose(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))]
+            # out_per_mod[cond_idx][mod_idx] = tensor
+
+            # Accumulate per modality
+            for mod_idx in range(len(modalities)):
+                mw = mod_windows[mod_idx]
+                # Build per-modality sub_conds_out list for combine
+                mod_sub_out = [out_per_mod[ci][mod_idx] for ci in range(len(sub_conds_out))]
+                self.combine_context_window_results(
+                    modalities[mod_idx], mod_sub_out, sub_conds, mw,
+                    window_idx, total_windows, timestep,
+                    accum[mod_idx], counts[mod_idx], biases[mod_idx])
+
         try:
-            # finalize conds
-            if self.fuse_method.name == ContextFuseMethods.RELATIVE:
-                # relative is already normalized, so return as is
-                del counts_final
-                return conds_final
-            else:
-                # normalize conds via division by context usage counts
-                for i in range(len(conds_final)):
-                    conds_final[i] /= counts_final[i]
-                del counts_final
-                return conds_final
+            result = []
+            for ci in range(len(conds)):
+                finalized = []
+                for mod_idx in range(len(modalities)):
+                    if self.fuse_method.name != ContextFuseMethods.RELATIVE:
+                        accum[mod_idx][ci] /= counts[mod_idx][ci]
+                    finalized.append(accum[mod_idx][ci])
+                composed, _ = self._compose(finalized)
+                result.append(composed)
+            return result
         finally:
             for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
                 callback(self, model, x_in, conds, timestep, model_options)
@@ -374,7 +473,10 @@ def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args, **kwargs):
     handler: IndexListContextHandler = model_options.get("context_handler", None)
     if handler is not None:
         noise_shape = list(noise_shape)
-        noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length)
+        # Guard: only clamp when dim is within bounds and the value is meaningful
+        # (packed multimodal tensors have noise_shape=[B,1,flat] where flat is not frame count)
+        if handler.dim < len(noise_shape) and noise_shape[handler.dim] > handler.context_length:
+            noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length)
     return executor(model, noise_shape, *args, **kwargs)

diff --git a/comfy/model_base.py b/comfy/model_base.py
index c2ae646aa..3096ca4fb 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -293,6 +293,11 @@ class BaseModel(torch.nn.Module):
         Use comfy.context_windows.slice_cond() for common cases."""
         return None
 
+    def map_context_window_to_modalities(self, primary_indices, latent_shapes, dim):
+        """Map primary modality's window indices to all modalities.
+        Returns list of index lists, one per modality."""
+        return [primary_indices]
+
     def extra_conds(self, **kwargs):
         out = {}
         concat_cond = self.concat_cond(**kwargs)
@@ -1082,6 +1087,34 @@ class LTXAV(BaseModel):
     def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
         return latent_image
 
+    def map_context_window_to_modalities(self, primary_indices, latent_shapes, dim):
+        result = [primary_indices]
+        if len(latent_shapes) < 2:
+            return result
+
+        video_total = latent_shapes[0][dim]
+        audio_total = latent_shapes[1][dim]
+
+        # Proportional mapping — video and audio cover same real-time duration
+        v_start, v_end = min(primary_indices), max(primary_indices) + 1
+        a_start = round(v_start * audio_total / video_total)
+        a_end = round(v_end * audio_total / video_total)
+        audio_indices = list(range(a_start, min(a_end, audio_total)))
+        if not audio_indices:
+            audio_indices = [min(a_start, audio_total - 1)]
+
+        result.append(audio_indices)
+        return result
+
+    def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
+        if cond_key == "audio_denoise_mask" and hasattr(window, 'modality_windows') and window.modality_windows:
+            audio_window = window.modality_windows.get(1)
+            if audio_window is not None:
+                import comfy.context_windows
+                return comfy.context_windows.slice_cond(
+                    cond_value, audio_window, x_in, device, temporal_dim=2)
+        return None
+
 class HunyuanVideo(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
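For reference, the packed-latent contract that _decompose/_compose in PATCH 01 rely on: comfy.utils.pack_latents flattens each modality and returns the packed tensor together with the original shapes, and unpack_latents reverses it. The patch's own comments describe the packed noise as [B, 1, flat]. A minimal standalone sketch of that contract (a hypothetical reimplementation for illustration only, not the actual comfy.utils code):

    import torch

    def pack_latents(mods):
        # Flatten each modality to [B, 1, N] and concatenate along the last dim.
        shapes = [m.shape for m in mods]
        packed = torch.cat([m.reshape(m.shape[0], 1, -1) for m in mods], dim=-1)
        return packed, shapes

    def unpack_latents(packed, shapes):
        # Split the flat dim back into per-modality tensors with their original shapes.
        out, offset = [], 0
        for s in shapes:
            n = int(torch.tensor(tuple(s[1:])).prod())
            out.append(packed[:, :, offset:offset + n].reshape(s))
            offset += n
        return out

    video = torch.randn(1, 128, 24, 16, 16)  # [B, C, T, H, W]
    audio = torch.randn(1, 8, 48, 16)        # hypothetical audio latent with its own T
    packed, shapes = pack_latents([video, audio])
    assert [m.shape for m in unpack_latents(packed, shapes)] == shapes

Under this contract, a handler given a packed [B, 1, flat] tensor can recover each modality's true temporal dimension, which is why should_use_context decomposes before comparing against context_length.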
From 56de390c2533112ba9925604573b12396a229885 Mon Sep 17 00:00:00 2001
From: ozbayb <17261091+ozbayb@users.noreply.github.com>
Date: Mon, 23 Mar 2026 13:50:28 -0600
Subject: [PATCH 02/23] LTX2 context windows part 2 - Guide-aware processing

---
 comfy/context_windows.py |  93 ++++++++++++++++++++++++-------
 comfy/model_base.py      | 117 +++++++++++++++++++++++++++++++++++--
 2 files changed, 185 insertions(+), 25 deletions(-)

diff --git a/comfy/context_windows.py b/comfy/context_windows.py
index 29ee2b5b1..91d019a00 100644
--- a/comfy/context_windows.py
+++ b/comfy/context_windows.py
@@ -204,8 +204,13 @@ class IndexListContextHandler(ContextHandlerABC):
     def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
         latent_shapes = self._get_latent_shapes(conds)
         primary = self._decompose(x_in, latent_shapes)[0]
-        if primary.size(self.dim) > self.context_length:
-            logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {primary.size(self.dim)} frames.")
+        guide_count = model.get_guide_frame_count(primary, conds) if model is not None else 0
+        video_frames = primary.size(self.dim) - guide_count
+        if video_frames > self.context_length:
+            if guide_count > 0:
+                logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {video_frames} video frames ({guide_count} guide frames excluded).")
+            else:
+                logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {video_frames} frames.")
             if self.cond_retain_index_list:
                 logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
             return True
@@ -321,18 +326,32 @@ class IndexListContextHandler(ContextHandlerABC):
         is_multimodal = len(modalities) > 1
         primary = modalities[0]
 
-        # Windows from primary modality's temporal dim
-        context_windows = self.get_context_windows(model, primary, model_options)
+        # Separate guide frames from primary modality (guides are appended at the end)
+        guide_count = model.get_guide_frame_count(primary, conds) if model is not None else 0
+        if guide_count > 0:
+            video_len = primary.size(self.dim) - guide_count
+            video_primary = primary.narrow(self.dim, 0, video_len)
+            guide_suffix = primary.narrow(self.dim, video_len, guide_count)
+        else:
+            video_primary = primary
+            guide_suffix = None
+
+        # Windows from video portion only (excluding guide frames)
+        context_windows = self.get_context_windows(model, video_primary, model_options)
         enumerated_context_windows = list(enumerate(context_windows))
         total_windows = len(enumerated_context_windows)
 
-        # Per-modality accumulators: accum[mod_idx][cond_idx]
-        accum = [[torch.zeros_like(m) for _ in conds] for m in modalities]
+        # Accumulators sized to video portion for primary, full for other modalities
+        accum_modalities = list(modalities)
+        if guide_suffix is not None:
+            accum_modalities[0] = video_primary
+
+        accum = [[torch.zeros_like(m) for _ in conds] for m in accum_modalities]
         if self.fuse_method.name == ContextFuseMethods.RELATIVE:
-            counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in modalities]
+            counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_modalities]
         else:
-            counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in modalities]
-        biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in modalities]
+            counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_modalities]
+        biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in accum_modalities]
 
         for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
             callback(self, model, x_in, conds, timestep, model_options)
@@ -340,10 +359,22 @@ class IndexListContextHandler(ContextHandlerABC):
         for window_idx, window in enumerated_context_windows:
             comfy.model_management.throw_exception_if_processing_interrupted()
 
+            # Attach guide info to window for resize_cond_for_context_window
+            window.guide_count = guide_count
+            if guide_suffix is not None:
+                window.guide_spatial = (guide_suffix.shape[3], guide_suffix.shape[4])
+
             # Per-modality window indices
             if is_multimodal:
+                # Adjust latent_shapes so video shape reflects video-only frames (excludes guides)
+                map_shapes = latent_shapes
+                if guide_count > 0:
+                    map_shapes = list(latent_shapes)
+                    video_shape = list(latent_shapes[0])
+                    video_shape[self.dim] = video_shape[self.dim] - guide_count
+                    map_shapes[0] = torch.Size(video_shape)
                 per_mod_indices = model.map_context_window_to_modalities(
-                    window.index_list, latent_shapes, self.dim)
+                    window.index_list, map_shapes, self.dim)
                 # Build per-modality windows and attach to primary window
                 modality_windows = {}
                 for mod_idx in range(1, len(modalities)):
@@ -351,8 +382,11 @@ class IndexListContextHandler(ContextHandlerABC):
                         per_mod_indices[mod_idx], dim=self.dim,
                         total_frames=modalities[mod_idx].shape[self.dim])
                 window = IndexListContextWindow(
-                    window.index_list, dim=self.dim, total_frames=primary.shape[self.dim],
+                    window.index_list, dim=self.dim, total_frames=video_primary.shape[self.dim],
                     modality_windows=modality_windows)
+                window.guide_count = guide_count
+                if guide_suffix is not None:
+                    window.guide_spatial = (guide_suffix.shape[3], guide_suffix.shape[4])
             else:
                 per_mod_indices = [window.index_list]
@@ -362,8 +396,14 @@ class IndexListContextHandler(ContextHandlerABC):
 
-            # Slice each modality
-            sliced = [mod_windows[mi].get_tensor(modalities[mi]) for mi in range(len(modalities))]
+            # Slice video and guide with same window indices, concatenate
+            sliced_video = mod_windows[0].get_tensor(video_primary)
+            if guide_suffix is not None:
+                sliced_guide = mod_windows[0].get_tensor(guide_suffix)
+                sliced_primary = torch.cat([sliced_video, sliced_guide], dim=self.dim)
+            else:
+                sliced_primary = sliced_video
+            sliced = [sliced_primary] + [mod_windows[mi].get_tensor(modalities[mi]) for mi in range(1, len(modalities))]
 
             # Compose for pipeline
             sub_x, sub_shapes = self._compose(sliced)
@@ -374,8 +414,8 @@ class IndexListContextHandler(ContextHandlerABC):
             model_options["transformer_options"]["context_window"] = window
             sub_timestep = window.get_tensor(timestep, dim=0)
-            # Resize conds using primary tensor as reference (correct temporal dim)
-            sub_conds = [self.get_resized_cond(cond, primary, window) for cond in conds]
+            # Resize conds using video_primary as reference (excludes guide frames)
+            sub_conds = [self.get_resized_cond(cond, video_primary, window) for cond in conds]
             if is_multimodal:
                 self._patch_latent_shapes(sub_conds, sub_shapes)
@@ -385,13 +425,19 @@ class IndexListContextHandler(ContextHandlerABC):
             out_per_mod = [self._decompose(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))]
             # out_per_mod[cond_idx][mod_idx] = tensor
 
-            # Accumulate per modality
-            for mod_idx in range(len(modalities)):
+            # Strip guide frames from primary output before accumulation
+            if guide_count > 0:
+                window_len = len(window.index_list)
+                for ci in range(len(sub_conds_out)):
+                    primary_out = out_per_mod[ci][0]
+                    out_per_mod[ci][0] = primary_out.narrow(self.dim, 0, window_len)
+
+            # Accumulate per modality (using video-only sizes)
+            for mod_idx in range(len(accum_modalities)):
                 mw = mod_windows[mod_idx]
-                # Build per-modality sub_conds_out list for combine
                 mod_sub_out = [out_per_mod[ci][mod_idx] for ci in range(len(sub_conds_out))]
                 self.combine_context_window_results(
-                    modalities[mod_idx], mod_sub_out, sub_conds, mw,
+                    accum_modalities[mod_idx], mod_sub_out, sub_conds, mw,
                     window_idx, total_windows, timestep,
                     accum[mod_idx], counts[mod_idx], biases[mod_idx])
@@ -399,10 +445,15 @@ class IndexListContextHandler(ContextHandlerABC):
             result = []
             for ci in range(len(conds)):
                 finalized = []
-                for mod_idx in range(len(modalities)):
+                for mod_idx in range(len(accum_modalities)):
                     if self.fuse_method.name != ContextFuseMethods.RELATIVE:
                         accum[mod_idx][ci] /= counts[mod_idx][ci]
-                    finalized.append(accum[mod_idx][ci])
+                    f = accum[mod_idx][ci]
+                    # Re-append original guide_suffix (not model output — sampling loop
+                    # respects denoise_mask and never modifies guide frame positions)
+                    if mod_idx == 0 and guide_suffix is not None:
+                        f = torch.cat([f, guide_suffix], dim=self.dim)
+                    finalized.append(f)
                 composed, _ = self._compose(finalized)
                 result.append(composed)
             return result

diff --git a/comfy/model_base.py b/comfy/model_base.py
index 3096ca4fb..12d49305f 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -298,6 +298,11 @@ class BaseModel(torch.nn.Module):
         Returns list of index lists, one per modality."""
         return [primary_indices]
 
+    def get_guide_frame_count(self, x, conds):
+        """Return the number of trailing guide frames appended to x along the temporal dim.
+        Override in subclasses that concatenate guide reference frames to the latent."""
+        return 0
+
     def extra_conds(self, **kwargs):
         out = {}
         concat_cond = self.concat_cond(**kwargs)
@@ -1021,12 +1026,64 @@ class LTXV(BaseModel):
     def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
         return latent_image
 
+    def get_guide_frame_count(self, x, conds):
+        for cond_list in conds:
+            if cond_list is None:
+                continue
+            for cond_dict in cond_list:
+                model_conds = cond_dict.get('model_conds', {})
+                gae = model_conds.get('guide_attention_entries')
+                if gae is not None and hasattr(gae, 'cond') and gae.cond:
+                    return sum(e["latent_shape"][0] for e in gae.cond)
+        return 0
+
+    def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
+        guide_count = getattr(window, 'guide_count', 0)
+
+        if cond_key == "denoise_mask" and guide_count > 0:
+            # Slice both video and guide halves with same window indices
+            cond_tensor = cond_value.cond
+            T_video = cond_tensor.size(window.dim) - guide_count
+            video_mask = cond_tensor.narrow(window.dim, 0, T_video)
+            guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count)
+            sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list)
+            sliced_guide = window.get_tensor(guide_mask, device)
+            return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim))
+
+        if cond_key == "keyframe_idxs" and guide_count > 0:
+            # Recompute coords for window_len frames so guide tokens are co-located
+            # with noise tokens in RoPE space (identical to a standalone short video)
+            window_len = len(window.index_list)
+            H, W = window.guide_spatial
+            patchifier = self.diffusion_model.patchifier
+            latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device)
+            from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords
+            pixel_coords = latent_to_pixel_coords(
+                latent_coords,
+                self.diffusion_model.vae_scale_factors,
+                causal_fix=self.diffusion_model.causal_temporal_positioning)
+            B = cond_value.cond.shape[0]
+            if B > 1:
+                pixel_coords = pixel_coords.expand(B, -1, -1, -1)
+            return cond_value._copy_with(pixel_coords)
+
+        if cond_key == "guide_attention_entries" and guide_count > 0:
+            # Adjust token counts for window size
+            window_len = len(window.index_list)
+            H, W = window.guide_spatial
+            new_entries = [{**e, "pre_filter_count": window_len * H * W,
+                            "latent_shape": [window_len, H, W]} for e in cond_value.cond]
+            return cond_value._copy_with(new_entries)
+
+        return None
+
 class LTXAV(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.av_model.LTXAVModel) #TODO
 
     def extra_conds(self, **kwargs):
         out = super().extra_conds(**kwargs)
+        logging.info(f"LTXAV.extra_conds: guide_attention_entries={'guide_attention_entries' in kwargs}, keyframe_idxs={'keyframe_idxs' in kwargs}")
 
         attention_mask = kwargs.get("attention_mask", None)
         device = kwargs["device"]
@@ -1106,13 +1163,65 @@ class LTXAV(BaseModel):
         result.append(audio_indices)
         return result
 
+    def get_guide_frame_count(self, x, conds):
+        for cond_list in conds:
+            if cond_list is None:
+                continue
+            for cond_dict in cond_list:
+                model_conds = cond_dict.get('model_conds', {})
+                gae = model_conds.get('guide_attention_entries')
+                logging.info(f"LTXAV.get_guide_frame_count: keys={list(model_conds.keys())}, gae={gae is not None}")
+                if gae is not None and hasattr(gae, 'cond') and gae.cond:
+                    count = sum(e["latent_shape"][0] for e in gae.cond)
+                    logging.info(f"LTXAV.get_guide_frame_count: found {count} guide frames")
+                    return count
+        logging.info("LTXAV.get_guide_frame_count: no guide frames found")
+        return 0
+
     def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
+        # Audio-specific handling
         if cond_key == "audio_denoise_mask" and hasattr(window, 'modality_windows') and window.modality_windows:
             audio_window = window.modality_windows.get(1)
-            if audio_window is not None:
-                import comfy.context_windows
-                return comfy.context_windows.slice_cond(
-                    cond_value, audio_window, x_in, device, temporal_dim=2)
+            if audio_window is not None and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
+                sliced = audio_window.get_tensor(cond_value.cond, device, dim=2)
+                return cond_value._copy_with(sliced)
+
+        # Guide handling (same as LTXV — shared guide mechanism)
+        guide_count = getattr(window, 'guide_count', 0)
+        if cond_key in ("keyframe_idxs", "guide_attention_entries", "denoise_mask"):
+            logging.info(f"LTXAV resize_cond: {cond_key}, guide_count={guide_count}, has_spatial={hasattr(window, 'guide_spatial')}")
+
+        if cond_key == "denoise_mask" and guide_count > 0:
+            cond_tensor = cond_value.cond
+            T_video = cond_tensor.size(window.dim) - guide_count
+            video_mask = cond_tensor.narrow(window.dim, 0, T_video)
+            guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count)
+            sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list)
+            sliced_guide = window.get_tensor(guide_mask, device)
+            return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim))
+
+        if cond_key == "keyframe_idxs" and guide_count > 0:
+            window_len = len(window.index_list)
+            H, W = window.guide_spatial
+            patchifier = self.diffusion_model.patchifier
+            latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device)
+            from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords
+            pixel_coords = latent_to_pixel_coords(
+                latent_coords,
+                self.diffusion_model.vae_scale_factors,
+                causal_fix=self.diffusion_model.causal_temporal_positioning)
+            B = cond_value.cond.shape[0]
+            if B > 1:
+                pixel_coords = pixel_coords.expand(B, -1, -1, -1)
+            return cond_value._copy_with(pixel_coords)
+
+        if cond_key == "guide_attention_entries" and guide_count > 0:
+            window_len = len(window.index_list)
+            H, W = window.guide_spatial
+            new_entries = [{**e, "pre_filter_count": window_len * H * W,
+                            "latent_shape": [window_len, H, W]} for e in cond_value.cond]
+            return cond_value._copy_with(new_entries)
+
         return None
 
 class HunyuanVideo(BaseModel):

From 941d50e77733ba160086478b614b6e5e0fec7ab7 Mon Sep 17 00:00:00 2001
From: ozbayb <17261091+ozbayb@users.noreply.github.com>
Date: Tue, 24 Mar 2026 11:43:42 -0600
Subject: [PATCH 03/23] LTX2 context windows part 2b - Calculate guide
 parameters in model code, refactor

---
 comfy/context_windows.py |  11 +---
 comfy/model_base.py      | 124 ++++++++++++++------------------------
 2 files changed, 48 insertions(+), 87 deletions(-)

diff --git a/comfy/context_windows.py b/comfy/context_windows.py
index 91d019a00..357fbae17 100644
--- a/comfy/context_windows.py
+++ b/comfy/context_windows.py
@@ -358,11 +358,9 @@ class IndexListContextHandler(ContextHandlerABC):
         for window_idx, window in enumerated_context_windows:
             comfy.model_management.throw_exception_if_processing_interrupted()
-
-            # Attach guide info to window for resize_cond_for_context_window
-            window.guide_count = guide_count
-            if guide_suffix is not None:
-                window.guide_spatial = (guide_suffix.shape[3], guide_suffix.shape[4])
+            logging.info(f"Context window {window_idx + 1}/{total_windows}: frames {window.index_list[0]}-{window.index_list[-1]} of {video_primary.shape[self.dim]}"
+                         + (f" (+{guide_count} guide)" if guide_count > 0 else "")
+                         + (f" [{len(modalities)} modalities]" if is_multimodal else ""))
 
             # Per-modality window indices
             if is_multimodal:
@@ -384,9 +382,6 @@ class IndexListContextHandler(ContextHandlerABC):
                 window = IndexListContextWindow(
                     window.index_list, dim=self.dim, total_frames=video_primary.shape[self.dim],
                     modality_windows=modality_windows)
-                window.guide_count = guide_count
-                if guide_suffix is not None:
-                    window.guide_spatial = (guide_suffix.shape[3], guide_suffix.shape[4])
             else:
                 per_mod_indices = [window.index_list]

diff --git a/comfy/model_base.py b/comfy/model_base.py
index 12d49305f..9c31e2651 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -303,6 +303,45 @@ class BaseModel(torch.nn.Module):
         Override in subclasses that concatenate guide reference frames to the latent."""
         return 0
 
+    def _resize_guide_cond(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
+        """Resize guide-related conditioning for context windows.
+        Derives guide_count from denoise_mask/x_in size difference.
+        Derives spatial dims from x_in. Requires self.diffusion_model.patchifier."""
+        if cond_key == "denoise_mask" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
+            cond_tensor = cond_value.cond
+            guide_count = cond_tensor.size(window.dim) - x_in.size(window.dim)
+            if guide_count > 0:
+                T_video = x_in.size(window.dim)
+                video_mask = cond_tensor.narrow(window.dim, 0, T_video)
+                guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count)
+                sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list)
+                sliced_guide = window.get_tensor(guide_mask, device)
+                return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim))
+
+        if cond_key == "keyframe_idxs":
+            window_len = len(window.index_list)
+            H, W = x_in.shape[3], x_in.shape[4]
+            patchifier = self.diffusion_model.patchifier
+            latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device)
+            from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords
+            pixel_coords = latent_to_pixel_coords(
+                latent_coords,
+                self.diffusion_model.vae_scale_factors,
+                causal_fix=self.diffusion_model.causal_temporal_positioning)
+            B = cond_value.cond.shape[0]
+            if B > 1:
+                pixel_coords = pixel_coords.expand(B, -1, -1, -1)
+            return cond_value._copy_with(pixel_coords)
+
+        if cond_key == "guide_attention_entries":
+            window_len = len(window.index_list)
+            H, W = x_in.shape[3], x_in.shape[4]
+            new_entries = [{**e, "pre_filter_count": window_len * H * W,
+                            "latent_shape": [window_len, H, W]} for e in cond_value.cond]
+            return cond_value._copy_with(new_entries)
+
+        return None
+
     def extra_conds(self, **kwargs):
         out = {}
         concat_cond = self.concat_cond(**kwargs)
@@ -1038,44 +1077,7 @@ class LTXV(BaseModel):
         return 0
 
     def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
-        guide_count = getattr(window, 'guide_count', 0)
-
-        if cond_key == "denoise_mask" and guide_count > 0:
-            # Slice both video and guide halves with same window indices
-            cond_tensor = cond_value.cond
-            T_video = cond_tensor.size(window.dim) - guide_count
-            video_mask = cond_tensor.narrow(window.dim, 0, T_video)
-            guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count)
-            sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list)
-            sliced_guide = window.get_tensor(guide_mask, device)
-            return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim))
-
-        if cond_key == "keyframe_idxs" and guide_count > 0:
-            # Recompute coords for window_len frames so guide tokens are co-located
-            # with noise tokens in RoPE space (identical to a standalone short video)
-            window_len = len(window.index_list)
-            H, W = window.guide_spatial
-            patchifier = self.diffusion_model.patchifier
-            latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device)
-            from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords
-            pixel_coords = latent_to_pixel_coords(
-                latent_coords,
-                self.diffusion_model.vae_scale_factors,
-                causal_fix=self.diffusion_model.causal_temporal_positioning)
-            B = cond_value.cond.shape[0]
-            if B > 1:
-                pixel_coords = pixel_coords.expand(B, -1, -1, -1)
-            return cond_value._copy_with(pixel_coords)
-
-        if cond_key == "guide_attention_entries" and guide_count > 0:
-            # Adjust token counts for window size
-            window_len = len(window.index_list)
-            H, W = window.guide_spatial
-            new_entries = [{**e, "pre_filter_count": window_len * H * W,
-                            "latent_shape": [window_len, H, W]} for e in cond_value.cond]
-            return cond_value._copy_with(new_entries)
-
-        return None
+        return self._resize_guide_cond(cond_key, cond_value, window, x_in, device, retain_index_list)
 
 class LTXAV(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
@@ -1083,7 +1085,6 @@ class LTXAV(BaseModel):
 
     def extra_conds(self, **kwargs):
         out = super().extra_conds(**kwargs)
-        logging.info(f"LTXAV.extra_conds: guide_attention_entries={'guide_attention_entries' in kwargs}, keyframe_idxs={'keyframe_idxs' in kwargs}")
 
         attention_mask = kwargs.get("attention_mask", None)
         device = kwargs["device"]
@@ -1170,12 +1171,8 @@ class LTXAV(BaseModel):
             for cond_dict in cond_list:
                 model_conds = cond_dict.get('model_conds', {})
                 gae = model_conds.get('guide_attention_entries')
-                logging.info(f"LTXAV.get_guide_frame_count: keys={list(model_conds.keys())}, gae={gae is not None}")
                 if gae is not None and hasattr(gae, 'cond') and gae.cond:
-                    count = sum(e["latent_shape"][0] for e in gae.cond)
-                    logging.info(f"LTXAV.get_guide_frame_count: found {count} guide frames")
-                    return count
+                    return sum(e["latent_shape"][0] for e in gae.cond)
-        logging.info("LTXAV.get_guide_frame_count: no guide frames found")
         return 0
@@ -1186,41 +1183,10 @@ class LTXAV(BaseModel):
                 sliced = audio_window.get_tensor(cond_value.cond, device, dim=2)
                 return cond_value._copy_with(sliced)
 
-        # Guide handling (same as LTXV — shared guide mechanism)
-        guide_count = getattr(window, 'guide_count', 0)
-        if cond_key in ("keyframe_idxs", "guide_attention_entries", "denoise_mask"):
-            logging.info(f"LTXAV resize_cond: {cond_key}, guide_count={guide_count}, has_spatial={hasattr(window, 'guide_spatial')}")
-
-        if cond_key == "denoise_mask" and guide_count > 0:
-            cond_tensor = cond_value.cond
-            T_video = cond_tensor.size(window.dim) - guide_count
-            video_mask = cond_tensor.narrow(window.dim, 0, T_video)
-            guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count)
-            sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list)
-            sliced_guide = window.get_tensor(guide_mask, device)
-            return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim))
-
-        if cond_key == "keyframe_idxs" and guide_count > 0:
-            window_len = len(window.index_list)
-            H, W = window.guide_spatial
-            patchifier = self.diffusion_model.patchifier
-            latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device)
-            from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords
-            pixel_coords = latent_to_pixel_coords(
-                latent_coords,
-                self.diffusion_model.vae_scale_factors,
-                causal_fix=self.diffusion_model.causal_temporal_positioning)
-            B = cond_value.cond.shape[0]
-            if B > 1:
-                pixel_coords = pixel_coords.expand(B, -1, -1, -1)
-            return cond_value._copy_with(pixel_coords)
-
-        if cond_key == "guide_attention_entries" and guide_count > 0:
-            window_len = len(window.index_list)
-            H, W = window.guide_spatial
-            new_entries = [{**e, "pre_filter_count": window_len * H * W,
-                            "latent_shape": [window_len, H, W]} for e in cond_value.cond]
-            return cond_value._copy_with(new_entries)
+        # Guide handling (shared with LTXV)
+        result = self._resize_guide_cond(cond_key, cond_value, window, x_in, device, retain_index_list)
+        if result is not None:
+            return result
 
         return None

From 115dbb69d18f29b1112cc7aec8a51ecf55f3e7ae Mon Sep 17 00:00:00 2001
From: ozbayb <17261091+ozbayb@users.noreply.github.com>
Date: Fri, 27 Mar 2026 11:23:43 -0600
Subject: [PATCH 04/23] LTX2 context windows part 3 - Generalize guide
 splitting to windows

---
 comfy/context_windows.py      | 84 ++++++++++++++++++++++++++++--
 comfy/ldm/lightricks/model.py |  2 +-
 comfy/model_base.py           | 90 ++++++++++++++++++++++----------
 comfy_extras/nodes_lt.py      |  4 +-
 4 files changed, 152 insertions(+), 28 deletions(-)

diff --git a/comfy/context_windows.py b/comfy/context_windows.py
index 357fbae17..4ace5ec13 100644
--- a/comfy/context_windows.py
+++ b/comfy/context_windows.py
@@ -140,6 +140,48 @@ def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, d
     return cond_value._copy_with(sliced)
 
+def _compute_guide_overlap(guide_entries, window_index_list):
+    """Compute which guide frames overlap with a context window.
+
+    Args:
+        guide_entries: list of guide_attention_entry dicts (must have 'latent_start' and 'latent_shape')
+        window_index_list: the window's frame indices into the video portion
+
+    Returns None if any entry lacks 'latent_start' (backward compat → legacy path).
+    Otherwise returns (suffix_indices, overlap_info, kf_local_positions, total_overlap):
+        suffix_indices: indices into the guide_suffix tensor for frame selection
+        overlap_info: list of (entry_idx, overlap_count) for guide_attention_entries adjustment
+        kf_local_positions: window-local frame positions for keyframe_idxs regeneration
+        total_overlap: total number of overlapping guide frames
+    """
+    window_set = set(window_index_list)
+    window_list = list(window_index_list)
+    suffix_indices = []
+    overlap_info = []
+    kf_local_positions = []
+    suffix_base = 0
+
+    for entry_idx, entry in enumerate(guide_entries):
+        latent_start = entry.get("latent_start", None)
+        if latent_start is None:
+            return None
+        guide_len = entry["latent_shape"][0]
+        entry_overlap = 0
+
+        for local_offset in range(guide_len):
+            video_pos = latent_start + local_offset
+            if video_pos in window_set:
+                suffix_indices.append(suffix_base + local_offset)
+                kf_local_positions.append(window_list.index(video_pos))
+                entry_overlap += 1
+
+        if entry_overlap > 0:
+            overlap_info.append((entry_idx, entry_overlap))
+        suffix_base += guide_len
+
+    return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices)
+
 @dataclass
 class ContextSchedule:
     name: str
@@ -201,6 +243,18 @@ class IndexListContextHandler(ContextHandlerABC):
                 if 'latent_shapes' in model_conds:
                     model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes)
 
+    def _get_guide_entries(self, conds):
+        """Extract guide_attention_entries list from conditioning. Returns None if absent."""
+        for cond_list in conds:
+            if cond_list is None:
+                continue
+            for cond_dict in cond_list:
+                model_conds = cond_dict.get('model_conds', {})
+                gae = model_conds.get('guide_attention_entries')
+                if gae is not None and hasattr(gae, 'cond') and gae.cond:
+                    return gae.cond
+        return None
+
     def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
         latent_shapes = self._get_latent_shapes(conds)
         primary = self._decompose(x_in, latent_shapes)[0]
@@ -353,6 +407,8 @@ class IndexListContextHandler(ContextHandlerABC):
             counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_modalities]
         biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in accum_modalities]
 
+        guide_entries = self._get_guide_entries(conds) if guide_count > 0 else None
+
         for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
             callback(self, model, x_in, conds, timestep, model_options)
@@ -391,10 +447,30 @@ class IndexListContextHandler(ContextHandlerABC):
                 for mod_idx in range(1, len(modalities)):
                     mod_windows.append(modality_windows[mod_idx])
 
-            # Slice video and guide with same window indices, concatenate
+            # Slice video, then select overlapping guide frames
             sliced_video = mod_windows[0].get_tensor(video_primary)
-            if guide_suffix is not None:
-                sliced_guide = mod_windows[0].get_tensor(guide_suffix)
+            num_guide_in_window = 0
+            if guide_suffix is not None and guide_entries is not None:
+                overlap = _compute_guide_overlap(guide_entries, window.index_list)
+                if overlap is None:
+                    # Legacy: no latent_start → equal-size assumption
+                    sliced_guide = mod_windows[0].get_tensor(guide_suffix)
+                    num_guide_in_window = sliced_guide.shape[self.dim]
+                elif overlap[3] > 0:
+                    suffix_idx, overlap_info, kf_local_pos, num_guide_in_window = overlap
+                    idx = tuple([slice(None)] * self.dim + [suffix_idx])
+                    sliced_guide = guide_suffix[idx]
+                    window.guide_suffix_indices = suffix_idx
+                    window.guide_overlap_info = overlap_info
+                    window.guide_kf_local_positions = kf_local_pos
+                else:
+                    sliced_guide = None
+                    window.guide_overlap_info = []
+                    window.guide_kf_local_positions = []
+            else:
+                sliced_guide = None
+
+            if sliced_guide is not None:
                 sliced_primary = torch.cat([sliced_video, sliced_guide], dim=self.dim)
             else:
                 sliced_primary = sliced_video
             sliced = [sliced_primary] + [mod_windows[mi].get_tensor(modalities[mi]) for mi in range(1, len(modalities))]
@@ -421,7 +497,7 @@ class IndexListContextHandler(ContextHandlerABC):
             # out_per_mod[cond_idx][mod_idx] = tensor
 
             # Strip guide frames from primary output before accumulation
-            if guide_count > 0:
+            if num_guide_in_window > 0:
                 window_len = len(window.index_list)
                 for ci in range(len(sub_conds_out)):
                     primary_out = out_per_mod[ci][0]
                     out_per_mod[ci][0] = primary_out.narrow(self.dim, 0, window_len)

diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py
index bfbc08357..c55e19ced 100644
--- a/comfy/ldm/lightricks/model.py
+++ b/comfy/ldm/lightricks/model.py
@@ -1028,7 +1028,7 @@ class LTXVModel(LTXBaseModel):
         )
 
         grid_mask = None
-        if keyframe_idxs is not None:
+        if keyframe_idxs is not None and keyframe_idxs.shape[2] > 0:
            additional_args.update({"orig_patchified_shape": list(x.shape)})
            denoise_mask = self.patchifier.patchify(denoise_mask)[0]
            grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0]
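A worked example of the overlap computation introduced above, using hypothetical guide shapes: with two single-frame guides anchored at video latents 0 and 40, a window covering latents 32-56 keeps only the second guide, mapped to window-local position 8.

    entries = [
        {"latent_start": 0,  "latent_shape": [1, 16, 24]},  # first-frame guide
        {"latent_start": 40, "latent_shape": [1, 16, 24]},  # mid-video guide
    ]
    window_indices = list(range(32, 57))
    suffix_idx, overlap_info, kf_local_pos, total = _compute_guide_overlap(entries, window_indices)
    # suffix_idx == [1]         -> take frame 1 of the guide_suffix tensor
    # overlap_info == [(1, 1)]  -> only entry 1 overlaps, with 1 frame
    # kf_local_pos == [8]       -> video latent 40 sits at window position 40 - 32 = 8
    # total == 1

This is why a window that contains no guide anchor runs with an empty guide set rather than attending to every guide, as the equal-size legacy path did.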
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 9c31e2651..ae2ce2eb0 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -305,8 +305,8 @@ class BaseModel(torch.nn.Module):
     def _resize_guide_cond(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
         """Resize guide-related conditioning for context windows.
-        Derives guide_count from denoise_mask/x_in size difference.
-        Derives spatial dims from x_in. Requires self.diffusion_model.patchifier."""
+        Uses overlap info from window if available (generalized path),
+        otherwise falls back to legacy equal-size assumption."""
         if cond_key == "denoise_mask" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
             cond_tensor = cond_value.cond
             guide_count = cond_tensor.size(window.dim) - x_in.size(window.dim)
@@ -315,30 +315,76 @@ class BaseModel(torch.nn.Module):
                 video_mask = cond_tensor.narrow(window.dim, 0, T_video)
                 guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count)
                 sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list)
-                sliced_guide = window.get_tensor(guide_mask, device)
-                return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim))
+                # Use overlap-based guide selection if available, otherwise legacy
+                suffix_indices = getattr(window, 'guide_suffix_indices', None)
+                if suffix_indices is not None:
+                    idx = tuple([slice(None)] * window.dim + [suffix_indices])
+                    sliced_guide = guide_mask[idx].to(device) if suffix_indices else None
+                else:
+                    sliced_guide = window.get_tensor(guide_mask, device)
+                if sliced_guide is not None and sliced_guide.shape[window.dim] > 0:
+                    return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim))
+                else:
+                    return cond_value._copy_with(sliced_video)
 
         if cond_key == "keyframe_idxs":
-            window_len = len(window.index_list)
-            H, W = x_in.shape[3], x_in.shape[4]
-            patchifier = self.diffusion_model.patchifier
-            latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device)
-            from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords
-            pixel_coords = latent_to_pixel_coords(
-                latent_coords,
-                self.diffusion_model.vae_scale_factors,
-                causal_fix=self.diffusion_model.causal_temporal_positioning)
-            B = cond_value.cond.shape[0]
-            if B > 1:
-                pixel_coords = pixel_coords.expand(B, -1, -1, -1)
-            return cond_value._copy_with(pixel_coords)
+            kf_local_pos = getattr(window, 'guide_kf_local_positions', None)
+            if kf_local_pos is not None:
+                # Generalized: regenerate coords for full window, select guide positions
+                if not kf_local_pos:
+                    return cond_value._copy_with(cond_value.cond[:, :, :0, :])  # empty
+                H, W = x_in.shape[3], x_in.shape[4]
+                window_len = len(window.index_list)
+                patchifier = self.diffusion_model.patchifier
+                latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device)
+                from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords
+                pixel_coords = latent_to_pixel_coords(
+                    latent_coords,
+                    self.diffusion_model.vae_scale_factors,
+                    causal_fix=self.diffusion_model.causal_temporal_positioning)
+                tokens = []
+                for pos in kf_local_pos:
+                    tokens.extend(range(pos * H * W, (pos + 1) * H * W))
+                pixel_coords = pixel_coords[:, :, tokens, :]
+                B = cond_value.cond.shape[0]
+                if B > 1:
+                    pixel_coords = pixel_coords.expand(B, -1, -1, -1)
+                return cond_value._copy_with(pixel_coords)
+            else:
+                # Legacy: regenerate for window_len (equal-size assumption)
+                window_len = len(window.index_list)
+                H, W = x_in.shape[3], x_in.shape[4]
+                patchifier = self.diffusion_model.patchifier
+                latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device)
+                from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords
+                pixel_coords = latent_to_pixel_coords(
+                    latent_coords,
+                    self.diffusion_model.vae_scale_factors,
+                    causal_fix=self.diffusion_model.causal_temporal_positioning)
+                B = cond_value.cond.shape[0]
+                if B > 1:
+                    pixel_coords = pixel_coords.expand(B, -1, -1, -1)
+                return cond_value._copy_with(pixel_coords)
 
         if cond_key == "guide_attention_entries":
-            window_len = len(window.index_list)
-            H, W = x_in.shape[3], x_in.shape[4]
-            new_entries = [{**e, "pre_filter_count": window_len * H * W,
-                            "latent_shape": [window_len, H, W]} for e in cond_value.cond]
-            return cond_value._copy_with(new_entries)
+            overlap_info = getattr(window, 'guide_overlap_info', None)
+            if overlap_info is not None:
+                # Generalized: per-guide adjustment based on overlap
+                H, W = x_in.shape[3], x_in.shape[4]
+                new_entries = []
+                for entry_idx, overlap_count in overlap_info:
+                    e = cond_value.cond[entry_idx]
+                    new_entries.append({**e,
+                                        "pre_filter_count": overlap_count * H * W,
+                                        "latent_shape": [overlap_count, H, W]})
+                return cond_value._copy_with(new_entries)
+            else:
+                # Legacy: all entries adjusted to window_len
+                window_len = len(window.index_list)
+                H, W = x_in.shape[3], x_in.shape[4]
+                new_entries = [{**e, "pre_filter_count": window_len * H * W,
+                                "latent_shape": [window_len, H, W]} for e in cond_value.cond]
+                return cond_value._copy_with(new_entries)
 
         return None

diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py
index d7c2e8744..d8ba0bb27 100644
--- a/comfy_extras/nodes_lt.py
+++ b/comfy_extras/nodes_lt.py
@@ -135,7 +135,7 @@ class LTXVImgToVideoInplace(io.ComfyNode):
     generate = execute  # TODO: remove
 
-def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0):
+def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0, latent_start=0):
     """Append a guide_attention_entry to both positive and negative conditioning.
 
     Each entry tracks one guide reference for per-reference attention control.
@@ -146,6 +146,7 @@
         "strength": strength,
         "pixel_mask": None,
         "latent_shape": latent_shape,
+        "latent_start": latent_start,
     }
     results = []
     for cond in (positive, negative):
@@ -362,6 +363,7 @@ class LTXVAddGuide(io.ComfyNode):
         guide_latent_shape = list(t.shape[2:])  # [F, H, W]
         positive, negative = _append_guide_attention_entry(
             positive, negative, pre_filter_count, guide_latent_shape, strength=strength,
+            latent_start=latent_idx,
         )
 
         return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})

From ef61ddfaed4d1f9e983dc0ce4f1afdd17d8f519a Mon Sep 17 00:00:00 2001
From: ozbayb <17261091+ozbayb@users.noreply.github.com>
Date: Fri, 27 Mar 2026 11:24:40 -0600
Subject: [PATCH 05/23] Fix FreeNoise application for LTXAV context windows;
 fix audio mapping to context windows

---
 comfy/context_windows.py | 23 ++++++++++++++++++++++-
 comfy/model_base.py      | 19 ++++++++++---------
 2 files changed, 32 insertions(+), 10 deletions(-)

diff --git a/comfy/context_windows.py b/comfy/context_windows.py
index 4ace5ec13..b528f6327 100644
--- a/comfy/context_windows.py
+++ b/comfy/context_windows.py
@@ -619,7 +619,28 @@ def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, noise, *args, **kwargs):
         raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
     if not handler.freenoise:
         return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
-    noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
+
+    # For packed multimodal tensors (e.g. LTXAV), noise is [B, 1, flat] and FreeNoise
+    # must only shuffle the video portion. Unpack, apply to video, repack.
+    latent_shapes = None
+    try:
+        latent_shapes = guider.conds['positive'][0]['model_conds']['latent_shapes'].cond
+    except (KeyError, IndexError, AttributeError):
+        pass
+
+    if latent_shapes is not None and len(latent_shapes) > 1:
+        modalities = comfy.utils.unpack_latents(noise, latent_shapes)
+        video_total = latent_shapes[0][handler.dim]
+        modalities[0] = apply_freenoise(modalities[0], handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
+        for i in range(1, len(modalities)):
+            mod_total = latent_shapes[i][handler.dim]
+            ratio = mod_total / video_total if video_total > 0 else 1
+            mod_ctx_len = max(round(handler.context_length * ratio), 1)
+            mod_ctx_overlap = max(round(handler.context_overlap * ratio), 0)
+            modalities[i] = apply_freenoise(modalities[i], handler.dim, mod_ctx_len, mod_ctx_overlap, extra_args["seed"])
+        noise, _ = comfy.utils.pack_latents(modalities)
+    else:
+        noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
     return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)

diff --git a/comfy/model_base.py b/comfy/model_base.py
index ae2ce2eb0..893beb85a 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1197,17 +1197,18 @@ class LTXAV(BaseModel):
             return result
 
         video_total = latent_shapes[0][dim]
-        audio_total = latent_shapes[1][dim]
+        video_window_len = len(primary_indices)
 
-        # Proportional mapping — video and audio cover same real-time duration
-        v_start, v_end = min(primary_indices), max(primary_indices) + 1
-        a_start = round(v_start * audio_total / video_total)
-        a_end = round(v_end * audio_total / video_total)
-        audio_indices = list(range(a_start, min(a_end, audio_total)))
-        if not audio_indices:
-            audio_indices = [min(a_start, audio_total - 1)]
+        for i in range(1, len(latent_shapes)):
+            mod_total = latent_shapes[i][dim]
+            # Length proportional to video window frame count (not index span)
+            mod_window_len = max(round(video_window_len * mod_total / video_total), 1)
+            # Anchor to end of video range
+            v_end = max(primary_indices) + 1
+            mod_end = min(round(v_end * mod_total / video_total), mod_total)
+            mod_start = max(mod_end - mod_window_len, 0)
+            result.append(list(range(mod_start, min(mod_start + mod_window_len, mod_total))))
 
-        result.append(audio_indices)
         return result
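To make the reworked proportional mapping concrete, with hypothetical sizes (121 video latents, 48 audio latents, a 24-frame window over indices 32..55):

    video_total, mod_total = 121, 48
    primary_indices = list(range(32, 56))
    mod_window_len = max(round(len(primary_indices) * mod_total / video_total), 1)  # round(24*48/121) = 10
    v_end = max(primary_indices) + 1                                                # 56
    mod_end = min(round(v_end * mod_total / video_total), mod_total)                # round(56*48/121) = 22
    mod_start = max(mod_end - mod_window_len, 0)                                    # 12
    audio_indices = list(range(mod_start, min(mod_start + mod_window_len, mod_total)))  # [12, ..., 21]

The FreeNoise wrapper above applies the same 48/121 ratio to scale context_length and context_overlap for the audio modality, so the shuffled noise windows stay time-aligned with the video windows.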
From 9566c18ced6de934a4fa961d5d44bb0c21217e4c Mon Sep 17 00:00:00 2001
From: ozbayb <17261091+ozbayb@users.noreply.github.com>
Date: Sun, 5 Apr 2026 14:30:54 -0600
Subject: [PATCH 06/23] LTX2 context windows - Fix crash when a window doesn't
 have a guide index

---
 comfy/context_windows.py      | 1 +
 comfy/ldm/lightricks/model.py | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/comfy/context_windows.py b/comfy/context_windows.py
index b528f6327..10987e73e 100644
--- a/comfy/context_windows.py
+++ b/comfy/context_windows.py
@@ -465,6 +465,7 @@ class IndexListContextHandler(ContextHandlerABC):
                 else:
                     sliced_guide = None
+                    window.guide_suffix_indices = []
                     window.guide_overlap_info = []
                    window.guide_kf_local_positions = []
             else:
                 sliced_guide = None

diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py
index c55e19ced..74395beae 100644
--- a/comfy/ldm/lightricks/model.py
+++ b/comfy/ldm/lightricks/model.py
@@ -1315,7 +1315,7 @@ class LTXVModel(LTXBaseModel):
             x = x * (1 + scale) + shift
 
         x = self.proj_out(x)
 
-        if keyframe_idxs is not None:
+        if keyframe_idxs is not None and keyframe_idxs.shape[2] > 0:
             grid_mask = kwargs["grid_mask"]
             orig_patchified_shape = kwargs["orig_patchified_shape"]
             full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device)

From 71712472f5f19b5b42970c9c094386bfe377274e Mon Sep 17 00:00:00 2001
From: ozbayb <17261091+ozbayb@users.noreply.github.com>
Date: Sun, 5 Apr 2026 14:32:06 -0600
Subject: [PATCH 07/23] LTX2 context windows - Ensure that in-place latent
 images are retained properly with the retain index list

---
 comfy/context_windows.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/comfy/context_windows.py b/comfy/context_windows.py
index 10987e73e..a4c49bbee 100644
--- a/comfy/context_windows.py
+++ b/comfy/context_windows.py
@@ -448,7 +448,7 @@ class IndexListContextHandler(ContextHandlerABC):
 
             # Slice video, then select overlapping guide frames
-            sliced_video = mod_windows[0].get_tensor(video_primary)
+            sliced_video = mod_windows[0].get_tensor(video_primary, retain_index_list=self.cond_retain_index_list)
             num_guide_in_window = 0
             if guide_suffix is not None and guide_entries is not None:
                 overlap = _compute_guide_overlap(guide_entries, window.index_list)

From 3660533f83941fb8a1f3779929598f63e398ac22 Mon Sep 17 00:00:00 2001
From: ozbayb <17261091+ozbayb@users.noreply.github.com>
Date: Mon, 6 Apr 2026 08:48:51 -0600
Subject: [PATCH 08/23] LTX2 context windows - Cleanup: latent_start value is
 required for context windows with guides

---
 comfy/context_windows.py | 11 ++---
 comfy/model_base.py      | 98 ++++++++++++++--------------------------
 2 files changed, 37 insertions(+), 72 deletions(-)

diff --git a/comfy/context_windows.py b/comfy/context_windows.py
index a4c49bbee..fe1afdffe 100644
--- a/comfy/context_windows.py
+++ b/comfy/context_windows.py
@@ -147,8 +147,7 @@ def _compute_guide_overlap(guide_entries, window_index_list):
         guide_entries: list of guide_attention_entry dicts (must have 'latent_start' and 'latent_shape')
         window_index_list: the window's frame indices into the video portion
 
-    Returns None if any entry lacks 'latent_start' (backward compat → legacy path).
-    Otherwise returns (suffix_indices, overlap_info, kf_local_positions, total_overlap):
+    Returns (suffix_indices, overlap_info, kf_local_positions, total_overlap):
         suffix_indices: indices into the guide_suffix tensor for frame selection
         overlap_info: list of (entry_idx, overlap_count) for guide_attention_entries adjustment
         kf_local_positions: window-local frame positions for keyframe_idxs regeneration
         total_overlap: total number of overlapping guide frames
@@ -164,7 +163,7 @@ def _compute_guide_overlap(guide_entries, window_index_list):
     for entry_idx, entry in enumerate(guide_entries):
         latent_start = entry.get("latent_start", None)
         if latent_start is None:
-            return None
+            raise ValueError("guide_attention_entry missing required 'latent_start'.")
         guide_len = entry["latent_shape"][0]
         entry_overlap = 0
@@ -452,11 +451,7 @@ class IndexListContextHandler(ContextHandlerABC):
             num_guide_in_window = 0
             if guide_suffix is not None and guide_entries is not None:
                 overlap = _compute_guide_overlap(guide_entries, window.index_list)
-                if overlap is None:
-                    # Legacy: no latent_start → equal-size assumption
-                    sliced_guide = mod_windows[0].get_tensor(guide_suffix)
-                    num_guide_in_window = sliced_guide.shape[self.dim]
-                elif overlap[3] > 0:
+                if overlap[3] > 0:
                     suffix_idx, overlap_info, kf_local_pos, num_guide_in_window = overlap
                     idx = tuple([slice(None)] * self.dim + [suffix_idx])
                     sliced_guide = guide_suffix[idx]
latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device) - from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords - pixel_coords = latent_to_pixel_coords( - latent_coords, - self.diffusion_model.vae_scale_factors, - causal_fix=self.diffusion_model.causal_temporal_positioning) - tokens = [] - for pos in kf_local_pos: - tokens.extend(range(pos * H * W, (pos + 1) * H * W)) - pixel_coords = pixel_coords[:, :, tokens, :] - B = cond_value.cond.shape[0] - if B > 1: - pixel_coords = pixel_coords.expand(B, -1, -1, -1) - return cond_value._copy_with(pixel_coords) - else: - # Legacy: regenerate for window_len (equal-size assumption) - window_len = len(window.index_list) - H, W = x_in.shape[3], x_in.shape[4] - patchifier = self.diffusion_model.patchifier - latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device) - from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords - pixel_coords = latent_to_pixel_coords( - latent_coords, - self.diffusion_model.vae_scale_factors, - causal_fix=self.diffusion_model.causal_temporal_positioning) - B = cond_value.cond.shape[0] - if B > 1: - pixel_coords = pixel_coords.expand(B, -1, -1, -1) - return cond_value._copy_with(pixel_coords) + kf_local_pos = window.guide_kf_local_positions + if not kf_local_pos: + return cond_value._copy_with(cond_value.cond[:, :, :0, :]) # empty + H, W = x_in.shape[3], x_in.shape[4] + window_len = len(window.index_list) + patchifier = self.diffusion_model.patchifier + latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device) + from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords + pixel_coords = latent_to_pixel_coords( + latent_coords, + self.diffusion_model.vae_scale_factors, + causal_fix=self.diffusion_model.causal_temporal_positioning) + tokens = [] + for pos in kf_local_pos: + tokens.extend(range(pos * H * W, (pos + 1) * H * W)) + pixel_coords = pixel_coords[:, :, tokens, :] + B = cond_value.cond.shape[0] + if B > 1: + pixel_coords = pixel_coords.expand(B, -1, -1, -1) + return cond_value._copy_with(pixel_coords) if cond_key == "guide_attention_entries": - overlap_info = getattr(window, 'guide_overlap_info', None) - if overlap_info is not None: - # Generalized: per-guide adjustment based on overlap - H, W = x_in.shape[3], x_in.shape[4] - new_entries = [] - for entry_idx, overlap_count in overlap_info: - e = cond_value.cond[entry_idx] - new_entries.append({**e, - "pre_filter_count": overlap_count * H * W, - "latent_shape": [overlap_count, H, W]}) - return cond_value._copy_with(new_entries) - else: - # Legacy: all entries adjusted to window_len - window_len = len(window.index_list) - H, W = x_in.shape[3], x_in.shape[4] - new_entries = [{**e, "pre_filter_count": window_len * H * W, - "latent_shape": [window_len, H, W]} for e in cond_value.cond] - return cond_value._copy_with(new_entries) + overlap_info = window.guide_overlap_info + H, W = x_in.shape[3], x_in.shape[4] + new_entries = [] + for entry_idx, overlap_count in overlap_info: + e = cond_value.cond[entry_idx] + new_entries.append({**e, + "pre_filter_count": overlap_count * H * W, + "latent_shape": [overlap_count, H, W]}) + return cond_value._copy_with(new_entries) return None From 350237618d1035fd016f845390ecd1784d90c48b Mon Sep 17 00:00:00 2001 From: ozbayb <17261091+ozbayb@users.noreply.github.com> Date: Mon, 6 Apr 2026 09:11:35 -0600 Subject: [PATCH 09/23] LTX2 context windows - Cleanup: Remove model specific code from 
BaseModel. Older LTXV model's guides + context_windows will need to be re-implemented, but that work is outside the scope of the LTX2 changes --- comfy/context_windows.py | 6 +- comfy/model_base.py | 136 +++++++++++++++------------------------ 2 files changed, 55 insertions(+), 87 deletions(-) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index fe1afdffe..295f348e6 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -257,7 +257,7 @@ class IndexListContextHandler(ContextHandlerABC): def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool: latent_shapes = self._get_latent_shapes(conds) primary = self._decompose(x_in, latent_shapes)[0] - guide_count = model.get_guide_frame_count(primary, conds) if model is not None else 0 + guide_count = model.get_guide_frame_count(primary, conds) if hasattr(model, 'get_guide_frame_count') else 0 video_frames = primary.size(self.dim) - guide_count if video_frames > self.context_length: if guide_count > 0: @@ -380,7 +380,7 @@ class IndexListContextHandler(ContextHandlerABC): primary = modalities[0] # Separate guide frames from primary modality (guides are appended at the end) - guide_count = model.get_guide_frame_count(primary, conds) if model is not None else 0 + guide_count = model.get_guide_frame_count(primary, conds) if hasattr(model, 'get_guide_frame_count') else 0 if guide_count > 0: video_len = primary.size(self.dim) - guide_count video_primary = primary.narrow(self.dim, 0, video_len) @@ -427,7 +427,7 @@ class IndexListContextHandler(ContextHandlerABC): video_shape[self.dim] = video_shape[self.dim] - guide_count map_shapes[0] = torch.Size(video_shape) per_mod_indices = model.map_context_window_to_modalities( - window.index_list, map_shapes, self.dim) + window.index_list, map_shapes, self.dim) if hasattr(model, 'map_context_window_to_modalities') else [window.index_list] # Build per-modality windows and attach to primary window modality_windows = {} for mod_idx in range(1, len(modalities)): diff --git a/comfy/model_base.py b/comfy/model_base.py index e4659a236..066e3f8d8 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -293,71 +293,6 @@ class BaseModel(torch.nn.Module): Use comfy.context_windows.slice_cond() for common cases.""" return None - def map_context_window_to_modalities(self, primary_indices, latent_shapes, dim): - """Map primary modality's window indices to all modalities. - Returns list of index lists, one per modality.""" - return [primary_indices] - - def get_guide_frame_count(self, x, conds): - """Return the number of trailing guide frames appended to x along the temporal dim. - Override in subclasses that concatenate guide reference frames to the latent.""" - return 0 - - def _resize_guide_cond(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): - """Resize guide-related conditioning for context windows.
- Requires guide_suffix_indices, guide_overlap_info, and guide_kf_local_positions - to be set on the window by _compute_guide_overlap.""" - if cond_key == "denoise_mask" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): - cond_tensor = cond_value.cond - guide_count = cond_tensor.size(window.dim) - x_in.size(window.dim) - if guide_count > 0: - T_video = x_in.size(window.dim) - video_mask = cond_tensor.narrow(window.dim, 0, T_video) - guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count) - sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list) - suffix_indices = window.guide_suffix_indices - if suffix_indices: - idx = tuple([slice(None)] * window.dim + [suffix_indices]) - sliced_guide = guide_mask[idx].to(device) - return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim)) - else: - return cond_value._copy_with(sliced_video) - - if cond_key == "keyframe_idxs": - kf_local_pos = window.guide_kf_local_positions - if not kf_local_pos: - return cond_value._copy_with(cond_value.cond[:, :, :0, :]) # empty - H, W = x_in.shape[3], x_in.shape[4] - window_len = len(window.index_list) - patchifier = self.diffusion_model.patchifier - latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device) - from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords - pixel_coords = latent_to_pixel_coords( - latent_coords, - self.diffusion_model.vae_scale_factors, - causal_fix=self.diffusion_model.causal_temporal_positioning) - tokens = [] - for pos in kf_local_pos: - tokens.extend(range(pos * H * W, (pos + 1) * H * W)) - pixel_coords = pixel_coords[:, :, tokens, :] - B = cond_value.cond.shape[0] - if B > 1: - pixel_coords = pixel_coords.expand(B, -1, -1, -1) - return cond_value._copy_with(pixel_coords) - - if cond_key == "guide_attention_entries": - overlap_info = window.guide_overlap_info - H, W = x_in.shape[3], x_in.shape[4] - new_entries = [] - for entry_idx, overlap_count in overlap_info: - e = cond_value.cond[entry_idx] - new_entries.append({**e, - "pre_filter_count": overlap_count * H * W, - "latent_shape": [overlap_count, H, W]}) - return cond_value._copy_with(new_entries) - - return None - def extra_conds(self, **kwargs): out = {} concat_cond = self.concat_cond(**kwargs) @@ -1081,20 +1016,6 @@ class LTXV(BaseModel): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): return latent_image - def get_guide_frame_count(self, x, conds): - for cond_list in conds: - if cond_list is None: - continue - for cond_dict in cond_list: - model_conds = cond_dict.get('model_conds', {}) - gae = model_conds.get('guide_attention_entries') - if gae is not None and hasattr(gae, 'cond') and gae.cond: - return sum(e["latent_shape"][0] for e in gae.cond) - return 0 - - def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): - return self._resize_guide_cond(cond_key, cond_value, window, x_in, device, retain_index_list) - class LTXAV(BaseModel): def __init__(self, model_config, model_type=ModelType.FLUX, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.av_model.LTXAVModel) #TODO @@ -1193,17 +1114,64 @@ class LTXAV(BaseModel): return 0 def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): - # Audio-specific handling + # Audio denoise mask — slice using audio modality window if cond_key == "audio_denoise_mask" and 
hasattr(window, 'modality_windows') and window.modality_windows: audio_window = window.modality_windows.get(1) if audio_window is not None and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): sliced = audio_window.get_tensor(cond_value.cond, device, dim=2) return cond_value._copy_with(sliced) - # Guide handling (shared with LTXV) - result = self._resize_guide_cond(cond_key, cond_value, window, x_in, device, retain_index_list) - if result is not None: - return result + # Video denoise mask — split into video + guide portions, slice each + if cond_key == "denoise_mask" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): + cond_tensor = cond_value.cond + guide_count = cond_tensor.size(window.dim) - x_in.size(window.dim) + if guide_count > 0: + T_video = x_in.size(window.dim) + video_mask = cond_tensor.narrow(window.dim, 0, T_video) + guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count) + sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list) + suffix_indices = window.guide_suffix_indices + if suffix_indices: + idx = tuple([slice(None)] * window.dim + [suffix_indices]) + sliced_guide = guide_mask[idx].to(device) + return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim)) + else: + return cond_value._copy_with(sliced_video) + + # Keyframe indices — regenerate pixel coords for window, select guide positions + if cond_key == "keyframe_idxs": + kf_local_pos = window.guide_kf_local_positions + if not kf_local_pos: + return cond_value._copy_with(cond_value.cond[:, :, :0, :]) # empty + H, W = x_in.shape[3], x_in.shape[4] + window_len = len(window.index_list) + patchifier = self.diffusion_model.patchifier + latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device) + from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords + pixel_coords = latent_to_pixel_coords( + latent_coords, + self.diffusion_model.vae_scale_factors, + causal_fix=self.diffusion_model.causal_temporal_positioning) + tokens = [] + for pos in kf_local_pos: + tokens.extend(range(pos * H * W, (pos + 1) * H * W)) + pixel_coords = pixel_coords[:, :, tokens, :] + B = cond_value.cond.shape[0] + if B > 1: + pixel_coords = pixel_coords.expand(B, -1, -1, -1) + return cond_value._copy_with(pixel_coords) + + # Guide attention entries — adjust per-guide counts based on window overlap + if cond_key == "guide_attention_entries": + overlap_info = window.guide_overlap_info + H, W = x_in.shape[3], x_in.shape[4] + new_entries = [] + for entry_idx, overlap_count in overlap_info: + e = cond_value.cond[entry_idx] + new_entries.append({**e, + "pre_filter_count": overlap_count * H * W, + "latent_shape": [overlap_count, H, W]}) + return cond_value._copy_with(new_entries) return None From 874690c01ca949b0479ce09bb64c37ff68b0e6e1 Mon Sep 17 00:00:00 2001 From: ozbayb <17261091+ozbayb@users.noreply.github.com> Date: Mon, 6 Apr 2026 11:44:14 -0600 Subject: [PATCH 10/23] LTX2 context windows - Refactor guide logic from context_windows into LTXAV model hooks --- comfy/context_windows.py | 88 ++++++++++++---------------------------- comfy/model_base.py | 51 +++++++++++++++++++++++ 2 files changed, 76 insertions(+), 63 deletions(-) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 295f348e6..9e7282fda 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable +from 
typing import TYPE_CHECKING, Any, Callable import torch import numpy as np import collections @@ -181,6 +181,12 @@ def _compute_guide_overlap(guide_entries, window_index_list): return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices) +@dataclass +class WindowingContext: + tensor: torch.Tensor + suffix: torch.Tensor | None + aux_data: Any + @dataclass class ContextSchedule: name: str @@ -242,18 +248,6 @@ class IndexListContextHandler(ContextHandlerABC): if 'latent_shapes' in model_conds: model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes) - def _get_guide_entries(self, conds): - """Extract guide_attention_entries list from conditioning. Returns None if absent.""" - for cond_list in conds: - if cond_list is None: - continue - for cond_dict in cond_list: - model_conds = cond_dict.get('model_conds', {}) - gae = model_conds.get('guide_attention_entries') - if gae is not None and hasattr(gae, 'cond') and gae.cond: - return gae.cond - return None - def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool: latent_shapes = self._get_latent_shapes(conds) primary = self._decompose(x_in, latent_shapes)[0] @@ -379,24 +373,19 @@ class IndexListContextHandler(ContextHandlerABC): is_multimodal = len(modalities) > 1 primary = modalities[0] - # Separate guide frames from primary modality (guides are appended at the end) - guide_count = model.get_guide_frame_count(primary, conds) if hasattr(model, 'get_guide_frame_count') else 0 - if guide_count > 0: - video_len = primary.size(self.dim) - guide_count - video_primary = primary.narrow(self.dim, 0, video_len) - guide_suffix = primary.narrow(self.dim, video_len, guide_count) - else: - video_primary = primary - guide_suffix = None + # Let model strip auxiliary frames (e.g. 
guide frames) + window_data = model.prepare_for_windowing(primary, conds, self.dim) + video_primary = window_data.tensor + aux_count = window_data.suffix.size(self.dim) if window_data.suffix is not None else 0 - # Windows from video portion only (excluding guide frames) + # Windows from video portion only context_windows = self.get_context_windows(model, video_primary, model_options) enumerated_context_windows = list(enumerate(context_windows)) total_windows = len(enumerated_context_windows) # Accumulators sized to video portion for primary, full for other modalities accum_modalities = list(modalities) - if guide_suffix is not None: + if window_data.suffix is not None: accum_modalities[0] = video_primary accum = [[torch.zeros_like(m) for _ in conds] for m in accum_modalities] @@ -406,25 +395,22 @@ class IndexListContextHandler(ContextHandlerABC): counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_modalities] biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in accum_modalities] - guide_entries = self._get_guide_entries(conds) if guide_count > 0 else None - for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks): callback(self, model, x_in, conds, timestep, model_options) for window_idx, window in enumerated_context_windows: comfy.model_management.throw_exception_if_processing_interrupted() logging.info(f"Context window {window_idx + 1}/{total_windows}: frames {window.index_list[0]}-{window.index_list[-1]} of {video_primary.shape[self.dim]}" - + (f" (+{guide_count} guide)" if guide_count > 0 else "") + + (f" (+{aux_count} aux)" if aux_count > 0 else "") + (f" [{len(modalities)} modalities]" if is_multimodal else "")) # Per-modality window indices if is_multimodal: - # Adjust latent_shapes so video shape reflects video-only frames (excludes guides) map_shapes = latent_shapes - if guide_count > 0: + if video_primary.size(self.dim) != primary.size(self.dim): map_shapes = list(latent_shapes) video_shape = list(latent_shapes[0]) - video_shape[self.dim] = video_shape[self.dim] - guide_count + video_shape[self.dim] = video_primary.size(self.dim) map_shapes[0] = torch.Size(video_shape) per_mod_indices = model.map_context_window_to_modalities( window.index_list, map_shapes, self.dim) if hasattr(model, 'map_context_window_to_modalities') else [window.index_list] @@ -446,30 +432,10 @@ class IndexListContextHandler(ContextHandlerABC): for mod_idx in range(1, len(modalities)): mod_windows.append(modality_windows[mod_idx]) - # Slice video, then select overlapping guide frames + # Slice video, then let model inject auxiliary frames sliced_video = mod_windows[0].get_tensor(video_primary, retain_index_list=self.cond_retain_index_list) - num_guide_in_window = 0 - if guide_suffix is not None and guide_entries is not None: - overlap = _compute_guide_overlap(guide_entries, window.index_list) - if overlap[3] > 0: - suffix_idx, overlap_info, kf_local_pos, num_guide_in_window = overlap - idx = tuple([slice(None)] * self.dim + [suffix_idx]) - sliced_guide = guide_suffix[idx] - window.guide_suffix_indices = suffix_idx - window.guide_overlap_info = overlap_info - window.guide_kf_local_positions = kf_local_pos - else: - sliced_guide = None - window.guide_suffix_indices = [] - window.guide_overlap_info = [] - window.guide_kf_local_positions = [] - else: - sliced_guide = None - - if sliced_guide is not None: - sliced_primary = torch.cat([sliced_video, sliced_guide], dim=self.dim) - else: - sliced_primary = 
sliced_video + sliced_primary, num_aux = model.prepare_window_input( + sliced_video, window, window_data.aux_data, self.dim) sliced = [sliced_primary] + [mod_windows[mi].get_tensor(modalities[mi]) for mi in range(1, len(modalities))] # Compose for pipeline @@ -481,7 +447,6 @@ class IndexListContextHandler(ContextHandlerABC): model_options["transformer_options"]["context_window"] = window sub_timestep = window.get_tensor(timestep, dim=0) - # Resize conds using video_primary as reference (excludes guide frames) sub_conds = [self.get_resized_cond(cond, video_primary, window) for cond in conds] if is_multimodal: self._patch_latent_shapes(sub_conds, sub_shapes) @@ -490,14 +455,12 @@ class IndexListContextHandler(ContextHandlerABC): # Decompose output per modality out_per_mod = [self._decompose(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))] - # out_per_mod[cond_idx][mod_idx] = tensor - # Strip guide frames from primary output before accumulation - if num_guide_in_window > 0: + # Strip auxiliary frames from primary output before accumulation + if num_aux > 0: window_len = len(window.index_list) for ci in range(len(sub_conds_out)): - primary_out = out_per_mod[ci][0] - out_per_mod[ci][0] = primary_out.narrow(self.dim, 0, window_len) + out_per_mod[ci][0] = out_per_mod[ci][0].narrow(self.dim, 0, window_len) # Accumulate per modality (using video-only sizes) for mod_idx in range(len(accum_modalities)): @@ -516,10 +479,9 @@ class IndexListContextHandler(ContextHandlerABC): if self.fuse_method.name != ContextFuseMethods.RELATIVE: accum[mod_idx][ci] /= counts[mod_idx][ci] f = accum[mod_idx][ci] - # Re-append original guide_suffix (not model output — sampling loop - # respects denoise_mask and never modifies guide frame positions) - if mod_idx == 0 and guide_suffix is not None: - f = torch.cat([f, guide_suffix], dim=self.dim) + # Re-append model's suffix (auxiliary frames stripped before windowing) + if mod_idx == 0 and window_data.suffix is not None: + f = torch.cat([f, window_data.suffix], dim=self.dim) finalized.append(f) composed, _ = self._compose(finalized) result.append(composed) diff --git a/comfy/model_base.py b/comfy/model_base.py index 066e3f8d8..c1e2ef8c2 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -287,6 +287,12 @@ class BaseModel(torch.nn.Module): return data return None + def prepare_for_windowing(self, primary, conds, dim): + return comfy.context_windows.WindowingContext(tensor=primary, suffix=None, aux_data=None) + + def prepare_window_input(self, video_slice, window, aux_data, dim): + return video_slice, 0 + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): """Override in subclasses to handle model-specific cond slicing for context windows. Return a sliced cond object, or None to fall through to default handling. 
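The hook contract stated in the docstring above (return a sliced cond, or None to fall through to default handling) is easiest to see in a toy override. A minimal sketch, not shipped code: 'my_frame_mask' is an invented cond key used purely for illustration, and slice_cond is assumed to take (cond_value, window, x_in, device), as its definition in comfy/context_windows.py suggests.

    import comfy.context_windows
    from comfy.model_base import BaseModel

    class MyWindowedModel(BaseModel):
        def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
            # Hypothetical per-frame cond: slice it to the window's temporal indices.
            if cond_key == "my_frame_mask":
                return comfy.context_windows.slice_cond(cond_value, window, x_in, device)
            # Returning None tells the handler to apply its default cond handling.
            return None
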
@@ -1113,6 +1119,51 @@ class LTXAV(BaseModel): return sum(e["latent_shape"][0] for e in gae.cond) return 0 + @staticmethod + def _get_guide_entries(conds): + for cond_list in conds: + if cond_list is None: + continue + for cond_dict in cond_list: + model_conds = cond_dict.get('model_conds', {}) + gae = model_conds.get('guide_attention_entries') + if gae is not None and hasattr(gae, 'cond') and gae.cond: + return gae.cond + return None + + def prepare_for_windowing(self, primary, conds, dim): + guide_count = self.get_guide_frame_count(primary, conds) + if guide_count <= 0: + return comfy.context_windows.WindowingContext(tensor=primary, suffix=None, aux_data=None) + video_len = primary.size(dim) - guide_count + video_primary = primary.narrow(dim, 0, video_len) + guide_suffix = primary.narrow(dim, video_len, guide_count) + guide_entries = self._get_guide_entries(conds) + return comfy.context_windows.WindowingContext( + tensor=video_primary, suffix=guide_suffix, + aux_data={"guide_entries": guide_entries, "guide_suffix": guide_suffix}) + + def prepare_window_input(self, video_slice, window, aux_data, dim): + if aux_data is None: + return video_slice, 0 + guide_entries = aux_data["guide_entries"] + guide_suffix = aux_data["guide_suffix"] + if guide_entries is None: + window.guide_suffix_indices = [] + window.guide_overlap_info = [] + window.guide_kf_local_positions = [] + return video_slice, 0 + overlap = comfy.context_windows._compute_guide_overlap(guide_entries, window.index_list) + suffix_idx, overlap_info, kf_local_pos, num_guide = overlap + window.guide_suffix_indices = suffix_idx + window.guide_overlap_info = overlap_info + window.guide_kf_local_positions = kf_local_pos + if num_guide > 0: + idx = tuple([slice(None)] * dim + [suffix_idx]) + sliced_guide = guide_suffix[idx] + return torch.cat([video_slice, sliced_guide], dim=dim), num_guide + return video_slice, 0 + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): # Audio denoise mask — slice using audio modality window if cond_key == "audio_denoise_mask" and hasattr(window, 'modality_windows') and window.modality_windows: From 3a061f4bbfc6f61367656ef6e18d0caa5271805b Mon Sep 17 00:00:00 2001 From: ozbayb <17261091+ozbayb@users.noreply.github.com> Date: Mon, 6 Apr 2026 15:13:46 -0600 Subject: [PATCH 11/23] LTX2 context windows - Cleanup: Simplify IndexListContextHandler standard execute path --- comfy/context_windows.py | 66 +++++++++++++++++++++++++++++++--------- 1 file changed, 51 insertions(+), 15 deletions(-) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 9e7282fda..e89c9cee2 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -367,18 +367,60 @@ class IndexListContextHandler(ContextHandlerABC): self._model = model self.set_step(timestep, model_options) - # Decompose — single-modality: [x_in], multimodal: [video, audio, ...] 
+ # Check if multimodal or model has auxiliary frames requiring the extended path latent_shapes = self._get_latent_shapes(conds) + is_multimodal = latent_shapes is not None and len(latent_shapes) > 1 + if is_multimodal: + return self._execute_extended(calc_cond_batch, model, conds, x_in, timestep, model_options, latent_shapes) + window_data = model.prepare_for_windowing(x_in, conds, self.dim) + if window_data.suffix is not None or window_data.aux_data is not None: + return self._execute_extended(calc_cond_batch, model, conds, x_in, timestep, model_options, + latent_shapes, window_data) + + context_windows = self.get_context_windows(model, x_in, model_options) + enumerated_context_windows = list(enumerate(context_windows)) + + conds_final = [torch.zeros_like(x_in) for _ in conds] + if self.fuse_method.name == ContextFuseMethods.RELATIVE: + counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds] + else: + counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds] + biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds] + + for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks): + callback(self, model, x_in, conds, timestep, model_options) + + for enum_window in enumerated_context_windows: + results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options) + for result in results: + self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep, + conds_final, counts_final, biases_final) + try: + if self.fuse_method.name == ContextFuseMethods.RELATIVE: + del counts_final + return conds_final + else: + for i in range(len(conds_final)): + conds_final[i] /= counts_final[i] + del counts_final + return conds_final + finally: + for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks): + callback(self, model, x_in, conds, timestep, model_options) + + def _execute_extended(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, + timestep: torch.Tensor, model_options: dict[str], + latent_shapes, window_data: WindowingContext=None): + """Extended execute path for multimodal models and models with auxiliary frames.""" modalities = self._decompose(x_in, latent_shapes) is_multimodal = len(modalities) > 1 - primary = modalities[0] - # Let model strip auxiliary frames (e.g. 
guide frames) - window_data = model.prepare_for_windowing(primary, conds, self.dim) + if window_data is None: + window_data = model.prepare_for_windowing(modalities[0], conds, self.dim) + video_primary = window_data.tensor aux_count = window_data.suffix.size(self.dim) if window_data.suffix is not None else 0 - # Windows from video portion only context_windows = self.get_context_windows(model, video_primary, model_options) enumerated_context_windows = list(enumerate(context_windows)) total_windows = len(enumerated_context_windows) @@ -407,14 +449,13 @@ class IndexListContextHandler(ContextHandlerABC): # Per-modality window indices if is_multimodal: map_shapes = latent_shapes - if video_primary.size(self.dim) != primary.size(self.dim): + if video_primary.size(self.dim) != modalities[0].size(self.dim): map_shapes = list(latent_shapes) video_shape = list(latent_shapes[0]) video_shape[self.dim] = video_primary.size(self.dim) map_shapes[0] = torch.Size(video_shape) per_mod_indices = model.map_context_window_to_modalities( window.index_list, map_shapes, self.dim) if hasattr(model, 'map_context_window_to_modalities') else [window.index_list] - # Build per-modality windows and attach to primary window modality_windows = {} for mod_idx in range(1, len(modalities)): modality_windows[mod_idx] = IndexListContextWindow( @@ -423,11 +464,9 @@ class IndexListContextHandler(ContextHandlerABC): window = IndexListContextWindow( window.index_list, dim=self.dim, total_frames=video_primary.shape[self.dim], modality_windows=modality_windows) - else: - per_mod_indices = [window.index_list] - # Build per-modality windows list (including primary) - mod_windows = [window] # primary window at index 0 + # Build per-modality windows list + mod_windows = [window] if is_multimodal: for mod_idx in range(1, len(modalities)): mod_windows.append(modality_windows[mod_idx]) @@ -438,10 +477,8 @@ class IndexListContextHandler(ContextHandlerABC): sliced_video, window, window_data.aux_data, self.dim) sliced = [sliced_primary] + [mod_windows[mi].get_tensor(modalities[mi]) for mi in range(1, len(modalities))] - # Compose for pipeline sub_x, sub_shapes = self._compose(sliced) - # Callbacks for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks): callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, None, None) @@ -462,7 +499,7 @@ class IndexListContextHandler(ContextHandlerABC): for ci in range(len(sub_conds_out)): out_per_mod[ci][0] = out_per_mod[ci][0].narrow(self.dim, 0, window_len) - # Accumulate per modality (using video-only sizes) + # Accumulate per modality for mod_idx in range(len(accum_modalities)): mw = mod_windows[mod_idx] mod_sub_out = [out_per_mod[ci][mod_idx] for ci in range(len(sub_conds_out))] @@ -479,7 +516,6 @@ class IndexListContextHandler(ContextHandlerABC): if self.fuse_method.name != ContextFuseMethods.RELATIVE: accum[mod_idx][ci] /= counts[mod_idx][ci] f = accum[mod_idx][ci] - # Re-append model's suffix (auxiliary frames stripped before windowing) if mod_idx == 0 and window_data.suffix is not None: f = torch.cat([f, window_data.suffix], dim=self.dim) finalized.append(f) From f1acd5bd858b732943fea784f21bbbf7afd98ce7 Mon Sep 17 00:00:00 2001 From: ozbayb <17261091+ozbayb@users.noreply.github.com> Date: Tue, 7 Apr 2026 12:43:41 -0600 Subject: [PATCH 12/23] LTX2 context windows - Cleanup: Simplify window data handling, improve variable names, refactor and condense new context window methods to separate execution 
paths cleanly --- comfy/context_windows.py | 242 ++++++++++++++++++++++++--------------- comfy/model_base.py | 86 ++++++-------- 2 files changed, 183 insertions(+), 145 deletions(-) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index e89c9cee2..2ec927f3e 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -140,15 +140,15 @@ def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, d return cond_value._copy_with(sliced) -def _compute_guide_overlap(guide_entries, window_index_list): - """Compute which guide frames overlap with a context window. +def compute_guide_overlap(guide_entries: list[dict], window_index_list: list[int]): + """Compute which concatenated guide frames overlap with a context window. Args: - guide_entries: list of guide_attention_entry dicts (must have 'latent_start' and 'latent_shape') + guide_entries: list of guide_attention_entry dicts window_index_list: the window's frame indices into the video portion - Returns (suffix_indices, overlap_info, kf_local_positions, total_overlap): - suffix_indices: indices into the guide_suffix tensor for frame selection + Returns: + suffix_indices: indices into the guide_frames tensor for frame selection overlap_info: list of (entry_idx, overlap_count) for guide_attention_entries adjustment kf_local_positions: window-local frame positions for keyframe_idxs regeneration total_overlap: total number of overlapping guide frames @@ -181,11 +181,37 @@ def _compute_guide_overlap(guide_entries, window_index_list): return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices) +def inject_guide_frames_into_window(video_slice: torch.Tensor, window: ContextWindowABC, + aux_data: dict, dim: int) -> tuple[torch.Tensor, int]: + """Inject overlapping guide frames into a context window slice. + + Uses aux_data from WindowingContext to determine which guide frames overlap + with this window's indices, concatenates them onto the video slice, and sets + window attributes for downstream conditioning resize. + + Returns (augmented_slice, num_guide_frames_added). + """ + guide_entries = aux_data["guide_entries"] + guide_frames = aux_data["guide_frames"] + overlap = compute_guide_overlap(guide_entries, window.index_list) + suffix_idx, overlap_info, kf_local_pos, guide_frame_count = overlap + window.guide_frames_indices = suffix_idx + window.guide_overlap_info = overlap_info + window.guide_kf_local_positions = kf_local_pos + if guide_frame_count > 0: + idx = tuple([slice(None)] * dim + [suffix_idx]) + sliced_guide = guide_frames[idx] + return torch.cat([video_slice, sliced_guide], dim=dim), guide_frame_count + return video_slice, 0 + + @dataclass class WindowingContext: tensor: torch.Tensor - suffix: torch.Tensor | None + guide_frames: torch.Tensor | None aux_data: Any + latent_shapes: list | None + is_multimodal: bool @dataclass class ContextSchedule: @@ -215,8 +241,8 @@ class IndexListContextHandler(ContextHandlerABC): self.callbacks = {} - def _get_latent_shapes(self, conds): - """Extract latent_shapes from conditioning. 
Returns None if absent.""" + @staticmethod + def _get_latent_shapes(conds): for cond_list in conds: if cond_list is None: continue @@ -226,20 +252,20 @@ class IndexListContextHandler(ContextHandlerABC): return model_conds['latent_shapes'].cond return None - def _decompose(self, x, latent_shapes): - """Packed tensor -> list of per-modality tensors.""" + @staticmethod + def _unpack(combined_latent, latent_shapes): if latent_shapes is not None and len(latent_shapes) > 1: - return comfy.utils.unpack_latents(x, latent_shapes) - return [x] + return comfy.utils.unpack_latents(combined_latent, latent_shapes) + return [combined_latent] - def _compose(self, modalities): - """List of per-modality tensors -> single tensor for pipeline.""" - if len(modalities) > 1: - return comfy.utils.pack_latents(modalities) - return modalities[0], [modalities[0].shape] + @staticmethod + def _pack(latents): + if len(latents) > 1: + return comfy.utils.pack_latents(latents) + return latents[0], [latents[0].shape] - def _patch_latent_shapes(self, sub_conds, new_shapes): - """Patch latent_shapes CONDConstant in (already-copied) sub_conds.""" + @staticmethod + def _patch_latent_shapes(sub_conds, new_shapes): for cond_list in sub_conds: if cond_list is None: continue @@ -248,14 +274,48 @@ class IndexListContextHandler(ContextHandlerABC): if 'latent_shapes' in model_conds: model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes) - def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool: + def _build_window_data(self, x_in: torch.Tensor, conds: list[list[dict]]) -> WindowingContext: latent_shapes = self._get_latent_shapes(conds) - primary = self._decompose(x_in, latent_shapes)[0] - guide_count = model.get_guide_frame_count(primary, conds) if hasattr(model, 'get_guide_frame_count') else 0 - video_frames = primary.size(self.dim) - guide_count + is_multimodal = latent_shapes is not None and len(latent_shapes) > 1 + if is_multimodal: + video_latent = comfy.utils.unpack_latents(x_in, latent_shapes)[0] + else: + video_latent = x_in + + guide_entries = None + for cond_list in conds: + if cond_list is None: + continue + for cond_dict in cond_list: + model_conds = cond_dict.get('model_conds', {}) + entries = model_conds.get('guide_attention_entries') + if entries is not None and hasattr(entries, 'cond') and entries.cond: + guide_entries = entries.cond + break + if guide_entries is not None: + break + + guide_frame_count = sum(e["latent_shape"][0] for e in guide_entries) if guide_entries is not None else 0 + primary_frame_count = video_latent.size(self.dim) - guide_frame_count + primary_frames = video_latent.narrow(self.dim, 0, primary_frame_count) + guide_frames = video_latent.narrow(self.dim, primary_frame_count, guide_frame_count) if guide_frame_count > 0 else None + + if guide_frame_count > 0: + aux_data = {"guide_entries": guide_entries, "guide_frames": guide_frames} + else: + aux_data = None + + return WindowingContext( + tensor=primary_frames, guide_frames=guide_frames, aux_data=aux_data, + latent_shapes=latent_shapes, is_multimodal=is_multimodal) + + def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool: + self._window_data = self._build_window_data(x_in, conds) + video_frames = self._window_data.tensor.size(self.dim) + guide_frames = self._window_data.guide_frames.size(self.dim) if self._window_data.guide_frames is not None else 
0 if video_frames > self.context_length: - if guide_count > 0: - logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {video_frames} video frames ({guide_count} guide frames excluded).") + if guide_frames > 0: + logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {video_frames} video frames ({guide_frames} guide frames).") else: logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {video_frames} frames.") if self.cond_retain_index_list: @@ -367,15 +427,9 @@ class IndexListContextHandler(ContextHandlerABC): self._model = model self.set_step(timestep, model_options) - # Check if multimodal or model has auxiliary frames requiring the extended path - latent_shapes = self._get_latent_shapes(conds) - is_multimodal = latent_shapes is not None and len(latent_shapes) > 1 - if is_multimodal: - return self._execute_extended(calc_cond_batch, model, conds, x_in, timestep, model_options, latent_shapes) - window_data = model.prepare_for_windowing(x_in, conds, self.dim) - if window_data.suffix is not None or window_data.aux_data is not None: - return self._execute_extended(calc_cond_batch, model, conds, x_in, timestep, model_options, - latent_shapes, window_data) + window_data = self._window_data + if window_data.is_multimodal or (window_data.guide_frames is not None and window_data.guide_frames.size(self.dim) > 0): + return self._execute_extended(calc_cond_batch, model, conds, x_in, timestep, model_options, window_data) context_windows = self.get_context_windows(model, x_in, model_options) enumerated_context_windows = list(enumerate(context_windows)) @@ -410,101 +464,104 @@ class IndexListContextHandler(ContextHandlerABC): def _execute_extended(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str], - latent_shapes, window_data: WindowingContext=None): - """Extended execute path for multimodal models and models with auxiliary frames.""" - modalities = self._decompose(x_in, latent_shapes) - is_multimodal = len(modalities) > 1 + window_data: WindowingContext): + """Extended execute path for multimodal models and models with guide frames appended to the noise latent.""" + latents = self._unpack(x_in, window_data.latent_shapes) + is_multimodal = window_data.is_multimodal - if window_data is None: - window_data = model.prepare_for_windowing(modalities[0], conds, self.dim) + primary_frames = window_data.tensor + num_guide_frames = window_data.guide_frames.size(self.dim) if window_data.guide_frames is not None else 0 - video_primary = window_data.tensor - aux_count = window_data.suffix.size(self.dim) if window_data.suffix is not None else 0 - - context_windows = self.get_context_windows(model, video_primary, model_options) + context_windows = self.get_context_windows(model, primary_frames, model_options) enumerated_context_windows = list(enumerate(context_windows)) total_windows = len(enumerated_context_windows) # Accumulators sized to video portion for primary, full for other modalities - accum_modalities = list(modalities) - if window_data.suffix is not None: - accum_modalities[0] = video_primary + accum_shape_refs = list(latents) + if window_data.guide_frames is not None: + accum_shape_refs[0] = primary_frames - accum = [[torch.zeros_like(m) for _ in conds] for m in accum_modalities] + accum = [[torch.zeros_like(m) for _ in conds] for m in accum_shape_refs] if self.fuse_method.name 
== ContextFuseMethods.RELATIVE: - counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_modalities] + counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_shape_refs] else: - counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_modalities] - biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in accum_modalities] + counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_shape_refs] + biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in accum_shape_refs] for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks): callback(self, model, x_in, conds, timestep, model_options) for window_idx, window in enumerated_context_windows: comfy.model_management.throw_exception_if_processing_interrupted() - logging.info(f"Context window {window_idx + 1}/{total_windows}: frames {window.index_list[0]}-{window.index_list[-1]} of {video_primary.shape[self.dim]}" - + (f" (+{aux_count} aux)" if aux_count > 0 else "") - + (f" [{len(modalities)} modalities]" if is_multimodal else "")) + logging.info(f"Context window {window_idx + 1}/{total_windows}: frames {window.index_list[0]}-{window.index_list[-1]} of {primary_frames.shape[self.dim]}" + + (f" (+{num_guide_frames} guide frames)" if num_guide_frames > 0 else "") + + (f" [{len(latents)} modalities]" if is_multimodal else "")) # Per-modality window indices if is_multimodal: - map_shapes = latent_shapes - if video_primary.size(self.dim) != modalities[0].size(self.dim): - map_shapes = list(latent_shapes) - video_shape = list(latent_shapes[0]) - video_shape[self.dim] = video_primary.size(self.dim) + map_shapes = window_data.latent_shapes + if primary_frames.size(self.dim) != latents[0].size(self.dim): + map_shapes = list(window_data.latent_shapes) + video_shape = list(window_data.latent_shapes[0]) + video_shape[self.dim] = primary_frames.size(self.dim) map_shapes[0] = torch.Size(video_shape) - per_mod_indices = model.map_context_window_to_modalities( - window.index_list, map_shapes, self.dim) if hasattr(model, 'map_context_window_to_modalities') else [window.index_list] + try: + per_modality_indices = model.map_context_window_to_modalities( + window.index_list, map_shapes, self.dim) + except AttributeError: + raise NotImplementedError( + f"{type(model).__name__} must implement map_context_window_to_modalities for multimodal context windows.") modality_windows = {} - for mod_idx in range(1, len(modalities)): + for mod_idx in range(1, len(latents)): modality_windows[mod_idx] = IndexListContextWindow( - per_mod_indices[mod_idx], dim=self.dim, - total_frames=modalities[mod_idx].shape[self.dim]) + per_modality_indices[mod_idx], dim=self.dim, + total_frames=latents[mod_idx].shape[self.dim]) window = IndexListContextWindow( - window.index_list, dim=self.dim, total_frames=video_primary.shape[self.dim], + window.index_list, dim=self.dim, total_frames=primary_frames.shape[self.dim], modality_windows=modality_windows) # Build per-modality windows list - mod_windows = [window] + per_modality_windows_list = [window] if is_multimodal: - for mod_idx in range(1, len(modalities)): - mod_windows.append(modality_windows[mod_idx]) + for mod_idx in range(1, len(latents)): + per_modality_windows_list.append(modality_windows[mod_idx]) - # Slice video, then let model inject auxiliary frames - sliced_video = 
mod_windows[0].get_tensor(video_primary, retain_index_list=self.cond_retain_index_list) - sliced_primary, num_aux = model.prepare_window_input( - sliced_video, window, window_data.aux_data, self.dim) - sliced = [sliced_primary] + [mod_windows[mi].get_tensor(modalities[mi]) for mi in range(1, len(modalities))] + # Slice video, then inject overlapping guide frames if present + sliced_video = per_modality_windows_list[0].get_tensor(primary_frames, retain_index_list=self.cond_retain_index_list) + if window_data.aux_data is not None: + sliced_primary, num_guide_frames = inject_guide_frames_into_window(sliced_video, window, window_data.aux_data, self.dim) + else: + sliced_primary, num_guide_frames = sliced_video, 0 + sliced = [sliced_primary] + [per_modality_windows_list[mi].get_tensor(latents[mi]) for mi in range(1, len(latents))] - sub_x, sub_shapes = self._compose(sliced) + sub_x, sub_shapes = self._pack(sliced) for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks): callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, None, None) model_options["transformer_options"]["context_window"] = window sub_timestep = window.get_tensor(timestep, dim=0) - sub_conds = [self.get_resized_cond(cond, video_primary, window) for cond in conds] + sub_conds = [self.get_resized_cond(cond, primary_frames, window) for cond in conds] if is_multimodal: self._patch_latent_shapes(sub_conds, sub_shapes) sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options) - # Decompose output per modality - out_per_mod = [self._decompose(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))] + # Unpack output per modality + out_per_modality = [self._unpack(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))] - # Strip auxiliary frames from primary output before accumulation - if num_aux > 0: + # Strip guide frames from primary output before accumulation + if num_guide_frames > 0: window_len = len(window.index_list) for ci in range(len(sub_conds_out)): - out_per_mod[ci][0] = out_per_mod[ci][0].narrow(self.dim, 0, window_len) + out_per_modality[ci][0] = out_per_modality[ci][0].narrow(self.dim, 0, window_len) # Accumulate per modality - for mod_idx in range(len(accum_modalities)): - mw = mod_windows[mod_idx] - mod_sub_out = [out_per_mod[ci][mod_idx] for ci in range(len(sub_conds_out))] + for mod_idx in range(len(accum_shape_refs)): + mw = per_modality_windows_list[mod_idx] + sub_conds_out_per_modality = [out_per_modality[ci][mod_idx] for ci in range(len(sub_conds_out))] self.combine_context_window_results( - accum_modalities[mod_idx], mod_sub_out, sub_conds, mw, + accum_shape_refs[mod_idx], sub_conds_out_per_modality, sub_conds, mw, window_idx, total_windows, timestep, accum[mod_idx], counts[mod_idx], biases[mod_idx]) @@ -512,15 +569,15 @@ class IndexListContextHandler(ContextHandlerABC): result = [] for ci in range(len(conds)): finalized = [] - for mod_idx in range(len(accum_modalities)): + for mod_idx in range(len(accum_shape_refs)): if self.fuse_method.name != ContextFuseMethods.RELATIVE: accum[mod_idx][ci] /= counts[mod_idx][ci] f = accum[mod_idx][ci] - if mod_idx == 0 and window_data.suffix is not None: - f = torch.cat([f, window_data.suffix], dim=self.dim) + if mod_idx == 0 and window_data.guide_frames is not None: + f = torch.cat([f, window_data.guide_frames], dim=self.dim) finalized.append(f) - composed, _ = self._compose(finalized) - result.append(composed) + packed, _ = 
self._pack(finalized) + result.append(packed) return result finally: for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks): @@ -616,11 +673,8 @@ def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, nois # For packed multimodal tensors (e.g. LTXAV), noise is [B, 1, flat] and FreeNoise # must only shuffle the video portion. Unpack, apply to video, repack. - latent_shapes = None - try: - latent_shapes = guider.conds['positive'][0]['model_conds']['latent_shapes'].cond - except (KeyError, IndexError, AttributeError): - pass + latent_shapes = IndexListContextHandler._get_latent_shapes( + [guider.conds.get('positive', guider.conds.get('negative', []))]) if latent_shapes is not None and len(latent_shapes) > 1: modalities = comfy.utils.unpack_latents(noise, latent_shapes) diff --git a/comfy/model_base.py b/comfy/model_base.py index c1e2ef8c2..65ce1bac5 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -287,12 +287,6 @@ class BaseModel(torch.nn.Module): return data return None - def prepare_for_windowing(self, primary, conds, dim): - return comfy.context_windows.WindowingContext(tensor=primary, suffix=None, aux_data=None) - - def prepare_window_input(self, video_slice, window, aux_data, dim): - return video_slice, 0 - def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): """Override in subclasses to handle model-specific cond slicing for context windows. Return a sliced cond object, or None to fall through to default handling. @@ -1098,7 +1092,7 @@ class LTXAV(BaseModel): for i in range(1, len(latent_shapes)): mod_total = latent_shapes[i][dim] - # Length proportional to video window frame count (not index span) + # Length proportional to video window frame count mod_window_len = max(round(video_window_len * mod_total / video_total), 1) # Anchor to end of video range v_end = max(primary_indices) + 1 @@ -1108,17 +1102,6 @@ class LTXAV(BaseModel): return result - def get_guide_frame_count(self, x, conds): - for cond_list in conds: - if cond_list is None: - continue - for cond_dict in cond_list: - model_conds = cond_dict.get('model_conds', {}) - gae = model_conds.get('guide_attention_entries') - if gae is not None and hasattr(gae, 'cond') and gae.cond: - return sum(e["latent_shape"][0] for e in gae.cond) - return 0 - @staticmethod def _get_guide_entries(conds): for cond_list in conds: @@ -1126,43 +1109,27 @@ class LTXAV(BaseModel): continue for cond_dict in cond_list: model_conds = cond_dict.get('model_conds', {}) - gae = model_conds.get('guide_attention_entries') - if gae is not None and hasattr(gae, 'cond') and gae.cond: - return gae.cond + entries = model_conds.get('guide_attention_entries') + if entries is not None and hasattr(entries, 'cond') and entries.cond: + return entries.cond return None - - def prepare_for_windowing(self, primary, conds, dim): - guide_count = self.get_guide_frame_count(primary, conds) + + def prepare_window_data(self, x_in, conds, dim, window_data): + primary = comfy.utils.unpack_latents(x_in, window_data.latent_shapes)[0] if window_data.is_multimodal else x_in + guide_entries = self._get_guide_entries(conds) + guide_count = sum(e["latent_shape"][0] for e in guide_entries) if guide_entries else 0 if guide_count <= 0: - return comfy.context_windows.WindowingContext(tensor=primary, suffix=None, aux_data=None) + return comfy.context_windows.WindowingContext( + tensor=primary, guide_frames=None, aux_data=None, + 
latent_shapes=window_data.latent_shapes, is_multimodal=window_data.is_multimodal) video_len = primary.size(dim) - guide_count video_primary = primary.narrow(dim, 0, video_len) - guide_suffix = primary.narrow(dim, video_len, guide_count) - guide_entries = self._get_guide_entries(conds) + guide_frames = primary.narrow(dim, video_len, guide_count) return comfy.context_windows.WindowingContext( - tensor=video_primary, suffix=guide_suffix, - aux_data={"guide_entries": guide_entries, "guide_suffix": guide_suffix}) + tensor=video_primary, guide_frames=guide_frames, + aux_data={"guide_entries": guide_entries, "guide_frames": guide_frames}, + latent_shapes=window_data.latent_shapes, is_multimodal=window_data.is_multimodal) - def prepare_window_input(self, video_slice, window, aux_data, dim): - if aux_data is None: - return video_slice, 0 - guide_entries = aux_data["guide_entries"] - guide_suffix = aux_data["guide_suffix"] - if guide_entries is None: - window.guide_suffix_indices = [] - window.guide_overlap_info = [] - window.guide_kf_local_positions = [] - return video_slice, 0 - overlap = comfy.context_windows._compute_guide_overlap(guide_entries, window.index_list) - suffix_idx, overlap_info, kf_local_pos, num_guide = overlap - window.guide_suffix_indices = suffix_idx - window.guide_overlap_info = overlap_info - window.guide_kf_local_positions = kf_local_pos - if num_guide > 0: - idx = tuple([slice(None)] * dim + [suffix_idx]) - sliced_guide = guide_suffix[idx] - return torch.cat([video_slice, sliced_guide], dim=dim), num_guide - return video_slice, 0 def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): # Audio denoise mask — slice using audio modality window @@ -1181,7 +1148,7 @@ class LTXAV(BaseModel): video_mask = cond_tensor.narrow(window.dim, 0, T_video) guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count) sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list) - suffix_indices = window.guide_suffix_indices + suffix_indices = window.guide_frames_indices if suffix_indices: idx = tuple([slice(None)] * window.dim + [suffix_indices]) sliced_guide = guide_mask[idx].to(device) @@ -1199,14 +1166,31 @@ class LTXAV(BaseModel): patchifier = self.diffusion_model.patchifier latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device) from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords + scale_factors = self.diffusion_model.vae_scale_factors pixel_coords = latent_to_pixel_coords( latent_coords, - self.diffusion_model.vae_scale_factors, + scale_factors, causal_fix=self.diffusion_model.causal_temporal_positioning) tokens = [] for pos in kf_local_pos: tokens.extend(range(pos * H * W, (pos + 1) * H * W)) pixel_coords = pixel_coords[:, :, tokens, :] + + # Adjust spatial end positions for dilated (downscaled) guides. + # Each guide entry may have a different downscale factor; expand the + # per-entry factor to cover all tokens belonging to that entry. 
+ downscale_factors = getattr(window, 'guide_downscale_factors', []) + overlap_info = window.guide_overlap_info + if downscale_factors: + per_token_factor = [] + for (entry_idx, overlap_count), dsf in zip(overlap_info, downscale_factors): + per_token_factor.extend([dsf] * (overlap_count * H * W)) + factor_tensor = torch.tensor(per_token_factor, device=pixel_coords.device, dtype=pixel_coords.dtype) + spatial_end_offset = (factor_tensor.unsqueeze(0).unsqueeze(0).unsqueeze(-1) - 1) * torch.tensor( + scale_factors[1:], device=pixel_coords.device, dtype=pixel_coords.dtype, + ).view(1, -1, 1, 1) + pixel_coords[:, 1:, :, 1:] += spatial_end_offset + B = cond_value.cond.shape[0] if B > 1: pixel_coords = pixel_coords.expand(B, -1, -1, -1) From c9edd2d7c0521e7ed8c46599eb9b930d30e71911 Mon Sep 17 00:00:00 2001 From: ozbayb <17261091+ozbayb@users.noreply.github.com> Date: Tue, 7 Apr 2026 12:44:28 -0600 Subject: [PATCH 13/23] LTX2 context windows - Add handling for downscaled IC-Lora guide frames --- comfy/context_windows.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 2ec927f3e..f955d4b67 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -198,6 +198,15 @@ def inject_guide_frames_into_window(video_slice: torch.Tensor, window: ContextWi window.guide_frames_indices = suffix_idx window.guide_overlap_info = overlap_info window.guide_kf_local_positions = kf_local_pos + # Derive per-overlap-entry latent_downscale_factor from guide entry latent_shape vs guide frame spatial dims. + # guide_frames has full (post-dilation) spatial dims; entry["latent_shape"] has pre-dilation dims. + guide_downscale_factors = [] + if guide_frame_count > 0: + full_H = guide_frames.shape[3] + for entry_idx, _ in overlap_info: + entry_H = guide_entries[entry_idx]["latent_shape"][1] + guide_downscale_factors.append(full_H // entry_H) + window.guide_downscale_factors = guide_downscale_factors if guide_frame_count > 0: idx = tuple([slice(None)] * dim + [suffix_idx]) sliced_guide = guide_frames[idx] From d5badc5f380729cd3fbef6a3742d06df3b36419f Mon Sep 17 00:00:00 2001 From: ozbayb <17261091+ozbayb@users.noreply.github.com> Date: Tue, 7 Apr 2026 13:00:38 -0600 Subject: [PATCH 14/23] LTX2 context windows - Clean up unnecessary code --- comfy/context_windows.py | 30 ++++++++++++++++++------------ comfy/model_base.py | 17 ----------------- 2 files changed, 18 insertions(+), 29 deletions(-) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index f955d4b67..6e21bdc81 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -182,22 +182,22 @@ def compute_guide_overlap(guide_entries: list[dict], window_index_list: list[int def inject_guide_frames_into_window(video_slice: torch.Tensor, window: ContextWindowABC, - aux_data: dict, dim: int) -> tuple[torch.Tensor, int]: + window_data: 'WindowingContext', dim: int) -> tuple[torch.Tensor, int]: """Inject overlapping guide frames into a context window slice. - Uses aux_data from WindowingContext to determine which guide frames overlap - with this window's indices, concatenates them onto the video slice, and sets - window attributes for downstream conditioning resize. + Determines which guide frames overlap with this window's indices, concatenates + them onto the video slice, and sets window attributes for downstream conditioning resize. Returns (augmented_slice, num_guide_frames_added). 
""" - guide_entries = aux_data["guide_entries"] - guide_frames = aux_data["guide_frames"] + guide_entries = window_data.aux_data["guide_entries"] + guide_frames = window_data.guide_frames overlap = compute_guide_overlap(guide_entries, window.index_list) suffix_idx, overlap_info, kf_local_pos, guide_frame_count = overlap window.guide_frames_indices = suffix_idx window.guide_overlap_info = overlap_info window.guide_kf_local_positions = kf_local_pos + # Derive per-overlap-entry latent_downscale_factor from guide entry latent_shape vs guide frame spatial dims. # guide_frames has full (post-dilation) spatial dims; entry["latent_shape"] has pre-dilation dims. guide_downscale_factors = [] @@ -207,6 +207,7 @@ def inject_guide_frames_into_window(video_slice: torch.Tensor, window: ContextWi entry_H = guide_entries[entry_idx]["latent_shape"][1] guide_downscale_factors.append(full_H // entry_H) window.guide_downscale_factors = guide_downscale_factors + if guide_frame_count > 0: idx = tuple([slice(None)] * dim + [suffix_idx]) sliced_guide = guide_frames[idx] @@ -220,7 +221,6 @@ class WindowingContext: guide_frames: torch.Tensor | None aux_data: Any latent_shapes: list | None - is_multimodal: bool @dataclass class ContextSchedule: @@ -310,13 +310,13 @@ class IndexListContextHandler(ContextHandlerABC): guide_frames = video_latent.narrow(self.dim, primary_frame_count, guide_frame_count) if guide_frame_count > 0 else None if guide_frame_count > 0: - aux_data = {"guide_entries": guide_entries, "guide_frames": guide_frames} + aux_data = {"guide_entries": guide_entries} else: aux_data = None return WindowingContext( tensor=primary_frames, guide_frames=guide_frames, aux_data=aux_data, - latent_shapes=latent_shapes, is_multimodal=is_multimodal) + latent_shapes=latent_shapes) def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool: self._window_data = self._build_window_data(x_in, conds) @@ -437,9 +437,14 @@ class IndexListContextHandler(ContextHandlerABC): self.set_step(timestep, model_options) window_data = self._window_data - if window_data.is_multimodal or (window_data.guide_frames is not None and window_data.guide_frames.size(self.dim) > 0): + is_multimodal = window_data.latent_shapes is not None and len(window_data.latent_shapes) > 1 + has_guide_frames = window_data.guide_frames is not None and window_data.guide_frames.size(self.dim) > 0 + + # if multimodal or has concatenated guide frames on noise latent, use the extended execute path + if is_multimodal or has_guide_frames: return self._execute_extended(calc_cond_batch, model, conds, x_in, timestep, model_options, window_data) + # basic legacy execution path for single-modal video latent with no guide frames concatenated context_windows = self.get_context_windows(model, x_in, model_options) enumerated_context_windows = list(enumerate(context_windows)) @@ -475,8 +480,9 @@ class IndexListContextHandler(ContextHandlerABC): timestep: torch.Tensor, model_options: dict[str], window_data: WindowingContext): """Extended execute path for multimodal models and models with guide frames appended to the noise latent.""" + latents = self._unpack(x_in, window_data.latent_shapes) - is_multimodal = window_data.is_multimodal + is_multimodal = window_data.latent_shapes is not None and len(window_data.latent_shapes) > 1 primary_frames = window_data.tensor num_guide_frames = window_data.guide_frames.size(self.dim) if window_data.guide_frames is not None else 0 @@ -538,7 +544,7 @@ 
class IndexListContextHandler(ContextHandlerABC): # Slice video, then inject overlapping guide frames if present sliced_video = per_modality_windows_list[0].get_tensor(primary_frames, retain_index_list=self.cond_retain_index_list) if window_data.aux_data is not None: - sliced_primary, num_guide_frames = inject_guide_frames_into_window(sliced_video, window, window_data.aux_data, self.dim) + sliced_primary, num_guide_frames = inject_guide_frames_into_window(sliced_video, window, window_data, self.dim) else: sliced_primary, num_guide_frames = sliced_video, 0 sliced = [sliced_primary] + [per_modality_windows_list[mi].get_tensor(latents[mi]) for mi in range(1, len(latents))] diff --git a/comfy/model_base.py b/comfy/model_base.py index 65ce1bac5..8960ecd19 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1113,23 +1113,6 @@ class LTXAV(BaseModel): if entries is not None and hasattr(entries, 'cond') and entries.cond: return entries.cond return None - - def prepare_window_data(self, x_in, conds, dim, window_data): - primary = comfy.utils.unpack_latents(x_in, window_data.latent_shapes)[0] if window_data.is_multimodal else x_in - guide_entries = self._get_guide_entries(conds) - guide_count = sum(e["latent_shape"][0] for e in guide_entries) if guide_entries else 0 - if guide_count <= 0: - return comfy.context_windows.WindowingContext( - tensor=primary, guide_frames=None, aux_data=None, - latent_shapes=window_data.latent_shapes, is_multimodal=window_data.is_multimodal) - video_len = primary.size(dim) - guide_count - video_primary = primary.narrow(dim, 0, video_len) - guide_frames = primary.narrow(dim, video_len, guide_count) - return comfy.context_windows.WindowingContext( - tensor=video_primary, guide_frames=guide_frames, - aux_data={"guide_entries": guide_entries, "guide_frames": guide_frames}, - latent_shapes=window_data.latent_shapes, is_multimodal=window_data.is_multimodal) - def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): # Audio denoise mask — slice using audio modality window From d1a9e2e4df261466414526a93b7350e0729fc4b5 Mon Sep 17 00:00:00 2001 From: ozbayb <17261091+ozbayb@users.noreply.github.com> Date: Tue, 7 Apr 2026 13:25:31 -0600 Subject: [PATCH 15/23] Fix whitespace --- comfy/context_windows.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 6e21bdc81..2723c77ff 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -197,7 +197,7 @@ def inject_guide_frames_into_window(video_slice: torch.Tensor, window: ContextWi window.guide_frames_indices = suffix_idx window.guide_overlap_info = overlap_info window.guide_kf_local_positions = kf_local_pos - + # Derive per-overlap-entry latent_downscale_factor from guide entry latent_shape vs guide frame spatial dims. # guide_frames has full (post-dilation) spatial dims; entry["latent_shape"] has pre-dilation dims. 
guide_downscale_factors = [] From 88643f3978c79dc794c641c4a939f7899b60df39 Mon Sep 17 00:00:00 2001 From: ozbayb <17261091+ozbayb@users.noreply.github.com> Date: Tue, 7 Apr 2026 14:11:19 -0600 Subject: [PATCH 16/23] Fix logging of guide frame number --- comfy/context_windows.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 2723c77ff..964f7dd33 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -485,8 +485,6 @@ class IndexListContextHandler(ContextHandlerABC): is_multimodal = window_data.latent_shapes is not None and len(window_data.latent_shapes) > 1 primary_frames = window_data.tensor - num_guide_frames = window_data.guide_frames.size(self.dim) if window_data.guide_frames is not None else 0 - context_windows = self.get_context_windows(model, primary_frames, model_options) enumerated_context_windows = list(enumerate(context_windows)) total_windows = len(enumerated_context_windows) @@ -508,9 +506,6 @@ class IndexListContextHandler(ContextHandlerABC): for window_idx, window in enumerated_context_windows: comfy.model_management.throw_exception_if_processing_interrupted() - logging.info(f"Context window {window_idx + 1}/{total_windows}: frames {window.index_list[0]}-{window.index_list[-1]} of {primary_frames.shape[self.dim]}" - + (f" (+{num_guide_frames} guide frames)" if num_guide_frames > 0 else "") - + (f" [{len(latents)} modalities]" if is_multimodal else "")) # Per-modality window indices if is_multimodal: @@ -547,6 +542,9 @@ class IndexListContextHandler(ContextHandlerABC): sliced_primary, num_guide_frames = inject_guide_frames_into_window(sliced_video, window, window_data, self.dim) else: sliced_primary, num_guide_frames = sliced_video, 0 + logging.info(f"Context window {window_idx + 1}/{total_windows}: frames {window.index_list[0]}-{window.index_list[-1]} of {primary_frames.shape[self.dim]}" + + (f" (+{num_guide_frames} guide frames)" if num_guide_frames > 0 else "") + + (f" [{len(latents)} modalities]" if is_multimodal else "")) sliced = [sliced_primary] + [per_modality_windows_list[mi].get_tensor(latents[mi]) for mi in range(1, len(latents))] sub_x, sub_shapes = self._pack(sliced) From ae3830a6d2028f4ccadc021ef6503752db2cff71 Mon Sep 17 00:00:00 2001 From: ozbayb <17261091+ozbayb@users.noreply.github.com> Date: Sat, 11 Apr 2026 11:31:04 -0600 Subject: [PATCH 17/23] LTX2 Context Windows - Collect multimodal methods into WindowingState; Condense execution path to treat all latents as potentially multimodal --- comfy/context_windows.py | 501 +++++++++++++++++++-------------------- comfy/model_base.py | 2 +- 2 files changed, 245 insertions(+), 258 deletions(-) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 964f7dd33..a9f456426 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Callable import torch import numpy as np import collections @@ -60,6 +60,10 @@ class IndexListContextWindow(ContextWindowABC): self.total_frames = total_frames self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames) self.modality_windows = modality_windows # dict of {mod_idx: IndexListContextWindow} + self.guide_frames_indices: list[int] = [] + self.guide_overlap_info: list[tuple[int, int]] = [] + self.guide_kf_local_positions: list[int] = [] + self.guide_downscale_factors: list[int] = [] def get_tensor(self, full: 
torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor: if dim is None: @@ -84,6 +88,11 @@ class IndexListContextWindow(ContextWindowABC): region_idx = int(self.center_ratio * num_regions) return min(max(region_idx, 0), num_regions - 1) + def get_window_for_modality(self, modality_idx: int) -> 'IndexListContextWindow': + if modality_idx == 0: + return self + return self.modality_windows[modality_idx] + class IndexListCallbacks: EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows" @@ -181,46 +190,109 @@ def compute_guide_overlap(guide_entries: list[dict], window_index_list: list[int return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices) -def inject_guide_frames_into_window(video_slice: torch.Tensor, window: ContextWindowABC, - window_data: 'WindowingContext', dim: int) -> tuple[torch.Tensor, int]: - """Inject overlapping guide frames into a context window slice. - - Determines which guide frames overlap with this window's indices, concatenates - them onto the video slice, and sets window attributes for downstream conditioning resize. - - Returns (augmented_slice, num_guide_frames_added). - """ - guide_entries = window_data.aux_data["guide_entries"] - guide_frames = window_data.guide_frames - overlap = compute_guide_overlap(guide_entries, window.index_list) - suffix_idx, overlap_info, kf_local_pos, guide_frame_count = overlap - window.guide_frames_indices = suffix_idx - window.guide_overlap_info = overlap_info - window.guide_kf_local_positions = kf_local_pos - - # Derive per-overlap-entry latent_downscale_factor from guide entry latent_shape vs guide frame spatial dims. - # guide_frames has full (post-dilation) spatial dims; entry["latent_shape"] has pre-dilation dims. - guide_downscale_factors = [] - if guide_frame_count > 0: - full_H = guide_frames.shape[3] - for entry_idx, _ in overlap_info: - entry_H = guide_entries[entry_idx]["latent_shape"][1] - guide_downscale_factors.append(full_H // entry_H) - window.guide_downscale_factors = guide_downscale_factors - - if guide_frame_count > 0: - idx = tuple([slice(None)] * dim + [suffix_idx]) - sliced_guide = guide_frames[idx] - return torch.cat([video_slice, sliced_guide], dim=dim), guide_frame_count - return video_slice, 0 - - @dataclass -class WindowingContext: - tensor: torch.Tensor - guide_frames: torch.Tensor | None - aux_data: Any - latent_shapes: list | None +class WindowingState: + """Per-modality context windowing state for each step, + built using IndexListContextHandler._build_window_state(). + For non-multimodal models the lists are length 1 + """ + latents: list[torch.Tensor] # per-modality working latents (guide frames stripped) + guide_latents: list[torch.Tensor | None] # per-modality guide frames stripped from latents + guide_entries: list[list[dict] | None] # per-modality guide_attention_entry metadata + latent_shapes: list | None # original packed shapes for unpack/pack (None if not multimodal) + dim: int = 0 # primary modality temporal dim for context windowing + is_multimodal: bool = False + + def prepare_window(self, window: IndexListContextWindow, model) -> IndexListContextWindow: + """Reformat window for multimodal contexts by deriving per-modality index lists. + Non-multimodal contexts return the input window unchanged. 
+ """ + if not self.is_multimodal: + return window + + x = self.latents[0] + map_shapes = self.latent_shapes + if x.size(self.dim) != self.latent_shapes[0][self.dim]: + map_shapes = list(self.latent_shapes) + video_shape = list(self.latent_shapes[0]) + video_shape[self.dim] = x.size(self.dim) + map_shapes[0] = torch.Size(video_shape) + try: + per_modality_indices = model.map_context_window_to_modalities( + window.index_list, map_shapes, self.dim) + except AttributeError: + raise NotImplementedError( + f"{type(model).__name__} must implement map_context_window_to_modalities for multimodal context windows.") + modality_windows = {} + for mod_idx in range(1, len(self.latents)): + modality_windows[mod_idx] = IndexListContextWindow( + per_modality_indices[mod_idx], dim=self.dim, + total_frames=self.latents[mod_idx].shape[self.dim]) + return IndexListContextWindow( + window.index_list, dim=self.dim, total_frames=x.shape[self.dim], + modality_windows=modality_windows) + + def slice_for_window(self, window: IndexListContextWindow, retain_index_list: list[int], device=None) -> tuple[list[torch.Tensor], list[int]]: + """Slice latents for a context window, injecting guide frames where applicable. + For multimodal contexts, uses the modality-specific windows derived in prepare_window(). + """ + sliced = [] + guide_frame_counts = [] + for idx in range(len(self.latents)): + modality_window = window.get_window_for_modality(idx) + retain = retain_index_list if idx == 0 else [] + s = modality_window.get_tensor(self.latents[idx], device, retain_index_list=retain) + if self.guide_entries[idx] is not None: + s, ng = self._inject_guide_frames(s, modality_window, modality_idx=idx) + else: + ng = 0 + sliced.append(s) + guide_frame_counts.append(ng) + return sliced, guide_frame_counts + + def strip_guide_frames(self, out_per_modality: list[list[torch.Tensor]], guide_frame_counts: list[int], window: IndexListContextWindow): + """Strip injected guide frames from per-cond, per-modality outputs in place.""" + for idx in range(len(self.latents)): + if guide_frame_counts[idx] > 0: + window_len = len(window.get_window_for_modality(idx).index_list) + for ci in range(len(out_per_modality)): + out_per_modality[ci][idx] = out_per_modality[ci][idx].narrow(self.dim, 0, window_len) + + def _inject_guide_frames(self, latent_slice: torch.Tensor, window: IndexListContextWindow, modality_idx: int = 0) -> tuple[torch.Tensor, int]: + guide_entries = self.guide_entries[modality_idx] + guide_frames = self.guide_latents[modality_idx] + suffix_idx, overlap_info, kf_local_pos, guide_frame_count = compute_guide_overlap(guide_entries, window.index_list) + window.guide_frames_indices = suffix_idx + window.guide_overlap_info = overlap_info + window.guide_kf_local_positions = kf_local_pos + + # Derive per-overlap-entry latent_downscale_factor from guide entry latent_shape vs guide frame spatial dims. + # guide_frames has full (post-dilation) spatial dims; entry["latent_shape"] has pre-dilation dims. 
+ guide_downscale_factors = [] + if guide_frame_count > 0: + full_H = guide_frames.shape[3] + for entry_idx, _ in overlap_info: + entry_H = guide_entries[entry_idx]["latent_shape"][1] + guide_downscale_factors.append(full_H // entry_H) + window.guide_downscale_factors = guide_downscale_factors + + if guide_frame_count > 0: + idx = tuple([slice(None)] * self.dim + [suffix_idx]) + return torch.cat([latent_slice, guide_frames[idx]], dim=self.dim), guide_frame_count + return latent_slice, 0 + + def patch_latent_shapes(self, sub_conds, new_shapes): + if not self.is_multimodal: + return + + for cond_list in sub_conds: + if cond_list is None: + continue + for cond_dict in cond_list: + model_conds = cond_dict.get('model_conds', {}) + if 'latent_shapes' in model_conds: + model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes) + @dataclass class ContextSchedule: @@ -261,37 +333,35 @@ class IndexListContextHandler(ContextHandlerABC): return model_conds['latent_shapes'].cond return None - @staticmethod - def _unpack(combined_latent, latent_shapes): + def _apply_freenoise(self, noise: torch.Tensor, conds: list[list[dict]], seed: int) -> torch.Tensor: + """Apply FreeNoise shuffling, scaling context length/overlap per-modality by frame ratio.""" + latent_shapes = self._get_latent_shapes(conds) if latent_shapes is not None and len(latent_shapes) > 1: - return comfy.utils.unpack_latents(combined_latent, latent_shapes) - return [combined_latent] + modalities = comfy.utils.unpack_latents(noise, latent_shapes) + primary_total = latent_shapes[0][self.dim] + modalities[0] = apply_freenoise(modalities[0], self.dim, self.context_length, self.context_overlap, seed) + for i in range(1, len(modalities)): + mod_total = latent_shapes[i][self.dim] + ratio = mod_total / primary_total if primary_total > 0 else 1 + mod_ctx_len = max(round(self.context_length * ratio), 1) + mod_ctx_overlap = max(round(self.context_overlap * ratio), 0) + modalities[i] = apply_freenoise(modalities[i], self.dim, mod_ctx_len, mod_ctx_overlap, seed) + noise, _ = comfy.utils.pack_latents(modalities) + return noise + return apply_freenoise(noise, self.dim, self.context_length, self.context_overlap, seed) - @staticmethod - def _pack(latents): - if len(latents) > 1: - return comfy.utils.pack_latents(latents) - return latents[0], [latents[0].shape] - - @staticmethod - def _patch_latent_shapes(sub_conds, new_shapes): - for cond_list in sub_conds: - if cond_list is None: - continue - for cond_dict in cond_list: - model_conds = cond_dict.get('model_conds', {}) - if 'latent_shapes' in model_conds: - model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes) - - def _build_window_data(self, x_in: torch.Tensor, conds: list[list[dict]]) -> WindowingContext: + def _build_window_state(self, x_in: torch.Tensor, conds: list[list[dict]]) -> WindowingState: + """Build windowing state for the current step, including unpacking latents and extracting guide frame info from conds.""" latent_shapes = self._get_latent_shapes(conds) is_multimodal = latent_shapes is not None and len(latent_shapes) > 1 - if is_multimodal: - video_latent = comfy.utils.unpack_latents(x_in, latent_shapes)[0] - else: - video_latent = x_in + unpacked_latents = comfy.utils.unpack_latents(x_in, latent_shapes) if is_multimodal else [x_in] - guide_entries = None + unpacked_latents_list = list(unpacked_latents) + guide_latents_list = [None] * len(unpacked_latents) + guide_entries_list = [None] * len(unpacked_latents) + + # Scan for 'guide_attention_entries' in conds + 
extracted_guide_entries = None
 for cond_list in conds:
 if cond_list is None:
 continue
@@ -299,37 +369,39 @@ class IndexListContextHandler(ContextHandlerABC):
 model_conds = cond_dict.get('model_conds', {})
 entries = model_conds.get('guide_attention_entries')
 if entries is not None and hasattr(entries, 'cond') and entries.cond:
- guide_entries = entries.cond
+ extracted_guide_entries = entries.cond
 break
- if guide_entries is not None:
+ if extracted_guide_entries is not None:
 break

- guide_frame_count = sum(e["latent_shape"][0] for e in guide_entries) if guide_entries is not None else 0
- primary_frame_count = video_latent.size(self.dim) - guide_frame_count
- primary_frames = video_latent.narrow(self.dim, 0, primary_frame_count)
- guide_frames = video_latent.narrow(self.dim, primary_frame_count, guide_frame_count) if guide_frame_count > 0 else None
+ # Strip guide frames (only from the first modality for now)
+ if extracted_guide_entries is not None:
+ guide_count = sum(e["latent_shape"][0] for e in extracted_guide_entries)
+ if guide_count > 0:
+ x = unpacked_latents[0]
+ latent_count = x.size(self.dim) - guide_count
+ unpacked_latents_list[0] = x.narrow(self.dim, 0, latent_count)
+ guide_latents_list[0] = x.narrow(self.dim, latent_count, guide_count)
+ guide_entries_list[0] = extracted_guide_entries

- if guide_frame_count > 0:
- aux_data = {"guide_entries": guide_entries}
- else:
- aux_data = None
- return WindowingContext(
- tensor=primary_frames, guide_frames=guide_frames, aux_data=aux_data,
- latent_shapes=latent_shapes)
+ return WindowingState(
+ latents=unpacked_latents_list,
+ guide_latents=guide_latents_list,
+ guide_entries=guide_entries_list,
+ latent_shapes=latent_shapes,
+ dim=self.dim,
+ is_multimodal=is_multimodal)

 def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
- self._window_data = self._build_window_data(x_in, conds)
- video_frames = self._window_data.tensor.size(self.dim)
- guide_frames = self._window_data.guide_frames.size(self.dim) if self._window_data.guide_frames is not None else 0
- if video_frames > self.context_length:
- if guide_frames > 0:
- logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {video_frames} video frames ({guide_frames} guide frames).")
- else:
- logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {video_frames} frames.")
+ window_state = self._build_window_state(x_in, conds) # built here only to check frame counts; rebuilt in execute()
+ total_frame_count = window_state.latents[0].size(self.dim)
+ if total_frame_count > self.context_length:
+ logging.info(f"\nUsing context windows: Context length {self.context_length} with overlap {self.context_overlap} for {total_frame_count} frames.")
 if self.cond_retain_index_list:
 logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
 return True
+ logging.info(f"\nNot using context windows since input frames ({total_frame_count}) do not exceed context length ({self.context_length}).")
 return False

 def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase:
@@ -436,188 +508,121 @@ class IndexListContextHandler(ContextHandlerABC):
 self._model = model
 self.set_step(timestep, model_options)

- window_data = self._window_data
- is_multimodal = window_data.latent_shapes is not None and len(window_data.latent_shapes) > 1
- has_guide_frames = window_data.guide_frames
is not None and window_data.guide_frames.size(self.dim) > 0 + window_state = self._build_window_state(x_in, conds) + num_modalities = len(window_state.latents) - # if multimodal or has concatenated guide frames on noise latent, use the extended execute path - if is_multimodal or has_guide_frames: - return self._execute_extended(calc_cond_batch, model, conds, x_in, timestep, model_options, window_data) - - # basic legacy execution path for single-modal video latent with no guide frames concatenated - context_windows = self.get_context_windows(model, x_in, model_options) - enumerated_context_windows = list(enumerate(context_windows)) - - conds_final = [torch.zeros_like(x_in) for _ in conds] - if self.fuse_method.name == ContextFuseMethods.RELATIVE: - counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds] - else: - counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds] - biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds] - - for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks): - callback(self, model, x_in, conds, timestep, model_options) - - for enum_window in enumerated_context_windows: - results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options) - for result in results: - self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep, - conds_final, counts_final, biases_final) - try: - if self.fuse_method.name == ContextFuseMethods.RELATIVE: - del counts_final - return conds_final - else: - for i in range(len(conds_final)): - conds_final[i] /= counts_final[i] - del counts_final - return conds_final - finally: - for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks): - callback(self, model, x_in, conds, timestep, model_options) - - def _execute_extended(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, - timestep: torch.Tensor, model_options: dict[str], - window_data: WindowingContext): - """Extended execute path for multimodal models and models with guide frames appended to the noise latent.""" - - latents = self._unpack(x_in, window_data.latent_shapes) - is_multimodal = window_data.latent_shapes is not None and len(window_data.latent_shapes) > 1 - - primary_frames = window_data.tensor - context_windows = self.get_context_windows(model, primary_frames, model_options) + context_windows = self.get_context_windows(model, window_state.latents[0], model_options) enumerated_context_windows = list(enumerate(context_windows)) total_windows = len(enumerated_context_windows) - # Accumulators sized to video portion for primary, full for other modalities - accum_shape_refs = list(latents) - if window_data.guide_frames is not None: - accum_shape_refs[0] = primary_frames - - accum = [[torch.zeros_like(m) for _ in conds] for m in accum_shape_refs] + # Initialize per-modality accumulators (length 1 for single-modality) + accum = [[torch.zeros_like(m) for _ in conds] for m in window_state.latents] if self.fuse_method.name == ContextFuseMethods.RELATIVE: - counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_shape_refs] + counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in window_state.latents] else: - counts = 
[[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_shape_refs] - biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in accum_shape_refs] + counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in window_state.latents] + biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in window_state.latents] for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks): callback(self, model, x_in, conds, timestep, model_options) - for window_idx, window in enumerated_context_windows: - comfy.model_management.throw_exception_if_processing_interrupted() - - # Per-modality window indices - if is_multimodal: - map_shapes = window_data.latent_shapes - if primary_frames.size(self.dim) != latents[0].size(self.dim): - map_shapes = list(window_data.latent_shapes) - video_shape = list(window_data.latent_shapes[0]) - video_shape[self.dim] = primary_frames.size(self.dim) - map_shapes[0] = torch.Size(video_shape) - try: - per_modality_indices = model.map_context_window_to_modalities( - window.index_list, map_shapes, self.dim) - except AttributeError: - raise NotImplementedError( - f"{type(model).__name__} must implement map_context_window_to_modalities for multimodal context windows.") - modality_windows = {} - for mod_idx in range(1, len(latents)): - modality_windows[mod_idx] = IndexListContextWindow( - per_modality_indices[mod_idx], dim=self.dim, - total_frames=latents[mod_idx].shape[self.dim]) - window = IndexListContextWindow( - window.index_list, dim=self.dim, total_frames=primary_frames.shape[self.dim], - modality_windows=modality_windows) - - # Build per-modality windows list - per_modality_windows_list = [window] - if is_multimodal: - for mod_idx in range(1, len(latents)): - per_modality_windows_list.append(modality_windows[mod_idx]) - - # Slice video, then inject overlapping guide frames if present - sliced_video = per_modality_windows_list[0].get_tensor(primary_frames, retain_index_list=self.cond_retain_index_list) - if window_data.aux_data is not None: - sliced_primary, num_guide_frames = inject_guide_frames_into_window(sliced_video, window, window_data, self.dim) - else: - sliced_primary, num_guide_frames = sliced_video, 0 - logging.info(f"Context window {window_idx + 1}/{total_windows}: frames {window.index_list[0]}-{window.index_list[-1]} of {primary_frames.shape[self.dim]}" - + (f" (+{num_guide_frames} guide frames)" if num_guide_frames > 0 else "") - + (f" [{len(latents)} modalities]" if is_multimodal else "")) - sliced = [sliced_primary] + [per_modality_windows_list[mi].get_tensor(latents[mi]) for mi in range(1, len(latents))] - - sub_x, sub_shapes = self._pack(sliced) - - for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks): - callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, None, None) - - model_options["transformer_options"]["context_window"] = window - sub_timestep = window.get_tensor(timestep, dim=0) - sub_conds = [self.get_resized_cond(cond, primary_frames, window) for cond in conds] - if is_multimodal: - self._patch_latent_shapes(sub_conds, sub_shapes) - - sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options) - - # Unpack output per modality - out_per_modality = [self._unpack(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))] - - # Strip guide frames from primary output before accumulation 
- if num_guide_frames > 0: - window_len = len(window.index_list) - for ci in range(len(sub_conds_out)): - out_per_modality[ci][0] = out_per_modality[ci][0].narrow(self.dim, 0, window_len) - - # Accumulate per modality - for mod_idx in range(len(accum_shape_refs)): - mw = per_modality_windows_list[mod_idx] - sub_conds_out_per_modality = [out_per_modality[ci][mod_idx] for ci in range(len(sub_conds_out))] - self.combine_context_window_results( - accum_shape_refs[mod_idx], sub_conds_out_per_modality, sub_conds, mw, - window_idx, total_windows, timestep, - accum[mod_idx], counts[mod_idx], biases[mod_idx]) + # accumulate results from each context window + for enum_window in enumerated_context_windows: + results = self.evaluate_context_windows( + calc_cond_batch, model, x_in, conds, timestep, [enum_window], + model_options, window_state=window_state, total_windows=total_windows) + for result in results: + # result.sub_conds_out is per-cond, per-modality: list[list[Tensor]] + for mod_idx in range(num_modalities): + mod_out = [result.sub_conds_out[ci][mod_idx] for ci in range(len(conds))] + modality_window = result.window.get_window_for_modality(mod_idx) + self.combine_context_window_results( + window_state.latents[mod_idx], mod_out, result.sub_conds, modality_window, + result.window_idx, total_windows, timestep, + accum[mod_idx], counts[mod_idx], biases[mod_idx]) + # fuse accumulated results into final conds try: - result = [] + result_out = [] for ci in range(len(conds)): finalized = [] - for mod_idx in range(len(accum_shape_refs)): + for mod_idx in range(num_modalities): if self.fuse_method.name != ContextFuseMethods.RELATIVE: accum[mod_idx][ci] /= counts[mod_idx][ci] f = accum[mod_idx][ci] - if mod_idx == 0 and window_data.guide_frames is not None: - f = torch.cat([f, window_data.guide_frames], dim=self.dim) + + # if guide frames were injected, append them to the end of the fused latents for the next step + if window_state.guide_latents[mod_idx] is not None: + f = torch.cat([f, window_state.guide_latents[mod_idx]], dim=self.dim) finalized.append(f) - packed, _ = self._pack(finalized) - result.append(packed) - return result + + # pack modalities together if needed + if window_state.is_multimodal and len(finalized) > 1: + packed, _ = comfy.utils.pack_latents(finalized) + else: + packed = finalized[0] + + result_out.append(packed) + return result_out finally: for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks): callback(self, model, x_in, conds, timestep, model_options) - def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]], - model_options, device=None, first_device=None): + def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, + timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]], + model_options, window_state: WindowingState, total_windows: int = None, + device=None, first_device=None): + """Evaluate context windows and return per-cond, per-modality outputs in ContextResults.sub_conds_out + + For each window: + 1. Builds windows (for each modality if multimodal) + 2. Slices window for each modality + 3. Injects concatenated latent guide frames where present + 4. Packs together if needed and calls model + 5. 
Unpacks and strips any guides from outputs + """ + x = window_state.latents[0] + results: list[ContextResults] = [] for window_idx, window in enumerated_context_windows: # allow processing to end between context window executions for faster Cancel comfy.model_management.throw_exception_if_processing_interrupted() + # prepare the window accounting for multimodal windows + window = window_state.prepare_window(window, model) + + # slice the window for each modality, injecting guide frames where applicable + sliced, guide_frame_counts_per_modality = window_state.slice_for_window(window, self.cond_retain_index_list, device) + for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks): callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device) - # update exposed params - model_options["transformer_options"]["context_window"] = window - # get subsections of x, timestep, conds - sub_x = window.get_tensor(x_in, device) - sub_timestep = window.get_tensor(timestep, device, dim=0) - sub_conds = [self.get_resized_cond(cond, x_in, window, device) for cond in conds] + logging.info(f"Context window {window_idx + 1}/{total_windows or len(enumerated_context_windows)}: frames {window.index_list[0]}-{window.index_list[-1]} of {x.shape[self.dim]}" + + (f" (+{guide_frame_counts_per_modality[0]} guide frames)" if guide_frame_counts_per_modality[0] > 0 else "") + ) + # if multimodal, pack modalities together + if window_state.is_multimodal and len(sliced) > 1: + sub_x, sub_shapes = comfy.utils.pack_latents(sliced) + else: + sub_x, sub_shapes = sliced[0], [sliced[0].shape] + + # get resized conds for window + model_options["transformer_options"]["context_window"] = window + sub_timestep = window.get_tensor(timestep, dim=0) + sub_conds = [self.get_resized_cond(cond, x, window) for cond in conds] + + # if multimodal, patch latent_shapes in conds for correct unpacking in model + window_state.patch_latent_shapes(sub_conds, sub_shapes) + + # call model on window sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options) - if device is not None: - for i in range(len(sub_conds_out)): - sub_conds_out[i] = sub_conds_out[i].to(x_in.device) - results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window)) + + # unpack outputs and strip guide frames + out_per_modality = [comfy.utils.unpack_latents(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))] + window_state.strip_guide_frames(out_per_modality, guide_frame_counts_per_modality, window) + + results.append(ContextResults(window_idx, out_per_modality, sub_conds, window)) return results @@ -684,28 +689,11 @@ def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, nois if not handler.freenoise: return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs) - # For packed multimodal tensors (e.g. LTXAV), noise is [B, 1, flat] and FreeNoise - # must only shuffle the video portion. Unpack, apply to video, repack. 
- latent_shapes = IndexListContextHandler._get_latent_shapes( - [guider.conds.get('positive', guider.conds.get('negative', []))]) - - if latent_shapes is not None and len(latent_shapes) > 1: - modalities = comfy.utils.unpack_latents(noise, latent_shapes) - video_total = latent_shapes[0][handler.dim] - modalities[0] = apply_freenoise(modalities[0], handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"]) - for i in range(1, len(modalities)): - mod_total = latent_shapes[i][handler.dim] - ratio = mod_total / video_total if video_total > 0 else 1 - mod_ctx_len = max(round(handler.context_length * ratio), 1) - mod_ctx_overlap = max(round(handler.context_overlap * ratio), 0) - modalities[i] = apply_freenoise(modalities[i], handler.dim, mod_ctx_len, mod_ctx_overlap, extra_args["seed"]) - noise, _ = comfy.utils.pack_latents(modalities) - else: - noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"]) + conds = [guider.conds.get('positive', guider.conds.get('negative', []))] + noise = handler._apply_freenoise(noise, conds, extra_args["seed"]) return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs) - def create_sampler_sample_wrapper(model: ModelPatcher): model.add_wrapper_with_key( comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, @@ -713,7 +701,6 @@ def create_sampler_sample_wrapper(model: ModelPatcher): _sampler_sample_wrapper ) - def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor: total_dims = len(x_in.shape) weights_tensor = torch.Tensor(weights).to(device=device) diff --git a/comfy/model_base.py b/comfy/model_base.py index 8960ecd19..08f8b058d 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1162,7 +1162,7 @@ class LTXAV(BaseModel): # Adjust spatial end positions for dilated (downscaled) guides. # Each guide entry may have a different downscale factor; expand the # per-entry factor to cover all tokens belonging to that entry. - downscale_factors = getattr(window, 'guide_downscale_factors', []) + downscale_factors = window.guide_downscale_factors overlap_info = window.guide_overlap_info if downscale_factors: per_token_factor = [] From f1f3182be1802ca27bc725745f936eb5b3f03af5 Mon Sep 17 00:00:00 2001 From: ozbayb <17261091+ozbayb@users.noreply.github.com> Date: Sun, 12 Apr 2026 14:52:54 -0600 Subject: [PATCH 18/23] LTX2 context windows - Fix audio index mapping for wrapped/strided primary windows The previous window-level calculation collapsed wrapped or strided primary windows into a contiguous audio tail, so audio attended to a different temporal region than the video. Replace with per-frame mapping that computes each primary index's audio span independently and concatenates in order. 
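
For illustration, a standalone sketch of the per-frame mapping (the helper
name and frame counts are hypothetical; the logic mirrors the loop added
below):

    def map_primary_to_modality(primary_indices, video_total, mod_total):
        # Each primary (video) index owns a proportional span of modality
        # (audio) indices; spans are concatenated in window order, so wrapped
        # or strided windows keep their temporal geometry.
        mod_indices, seen = [], set()
        for v_idx in primary_indices:
            a_start = min(int(round(v_idx * mod_total / video_total)), mod_total - 1)
            a_end = min(int(round((v_idx + 1) * mod_total / video_total)), mod_total)
            if a_end <= a_start:
                a_end = a_start + 1
            for a in range(a_start, a_end):
                if a not in seen:
                    seen.add(a)
                    mod_indices.append(a)
        return mod_indices

    # A wrapped window over 24 video frames mapped onto 48 audio frames:
    # map_primary_to_modality([20, 21, 22, 23, 0, 1], 24, 48)
    # -> [40, 41, 42, 43, 44, 45, 46, 47, 0, 1, 2, 3]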
--- comfy/model_base.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 08f8b058d..18b3f5c0e 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1088,17 +1088,24 @@ class LTXAV(BaseModel): return result video_total = latent_shapes[0][dim] - video_window_len = len(primary_indices) for i in range(1, len(latent_shapes)): mod_total = latent_shapes[i][dim] - # Length proportional to video window frame count - mod_window_len = max(round(video_window_len * mod_total / video_total), 1) - # Anchor to end of video range - v_end = max(primary_indices) + 1 - mod_end = min(round(v_end * mod_total / video_total), mod_total) - mod_start = max(mod_end - mod_window_len, 0) - result.append(list(range(mod_start, min(mod_start + mod_window_len, mod_total)))) + # Map each primary index to its proportional range of modality indices and + # concatenate in order. Preserves wrapped/strided geometry so the modality + # attends to the same temporal regions as the primary window. + mod_indices = [] + seen = set() + for v_idx in primary_indices: + a_start = min(int(round(v_idx * mod_total / video_total)), mod_total - 1) + a_end = min(int(round((v_idx + 1) * mod_total / video_total)), mod_total) + if a_end <= a_start: + a_end = a_start + 1 + for a in range(a_start, a_end): + if a not in seen: + seen.add(a) + mod_indices.append(a) + result.append(mod_indices) return result From d59d6fb7a0796a6adfa9efcd55f80b4477e4020a Mon Sep 17 00:00:00 2001 From: ozbayb <17261091+ozbayb@users.noreply.github.com> Date: Sun, 12 Apr 2026 15:43:36 -0600 Subject: [PATCH 19/23] LTX2 context windows - Skip VRAM estimate clamp for packed latents --- comfy/context_windows.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index a9f456426..89963699c 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -656,19 +656,22 @@ class IndexListContextHandler(ContextHandlerABC): callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final) -def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args, **kwargs): - # limit noise_shape length to context_length for more accurate vram use estimation +def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, conds, *args, **kwargs): + # Scale noise_shape to a single context window so VRAM estimation budgets per-window. model_options = kwargs.get("model_options", None) if model_options is None: raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.") handler: IndexListContextHandler = model_options.get("context_handler", None) if handler is not None: noise_shape = list(noise_shape) - # Guard: only clamp when dim is within bounds and the value is meaningful - # (packed multimodal tensors have noise_shape=[B,1,flat] where flat is not frame count) - if handler.dim < len(noise_shape) and noise_shape[handler.dim] > handler.context_length: + is_packed = len(noise_shape) == 3 and noise_shape[1] == 1 + if is_packed: + # TODO: latent_shapes cond isn't attached yet at this point, so we can't compute a + # per-window flat latent here. Skipping the clamp over-estimates but prevents immediate OOM. 
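+ # (For packed latents noise_shape is [B, 1, flat], so noise_shape[dim] is a
+ # flattened token count across modalities, not a frame count to clamp.)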
+ pass + elif handler.dim < len(noise_shape) and noise_shape[handler.dim] > handler.context_length: noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length) - return executor(model, noise_shape, *args, **kwargs) + return executor(model, noise_shape, conds, *args, **kwargs) def create_prepare_sampling_wrapper(model: ModelPatcher): From f72583d1f301c2b904b12f9404f168b031b43337 Mon Sep 17 00:00:00 2001 From: ozbayb <17261091+ozbayb@users.noreply.github.com> Date: Sun, 12 Apr 2026 15:46:20 -0600 Subject: [PATCH 20/23] LTX2 context windows - Move symmetric_patchifier import to module level --- comfy/model_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index e7ad3011a..c9959f255 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -21,6 +21,7 @@ import comfy.ldm.hunyuan3dv2_1.hunyuandit import torch import logging import comfy.ldm.lightricks.av_model +import comfy.ldm.lightricks.symmetric_patchifier import comfy.context_windows from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep from comfy.ldm.cascade.stage_c import StageC @@ -1156,9 +1157,8 @@ class LTXAV(BaseModel): window_len = len(window.index_list) patchifier = self.diffusion_model.patchifier latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device) - from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords scale_factors = self.diffusion_model.vae_scale_factors - pixel_coords = latent_to_pixel_coords( + pixel_coords = comfy.ldm.lightricks.symmetric_patchifier.latent_to_pixel_coords( latent_coords, scale_factors, causal_fix=self.diffusion_model.causal_temporal_positioning) From a8b084ed5813852d95916d779e726135b7d07d50 Mon Sep 17 00:00:00 2001 From: ozbayb <17261091+ozbayb@users.noreply.github.com> Date: Sun, 12 Apr 2026 18:42:45 -0600 Subject: [PATCH 21/23] LTX2 context windows - Thread per-modality overlap into fuse weights --- comfy/context_windows.py | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 89963699c..409bcc271 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -53,9 +53,10 @@ class ContextHandlerABC(ABC): class IndexListContextWindow(ContextWindowABC): - def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0, modality_windows: dict=None): + def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0, modality_windows: dict=None, context_overlap: int=0): self.index_list = index_list self.context_length = len(index_list) + self.context_overlap = context_overlap self.dim = dim self.total_frames = total_frames self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames) @@ -211,8 +212,10 @@ class WindowingState: return window x = self.latents[0] + primary_total = self.latent_shapes[0][self.dim] + primary_overlap = window.context_overlap map_shapes = self.latent_shapes - if x.size(self.dim) != self.latent_shapes[0][self.dim]: + if x.size(self.dim) != primary_total: map_shapes = list(self.latent_shapes) video_shape = list(self.latent_shapes[0]) video_shape[self.dim] = x.size(self.dim) @@ -225,12 +228,16 @@ class WindowingState: f"{type(model).__name__} must implement map_context_window_to_modalities for multimodal context windows.") modality_windows = {} for mod_idx in range(1, len(self.latents)): + modality_total_frames = self.latents[mod_idx].shape[self.dim] + ratio = 
modality_total_frames / primary_total if primary_total > 0 else 1 + modality_overlap = max(round(primary_overlap * ratio), 0) modality_windows[mod_idx] = IndexListContextWindow( per_modality_indices[mod_idx], dim=self.dim, - total_frames=self.latents[mod_idx].shape[self.dim]) + total_frames=modality_total_frames, + context_overlap=modality_overlap) return IndexListContextWindow( window.index_list, dim=self.dim, total_frames=x.shape[self.dim], - modality_windows=modality_windows) + modality_windows=modality_windows, context_overlap=primary_overlap) def slice_for_window(self, window: IndexListContextWindow, retain_index_list: list[int], device=None) -> tuple[list[torch.Tensor], list[int]]: """Slice latents for a context window, injecting guide frames where applicable. @@ -501,7 +508,7 @@ class IndexListContextHandler(ContextHandlerABC): def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]: full_length = x_in.size(self.dim) # TODO: choose dim based on model context_windows = self.context_schedule.func(full_length, self, model_options) - context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows] + context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length, context_overlap=self.context_overlap) for window in context_windows] return context_windows def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): @@ -646,7 +653,7 @@ class IndexListContextHandler(ContextHandlerABC): biases_final[i][idx] = bias_total + bias else: # add conds and counts based on weights of fuse method - weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep) + weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep, context_overlap=window.context_overlap) weights_tensor = match_weights_to_dim(weights, x_in, self.dim, device=x_in.device) for i in range(len(sub_conds_out)): window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor) @@ -849,8 +856,9 @@ def get_matching_context_schedule(context_schedule: str) -> ContextSchedule: return ContextSchedule(context_schedule, func) -def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None): - return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs) +def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None, context_overlap: int=None): + context_overlap = handler.context_overlap if context_overlap is None else context_overlap + return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs, context_overlap=context_overlap) def create_weights_flat(length: int, **kwargs) -> list[float]: @@ -868,18 +876,18 @@ def create_weights_pyramid(length: int, **kwargs) -> list[float]: weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1)) return weight_sequence -def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, **kwargs): +def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], context_overlap: int, **kwargs): # based on code in Kijai's 
WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302 # only expected overlap is given different weights weights_torch = torch.ones((length)) # blend left-side on all except first window if min(idxs) > 0: - ramp_up = torch.linspace(1e-37, 1, handler.context_overlap) - weights_torch[:handler.context_overlap] = ramp_up + ramp_up = torch.linspace(1e-37, 1, context_overlap) + weights_torch[:context_overlap] = ramp_up # blend right-side on all except last window if max(idxs) < full_length-1: - ramp_down = torch.linspace(1, 1e-37, handler.context_overlap) - weights_torch[-handler.context_overlap:] = ramp_down + ramp_down = torch.linspace(1, 1e-37, context_overlap) + weights_torch[-context_overlap:] = ramp_down return weights_torch class ContextFuseMethods: From 6a53695006e7bfe1f1823ab38d720e8356dd3726 Mon Sep 17 00:00:00 2001 From: ozbayb <17261091+ozbayb@users.noreply.github.com> Date: Sun, 12 Apr 2026 20:18:10 -0600 Subject: [PATCH 22/23] LTX2 context windows - Skip guide frames in freenoise shuffle --- comfy/context_windows.py | 40 ++++++++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 409bcc271..8fb7b9642 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -340,13 +340,31 @@ class IndexListContextHandler(ContextHandlerABC): return model_conds['latent_shapes'].cond return None + @staticmethod + def _get_guide_entries(conds): + for cond_list in conds: + if cond_list is None: + continue + for cond_dict in cond_list: + model_conds = cond_dict.get('model_conds', {}) + entries = model_conds.get('guide_attention_entries') + if entries is not None and hasattr(entries, 'cond') and entries.cond: + return entries.cond + return None + def _apply_freenoise(self, noise: torch.Tensor, conds: list[list[dict]], seed: int) -> torch.Tensor: - """Apply FreeNoise shuffling, scaling context length/overlap per-modality by frame ratio.""" + """Apply FreeNoise shuffling, scaling context length/overlap per-modality by frame ratio. + If guide frames are present on the primary modality, only the video portion is shuffled. 
+ """ + guide_entries = self._get_guide_entries(conds) + guide_count = sum(e["latent_shape"][0] for e in guide_entries) if guide_entries else 0 + latent_shapes = self._get_latent_shapes(conds) if latent_shapes is not None and len(latent_shapes) > 1: modalities = comfy.utils.unpack_latents(noise, latent_shapes) primary_total = latent_shapes[0][self.dim] - modalities[0] = apply_freenoise(modalities[0], self.dim, self.context_length, self.context_overlap, seed) + primary_video_count = modalities[0].size(self.dim) - guide_count + apply_freenoise(modalities[0].narrow(self.dim, 0, primary_video_count), self.dim, self.context_length, self.context_overlap, seed) for i in range(1, len(modalities)): mod_total = latent_shapes[i][self.dim] ratio = mod_total / primary_total if primary_total > 0 else 1 @@ -355,7 +373,9 @@ class IndexListContextHandler(ContextHandlerABC): modalities[i] = apply_freenoise(modalities[i], self.dim, mod_ctx_len, mod_ctx_overlap, seed) noise, _ = comfy.utils.pack_latents(modalities) return noise - return apply_freenoise(noise, self.dim, self.context_length, self.context_overlap, seed) + video_count = noise.size(self.dim) - guide_count + apply_freenoise(noise.narrow(self.dim, 0, video_count), self.dim, self.context_length, self.context_overlap, seed) + return noise def _build_window_state(self, x_in: torch.Tensor, conds: list[list[dict]]) -> WindowingState: """Build windowing state for the current step, including unpacking latents and extracting guide frame info from conds.""" @@ -367,19 +387,7 @@ class IndexListContextHandler(ContextHandlerABC): guide_latents_list = [None] * len(unpacked_latents) guide_entries_list = [None] * len(unpacked_latents) - # Scan for 'guide_attention_entries' in conds - extracted_guide_entries = None - for cond_list in conds: - if cond_list is None: - continue - for cond_dict in cond_list: - model_conds = cond_dict.get('model_conds', {}) - entries = model_conds.get('guide_attention_entries') - if entries is not None and hasattr(entries, 'cond') and entries.cond: - extracted_guide_entries = entries.cond - break - if extracted_guide_entries is not None: - break + extracted_guide_entries = self._get_guide_entries(conds) # Strip guide frames (only from first modality for now) if extracted_guide_entries is not None: From 6442392810e13e2497520740699b987edfc73be7 Mon Sep 17 00:00:00 2001 From: ozbayb <17261091+ozbayb@users.noreply.github.com> Date: Mon, 13 Apr 2026 15:19:25 -0600 Subject: [PATCH 23/23] Add defensive dtype cast before sigma step check --- comfy/context_windows.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 8fb7b9642..012f1bbd8 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -507,7 +507,9 @@ class IndexListContextHandler(ContextHandlerABC): return resized_cond def set_step(self, timestep: torch.Tensor, model_options: dict[str]): - mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001) + sample_sigmas = model_options["transformer_options"]["sample_sigmas"] + current_timestep = timestep[0].to(sample_sigmas.dtype) + mask = torch.isclose(sample_sigmas, current_timestep, rtol=0.0001) matches = torch.nonzero(mask) if torch.numel(matches) == 0: return # substep from multi-step sampler: keep self._step from the last full step