qwen35: support flash linear attention.
commit 7f71575235 (parent 8cbbea8f6a)
@@ -12,6 +12,12 @@ import comfy.text_encoders.qwen_vl
 from .llama import BaseLlama, BaseGenerate, Llama2_, MLP, RMSNorm, apply_rope
 
+FLA_IS_AVAILABLE = False
+try:
+    from fla.ops.gated_delta_rule import chunk_gated_delta_rule as fla_chunk_gated_delta_rule
+    FLA_IS_AVAILABLE = True
+except ImportError:
+    pass
 
 def _qwen35_layer_types(n):
     return [("full_attention" if (i + 1) % 4 == 0 else "linear_attention") for i in range(n)]
 
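For orientation, a quick standalone illustration of the pattern the new _qwen35_layer_types helper produces; the helper body is copied from the hunk above, and the call with n=8 is only an example:

def _qwen35_layer_types(n):
    # Every 4th layer uses full attention; all other layers use linear (gated delta) attention.
    return [("full_attention" if (i + 1) % 4 == 0 else "linear_attention") for i in range(n)]

print(_qwen35_layer_types(8))
# ['linear_attention', 'linear_attention', 'linear_attention', 'full_attention',
#  'linear_attention', 'linear_attention', 'linear_attention', 'full_attention']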
@@ -253,11 +259,19 @@ class GatedDeltaNet(nn.Module):
         query = query.repeat_interleave(rep, dim=2)
         key = key.repeat_interleave(rep, dim=2)
 
-        core_attn_out, last_recurrent_state = torch_chunk_gated_delta_rule(
-            query, key, value, g=g, beta=beta,
-            initial_state=None,
-            output_final_state=past_key_value is not None,
-        )
+        if FLA_IS_AVAILABLE:
+            core_attn_out, last_recurrent_state = fla_chunk_gated_delta_rule(
+                query, key, value, g=g, beta=beta,
+                initial_state=None,
+                output_final_state=past_key_value is not None,
+                use_qk_l2norm_in_kernel=True,
+            )
+        else:
+            core_attn_out, last_recurrent_state = torch_chunk_gated_delta_rule(
+                query, key, value, g=g, beta=beta,
+                initial_state=None,
+                output_final_state=past_key_value is not None,
+            )
 
         present_key_value = None
         if past_key_value is not None:
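A note on the extra keyword in the fla path: use_qk_l2norm_in_kernel=True asks the fused kernel to L2-normalize query and key internally, which the pure-PyTorch fallback is assumed to do explicitly before the recurrence. A minimal, torch-only sketch of that normalization step (tensor shapes are hypothetical and only for illustration):

import torch
import torch.nn.functional as F

# Hypothetical (batch, seq_len, heads, head_dim) projections.
query = torch.randn(1, 64, 4, 128)
key = torch.randn(1, 64, 4, 128)

# Out-of-kernel equivalent of use_qk_l2norm_in_kernel=True:
# L2-normalize query and key along the head dimension before the gated delta rule.
query = F.normalize(query, p=2, dim=-1)
key = F.normalize(key, p=2, dim=-1)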