mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-14 20:42:31 +08:00
LTX2 context windows part 2 - Guide aware processing
This commit is contained in:
parent
5bfe660b7c
commit
56de390c25
@ -204,8 +204,13 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
||||
latent_shapes = self._get_latent_shapes(conds)
|
||||
primary = self._decompose(x_in, latent_shapes)[0]
|
||||
if primary.size(self.dim) > self.context_length:
|
||||
logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {primary.size(self.dim)} frames.")
|
||||
guide_count = model.get_guide_frame_count(primary, conds) if model is not None else 0
|
||||
video_frames = primary.size(self.dim) - guide_count
|
||||
if video_frames > self.context_length:
|
||||
if guide_count > 0:
|
||||
logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {video_frames} video frames ({guide_count} guide frames excluded).")
|
||||
else:
|
||||
logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {video_frames} frames.")
|
||||
if self.cond_retain_index_list:
|
||||
logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
|
||||
return True
|
||||
@ -321,18 +326,32 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
is_multimodal = len(modalities) > 1
|
||||
primary = modalities[0]
|
||||
|
||||
# Windows from primary modality's temporal dim
|
||||
context_windows = self.get_context_windows(model, primary, model_options)
|
||||
# Separate guide frames from primary modality (guides are appended at the end)
|
||||
guide_count = model.get_guide_frame_count(primary, conds) if model is not None else 0
|
||||
if guide_count > 0:
|
||||
video_len = primary.size(self.dim) - guide_count
|
||||
video_primary = primary.narrow(self.dim, 0, video_len)
|
||||
guide_suffix = primary.narrow(self.dim, video_len, guide_count)
|
||||
else:
|
||||
video_primary = primary
|
||||
guide_suffix = None
|
||||
|
||||
# Windows from video portion only (excluding guide frames)
|
||||
context_windows = self.get_context_windows(model, video_primary, model_options)
|
||||
enumerated_context_windows = list(enumerate(context_windows))
|
||||
total_windows = len(enumerated_context_windows)
|
||||
|
||||
# Per-modality accumulators: accum[mod_idx][cond_idx]
|
||||
accum = [[torch.zeros_like(m) for _ in conds] for m in modalities]
|
||||
# Accumulators sized to video portion for primary, full for other modalities
|
||||
accum_modalities = list(modalities)
|
||||
if guide_suffix is not None:
|
||||
accum_modalities[0] = video_primary
|
||||
|
||||
accum = [[torch.zeros_like(m) for _ in conds] for m in accum_modalities]
|
||||
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
|
||||
counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in modalities]
|
||||
counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_modalities]
|
||||
else:
|
||||
counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in modalities]
|
||||
biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in modalities]
|
||||
counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_modalities]
|
||||
biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in accum_modalities]
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options)
|
||||
@ -340,10 +359,22 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
for window_idx, window in enumerated_context_windows:
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
|
||||
# Attach guide info to window for resize_cond_for_context_window
|
||||
window.guide_count = guide_count
|
||||
if guide_suffix is not None:
|
||||
window.guide_spatial = (guide_suffix.shape[3], guide_suffix.shape[4])
|
||||
|
||||
# Per-modality window indices
|
||||
if is_multimodal:
|
||||
# Adjust latent_shapes so video shape reflects video-only frames (excludes guides)
|
||||
map_shapes = latent_shapes
|
||||
if guide_count > 0:
|
||||
map_shapes = list(latent_shapes)
|
||||
video_shape = list(latent_shapes[0])
|
||||
video_shape[self.dim] = video_shape[self.dim] - guide_count
|
||||
map_shapes[0] = torch.Size(video_shape)
|
||||
per_mod_indices = model.map_context_window_to_modalities(
|
||||
window.index_list, latent_shapes, self.dim)
|
||||
window.index_list, map_shapes, self.dim)
|
||||
# Build per-modality windows and attach to primary window
|
||||
modality_windows = {}
|
||||
for mod_idx in range(1, len(modalities)):
|
||||
@ -351,8 +382,11 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
per_mod_indices[mod_idx], dim=self.dim,
|
||||
total_frames=modalities[mod_idx].shape[self.dim])
|
||||
window = IndexListContextWindow(
|
||||
window.index_list, dim=self.dim, total_frames=primary.shape[self.dim],
|
||||
window.index_list, dim=self.dim, total_frames=video_primary.shape[self.dim],
|
||||
modality_windows=modality_windows)
|
||||
window.guide_count = guide_count
|
||||
if guide_suffix is not None:
|
||||
window.guide_spatial = (guide_suffix.shape[3], guide_suffix.shape[4])
|
||||
else:
|
||||
per_mod_indices = [window.index_list]
|
||||
|
||||
@ -362,8 +396,14 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
for mod_idx in range(1, len(modalities)):
|
||||
mod_windows.append(modality_windows[mod_idx])
|
||||
|
||||
# Slice each modality
|
||||
sliced = [mod_windows[mi].get_tensor(modalities[mi]) for mi in range(len(modalities))]
|
||||
# Slice video and guide with same window indices, concatenate
|
||||
sliced_video = mod_windows[0].get_tensor(video_primary)
|
||||
if guide_suffix is not None:
|
||||
sliced_guide = mod_windows[0].get_tensor(guide_suffix)
|
||||
sliced_primary = torch.cat([sliced_video, sliced_guide], dim=self.dim)
|
||||
else:
|
||||
sliced_primary = sliced_video
|
||||
sliced = [sliced_primary] + [mod_windows[mi].get_tensor(modalities[mi]) for mi in range(1, len(modalities))]
|
||||
|
||||
# Compose for pipeline
|
||||
sub_x, sub_shapes = self._compose(sliced)
|
||||
@ -374,8 +414,8 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
|
||||
model_options["transformer_options"]["context_window"] = window
|
||||
sub_timestep = window.get_tensor(timestep, dim=0)
|
||||
# Resize conds using primary tensor as reference (correct temporal dim)
|
||||
sub_conds = [self.get_resized_cond(cond, primary, window) for cond in conds]
|
||||
# Resize conds using video_primary as reference (excludes guide frames)
|
||||
sub_conds = [self.get_resized_cond(cond, video_primary, window) for cond in conds]
|
||||
if is_multimodal:
|
||||
self._patch_latent_shapes(sub_conds, sub_shapes)
|
||||
|
||||
@ -385,13 +425,19 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
out_per_mod = [self._decompose(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))]
|
||||
# out_per_mod[cond_idx][mod_idx] = tensor
|
||||
|
||||
# Accumulate per modality
|
||||
for mod_idx in range(len(modalities)):
|
||||
# Strip guide frames from primary output before accumulation
|
||||
if guide_count > 0:
|
||||
window_len = len(window.index_list)
|
||||
for ci in range(len(sub_conds_out)):
|
||||
primary_out = out_per_mod[ci][0]
|
||||
out_per_mod[ci][0] = primary_out.narrow(self.dim, 0, window_len)
|
||||
|
||||
# Accumulate per modality (using video-only sizes)
|
||||
for mod_idx in range(len(accum_modalities)):
|
||||
mw = mod_windows[mod_idx]
|
||||
# Build per-modality sub_conds_out list for combine
|
||||
mod_sub_out = [out_per_mod[ci][mod_idx] for ci in range(len(sub_conds_out))]
|
||||
self.combine_context_window_results(
|
||||
modalities[mod_idx], mod_sub_out, sub_conds, mw,
|
||||
accum_modalities[mod_idx], mod_sub_out, sub_conds, mw,
|
||||
window_idx, total_windows, timestep,
|
||||
accum[mod_idx], counts[mod_idx], biases[mod_idx])
|
||||
|
||||
@ -399,10 +445,15 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
result = []
|
||||
for ci in range(len(conds)):
|
||||
finalized = []
|
||||
for mod_idx in range(len(modalities)):
|
||||
for mod_idx in range(len(accum_modalities)):
|
||||
if self.fuse_method.name != ContextFuseMethods.RELATIVE:
|
||||
accum[mod_idx][ci] /= counts[mod_idx][ci]
|
||||
finalized.append(accum[mod_idx][ci])
|
||||
f = accum[mod_idx][ci]
|
||||
# Re-append original guide_suffix (not model output — sampling loop
|
||||
# respects denoise_mask and never modifies guide frame positions)
|
||||
if mod_idx == 0 and guide_suffix is not None:
|
||||
f = torch.cat([f, guide_suffix], dim=self.dim)
|
||||
finalized.append(f)
|
||||
composed, _ = self._compose(finalized)
|
||||
result.append(composed)
|
||||
return result
|
||||
|
||||
@ -298,6 +298,11 @@ class BaseModel(torch.nn.Module):
|
||||
Returns list of index lists, one per modality."""
|
||||
return [primary_indices]
|
||||
|
||||
def get_guide_frame_count(self, x, conds):
    """Report how many guide frames trail the video frames in ``x``.

    This base implementation ignores both arguments and always answers
    zero; model classes that append guide reference frames to the end of
    the latent's temporal dimension are expected to override it.
    """
    no_guide_frames = 0
    return no_guide_frames
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
concat_cond = self.concat_cond(**kwargs)
|
||||
@ -1021,12 +1026,64 @@ class LTXV(BaseModel):
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||
return latent_image
|
||||
|
||||
def get_guide_frame_count(self, x, conds):
    """Count the guide frames described by the conditioning.

    Walks ``conds`` (an iterable of cond lists, where a list may be
    ``None``) and looks in each cond dict for
    ``model_conds['guide_attention_entries']``. The first such object with
    a non-empty ``cond`` attribute decides the answer: the sum of
    ``latent_shape[0]`` over its entries. ``x`` is not consulted.

    Returns 0 when no cond dict carries usable guide attention entries.
    """
    for group in conds:
        if group is None:
            continue
        for entry in group:
            guide_entries = entry.get('model_conds', {}).get('guide_attention_entries')
            if guide_entries is None:
                continue
            # A missing or empty ``cond`` attribute means "keep looking".
            guide_list = getattr(guide_entries, 'cond', None)
            if not guide_list:
                continue
            total = 0
            for guide in guide_list:
                total += guide["latent_shape"][0]
            return total
    return 0
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
    """Resize guide-related conds so they match one context window.

    Only acts when ``window.guide_count`` is present and positive, and only
    for the keys ``"denoise_mask"``, ``"keyframe_idxs"`` and
    ``"guide_attention_entries"``. In every other case ``None`` is returned.
    ``x_in`` is accepted for signature compatibility and not used here.

    Returns the result of ``cond_value._copy_with(...)`` for a handled key.
    """
    guide_frames = getattr(window, 'guide_count', 0)
    if not (guide_frames > 0):
        return None

    if cond_key == "denoise_mask":
        # The mask is split into a leading video part and a trailing guide
        # part; each part is sliced by the window and the two are re-joined.
        full_mask = cond_value.cond
        axis = window.dim
        video_frames = full_mask.size(axis) - guide_frames
        video_part = window.get_tensor(
            full_mask.narrow(axis, 0, video_frames),
            device,
            retain_index_list=retain_index_list,
        )
        guide_part = window.get_tensor(
            full_mask.narrow(axis, video_frames, guide_frames),
            device,
        )
        return cond_value._copy_with(torch.cat([video_part, guide_part], dim=window.dim))

    elif cond_key == "keyframe_idxs":
        # Coordinates are regenerated for a clip of exactly the window's
        # length instead of slicing the original coordinates.
        frames_in_window = len(window.index_list)
        height, width = window.guide_spatial
        model = self.diffusion_model
        coords = model.patchifier.get_latent_coords(
            frames_in_window, height, width, 1, cond_value.cond.device)
        from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords
        coords = latent_to_pixel_coords(
            coords,
            self.diffusion_model.vae_scale_factors,
            causal_fix=self.diffusion_model.causal_temporal_positioning)
        batch = cond_value.cond.shape[0]
        if batch > 1:
            coords = coords.expand(batch, -1, -1, -1)
        return cond_value._copy_with(coords)

    elif cond_key == "guide_attention_entries":
        # Every entry is rewritten to describe a guide of window length.
        frames_in_window = len(window.index_list)
        height, width = window.guide_spatial
        resized = []
        for item in cond_value.cond:
            updated = dict(item)
            updated["pre_filter_count"] = frames_in_window * height * width
            updated["latent_shape"] = [frames_in_window, height, width]
            resized.append(updated)
        return cond_value._copy_with(resized)

    return None
|
||||
|
||||
class LTXAV(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.av_model.LTXAVModel) #TODO
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
logging.info(f"LTXAV.extra_conds: guide_attention_entries={'guide_attention_entries' in kwargs}, keyframe_idxs={'keyframe_idxs' in kwargs}")
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
device = kwargs["device"]
|
||||
|
||||
@ -1106,13 +1163,65 @@ class LTXAV(BaseModel):
|
||||
result.append(audio_indices)
|
||||
return result
|
||||
|
||||
def get_guide_frame_count(self, x, conds):
    """Return the number of guide frames described by the conditioning.

    Scans ``conds`` (an iterable of cond lists; ``None`` lists are skipped)
    for the first cond dict whose ``model_conds['guide_attention_entries']``
    has a non-empty ``cond`` attribute, and returns the sum of
    ``latent_shape[0]`` over its entries. ``x`` is not consulted.

    Returns 0 when no guide attention entries are found.
    """
    for cond_list in conds:
        if cond_list is None:
            continue
        for cond_dict in cond_list:
            model_conds = cond_dict.get('model_conds', {})
            gae = model_conds.get('guide_attention_entries')
            # Diagnostics are kept at DEBUG with lazy formatting: this runs
            # for every cond dict on each call, so INFO would flood the log.
            logging.debug("LTXAV.get_guide_frame_count: keys=%s, gae=%s",
                          list(model_conds.keys()), gae is not None)
            if gae is not None and hasattr(gae, 'cond') and gae.cond:
                count = sum(e["latent_shape"][0] for e in gae.cond)
                logging.debug("LTXAV.get_guide_frame_count: found %s guide frames", count)
                return count
    logging.debug("LTXAV.get_guide_frame_count: no guide frames found")
    return 0
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
    """Resize audio and guide-related conds so they match one context window.

    Handled keys:

    * ``"audio_denoise_mask"``: when the window carries a modality window
      under key ``1`` and ``cond_value.cond`` is a tensor, that tensor is
      sliced along dim 2 with the audio window.
    * ``"denoise_mask"``, ``"keyframe_idxs"``, ``"guide_attention_entries"``:
      only when ``window.guide_count`` is present and positive.

    Returns the result of ``cond_value._copy_with(...)`` for a handled key,
    otherwise ``None``. ``x_in`` is accepted for signature compatibility and
    not used here.
    """
    # Audio-specific handling
    if cond_key == "audio_denoise_mask" and hasattr(window, 'modality_windows') and window.modality_windows:
        audio_window = window.modality_windows.get(1)
        # NOTE(review): the scraped diff also showed an older
        # comfy.context_windows.slice_cond early return ahead of this branch,
        # which would make it unreachable; hunk line counts indicate those
        # lines were removed by this commit, so only the new branch is kept.
        if audio_window is not None and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
            sliced = audio_window.get_tensor(cond_value.cond, device, dim=2)
            return cond_value._copy_with(sliced)

    # Guide handling (same as LTXV — shared guide mechanism)
    guide_count = getattr(window, 'guide_count', 0)
    if cond_key in ("keyframe_idxs", "guide_attention_entries", "denoise_mask"):
        # DEBUG with lazy formatting: this is reached for every matching
        # cond on every window, so INFO would flood the log.
        logging.debug("LTXAV resize_cond: %s, guide_count=%s, has_spatial=%s",
                      cond_key, guide_count, hasattr(window, 'guide_spatial'))

    if cond_key == "denoise_mask" and guide_count > 0:
        # Split the mask into its leading video part and trailing guide
        # part, slice each with the window, then join them again.
        cond_tensor = cond_value.cond
        T_video = cond_tensor.size(window.dim) - guide_count
        video_mask = cond_tensor.narrow(window.dim, 0, T_video)
        guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count)
        sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list)
        sliced_guide = window.get_tensor(guide_mask, device)
        return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim))

    if cond_key == "keyframe_idxs" and guide_count > 0:
        # Coordinates are regenerated for a clip of exactly the window's
        # length instead of slicing the original coordinates.
        window_len = len(window.index_list)
        H, W = window.guide_spatial
        patchifier = self.diffusion_model.patchifier
        latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device)
        from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords
        pixel_coords = latent_to_pixel_coords(
            latent_coords,
            self.diffusion_model.vae_scale_factors,
            causal_fix=self.diffusion_model.causal_temporal_positioning)
        B = cond_value.cond.shape[0]
        if B > 1:
            pixel_coords = pixel_coords.expand(B, -1, -1, -1)
        return cond_value._copy_with(pixel_coords)

    if cond_key == "guide_attention_entries" and guide_count > 0:
        # Every entry is rewritten to describe a guide of window length.
        window_len = len(window.index_list)
        H, W = window.guide_spatial
        new_entries = [{**e, "pre_filter_count": window_len * H * W,
                        "latent_shape": [window_len, H, W]} for e in cond_value.cond]
        return cond_value._copy_with(new_entries)

    return None
|
||||
|
||||
class HunyuanVideo(BaseModel):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user