mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-27 19:02:31 +08:00
LTX2 context windows part 2b - Calculate guide parameters in model code, refactor
This commit is contained in:
parent
56de390c25
commit
941d50e777
@ -358,11 +358,9 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
|
|
||||||
for window_idx, window in enumerated_context_windows:
|
for window_idx, window in enumerated_context_windows:
|
||||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||||
|
logging.info(f"Context window {window_idx + 1}/{total_windows}: frames {window.index_list[0]}-{window.index_list[-1]} of {video_primary.shape[self.dim]}"
|
||||||
# Attach guide info to window for resize_cond_for_context_window
|
+ (f" (+{guide_count} guide)" if guide_count > 0 else "")
|
||||||
window.guide_count = guide_count
|
+ (f" [{len(modalities)} modalities]" if is_multimodal else ""))
|
||||||
if guide_suffix is not None:
|
|
||||||
window.guide_spatial = (guide_suffix.shape[3], guide_suffix.shape[4])
|
|
||||||
|
|
||||||
# Per-modality window indices
|
# Per-modality window indices
|
||||||
if is_multimodal:
|
if is_multimodal:
|
||||||
@ -384,9 +382,6 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
window = IndexListContextWindow(
|
window = IndexListContextWindow(
|
||||||
window.index_list, dim=self.dim, total_frames=video_primary.shape[self.dim],
|
window.index_list, dim=self.dim, total_frames=video_primary.shape[self.dim],
|
||||||
modality_windows=modality_windows)
|
modality_windows=modality_windows)
|
||||||
window.guide_count = guide_count
|
|
||||||
if guide_suffix is not None:
|
|
||||||
window.guide_spatial = (guide_suffix.shape[3], guide_suffix.shape[4])
|
|
||||||
else:
|
else:
|
||||||
per_mod_indices = [window.index_list]
|
per_mod_indices = [window.index_list]
|
||||||
|
|
||||||
|
|||||||
@ -303,6 +303,45 @@ class BaseModel(torch.nn.Module):
|
|||||||
Override in subclasses that concatenate guide reference frames to the latent."""
|
Override in subclasses that concatenate guide reference frames to the latent."""
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
def _resize_guide_cond(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||||
|
"""Resize guide-related conditioning for context windows.
|
||||||
|
Derives guide_count from denoise_mask/x_in size difference.
|
||||||
|
Derives spatial dims from x_in. Requires self.diffusion_model.patchifier."""
|
||||||
|
if cond_key == "denoise_mask" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||||
|
cond_tensor = cond_value.cond
|
||||||
|
guide_count = cond_tensor.size(window.dim) - x_in.size(window.dim)
|
||||||
|
if guide_count > 0:
|
||||||
|
T_video = x_in.size(window.dim)
|
||||||
|
video_mask = cond_tensor.narrow(window.dim, 0, T_video)
|
||||||
|
guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count)
|
||||||
|
sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list)
|
||||||
|
sliced_guide = window.get_tensor(guide_mask, device)
|
||||||
|
return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim))
|
||||||
|
|
||||||
|
if cond_key == "keyframe_idxs":
|
||||||
|
window_len = len(window.index_list)
|
||||||
|
H, W = x_in.shape[3], x_in.shape[4]
|
||||||
|
patchifier = self.diffusion_model.patchifier
|
||||||
|
latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device)
|
||||||
|
from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords
|
||||||
|
pixel_coords = latent_to_pixel_coords(
|
||||||
|
latent_coords,
|
||||||
|
self.diffusion_model.vae_scale_factors,
|
||||||
|
causal_fix=self.diffusion_model.causal_temporal_positioning)
|
||||||
|
B = cond_value.cond.shape[0]
|
||||||
|
if B > 1:
|
||||||
|
pixel_coords = pixel_coords.expand(B, -1, -1, -1)
|
||||||
|
return cond_value._copy_with(pixel_coords)
|
||||||
|
|
||||||
|
if cond_key == "guide_attention_entries":
|
||||||
|
window_len = len(window.index_list)
|
||||||
|
H, W = x_in.shape[3], x_in.shape[4]
|
||||||
|
new_entries = [{**e, "pre_filter_count": window_len * H * W,
|
||||||
|
"latent_shape": [window_len, H, W]} for e in cond_value.cond]
|
||||||
|
return cond_value._copy_with(new_entries)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
def extra_conds(self, **kwargs):
|
def extra_conds(self, **kwargs):
|
||||||
out = {}
|
out = {}
|
||||||
concat_cond = self.concat_cond(**kwargs)
|
concat_cond = self.concat_cond(**kwargs)
|
||||||
@ -1038,44 +1077,7 @@ class LTXV(BaseModel):
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||||
guide_count = getattr(window, 'guide_count', 0)
|
return self._resize_guide_cond(cond_key, cond_value, window, x_in, device, retain_index_list)
|
||||||
|
|
||||||
if cond_key == "denoise_mask" and guide_count > 0:
|
|
||||||
# Slice both video and guide halves with same window indices
|
|
||||||
cond_tensor = cond_value.cond
|
|
||||||
T_video = cond_tensor.size(window.dim) - guide_count
|
|
||||||
video_mask = cond_tensor.narrow(window.dim, 0, T_video)
|
|
||||||
guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count)
|
|
||||||
sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list)
|
|
||||||
sliced_guide = window.get_tensor(guide_mask, device)
|
|
||||||
return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim))
|
|
||||||
|
|
||||||
if cond_key == "keyframe_idxs" and guide_count > 0:
|
|
||||||
# Recompute coords for window_len frames so guide tokens are co-located
|
|
||||||
# with noise tokens in RoPE space (identical to a standalone short video)
|
|
||||||
window_len = len(window.index_list)
|
|
||||||
H, W = window.guide_spatial
|
|
||||||
patchifier = self.diffusion_model.patchifier
|
|
||||||
latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device)
|
|
||||||
from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords
|
|
||||||
pixel_coords = latent_to_pixel_coords(
|
|
||||||
latent_coords,
|
|
||||||
self.diffusion_model.vae_scale_factors,
|
|
||||||
causal_fix=self.diffusion_model.causal_temporal_positioning)
|
|
||||||
B = cond_value.cond.shape[0]
|
|
||||||
if B > 1:
|
|
||||||
pixel_coords = pixel_coords.expand(B, -1, -1, -1)
|
|
||||||
return cond_value._copy_with(pixel_coords)
|
|
||||||
|
|
||||||
if cond_key == "guide_attention_entries" and guide_count > 0:
|
|
||||||
# Adjust token counts for window size
|
|
||||||
window_len = len(window.index_list)
|
|
||||||
H, W = window.guide_spatial
|
|
||||||
new_entries = [{**e, "pre_filter_count": window_len * H * W,
|
|
||||||
"latent_shape": [window_len, H, W]} for e in cond_value.cond]
|
|
||||||
return cond_value._copy_with(new_entries)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
class LTXAV(BaseModel):
|
class LTXAV(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||||
@ -1083,7 +1085,6 @@ class LTXAV(BaseModel):
|
|||||||
|
|
||||||
def extra_conds(self, **kwargs):
|
def extra_conds(self, **kwargs):
|
||||||
out = super().extra_conds(**kwargs)
|
out = super().extra_conds(**kwargs)
|
||||||
logging.info(f"LTXAV.extra_conds: guide_attention_entries={'guide_attention_entries' in kwargs}, keyframe_idxs={'keyframe_idxs' in kwargs}")
|
|
||||||
attention_mask = kwargs.get("attention_mask", None)
|
attention_mask = kwargs.get("attention_mask", None)
|
||||||
device = kwargs["device"]
|
device = kwargs["device"]
|
||||||
|
|
||||||
@ -1170,12 +1171,8 @@ class LTXAV(BaseModel):
|
|||||||
for cond_dict in cond_list:
|
for cond_dict in cond_list:
|
||||||
model_conds = cond_dict.get('model_conds', {})
|
model_conds = cond_dict.get('model_conds', {})
|
||||||
gae = model_conds.get('guide_attention_entries')
|
gae = model_conds.get('guide_attention_entries')
|
||||||
logging.info(f"LTXAV.get_guide_frame_count: keys={list(model_conds.keys())}, gae={gae is not None}")
|
|
||||||
if gae is not None and hasattr(gae, 'cond') and gae.cond:
|
if gae is not None and hasattr(gae, 'cond') and gae.cond:
|
||||||
count = sum(e["latent_shape"][0] for e in gae.cond)
|
return sum(e["latent_shape"][0] for e in gae.cond)
|
||||||
logging.info(f"LTXAV.get_guide_frame_count: found {count} guide frames")
|
|
||||||
return count
|
|
||||||
logging.info("LTXAV.get_guide_frame_count: no guide frames found")
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||||
@ -1186,41 +1183,10 @@ class LTXAV(BaseModel):
|
|||||||
sliced = audio_window.get_tensor(cond_value.cond, device, dim=2)
|
sliced = audio_window.get_tensor(cond_value.cond, device, dim=2)
|
||||||
return cond_value._copy_with(sliced)
|
return cond_value._copy_with(sliced)
|
||||||
|
|
||||||
# Guide handling (same as LTXV — shared guide mechanism)
|
# Guide handling (shared with LTXV)
|
||||||
guide_count = getattr(window, 'guide_count', 0)
|
result = self._resize_guide_cond(cond_key, cond_value, window, x_in, device, retain_index_list)
|
||||||
if cond_key in ("keyframe_idxs", "guide_attention_entries", "denoise_mask"):
|
if result is not None:
|
||||||
logging.info(f"LTXAV resize_cond: {cond_key}, guide_count={guide_count}, has_spatial={hasattr(window, 'guide_spatial')}")
|
return result
|
||||||
|
|
||||||
if cond_key == "denoise_mask" and guide_count > 0:
|
|
||||||
cond_tensor = cond_value.cond
|
|
||||||
T_video = cond_tensor.size(window.dim) - guide_count
|
|
||||||
video_mask = cond_tensor.narrow(window.dim, 0, T_video)
|
|
||||||
guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count)
|
|
||||||
sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list)
|
|
||||||
sliced_guide = window.get_tensor(guide_mask, device)
|
|
||||||
return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim))
|
|
||||||
|
|
||||||
if cond_key == "keyframe_idxs" and guide_count > 0:
|
|
||||||
window_len = len(window.index_list)
|
|
||||||
H, W = window.guide_spatial
|
|
||||||
patchifier = self.diffusion_model.patchifier
|
|
||||||
latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device)
|
|
||||||
from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords
|
|
||||||
pixel_coords = latent_to_pixel_coords(
|
|
||||||
latent_coords,
|
|
||||||
self.diffusion_model.vae_scale_factors,
|
|
||||||
causal_fix=self.diffusion_model.causal_temporal_positioning)
|
|
||||||
B = cond_value.cond.shape[0]
|
|
||||||
if B > 1:
|
|
||||||
pixel_coords = pixel_coords.expand(B, -1, -1, -1)
|
|
||||||
return cond_value._copy_with(pixel_coords)
|
|
||||||
|
|
||||||
if cond_key == "guide_attention_entries" and guide_count > 0:
|
|
||||||
window_len = len(window.index_list)
|
|
||||||
H, W = window.guide_spatial
|
|
||||||
new_entries = [{**e, "pre_filter_count": window_len * H * W,
|
|
||||||
"latent_shape": [window_len, H, W]} for e in cond_value.cond]
|
|
||||||
return cond_value._copy_with(new_entries)
|
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user