Add some safety guards

kijai 2026-05-09 14:57:23 +03:00
parent 1e76c3b9c9
commit 2d07a1004a


@@ -383,14 +383,6 @@ class GuideAttentionMask:
         self.tracked_mask = torch.zeros((1, 1, tracked_count, total_tokens), device=device, dtype=dtype)
         self.tracked_mask[:, :, :, :guide_start] = log_w.view(1, 1, -1, 1)
 
-    def to(self, *args, **kwargs):
-        new = GuideAttentionMask.__new__(GuideAttentionMask)
-        new.guide_start = self.guide_start
-        new.tracked_count = self.tracked_count
-        new.noisy_mask = self.noisy_mask.to(*args, **kwargs)
-        new.tracked_mask = self.tracked_mask.to(*args, **kwargs)
-        return new
-
 def _attention_with_guide_mask(q, k, v, heads, guide_mask, attn_precision, transformer_options):
     """Apply the guide mask by partitioning Q into noisy and tracked-guide
@@ -402,16 +394,22 @@ def _attention_with_guide_mask(q, k, v, heads, guide_mask, attn_precision, transformer_options):
     out = torch.empty_like(q)
-    out[:, :guide_start, :] = comfy.ldm.modules.attention.optimized_attention(
-        q[:, :guide_start, :], k, v, heads, mask=guide_mask.noisy_mask,
-        attn_precision=attn_precision, transformer_options=transformer_options,
-        low_precision_attention=False,  # sageattn mask support is unreliable
-    )
+    if guide_start > 0:  # Guides currently always follow the noisy tokens; guard in case this changes.
+        out[:, :guide_start, :] = comfy.ldm.modules.attention.optimized_attention(
+            q[:, :guide_start, :], k, v, heads, mask=guide_mask.noisy_mask,
+            attn_precision=attn_precision, transformer_options=transformer_options,
+            low_precision_attention=False,  # sageattn mask support is unreliable
+        )
     out[:, guide_start:tracked_end, :] = comfy.ldm.modules.attention.optimized_attention(
         q[:, guide_start:tracked_end, :], k, v, heads, mask=guide_mask.tracked_mask,
         attn_precision=attn_precision, transformer_options=transformer_options,
         low_precision_attention=False,
     )
+    if tracked_end < q.shape[1]:  # Currently every guide token is tracked and nothing follows them; guard in case this changes.
+        out[:, tracked_end:, :] = comfy.ldm.modules.attention.optimized_attention(
+            q[:, tracked_end:, :], k, v, heads,
+            attn_precision=attn_precision, transformer_options=transformer_options,
+        )
     return out
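For reference, a minimal sketch of the same Q-partitioning scheme, with plain PyTorch scaled_dot_product_attention standing in for comfy.ldm.modules.attention.optimized_attention; the shapes, helper names, and mask semantics are assumptions for illustration only:

    # Sketch only: plain PyTorch SDPA in place of optimized_attention.
    # Assumed shapes: q, k, v are (batch, tokens, dim); masks are additive
    # float biases shaped (1, 1, q_tokens_in_block, kv_tokens).
    import torch
    import torch.nn.functional as F

    def partitioned_attention(q, k, v, noisy_mask, tracked_mask, guide_start, tracked_end):
        def sdpa(qs, mask=None):
            # Add a singleton heads dim so the (1, 1, Lq, Lkv) masks broadcast.
            return F.scaled_dot_product_attention(
                qs.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1), attn_mask=mask
            ).squeeze(1)

        out = torch.empty_like(q)
        if guide_start > 0:  # noisy tokens, when any precede the guides
            out[:, :guide_start] = sdpa(q[:, :guide_start], noisy_mask)
        out[:, guide_start:tracked_end] = sdpa(q[:, guide_start:tracked_end], tracked_mask)
        if tracked_end < q.shape[1]:  # trailing tokens attend unmasked
            out[:, tracked_end:] = sdpa(q[:, tracked_end:])
        return out

Splitting Q into blocks this way presumably avoids materializing a single full (tokens x tokens) mask and lets the unmasked tail run on the fastest available kernel, at the cost of the guards added in this commit.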