Update chunking

This commit is contained in:
kijai 2025-11-29 16:10:37 +02:00
parent a3ce1e02d7
commit b24bac604a

View File

@ -110,23 +110,43 @@ class SelfAttention(nn.Module):
self.key_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.num_chunks = 2
def forward(self, x, freqs, transformer_options={}):
def compute_q(x):
q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
return apply_rope1(self.query_norm(q), freqs)
def compute_k(x):
k = self.to_key(x).view(*x.shape[:-1], self.num_heads, -1)
return apply_rope1(self.key_norm(k), freqs)
q = compute_q(x)
k = compute_k(x)
def _compute_qk(self, x, freqs, proj_fn, norm_fn):
result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1)
return apply_rope1(norm_fn(result), freqs)
def _forward(self, x, freqs, transformer_options={}):
q = self._compute_qk(x, freqs, self.to_query, self.query_norm)
k = self._compute_qk(x, freqs, self.to_key, self.key_norm)
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
return self.out_layer(out)
def _forward_chunked(self, x, freqs, transformer_options={}):
def process_chunks(proj_fn, norm_fn):
B, L, _ = x.shape
chunk_size = (L + self.num_chunks - 1) // self.num_chunks
chunks = []
for i in range(0, L, chunk_size):
end_idx = min(i + chunk_size, L)
x_chunk = x[:, i:end_idx]
freqs_chunk = freqs[:, i:end_idx]
chunks.append(self._compute_qk(x_chunk, freqs_chunk, proj_fn, norm_fn))
return torch.cat(chunks, dim=1)
q = process_chunks(self.to_query, self.query_norm)
k = process_chunks(self.to_key, self.key_norm)
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
return self.out_layer(out)
def forward(self, x, freqs, transformer_options={}):
if x.shape[1] > 8192:
return self._forward_chunked(x, freqs, transformer_options=transformer_options)
else:
return self._forward(x, freqs, transformer_options=transformer_options)
class CrossAttention(SelfAttention):
def get_qkv(self, x, context):
@ -150,22 +170,24 @@ class FeedForward(nn.Module):
self.out_layer = operations.Linear(ff_dim, dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.num_chunks = 4
def forward(self, x):
#return self.out_layer(self.activation(self.in_layer(x)))
# ffn is the peak memory consumer, chunking here helps
B, L, C = x.shape
def _forward(self, x):
return self.out_layer(self.activation(self.in_layer(x)))
def _forward_chunked(self, x):
B, L, _ = x.shape
chunk_size = (L + self.num_chunks - 1) // self.num_chunks
output = torch.empty(B, L, self.out_layer.out_features, dtype=x.dtype, device=x.device)
for i in range(0, L, chunk_size):
end_idx = min(i + chunk_size, L)
def compute_chunk(x_chunk):
activated = self.activation(self.in_layer(x_chunk))
return self.out_layer(activated)
output[:, i:end_idx] = compute_chunk(x[:, i:end_idx])
output[:, i:end_idx] = self._forward(x[:, i:end_idx])
return output
def forward(self, x):
if x.shape[1] > 8192:
return self._forward_chunked(x)
else:
return self._forward(x)
class OutLayer(nn.Module):
def __init__(self, model_dim, time_dim, visual_dim, patch_size, operation_settings=None):