Further reduce peak VRAM consumption by chunking ffn

kijai 2025-11-28 02:13:35 +02:00
parent 1baf5ec4af
commit a3ce1e02d7


@@ -148,9 +148,23 @@ class FeedForward(nn.Module):
         self.in_layer = operations.Linear(dim, ff_dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.activation = nn.GELU()
         self.out_layer = operations.Linear(ff_dim, dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.num_chunks = 4
 
     def forward(self, x):
-        return self.out_layer(self.activation(self.in_layer(x)))
+        #return self.out_layer(self.activation(self.in_layer(x)))
+        # ffn is the peak memory consumer, chunking here helps
+        B, L, C = x.shape
+        chunk_size = (L + self.num_chunks - 1) // self.num_chunks
+        output = torch.empty(B, L, self.out_layer.out_features, dtype=x.dtype, device=x.device)
+        for i in range(0, L, chunk_size):
+            end_idx = min(i + chunk_size, L)
+            def compute_chunk(x_chunk):
+                activated = self.activation(self.in_layer(x_chunk))
+                return self.out_layer(activated)
+            output[:, i:end_idx] = compute_chunk(x[:, i:end_idx])
+        return output
 
 
 class OutLayer(nn.Module):
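
For context on why this lowers the peak: the unchunked forward materializes a (B, L, ff_dim) activation all at once, while the chunked version only ever holds a (B, ceil(L / num_chunks), ff_dim) slice. Below is a minimal standalone sketch of the same idea, not part of the commit: plain torch.nn.Linear stands in for operations.Linear, and the sizes (dim=64, ff_dim=256, B=2, L=1000) are made up for illustration. It checks that chunking the sequence dimension leaves the output unchanged.

# Standalone sketch only; plain nn.Linear and made-up shapes, not the repo's operations.Linear.
import torch
import torch.nn as nn

torch.manual_seed(0)

dim, ff_dim = 64, 256          # hypothetical model sizes
B, L = 2, 1000                 # hypothetical batch and sequence length
num_chunks = 4

in_layer = nn.Linear(dim, ff_dim, bias=False)
out_layer = nn.Linear(ff_dim, dim, bias=False)
activation = nn.GELU()

x = torch.randn(B, L, dim)

# Unchunked forward: a (B, L, ff_dim) intermediate exists all at once.
full = out_layer(activation(in_layer(x)))

# Chunked forward: only a (B, ceil(L / num_chunks), ff_dim) intermediate
# is alive at any point, which is what trims the VRAM peak.
chunk_size = (L + num_chunks - 1) // num_chunks
output = torch.empty(B, L, out_layer.out_features, dtype=x.dtype, device=x.device)
for i in range(0, L, chunk_size):
    end_idx = min(i + chunk_size, L)
    output[:, i:end_idx] = out_layer(activation(in_layer(x[:, i:end_idx])))

# Should print True: each sequence position is processed independently,
# so splitting along L does not change the math.
print(torch.allclose(full, output, atol=1e-6))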