From 56de390c2533112ba9925604573b12396a229885 Mon Sep 17 00:00:00 2001 From: ozbayb <17261091+ozbayb@users.noreply.github.com> Date: Mon, 23 Mar 2026 13:50:28 -0600 Subject: [PATCH] LTX2 context windows part 2 - Guide aware processing --- comfy/context_windows.py | 93 ++++++++++++++++++++++++------- comfy/model_base.py | 117 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 185 insertions(+), 25 deletions(-) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 29ee2b5b1..91d019a00 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -204,8 +204,13 @@ class IndexListContextHandler(ContextHandlerABC): def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool: latent_shapes = self._get_latent_shapes(conds) primary = self._decompose(x_in, latent_shapes)[0] - if primary.size(self.dim) > self.context_length: - logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {primary.size(self.dim)} frames.") + guide_count = model.get_guide_frame_count(primary, conds) if model is not None else 0 + video_frames = primary.size(self.dim) - guide_count + if video_frames > self.context_length: + if guide_count > 0: + logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {video_frames} video frames ({guide_count} guide frames excluded).") + else: + logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {video_frames} frames.") if self.cond_retain_index_list: logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}") return True @@ -321,18 +326,32 @@ class IndexListContextHandler(ContextHandlerABC): is_multimodal = len(modalities) > 1 primary = modalities[0] - # Windows from primary modality's temporal dim - context_windows = self.get_context_windows(model, primary, model_options) + # Separate guide frames from primary modality (guides are appended at the end) + guide_count = model.get_guide_frame_count(primary, conds) if model is not None else 0 + if guide_count > 0: + video_len = primary.size(self.dim) - guide_count + video_primary = primary.narrow(self.dim, 0, video_len) + guide_suffix = primary.narrow(self.dim, video_len, guide_count) + else: + video_primary = primary + guide_suffix = None + + # Windows from video portion only (excluding guide frames) + context_windows = self.get_context_windows(model, video_primary, model_options) enumerated_context_windows = list(enumerate(context_windows)) total_windows = len(enumerated_context_windows) - # Per-modality accumulators: accum[mod_idx][cond_idx] - accum = [[torch.zeros_like(m) for _ in conds] for m in modalities] + # Accumulators sized to video portion for primary, full for other modalities + accum_modalities = list(modalities) + if guide_suffix is not None: + accum_modalities[0] = video_primary + + accum = [[torch.zeros_like(m) for _ in conds] for m in accum_modalities] if self.fuse_method.name == ContextFuseMethods.RELATIVE: - counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in modalities] + counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_modalities] else: - counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in modalities] - biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in modalities] + counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_modalities] + biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in accum_modalities] for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks): callback(self, model, x_in, conds, timestep, model_options) @@ -340,10 +359,22 @@ class IndexListContextHandler(ContextHandlerABC): for window_idx, window in enumerated_context_windows: comfy.model_management.throw_exception_if_processing_interrupted() + # Attach guide info to window for resize_cond_for_context_window + window.guide_count = guide_count + if guide_suffix is not None: + window.guide_spatial = (guide_suffix.shape[3], guide_suffix.shape[4]) + # Per-modality window indices if is_multimodal: + # Adjust latent_shapes so video shape reflects video-only frames (excludes guides) + map_shapes = latent_shapes + if guide_count > 0: + map_shapes = list(latent_shapes) + video_shape = list(latent_shapes[0]) + video_shape[self.dim] = video_shape[self.dim] - guide_count + map_shapes[0] = torch.Size(video_shape) per_mod_indices = model.map_context_window_to_modalities( - window.index_list, latent_shapes, self.dim) + window.index_list, map_shapes, self.dim) # Build per-modality windows and attach to primary window modality_windows = {} for mod_idx in range(1, len(modalities)): @@ -351,8 +382,11 @@ class IndexListContextHandler(ContextHandlerABC): per_mod_indices[mod_idx], dim=self.dim, total_frames=modalities[mod_idx].shape[self.dim]) window = IndexListContextWindow( - window.index_list, dim=self.dim, total_frames=primary.shape[self.dim], + window.index_list, dim=self.dim, total_frames=video_primary.shape[self.dim], modality_windows=modality_windows) + window.guide_count = guide_count + if guide_suffix is not None: + window.guide_spatial = (guide_suffix.shape[3], guide_suffix.shape[4]) else: per_mod_indices = [window.index_list] @@ -362,8 +396,14 @@ class IndexListContextHandler(ContextHandlerABC): for mod_idx in range(1, len(modalities)): mod_windows.append(modality_windows[mod_idx]) - # Slice each modality - sliced = [mod_windows[mi].get_tensor(modalities[mi]) for mi in range(len(modalities))] + # Slice video and guide with same window indices, concatenate + sliced_video = mod_windows[0].get_tensor(video_primary) + if guide_suffix is not None: + sliced_guide = mod_windows[0].get_tensor(guide_suffix) + sliced_primary = torch.cat([sliced_video, sliced_guide], dim=self.dim) + else: + sliced_primary = sliced_video + sliced = [sliced_primary] + [mod_windows[mi].get_tensor(modalities[mi]) for mi in range(1, len(modalities))] # Compose for pipeline sub_x, sub_shapes = self._compose(sliced) @@ -374,8 +414,8 @@ class IndexListContextHandler(ContextHandlerABC): model_options["transformer_options"]["context_window"] = window sub_timestep = window.get_tensor(timestep, dim=0) - # Resize conds using primary tensor as reference (correct temporal dim) - sub_conds = [self.get_resized_cond(cond, primary, window) for cond in conds] + # Resize conds using video_primary as reference (excludes guide frames) + sub_conds = [self.get_resized_cond(cond, video_primary, window) for cond in conds] if is_multimodal: self._patch_latent_shapes(sub_conds, sub_shapes) @@ -385,13 +425,19 @@ class IndexListContextHandler(ContextHandlerABC): out_per_mod = [self._decompose(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))] # out_per_mod[cond_idx][mod_idx] = tensor - # Accumulate per modality - for mod_idx in range(len(modalities)): + # Strip guide frames from primary output before accumulation + if guide_count > 0: + window_len = len(window.index_list) + for ci in range(len(sub_conds_out)): + primary_out = out_per_mod[ci][0] + out_per_mod[ci][0] = primary_out.narrow(self.dim, 0, window_len) + + # Accumulate per modality (using video-only sizes) + for mod_idx in range(len(accum_modalities)): mw = mod_windows[mod_idx] - # Build per-modality sub_conds_out list for combine mod_sub_out = [out_per_mod[ci][mod_idx] for ci in range(len(sub_conds_out))] self.combine_context_window_results( - modalities[mod_idx], mod_sub_out, sub_conds, mw, + accum_modalities[mod_idx], mod_sub_out, sub_conds, mw, window_idx, total_windows, timestep, accum[mod_idx], counts[mod_idx], biases[mod_idx]) @@ -399,10 +445,15 @@ class IndexListContextHandler(ContextHandlerABC): result = [] for ci in range(len(conds)): finalized = [] - for mod_idx in range(len(modalities)): + for mod_idx in range(len(accum_modalities)): if self.fuse_method.name != ContextFuseMethods.RELATIVE: accum[mod_idx][ci] /= counts[mod_idx][ci] - finalized.append(accum[mod_idx][ci]) + f = accum[mod_idx][ci] + # Re-append original guide_suffix (not model output — sampling loop + # respects denoise_mask and never modifies guide frame positions) + if mod_idx == 0 and guide_suffix is not None: + f = torch.cat([f, guide_suffix], dim=self.dim) + finalized.append(f) composed, _ = self._compose(finalized) result.append(composed) return result diff --git a/comfy/model_base.py b/comfy/model_base.py index 3096ca4fb..12d49305f 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -298,6 +298,11 @@ class BaseModel(torch.nn.Module): Returns list of index lists, one per modality.""" return [primary_indices] + def get_guide_frame_count(self, x, conds): + """Return the number of trailing guide frames appended to x along the temporal dim. + Override in subclasses that concatenate guide reference frames to the latent.""" + return 0 + def extra_conds(self, **kwargs): out = {} concat_cond = self.concat_cond(**kwargs) @@ -1021,12 +1026,64 @@ class LTXV(BaseModel): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): return latent_image + def get_guide_frame_count(self, x, conds): + for cond_list in conds: + if cond_list is None: + continue + for cond_dict in cond_list: + model_conds = cond_dict.get('model_conds', {}) + gae = model_conds.get('guide_attention_entries') + if gae is not None and hasattr(gae, 'cond') and gae.cond: + return sum(e["latent_shape"][0] for e in gae.cond) + return 0 + + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + guide_count = getattr(window, 'guide_count', 0) + + if cond_key == "denoise_mask" and guide_count > 0: + # Slice both video and guide halves with same window indices + cond_tensor = cond_value.cond + T_video = cond_tensor.size(window.dim) - guide_count + video_mask = cond_tensor.narrow(window.dim, 0, T_video) + guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count) + sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list) + sliced_guide = window.get_tensor(guide_mask, device) + return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim)) + + if cond_key == "keyframe_idxs" and guide_count > 0: + # Recompute coords for window_len frames so guide tokens are co-located + # with noise tokens in RoPE space (identical to a standalone short video) + window_len = len(window.index_list) + H, W = window.guide_spatial + patchifier = self.diffusion_model.patchifier + latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device) + from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords + pixel_coords = latent_to_pixel_coords( + latent_coords, + self.diffusion_model.vae_scale_factors, + causal_fix=self.diffusion_model.causal_temporal_positioning) + B = cond_value.cond.shape[0] + if B > 1: + pixel_coords = pixel_coords.expand(B, -1, -1, -1) + return cond_value._copy_with(pixel_coords) + + if cond_key == "guide_attention_entries" and guide_count > 0: + # Adjust token counts for window size + window_len = len(window.index_list) + H, W = window.guide_spatial + new_entries = [{**e, "pre_filter_count": window_len * H * W, + "latent_shape": [window_len, H, W]} for e in cond_value.cond] + return cond_value._copy_with(new_entries) + + return None + class LTXAV(BaseModel): def __init__(self, model_config, model_type=ModelType.FLUX, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.av_model.LTXAVModel) #TODO def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) + logging.info(f"LTXAV.extra_conds: guide_attention_entries={'guide_attention_entries' in kwargs}, keyframe_idxs={'keyframe_idxs' in kwargs}") attention_mask = kwargs.get("attention_mask", None) device = kwargs["device"] @@ -1106,13 +1163,65 @@ class LTXAV(BaseModel): result.append(audio_indices) return result + def get_guide_frame_count(self, x, conds): + for cond_list in conds: + if cond_list is None: + continue + for cond_dict in cond_list: + model_conds = cond_dict.get('model_conds', {}) + gae = model_conds.get('guide_attention_entries') + logging.info(f"LTXAV.get_guide_frame_count: keys={list(model_conds.keys())}, gae={gae is not None}") + if gae is not None and hasattr(gae, 'cond') and gae.cond: + count = sum(e["latent_shape"][0] for e in gae.cond) + logging.info(f"LTXAV.get_guide_frame_count: found {count} guide frames") + return count + logging.info("LTXAV.get_guide_frame_count: no guide frames found") + return 0 + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + # Audio-specific handling if cond_key == "audio_denoise_mask" and hasattr(window, 'modality_windows') and window.modality_windows: audio_window = window.modality_windows.get(1) - if audio_window is not None: - import comfy.context_windows - return comfy.context_windows.slice_cond( - cond_value, audio_window, x_in, device, temporal_dim=2) + if audio_window is not None and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): + sliced = audio_window.get_tensor(cond_value.cond, device, dim=2) + return cond_value._copy_with(sliced) + + # Guide handling (same as LTXV — shared guide mechanism) + guide_count = getattr(window, 'guide_count', 0) + if cond_key in ("keyframe_idxs", "guide_attention_entries", "denoise_mask"): + logging.info(f"LTXAV resize_cond: {cond_key}, guide_count={guide_count}, has_spatial={hasattr(window, 'guide_spatial')}") + + if cond_key == "denoise_mask" and guide_count > 0: + cond_tensor = cond_value.cond + T_video = cond_tensor.size(window.dim) - guide_count + video_mask = cond_tensor.narrow(window.dim, 0, T_video) + guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count) + sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list) + sliced_guide = window.get_tensor(guide_mask, device) + return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim)) + + if cond_key == "keyframe_idxs" and guide_count > 0: + window_len = len(window.index_list) + H, W = window.guide_spatial + patchifier = self.diffusion_model.patchifier + latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device) + from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords + pixel_coords = latent_to_pixel_coords( + latent_coords, + self.diffusion_model.vae_scale_factors, + causal_fix=self.diffusion_model.causal_temporal_positioning) + B = cond_value.cond.shape[0] + if B > 1: + pixel_coords = pixel_coords.expand(B, -1, -1, -1) + return cond_value._copy_with(pixel_coords) + + if cond_key == "guide_attention_entries" and guide_count > 0: + window_len = len(window.index_list) + H, W = window.guide_spatial + new_entries = [{**e, "pre_filter_count": window_len * H * W, + "latent_shape": [window_len, H, W]} for e in cond_value.cond] + return cond_value._copy_with(new_entries) + return None class HunyuanVideo(BaseModel):