From 6a53695006e7bfe1f1823ab38d720e8356dd3726 Mon Sep 17 00:00:00 2001 From: ozbayb <17261091+ozbayb@users.noreply.github.com> Date: Sun, 12 Apr 2026 20:18:10 -0600 Subject: [PATCH] LTX2 context windows - Skip guide frames in freenoise shuffle --- comfy/context_windows.py | 40 ++++++++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 409bcc271..8fb7b9642 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -340,13 +340,31 @@ class IndexListContextHandler(ContextHandlerABC): return model_conds['latent_shapes'].cond return None + @staticmethod + def _get_guide_entries(conds): + for cond_list in conds: + if cond_list is None: + continue + for cond_dict in cond_list: + model_conds = cond_dict.get('model_conds', {}) + entries = model_conds.get('guide_attention_entries') + if entries is not None and hasattr(entries, 'cond') and entries.cond: + return entries.cond + return None + def _apply_freenoise(self, noise: torch.Tensor, conds: list[list[dict]], seed: int) -> torch.Tensor: - """Apply FreeNoise shuffling, scaling context length/overlap per-modality by frame ratio.""" + """Apply FreeNoise shuffling, scaling context length/overlap per-modality by frame ratio. + If guide frames are present on the primary modality, only the video portion is shuffled. + """ + guide_entries = self._get_guide_entries(conds) + guide_count = sum(e["latent_shape"][0] for e in guide_entries) if guide_entries else 0 + latent_shapes = self._get_latent_shapes(conds) if latent_shapes is not None and len(latent_shapes) > 1: modalities = comfy.utils.unpack_latents(noise, latent_shapes) primary_total = latent_shapes[0][self.dim] - modalities[0] = apply_freenoise(modalities[0], self.dim, self.context_length, self.context_overlap, seed) + primary_video_count = modalities[0].size(self.dim) - guide_count + apply_freenoise(modalities[0].narrow(self.dim, 0, primary_video_count), self.dim, self.context_length, self.context_overlap, seed) for i in range(1, len(modalities)): mod_total = latent_shapes[i][self.dim] ratio = mod_total / primary_total if primary_total > 0 else 1 @@ -355,7 +373,9 @@ class IndexListContextHandler(ContextHandlerABC): modalities[i] = apply_freenoise(modalities[i], self.dim, mod_ctx_len, mod_ctx_overlap, seed) noise, _ = comfy.utils.pack_latents(modalities) return noise - return apply_freenoise(noise, self.dim, self.context_length, self.context_overlap, seed) + video_count = noise.size(self.dim) - guide_count + apply_freenoise(noise.narrow(self.dim, 0, video_count), self.dim, self.context_length, self.context_overlap, seed) + return noise def _build_window_state(self, x_in: torch.Tensor, conds: list[list[dict]]) -> WindowingState: """Build windowing state for the current step, including unpacking latents and extracting guide frame info from conds.""" @@ -367,19 +387,7 @@ class IndexListContextHandler(ContextHandlerABC): guide_latents_list = [None] * len(unpacked_latents) guide_entries_list = [None] * len(unpacked_latents) - # Scan for 'guide_attention_entries' in conds - extracted_guide_entries = None - for cond_list in conds: - if cond_list is None: - continue - for cond_dict in cond_list: - model_conds = cond_dict.get('model_conds', {}) - entries = model_conds.get('guide_attention_entries') - if entries is not None and hasattr(entries, 'cond') and entries.cond: - extracted_guide_entries = entries.cond - break - if extracted_guide_entries is not None: - break + extracted_guide_entries = self._get_guide_entries(conds) # Strip guide frames (only from first modality for now) if extracted_guide_entries is not None: