mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-10 01:02:56 +08:00
Update compute_guide_overlap to take the keyframe indices and the temporal downscale ratio as arguments, instead of relying on a latent_start key added to each guide_attention_entry.
- This lets existing nodes, such as the ComfyUI-LTXVideo nodes, work with context windows without modification.
This commit is contained in:
parent
cbe2a1ba42
commit
75e8e4b6dc
@ -161,11 +161,17 @@ def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, d
|
||||
return cond_value._copy_with(sliced)
|
||||
|
||||
|
||||
def compute_guide_overlap(guide_entries: list[dict], window_index_list: list[int]):
|
||||
def compute_guide_overlap(guide_entries: list[dict], keyframe_idxs: torch.Tensor, temporal_downscale_ratio: int, window_index_list: list[int]):
|
||||
"""Compute which concatenated guide frames overlap with a context window.
|
||||
|
||||
Each guide's latent-space start is derived from its first token's pixel-t-start
|
||||
in keyframe_idxs (shape (B, [t,h,w], num_tokens, [start, end])), divided by the
|
||||
model's temporal_downscale_ratio.
|
||||
|
||||
Args:
|
||||
guide_entries: list of guide_attention_entry dicts
|
||||
keyframe_idxs: per-token pixel coords cond tensor for the modality
|
||||
temporal_downscale_ratio: model's pixel-to-latent temporal compression ratio
|
||||
window_index_list: the window's frame indices into the video portion
|
||||
|
||||
Returns:
|
||||
@ -180,11 +186,11 @@ def compute_guide_overlap(guide_entries: list[dict], window_index_list: list[int
|
||||
overlap_info = []
|
||||
kf_local_positions = []
|
||||
suffix_base = 0
|
||||
token_offset = 0
|
||||
|
||||
for entry_idx, entry in enumerate(guide_entries):
|
||||
latent_start = entry.get("latent_start", None)
|
||||
if latent_start is None:
|
||||
raise ValueError("guide_attention_entry missing required 'latent_start'.")
|
||||
first_t_pixel = int(keyframe_idxs[0, 0, token_offset, 0].item())
|
||||
latent_start = (first_t_pixel + temporal_downscale_ratio - 1) // temporal_downscale_ratio
|
||||
guide_len = entry["latent_shape"][0]
|
||||
entry_overlap = 0
|
||||
|
||||
@ -198,6 +204,7 @@ def compute_guide_overlap(guide_entries: list[dict], window_index_list: list[int
|
||||
if entry_overlap > 0:
|
||||
overlap_info.append((entry_idx, entry_overlap))
|
||||
suffix_base += guide_len
|
||||
token_offset += entry["pre_filter_count"]
|
||||
|
||||
return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices)
|
||||
|
||||
@ -211,9 +218,11 @@ class WindowingState:
|
||||
latents: list[torch.Tensor] # per-modality working latents (guide frames stripped)
|
||||
guide_latents: list[torch.Tensor | None] # per-modality guide frames stripped from latents
|
||||
guide_entries: list[list[dict] | None] # per-modality guide_attention_entry metadata
|
||||
keyframe_idxs: list[torch.Tensor | None] # per-modality keyframe_idxs tensor for guide latent_start derivation
|
||||
latent_shapes: list | None # original packed shapes for unpack/pack (None if not multimodal)
|
||||
dim: int = 0 # primary modality temporal dim for context windowing
|
||||
is_multimodal: bool = False
|
||||
temporal_downscale_ratio: int = 1 # model's pixel-to-latent temporal compression ratio
|
||||
|
||||
def prepare_window(self, window: IndexListContextWindow, model) -> IndexListContextWindow:
|
||||
"""Reformat window for multimodal contexts by deriving per-modality index lists.
|
||||
@ -279,7 +288,9 @@ class WindowingState:
|
||||
def _inject_guide_frames(self, latent_slice: torch.Tensor, window: IndexListContextWindow, modality_idx: int = 0) -> tuple[torch.Tensor, int]:
|
||||
guide_entries = self.guide_entries[modality_idx]
|
||||
guide_frames = self.guide_latents[modality_idx]
|
||||
suffix_idx, overlap_info, kf_local_pos, guide_frame_count = compute_guide_overlap(guide_entries, window.index_list)
|
||||
keyframe_idxs = self.keyframe_idxs[modality_idx]
|
||||
suffix_idx, overlap_info, kf_local_pos, guide_frame_count = compute_guide_overlap(
|
||||
guide_entries, keyframe_idxs, self.temporal_downscale_ratio, window.index_list)
|
||||
# Shift keyframe positions to account for causal_window_fix anchor occupying sub-pos 0.
|
||||
anchor_idx = getattr(window, 'causal_anchor_index', None)
|
||||
if anchor_idx is not None and anchor_idx >= 0:
|
||||
@ -370,6 +381,18 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
return entries.cond
|
||||
return None
|
||||
|
||||
@staticmethod
def _get_keyframe_idxs(conds):
    """Return the first non-None ``keyframe_idxs`` cond payload found in ``conds``.

    Args:
        conds: iterable of cond lists; an element may be None, otherwise it
            is iterated as a sequence of dicts that may carry a
            'model_conds' mapping.

    Returns:
        The ``.cond`` attribute of the first 'keyframe_idxs' model cond whose
        ``.cond`` is not None, or None when no such entry exists.
    """
    for group in conds:
        # Empty slots are allowed in the outer list; skip them.
        if group is None:
            continue
        for item in group:
            candidate = item.get('model_conds', {}).get('keyframe_idxs')
            # A missing entry, an object without `.cond`, and a None `.cond`
            # are all treated as "not found here"; keep scanning.
            payload = getattr(candidate, 'cond', None)
            if payload is not None:
                return payload
    return None
|
||||
|
||||
def _apply_freenoise(self, noise: torch.Tensor, conds: list[list[dict]], seed: int) -> torch.Tensor:
|
||||
"""Apply FreeNoise shuffling, scaling context length/overlap per-modality by frame ratio.
|
||||
If guide frames are present on the primary modality, only the video portion is shuffled.
|
||||
@ -395,7 +418,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
apply_freenoise(noise.narrow(self.dim, 0, video_count), self.dim, self.context_length, self.context_overlap, seed)
|
||||
return noise
|
||||
|
||||
def _build_window_state(self, x_in: torch.Tensor, conds: list[list[dict]]) -> WindowingState:
|
||||
def _build_window_state(self, x_in: torch.Tensor, conds: list[list[dict]], model: BaseModel) -> WindowingState:
|
||||
"""Build windowing state for the current step, including unpacking latents and extracting guide frame info from conds."""
|
||||
latent_shapes = self._get_latent_shapes(conds)
|
||||
is_multimodal = latent_shapes is not None and len(latent_shapes) > 1
|
||||
@ -404,8 +427,10 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
unpacked_latents_list = list(unpacked_latents)
|
||||
guide_latents_list = [None] * len(unpacked_latents)
|
||||
guide_entries_list = [None] * len(unpacked_latents)
|
||||
keyframe_idxs_list = [None] * len(unpacked_latents)
|
||||
|
||||
extracted_guide_entries = self._get_guide_entries(conds)
|
||||
extracted_keyframe_idxs = self._get_keyframe_idxs(conds)
|
||||
|
||||
# Strip guide frames (only from first modality for now)
|
||||
if extracted_guide_entries is not None:
|
||||
@ -416,18 +441,21 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
unpacked_latents_list[0] = x.narrow(self.dim, 0, latent_count)
|
||||
guide_latents_list[0] = x.narrow(self.dim, latent_count, guide_count)
|
||||
guide_entries_list[0] = extracted_guide_entries
|
||||
keyframe_idxs_list[0] = extracted_keyframe_idxs
|
||||
|
||||
|
||||
return WindowingState(
|
||||
latents=unpacked_latents_list,
|
||||
guide_latents=guide_latents_list,
|
||||
guide_entries=guide_entries_list,
|
||||
keyframe_idxs=keyframe_idxs_list,
|
||||
latent_shapes=latent_shapes,
|
||||
dim=self.dim,
|
||||
is_multimodal=is_multimodal)
|
||||
is_multimodal=is_multimodal,
|
||||
temporal_downscale_ratio=model.latent_format.temporal_downscale_ratio)
|
||||
|
||||
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
||||
window_state = self._build_window_state(x_in, conds) # build window_state to check frame counts, will be built again in execute
|
||||
window_state = self._build_window_state(x_in, conds, model) # build window_state to check frame counts, will be built again in execute
|
||||
total_frame_count = window_state.latents[0].size(self.dim)
|
||||
if total_frame_count > self.context_length:
|
||||
logging.info(f"\nUsing context windows: Context length {self.context_length} with overlap {self.context_overlap} for {total_frame_count} frames.")
|
||||
@ -545,7 +573,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
self._model = model
|
||||
self.set_step(timestep, model_options)
|
||||
|
||||
window_state = self._build_window_state(x_in, conds)
|
||||
window_state = self._build_window_state(x_in, conds, model)
|
||||
num_modalities = len(window_state.latents)
|
||||
|
||||
context_windows = self.get_context_windows(model, window_state.latents[0], model_options)
|
||||
|
||||
@ -136,7 +136,7 @@ class LTXVImgToVideoInplace(io.ComfyNode):
|
||||
generate = execute # TODO: remove
|
||||
|
||||
|
||||
def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0, latent_start=0):
|
||||
def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0):
|
||||
"""Append a guide_attention_entry to both positive and negative conditioning.
|
||||
|
||||
Each entry tracks one guide reference for per-reference attention control.
|
||||
@ -147,7 +147,6 @@ def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_s
|
||||
"strength": strength,
|
||||
"pixel_mask": None,
|
||||
"latent_shape": latent_shape,
|
||||
"latent_start": latent_start,
|
||||
}
|
||||
results = []
|
||||
for cond in (positive, negative):
|
||||
@ -364,7 +363,6 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
guide_latent_shape = list(t.shape[2:]) # [F, H, W]
|
||||
positive, negative = _append_guide_attention_entry(
|
||||
positive, negative, pre_filter_count, guide_latent_shape, strength=strength,
|
||||
latent_start=latent_idx,
|
||||
)
|
||||
|
||||
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
|
||||
|
||||
Loading…
Reference in New Issue
Block a user