From bd59bae606fa65d4c6e1789bdccd13bc23c1a565 Mon Sep 17 00:00:00 2001
From: Max Tretikov
Date: Fri, 14 Jun 2024 14:43:55 -0600
Subject: [PATCH] Fix compile_core in comfy.ldm.modules.diffusionmodules.mmdit

---
 comfy/ldm/modules/diffusionmodules/mmdit.py | 6 ++++--
 comfy/ldm/modules/diffusionmodules/model.py | 4 ++--
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py
index 7c49f2946..12a44dcee 100644
--- a/comfy/ldm/modules/diffusionmodules/mmdit.py
+++ b/comfy/ldm/modules/diffusionmodules/mmdit.py
@@ -837,9 +837,9 @@ class MMDiT(nn.Module):
 
         self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
 
+        self.compile_core = compile_core
         if compile_core:
-            assert False
-            self.forward_core_with_concat = torch.compile(self.forward_core_with_concat)
+            self.forward_core_with_concat_compiled = torch.compile(self.forward_core_with_concat)
 
     def cropped_pos_embed(self, hw, device=None):
         p = self.x_embedder.patch_size[0]
@@ -895,6 +895,8 @@ class MMDiT(nn.Module):
         c_mod: torch.Tensor,
         context: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        if self.compile_core:
+            return self.forward_core_with_concat_compiled(x, c_mod, context)
         if self.register_length > 0:
             context = torch.cat(
                 (
diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index 8e11c1eab..dbba2bb1a 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -216,10 +216,10 @@ def xformers_attention(q, k, v):
         (q, k, v),
     )
 
-    try:
+    if model_management.xformers_enabled_vae():
         out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
         out = out.transpose(1, 2).reshape(B, C, H, W)
-    except NotImplementedError as e:
+    else:
         out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
     return out
 
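
Review note on the mmdit.py change: the fix keeps the eager
forward_core_with_concat intact and stores the torch.compile wrapper under a
separate attribute, dispatching on a stored self.compile_core flag. Below is
a minimal, self-contained sketch of that pattern; the class and attribute
names (MiniCore, core, core_compiled) are illustrative, not ComfyUI's API,
and the sketch hoists the flag check into the caller so the compiled body
itself contains no dispatch logic.

import torch
import torch.nn as nn

class MiniCore(nn.Module):
    # Sketch only: the compiled wrapper lives in a separate attribute so the
    # eager method is never rebound to its own compiled version; a stored
    # flag selects the path at call time.
    def __init__(self, compile_core=False):
        super().__init__()
        self.proj = nn.Linear(16, 16)
        self.compile_core = compile_core
        if compile_core:
            # torch.compile is lazy: tracing happens on the first call,
            # so construction stays cheap.
            self.core_compiled = torch.compile(self.core)

    def core(self, x):
        # The compiled graph is traced from this body.
        return torch.nn.functional.gelu(self.proj(x))

    def forward(self, x):
        if self.compile_core:
            return self.core_compiled(x)
        return self.core(x)

# Usage: both paths compute the same result; only the backend differs.
m = MiniCore(compile_core=False)
y = m(torch.randn(2, 16))

Keeping the compiled callable in a distinct attribute also avoids the
pitfall the removed line had flirted with: rebinding a method to
torch.compile(itself) makes the original eager body unreachable.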
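
Review note on the model.py change: the patch swaps exception-driven
fallback (try/except NotImplementedError) for an explicit capability
predicate. A hedged sketch of that pattern follows; fast_attention and
sliced_attention are stand-ins for xformers.ops.memory_efficient_attention
and comfy's slice_attention, and the fast_path_enabled argument stands in
for model_management.xformers_enabled_vae().

import torch
import torch.nn.functional as F

def fast_attention(q, k, v):
    # Stand-in for the memory-efficient kernel.
    return F.scaled_dot_product_attention(q, k, v)

def sliced_attention(q, k, v, slice_size=2):
    # Stand-in fallback: process query rows in chunks to bound peak memory.
    scale = q.shape[-1] ** -0.5
    outs = []
    for i in range(0, q.shape[-2], slice_size):
        w = (q[..., i:i + slice_size, :] @ k.transpose(-2, -1)) * scale
        outs.append(w.softmax(dim=-1) @ v)
    return torch.cat(outs, dim=-2)

def attention(q, k, v, fast_path_enabled):
    # Explicit capability check up front, mirroring the patch's replacement
    # of the try/except with a predicate.
    if fast_path_enabled:
        return fast_attention(q, k, v)
    return sliced_attention(q, k, v)

q = k = v = torch.randn(1, 4, 8, 16)  # (batch, heads, seq, head_dim)
assert torch.allclose(attention(q, k, v, True), attention(q, k, v, False), atol=1e-5)

Branching on a predicate makes the fallback decision deterministic and keeps
genuine kernel failures from being silently swallowed by the except clause.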