diff --git a/comfy/customzluda/zluda.py b/comfy/customzluda/zluda.py index ff73a0b5b..d54fe8d74 100644 --- a/comfy/customzluda/zluda.py +++ b/comfy/customzluda/zluda.py @@ -7,17 +7,23 @@ os.environ.pop("HIP_HOME", None) os.environ.pop("ROCM_VERSION", None) #triton fix? -os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE" -os.environ["FLASH_ATTENTION_TRITON_AMD_AUTOTUNE"] = "TRUE" -os.environ["TRITON_DEBUG"] = "1" # Verbose logging +if "FLASH_ATTENTION_TRITON_AMD_ENABLE" not in os.environ: + os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE" +if "FLASH_ATTENTION_TRITON_AMD_AUTOTUNE" not in os.environ: + os.environ["FLASH_ATTENTION_TRITON_AMD_AUTOTUNE"] = "TRUE" +if "TRITON_DEBUG" not in os.environ: + os.environ["TRITON_DEBUG"] = "1" # Verbose logging + +paths = os.environ["PATH"].split(os.pathsep) +paths_with_rocm = [p for p in paths if "rocm" in p.lower()] +if paths_with_rocm: + print("Warning: ROCm related path(s) may need to be removed from %PATH%: " + os.pathsep.join(paths_with_rocm)) -paths = os.environ["PATH"].split(";") -paths_no_rocm = [p for p in paths if "rocm" not in p.lower()] -os.environ["PATH"] = ";".join(paths_no_rocm) # ------------------- End ROCm/HIP Hiding ------------- # Fix for cublasLt errors on newer ZLUDA (if no hipblaslt) -os.environ['DISABLE_ADDMM_CUDA_LT'] = '1' +if "DISABLE_ADDMM_CUDA_LT" not in os.environ: + os.environ['DISABLE_ADDMM_CUDA_LT'] = '1' # ------------------- main imports ------------------- # main imports