mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-16 08:22:36 +08:00
Reduce peak VRAM usage a bit
This commit is contained in:
parent
910c47abc8
commit
1baf5ec4af
@ -111,16 +111,19 @@ class SelfAttention(nn.Module):
|
|||||||
|
|
||||||
self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
|
||||||
def get_qkv(self, x):
|
|
||||||
q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
|
|
||||||
k = self.to_key(x).view(*x.shape[:-1], self.num_heads, -1)
|
|
||||||
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
|
|
||||||
return q, k, v
|
|
||||||
|
|
||||||
def forward(self, x, freqs, transformer_options={}):
|
def forward(self, x, freqs, transformer_options={}):
|
||||||
q, k, v = self.get_qkv(x)
|
def compute_q(x):
|
||||||
q = apply_rope1(self.query_norm(q), freqs)
|
q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||||
k = apply_rope1(self.key_norm(k), freqs)
|
return apply_rope1(self.query_norm(q), freqs)
|
||||||
|
|
||||||
|
def compute_k(x):
|
||||||
|
k = self.to_key(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||||
|
return apply_rope1(self.key_norm(k), freqs)
|
||||||
|
|
||||||
|
q = compute_q(x)
|
||||||
|
k = compute_k(x)
|
||||||
|
|
||||||
|
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||||
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
|
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
|
||||||
return self.out_layer(out)
|
return self.out_layer(out)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user