diff --git a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py index df10e4496..ac9352867 100644 --- a/comfy/ldm/kandinsky5/model.py +++ b/comfy/ldm/kandinsky5/model.py @@ -387,13 +387,12 @@ class Kandinsky5(nn.Module): transformer_options["block_type"] = "double" B, _, T, H, W = x.shape - if T > 30: # 10 sec generation + NABLA_THR = 40 # long (10 sec) generation + if T > NABLA_THR: assert self.patch_size[0] == 1 - freqs = freqs.view(freqs.shape[0], *visual_shape[1:], *freqs.shape[2:])[0] - visual_embed_4d, freqs = fractal_flatten(visual_embed[0], freqs, visual_shape[1:]) - visual_embed, freqs = visual_embed_4d.unsqueeze(0), freqs.unsqueeze(0) - + freqs = freqs.view(freqs.shape[0], *visual_shape[1:], *freqs.shape[2:]) + visual_embed, freqs = fractal_flatten(visual_embed, freqs, visual_shape[1:]) pt, ph, pw = self.patch_size T, H, W = T // pt, H // ph, W // pw @@ -447,11 +446,11 @@ class Kandinsky5(nn.Module): transformer_options=transformer_options, ) - if T > 30: + if T > NABLA_THR: visual_embed = fractal_unflatten( - visual_embed[0], + visual_embed, visual_shape[1:], - ).unsqueeze(0) + ) else: visual_embed = visual_embed.reshape(*visual_shape, -1) diff --git a/comfy/ldm/kandinsky5/utils_nabla.py b/comfy/ldm/kandinsky5/utils_nabla.py index 705b1d75e..5e2bc4076 100644 --- a/comfy/ldm/kandinsky5/utils_nabla.py +++ b/comfy/ldm/kandinsky5/utils_nabla.py @@ -7,20 +7,19 @@ from torch.nn.attention.flex_attention import BlockMask, flex_attention def fractal_flatten(x, rope, shape): pixel_size = 8 - x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=0) - rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=0) - x = x.flatten(0, 1) - rope = rope.flatten(0, 1) + x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=1) + rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=1) + x = x.flatten(1, 2) + rope = rope.flatten(1, 2) return x, rope def fractal_unflatten(x, shape): pixel_size = 8 - x = x.reshape(-1, pixel_size**2, x.shape[-1]) - x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=0) + x = x.reshape(x.shape[0], -1, pixel_size**2, x.shape[-1]) + x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=1) return x - def local_patching(x, shape, group_size, dim=0): duration, height, width = shape g1, g2, g3 = group_size