Allow strength above 1.0

This commit is contained in:
kijai 2026-05-06 23:56:21 +03:00
parent 848880c3d3
commit 989dea8c40
2 changed files with 12 additions and 10 deletions

View File

@ -1122,7 +1122,9 @@ class LTXVModel(LTXBaseModel):
additional_args["resolved_guide_entries"] = resolved_entries
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
if keyframe_idxs.shape[2] > 0: # Guard for the case of no keyframes surviving
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
# Total surviving guide tokens (all guides)
additional_args["num_guide_tokens"] = keyframe_idxs.shape[2]
@ -1158,12 +1160,12 @@ class LTXVModel(LTXBaseModel):
if not resolved_entries:
return None
# Check if any attenuation is actually needed
needs_attenuation = any(
e["strength"] < 1.0 or e.get("pixel_mask") is not None
# strength != 1.0 means we want to either attenuate (< 1) or amplify (> 1) guide attention.
needs_mask = any(
e["strength"] != 1.0 or e.get("pixel_mask") is not None
for e in resolved_entries
)
if not needs_attenuation:
if not needs_mask:
return None
# Build per-guide-token weights for all tracked guide tokens.
@ -1218,8 +1220,8 @@ class LTXVModel(LTXBaseModel):
# Concatenate per-token weights for all tracked guides
tracked_weights = torch.cat(all_weights, dim=1) # (1, total_tracked)
# Check if any weight is actually < 1.0 (otherwise no attenuation needed)
if (tracked_weights >= 1.0).all():
# Skip when every weight is exactly 1.0 (additive bias would be 0).
if (tracked_weights == 1.0).all():
return None
return GuideAttentionMask(total_tokens, guide_start, total_tracked, tracked_weights)

View File

@ -223,7 +223,7 @@ class LTXVAddGuide(io.ComfyNode):
"For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded "
"down to the nearest multiple of 8. Negative values are counted from the end of the video.",
),
io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
@ -302,7 +302,7 @@ class LTXVAddGuide(io.ComfyNode):
else:
mask = torch.full(
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
1.0 - strength,
max(0.0, 1.0 - strength), # clamp here to amplify only via the attention mask
dtype=noise_mask.dtype,
device=noise_mask.device,
)
@ -322,7 +322,7 @@ class LTXVAddGuide(io.ComfyNode):
mask = torch.full(
(noise_mask.shape[0], 1, cond_length, 1, 1),
1.0 - strength,
max(0.0, 1.0 - strength), # clamp here to amplify only via the attention mask
dtype=noise_mask.dtype,
device=noise_mask.device,
)