mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 13:02:35 +08:00
LTX2 Context Windows - Collect multimodal methods into WindowingState; Condense execution path to treat all latents as potentially multimodal
This commit is contained in:
parent
88643f3978
commit
ae3830a6d2
@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
import torch
|
||||
import numpy as np
|
||||
import collections
|
||||
@ -60,6 +60,10 @@ class IndexListContextWindow(ContextWindowABC):
|
||||
self.total_frames = total_frames
|
||||
self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
|
||||
self.modality_windows = modality_windows # dict of {mod_idx: IndexListContextWindow}
|
||||
self.guide_frames_indices: list[int] = []
|
||||
self.guide_overlap_info: list[tuple[int, int]] = []
|
||||
self.guide_kf_local_positions: list[int] = []
|
||||
self.guide_downscale_factors: list[int] = []
|
||||
|
||||
def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
|
||||
if dim is None:
|
||||
@ -84,6 +88,11 @@ class IndexListContextWindow(ContextWindowABC):
|
||||
region_idx = int(self.center_ratio * num_regions)
|
||||
return min(max(region_idx, 0), num_regions - 1)
|
||||
|
||||
def get_window_for_modality(self, modality_idx: int) -> 'IndexListContextWindow':
    """Return the context window that applies to the given modality.

    Modality 0 is this window itself; any other index is looked up in
    ``self.modality_windows``.

    Args:
        modality_idx: index of the modality whose window is wanted.

    Returns:
        ``self`` for index 0, otherwise the entry stored under
        ``modality_idx`` in ``self.modality_windows``.
    """
    if modality_idx == 0:
        return self
    # Non-primary modalities carry their own window; a missing entry
    # surfaces as the lookup's own error.
    return self.modality_windows[modality_idx]
|
||||
|
||||
|
||||
class IndexListCallbacks:
|
||||
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
|
||||
@ -181,46 +190,109 @@ def compute_guide_overlap(guide_entries: list[dict], window_index_list: list[int
|
||||
return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices)
|
||||
|
||||
|
||||
def inject_guide_frames_into_window(video_slice: torch.Tensor, window: ContextWindowABC,
                                    window_data: 'WindowingContext', dim: int) -> tuple[torch.Tensor, int]:
    """Inject overlapping guide frames into a context window slice.

    Determines which guide frames overlap with this window's indices, concatenates
    them onto the video slice, and sets window attributes for downstream conditioning resize.

    Args:
        video_slice: the already-sliced latent for this window.
        window: context window; its ``guide_*`` attributes are overwritten here.
        window_data: provides ``aux_data["guide_entries"]`` and ``guide_frames``.
        dim: dimension along which guide frames are indexed and concatenated.

    Returns:
        (augmented_slice, num_guide_frames_added). When nothing overlaps the
        input slice is returned unchanged with a count of 0.
    """
    guide_entries = window_data.aux_data["guide_entries"]
    guide_frames = window_data.guide_frames
    suffix_idx, overlap_info, kf_local_pos, guide_frame_count = compute_guide_overlap(
        guide_entries, window.index_list)
    window.guide_frames_indices = suffix_idx
    window.guide_overlap_info = overlap_info
    window.guide_kf_local_positions = kf_local_pos

    if guide_frame_count <= 0:
        # No overlap: still reset the factors so stale values are not left on the window.
        window.guide_downscale_factors = []
        return video_slice, 0

    # Derive per-overlap-entry latent_downscale_factor from guide entry latent_shape vs guide frame spatial dims.
    # guide_frames has full (post-dilation) spatial dims; entry["latent_shape"] has pre-dilation dims.
    full_H = guide_frames.shape[3]
    window.guide_downscale_factors = [
        full_H // guide_entries[entry_idx]["latent_shape"][1]
        for entry_idx, _ in overlap_info
    ]

    # Select the overlapping guide frames along `dim`, keeping all leading dims intact.
    idx = tuple([slice(None)] * dim + [suffix_idx])
    return torch.cat([video_slice, guide_frames[idx]], dim=dim), guide_frame_count
|
||||
|
||||
|
||||
@dataclass
|
||||
class WindowingContext:
|
||||
tensor: torch.Tensor
|
||||
guide_frames: torch.Tensor | None
|
||||
aux_data: Any
|
||||
latent_shapes: list | None
|
||||
class WindowingState:
|
||||
"""Per-modality context windowing state for each step,
|
||||
built using IndexListContextHandler._build_window_state().
|
||||
For non-multimodal models the lists are length 1
|
||||
"""
|
||||
latents: list[torch.Tensor] # per-modality working latents (guide frames stripped)
|
||||
guide_latents: list[torch.Tensor | None] # per-modality guide frames stripped from latents
|
||||
guide_entries: list[list[dict] | None] # per-modality guide_attention_entry metadata
|
||||
latent_shapes: list | None # original packed shapes for unpack/pack (None if not multimodal)
|
||||
dim: int = 0 # primary modality temporal dim for context windowing
|
||||
is_multimodal: bool = False
|
||||
|
||||
def prepare_window(self, window: IndexListContextWindow, model) -> IndexListContextWindow:
    """Reformat window for multimodal contexts by deriving per-modality index lists.

    Non-multimodal contexts return the input window unchanged.

    Args:
        window: the primary-modality context window.
        model: object expected to provide
            ``map_context_window_to_modalities(index_list, shapes, dim)``.

    Returns:
        The input window when not multimodal; otherwise a new
        IndexListContextWindow for the primary modality carrying one window
        per additional modality in ``modality_windows``.

    Raises:
        NotImplementedError: if ``model`` has no
            ``map_context_window_to_modalities`` attribute. Errors raised
            inside the model's implementation propagate unchanged.
    """
    if not self.is_multimodal:
        return window

    primary = self.latents[0]
    primary_len = primary.size(self.dim)
    map_shapes = self.latent_shapes
    if primary_len != self.latent_shapes[0][self.dim]:
        # The working primary latent no longer matches the recorded shape
        # (e.g. after guide frames were stripped), so map against a patched
        # copy rather than mutating self.latent_shapes.
        map_shapes = list(self.latent_shapes)
        video_shape = list(self.latent_shapes[0])
        video_shape[self.dim] = primary_len
        map_shapes[0] = torch.Size(video_shape)

    # Look the method up explicitly instead of catching AttributeError around
    # the call: that would also swallow AttributeErrors raised *inside* the
    # model's implementation and misreport them as "not implemented".
    mapper = getattr(model, "map_context_window_to_modalities", None)
    if mapper is None:
        raise NotImplementedError(
            f"{type(model).__name__} must implement map_context_window_to_modalities for multimodal context windows.")
    per_modality_indices = mapper(window.index_list, map_shapes, self.dim)

    modality_windows = {
        mod_idx: IndexListContextWindow(
            per_modality_indices[mod_idx], dim=self.dim,
            total_frames=self.latents[mod_idx].shape[self.dim])
        for mod_idx in range(1, len(self.latents))
    }
    return IndexListContextWindow(
        window.index_list, dim=self.dim, total_frames=primary.shape[self.dim],
        modality_windows=modality_windows)
|
||||
|
||||
def slice_for_window(self, window: IndexListContextWindow, retain_index_list: list[int], device=None) -> tuple[list[torch.Tensor], list[int]]:
    """Slice latents for a context window, injecting guide frames where applicable.

    For multimodal contexts, uses the modality-specific windows derived in prepare_window().

    Args:
        window: context window; per-modality windows are obtained through
            ``window.get_window_for_modality``.
        retain_index_list: passed to ``get_tensor`` for modality 0 only;
            other modalities receive an empty list.
        device: forwarded to ``get_tensor``.

    Returns:
        (sliced, guide_frame_counts): one sliced tensor and one injected
        guide-frame count per entry of ``self.latents``.
    """
    sliced = []
    guide_frame_counts = []
    for idx, latent in enumerate(self.latents):
        modality_window = window.get_window_for_modality(idx)
        # Retained indices only apply to the primary modality.
        retain = retain_index_list if idx == 0 else []
        piece = modality_window.get_tensor(latent, device, retain_index_list=retain)
        guide_count = 0
        if self.guide_entries[idx] is not None:
            piece, guide_count = self._inject_guide_frames(piece, modality_window, modality_idx=idx)
        sliced.append(piece)
        guide_frame_counts.append(guide_count)
    return sliced, guide_frame_counts
|
||||
|
||||
def strip_guide_frames(self, out_per_modality: list[list[torch.Tensor]], guide_frame_counts: list[int], window: IndexListContextWindow):
    """Strip injected guide frames from per-cond, per-modality outputs in place.

    Args:
        out_per_modality: outputs indexed as ``[cond][modality]``; entries for
            modalities that had guide frames injected are replaced.
        guide_frame_counts: number of guide frames injected per modality.
        window: context window used to find each modality's window length.
    """
    for idx in range(len(self.latents)):
        if guide_frame_counts[idx] <= 0:
            continue
        # Guide frames were appended after the window's own frames, so keeping
        # the first `window_len` entries along self.dim drops them.
        window_len = len(window.get_window_for_modality(idx).index_list)
        for cond_out in out_per_modality:
            cond_out[idx] = cond_out[idx].narrow(self.dim, 0, window_len)
|
||||
|
||||
def _inject_guide_frames(self, latent_slice: torch.Tensor, window: IndexListContextWindow, modality_idx: int = 0) -> tuple[torch.Tensor, int]:
    """Append the guide frames that overlap ``window`` to ``latent_slice``.

    The window's ``guide_frames_indices``, ``guide_overlap_info``,
    ``guide_kf_local_positions`` and ``guide_downscale_factors`` attributes
    are overwritten with the values computed for this window.

    Args:
        latent_slice: the already-sliced latent for this window.
        window: context window whose ``index_list`` is matched against the guides.
        modality_idx: which modality's guide entries / guide latents to use.

    Returns:
        (augmented_slice, num_guide_frames_added). When nothing overlaps the
        input slice is returned unchanged with a count of 0.
    """
    guide_entries = self.guide_entries[modality_idx]
    guide_frames = self.guide_latents[modality_idx]
    suffix_idx, overlap_info, kf_local_pos, guide_frame_count = compute_guide_overlap(
        guide_entries, window.index_list)
    window.guide_frames_indices = suffix_idx
    window.guide_overlap_info = overlap_info
    window.guide_kf_local_positions = kf_local_pos

    if guide_frame_count <= 0:
        # No overlap: still reset the factors so stale values are not left on the window.
        window.guide_downscale_factors = []
        return latent_slice, 0

    # Derive per-overlap-entry latent_downscale_factor from guide entry latent_shape vs guide frame spatial dims.
    # guide_frames has full (post-dilation) spatial dims; entry["latent_shape"] has pre-dilation dims.
    full_H = guide_frames.shape[3]
    window.guide_downscale_factors = [
        full_H // guide_entries[entry_idx]["latent_shape"][1]
        for entry_idx, _ in overlap_info
    ]

    # Select the overlapping guide frames along self.dim, keeping all leading dims intact.
    idx = tuple([slice(None)] * self.dim + [suffix_idx])
    return torch.cat([latent_slice, guide_frames[idx]], dim=self.dim), guide_frame_count
|
||||
|
||||
def patch_latent_shapes(self, sub_conds, new_shapes):
    """Replace the 'latent_shapes' model cond with ``new_shapes`` in place.

    Does nothing when this state is not multimodal. ``None`` cond lists are
    skipped, and cond dicts without a 'latent_shapes' entry in their
    'model_conds' are left untouched.

    Args:
        sub_conds: iterable of cond lists (each a list of cond dicts, or None).
        new_shapes: value wrapped in ``comfy.conds.CONDConstant`` and stored
            under 'latent_shapes'.
    """
    if not self.is_multimodal:
        return

    for cond_list in sub_conds:
        if cond_list is None:
            continue
        for cond_dict in cond_list:
            model_conds = cond_dict.get('model_conds', {})
            # Only overwrite an existing entry; never introduce the key.
            if 'latent_shapes' in model_conds:
                model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextSchedule:
|
||||
@ -261,37 +333,35 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
return model_conds['latent_shapes'].cond
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _unpack(combined_latent, latent_shapes):
|
||||
def _apply_freenoise(self, noise: torch.Tensor, conds: list[list[dict]], seed: int) -> torch.Tensor:
|
||||
"""Apply FreeNoise shuffling, scaling context length/overlap per-modality by frame ratio."""
|
||||
latent_shapes = self._get_latent_shapes(conds)
|
||||
if latent_shapes is not None and len(latent_shapes) > 1:
|
||||
return comfy.utils.unpack_latents(combined_latent, latent_shapes)
|
||||
return [combined_latent]
|
||||
modalities = comfy.utils.unpack_latents(noise, latent_shapes)
|
||||
primary_total = latent_shapes[0][self.dim]
|
||||
modalities[0] = apply_freenoise(modalities[0], self.dim, self.context_length, self.context_overlap, seed)
|
||||
for i in range(1, len(modalities)):
|
||||
mod_total = latent_shapes[i][self.dim]
|
||||
ratio = mod_total / primary_total if primary_total > 0 else 1
|
||||
mod_ctx_len = max(round(self.context_length * ratio), 1)
|
||||
mod_ctx_overlap = max(round(self.context_overlap * ratio), 0)
|
||||
modalities[i] = apply_freenoise(modalities[i], self.dim, mod_ctx_len, mod_ctx_overlap, seed)
|
||||
noise, _ = comfy.utils.pack_latents(modalities)
|
||||
return noise
|
||||
return apply_freenoise(noise, self.dim, self.context_length, self.context_overlap, seed)
|
||||
|
||||
@staticmethod
|
||||
def _pack(latents):
|
||||
if len(latents) > 1:
|
||||
return comfy.utils.pack_latents(latents)
|
||||
return latents[0], [latents[0].shape]
|
||||
|
||||
@staticmethod
|
||||
def _patch_latent_shapes(sub_conds, new_shapes):
|
||||
for cond_list in sub_conds:
|
||||
if cond_list is None:
|
||||
continue
|
||||
for cond_dict in cond_list:
|
||||
model_conds = cond_dict.get('model_conds', {})
|
||||
if 'latent_shapes' in model_conds:
|
||||
model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes)
|
||||
|
||||
def _build_window_data(self, x_in: torch.Tensor, conds: list[list[dict]]) -> WindowingContext:
|
||||
def _build_window_state(self, x_in: torch.Tensor, conds: list[list[dict]]) -> WindowingState:
|
||||
"""Build windowing state for the current step, including unpacking latents and extracting guide frame info from conds."""
|
||||
latent_shapes = self._get_latent_shapes(conds)
|
||||
is_multimodal = latent_shapes is not None and len(latent_shapes) > 1
|
||||
if is_multimodal:
|
||||
video_latent = comfy.utils.unpack_latents(x_in, latent_shapes)[0]
|
||||
else:
|
||||
video_latent = x_in
|
||||
unpacked_latents = comfy.utils.unpack_latents(x_in, latent_shapes) if is_multimodal else [x_in]
|
||||
|
||||
guide_entries = None
|
||||
unpacked_latents_list = list(unpacked_latents)
|
||||
guide_latents_list = [None] * len(unpacked_latents)
|
||||
guide_entries_list = [None] * len(unpacked_latents)
|
||||
|
||||
# Scan for 'guide_attention_entries' in conds
|
||||
extracted_guide_entries = None
|
||||
for cond_list in conds:
|
||||
if cond_list is None:
|
||||
continue
|
||||
@ -299,37 +369,39 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
model_conds = cond_dict.get('model_conds', {})
|
||||
entries = model_conds.get('guide_attention_entries')
|
||||
if entries is not None and hasattr(entries, 'cond') and entries.cond:
|
||||
guide_entries = entries.cond
|
||||
extracted_guide_entries = entries.cond
|
||||
break
|
||||
if guide_entries is not None:
|
||||
if extracted_guide_entries is not None:
|
||||
break
|
||||
|
||||
guide_frame_count = sum(e["latent_shape"][0] for e in guide_entries) if guide_entries is not None else 0
|
||||
primary_frame_count = video_latent.size(self.dim) - guide_frame_count
|
||||
primary_frames = video_latent.narrow(self.dim, 0, primary_frame_count)
|
||||
guide_frames = video_latent.narrow(self.dim, primary_frame_count, guide_frame_count) if guide_frame_count > 0 else None
|
||||
# Strip guide frames (only from first modality for now)
|
||||
if extracted_guide_entries is not None:
|
||||
guide_count = sum(e["latent_shape"][0] for e in extracted_guide_entries)
|
||||
if guide_count > 0:
|
||||
x = unpacked_latents[0]
|
||||
latent_count = x.size(self.dim) - guide_count
|
||||
unpacked_latents_list[0] = x.narrow(self.dim, 0, latent_count)
|
||||
guide_latents_list[0] = x.narrow(self.dim, latent_count, guide_count)
|
||||
guide_entries_list[0] = extracted_guide_entries
|
||||
|
||||
if guide_frame_count > 0:
|
||||
aux_data = {"guide_entries": guide_entries}
|
||||
else:
|
||||
aux_data = None
|
||||
|
||||
return WindowingContext(
|
||||
tensor=primary_frames, guide_frames=guide_frames, aux_data=aux_data,
|
||||
latent_shapes=latent_shapes)
|
||||
return WindowingState(
|
||||
latents=unpacked_latents_list,
|
||||
guide_latents=guide_latents_list,
|
||||
guide_entries=guide_entries_list,
|
||||
latent_shapes=latent_shapes,
|
||||
dim=self.dim,
|
||||
is_multimodal=is_multimodal)
|
||||
|
||||
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
||||
self._window_data = self._build_window_data(x_in, conds)
|
||||
video_frames = self._window_data.tensor.size(self.dim)
|
||||
guide_frames = self._window_data.guide_frames.size(self.dim) if self._window_data.guide_frames is not None else 0
|
||||
if video_frames > self.context_length:
|
||||
if guide_frames > 0:
|
||||
logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {video_frames} video frames ({guide_frames} guide frames).")
|
||||
else:
|
||||
logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {video_frames} frames.")
|
||||
window_state = self._build_window_state(x_in, conds) # build window_state to check frame counts, will be built again in execute
|
||||
total_frame_count = window_state.latents[0].size(self.dim)
|
||||
if total_frame_count > self.context_length:
|
||||
logging.info(f"\nUsing context windows: Context length {self.context_length} with overlap {self.context_overlap} for {total_frame_count} frames.")
|
||||
if self.cond_retain_index_list:
|
||||
logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
|
||||
return True
|
||||
logging.info(f"\nNot using context windows since context length ({self.context_length}) exceeds input frames ({total_frame_count}).")
|
||||
return False
|
||||
|
||||
def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase:
|
||||
@ -436,188 +508,121 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
self._model = model
|
||||
self.set_step(timestep, model_options)
|
||||
|
||||
window_data = self._window_data
|
||||
is_multimodal = window_data.latent_shapes is not None and len(window_data.latent_shapes) > 1
|
||||
has_guide_frames = window_data.guide_frames is not None and window_data.guide_frames.size(self.dim) > 0
|
||||
window_state = self._build_window_state(x_in, conds)
|
||||
num_modalities = len(window_state.latents)
|
||||
|
||||
# if multimodal or has concatenated guide frames on noise latent, use the extended execute path
|
||||
if is_multimodal or has_guide_frames:
|
||||
return self._execute_extended(calc_cond_batch, model, conds, x_in, timestep, model_options, window_data)
|
||||
|
||||
# basic legacy execution path for single-modal video latent with no guide frames concatenated
|
||||
context_windows = self.get_context_windows(model, x_in, model_options)
|
||||
enumerated_context_windows = list(enumerate(context_windows))
|
||||
|
||||
conds_final = [torch.zeros_like(x_in) for _ in conds]
|
||||
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
|
||||
counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
|
||||
else:
|
||||
counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
|
||||
biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds]
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options)
|
||||
|
||||
for enum_window in enumerated_context_windows:
|
||||
results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options)
|
||||
for result in results:
|
||||
self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep,
|
||||
conds_final, counts_final, biases_final)
|
||||
try:
|
||||
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
|
||||
del counts_final
|
||||
return conds_final
|
||||
else:
|
||||
for i in range(len(conds_final)):
|
||||
conds_final[i] /= counts_final[i]
|
||||
del counts_final
|
||||
return conds_final
|
||||
finally:
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options)
|
||||
|
||||
def _execute_extended(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor,
|
||||
timestep: torch.Tensor, model_options: dict[str],
|
||||
window_data: WindowingContext):
|
||||
"""Extended execute path for multimodal models and models with guide frames appended to the noise latent."""
|
||||
|
||||
latents = self._unpack(x_in, window_data.latent_shapes)
|
||||
is_multimodal = window_data.latent_shapes is not None and len(window_data.latent_shapes) > 1
|
||||
|
||||
primary_frames = window_data.tensor
|
||||
context_windows = self.get_context_windows(model, primary_frames, model_options)
|
||||
context_windows = self.get_context_windows(model, window_state.latents[0], model_options)
|
||||
enumerated_context_windows = list(enumerate(context_windows))
|
||||
total_windows = len(enumerated_context_windows)
|
||||
|
||||
# Accumulators sized to video portion for primary, full for other modalities
|
||||
accum_shape_refs = list(latents)
|
||||
if window_data.guide_frames is not None:
|
||||
accum_shape_refs[0] = primary_frames
|
||||
|
||||
accum = [[torch.zeros_like(m) for _ in conds] for m in accum_shape_refs]
|
||||
# Initialize per-modality accumulators (length 1 for single-modality)
|
||||
accum = [[torch.zeros_like(m) for _ in conds] for m in window_state.latents]
|
||||
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
|
||||
counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_shape_refs]
|
||||
counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in window_state.latents]
|
||||
else:
|
||||
counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_shape_refs]
|
||||
biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in accum_shape_refs]
|
||||
counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in window_state.latents]
|
||||
biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in window_state.latents]
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options)
|
||||
|
||||
for window_idx, window in enumerated_context_windows:
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
|
||||
# Per-modality window indices
|
||||
if is_multimodal:
|
||||
map_shapes = window_data.latent_shapes
|
||||
if primary_frames.size(self.dim) != latents[0].size(self.dim):
|
||||
map_shapes = list(window_data.latent_shapes)
|
||||
video_shape = list(window_data.latent_shapes[0])
|
||||
video_shape[self.dim] = primary_frames.size(self.dim)
|
||||
map_shapes[0] = torch.Size(video_shape)
|
||||
try:
|
||||
per_modality_indices = model.map_context_window_to_modalities(
|
||||
window.index_list, map_shapes, self.dim)
|
||||
except AttributeError:
|
||||
raise NotImplementedError(
|
||||
f"{type(model).__name__} must implement map_context_window_to_modalities for multimodal context windows.")
|
||||
modality_windows = {}
|
||||
for mod_idx in range(1, len(latents)):
|
||||
modality_windows[mod_idx] = IndexListContextWindow(
|
||||
per_modality_indices[mod_idx], dim=self.dim,
|
||||
total_frames=latents[mod_idx].shape[self.dim])
|
||||
window = IndexListContextWindow(
|
||||
window.index_list, dim=self.dim, total_frames=primary_frames.shape[self.dim],
|
||||
modality_windows=modality_windows)
|
||||
|
||||
# Build per-modality windows list
|
||||
per_modality_windows_list = [window]
|
||||
if is_multimodal:
|
||||
for mod_idx in range(1, len(latents)):
|
||||
per_modality_windows_list.append(modality_windows[mod_idx])
|
||||
|
||||
# Slice video, then inject overlapping guide frames if present
|
||||
sliced_video = per_modality_windows_list[0].get_tensor(primary_frames, retain_index_list=self.cond_retain_index_list)
|
||||
if window_data.aux_data is not None:
|
||||
sliced_primary, num_guide_frames = inject_guide_frames_into_window(sliced_video, window, window_data, self.dim)
|
||||
else:
|
||||
sliced_primary, num_guide_frames = sliced_video, 0
|
||||
logging.info(f"Context window {window_idx + 1}/{total_windows}: frames {window.index_list[0]}-{window.index_list[-1]} of {primary_frames.shape[self.dim]}"
|
||||
+ (f" (+{num_guide_frames} guide frames)" if num_guide_frames > 0 else "")
|
||||
+ (f" [{len(latents)} modalities]" if is_multimodal else ""))
|
||||
sliced = [sliced_primary] + [per_modality_windows_list[mi].get_tensor(latents[mi]) for mi in range(1, len(latents))]
|
||||
|
||||
sub_x, sub_shapes = self._pack(sliced)
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, None, None)
|
||||
|
||||
model_options["transformer_options"]["context_window"] = window
|
||||
sub_timestep = window.get_tensor(timestep, dim=0)
|
||||
sub_conds = [self.get_resized_cond(cond, primary_frames, window) for cond in conds]
|
||||
if is_multimodal:
|
||||
self._patch_latent_shapes(sub_conds, sub_shapes)
|
||||
|
||||
sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options)
|
||||
|
||||
# Unpack output per modality
|
||||
out_per_modality = [self._unpack(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))]
|
||||
|
||||
# Strip guide frames from primary output before accumulation
|
||||
if num_guide_frames > 0:
|
||||
window_len = len(window.index_list)
|
||||
for ci in range(len(sub_conds_out)):
|
||||
out_per_modality[ci][0] = out_per_modality[ci][0].narrow(self.dim, 0, window_len)
|
||||
|
||||
# Accumulate per modality
|
||||
for mod_idx in range(len(accum_shape_refs)):
|
||||
mw = per_modality_windows_list[mod_idx]
|
||||
sub_conds_out_per_modality = [out_per_modality[ci][mod_idx] for ci in range(len(sub_conds_out))]
|
||||
self.combine_context_window_results(
|
||||
accum_shape_refs[mod_idx], sub_conds_out_per_modality, sub_conds, mw,
|
||||
window_idx, total_windows, timestep,
|
||||
accum[mod_idx], counts[mod_idx], biases[mod_idx])
|
||||
# accumulate results from each context window
|
||||
for enum_window in enumerated_context_windows:
|
||||
results = self.evaluate_context_windows(
|
||||
calc_cond_batch, model, x_in, conds, timestep, [enum_window],
|
||||
model_options, window_state=window_state, total_windows=total_windows)
|
||||
for result in results:
|
||||
# result.sub_conds_out is per-cond, per-modality: list[list[Tensor]]
|
||||
for mod_idx in range(num_modalities):
|
||||
mod_out = [result.sub_conds_out[ci][mod_idx] for ci in range(len(conds))]
|
||||
modality_window = result.window.get_window_for_modality(mod_idx)
|
||||
self.combine_context_window_results(
|
||||
window_state.latents[mod_idx], mod_out, result.sub_conds, modality_window,
|
||||
result.window_idx, total_windows, timestep,
|
||||
accum[mod_idx], counts[mod_idx], biases[mod_idx])
|
||||
|
||||
# fuse accumulated results into final conds
|
||||
try:
|
||||
result = []
|
||||
result_out = []
|
||||
for ci in range(len(conds)):
|
||||
finalized = []
|
||||
for mod_idx in range(len(accum_shape_refs)):
|
||||
for mod_idx in range(num_modalities):
|
||||
if self.fuse_method.name != ContextFuseMethods.RELATIVE:
|
||||
accum[mod_idx][ci] /= counts[mod_idx][ci]
|
||||
f = accum[mod_idx][ci]
|
||||
if mod_idx == 0 and window_data.guide_frames is not None:
|
||||
f = torch.cat([f, window_data.guide_frames], dim=self.dim)
|
||||
|
||||
# if guide frames were injected, append them to the end of the fused latents for the next step
|
||||
if window_state.guide_latents[mod_idx] is not None:
|
||||
f = torch.cat([f, window_state.guide_latents[mod_idx]], dim=self.dim)
|
||||
finalized.append(f)
|
||||
packed, _ = self._pack(finalized)
|
||||
result.append(packed)
|
||||
return result
|
||||
|
||||
# pack modalities together if needed
|
||||
if window_state.is_multimodal and len(finalized) > 1:
|
||||
packed, _ = comfy.utils.pack_latents(finalized)
|
||||
else:
|
||||
packed = finalized[0]
|
||||
|
||||
result_out.append(packed)
|
||||
return result_out
|
||||
finally:
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options)
|
||||
|
||||
def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
|
||||
model_options, device=None, first_device=None):
|
||||
def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds,
|
||||
timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
|
||||
model_options, window_state: WindowingState, total_windows: int = None,
|
||||
device=None, first_device=None):
|
||||
"""Evaluate context windows and return per-cond, per-modality outputs in ContextResults.sub_conds_out
|
||||
|
||||
For each window:
|
||||
1. Builds windows (for each modality if multimodal)
|
||||
2. Slices window for each modality
|
||||
3. Injects concatenated latent guide frames where present
|
||||
4. Packs together if needed and calls model
|
||||
5. Unpacks and strips any guides from outputs
|
||||
"""
|
||||
x = window_state.latents[0]
|
||||
|
||||
results: list[ContextResults] = []
|
||||
for window_idx, window in enumerated_context_windows:
|
||||
# allow processing to end between context window executions for faster Cancel
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
|
||||
# prepare the window accounting for multimodal windows
|
||||
window = window_state.prepare_window(window, model)
|
||||
|
||||
# slice the window for each modality, injecting guide frames where applicable
|
||||
sliced, guide_frame_counts_per_modality = window_state.slice_for_window(window, self.cond_retain_index_list, device)
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
|
||||
|
||||
# update exposed params
|
||||
model_options["transformer_options"]["context_window"] = window
|
||||
# get subsections of x, timestep, conds
|
||||
sub_x = window.get_tensor(x_in, device)
|
||||
sub_timestep = window.get_tensor(timestep, device, dim=0)
|
||||
sub_conds = [self.get_resized_cond(cond, x_in, window, device) for cond in conds]
|
||||
logging.info(f"Context window {window_idx + 1}/{total_windows or len(enumerated_context_windows)}: frames {window.index_list[0]}-{window.index_list[-1]} of {x.shape[self.dim]}"
|
||||
+ (f" (+{guide_frame_counts_per_modality[0]} guide frames)" if guide_frame_counts_per_modality[0] > 0 else "")
|
||||
)
|
||||
|
||||
# if multimodal, pack modalities together
|
||||
if window_state.is_multimodal and len(sliced) > 1:
|
||||
sub_x, sub_shapes = comfy.utils.pack_latents(sliced)
|
||||
else:
|
||||
sub_x, sub_shapes = sliced[0], [sliced[0].shape]
|
||||
|
||||
# get resized conds for window
|
||||
model_options["transformer_options"]["context_window"] = window
|
||||
sub_timestep = window.get_tensor(timestep, dim=0)
|
||||
sub_conds = [self.get_resized_cond(cond, x, window) for cond in conds]
|
||||
|
||||
# if multimodal, patch latent_shapes in conds for correct unpacking in model
|
||||
window_state.patch_latent_shapes(sub_conds, sub_shapes)
|
||||
|
||||
# call model on window
|
||||
sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options)
|
||||
if device is not None:
|
||||
for i in range(len(sub_conds_out)):
|
||||
sub_conds_out[i] = sub_conds_out[i].to(x_in.device)
|
||||
results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
|
||||
|
||||
# unpack outputs and strip guide frames
|
||||
out_per_modality = [comfy.utils.unpack_latents(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))]
|
||||
window_state.strip_guide_frames(out_per_modality, guide_frame_counts_per_modality, window)
|
||||
|
||||
results.append(ContextResults(window_idx, out_per_modality, sub_conds, window))
|
||||
return results
|
||||
|
||||
|
||||
@ -684,28 +689,11 @@ def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, nois
|
||||
if not handler.freenoise:
|
||||
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
||||
|
||||
# For packed multimodal tensors (e.g. LTXAV), noise is [B, 1, flat] and FreeNoise
|
||||
# must only shuffle the video portion. Unpack, apply to video, repack.
|
||||
latent_shapes = IndexListContextHandler._get_latent_shapes(
|
||||
[guider.conds.get('positive', guider.conds.get('negative', []))])
|
||||
|
||||
if latent_shapes is not None and len(latent_shapes) > 1:
|
||||
modalities = comfy.utils.unpack_latents(noise, latent_shapes)
|
||||
video_total = latent_shapes[0][handler.dim]
|
||||
modalities[0] = apply_freenoise(modalities[0], handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
|
||||
for i in range(1, len(modalities)):
|
||||
mod_total = latent_shapes[i][handler.dim]
|
||||
ratio = mod_total / video_total if video_total > 0 else 1
|
||||
mod_ctx_len = max(round(handler.context_length * ratio), 1)
|
||||
mod_ctx_overlap = max(round(handler.context_overlap * ratio), 0)
|
||||
modalities[i] = apply_freenoise(modalities[i], handler.dim, mod_ctx_len, mod_ctx_overlap, extra_args["seed"])
|
||||
noise, _ = comfy.utils.pack_latents(modalities)
|
||||
else:
|
||||
noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
|
||||
conds = [guider.conds.get('positive', guider.conds.get('negative', []))]
|
||||
noise = handler._apply_freenoise(noise, conds, extra_args["seed"])
|
||||
|
||||
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
||||
|
||||
|
||||
def create_sampler_sample_wrapper(model: ModelPatcher):
|
||||
model.add_wrapper_with_key(
|
||||
comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
|
||||
@ -713,7 +701,6 @@ def create_sampler_sample_wrapper(model: ModelPatcher):
|
||||
_sampler_sample_wrapper
|
||||
)
|
||||
|
||||
|
||||
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
|
||||
total_dims = len(x_in.shape)
|
||||
weights_tensor = torch.Tensor(weights).to(device=device)
|
||||
|
||||
@ -1162,7 +1162,7 @@ class LTXAV(BaseModel):
|
||||
# Adjust spatial end positions for dilated (downscaled) guides.
|
||||
# Each guide entry may have a different downscale factor; expand the
|
||||
# per-entry factor to cover all tokens belonging to that entry.
|
||||
downscale_factors = getattr(window, 'guide_downscale_factors', [])
|
||||
downscale_factors = window.guide_downscale_factors
|
||||
overlap_info = window.guide_overlap_info
|
||||
if downscale_factors:
|
||||
per_token_factor = []
|
||||
|
||||
Loading…
Reference in New Issue
Block a user