LTX2 context windows part 2b - Calculate guide parameters in model code, refactor

This commit is contained in:
ozbayb 2026-03-24 11:43:42 -06:00
parent 56de390c25
commit 941d50e777
2 changed files with 48 additions and 87 deletions

View File

@@ -358,11 +358,9 @@ class IndexListContextHandler(ContextHandlerABC):
for window_idx, window in enumerated_context_windows:
comfy.model_management.throw_exception_if_processing_interrupted()
# Attach guide info to window for resize_cond_for_context_window
window.guide_count = guide_count
if guide_suffix is not None:
window.guide_spatial = (guide_suffix.shape[3], guide_suffix.shape[4])
logging.info(f"Context window {window_idx + 1}/{total_windows}: frames {window.index_list[0]}-{window.index_list[-1]} of {video_primary.shape[self.dim]}"
+ (f" (+{guide_count} guide)" if guide_count > 0 else "")
+ (f" [{len(modalities)} modalities]" if is_multimodal else ""))
# Per-modality window indices
if is_multimodal:
@@ -384,9 +382,6 @@ class IndexListContextHandler(ContextHandlerABC):
window = IndexListContextWindow(
window.index_list, dim=self.dim, total_frames=video_primary.shape[self.dim],
modality_windows=modality_windows)
window.guide_count = guide_count
if guide_suffix is not None:
window.guide_spatial = (guide_suffix.shape[3], guide_suffix.shape[4])
else:
per_mod_indices = [window.index_list]

View File

@@ -303,6 +303,45 @@ class BaseModel(torch.nn.Module):
Override in subclasses that concatenate guide reference frames to the latent."""
return 0
def _resize_guide_cond(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
"""Resize guide-related conditioning for context windows.
Derives guide_count from denoise_mask/x_in size difference.
Derives spatial dims from x_in. Requires self.diffusion_model.patchifier."""
if cond_key == "denoise_mask" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
cond_tensor = cond_value.cond
guide_count = cond_tensor.size(window.dim) - x_in.size(window.dim)
if guide_count > 0:
T_video = x_in.size(window.dim)
video_mask = cond_tensor.narrow(window.dim, 0, T_video)
guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count)
sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list)
sliced_guide = window.get_tensor(guide_mask, device)
return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim))
if cond_key == "keyframe_idxs":
window_len = len(window.index_list)
H, W = x_in.shape[3], x_in.shape[4]
patchifier = self.diffusion_model.patchifier
latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device)
from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords
pixel_coords = latent_to_pixel_coords(
latent_coords,
self.diffusion_model.vae_scale_factors,
causal_fix=self.diffusion_model.causal_temporal_positioning)
B = cond_value.cond.shape[0]
if B > 1:
pixel_coords = pixel_coords.expand(B, -1, -1, -1)
return cond_value._copy_with(pixel_coords)
if cond_key == "guide_attention_entries":
window_len = len(window.index_list)
H, W = x_in.shape[3], x_in.shape[4]
new_entries = [{**e, "pre_filter_count": window_len * H * W,
"latent_shape": [window_len, H, W]} for e in cond_value.cond]
return cond_value._copy_with(new_entries)
return None
def extra_conds(self, **kwargs):
out = {}
concat_cond = self.concat_cond(**kwargs)
@@ -1038,44 +1077,7 @@ class LTXV(BaseModel):
return 0
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
    """Resize guide-related conditioning for a context window (LTXV).

    Delegates to the shared ``BaseModel._resize_guide_cond`` helper, which
    derives guide_count and spatial dims from ``x_in`` itself instead of
    attributes stashed on the window. Returns the resized cond value, or
    ``None`` when ``cond_key`` is not guide-related (caller falls back to
    default handling).
    """
    return self._resize_guide_cond(cond_key, cond_value, window, x_in, device, retain_index_list)
class LTXAV(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
@@ -1083,7 +1085,6 @@ class LTXAV(BaseModel):
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
logging.info(f"LTXAV.extra_conds: guide_attention_entries={'guide_attention_entries' in kwargs}, keyframe_idxs={'keyframe_idxs' in kwargs}")
attention_mask = kwargs.get("attention_mask", None)
device = kwargs["device"]
@@ -1170,12 +1171,8 @@ class LTXAV(BaseModel):
for cond_dict in cond_list:
model_conds = cond_dict.get('model_conds', {})
gae = model_conds.get('guide_attention_entries')
logging.info(f"LTXAV.get_guide_frame_count: keys={list(model_conds.keys())}, gae={gae is not None}")
if gae is not None and hasattr(gae, 'cond') and gae.cond:
count = sum(e["latent_shape"][0] for e in gae.cond)
logging.info(f"LTXAV.get_guide_frame_count: found {count} guide frames")
return count
logging.info("LTXAV.get_guide_frame_count: no guide frames found")
return sum(e["latent_shape"][0] for e in gae.cond)
return 0
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
@@ -1186,41 +1183,10 @@ class LTXAV(BaseModel):
sliced = audio_window.get_tensor(cond_value.cond, device, dim=2)
return cond_value._copy_with(sliced)
# Guide handling (same as LTXV — shared guide mechanism)
guide_count = getattr(window, 'guide_count', 0)
if cond_key in ("keyframe_idxs", "guide_attention_entries", "denoise_mask"):
logging.info(f"LTXAV resize_cond: {cond_key}, guide_count={guide_count}, has_spatial={hasattr(window, 'guide_spatial')}")
if cond_key == "denoise_mask" and guide_count > 0:
cond_tensor = cond_value.cond
T_video = cond_tensor.size(window.dim) - guide_count
video_mask = cond_tensor.narrow(window.dim, 0, T_video)
guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count)
sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list)
sliced_guide = window.get_tensor(guide_mask, device)
return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim))
if cond_key == "keyframe_idxs" and guide_count > 0:
window_len = len(window.index_list)
H, W = window.guide_spatial
patchifier = self.diffusion_model.patchifier
latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device)
from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords
pixel_coords = latent_to_pixel_coords(
latent_coords,
self.diffusion_model.vae_scale_factors,
causal_fix=self.diffusion_model.causal_temporal_positioning)
B = cond_value.cond.shape[0]
if B > 1:
pixel_coords = pixel_coords.expand(B, -1, -1, -1)
return cond_value._copy_with(pixel_coords)
if cond_key == "guide_attention_entries" and guide_count > 0:
window_len = len(window.index_list)
H, W = window.guide_spatial
new_entries = [{**e, "pre_filter_count": window_len * H * W,
"latent_shape": [window_len, H, W]} for e in cond_value.cond]
return cond_value._copy_with(new_entries)
# Guide handling (shared with LTXV)
result = self._resize_guide_cond(cond_key, cond_value, window, x_in, device, retain_index_list)
if result is not None:
return result
return None