Mirror of https://github.com/comfyanonymous/ComfyUI.git
Commit a68686be3f: Merge 9d1617bb21 into acbf08cd60
@@ -40,6 +40,7 @@ class CPUState(Enum):
     GPU = 0
     CPU = 1
     MPS = 2
+    OCL = 3
 
 # Determine VRAM State
 vram_state = VRAMState.NORMAL_VRAM
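For orientation, a standalone sketch of the state enum after this hunk (VRAMState and the rest of the module are elided; the CPUState.GPU default is assumed from the surrounding module, which this diff does not show):

from enum import Enum

class CPUState(Enum):
    GPU = 0
    CPU = 1
    MPS = 2
    OCL = 3  # new in this commit: an OpenCL device state

cpu_state = CPUState.GPU  # assumed module default; the probes below may override it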
@@ -102,6 +103,14 @@ if args.directml is not None:
     # torch_directml.disable_tiled_resources(True)
     lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
 
+ocl_available = False
+try:
+    import pytorch_ocl
+    import torch.ocl
+    ocl_available = True
+except ImportError:
+    pass
+
 try:
     import intel_extension_for_pytorch as ipex # noqa: F401
 except:
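The guarded import above is the usual optional-backend probe, matching how the module already treats DirectML and IPEX. An equivalent presence check can be written without triggering the import, as in this minimal sketch; note the try/except form in the diff additionally catches errors raised while importing the module, which find_spec does not:

import importlib.util

def optional_backend_available(module_name: str) -> bool:
    # find_spec only checks that the module can be located; it does not run it
    return importlib.util.find_spec(module_name) is not None

ocl_available = optional_backend_available("pytorch_ocl")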
@@ -139,6 +148,10 @@ try:
 except:
     ixuca_available = False
 
+if ocl_available:
+    # TODO gate behind flag.
+    cpu_state = CPUState.OCL
+
 if args.cpu:
     cpu_state = CPUState.CPU
 
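As written, merely having pytorch_ocl installed switches the whole process to OCL; the `# TODO gate behind flag.` comment suggests the intended shape. A hypothetical sketch of such a gate (the --use-ocl flag name is invented here for illustration, not added by this commit):

import argparse
from enum import Enum

class CPUState(Enum):
    GPU = 0
    CPU = 1
    MPS = 2
    OCL = 3

ocl_available = True  # stand-in for the import probe above

parser = argparse.ArgumentParser()
# Hypothetical flag; the commit currently leaves OCL selection ungated (see TODO).
parser.add_argument("--use-ocl", action="store_true")
args = parser.parse_args(["--use-ocl"])

cpu_state = CPUState.GPU
if ocl_available and args.use_ocl:
    cpu_state = CPUState.OCL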
@@ -168,6 +181,12 @@ def is_ixuca():
         return True
     return False
 
+def is_ocl():
+    global ocl_available
+    if ocl_available:
+        return True
+    return False
+
 def get_torch_device():
     global directml_enabled
     global cpu_state
@@ -178,6 +197,8 @@ def get_torch_device():
         return torch.device("mps")
     if cpu_state == CPUState.CPU:
         return torch.device("cpu")
+    if cpu_state == CPUState.OCL:
+        return torch.device("ocl:0")
     else:
         if is_intel_xpu():
             return torch.device("xpu", torch.xpu.current_device())
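Note that the new check lands between the CPU check and the existing else:, so the else now binds to the OCL if rather than the CPU one; behavior is unchanged only because the CPU branch returns. Once pytorch_ocl is installed and registers the "ocl" device type, moving work to the device looks like this (a usage sketch, not part of the commit; device index 0 is hardcoded by the diff above):

import torch
import pytorch_ocl  # registers the "ocl" device type with torch

dev = torch.device("ocl:0")
x = torch.randn(4, 4)   # allocated on the CPU
y = (x @ x).to(dev)     # computed on CPU, result moved to the OpenCL device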
@@ -193,7 +214,7 @@ def get_total_memory(dev=None, torch_total_too=False):
     if dev is None:
         dev = get_torch_device()
 
-    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
+    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps' or dev.type == 'ocl'):
         mem_total = psutil.virtual_memory().total
         mem_total_torch = mem_total
     else:
@@ -218,6 +239,9 @@ def get_total_memory(dev=None, torch_total_too=False):
             _, mem_total_mlu = torch.mlu.mem_get_info(dev)
             mem_total_torch = mem_reserved
             mem_total = mem_total_mlu
+        elif is_ocl():
+            mem_total = 1024 * 1024 * 1024 #TODO
+            mem_total_torch = mem_total
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_reserved = stats['reserved_bytes.all.current']
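The 1 GiB value is an explicit placeholder. If a real query were wanted before torch.ocl grows one, OpenCL itself reports total global memory; a hypothetical sketch using pyopencl, a dependency this commit does not add:

import pyopencl as cl  # hypothetical extra dependency, for illustration only

def ocl_total_memory(platform_index: int = 0, device_index: int = 0) -> int:
    dev = cl.get_platforms()[platform_index].get_devices()[device_index]
    return dev.global_mem_size  # CL_DEVICE_GLOBAL_MEM_SIZE, in bytes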
@@ -1280,7 +1304,7 @@ def get_free_memory(dev=None, torch_free_too=False):
     if dev is None:
         dev = get_torch_device()
 
-    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
+    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps' or dev.type == 'ocl'):
         mem_free_total = psutil.virtual_memory().available
         mem_free_torch = mem_free_total
     else:
@@ -1308,6 +1332,15 @@ def get_free_memory(dev=None, torch_free_too=False):
             mem_free_mlu, _ = torch.mlu.mem_get_info(dev)
             mem_free_torch = mem_reserved - mem_active
             mem_free_total = mem_free_mlu + mem_free_torch
+        elif is_ocl():
+            # stats = torch.ocl.memory_stats(dev)
+            # mem_active = stats['active_bytes.all.current']
+            # mem_reserved = stats['reserved_bytes.all.current']
+            # mem_free_ocl, _ = torch.ocl.mem_get_info(dev)
+            # mem_free_torch = mem_reserved - mem_active
+            # mem_free_total = mem_free_ocl + mem_free_torch
+            mem_free_total = 1024 * 1024 * 1024 #TODO
+            mem_free_torch = mem_free_total
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_active = stats['active_bytes.all.current']
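Core OpenCL has no portable free-memory query (only vendor extensions provide one), which is likely why both the commented-out torch.ocl calls and the live code fall back to a fixed 1 GiB. Also note that the dev.type == 'ocl' change earlier in this commit already routes actual OCL devices through the host-RAM path, so this elif only fires for other device types when pytorch_ocl is importable; whether that is intended is unclear from the diff. A hypothetical stopgap, invented here rather than taken from the commit, would estimate usable memory as a fraction of the total:

def ocl_free_memory_estimate(total_bytes: int, usable_fraction: float = 0.8) -> int:
    # usable_fraction is an invented tuning knob, not part of the commit
    return int(total_bytes * usable_fraction)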
@@ -1386,6 +1419,10 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
     if is_mlu():
         return True
 
+    if is_ocl():
+        # TODO ? RustiCL now supports fp16 at least.
+        return True
+
     if is_ixuca():
         return True
 
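The TODO refers to RustiCL, Mesa's OpenCL implementation, having gained fp16 support. At the OpenCL level, half precision is advertised through the cl_khr_fp16 extension; a hypothetical probe via pyopencl (again, not a dependency of this commit):

import pyopencl as cl  # hypothetical dependency, for illustration only

def ocl_device_supports_fp16(platform_index: int = 0, device_index: int = 0) -> bool:
    dev = cl.get_platforms()[platform_index].get_devices()[device_index]
    return "cl_khr_fp16" in dev.extensions  # CL_DEVICE_EXTENSIONS string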
@@ -1462,6 +1499,10 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
                 return True
             return False
 
+    if is_ocl():
+        # TODO
+        return True
+
     props = torch.cuda.get_device_properties(device)
 
     if is_mlu():
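Returning True unconditionally is marked TODO. Until the backend can report bf16 capability, a runtime smoke test is one possible stand-in; the helper below is invented here, not part of the commit:

import torch

def device_bf16_works(dev: torch.device) -> bool:
    # Try a tiny bf16 op on the device; any failure is treated as "no bf16".
    try:
        a = torch.ones(2, 2, dtype=torch.bfloat16, device=dev)
        return bool((a @ a).isfinite().all())
    except Exception:
        return False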