commit a68686be3f
Stuart Axon 2026-01-05 00:38:09 +08:00 (committed by GitHub)


@@ -40,6 +40,7 @@ class CPUState(Enum):
GPU = 0
CPU = 1
MPS = 2
OCL = 3
# Determine VRAM State
vram_state = VRAMState.NORMAL_VRAM
@@ -102,6 +103,14 @@ if args.directml is not None:
# torch_directml.disable_tiled_resources(True)
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
ocl_available = False
try:
import pytorch_ocl
import torch.ocl
ocl_available = True
except ImportError:
pass
try:
import intel_extension_for_pytorch as ipex # noqa: F401
except:
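
A note on the detection block above (not part of the commit): a successful import of pytorch_ocl only proves the package is installed, not that an OpenCL runtime and device are actually usable. A minimal sketch of a stricter probe, assuming the backend accepts eager allocation on the "ocl:0" device string this diff uses later:

ocl_available = False
try:
    import torch
    import pytorch_ocl  # noqa: F401  # OpenCL backend for PyTorch
    # A tiny allocation surfaces a missing or broken OpenCL runtime here,
    # rather than at model-load time.
    torch.zeros(1, device="ocl:0")
    ocl_available = True
except Exception:
    ocl_available = False
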
@@ -139,6 +148,10 @@ try:
except:
ixuca_available = False
if ocl_available:
# TODO gate behind flag.
cpu_state = CPUState.OCL
if args.cpu:
cpu_state = CPUState.CPU
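
The TODO above asks for the OCL path to be gated behind a flag, so that OpenCL is opt-in instead of overriding CUDA/XPU whenever pytorch_ocl happens to be installed. A hypothetical shape for that gate (the args.use_ocl flag name is an assumption, not something that exists in the codebase):

# Hypothetical: only select the OpenCL backend when explicitly requested.
# "args.use_ocl" is an assumed flag name, not part of this commit.
if ocl_available and getattr(args, "use_ocl", False):
    cpu_state = CPUState.OCL
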
@@ -168,6 +181,12 @@ def is_ixuca():
return True
return False
def is_ocl():
global ocl_available
if ocl_available:
return True
return False
def get_torch_device():
global directml_enabled
global cpu_state
@@ -178,6 +197,8 @@ def get_torch_device():
return torch.device("mps")
if cpu_state == CPUState.CPU:
return torch.device("cpu")
if cpu_state == CPUState.OCL:
return torch.device("ocl:0")
else:
if is_intel_xpu():
return torch.device("xpu", torch.xpu.current_device())
@@ -193,7 +214,7 @@ def get_total_memory(dev=None, torch_total_too=False):
if dev is None:
dev = get_torch_device()
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps' or dev.type == 'ocl'):
mem_total = psutil.virtual_memory().total
mem_total_torch = mem_total
else:
@@ -218,6 +239,9 @@ def get_total_memory(dev=None, torch_total_too=False):
_, mem_total_mlu = torch.mlu.mem_get_info(dev)
mem_total_torch = mem_reserved
mem_total = mem_total_mlu
elif is_ocl():
mem_total = 1024 * 1024 * 1024 #TODO
mem_total_torch = mem_total
else:
stats = torch.cuda.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
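
The OCL branch above hardcodes 1 GiB as a placeholder. Total device memory is queryable in OpenCL (CL_DEVICE_GLOBAL_MEM_SIZE), so one way to fill in the TODO, sketched under the assumptions that pyopencl is available and that the n-th GPU reported by pyopencl matches the device pytorch_ocl opened as "ocl:n":

import pyopencl as cl

def _ocl_total_memory_bytes(index=0):
    # CL_DEVICE_GLOBAL_MEM_SIZE via pyopencl; falls back to the commit's 1 GiB.
    gpus = []
    for platform in cl.get_platforms():
        try:
            gpus.extend(platform.get_devices(device_type=cl.device_type.GPU))
        except cl.Error:
            pass  # platform with no GPU devices
    if index < len(gpus):
        return gpus[index].global_mem_size
    return 1024 * 1024 * 1024
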
@@ -1280,7 +1304,7 @@ def get_free_memory(dev=None, torch_free_too=False):
if dev is None:
dev = get_torch_device()
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps' or dev.type == 'ocl'):
mem_free_total = psutil.virtual_memory().available
mem_free_torch = mem_free_total
else:
@@ -1308,6 +1332,15 @@ def get_free_memory(dev=None, torch_free_too=False):
mem_free_mlu, _ = torch.mlu.mem_get_info(dev)
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_mlu + mem_free_torch
elif is_ocl():
# stats = torch.ocl.memory_stats(dev)
# mem_active = stats['active_bytes.all.current']
# mem_reserved = stats['reserved_bytes.all.current']
# mem_free_ocl, _ = torch.ocl.mem_get_info(dev)
# mem_free_torch = mem_reserved - mem_active
# mem_free_total = mem_free_mlu + mem_free_torch
mem_free_total = 1024 * 1024 * 1024 #TODO
mem_free_torch = mem_free_total
else:
stats = torch.cuda.memory_stats(dev)
mem_active = stats['active_bytes.all.current']
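
Free memory is the harder half: core OpenCL has no portable "how much is free" query (only vendor extensions such as cl_amd_device_attribute_query expose it), which is presumably why the block above is left commented out. Until pytorch_ocl offers a mem_get_info-style call, a conservative stand-in is a fixed fraction of total device memory; the 0.8 factor below is an arbitrary assumption, not a measured value:

def _ocl_free_memory_bytes(total_bytes, fraction=0.8):
    # OpenCL cannot report actual free memory, so treat a fixed share of the
    # device as usable; this only bounds allocations heuristically.
    return int(total_bytes * fraction)
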
@@ -1386,6 +1419,10 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if is_mlu():
return True
if is_ocl():
# TODO ? RustiCL now supports fp16 at least.
return True
if is_ixuca():
return True
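
Instead of returning True unconditionally, fp16 support can be read off the device's extension string (cl_khr_fp16), which recent Mesa/RustiCL advertises as the comment notes. A sketch assuming pyopencl is available and the queried device corresponds to the one pytorch_ocl selected:

import pyopencl as cl

def _ocl_supports_fp16(index=0):
    # cl_khr_fp16 in CL_DEVICE_EXTENSIONS indicates native half-precision support.
    gpus = []
    for platform in cl.get_platforms():
        try:
            gpus.extend(platform.get_devices(device_type=cl.device_type.GPU))
        except cl.Error:
            pass
    if index < len(gpus):
        return "cl_khr_fp16" in gpus[index].extensions
    return False
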
@@ -1462,6 +1499,10 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return True
return False
if is_ocl():
# TODO
return True
props = torch.cuda.get_device_properties(device)
if is_mlu():
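
bf16 is murkier than fp16: there is no core OpenCL bfloat16 type, only vendor extensions (for example Intel's cl_intel_bfloat16_conversions, which covers conversions rather than full bf16 math), so a cautious way to resolve the TODO above is to report bf16 as unsupported unless the device's extension string says otherwise. The extension name below is illustrative, not exhaustive:

def _ocl_supports_bf16(extensions: str) -> bool:
    # Conservative: only claim bf16 when a known extension is advertised.
    # cl_intel_bfloat16_conversions covers conversions only, not full bf16 compute.
    return "cl_intel_bfloat16_conversions" in extensions
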