Fix compile_core in comfy.ldm.modules.diffusionmodules.mmdit

Max Tretikov 2024-06-14 14:43:55 -06:00
parent 891154b79e
commit bd59bae606
2 changed files with 6 additions and 4 deletions

comfy/ldm/modules/diffusionmodules/mmdit.py

@@ -837,9 +837,9 @@ class MMDiT(nn.Module):
         self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)

         self.compile_core = compile_core
         if compile_core:
-            assert False
-            self.forward_core_with_concat = torch.compile(self.forward_core_with_concat)
+            self.forward_core_with_concat_compiled = torch.compile(self.forward_core_with_concat)
+

     def cropped_pos_embed(self, hw, device=None):
         p = self.x_embedder.patch_size[0]
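Together with the hunk that follows, this is a common torch.compile pattern: compile the core method once at construction time and cache it under a separate attribute, rather than overwriting the eager method (which here sat behind a dead assert False). A minimal runnable sketch of that pattern; Model, core, and core_compiled are hypothetical names, and only torch.compile itself comes from the diff:

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, compile_core: bool = False):
        super().__init__()
        self.linear = nn.Linear(8, 8)
        self.compile_core = compile_core
        if compile_core:
            # Cache the compiled callable under a new name so the eager
            # method stays callable and is not shadowed.
            self.core_compiled = torch.compile(self.core)

    def core(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x).relu()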
@@ -895,6 +895,8 @@ class MMDiT(nn.Module):
         c_mod: torch.Tensor,
         context: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        if self.compile_core:
+            return self.forward_core_with_concat_compiled(x, c_mod, context)
         if self.register_length > 0:
             context = torch.cat(
                 (
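Continuing the sketch above (same hypothetical names), the call-time dispatch this hunk adds looks like the following: check the flag and route to the cached compiled callable, falling back to the eager path otherwise.

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.compile_core:
            return self.core_compiled(x)
        return self.core(x)

# Usage:
model = Model(compile_core=True)
out = model(torch.randn(2, 8))

One deliberate difference from the commit: the commit places the check at the top of forward_core_with_concat itself, while the sketch keeps the dispatch in an uncompiled wrapper so the compiled body never re-enters its own flag check.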

comfy/ldm/modules/diffusionmodules/model.py

@@ -216,10 +216,10 @@ def xformers_attention(q, k, v):
         (q, k, v),
     )

-    try:
+    if model_management.xformers_enabled_vae():
         out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
         out = out.transpose(1, 2).reshape(B, C, H, W)
-    except NotImplementedError as e:
+    else:
         out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
     return out

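This hunk replaces an exception-driven fallback with an explicit capability check decided up front. A minimal self-contained sketch of that control-flow change; xformers_available and sliced_attention are hypothetical stand-ins for model_management.xformers_enabled_vae() and slice_attention, while xformers.ops.memory_efficient_attention is the real xformers call:

import torch

def xformers_available() -> bool:
    # Hypothetical stand-in for model_management.xformers_enabled_vae().
    try:
        import xformers.ops  # noqa: F401
        return True
    except ImportError:
        return False

def sliced_attention(q, k, v):
    # Hypothetical stand-in for the slice_attention fallback:
    # plain scaled dot-product attention.
    scale = q.shape[-1] ** -0.5
    return torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1) @ v

def attention(q, k, v):
    # Pick the path up front instead of catching NotImplementedError
    # after the fast kernel has already been attempted.
    if xformers_available():
        import xformers.ops
        return xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
    return sliced_attention(q, k, v)

Deciding the path from a single capability check makes the fallback deterministic, rather than depending on which error the fast kernel happens to raise at runtime.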