diff --git a/main.py b/main.py index 12b04719d..569babbe6 100644 --- a/main.py +++ b/main.py @@ -5,6 +5,7 @@ import os import importlib.util import shutil import importlib.metadata +import sysconfig import folder_paths import time from comfy.cli_args import args, enables_dynamic_vram @@ -29,6 +30,40 @@ if __name__ == "__main__": setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) + +def configure_windows_rocm_sdk_env(): + if os.name != "nt": + return None + + if any(os.environ.get(var) for var in ("ROCM_HOME", "ROCM_PATH", "HIP_PATH")): + return None + + platlib = sysconfig.get_paths().get("platlib") + if not platlib: + return None + + rocm_sdk_root = os.path.join(platlib, "_rocm_sdk_core") + hip_header = os.path.join(rocm_sdk_root, "include", "hip", "hip_runtime.h") + if not os.path.isfile(hip_header): + return None + + os.environ["ROCM_HOME"] = rocm_sdk_root + os.environ["ROCM_PATH"] = rocm_sdk_root + os.environ["HIP_PATH"] = rocm_sdk_root + + rocm_bin = os.path.join(rocm_sdk_root, "bin") + if os.path.isdir(rocm_bin): + path_entries = os.environ.get("PATH", "").split(os.pathsep) + if rocm_bin not in path_entries: + os.environ["PATH"] = rocm_bin + os.pathsep + os.environ.get("PATH", "") + + return rocm_sdk_root + + +configured_rocm_sdk_root = configure_windows_rocm_sdk_env() +if configured_rocm_sdk_root is not None: + logging.info("Configured ROCm SDK environment from Python package path: %s", configured_rocm_sdk_root) + faulthandler.enable(file=sys.stderr, all_threads=False) import comfy_aimdo.control