mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 23:00:51 +08:00
Merge pull request #297 from Rando717/Rando717-zluda.py
zluda.py "Expanded gfx identifier, lowercase gpu search, detect Triton version"
This commit is contained in:
commit
aae8c1486f
@ -88,55 +88,71 @@ def detect_amd_gpu_architecture():
|
||||
print(f" :: GPU detection failed: {str(e)}")
|
||||
return None
|
||||
|
||||
def gpu_name_to_gfx(gpu_name: str) -> str:
    """
    Map a GPU name to its corresponding gfx architecture code.

    The name is lowercased and tested against an ordered list of rules; the
    first rule with any substring contained in the name wins. Because the
    match is a plain substring test, a more specific pattern (e.g.
    'rx 7700s') must be listed BEFORE any shorter pattern it contains
    (e.g. 'rx 7700'), otherwise it can never be reached.

    Args:
        gpu_name: Device name such as "AMD Radeon RX 7900 XTX"
            (matching is case-insensitive).

    Returns:
        The gfx code string (e.g. 'gfx1100'). Names that match no rule fall
        back to a per-series default ('rx 9' .. 'rx 5'); anything else
        prints a notice and returns 'gfx1030'.
    """
    gpu_name_lower = gpu_name.lower()

    # List of (substrings, gfx_arch, comment)
    rules = [
        # RDNA4 (gfx12xx)
        (['rx 9060'], 'gfx1200', 'Navi 44'),
        # 'r9700' covers "Radeon AI PRO R9700"; 'r9070' is kept for backward compatibility
        (['rx 9070', 'r9070', 'r9700'], 'gfx1201', 'Navi 48'),

        # RDNA3.5 (gfx115x)
        (['890m'], 'gfx1150', 'Strix Point'),
        # NOTE(review): '880m' may belong with Strix Point (gfx1150) - confirm
        (['8060s', '8050s', '8040s', '880m'], 'gfx1151', 'Strix Halo'),
        (['860m', '840m', '820m'], 'gfx1152', 'Krackan Point'),

        # RDNA3 (gfx110x)
        (['rx 7900', 'w7900', 'w7800'], 'gfx1100', 'Navi 31'),
        (['rx 7700s'], 'gfx1102', 'Navi 33'),  # must be before 'rx 7700'
        (['rx 7800', 'rx 7700', 'w7700'], 'gfx1101', 'Navi 32'),
        (['rx 7650', 'rx 7600', 'w7600', 'w7500', 'rx 7400', 'w7400'], 'gfx1102', 'Navi 33'),
        (['780m', '760m', '740m'], 'gfx1100', 'Hawk Point'),

        # RDNA2 (gfx103x)
        (['rx 6800m'], 'gfx1031', 'Navi 22'),  # must be before 'rx 6800'
        (['rx 6800s', 'rx 6700s'], 'gfx1032', 'Navi 23'),  # must be before 'rx 6800' / 'rx 6700'
        (['rx 6950', 'rx 6900', 'rx 6800', 'w6800'], 'gfx1030', 'Navi 21'),
        (['rx 6850', 'rx 6750', 'rx 6700'], 'gfx1031', 'Navi 22'),
        (['rx 6650', 'rx 6600', 'w6600'], 'gfx1032', 'Navi 23'),
        (['rx 6550', 'rx 6500', 'w6500', 'rx 6450', 'rx 6400', 'w6400', 'rx 6300', 'w6300'], 'gfx1034', 'Navi 24'),
        (['680m', '660m'], 'gfx1035', 'Rembrandt'),
        (['610m'], 'gfx1037', ''),

        # RDNA1 (gfx101x)
        (['rx 5700', 'w5700', 'rx 5600'], 'gfx1010', 'Navi 10'),
        (['rx 5500', 'w5500', 'rx 5300', 'w5300'], 'gfx1012', 'Navi 14'),

        # Vega (gfx90x)
        # 'vega 64' / 'vega 56' must stay before the APU rule, whose 'vega 6' is a prefix of 'vega 64'
        (['vega 64', 'vega 56', 'frontier'], 'gfx900', 'Vega 10'),
        (['radeon vii', 'radeon pro vii'], 'gfx906', 'Vega 20'),
        (['vega 11', 'vega 10', 'vega 9', 'vega 8', 'vega 6', 'vega 3'], 'gfx902', 'Raven Ridge'),

        # Polaris (gfx80x)
        (['rx 590', 'rx 580', 'rx 570', 'rx 560', 'rx 480', 'rx 470', 'rx 460'], 'gfx803', 'Polaris 10'),
        (['rx 640', 'rx 550', 'rx 540'], 'gfx804', 'Polaris 12'),
    ]

    # Apply rules in order (priority matters)
    for substrings, gfx, _ in rules:
        if any(sub in gpu_name_lower for sub in substrings):
            return gfx

    # Default fallback - guess from the series prefix alone
    if 'rx 9' in gpu_name_lower:
        return 'gfx1200'  # Default RDNA4
    elif 'rx 8' in gpu_name_lower:
        return 'gfx1150'  # Default RDNA3.5
    elif 'rx 7' in gpu_name_lower:
        return 'gfx1100'  # Default RDNA3
    elif 'rx 6' in gpu_name_lower:
        return 'gfx1030'  # Default RDNA2
    elif 'rx 5' in gpu_name_lower:
        return 'gfx1010'  # Default RDNA1

    print(f" :: Unknown GPU model: {gpu_name}, using default gfx1030")
    return 'gfx1030'  # Safe default for most modern AMD GPUs
|
||||
|
||||
@ -431,6 +447,12 @@ try:
|
||||
import triton.language as tl
|
||||
print(" :: Triton core imported successfully")
|
||||
|
||||
# Detect Triton version, if possible
|
||||
version = getattr(triton, "__version__", None)
|
||||
if version:
|
||||
print(f" :: Detected Triton version: {version}")
|
||||
# else: do nothing
|
||||
|
||||
# This needs to be up here, so it can disable cudnn before anything can even think about using it
|
||||
torch.backends.cudnn.enabled = os.environ.get("TORCH_BACKENDS_CUDNN_ENABLED", "1").strip().lower() not in {"0", "off", "false", "disable", "disabled", "no"}
|
||||
if torch.backends.cudnn.enabled:
|
||||
@ -485,6 +507,7 @@ except Exception as e:
|
||||
|
||||
# # ------------------- ZLUDA Core Implementation -------------------
|
||||
MEM_BUS_WIDTH = {
|
||||
k.lower(): v for k, v in {
|
||||
"AMD Radeon RX 9070 XT": 256,
|
||||
"AMD Radeon RX 9070": 256,
|
||||
"AMD Radeon RX 9070 GRE": 192,
|
||||
@ -520,7 +543,10 @@ MEM_BUS_WIDTH = {
|
||||
"AMD Radeon RX 5500 XT": 128,
|
||||
"AMD Radeon RX 5500": 128,
|
||||
"AMD Radeon RX 5300": 96,
|
||||
# AMD Radeon Pro R9000/W7000/W6000/W5000 series, Apple exclusive WX series not listed
|
||||
"AMD Radeon Vega Frontier Edition": 2048,
|
||||
"AMD Radeon RX Vega 64": 2048,
|
||||
"AMD Radeon RX Vega 56": 2048,
|
||||
"AMD Radeon VII": 4096,
|
||||
"AMD Radeon AI PRO R9700": 256,
|
||||
"AMD Radeon PRO W7900": 384,
|
||||
"AMD Radeon PRO W7800 48GB": 384,
|
||||
@ -532,8 +558,10 @@ MEM_BUS_WIDTH = {
|
||||
"AMD Radeon PRO W6800": 256,
|
||||
"AMD Radeon PRO W6600": 128,
|
||||
"AMD Radeon PRO W6400": 64,
|
||||
"AMD Radeon PRO W5700": 256,
|
||||
"AMD Radeon PRO W5500": 128,
|
||||
"AMD Radeon Pro W5700": 256,
|
||||
"AMD Radeon Pro W5500": 128,
|
||||
"AMD Radeon Pro VII": 4096,
|
||||
}.items()
|
||||
}
|
||||
|
||||
# ------------------- Device Properties Implementation -------------------
|
||||
@ -633,8 +661,8 @@ def do_hijack():
|
||||
def patched_props(device):
    """
    Return the device properties from `_get_props` with "mem_bus_width" set.

    The bus width is looked up in MEM_BUS_WIDTH (whose keys are lowercased)
    using the current CUDA device name with its trailing " [ZLUDA]" tag
    removed. When the name is not in the table, 128 is used and a notice
    is printed.
    """
    props = _get_props(device)

    # Strip the tag only when it is actually present; a blind [:-8] slice
    # would chop real characters off a name that lacks the suffix.
    zluda_tag = " [ZLUDA]"
    name = torch.cuda.get_device_name()
    if name.endswith(zluda_tag):
        name = name[:-len(zluda_tag)]

    key = name.lower()  # MEM_BUS_WIDTH keys are stored lowercased
    if key in MEM_BUS_WIDTH:
        props["mem_bus_width"] = MEM_BUS_WIDTH[key]
    else:
        props["mem_bus_width"] = 128
        print(f' :: Using default mem_bus_width=128 for {name}')
    return props
|
||||
triton.runtime.driver.active.utils.get_device_properties = patched_props
|
||||
|
||||
Loading…
Reference in New Issue
Block a user