device identification and setting triton arch override

This commit is contained in:
patientx 2025-08-04 10:44:18 +03:00 committed by GitHub
parent d823c0c615
commit 37415c40c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -35,6 +35,134 @@ from typing import Union, List
from enum import Enum
# ------------------- main imports -------------------
# ------------------- gfx detection -------------------
import os
import re
def detect_amd_gpu_architecture():
"""
Detect AMD GPU architecture on Windows and return the appropriate gfx code for TRITON_OVERRIDE_ARCH
"""
try:
# Method 1: Try Windows registry
try:
import winreg
key_path = r"SYSTEM\CurrentControlSet\Control\Class\{4d36e968-e325-11ce-bfc1-08002be10318}"
with winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, key_path) as key:
i = 0
while True:
try:
subkey_name = winreg.EnumKey(key, i)
with winreg.OpenKey(key, subkey_name) as subkey:
try:
desc = winreg.QueryValueEx(subkey, "DriverDesc")[0]
if "AMD" in desc or "Radeon" in desc:
print(f" :: Detected GPU via Windows registry: {desc}")
return gpu_name_to_gfx(desc)
except FileNotFoundError:
pass
i += 1
except OSError:
break
except ImportError:
pass
# Method 2: Try WMIC command
try:
import subprocess
result = subprocess.run(['wmic', 'path', 'win32_VideoController', 'get', 'name'],
capture_output=True, text=True, timeout=10)
if result.returncode == 0:
for line in result.stdout.split('\n'):
line = line.strip()
if line and "AMD" in line or "Radeon" in line:
print(f" :: Detected GPU via WMIC: {line}")
return gpu_name_to_gfx(line)
except (FileNotFoundError, subprocess.TimeoutExpired):
pass
print(" :: Could not detect AMD GPU architecture automatically")
return None
except Exception as e:
print(f" :: GPU detection failed: {str(e)}")
return None
def gpu_name_to_gfx(gpu_name):
"""
Map GPU names to their corresponding gfx architecture codes
"""
gpu_name_lower = gpu_name.lower()
# RDNA3 (gfx11xx)
if any(x in gpu_name_lower for x in ['rx 7900', 'rx 7800', 'rx 7700', 'rx 7600', 'rx 7500']):
if 'rx 7900' in gpu_name_lower:
return 'gfx1100' # Navi 31
elif 'rx 7800' in gpu_name_lower or 'rx 7700' in gpu_name_lower:
return 'gfx1101' # Navi 32
elif 'rx 7600' in gpu_name_lower or 'rx 7500' in gpu_name_lower:
return 'gfx1102' # Navi 33
# RDNA2 (gfx10xx)
elif any(x in gpu_name_lower for x in ['rx 6950', 'rx 6900', 'rx 6800', 'rx 6750', 'rx 6700']):
return 'gfx1030' # Navi 21/22
elif any(x in gpu_name_lower for x in ['rx 6650', 'rx 6600', 'rx 6500', 'rx 6400']):
return 'gfx1032' # Navi 23/24
# RDNA1 (gfx10xx)
elif any(x in gpu_name_lower for x in ['rx 5700', 'rx 5600', 'rx 5500']):
return 'gfx1010' # Navi 10
# Vega (gfx9xx)
elif any(x in gpu_name_lower for x in ['vega 64', 'vega 56', 'vega 20', 'radeon vii']):
return 'gfx900' # Vega 10/20
elif 'vega 11' in gpu_name_lower or 'vega 8' in gpu_name_lower:
return 'gfx902' # Raven Ridge APU
# Polaris (gfx8xx)
elif any(x in gpu_name_lower for x in ['rx 580', 'rx 570', 'rx 480', 'rx 470']):
return 'gfx803' # Polaris 10/20
elif any(x in gpu_name_lower for x in ['rx 560', 'rx 550', 'rx 460']):
return 'gfx803' # Polaris 11/12
# Default fallback - try to extract numbers and make educated guess
if 'rx 9' in gpu_name_lower: # Future RDNA4?
return 'gfx1200' # Anticipated next gen
elif 'rx 8' in gpu_name_lower: # Future RDNA4?
return 'gfx1150' # Anticipated next gen
elif 'rx 7' in gpu_name_lower:
return 'gfx1100' # Default RDNA3
elif 'rx 6' in gpu_name_lower:
return 'gfx1030' # Default RDNA2
elif 'rx 5' in gpu_name_lower:
return 'gfx1010' # Default RDNA1
print(f" :: Unknown GPU model: {gpu_name}, using default gfx1030")
return 'gfx1030' # Safe default for most modern AMD GPUs
def set_triton_arch_override():
"""
Automatically detect and set TRITON_OVERRIDE_ARCH environment variable
"""
# Check if already set by user
if 'TRITON_OVERRIDE_ARCH' in os.environ:
print(f" :: TRITON_OVERRIDE_ARCH already set to: {os.environ['TRITON_OVERRIDE_ARCH']}")
return
print(" :: Auto-detecting AMD GPU architecture for Triton...")
gfx_arch = detect_amd_gpu_architecture()
if gfx_arch:
os.environ['TRITON_OVERRIDE_ARCH'] = gfx_arch
print(f" :: Set TRITON_OVERRIDE_ARCH={gfx_arch}")
else:
# Fallback to a common architecture
fallback_arch = 'gfx1030'
os.environ['TRITON_OVERRIDE_ARCH'] = fallback_arch
print(f" :: Using fallback TRITON_OVERRIDE_ARCH={fallback_arch}")
print(" :: If Triton fails, you may need to manually set TRITON_OVERRIDE_ARCH in your environment")
# ------------------- gfx detection -------------------
# ------------------- ComfyUI Package Version Check -------------------
def get_package_version(package_name):
try:
@ -288,8 +416,16 @@ for package_name in packages_to_monitor:
print(" :: Package version check complete.")
# ------------------- End Version Check -------------------
# ------------------- Triton Setup -------------------
print("\n :: ------------------------ ZLUDA ----------------------- :: ")
# identify device and set triton arch override
zluda_device_name = torch.cuda.get_device_name() if torch.cuda.is_available() else ""
is_zluda = zluda_device_name.endswith("[ZLUDA]")
if is_zluda:
set_triton_arch_override()
try:
import triton
import triton.language as tl
@ -333,11 +469,6 @@ except Exception as e:
triton_available = False
# ------------------- End Triton Verification -------------------
# ------------------- ZLUDA Detection -------------------
zluda_device_name = torch.cuda.get_device_name() if torch.cuda.is_available() else ""
is_zluda = zluda_device_name.endswith("[ZLUDA]")
# ------------------- End Detection --------------------
# # ------------------- ZLUDA Core Implementation -------------------
MEM_BUS_WIDTH = {
"AMD Radeon RX 9070 XT": 256,