Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-21 20:00:17 +08:00)
Update attention.py to keep older torch versions running
commit bcea9b9a0c
parent eaf40b802d
@@ -503,16 +503,23 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
     return out


-@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
-def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
-                       dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
-    return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
+try:
+    @torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
+    def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                           dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
+        return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)


-@flash_attn_wrapper.register_fake
-def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
-    # Output shape is the same as q
-    return q.new_empty(q.shape)
+    @flash_attn_wrapper.register_fake
+    def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
+        # Output shape is the same as q
+        return q.new_empty(q.shape)
+except AttributeError as error:
+    FLASH_ATTN_ERROR = error
+
+    def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                           dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
+        assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"


 def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
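For context, here is a minimal standalone sketch of the same guard pattern. The op name example::scaled_add, the function scaled_add, and SCALED_ADD_ERROR are hypothetical and not part of ComfyUI or flash-attn: on a torch build that provides torch.library.custom_op, the op and its fake (meta) kernel register normally; on an older torch the attribute lookup raises AttributeError inside the try block, and a stub that records the original failure is installed instead.

    import torch

    try:
        # Newer torch: register a real custom op plus a fake/meta kernel
        # (the fake kernel lets tracing and torch.compile infer output shapes).
        @torch.library.custom_op("example::scaled_add", mutates_args=())
        def scaled_add(a: torch.Tensor, b: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
            return a + alpha * b

        @scaled_add.register_fake
        def scaled_add_fake(a, b, alpha=1.0):
            # Only shape/dtype/device matter here; assumes a and b share a shape.
            return a.new_empty(a.shape)
    except AttributeError as error:  # older torch: torch.library.custom_op is missing
        SCALED_ADD_ERROR = error

        def scaled_add(a: torch.Tensor, b: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
            raise RuntimeError(f"Could not define scaled_add: {SCALED_ADD_ERROR}")

As in the commit's except branch, the stub only fails at call time, so merely importing the module on an older torch keeps working.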