mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
LTX2 context windows part 3 - Generalize guide splitting to windows
This commit is contained in:
parent
941d50e777
commit
115dbb69d1
@ -140,6 +140,48 @@ def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, d
|
||||
return cond_value._copy_with(sliced)
|
||||
|
||||
|
||||
def _compute_guide_overlap(guide_entries, window_index_list):
|
||||
"""Compute which guide frames overlap with a context window.
|
||||
|
||||
Args:
|
||||
guide_entries: list of guide_attention_entry dicts (must have 'latent_start' and 'latent_shape')
|
||||
window_index_list: the window's frame indices into the video portion
|
||||
|
||||
Returns None if any entry lacks 'latent_start' (backward compat → legacy path).
|
||||
Otherwise returns (suffix_indices, overlap_info, kf_local_positions, total_overlap):
|
||||
suffix_indices: indices into the guide_suffix tensor for frame selection
|
||||
overlap_info: list of (entry_idx, overlap_count) for guide_attention_entries adjustment
|
||||
kf_local_positions: window-local frame positions for keyframe_idxs regeneration
|
||||
total_overlap: total number of overlapping guide frames
|
||||
"""
|
||||
window_set = set(window_index_list)
|
||||
window_list = list(window_index_list)
|
||||
suffix_indices = []
|
||||
overlap_info = []
|
||||
kf_local_positions = []
|
||||
suffix_base = 0
|
||||
|
||||
for entry_idx, entry in enumerate(guide_entries):
|
||||
latent_start = entry.get("latent_start", None)
|
||||
if latent_start is None:
|
||||
return None
|
||||
guide_len = entry["latent_shape"][0]
|
||||
entry_overlap = 0
|
||||
|
||||
for local_offset in range(guide_len):
|
||||
video_pos = latent_start + local_offset
|
||||
if video_pos in window_set:
|
||||
suffix_indices.append(suffix_base + local_offset)
|
||||
kf_local_positions.append(window_list.index(video_pos))
|
||||
entry_overlap += 1
|
||||
|
||||
if entry_overlap > 0:
|
||||
overlap_info.append((entry_idx, entry_overlap))
|
||||
suffix_base += guide_len
|
||||
|
||||
return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextSchedule:
|
||||
name: str
|
||||
@ -201,6 +243,18 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
if 'latent_shapes' in model_conds:
|
||||
model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes)
|
||||
|
||||
def _get_guide_entries(self, conds):
|
||||
"""Extract guide_attention_entries list from conditioning. Returns None if absent."""
|
||||
for cond_list in conds:
|
||||
if cond_list is None:
|
||||
continue
|
||||
for cond_dict in cond_list:
|
||||
model_conds = cond_dict.get('model_conds', {})
|
||||
gae = model_conds.get('guide_attention_entries')
|
||||
if gae is not None and hasattr(gae, 'cond') and gae.cond:
|
||||
return gae.cond
|
||||
return None
|
||||
|
||||
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
||||
latent_shapes = self._get_latent_shapes(conds)
|
||||
primary = self._decompose(x_in, latent_shapes)[0]
|
||||
@ -353,6 +407,8 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_modalities]
|
||||
biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in accum_modalities]
|
||||
|
||||
guide_entries = self._get_guide_entries(conds) if guide_count > 0 else None
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options)
|
||||
|
||||
@ -391,10 +447,30 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
for mod_idx in range(1, len(modalities)):
|
||||
mod_windows.append(modality_windows[mod_idx])
|
||||
|
||||
# Slice video and guide with same window indices, concatenate
|
||||
# Slice video, then select overlapping guide frames
|
||||
sliced_video = mod_windows[0].get_tensor(video_primary)
|
||||
if guide_suffix is not None:
|
||||
sliced_guide = mod_windows[0].get_tensor(guide_suffix)
|
||||
num_guide_in_window = 0
|
||||
if guide_suffix is not None and guide_entries is not None:
|
||||
overlap = _compute_guide_overlap(guide_entries, window.index_list)
|
||||
if overlap is None:
|
||||
# Legacy: no latent_start → equal-size assumption
|
||||
sliced_guide = mod_windows[0].get_tensor(guide_suffix)
|
||||
num_guide_in_window = sliced_guide.shape[self.dim]
|
||||
elif overlap[3] > 0:
|
||||
suffix_idx, overlap_info, kf_local_pos, num_guide_in_window = overlap
|
||||
idx = tuple([slice(None)] * self.dim + [suffix_idx])
|
||||
sliced_guide = guide_suffix[idx]
|
||||
window.guide_suffix_indices = suffix_idx
|
||||
window.guide_overlap_info = overlap_info
|
||||
window.guide_kf_local_positions = kf_local_pos
|
||||
else:
|
||||
sliced_guide = None
|
||||
window.guide_overlap_info = []
|
||||
window.guide_kf_local_positions = []
|
||||
else:
|
||||
sliced_guide = None
|
||||
|
||||
if sliced_guide is not None:
|
||||
sliced_primary = torch.cat([sliced_video, sliced_guide], dim=self.dim)
|
||||
else:
|
||||
sliced_primary = sliced_video
|
||||
@ -421,7 +497,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
# out_per_mod[cond_idx][mod_idx] = tensor
|
||||
|
||||
# Strip guide frames from primary output before accumulation
|
||||
if guide_count > 0:
|
||||
if num_guide_in_window > 0:
|
||||
window_len = len(window.index_list)
|
||||
for ci in range(len(sub_conds_out)):
|
||||
primary_out = out_per_mod[ci][0]
|
||||
|
||||
@ -1028,7 +1028,7 @@ class LTXVModel(LTXBaseModel):
|
||||
)
|
||||
|
||||
grid_mask = None
|
||||
if keyframe_idxs is not None:
|
||||
if keyframe_idxs is not None and keyframe_idxs.shape[2] > 0:
|
||||
additional_args.update({ "orig_patchified_shape": list(x.shape)})
|
||||
denoise_mask = self.patchifier.patchify(denoise_mask)[0]
|
||||
grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0]
|
||||
|
||||
@ -305,8 +305,8 @@ class BaseModel(torch.nn.Module):
|
||||
|
||||
def _resize_guide_cond(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
"""Resize guide-related conditioning for context windows.
|
||||
Derives guide_count from denoise_mask/x_in size difference.
|
||||
Derives spatial dims from x_in. Requires self.diffusion_model.patchifier."""
|
||||
Uses overlap info from window if available (generalized path),
|
||||
otherwise falls back to legacy equal-size assumption."""
|
||||
if cond_key == "denoise_mask" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
cond_tensor = cond_value.cond
|
||||
guide_count = cond_tensor.size(window.dim) - x_in.size(window.dim)
|
||||
@ -315,30 +315,76 @@ class BaseModel(torch.nn.Module):
|
||||
video_mask = cond_tensor.narrow(window.dim, 0, T_video)
|
||||
guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count)
|
||||
sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list)
|
||||
sliced_guide = window.get_tensor(guide_mask, device)
|
||||
return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim))
|
||||
# Use overlap-based guide selection if available, otherwise legacy
|
||||
suffix_indices = getattr(window, 'guide_suffix_indices', None)
|
||||
if suffix_indices is not None:
|
||||
idx = tuple([slice(None)] * window.dim + [suffix_indices])
|
||||
sliced_guide = guide_mask[idx].to(device) if suffix_indices else None
|
||||
else:
|
||||
sliced_guide = window.get_tensor(guide_mask, device)
|
||||
if sliced_guide is not None and sliced_guide.shape[window.dim] > 0:
|
||||
return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim))
|
||||
else:
|
||||
return cond_value._copy_with(sliced_video)
|
||||
|
||||
if cond_key == "keyframe_idxs":
|
||||
window_len = len(window.index_list)
|
||||
H, W = x_in.shape[3], x_in.shape[4]
|
||||
patchifier = self.diffusion_model.patchifier
|
||||
latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device)
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords
|
||||
pixel_coords = latent_to_pixel_coords(
|
||||
latent_coords,
|
||||
self.diffusion_model.vae_scale_factors,
|
||||
causal_fix=self.diffusion_model.causal_temporal_positioning)
|
||||
B = cond_value.cond.shape[0]
|
||||
if B > 1:
|
||||
pixel_coords = pixel_coords.expand(B, -1, -1, -1)
|
||||
return cond_value._copy_with(pixel_coords)
|
||||
kf_local_pos = getattr(window, 'guide_kf_local_positions', None)
|
||||
if kf_local_pos is not None:
|
||||
# Generalized: regenerate coords for full window, select guide positions
|
||||
if not kf_local_pos:
|
||||
return cond_value._copy_with(cond_value.cond[:, :, :0, :]) # empty
|
||||
H, W = x_in.shape[3], x_in.shape[4]
|
||||
window_len = len(window.index_list)
|
||||
patchifier = self.diffusion_model.patchifier
|
||||
latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device)
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords
|
||||
pixel_coords = latent_to_pixel_coords(
|
||||
latent_coords,
|
||||
self.diffusion_model.vae_scale_factors,
|
||||
causal_fix=self.diffusion_model.causal_temporal_positioning)
|
||||
tokens = []
|
||||
for pos in kf_local_pos:
|
||||
tokens.extend(range(pos * H * W, (pos + 1) * H * W))
|
||||
pixel_coords = pixel_coords[:, :, tokens, :]
|
||||
B = cond_value.cond.shape[0]
|
||||
if B > 1:
|
||||
pixel_coords = pixel_coords.expand(B, -1, -1, -1)
|
||||
return cond_value._copy_with(pixel_coords)
|
||||
else:
|
||||
# Legacy: regenerate for window_len (equal-size assumption)
|
||||
window_len = len(window.index_list)
|
||||
H, W = x_in.shape[3], x_in.shape[4]
|
||||
patchifier = self.diffusion_model.patchifier
|
||||
latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device)
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords
|
||||
pixel_coords = latent_to_pixel_coords(
|
||||
latent_coords,
|
||||
self.diffusion_model.vae_scale_factors,
|
||||
causal_fix=self.diffusion_model.causal_temporal_positioning)
|
||||
B = cond_value.cond.shape[0]
|
||||
if B > 1:
|
||||
pixel_coords = pixel_coords.expand(B, -1, -1, -1)
|
||||
return cond_value._copy_with(pixel_coords)
|
||||
|
||||
if cond_key == "guide_attention_entries":
|
||||
window_len = len(window.index_list)
|
||||
H, W = x_in.shape[3], x_in.shape[4]
|
||||
new_entries = [{**e, "pre_filter_count": window_len * H * W,
|
||||
"latent_shape": [window_len, H, W]} for e in cond_value.cond]
|
||||
return cond_value._copy_with(new_entries)
|
||||
overlap_info = getattr(window, 'guide_overlap_info', None)
|
||||
if overlap_info is not None:
|
||||
# Generalized: per-guide adjustment based on overlap
|
||||
H, W = x_in.shape[3], x_in.shape[4]
|
||||
new_entries = []
|
||||
for entry_idx, overlap_count in overlap_info:
|
||||
e = cond_value.cond[entry_idx]
|
||||
new_entries.append({**e,
|
||||
"pre_filter_count": overlap_count * H * W,
|
||||
"latent_shape": [overlap_count, H, W]})
|
||||
return cond_value._copy_with(new_entries)
|
||||
else:
|
||||
# Legacy: all entries adjusted to window_len
|
||||
window_len = len(window.index_list)
|
||||
H, W = x_in.shape[3], x_in.shape[4]
|
||||
new_entries = [{**e, "pre_filter_count": window_len * H * W,
|
||||
"latent_shape": [window_len, H, W]} for e in cond_value.cond]
|
||||
return cond_value._copy_with(new_entries)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@ -135,7 +135,7 @@ class LTXVImgToVideoInplace(io.ComfyNode):
|
||||
generate = execute # TODO: remove
|
||||
|
||||
|
||||
def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0):
|
||||
def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0, latent_start=0):
|
||||
"""Append a guide_attention_entry to both positive and negative conditioning.
|
||||
|
||||
Each entry tracks one guide reference for per-reference attention control.
|
||||
@ -146,6 +146,7 @@ def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_s
|
||||
"strength": strength,
|
||||
"pixel_mask": None,
|
||||
"latent_shape": latent_shape,
|
||||
"latent_start": latent_start,
|
||||
}
|
||||
results = []
|
||||
for cond in (positive, negative):
|
||||
@ -362,6 +363,7 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
guide_latent_shape = list(t.shape[2:]) # [F, H, W]
|
||||
positive, negative = _append_guide_attention_entry(
|
||||
positive, negative, pre_filter_count, guide_latent_shape, strength=strength,
|
||||
latent_start=latent_idx,
|
||||
)
|
||||
|
||||
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
|
||||
|
||||
Loading…
Reference in New Issue
Block a user