diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py
index 7ccb2c42a..80a3f08d7 100644
--- a/comfy/ldm/lightricks/model.py
+++ b/comfy/ldm/lightricks/model.py
@@ -1122,7 +1122,9 @@ class LTXVModel(LTXBaseModel):
         additional_args["resolved_guide_entries"] = resolved_entries
 
         keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
-        pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
+
+        if keyframe_idxs.shape[2] > 0:  # Guard for the case of no keyframes surviving
+            pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
 
         # Total surviving guide tokens (all guides)
         additional_args["num_guide_tokens"] = keyframe_idxs.shape[2]
@@ -1158,12 +1160,12 @@ class LTXVModel(LTXBaseModel):
         if not resolved_entries:
             return None
 
-        # Check if any attenuation is actually needed
-        needs_attenuation = any(
-            e["strength"] < 1.0 or e.get("pixel_mask") is not None
+        # strength != 1.0 means we want to either attenuate (< 1) or amplify (> 1) guide attention.
+        needs_mask = any(
+            e["strength"] != 1.0 or e.get("pixel_mask") is not None
             for e in resolved_entries
         )
-        if not needs_attenuation:
+        if not needs_mask:
             return None
 
         # Build per-guide-token weights for all tracked guide tokens.
@@ -1218,8 +1220,8 @@ class LTXVModel(LTXBaseModel):
         # Concatenate per-token weights for all tracked guides
         tracked_weights = torch.cat(all_weights, dim=1)  # (1, total_tracked)
 
-        # Check if any weight is actually < 1.0 (otherwise no attenuation needed)
-        if (tracked_weights >= 1.0).all():
+        # Skip when every weight is exactly 1.0 (additive bias would be 0).
+        if (tracked_weights == 1.0).all():
             return None
 
         return GuideAttentionMask(total_tokens, guide_start, total_tracked, tracked_weights)
diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py
index 19d8a387f..28c231daf 100644
--- a/comfy_extras/nodes_lt.py
+++ b/comfy_extras/nodes_lt.py
@@ -223,7 +223,7 @@ class LTXVAddGuide(io.ComfyNode):
                     "For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded "
                     "down to the nearest multiple of 8. Negative values are counted from the end of the video.",
                 ),
-                io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01),
+                io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
             ],
             outputs=[
                 io.Conditioning.Output(display_name="positive"),
@@ -302,7 +302,7 @@ class LTXVAddGuide(io.ComfyNode):
         else:
             mask = torch.full(
                 (noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
-                1.0 - strength,
+                max(0.0, 1.0 - strength),  # clamp here to amplify only via the attention mask
                 dtype=noise_mask.dtype,
                 device=noise_mask.device,
             )
@@ -322,7 +322,7 @@ class LTXVAddGuide(io.ComfyNode):
 
         mask = torch.full(
             (noise_mask.shape[0], 1, cond_length, 1, 1),
-            1.0 - strength,
+            max(0.0, 1.0 - strength),  # clamp here to amplify only via the attention mask
             dtype=noise_mask.dtype,
             device=noise_mask.device,
         )