Update zluda.py (MEM_BUS_WIDTH#3)

Lower casing the lookup inside MEM_BUS_WIDTH, just in case of incorrect casing on Radeon Pro (PRO) GPUs.

fixed/lower-casing "Triton device properties" lookup inside MEM_BUS_WIDTH.
This commit is contained in:
Rando717 2025-09-09 20:04:20 +02:00 committed by GitHub
parent 13ba6a8a8d
commit 4057f2984c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -507,6 +507,7 @@ except Exception as e:
# # ------------------- ZLUDA Core Implementation -------------------
MEM_BUS_WIDTH = {
k.lower(): v for k, v in {
"AMD Radeon RX 9070 XT": 256,
"AMD Radeon RX 9070": 256,
"AMD Radeon RX 9070 GRE": 192,
@ -559,7 +560,8 @@ MEM_BUS_WIDTH = {
"AMD Radeon PRO W6400": 64,
"AMD Radeon Pro W5700": 256,
"AMD Radeon Pro W5500": 128,
"AMD Radeon Pro VII": 4096,
"AMD Radeon Pro VII": 4096,
}.items()
}
# ------------------- Device Properties Implementation -------------------
@ -659,8 +661,8 @@ def do_hijack():
def patched_props(device):
props = _get_props(device)
name = torch.cuda.get_device_name()[:-8] # Remove [ZLUDA]
props["mem_bus_width"] = MEM_BUS_WIDTH.get(name, 128)
if name not in MEM_BUS_WIDTH:
props["mem_bus_width"] = MEM_BUS_WIDTH.get(name.lower(), 128)
if name.lower() not in MEM_BUS_WIDTH:
print(f' :: Using default mem_bus_width=128 for {name}')
return props
triton.runtime.driver.active.utils.get_device_properties = patched_props