diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index afbab2ac7..23627f55d 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -11,7 +11,7 @@ import comfy.ldm.common_dit from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder from comfy.ldm.modules.attention import optimized_attention_masked from comfy.ldm.flux.layers import EmbedND -from comfy.ldm.flux.math import apply_rope +from comfy.ldm.flux.math import apply_rope, apply_rope1 import comfy.patcher_extension @@ -111,7 +111,12 @@ class JointAttention(nn.Module): xq = self.q_norm(xq) xk = self.k_norm(xk) - xq, xk = apply_rope(xq, xk, freqs_cis) + # Use apply_rope1 separately for GQA compatibility (when n_heads != n_kv_heads) + if xq.shape == xk.shape: + xq, xk = apply_rope(xq, xk, freqs_cis) + else: + xq = apply_rope1(xq, freqs_cis) + xk = apply_rope1(xk, freqs_cis) n_rep = self.n_local_heads // self.n_local_kv_heads if n_rep >= 1: