import os

import torch

from comfy.flash_attn_triton_amd.fwd_prefill import attention_prefill_forward_triton_impl
from comfy.flash_attn_triton_amd.bwd_prefill import attention_prefill_backward_triton_impl
from comfy.flash_attn_triton_amd.fwd_decode import attention_decode_forward_triton_impl
from comfy.flash_attn_triton_amd.fwd_ref import attention_forward_pytorch_ref_impl
from comfy.flash_attn_triton_amd.bwd_ref import attention_backward_pytorch_ref_impl
from comfy.flash_attn_triton_amd.utils import MetaData, get_shape_from_layout

USE_REF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_REF', '0').lower() in ('1', 'true', 'yes')
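# When FLASH_ATTENTION_TRITON_AMD_REF is truthy, the wrappers below dispatch to the
# pure-PyTorch reference implementations instead of the Triton kernels, which is
# useful for debugging numerical differences between the two paths.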


def fwd(q,
        k,
        v,
        o,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_size_left, window_size_right, softcap,  # pylint: disable=unused-argument
        return_softmax,
        gen_  # pylint: disable=unused-argument
):
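    """Forward pass for fixed-length batches in "bshd" (batch, seqlen, heads, head_dim) layout.

    Follows the flash-attn-style ``fwd`` signature; the dropout, sliding-window
    and softcap arguments are accepted for compatibility but not supported by
    this backend.

    A minimal call sketch (shapes, dtype and device are illustrative
    assumptions, not taken from this file)::

        q = k = v = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")
        out, lse, scores, _ = fwd(q, k, v, None, None, 0.0, 64 ** -0.5,
                                  True, -1, -1, 0.0, False, None)
    """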
    if dropout_p != 0.0:
        raise ValueError("dropout is not supported on AMD's Triton Backend yet")

    if o is None:
        o = torch.empty_like(q)

    # Setup metadata
    metadata = MetaData(sm_scale=softmax_scale)
    metadata.max_seqlens_q = q.shape[1]
    metadata.max_seqlens_k = k.shape[1]
    metadata.layout = "bshd"
    if return_softmax:
        metadata.return_scores = True

    batch, nheads_q, nheads_k, head_size, _, _ = get_shape_from_layout(q, k, metadata.layout)  # pylint: disable=unused-variable

    if causal:
        metadata.need_causal()

    if alibi_slopes is not None:
        metadata.need_alibi(alibi_slopes, batch, nheads_q)

    if dropout_p > 0.0:
        metadata.need_dropout(dropout_p, return_softmax)

    # Check arguments
    metadata.check_args(q, k, v, o)
    if USE_REF:
        (output,
         softmax_lse,
         exp_scores,
         _,
         _,
         _,
         _) = attention_forward_pytorch_ref_impl(
            q,
            k,
            v,
            metadata.sm_scale,
            metadata.causal,
            metadata.layout,
            metadata.cu_seqlens_q,
            metadata.cu_seqlens_k,
            metadata.max_seqlens_q,
            metadata.max_seqlens_k,
            metadata.use_exp2)
        o.copy_(output)
    else:
        (_,
         softmax_lse,
         exp_scores,
         _,
         _,
         _,
         _,
         _,
         _) = attention_prefill_forward_triton_impl(
            q,
            k,
            v,
            o,
            metadata.sm_scale,
            metadata.alibi_slopes,
            metadata.causal,
            metadata.bias,
            metadata.dropout_p,
            metadata.layout,
            metadata.cu_seqlens_q,
            metadata.cu_seqlens_k,
            metadata.max_seqlens_q,
            metadata.max_seqlens_k,
            metadata.return_scores,
            metadata.use_exp2)

    return o, softmax_lse, exp_scores, None


def bwd(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    alibi_slopes,
    dropout_p,
    softmax_scale,
    causal,
    window_size_left, window_size_right, softcap, deterministic, gen_, rng_state,  # pylint: disable=unused-argument
):
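    """Backward pass for fixed-length batches in "bshd" layout.

    ``dq``, ``dk`` and ``dv`` are preallocated gradient tensors filled in place;
    they are returned together with ``delta``, an auxiliary per-row term produced
    by the backward implementations.

    A minimal call sketch (preallocation shown; the other values are assumed to
    come from a matching ``fwd`` call)::

        dq, dk, dv = (torch.empty_like(t) for t in (q, k, v))
        dq, dk, dv, delta = bwd(dout, q, k, v, out, lse, dq, dk, dv, None, 0.0,
                                64 ** -0.5, True, -1, -1, 0.0, False, None, None)
    """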
    if dropout_p != 0.0:
        raise ValueError("dropout is not supported on AMD yet")

    if USE_REF:
        dq_ref, dk_ref, dv_ref, delta_ref = attention_backward_pytorch_ref_impl(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            softmax_scale,
            causal,
            "bshd",
            None,
            None,
            None,
            None,
            False,
        )
        dq.copy_(dq_ref)
        dk.copy_(dk_ref)
        dv.copy_(dv_ref)
        delta = delta_ref
    else:
        dq_triton, dk_triton, dv_triton, delta_triton, _, _ = attention_prefill_backward_triton_impl(  # pylint: disable=unused-variable
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            softmax_scale,
            alibi_slopes,
            causal,
            "bshd",
            None,
            None,
            None,
            None,
            False,
        )
        delta = delta_triton

    return dq, dk, dv, delta


def varlen_fwd(
    q,
    k,
    v,
    o,
    cu_seqlens_q,
    cu_seqlens_k,
    seqused_k, leftpad_k, block_table_,  # pylint: disable=unused-argument
    alibi_slopes,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    zero_tensors,  # pylint: disable=unused-argument
    causal,
    window_size_left, window_size_right, softcap,  # pylint: disable=unused-argument
    return_softmax,
    gen_  # pylint: disable=unused-argument
):
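    """Forward pass for variable-length (packed) batches in "thd" layout.

    Sequences are concatenated along the token dimension and delimited by the
    cumulative-length tensors ``cu_seqlens_q`` / ``cu_seqlens_k``. As an
    illustrative assumption (not from this file), two sequences of lengths 3
    and 5 packed into a ``(8, nheads, head_dim)`` tensor would use::

        cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")

    The dropout, sliding-window and softcap arguments are accepted for
    compatibility but not supported by this backend.
    """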
    if dropout_p != 0.0:
        raise ValueError("dropout is not supported on AMD's Triton Backend yet")

    if o is None:
        o = torch.empty_like(q)

    # Setup metadata
    metadata = MetaData(sm_scale=softmax_scale)
    if return_softmax:
        metadata.return_scores = True
    metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k)  # set layout to "thd" and other metadata

    # get shapes
    batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k = get_shape_from_layout(q, k, metadata.layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)  # pylint: disable=unused-variable

    if causal:
        metadata.need_causal()

    if alibi_slopes is not None:
        metadata.need_alibi(alibi_slopes, batch, nheads_q)

    if dropout_p > 0.0:
        metadata.need_dropout(dropout_p, return_softmax)

    # Check arguments
    metadata.check_args(q, k, v, o)

    if USE_REF:
        (output,
         softmax_lse,
         exp_scores,
         _,
         _,
         _,
         _) = attention_forward_pytorch_ref_impl(
            q,
            k,
            v,
            metadata.sm_scale,
            metadata.causal,
            metadata.layout,
            metadata.cu_seqlens_q,
            metadata.cu_seqlens_k,
            metadata.max_seqlens_q,
            metadata.max_seqlens_k,
            metadata.use_exp2)
        o.copy_(output)
    else:
        (_,
         softmax_lse,
         exp_scores,
         _,
         _,
         _,
         _,
         _,
         _) = attention_prefill_forward_triton_impl(
            q,
            k,
            v,
            o,
            metadata.sm_scale,
            metadata.alibi_slopes,
            metadata.causal,
            metadata.bias,
            metadata.dropout_p,
            metadata.layout,
            metadata.cu_seqlens_q,
            metadata.cu_seqlens_k,
            metadata.max_seqlens_q,
            metadata.max_seqlens_k,
            metadata.return_scores,
            metadata.use_exp2)

    return o, softmax_lse, exp_scores, None


def varlen_bwd(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    alibi_slopes,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    zero_tensors,  # pylint: disable=unused-argument
    causal,
    window_size_left, window_size_right, softcap, deterministic, gen_, rng_state,  # pylint: disable=unused-argument
):
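    """Backward pass for variable-length (packed) batches in "thd" layout.

    The counterpart of ``varlen_fwd``: ``dq``, ``dk`` and ``dv`` are preallocated
    and filled in place, and the same ``cu_seqlens_q`` / ``cu_seqlens_k`` offsets
    used in the forward pass must be passed here along with ``max_seqlen_q`` /
    ``max_seqlen_k``.
    """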
    if dropout_p != 0.0:
        raise ValueError("dropout is not supported on AMD yet")

    if USE_REF:
        dq_ref, dk_ref, dv_ref, delta_ref = attention_backward_pytorch_ref_impl(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            softmax_scale,
            causal,
            "thd",
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            False,
        )
        dq.copy_(dq_ref)
        dk.copy_(dk_ref)
        dv.copy_(dv_ref)
        delta = delta_ref
    else:
        dq_triton, dk_triton, dv_triton, delta_triton, _, _ = attention_prefill_backward_triton_impl(  # pylint: disable=unused-variable
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            softmax_scale,
            alibi_slopes,
            causal,
            "thd",
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            False,
        )
        delta = delta_triton

    return dq, dk, dv, delta


def fwd_kvcache(
    q,
    k_cache,
    v_cache,
    k,
    v,
    cache_seqlens,
    rotary_cos, rotary_sin,  # pylint: disable=unused-argument
    cache_batch_idx,
    cache_leftpad, block_table,  # pylint: disable=unused-argument
    alibi_slopes,
    out,
    softmax_scale,
    causal,
    window_size_left, window_size_right, softcap, rotary_interleaved, num_splits,  # pylint: disable=unused-argument
):
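    """Single-token (decode) forward pass against a pre-filled KV cache.

    When new ``k`` / ``v`` tensors are given they are handed to the decode kernel
    as ``metadata.k_new`` / ``metadata.v_new``. The rotary-embedding, paged-cache
    (``block_table``) and ``num_splits`` arguments are accepted for compatibility
    but not handled by this backend.

    A minimal decode sketch (shapes, dtype and cache length are illustrative
    assumptions)::

        q = torch.randn(1, 1, 8, 64, dtype=torch.float16, device="cuda")
        k_cache = torch.zeros(1, 2048, 8, 64, dtype=torch.float16, device="cuda")
        v_cache = torch.zeros_like(k_cache)
        cache_seqlens = torch.tensor([17], dtype=torch.int32, device="cuda")
        out, lse = fwd_kvcache(q, k_cache, v_cache, None, None, cache_seqlens,
                               None, None, None, None, None, None, None,
                               64 ** -0.5, True, -1, -1, 0.0, False, 0)
    """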
    if out is None:
        out = torch.empty_like(q)

    # fill metadata
    metadata = MetaData(sm_scale=softmax_scale)
    metadata.layout = "bshd"
    metadata.max_seqlens_q = q.shape[1]
    metadata.max_seqlens_k = k_cache.shape[1]
    metadata.cache_seqlens = cache_seqlens
    metadata.cache_batch_idx = cache_batch_idx

    if k is not None and v is not None:
        metadata.new_kv = True
        metadata.seqlen_new = k.shape[1]
        metadata.k_new = k
        metadata.v_new = v

    if causal:
        metadata.need_causal()

    if alibi_slopes is not None:
        batch, _, nheads_q, _ = q.shape
        metadata.need_alibi(alibi_slopes, batch, nheads_q)

    # launch kernel
    # TODO: pass output as an arg. Maybe we are copying output which is causing slow down
    output, softmax_lse = attention_decode_forward_triton_impl(
        q,
        k_cache,
        v_cache,
        metadata.sm_scale,
        metadata.causal,
        metadata.alibi_slopes,
        metadata.layout,
        metadata.cache_seqlens,
        metadata.cache_batch_idx,
        metadata.new_kv,
        metadata.k_new,
        metadata.v_new,
    )
    return output, softmax_lse