mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 13:02:35 +08:00
LTX2 context windows - Refactor guide logic from context_windows into LTXAV model hooks
This commit is contained in:
parent
350237618d
commit
874690c01c
@ -1,5 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from typing import TYPE_CHECKING, Callable
|
from typing import TYPE_CHECKING, Any, Callable
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import collections
|
import collections
|
||||||
@ -181,6 +181,12 @@ def _compute_guide_overlap(guide_entries, window_index_list):
|
|||||||
return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices)
|
return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class WindowingContext:
|
||||||
|
tensor: torch.Tensor
|
||||||
|
suffix: torch.Tensor | None
|
||||||
|
aux_data: Any
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ContextSchedule:
|
class ContextSchedule:
|
||||||
name: str
|
name: str
|
||||||
@ -242,18 +248,6 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
if 'latent_shapes' in model_conds:
|
if 'latent_shapes' in model_conds:
|
||||||
model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes)
|
model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes)
|
||||||
|
|
||||||
def _get_guide_entries(self, conds):
|
|
||||||
"""Extract guide_attention_entries list from conditioning. Returns None if absent."""
|
|
||||||
for cond_list in conds:
|
|
||||||
if cond_list is None:
|
|
||||||
continue
|
|
||||||
for cond_dict in cond_list:
|
|
||||||
model_conds = cond_dict.get('model_conds', {})
|
|
||||||
gae = model_conds.get('guide_attention_entries')
|
|
||||||
if gae is not None and hasattr(gae, 'cond') and gae.cond:
|
|
||||||
return gae.cond
|
|
||||||
return None
|
|
||||||
|
|
||||||
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
||||||
latent_shapes = self._get_latent_shapes(conds)
|
latent_shapes = self._get_latent_shapes(conds)
|
||||||
primary = self._decompose(x_in, latent_shapes)[0]
|
primary = self._decompose(x_in, latent_shapes)[0]
|
||||||
@ -379,24 +373,19 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
is_multimodal = len(modalities) > 1
|
is_multimodal = len(modalities) > 1
|
||||||
primary = modalities[0]
|
primary = modalities[0]
|
||||||
|
|
||||||
# Separate guide frames from primary modality (guides are appended at the end)
|
# Let model strip auxiliary frames (e.g. guide frames)
|
||||||
guide_count = model.get_guide_frame_count(primary, conds) if hasattr(model, 'get_guide_frame_count') else 0
|
window_data = model.prepare_for_windowing(primary, conds, self.dim)
|
||||||
if guide_count > 0:
|
video_primary = window_data.tensor
|
||||||
video_len = primary.size(self.dim) - guide_count
|
aux_count = window_data.suffix.size(self.dim) if window_data.suffix is not None else 0
|
||||||
video_primary = primary.narrow(self.dim, 0, video_len)
|
|
||||||
guide_suffix = primary.narrow(self.dim, video_len, guide_count)
|
|
||||||
else:
|
|
||||||
video_primary = primary
|
|
||||||
guide_suffix = None
|
|
||||||
|
|
||||||
# Windows from video portion only (excluding guide frames)
|
# Windows from video portion only
|
||||||
context_windows = self.get_context_windows(model, video_primary, model_options)
|
context_windows = self.get_context_windows(model, video_primary, model_options)
|
||||||
enumerated_context_windows = list(enumerate(context_windows))
|
enumerated_context_windows = list(enumerate(context_windows))
|
||||||
total_windows = len(enumerated_context_windows)
|
total_windows = len(enumerated_context_windows)
|
||||||
|
|
||||||
# Accumulators sized to video portion for primary, full for other modalities
|
# Accumulators sized to video portion for primary, full for other modalities
|
||||||
accum_modalities = list(modalities)
|
accum_modalities = list(modalities)
|
||||||
if guide_suffix is not None:
|
if window_data.suffix is not None:
|
||||||
accum_modalities[0] = video_primary
|
accum_modalities[0] = video_primary
|
||||||
|
|
||||||
accum = [[torch.zeros_like(m) for _ in conds] for m in accum_modalities]
|
accum = [[torch.zeros_like(m) for _ in conds] for m in accum_modalities]
|
||||||
@ -406,25 +395,22 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_modalities]
|
counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_modalities]
|
||||||
biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in accum_modalities]
|
biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in accum_modalities]
|
||||||
|
|
||||||
guide_entries = self._get_guide_entries(conds) if guide_count > 0 else None
|
|
||||||
|
|
||||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
|
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
|
||||||
callback(self, model, x_in, conds, timestep, model_options)
|
callback(self, model, x_in, conds, timestep, model_options)
|
||||||
|
|
||||||
for window_idx, window in enumerated_context_windows:
|
for window_idx, window in enumerated_context_windows:
|
||||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||||
logging.info(f"Context window {window_idx + 1}/{total_windows}: frames {window.index_list[0]}-{window.index_list[-1]} of {video_primary.shape[self.dim]}"
|
logging.info(f"Context window {window_idx + 1}/{total_windows}: frames {window.index_list[0]}-{window.index_list[-1]} of {video_primary.shape[self.dim]}"
|
||||||
+ (f" (+{guide_count} guide)" if guide_count > 0 else "")
|
+ (f" (+{aux_count} aux)" if aux_count > 0 else "")
|
||||||
+ (f" [{len(modalities)} modalities]" if is_multimodal else ""))
|
+ (f" [{len(modalities)} modalities]" if is_multimodal else ""))
|
||||||
|
|
||||||
# Per-modality window indices
|
# Per-modality window indices
|
||||||
if is_multimodal:
|
if is_multimodal:
|
||||||
# Adjust latent_shapes so video shape reflects video-only frames (excludes guides)
|
|
||||||
map_shapes = latent_shapes
|
map_shapes = latent_shapes
|
||||||
if guide_count > 0:
|
if video_primary.size(self.dim) != primary.size(self.dim):
|
||||||
map_shapes = list(latent_shapes)
|
map_shapes = list(latent_shapes)
|
||||||
video_shape = list(latent_shapes[0])
|
video_shape = list(latent_shapes[0])
|
||||||
video_shape[self.dim] = video_shape[self.dim] - guide_count
|
video_shape[self.dim] = video_primary.size(self.dim)
|
||||||
map_shapes[0] = torch.Size(video_shape)
|
map_shapes[0] = torch.Size(video_shape)
|
||||||
per_mod_indices = model.map_context_window_to_modalities(
|
per_mod_indices = model.map_context_window_to_modalities(
|
||||||
window.index_list, map_shapes, self.dim) if hasattr(model, 'map_context_window_to_modalities') else [window.index_list]
|
window.index_list, map_shapes, self.dim) if hasattr(model, 'map_context_window_to_modalities') else [window.index_list]
|
||||||
@ -446,30 +432,10 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
for mod_idx in range(1, len(modalities)):
|
for mod_idx in range(1, len(modalities)):
|
||||||
mod_windows.append(modality_windows[mod_idx])
|
mod_windows.append(modality_windows[mod_idx])
|
||||||
|
|
||||||
# Slice video, then select overlapping guide frames
|
# Slice video, then let model inject auxiliary frames
|
||||||
sliced_video = mod_windows[0].get_tensor(video_primary, retain_index_list=self.cond_retain_index_list)
|
sliced_video = mod_windows[0].get_tensor(video_primary, retain_index_list=self.cond_retain_index_list)
|
||||||
num_guide_in_window = 0
|
sliced_primary, num_aux = model.prepare_window_input(
|
||||||
if guide_suffix is not None and guide_entries is not None:
|
sliced_video, window, window_data.aux_data, self.dim)
|
||||||
overlap = _compute_guide_overlap(guide_entries, window.index_list)
|
|
||||||
if overlap[3] > 0:
|
|
||||||
suffix_idx, overlap_info, kf_local_pos, num_guide_in_window = overlap
|
|
||||||
idx = tuple([slice(None)] * self.dim + [suffix_idx])
|
|
||||||
sliced_guide = guide_suffix[idx]
|
|
||||||
window.guide_suffix_indices = suffix_idx
|
|
||||||
window.guide_overlap_info = overlap_info
|
|
||||||
window.guide_kf_local_positions = kf_local_pos
|
|
||||||
else:
|
|
||||||
sliced_guide = None
|
|
||||||
window.guide_suffix_indices = []
|
|
||||||
window.guide_overlap_info = []
|
|
||||||
window.guide_kf_local_positions = []
|
|
||||||
else:
|
|
||||||
sliced_guide = None
|
|
||||||
|
|
||||||
if sliced_guide is not None:
|
|
||||||
sliced_primary = torch.cat([sliced_video, sliced_guide], dim=self.dim)
|
|
||||||
else:
|
|
||||||
sliced_primary = sliced_video
|
|
||||||
sliced = [sliced_primary] + [mod_windows[mi].get_tensor(modalities[mi]) for mi in range(1, len(modalities))]
|
sliced = [sliced_primary] + [mod_windows[mi].get_tensor(modalities[mi]) for mi in range(1, len(modalities))]
|
||||||
|
|
||||||
# Compose for pipeline
|
# Compose for pipeline
|
||||||
@ -481,7 +447,6 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
|
|
||||||
model_options["transformer_options"]["context_window"] = window
|
model_options["transformer_options"]["context_window"] = window
|
||||||
sub_timestep = window.get_tensor(timestep, dim=0)
|
sub_timestep = window.get_tensor(timestep, dim=0)
|
||||||
# Resize conds using video_primary as reference (excludes guide frames)
|
|
||||||
sub_conds = [self.get_resized_cond(cond, video_primary, window) for cond in conds]
|
sub_conds = [self.get_resized_cond(cond, video_primary, window) for cond in conds]
|
||||||
if is_multimodal:
|
if is_multimodal:
|
||||||
self._patch_latent_shapes(sub_conds, sub_shapes)
|
self._patch_latent_shapes(sub_conds, sub_shapes)
|
||||||
@ -490,14 +455,12 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
|
|
||||||
# Decompose output per modality
|
# Decompose output per modality
|
||||||
out_per_mod = [self._decompose(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))]
|
out_per_mod = [self._decompose(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))]
|
||||||
# out_per_mod[cond_idx][mod_idx] = tensor
|
|
||||||
|
|
||||||
# Strip guide frames from primary output before accumulation
|
# Strip auxiliary frames from primary output before accumulation
|
||||||
if num_guide_in_window > 0:
|
if num_aux > 0:
|
||||||
window_len = len(window.index_list)
|
window_len = len(window.index_list)
|
||||||
for ci in range(len(sub_conds_out)):
|
for ci in range(len(sub_conds_out)):
|
||||||
primary_out = out_per_mod[ci][0]
|
out_per_mod[ci][0] = out_per_mod[ci][0].narrow(self.dim, 0, window_len)
|
||||||
out_per_mod[ci][0] = primary_out.narrow(self.dim, 0, window_len)
|
|
||||||
|
|
||||||
# Accumulate per modality (using video-only sizes)
|
# Accumulate per modality (using video-only sizes)
|
||||||
for mod_idx in range(len(accum_modalities)):
|
for mod_idx in range(len(accum_modalities)):
|
||||||
@ -516,10 +479,9 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
if self.fuse_method.name != ContextFuseMethods.RELATIVE:
|
if self.fuse_method.name != ContextFuseMethods.RELATIVE:
|
||||||
accum[mod_idx][ci] /= counts[mod_idx][ci]
|
accum[mod_idx][ci] /= counts[mod_idx][ci]
|
||||||
f = accum[mod_idx][ci]
|
f = accum[mod_idx][ci]
|
||||||
# Re-append original guide_suffix (not model output — sampling loop
|
# Re-append model's suffix (auxiliary frames stripped before windowing)
|
||||||
# respects denoise_mask and never modifies guide frame positions)
|
if mod_idx == 0 and window_data.suffix is not None:
|
||||||
if mod_idx == 0 and guide_suffix is not None:
|
f = torch.cat([f, window_data.suffix], dim=self.dim)
|
||||||
f = torch.cat([f, guide_suffix], dim=self.dim)
|
|
||||||
finalized.append(f)
|
finalized.append(f)
|
||||||
composed, _ = self._compose(finalized)
|
composed, _ = self._compose(finalized)
|
||||||
result.append(composed)
|
result.append(composed)
|
||||||
|
|||||||
@ -287,6 +287,12 @@ class BaseModel(torch.nn.Module):
|
|||||||
return data
|
return data
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def prepare_for_windowing(self, primary, conds, dim):
|
||||||
|
return comfy.context_windows.WindowingContext(tensor=primary, suffix=None, aux_data=None)
|
||||||
|
|
||||||
|
def prepare_window_input(self, video_slice, window, aux_data, dim):
|
||||||
|
return video_slice, 0
|
||||||
|
|
||||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||||
"""Override in subclasses to handle model-specific cond slicing for context windows.
|
"""Override in subclasses to handle model-specific cond slicing for context windows.
|
||||||
Return a sliced cond object, or None to fall through to default handling.
|
Return a sliced cond object, or None to fall through to default handling.
|
||||||
@ -1113,6 +1119,51 @@ class LTXAV(BaseModel):
|
|||||||
return sum(e["latent_shape"][0] for e in gae.cond)
|
return sum(e["latent_shape"][0] for e in gae.cond)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_guide_entries(conds):
|
||||||
|
for cond_list in conds:
|
||||||
|
if cond_list is None:
|
||||||
|
continue
|
||||||
|
for cond_dict in cond_list:
|
||||||
|
model_conds = cond_dict.get('model_conds', {})
|
||||||
|
gae = model_conds.get('guide_attention_entries')
|
||||||
|
if gae is not None and hasattr(gae, 'cond') and gae.cond:
|
||||||
|
return gae.cond
|
||||||
|
return None
|
||||||
|
|
||||||
|
def prepare_for_windowing(self, primary, conds, dim):
|
||||||
|
guide_count = self.get_guide_frame_count(primary, conds)
|
||||||
|
if guide_count <= 0:
|
||||||
|
return comfy.context_windows.WindowingContext(tensor=primary, suffix=None, aux_data=None)
|
||||||
|
video_len = primary.size(dim) - guide_count
|
||||||
|
video_primary = primary.narrow(dim, 0, video_len)
|
||||||
|
guide_suffix = primary.narrow(dim, video_len, guide_count)
|
||||||
|
guide_entries = self._get_guide_entries(conds)
|
||||||
|
return comfy.context_windows.WindowingContext(
|
||||||
|
tensor=video_primary, suffix=guide_suffix,
|
||||||
|
aux_data={"guide_entries": guide_entries, "guide_suffix": guide_suffix})
|
||||||
|
|
||||||
|
def prepare_window_input(self, video_slice, window, aux_data, dim):
|
||||||
|
if aux_data is None:
|
||||||
|
return video_slice, 0
|
||||||
|
guide_entries = aux_data["guide_entries"]
|
||||||
|
guide_suffix = aux_data["guide_suffix"]
|
||||||
|
if guide_entries is None:
|
||||||
|
window.guide_suffix_indices = []
|
||||||
|
window.guide_overlap_info = []
|
||||||
|
window.guide_kf_local_positions = []
|
||||||
|
return video_slice, 0
|
||||||
|
overlap = comfy.context_windows._compute_guide_overlap(guide_entries, window.index_list)
|
||||||
|
suffix_idx, overlap_info, kf_local_pos, num_guide = overlap
|
||||||
|
window.guide_suffix_indices = suffix_idx
|
||||||
|
window.guide_overlap_info = overlap_info
|
||||||
|
window.guide_kf_local_positions = kf_local_pos
|
||||||
|
if num_guide > 0:
|
||||||
|
idx = tuple([slice(None)] * dim + [suffix_idx])
|
||||||
|
sliced_guide = guide_suffix[idx]
|
||||||
|
return torch.cat([video_slice, sliced_guide], dim=dim), num_guide
|
||||||
|
return video_slice, 0
|
||||||
|
|
||||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||||
# Audio denoise mask — slice using audio modality window
|
# Audio denoise mask — slice using audio modality window
|
||||||
if cond_key == "audio_denoise_mask" and hasattr(window, 'modality_windows') and window.modality_windows:
|
if cond_key == "audio_denoise_mask" and hasattr(window, 'modality_windows') and window.modality_windows:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user