mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-11 23:00:51 +08:00)

Disable flash attention in zluda.py
Comment out the flash-attention-related environment variables and code.
This commit is contained in:
parent 389515d350
commit e3131abac6
@@ -7,8 +7,9 @@ os.environ.pop("HIP_HOME", None)
 os.environ.pop("ROCM_VERSION", None)
 
 #triton fix?
-os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"
-os.environ["FLASH_ATTENTION_TRITON_AMD_AUTOTUNE"] = "TRUE"
+# disabling flash-attention
+# os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"
+# os.environ["FLASH_ATTENTION_TRITON_AMD_AUTOTUNE"] = "TRUE"
 os.environ["TRITON_DEBUG"] = "1" # Verbose logging
 
 paths = os.environ["PATH"].split(";")
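With the two FLASH_ATTENTION_TRITON_AMD_* variables commented out, the Triton AMD flash-attention path is no longer requested through the environment. As a quick sanity check, something like the following snippet (not part of the repository, only an illustration using the standard library) could be run after zluda.py has been imported:

import os

# Hypothetical check: report whether the Triton AMD flash-attention
# environment variables are still set once zluda.py has run.
FLASH_ENV_VARS = (
    "FLASH_ATTENTION_TRITON_AMD_ENABLE",
    "FLASH_ATTENTION_TRITON_AMD_AUTOTUNE",
)

for name in FLASH_ENV_VARS:
    value = os.environ.get(name)
    if value is not None:
        print(f" :: {name} = {value!r}")
    else:
        print(f" :: {name} is unset")

# After this commit both variables should report as unset, unless the
# user exports them manually before launching ComfyUI.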
@@ -668,46 +669,46 @@ def do_hijack():
         triton.runtime.driver.active.utils.get_device_properties = patched_props
         print(" :: Triton device properties configured")
 
-        # Flash Attention
-        flash_enabled = False
-        try:
-            from comfy.flash_attn_triton_amd import interface_fa
-            print(" :: Flash attention components found")
+        # # Flash Attention
+        # flash_enabled = False
+        # try:
+        # from comfy.flash_attn_triton_amd import interface_fa
+        # print(" :: Flash attention components found")
 
-            original_sdpa = torch.nn.functional.scaled_dot_product_attention
+        # original_sdpa = torch.nn.functional.scaled_dot_product_attention
 
-            def amd_flash_wrapper(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
-                try:
-                    if (query.shape[-1] <= 128 and
-                        attn_mask is None and # fix flash-attention error : "Flash attention error: Boolean value of Tensor with more than one value is ambiguous"
-                        query.dtype != torch.float32):
-                        if scale is None:
-                            scale = query.shape[-1] ** -0.5
-                        return interface_fa.fwd(
-                            query.transpose(1, 2),
-                            key.transpose(1, 2),
-                            value.transpose(1, 2),
-                            None, None, dropout_p, scale,
-                            is_causal, -1, -1, 0.0, False, None
-                        )[0].transpose(1, 2)
-                except Exception as e:
-                    print(f' :: Flash attention error: {str(e)}')
-                return original_sdpa(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
+        # def amd_flash_wrapper(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
+        # try:
+        # if (query.shape[-1] <= 128 and
+        # attn_mask is None and # fix flash-attention error : "Flash attention error: Boolean value of Tensor with more than one value is ambiguous"
+        # query.dtype != torch.float32):
+        # if scale is None:
+        # scale = query.shape[-1] ** -0.5
+        # return interface_fa.fwd(
+        # query.transpose(1, 2),
+        # key.transpose(1, 2),
+        # value.transpose(1, 2),
+        # None, None, dropout_p, scale,
+        # is_causal, -1, -1, 0.0, False, None
+        # )[0].transpose(1, 2)
+        # except Exception as e:
+        # print(f' :: Flash attention error: {str(e)}')
+        # return original_sdpa(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
 
-            torch.nn.functional.scaled_dot_product_attention = amd_flash_wrapper
-            flash_enabled = True
-            print(" :: AMD flash attention enabled successfully")
+        # torch.nn.functional.scaled_dot_product_attention = amd_flash_wrapper
+        # flash_enabled = True
+        # print(" :: AMD flash attention enabled successfully")
 
-        except ImportError:
-            print(" :: Flash attention components not installed")
-        except Exception as e:
-            print(f" :: Flash attention setup failed: {str(e)}")
+        # except ImportError:
+        # print(" :: Flash attention components not installed")
+        # except Exception as e:
+        # print(f" :: Flash attention setup failed: {str(e)}")
 
-        # Other Triton optimizations
-        if not flash_enabled:
-            print(" :: Applying basic Triton optimizations")
-            # Add other Triton optimizations here
-            # ...
+        # # Other Triton optimizations
+        # if not flash_enabled:
+        # print(" :: Applying basic Triton optimizations")
+        # # Add other Triton optimizations here
+        # # ...
 
     except Exception as e:
         print(f" :: Triton optimization failed: {str(e)}")
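The code removed above followed a standard monkey-patching pattern: keep a reference to the stock torch.nn.functional.scaled_dot_product_attention, install a wrapper that only tries a custom fast path under narrow conditions (head dimension of at most 128, no attention mask, non-float32 inputs), and fall back to the saved original whenever the fast path is unavailable or raises. The sketch below shows that pattern in isolation, assuming PyTorch 2.1 or newer for the scale argument; the Triton AMD kernel call is deliberately left out because interface_fa is not assumed to be installed, so the fast path simply delegates to the original as well.

import torch
import torch.nn.functional as F

# Keep a reference to the stock implementation so the wrapper can fall back to it.
_original_sdpa = F.scaled_dot_product_attention

def guarded_sdpa(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    # Guard conditions mirrored from the removed amd_flash_wrapper.
    fast_path_ok = (
        query.shape[-1] <= 128
        and attn_mask is None
        and query.dtype != torch.float32
    )
    if fast_path_ok:
        try:
            # A custom kernel (such as the Triton AMD flash-attention forward)
            # would be invoked here; this sketch just delegates to the original.
            return _original_sdpa(query, key, value, attn_mask=attn_mask,
                                  dropout_p=dropout_p, is_causal=is_causal, scale=scale)
        except Exception as exc:
            print(f" :: fast path failed, falling back: {exc}")
    return _original_sdpa(query, key, value, attn_mask=attn_mask,
                          dropout_p=dropout_p, is_causal=is_causal, scale=scale)

# Install the patch, run one attention call, then restore the original.
F.scaled_dot_product_attention = guarded_sdpa
q = k = v = torch.randn(1, 8, 16, 64)   # float32, so the guard routes to the fallback
out = F.scaled_dot_product_attention(q, k, v)
F.scaled_dot_product_attention = _original_sdpa
print(" :: output shape:", tuple(out.shape))

Saving and restoring the original is what makes this kind of hijack reversible, which is also why the removed code stored original_sdpa before overwriting the attribute.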
@@ -720,7 +721,6 @@ def do_hijack():
     torch.backends.cuda.enable_mem_efficient_sdp = do_nothing
     if hasattr(torch.backends.cuda, "enable_flash_sdp"):
         torch.backends.cuda.enable_flash_sdp(True)
-        print(" :: Disabled CUDA flash attention")
     if hasattr(torch.backends.cuda, "enable_math_sdp"):
         torch.backends.cuda.enable_math_sdp(True)
         print(" :: Enabled math attention fallback")
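The enable_flash_sdp and enable_math_sdp calls in the hunk above use PyTorch's standard SDP backend toggles on torch.backends.cuda. A small sketch, assuming a PyTorch 2.x install (the toggles are plain flags, so no GPU is needed to query them), that prints the current backend state and keeps the math fallback enabled the way the hijack does:

import torch

def show_sdp_backends():
    # Query functions available on torch.backends.cuda in PyTorch 2.x.
    cuda_backends = torch.backends.cuda
    print(" :: flash SDP enabled:         ", cuda_backends.flash_sdp_enabled())
    print(" :: mem-efficient SDP enabled: ", cuda_backends.mem_efficient_sdp_enabled())
    print(" :: math SDP enabled:          ", cuda_backends.math_sdp_enabled())

show_sdp_backends()

# Mirror the hijack's intent: make sure the math fallback stays available.
if hasattr(torch.backends.cuda, "enable_math_sdp"):
    torch.backends.cuda.enable_math_sdp(True)

show_sdp_backends()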