mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-15 03:27:24 +08:00
fix: Stop LTXVCropGuides leaving stray latent frames when guides share a start position
This commit is contained in:
parent
26515acd23
commit
09f15d82da
@ -183,10 +183,20 @@ def get_noise_mask(latent):
|
|||||||
noise_mask = noise_mask.clone()
|
noise_mask = noise_mask.clone()
|
||||||
return noise_mask
|
return noise_mask
|
||||||
|
|
||||||
def get_keyframe_idxs(cond):
|
def get_keyframe_idxs(cond, latent_shape=None):
|
||||||
keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
|
keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
|
||||||
if keyframe_idxs is None:
|
if keyframe_idxs is None:
|
||||||
return None, 0
|
return None, 0
|
||||||
|
# Get number of keyframes from latent_shape or guide_attention_entries if available
|
||||||
|
if latent_shape is not None and len(latent_shape) == 5:
|
||||||
|
tokens_per_frame = latent_shape[-2] * latent_shape[-1]
|
||||||
|
num_keyframes = keyframe_idxs.shape[2] // tokens_per_frame
|
||||||
|
return keyframe_idxs, num_keyframes
|
||||||
|
entries = conditioning_get_any_value(cond, "guide_attention_entries", None)
|
||||||
|
if entries:
|
||||||
|
num_keyframes = sum(e["latent_shape"][0] for e in entries)
|
||||||
|
return keyframe_idxs, num_keyframes
|
||||||
|
# fallback, may under-count if keyframes share t-start
|
||||||
# keyframe_idxs contains start/end positions (last dimension), checking for unique values only for start
|
# keyframe_idxs contains start/end positions (last dimension), checking for unique values only for start
|
||||||
num_keyframes = torch.unique(keyframe_idxs[:, 0, :, 0]).shape[0]
|
num_keyframes = torch.unique(keyframe_idxs[:, 0, :, 0]).shape[0]
|
||||||
return keyframe_idxs, num_keyframes
|
return keyframe_idxs, num_keyframes
|
||||||
@ -238,9 +248,9 @@ class LTXVAddGuide(io.ComfyNode):
|
|||||||
return encode_pixels, t
|
return encode_pixels, t
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors):
|
def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors, latent_shape=None):
|
||||||
time_scale_factor, _, _ = scale_factors
|
time_scale_factor, _, _ = scale_factors
|
||||||
_, num_keyframes = get_keyframe_idxs(cond)
|
_, num_keyframes = get_keyframe_idxs(cond, latent_shape)
|
||||||
latent_count = latent_length - num_keyframes
|
latent_count = latent_length - num_keyframes
|
||||||
frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
|
frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
|
||||||
if guide_length > 1 and frame_idx != 0:
|
if guide_length > 1 and frame_idx != 0:
|
||||||
@ -344,7 +354,7 @@ class LTXVAddGuide(io.ComfyNode):
|
|||||||
num_frames_to_keep = ((image.shape[0] - 1) // time_scale_factor) * time_scale_factor + 1
|
num_frames_to_keep = ((image.shape[0] - 1) // time_scale_factor) * time_scale_factor + 1
|
||||||
resolved_frame_idx = frame_idx
|
resolved_frame_idx = frame_idx
|
||||||
if frame_idx < 0:
|
if frame_idx < 0:
|
||||||
_, num_keyframes = get_keyframe_idxs(positive)
|
_, num_keyframes = get_keyframe_idxs(positive, latent_image.shape)
|
||||||
resolved_frame_idx = max((latent_length - num_keyframes - 1) * time_scale_factor + 1 + frame_idx, 0)
|
resolved_frame_idx = max((latent_length - num_keyframes - 1) * time_scale_factor + 1 + frame_idx, 0)
|
||||||
causal_fix = resolved_frame_idx == 0 or num_frames_to_keep == 1
|
causal_fix = resolved_frame_idx == 0 or num_frames_to_keep == 1
|
||||||
|
|
||||||
@ -357,7 +367,7 @@ class LTXVAddGuide(io.ComfyNode):
|
|||||||
t = t[:, :, 1:, :, :]
|
t = t[:, :, 1:, :, :]
|
||||||
image = image[1:]
|
image = image[1:]
|
||||||
|
|
||||||
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
|
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors, latent_shape=latent_image.shape)
|
||||||
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
|
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
|
||||||
|
|
||||||
positive, negative, latent_image, noise_mask = cls.append_keyframe(
|
positive, negative, latent_image, noise_mask = cls.append_keyframe(
|
||||||
@ -407,7 +417,7 @@ class LTXVCropGuides(io.ComfyNode):
|
|||||||
latent_image = latent["samples"].clone()
|
latent_image = latent["samples"].clone()
|
||||||
noise_mask = get_noise_mask(latent)
|
noise_mask = get_noise_mask(latent)
|
||||||
|
|
||||||
_, num_keyframes = get_keyframe_idxs(positive)
|
_, num_keyframes = get_keyframe_idxs(positive, latent_image.shape)
|
||||||
if num_keyframes == 0:
|
if num_keyframes == 0:
|
||||||
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
|
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user