mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
Test implementation for LTX2 context windows
This commit is contained in:
parent
4b1444fc7a
commit
5bfe660b7c
@ -8,6 +8,8 @@ from abc import ABC, abstractmethod
|
||||
import logging
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
import comfy.utils
|
||||
import comfy.conds
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
@ -51,12 +53,13 @@ class ContextHandlerABC(ABC):
|
||||
|
||||
|
||||
class IndexListContextWindow(ContextWindowABC):
    """A context window defined by an explicit list of frame indexes.

    When sampling multimodal latents (e.g. LTXAV video+audio), the primary
    (video) window may carry per-modality sub-windows in modality_windows.
    """

    def __init__(self, index_list: list[int], dim: int = 0, total_frames: int = 0, modality_windows: dict = None):
        """
        index_list: frame indexes covered by this window.
        dim: tensor dimension the indexes apply to.
        total_frames: total frame count of the full tensor along dim.
        modality_windows: optional dict of {mod_idx: IndexListContextWindow}
            mapping secondary modalities to their own windows.
        """
        self.index_list = index_list
        self.context_length = len(index_list)
        self.dim = dim
        self.total_frames = total_frames
        # Guard: the default total_frames=0 (and an empty index_list) would
        # otherwise raise on division / min-max; fall back to 0.0.
        if total_frames and index_list:
            self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
        else:
            self.center_ratio = 0.0
        self.modality_windows = modality_windows
||||
def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
|
||||
if dim is None:
|
||||
@ -165,10 +168,44 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
|
||||
self.callbacks = {}
|
||||
|
||||
def _get_latent_shapes(self, conds):
|
||||
"""Extract latent_shapes from conditioning. Returns None if absent."""
|
||||
for cond_list in conds:
|
||||
if cond_list is None:
|
||||
continue
|
||||
for cond_dict in cond_list:
|
||||
model_conds = cond_dict.get('model_conds', {})
|
||||
if 'latent_shapes' in model_conds:
|
||||
return model_conds['latent_shapes'].cond
|
||||
return None
|
||||
|
||||
def _decompose(self, x, latent_shapes):
|
||||
"""Packed tensor -> list of per-modality tensors."""
|
||||
if latent_shapes is not None and len(latent_shapes) > 1:
|
||||
return comfy.utils.unpack_latents(x, latent_shapes)
|
||||
return [x]
|
||||
|
||||
def _compose(self, modalities):
|
||||
"""List of per-modality tensors -> single tensor for pipeline."""
|
||||
if len(modalities) > 1:
|
||||
return comfy.utils.pack_latents(modalities)
|
||||
return modalities[0], [modalities[0].shape]
|
||||
|
||||
def _patch_latent_shapes(self, sub_conds, new_shapes):
|
||||
"""Patch latent_shapes CONDConstant in (already-copied) sub_conds."""
|
||||
for cond_list in sub_conds:
|
||||
if cond_list is None:
|
||||
continue
|
||||
for cond_dict in cond_list:
|
||||
model_conds = cond_dict.get('model_conds', {})
|
||||
if 'latent_shapes' in model_conds:
|
||||
model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes)
|
||||
|
||||
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
||||
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
|
||||
if x_in.size(self.dim) > self.context_length:
|
||||
logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.")
|
||||
latent_shapes = self._get_latent_shapes(conds)
|
||||
primary = self._decompose(x_in, latent_shapes)[0]
|
||||
if primary.size(self.dim) > self.context_length:
|
||||
logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {primary.size(self.dim)} frames.")
|
||||
if self.cond_retain_index_list:
|
||||
logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
|
||||
return True
|
||||
@ -277,36 +314,98 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
    """Run sampling over context windows and fuse the per-window results.

    x_in may be a packed multimodal latent (e.g. LTXAV video+audio): it is
    decomposed into per-modality tensors, each window is sliced per
    modality, the model is called on the re-packed slice, and outputs are
    accumulated and normalized per modality before being re-packed.
    """
    self._model = model
    self.set_step(timestep, model_options)

    # Decompose — single-modality: [x_in], multimodal: [video, audio, ...]
    latent_shapes = self._get_latent_shapes(conds)
    modalities = self._decompose(x_in, latent_shapes)
    is_multimodal = len(modalities) > 1
    primary = modalities[0]

    # Windows are derived from the primary modality's temporal dim.
    context_windows = self.get_context_windows(model, primary, model_options)
    enumerated_context_windows = list(enumerate(context_windows))
    total_windows = len(enumerated_context_windows)

    # Per-modality accumulators: accum[mod_idx][cond_idx]
    accum = [[torch.zeros_like(m) for _ in conds] for m in modalities]
    if self.fuse_method.name == ContextFuseMethods.RELATIVE:
        # RELATIVE fuse is normalized as it goes, so counts start at one.
        counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in modalities]
    else:
        counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in modalities]
    # biases must exist for every fuse method: combine_context_window_results
    # receives it unconditionally below.
    biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in modalities]

    for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
        callback(self, model, x_in, conds, timestep, model_options)

    for window_idx, window in enumerated_context_windows:
        comfy.model_management.throw_exception_if_processing_interrupted()

        # Per-modality window indices.
        if is_multimodal:
            per_mod_indices = model.map_context_window_to_modalities(
                window.index_list, latent_shapes, self.dim)
            # Build per-modality windows and attach them to the primary window.
            modality_windows = {}
            for mod_idx in range(1, len(modalities)):
                modality_windows[mod_idx] = IndexListContextWindow(
                    per_mod_indices[mod_idx], dim=self.dim,
                    total_frames=modalities[mod_idx].shape[self.dim])
            window = IndexListContextWindow(
                window.index_list, dim=self.dim, total_frames=primary.shape[self.dim],
                modality_windows=modality_windows)
        else:
            per_mod_indices = [window.index_list]
            modality_windows = {}

        # Per-modality windows list (primary window at index 0).
        mod_windows = [window]
        for mod_idx in range(1, len(modalities)):
            mod_windows.append(modality_windows[mod_idx])

        # Slice each modality with its own window.
        sliced = [mod_windows[mi].get_tensor(modalities[mi]) for mi in range(len(modalities))]

        # Compose the slices into one tensor for the pipeline.
        sub_x, sub_shapes = self._compose(sliced)

        for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
            # NOTE(review): model_options appears twice in this call — looks
            # like a leftover; confirm the expected callback signature.
            callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, None, None)

        model_options["transformer_options"]["context_window"] = window
        sub_timestep = window.get_tensor(timestep, dim=0)
        # Resize conds using the primary tensor as reference (correct temporal dim).
        sub_conds = [self.get_resized_cond(cond, primary, window) for cond in conds]
        if is_multimodal:
            self._patch_latent_shapes(sub_conds, sub_shapes)

        sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options)

        # Decompose output per modality: out_per_mod[cond_idx][mod_idx]
        out_per_mod = [self._decompose(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))]

        # Accumulate each modality with its own window.
        for mod_idx in range(len(modalities)):
            mw = mod_windows[mod_idx]
            mod_sub_out = [out_per_mod[ci][mod_idx] for ci in range(len(sub_conds_out))]
            self.combine_context_window_results(
                modalities[mod_idx], mod_sub_out, sub_conds, mw,
                window_idx, total_windows, timestep,
                accum[mod_idx], counts[mod_idx], biases[mod_idx])

    try:
        # Finalize: normalize by usage counts (RELATIVE is already
        # normalized) and re-pack each cond's modalities.
        result = []
        for ci in range(len(conds)):
            finalized = []
            for mod_idx in range(len(modalities)):
                if self.fuse_method.name != ContextFuseMethods.RELATIVE:
                    accum[mod_idx][ci] /= counts[mod_idx][ci]
                finalized.append(accum[mod_idx][ci])
            composed, _ = self._compose(finalized)
            result.append(composed)
        return result
    finally:
        for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
            callback(self, model, x_in, conds, timestep, model_options)
|
||||
@ -374,7 +473,10 @@ def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args,
|
||||
handler: IndexListContextHandler = model_options.get("context_handler", None)
|
||||
if handler is not None:
|
||||
noise_shape = list(noise_shape)
|
||||
noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length)
|
||||
# Guard: only clamp when dim is within bounds and the value is meaningful
|
||||
# (packed multimodal tensors have noise_shape=[B,1,flat] where flat is not frame count)
|
||||
if handler.dim < len(noise_shape) and noise_shape[handler.dim] > handler.context_length:
|
||||
noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length)
|
||||
return executor(model, noise_shape, *args, **kwargs)
|
||||
|
||||
|
||||
|
||||
@ -293,6 +293,11 @@ class BaseModel(torch.nn.Module):
|
||||
Use comfy.context_windows.slice_cond() for common cases."""
|
||||
return None
|
||||
|
||||
def map_context_window_to_modalities(self, primary_indices, latent_shapes, dim):
    """Map the primary modality's window indices to all modalities.

    Base implementation: a single modality, so the primary indices are
    returned as the only entry. Multimodal models override this to return
    one index list per modality.
    """
    return [primary_indices]
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
concat_cond = self.concat_cond(**kwargs)
|
||||
@ -1082,6 +1087,34 @@ class LTXAV(BaseModel):
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
    """Return the latent image unchanged — no inpaint noise scaling here."""
    return latent_image
|
||||
|
||||
def map_context_window_to_modalities(self, primary_indices, latent_shapes, dim):
    """Map primary (video) window indices to one index list per modality.

    Video and audio latents cover the same real-time duration, so audio
    indices are derived by proportional scaling of the video span. A
    degenerate (empty) audio range falls back to a single clamped frame.
    """
    mapped = [primary_indices]
    if len(latent_shapes) >= 2:
        frames_video = latent_shapes[0][dim]
        frames_audio = latent_shapes[1][dim]
        lo = min(primary_indices)
        hi = max(primary_indices) + 1
        # Proportional mapping onto the audio timeline.
        start = round(lo * frames_audio / frames_video)
        stop = round(hi * frames_audio / frames_video)
        audio_idx = list(range(start, min(stop, frames_audio)))
        if not audio_idx:
            audio_idx = [min(start, frames_audio - 1)]
        mapped.append(audio_idx)
    return mapped
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
    """Slice audio_denoise_mask with the audio modality's window.

    Returns the sliced cond only for "audio_denoise_mask" when the window
    carries per-modality sub-windows (multimodal run); returns None for
    every other key so the generic handler takes over.
    """
    if cond_key != "audio_denoise_mask":
        return None
    mod_windows = getattr(window, 'modality_windows', None)
    if not mod_windows:
        return None
    audio_window = mod_windows.get(1)
    if audio_window is None:
        return None
    import comfy.context_windows
    # temporal_dim=2 — presumably the audio latent is [B, C, T, ...]; confirm.
    return comfy.context_windows.slice_cond(
        cond_value, audio_window, x_in, device, temporal_dim=2)
|
||||
|
||||
class HunyuanVideo(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user