Update chunking
commit b24bac604a (parent a3ce1e02d7)
@@ -110,23 +110,43 @@ class SelfAttention(nn.Module):
         self.key_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
 
         self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.num_chunks = 2
 
-    def forward(self, x, freqs, transformer_options={}):
-        def compute_q(x):
-            q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
-            return apply_rope1(self.query_norm(q), freqs)
-
-        def compute_k(x):
-            k = self.to_key(x).view(*x.shape[:-1], self.num_heads, -1)
-            return apply_rope1(self.key_norm(k), freqs)
-
-        q = compute_q(x)
-        k = compute_k(x)
-
+    def _compute_qk(self, x, freqs, proj_fn, norm_fn):
+        result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1)
+        return apply_rope1(norm_fn(result), freqs)
+
+    def _forward(self, x, freqs, transformer_options={}):
+        q = self._compute_qk(x, freqs, self.to_query, self.query_norm)
+        k = self._compute_qk(x, freqs, self.to_key, self.key_norm)
         v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
         out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
         return self.out_layer(out)
+
+    def _forward_chunked(self, x, freqs, transformer_options={}):
+        def process_chunks(proj_fn, norm_fn):
+            B, L, _ = x.shape
+            chunk_size = (L + self.num_chunks - 1) // self.num_chunks
+            chunks = []
+            for i in range(0, L, chunk_size):
+                end_idx = min(i + chunk_size, L)
+                x_chunk = x[:, i:end_idx]
+                freqs_chunk = freqs[:, i:end_idx]
+                chunks.append(self._compute_qk(x_chunk, freqs_chunk, proj_fn, norm_fn))
+            return torch.cat(chunks, dim=1)
+
+        q = process_chunks(self.to_query, self.query_norm)
+        k = process_chunks(self.to_key, self.key_norm)
+        v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
+        out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
+        return self.out_layer(out)
+
+    def forward(self, x, freqs, transformer_options={}):
+        if x.shape[1] > 8192:
+            return self._forward_chunked(x, freqs, transformer_options=transformer_options)
+        else:
+            return self._forward(x, freqs, transformer_options=transformer_options)
 
 
 class CrossAttention(SelfAttention):
     def get_qkv(self, x, context):
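Note: a minimal self-contained sketch of the sequence-chunking idea in `_forward_chunked` above. Projection, norm, and RoPE are all position-wise, so applying them per chunk and concatenating is mathematically exact while keeping only one chunk's projection activations live at a time. The `toy_rope` helper and the tensor layouts are illustrative assumptions, not the repository's `apply_rope1`:

import torch
import torch.nn as nn

def toy_rope(t, freqs):
    # Illustrative stand-in for apply_rope1: any position-wise op that
    # consumes per-position frequencies preserves the chunking argument.
    return t * freqs.unsqueeze(-2)  # broadcast freqs over the head axis

def chunked_qk(x, freqs, proj, norm, num_heads, num_chunks=2):
    # x: (B, L, C); freqs: (B, L, head_dim) in this toy layout.
    B, L, _ = x.shape
    chunk_size = (L + num_chunks - 1) // num_chunks  # ceiling division
    chunks = []
    for i in range(0, L, chunk_size):
        x_c = x[:, i:i + chunk_size]
        f_c = freqs[:, i:i + chunk_size]
        # Project only this slice, so the full-length projection
        # activation never materializes at once.
        q_c = proj(x_c).view(*x_c.shape[:-1], num_heads, -1)
        chunks.append(toy_rope(norm(q_c), f_c))
    return torch.cat(chunks, dim=1)

B, L, C, H = 1, 10, 16, 4
proj = nn.Linear(C, C)
norm = nn.RMSNorm(C // H)  # requires PyTorch >= 2.4
x = torch.randn(B, L, C)
freqs = torch.randn(B, L, C // H)
full = toy_rope(norm(proj(x).view(B, L, H, -1)), freqs)
assert torch.allclose(full, chunked_qk(x, freqs, proj, norm, H), atol=1e-6)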
@@ -150,22 +170,24 @@ class FeedForward(nn.Module):
         self.out_layer = operations.Linear(ff_dim, dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.num_chunks = 4
 
-    def forward(self, x):
-        #return self.out_layer(self.activation(self.in_layer(x)))
-        # ffn is the peak memory consumer, chunking here helps
-        B, L, C = x.shape
+    def _forward(self, x):
+        return self.out_layer(self.activation(self.in_layer(x)))
+
+    def _forward_chunked(self, x):
+        B, L, _ = x.shape
         chunk_size = (L + self.num_chunks - 1) // self.num_chunks
         output = torch.empty(B, L, self.out_layer.out_features, dtype=x.dtype, device=x.device)
 
         for i in range(0, L, chunk_size):
             end_idx = min(i + chunk_size, L)
-            def compute_chunk(x_chunk):
-                activated = self.activation(self.in_layer(x_chunk))
-                return self.out_layer(activated)
-            output[:, i:end_idx] = compute_chunk(x[:, i:end_idx])
+            output[:, i:end_idx] = self._forward(x[:, i:end_idx])
 
         return output
 
+    def forward(self, x):
+        if x.shape[1] > 8192:
+            return self._forward_chunked(x)
+        else:
+            return self._forward(x)
+
 
 class OutLayer(nn.Module):
     def __init__(self, model_dim, time_dim, visual_dim, patch_size, operation_settings=None):
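Note: a minimal standalone version of the chunked FeedForward above. The bias-free linears, preallocated output, and 8192-token cutoff mirror the diff; the SiLU activation and the dimensions are illustrative assumptions. The FFN is position-wise, so slicing along the sequence dimension is exact, and only one (B, chunk, ff_dim) hidden tensor is alive at a time instead of (B, L, ff_dim):

import torch
import torch.nn as nn

class ChunkedFF(nn.Module):
    def __init__(self, dim, ff_dim, num_chunks=4):
        super().__init__()
        self.in_layer = nn.Linear(dim, ff_dim, bias=False)
        self.activation = nn.SiLU()  # assumption; the diff doesn't show it
        self.out_layer = nn.Linear(ff_dim, dim, bias=False)
        self.num_chunks = num_chunks

    def _forward(self, x):
        return self.out_layer(self.activation(self.in_layer(x)))

    def _forward_chunked(self, x):
        B, L, _ = x.shape
        chunk_size = (L + self.num_chunks - 1) // self.num_chunks
        # Preallocate the result so the full hidden activation
        # never exists; each loop iteration frees its chunk.
        output = torch.empty(B, L, self.out_layer.out_features, dtype=x.dtype, device=x.device)
        for i in range(0, L, chunk_size):
            end_idx = min(i + chunk_size, L)
            output[:, i:end_idx] = self._forward(x[:, i:end_idx])
        return output

    def forward(self, x):
        # The hidden activation dominates peak memory only for long
        # sequences; short inputs take the direct path.
        return self._forward_chunked(x) if x.shape[1] > 8192 else self._forward(x)

ff = ChunkedFF(dim=32, ff_dim=128)
x = torch.randn(2, 100, 32)
assert torch.allclose(ff._forward(x), ff._forward_chunked(x), atol=1e-5)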