From 115dbb69d18f29b1112cc7aec8a51ecf55f3e7ae Mon Sep 17 00:00:00 2001 From: ozbayb <17261091+ozbayb@users.noreply.github.com> Date: Fri, 27 Mar 2026 11:23:43 -0600 Subject: [PATCH] LTX2 context windows part 3 - Generalize guide splitting to windows --- comfy/context_windows.py | 84 ++++++++++++++++++++++++++++++-- comfy/ldm/lightricks/model.py | 2 +- comfy/model_base.py | 90 ++++++++++++++++++++++++++--------- comfy_extras/nodes_lt.py | 4 +- 4 files changed, 152 insertions(+), 28 deletions(-) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 357fbae17..4ace5ec13 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -140,6 +140,48 @@ def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, d return cond_value._copy_with(sliced) +def _compute_guide_overlap(guide_entries, window_index_list): + """Compute which guide frames overlap with a context window. + + Args: + guide_entries: list of guide_attention_entry dicts (must have 'latent_start' and 'latent_shape') + window_index_list: the window's frame indices into the video portion + + Returns None if any entry lacks 'latent_start' (backward compat → legacy path). + Otherwise returns (suffix_indices, overlap_info, kf_local_positions, total_overlap): + suffix_indices: indices into the guide_suffix tensor for frame selection + overlap_info: list of (entry_idx, overlap_count) for guide_attention_entries adjustment + kf_local_positions: window-local frame positions for keyframe_idxs regeneration + total_overlap: total number of overlapping guide frames + """ + window_set = set(window_index_list) + window_list = list(window_index_list) + suffix_indices = [] + overlap_info = [] + kf_local_positions = [] + suffix_base = 0 + + for entry_idx, entry in enumerate(guide_entries): + latent_start = entry.get("latent_start", None) + if latent_start is None: + return None + guide_len = entry["latent_shape"][0] + entry_overlap = 0 + + for local_offset in range(guide_len): + video_pos = latent_start + local_offset + if video_pos in window_set: + suffix_indices.append(suffix_base + local_offset) + kf_local_positions.append(window_list.index(video_pos)) + entry_overlap += 1 + + if entry_overlap > 0: + overlap_info.append((entry_idx, entry_overlap)) + suffix_base += guide_len + + return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices) + + @dataclass class ContextSchedule: name: str @@ -201,6 +243,18 @@ class IndexListContextHandler(ContextHandlerABC): if 'latent_shapes' in model_conds: model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes) + def _get_guide_entries(self, conds): + """Extract guide_attention_entries list from conditioning. Returns None if absent.""" + for cond_list in conds: + if cond_list is None: + continue + for cond_dict in cond_list: + model_conds = cond_dict.get('model_conds', {}) + gae = model_conds.get('guide_attention_entries') + if gae is not None and hasattr(gae, 'cond') and gae.cond: + return gae.cond + return None + def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool: latent_shapes = self._get_latent_shapes(conds) primary = self._decompose(x_in, latent_shapes)[0] @@ -353,6 +407,8 @@ class IndexListContextHandler(ContextHandlerABC): counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_modalities] biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in accum_modalities] + guide_entries = self._get_guide_entries(conds) if guide_count > 0 else None + for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks): callback(self, model, x_in, conds, timestep, model_options) @@ -391,10 +447,30 @@ class IndexListContextHandler(ContextHandlerABC): for mod_idx in range(1, len(modalities)): mod_windows.append(modality_windows[mod_idx]) - # Slice video and guide with same window indices, concatenate + # Slice video, then select overlapping guide frames sliced_video = mod_windows[0].get_tensor(video_primary) - if guide_suffix is not None: - sliced_guide = mod_windows[0].get_tensor(guide_suffix) + num_guide_in_window = 0 + if guide_suffix is not None and guide_entries is not None: + overlap = _compute_guide_overlap(guide_entries, window.index_list) + if overlap is None: + # Legacy: no latent_start → equal-size assumption + sliced_guide = mod_windows[0].get_tensor(guide_suffix) + num_guide_in_window = sliced_guide.shape[self.dim] + elif overlap[3] > 0: + suffix_idx, overlap_info, kf_local_pos, num_guide_in_window = overlap + idx = tuple([slice(None)] * self.dim + [suffix_idx]) + sliced_guide = guide_suffix[idx] + window.guide_suffix_indices = suffix_idx + window.guide_overlap_info = overlap_info + window.guide_kf_local_positions = kf_local_pos + else: + sliced_guide = None + window.guide_overlap_info = [] + window.guide_kf_local_positions = [] + else: + sliced_guide = None + + if sliced_guide is not None: sliced_primary = torch.cat([sliced_video, sliced_guide], dim=self.dim) else: sliced_primary = sliced_video @@ -421,7 +497,7 @@ class IndexListContextHandler(ContextHandlerABC): # out_per_mod[cond_idx][mod_idx] = tensor # Strip guide frames from primary output before accumulation - if guide_count > 0: + if num_guide_in_window > 0: window_len = len(window.index_list) for ci in range(len(sub_conds_out)): primary_out = out_per_mod[ci][0] diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index bfbc08357..c55e19ced 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -1028,7 +1028,7 @@ class LTXVModel(LTXBaseModel): ) grid_mask = None - if keyframe_idxs is not None: + if keyframe_idxs is not None and keyframe_idxs.shape[2] > 0: additional_args.update({ "orig_patchified_shape": list(x.shape)}) denoise_mask = self.patchifier.patchify(denoise_mask)[0] grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0] diff --git a/comfy/model_base.py b/comfy/model_base.py index 9c31e2651..ae2ce2eb0 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -305,8 +305,8 @@ class BaseModel(torch.nn.Module): def _resize_guide_cond(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): """Resize guide-related conditioning for context windows. - Derives guide_count from denoise_mask/x_in size difference. - Derives spatial dims from x_in. Requires self.diffusion_model.patchifier.""" + Uses overlap info from window if available (generalized path), + otherwise falls back to legacy equal-size assumption.""" if cond_key == "denoise_mask" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): cond_tensor = cond_value.cond guide_count = cond_tensor.size(window.dim) - x_in.size(window.dim) @@ -315,30 +315,76 @@ class BaseModel(torch.nn.Module): video_mask = cond_tensor.narrow(window.dim, 0, T_video) guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count) sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list) - sliced_guide = window.get_tensor(guide_mask, device) - return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim)) + # Use overlap-based guide selection if available, otherwise legacy + suffix_indices = getattr(window, 'guide_suffix_indices', None) + if suffix_indices is not None: + idx = tuple([slice(None)] * window.dim + [suffix_indices]) + sliced_guide = guide_mask[idx].to(device) if suffix_indices else None + else: + sliced_guide = window.get_tensor(guide_mask, device) + if sliced_guide is not None and sliced_guide.shape[window.dim] > 0: + return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim)) + else: + return cond_value._copy_with(sliced_video) if cond_key == "keyframe_idxs": - window_len = len(window.index_list) - H, W = x_in.shape[3], x_in.shape[4] - patchifier = self.diffusion_model.patchifier - latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device) - from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords - pixel_coords = latent_to_pixel_coords( - latent_coords, - self.diffusion_model.vae_scale_factors, - causal_fix=self.diffusion_model.causal_temporal_positioning) - B = cond_value.cond.shape[0] - if B > 1: - pixel_coords = pixel_coords.expand(B, -1, -1, -1) - return cond_value._copy_with(pixel_coords) + kf_local_pos = getattr(window, 'guide_kf_local_positions', None) + if kf_local_pos is not None: + # Generalized: regenerate coords for full window, select guide positions + if not kf_local_pos: + return cond_value._copy_with(cond_value.cond[:, :, :0, :]) # empty + H, W = x_in.shape[3], x_in.shape[4] + window_len = len(window.index_list) + patchifier = self.diffusion_model.patchifier + latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device) + from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords + pixel_coords = latent_to_pixel_coords( + latent_coords, + self.diffusion_model.vae_scale_factors, + causal_fix=self.diffusion_model.causal_temporal_positioning) + tokens = [] + for pos in kf_local_pos: + tokens.extend(range(pos * H * W, (pos + 1) * H * W)) + pixel_coords = pixel_coords[:, :, tokens, :] + B = cond_value.cond.shape[0] + if B > 1: + pixel_coords = pixel_coords.expand(B, -1, -1, -1) + return cond_value._copy_with(pixel_coords) + else: + # Legacy: regenerate for window_len (equal-size assumption) + window_len = len(window.index_list) + H, W = x_in.shape[3], x_in.shape[4] + patchifier = self.diffusion_model.patchifier + latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device) + from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords + pixel_coords = latent_to_pixel_coords( + latent_coords, + self.diffusion_model.vae_scale_factors, + causal_fix=self.diffusion_model.causal_temporal_positioning) + B = cond_value.cond.shape[0] + if B > 1: + pixel_coords = pixel_coords.expand(B, -1, -1, -1) + return cond_value._copy_with(pixel_coords) if cond_key == "guide_attention_entries": - window_len = len(window.index_list) - H, W = x_in.shape[3], x_in.shape[4] - new_entries = [{**e, "pre_filter_count": window_len * H * W, - "latent_shape": [window_len, H, W]} for e in cond_value.cond] - return cond_value._copy_with(new_entries) + overlap_info = getattr(window, 'guide_overlap_info', None) + if overlap_info is not None: + # Generalized: per-guide adjustment based on overlap + H, W = x_in.shape[3], x_in.shape[4] + new_entries = [] + for entry_idx, overlap_count in overlap_info: + e = cond_value.cond[entry_idx] + new_entries.append({**e, + "pre_filter_count": overlap_count * H * W, + "latent_shape": [overlap_count, H, W]}) + return cond_value._copy_with(new_entries) + else: + # Legacy: all entries adjusted to window_len + window_len = len(window.index_list) + H, W = x_in.shape[3], x_in.shape[4] + new_entries = [{**e, "pre_filter_count": window_len * H * W, + "latent_shape": [window_len, H, W]} for e in cond_value.cond] + return cond_value._copy_with(new_entries) return None diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index d7c2e8744..d8ba0bb27 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -135,7 +135,7 @@ class LTXVImgToVideoInplace(io.ComfyNode): generate = execute # TODO: remove -def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0): +def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0, latent_start=0): """Append a guide_attention_entry to both positive and negative conditioning. Each entry tracks one guide reference for per-reference attention control. @@ -146,6 +146,7 @@ def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_s "strength": strength, "pixel_mask": None, "latent_shape": latent_shape, + "latent_start": latent_start, } results = [] for cond in (positive, negative): @@ -362,6 +363,7 @@ class LTXVAddGuide(io.ComfyNode): guide_latent_shape = list(t.shape[2:]) # [F, H, W] positive, negative = _append_guide_attention_entry( positive, negative, pre_filter_count, guide_latent_shape, strength=strength, + latent_start=latent_idx, ) return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})