mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 14:50:49 +08:00
Update zluda.py (MEM_BUS_WIDTH#3)
Lower casing the lookup inside MEM_BUS_WIDTH, just in case of incorrect casing on Radeon Pro (PRO) GPUs. fixed/lower-casing "Triton device properties" lookup inside MEM_BUS_WIDTH.
This commit is contained in:
parent
13ba6a8a8d
commit
4057f2984c
@ -507,6 +507,7 @@ except Exception as e:
|
||||
|
||||
# # ------------------- ZLUDA Core Implementation -------------------
|
||||
MEM_BUS_WIDTH = {
|
||||
k.lower(): v for k, v in {
|
||||
"AMD Radeon RX 9070 XT": 256,
|
||||
"AMD Radeon RX 9070": 256,
|
||||
"AMD Radeon RX 9070 GRE": 192,
|
||||
@ -559,7 +560,8 @@ MEM_BUS_WIDTH = {
|
||||
"AMD Radeon PRO W6400": 64,
|
||||
"AMD Radeon Pro W5700": 256,
|
||||
"AMD Radeon Pro W5500": 128,
|
||||
"AMD Radeon Pro VII": 4096,
|
||||
"AMD Radeon Pro VII": 4096,
|
||||
}.items()
|
||||
}
|
||||
|
||||
# ------------------- Device Properties Implementation -------------------
|
||||
@ -659,8 +661,8 @@ def do_hijack():
|
||||
def patched_props(device):
|
||||
props = _get_props(device)
|
||||
name = torch.cuda.get_device_name()[:-8] # Remove [ZLUDA]
|
||||
props["mem_bus_width"] = MEM_BUS_WIDTH.get(name, 128)
|
||||
if name not in MEM_BUS_WIDTH:
|
||||
props["mem_bus_width"] = MEM_BUS_WIDTH.get(name.lower(), 128)
|
||||
if name.lower() not in MEM_BUS_WIDTH:
|
||||
print(f' :: Using default mem_bus_width=128 for {name}')
|
||||
return props
|
||||
triton.runtime.driver.active.utils.get_device_properties = patched_props
|
||||
|
||||
Loading…
Reference in New Issue
Block a user