mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-14 20:42:31 +08:00
LTX2 context windows - Clean up unnecessary code
This commit is contained in:
parent
c9edd2d7c0
commit
d5badc5f38
@ -182,22 +182,22 @@ def compute_guide_overlap(guide_entries: list[dict], window_index_list: list[int
|
||||
|
||||
|
||||
def inject_guide_frames_into_window(video_slice: torch.Tensor, window: ContextWindowABC,
|
||||
aux_data: dict, dim: int) -> tuple[torch.Tensor, int]:
|
||||
window_data: 'WindowingContext', dim: int) -> tuple[torch.Tensor, int]:
|
||||
"""Inject overlapping guide frames into a context window slice.
|
||||
|
||||
Uses aux_data from WindowingContext to determine which guide frames overlap
|
||||
with this window's indices, concatenates them onto the video slice, and sets
|
||||
window attributes for downstream conditioning resize.
|
||||
Determines which guide frames overlap with this window's indices, concatenates
|
||||
them onto the video slice, and sets window attributes for downstream conditioning resize.
|
||||
|
||||
Returns (augmented_slice, num_guide_frames_added).
|
||||
"""
|
||||
guide_entries = aux_data["guide_entries"]
|
||||
guide_frames = aux_data["guide_frames"]
|
||||
guide_entries = window_data.aux_data["guide_entries"]
|
||||
guide_frames = window_data.guide_frames
|
||||
overlap = compute_guide_overlap(guide_entries, window.index_list)
|
||||
suffix_idx, overlap_info, kf_local_pos, guide_frame_count = overlap
|
||||
window.guide_frames_indices = suffix_idx
|
||||
window.guide_overlap_info = overlap_info
|
||||
window.guide_kf_local_positions = kf_local_pos
|
||||
|
||||
# Derive per-overlap-entry latent_downscale_factor from guide entry latent_shape vs guide frame spatial dims.
|
||||
# guide_frames has full (post-dilation) spatial dims; entry["latent_shape"] has pre-dilation dims.
|
||||
guide_downscale_factors = []
|
||||
@ -207,6 +207,7 @@ def inject_guide_frames_into_window(video_slice: torch.Tensor, window: ContextWi
|
||||
entry_H = guide_entries[entry_idx]["latent_shape"][1]
|
||||
guide_downscale_factors.append(full_H // entry_H)
|
||||
window.guide_downscale_factors = guide_downscale_factors
|
||||
|
||||
if guide_frame_count > 0:
|
||||
idx = tuple([slice(None)] * dim + [suffix_idx])
|
||||
sliced_guide = guide_frames[idx]
|
||||
@ -220,7 +221,6 @@ class WindowingContext:
|
||||
guide_frames: torch.Tensor | None
|
||||
aux_data: Any
|
||||
latent_shapes: list | None
|
||||
is_multimodal: bool
|
||||
|
||||
@dataclass
|
||||
class ContextSchedule:
|
||||
@ -310,13 +310,13 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
guide_frames = video_latent.narrow(self.dim, primary_frame_count, guide_frame_count) if guide_frame_count > 0 else None
|
||||
|
||||
if guide_frame_count > 0:
|
||||
aux_data = {"guide_entries": guide_entries, "guide_frames": guide_frames}
|
||||
aux_data = {"guide_entries": guide_entries}
|
||||
else:
|
||||
aux_data = None
|
||||
|
||||
return WindowingContext(
|
||||
tensor=primary_frames, guide_frames=guide_frames, aux_data=aux_data,
|
||||
latent_shapes=latent_shapes, is_multimodal=is_multimodal)
|
||||
latent_shapes=latent_shapes)
|
||||
|
||||
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
||||
self._window_data = self._build_window_data(x_in, conds)
|
||||
@ -437,9 +437,14 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
self.set_step(timestep, model_options)
|
||||
|
||||
window_data = self._window_data
|
||||
if window_data.is_multimodal or (window_data.guide_frames is not None and window_data.guide_frames.size(self.dim) > 0):
|
||||
is_multimodal = window_data.latent_shapes is not None and len(window_data.latent_shapes) > 1
|
||||
has_guide_frames = window_data.guide_frames is not None and window_data.guide_frames.size(self.dim) > 0
|
||||
|
||||
# if multimodal or has concatenated guide frames on noise latent, use the extended execute path
|
||||
if is_multimodal or has_guide_frames:
|
||||
return self._execute_extended(calc_cond_batch, model, conds, x_in, timestep, model_options, window_data)
|
||||
|
||||
# basic legacy execution path for single-modal video latent with no guide frames concatenated
|
||||
context_windows = self.get_context_windows(model, x_in, model_options)
|
||||
enumerated_context_windows = list(enumerate(context_windows))
|
||||
|
||||
@ -475,8 +480,9 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
timestep: torch.Tensor, model_options: dict[str],
|
||||
window_data: WindowingContext):
|
||||
"""Extended execute path for multimodal models and models with guide frames appended to the noise latent."""
|
||||
|
||||
latents = self._unpack(x_in, window_data.latent_shapes)
|
||||
is_multimodal = window_data.is_multimodal
|
||||
is_multimodal = window_data.latent_shapes is not None and len(window_data.latent_shapes) > 1
|
||||
|
||||
primary_frames = window_data.tensor
|
||||
num_guide_frames = window_data.guide_frames.size(self.dim) if window_data.guide_frames is not None else 0
|
||||
@ -538,7 +544,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
# Slice video, then inject overlapping guide frames if present
|
||||
sliced_video = per_modality_windows_list[0].get_tensor(primary_frames, retain_index_list=self.cond_retain_index_list)
|
||||
if window_data.aux_data is not None:
|
||||
sliced_primary, num_guide_frames = inject_guide_frames_into_window(sliced_video, window, window_data.aux_data, self.dim)
|
||||
sliced_primary, num_guide_frames = inject_guide_frames_into_window(sliced_video, window, window_data, self.dim)
|
||||
else:
|
||||
sliced_primary, num_guide_frames = sliced_video, 0
|
||||
sliced = [sliced_primary] + [per_modality_windows_list[mi].get_tensor(latents[mi]) for mi in range(1, len(latents))]
|
||||
|
||||
@ -1113,23 +1113,6 @@ class LTXAV(BaseModel):
|
||||
if entries is not None and hasattr(entries, 'cond') and entries.cond:
|
||||
return entries.cond
|
||||
return None
|
||||
|
||||
def prepare_window_data(self, x_in, conds, dim, window_data):
|
||||
primary = comfy.utils.unpack_latents(x_in, window_data.latent_shapes)[0] if window_data.is_multimodal else x_in
|
||||
guide_entries = self._get_guide_entries(conds)
|
||||
guide_count = sum(e["latent_shape"][0] for e in guide_entries) if guide_entries else 0
|
||||
if guide_count <= 0:
|
||||
return comfy.context_windows.WindowingContext(
|
||||
tensor=primary, guide_frames=None, aux_data=None,
|
||||
latent_shapes=window_data.latent_shapes, is_multimodal=window_data.is_multimodal)
|
||||
video_len = primary.size(dim) - guide_count
|
||||
video_primary = primary.narrow(dim, 0, video_len)
|
||||
guide_frames = primary.narrow(dim, video_len, guide_count)
|
||||
return comfy.context_windows.WindowingContext(
|
||||
tensor=video_primary, guide_frames=guide_frames,
|
||||
aux_data={"guide_entries": guide_entries, "guide_frames": guide_frames},
|
||||
latent_shapes=window_data.latent_shapes, is_multimodal=window_data.is_multimodal)
|
||||
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
# Audio denoise mask — slice using audio modality window
|
||||
|
||||
Loading…
Reference in New Issue
Block a user