diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index b193fe5e8..ba52540a8 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -15,8 +15,9 @@ from .sub_quadratic_attention import efficient_dot_product_attention
 from comfy import model_management
 
 if model_management.xformers_enabled():
-    import xformers
-    import xformers.ops
+    # xFormers's fmha module is now provided by MSLK
+    import mslk
+    import mslk.attention.fmha
 
 SAGE_ATTENTION_IS_AVAILABLE = False
 try:
@@ -404,12 +405,6 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     return r1
 
 BROKEN_XFORMERS = False
-try:
-    x_vers = xformers.__version__
-    # XFormers bug confirmed on all versions from 0.0.21 to 0.0.26 (q with bs bigger than 65535 gives CUDA error)
-    BROKEN_XFORMERS = x_vers.startswith("0.0.2") and not x_vers.startswith("0.0.20")
-except:
-    pass
 
 @wrap_attn
 def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
@@ -463,7 +458,8 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
         mask = mask_out[..., :mask.shape[-1]]
         mask = mask.expand(b, heads, -1, -1)
 
-    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
+    # xFormers's fmha module is now provided by MSLK
+    out = mslk.attention.fmha.memory_efficient_attention(q, k, v, attn_bias=mask)
 
     if skip_output_reshape:
         out = out.permute(0, 2, 1, 3)
diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index fcbaa074f..4c8375742 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -10,8 +10,8 @@ import comfy.ops
 ops = comfy.ops.disable_weight_init
 
 if model_management.xformers_enabled_vae():
-    import xformers
-    import xformers.ops
+    # xFormers's fmha module is now provided by MSLK
+    import mslk.attention.fmha
 
 def torch_cat_if_needed(xl, dim):
     xl = [x for x in xl if x is not None and x.shape[dim] > 0]
@@ -295,7 +295,8 @@ def xformers_attention(q, k, v):
     )
 
     try:
-        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
+        # xFormers's fmha module is now provided by MSLK
+        out = mslk.attention.fmha.memory_efficient_attention(q, k, v, attn_bias=None)
         out = out.transpose(1, 2).reshape(orig_shape)
     except NotImplementedError:
         out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
diff --git a/comfy/ldm/pixart/blocks.py b/comfy/ldm/pixart/blocks.py
index 2225076e5..335e6ef77 100644
--- a/comfy/ldm/pixart/blocks.py
+++ b/comfy/ldm/pixart/blocks.py
@@ -10,11 +10,9 @@ from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, Mlp, time
 from comfy.ldm.modules.attention import optimized_attention
 
 # if model_management.xformers_enabled():
-#     import xformers.ops
-#     if int((xformers.__version__).split(".")[2].split("+")[0]) >= 28:
-#         block_diagonal_mask_from_seqlens = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens
-#     else:
-#         block_diagonal_mask_from_seqlens = xformers.ops.fmha.BlockDiagonalMask.from_seqlens
+#     # xFormers's fmha module is now provided by MSLK
+#     import mslk.attention.fmha
+#     block_diagonal_mask_from_seqlens = mslk.attention.fmha.attn_bias.BlockDiagonalMask.from_seqlens
 
 def modulate(x, shift, scale):
     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
@@ -51,7 +49,8 @@ class MultiHeadCrossAttention(nn.Module):
         # attn_bias = None
         # if mask is not None:
         #     attn_bias = block_diagonal_mask_from_seqlens([N] * B, mask)
-        # x = xformers.ops.memory_efficient_attention(q, k, v, p=0, attn_bias=attn_bias)
+        # # xFormers's fmha module is now provided by MSLK
+        # x = mslk.attention.fmha.memory_efficient_attention(q, k, v, p=0, attn_bias=attn_bias)
         # else:
         # q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
         # attn_mask = None
diff --git a/comfy/model_management.py b/comfy/model_management.py
index bcf1399c4..f2cd7a3e4 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -294,20 +294,13 @@ if args.disable_xformers:
     XFORMERS_IS_AVAILABLE = False
 else:
     try:
-        import xformers
-        import xformers.ops
+        # xFormers's fmha module is now provided by MSLK
+        import mslk
+        import mslk.attention.fmha
         XFORMERS_IS_AVAILABLE = True
         try:
-            XFORMERS_IS_AVAILABLE = xformers._has_cpp_library
-        except:
-            pass
-        try:
-            XFORMERS_VERSION = xformers.version.__version__
+            XFORMERS_VERSION = mslk.__version__
             logging.info("xformers version: {}".format(XFORMERS_VERSION))
-            if XFORMERS_VERSION.startswith("0.0.18"):
-                logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
-                logging.warning("Please downgrade or upgrade xformers to a different version.\n")
-                XFORMERS_ENABLED_VAE = False
         except:
             pass
     except:
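
Note for out-of-tree callers (custom nodes): the diff above changes only the module path, keeping the call signature of xformers.ops.memory_efficient_attention. Below is a minimal sketch of a compatibility shim that works against either backend during the migration. It assumes only what this diff shows, namely that mslk.attention.fmha.memory_efficient_attention is a drop-in for the xFormers call; the fmha alias and the wrapper function are illustrative, not part of this PR.

    # Hypothetical compatibility shim: prefer MSLK's fmha module, fall back
    # to classic xFormers if MSLK is not installed. Assumes the two calls
    # share a signature, as the diff above implies.
    try:
        import mslk.attention.fmha as fmha  # new backend used by this PR
    except ImportError:
        import xformers.ops as fmha  # legacy backend

    def memory_efficient_attention(q, k, v, attn_bias=None):
        # Both backends take (batch, seq_len, heads, head_dim) tensors and an
        # optional additive attention bias / mask object.
        return fmha.memory_efficient_attention(q, k, v, attn_bias=attn_bias)

Code that imports the shim instead of either library directly keeps working whether or not MSLK is present.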