diff --git a/comfy/customzluda/zluda.py b/comfy/customzluda/zluda.py index faaa9695c..0dcd85613 100644 --- a/comfy/customzluda/zluda.py +++ b/comfy/customzluda/zluda.py @@ -66,11 +66,11 @@ def detect_amd_gpu_architecture(): break except ImportError: pass - + # Method 2: Try WMIC command try: import subprocess - result = subprocess.run(['wmic', 'path', 'win32_VideoController', 'get', 'name'], + result = subprocess.run(['wmic', 'path', 'win32_VideoController', 'get', 'name'], capture_output=True, text=True, timeout=10) if result.returncode == 0: for line in result.stdout.split('\n'): @@ -80,10 +80,10 @@ def detect_amd_gpu_architecture(): return gpu_name_to_gfx(line) except (FileNotFoundError, subprocess.TimeoutExpired): pass - + print(" :: Could not detect AMD GPU architecture automatically") return None - + except Exception as e: print(f" :: GPU detection failed: {str(e)}") return None @@ -93,7 +93,7 @@ def gpu_name_to_gfx(gpu_name): Map GPU names to their corresponding gfx architecture codes """ gpu_name_lower = gpu_name.lower() - + # RDNA3 (gfx11xx) if any(x in gpu_name_lower for x in ['rx 7900', 'rx 7800', 'rx 7700', 'rx 7600', 'rx 7500']): if 'rx 7900' in gpu_name_lower: @@ -102,29 +102,29 @@ def gpu_name_to_gfx(gpu_name): return 'gfx1101' # Navi 32 elif 'rx 7600' in gpu_name_lower or 'rx 7500' in gpu_name_lower: return 'gfx1102' # Navi 33 - + # RDNA2 (gfx10xx) elif any(x in gpu_name_lower for x in ['rx 6950', 'rx 6900', 'rx 6800', 'rx 6750', 'rx 6700']): return 'gfx1030' # Navi 21/22 elif any(x in gpu_name_lower for x in ['rx 6650', 'rx 6600', 'rx 6500', 'rx 6400']): return 'gfx1032' # Navi 23/24 - + # RDNA1 (gfx10xx) elif any(x in gpu_name_lower for x in ['rx 5700', 'rx 5600', 'rx 5500']): return 'gfx1010' # Navi 10 - + # Vega (gfx9xx) elif any(x in gpu_name_lower for x in ['vega 64', 'vega 56', 'vega 20', 'radeon vii']): return 'gfx900' # Vega 10/20 elif 'vega 11' in gpu_name_lower or 'vega 8' in gpu_name_lower: return 'gfx902' # Raven Ridge APU - + # Polaris (gfx8xx) elif any(x in gpu_name_lower for x in ['rx 580', 'rx 570', 'rx 480', 'rx 470']): return 'gfx803' # Polaris 10/20 elif any(x in gpu_name_lower for x in ['rx 560', 'rx 550', 'rx 460']): return 'gfx803' # Polaris 11/12 - + # Default fallback - try to extract numbers and make educated guess if 'rx 9' in gpu_name_lower: # Future RDNA4? return 'gfx1200' # Anticipated next gen @@ -136,7 +136,7 @@ def gpu_name_to_gfx(gpu_name): return 'gfx1030' # Default RDNA2 elif 'rx 5' in gpu_name_lower: return 'gfx1010' # Default RDNA1 - + print(f" :: Unknown GPU model: {gpu_name}, using default gfx1030") return 'gfx1030' # Safe default for most modern AMD GPUs @@ -148,10 +148,10 @@ def set_triton_arch_override(): if 'TRITON_OVERRIDE_ARCH' in os.environ: print(f" :: TRITON_OVERRIDE_ARCH already set to: {os.environ['TRITON_OVERRIDE_ARCH']}") return - + print(" :: Auto-detecting AMD GPU architecture for Triton...") gfx_arch = detect_amd_gpu_architecture() - + if gfx_arch: os.environ['TRITON_OVERRIDE_ARCH'] = gfx_arch print(f" :: Set TRITON_OVERRIDE_ARCH={gfx_arch}") @@ -201,7 +201,7 @@ def is_compatible_version(installed_version, required_version, operator='>='): from packaging import version installed_v = version.parse(installed_version) required_v = version.parse(required_version) - + if operator == '>=': return installed_v >= required_v elif operator == '==': @@ -212,11 +212,11 @@ def is_compatible_version(installed_version, required_version, operator='>='): required_parts = required_v.release if len(required_parts) == 1: # ~=2 means >=2.0, <3.0 - return (installed_v >= required_v and + return (installed_v >= required_v and installed_v.release[0] == required_parts[0]) else: # ~=2.1 means >=2.1, <2.2 - return (installed_v >= required_v and + return (installed_v >= required_v and installed_v.release[:len(required_parts)-1] == required_parts[:-1] and installed_v.release[len(required_parts)-1] >= required_parts[-1]) else: @@ -255,86 +255,86 @@ def handle_pydantic_packages(required_packages): """Special handling for pydantic packages to ensure compatibility""" import subprocess import sys - + pydantic_packages = ['pydantic', 'pydantic-settings'] packages_in_requirements = [pkg for pkg in pydantic_packages if pkg in required_packages] - + if not packages_in_requirements: return # No pydantic packages to handle - + # Check if both packages are available and what versions pydantic_installed = None pydantic_settings_installed = None - + try: pydantic_installed = get_package_version('pydantic') except: pass - + try: pydantic_settings_installed = get_package_version('pydantic-settings') except: pass - + # If both are installed, check compatibility if pydantic_installed and pydantic_settings_installed: print(f"Found pydantic: {pydantic_installed}, pydantic-settings: {pydantic_settings_installed}") - + # Check if they're compatible by testing the import if not check_pydantic_compatibility(): print(" :: Pydantic packages are compatible, skipping reinstall") return else: print(" :: Pydantic packages are incompatible, need to reinstall") - + # If we get here, we need to install/reinstall pydantic packages print(" :: Setting up pydantic packages for compatibility...") - + # Uninstall existing versions to avoid conflicts if pydantic_installed: print(f" :: Uninstalling existing pydantic {pydantic_installed}") uninstall_package('pydantic') - + if pydantic_settings_installed: print(f" :: Uninstalling existing pydantic-settings {pydantic_settings_installed}") uninstall_package('pydantic-settings') - + # Install both packages together try: print(" :: Installing compatible pydantic packages...") - combined_args = [sys.executable, '-m', 'pip', 'install', - 'pydantic~=2.0', + combined_args = [sys.executable, '-m', 'pip', 'install', + 'pydantic~=2.0', 'pydantic-settings~=2.0', - '--quiet', + '--quiet', '--disable-pip-version-check'] - + subprocess.check_call(combined_args) - + # Verify installation new_pydantic = get_package_version('pydantic') new_pydantic_settings = get_package_version('pydantic-settings') print(f" :: Successfully installed pydantic: {new_pydantic}, pydantic-settings: {new_pydantic_settings}") - + except subprocess.CalledProcessError as e: print(f" :: Failed to install pydantic packages: {e}") def install_package(package_name, version_spec, upgrade=False): import subprocess import sys - + # For ~= operator, install with the compatible release syntax if '~=' in version_spec: package_spec = f'{package_name}~={version_spec}' else: package_spec = f'{package_name}=={version_spec}' - - args = [sys.executable, '-m', 'pip', 'install', - package_spec, - '--quiet', + + args = [sys.executable, '-m', 'pip', 'install', + package_spec, + '--quiet', '--disable-pip-version-check'] if upgrade: args.append('--upgrade') - + try: subprocess.check_call(args) except subprocess.CalledProcessError as e: @@ -343,10 +343,10 @@ def install_package(package_name, version_spec, upgrade=False): if upgrade and '~=' in package_spec: try: print(f" :: Retrying {package_name} installation without version constraint...") - fallback_args = [sys.executable, '-m', 'pip', 'install', - package_name, + fallback_args = [sys.executable, '-m', 'pip', 'install', + package_name, '--upgrade', - '--quiet', + '--quiet', '--disable-pip-version-check'] subprocess.check_call(fallback_args) print(f" :: {package_name} installed successfully without version constraint") @@ -357,11 +357,11 @@ def ensure_package(package_name, required_version, operator='>='): # Skip individual pydantic package handling - they're handled together if package_name in ['pydantic', 'pydantic-settings']: return - + try: installed_version = get_package_version(package_name) print(f"Installed version of {package_name}: {installed_version}") - + if not is_compatible_version(installed_version, required_version, operator): install_package(package_name, required_version, upgrade=True) print(f"\n{package_name} outdated. Upgraded to {required_version}.") @@ -430,7 +430,7 @@ try: import triton import triton.language as tl print(" :: Triton core imported successfully") - + @triton.jit def _zluda_kernel_test(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(axis=0) @@ -438,7 +438,7 @@ try: mask = offsets < n_elements x = tl.load(x_ptr + offsets, mask=mask) tl.store(y_ptr + offsets, x + 1, mask=mask) - + def _verify_triton() -> bool: try: print(" :: Running Triton kernel test...") @@ -453,7 +453,7 @@ try: except Exception as e: print(f" :: Triton test failed: {str(e)}") return False - + triton_available = _verify_triton() if triton_available: print(" :: Triton initialized successfully") @@ -519,7 +519,7 @@ MEM_BUS_WIDTH = { "AMD Radeon PRO W6600": 128, "AMD Radeon PRO W6400": 64, "AMD Radeon PRO W5700": 256, - "AMD Radeon PRO W5500": 128, + "AMD Radeon PRO W5500": 128, } # ------------------- Device Properties Implementation ------------------- @@ -554,7 +554,7 @@ class DeviceProperties: # torch.stft = z_stft # torch.istft = z_istft # torch.jit.script = z_jit -# # ------------------- End Audio Patch ------------------- +# # ------------------- End Audio Patch ------------------- # ------------------- Top-K Fallback Patch ------------------- if is_zluda: @@ -608,7 +608,7 @@ def do_hijack(): return print(f" :: Using ZLUDA with device: {zluda_device_name}") print(" :: Applying core ZLUDA patches...") - + # 2. Triton optimizations if triton_available: print(" :: Initializing Triton optimizations") @@ -631,13 +631,13 @@ def do_hijack(): try: from comfy.flash_attn_triton_amd import interface_fa print(" :: Flash attention components found") - + original_sdpa = torch.nn.functional.scaled_dot_product_attention - + def amd_flash_wrapper(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None): try: - if (query.shape[-1] <= 128 and - attn_mask is None and # fix flash-attention error : "Flash attention error: Boolean value of Tensor with more than one value is ambiguous" + if (query.shape[-1] <= 128 and + attn_mask is None and # fix flash-attention error : "Flash attention error: Boolean value of Tensor with more than one value is ambiguous" query.dtype != torch.float32): if scale is None: scale = query.shape[-1] ** -0.5 @@ -651,11 +651,11 @@ def do_hijack(): except Exception as e: print(f' :: Flash attention error: {str(e)}') return original_sdpa(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale) - + torch.nn.functional.scaled_dot_product_attention = amd_flash_wrapper flash_enabled = True print(" :: AMD flash attention enabled successfully") - + except ImportError: print(" :: Flash attention components not installed") except Exception as e: