mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
LTX2 context windows part 2b - Calculate guide parameters in model code, refactor
This commit is contained in:
parent
56de390c25
commit
941d50e777
@ -358,11 +358,9 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
|
||||
for window_idx, window in enumerated_context_windows:
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
|
||||
# Attach guide info to window for resize_cond_for_context_window
|
||||
window.guide_count = guide_count
|
||||
if guide_suffix is not None:
|
||||
window.guide_spatial = (guide_suffix.shape[3], guide_suffix.shape[4])
|
||||
logging.info(f"Context window {window_idx + 1}/{total_windows}: frames {window.index_list[0]}-{window.index_list[-1]} of {video_primary.shape[self.dim]}"
|
||||
+ (f" (+{guide_count} guide)" if guide_count > 0 else "")
|
||||
+ (f" [{len(modalities)} modalities]" if is_multimodal else ""))
|
||||
|
||||
# Per-modality window indices
|
||||
if is_multimodal:
|
||||
@ -384,9 +382,6 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
window = IndexListContextWindow(
|
||||
window.index_list, dim=self.dim, total_frames=video_primary.shape[self.dim],
|
||||
modality_windows=modality_windows)
|
||||
window.guide_count = guide_count
|
||||
if guide_suffix is not None:
|
||||
window.guide_spatial = (guide_suffix.shape[3], guide_suffix.shape[4])
|
||||
else:
|
||||
per_mod_indices = [window.index_list]
|
||||
|
||||
|
||||
@ -303,6 +303,45 @@ class BaseModel(torch.nn.Module):
|
||||
Override in subclasses that concatenate guide reference frames to the latent."""
|
||||
return 0
|
||||
|
||||
def _resize_guide_cond(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
"""Resize guide-related conditioning for context windows.
|
||||
Derives guide_count from denoise_mask/x_in size difference.
|
||||
Derives spatial dims from x_in. Requires self.diffusion_model.patchifier."""
|
||||
if cond_key == "denoise_mask" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
cond_tensor = cond_value.cond
|
||||
guide_count = cond_tensor.size(window.dim) - x_in.size(window.dim)
|
||||
if guide_count > 0:
|
||||
T_video = x_in.size(window.dim)
|
||||
video_mask = cond_tensor.narrow(window.dim, 0, T_video)
|
||||
guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count)
|
||||
sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list)
|
||||
sliced_guide = window.get_tensor(guide_mask, device)
|
||||
return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim))
|
||||
|
||||
if cond_key == "keyframe_idxs":
|
||||
window_len = len(window.index_list)
|
||||
H, W = x_in.shape[3], x_in.shape[4]
|
||||
patchifier = self.diffusion_model.patchifier
|
||||
latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device)
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords
|
||||
pixel_coords = latent_to_pixel_coords(
|
||||
latent_coords,
|
||||
self.diffusion_model.vae_scale_factors,
|
||||
causal_fix=self.diffusion_model.causal_temporal_positioning)
|
||||
B = cond_value.cond.shape[0]
|
||||
if B > 1:
|
||||
pixel_coords = pixel_coords.expand(B, -1, -1, -1)
|
||||
return cond_value._copy_with(pixel_coords)
|
||||
|
||||
if cond_key == "guide_attention_entries":
|
||||
window_len = len(window.index_list)
|
||||
H, W = x_in.shape[3], x_in.shape[4]
|
||||
new_entries = [{**e, "pre_filter_count": window_len * H * W,
|
||||
"latent_shape": [window_len, H, W]} for e in cond_value.cond]
|
||||
return cond_value._copy_with(new_entries)
|
||||
|
||||
return None
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
concat_cond = self.concat_cond(**kwargs)
|
||||
@ -1038,44 +1077,7 @@ class LTXV(BaseModel):
|
||||
return 0
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=None):
    """Resize conds for a context window (LTXV).

    All guide-related conds ("denoise_mask", "keyframe_idxs",
    "guide_attention_entries") are handled by the shared BaseModel
    helper, which derives guide parameters from ``x_in`` itself — no
    per-window guide attributes are needed here.

    Returns:
        The resized cond from the shared helper, or ``None`` for any
        other cond key so the caller's default handling applies.
    """
    # Avoid the mutable-default-argument pitfall; normalize before
    # forwarding so the helper always receives a list.
    if retain_index_list is None:
        retain_index_list = []
    # The previous per-model guide branches duplicated the shared helper
    # verbatim and were unreachable dead code here; delegate instead.
    return self._resize_guide_cond(cond_key, cond_value, window, x_in, device, retain_index_list)
|
||||
|
||||
class LTXAV(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
@ -1083,7 +1085,6 @@ class LTXAV(BaseModel):
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
logging.info(f"LTXAV.extra_conds: guide_attention_entries={'guide_attention_entries' in kwargs}, keyframe_idxs={'keyframe_idxs' in kwargs}")
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
device = kwargs["device"]
|
||||
|
||||
@ -1170,12 +1171,8 @@ class LTXAV(BaseModel):
|
||||
for cond_dict in cond_list:
|
||||
model_conds = cond_dict.get('model_conds', {})
|
||||
gae = model_conds.get('guide_attention_entries')
|
||||
logging.info(f"LTXAV.get_guide_frame_count: keys={list(model_conds.keys())}, gae={gae is not None}")
|
||||
if gae is not None and hasattr(gae, 'cond') and gae.cond:
|
||||
count = sum(e["latent_shape"][0] for e in gae.cond)
|
||||
logging.info(f"LTXAV.get_guide_frame_count: found {count} guide frames")
|
||||
return count
|
||||
logging.info("LTXAV.get_guide_frame_count: no guide frames found")
|
||||
return sum(e["latent_shape"][0] for e in gae.cond)
|
||||
return 0
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
@ -1186,41 +1183,10 @@ class LTXAV(BaseModel):
|
||||
sliced = audio_window.get_tensor(cond_value.cond, device, dim=2)
|
||||
return cond_value._copy_with(sliced)
|
||||
|
||||
# Guide handling (same as LTXV — shared guide mechanism)
|
||||
guide_count = getattr(window, 'guide_count', 0)
|
||||
if cond_key in ("keyframe_idxs", "guide_attention_entries", "denoise_mask"):
|
||||
logging.info(f"LTXAV resize_cond: {cond_key}, guide_count={guide_count}, has_spatial={hasattr(window, 'guide_spatial')}")
|
||||
|
||||
if cond_key == "denoise_mask" and guide_count > 0:
|
||||
cond_tensor = cond_value.cond
|
||||
T_video = cond_tensor.size(window.dim) - guide_count
|
||||
video_mask = cond_tensor.narrow(window.dim, 0, T_video)
|
||||
guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count)
|
||||
sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list)
|
||||
sliced_guide = window.get_tensor(guide_mask, device)
|
||||
return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim))
|
||||
|
||||
if cond_key == "keyframe_idxs" and guide_count > 0:
|
||||
window_len = len(window.index_list)
|
||||
H, W = window.guide_spatial
|
||||
patchifier = self.diffusion_model.patchifier
|
||||
latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device)
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords
|
||||
pixel_coords = latent_to_pixel_coords(
|
||||
latent_coords,
|
||||
self.diffusion_model.vae_scale_factors,
|
||||
causal_fix=self.diffusion_model.causal_temporal_positioning)
|
||||
B = cond_value.cond.shape[0]
|
||||
if B > 1:
|
||||
pixel_coords = pixel_coords.expand(B, -1, -1, -1)
|
||||
return cond_value._copy_with(pixel_coords)
|
||||
|
||||
if cond_key == "guide_attention_entries" and guide_count > 0:
|
||||
window_len = len(window.index_list)
|
||||
H, W = window.guide_spatial
|
||||
new_entries = [{**e, "pre_filter_count": window_len * H * W,
|
||||
"latent_shape": [window_len, H, W]} for e in cond_value.cond]
|
||||
return cond_value._copy_with(new_entries)
|
||||
# Guide handling (shared with LTXV)
|
||||
result = self._resize_guide_cond(cond_key, cond_value, window, x_in, device, retain_index_list)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
return None
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user