Disable flash attention in zluda.py

Comment out flash attention related environment variables and code.
Author: patientx (committed by GitHub)
Date: 2025-12-08 15:10:12 +03:00
Parent: 389515d350
Commit: e3131abac6

zluda.py


@@ -7,8 +7,9 @@ os.environ.pop("HIP_HOME", None)
 os.environ.pop("ROCM_VERSION", None)
 #triton fix?
-os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"
-os.environ["FLASH_ATTENTION_TRITON_AMD_AUTOTUNE"] = "TRUE"
+# disabling flash-attention
+# os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"
+# os.environ["FLASH_ATTENTION_TRITON_AMD_AUTOTUNE"] = "TRUE"
 os.environ["TRITON_DEBUG"] = "1" # Verbose logging
 paths = os.environ["PATH"].split(";")
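
Note on the first hunk: the two variables that are now commented out are the opt-in switches for the Triton-based AMD flash-attention backend that the repo vendors as `comfy.flash_attn_triton_amd`. Exactly how that package reads them is an assumption here, but the usual pattern is a string comparison against "TRUE", so simply leaving the variables unset keeps the backend off. A minimal sketch of such a gate (the helper name is hypothetical, not from the repo):

```python
import os

# Hypothetical helper illustrating the assumed gate: the Triton AMD
# flash-attention path only activates when the variable is exactly "TRUE".
def triton_flash_requested() -> bool:
    return os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE") == "TRUE"

print(triton_flash_requested())  # False after this commit: the variable is never set
```
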
@@ -668,46 +669,46 @@ def do_hijack():
         triton.runtime.driver.active.utils.get_device_properties = patched_props
         print(" :: Triton device properties configured")
-        # Flash Attention
-        flash_enabled = False
-        try:
-            from comfy.flash_attn_triton_amd import interface_fa
-            print(" :: Flash attention components found")
-            original_sdpa = torch.nn.functional.scaled_dot_product_attention
-            def amd_flash_wrapper(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
-                try:
-                    if (query.shape[-1] <= 128 and
-                        attn_mask is None and # fix flash-attention error : "Flash attention error: Boolean value of Tensor with more than one value is ambiguous"
-                        query.dtype != torch.float32):
-                        if scale is None:
-                            scale = query.shape[-1] ** -0.5
-                        return interface_fa.fwd(
-                            query.transpose(1, 2),
-                            key.transpose(1, 2),
-                            value.transpose(1, 2),
-                            None, None, dropout_p, scale,
-                            is_causal, -1, -1, 0.0, False, None
-                        )[0].transpose(1, 2)
-                except Exception as e:
-                    print(f' :: Flash attention error: {str(e)}')
-                return original_sdpa(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
-            torch.nn.functional.scaled_dot_product_attention = amd_flash_wrapper
-            flash_enabled = True
-            print(" :: AMD flash attention enabled successfully")
-        except ImportError:
-            print(" :: Flash attention components not installed")
-        except Exception as e:
-            print(f" :: Flash attention setup failed: {str(e)}")
-        # Other Triton optimizations
-        if not flash_enabled:
-            print(" :: Applying basic Triton optimizations")
-            # Add other Triton optimizations here
-            # ...
+        # # Flash Attention
+        # flash_enabled = False
+        # try:
+        #     from comfy.flash_attn_triton_amd import interface_fa
+        #     print(" :: Flash attention components found")
+        #     original_sdpa = torch.nn.functional.scaled_dot_product_attention
+        #     def amd_flash_wrapper(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
+        #         try:
+        #             if (query.shape[-1] <= 128 and
+        #                 attn_mask is None and # fix flash-attention error : "Flash attention error: Boolean value of Tensor with more than one value is ambiguous"
+        #                 query.dtype != torch.float32):
+        #                 if scale is None:
+        #                     scale = query.shape[-1] ** -0.5
+        #                 return interface_fa.fwd(
+        #                     query.transpose(1, 2),
+        #                     key.transpose(1, 2),
+        #                     value.transpose(1, 2),
+        #                     None, None, dropout_p, scale,
+        #                     is_causal, -1, -1, 0.0, False, None
+        #                 )[0].transpose(1, 2)
+        #         except Exception as e:
+        #             print(f' :: Flash attention error: {str(e)}')
+        #         return original_sdpa(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
+        #     torch.nn.functional.scaled_dot_product_attention = amd_flash_wrapper
+        #     flash_enabled = True
+        #     print(" :: AMD flash attention enabled successfully")
+        # except ImportError:
+        #     print(" :: Flash attention components not installed")
+        # except Exception as e:
+        #     print(f" :: Flash attention setup failed: {str(e)}")
+        # # Other Triton optimizations
+        # if not flash_enabled:
+        #     print(" :: Applying basic Triton optimizations")
+        #     # Add other Triton optimizations here
+        #     # ...
     except Exception as e:
         print(f" :: Triton optimization failed: {str(e)}")
@@ -720,7 +721,6 @@ def do_hijack():
     torch.backends.cuda.enable_mem_efficient_sdp = do_nothing
     if hasattr(torch.backends.cuda, "enable_flash_sdp"):
         torch.backends.cuda.enable_flash_sdp(True)
-        print(" :: Disabled CUDA flash attention")
     if hasattr(torch.backends.cuda, "enable_math_sdp"):
         torch.backends.cuda.enable_math_sdp(True)
         print(" :: Enabled math attention fallback")