mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-10 01:02:56 +08:00
Add some safety guards
This commit is contained in:
parent
1e76c3b9c9
commit
2d07a1004a
@ -383,14 +383,6 @@ class GuideAttentionMask:
|
||||
self.tracked_mask = torch.zeros((1, 1, tracked_count, total_tokens), device=device, dtype=dtype)
|
||||
self.tracked_mask[:, :, :, :guide_start] = log_w.view(1, 1, -1, 1)
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
new = GuideAttentionMask.__new__(GuideAttentionMask)
|
||||
new.guide_start = self.guide_start
|
||||
new.tracked_count = self.tracked_count
|
||||
new.noisy_mask = self.noisy_mask.to(*args, **kwargs)
|
||||
new.tracked_mask = self.tracked_mask.to(*args, **kwargs)
|
||||
return new
|
||||
|
||||
|
||||
def _attention_with_guide_mask(q, k, v, heads, guide_mask, attn_precision, transformer_options):
|
||||
"""Apply the guide mask by partitioning Q into noisy and tracked-guide
|
||||
@ -402,16 +394,22 @@ def _attention_with_guide_mask(q, k, v, heads, guide_mask, attn_precision, trans
|
||||
|
||||
out = torch.empty_like(q)
|
||||
|
||||
out[:, :guide_start, :] = comfy.ldm.modules.attention.optimized_attention(
|
||||
q[:, :guide_start, :], k, v, heads, mask=guide_mask.noisy_mask,
|
||||
attn_precision=attn_precision, transformer_options=transformer_options,
|
||||
low_precision_attention=False, # sageattn mask support is unreliable
|
||||
)
|
||||
if guide_start > 0: # In practice currently guides are always after noise, guard for safety if this changes.
|
||||
out[:, :guide_start, :] = comfy.ldm.modules.attention.optimized_attention(
|
||||
q[:, :guide_start, :], k, v, heads, mask=guide_mask.noisy_mask,
|
||||
attn_precision=attn_precision, transformer_options=transformer_options,
|
||||
low_precision_attention=False, # sageattn mask support is unreliable
|
||||
)
|
||||
out[:, guide_start:tracked_end, :] = comfy.ldm.modules.attention.optimized_attention(
|
||||
q[:, guide_start:tracked_end, :], k, v, heads, mask=guide_mask.tracked_mask,
|
||||
attn_precision=attn_precision, transformer_options=transformer_options,
|
||||
low_precision_attention=False,
|
||||
)
|
||||
if tracked_end < q.shape[1]: # Every guide token is tracked, and nothing comes after them, guard for safety if this changes.
|
||||
out[:, tracked_end:, :] = comfy.ldm.modules.attention.optimized_attention(
|
||||
q[:, tracked_end:, :], k, v, heads,
|
||||
attn_precision=attn_precision, transformer_options=transformer_options,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user