Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-21 20:00:17 +08:00
Fix compile_core in comfy.ldm.modules.diffusionmodules.mmdit
This commit is contained in:
parent 891154b79e
commit bd59bae606
@@ -837,9 +837,9 @@ class MMDiT(nn.Module):
 
         self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
 
+        self.compile_core = compile_core
         if compile_core:
-            assert False
-            self.forward_core_with_concat = torch.compile(self.forward_core_with_concat)
+            self.forward_core_with_concat_compiled = torch.compile(self.forward_core_with_concat)
 
     def cropped_pos_embed(self, hw, device=None):
         p = self.x_embedder.patch_size[0]
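The fix replaces the assert with a working path: the compiled callable is stored under a new attribute (forward_core_with_concat_compiled) rather than overwriting forward_core_with_concat, so the eager method stays reachable. Below is a minimal, self-contained sketch of that pattern, assuming PyTorch 2.x for torch.compile. ToyMMDiT, _core, and _core_compiled are hypothetical stand-ins for MMDiT and its methods, and unlike the diff the sketch keeps the dispatch outside the compiled body, so the compiled wrapper never re-enters itself.

import torch
import torch.nn as nn

class ToyMMDiT(nn.Module):
    # Hypothetical stand-in for MMDiT; only the compile_core wiring mirrors the diff.
    def __init__(self, dim=64, compile_core=False):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.compile_core = compile_core
        if compile_core:
            # Store the compiled callable under a *new* attribute instead of
            # overwriting the method, so the eager path stays reachable.
            self._core_compiled = torch.compile(self._core)

    def _core(self, x):
        # Stand-in for the heavy transformer core (forward_core_with_concat).
        return torch.relu(self.proj(x))

    def forward(self, x):
        # Dispatch once, up front, in the spirit of the early return the
        # second hunk adds to forward_core_with_concat.
        if self.compile_core:
            return self._core_compiled(x)
        return self._core(x)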
@@ -895,6 +895,8 @@ class MMDiT(nn.Module):
         c_mod: torch.Tensor,
         context: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        if self.compile_core:
+            return self.forward_core_with_concat_compiled(x, c_mod, context)
         if self.register_length > 0:
             context = torch.cat(
                 (
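Hypothetical usage of the sketch above; callers are identical with and without compilation, which is the point of dispatching inside the module:

x = torch.randn(2, 16, 64)
eager = ToyMMDiT(compile_core=False)
fast = ToyMMDiT(compile_core=True)  # torch.compile is lazy: compilation runs on the first call
print(eager(x).shape)  # torch.Size([2, 16, 64])
print(fast(x).shape)   # torch.Size([2, 16, 64]); later calls reuse the compiled graph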
@@ -216,10 +216,10 @@ def xformers_attention(q, k, v):
         (q, k, v),
     )
 
-    try:
+    if model_management.xformers_enabled_vae():
         out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
         out = out.transpose(1, 2).reshape(B, C, H, W)
-    except NotImplementedError as e:
+    else:
         out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
     return out
 
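The third hunk switches the VAE's xformers attention from an exception-driven fallback to an explicit capability check. A short sketch under stated assumptions: the memory_efficient_attention call is xformers' public API as used in the diff, xformers_enabled is a hypothetical stand-in for model_management.xformers_enabled_vae(), and PyTorch's scaled_dot_product_attention stands in for ComfyUI's slice_attention fallback.

import torch

try:
    import xformers.ops
    _XFORMERS = True
except ImportError:
    _XFORMERS = False

def xformers_enabled():
    # Hypothetical stand-in for model_management.xformers_enabled_vae().
    return _XFORMERS

def vae_attention(q, k, v):
    # q, k, v: (B, C, H, W) feature maps; attend over the H*W spatial positions.
    B, C, H, W = q.shape
    q, k, v = map(lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(), (q, k, v))
    if xformers_enabled():
        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
    else:
        # Stand-in for slice_attention: plain (unsliced) scaled dot-product attention.
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    return out.transpose(1, 2).reshape(B, C, H, W)

# e.g. vae_attention(*[torch.randn(1, 8, 4, 4) for _ in range(3)]).shape -> (1, 8, 4, 4)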