From 77107e22a1690b2a388fbc69b889453bfd744948 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=87=8E=E7=94=9F=E3=81=AE=E7=94=B7?= Date: Fri, 3 Apr 2026 01:47:40 +0900 Subject: [PATCH] Align ROCm detection with torch.version.hip --- main.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/main.py b/main.py index 51e062005..a757a5707 100644 --- a/main.py +++ b/main.py @@ -32,6 +32,21 @@ setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) def is_rocm_torch_build(): + try: + torch_spec = importlib.util.find_spec("torch") + if torch_spec is not None: + for folder in torch_spec.submodule_search_locations or []: + version_file = os.path.join(folder, "version.py") + if os.path.isfile(version_file): + spec = importlib.util.spec_from_file_location("torch_version_probe", version_file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + if bool(getattr(module, "hip", None)): + return True + break + except Exception: + pass + try: torch_version = importlib.metadata.version("torch").lower() except importlib.metadata.PackageNotFoundError: