diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py
index 981b85da2..64b743fb6 100644
--- a/comfy/ldm/modules/sub_quadratic_attention.py
+++ b/comfy/ldm/modules/sub_quadratic_attention.py
@@ -176,10 +176,10 @@ def _get_attention_scores_no_kv_chunking(
         attn_scores /= summed
         attn_probs = attn_scores
-    if attention_to_mux is not None:
+    if attention_to_mux is not None and attention_weight > 0.0:
         attention_to_mux = attention_to_mux.to(attn_probs.device)
         attn_probs = attn_probs * (1 - attention_weight) + attention_to_mux * attention_weight
-        print(f"muxed attention with weight {attention_weight}")
+        # print(f"muxed attention with weight {attention_weight}")
     hidden_states_slice = torch.bmm(attn_probs, value)
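
For context, a minimal standalone sketch of the muxing step this patch guards; the helper name mux_attention_probs and the toy tensors are illustrative assumptions, not part of the patch. With the added attention_weight > 0.0 check, a zero weight skips the branch entirely and the output equals the unmodified attention probabilities.

import torch

def mux_attention_probs(attn_probs, attention_to_mux=None, attention_weight=0.0):
    # Convex combination of the freshly computed attention probabilities with an
    # injected map, mirroring the patched branch: a zero weight is a true no-op.
    if attention_to_mux is not None and attention_weight > 0.0:
        attention_to_mux = attention_to_mux.to(attn_probs.device)
        attn_probs = attn_probs * (1 - attention_weight) + attention_to_mux * attention_weight
    return attn_probs

# Toy usage with (batch*heads, q_len, k_len) shaped probability tensors.
a = torch.softmax(torch.randn(2, 4, 4), dim=-1)
b = torch.softmax(torch.randn(2, 4, 4), dim=-1)
blended = mux_attention_probs(a, b, attention_weight=0.3)    # 70% a, 30% b
unchanged = mux_attention_probs(a, b, attention_weight=0.0)  # identical to a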