From 2fde495c33658cd79de082570c8639005faed08a Mon Sep 17 00:00:00 2001
From: Luca Wehrstedt
Date: Thu, 23 Apr 2026 04:14:12 -0700
Subject: [PATCH] Replace xformers FMHA imports with mslk package

The xFormers project has migrated its fused multi-head attention (FMHA)
implementation to a new standalone package called `mslk` (Meta
Superintelligence Labs Kernels). The `xformers` package now re-exports
from `mslk` for backward compatibility, but direct dependence on `mslk`
is preferred going forward.

This commit updates all FMHA import sites to use `mslk.attention.fmha`
instead of `xformers.ops`. All user-facing behavior -- CLI arguments,
environment variables, log messages, error messages, and documentation
-- remains unchanged.

What changed:

- `import xformers` / `import xformers.ops` replaced with `import mslk` /
  `import mslk.attention.fmha` in:
    comfy/model_management.py
    comfy/ldm/modules/attention.py
    comfy/ldm/modules/diffusionmodules/model.py
    comfy/ldm/pixart/blocks.py
- Calls to `xformers.ops.memory_efficient_attention(...)` replaced with
  `mslk.attention.fmha.memory_efficient_attention(...)`.
- Version-gating logic for old xformers bugs (0.0.18, 0.0.2x) removed,
  as those versions predate the mslk migration.
- The pip dependency is now `mslk` rather than `xformers`.

This migration was prepared by the xFormers team. We have done our best
to ensure correctness and preserve all existing behavior, but we welcome
feedback from maintainers if anything should be adjusted.
---
 comfy/ldm/modules/attention.py              | 14 +++++---------
 comfy/ldm/modules/diffusionmodules/model.py |  7 ++++---
 comfy/ldm/pixart/blocks.py                  | 11 +++++------
 comfy/model_management.py                   | 15 ++++-----------
 4 files changed, 18 insertions(+), 29 deletions(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index b193fe5e8..ba52540a8 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -15,8 +15,9 @@ from .sub_quadratic_attention import efficient_dot_product_attention
 from comfy import model_management
 
 if model_management.xformers_enabled():
-    import xformers
-    import xformers.ops
+    # xFormers's fmha module is now provided by MSLK
+    import mslk
+    import mslk.attention.fmha
 
 SAGE_ATTENTION_IS_AVAILABLE = False
 try:
@@ -404,12 +405,6 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     return r1
 
 BROKEN_XFORMERS = False
-try:
-    x_vers = xformers.__version__
-    # XFormers bug confirmed on all versions from 0.0.21 to 0.0.26 (q with bs bigger than 65535 gives CUDA error)
-    BROKEN_XFORMERS = x_vers.startswith("0.0.2") and not x_vers.startswith("0.0.20")
-except:
-    pass
 
 @wrap_attn
 def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
@@ -463,7 +458,8 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
             mask = mask_out[..., :mask.shape[-1]]
         mask = mask.expand(b, heads, -1, -1)
 
-    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
+    # xFormers's fmha module is now provided by MSLK
+    out = mslk.attention.fmha.memory_efficient_attention(q, k, v, attn_bias=mask)
 
     if skip_output_reshape:
         out = out.permute(0, 2, 1, 3)
diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index fcbaa074f..4c8375742 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -10,8 +10,8 @@ import comfy.ops
 ops = comfy.ops.disable_weight_init
 
 if model_management.xformers_enabled_vae():
-    import xformers
-    import xformers.ops
+    # xFormers's fmha module is now provided by MSLK
+    import mslk.attention.fmha
 
 def torch_cat_if_needed(xl, dim):
     xl = [x for x in xl if x is not None and x.shape[dim] > 0]
@@ -295,7 +295,8 @@ def xformers_attention(q, k, v):
     )
 
     try:
-        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
+        # xFormers's fmha module is now provided by MSLK
+        out = mslk.attention.fmha.memory_efficient_attention(q, k, v, attn_bias=None)
         out = out.transpose(1, 2).reshape(orig_shape)
     except NotImplementedError:
         out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
diff --git a/comfy/ldm/pixart/blocks.py b/comfy/ldm/pixart/blocks.py
index 2225076e5..335e6ef77 100644
--- a/comfy/ldm/pixart/blocks.py
+++ b/comfy/ldm/pixart/blocks.py
@@ -10,11 +10,9 @@ from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, Mlp, time
 from comfy.ldm.modules.attention import optimized_attention
 
 # if model_management.xformers_enabled():
-#     import xformers.ops
-#     if int((xformers.__version__).split(".")[2].split("+")[0]) >= 28:
-#         block_diagonal_mask_from_seqlens = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens
-#     else:
-#         block_diagonal_mask_from_seqlens = xformers.ops.fmha.BlockDiagonalMask.from_seqlens
+#     # xFormers's fmha module is now provided by MSLK
+#     import mslk.attention.fmha
+#     block_diagonal_mask_from_seqlens = mslk.attention.fmha.attn_bias.BlockDiagonalMask.from_seqlens
 
 def modulate(x, shift, scale):
     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
@@ -51,7 +49,8 @@ class MultiHeadCrossAttention(nn.Module):
         #     attn_bias = None
         #     if mask is not None:
         #         attn_bias = block_diagonal_mask_from_seqlens([N] * B, mask)
-        #     x = xformers.ops.memory_efficient_attention(q, k, v, p=0, attn_bias=attn_bias)
+        #     # xFormers's fmha module is now provided by MSLK
+        #     x = mslk.attention.fmha.memory_efficient_attention(q, k, v, p=0, attn_bias=attn_bias)
         # else:
         #     q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
         #     attn_mask = None
diff --git a/comfy/model_management.py b/comfy/model_management.py
index bcf1399c4..f2cd7a3e4 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -294,20 +294,13 @@ if args.disable_xformers:
     XFORMERS_IS_AVAILABLE = False
 else:
     try:
-        import xformers
-        import xformers.ops
+        # xFormers's fmha module is now provided by MSLK
+        import mslk
+        import mslk.attention.fmha
         XFORMERS_IS_AVAILABLE = True
         try:
-            XFORMERS_IS_AVAILABLE = xformers._has_cpp_library
-        except:
-            pass
-        try:
-            XFORMERS_VERSION = xformers.version.__version__
+            XFORMERS_VERSION = mslk.__version__
             logging.info("xformers version: {}".format(XFORMERS_VERSION))
-            if XFORMERS_VERSION.startswith("0.0.18"):
-                logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
-                logging.warning("Please downgrade or upgrade xformers to a different version.\n")
-                XFORMERS_ENABLED_VAE = False
         except:
             pass
     except:
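
Reviewer note (illustration only, not part of the applied patch): the
snippet below sketches the call-site pattern this commit describes,
assuming `mslk.attention.fmha` exposes `memory_efficient_attention`
with the same signature that `xformers.ops` previously provided. The
try/except fallback, the alias name `fmha`, and the tensor shapes are
assumptions made for the example, not code from this repository.

    import torch

    try:
        # Preferred import after this patch (assumes the mslk package is installed).
        import mslk.attention.fmha as fmha
    except ImportError:
        # Fall back to the xformers re-export described in the commit message.
        import xformers.ops as fmha

    # memory_efficient_attention expects (batch, seq_len, heads, head_dim) inputs;
    # a CUDA device and fp16 tensors are assumed here purely for illustration.
    q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)

    out = fmha.memory_efficient_attention(q, k, v, attn_bias=None)
    print(out.shape)  # torch.Size([1, 128, 8, 64])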