Add pytorch-triton-rocm to dependencies when targeting AMD because accelerate needs to find it in the rocm repo

This commit is contained in:
doctorpangloss 2024-05-22 21:21:23 -07:00
parent 0fcd07962f
commit c0fc1d1458

View File

@ -146,6 +146,7 @@ def dependencies(force_nightly: bool = False) -> List[str]:
index_urls += [nvidia_torch_index]
elif _is_amd():
index_urls += [amd_torch_index]
_dependencies += ["pytorch-triton-rocm"]
else:
index_urls += [cpu_torch_index]