mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-14 19:17:32 +08:00
Merge 09f15d82da into 26515acd23
This commit is contained in:
commit
4ca03db082
@ -183,10 +183,20 @@ def get_noise_mask(latent):
|
||||
noise_mask = noise_mask.clone()
|
||||
return noise_mask
|
||||
|
||||
def get_keyframe_idxs(cond):
|
||||
def get_keyframe_idxs(cond, latent_shape=None):
|
||||
keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
|
||||
if keyframe_idxs is None:
|
||||
return None, 0
|
||||
# Get number of keyframes from latent_shape or guide_attention_entries if available
|
||||
if latent_shape is not None and len(latent_shape) == 5:
|
||||
tokens_per_frame = latent_shape[-2] * latent_shape[-1]
|
||||
num_keyframes = keyframe_idxs.shape[2] // tokens_per_frame
|
||||
return keyframe_idxs, num_keyframes
|
||||
entries = conditioning_get_any_value(cond, "guide_attention_entries", None)
|
||||
if entries:
|
||||
num_keyframes = sum(e["latent_shape"][0] for e in entries)
|
||||
return keyframe_idxs, num_keyframes
|
||||
# fallback, may under-count if keyframes share t-start
|
||||
# keyframe_idxs contains start/end positions (last dimension), checking for unqiue values only for start
|
||||
num_keyframes = torch.unique(keyframe_idxs[:, 0, :, 0]).shape[0]
|
||||
return keyframe_idxs, num_keyframes
|
||||
@ -238,9 +248,9 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
return encode_pixels, t
|
||||
|
||||
@classmethod
|
||||
def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors):
|
||||
def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors, latent_shape=None):
|
||||
time_scale_factor, _, _ = scale_factors
|
||||
_, num_keyframes = get_keyframe_idxs(cond)
|
||||
_, num_keyframes = get_keyframe_idxs(cond, latent_shape)
|
||||
latent_count = latent_length - num_keyframes
|
||||
frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
|
||||
if guide_length > 1 and frame_idx != 0:
|
||||
@ -344,7 +354,7 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
num_frames_to_keep = ((image.shape[0] - 1) // time_scale_factor) * time_scale_factor + 1
|
||||
resolved_frame_idx = frame_idx
|
||||
if frame_idx < 0:
|
||||
_, num_keyframes = get_keyframe_idxs(positive)
|
||||
_, num_keyframes = get_keyframe_idxs(positive, latent_image.shape)
|
||||
resolved_frame_idx = max((latent_length - num_keyframes - 1) * time_scale_factor + 1 + frame_idx, 0)
|
||||
causal_fix = resolved_frame_idx == 0 or num_frames_to_keep == 1
|
||||
|
||||
@ -357,7 +367,7 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
t = t[:, :, 1:, :, :]
|
||||
image = image[1:]
|
||||
|
||||
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
|
||||
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors, latent_shape=latent_image.shape)
|
||||
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
|
||||
|
||||
positive, negative, latent_image, noise_mask = cls.append_keyframe(
|
||||
@ -407,7 +417,7 @@ class LTXVCropGuides(io.ComfyNode):
|
||||
latent_image = latent["samples"].clone()
|
||||
noise_mask = get_noise_mask(latent)
|
||||
|
||||
_, num_keyframes = get_keyframe_idxs(positive)
|
||||
_, num_keyframes = get_keyframe_idxs(positive, latent_image.shape)
|
||||
if num_keyframes == 0:
|
||||
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user