qwen35: support flash linear attention.

omarom 2026-04-05 14:15:44 +00:00
parent 8cbbea8f6a
commit 7f71575235


@@ -12,6 +12,12 @@ import comfy.text_encoders.qwen_vl
 from .llama import BaseLlama, BaseGenerate, Llama2_, MLP, RMSNorm, apply_rope
 
+FLA_IS_AVAILABLE = False
+try:
+    from fla.ops.gated_delta_rule import chunk_gated_delta_rule as fla_chunk_gated_delta_rule
+    FLA_IS_AVAILABLE = True
+except ImportError:
+    pass
 
 def _qwen35_layer_types(n):
     return [("full_attention" if (i + 1) % 4 == 0 else "linear_attention") for i in range(n)]
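For illustration only (not part of the diff): `_qwen35_layer_types` marks every fourth layer as full attention and the rest as linear attention, so `_qwen35_layer_types(8)` returns:

    ['linear_attention', 'linear_attention', 'linear_attention', 'full_attention',
     'linear_attention', 'linear_attention', 'linear_attention', 'full_attention']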
@@ -253,11 +259,19 @@ class GatedDeltaNet(nn.Module):
         query = query.repeat_interleave(rep, dim=2)
         key = key.repeat_interleave(rep, dim=2)
 
-        core_attn_out, last_recurrent_state = torch_chunk_gated_delta_rule(
-            query, key, value, g=g, beta=beta,
-            initial_state=None,
-            output_final_state=past_key_value is not None,
-        )
+        if FLA_IS_AVAILABLE:
+            core_attn_out, last_recurrent_state = fla_chunk_gated_delta_rule(
+                query, key, value, g=g, beta=beta,
+                initial_state=None,
+                output_final_state=past_key_value is not None,
+                use_qk_l2norm_in_kernel=True,
+            )
+        else:
+            core_attn_out, last_recurrent_state = torch_chunk_gated_delta_rule(
+                query, key, value, g=g, beta=beta,
+                initial_state=None,
+                output_final_state=past_key_value is not None,
+            )
 
         present_key_value = None
         if past_key_value is not None:
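Out of diff context, the change is easier to read as a standalone dispatch. A minimal sketch follows; the wrapper name is hypothetical, while `FLA_IS_AVAILABLE`, `fla_chunk_gated_delta_rule`, and `torch_chunk_gated_delta_rule` are the names from the diff, the last being the file's existing pure-PyTorch path:

    # Hypothetical wrapper illustrating the dispatch added by this commit.
    def _dispatch_chunk_gated_delta_rule(query, key, value, g, beta, past_key_value=None):
        kwargs = dict(
            g=g,
            beta=beta,
            initial_state=None,
            # Only materialize the final recurrent state when a cache is in use.
            output_final_state=past_key_value is not None,
        )
        if FLA_IS_AVAILABLE:
            # Fused flash-linear-attention kernel; use_qk_l2norm_in_kernel=True
            # folds the query/key L2 normalization into the kernel itself.
            return fla_chunk_gated_delta_rule(
                query, key, value, use_qk_l2norm_in_kernel=True, **kwargs)
        # Pure-PyTorch fallback when the optional fla package is not installed.
        return torch_chunk_gated_delta_rule(query, key, value, **kwargs)

Both branches return `(core_attn_out, last_recurrent_state)`, so the cache-handling code below the call site is unchanged.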