Add batch support for nabla

Mihail Karaev 2025-12-16 11:15:59 +00:00
parent 2bff3c520f
commit 0c84b7650f
2 changed files with 13 additions and 15 deletions


@@ -387,13 +387,12 @@ class Kandinsky5(nn.Module):
         transformer_options["block_type"] = "double"
         B, _, T, H, W = x.shape
-        if T > 30: # 10 sec generation
+        NABLA_THR = 40 # long (10 sec) generation
+        if T > NABLA_THR:
             assert self.patch_size[0] == 1
-            freqs = freqs.view(freqs.shape[0], *visual_shape[1:], *freqs.shape[2:])[0]
-            visual_embed_4d, freqs = fractal_flatten(visual_embed[0], freqs, visual_shape[1:])
-            visual_embed, freqs = visual_embed_4d.unsqueeze(0), freqs.unsqueeze(0)
+            freqs = freqs.view(freqs.shape[0], *visual_shape[1:], *freqs.shape[2:])
+            visual_embed, freqs = fractal_flatten(visual_embed, freqs, visual_shape[1:])
         pt, ph, pw = self.patch_size
         T, H, W = T // pt, H // ph, W // pw
@@ -447,11 +446,11 @@ class Kandinsky5(nn.Module):
                 transformer_options=transformer_options,
             )
-        if T > 30:
+        if T > NABLA_THR:
             visual_embed = fractal_unflatten(
-                visual_embed[0],
+                visual_embed,
                 visual_shape[1:],
-            ).unsqueeze(0)
+            )
         else:
             visual_embed = visual_embed.reshape(*visual_shape, -1)
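
For context, a minimal shape sketch of the batching change in these hunks, using assumed toy sizes and a single trailing channel axis rather than the model's real tensors: the old path indexed [0] after the view and dropped every batch element but the first, while the new path keeps the batch axis so fractal_flatten receives the whole batch.

import torch

# Hypothetical sizes, only for illustration (T larger than NABLA_THR stands in for a long generation).
B, T, H, W, D = 2, 48, 16, 16, 64
visual_shape = (B, T, H, W)
freqs = torch.randn(B, T * H * W, D)  # flat per-token rotary frequencies (assumed layout)

# old: freqs.view(freqs.shape[0], *visual_shape[1:], *freqs.shape[2:])[0] -> (T, H, W, D), batch dropped
# new: the same view without [0]                                          -> (B, T, H, W, D), batch preserved
freqs = freqs.view(freqs.shape[0], *visual_shape[1:], *freqs.shape[2:])
assert freqs.shape == (B, T, H, W, D)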


@@ -7,20 +7,19 @@ from torch.nn.attention.flex_attention import BlockMask, flex_attention

 def fractal_flatten(x, rope, shape):
     pixel_size = 8
-    x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=0)
-    rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=0)
-    x = x.flatten(0, 1)
-    rope = rope.flatten(0, 1)
+    x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=1)
+    rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=1)
+    x = x.flatten(1, 2)
+    rope = rope.flatten(1, 2)
     return x, rope


 def fractal_unflatten(x, shape):
     pixel_size = 8
-    x = x.reshape(-1, pixel_size**2, x.shape[-1])
-    x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=0)
+    x = x.reshape(x.shape[0], -1, pixel_size**2, x.shape[-1])
+    x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=1)
     return x


 def local_patching(x, shape, group_size, dim=0):
     duration, height, width = shape
     g1, g2, g3 = group_size
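
local_patching and local_merge appear only as context in this hunk, so their real bodies may differ from the sketch below, which substitutes hypothetical reshape/movedim implementations. The sketch only illustrates what the dim=0 -> dim=1 and flatten(0, 1) -> flatten(1, 2) changes do: fractal_flatten and fractal_unflatten now treat the leading axis as a batch and round-trip a (B, T*H*W, C) token tensor instead of a single sample.

import torch

def local_patching(x, shape, group_size, dim=0):
    # Hypothetical stand-in: split the flat T*H*W token axis at `dim` into
    # spatially contiguous g1 x g2 x g3 groups:
    # (..., T*H*W, ...) -> (..., num_groups, g1*g2*g3, ...).
    t, h, w = shape
    g1, g2, g3 = group_size
    x = x.reshape(*x.shape[:dim], t // g1, g1, h // g2, g2, w // g3, g3, *x.shape[dim + 1:])
    x = x.movedim((dim, dim + 1, dim + 2, dim + 3, dim + 4, dim + 5),
                  (dim, dim + 3, dim + 1, dim + 4, dim + 2, dim + 5))
    return x.flatten(dim + 3, dim + 5).flatten(dim, dim + 2)

def local_merge(x, shape, group_size, dim=0):
    # Hypothetical inverse: (..., num_groups, g1*g2*g3, ...) -> (..., T*H*W, ...).
    t, h, w = shape
    g1, g2, g3 = group_size
    x = x.reshape(*x.shape[:dim], t // g1, h // g2, w // g3, g1, g2, g3, *x.shape[dim + 2:])
    x = x.movedim((dim, dim + 1, dim + 2, dim + 3, dim + 4, dim + 5),
                  (dim, dim + 2, dim + 4, dim + 1, dim + 3, dim + 5))
    return x.flatten(dim, dim + 5)

def fractal_flatten(x, rope, shape):
    # Batched form from this commit: group along the token axis (dim=1), then
    # merge (num_groups, group) back into one token axis per batch element.
    pixel_size = 8
    x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=1)
    rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=1)
    x = x.flatten(1, 2)
    rope = rope.flatten(1, 2)
    return x, rope

def fractal_unflatten(x, shape):
    # Batched form from this commit: keep x.shape[0] as the batch axis.
    pixel_size = 8
    x = x.reshape(x.shape[0], -1, pixel_size**2, x.shape[-1])
    x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=1)
    return x

# Round trip on toy batched tensors (H and W must be divisible by pixel_size).
B, T, H, W, C = 2, 4, 16, 16, 8
tokens = torch.randn(B, T * H * W, C)
rope = torch.randn(B, T * H * W, C)
flat, flat_rope = fractal_flatten(tokens, rope, (T, H, W))
assert flat.shape == (B, T * H * W, C)                          # same token count, reordered into 1x8x8 blocks
assert torch.equal(fractal_unflatten(flat, (T, H, W)), tokens)  # exact inverse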