mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-25 05:40:15 +08:00
device identification and setting triton arch override
This commit is contained in:
parent
d823c0c615
commit
37415c40c1
@ -35,6 +35,134 @@ from typing import Union, List
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
# ------------------- main imports -------------------
|
# ------------------- main imports -------------------
|
||||||
|
|
||||||
|
# ------------------- gfx detection -------------------
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
|
def detect_amd_gpu_architecture():
|
||||||
|
"""
|
||||||
|
Detect AMD GPU architecture on Windows and return the appropriate gfx code for TRITON_OVERRIDE_ARCH
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Method 1: Try Windows registry
|
||||||
|
try:
|
||||||
|
import winreg
|
||||||
|
key_path = r"SYSTEM\CurrentControlSet\Control\Class\{4d36e968-e325-11ce-bfc1-08002be10318}"
|
||||||
|
with winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, key_path) as key:
|
||||||
|
i = 0
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
subkey_name = winreg.EnumKey(key, i)
|
||||||
|
with winreg.OpenKey(key, subkey_name) as subkey:
|
||||||
|
try:
|
||||||
|
desc = winreg.QueryValueEx(subkey, "DriverDesc")[0]
|
||||||
|
if "AMD" in desc or "Radeon" in desc:
|
||||||
|
print(f" :: Detected GPU via Windows registry: {desc}")
|
||||||
|
return gpu_name_to_gfx(desc)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
i += 1
|
||||||
|
except OSError:
|
||||||
|
break
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Method 2: Try WMIC command
|
||||||
|
try:
|
||||||
|
import subprocess
|
||||||
|
result = subprocess.run(['wmic', 'path', 'win32_VideoController', 'get', 'name'],
|
||||||
|
capture_output=True, text=True, timeout=10)
|
||||||
|
if result.returncode == 0:
|
||||||
|
for line in result.stdout.split('\n'):
|
||||||
|
line = line.strip()
|
||||||
|
if line and "AMD" in line or "Radeon" in line:
|
||||||
|
print(f" :: Detected GPU via WMIC: {line}")
|
||||||
|
return gpu_name_to_gfx(line)
|
||||||
|
except (FileNotFoundError, subprocess.TimeoutExpired):
|
||||||
|
pass
|
||||||
|
|
||||||
|
print(" :: Could not detect AMD GPU architecture automatically")
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f" :: GPU detection failed: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def gpu_name_to_gfx(gpu_name):
|
||||||
|
"""
|
||||||
|
Map GPU names to their corresponding gfx architecture codes
|
||||||
|
"""
|
||||||
|
gpu_name_lower = gpu_name.lower()
|
||||||
|
|
||||||
|
# RDNA3 (gfx11xx)
|
||||||
|
if any(x in gpu_name_lower for x in ['rx 7900', 'rx 7800', 'rx 7700', 'rx 7600', 'rx 7500']):
|
||||||
|
if 'rx 7900' in gpu_name_lower:
|
||||||
|
return 'gfx1100' # Navi 31
|
||||||
|
elif 'rx 7800' in gpu_name_lower or 'rx 7700' in gpu_name_lower:
|
||||||
|
return 'gfx1101' # Navi 32
|
||||||
|
elif 'rx 7600' in gpu_name_lower or 'rx 7500' in gpu_name_lower:
|
||||||
|
return 'gfx1102' # Navi 33
|
||||||
|
|
||||||
|
# RDNA2 (gfx10xx)
|
||||||
|
elif any(x in gpu_name_lower for x in ['rx 6950', 'rx 6900', 'rx 6800', 'rx 6750', 'rx 6700']):
|
||||||
|
return 'gfx1030' # Navi 21/22
|
||||||
|
elif any(x in gpu_name_lower for x in ['rx 6650', 'rx 6600', 'rx 6500', 'rx 6400']):
|
||||||
|
return 'gfx1032' # Navi 23/24
|
||||||
|
|
||||||
|
# RDNA1 (gfx10xx)
|
||||||
|
elif any(x in gpu_name_lower for x in ['rx 5700', 'rx 5600', 'rx 5500']):
|
||||||
|
return 'gfx1010' # Navi 10
|
||||||
|
|
||||||
|
# Vega (gfx9xx)
|
||||||
|
elif any(x in gpu_name_lower for x in ['vega 64', 'vega 56', 'vega 20', 'radeon vii']):
|
||||||
|
return 'gfx900' # Vega 10/20
|
||||||
|
elif 'vega 11' in gpu_name_lower or 'vega 8' in gpu_name_lower:
|
||||||
|
return 'gfx902' # Raven Ridge APU
|
||||||
|
|
||||||
|
# Polaris (gfx8xx)
|
||||||
|
elif any(x in gpu_name_lower for x in ['rx 580', 'rx 570', 'rx 480', 'rx 470']):
|
||||||
|
return 'gfx803' # Polaris 10/20
|
||||||
|
elif any(x in gpu_name_lower for x in ['rx 560', 'rx 550', 'rx 460']):
|
||||||
|
return 'gfx803' # Polaris 11/12
|
||||||
|
|
||||||
|
# Default fallback - try to extract numbers and make educated guess
|
||||||
|
if 'rx 9' in gpu_name_lower: # Future RDNA4?
|
||||||
|
return 'gfx1200' # Anticipated next gen
|
||||||
|
elif 'rx 8' in gpu_name_lower: # Future RDNA4?
|
||||||
|
return 'gfx1150' # Anticipated next gen
|
||||||
|
elif 'rx 7' in gpu_name_lower:
|
||||||
|
return 'gfx1100' # Default RDNA3
|
||||||
|
elif 'rx 6' in gpu_name_lower:
|
||||||
|
return 'gfx1030' # Default RDNA2
|
||||||
|
elif 'rx 5' in gpu_name_lower:
|
||||||
|
return 'gfx1010' # Default RDNA1
|
||||||
|
|
||||||
|
print(f" :: Unknown GPU model: {gpu_name}, using default gfx1030")
|
||||||
|
return 'gfx1030' # Safe default for most modern AMD GPUs
|
||||||
|
|
||||||
|
def set_triton_arch_override():
|
||||||
|
"""
|
||||||
|
Automatically detect and set TRITON_OVERRIDE_ARCH environment variable
|
||||||
|
"""
|
||||||
|
# Check if already set by user
|
||||||
|
if 'TRITON_OVERRIDE_ARCH' in os.environ:
|
||||||
|
print(f" :: TRITON_OVERRIDE_ARCH already set to: {os.environ['TRITON_OVERRIDE_ARCH']}")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(" :: Auto-detecting AMD GPU architecture for Triton...")
|
||||||
|
gfx_arch = detect_amd_gpu_architecture()
|
||||||
|
|
||||||
|
if gfx_arch:
|
||||||
|
os.environ['TRITON_OVERRIDE_ARCH'] = gfx_arch
|
||||||
|
print(f" :: Set TRITON_OVERRIDE_ARCH={gfx_arch}")
|
||||||
|
else:
|
||||||
|
# Fallback to a common architecture
|
||||||
|
fallback_arch = 'gfx1030'
|
||||||
|
os.environ['TRITON_OVERRIDE_ARCH'] = fallback_arch
|
||||||
|
print(f" :: Using fallback TRITON_OVERRIDE_ARCH={fallback_arch}")
|
||||||
|
print(" :: If Triton fails, you may need to manually set TRITON_OVERRIDE_ARCH in your environment")
|
||||||
|
# ------------------- gfx detection -------------------
|
||||||
|
|
||||||
# ------------------- ComfyUI Package Version Check -------------------
|
# ------------------- ComfyUI Package Version Check -------------------
|
||||||
def get_package_version(package_name):
|
def get_package_version(package_name):
|
||||||
try:
|
try:
|
||||||
@ -288,8 +416,16 @@ for package_name in packages_to_monitor:
|
|||||||
|
|
||||||
print(" :: Package version check complete.")
|
print(" :: Package version check complete.")
|
||||||
# ------------------- End Version Check -------------------
|
# ------------------- End Version Check -------------------
|
||||||
|
|
||||||
# ------------------- Triton Setup -------------------
|
# ------------------- Triton Setup -------------------
|
||||||
print("\n :: ------------------------ ZLUDA ----------------------- :: ")
|
print("\n :: ------------------------ ZLUDA ----------------------- :: ")
|
||||||
|
|
||||||
|
# identify device and set triton arch override
|
||||||
|
zluda_device_name = torch.cuda.get_device_name() if torch.cuda.is_available() else ""
|
||||||
|
is_zluda = zluda_device_name.endswith("[ZLUDA]")
|
||||||
|
if is_zluda:
|
||||||
|
set_triton_arch_override()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
@ -333,11 +469,6 @@ except Exception as e:
|
|||||||
triton_available = False
|
triton_available = False
|
||||||
# ------------------- End Triton Verification -------------------
|
# ------------------- End Triton Verification -------------------
|
||||||
|
|
||||||
# ------------------- ZLUDA Detection -------------------
|
|
||||||
zluda_device_name = torch.cuda.get_device_name() if torch.cuda.is_available() else ""
|
|
||||||
is_zluda = zluda_device_name.endswith("[ZLUDA]")
|
|
||||||
# ------------------- End Detection --------------------
|
|
||||||
|
|
||||||
# # ------------------- ZLUDA Core Implementation -------------------
|
# # ------------------- ZLUDA Core Implementation -------------------
|
||||||
MEM_BUS_WIDTH = {
|
MEM_BUS_WIDTH = {
|
||||||
"AMD Radeon RX 9070 XT": 256,
|
"AMD Radeon RX 9070 XT": 256,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user