Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-27 23:00:20 +08:00
Disable flash attention in zluda.py
Comment out the flash-attention-related environment variables and code.
This commit is contained in:
parent 389515d350
commit e3131abac6
@@ -7,8 +7,9 @@ os.environ.pop("HIP_HOME", None)
 os.environ.pop("ROCM_VERSION", None)
 
 #triton fix?
-os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"
-os.environ["FLASH_ATTENTION_TRITON_AMD_AUTOTUNE"] = "TRUE"
+# disabling flash-attention
+# os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"
+# os.environ["FLASH_ATTENTION_TRITON_AMD_AUTOTUNE"] = "TRUE"
 os.environ["TRITON_DEBUG"] = "1" # Verbose logging
 
 paths = os.environ["PATH"].split(";")
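
For anyone who wants to toggle this later without editing the file again, here is a minimal sketch of gating the same two environment variables behind an opt-in flag instead of commenting them out. The `maybe_enable_triton_flash` helper is hypothetical and not part of zluda.py; the variables are typically read when the flash-attention backend is imported, so they would have to be set this early regardless.

```python
import os

# Hypothetical helper: opt in to the Triton AMD flash-attention backend
# instead of commenting the variables out. Not part of zluda.py.
def maybe_enable_triton_flash(enable: bool = False) -> None:
    if enable:
        os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"
        os.environ["FLASH_ATTENTION_TRITON_AMD_AUTOTUNE"] = "TRUE"
    else:
        # Drop any value inherited from the shell so the backend stays off.
        os.environ.pop("FLASH_ATTENTION_TRITON_AMD_ENABLE", None)
        os.environ.pop("FLASH_ATTENTION_TRITON_AMD_AUTOTUNE", None)

maybe_enable_triton_flash(enable=False)
```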
@@ -668,46 +669,46 @@ def do_hijack():
         triton.runtime.driver.active.utils.get_device_properties = patched_props
         print(" :: Triton device properties configured")
 
-        # Flash Attention
-        flash_enabled = False
-        try:
-            from comfy.flash_attn_triton_amd import interface_fa
-            print(" :: Flash attention components found")
+        # # Flash Attention
+        # flash_enabled = False
+        # try:
+        #     from comfy.flash_attn_triton_amd import interface_fa
+        #     print(" :: Flash attention components found")
 
-            original_sdpa = torch.nn.functional.scaled_dot_product_attention
+        #     original_sdpa = torch.nn.functional.scaled_dot_product_attention
 
-            def amd_flash_wrapper(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
-                try:
-                    if (query.shape[-1] <= 128 and
-                        attn_mask is None and # fix flash-attention error : "Flash attention error: Boolean value of Tensor with more than one value is ambiguous"
-                        query.dtype != torch.float32):
-                        if scale is None:
-                            scale = query.shape[-1] ** -0.5
-                        return interface_fa.fwd(
-                            query.transpose(1, 2),
-                            key.transpose(1, 2),
-                            value.transpose(1, 2),
-                            None, None, dropout_p, scale,
-                            is_causal, -1, -1, 0.0, False, None
-                        )[0].transpose(1, 2)
-                except Exception as e:
-                    print(f' :: Flash attention error: {str(e)}')
-                return original_sdpa(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
+        #     def amd_flash_wrapper(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
+        #         try:
+        #             if (query.shape[-1] <= 128 and
+        #                 attn_mask is None and # fix flash-attention error : "Flash attention error: Boolean value of Tensor with more than one value is ambiguous"
+        #                 query.dtype != torch.float32):
+        #                 if scale is None:
+        #                     scale = query.shape[-1] ** -0.5
+        #                 return interface_fa.fwd(
+        #                     query.transpose(1, 2),
+        #                     key.transpose(1, 2),
+        #                     value.transpose(1, 2),
+        #                     None, None, dropout_p, scale,
+        #                     is_causal, -1, -1, 0.0, False, None
+        #                 )[0].transpose(1, 2)
+        #         except Exception as e:
+        #             print(f' :: Flash attention error: {str(e)}')
+        #         return original_sdpa(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
 
-            torch.nn.functional.scaled_dot_product_attention = amd_flash_wrapper
-            flash_enabled = True
-            print(" :: AMD flash attention enabled successfully")
+        #     torch.nn.functional.scaled_dot_product_attention = amd_flash_wrapper
+        #     flash_enabled = True
+        #     print(" :: AMD flash attention enabled successfully")
 
-        except ImportError:
-            print(" :: Flash attention components not installed")
-        except Exception as e:
-            print(f" :: Flash attention setup failed: {str(e)}")
+        # except ImportError:
+        #     print(" :: Flash attention components not installed")
+        # except Exception as e:
+        #     print(f" :: Flash attention setup failed: {str(e)}")
 
-        # Other Triton optimizations
-        if not flash_enabled:
-            print(" :: Applying basic Triton optimizations")
-            # Add other Triton optimizations here
-            # ...
+        # # Other Triton optimizations
+        # if not flash_enabled:
+        #     print(" :: Applying basic Triton optimizations")
+        #     # Add other Triton optimizations here
+        #     # ...
 
     except Exception as e:
         print(f" :: Triton optimization failed: {str(e)}")
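
The block commented out above follows a common monkey-patching pattern: replace torch.nn.functional.scaled_dot_product_attention with a wrapper that routes eligible calls (head dim at most 128, no attention mask, non-fp32) to a flash-attention forward and falls back to the stock SDPA otherwise. Below is a minimal, self-contained sketch of that pattern; `custom_flash_fwd` is a placeholder standing in for `interface_fa.fwd`, whose exact signature is not reproduced here.

```python
import torch
import torch.nn.functional as F

_original_sdpa = F.scaled_dot_product_attention  # keep a handle to the stock kernel

def sdpa_with_fallback(query, key, value, attn_mask=None, dropout_p=0.0,
                       is_causal=False, scale=None, custom_flash_fwd=None):
    """Route eligible calls to a flash kernel, otherwise use the original SDPA.

    `custom_flash_fwd` is a placeholder taking (q, k, v, dropout_p, scale, is_causal)
    in (batch, seq, heads, head_dim) layout and returning a tensor in that layout.
    """
    eligible = (
        custom_flash_fwd is not None
        and query.shape[-1] <= 128        # head-dim limit of the Triton kernel
        and attn_mask is None             # the kernel path takes no explicit mask
        and query.dtype != torch.float32  # fp16/bf16 only
    )
    if eligible:
        if scale is None:
            scale = query.shape[-1] ** -0.5  # default softmax scaling 1/sqrt(d)
        # SDPA uses (batch, heads, seq, head_dim); the kernel expects (batch, seq, heads, head_dim).
        out = custom_flash_fwd(query.transpose(1, 2), key.transpose(1, 2),
                               value.transpose(1, 2), dropout_p, scale, is_causal)
        return out.transpose(1, 2)
    return _original_sdpa(query, key, value, attn_mask=attn_mask,
                          dropout_p=dropout_p, is_causal=is_causal, scale=scale)

# With no flash kernel supplied, the wrapper simply defers to the original SDPA:
q = k = v = torch.randn(1, 8, 16, 64)
out = sdpa_with_fallback(q, k, v)
```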
@@ -720,7 +721,6 @@ def do_hijack():
     torch.backends.cuda.enable_mem_efficient_sdp = do_nothing
     if hasattr(torch.backends.cuda, "enable_flash_sdp"):
         torch.backends.cuda.enable_flash_sdp(True)
-        print(" :: Disabled CUDA flash attention")
     if hasattr(torch.backends.cuda, "enable_math_sdp"):
         torch.backends.cuda.enable_math_sdp(True)
         print(" :: Enabled math attention fallback")
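
With the Triton path commented out, attention selection is left to PyTorch's built-in SDPA backends, which the retained enable_flash_sdp(True) and enable_math_sdp(True) calls keep available. A quick way to confirm which backends end up enabled is sketched below; it uses the query functions available in recent PyTorch builds and guards with getattr, mirroring the hasattr checks above, since older builds may not expose every toggle.

```python
import torch

# Report which scaled-dot-product-attention backends PyTorch will consider.
for name in ("flash_sdp_enabled", "mem_efficient_sdp_enabled", "math_sdp_enabled"):
    query_fn = getattr(torch.backends.cuda, name, None)
    if query_fn is not None:
        print(f" :: {name}: {query_fn()}")
    else:
        print(f" :: {name}: not available in this torch build")
```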