mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-26 10:22:30 +08:00
LTX2 context windows - Cleanup: latent_start value is required for context windows with guides
This commit is contained in:
parent
71712472f5
commit
3660533f83
@ -147,8 +147,7 @@ def _compute_guide_overlap(guide_entries, window_index_list):
|
|||||||
guide_entries: list of guide_attention_entry dicts (must have 'latent_start' and 'latent_shape')
|
guide_entries: list of guide_attention_entry dicts (must have 'latent_start' and 'latent_shape')
|
||||||
window_index_list: the window's frame indices into the video portion
|
window_index_list: the window's frame indices into the video portion
|
||||||
|
|
||||||
Returns None if any entry lacks 'latent_start' (backward compat → legacy path).
|
Returns (suffix_indices, overlap_info, kf_local_positions, total_overlap):
|
||||||
Otherwise returns (suffix_indices, overlap_info, kf_local_positions, total_overlap):
|
|
||||||
suffix_indices: indices into the guide_suffix tensor for frame selection
|
suffix_indices: indices into the guide_suffix tensor for frame selection
|
||||||
overlap_info: list of (entry_idx, overlap_count) for guide_attention_entries adjustment
|
overlap_info: list of (entry_idx, overlap_count) for guide_attention_entries adjustment
|
||||||
kf_local_positions: window-local frame positions for keyframe_idxs regeneration
|
kf_local_positions: window-local frame positions for keyframe_idxs regeneration
|
||||||
@ -164,7 +163,7 @@ def _compute_guide_overlap(guide_entries, window_index_list):
|
|||||||
for entry_idx, entry in enumerate(guide_entries):
|
for entry_idx, entry in enumerate(guide_entries):
|
||||||
latent_start = entry.get("latent_start", None)
|
latent_start = entry.get("latent_start", None)
|
||||||
if latent_start is None:
|
if latent_start is None:
|
||||||
return None
|
raise ValueError("guide_attention_entry missing required 'latent_start'.")
|
||||||
guide_len = entry["latent_shape"][0]
|
guide_len = entry["latent_shape"][0]
|
||||||
entry_overlap = 0
|
entry_overlap = 0
|
||||||
|
|
||||||
@ -452,11 +451,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
num_guide_in_window = 0
|
num_guide_in_window = 0
|
||||||
if guide_suffix is not None and guide_entries is not None:
|
if guide_suffix is not None and guide_entries is not None:
|
||||||
overlap = _compute_guide_overlap(guide_entries, window.index_list)
|
overlap = _compute_guide_overlap(guide_entries, window.index_list)
|
||||||
if overlap is None:
|
if overlap[3] > 0:
|
||||||
# Legacy: no latent_start → equal-size assumption
|
|
||||||
sliced_guide = mod_windows[0].get_tensor(guide_suffix)
|
|
||||||
num_guide_in_window = sliced_guide.shape[self.dim]
|
|
||||||
elif overlap[3] > 0:
|
|
||||||
suffix_idx, overlap_info, kf_local_pos, num_guide_in_window = overlap
|
suffix_idx, overlap_info, kf_local_pos, num_guide_in_window = overlap
|
||||||
idx = tuple([slice(None)] * self.dim + [suffix_idx])
|
idx = tuple([slice(None)] * self.dim + [suffix_idx])
|
||||||
sliced_guide = guide_suffix[idx]
|
sliced_guide = guide_suffix[idx]
|
||||||
|
|||||||
@ -305,8 +305,8 @@ class BaseModel(torch.nn.Module):
|
|||||||
|
|
||||||
def _resize_guide_cond(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
def _resize_guide_cond(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||||
"""Resize guide-related conditioning for context windows.
|
"""Resize guide-related conditioning for context windows.
|
||||||
Uses overlap info from window if available (generalized path),
|
Requires guide_suffix_indices, guide_overlap_info, and guide_kf_local_positions
|
||||||
otherwise falls back to legacy equal-size assumption."""
|
to be set on the window by _compute_guide_overlap."""
|
||||||
if cond_key == "denoise_mask" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
if cond_key == "denoise_mask" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||||
cond_tensor = cond_value.cond
|
cond_tensor = cond_value.cond
|
||||||
guide_count = cond_tensor.size(window.dim) - x_in.size(window.dim)
|
guide_count = cond_tensor.size(window.dim) - x_in.size(window.dim)
|
||||||
@ -315,76 +315,46 @@ class BaseModel(torch.nn.Module):
|
|||||||
video_mask = cond_tensor.narrow(window.dim, 0, T_video)
|
video_mask = cond_tensor.narrow(window.dim, 0, T_video)
|
||||||
guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count)
|
guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count)
|
||||||
sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list)
|
sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list)
|
||||||
# Use overlap-based guide selection if available, otherwise legacy
|
suffix_indices = window.guide_suffix_indices
|
||||||
suffix_indices = getattr(window, 'guide_suffix_indices', None)
|
if suffix_indices:
|
||||||
if suffix_indices is not None:
|
|
||||||
idx = tuple([slice(None)] * window.dim + [suffix_indices])
|
idx = tuple([slice(None)] * window.dim + [suffix_indices])
|
||||||
sliced_guide = guide_mask[idx].to(device) if suffix_indices else None
|
sliced_guide = guide_mask[idx].to(device)
|
||||||
else:
|
|
||||||
sliced_guide = window.get_tensor(guide_mask, device)
|
|
||||||
if sliced_guide is not None and sliced_guide.shape[window.dim] > 0:
|
|
||||||
return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim))
|
return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim))
|
||||||
else:
|
else:
|
||||||
return cond_value._copy_with(sliced_video)
|
return cond_value._copy_with(sliced_video)
|
||||||
|
|
||||||
if cond_key == "keyframe_idxs":
|
if cond_key == "keyframe_idxs":
|
||||||
kf_local_pos = getattr(window, 'guide_kf_local_positions', None)
|
kf_local_pos = window.guide_kf_local_positions
|
||||||
if kf_local_pos is not None:
|
if not kf_local_pos:
|
||||||
# Generalized: regenerate coords for full window, select guide positions
|
return cond_value._copy_with(cond_value.cond[:, :, :0, :]) # empty
|
||||||
if not kf_local_pos:
|
H, W = x_in.shape[3], x_in.shape[4]
|
||||||
return cond_value._copy_with(cond_value.cond[:, :, :0, :]) # empty
|
window_len = len(window.index_list)
|
||||||
H, W = x_in.shape[3], x_in.shape[4]
|
patchifier = self.diffusion_model.patchifier
|
||||||
window_len = len(window.index_list)
|
latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device)
|
||||||
patchifier = self.diffusion_model.patchifier
|
from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords
|
||||||
latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device)
|
pixel_coords = latent_to_pixel_coords(
|
||||||
from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords
|
latent_coords,
|
||||||
pixel_coords = latent_to_pixel_coords(
|
self.diffusion_model.vae_scale_factors,
|
||||||
latent_coords,
|
causal_fix=self.diffusion_model.causal_temporal_positioning)
|
||||||
self.diffusion_model.vae_scale_factors,
|
tokens = []
|
||||||
causal_fix=self.diffusion_model.causal_temporal_positioning)
|
for pos in kf_local_pos:
|
||||||
tokens = []
|
tokens.extend(range(pos * H * W, (pos + 1) * H * W))
|
||||||
for pos in kf_local_pos:
|
pixel_coords = pixel_coords[:, :, tokens, :]
|
||||||
tokens.extend(range(pos * H * W, (pos + 1) * H * W))
|
B = cond_value.cond.shape[0]
|
||||||
pixel_coords = pixel_coords[:, :, tokens, :]
|
if B > 1:
|
||||||
B = cond_value.cond.shape[0]
|
pixel_coords = pixel_coords.expand(B, -1, -1, -1)
|
||||||
if B > 1:
|
return cond_value._copy_with(pixel_coords)
|
||||||
pixel_coords = pixel_coords.expand(B, -1, -1, -1)
|
|
||||||
return cond_value._copy_with(pixel_coords)
|
|
||||||
else:
|
|
||||||
# Legacy: regenerate for window_len (equal-size assumption)
|
|
||||||
window_len = len(window.index_list)
|
|
||||||
H, W = x_in.shape[3], x_in.shape[4]
|
|
||||||
patchifier = self.diffusion_model.patchifier
|
|
||||||
latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device)
|
|
||||||
from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords
|
|
||||||
pixel_coords = latent_to_pixel_coords(
|
|
||||||
latent_coords,
|
|
||||||
self.diffusion_model.vae_scale_factors,
|
|
||||||
causal_fix=self.diffusion_model.causal_temporal_positioning)
|
|
||||||
B = cond_value.cond.shape[0]
|
|
||||||
if B > 1:
|
|
||||||
pixel_coords = pixel_coords.expand(B, -1, -1, -1)
|
|
||||||
return cond_value._copy_with(pixel_coords)
|
|
||||||
|
|
||||||
if cond_key == "guide_attention_entries":
|
if cond_key == "guide_attention_entries":
|
||||||
overlap_info = getattr(window, 'guide_overlap_info', None)
|
overlap_info = window.guide_overlap_info
|
||||||
if overlap_info is not None:
|
H, W = x_in.shape[3], x_in.shape[4]
|
||||||
# Generalized: per-guide adjustment based on overlap
|
new_entries = []
|
||||||
H, W = x_in.shape[3], x_in.shape[4]
|
for entry_idx, overlap_count in overlap_info:
|
||||||
new_entries = []
|
e = cond_value.cond[entry_idx]
|
||||||
for entry_idx, overlap_count in overlap_info:
|
new_entries.append({**e,
|
||||||
e = cond_value.cond[entry_idx]
|
"pre_filter_count": overlap_count * H * W,
|
||||||
new_entries.append({**e,
|
"latent_shape": [overlap_count, H, W]})
|
||||||
"pre_filter_count": overlap_count * H * W,
|
return cond_value._copy_with(new_entries)
|
||||||
"latent_shape": [overlap_count, H, W]})
|
|
||||||
return cond_value._copy_with(new_entries)
|
|
||||||
else:
|
|
||||||
# Legacy: all entries adjusted to window_len
|
|
||||||
window_len = len(window.index_list)
|
|
||||||
H, W = x_in.shape[3], x_in.shape[4]
|
|
||||||
new_entries = [{**e, "pre_filter_count": window_len * H * W,
|
|
||||||
"latent_shape": [window_len, H, W]} for e in cond_value.cond]
|
|
||||||
return cond_value._copy_with(new_entries)
|
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user