diff --git a/comfy/text_encoders/qwen35.py b/comfy/text_encoders/qwen35.py
index ce9b07464..2ab4990fc 100644
--- a/comfy/text_encoders/qwen35.py
+++ b/comfy/text_encoders/qwen35.py
@@ -12,6 +12,12 @@ import comfy.text_encoders.qwen_vl
 from .llama import BaseLlama, BaseGenerate, Llama2_, MLP, RMSNorm, apply_rope
 
+FLA_IS_AVAILABLE = False
+try:
+    from fla.ops.gated_delta_rule import chunk_gated_delta_rule as fla_chunk_gated_delta_rule
+    FLA_IS_AVAILABLE = True
+except ImportError:
+    pass
 
 def _qwen35_layer_types(n):
     return [("full_attention" if (i + 1) % 4 == 0 else "linear_attention") for i in range(n)]
 
@@ -253,11 +259,19 @@ class GatedDeltaNet(nn.Module):
         query = query.repeat_interleave(rep, dim=2)
         key = key.repeat_interleave(rep, dim=2)
 
-        core_attn_out, last_recurrent_state = torch_chunk_gated_delta_rule(
-            query, key, value, g=g, beta=beta,
-            initial_state=None,
-            output_final_state=past_key_value is not None,
-        )
+        if FLA_IS_AVAILABLE:
+            core_attn_out, last_recurrent_state = fla_chunk_gated_delta_rule(
+                query, key, value, g=g, beta=beta,
+                initial_state=None,
+                output_final_state=past_key_value is not None,
+                use_qk_l2norm_in_kernel=True,
+            )
+        else:
+            core_attn_out, last_recurrent_state = torch_chunk_gated_delta_rule(
+                query, key, value, g=g, beta=beta,
+                initial_state=None,
+                output_final_state=past_key_value is not None,
+            )
 
         present_key_value = None
         if past_key_value is not None:
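
The patch hinges on the guarded import at the top: the Triton-backed FLA kernel is used only when the optional flash-linear-attention package imports cleanly, and the existing pure-PyTorch torch_chunk_gated_delta_rule remains the fallback otherwise. Below is a minimal standalone sketch of that dispatch pattern; the wrapper name gated_delta_rule_dispatch and the fallback stub are illustrative, not part of the patch (the real torch fallback already lives in qwen35.py).

import torch

FLA_IS_AVAILABLE = False
try:
    # Optional dependency: flash-linear-attention ships Triton GPU kernels,
    # so this import fails cleanly on machines without the package installed.
    from fla.ops.gated_delta_rule import chunk_gated_delta_rule as fla_chunk_gated_delta_rule
    FLA_IS_AVAILABLE = True
except ImportError:
    pass

def torch_chunk_gated_delta_rule(query, key, value, g, beta,
                                 initial_state=None, output_final_state=False):
    # Illustrative stand-in for ComfyUI's pure-PyTorch reference
    # implementation; the real function is defined elsewhere in qwen35.py.
    raise NotImplementedError("placeholder for the torch fallback")

def gated_delta_rule_dispatch(query, key, value, g, beta, output_final_state=False):
    """Pick the fused FLA kernel when available, else the torch fallback."""
    if FLA_IS_AVAILABLE:
        # Same kwargs as the patch; use_qk_l2norm_in_kernel=True asks the
        # FLA kernel to L2-normalize q/k inside the kernel itself.
        return fla_chunk_gated_delta_rule(
            query, key, value, g=g, beta=beta,
            initial_state=None,
            output_final_state=output_final_state,
            use_qk_l2norm_in_kernel=True,
        )
    return torch_chunk_gated_delta_rule(
        query, key, value, g=g, beta=beta,
        initial_state=None,
        output_final_state=output_final_state,
    )

One design note: because the availability check happens once at import time rather than per call, a missing package costs nothing on the hot path, and the two branches take identical tensor arguments, so callers such as GatedDeltaNet need no knowledge of which backend is active.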