From 350237618d1035fd016f845390ecd1784d90c48b Mon Sep 17 00:00:00 2001 From: ozbayb <17261091+ozbayb@users.noreply.github.com> Date: Mon, 6 Apr 2026 09:11:35 -0600 Subject: [PATCH] LTX2 context windows - Cleanup: Remove model specific code from BaseModel. Older LTXV model's guides + context_windows will need to be re-implemented but outside the scope of LTX2 changes --- comfy/context_windows.py | 6 +- comfy/model_base.py | 136 +++++++++++++++------------------------ 2 files changed, 55 insertions(+), 87 deletions(-) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index fe1afdffe..295f348e6 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -257,7 +257,7 @@ class IndexListContextHandler(ContextHandlerABC): def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool: latent_shapes = self._get_latent_shapes(conds) primary = self._decompose(x_in, latent_shapes)[0] - guide_count = model.get_guide_frame_count(primary, conds) if model is not None else 0 + guide_count = model.get_guide_frame_count(primary, conds) if hasattr(model, 'get_guide_frame_count') else 0 video_frames = primary.size(self.dim) - guide_count if video_frames > self.context_length: if guide_count > 0: @@ -380,7 +380,7 @@ class IndexListContextHandler(ContextHandlerABC): primary = modalities[0] # Separate guide frames from primary modality (guides are appended at the end) - guide_count = model.get_guide_frame_count(primary, conds) if model is not None else 0 + guide_count = model.get_guide_frame_count(primary, conds) if hasattr(model, 'get_guide_frame_count') else 0 if guide_count > 0: video_len = primary.size(self.dim) - guide_count video_primary = primary.narrow(self.dim, 0, video_len) @@ -427,7 +427,7 @@ class IndexListContextHandler(ContextHandlerABC): video_shape[self.dim] = video_shape[self.dim] - guide_count map_shapes[0] = torch.Size(video_shape) per_mod_indices = model.map_context_window_to_modalities( - window.index_list, map_shapes, self.dim) + window.index_list, map_shapes, self.dim) if hasattr(model, 'map_context_window_to_modalities') else [window.index_list] # Build per-modality windows and attach to primary window modality_windows = {} for mod_idx in range(1, len(modalities)): diff --git a/comfy/model_base.py b/comfy/model_base.py index e4659a236..066e3f8d8 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -293,71 +293,6 @@ class BaseModel(torch.nn.Module): Use comfy.context_windows.slice_cond() for common cases.""" return None - def map_context_window_to_modalities(self, primary_indices, latent_shapes, dim): - """Map primary modality's window indices to all modalities. - Returns list of index lists, one per modality.""" - return [primary_indices] - - def get_guide_frame_count(self, x, conds): - """Return the number of trailing guide frames appended to x along the temporal dim. - Override in subclasses that concatenate guide reference frames to the latent.""" - return 0 - - def _resize_guide_cond(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): - """Resize guide-related conditioning for context windows. - Requires guide_suffix_indices, guide_overlap_info, and guide_kf_local_positions - to be set on the window by _compute_guide_overlap.""" - if cond_key == "denoise_mask" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): - cond_tensor = cond_value.cond - guide_count = cond_tensor.size(window.dim) - x_in.size(window.dim) - if guide_count > 0: - T_video = x_in.size(window.dim) - video_mask = cond_tensor.narrow(window.dim, 0, T_video) - guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count) - sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list) - suffix_indices = window.guide_suffix_indices - if suffix_indices: - idx = tuple([slice(None)] * window.dim + [suffix_indices]) - sliced_guide = guide_mask[idx].to(device) - return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim)) - else: - return cond_value._copy_with(sliced_video) - - if cond_key == "keyframe_idxs": - kf_local_pos = window.guide_kf_local_positions - if not kf_local_pos: - return cond_value._copy_with(cond_value.cond[:, :, :0, :]) # empty - H, W = x_in.shape[3], x_in.shape[4] - window_len = len(window.index_list) - patchifier = self.diffusion_model.patchifier - latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device) - from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords - pixel_coords = latent_to_pixel_coords( - latent_coords, - self.diffusion_model.vae_scale_factors, - causal_fix=self.diffusion_model.causal_temporal_positioning) - tokens = [] - for pos in kf_local_pos: - tokens.extend(range(pos * H * W, (pos + 1) * H * W)) - pixel_coords = pixel_coords[:, :, tokens, :] - B = cond_value.cond.shape[0] - if B > 1: - pixel_coords = pixel_coords.expand(B, -1, -1, -1) - return cond_value._copy_with(pixel_coords) - - if cond_key == "guide_attention_entries": - overlap_info = window.guide_overlap_info - H, W = x_in.shape[3], x_in.shape[4] - new_entries = [] - for entry_idx, overlap_count in overlap_info: - e = cond_value.cond[entry_idx] - new_entries.append({**e, - "pre_filter_count": overlap_count * H * W, - "latent_shape": [overlap_count, H, W]}) - return cond_value._copy_with(new_entries) - - return None - def extra_conds(self, **kwargs): out = {} concat_cond = self.concat_cond(**kwargs) @@ -1081,20 +1016,6 @@ class LTXV(BaseModel): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): return latent_image - def get_guide_frame_count(self, x, conds): - for cond_list in conds: - if cond_list is None: - continue - for cond_dict in cond_list: - model_conds = cond_dict.get('model_conds', {}) - gae = model_conds.get('guide_attention_entries') - if gae is not None and hasattr(gae, 'cond') and gae.cond: - return sum(e["latent_shape"][0] for e in gae.cond) - return 0 - - def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): - return self._resize_guide_cond(cond_key, cond_value, window, x_in, device, retain_index_list) - class LTXAV(BaseModel): def __init__(self, model_config, model_type=ModelType.FLUX, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.av_model.LTXAVModel) #TODO @@ -1193,17 +1114,64 @@ class LTXAV(BaseModel): return 0 def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): - # Audio-specific handling + # Audio denoise mask — slice using audio modality window if cond_key == "audio_denoise_mask" and hasattr(window, 'modality_windows') and window.modality_windows: audio_window = window.modality_windows.get(1) if audio_window is not None and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): sliced = audio_window.get_tensor(cond_value.cond, device, dim=2) return cond_value._copy_with(sliced) - # Guide handling (shared with LTXV) - result = self._resize_guide_cond(cond_key, cond_value, window, x_in, device, retain_index_list) - if result is not None: - return result + # Video denoise mask — split into video + guide portions, slice each + if cond_key == "denoise_mask" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): + cond_tensor = cond_value.cond + guide_count = cond_tensor.size(window.dim) - x_in.size(window.dim) + if guide_count > 0: + T_video = x_in.size(window.dim) + video_mask = cond_tensor.narrow(window.dim, 0, T_video) + guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count) + sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list) + suffix_indices = window.guide_suffix_indices + if suffix_indices: + idx = tuple([slice(None)] * window.dim + [suffix_indices]) + sliced_guide = guide_mask[idx].to(device) + return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim)) + else: + return cond_value._copy_with(sliced_video) + + # Keyframe indices — regenerate pixel coords for window, select guide positions + if cond_key == "keyframe_idxs": + kf_local_pos = window.guide_kf_local_positions + if not kf_local_pos: + return cond_value._copy_with(cond_value.cond[:, :, :0, :]) # empty + H, W = x_in.shape[3], x_in.shape[4] + window_len = len(window.index_list) + patchifier = self.diffusion_model.patchifier + latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device) + from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords + pixel_coords = latent_to_pixel_coords( + latent_coords, + self.diffusion_model.vae_scale_factors, + causal_fix=self.diffusion_model.causal_temporal_positioning) + tokens = [] + for pos in kf_local_pos: + tokens.extend(range(pos * H * W, (pos + 1) * H * W)) + pixel_coords = pixel_coords[:, :, tokens, :] + B = cond_value.cond.shape[0] + if B > 1: + pixel_coords = pixel_coords.expand(B, -1, -1, -1) + return cond_value._copy_with(pixel_coords) + + # Guide attention entries — adjust per-guide counts based on window overlap + if cond_key == "guide_attention_entries": + overlap_info = window.guide_overlap_info + H, W = x_in.shape[3], x_in.shape[4] + new_entries = [] + for entry_idx, overlap_count in overlap_info: + e = cond_value.cond[entry_idx] + new_entries.append({**e, + "pre_filter_count": overlap_count * H * W, + "latent_shape": [overlap_count, H, W]}) + return cond_value._copy_with(new_entries) return None