Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-15 16:50:57 +08:00
Further reduce peak VRAM consumption by chunking ffn
This commit is contained in: parent 1baf5ec4af, commit a3ce1e02d7
@@ -148,9 +148,23 @@ class FeedForward(nn.Module):
         self.in_layer = operations.Linear(dim, ff_dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.activation = nn.GELU()
         self.out_layer = operations.Linear(ff_dim, dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.num_chunks = 4
 
     def forward(self, x):
-        return self.out_layer(self.activation(self.in_layer(x)))
+        #return self.out_layer(self.activation(self.in_layer(x)))
+        # ffn is the peak memory consumer, chunking here helps
+        B, L, C = x.shape
+        chunk_size = (L + self.num_chunks - 1) // self.num_chunks
+        output = torch.empty(B, L, self.out_layer.out_features, dtype=x.dtype, device=x.device)
+
+        for i in range(0, L, chunk_size):
+            end_idx = min(i + chunk_size, L)
+            def compute_chunk(x_chunk):
+                activated = self.activation(self.in_layer(x_chunk))
+                return self.out_layer(activated)
+            output[:, i:end_idx] = compute_chunk(x[:, i:end_idx])
+
+        return output
 
 
 class OutLayer(nn.Module):
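For reference, here is a minimal self-contained sketch of the same sequence-chunking technique, not the committed code: it swaps ComfyUI's operations.Linear wrapper for plain torch.nn.Linear, and the class name FeedForwardChunked plus the dim/ff_dim values are illustrative assumptions. The saving comes from the GELU intermediate only ever materializing at shape (B, chunk, ff_dim) instead of (B, L, ff_dim); the (B, L, dim) output buffer is needed either way.

# Illustrative sketch of a sequence-chunked feed-forward block.
# Assumes plain nn.Linear in place of ComfyUI's operations.Linear wrapper.
import torch
import torch.nn as nn

class FeedForwardChunked(nn.Module):
    def __init__(self, dim, ff_dim, num_chunks=4):
        super().__init__()
        self.in_layer = nn.Linear(dim, ff_dim, bias=False)
        self.activation = nn.GELU()
        self.out_layer = nn.Linear(ff_dim, dim, bias=False)
        self.num_chunks = num_chunks

    def forward(self, x):
        B, L, C = x.shape
        # Ceil division so every token falls into exactly one chunk.
        chunk_size = (L + self.num_chunks - 1) // self.num_chunks
        output = torch.empty(B, L, self.out_layer.out_features,
                             dtype=x.dtype, device=x.device)
        for i in range(0, L, chunk_size):
            end = min(i + chunk_size, L)
            # Peak intermediate here is (B, end - i, ff_dim), not (B, L, ff_dim).
            output[:, i:end] = self.out_layer(
                self.activation(self.in_layer(x[:, i:end])))
        return output

# Quick equivalence check against the unchunked path (shapes are illustrative):
ffn = FeedForwardChunked(dim=64, ff_dim=256)
x = torch.randn(2, 10, 64)
reference = ffn.out_layer(ffn.activation(ffn.in_layer(x)))
assert torch.allclose(ffn(x), reference, atol=1e-6)

The closing assertion reflects why this optimization is safe: each token's feed-forward result depends only on that token, so splitting along the length dimension changes peak memory but not the output.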