From 33c43b68c3c996a6ae4c4fded91660c18cd662b0 Mon Sep 17 00:00:00 2001
From: Christopher Anderson
Date: Mon, 25 Aug 2025 09:38:22 +1000
Subject: [PATCH] worst PR ever

---
 comfy/customzluda/zluda.py | 176 +++++++++++++++++++++----------------
 1 file changed, 100 insertions(+), 76 deletions(-)

diff --git a/comfy/customzluda/zluda.py b/comfy/customzluda/zluda.py
index 013cdc058..97dc425bb 100644
--- a/comfy/customzluda/zluda.py
+++ b/comfy/customzluda/zluda.py
@@ -66,12 +66,12 @@ def detect_amd_gpu_architecture():
                     break
         except ImportError:
             pass
-    
+
         # Method 2: Try WMIC command
         try:
            import subprocess
-            result = subprocess.run(['wmic', 'path', 'win32_VideoController', 'get', 'name'], 
-                                  capture_output=True, text=True, timeout=10)
+            result = subprocess.run(['wmic', 'path', 'win32_VideoController', 'get', 'name'],
+                                  capture_output=True, text=True, timeout=10)
            if result.returncode == 0:
                for line in result.stdout.split('\n'):
                    line = line.strip()
@@ -80,10 +80,10 @@ def detect_amd_gpu_architecture():
                        return gpu_name_to_gfx(line)
        except (FileNotFoundError, subprocess.TimeoutExpired):
            pass
-    
+
        print(" :: Could not detect AMD GPU architecture automatically")
        return None
-    
+
    except Exception as e:
        print(f" :: GPU detection failed: {str(e)}")
        return None
@@ -93,7 +93,7 @@ def gpu_name_to_gfx(gpu_name):
     Map GPU names to their corresponding gfx architecture codes
     """
     gpu_name_lower = gpu_name.lower()
-    
+
     # RDNA3 (gfx11xx)
     if any(x in gpu_name_lower for x in ['rx 7900', 'rx 7800', 'rx 7700', 'rx 7600', 'rx 7500']):
         if 'rx 7900' in gpu_name_lower:
@@ -102,29 +102,29 @@ def gpu_name_to_gfx(gpu_name):
             return 'gfx1101' # Navi 32
         elif 'rx 7600' in gpu_name_lower or 'rx 7500' in gpu_name_lower:
             return 'gfx1102' # Navi 33
-    
+
     # RDNA2 (gfx10xx)
     elif any(x in gpu_name_lower for x in ['rx 6950', 'rx 6900', 'rx 6800', 'rx 6750', 'rx 6700']):
         return 'gfx1030' # Navi 21/22
     elif any(x in gpu_name_lower for x in ['rx 6650', 'rx 6600', 'rx 6500', 'rx 6400']):
         return 'gfx1032' # Navi 23/24
-    
+
     # RDNA1 (gfx10xx)
     elif any(x in gpu_name_lower for x in ['rx 5700', 'rx 5600', 'rx 5500']):
         return 'gfx1010' # Navi 10
-    
+
     # Vega (gfx9xx)
     elif any(x in gpu_name_lower for x in ['vega 64', 'vega 56', 'vega 20', 'radeon vii']):
         return 'gfx900' # Vega 10/20
     elif 'vega 11' in gpu_name_lower or 'vega 8' in gpu_name_lower:
         return 'gfx902' # Raven Ridge APU
-    
+
     # Polaris (gfx8xx)
     elif any(x in gpu_name_lower for x in ['rx 580', 'rx 570', 'rx 480', 'rx 470']):
         return 'gfx803' # Polaris 10/20
     elif any(x in gpu_name_lower for x in ['rx 560', 'rx 550', 'rx 460']):
         return 'gfx803' # Polaris 11/12
-    
+
     # Default fallback - try to extract numbers and make educated guess
     if 'rx 9' in gpu_name_lower: # Future RDNA4?
         return 'gfx1200' # Anticipated next gen
@@ -136,7 +136,7 @@ def gpu_name_to_gfx(gpu_name):
         return 'gfx1030' # Default RDNA2
     elif 'rx 5' in gpu_name_lower:
         return 'gfx1010' # Default RDNA1
-    
+
     print(f" :: Unknown GPU model: {gpu_name}, using default gfx1030")
     return 'gfx1030' # Safe default for most modern AMD GPUs
 
@@ -148,10 +148,10 @@ def set_triton_arch_override():
     if 'TRITON_OVERRIDE_ARCH' in os.environ:
         print(f" :: TRITON_OVERRIDE_ARCH already set to: {os.environ['TRITON_OVERRIDE_ARCH']}")
         return
-    
+
     print(" :: Auto-detecting AMD GPU architecture for Triton...")
     gfx_arch = detect_amd_gpu_architecture()
-    
+
     if gfx_arch:
         os.environ['TRITON_OVERRIDE_ARCH'] = gfx_arch
         print(f" :: Set TRITON_OVERRIDE_ARCH={gfx_arch}")
@@ -201,7 +201,7 @@ def is_compatible_version(installed_version, required_version, operator='>='):
     from packaging import version
     installed_v = version.parse(installed_version)
     required_v = version.parse(required_version)
-    
+
     if operator == '>=':
         return installed_v >= required_v
     elif operator == '==':
@@ -212,11 +212,11 @@
         required_parts = required_v.release
         if len(required_parts) == 1:
             # ~=2 means >=2.0, <3.0
-            return (installed_v >= required_v and 
+            return (installed_v >= required_v and
                     installed_v.release[0] == required_parts[0])
         else:
             # ~=2.1 means >=2.1, <2.2
-            return (installed_v >= required_v and 
+            return (installed_v >= required_v and
                     installed_v.release[:len(required_parts)-1] == required_parts[:-1] and
                     installed_v.release[len(required_parts)-1] >= required_parts[-1])
     else:
@@ -255,86 +255,86 @@
     """Special handling for pydantic packages to ensure compatibility"""
     import subprocess
     import sys
-    
+
     pydantic_packages = ['pydantic', 'pydantic-settings']
     packages_in_requirements = [pkg for pkg in pydantic_packages if pkg in required_packages]
-    
+
     if not packages_in_requirements:
         return # No pydantic packages to handle
-    
+
     # Check if both packages are available and what versions
     pydantic_installed = None
     pydantic_settings_installed = None
-    
+
     try:
         pydantic_installed = get_package_version('pydantic')
     except:
         pass
-    
+
     try:
         pydantic_settings_installed = get_package_version('pydantic-settings')
     except:
         pass
-    
+
     # If both are installed, check compatibility
     if pydantic_installed and pydantic_settings_installed:
         print(f"Found pydantic: {pydantic_installed}, pydantic-settings: {pydantic_settings_installed}")
-        
+
         # Check if they're compatible by testing the import
         if not check_pydantic_compatibility():
             print(" :: Pydantic packages are compatible, skipping reinstall")
             return
         else:
             print(" :: Pydantic packages are incompatible, need to reinstall")
-    
+
     # If we get here, we need to install/reinstall pydantic packages
     print(" :: Setting up pydantic packages for compatibility...")
-    
+
     # Uninstall existing versions to avoid conflicts
     if pydantic_installed:
         print(f" :: Uninstalling existing pydantic {pydantic_installed}")
         uninstall_package('pydantic')
-    
+
     if pydantic_settings_installed:
         print(f" :: Uninstalling existing pydantic-settings {pydantic_settings_installed}")
         uninstall_package('pydantic-settings')
-    
+
     # Install both packages together
     try:
         print(" :: Installing compatible pydantic packages...")
-        combined_args = [sys.executable, '-m', 'pip', 'install', 
-                        'pydantic~=2.0', 
-                        'pydantic-settings~=2.0', 
-                        '--quiet', 
-                        '--disable-pip-version-check']
-        
+        combined_args = [sys.executable, '-m', 'pip', 'install',
+                        'pydantic~=2.0',
+                        'pydantic-settings~=2.0',
+                        '--quiet',
+                        '--disable-pip-version-check']
+
         subprocess.check_call(combined_args)
-        
+
         # Verify installation
         new_pydantic = get_package_version('pydantic')
         new_pydantic_settings = get_package_version('pydantic-settings')
         print(f" :: Successfully installed pydantic: {new_pydantic}, pydantic-settings: {new_pydantic_settings}")
-        
+
     except subprocess.CalledProcessError as e:
         print(f" :: Failed to install pydantic packages: {e}")
 
 def install_package(package_name, version_spec, upgrade=False):
     import subprocess
     import sys
-    
+
     # For ~= operator, install with the compatible release syntax
     if '~=' in version_spec:
         package_spec = f'{package_name}~={version_spec}'
     else:
         package_spec = f'{package_name}=={version_spec}'
-    
-    args = [sys.executable, '-m', 'pip', 'install', 
-            package_spec, 
-            '--quiet', 
+
+    args = [sys.executable, '-m', 'pip', 'install',
+            package_spec,
+            '--quiet',
             '--disable-pip-version-check']
     if upgrade:
         args.append('--upgrade')
-    
+
     try:
         subprocess.check_call(args)
     except subprocess.CalledProcessError as e:
@@ -343,11 +343,11 @@ def install_package(package_name, version_spec, upgrade=False):
         if upgrade and '~=' in package_spec:
             try:
                 print(f" :: Retrying {package_name} installation without version constraint...")
-                fallback_args = [sys.executable, '-m', 'pip', 'install', 
-                                package_name, 
-                                '--upgrade', 
-                                '--quiet', 
-                                '--disable-pip-version-check']
+                fallback_args = [sys.executable, '-m', 'pip', 'install',
+                                package_name,
+                                '--upgrade',
+                                '--quiet',
+                                '--disable-pip-version-check']
                 subprocess.check_call(fallback_args)
                 print(f" :: {package_name} installed successfully without version constraint")
             except subprocess.CalledProcessError as e2:
@@ -357,11 +357,11 @@ def ensure_package(package_name, required_version, operator='>='):
     # Skip individual pydantic package handling - they're handled together
     if package_name in ['pydantic', 'pydantic-settings']:
         return
-    
+
     try:
         installed_version = get_package_version(package_name)
         print(f"Installed version of {package_name}: {installed_version}")
-        
+
         if not is_compatible_version(installed_version, required_version, operator):
             install_package(package_name, required_version, upgrade=True)
             print(f"\n{package_name} outdated. Upgraded to {required_version}.")
@@ -430,7 +430,7 @@ try:
     import triton
     import triton.language as tl
     print(" :: Triton core imported successfully")
-    
+
     @triton.jit
     def _zluda_kernel_test(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
         pid = tl.program_id(axis=0)
@@ -438,7 +438,7 @@ try:
         mask = offsets < n_elements
         x = tl.load(x_ptr + offsets, mask=mask)
         tl.store(y_ptr + offsets, x + 1, mask=mask)
-    
+
     def _verify_triton() -> bool:
         try:
             print(" :: Running Triton kernel test...")
@@ -453,7 +453,7 @@ try:
         except Exception as e:
             print(f" :: Triton test failed: {str(e)}")
             return False
-    
+
     triton_available = _verify_triton()
     if triton_available:
         print(" :: Triton initialized successfully")
@@ -473,7 +473,9 @@ except Exception as e:
 MEM_BUS_WIDTH = {
     "AMD Radeon RX 9070 XT": 256,
     "AMD Radeon RX 9070": 256,
-    "AMD Radeon RX 9060 XT": 192,
+    "AMD Radeon RX 9070 GRE": 192,
+    "AMD Radeon RX 9060 XT": 128,
+    "AMD Radeon RX 9060": 128,
     "AMD Radeon RX 7900 XTX": 384,
     "AMD Radeon RX 7900 XT": 320,
     "AMD Radeon RX 7900 GRE": 256,
@@ -483,12 +485,14 @@ MEM_BUS_WIDTH = {
     "AMD Radeon RX 7650 GRE": 128,
     "AMD Radeon RX 7600 XT": 128,
     "AMD Radeon RX 7600": 128,
-    "AMD Radeon RX 7500 XT": 96,
+    "AMD Radeon RX 7400": 128,
     "AMD Radeon RX 6950 XT": 256,
     "AMD Radeon RX 6900 XT": 256,
     "AMD Radeon RX 6800 XT": 256,
     "AMD Radeon RX 6800": 256,
     "AMD Radeon RX 6750 XT": 192,
+    "AMD Radeon RX 6750 GRE 12GB": 192,
+    "AMD Radeon RX 6750 GRE 10GB": 160,
     "AMD Radeon RX 6700 XT": 192,
     "AMD Radeon RX 6700": 160,
     "AMD Radeon RX 6650 XT": 128,
@@ -496,6 +500,26 @@
     "AMD Radeon RX 6600": 128,
     "AMD Radeon RX 6500 XT": 64,
     "AMD Radeon RX 6400": 64,
+    "AMD Radeon RX 5700 XT": 256,
+    "AMD Radeon RX 5700": 256,
+    "AMD Radeon RX 5600 XT": 192,
+    "AMD Radeon RX 5500 XT": 128,
+    "AMD Radeon RX 5500": 128,
+    "AMD Radeon RX 5300": 96,
+    # AMD Radeon PRO R9000/W7000/W6000/W5000 series; Apple-exclusive WX series not listed
+    "AMD Radeon AI PRO R9700": 256,
+    "AMD Radeon PRO W7900": 384,
+    "AMD Radeon PRO W7800 48GB": 384,
+    "AMD Radeon PRO W7800": 256,
+    "AMD Radeon PRO W7700": 256,
+    "AMD Radeon PRO W7600": 128,
+    "AMD Radeon PRO W7500": 128,
+    "AMD Radeon PRO W7400": 128,
+    "AMD Radeon PRO W6800": 256,
+    "AMD Radeon PRO W6600": 128,
+    "AMD Radeon PRO W6400": 64,
+    "AMD Radeon PRO W5700": 256,
+    "AMD Radeon PRO W5500": 128,
 }
 
 # ------------------- Device Properties Implementation -------------------
@@ -513,24 +537,24 @@ class DeviceProperties:
 # # ------------------- Audio Ops Patch -------------------
 # if is_zluda:
-    # _torch_stft = torch.stft
-    # _torch_istft = torch.istft
+# _torch_stft = torch.stft
+# _torch_istft = torch.istft
 
-    # def z_stft(input: torch.Tensor, window: torch.Tensor, *args, **kwargs):
-    #     return _torch_stft(input=input.cpu(), window=window.cpu(), *args, **kwargs).to(input.device)
+# def z_stft(input: torch.Tensor, window: torch.Tensor, *args, **kwargs):
+#     return _torch_stft(input=input.cpu(), window=window.cpu(), *args, **kwargs).to(input.device)
 
-    # def z_istft(input: torch.Tensor, window: torch.Tensor, *args, **kwargs):
-    #     return _torch_istft(input=input.cpu(), window=window.cpu(), *args, **kwargs).to(input.device)
+# def z_istft(input: torch.Tensor, window: torch.Tensor, *args, **kwargs):
+#     return _torch_istft(input=input.cpu(), window=window.cpu(), *args, **kwargs).to(input.device)
 
-    # def z_jit(f, *_, **__):
-    #     f.graph = torch._C.Graph()
-    #     return f
+# def z_jit(f, *_, **__):
+#     f.graph = torch._C.Graph()
+#     return f
 
-    # torch._dynamo.config.suppress_errors = True
-    # torch.stft = z_stft
-    # torch.istft = z_istft
-    # torch.jit.script = z_jit
-# # ------------------- End Audio Patch -------------------
+# torch._dynamo.config.suppress_errors = True
+# torch.stft = z_stft
+# torch.istft = z_istft
+# torch.jit.script = z_jit
+# # ------------------- End Audio Patch -------------------
 
 # ------------------- Top-K Fallback Patch -------------------
 if is_zluda:
@@ -584,7 +608,7 @@ def do_hijack():
         return
     print(f" :: Using ZLUDA with device: {zluda_device_name}")
     print(" :: Applying core ZLUDA patches...")
-    
+
     # 2. Triton optimizations
     if triton_available:
         print(" :: Initializing Triton optimizations")
@@ -607,14 +631,14 @@ def do_hijack():
     try:
         from comfy.flash_attn_triton_amd import interface_fa
         print(" :: Flash attention components found")
-        
+
         original_sdpa = torch.nn.functional.scaled_dot_product_attention
-        
+
         def amd_flash_wrapper(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
             try:
-                if (query.shape[-1] <= 128 and 
-                    attn_mask is None and # fix flash-attention error : "Flash attention error: Boolean value of Tensor with more than one value is ambiguous"
-                    query.dtype != torch.float32):
+                if (query.shape[-1] <= 128 and
+                    attn_mask is None and  # fix flash-attention error: "Boolean value of Tensor with more than one value is ambiguous"
+                    query.dtype != torch.float32):
                     if scale is None:
                         scale = query.shape[-1] ** -0.5
                     return interface_fa.fwd(
@@ -627,11 +651,11 @@ def do_hijack():
             except Exception as e:
                 print(f' :: Flash attention error: {str(e)}')
                 return original_sdpa(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
-        
+
         torch.nn.functional.scaled_dot_product_attention = amd_flash_wrapper
         flash_enabled = True
         print(" :: AMD flash attention enabled successfully")
-        
+
     except ImportError:
         print(" :: Flash attention components not installed")
     except Exception as e:
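
The name-to-architecture table in gpu_name_to_gfx is pure string matching, so it can be spot-checked without a GPU. A minimal sketch, assuming comfy.customzluda.zluda imports cleanly outside a running ComfyUI session (it imports torch at module scope, so that is not guaranteed); the expected values are read directly off the branches shown in the hunks above:

    # Spot-check of gpu_name_to_gfx; expected values taken from the patch itself.
    # Assumption: the module is importable outside ComfyUI.
    from comfy.customzluda.zluda import gpu_name_to_gfx

    expected = {
        "AMD Radeon RX 7900 XTX": "gfx1100",   # RDNA3, Navi 31
        "AMD Radeon RX 6800 XT": "gfx1030",    # RDNA2, Navi 21
        "AMD Radeon RX 5700": "gfx1010",       # RDNA1, Navi 10
        "Radeon RX 580": "gfx803",             # Polaris
        "Some Unknown Adapter": "gfx1030",     # fallback default
    }
    for name, gfx in expected.items():
        got = gpu_name_to_gfx(name)
        assert got == gfx, f"{name}: expected {gfx}, got {got}"
    print("gpu_name_to_gfx spot check passed")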
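
The '~=' branch of is_compatible_version hand-rolls PEP 440's compatible-release operator. Two things are worth noting: PEP 440 rejects a single-segment '~=2', which is why the code carries a special case for it, and the inline comment "~=2.1 means >=2.1, <2.2" actually describes '~=2.1.0'; the code itself computes the correct ">=2.1, <3.0". Since packaging is already imported there, its SpecifierSet gives a reference oracle for the intended behavior; a sketch (my example, standard packaging API only):

    # PEP 440 compatible-release semantics via packaging, the behavior the
    # '~=' branch of is_compatible_version reproduces by hand.
    from packaging.specifiers import SpecifierSet
    from packaging.version import Version

    assert Version("2.7.0") in SpecifierSet("~=2.1")        # ~=2.1 -> >=2.1, <3.0
    assert Version("1.9.0") not in SpecifierSet("~=2.1")
    assert Version("2.1.5") in SpecifierSet("~=2.1.0")      # ~=2.1.0 -> >=2.1.0, <2.2.0
    assert Version("2.2.0") not in SpecifierSet("~=2.1.0")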
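
The hunks around _zluda_kernel_test show the kernel, but the diff context elides most of _verify_triton, including the actual launch. For reference, a minimal launch of that kernel would look like the sketch below; the tensor size and BLOCK_SIZE are assumptions, not the patch's real test body:

    # Minimal launch sketch for _zluda_kernel_test; size and BLOCK_SIZE assumed.
    import torch
    import triton

    x = torch.zeros(1024, device="cuda")   # ZLUDA exposes the AMD GPU as "cuda"
    y = torch.empty_like(x)
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _zluda_kernel_test[grid](x, y, n_elements, BLOCK_SIZE=256)
    assert torch.allclose(y, x + 1)        # the kernel stores x + 1 into y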
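
The MEM_BUS_WIDTH values, including the new RX 9000, RX 5000 and Radeon PRO rows, are memory-interface widths in bits. The diff does not show the consumer of the table, but the usual derived quantity is peak memory bandwidth: bus width in bits divided by 8, times the effective transfer rate in Gbps. A worked example; the 20 Gbps GDDR6 rate is the RX 7900 XTX's published spec, not a value from this patch:

    # Peak-bandwidth arithmetic from a bus-width entry; illustration only.
    def peak_bandwidth_gbs(bus_width_bits: int, rate_gbps: float) -> float:
        return bus_width_bits / 8 * rate_gbps

    # 384-bit at 20 Gbps -> 960 GB/s, matching AMD's quoted RX 7900 XTX figure.
    assert peak_bandwidth_gbs(384, 20.0) == 960.0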
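
amd_flash_wrapper only takes the Triton flash path when the head dim is at most 128, attn_mask is None, and the dtype is not float32; everything else, including any exception, falls back to the original scaled_dot_product_attention. A quick fallback-path check after do_hijack() has run, assuming a working ZLUDA "cuda" device; the shapes are arbitrary:

    # Exercises the fallback branch of the patched sdpa: an explicit attn_mask
    # fails the guard above, so the original implementation handles the call.
    import torch
    import torch.nn.functional as F

    q = torch.randn(1, 8, 16, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(1, 8, 16, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(1, 8, 16, 64, device="cuda", dtype=torch.float16)
    mask = torch.zeros(1, 8, 16, 16, device="cuda", dtype=torch.float16)

    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    assert out.shape == v.shape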