Reduce peak VRAM usage a bit

kijai 2025-11-28 01:11:10 +02:00
parent 910c47abc8
commit 1baf5ec4af


@@ -111,16 +111,19 @@ class SelfAttention(nn.Module):
         self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
 
-    def get_qkv(self, x):
-        q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
-        k = self.to_key(x).view(*x.shape[:-1], self.num_heads, -1)
-        v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
-        return q, k, v
 
     def forward(self, x, freqs, transformer_options={}):
-        q, k, v = self.get_qkv(x)
-        q = apply_rope1(self.query_norm(q), freqs)
-        k = apply_rope1(self.key_norm(k), freqs)
+        def compute_q(x):
+            q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
+            return apply_rope1(self.query_norm(q), freqs)
+
+        def compute_k(x):
+            k = self.to_key(x).view(*x.shape[:-1], self.num_heads, -1)
+            return apply_rope1(self.key_norm(k), freqs)
+
+        q = compute_q(x)
+        k = compute_k(x)
+        v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
 
         out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
         return self.out_layer(out)
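
Why this lowers the peak: with the old get_qkv, the raw q, k and v projections were all alive at the same time before any of them was normalized and RoPE'd; by computing q and k inside small helper functions and materializing v last, each raw projection can be freed as soon as its helper returns, so fewer full-size temporaries coexist. The sketch below illustrates the pattern only; it is not the repository's code, and proj_q, proj_k, proj_v, norm_q, norm_k and rope are hypothetical stand-ins for to_query/to_key/to_value, query_norm/key_norm and apply_rope1, with LayerNorm and a dummy elementwise rope used for simplicity.

# Minimal sketch of the scoping trick, under the assumptions named above.
import torch
import torch.nn as nn

dim, heads, seq = 1024, 16, 4096
proj_q = nn.Linear(dim, dim, bias=True)   # stand-in for to_query
proj_k = nn.Linear(dim, dim, bias=True)   # stand-in for to_key
proj_v = nn.Linear(dim, dim, bias=True)   # stand-in for to_value
norm_q = nn.LayerNorm(dim // heads)       # stand-in for query_norm
norm_k = nn.LayerNorm(dim // heads)       # stand-in for key_norm

def rope(t):
    # stand-in for apply_rope1; any elementwise transform illustrates the point
    return t * 2.0

def qkv_old(x):
    # Old layout: all three raw projections exist before any norm/RoPE runs,
    # so the peak holds q, k and v (plus the norm/RoPE temporaries) at once.
    q = proj_q(x).view(*x.shape[:-1], heads, -1)
    k = proj_k(x).view(*x.shape[:-1], heads, -1)
    v = proj_v(x).view(*x.shape[:-1], heads, -1)
    q = rope(norm_q(q))
    k = rope(norm_k(k))
    return q, k, v

def qkv_new(x):
    # New layout: each raw projection lives only inside its helper, so it can
    # be released before the next projection is allocated; v is computed last.
    def compute_q(x):
        q = proj_q(x).view(*x.shape[:-1], heads, -1)
        return rope(norm_q(q))

    def compute_k(x):
        k = proj_k(x).view(*x.shape[:-1], heads, -1)
        return rope(norm_k(k))

    q = compute_q(x)
    k = compute_k(x)
    v = proj_v(x).view(*x.shape[:-1], heads, -1)
    return q, k, v

x = torch.randn(1, seq, dim)
assert all(a.shape == b.shape for a, b in zip(qkv_old(x), qkv_new(x)))

On a CUDA build the difference can be inspected by calling torch.cuda.reset_peak_memory_stats() before each variant and comparing torch.cuda.max_memory_allocated() afterwards; the gap should be on the order of one projection-sized tensor, which matches the "a bit" in the commit title.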