diff --git a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py
index a8cf894e4..393b69e34 100644
--- a/comfy/ldm/kandinsky5/model.py
+++ b/comfy/ldm/kandinsky5/model.py
@@ -148,9 +148,23 @@ class FeedForward(nn.Module):
         self.in_layer = operations.Linear(dim, ff_dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.activation = nn.GELU()
         self.out_layer = operations.Linear(ff_dim, dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.num_chunks = 4
 
     def forward(self, x):
-        return self.out_layer(self.activation(self.in_layer(x)))
+        #return self.out_layer(self.activation(self.in_layer(x)))
+        # ffn is the peak memory consumer, chunking here helps
+        B, L, C = x.shape
+        chunk_size = (L + self.num_chunks - 1) // self.num_chunks
+        output = torch.empty(B, L, self.out_layer.out_features, dtype=x.dtype, device=x.device)
+
+        for i in range(0, L, chunk_size):
+            end_idx = min(i + chunk_size, L)
+            def compute_chunk(x_chunk):
+                activated = self.activation(self.in_layer(x_chunk))
+                return self.out_layer(activated)
+            output[:, i:end_idx] = compute_chunk(x[:, i:end_idx])
+
+        return output
 
 
 class OutLayer(nn.Module):
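
For reference, a minimal standalone sketch of the same sequence-chunking idea, outside the ComfyUI codebase: it uses plain `torch.nn.Linear` instead of `operations.Linear` and a hypothetical `ChunkedFeedForward` class name, so it only illustrates how splitting the sequence dimension bounds the size of the `(B, chunk, ff_dim)` intermediate that drives peak memory.

```python
import torch
import torch.nn as nn


class ChunkedFeedForward(nn.Module):
    """Sketch: feed-forward block that processes the sequence in chunks
    so only one chunk-sized hidden activation is alive at a time."""

    def __init__(self, dim, ff_dim, num_chunks=4):
        super().__init__()
        self.in_layer = nn.Linear(dim, ff_dim, bias=False)
        self.activation = nn.GELU()
        self.out_layer = nn.Linear(ff_dim, dim, bias=False)
        self.num_chunks = num_chunks

    def forward(self, x):
        B, L, _ = x.shape
        chunk_size = (L + self.num_chunks - 1) // self.num_chunks
        output = torch.empty(B, L, self.out_layer.out_features, dtype=x.dtype, device=x.device)
        for i in range(0, L, chunk_size):
            end = min(i + chunk_size, L)
            # The (B, chunk, ff_dim) intermediate exists only for this slice.
            output[:, i:end] = self.out_layer(self.activation(self.in_layer(x[:, i:end])))
        return output


if __name__ == "__main__":
    ff = ChunkedFeedForward(dim=64, ff_dim=256)
    y = ff(torch.randn(2, 100, 64))
    print(y.shape)  # torch.Size([2, 100, 64])
```

The output is numerically identical to the unchunked version (each token is processed independently by the two linear layers), so the trade-off is purely a small amount of loop overhead in exchange for a roughly `num_chunks`-fold reduction in the hidden activation footprint.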