diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index a4c85db77..3dc1199c2 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -338,8 +338,25 @@ class LTXVAddGuide(io.ComfyNode): noise_mask = get_noise_mask(latent) _, _, latent_length, latent_height, latent_width = latent_image.shape + + # For mid-video multi-frame guides, prepend+strip a throwaway first frame so the VAE's "first latent = 1 pixel frame" asymmetry lands on the discarded slot + time_scale_factor = scale_factors[0] + num_frames_to_keep = ((image.shape[0] - 1) // time_scale_factor) * time_scale_factor + 1 + resolved_frame_idx = frame_idx + if frame_idx < 0: + _, num_keyframes = get_keyframe_idxs(positive) + resolved_frame_idx = max((latent_length - num_keyframes - 1) * time_scale_factor + 1 + frame_idx, 0) + causal_fix = resolved_frame_idx == 0 or num_frames_to_keep == 1 + + if not causal_fix: + image = torch.cat([image[:1], image], dim=0) + image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors) + if not causal_fix: + t = t[:, :, 1:, :, :] + image = image[1:] + frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors) assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence." @@ -352,6 +369,7 @@ class LTXVAddGuide(io.ComfyNode): t, strength, scale_factors, + causal_fix=causal_fix, ) # Track this guide for per-reference attention control.