diff --git a/main.py b/main.py index 51e062005..a757a5707 100644 --- a/main.py +++ b/main.py @@ -32,6 +32,21 @@ setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) def is_rocm_torch_build(): + try: + torch_spec = importlib.util.find_spec("torch") + if torch_spec is not None: + for folder in torch_spec.submodule_search_locations or []: + version_file = os.path.join(folder, "version.py") + if os.path.isfile(version_file): + spec = importlib.util.spec_from_file_location("torch_version_probe", version_file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + if bool(getattr(module, "hip", None)): + return True + break + except Exception: + pass + try: torch_version = importlib.metadata.version("torch").lower() except importlib.metadata.PackageNotFoundError: