From c0fc1d14584acf285cf1b4413a04316d12da854a Mon Sep 17 00:00:00 2001 From: doctorpangloss <@hiddenswitch.com> Date: Wed, 22 May 2024 21:21:23 -0700 Subject: [PATCH] Add pytorch-triton-rocm to dependencies when targeting AMD because accelerate needs to find it in the rocm repo --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index d2a21c4ad..c8b2eb923 100644 --- a/setup.py +++ b/setup.py @@ -146,6 +146,7 @@ def dependencies(force_nightly: bool = False) -> List[str]: index_urls += [nvidia_torch_index] elif _is_amd(): index_urls += [amd_torch_index] + _dependencies += ["pytorch-triton-rocm"] else: index_urls += [cpu_torch_index]