Allow strength above 1.0

This commit is contained in:
kijai 2026-05-06 23:56:21 +03:00
parent 848880c3d3
commit 989dea8c40
2 changed files with 12 additions and 10 deletions

View File

@@ -1122,7 +1122,9 @@ class LTXVModel(LTXBaseModel):
         additional_args["resolved_guide_entries"] = resolved_entries
         keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
-        pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
+        if keyframe_idxs.shape[2] > 0:  # Guard for the case of no keyframes surviving
+            pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
         # Total surviving guide tokens (all guides)
         additional_args["num_guide_tokens"] = keyframe_idxs.shape[2]
@@ -1158,12 +1160,12 @@ class LTXVModel(LTXBaseModel):
         if not resolved_entries:
             return None
-        # Check if any attenuation is actually needed
-        needs_attenuation = any(
-            e["strength"] < 1.0 or e.get("pixel_mask") is not None
+        # strength != 1.0 means we want to either attenuate (< 1) or amplify (> 1) guide attention.
+        needs_mask = any(
+            e["strength"] != 1.0 or e.get("pixel_mask") is not None
             for e in resolved_entries
         )
-        if not needs_attenuation:
+        if not needs_mask:
             return None
         # Build per-guide-token weights for all tracked guide tokens.
@@ -1218,8 +1220,8 @@ class LTXVModel(LTXBaseModel):
         # Concatenate per-token weights for all tracked guides
         tracked_weights = torch.cat(all_weights, dim=1)  # (1, total_tracked)
-        # Check if any weight is actually < 1.0 (otherwise no attenuation needed)
-        if (tracked_weights >= 1.0).all():
+        # Skip when every weight is exactly 1.0 (additive bias would be 0).
+        if (tracked_weights == 1.0).all():
             return None
         return GuideAttentionMask(total_tokens, guide_start, total_tracked, tracked_weights)

View File

@@ -223,7 +223,7 @@ class LTXVAddGuide(io.ComfyNode):
                     "For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded "
                     "down to the nearest multiple of 8. Negative values are counted from the end of the video.",
                 ),
-                io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01),
+                io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
             ],
             outputs=[
                 io.Conditioning.Output(display_name="positive"),
@@ -302,7 +302,7 @@ class LTXVAddGuide(io.ComfyNode):
        else:
            mask = torch.full(
                (noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
-                1.0 - strength,
+                max(0.0, 1.0 - strength),  # clamp here to amplify only via the attention mask
                dtype=noise_mask.dtype,
                device=noise_mask.device,
            )
@@ -322,7 +322,7 @@ class LTXVAddGuide(io.ComfyNode):
            mask = torch.full(
                (noise_mask.shape[0], 1, cond_length, 1, 1),
-                1.0 - strength,
+                max(0.0, 1.0 - strength),  # clamp here to amplify only via the attention mask
                dtype=noise_mask.dtype,
                device=noise_mask.device,
            )