worst PR ever

Christopher Anderson 2025-08-25 09:38:22 +10:00
parent 2a06dc8e87
commit 33c43b68c3


@@ -71,7 +71,7 @@ def detect_amd_gpu_architecture():
    try:
        import subprocess
        result = subprocess.run(['wmic', 'path', 'win32_VideoController', 'get', 'name'],
                                capture_output=True, text=True, timeout=10)
        if result.returncode == 0:
            for line in result.stdout.split('\n'):
                line = line.strip()
@@ -303,10 +303,10 @@ def handle_pydantic_packages(required_packages):
    try:
        print(" :: Installing compatible pydantic packages...")
        combined_args = [sys.executable, '-m', 'pip', 'install',
                         'pydantic~=2.0',
                         'pydantic-settings~=2.0',
                         '--quiet',
                         '--disable-pip-version-check']
        subprocess.check_call(combined_args)
@@ -344,10 +344,10 @@ def install_package(package_name, version_spec, upgrade=False):
    try:
        print(f" :: Retrying {package_name} installation without version constraint...")
        fallback_args = [sys.executable, '-m', 'pip', 'install',
                         package_name,
                         '--upgrade',
                         '--quiet',
                         '--disable-pip-version-check']
        subprocess.check_call(fallback_args)
        print(f" :: {package_name} installed successfully without version constraint")
    except subprocess.CalledProcessError as e2:
@@ -473,7 +473,9 @@ except Exception as e:
MEM_BUS_WIDTH = {
    "AMD Radeon RX 9070 XT": 256,
    "AMD Radeon RX 9070": 256,
-   "AMD Radeon RX 9060 XT": 192,
+   "AMD Radeon RX 9070 GRE": 192,
+   "AMD Radeon RX 9060 XT": 128,
+   "AMD Radeon RX 9060": 128,
    "AMD Radeon RX 7900 XTX": 384,
    "AMD Radeon RX 7900 XT": 320,
    "AMD Radeon RX 7900 GRE": 256,
@@ -483,12 +485,14 @@ MEM_BUS_WIDTH = {
    "AMD Radeon RX 7650 GRE": 128,
    "AMD Radeon RX 7600 XT": 128,
    "AMD Radeon RX 7600": 128,
-   "AMD Radeon RX 7500 XT": 96,
+   "AMD Radeon RX 7400": 128,
    "AMD Radeon RX 6950 XT": 256,
    "AMD Radeon RX 6900 XT": 256,
    "AMD Radeon RX 6800 XT": 256,
    "AMD Radeon RX 6800": 256,
    "AMD Radeon RX 6750 XT": 192,
+   "AMD Radeon RX 6750 GRE 12GB": 192,
+   "AMD Radeon RX 6750 GRE 10GB": 160,
    "AMD Radeon RX 6700 XT": 192,
    "AMD Radeon RX 6700": 160,
    "AMD Radeon RX 6650 XT": 128,
@@ -496,6 +500,26 @@ MEM_BUS_WIDTH = {
    "AMD Radeon RX 6600": 128,
    "AMD Radeon RX 6500 XT": 64,
    "AMD Radeon RX 6400": 64,
+   "AMD Radeon RX 5700 XT": 256,
+   "AMD Radeon RX 5700": 256,
+   "AMD Radeon RX 5600 XT": 192,
+   "AMD Radeon RX 5500 XT": 128,
+   "AMD Radeon RX 5500": 128,
+   "AMD Radeon RX 5300": 96,
+   # AMD Radeon Pro R9000/W7000/W6000/W5000 series; Apple-exclusive WX series not listed
+   "AMD Radeon AI PRO R9700": 256,
+   "AMD Radeon PRO W7900": 384,
+   "AMD Radeon PRO W7800 48GB": 384,
+   "AMD Radeon PRO W7800": 256,
+   "AMD Radeon PRO W7700": 256,
+   "AMD Radeon PRO W7600": 128,
+   "AMD Radeon PRO W7500": 128,
+   "AMD Radeon PRO W7400": 128,
+   "AMD Radeon PRO W6800": 256,
+   "AMD Radeon PRO W6600": 128,
+   "AMD Radeon PRO W6400": 64,
+   "AMD Radeon PRO W5700": 256,
+   "AMD Radeon PRO W5500": 128,
}

# ------------------- Device Properties Implementation -------------------
@@ -513,23 +537,23 @@ class DeviceProperties:
# # ------------------- Audio Ops Patch -------------------
# if is_zluda:
#     _torch_stft = torch.stft
#     _torch_istft = torch.istft
#     def z_stft(input: torch.Tensor, window: torch.Tensor, *args, **kwargs):
#         return _torch_stft(input=input.cpu(), window=window.cpu(), *args, **kwargs).to(input.device)
#     def z_istft(input: torch.Tensor, window: torch.Tensor, *args, **kwargs):
#         return _torch_istft(input=input.cpu(), window=window.cpu(), *args, **kwargs).to(input.device)
#     def z_jit(f, *_, **__):
#         f.graph = torch._C.Graph()
#         return f
#     torch._dynamo.config.suppress_errors = True
#     torch.stft = z_stft
#     torch.istft = z_istft
#     torch.jit.script = z_jit
# # ------------------- End Audio Patch -------------------

# ------------------- Top-K Fallback Patch -------------------
@@ -613,8 +637,8 @@ def do_hijack():
    def amd_flash_wrapper(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
        try:
            if (query.shape[-1] <= 128 and
                attn_mask is None and  # fix flash-attention error: "Boolean value of Tensor with more than one value is ambiguous"
                query.dtype != torch.float32):
                if scale is None:
                    scale = query.shape[-1] ** -0.5
                return interface_fa.fwd(