Fixing a bug where the attentions were mixed together even when the weight was 0.0

InconsolableCellist 2023-04-30 17:31:55 -06:00
parent 2ae3b42b26
commit 62f6a1a1da


@@ -176,10 +176,10 @@ def _get_attention_scores_no_kv_chunking(
     attn_scores /= summed
     attn_probs = attn_scores
-    if attention_to_mux is not None:
+    if attention_to_mux is not None and attention_weight > 0.0:
         attention_to_mux = attention_to_mux.to(attn_probs.device)
         attn_probs = attn_probs * (1 - attention_weight) + attention_to_mux * attention_weight
-        print(f"muxed attention with weight {attention_weight}")
+        # print(f"muxed attention with weight {attention_weight}")
     hidden_states_slice = torch.bmm(attn_probs, value)
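
For context, a minimal standalone sketch of the muxing step this patch guards, assuming attn_probs and attention_to_mux are batched attention matrices of shape (batch, q_len, kv_len); the mux_attention helper is illustrative, not part of the repository:

    import torch
    from typing import Optional

    def mux_attention(attn_probs: torch.Tensor,
                      attention_to_mux: Optional[torch.Tensor],
                      attention_weight: float) -> torch.Tensor:
        # Linearly blend two attention maps. The weight > 0.0 guard is the
        # fix from this commit: with a weight of 0.0 the external attention
        # contributes nothing, so the device copy and blend are skipped.
        if attention_to_mux is not None and attention_weight > 0.0:
            attention_to_mux = attention_to_mux.to(attn_probs.device)
            attn_probs = (attn_probs * (1 - attention_weight)
                          + attention_to_mux * attention_weight)
        return attn_probs

    # Usage: a weight of 0.0 passes the original probabilities through unchanged.
    probs = torch.softmax(torch.randn(2, 4, 4), dim=-1)
    external = torch.softmax(torch.randn(2, 4, 4), dim=-1)
    assert torch.equal(mux_attention(probs, external, 0.0), probs)
    blended = mux_attention(probs, external, 0.25)  # convex combination of the two

With the guard in place, a weight of 0.0 returns attn_probs untouched instead of still running the blend (and its cross-device copy), which is the behavior the commit message describes fixing.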