Align ROCm detection with torch.version.hip

This commit is contained in:
野生の男 2026-04-03 01:47:40 +09:00
parent eb465ad574
commit 77107e22a1

15
main.py
View File

@ -32,6 +32,21 @@ setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
def is_rocm_torch_build():
try:
torch_spec = importlib.util.find_spec("torch")
if torch_spec is not None:
for folder in torch_spec.submodule_search_locations or []:
version_file = os.path.join(folder, "version.py")
if os.path.isfile(version_file):
spec = importlib.util.spec_from_file_location("torch_version_probe", version_file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
if bool(getattr(module, "hip", None)):
return True
break
except Exception:
pass
try:
torch_version = importlib.metadata.version("torch").lower()
except importlib.metadata.PackageNotFoundError: