LTX2 context windows - Skip guide frames in freenoise shuffle

This commit is contained in:
ozbayb 2026-04-12 20:18:10 -06:00
parent a8b084ed58
commit 6a53695006

View File

@@ -340,13 +340,31 @@ class IndexListContextHandler(ContextHandlerABC):
return model_conds['latent_shapes'].cond
return None
@staticmethod
def _get_guide_entries(conds):
for cond_list in conds:
if cond_list is None:
continue
for cond_dict in cond_list:
model_conds = cond_dict.get('model_conds', {})
entries = model_conds.get('guide_attention_entries')
if entries is not None and hasattr(entries, 'cond') and entries.cond:
return entries.cond
return None
def _apply_freenoise(self, noise: torch.Tensor, conds: list[list[dict]], seed: int) -> torch.Tensor:
"""Apply FreeNoise shuffling, scaling context length/overlap per-modality by frame ratio."""
"""Apply FreeNoise shuffling, scaling context length/overlap per-modality by frame ratio.
If guide frames are present on the primary modality, only the video portion is shuffled.
"""
guide_entries = self._get_guide_entries(conds)
guide_count = sum(e["latent_shape"][0] for e in guide_entries) if guide_entries else 0
latent_shapes = self._get_latent_shapes(conds)
if latent_shapes is not None and len(latent_shapes) > 1:
modalities = comfy.utils.unpack_latents(noise, latent_shapes)
primary_total = latent_shapes[0][self.dim]
modalities[0] = apply_freenoise(modalities[0], self.dim, self.context_length, self.context_overlap, seed)
primary_video_count = modalities[0].size(self.dim) - guide_count
apply_freenoise(modalities[0].narrow(self.dim, 0, primary_video_count), self.dim, self.context_length, self.context_overlap, seed)
for i in range(1, len(modalities)):
mod_total = latent_shapes[i][self.dim]
ratio = mod_total / primary_total if primary_total > 0 else 1
@@ -355,7 +373,9 @@ class IndexListContextHandler(ContextHandlerABC):
modalities[i] = apply_freenoise(modalities[i], self.dim, mod_ctx_len, mod_ctx_overlap, seed)
noise, _ = comfy.utils.pack_latents(modalities)
return noise
return apply_freenoise(noise, self.dim, self.context_length, self.context_overlap, seed)
video_count = noise.size(self.dim) - guide_count
apply_freenoise(noise.narrow(self.dim, 0, video_count), self.dim, self.context_length, self.context_overlap, seed)
return noise
def _build_window_state(self, x_in: torch.Tensor, conds: list[list[dict]]) -> WindowingState:
"""Build windowing state for the current step, including unpacking latents and extracting guide frame info from conds."""
@@ -367,19 +387,7 @@ class IndexListContextHandler(ContextHandlerABC):
guide_latents_list = [None] * len(unpacked_latents)
guide_entries_list = [None] * len(unpacked_latents)
# Scan for 'guide_attention_entries' in conds
extracted_guide_entries = None
for cond_list in conds:
if cond_list is None:
continue
for cond_dict in cond_list:
model_conds = cond_dict.get('model_conds', {})
entries = model_conds.get('guide_attention_entries')
if entries is not None and hasattr(entries, 'cond') and entries.cond:
extracted_guide_entries = entries.cond
break
if extracted_guide_entries is not None:
break
extracted_guide_entries = self._get_guide_entries(conds)
# Strip guide frames (only from first modality for now)
if extracted_guide_entries is not None: