mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-15 08:40:50 +08:00
Reduce peak VRAM usage a bit
This commit is contained in:
parent
910c47abc8
commit
1baf5ec4af
@ -111,16 +111,19 @@ class SelfAttention(nn.Module):
|
||||
|
||||
self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def get_qkv(self, x):
|
||||
q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
k = self.to_key(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
return q, k, v
|
||||
|
||||
def forward(self, x, freqs, transformer_options={}):
|
||||
q, k, v = self.get_qkv(x)
|
||||
q = apply_rope1(self.query_norm(q), freqs)
|
||||
k = apply_rope1(self.key_norm(k), freqs)
|
||||
def compute_q(x):
|
||||
q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
return apply_rope1(self.query_norm(q), freqs)
|
||||
|
||||
def compute_k(x):
|
||||
k = self.to_key(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
return apply_rope1(self.key_norm(k), freqs)
|
||||
|
||||
q = compute_q(x)
|
||||
k = compute_k(x)
|
||||
|
||||
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
|
||||
return self.out_layer(out)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user