mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-14 20:42:31 +08:00
LTX2 context windows - Refactor guide logic from context_windows into LTXAV model hooks
This commit is contained in:
parent
350237618d
commit
874690c01c
@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
import torch
|
||||
import numpy as np
|
||||
import collections
|
||||
@ -181,6 +181,12 @@ def _compute_guide_overlap(guide_entries, window_index_list):
|
||||
return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices)
|
||||
|
||||
|
||||
@dataclass
|
||||
class WindowingContext:
|
||||
tensor: torch.Tensor
|
||||
suffix: torch.Tensor | None
|
||||
aux_data: Any
|
||||
|
||||
@dataclass
|
||||
class ContextSchedule:
|
||||
name: str
|
||||
@ -242,18 +248,6 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
if 'latent_shapes' in model_conds:
|
||||
model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes)
|
||||
|
||||
def _get_guide_entries(self, conds):
|
||||
"""Extract guide_attention_entries list from conditioning. Returns None if absent."""
|
||||
for cond_list in conds:
|
||||
if cond_list is None:
|
||||
continue
|
||||
for cond_dict in cond_list:
|
||||
model_conds = cond_dict.get('model_conds', {})
|
||||
gae = model_conds.get('guide_attention_entries')
|
||||
if gae is not None and hasattr(gae, 'cond') and gae.cond:
|
||||
return gae.cond
|
||||
return None
|
||||
|
||||
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
||||
latent_shapes = self._get_latent_shapes(conds)
|
||||
primary = self._decompose(x_in, latent_shapes)[0]
|
||||
@ -379,24 +373,19 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
is_multimodal = len(modalities) > 1
|
||||
primary = modalities[0]
|
||||
|
||||
# Separate guide frames from primary modality (guides are appended at the end)
|
||||
guide_count = model.get_guide_frame_count(primary, conds) if hasattr(model, 'get_guide_frame_count') else 0
|
||||
if guide_count > 0:
|
||||
video_len = primary.size(self.dim) - guide_count
|
||||
video_primary = primary.narrow(self.dim, 0, video_len)
|
||||
guide_suffix = primary.narrow(self.dim, video_len, guide_count)
|
||||
else:
|
||||
video_primary = primary
|
||||
guide_suffix = None
|
||||
# Let model strip auxiliary frames (e.g. guide frames)
|
||||
window_data = model.prepare_for_windowing(primary, conds, self.dim)
|
||||
video_primary = window_data.tensor
|
||||
aux_count = window_data.suffix.size(self.dim) if window_data.suffix is not None else 0
|
||||
|
||||
# Windows from video portion only (excluding guide frames)
|
||||
# Windows from video portion only
|
||||
context_windows = self.get_context_windows(model, video_primary, model_options)
|
||||
enumerated_context_windows = list(enumerate(context_windows))
|
||||
total_windows = len(enumerated_context_windows)
|
||||
|
||||
# Accumulators sized to video portion for primary, full for other modalities
|
||||
accum_modalities = list(modalities)
|
||||
if guide_suffix is not None:
|
||||
if window_data.suffix is not None:
|
||||
accum_modalities[0] = video_primary
|
||||
|
||||
accum = [[torch.zeros_like(m) for _ in conds] for m in accum_modalities]
|
||||
@ -406,25 +395,22 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_modalities]
|
||||
biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in accum_modalities]
|
||||
|
||||
guide_entries = self._get_guide_entries(conds) if guide_count > 0 else None
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options)
|
||||
|
||||
for window_idx, window in enumerated_context_windows:
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
logging.info(f"Context window {window_idx + 1}/{total_windows}: frames {window.index_list[0]}-{window.index_list[-1]} of {video_primary.shape[self.dim]}"
|
||||
+ (f" (+{guide_count} guide)" if guide_count > 0 else "")
|
||||
+ (f" (+{aux_count} aux)" if aux_count > 0 else "")
|
||||
+ (f" [{len(modalities)} modalities]" if is_multimodal else ""))
|
||||
|
||||
# Per-modality window indices
|
||||
if is_multimodal:
|
||||
# Adjust latent_shapes so video shape reflects video-only frames (excludes guides)
|
||||
map_shapes = latent_shapes
|
||||
if guide_count > 0:
|
||||
if video_primary.size(self.dim) != primary.size(self.dim):
|
||||
map_shapes = list(latent_shapes)
|
||||
video_shape = list(latent_shapes[0])
|
||||
video_shape[self.dim] = video_shape[self.dim] - guide_count
|
||||
video_shape[self.dim] = video_primary.size(self.dim)
|
||||
map_shapes[0] = torch.Size(video_shape)
|
||||
per_mod_indices = model.map_context_window_to_modalities(
|
||||
window.index_list, map_shapes, self.dim) if hasattr(model, 'map_context_window_to_modalities') else [window.index_list]
|
||||
@ -446,30 +432,10 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
for mod_idx in range(1, len(modalities)):
|
||||
mod_windows.append(modality_windows[mod_idx])
|
||||
|
||||
# Slice video, then select overlapping guide frames
|
||||
# Slice video, then let model inject auxiliary frames
|
||||
sliced_video = mod_windows[0].get_tensor(video_primary, retain_index_list=self.cond_retain_index_list)
|
||||
num_guide_in_window = 0
|
||||
if guide_suffix is not None and guide_entries is not None:
|
||||
overlap = _compute_guide_overlap(guide_entries, window.index_list)
|
||||
if overlap[3] > 0:
|
||||
suffix_idx, overlap_info, kf_local_pos, num_guide_in_window = overlap
|
||||
idx = tuple([slice(None)] * self.dim + [suffix_idx])
|
||||
sliced_guide = guide_suffix[idx]
|
||||
window.guide_suffix_indices = suffix_idx
|
||||
window.guide_overlap_info = overlap_info
|
||||
window.guide_kf_local_positions = kf_local_pos
|
||||
else:
|
||||
sliced_guide = None
|
||||
window.guide_suffix_indices = []
|
||||
window.guide_overlap_info = []
|
||||
window.guide_kf_local_positions = []
|
||||
else:
|
||||
sliced_guide = None
|
||||
|
||||
if sliced_guide is not None:
|
||||
sliced_primary = torch.cat([sliced_video, sliced_guide], dim=self.dim)
|
||||
else:
|
||||
sliced_primary = sliced_video
|
||||
sliced_primary, num_aux = model.prepare_window_input(
|
||||
sliced_video, window, window_data.aux_data, self.dim)
|
||||
sliced = [sliced_primary] + [mod_windows[mi].get_tensor(modalities[mi]) for mi in range(1, len(modalities))]
|
||||
|
||||
# Compose for pipeline
|
||||
@ -481,7 +447,6 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
|
||||
model_options["transformer_options"]["context_window"] = window
|
||||
sub_timestep = window.get_tensor(timestep, dim=0)
|
||||
# Resize conds using video_primary as reference (excludes guide frames)
|
||||
sub_conds = [self.get_resized_cond(cond, video_primary, window) for cond in conds]
|
||||
if is_multimodal:
|
||||
self._patch_latent_shapes(sub_conds, sub_shapes)
|
||||
@ -490,14 +455,12 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
|
||||
# Decompose output per modality
|
||||
out_per_mod = [self._decompose(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))]
|
||||
# out_per_mod[cond_idx][mod_idx] = tensor
|
||||
|
||||
# Strip guide frames from primary output before accumulation
|
||||
if num_guide_in_window > 0:
|
||||
# Strip auxiliary frames from primary output before accumulation
|
||||
if num_aux > 0:
|
||||
window_len = len(window.index_list)
|
||||
for ci in range(len(sub_conds_out)):
|
||||
primary_out = out_per_mod[ci][0]
|
||||
out_per_mod[ci][0] = primary_out.narrow(self.dim, 0, window_len)
|
||||
out_per_mod[ci][0] = out_per_mod[ci][0].narrow(self.dim, 0, window_len)
|
||||
|
||||
# Accumulate per modality (using video-only sizes)
|
||||
for mod_idx in range(len(accum_modalities)):
|
||||
@ -516,10 +479,9 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
if self.fuse_method.name != ContextFuseMethods.RELATIVE:
|
||||
accum[mod_idx][ci] /= counts[mod_idx][ci]
|
||||
f = accum[mod_idx][ci]
|
||||
# Re-append original guide_suffix (not model output — sampling loop
|
||||
# respects denoise_mask and never modifies guide frame positions)
|
||||
if mod_idx == 0 and guide_suffix is not None:
|
||||
f = torch.cat([f, guide_suffix], dim=self.dim)
|
||||
# Re-append model's suffix (auxiliary frames stripped before windowing)
|
||||
if mod_idx == 0 and window_data.suffix is not None:
|
||||
f = torch.cat([f, window_data.suffix], dim=self.dim)
|
||||
finalized.append(f)
|
||||
composed, _ = self._compose(finalized)
|
||||
result.append(composed)
|
||||
|
||||
@ -287,6 +287,12 @@ class BaseModel(torch.nn.Module):
|
||||
return data
|
||||
return None
|
||||
|
||||
def prepare_for_windowing(self, primary, conds, dim):
|
||||
return comfy.context_windows.WindowingContext(tensor=primary, suffix=None, aux_data=None)
|
||||
|
||||
def prepare_window_input(self, video_slice, window, aux_data, dim):
|
||||
return video_slice, 0
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
"""Override in subclasses to handle model-specific cond slicing for context windows.
|
||||
Return a sliced cond object, or None to fall through to default handling.
|
||||
@ -1113,6 +1119,51 @@ class LTXAV(BaseModel):
|
||||
return sum(e["latent_shape"][0] for e in gae.cond)
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
def _get_guide_entries(conds):
|
||||
for cond_list in conds:
|
||||
if cond_list is None:
|
||||
continue
|
||||
for cond_dict in cond_list:
|
||||
model_conds = cond_dict.get('model_conds', {})
|
||||
gae = model_conds.get('guide_attention_entries')
|
||||
if gae is not None and hasattr(gae, 'cond') and gae.cond:
|
||||
return gae.cond
|
||||
return None
|
||||
|
||||
def prepare_for_windowing(self, primary, conds, dim):
|
||||
guide_count = self.get_guide_frame_count(primary, conds)
|
||||
if guide_count <= 0:
|
||||
return comfy.context_windows.WindowingContext(tensor=primary, suffix=None, aux_data=None)
|
||||
video_len = primary.size(dim) - guide_count
|
||||
video_primary = primary.narrow(dim, 0, video_len)
|
||||
guide_suffix = primary.narrow(dim, video_len, guide_count)
|
||||
guide_entries = self._get_guide_entries(conds)
|
||||
return comfy.context_windows.WindowingContext(
|
||||
tensor=video_primary, suffix=guide_suffix,
|
||||
aux_data={"guide_entries": guide_entries, "guide_suffix": guide_suffix})
|
||||
|
||||
def prepare_window_input(self, video_slice, window, aux_data, dim):
|
||||
if aux_data is None:
|
||||
return video_slice, 0
|
||||
guide_entries = aux_data["guide_entries"]
|
||||
guide_suffix = aux_data["guide_suffix"]
|
||||
if guide_entries is None:
|
||||
window.guide_suffix_indices = []
|
||||
window.guide_overlap_info = []
|
||||
window.guide_kf_local_positions = []
|
||||
return video_slice, 0
|
||||
overlap = comfy.context_windows._compute_guide_overlap(guide_entries, window.index_list)
|
||||
suffix_idx, overlap_info, kf_local_pos, num_guide = overlap
|
||||
window.guide_suffix_indices = suffix_idx
|
||||
window.guide_overlap_info = overlap_info
|
||||
window.guide_kf_local_positions = kf_local_pos
|
||||
if num_guide > 0:
|
||||
idx = tuple([slice(None)] * dim + [suffix_idx])
|
||||
sliced_guide = guide_suffix[idx]
|
||||
return torch.cat([video_slice, sliced_guide], dim=dim), num_guide
|
||||
return video_slice, 0
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
# Audio denoise mask — slice using audio modality window
|
||||
if cond_key == "audio_denoise_mask" and hasattr(window, 'modality_windows') and window.modality_windows:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user