Mirror of https://github.com/comfyanonymous/ComfyUI.git
Add batch support for nabla

commit 0c84b7650f
parent 2bff3c520f
@@ -387,13 +387,12 @@ class Kandinsky5(nn.Module):
         transformer_options["block_type"] = "double"

         B, _, T, H, W = x.shape
-        if T > 30: # 10 sec generation
+        NABLA_THR = 40 # long (10 sec) generation
+        if T > NABLA_THR:
             assert self.patch_size[0] == 1

-            freqs = freqs.view(freqs.shape[0], *visual_shape[1:], *freqs.shape[2:])[0]
-            visual_embed_4d, freqs = fractal_flatten(visual_embed[0], freqs, visual_shape[1:])
-            visual_embed, freqs = visual_embed_4d.unsqueeze(0), freqs.unsqueeze(0)
+            freqs = freqs.view(freqs.shape[0], *visual_shape[1:], *freqs.shape[2:])
+            visual_embed, freqs = fractal_flatten(visual_embed, freqs, visual_shape[1:])

         pt, ph, pw = self.patch_size
         T, H, W = T // pt, H // ph, W // pw
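This first hunk removes the batch-size-1 bracketing around the NABLA path: previously freqs was reshaped and then indexed with [0], and visual_embed[0] was flattened per sample and re-batched with unsqueeze(0); now both tensors keep their batch axis through fractal_flatten. A minimal sketch of the freqs reshape change, using hypothetical toy shapes (B, T, H, W, C below are illustrative, not the model's real sizes):

import torch

# Illustrative stand-in shapes; the real values come from the latent video.
B, T, H, W, C = 2, 4, 8, 8, 16
visual_shape = (B, T, H, W)
freqs = torch.randn(B, T * H * W, C)  # stand-in for the RoPE frequency table

# Before: the trailing [0] kept only the first sample.
old = freqs.view(freqs.shape[0], *visual_shape[1:], *freqs.shape[2:])[0]
print(old.shape)  # torch.Size([4, 8, 8, 16])

# After: the batch axis survives, so every sample keeps its frequencies.
new = freqs.view(freqs.shape[0], *visual_shape[1:], *freqs.shape[2:])
print(new.shape)  # torch.Size([2, 4, 8, 8, 16])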
@@ -447,11 +446,11 @@ class Kandinsky5(nn.Module):
             transformer_options=transformer_options,
         )

-        if T > 30:
+        if T > NABLA_THR:
             visual_embed = fractal_unflatten(
-                visual_embed[0],
+                visual_embed,
                 visual_shape[1:],
-            ).unsqueeze(0)
+            )
         else:
             visual_embed = visual_embed.reshape(*visual_shape, -1)
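The output side mirrors the input-side change: fractal_unflatten used to receive visual_embed[0] and its result was re-batched with .unsqueeze(0), so only the first sample ever made it through; the batched call keeps every sample. A toy illustration of the difference (f here is a hypothetical stand-in for any per-sample transform, not the real fractal_unflatten):

import torch

f = lambda t: t * 2.0          # hypothetical stand-in for fractal_unflatten

x = torch.randn(2, 10, 4)      # batch of 2 samples
old = f(x[0]).unsqueeze(0)     # shape (1, 10, 4): sample 1 is silently lost
new = f(x)                     # shape (2, 10, 4): the full batch is kept
print(old.shape, new.shape)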
@@ -7,20 +7,19 @@ from torch.nn.attention.flex_attention import BlockMask, flex_attention


 def fractal_flatten(x, rope, shape):
     pixel_size = 8
-    x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=0)
-    rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=0)
-    x = x.flatten(0, 1)
-    rope = rope.flatten(0, 1)
+    x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=1)
+    rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=1)
+    x = x.flatten(1, 2)
+    rope = rope.flatten(1, 2)
     return x, rope


 def fractal_unflatten(x, shape):
     pixel_size = 8
-    x = x.reshape(-1, pixel_size**2, x.shape[-1])
-    x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=0)
+    x = x.reshape(x.shape[0], -1, pixel_size**2, x.shape[-1])
+    x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=1)
     return x


 def local_patching(x, shape, group_size, dim=0):
     duration, height, width = shape
     g1, g2, g3 = group_size
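In the helpers themselves the change is uniform: every group axis moves from dim=0 to dim=1, leaving dim 0 free for the batch. Since local_patching and local_merge are not shown in full in this diff, the sketch below uses simplified hypothetical stand-ins (plain 1-D token grouping instead of the real (1, 8, 8) spatio-temporal blocks) purely to demonstrate that the dim shift preserves the per-sample result and round-trips cleanly:

import torch

# Hypothetical simplified stand-ins for local_patching / local_merge:
# group a flat token axis into (n_groups, group_len) blocks and invert it.
def group_tokens(x, group_len, dim):
    return x.unflatten(dim, (x.shape[dim] // group_len, group_len))

def merge_tokens(x, dim):
    return x.flatten(dim, dim + 1)

B, N, C = 2, 4 * 64, 16        # 64 = pixel_size ** 2 tokens per group
x = torch.randn(B, N, C)

# Old convention (dim=0): only a single un-batched (N, C) tensor fits.
flat_old = group_tokens(x[0], 64, dim=0).flatten(0, 1)   # (N, C)
# New convention (dim=1): the whole batch goes through at once.
flat_new = group_tokens(x, 64, dim=1).flatten(1, 2)      # (B, N, C)

assert torch.equal(flat_new[0], flat_old)  # per-sample output is unchanged
assert torch.equal(merge_tokens(group_tokens(x, 64, dim=1), dim=1), x)  # round trip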