Fix xformers import statements in comfy.ldm.modules.attention

Max Tretikov 2024-06-14 11:21:08 -06:00
parent 74023da3a0
commit 6c53388619


@@ -11,8 +11,7 @@ from .sub_quadratic_attention import efficient_dot_product_attention
 from ... import model_management
 
 if model_management.xformers_enabled():
-    import xformers
-    import xformers.ops
+    import xformers # pylint: disable=import-error
 
 from ...cli_args import args
 from ... import ops
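
The pylint directive is needed because xformers is an optional dependency: when it is not installed, pylint cannot resolve the module and flags the guarded import with import-error even though the line never runs in that environment. A minimal standalone sketch of the pattern follows; xformers_enabled here is an illustrative stand-in for model_management.xformers_enabled(), which additionally consults runtime configuration, and is not the file's actual code.

    import importlib.util

    def xformers_enabled() -> bool:
        # Stand-in check: is the optional xformers package importable at all?
        return importlib.util.find_spec("xformers") is not None

    if xformers_enabled():
        # Guarded import of the optional dependency; the pylint directive keeps
        # linting clean on machines where xformers is absent.
        import xformers  # pylint: disable=import-error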
@@ -303,12 +302,10 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None):
     return r1
 
 BROKEN_XFORMERS = False
-try:
+if model_management.xformers_enabled():
     x_vers = xformers.__version__
     # XFormers bug confirmed on all versions from 0.0.21 to 0.0.26 (q with bs bigger than 65535 gives CUDA error)
     BROKEN_XFORMERS = x_vers.startswith("0.0.2") and not x_vers.startswith("0.0.20")
-except:
-    pass
 
 def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
     b, _, dim_head = q.shape
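
For context on how a flag like BROKEN_XFORMERS is typically consumed (the hunk ends before the body of attention_xformers): the bug noted in the comment only bites at large batch sizes, so the xformers path can fall back to another attention implementation when the flag is set. The sketch below is a hedged, self-contained illustration under those assumptions; should_use_xformers and the batch * heads comparison are illustrative names, not the file's actual code.

    import importlib.util

    BROKEN_XFORMERS = False
    if importlib.util.find_spec("xformers") is not None:
        import xformers  # pylint: disable=import-error
        v = xformers.__version__
        # Affected releases: 0.0.21 through 0.0.26 (large batch sizes hit a CUDA error).
        BROKEN_XFORMERS = v.startswith("0.0.2") and not v.startswith("0.0.20")

    def should_use_xformers(batch: int, heads: int) -> bool:
        # Skip the xformers kernel when the effective batch dimension would
        # trigger the known bug; the caller then uses a fallback attention path.
        return not (BROKEN_XFORMERS and batch * heads > 65535)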