Mirror of https://github.com/comfyanonymous/ComfyUI.git

Merge 9d1617bb21 into 35fa091340
commit b68a6cb5fa
comfy/model_management.py

@@ -39,6 +39,7 @@ class CPUState(Enum):
     GPU = 0
     CPU = 1
     MPS = 2
+    OCL = 3
 
 # Determine VRAM State
 vram_state = VRAMState.NORMAL_VRAM
@@ -101,6 +102,14 @@ if args.directml is not None:
     # torch_directml.disable_tiled_resources(True)
     lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
 
+ocl_available = False
+try:
+    import pytorch_ocl
+    import torch.ocl
+    ocl_available = True
+except ImportError:
+    pass
+
 try:
     import intel_extension_for_pytorch as ipex # noqa: F401
 except:
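For context, a minimal standalone sketch of how this guarded import is meant to be exercised. It assumes pytorch_ocl registers an "ocl" device type with PyTorch, as the rest of the diff implies:

    # Minimal sketch (assumption: pytorch_ocl registers the "ocl" device type).
    import torch

    try:
        import pytorch_ocl  # out-of-tree OpenCL backend
        ocl_available = True
    except ImportError:
        ocl_available = False

    if ocl_available:
        x = torch.randn(4, 4).to("ocl:0")  # round-trip a tensor through the backend
        print((x @ x).cpu())
    else:
        print("pytorch_ocl not installed; OCL path disabled")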
@@ -138,6 +147,10 @@ try:
 except:
     ixuca_available = False
 
+if ocl_available:
+    # TODO: gate this behind a CLI flag.
+    cpu_state = CPUState.OCL
+
 if args.cpu:
     cpu_state = CPUState.CPU
 
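The "gate behind a CLI flag" TODO could follow the existing --cpu switch. A hypothetical sketch; the --ocl flag is an assumption, not part of this commit:

    import argparse

    # Hypothetical --ocl flag (assumption), styled after the existing --cpu flag.
    parser = argparse.ArgumentParser()
    parser.add_argument("--ocl", action="store_true",
                        help="run on an OpenCL device via pytorch_ocl")
    args = parser.parse_args(["--ocl"])

    # With the flag in place, the assignment above would become conditional:
    # if ocl_available and args.ocl:
    #     cpu_state = CPUState.OCL
    print(args.ocl)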
@@ -167,6 +180,12 @@ def is_ixuca():
         return True
     return False
 
+def is_ocl():
+    global ocl_available
+    if ocl_available:
+        return True
+    return False
+
 def get_torch_device():
     global directml_enabled
     global cpu_state
@@ -177,6 +196,8 @@ def get_torch_device():
         return torch.device("mps")
     if cpu_state == CPUState.CPU:
         return torch.device("cpu")
+    if cpu_state == CPUState.OCL:
+        return torch.device("ocl:0")
     else:
         if is_intel_xpu():
             return torch.device("xpu", torch.xpu.current_device())
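An isolated sketch of the new branch, showing how the enum maps to a torch.device. Constructing "ocl:0" only works once pytorch_ocl has registered the backend, so the demo call uses CPU:

    import enum
    import torch

    class CPUState(enum.Enum):
        GPU = 0
        CPU = 1
        MPS = 2
        OCL = 3

    def device_for(state: CPUState) -> torch.device:
        # Simplified vs. get_torch_device(): no DirectML/XPU handling.
        if state == CPUState.MPS:
            return torch.device("mps")
        if state == CPUState.CPU:
            return torch.device("cpu")
        if state == CPUState.OCL:
            return torch.device("ocl:0")  # requires pytorch_ocl to be loaded
        return torch.device("cuda")

    print(device_for(CPUState.CPU))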
@@ -192,7 +213,7 @@ def get_total_memory(dev=None, torch_total_too=False):
     if dev is None:
         dev = get_torch_device()
 
-    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
+    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps' or dev.type == 'ocl'):
         mem_total = psutil.virtual_memory().total
         mem_total_torch = mem_total
     else:
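With 'ocl' added to this check, OCL devices are sized from system RAM. The psutil fields the branch relies on:

    # psutil fields used by the branch above (real psutil API).
    import psutil

    vm = psutil.virtual_memory()
    print(vm.total)      # bytes of physical RAM -> mem_total
    print(vm.available)  # bytes currently available -> mem_free_total (below)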
@@ -217,6 +238,9 @@ def get_total_memory(dev=None, torch_total_too=False):
             _, mem_total_mlu = torch.mlu.mem_get_info(dev)
             mem_total_torch = mem_reserved
             mem_total = mem_total_mlu
+        elif is_ocl():
+            mem_total = 1024 * 1024 * 1024  # TODO: query the real device memory.
+            mem_total_torch = mem_total
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_reserved = stats['reserved_bytes.all.current']
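One way the 1 GiB placeholder could eventually be replaced is by asking OpenCL directly. A hedged sketch using pyopencl, which is an assumption here; the commit does not depend on it:

    # Possible TODO fill-in (assumption: pyopencl is installed).
    import pyopencl as cl

    def ocl_total_memory(index: int = 0) -> int:
        devices = [d for p in cl.get_platforms() for d in p.get_devices()]
        return devices[index].global_mem_size  # CL_DEVICE_GLOBAL_MEM_SIZE, in bytes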
@@ -1261,7 +1285,7 @@ def get_free_memory(dev=None, torch_free_too=False):
     if dev is None:
         dev = get_torch_device()
 
-    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
+    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps' or dev.type == 'ocl'):
         mem_free_total = psutil.virtual_memory().available
         mem_free_torch = mem_free_total
     else:
@@ -1289,6 +1313,15 @@ def get_free_memory(dev=None, torch_free_too=False):
             mem_free_mlu, _ = torch.mlu.mem_get_info(dev)
             mem_free_torch = mem_reserved - mem_active
             mem_free_total = mem_free_mlu + mem_free_torch
+        elif is_ocl():
+            # stats = torch.ocl.memory_stats(dev)
+            # mem_active = stats['active_bytes.all.current']
+            # mem_reserved = stats['reserved_bytes.all.current']
+            # mem_free_ocl, _ = torch.ocl.mem_get_info(dev)
+            # mem_free_torch = mem_reserved - mem_active
+            # mem_free_total = mem_free_ocl + mem_free_torch
+            mem_free_total = 1024 * 1024 * 1024  # TODO: query real free memory.
+            mem_free_torch = mem_free_total
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_active = stats['active_bytes.all.current']
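Core OpenCL has no portable free-memory query (only vendor extensions expose one), which is presumably why this falls back to a constant. An interim heuristic might report a fraction of total device memory; a sketch, not the commit's approach:

    # Hedged interim heuristic (assumption): treat a fixed share of total
    # device memory as free until a real query exists.
    def ocl_free_memory_estimate(total_bytes: int, fraction: float = 0.8) -> int:
        return int(total_bytes * fraction)

    print(ocl_free_memory_estimate(8 * 1024**3))  # 80% of an 8 GiB device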
@@ -1367,6 +1400,10 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
     if is_mlu():
         return True
 
+    if is_ocl():
+        # TODO: verify per device; RustiCL now supports fp16 at least.
+        return True
+
     if is_ixuca():
         return True
 
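Instead of returning True unconditionally, a capability probe could run one small op in the requested dtype. A hypothetical sketch, not part of the commit; the same probe with dtype=torch.bfloat16 would also cover the should_use_bf16 TODO below:

    # Hypothetical dtype probe (assumption: one matmul is representative).
    import torch

    def probe_dtype(dev, dtype=torch.float16) -> bool:
        try:
            a = torch.ones((8, 8), dtype=dtype, device=dev)
            return bool(torch.isfinite(a @ a).all())
        except (RuntimeError, TypeError):
            return False

    print(probe_dtype("cpu"))  # on an OCL build: probe_dtype("ocl:0")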
@@ -1443,6 +1480,10 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
             return True
         return False
 
+    if is_ocl():
+        # TODO: confirm bf16 support on OpenCL backends.
+        return True
+
     props = torch.cuda.get_device_properties(device)
 
     if is_mlu():