from typing import Union, List
from enum import Enum
# ------------------- main imports -------------------

# ------------------- gfx detection -------------------
import os
import re

# Registry class key under which display adapters are enumerated.
_DISPLAY_CLASS_KEY = (
    r"SYSTEM\CurrentControlSet\Control\Class"
    r"\{4d36e968-e325-11ce-bfc1-08002be10318}"
)

# Ordered (name fragments, gfx code) table. The first matching row wins, so a
# longer model number must come before any shorter one it contains
# (e.g. 'rx 5700' before 'rx 570', 'rx 5500' before 'rx 550').
_GFX_BY_MODEL = (
    # RDNA3 (gfx11xx)
    (('rx 7900',), 'gfx1100'),                          # Navi 31
    (('rx 7800', 'rx 7700'), 'gfx1101'),                # Navi 32
    (('rx 7600', 'rx 7500'), 'gfx1102'),                # Navi 33
    # RDNA2 (gfx10xx)
    # NOTE(review): LLVM lists Navi 22 (RX 6700/6750) as gfx1031 and Navi 24
    # (RX 6500/6400) as gfx1034; values kept as-is - confirm the coarser
    # mapping is intentional for Triton.
    (('rx 6950', 'rx 6900', 'rx 6800', 'rx 6750', 'rx 6700'), 'gfx1030'),  # Navi 21/22
    (('rx 6650', 'rx 6600', 'rx 6500', 'rx 6400'), 'gfx1032'),             # Navi 23/24
    # RDNA1 (gfx10xx)
    (('rx 5700', 'rx 5600', 'rx 5500'), 'gfx1010'),     # Navi 10
    # Vega (gfx9xx)
    # NOTE(review): LLVM lists Vega 20 / Radeon VII as gfx906; value kept
    # as-is - confirm.
    (('vega 64', 'vega 56', 'vega 20', 'radeon vii'), 'gfx900'),  # Vega 10/20
    (('vega 11', 'vega 8'), 'gfx902'),                  # Raven Ridge APU
    # Polaris (gfx8xx)
    (('rx 580', 'rx 570', 'rx 480', 'rx 470'), 'gfx803'),  # Polaris 10/20
    (('rx 560', 'rx 550', 'rx 460'), 'gfx803'),            # Polaris 11/12
)

# Three-digit RX 4xx/5xx/6xx models (e.g. RX 590, RX 640) are Polaris-era
# parts. They must be recognised before the series fallback below, whose
# 'rx 5' / 'rx 6' prefixes would otherwise label them as RDNA1 / RDNA2.
_THREE_DIGIT_RX = re.compile(r'rx [456]\d{2}(?!\d)')

# Series-level guesses for models missing from _GFX_BY_MODEL, checked in order.
# NOTE(review): LLVM lists Navi 48 (RX 9070) as gfx1201; 'rx 9' is kept at
# gfx1200 for compatibility - confirm.
_GFX_BY_SERIES = (
    ('rx 9', 'gfx1200'),
    ('rx 8', 'gfx1150'),
    ('rx 7', 'gfx1100'),  # Default RDNA3
    ('rx 6', 'gfx1030'),  # Default RDNA2
    ('rx 5', 'gfx1010'),  # Default RDNA1
)

_DEFAULT_GFX = 'gfx1030'


def _is_amd_name(text):
    """Return True when *text* is a string that mentions AMD or Radeon."""
    return isinstance(text, str) and ("AMD" in text or "Radeon" in text)


def _first_registry_gpu_name():
    """Return the first AMD/Radeon DriverDesc in the registry, or None.

    None is returned when winreg is unavailable (non-Windows), when the
    display class key cannot be opened, or when no adapter matches.
    """
    try:
        import winreg
    except ImportError:
        return None

    try:
        class_key = winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, _DISPLAY_CLASS_KEY)
    except OSError:
        return None

    with class_key:
        index = 0
        while True:
            try:
                subkey_name = winreg.EnumKey(class_key, index)
            except OSError:
                # EnumKey raises OSError once the index runs past the last subkey.
                return None
            index += 1
            try:
                with winreg.OpenKey(class_key, subkey_name) as subkey:
                    desc = winreg.QueryValueEx(subkey, "DriverDesc")[0]
            except OSError:
                # This subkey is unreadable or has no DriverDesc value; skip
                # it instead of ending the whole scan.
                continue
            if _is_amd_name(desc):
                return desc


def _first_wmic_gpu_name():
    """Return the first AMD/Radeon adapter name reported by WMIC, or None.

    None is returned when wmic cannot be launched, times out, exits with a
    non-zero status, or lists no matching adapter.
    """
    import subprocess

    try:
        result = subprocess.run(
            ['wmic', 'path', 'win32_VideoController', 'get', 'name'],
            capture_output=True, text=True, timeout=10,
        )
    except (OSError, subprocess.TimeoutExpired):
        return None
    if result.returncode != 0:
        return None

    for raw_line in result.stdout.splitlines():
        line = raw_line.strip()
        if _is_amd_name(line):
            return line
    return None


def detect_amd_gpu_architecture():
    """
    Detect AMD GPU architecture on Windows and return the appropriate gfx code for TRITON_OVERRIDE_ARCH

    The Windows registry is tried first, then the WMIC command; the first
    adapter whose name contains "AMD" or "Radeon" is mapped through
    gpu_name_to_gfx(). A failure of the registry probe no longer prevents
    the WMIC probe from running.

    Returns:
        The gfx code string, or None when no AMD adapter could be found or
        detection failed unexpectedly. This function never raises.
    """
    probes = (
        ("Windows registry", _first_registry_gpu_name),
        ("WMIC", _first_wmic_gpu_name),
    )
    try:
        for source, probe in probes:
            name = probe()
            if name is not None:
                print(f" :: Detected GPU via {source}: {name}")
                return gpu_name_to_gfx(name)

        print(" :: Could not detect AMD GPU architecture automatically")
        return None

    except Exception as e:
        # Detection is best-effort at startup; report the problem and let the
        # caller fall back rather than aborting the import.
        print(f" :: GPU detection failed: {str(e)}")
        return None


def gpu_name_to_gfx(gpu_name):
    """
    Map GPU names to their corresponding gfx architecture codes

    Matching is case-insensitive and substring based. Lookup order:
    exact model table, three-digit RX 4xx/5xx/6xx models (gfx803), series
    prefix ('rx 9' .. 'rx 5'), then the gfx1030 default.

    Args:
        gpu_name: Adapter name such as "AMD Radeon RX 7900 XTX". None or an
            empty string is treated as an unknown model.

    Returns:
        A gfx code string; never None.
    """
    gpu_name_lower = (gpu_name or "").lower()

    for fragments, gfx in _GFX_BY_MODEL:
        if any(fragment in gpu_name_lower for fragment in fragments):
            return gfx

    # Checked before the series prefixes so that e.g. "RX 590" is not taken
    # for an RX 5000-series (RDNA1) card.
    if _THREE_DIGIT_RX.search(gpu_name_lower):
        return 'gfx803'

    # Default fallback - make an educated guess from the series digit
    for prefix, gfx in _GFX_BY_SERIES:
        if prefix in gpu_name_lower:
            return gfx

    print(f" :: Unknown GPU model: {gpu_name}, using default gfx1030")
    return _DEFAULT_GFX  # Safe default for most modern AMD GPUs


def set_triton_arch_override():
    """
    Automatically detect and set TRITON_OVERRIDE_ARCH environment variable

    A value already present in the environment is left untouched. Otherwise
    the detected gfx code is written to os.environ, or gfx1030 when
    detection returns None.
    """
    # Check if already set by user
    if 'TRITON_OVERRIDE_ARCH' in os.environ:
        print(f" :: TRITON_OVERRIDE_ARCH already set to: {os.environ['TRITON_OVERRIDE_ARCH']}")
        return

    print(" :: Auto-detecting AMD GPU architecture for Triton...")
    gfx_arch = detect_amd_gpu_architecture()

    if gfx_arch:
        os.environ['TRITON_OVERRIDE_ARCH'] = gfx_arch
        print(f" :: Set TRITON_OVERRIDE_ARCH={gfx_arch}")
        return

    # Fallback to a common architecture
    fallback_arch = _DEFAULT_GFX
    os.environ['TRITON_OVERRIDE_ARCH'] = fallback_arch
    print(f" :: Using fallback TRITON_OVERRIDE_ARCH={fallback_arch}")
    print(" :: If Triton fails, you may need to manually set TRITON_OVERRIDE_ARCH in your environment")
# ------------------- gfx detection -------------------
triton_available = False # ------------------- End Triton Verification ------------------- -# ------------------- ZLUDA Detection ------------------- -zluda_device_name = torch.cuda.get_device_name() if torch.cuda.is_available() else "" -is_zluda = zluda_device_name.endswith("[ZLUDA]") -# ------------------- End Detection -------------------- - # # ------------------- ZLUDA Core Implementation ------------------- MEM_BUS_WIDTH = { "AMD Radeon RX 9070 XT": 256,