mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 22:30:50 +08:00
Merge 296b7c7b6d into 3cd7b32f1b
This commit is contained in:
commit
73459299a3
@ -6,6 +6,12 @@ import comfy.ldm.common_dit
|
|||||||
from comfy.ldm.modules.attention import optimized_attention
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
from comfy.ldm.flux.math import apply_rope1
|
from comfy.ldm.flux.math import apply_rope1
|
||||||
from comfy.ldm.flux.layers import EmbedND
|
from comfy.ldm.flux.layers import EmbedND
|
||||||
|
from comfy.ldm.kandinsky5.utils_nabla import (
|
||||||
|
fractal_flatten,
|
||||||
|
fractal_unflatten,
|
||||||
|
fast_sta_nabla,
|
||||||
|
nabla,
|
||||||
|
)
|
||||||
|
|
||||||
def attention(q, k, v, heads, transformer_options={}):
|
def attention(q, k, v, heads, transformer_options={}):
|
||||||
return optimized_attention(
|
return optimized_attention(
|
||||||
@ -116,14 +122,17 @@ class SelfAttention(nn.Module):
|
|||||||
result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1)
|
result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||||
return apply_rope1(norm_fn(result), freqs)
|
return apply_rope1(norm_fn(result), freqs)
|
||||||
|
|
||||||
def _forward(self, x, freqs, transformer_options={}):
|
def _forward(self, x, freqs, sparse_params=None, transformer_options={}):
|
||||||
q = self._compute_qk(x, freqs, self.to_query, self.query_norm)
|
q = self._compute_qk(x, freqs, self.to_query, self.query_norm)
|
||||||
k = self._compute_qk(x, freqs, self.to_key, self.key_norm)
|
k = self._compute_qk(x, freqs, self.to_key, self.key_norm)
|
||||||
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
|
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||||
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
|
if sparse_params is None:
|
||||||
|
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
|
||||||
|
else:
|
||||||
|
out = nabla(q, k, v, sparse_params)
|
||||||
return self.out_layer(out)
|
return self.out_layer(out)
|
||||||
|
|
||||||
def _forward_chunked(self, x, freqs, transformer_options={}):
|
def _forward_chunked(self, x, freqs, sparse_params=None, transformer_options={}):
|
||||||
def process_chunks(proj_fn, norm_fn):
|
def process_chunks(proj_fn, norm_fn):
|
||||||
x_chunks = torch.chunk(x, self.num_chunks, dim=1)
|
x_chunks = torch.chunk(x, self.num_chunks, dim=1)
|
||||||
freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1)
|
freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1)
|
||||||
@ -135,14 +144,17 @@ class SelfAttention(nn.Module):
|
|||||||
q = process_chunks(self.to_query, self.query_norm)
|
q = process_chunks(self.to_query, self.query_norm)
|
||||||
k = process_chunks(self.to_key, self.key_norm)
|
k = process_chunks(self.to_key, self.key_norm)
|
||||||
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
|
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||||
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
|
if sparse_params is None:
|
||||||
|
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
|
||||||
|
else:
|
||||||
|
out = nabla(q, k, v, sparse_params)
|
||||||
return self.out_layer(out)
|
return self.out_layer(out)
|
||||||
|
|
||||||
def forward(self, x, freqs, transformer_options={}):
|
def forward(self, x, freqs, sparse_params=None, transformer_options={}):
|
||||||
if x.shape[1] > 8192:
|
if x.shape[1] > 8192:
|
||||||
return self._forward_chunked(x, freqs, transformer_options=transformer_options)
|
return self._forward_chunked(x, freqs, sparse_params=sparse_params, transformer_options=transformer_options)
|
||||||
else:
|
else:
|
||||||
return self._forward(x, freqs, transformer_options=transformer_options)
|
return self._forward(x, freqs, sparse_params=sparse_params, transformer_options=transformer_options)
|
||||||
|
|
||||||
|
|
||||||
class CrossAttention(SelfAttention):
|
class CrossAttention(SelfAttention):
|
||||||
@ -251,12 +263,12 @@ class TransformerDecoderBlock(nn.Module):
|
|||||||
self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
|
self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
|
||||||
|
|
||||||
def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}):
|
def forward(self, visual_embed, text_embed, time_embed, freqs, sparse_params=None, transformer_options={}):
|
||||||
self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1)
|
self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1)
|
||||||
# self attention
|
# self attention
|
||||||
shift, scale, gate = get_shift_scale_gate(self_attn_params)
|
shift, scale, gate = get_shift_scale_gate(self_attn_params)
|
||||||
visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift)
|
visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift)
|
||||||
visual_out = self.self_attention(visual_out, freqs, transformer_options=transformer_options)
|
visual_out = self.self_attention(visual_out, freqs, sparse_params=sparse_params, transformer_options=transformer_options)
|
||||||
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
|
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
|
||||||
# cross attention
|
# cross attention
|
||||||
shift, scale, gate = get_shift_scale_gate(cross_attn_params)
|
shift, scale, gate = get_shift_scale_gate(cross_attn_params)
|
||||||
@ -369,21 +381,82 @@ class Kandinsky5(nn.Module):
|
|||||||
|
|
||||||
visual_embed = self.visual_embeddings(x)
|
visual_embed = self.visual_embeddings(x)
|
||||||
visual_shape = visual_embed.shape[:-1]
|
visual_shape = visual_embed.shape[:-1]
|
||||||
visual_embed = visual_embed.flatten(1, -2)
|
|
||||||
|
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
transformer_options["total_blocks"] = len(self.visual_transformer_blocks)
|
transformer_options["total_blocks"] = len(self.visual_transformer_blocks)
|
||||||
transformer_options["block_type"] = "double"
|
transformer_options["block_type"] = "double"
|
||||||
|
|
||||||
|
B, _, T, H, W = x.shape
|
||||||
|
NABLA_THR = 31 # long (10 sec) generation
|
||||||
|
if T > NABLA_THR:
|
||||||
|
assert self.patch_size[0] == 1
|
||||||
|
|
||||||
|
# pro video model uses lower P at higher resolutions
|
||||||
|
P = 0.7 if self.model_dim == 4096 and H * W >= 14080 else 0.9
|
||||||
|
|
||||||
|
freqs = freqs.view(freqs.shape[0], *visual_shape[1:], *freqs.shape[2:])
|
||||||
|
visual_embed, freqs = fractal_flatten(visual_embed, freqs, visual_shape[1:])
|
||||||
|
pt, ph, pw = self.patch_size
|
||||||
|
T, H, W = T // pt, H // ph, W // pw
|
||||||
|
|
||||||
|
wT, wW, wH = 11, 3, 3
|
||||||
|
sta_mask = fast_sta_nabla(T, H // 8, W // 8, wT, wH, wW, device=x.device)
|
||||||
|
|
||||||
|
sparse_params = dict(
|
||||||
|
sta_mask=sta_mask.unsqueeze_(0).unsqueeze_(0),
|
||||||
|
attention_type="nabla",
|
||||||
|
to_fractal=True,
|
||||||
|
P=P,
|
||||||
|
wT=wT, wW=wW, wH=wH,
|
||||||
|
add_sta=True,
|
||||||
|
visual_shape=(T, H, W),
|
||||||
|
method="topcdf",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sparse_params = None
|
||||||
|
visual_embed = visual_embed.flatten(1, -2)
|
||||||
|
|
||||||
for i, block in enumerate(self.visual_transformer_blocks):
|
for i, block in enumerate(self.visual_transformer_blocks):
|
||||||
transformer_options["block_index"] = i
|
transformer_options["block_index"] = i
|
||||||
if ("double_block", i) in blocks_replace:
|
if ("double_block", i) in blocks_replace:
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options"))
|
return block(
|
||||||
visual_embed = blocks_replace[("double_block", i)]({"x": visual_embed, "context": context, "time_embed": time_embed, "freqs": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})["x"]
|
x=args["x"],
|
||||||
|
context=args["context"],
|
||||||
|
time_embed=args["time_embed"],
|
||||||
|
freqs=args["freqs"],
|
||||||
|
sparse_params=args.get("sparse_params"),
|
||||||
|
transformer_options=args.get("transformer_options"),
|
||||||
|
)
|
||||||
|
visual_embed = blocks_replace[("double_block", i)](
|
||||||
|
{
|
||||||
|
"x": visual_embed,
|
||||||
|
"context": context,
|
||||||
|
"time_embed": time_embed,
|
||||||
|
"freqs": freqs,
|
||||||
|
"sparse_params": sparse_params,
|
||||||
|
"transformer_options": transformer_options,
|
||||||
|
},
|
||||||
|
{"original_block": block_wrap},
|
||||||
|
)["x"]
|
||||||
else:
|
else:
|
||||||
visual_embed = block(visual_embed, context, time_embed, freqs=freqs, transformer_options=transformer_options)
|
visual_embed = block(
|
||||||
|
visual_embed,
|
||||||
|
context,
|
||||||
|
time_embed,
|
||||||
|
freqs=freqs,
|
||||||
|
sparse_params=sparse_params,
|
||||||
|
transformer_options=transformer_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
if T > NABLA_THR:
|
||||||
|
visual_embed = fractal_unflatten(
|
||||||
|
visual_embed,
|
||||||
|
visual_shape[1:],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
visual_embed = visual_embed.reshape(*visual_shape, -1)
|
||||||
|
|
||||||
visual_embed = visual_embed.reshape(*visual_shape, -1)
|
|
||||||
return self.out_layer(visual_embed, time_embed)
|
return self.out_layer(visual_embed, time_embed)
|
||||||
|
|
||||||
def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
|
def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
|
||||||
|
|||||||
146
comfy/ldm/kandinsky5/utils_nabla.py
Normal file
146
comfy/ldm/kandinsky5/utils_nabla.py
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.nn.attention.flex_attention import BlockMask, flex_attention
|
||||||
|
|
||||||
|
|
||||||
|
def fractal_flatten(x, rope, shape):
|
||||||
|
pixel_size = 8
|
||||||
|
x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=1)
|
||||||
|
rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=1)
|
||||||
|
x = x.flatten(1, 2)
|
||||||
|
rope = rope.flatten(1, 2)
|
||||||
|
return x, rope
|
||||||
|
|
||||||
|
|
||||||
|
def fractal_unflatten(x, shape):
|
||||||
|
pixel_size = 8
|
||||||
|
x = x.reshape(x.shape[0], -1, pixel_size**2, x.shape[-1])
|
||||||
|
x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def local_patching(x, shape, group_size, dim=0):
|
||||||
|
duration, height, width = shape
|
||||||
|
g1, g2, g3 = group_size
|
||||||
|
x = x.reshape(
|
||||||
|
*x.shape[:dim],
|
||||||
|
duration // g1,
|
||||||
|
g1,
|
||||||
|
height // g2,
|
||||||
|
g2,
|
||||||
|
width // g3,
|
||||||
|
g3,
|
||||||
|
*x.shape[dim + 3 :]
|
||||||
|
)
|
||||||
|
x = x.permute(
|
||||||
|
*range(len(x.shape[:dim])),
|
||||||
|
dim,
|
||||||
|
dim + 2,
|
||||||
|
dim + 4,
|
||||||
|
dim + 1,
|
||||||
|
dim + 3,
|
||||||
|
dim + 5,
|
||||||
|
*range(dim + 6, len(x.shape))
|
||||||
|
)
|
||||||
|
x = x.flatten(dim, dim + 2).flatten(dim + 1, dim + 3)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def local_merge(x, shape, group_size, dim=0):
|
||||||
|
duration, height, width = shape
|
||||||
|
g1, g2, g3 = group_size
|
||||||
|
x = x.reshape(
|
||||||
|
*x.shape[:dim],
|
||||||
|
duration // g1,
|
||||||
|
height // g2,
|
||||||
|
width // g3,
|
||||||
|
g1,
|
||||||
|
g2,
|
||||||
|
g3,
|
||||||
|
*x.shape[dim + 2 :]
|
||||||
|
)
|
||||||
|
x = x.permute(
|
||||||
|
*range(len(x.shape[:dim])),
|
||||||
|
dim,
|
||||||
|
dim + 3,
|
||||||
|
dim + 1,
|
||||||
|
dim + 4,
|
||||||
|
dim + 2,
|
||||||
|
dim + 5,
|
||||||
|
*range(dim + 6, len(x.shape))
|
||||||
|
)
|
||||||
|
x = x.flatten(dim, dim + 1).flatten(dim + 1, dim + 2).flatten(dim + 2, dim + 3)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def fast_sta_nabla(T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3, device="cuda") -> Tensor:
|
||||||
|
l = torch.Tensor([T, H, W]).amax()
|
||||||
|
r = torch.arange(0, l, 1, dtype=torch.int16, device=device)
|
||||||
|
mat = (r.unsqueeze(1) - r.unsqueeze(0)).abs()
|
||||||
|
sta_t, sta_h, sta_w = (
|
||||||
|
mat[:T, :T].flatten(),
|
||||||
|
mat[:H, :H].flatten(),
|
||||||
|
mat[:W, :W].flatten(),
|
||||||
|
)
|
||||||
|
sta_t = sta_t <= wT // 2
|
||||||
|
sta_h = sta_h <= wH // 2
|
||||||
|
sta_w = sta_w <= wW // 2
|
||||||
|
sta_hw = (
|
||||||
|
(sta_h.unsqueeze(1) * sta_w.unsqueeze(0))
|
||||||
|
.reshape(H, H, W, W)
|
||||||
|
.transpose(1, 2)
|
||||||
|
.flatten()
|
||||||
|
)
|
||||||
|
sta = (
|
||||||
|
(sta_t.unsqueeze(1) * sta_hw.unsqueeze(0))
|
||||||
|
.reshape(T, T, H * W, H * W)
|
||||||
|
.transpose(1, 2)
|
||||||
|
)
|
||||||
|
return sta.reshape(T * H * W, T * H * W)
|
||||||
|
|
||||||
|
def nablaT_v2(q: Tensor, k: Tensor, sta: Tensor, thr: float = 0.9) -> BlockMask:
|
||||||
|
# Map estimation
|
||||||
|
B, h, S, D = q.shape
|
||||||
|
s1 = S // 64
|
||||||
|
qa = q.reshape(B, h, s1, 64, D).mean(-2)
|
||||||
|
ka = k.reshape(B, h, s1, 64, D).mean(-2).transpose(-2, -1)
|
||||||
|
map = qa @ ka
|
||||||
|
|
||||||
|
map = torch.softmax(map / math.sqrt(D), dim=-1)
|
||||||
|
# Map binarization
|
||||||
|
vals, inds = map.sort(-1)
|
||||||
|
cvals = vals.cumsum_(-1)
|
||||||
|
mask = (cvals >= 1 - thr).int()
|
||||||
|
mask = mask.gather(-1, inds.argsort(-1))
|
||||||
|
mask = torch.logical_or(mask, sta)
|
||||||
|
|
||||||
|
# BlockMask creation
|
||||||
|
kv_nb = mask.sum(-1).to(torch.int32)
|
||||||
|
kv_inds = mask.argsort(dim=-1, descending=True).to(torch.int32)
|
||||||
|
return BlockMask.from_kv_blocks(
|
||||||
|
torch.zeros_like(kv_nb), kv_inds, kv_nb, kv_inds, BLOCK_SIZE=64, mask_mod=None
|
||||||
|
)
|
||||||
|
|
||||||
|
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
|
||||||
|
def nabla(query, key, value, sparse_params=None):
|
||||||
|
query = query.transpose(1, 2).contiguous()
|
||||||
|
key = key.transpose(1, 2).contiguous()
|
||||||
|
value = value.transpose(1, 2).contiguous()
|
||||||
|
block_mask = nablaT_v2(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
sparse_params["sta_mask"],
|
||||||
|
thr=sparse_params["P"],
|
||||||
|
)
|
||||||
|
out = (
|
||||||
|
flex_attention(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
block_mask=block_mask
|
||||||
|
)
|
||||||
|
.transpose(1, 2)
|
||||||
|
.contiguous()
|
||||||
|
)
|
||||||
|
out = out.flatten(-2, -1)
|
||||||
|
return out
|
||||||
@ -34,6 +34,9 @@ class Kandinsky5ImageToVideo(io.ComfyNode):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput:
|
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput:
|
||||||
|
if length > 121: # 10 sec generation, for nabla
|
||||||
|
height = 128 * round(height / 128)
|
||||||
|
width = 128 * round(width / 128)
|
||||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||||
cond_latent_out = {}
|
cond_latent_out = {}
|
||||||
if start_image is not None:
|
if start_image is not None:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user