mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-10 09:12:31 +08:00
Add some safety guards
This commit is contained in:
parent
1e76c3b9c9
commit
2d07a1004a
@ -383,14 +383,6 @@ class GuideAttentionMask:
|
|||||||
self.tracked_mask = torch.zeros((1, 1, tracked_count, total_tokens), device=device, dtype=dtype)
|
self.tracked_mask = torch.zeros((1, 1, tracked_count, total_tokens), device=device, dtype=dtype)
|
||||||
self.tracked_mask[:, :, :, :guide_start] = log_w.view(1, 1, -1, 1)
|
self.tracked_mask[:, :, :, :guide_start] = log_w.view(1, 1, -1, 1)
|
||||||
|
|
||||||
def to(self, *args, **kwargs):
|
|
||||||
new = GuideAttentionMask.__new__(GuideAttentionMask)
|
|
||||||
new.guide_start = self.guide_start
|
|
||||||
new.tracked_count = self.tracked_count
|
|
||||||
new.noisy_mask = self.noisy_mask.to(*args, **kwargs)
|
|
||||||
new.tracked_mask = self.tracked_mask.to(*args, **kwargs)
|
|
||||||
return new
|
|
||||||
|
|
||||||
|
|
||||||
def _attention_with_guide_mask(q, k, v, heads, guide_mask, attn_precision, transformer_options):
|
def _attention_with_guide_mask(q, k, v, heads, guide_mask, attn_precision, transformer_options):
|
||||||
"""Apply the guide mask by partitioning Q into noisy and tracked-guide
|
"""Apply the guide mask by partitioning Q into noisy and tracked-guide
|
||||||
@ -402,16 +394,22 @@ def _attention_with_guide_mask(q, k, v, heads, guide_mask, attn_precision, trans
|
|||||||
|
|
||||||
out = torch.empty_like(q)
|
out = torch.empty_like(q)
|
||||||
|
|
||||||
out[:, :guide_start, :] = comfy.ldm.modules.attention.optimized_attention(
|
if guide_start > 0: # In practice currently guides are always after noise, guard for safety if this changes.
|
||||||
q[:, :guide_start, :], k, v, heads, mask=guide_mask.noisy_mask,
|
out[:, :guide_start, :] = comfy.ldm.modules.attention.optimized_attention(
|
||||||
attn_precision=attn_precision, transformer_options=transformer_options,
|
q[:, :guide_start, :], k, v, heads, mask=guide_mask.noisy_mask,
|
||||||
low_precision_attention=False, # sageattn mask support is unreliable
|
attn_precision=attn_precision, transformer_options=transformer_options,
|
||||||
)
|
low_precision_attention=False, # sageattn mask support is unreliable
|
||||||
|
)
|
||||||
out[:, guide_start:tracked_end, :] = comfy.ldm.modules.attention.optimized_attention(
|
out[:, guide_start:tracked_end, :] = comfy.ldm.modules.attention.optimized_attention(
|
||||||
q[:, guide_start:tracked_end, :], k, v, heads, mask=guide_mask.tracked_mask,
|
q[:, guide_start:tracked_end, :], k, v, heads, mask=guide_mask.tracked_mask,
|
||||||
attn_precision=attn_precision, transformer_options=transformer_options,
|
attn_precision=attn_precision, transformer_options=transformer_options,
|
||||||
low_precision_attention=False,
|
low_precision_attention=False,
|
||||||
)
|
)
|
||||||
|
if tracked_end < q.shape[1]: # Every guide token is tracked, and nothing comes after them, guard for safety if this changes.
|
||||||
|
out[:, tracked_end:, :] = comfy.ldm.modules.attention.optimized_attention(
|
||||||
|
q[:, tracked_end:, :], k, v, heads,
|
||||||
|
attn_precision=attn_precision, transformer_options=transformer_options,
|
||||||
|
)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user