LTX2 Context Windows - Collect multimodal methods into WindowingState; Condense execution path to treat all latents as potentially multimodal

ozbayb 2026-04-11 11:31:04 -06:00
parent 88643f3978
commit ae3830a6d2
2 changed files with 245 additions and 258 deletions
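For orientation, a minimal runnable sketch of the pattern the new execute path follows: build per-step windowing state, run the model on each window slice, and fuse overlapping outputs back into a full-length latent. Everything here (ToyWindowingState, the 0.5x stand-in for calc_cond_batch, uniform averaging in place of the handler's fuse methods) is illustrative only, not the handler's actual code.

# Illustrative only: a toy stand-in for the consolidated execute path in this commit.
from dataclasses import dataclass

import torch


@dataclass
class ToyWindowingState:
    latents: list[torch.Tensor]   # per-modality working latents (toy: one modality)
    dim: int = 2                  # temporal dim in this toy [B, C, T, H, W] layout


def toy_execute(state: ToyWindowingState, windows: list[list[int]]) -> torch.Tensor:
    """Run a stand-in model per window and average overlapping frames."""
    accum = torch.zeros_like(state.latents[0])
    counts = torch.zeros(state.latents[0].shape[state.dim])
    for index_list in windows:
        idx = torch.tensor(index_list)
        sub_x = state.latents[0].index_select(state.dim, idx)  # slice the window
        out = sub_x * 0.5                                      # stand-in for calc_cond_batch
        accum.index_add_(state.dim, idx, out)                  # accumulate into full latent
        counts[idx] += 1
    view = [1] * accum.ndim
    view[state.dim] = -1
    return accum / counts.view(view)                           # average the overlaps


state = ToyWindowingState(latents=[torch.randn(1, 4, 8, 16, 16)])
out = toy_execute(state, windows=[[0, 1, 2, 3], [3, 4, 5, 6], [6, 7]])
print(out.shape)  # torch.Size([1, 4, 8, 16, 16])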

View File

@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Callable
 import torch
 import numpy as np
 import collections
@@ -60,6 +60,10 @@ class IndexListContextWindow(ContextWindowABC):
         self.total_frames = total_frames
         self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
         self.modality_windows = modality_windows  # dict of {mod_idx: IndexListContextWindow}
+        self.guide_frames_indices: list[int] = []
+        self.guide_overlap_info: list[tuple[int, int]] = []
+        self.guide_kf_local_positions: list[int] = []
+        self.guide_downscale_factors: list[int] = []

     def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
         if dim is None:
@@ -84,6 +88,11 @@ class IndexListContextWindow(ContextWindowABC):
         region_idx = int(self.center_ratio * num_regions)
         return min(max(region_idx, 0), num_regions - 1)

+    def get_window_for_modality(self, modality_idx: int) -> 'IndexListContextWindow':
+        if modality_idx == 0:
+            return self
+        return self.modality_windows[modality_idx]
+

 class IndexListCallbacks:
     EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
@@ -181,46 +190,109 @@ def compute_guide_overlap(guide_entries: list[dict], window_index_list: list[int
     return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices)

-def inject_guide_frames_into_window(video_slice: torch.Tensor, window: ContextWindowABC,
-                                    window_data: 'WindowingContext', dim: int) -> tuple[torch.Tensor, int]:
-    """Inject overlapping guide frames into a context window slice.
-
-    Determines which guide frames overlap with this window's indices, concatenates
-    them onto the video slice, and sets window attributes for downstream conditioning resize.
-    Returns (augmented_slice, num_guide_frames_added).
-    """
-    guide_entries = window_data.aux_data["guide_entries"]
-    guide_frames = window_data.guide_frames
-    overlap = compute_guide_overlap(guide_entries, window.index_list)
-    suffix_idx, overlap_info, kf_local_pos, guide_frame_count = overlap
-    window.guide_frames_indices = suffix_idx
-    window.guide_overlap_info = overlap_info
-    window.guide_kf_local_positions = kf_local_pos
-    # Derive per-overlap-entry latent_downscale_factor from guide entry latent_shape vs guide frame spatial dims.
-    # guide_frames has full (post-dilation) spatial dims; entry["latent_shape"] has pre-dilation dims.
-    guide_downscale_factors = []
-    if guide_frame_count > 0:
-        full_H = guide_frames.shape[3]
-        for entry_idx, _ in overlap_info:
-            entry_H = guide_entries[entry_idx]["latent_shape"][1]
-            guide_downscale_factors.append(full_H // entry_H)
-    window.guide_downscale_factors = guide_downscale_factors
-    if guide_frame_count > 0:
-        idx = tuple([slice(None)] * dim + [suffix_idx])
-        sliced_guide = guide_frames[idx]
-        return torch.cat([video_slice, sliced_guide], dim=dim), guide_frame_count
-    return video_slice, 0
-
 @dataclass
-class WindowingContext:
-    tensor: torch.Tensor
-    guide_frames: torch.Tensor | None
-    aux_data: Any
-    latent_shapes: list | None
+class WindowingState:
+    """Per-modality context windowing state for each step,
+    built using IndexListContextHandler._build_window_state().
+
+    For non-multimodal models the lists are length 1.
+    """
+    latents: list[torch.Tensor]  # per-modality working latents (guide frames stripped)
+    guide_latents: list[torch.Tensor | None]  # per-modality guide frames stripped from latents
+    guide_entries: list[list[dict] | None]  # per-modality guide_attention_entry metadata
+    latent_shapes: list | None  # original packed shapes for unpack/pack (None if not multimodal)
+    dim: int = 0  # primary modality temporal dim for context windowing
+    is_multimodal: bool = False
+
+    def prepare_window(self, window: IndexListContextWindow, model) -> IndexListContextWindow:
+        """Reformat window for multimodal contexts by deriving per-modality index lists.
+
+        Non-multimodal contexts return the input window unchanged.
+        """
+        if not self.is_multimodal:
+            return window
+        x = self.latents[0]
+        map_shapes = self.latent_shapes
+        if x.size(self.dim) != self.latent_shapes[0][self.dim]:
+            map_shapes = list(self.latent_shapes)
+            video_shape = list(self.latent_shapes[0])
+            video_shape[self.dim] = x.size(self.dim)
+            map_shapes[0] = torch.Size(video_shape)
+        try:
+            per_modality_indices = model.map_context_window_to_modalities(
+                window.index_list, map_shapes, self.dim)
+        except AttributeError:
+            raise NotImplementedError(
+                f"{type(model).__name__} must implement map_context_window_to_modalities for multimodal context windows.")
+        modality_windows = {}
+        for mod_idx in range(1, len(self.latents)):
+            modality_windows[mod_idx] = IndexListContextWindow(
+                per_modality_indices[mod_idx], dim=self.dim,
+                total_frames=self.latents[mod_idx].shape[self.dim])
+        return IndexListContextWindow(
+            window.index_list, dim=self.dim, total_frames=x.shape[self.dim],
+            modality_windows=modality_windows)
+
+    def slice_for_window(self, window: IndexListContextWindow, retain_index_list: list[int], device=None) -> tuple[list[torch.Tensor], list[int]]:
+        """Slice latents for a context window, injecting guide frames where applicable.
+
+        For multimodal contexts, uses the modality-specific windows derived in prepare_window().
+        """
+        sliced = []
+        guide_frame_counts = []
+        for idx in range(len(self.latents)):
+            modality_window = window.get_window_for_modality(idx)
+            retain = retain_index_list if idx == 0 else []
+            s = modality_window.get_tensor(self.latents[idx], device, retain_index_list=retain)
+            if self.guide_entries[idx] is not None:
+                s, ng = self._inject_guide_frames(s, modality_window, modality_idx=idx)
+            else:
+                ng = 0
+            sliced.append(s)
+            guide_frame_counts.append(ng)
+        return sliced, guide_frame_counts
+
+    def strip_guide_frames(self, out_per_modality: list[list[torch.Tensor]], guide_frame_counts: list[int], window: IndexListContextWindow):
+        """Strip injected guide frames from per-cond, per-modality outputs in place."""
+        for idx in range(len(self.latents)):
+            if guide_frame_counts[idx] > 0:
+                window_len = len(window.get_window_for_modality(idx).index_list)
+                for ci in range(len(out_per_modality)):
+                    out_per_modality[ci][idx] = out_per_modality[ci][idx].narrow(self.dim, 0, window_len)
+
+    def _inject_guide_frames(self, latent_slice: torch.Tensor, window: IndexListContextWindow, modality_idx: int = 0) -> tuple[torch.Tensor, int]:
+        guide_entries = self.guide_entries[modality_idx]
+        guide_frames = self.guide_latents[modality_idx]
+        suffix_idx, overlap_info, kf_local_pos, guide_frame_count = compute_guide_overlap(guide_entries, window.index_list)
+        window.guide_frames_indices = suffix_idx
+        window.guide_overlap_info = overlap_info
+        window.guide_kf_local_positions = kf_local_pos
+        # Derive per-overlap-entry latent_downscale_factor from guide entry latent_shape vs guide frame spatial dims.
+        # guide_frames has full (post-dilation) spatial dims; entry["latent_shape"] has pre-dilation dims.
+        guide_downscale_factors = []
+        if guide_frame_count > 0:
+            full_H = guide_frames.shape[3]
+            for entry_idx, _ in overlap_info:
+                entry_H = guide_entries[entry_idx]["latent_shape"][1]
+                guide_downscale_factors.append(full_H // entry_H)
+        window.guide_downscale_factors = guide_downscale_factors
+        if guide_frame_count > 0:
+            idx = tuple([slice(None)] * self.dim + [suffix_idx])
+            return torch.cat([latent_slice, guide_frames[idx]], dim=self.dim), guide_frame_count
+        return latent_slice, 0
+
+    def patch_latent_shapes(self, sub_conds, new_shapes):
+        if not self.is_multimodal:
+            return
+        for cond_list in sub_conds:
+            if cond_list is None:
+                continue
+            for cond_dict in cond_list:
+                model_conds = cond_dict.get('model_conds', {})
+                if 'latent_shapes' in model_conds:
+                    model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes)

 @dataclass
 class ContextSchedule:
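Note that prepare_window defers the actual index mapping to the model. A hypothetical ratio-based implementation of map_context_window_to_modalities (the real LTXAV version is not part of this diff, so this is only a guess at its shape) could look like:

# Hypothetical sketch: scale the primary window's frame indices into each
# modality's frame range by the ratio of total frame counts.
def map_context_window_to_modalities(index_list, latent_shapes, dim):
    primary_total = latent_shapes[0][dim]
    per_modality = [list(index_list)]
    for shape in latent_shapes[1:]:
        mod_total = shape[dim]
        ratio = mod_total / primary_total
        start = int(index_list[0] * ratio)
        end = min(int((index_list[-1] + 1) * ratio), mod_total)
        per_modality.append(list(range(start, end)))
    return per_modality


# Example: a 4-frame video window maps to an 8-frame audio window at 2x rate.
print(map_context_window_to_modalities([2, 3, 4, 5], [(16,), (32,)], 0))
# [[2, 3, 4, 5], [4, 5, 6, 7, 8, 9, 10, 11]]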
@@ -261,37 +333,35 @@ class IndexListContextHandler(ContextHandlerABC):
                 return model_conds['latent_shapes'].cond
         return None

-    @staticmethod
-    def _unpack(combined_latent, latent_shapes):
-        if latent_shapes is not None and len(latent_shapes) > 1:
-            return comfy.utils.unpack_latents(combined_latent, latent_shapes)
-        return [combined_latent]
-
-    @staticmethod
-    def _pack(latents):
-        if len(latents) > 1:
-            return comfy.utils.pack_latents(latents)
-        return latents[0], [latents[0].shape]
-
-    @staticmethod
-    def _patch_latent_shapes(sub_conds, new_shapes):
-        for cond_list in sub_conds:
-            if cond_list is None:
-                continue
-            for cond_dict in cond_list:
-                model_conds = cond_dict.get('model_conds', {})
-                if 'latent_shapes' in model_conds:
-                    model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes)
-
-    def _build_window_data(self, x_in: torch.Tensor, conds: list[list[dict]]) -> WindowingContext:
+    def _apply_freenoise(self, noise: torch.Tensor, conds: list[list[dict]], seed: int) -> torch.Tensor:
+        """Apply FreeNoise shuffling, scaling context length/overlap per-modality by frame ratio."""
+        latent_shapes = self._get_latent_shapes(conds)
+        if latent_shapes is not None and len(latent_shapes) > 1:
+            modalities = comfy.utils.unpack_latents(noise, latent_shapes)
+            primary_total = latent_shapes[0][self.dim]
+            modalities[0] = apply_freenoise(modalities[0], self.dim, self.context_length, self.context_overlap, seed)
+            for i in range(1, len(modalities)):
+                mod_total = latent_shapes[i][self.dim]
+                ratio = mod_total / primary_total if primary_total > 0 else 1
+                mod_ctx_len = max(round(self.context_length * ratio), 1)
+                mod_ctx_overlap = max(round(self.context_overlap * ratio), 0)
+                modalities[i] = apply_freenoise(modalities[i], self.dim, mod_ctx_len, mod_ctx_overlap, seed)
+            noise, _ = comfy.utils.pack_latents(modalities)
+            return noise
+        return apply_freenoise(noise, self.dim, self.context_length, self.context_overlap, seed)
+
+    def _build_window_state(self, x_in: torch.Tensor, conds: list[list[dict]]) -> WindowingState:
+        """Build windowing state for the current step, including unpacking latents and extracting guide frame info from conds."""
         latent_shapes = self._get_latent_shapes(conds)
         is_multimodal = latent_shapes is not None and len(latent_shapes) > 1
-        if is_multimodal:
-            video_latent = comfy.utils.unpack_latents(x_in, latent_shapes)[0]
-        else:
-            video_latent = x_in
+        unpacked_latents = comfy.utils.unpack_latents(x_in, latent_shapes) if is_multimodal else [x_in]

-        guide_entries = None
+        unpacked_latents_list = list(unpacked_latents)
+        guide_latents_list = [None] * len(unpacked_latents)
+        guide_entries_list = [None] * len(unpacked_latents)
+
+        # Scan for 'guide_attention_entries' in conds
+        extracted_guide_entries = None
         for cond_list in conds:
             if cond_list is None:
                 continue
@@ -299,37 +369,39 @@ class IndexListContextHandler(ContextHandlerABC):
                 model_conds = cond_dict.get('model_conds', {})
                 entries = model_conds.get('guide_attention_entries')
                 if entries is not None and hasattr(entries, 'cond') and entries.cond:
-                    guide_entries = entries.cond
+                    extracted_guide_entries = entries.cond
                     break
-            if guide_entries is not None:
+            if extracted_guide_entries is not None:
                 break

-        guide_frame_count = sum(e["latent_shape"][0] for e in guide_entries) if guide_entries is not None else 0
-        primary_frame_count = video_latent.size(self.dim) - guide_frame_count
-        primary_frames = video_latent.narrow(self.dim, 0, primary_frame_count)
-        guide_frames = video_latent.narrow(self.dim, primary_frame_count, guide_frame_count) if guide_frame_count > 0 else None
-
-        if guide_frame_count > 0:
-            aux_data = {"guide_entries": guide_entries}
-        else:
-            aux_data = None
+        # Strip guide frames (only from first modality for now)
+        if extracted_guide_entries is not None:
+            guide_count = sum(e["latent_shape"][0] for e in extracted_guide_entries)
+            if guide_count > 0:
+                x = unpacked_latents[0]
+                latent_count = x.size(self.dim) - guide_count
+                unpacked_latents_list[0] = x.narrow(self.dim, 0, latent_count)
+                guide_latents_list[0] = x.narrow(self.dim, latent_count, guide_count)
+                guide_entries_list[0] = extracted_guide_entries

-        return WindowingContext(
-            tensor=primary_frames, guide_frames=guide_frames, aux_data=aux_data,
-            latent_shapes=latent_shapes)
+        return WindowingState(
+            latents=unpacked_latents_list,
+            guide_latents=guide_latents_list,
+            guide_entries=guide_entries_list,
+            latent_shapes=latent_shapes,
+            dim=self.dim,
+            is_multimodal=is_multimodal)

     def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
-        self._window_data = self._build_window_data(x_in, conds)
-        video_frames = self._window_data.tensor.size(self.dim)
-        guide_frames = self._window_data.guide_frames.size(self.dim) if self._window_data.guide_frames is not None else 0
-        if video_frames > self.context_length:
-            if guide_frames > 0:
-                logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {video_frames} video frames ({guide_frames} guide frames).")
-            else:
-                logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {video_frames} frames.")
+        window_state = self._build_window_state(x_in, conds)  # build window_state to check frame counts, will be built again in execute
+        total_frame_count = window_state.latents[0].size(self.dim)
+        if total_frame_count > self.context_length:
+            logging.info(f"\nUsing context windows: Context length {self.context_length} with overlap {self.context_overlap} for {total_frame_count} frames.")
             if self.cond_retain_index_list:
                 logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
             return True
+        logging.info(f"\nNot using context windows since context length ({self.context_length}) exceeds input frames ({total_frame_count}).")
        return False

     def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase:
@@ -436,188 +508,121 @@ class IndexListContextHandler(ContextHandlerABC):
         self._model = model
         self.set_step(timestep, model_options)

-        window_data = self._window_data
-        is_multimodal = window_data.latent_shapes is not None and len(window_data.latent_shapes) > 1
-        has_guide_frames = window_data.guide_frames is not None and window_data.guide_frames.size(self.dim) > 0
-        # if multimodal or has concatenated guide frames on noise latent, use the extended execute path
-        if is_multimodal or has_guide_frames:
-            return self._execute_extended(calc_cond_batch, model, conds, x_in, timestep, model_options, window_data)
-
-        # basic legacy execution path for single-modal video latent with no guide frames concatenated
-        context_windows = self.get_context_windows(model, x_in, model_options)
-        enumerated_context_windows = list(enumerate(context_windows))
-
-        conds_final = [torch.zeros_like(x_in) for _ in conds]
-        if self.fuse_method.name == ContextFuseMethods.RELATIVE:
-            counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
-        else:
-            counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
-        biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds]
-
-        for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
-            callback(self, model, x_in, conds, timestep, model_options)
-
-        for enum_window in enumerated_context_windows:
-            results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options)
-            for result in results:
-                self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep,
-                                                    conds_final, counts_final, biases_final)
-        try:
-            if self.fuse_method.name == ContextFuseMethods.RELATIVE:
-                del counts_final
-                return conds_final
-            else:
-                for i in range(len(conds_final)):
-                    conds_final[i] /= counts_final[i]
-                del counts_final
-                return conds_final
-        finally:
-            for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
-                callback(self, model, x_in, conds, timestep, model_options)
-
-    def _execute_extended(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor,
-                          timestep: torch.Tensor, model_options: dict[str],
-                          window_data: WindowingContext):
-        """Extended execute path for multimodal models and models with guide frames appended to the noise latent."""
-        latents = self._unpack(x_in, window_data.latent_shapes)
-        is_multimodal = window_data.latent_shapes is not None and len(window_data.latent_shapes) > 1
-        primary_frames = window_data.tensor
-        context_windows = self.get_context_windows(model, primary_frames, model_options)
+        window_state = self._build_window_state(x_in, conds)
+        num_modalities = len(window_state.latents)
+
+        context_windows = self.get_context_windows(model, window_state.latents[0], model_options)
         enumerated_context_windows = list(enumerate(context_windows))
         total_windows = len(enumerated_context_windows)

-        # Accumulators sized to video portion for primary, full for other modalities
-        accum_shape_refs = list(latents)
-        if window_data.guide_frames is not None:
-            accum_shape_refs[0] = primary_frames
-        accum = [[torch.zeros_like(m) for _ in conds] for m in accum_shape_refs]
+        # Initialize per-modality accumulators (length 1 for single-modality)
+        accum = [[torch.zeros_like(m) for _ in conds] for m in window_state.latents]
         if self.fuse_method.name == ContextFuseMethods.RELATIVE:
-            counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_shape_refs]
+            counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in window_state.latents]
         else:
-            counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_shape_refs]
-        biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in accum_shape_refs]
+            counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in window_state.latents]
+        biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in window_state.latents]

         for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
             callback(self, model, x_in, conds, timestep, model_options)

-        for window_idx, window in enumerated_context_windows:
-            comfy.model_management.throw_exception_if_processing_interrupted()
-
-            # Per-modality window indices
-            if is_multimodal:
-                map_shapes = window_data.latent_shapes
-                if primary_frames.size(self.dim) != latents[0].size(self.dim):
-                    map_shapes = list(window_data.latent_shapes)
-                    video_shape = list(window_data.latent_shapes[0])
-                    video_shape[self.dim] = primary_frames.size(self.dim)
-                    map_shapes[0] = torch.Size(video_shape)
-                try:
-                    per_modality_indices = model.map_context_window_to_modalities(
-                        window.index_list, map_shapes, self.dim)
-                except AttributeError:
-                    raise NotImplementedError(
-                        f"{type(model).__name__} must implement map_context_window_to_modalities for multimodal context windows.")
-                modality_windows = {}
-                for mod_idx in range(1, len(latents)):
-                    modality_windows[mod_idx] = IndexListContextWindow(
-                        per_modality_indices[mod_idx], dim=self.dim,
-                        total_frames=latents[mod_idx].shape[self.dim])
-                window = IndexListContextWindow(
-                    window.index_list, dim=self.dim, total_frames=primary_frames.shape[self.dim],
-                    modality_windows=modality_windows)
-
-            # Build per-modality windows list
-            per_modality_windows_list = [window]
-            if is_multimodal:
-                for mod_idx in range(1, len(latents)):
-                    per_modality_windows_list.append(modality_windows[mod_idx])
-
-            # Slice video, then inject overlapping guide frames if present
-            sliced_video = per_modality_windows_list[0].get_tensor(primary_frames, retain_index_list=self.cond_retain_index_list)
-            if window_data.aux_data is not None:
-                sliced_primary, num_guide_frames = inject_guide_frames_into_window(sliced_video, window, window_data, self.dim)
-            else:
-                sliced_primary, num_guide_frames = sliced_video, 0
-
-            logging.info(f"Context window {window_idx + 1}/{total_windows}: frames {window.index_list[0]}-{window.index_list[-1]} of {primary_frames.shape[self.dim]}"
-                         + (f" (+{num_guide_frames} guide frames)" if num_guide_frames > 0 else "")
-                         + (f" [{len(latents)} modalities]" if is_multimodal else ""))
-
-            sliced = [sliced_primary] + [per_modality_windows_list[mi].get_tensor(latents[mi]) for mi in range(1, len(latents))]
-            sub_x, sub_shapes = self._pack(sliced)
-
-            for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
-                callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, None, None)
-
-            model_options["transformer_options"]["context_window"] = window
-            sub_timestep = window.get_tensor(timestep, dim=0)
-            sub_conds = [self.get_resized_cond(cond, primary_frames, window) for cond in conds]
-            if is_multimodal:
-                self._patch_latent_shapes(sub_conds, sub_shapes)
-
-            sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options)
-
-            # Unpack output per modality
-            out_per_modality = [self._unpack(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))]
-            # Strip guide frames from primary output before accumulation
-            if num_guide_frames > 0:
-                window_len = len(window.index_list)
-                for ci in range(len(sub_conds_out)):
-                    out_per_modality[ci][0] = out_per_modality[ci][0].narrow(self.dim, 0, window_len)
-
-            # Accumulate per modality
-            for mod_idx in range(len(accum_shape_refs)):
-                mw = per_modality_windows_list[mod_idx]
-                sub_conds_out_per_modality = [out_per_modality[ci][mod_idx] for ci in range(len(sub_conds_out))]
-                self.combine_context_window_results(
-                    accum_shape_refs[mod_idx], sub_conds_out_per_modality, sub_conds, mw,
-                    window_idx, total_windows, timestep,
-                    accum[mod_idx], counts[mod_idx], biases[mod_idx])
+        # accumulate results from each context window
+        for enum_window in enumerated_context_windows:
+            results = self.evaluate_context_windows(
+                calc_cond_batch, model, x_in, conds, timestep, [enum_window],
+                model_options, window_state=window_state, total_windows=total_windows)
+            for result in results:
+                # result.sub_conds_out is per-cond, per-modality: list[list[Tensor]]
+                for mod_idx in range(num_modalities):
+                    mod_out = [result.sub_conds_out[ci][mod_idx] for ci in range(len(conds))]
+                    modality_window = result.window.get_window_for_modality(mod_idx)
+                    self.combine_context_window_results(
+                        window_state.latents[mod_idx], mod_out, result.sub_conds, modality_window,
+                        result.window_idx, total_windows, timestep,
+                        accum[mod_idx], counts[mod_idx], biases[mod_idx])

+        # fuse accumulated results into final conds
         try:
-            result = []
+            result_out = []
             for ci in range(len(conds)):
                 finalized = []
-                for mod_idx in range(len(accum_shape_refs)):
+                for mod_idx in range(num_modalities):
                     if self.fuse_method.name != ContextFuseMethods.RELATIVE:
                         accum[mod_idx][ci] /= counts[mod_idx][ci]
                     f = accum[mod_idx][ci]
-                    if mod_idx == 0 and window_data.guide_frames is not None:
-                        f = torch.cat([f, window_data.guide_frames], dim=self.dim)
+
+                    # if guide frames were injected, append them to the end of the fused latents for the next step
+                    if window_state.guide_latents[mod_idx] is not None:
+                        f = torch.cat([f, window_state.guide_latents[mod_idx]], dim=self.dim)
                     finalized.append(f)
-                packed, _ = self._pack(finalized)
-                result.append(packed)
-            return result
+
+                # pack modalities together if needed
+                if window_state.is_multimodal and len(finalized) > 1:
+                    packed, _ = comfy.utils.pack_latents(finalized)
+                else:
+                    packed = finalized[0]
+                result_out.append(packed)
+            return result_out
         finally:
             for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
                 callback(self, model, x_in, conds, timestep, model_options)

-    def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
-                                 model_options, device=None, first_device=None):
+    def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds,
+                                 timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
+                                 model_options, window_state: WindowingState, total_windows: int = None,
+                                 device=None, first_device=None):
+        """Evaluate context windows and return per-cond, per-modality outputs in ContextResults.sub_conds_out.
+
+        For each window:
+        1. Builds windows (for each modality if multimodal)
+        2. Slices window for each modality
+        3. Injects concatenated latent guide frames where present
+        4. Packs together if needed and calls model
+        5. Unpacks and strips any guides from outputs
+        """
+        x = window_state.latents[0]
         results: list[ContextResults] = []
         for window_idx, window in enumerated_context_windows:
             # allow processing to end between context window executions for faster Cancel
             comfy.model_management.throw_exception_if_processing_interrupted()

+            # prepare the window accounting for multimodal windows
+            window = window_state.prepare_window(window, model)
+            # slice the window for each modality, injecting guide frames where applicable
+            sliced, guide_frame_counts_per_modality = window_state.slice_for_window(window, self.cond_retain_index_list, device)
+
             for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
                 callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)

-            # update exposed params
-            model_options["transformer_options"]["context_window"] = window
-            # get subsections of x, timestep, conds
-            sub_x = window.get_tensor(x_in, device)
-            sub_timestep = window.get_tensor(timestep, device, dim=0)
-            sub_conds = [self.get_resized_cond(cond, x_in, window, device) for cond in conds]
+            logging.info(f"Context window {window_idx + 1}/{total_windows or len(enumerated_context_windows)}: frames {window.index_list[0]}-{window.index_list[-1]} of {x.shape[self.dim]}"
+                         + (f" (+{guide_frame_counts_per_modality[0]} guide frames)" if guide_frame_counts_per_modality[0] > 0 else ""))
+
+            # if multimodal, pack modalities together
+            if window_state.is_multimodal and len(sliced) > 1:
+                sub_x, sub_shapes = comfy.utils.pack_latents(sliced)
+            else:
+                sub_x, sub_shapes = sliced[0], [sliced[0].shape]
+
+            # get resized conds for window
+            model_options["transformer_options"]["context_window"] = window
+            sub_timestep = window.get_tensor(timestep, dim=0)
+            sub_conds = [self.get_resized_cond(cond, x, window) for cond in conds]
+            # if multimodal, patch latent_shapes in conds for correct unpacking in model
+            window_state.patch_latent_shapes(sub_conds, sub_shapes)

+            # call model on window
             sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options)
-            if device is not None:
-                for i in range(len(sub_conds_out)):
-                    sub_conds_out[i] = sub_conds_out[i].to(x_in.device)
-            results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
+
+            # unpack outputs and strip guide frames
+            out_per_modality = [comfy.utils.unpack_latents(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))]
+            window_state.strip_guide_frames(out_per_modality, guide_frame_counts_per_modality, window)
+            results.append(ContextResults(window_idx, out_per_modality, sub_conds, window))
         return results
@@ -684,28 +689,11 @@ def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, nois
     if not handler.freenoise:
         return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)

-    # For packed multimodal tensors (e.g. LTXAV), noise is [B, 1, flat] and FreeNoise
-    # must only shuffle the video portion. Unpack, apply to video, repack.
-    latent_shapes = IndexListContextHandler._get_latent_shapes(
-        [guider.conds.get('positive', guider.conds.get('negative', []))])
-    if latent_shapes is not None and len(latent_shapes) > 1:
-        modalities = comfy.utils.unpack_latents(noise, latent_shapes)
-        video_total = latent_shapes[0][handler.dim]
-        modalities[0] = apply_freenoise(modalities[0], handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
-        for i in range(1, len(modalities)):
-            mod_total = latent_shapes[i][handler.dim]
-            ratio = mod_total / video_total if video_total > 0 else 1
-            mod_ctx_len = max(round(handler.context_length * ratio), 1)
-            mod_ctx_overlap = max(round(handler.context_overlap * ratio), 0)
-            modalities[i] = apply_freenoise(modalities[i], handler.dim, mod_ctx_len, mod_ctx_overlap, extra_args["seed"])
-        noise, _ = comfy.utils.pack_latents(modalities)
-    else:
-        noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
+    conds = [guider.conds.get('positive', guider.conds.get('negative', []))]
+    noise = handler._apply_freenoise(noise, conds, extra_args["seed"])
     return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)

 def create_sampler_sample_wrapper(model: ModelPatcher):
     model.add_wrapper_with_key(
         comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
@@ -713,7 +701,6 @@ def create_sampler_sample_wrapper(model: ModelPatcher):
         _sampler_sample_wrapper
     )
-
 def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
     total_dims = len(x_in.shape)
     weights_tensor = torch.Tensor(weights).to(device=device)
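A small, runnable illustration of the ratio math the new _apply_freenoise method (above, in this file) uses for each non-primary modality. Only the arithmetic is reproduced; the helper name scaled_context is ours, and comfy's pack/unpack and the actual noise shuffling are omitted.

# Toy illustration of the per-modality context scaling in _apply_freenoise.
def scaled_context(context_length: int, context_overlap: int,
                   primary_total: int, mod_total: int) -> tuple[int, int]:
    ratio = mod_total / primary_total if primary_total > 0 else 1
    return max(round(context_length * ratio), 1), max(round(context_overlap * ratio), 0)


# A 48-frame video with 96 audio latent frames doubles the audio window:
print(scaled_context(16, 4, primary_total=48, mod_total=96))  # (32, 8)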

View File

@@ -1162,7 +1162,7 @@ class LTXAV(BaseModel):
         # Adjust spatial end positions for dilated (downscaled) guides.
         # Each guide entry may have a different downscale factor; expand the
         # per-entry factor to cover all tokens belonging to that entry.
-        downscale_factors = getattr(window, 'guide_downscale_factors', [])
+        downscale_factors = window.guide_downscale_factors
         overlap_info = window.guide_overlap_info
         if downscale_factors:
             per_token_factor = []
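To make the comment in that hunk concrete, here is a hedged sketch of the per-entry-to-per-token expansion. It assumes each overlap_info tuple carries (entry_idx, overlapping_frame_count); the diff itself only ever unpacks the tuple as `entry_idx, _`, so that layout is a guess.

# Hedged sketch: expand per-entry downscale factors to one factor per token.
# ASSUMPTION: overlap_info[i] == (entry_idx, overlapping_frame_count).
def expand_per_token(overlap_info: list[tuple[int, int]],
                     downscale_factors: list[int]) -> list[int]:
    per_token = []
    for (_, frame_count), factor in zip(overlap_info, downscale_factors):
        per_token.extend([factor] * frame_count)  # every token of an entry shares its factor
    return per_token


print(expand_per_token([(0, 2), (1, 3)], [2, 1]))  # [2, 2, 1, 1, 1]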