# ComfyUI/comfy/flash_attn_triton_amd/interface_fa.py
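"""FlashAttention-compatible interface implemented with Triton kernels for AMD GPUs.

Exposes the flash_attn extension entry points -- ``fwd``, ``bwd``, ``varlen_fwd``,
``varlen_bwd`` and ``fwd_kvcache`` -- backed by Triton prefill and decode
kernels, with pure-PyTorch reference implementations selectable through the
``FLASH_ATTENTION_TRITON_AMD_REF`` environment variable.
"""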
import os
import torch
from comfy.flash_attn_triton_amd.fwd_prefill import attention_prefill_forward_triton_impl
from comfy.flash_attn_triton_amd.bwd_prefill import attention_prefill_backward_triton_impl
from comfy.flash_attn_triton_amd.fwd_decode import attention_decode_forward_triton_impl
from comfy.flash_attn_triton_amd.fwd_ref import attention_forward_pytorch_ref_impl
from comfy.flash_attn_triton_amd.bwd_ref import attention_backward_pytorch_ref_impl
from comfy.flash_attn_triton_amd.utils import MetaData, get_shape_from_layout
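
# When FLASH_ATTENTION_TRITON_AMD_REF is set to 1/true/yes, the prefill
# forward and backward paths below use the pure-PyTorch reference
# implementations instead of the Triton kernels (fwd_kvcache always uses
# the Triton decode kernel).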
USE_REF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_REF', '0').lower() in ('1', 'true', 'yes')


def fwd(q,
k,
v,
o,
alibi_slopes,
dropout_p,
softmax_scale,
causal,
window_size_left, window_size_right, softcap, # pylint: disable=unused-argument
return_softmax,
gen_ # pylint: disable=unused-argument
):
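    """Forward attention for fixed-length batches in ``bshd`` layout.

    ``q``, ``k`` and ``v`` are expected as ``(batch, seqlen, nheads, head_dim)``
    tensors; ``o`` is allocated when None. Dropout is not supported and raises.
    Returns ``(o, softmax_lse, exp_scores, None)``, where the final ``None``
    fills the rng-state slot of the flash_attn interface.
    """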
if dropout_p != 0.0:
raise ValueError("dropout is not supported on AMD's Triton Backend yet")
if o is None:
o = torch.empty_like(q)
# Setup metadata
metadata = MetaData(sm_scale=softmax_scale)
metadata.max_seqlens_q = q.shape[1]
metadata.max_seqlens_k = k.shape[1]
metadata.layout = "bshd"
if return_softmax:
metadata.return_scores = True
batch, nheads_q, nheads_k, head_size, _, _ = get_shape_from_layout(q, k, metadata.layout) # pylint: disable=unused-variable
if causal:
metadata.need_causal()
if alibi_slopes is not None:
metadata.need_alibi(alibi_slopes, batch, nheads_q)
if dropout_p > 0.0:
metadata.need_dropout(dropout_p, return_softmax)
# Check arguments
metadata.check_args(q, k, v, o)
if USE_REF:
(output,
softmax_lse,
exp_scores,
_,
_,
_,
_) = attention_forward_pytorch_ref_impl(
q,
k,
v,
metadata.sm_scale,
metadata.causal,
metadata.layout,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
metadata.max_seqlens_q,
metadata.max_seqlens_k,
metadata.use_exp2)
o.copy_(output)
else:
(_,
softmax_lse,
exp_scores,
_,
_,
_,
_,
_,
_) = attention_prefill_forward_triton_impl(
q,
k,
v,
o,
metadata.sm_scale,
metadata.alibi_slopes,
metadata.causal,
metadata.bias,
metadata.dropout_p,
metadata.layout,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
metadata.max_seqlens_q,
metadata.max_seqlens_k,
metadata.return_scores,
metadata.use_exp2)
    return o, softmax_lse, exp_scores, None


def bwd(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
alibi_slopes,
dropout_p,
softmax_scale,
causal,
window_size_left, window_size_right, softcap, deterministic, gen_, rng_state, # pylint: disable=unused-argument
):
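    """Backward attention for fixed-length batches in ``bshd`` layout.

    Writes gradients into the preallocated ``dq``, ``dk`` and ``dv`` tensors
    and returns them along with ``delta``, the row-wise ``sum(dout * out)``
    correction term used by the FlashAttention backward pass. Dropout is not
    supported and raises.
    """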
if dropout_p != 0.0:
raise ValueError("dropout is not supported on AMD yet")
if USE_REF:
dq_ref, dk_ref, dv_ref, delta_ref = attention_backward_pytorch_ref_impl(
dout,
q,
k,
v,
out,
softmax_lse,
softmax_scale,
causal,
"bshd",
None,
None,
None,
None,
False,
)
dq.copy_(dq_ref)
dk.copy_(dk_ref)
dv.copy_(dv_ref)
delta = delta_ref
else:
dq_triton, dk_triton, dv_triton, delta_triton, _, _ = attention_prefill_backward_triton_impl( # pylint: disable=unused-variable
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
softmax_scale,
alibi_slopes,
causal,
"bshd",
None,
None,
None,
None,
False,
)
delta = delta_triton
    return dq, dk, dv, delta


def varlen_fwd(
q,
k,
v,
o,
cu_seqlens_q,
cu_seqlens_k,
seqused_k, leftpad_k, block_table_, # pylint: disable=unused-argument
    alibi_slopes,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
zero_tensors, # pylint: disable=unused-argument
causal,
window_size_left, window_size_right, softcap, # pylint: disable=unused-argument
return_softmax,
gen_ # pylint: disable=unused-argument
):
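    """Forward attention for variable-length (packed) batches.

    Sequences are packed into ``thd`` layout, i.e. ``(total_tokens, nheads,
    head_dim)``, and delimited by ``cu_seqlens_q`` / ``cu_seqlens_k``. ``o`` is
    allocated when None. Dropout is not supported and raises. Returns
    ``(o, softmax_lse, exp_scores, None)``.
    """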
if dropout_p != 0.0:
raise ValueError("dropout is not supported on AMD's Triton Backend yet")
if o is None:
o = torch.empty_like(q)
# Setup metadata
metadata = MetaData(sm_scale=softmax_scale)
if return_softmax:
metadata.return_scores = True
    metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k)  # set layout to "thd" and other metadata
# get shapes
    batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k = get_shape_from_layout(q, k, metadata.layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)  # pylint: disable=unused-variable
if causal:
metadata.need_causal()
if alibi_slopes is not None:
metadata.need_alibi(alibi_slopes, batch, nheads_q)
if dropout_p > 0.0:
metadata.need_dropout(dropout_p, return_softmax)
# Check arguments
metadata.check_args(q, k, v, o)
if USE_REF:
(output,
softmax_lse,
exp_scores,
_,
_,
_,
_) = attention_forward_pytorch_ref_impl(
q,
k,
v,
metadata.sm_scale,
metadata.causal,
metadata.layout,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
metadata.max_seqlens_q,
metadata.max_seqlens_k,
metadata.use_exp2)
o.copy_(output)
else:
(_,
softmax_lse,
exp_scores,
_,
_,
_,
_,
_,
_) = attention_prefill_forward_triton_impl(
q,
k,
v,
o,
metadata.sm_scale,
metadata.alibi_slopes,
metadata.causal,
metadata.bias,
metadata.dropout_p,
metadata.layout,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
metadata.max_seqlens_q,
metadata.max_seqlens_k,
metadata.return_scores,
metadata.use_exp2)
    return o, softmax_lse, exp_scores, None


def varlen_bwd(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
cu_seqlens_q,
cu_seqlens_k,
alibi_slopes,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
zero_tensors, # pylint: disable=unused-argument
causal,
window_size_left, window_size_right, softcap, deterministic, gen_, rng_state, # pylint: disable=unused-argument
):
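    """Backward attention for variable-length (packed) batches in ``thd`` layout.

    Writes gradients into the preallocated ``dq``, ``dk`` and ``dv`` tensors and
    returns them along with ``delta``. Dropout is not supported and raises.
    """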
if dropout_p != 0.0:
raise ValueError("dropout is not supported on AMD yet")
if USE_REF:
dq_ref, dk_ref, dv_ref, delta_ref = attention_backward_pytorch_ref_impl(
dout,
q,
k,
v,
out,
softmax_lse,
softmax_scale,
causal,
"thd",
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
False,
)
dq.copy_(dq_ref)
dk.copy_(dk_ref)
dv.copy_(dv_ref)
delta = delta_ref
else:
dq_triton, dk_triton, dv_triton, delta_triton, _, _ = attention_prefill_backward_triton_impl( # pylint: disable=unused-variable
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
softmax_scale,
alibi_slopes,
causal,
"thd",
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
False,
)
delta = delta_triton
    return dq, dk, dv, delta


def fwd_kvcache(
q,
k_cache,
v_cache,
k,
v,
cache_seqlens,
rotary_cos, rotary_sin, # pylint: disable=unused-argument
cache_batch_idx,
cache_leftpad, block_table, # pylint: disable=unused-argument
alibi_slopes,
out,
softmax_scale,
causal,
window_size_left, window_size_right, softcap, rotary_interleaved, num_splits, # pylint: disable=unused-argument
):
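    """Decode-time attention of ``q`` against a key/value cache.

    Operates in ``bshd`` layout. When new ``k`` / ``v`` tensors are supplied,
    the decode kernel appends them to ``k_cache`` / ``v_cache`` before
    attending. ``cache_seqlens`` gives the valid cache length per batch entry
    and ``cache_batch_idx`` optionally remaps batch entries to cache rows.
    Returns ``(output, softmax_lse)``.
    """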
if out is None:
out = torch.empty_like(q)
# fill metadata
metadata = MetaData(sm_scale=softmax_scale)
metadata.layout = "bshd"
metadata.max_seqlens_q = q.shape[1]
metadata.max_seqlens_k = k_cache.shape[1]
metadata.cache_seqlens = cache_seqlens
metadata.cache_batch_idx = cache_batch_idx
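    # New key/value tensors, when given, are appended to the cache by the
    # decode kernel before attending.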
if k is not None and v is not None:
metadata.new_kv = True
metadata.seqlen_new = k.shape[1]
metadata.k_new = k
metadata.v_new = v
if causal:
metadata.need_causal()
if alibi_slopes is not None:
        batch, _, nheads_q, _ = q.shape
metadata.need_alibi(alibi_slopes, batch, nheads_q)
# launch kernel
    # TODO: pass output as an arg; we may be copying the output, which causes a
    # slowdown. Note that a caller-provided `out` is currently never written.
output, softmax_lse = attention_decode_forward_triton_impl(
q,
k_cache,
v_cache,
metadata.sm_scale,
metadata.causal,
metadata.alibi_slopes,
metadata.layout,
metadata.cache_seqlens,
metadata.cache_batch_idx,
metadata.new_kv,
metadata.k_new,
metadata.v_new,
)
return output, softmax_lse
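

if __name__ == "__main__":
    # Minimal smoke-test sketch: assumes a ROCm GPU exposed to PyTorch as
    # "cuda" and a working Triton install; the shapes and the 1/sqrt(d)
    # softmax scale below are illustrative only.
    b, s, h, d = 2, 128, 4, 64
    q = torch.randn(b, s, h, d, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    out, lse, _, _ = fwd(q, k, v, None, None, 0.0, d ** -0.5, True,
                         -1, -1, 0.0, False, None)
    print(out.shape, lse.shape)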