LTX2 context windows - Skip guide frames in freenoise shuffle

This commit is contained in:
ozbayb 2026-04-12 20:18:10 -06:00
parent a8b084ed58
commit 6a53695006

View File

@ -340,13 +340,31 @@ class IndexListContextHandler(ContextHandlerABC):
return model_conds['latent_shapes'].cond return model_conds['latent_shapes'].cond
return None return None
@staticmethod
def _get_guide_entries(conds):
for cond_list in conds:
if cond_list is None:
continue
for cond_dict in cond_list:
model_conds = cond_dict.get('model_conds', {})
entries = model_conds.get('guide_attention_entries')
if entries is not None and hasattr(entries, 'cond') and entries.cond:
return entries.cond
return None
def _apply_freenoise(self, noise: torch.Tensor, conds: list[list[dict]], seed: int) -> torch.Tensor: def _apply_freenoise(self, noise: torch.Tensor, conds: list[list[dict]], seed: int) -> torch.Tensor:
"""Apply FreeNoise shuffling, scaling context length/overlap per-modality by frame ratio.""" """Apply FreeNoise shuffling, scaling context length/overlap per-modality by frame ratio.
If guide frames are present on the primary modality, only the video portion is shuffled.
"""
guide_entries = self._get_guide_entries(conds)
guide_count = sum(e["latent_shape"][0] for e in guide_entries) if guide_entries else 0
latent_shapes = self._get_latent_shapes(conds) latent_shapes = self._get_latent_shapes(conds)
if latent_shapes is not None and len(latent_shapes) > 1: if latent_shapes is not None and len(latent_shapes) > 1:
modalities = comfy.utils.unpack_latents(noise, latent_shapes) modalities = comfy.utils.unpack_latents(noise, latent_shapes)
primary_total = latent_shapes[0][self.dim] primary_total = latent_shapes[0][self.dim]
modalities[0] = apply_freenoise(modalities[0], self.dim, self.context_length, self.context_overlap, seed) primary_video_count = modalities[0].size(self.dim) - guide_count
apply_freenoise(modalities[0].narrow(self.dim, 0, primary_video_count), self.dim, self.context_length, self.context_overlap, seed)
for i in range(1, len(modalities)): for i in range(1, len(modalities)):
mod_total = latent_shapes[i][self.dim] mod_total = latent_shapes[i][self.dim]
ratio = mod_total / primary_total if primary_total > 0 else 1 ratio = mod_total / primary_total if primary_total > 0 else 1
@ -355,7 +373,9 @@ class IndexListContextHandler(ContextHandlerABC):
modalities[i] = apply_freenoise(modalities[i], self.dim, mod_ctx_len, mod_ctx_overlap, seed) modalities[i] = apply_freenoise(modalities[i], self.dim, mod_ctx_len, mod_ctx_overlap, seed)
noise, _ = comfy.utils.pack_latents(modalities) noise, _ = comfy.utils.pack_latents(modalities)
return noise return noise
return apply_freenoise(noise, self.dim, self.context_length, self.context_overlap, seed) video_count = noise.size(self.dim) - guide_count
apply_freenoise(noise.narrow(self.dim, 0, video_count), self.dim, self.context_length, self.context_overlap, seed)
return noise
def _build_window_state(self, x_in: torch.Tensor, conds: list[list[dict]]) -> WindowingState: def _build_window_state(self, x_in: torch.Tensor, conds: list[list[dict]]) -> WindowingState:
"""Build windowing state for the current step, including unpacking latents and extracting guide frame info from conds.""" """Build windowing state for the current step, including unpacking latents and extracting guide frame info from conds."""
@ -367,19 +387,7 @@ class IndexListContextHandler(ContextHandlerABC):
guide_latents_list = [None] * len(unpacked_latents) guide_latents_list = [None] * len(unpacked_latents)
guide_entries_list = [None] * len(unpacked_latents) guide_entries_list = [None] * len(unpacked_latents)
# Scan for 'guide_attention_entries' in conds extracted_guide_entries = self._get_guide_entries(conds)
extracted_guide_entries = None
for cond_list in conds:
if cond_list is None:
continue
for cond_dict in cond_list:
model_conds = cond_dict.get('model_conds', {})
entries = model_conds.get('guide_attention_entries')
if entries is not None and hasattr(entries, 'cond') and entries.cond:
extracted_guide_entries = entries.cond
break
if extracted_guide_entries is not None:
break
# Strip guide frames (only from first modality for now) # Strip guide frames (only from first modality for now)
if extracted_guide_entries is not None: if extracted_guide_entries is not None: