Disable flash attention in zluda.py

Comment out the flash-attention-related environment variables and the flash-attention setup in do_hijack(); attention falls back to PyTorch's stock scaled_dot_product_attention.
patientx authored 2025-12-08 15:10:12 +03:00, committed by GitHub
parent 389515d350
commit e3131abac6

@@ -7,8 +7,9 @@ os.environ.pop("HIP_HOME", None)
 os.environ.pop("ROCM_VERSION", None)
 #triton fix?
-os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"
-os.environ["FLASH_ATTENTION_TRITON_AMD_AUTOTUNE"] = "TRUE"
+# disabling flash-attention
+# os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"
+# os.environ["FLASH_ATTENTION_TRITON_AMD_AUTOTUNE"] = "TRUE"
 os.environ["TRITON_DEBUG"] = "1" # Verbose logging
 paths = os.environ["PATH"].split(";")
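
With these two assignments commented out, zluda.py no longer opts the process into the Triton AMD flash-attention backend; the flags now only take effect if the launcher environment exports them before start-up. A minimal sketch of reading those flags (a hypothetical check for illustration, not code from this commit):

import os

# Hypothetical check: with zluda.py no longer setting the flags, any opt-in
# has to come from the environment the launcher was started with.
flash_requested = os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE", "").upper() == "TRUE"
autotune_requested = os.environ.get("FLASH_ATTENTION_TRITON_AMD_AUTOTUNE", "").upper() == "TRUE"
print(f" :: Triton AMD flash attention requested: {flash_requested}, autotune: {autotune_requested}")
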
@@ -668,46 +669,46 @@ def do_hijack():
             triton.runtime.driver.active.utils.get_device_properties = patched_props
             print(" :: Triton device properties configured")

-            # Flash Attention
-            flash_enabled = False
-            try:
-                from comfy.flash_attn_triton_amd import interface_fa
-                print(" :: Flash attention components found")
+            # # Flash Attention
+            # flash_enabled = False
+            # try:
+                # from comfy.flash_attn_triton_amd import interface_fa
+                # print(" :: Flash attention components found")

-                original_sdpa = torch.nn.functional.scaled_dot_product_attention
+                # original_sdpa = torch.nn.functional.scaled_dot_product_attention

-                def amd_flash_wrapper(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
-                    try:
-                        if (query.shape[-1] <= 128 and
-                            attn_mask is None and # fix flash-attention error : "Flash attention error: Boolean value of Tensor with more than one value is ambiguous"
-                            query.dtype != torch.float32):
-                            if scale is None:
-                                scale = query.shape[-1] ** -0.5
-                            return interface_fa.fwd(
-                                query.transpose(1, 2),
-                                key.transpose(1, 2),
-                                value.transpose(1, 2),
-                                None, None, dropout_p, scale,
-                                is_causal, -1, -1, 0.0, False, None
-                            )[0].transpose(1, 2)
-                    except Exception as e:
-                        print(f' :: Flash attention error: {str(e)}')
-                    return original_sdpa(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
+                # def amd_flash_wrapper(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
+                    # try:
+                        # if (query.shape[-1] <= 128 and
+                            # attn_mask is None and # fix flash-attention error : "Flash attention error: Boolean value of Tensor with more than one value is ambiguous"
+                            # query.dtype != torch.float32):
+                            # if scale is None:
+                                # scale = query.shape[-1] ** -0.5
+                            # return interface_fa.fwd(
+                                # query.transpose(1, 2),
+                                # key.transpose(1, 2),
+                                # value.transpose(1, 2),
+                                # None, None, dropout_p, scale,
+                                # is_causal, -1, -1, 0.0, False, None
+                            # )[0].transpose(1, 2)
+                    # except Exception as e:
+                        # print(f' :: Flash attention error: {str(e)}')
+                    # return original_sdpa(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)

-                torch.nn.functional.scaled_dot_product_attention = amd_flash_wrapper
-                flash_enabled = True
-                print(" :: AMD flash attention enabled successfully")
+                # torch.nn.functional.scaled_dot_product_attention = amd_flash_wrapper
+                # flash_enabled = True
+                # print(" :: AMD flash attention enabled successfully")

-            except ImportError:
-                print(" :: Flash attention components not installed")
-            except Exception as e:
-                print(f" :: Flash attention setup failed: {str(e)}")
+            # except ImportError:
+                # print(" :: Flash attention components not installed")
+            # except Exception as e:
+                # print(f" :: Flash attention setup failed: {str(e)}")

-            # Other Triton optimizations
-            if not flash_enabled:
-                print(" :: Applying basic Triton optimizations")
-                # Add other Triton optimizations here
-                # ...
+            # # Other Triton optimizations
+            # if not flash_enabled:
+                # print(" :: Applying basic Triton optimizations")
+                # # Add other Triton optimizations here
+                # # ...

         except Exception as e:
             print(f" :: Triton optimization failed: {str(e)}")
@@ -720,7 +721,6 @@ def do_hijack():
     torch.backends.cuda.enable_mem_efficient_sdp = do_nothing
     if hasattr(torch.backends.cuda, "enable_flash_sdp"):
         torch.backends.cuda.enable_flash_sdp(True)
-        print(" :: Disabled CUDA flash attention")
     if hasattr(torch.backends.cuda, "enable_math_sdp"):
         torch.backends.cuda.enable_math_sdp(True)
         print(" :: Enabled math attention fallback")