diff --git a/comfy/customzluda/zluda.py b/comfy/customzluda/zluda.py index 978f1a5d6..ecc72413c 100644 --- a/comfy/customzluda/zluda.py +++ b/comfy/customzluda/zluda.py @@ -507,6 +507,7 @@ except Exception as e: # # ------------------- ZLUDA Core Implementation ------------------- MEM_BUS_WIDTH = { + k.lower(): v for k, v in { "AMD Radeon RX 9070 XT": 256, "AMD Radeon RX 9070": 256, "AMD Radeon RX 9070 GRE": 192, @@ -559,7 +560,8 @@ MEM_BUS_WIDTH = { "AMD Radeon PRO W6400": 64, "AMD Radeon Pro W5700": 256, "AMD Radeon Pro W5500": 128, - "AMD Radeon Pro VII": 4096, + "AMD Radeon Pro VII": 4096, + }.items() } # ------------------- Device Properties Implementation ------------------- @@ -659,8 +661,8 @@ def do_hijack(): def patched_props(device): props = _get_props(device) name = torch.cuda.get_device_name()[:-8] # Remove [ZLUDA] - props["mem_bus_width"] = MEM_BUS_WIDTH.get(name, 128) - if name not in MEM_BUS_WIDTH: + props["mem_bus_width"] = MEM_BUS_WIDTH.get(name.lower(), 128) + if name.lower() not in MEM_BUS_WIDTH: print(f' :: Using default mem_bus_width=128 for {name}') return props triton.runtime.driver.active.utils.get_device_properties = patched_props