diff --git a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py
index a78b9421e..a8cf894e4 100644
--- a/comfy/ldm/kandinsky5/model.py
+++ b/comfy/ldm/kandinsky5/model.py
@@ -111,16 +111,19 @@ class SelfAttention(nn.Module):
         self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
 
-    def get_qkv(self, x):
-        q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
-        k = self.to_key(x).view(*x.shape[:-1], self.num_heads, -1)
-        v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
-        return q, k, v
-
     def forward(self, x, freqs, transformer_options={}):
-        q, k, v = self.get_qkv(x)
-        q = apply_rope1(self.query_norm(q), freqs)
-        k = apply_rope1(self.key_norm(k), freqs)
+        def compute_q(x):
+            q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
+            return apply_rope1(self.query_norm(q), freqs)
+
+        def compute_k(x):
+            k = self.to_key(x).view(*x.shape[:-1], self.num_heads, -1)
+            return apply_rope1(self.key_norm(k), freqs)
+
+        q = compute_q(x)
+        k = compute_k(x)
+
+        v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
         out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
         return self.out_layer(out)
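
For context, here is a minimal, self-contained sketch of the pattern the diff applies: instead of a single get_qkv that materializes q, k, and v together, q and k are produced by small local helpers inside forward, so each projection, norm, and RoPE step stays local to its closure. This is only an illustration under assumptions: it swaps the repo's operations.Linear, apply_rope1, and attention helpers for plain torch.nn layers, an identity apply_rope1_stub, and scaled_dot_product_attention; the class name SelfAttentionSketch, the stub, and the shapes are hypothetical, not the model's real implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F


def apply_rope1_stub(x, freqs):
    # Stand-in (assumption) for the repo's apply_rope1: identity here so the
    # sketch runs without real rotary-embedding frequencies.
    return x


class SelfAttentionSketch(nn.Module):
    def __init__(self, num_channels, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.to_query = nn.Linear(num_channels, num_channels, bias=True)
        self.to_key = nn.Linear(num_channels, num_channels, bias=True)
        self.to_value = nn.Linear(num_channels, num_channels, bias=True)
        self.query_norm = nn.LayerNorm(num_channels // num_heads)
        self.key_norm = nn.LayerNorm(num_channels // num_heads)
        self.out_layer = nn.Linear(num_channels, num_channels, bias=True)

    def forward(self, x, freqs=None):
        # Local helpers mirror the refactor: each one projects, normalizes, and
        # applies RoPE, keeping its intermediates confined to the closure.
        def compute_q(x):
            q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
            return apply_rope1_stub(self.query_norm(q), freqs)

        def compute_k(x):
            k = self.to_key(x).view(*x.shape[:-1], self.num_heads, -1)
            return apply_rope1_stub(self.key_norm(k), freqs)

        q = compute_q(x)
        k = compute_k(x)
        v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)

        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim) for SDPA.
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v)
        out = out.transpose(1, 2).reshape(*x.shape)
        return self.out_layer(out)


if __name__ == "__main__":
    attn = SelfAttentionSketch(num_channels=64, num_heads=8)
    y = attn(torch.randn(2, 16, 64))
    print(y.shape)  # torch.Size([2, 16, 64])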