From b24bac604aa356471a6ee24560e5044d6718d024 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sat, 29 Nov 2025 16:10:37 +0200
Subject: [PATCH] Update chunking

---
 comfy/ldm/kandinsky5/model.py | 64 +++++++++++++++++++++++------------
 1 file changed, 43 insertions(+), 21 deletions(-)

diff --git a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py
index 393b69e34..2bcaae195 100644
--- a/comfy/ldm/kandinsky5/model.py
+++ b/comfy/ldm/kandinsky5/model.py
@@ -110,23 +110,43 @@ class SelfAttention(nn.Module):
         self.key_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
 
         self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.num_chunks = 2
 
-    def forward(self, x, freqs, transformer_options={}):
-        def compute_q(x):
-            q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
-            return apply_rope1(self.query_norm(q), freqs)
-
-        def compute_k(x):
-            k = self.to_key(x).view(*x.shape[:-1], self.num_heads, -1)
-            return apply_rope1(self.key_norm(k), freqs)
-
-        q = compute_q(x)
-        k = compute_k(x)
+    def _compute_qk(self, x, freqs, proj_fn, norm_fn):
+        result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1)
+        return apply_rope1(norm_fn(result), freqs)
 
+    def _forward(self, x, freqs, transformer_options={}):
+        q = self._compute_qk(x, freqs, self.to_query, self.query_norm)
+        k = self._compute_qk(x, freqs, self.to_key, self.key_norm)
         v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
         out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
         return self.out_layer(out)
 
+    def _forward_chunked(self, x, freqs, transformer_options={}):
+        def process_chunks(proj_fn, norm_fn):
+            B, L, _ = x.shape
+            chunk_size = (L + self.num_chunks - 1) // self.num_chunks
+            chunks = []
+            for i in range(0, L, chunk_size):
+                end_idx = min(i + chunk_size, L)
+                x_chunk = x[:, i:end_idx]
+                freqs_chunk = freqs[:, i:end_idx]
+                chunks.append(self._compute_qk(x_chunk, freqs_chunk, proj_fn, norm_fn))
+            return torch.cat(chunks, dim=1)
+
+        q = process_chunks(self.to_query, self.query_norm)
+        k = process_chunks(self.to_key, self.key_norm)
+        v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
+        out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
+        return self.out_layer(out)
+
+    def forward(self, x, freqs, transformer_options={}):
+        if x.shape[1] > 8192:
+            return self._forward_chunked(x, freqs, transformer_options=transformer_options)
+        else:
+            return self._forward(x, freqs, transformer_options=transformer_options)
+
 
 class CrossAttention(SelfAttention):
     def get_qkv(self, x, context):
@@ -150,22 +170,24 @@ class FeedForward(nn.Module):
         self.out_layer = operations.Linear(ff_dim, dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.num_chunks = 4
 
-    def forward(self, x):
-        #return self.out_layer(self.activation(self.in_layer(x)))
-        # ffn is the peak memory consumer, chunking here helps
-        B, L, C = x.shape
+    def _forward(self, x):
+        return self.out_layer(self.activation(self.in_layer(x)))
+
+    def _forward_chunked(self, x):
+        B, L, _ = x.shape
         chunk_size = (L + self.num_chunks - 1) // self.num_chunks
         output = torch.empty(B, L, self.out_layer.out_features, dtype=x.dtype, device=x.device)
-
         for i in range(0, L, chunk_size):
             end_idx = min(i + chunk_size, L)
-            def compute_chunk(x_chunk):
-                activated = self.activation(self.in_layer(x_chunk))
-                return self.out_layer(activated)
-            output[:, i:end_idx] = compute_chunk(x[:, i:end_idx])
-
+            output[:, i:end_idx] = self._forward(x[:, i:end_idx])
         return output
 
+    def forward(self, x):
+        if x.shape[1] > 8192:
+            return self._forward_chunked(x)
+        else:
+            return self._forward(x)
+
 
 class OutLayer(nn.Module):
     def __init__(self, model_dim, time_dim, visual_dim, patch_size, operation_settings=None):
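
Note (not part of the patch): the chunked paths only cap peak activation memory for long sequences (past the 8192-token threshold); because the feed-forward MLP, the Q/K projections, and RoPE with a matching freqs slice all act position-wise along the sequence dimension, chunking should not change the result. Below is a minimal, hypothetical sketch of that equivalence for the feed-forward case, using plain torch.nn layers and made-up dimensions in place of ComfyUI's `operations` wrappers; it is an illustration of the technique, not the repository code.

import torch
import torch.nn as nn

class ChunkedFeedForward(nn.Module):
    # Simplified stand-in for the patched FeedForward: same chunking logic,
    # plain nn.Linear/GELU instead of ComfyUI's `operations` layers.
    def __init__(self, dim=64, ff_dim=256, num_chunks=4):
        super().__init__()
        self.in_layer = nn.Linear(dim, ff_dim, bias=False)
        self.activation = nn.GELU()
        self.out_layer = nn.Linear(ff_dim, dim, bias=False)
        self.num_chunks = num_chunks

    def _forward(self, x):
        return self.out_layer(self.activation(self.in_layer(x)))

    def _forward_chunked(self, x):
        # Split along the sequence dimension; each chunk only materializes a
        # (B, chunk, ff_dim) intermediate instead of the full (B, L, ff_dim).
        B, L, _ = x.shape
        chunk_size = (L + self.num_chunks - 1) // self.num_chunks
        output = torch.empty(B, L, self.out_layer.out_features, dtype=x.dtype, device=x.device)
        for i in range(0, L, chunk_size):
            end_idx = min(i + chunk_size, L)
            output[:, i:end_idx] = self._forward(x[:, i:end_idx])
        return output

ff = ChunkedFeedForward()
x = torch.randn(2, 1000, 64)
# Chunked and unchunked outputs agree up to floating-point noise.
assert torch.allclose(ff._forward(x), ff._forward_chunked(x), atol=1e-6)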