diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c9a0a2183..f2b0ced2a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -54,6 +54,8 @@ jobs: container: "ubuntu:22.04" - labels: [self-hosted, Linux, X64, cuda-3060-12gb] container: "nvcr.io/nvidia/pytorch:24.03-py3" + - labels: [self-hosted, Linux, X64, rocm-7600-8gb] + container: "rocm/pytorch:rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0" steps: - run: | apt update @@ -72,11 +74,8 @@ jobs: - name: Run tests run: | export HSA_OVERRIDE_GFX_VERSION=11.0.0 - export TORCH_BLAS_PREFER_HIPBLASLT=0 export HIP_VISIBLE_DEVICES=0 - export PYTORCH_HIP_ALLOC_CONF=expandable_segments:True - export NUMBA_THREADING_LAYER=omp - export AMD_SERIALIZE_KERNEL=1 + nvidia-smi || rocminfo || true pytest -v tests/unit - name: Lint for errors run: |