From 2d07a1004a6aeca2b5b0ec9e642e080f2f789e1b Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sat, 9 May 2026 14:57:23 +0300
Subject: [PATCH] Add some safety guards

---
 comfy/ldm/lightricks/model.py | 24 +++++++++++-------------
 1 file changed, 11 insertions(+), 13 deletions(-)

diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py
index 80a3f08d7..e0a4a0f9b 100644
--- a/comfy/ldm/lightricks/model.py
+++ b/comfy/ldm/lightricks/model.py
@@ -383,14 +383,6 @@ class GuideAttentionMask:
         self.tracked_mask = torch.zeros((1, 1, tracked_count, total_tokens), device=device, dtype=dtype)
         self.tracked_mask[:, :, :, :guide_start] = log_w.view(1, 1, -1, 1)
 
-    def to(self, *args, **kwargs):
-        new = GuideAttentionMask.__new__(GuideAttentionMask)
-        new.guide_start = self.guide_start
-        new.tracked_count = self.tracked_count
-        new.noisy_mask = self.noisy_mask.to(*args, **kwargs)
-        new.tracked_mask = self.tracked_mask.to(*args, **kwargs)
-        return new
-
 
 def _attention_with_guide_mask(q, k, v, heads, guide_mask, attn_precision, transformer_options):
     """Apply the guide mask by partitioning Q into noisy and tracked-guide
@@ -402,16 +394,22 @@ def _attention_with_guide_mask(q, k, v, heads, guide_mask, attn_precision, trans
 
     out = torch.empty_like(q)
 
-    out[:, :guide_start, :] = comfy.ldm.modules.attention.optimized_attention(
-        q[:, :guide_start, :], k, v, heads, mask=guide_mask.noisy_mask,
-        attn_precision=attn_precision, transformer_options=transformer_options,
-        low_precision_attention=False,  # sageattn mask support is unreliable
-    )
+    if guide_start > 0:  # Currently guides always come after the noise tokens; guard in case this changes.
+        out[:, :guide_start, :] = comfy.ldm.modules.attention.optimized_attention(
+            q[:, :guide_start, :], k, v, heads, mask=guide_mask.noisy_mask,
+            attn_precision=attn_precision, transformer_options=transformer_options,
+            low_precision_attention=False,  # sageattn mask support is unreliable
+        )
     out[:, guide_start:tracked_end, :] = comfy.ldm.modules.attention.optimized_attention(
         q[:, guide_start:tracked_end, :], k, v, heads, mask=guide_mask.tracked_mask,
         attn_precision=attn_precision, transformer_options=transformer_options,
         low_precision_attention=False,
     )
+    if tracked_end < q.shape[1]:  # Currently every guide token is tracked and nothing follows them; guard in case this changes.
+        out[:, tracked_end:, :] = comfy.ldm.modules.attention.optimized_attention(
+            q[:, tracked_end:, :], k, v, heads,
+            attn_precision=attn_precision, transformer_options=transformer_options,
+        )
 
     return out
 
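
Reviewer note (not part of the patch): the two guards close distinct edge cases. If guide_start were ever 0, the old code would invoke the attention kernel on an empty query slice; and because out starts life as torch.empty_like(q), any query rows past tracked_end would otherwise be returned as uninitialized memory. Below is a minimal standalone sketch of the same partitioning, with torch.nn.functional.scaled_dot_product_attention standing in for comfy.ldm.modules.attention.optimized_attention; the (B, H, L, E) layout, the zero masks, and the partitioned_attention name are illustrative assumptions, not ComfyUI's API.

import torch
import torch.nn.functional as F

def partitioned_attention(q, k, v, noisy_mask, tracked_mask, guide_start, tracked_count):
    tracked_end = guide_start + tracked_count
    out = torch.empty_like(q)  # uninitialized: every query row must be written below
    if guide_start > 0:  # skip the call entirely when no noisy tokens precede the guides
        out[..., :guide_start, :] = F.scaled_dot_product_attention(
            q[..., :guide_start, :], k, v, attn_mask=noisy_mask)
    out[..., guide_start:tracked_end, :] = F.scaled_dot_product_attention(
        q[..., guide_start:tracked_end, :], k, v, attn_mask=tracked_mask)
    if tracked_end < q.shape[-2]:  # without this, these rows would keep empty_like() garbage
        out[..., tracked_end:, :] = F.scaled_dot_product_attention(
            q[..., tracked_end:, :], k, v)
    return out

# Smoke test: 4 noisy tokens, 2 tracked guides, 2 trailing untracked tokens.
B, H, L, E = 1, 2, 8, 16
q, k, v = (torch.randn(B, H, L, E) for _ in range(3))
noisy_mask = torch.zeros(1, 1, 4, L)    # additive float mask, broadcast over batch/heads
tracked_mask = torch.zeros(1, 1, 2, L)
out = partitioned_attention(q, k, v, noisy_mask, tracked_mask, guide_start=4, tracked_count=2)
assert torch.isfinite(out).all()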