From 09f15d82daac51594bb711237cc534d88e29c8d9 Mon Sep 17 00:00:00 2001 From: ozbayb <17261091+ozbayb@users.noreply.github.com> Date: Wed, 13 May 2026 18:14:52 -0600 Subject: [PATCH] fix: Stop LTXVCropGuides leaving stray latent frames when guides share a start position --- comfy_extras/nodes_lt.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index 3dc1199c2..695549c1c 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -183,10 +183,20 @@ def get_noise_mask(latent): noise_mask = noise_mask.clone() return noise_mask -def get_keyframe_idxs(cond): +def get_keyframe_idxs(cond, latent_shape=None): keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None) if keyframe_idxs is None: return None, 0 + # Get number of keyframes from latent_shape or guide_attention_entries if available + if latent_shape is not None and len(latent_shape) == 5: + tokens_per_frame = latent_shape[-2] * latent_shape[-1] + num_keyframes = keyframe_idxs.shape[2] // tokens_per_frame + return keyframe_idxs, num_keyframes + entries = conditioning_get_any_value(cond, "guide_attention_entries", None) + if entries: + num_keyframes = sum(e["latent_shape"][0] for e in entries) + return keyframe_idxs, num_keyframes + # fallback, may under-count if keyframes share t-start # keyframe_idxs contains start/end positions (last dimension), checking for unqiue values only for start num_keyframes = torch.unique(keyframe_idxs[:, 0, :, 0]).shape[0] return keyframe_idxs, num_keyframes @@ -238,9 +248,9 @@ class LTXVAddGuide(io.ComfyNode): return encode_pixels, t @classmethod - def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors): + def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors, latent_shape=None): time_scale_factor, _, _ = scale_factors - _, num_keyframes = get_keyframe_idxs(cond) + _, num_keyframes = get_keyframe_idxs(cond, latent_shape) latent_count = latent_length - num_keyframes frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0) if guide_length > 1 and frame_idx != 0: @@ -344,7 +354,7 @@ class LTXVAddGuide(io.ComfyNode): num_frames_to_keep = ((image.shape[0] - 1) // time_scale_factor) * time_scale_factor + 1 resolved_frame_idx = frame_idx if frame_idx < 0: - _, num_keyframes = get_keyframe_idxs(positive) + _, num_keyframes = get_keyframe_idxs(positive, latent_image.shape) resolved_frame_idx = max((latent_length - num_keyframes - 1) * time_scale_factor + 1 + frame_idx, 0) causal_fix = resolved_frame_idx == 0 or num_frames_to_keep == 1 @@ -357,7 +367,7 @@ class LTXVAddGuide(io.ComfyNode): t = t[:, :, 1:, :, :] image = image[1:] - frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors) + frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors, latent_shape=latent_image.shape) assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence." positive, negative, latent_image, noise_mask = cls.append_keyframe( @@ -407,7 +417,7 @@ class LTXVCropGuides(io.ComfyNode): latent_image = latent["samples"].clone() noise_mask = get_noise_mask(latent) - _, num_keyframes = get_keyframe_idxs(positive) + _, num_keyframes = get_keyframe_idxs(positive, latent_image.shape) if num_keyframes == 0: return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)