attention: use flag based OOM fallback (#11038)
Exceptions keep references to all local variables for the lifetime of the exception context. Instead of running the fallback inside the except block, just set a flag and branch on it afterwards, so the exception is dropped before falling back.
This commit is contained in:
parent
daaceac769
commit
277237ccc1
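The pattern is easiest to see in isolation. Below is a minimal, self-contained sketch of the flag-based fallback this commit adopts; run_fast and run_slow are hypothetical stand-ins for the real kernels (sageattn, scaled_dot_product_attention) and their fallbacks, not code from the repository.

import logging

def run_fast(x):
    raise MemoryError("pretend the fast kernel OOMed")

def run_slow(x):
    return [v * 2 for v in x]

def compute_with_fallback(x):
    fallback = False
    try:
        out = run_fast(x)
    except Exception as e:
        # While this block runs, e.__traceback__ pins every frame of the failed
        # call, and with it any large locals those frames still hold.
        # Do as little as possible here: log and set a flag.
        logging.error("fast path failed: %s, using fallback", e)
        fallback = True

    # By this point the exception and its traceback have been released, so the
    # memory held by the failed attempt can be reclaimed before retrying.
    if fallback:
        out = run_slow(x)
    return out

print(compute_with_fallback([1, 2, 3]))  # -> [2, 4, 6]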
@@ -517,6 +517,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
 
 @wrap_attn
 def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    exception_fallback = False
     if skip_reshape:
         b, _, _, dim_head = q.shape
         tensor_layout = "HND"
@@ -541,6 +542,8 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
         out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
     except Exception as e:
         logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
+        exception_fallback = True
+    if exception_fallback:
         if tensor_layout == "NHD":
             q, k, v = map(
                 lambda t: t.transpose(1, 2),
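Why moving the fallback out of the except block matters can be shown with a weak reference. This is a sketch under the assumption of CPython reference counting; Big and failing_call are illustrative names, not part of the commit.

import gc
import weakref

class Big:
    """Stand-in for a large intermediate tensor."""

def failing_call():
    temp = Big()                          # large intermediate we want released on failure
    failing_call.ref = weakref.ref(temp)  # weak reference to observe liveness
    raise MemoryError("boom")

fallback = False
try:
    failing_call()
except MemoryError:
    # The traceback still references failing_call's frame here, so `temp`
    # cannot be collected yet -- a bad place to attempt a memory-hungry retry.
    assert failing_call.ref() is not None
    fallback = True

gc.collect()
# Outside the handler the exception and its traceback are gone, so the
# intermediate has been freed before the fallback runs.
assert failing_call.ref() is None
if fallback:
    print("fallback runs with the failed attempt's memory already released")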
@@ -279,6 +279,7 @@ def pytorch_attention(q, k, v):
     orig_shape = q.shape
     B = orig_shape[0]
     C = orig_shape[1]
+    oom_fallback = False
     q, k, v = map(
         lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
         (q, k, v),
@@ -289,6 +290,8 @@ def pytorch_attention(q, k, v):
         out = out.transpose(2, 3).reshape(orig_shape)
     except model_management.OOM_EXCEPTION:
         logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
+        oom_fallback = True
+    if oom_fallback:
         out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
     return out
 
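For pytorch_attention the same structure gives an OOM-specific fallback. The sketch below mirrors it under a few assumptions: torch.cuda.OutOfMemoryError stands in for model_management.OOM_EXCEPTION, and chunked_attention is a hypothetical stand-in for slice_attention.

import logging
import torch
import torch.nn.functional as F

def chunked_attention(q, k, v, chunks=4):
    # Process the queries in chunks so each SDPA call needs less memory at once.
    outs = [F.scaled_dot_product_attention(qc, k, v) for qc in q.chunk(chunks, dim=-2)]
    return torch.cat(outs, dim=-2)

def sdp_attention_with_oom_fallback(q, k, v):
    oom_fallback = False
    try:
        out = F.scaled_dot_product_attention(q, k, v)
    except torch.cuda.OutOfMemoryError:
        logging.warning("scaled_dot_product_attention OOMed: switching to chunked attention")
        oom_fallback = True  # don't retry here: the traceback still pins the failed attempt

    if oom_fallback:
        # The exception has been released, so the failed attempt's memory is
        # reclaimable before the chunked retry allocates again.
        out = chunked_attention(q, k, v)
    return out

# Usage on CPU-sized tensors (the OOM path only triggers on a real device under pressure).
q = k = v = torch.randn(1, 8, 64, 32)
print(sdp_attention_with_oom_fallback(q, k, v).shape)  # torch.Size([1, 8, 64, 32])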