mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-12 07:10:52 +08:00
Fix compile_core in comfy.ldm.modules.diffusionmodules.mmdit
This commit is contained in:
parent
891154b79e
commit
bd59bae606
@ -837,9 +837,9 @@ class MMDiT(nn.Module):
|
||||
|
||||
self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.compile_core = compile_core
|
||||
if compile_core:
|
||||
assert False
|
||||
self.forward_core_with_concat = torch.compile(self.forward_core_with_concat)
|
||||
self.forward_core_with_concat_compiled = torch.compile(self.forward_core_with_concat)
|
||||
|
||||
def cropped_pos_embed(self, hw, device=None):
|
||||
p = self.x_embedder.patch_size[0]
|
||||
@ -895,6 +895,8 @@ class MMDiT(nn.Module):
|
||||
c_mod: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if self.compile_core:
|
||||
return self.forward_core_with_concat_compiled(x, c_mod, context)
|
||||
if self.register_length > 0:
|
||||
context = torch.cat(
|
||||
(
|
||||
|
||||
@ -216,10 +216,10 @@ def xformers_attention(q, k, v):
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
try:
|
||||
if model_management.xformers_enabled_vae():
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
|
||||
out = out.transpose(1, 2).reshape(B, C, H, W)
|
||||
except NotImplementedError as e:
|
||||
else:
|
||||
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
|
||||
return out
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user