Add nabla support

Mihail Karaev 2025-12-10 14:06:16 +00:00
parent 3a5f239cb6
commit 2bff3c520f
2 changed files with 234 additions and 14 deletions

View File

@@ -6,6 +6,12 @@ import comfy.ldm.common_dit
 from comfy.ldm.modules.attention import optimized_attention
 from comfy.ldm.flux.math import apply_rope1
 from comfy.ldm.flux.layers import EmbedND
+from comfy.ldm.kandinsky5.utils_nabla import (
+    fractal_flatten,
+    fractal_unflatten,
+    fast_sta_nabla,
+    nabla,
+)
 
 def attention(q, k, v, heads, transformer_options={}):
     return optimized_attention(
@@ -116,14 +122,17 @@ class SelfAttention(nn.Module):
         result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1)
         return apply_rope1(norm_fn(result), freqs)
 
-    def _forward(self, x, freqs, transformer_options={}):
+    def _forward(self, x, freqs, sparse_params=None, transformer_options={}):
         q = self._compute_qk(x, freqs, self.to_query, self.query_norm)
         k = self._compute_qk(x, freqs, self.to_key, self.key_norm)
         v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
-        out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
+        if sparse_params is None:
+            out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
+        else:
+            out = nabla(q, k, v, sparse_params)
         return self.out_layer(out)
 
-    def _forward_chunked(self, x, freqs, transformer_options={}):
+    def _forward_chunked(self, x, freqs, sparse_params=None, transformer_options={}):
         def process_chunks(proj_fn, norm_fn):
             x_chunks = torch.chunk(x, self.num_chunks, dim=1)
             freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1)
@@ -135,14 +144,17 @@ class SelfAttention(nn.Module):
         q = process_chunks(self.to_query, self.query_norm)
         k = process_chunks(self.to_key, self.key_norm)
         v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
-        out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
+        if sparse_params is None:
+            out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
+        else:
+            out = nabla(q, k, v, sparse_params)
         return self.out_layer(out)
 
-    def forward(self, x, freqs, transformer_options={}):
+    def forward(self, x, freqs, sparse_params=None, transformer_options={}):
         if x.shape[1] > 8192:
-            return self._forward_chunked(x, freqs, transformer_options=transformer_options)
+            return self._forward_chunked(x, freqs, sparse_params=sparse_params, transformer_options=transformer_options)
         else:
-            return self._forward(x, freqs, transformer_options=transformer_options)
+            return self._forward(x, freqs, sparse_params=sparse_params, transformer_options=transformer_options)
 
 
 class CrossAttention(SelfAttention):
@@ -251,12 +263,12 @@ class TransformerDecoderBlock(nn.Module):
         self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
 
-    def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}):
+    def forward(self, visual_embed, text_embed, time_embed, freqs, sparse_params=None, transformer_options={}):
         self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1)
         # self attention
         shift, scale, gate = get_shift_scale_gate(self_attn_params)
         visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift)
-        visual_out = self.self_attention(visual_out, freqs, transformer_options=transformer_options)
+        visual_out = self.self_attention(visual_out, freqs, sparse_params=sparse_params, transformer_options=transformer_options)
         visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
         # cross attention
         shift, scale, gate = get_shift_scale_gate(cross_attn_params)
@@ -369,21 +381,80 @@ class Kandinsky5(nn.Module):
         visual_embed = self.visual_embeddings(x)
         visual_shape = visual_embed.shape[:-1]
-        visual_embed = visual_embed.flatten(1, -2)
 
         blocks_replace = patches_replace.get("dit", {})
         transformer_options["total_blocks"] = len(self.visual_transformer_blocks)
         transformer_options["block_type"] = "double"
+        B, _, T, H, W = x.shape
+        if T > 30: # 10 sec generation
+            assert self.patch_size[0] == 1
+            freqs = freqs.view(freqs.shape[0], *visual_shape[1:], *freqs.shape[2:])[0]
+            visual_embed_4d, freqs = fractal_flatten(visual_embed[0], freqs, visual_shape[1:])
+            visual_embed, freqs = visual_embed_4d.unsqueeze(0), freqs.unsqueeze(0)
+            pt, ph, pw = self.patch_size
+            T, H, W = T // pt, H // ph, W // pw
+            wT, wW, wH = 11, 11, 3
+            sta_mask = fast_sta_nabla(T, H // 8, W // 8, wT, wH, wW, device=x.device)
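+            # nabla() in utils_nabla.py consumes the sta_mask and P entries (P is the fraction of
+            # attention mass kept per query block); the remaining fields record the NABLA configuration.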
+            sparse_params = dict(
+                sta_mask=sta_mask.unsqueeze_(0).unsqueeze_(0),
+                attention_type="nabla",
+                to_fractal=True,
+                P=0.8,
+                wT=wT, wW=wW, wH=wH,
+                add_sta=True,
+                visual_shape=(T, H, W),
+                method="topcdf",
+            )
+        else:
+            sparse_params = None
+            visual_embed = visual_embed.flatten(1, -2)
+
         for i, block in enumerate(self.visual_transformer_blocks):
             transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
-                    return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options"))
-                visual_embed = blocks_replace[("double_block", i)]({"x": visual_embed, "context": context, "time_embed": time_embed, "freqs": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})["x"]
+                    return block(
+                        x=args["x"],
+                        context=args["context"],
+                        time_embed=args["time_embed"],
+                        freqs=args["freqs"],
+                        sparse_params=args.get("sparse_params"),
+                        transformer_options=args.get("transformer_options"),
+                    )
+                visual_embed = blocks_replace[("double_block", i)](
+                    {
+                        "x": visual_embed,
+                        "context": context,
+                        "time_embed": time_embed,
+                        "freqs": freqs,
+                        "sparse_params": sparse_params,
+                        "transformer_options": transformer_options,
+                    },
+                    {"original_block": block_wrap},
+                )["x"]
             else:
-                visual_embed = block(visual_embed, context, time_embed, freqs=freqs, transformer_options=transformer_options)
+                visual_embed = block(
+                    visual_embed,
+                    context,
+                    time_embed,
+                    freqs=freqs,
+                    sparse_params=sparse_params,
+                    transformer_options=transformer_options,
+                )
 
-        visual_embed = visual_embed.reshape(*visual_shape, -1)
+        if T > 30:
+            visual_embed = fractal_unflatten(
+                visual_embed[0],
+                visual_shape[1:],
+            ).unsqueeze(0)
+        else:
+            visual_embed = visual_embed.reshape(*visual_shape, -1)
 
         return self.out_layer(visual_embed, time_embed)
 
     def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
@@ -411,3 +482,5 @@ class Kandinsky5(nn.Module):
             self,
             comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
         ).execute(x, timestep, context, y, time_dim_replace=time_dim_replace, transformer_options=transformer_options, **kwargs)

View File: comfy/ldm/kandinsky5/utils_nabla.py (new file)

@@ -0,0 +1,147 @@
+import math
+import torch
+from torch import Tensor
+from torch.nn.attention.flex_attention import BlockMask, flex_attention
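+
+
+# fractal_flatten / fractal_unflatten reorder a (T, H, W, C) token grid (and the matching rope
+# frequencies) into 1x8x8 local patches, so every run of 64 consecutive tokens is spatially
+# contiguous and lines up with the 64-token blocks used by the NABLA block mask below.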
+def fractal_flatten(x, rope, shape):
+    pixel_size = 8
+    x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=0)
+    rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=0)
+    x = x.flatten(0, 1)
+    rope = rope.flatten(0, 1)
+    return x, rope
+
+
+def fractal_unflatten(x, shape):
+    pixel_size = 8
+    x = x.reshape(-1, pixel_size**2, x.shape[-1])
+    x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=0)
+    return x
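+
+
+# local_patching groups a (..., T, H, W, ...) grid into (..., num_tiles, g1*g2*g3, ...) tiles of
+# size g1 x g2 x g3; local_merge is its inverse and restores the (..., T, H, W, ...) layout.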
+def local_patching(x, shape, group_size, dim=0):
+    duration, height, width = shape
+    g1, g2, g3 = group_size
+    x = x.reshape(
+        *x.shape[:dim],
+        duration // g1,
+        g1,
+        height // g2,
+        g2,
+        width // g3,
+        g3,
+        *x.shape[dim + 3 :]
+    )
+    x = x.permute(
+        *range(len(x.shape[:dim])),
+        dim,
+        dim + 2,
+        dim + 4,
+        dim + 1,
+        dim + 3,
+        dim + 5,
+        *range(dim + 6, len(x.shape))
+    )
+    x = x.flatten(dim, dim + 2).flatten(dim + 1, dim + 3)
+    return x
+
+
+def local_merge(x, shape, group_size, dim=0):
+    duration, height, width = shape
+    g1, g2, g3 = group_size
+    x = x.reshape(
+        *x.shape[:dim],
+        duration // g1,
+        height // g2,
+        width // g3,
+        g1,
+        g2,
+        g3,
+        *x.shape[dim + 2 :]
+    )
+    x = x.permute(
+        *range(len(x.shape[:dim])),
+        dim,
+        dim + 3,
+        dim + 1,
+        dim + 4,
+        dim + 2,
+        dim + 5,
+        *range(dim + 6, len(x.shape))
+    )
+    x = x.flatten(dim, dim + 1).flatten(dim + 1, dim + 2).flatten(dim + 2, dim + 3)
+    return x
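+
+
+# fast_sta_nabla builds a (T*H*W) x (T*H*W) boolean sliding-tile attention (STA) mask at block
+# granularity: block (t, h, w) may attend block (t', h', w') when each coordinate differs by at
+# most half the corresponding window size (wT, wH, wW).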
+def fast_sta_nabla(T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3, device="cuda") -> Tensor:
+    l = torch.Tensor([T, H, W]).amax()
+    r = torch.arange(0, l, 1, dtype=torch.int16, device=device)
+    mat = (r.unsqueeze(1) - r.unsqueeze(0)).abs()
+    sta_t, sta_h, sta_w = (
+        mat[:T, :T].flatten(),
+        mat[:H, :H].flatten(),
+        mat[:W, :W].flatten(),
+    )
+    sta_t = sta_t <= wT // 2
+    sta_h = sta_h <= wH // 2
+    sta_w = sta_w <= wW // 2
+    sta_hw = (
+        (sta_h.unsqueeze(1) * sta_w.unsqueeze(0))
+        .reshape(H, H, W, W)
+        .transpose(1, 2)
+        .flatten()
+    )
+    sta = (
+        (sta_t.unsqueeze(1) * sta_hw.unsqueeze(0))
+        .reshape(T, T, H * W, H * W)
+        .transpose(1, 2)
+    )
+    return sta.reshape(T * H * W, T * H * W)
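+
+
+# nablaT_v2 estimates a block-level attention map from mean-pooled 64-token blocks of q and k,
+# keeps for each query block the highest-probability key blocks that together cover at least thr
+# of the softmax mass ("topcdf"), ORs in the STA mask, and packs the result into a flex_attention
+# BlockMask with 64-token blocks.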
+def nablaT_v2(q: Tensor, k: Tensor, sta: Tensor, thr: float = 0.9) -> BlockMask:
+    # Map estimation
+    B, h, S, D = q.shape
+    s1 = S // 64
+    qa = q.reshape(B, h, s1, 64, D).mean(-2)
+    ka = k.reshape(B, h, s1, 64, D).mean(-2).transpose(-2, -1)
+    map = qa @ ka
+    map = torch.softmax(map / math.sqrt(D), dim=-1)
+
+    # Map binarization
+    vals, inds = map.sort(-1)
+    cvals = vals.cumsum_(-1)
+    mask = (cvals >= 1 - thr).int()
+    mask = mask.gather(-1, inds.argsort(-1))
+    mask = torch.logical_or(mask, sta)
+
+    # BlockMask creation
+    kv_nb = mask.sum(-1).to(torch.int32)
+    kv_inds = mask.argsort(dim=-1, descending=True).to(torch.int32)
+    return BlockMask.from_kv_blocks(
+        torch.zeros_like(kv_nb), kv_inds, kv_nb, kv_inds, BLOCK_SIZE=64, mask_mod=None
+    )
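+
+
+# nabla is the entry point called from SelfAttention: it takes (B, S, heads, dim) tensors, runs
+# flex_attention under the NABLA BlockMask, and returns (B, S, heads * dim) so the result can be
+# fed to self.out_layer just like the dense attention output.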
+@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
+def nabla(query, key, value, sparse_params=None):
+    query = query.transpose(1, 2).contiguous()
+    key = key.transpose(1, 2).contiguous()
+    value = value.transpose(1, 2).contiguous()
+    block_mask = nablaT_v2(
+        query,
+        key,
+        sparse_params["sta_mask"],
+        thr=sparse_params["P"],
+    )
+    out = (
+        flex_attention(
+            query,
+            key,
+            value,
+            block_mask=block_mask
+        )
+        .transpose(1, 2)
+        .contiguous()
+    )
+    out = out.flatten(-2, -1)
+    return out
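
For reference, a minimal sketch of how these helpers compose, assuming a CUDA device, bfloat16 inputs, and a sequence length that is a multiple of the 64-token block size; nabla() itself only reads the sta_mask and P entries of sparse_params, so those are the only fields supplied here. The block-grid and head sizes below are illustrative values, not ones used by the model.

import torch
from comfy.ldm.kandinsky5.utils_nabla import fast_sta_nabla, nabla

B, heads, dim = 1, 8, 64
T, Hb, Wb = 4, 4, 4                    # block grid: one 64-token block per cell
S = T * Hb * Wb * 64                   # 4096 tokens, assumed to be in fractal (block-contiguous) order

# Sliding-tile mask over the block grid, broadcast to (1, 1, blocks, blocks).
sta_mask = fast_sta_nabla(T, Hb, Wb, wT=3, wH=3, wW=3, device="cuda")
sparse_params = {"sta_mask": sta_mask.unsqueeze(0).unsqueeze(0), "P": 0.8}

q = torch.randn(B, S, heads, dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = nabla(q, k, v, sparse_params)    # (B, S, heads * dim)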