commit 6e7b15c240
Luca Wehrstedt authored on 2026-05-08 10:45:56 +03:00; committed by GitHub
GPG key ID: B5690EEEBB952194 (no known key found for this signature in database)
4 changed files with 18 additions and 29 deletions


@@ -17,8 +17,9 @@ from comfy import model_management
 TORCH_HAS_GQA = model_management.torch_version_numeric >= (2, 5)
 
 if model_management.xformers_enabled():
-    import xformers
-    import xformers.ops
+    # xFormers's fmha module is now provided by MSLK
+    import mslk
+    import mslk.attention.fmha
 
 SAGE_ATTENTION_IS_AVAILABLE = False
 try:
@@ -415,12 +416,6 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     return r1
 
-BROKEN_XFORMERS = False
-try:
-    x_vers = xformers.__version__
-    # XFormers bug confirmed on all versions from 0.0.21 to 0.0.26 (q with bs bigger than 65535 gives CUDA error)
-    BROKEN_XFORMERS = x_vers.startswith("0.0.2") and not x_vers.startswith("0.0.20")
-except:
-    pass
 
 @wrap_attn
 def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
@@ -474,7 +469,8 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
         mask = mask_out[..., :mask.shape[-1]]
         mask = mask.expand(b, heads, -1, -1)
 
-    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
+    # xFormers's fmha module is now provided by MSLK
+    out = mslk.attention.fmha.memory_efficient_attention(q, k, v, attn_bias=mask)
 
     if skip_output_reshape:
         out = out.permute(0, 2, 1, 3)
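
Every change in this file follows the same pattern: the xformers.ops.memory_efficient_attention call is swapped one-for-one for mslk.attention.fmha.memory_efficient_attention with identical arguments, so the two entry points are treated as call-compatible. A minimal sketch (not part of this commit) of how downstream code could tolerate either provider; the fmha alias and run_attention helper are hypothetical:

# Hedged sketch: prefer MSLK, fall back to xFormers. Assumes both expose the
# same memory_efficient_attention signature, which is what the verbatim
# one-line swaps in this commit rely on.
try:
    import mslk.attention.fmha as fmha  # xFormers's fmha module is now provided by MSLK
except ImportError:
    import xformers.ops as fmha  # older environments without mslk

def run_attention(q, k, v, mask=None):
    # q, k, v are (batch, seq_len, heads, dim_head) tensors, as at the call
    # sites above; mask becomes the attn_bias argument.
    return fmha.memory_efficient_attention(q, k, v, attn_bias=mask)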


@@ -10,8 +10,8 @@ import comfy.ops
 ops = comfy.ops.disable_weight_init
 
 if model_management.xformers_enabled_vae():
-    import xformers
-    import xformers.ops
+    # xFormers's fmha module is now provided by MSLK
+    import mslk.attention.fmha
 
 def torch_cat_if_needed(xl, dim):
     xl = [x for x in xl if x is not None and x.shape[dim] > 0]
@@ -295,7 +295,8 @@ def xformers_attention(q, k, v):
     )
 
     try:
-        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
+        # xFormers's fmha module is now provided by MSLK
+        out = mslk.attention.fmha.memory_efficient_attention(q, k, v, attn_bias=None)
         out = out.transpose(1, 2).reshape(orig_shape)
     except NotImplementedError:
         out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
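
As before, the fused call stays wrapped in a try/except: when the kernel cannot handle the input it raises NotImplementedError and the code falls back to slice_attention. The same guard in isolation, as a hedged sketch; fallback_attention is a hypothetical stand-in for the slice_attention path above:

# Hedged sketch of the guard above; fallback_attention is hypothetical.
import mslk.attention.fmha

def guarded_attention(q, k, v, fallback_attention):
    try:
        # Fused kernel path; may raise NotImplementedError for unsupported
        # shapes or dtypes.
        return mslk.attention.fmha.memory_efficient_attention(q, k, v, attn_bias=None)
    except NotImplementedError:
        # Slower but shape-agnostic fallback, like slice_attention above.
        return fallback_attention(q, k, v)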


@@ -10,11 +10,9 @@ from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, Mlp, time
 from comfy.ldm.modules.attention import optimized_attention
 
 # if model_management.xformers_enabled():
-#     import xformers.ops
-#     if int((xformers.__version__).split(".")[2].split("+")[0]) >= 28:
-#         block_diagonal_mask_from_seqlens = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens
-#     else:
-#         block_diagonal_mask_from_seqlens = xformers.ops.fmha.BlockDiagonalMask.from_seqlens
+#     # xFormers's fmha module is now provided by MSLK
+#     import mslk.attention.fmha
+#     block_diagonal_mask_from_seqlens = mslk.attention.fmha.attn_bias.BlockDiagonalMask.from_seqlens
 
 def modulate(x, shift, scale):
     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
@@ -51,7 +49,8 @@ class MultiHeadCrossAttention(nn.Module):
         #     attn_bias = None
         #     if mask is not None:
         #         attn_bias = block_diagonal_mask_from_seqlens([N] * B, mask)
-        #     x = xformers.ops.memory_efficient_attention(q, k, v, p=0, attn_bias=attn_bias)
+        #     # xFormers's fmha module is now provided by MSLK
+        #     x = mslk.attention.fmha.memory_efficient_attention(q, k, v, p=0, attn_bias=attn_bias)
         # else:
         #     q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
         #     attn_mask = None
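
Although this path remains commented out, the migrated comment preserves the masking trick: BlockDiagonalMask.from_seqlens builds an attention bias in which each sample's N queries attend only to that sample's own context tokens, letting variable-length contexts be packed into one batched call. A hedged sketch of that construction, assuming mslk keeps the from_seqlens interface the comment points at:

# Hedged sketch, assuming mslk.attention.fmha.attn_bias keeps the
# BlockDiagonalMask.from_seqlens interface referenced above.
import mslk.attention.fmha

def cross_attention_bias(B, N, kv_lens):
    # One diagonal block per sample: the N queries of sample i see only
    # kv_lens[i] context tokens, mirroring
    # block_diagonal_mask_from_seqlens([N] * B, mask) in the commented code.
    return mslk.attention.fmha.attn_bias.BlockDiagonalMask.from_seqlens([N] * B, kv_lens)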


@@ -291,20 +291,13 @@ if args.disable_xformers:
     XFORMERS_IS_AVAILABLE = False
 else:
     try:
-        import xformers
-        import xformers.ops
+        # xFormers's fmha module is now provided by MSLK
+        import mslk
+        import mslk.attention.fmha
         XFORMERS_IS_AVAILABLE = True
-        try:
-            XFORMERS_IS_AVAILABLE = xformers._has_cpp_library
-        except:
-            pass
         try:
-            XFORMERS_VERSION = xformers.version.__version__
+            XFORMERS_VERSION = mslk.__version__
             logging.info("xformers version: {}".format(XFORMERS_VERSION))
-            if XFORMERS_VERSION.startswith("0.0.18"):
-                logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
-                logging.warning("Please downgrade or upgrade xformers to a different version.\n")
-                XFORMERS_ENABLED_VAE = False
         except:
             pass
     except: