Merge branch 'comfyanonymous:master' into master

This commit is contained in:
patientx 2025-12-03 01:25:57 +03:00 committed by GitHub
commit 3d3a049c06
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 6 additions and 0 deletions

View File

@ -517,6 +517,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
@wrap_attn
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
exception_fallback = False
if skip_reshape:
b, _, _, dim_head = q.shape
tensor_layout = "HND"
@ -541,6 +542,8 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
except Exception as e:
logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
exception_fallback = True
if exception_fallback:
if tensor_layout == "NHD":
q, k, v = map(
lambda t: t.transpose(1, 2),

View File

@ -279,6 +279,7 @@ def pytorch_attention(q, k, v):
orig_shape = q.shape
B = orig_shape[0]
C = orig_shape[1]
oom_fallback = False
q, k, v = map(
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
(q, k, v),
@ -289,6 +290,8 @@ def pytorch_attention(q, k, v):
out = out.transpose(2, 3).reshape(orig_shape)
except model_management.OOM_EXCEPTION:
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
oom_fallback = True
if oom_fallback:
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
return out