diff --git a/comfy/customzluda/zluda.py b/comfy/customzluda/zluda.py
index ecc72413c..97744a741 100644
--- a/comfy/customzluda/zluda.py
+++ b/comfy/customzluda/zluda.py
@@ -7,8 +7,9 @@
 os.environ.pop("HIP_HOME", None)
 os.environ.pop("ROCM_VERSION", None)
 
 #triton fix?
-os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"
-os.environ["FLASH_ATTENTION_TRITON_AMD_AUTOTUNE"] = "TRUE"
+# disabling flash-attention
+# os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"
+# os.environ["FLASH_ATTENTION_TRITON_AMD_AUTOTUNE"] = "TRUE"
 os.environ["TRITON_DEBUG"] = "1" # Verbose logging
 paths = os.environ["PATH"].split(";")
@@ -668,46 +669,46 @@ def do_hijack():
         triton.runtime.driver.active.utils.get_device_properties = patched_props
         print(" :: Triton device properties configured")
 
-        # Flash Attention
-        flash_enabled = False
-        try:
-            from comfy.flash_attn_triton_amd import interface_fa
-            print(" :: Flash attention components found")
+        # # Flash Attention
+        # flash_enabled = False
+        # try:
+        #     from comfy.flash_attn_triton_amd import interface_fa
+        #     print(" :: Flash attention components found")
 
-            original_sdpa = torch.nn.functional.scaled_dot_product_attention
+        #     original_sdpa = torch.nn.functional.scaled_dot_product_attention
 
-            def amd_flash_wrapper(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
-                try:
-                    if (query.shape[-1] <= 128 and
-                        attn_mask is None and # fix flash-attention error : "Flash attention error: Boolean value of Tensor with more than one value is ambiguous"
-                        query.dtype != torch.float32):
-                        if scale is None:
-                            scale = query.shape[-1] ** -0.5
-                        return interface_fa.fwd(
-                            query.transpose(1, 2),
-                            key.transpose(1, 2),
-                            value.transpose(1, 2),
-                            None, None, dropout_p, scale,
-                            is_causal, -1, -1, 0.0, False, None
-                        )[0].transpose(1, 2)
-                except Exception as e:
-                    print(f' :: Flash attention error: {str(e)}')
-                return original_sdpa(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
+        #     def amd_flash_wrapper(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
+        #         try:
+        #             if (query.shape[-1] <= 128 and
+        #                 attn_mask is None and # fix flash-attention error : "Flash attention error: Boolean value of Tensor with more than one value is ambiguous"
+        #                 query.dtype != torch.float32):
+        #                 if scale is None:
+        #                     scale = query.shape[-1] ** -0.5
+        #                 return interface_fa.fwd(
+        #                     query.transpose(1, 2),
+        #                     key.transpose(1, 2),
+        #                     value.transpose(1, 2),
+        #                     None, None, dropout_p, scale,
+        #                     is_causal, -1, -1, 0.0, False, None
+        #                 )[0].transpose(1, 2)
+        #         except Exception as e:
+        #             print(f' :: Flash attention error: {str(e)}')
+        #         return original_sdpa(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
 
-            torch.nn.functional.scaled_dot_product_attention = amd_flash_wrapper
-            flash_enabled = True
-            print(" :: AMD flash attention enabled successfully")
+        #     torch.nn.functional.scaled_dot_product_attention = amd_flash_wrapper
+        #     flash_enabled = True
+        #     print(" :: AMD flash attention enabled successfully")
 
-        except ImportError:
-            print(" :: Flash attention components not installed")
-        except Exception as e:
-            print(f" :: Flash attention setup failed: {str(e)}")
+        # except ImportError:
+        #     print(" :: Flash attention components not installed")
+        # except Exception as e:
+        #     print(f" :: Flash attention setup failed: {str(e)}")
 
-        # Other Triton optimizations
-        if not flash_enabled:
-            print(" :: Applying basic Triton optimizations")
-            # Add other Triton optimizations here
-            # ...
+        # # Other Triton optimizations
+        # if not flash_enabled:
+        #     print(" :: Applying basic Triton optimizations")
+        #     # Add other Triton optimizations here
+        #     # ...
 
     except Exception as e:
         print(f" :: Triton optimization failed: {str(e)}")
@@ -720,7 +721,6 @@ def do_hijack():
         torch.backends.cuda.enable_mem_efficient_sdp = do_nothing
     if hasattr(torch.backends.cuda, "enable_flash_sdp"):
         torch.backends.cuda.enable_flash_sdp(True)
-        print(" :: Disabled CUDA flash attention")
    if hasattr(torch.backends.cuda, "enable_math_sdp"):
        torch.backends.cuda.enable_math_sdp(True)
        print(" :: Enabled math attention fallback")
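Not part of the patch, but a quick way to sanity-check its effect: once do_hijack() has run, attention goes through PyTorch's built-in scaled_dot_product_attention with the flash and math SDP backends enabled (last hunk) instead of the now commented-out Triton AMD wrapper. The sketch below assumes PyTorch 2.x; the helper names report_sdpa_backends and sdpa_smoke_test are illustrative and not part of the repository.

# Hypothetical verification sketch (PyTorch 2.x APIs only; not in the repo).
import torch
import torch.nn.functional as F

def report_sdpa_backends():
    # Reflects the enable_flash_sdp(True) / enable_math_sdp(True) calls made in do_hijack().
    print("flash sdp enabled:        ", torch.backends.cuda.flash_sdp_enabled())
    print("mem-efficient sdp enabled:", torch.backends.cuda.mem_efficient_sdp_enabled())
    print("math sdp enabled:         ", torch.backends.cuda.math_sdp_enabled())

def sdpa_smoke_test():
    # (batch, heads, seq, head_dim) layout; head_dim 64 stays under the <=128
    # limit the removed wrapper also used.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    q = torch.randn(1, 8, 128, 64, device=device, dtype=dtype)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    out = F.scaled_dot_product_attention(q, k, v)
    print("sdpa output shape:", tuple(out.shape))

report_sdpa_backends()
sdpa_smoke_test()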