mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-07 21:00:49 +08:00
Add batch support for nabla
This commit is contained in:
parent
2bff3c520f
commit
0c84b7650f
@ -387,13 +387,12 @@ class Kandinsky5(nn.Module):
|
||||
transformer_options["block_type"] = "double"
|
||||
|
||||
B, _, T, H, W = x.shape
|
||||
if T > 30: # 10 sec generation
|
||||
NABLA_THR = 40 # long (10 sec) generation
|
||||
if T > NABLA_THR:
|
||||
assert self.patch_size[0] == 1
|
||||
|
||||
freqs = freqs.view(freqs.shape[0], *visual_shape[1:], *freqs.shape[2:])[0]
|
||||
visual_embed_4d, freqs = fractal_flatten(visual_embed[0], freqs, visual_shape[1:])
|
||||
visual_embed, freqs = visual_embed_4d.unsqueeze(0), freqs.unsqueeze(0)
|
||||
|
||||
freqs = freqs.view(freqs.shape[0], *visual_shape[1:], *freqs.shape[2:])
|
||||
visual_embed, freqs = fractal_flatten(visual_embed, freqs, visual_shape[1:])
|
||||
pt, ph, pw = self.patch_size
|
||||
T, H, W = T // pt, H // ph, W // pw
|
||||
|
||||
@ -447,11 +446,11 @@ class Kandinsky5(nn.Module):
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
if T > 30:
|
||||
if T > NABLA_THR:
|
||||
visual_embed = fractal_unflatten(
|
||||
visual_embed[0],
|
||||
visual_embed,
|
||||
visual_shape[1:],
|
||||
).unsqueeze(0)
|
||||
)
|
||||
else:
|
||||
visual_embed = visual_embed.reshape(*visual_shape, -1)
|
||||
|
||||
|
||||
@ -7,20 +7,19 @@ from torch.nn.attention.flex_attention import BlockMask, flex_attention
|
||||
|
||||
def fractal_flatten(x, rope, shape):
|
||||
pixel_size = 8
|
||||
x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=0)
|
||||
rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=0)
|
||||
x = x.flatten(0, 1)
|
||||
rope = rope.flatten(0, 1)
|
||||
x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=1)
|
||||
rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=1)
|
||||
x = x.flatten(1, 2)
|
||||
rope = rope.flatten(1, 2)
|
||||
return x, rope
|
||||
|
||||
|
||||
def fractal_unflatten(x, shape):
|
||||
pixel_size = 8
|
||||
x = x.reshape(-1, pixel_size**2, x.shape[-1])
|
||||
x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=0)
|
||||
x = x.reshape(x.shape[0], -1, pixel_size**2, x.shape[-1])
|
||||
x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=1)
|
||||
return x
|
||||
|
||||
|
||||
def local_patching(x, shape, group_size, dim=0):
|
||||
duration, height, width = shape
|
||||
g1, g2, g3 = group_size
|
||||
|
||||
Loading…
Reference in New Issue
Block a user