mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-02 21:32:31 +08:00
qwen35: support flash linear attention.
This commit is contained in:
parent
8cbbea8f6a
commit
7f71575235
@ -12,6 +12,12 @@ import comfy.text_encoders.qwen_vl
|
|||||||
|
|
||||||
from .llama import BaseLlama, BaseGenerate, Llama2_, MLP, RMSNorm, apply_rope
|
from .llama import BaseLlama, BaseGenerate, Llama2_, MLP, RMSNorm, apply_rope
|
||||||
|
|
||||||
|
FLA_IS_AVAILABLE = False
|
||||||
|
try:
|
||||||
|
from fla.ops.gated_delta_rule import chunk_gated_delta_rule as fla_chunk_gated_delta_rule
|
||||||
|
FLA_IS_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
def _qwen35_layer_types(n):
|
def _qwen35_layer_types(n):
|
||||||
return [("full_attention" if (i + 1) % 4 == 0 else "linear_attention") for i in range(n)]
|
return [("full_attention" if (i + 1) % 4 == 0 else "linear_attention") for i in range(n)]
|
||||||
@ -253,11 +259,19 @@ class GatedDeltaNet(nn.Module):
|
|||||||
query = query.repeat_interleave(rep, dim=2)
|
query = query.repeat_interleave(rep, dim=2)
|
||||||
key = key.repeat_interleave(rep, dim=2)
|
key = key.repeat_interleave(rep, dim=2)
|
||||||
|
|
||||||
core_attn_out, last_recurrent_state = torch_chunk_gated_delta_rule(
|
if FLA_IS_AVAILABLE:
|
||||||
query, key, value, g=g, beta=beta,
|
core_attn_out, last_recurrent_state = fla_chunk_gated_delta_rule(
|
||||||
initial_state=None,
|
query, key, value, g=g, beta=beta,
|
||||||
output_final_state=past_key_value is not None,
|
initial_state=None,
|
||||||
)
|
output_final_state=past_key_value is not None,
|
||||||
|
use_qk_l2norm_in_kernel=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
core_attn_out, last_recurrent_state = torch_chunk_gated_delta_rule(
|
||||||
|
query, key, value, g=g, beta=beta,
|
||||||
|
initial_state=None,
|
||||||
|
output_final_state=past_key_value is not None,
|
||||||
|
)
|
||||||
|
|
||||||
present_key_value = None
|
present_key_value = None
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user