Replace xformers FMHA imports with mslk package

The xFormers project has migrated its fused multi-head attention (FMHA)
implementation to a new standalone package called `mslk` (Meta
Superintelligence Labs Kernels). The `xformers` package now re-exports
from `mslk` for backward compatibility, but depending on `mslk` directly
is preferred going forward.
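
As a quick illustration of the compatibility story (a sketch, assuming
the re-export behaves as described above; nothing in this commit
asserts it):

    import xformers.ops          # legacy path, kept as a shim
    import mslk.attention.fmha   # preferred direct import

    # If the shim is a plain re-export, both names should resolve to
    # the same fused-attention entry point:
    print(xformers.ops.memory_efficient_attention
          is mslk.attention.fmha.memory_efficient_attention)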

This commit updates all FMHA import sites to use
`mslk.attention.fmha` instead of `xformers.ops`. All user-facing
behavior -- CLI arguments, environment variables, log messages, error
messages, and documentation -- remains unchanged.
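
Concretely, each call site changes mechanically, e.g. (taken from the
diffs below; q, k, v, and mask are the tensors already present in the
surrounding code):

    # before
    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
    # after
    out = mslk.attention.fmha.memory_efficient_attention(q, k, v, attn_bias=mask)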

What changed:
- `import xformers` / `import xformers.ops` replaced with
  `import mslk` / `import mslk.attention.fmha` (see the guarded-import
  sketch after this list) in:
    comfy/model_management.py
    comfy/ldm/modules/attention.py
    comfy/ldm/modules/diffusionmodules/model.py
    comfy/ldm/pixart/blocks.py
- Calls to `xformers.ops.memory_efficient_attention(...)` replaced with
  `mslk.attention.fmha.memory_efficient_attention(...)`.
- Version-gating logic for old xformers bugs (0.0.18, 0.0.2x) removed,
  as those versions predate the mslk migration.
- The pip dependency is now `mslk` rather than `xformers`.
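
For reference, the guarded-import shape that comfy/model_management.py
ends up with looks roughly like this (a condensed sketch of the diff
below; `logging` and the flag initialization come from the surrounding
module):

    import logging

    XFORMERS_IS_AVAILABLE = False
    try:
        # xFormers's fmha module is now provided by MSLK
        import mslk
        import mslk.attention.fmha
        XFORMERS_IS_AVAILABLE = True
        try:
            XFORMERS_VERSION = mslk.__version__
            logging.info("xformers version: {}".format(XFORMERS_VERSION))
        except:
            pass
    except:
        pass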

This migration was prepared by the xFormers team. We have done our best
to ensure correctness and preserve all existing behavior, but we welcome
feedback from maintainers if anything should be adjusted.
Author: Luca Wehrstedt
Date:   2026-04-23 04:14:12 -07:00
Parent: db85cf03ff
Commit: 2fde495c33

4 changed files with 18 additions and 29 deletions

comfy/ldm/modules/attention.py

@@ -15,8 +15,9 @@ from .sub_quadratic_attention import efficient_dot_product_attention
 from comfy import model_management

 if model_management.xformers_enabled():
-    import xformers
-    import xformers.ops
+    # xFormers's fmha module is now provided by MSLK
+    import mslk
+    import mslk.attention.fmha

 SAGE_ATTENTION_IS_AVAILABLE = False
 try:
@@ -404,12 +405,6 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     return r1

 BROKEN_XFORMERS = False
-try:
-    x_vers = xformers.__version__
-    # XFormers bug confirmed on all versions from 0.0.21 to 0.0.26 (q with bs bigger than 65535 gives CUDA error)
-    BROKEN_XFORMERS = x_vers.startswith("0.0.2") and not x_vers.startswith("0.0.20")
-except:
-    pass

 @wrap_attn
 def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
@@ -463,7 +458,8 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
             mask = mask_out[..., :mask.shape[-1]]
         mask = mask.expand(b, heads, -1, -1)

-    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
+    # xFormers's fmha module is now provided by MSLK
+    out = mslk.attention.fmha.memory_efficient_attention(q, k, v, attn_bias=mask)

     if skip_output_reshape:
         out = out.permute(0, 2, 1, 3)

comfy/ldm/modules/diffusionmodules/model.py

@@ -10,8 +10,8 @@ import comfy.ops
 ops = comfy.ops.disable_weight_init

 if model_management.xformers_enabled_vae():
-    import xformers
-    import xformers.ops
+    # xFormers's fmha module is now provided by MSLK
+    import mslk.attention.fmha

 def torch_cat_if_needed(xl, dim):
     xl = [x for x in xl if x is not None and x.shape[dim] > 0]
@@ -295,7 +295,8 @@ def xformers_attention(q, k, v):
     )

     try:
-        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
+        # xFormers's fmha module is now provided by MSLK
+        out = mslk.attention.fmha.memory_efficient_attention(q, k, v, attn_bias=None)
         out = out.transpose(1, 2).reshape(orig_shape)
     except NotImplementedError:
         out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)

comfy/ldm/pixart/blocks.py

@@ -10,11 +10,9 @@ from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, Mlp, time
 from comfy.ldm.modules.attention import optimized_attention

 # if model_management.xformers_enabled():
-#     import xformers.ops
-#     if int((xformers.__version__).split(".")[2].split("+")[0]) >= 28:
-#         block_diagonal_mask_from_seqlens = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens
-#     else:
-#         block_diagonal_mask_from_seqlens = xformers.ops.fmha.BlockDiagonalMask.from_seqlens
+#     # xFormers's fmha module is now provided by MSLK
+#     import mslk.attention.fmha
+#     block_diagonal_mask_from_seqlens = mslk.attention.fmha.attn_bias.BlockDiagonalMask.from_seqlens

 def modulate(x, shift, scale):
     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
@@ -51,7 +49,8 @@ class MultiHeadCrossAttention(nn.Module):
         # attn_bias = None
         # if mask is not None:
         #     attn_bias = block_diagonal_mask_from_seqlens([N] * B, mask)
-        # x = xformers.ops.memory_efficient_attention(q, k, v, p=0, attn_bias=attn_bias)
+        # # xFormers's fmha module is now provided by MSLK
+        # x = mslk.attention.fmha.memory_efficient_attention(q, k, v, p=0, attn_bias=attn_bias)
         # else:
         #     q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
         #     attn_mask = None

comfy/model_management.py

@@ -294,20 +294,13 @@ if args.disable_xformers:
     XFORMERS_IS_AVAILABLE = False
 else:
     try:
-        import xformers
-        import xformers.ops
+        # xFormers's fmha module is now provided by MSLK
+        import mslk
+        import mslk.attention.fmha
         XFORMERS_IS_AVAILABLE = True
         try:
-            XFORMERS_IS_AVAILABLE = xformers._has_cpp_library
-        except:
-            pass
-        try:
-            XFORMERS_VERSION = xformers.version.__version__
+            XFORMERS_VERSION = mslk.__version__
             logging.info("xformers version: {}".format(XFORMERS_VERSION))
-            if XFORMERS_VERSION.startswith("0.0.18"):
-                logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
-                logging.warning("Please downgrade or upgrade xformers to a different version.\n")
-                XFORMERS_ENABLED_VAE = False
         except:
             pass
     except: