diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py
index 7c49f2946..12a44dcee 100644
--- a/comfy/ldm/modules/diffusionmodules/mmdit.py
+++ b/comfy/ldm/modules/diffusionmodules/mmdit.py
@@ -837,9 +837,9 @@ class MMDiT(nn.Module):
 
         self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
 
+        self.compile_core = compile_core
         if compile_core:
-            assert False
-            self.forward_core_with_concat = torch.compile(self.forward_core_with_concat)
+            self.forward_core_with_concat_compiled = torch.compile(self.forward_core_with_concat)
 
     def cropped_pos_embed(self, hw, device=None):
         p = self.x_embedder.patch_size[0]
@@ -895,6 +895,8 @@ class MMDiT(nn.Module):
         c_mod: torch.Tensor,
         context: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        if self.compile_core:
+            return self.forward_core_with_concat_compiled(x, c_mod, context)
         if self.register_length > 0:
            context = torch.cat(
                 (
diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index 8e11c1eab..dbba2bb1a 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -216,10 +216,10 @@ def xformers_attention(q, k, v):
         (q, k, v),
     )
 
-    try:
+    if model_management.xformers_enabled_vae():
         out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
         out = out.transpose(1, 2).reshape(B, C, H, W)
-    except NotImplementedError as e:
+    else:
         out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
     return out
 
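
Note on the compile path: the first hunk stores the `compile_core` flag, leaves the eager `forward_core_with_concat` untouched, and routes calls through a `torch.compile` wrapper kept on a separate attribute. Below is a minimal, self-contained sketch of that gating pattern on a hypothetical `ToyCore` module (the class name, its `core` method, and the `hidden_size` argument are illustrative only, not part of the patch); as a simplification, the dispatch lives in a thin entry point rather than inside the compiled method itself, so the compiled callable never traces its own dispatch branch.

```python
import torch
import torch.nn as nn

class ToyCore(nn.Module):
    # Hypothetical stand-in; only illustrates the optional-torch.compile gating
    # pattern used by the mmdit.py change, not the MMDiT model itself.
    def __init__(self, hidden_size: int = 64, compile_core: bool = False):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)
        # Remember the flag so the call path can dispatch on it later.
        self.compile_core = compile_core
        # Keep the compiled callable on a separate attribute so the eager
        # implementation is never overwritten (mirrors the patch).
        self.core_compiled = torch.compile(self.core) if compile_core else None

    def core(self, x: torch.Tensor) -> torch.Tensor:
        # The actual computation; free of the dispatch flag so tracing it
        # does not re-enter the compiled wrapper.
        return torch.nn.functional.silu(self.proj(x))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Dispatch once at the entry point.
        if self.compile_core:
            return self.core_compiled(x)
        return self.core(x)

if __name__ == "__main__":
    m = ToyCore(compile_core=False)  # flip to True to exercise torch.compile
    print(m(torch.randn(2, 64)).shape)
```

The second hunk is in the same spirit of making control flow explicit: instead of attempting `xformers.ops.memory_efficient_attention` and catching `NotImplementedError`, the VAE attention path asks `model_management.xformers_enabled_vae()` up front and falls back to `slice_attention` when it returns False.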