diff --git a/comfy/model_management.py b/comfy/model_management.py
index 2501cecb7..f6b80152f 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -40,6 +40,7 @@ class CPUState(Enum):
     GPU = 0
     CPU = 1
     MPS = 2
+    OCL = 3
 
 # Determine VRAM State
 vram_state = VRAMState.NORMAL_VRAM
@@ -102,6 +103,14 @@ if args.directml is not None:
     # torch_directml.disable_tiled_resources(True)
     lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
 
+ocl_available = False
+try:
+    import pytorch_ocl
+    import torch.ocl
+    ocl_available = True
+except ImportError:
+    pass
+
 try:
     import intel_extension_for_pytorch as ipex # noqa: F401
 except:
@@ -139,6 +148,10 @@ try:
 except:
     ixuca_available = False
 
+if ocl_available:
+    # TODO: gate behind a flag.
+    cpu_state = CPUState.OCL
+
 if args.cpu:
     cpu_state = CPUState.CPU
 
@@ -168,6 +181,12 @@ def is_ixuca():
         return True
     return False
 
+def is_ocl():
+    global ocl_available
+    if ocl_available:
+        return True
+    return False
+
 def get_torch_device():
     global directml_enabled
     global cpu_state
@@ -178,6 +197,8 @@ def get_torch_device():
             return torch.device("mps")
         if cpu_state == CPUState.CPU:
             return torch.device("cpu")
+        if cpu_state == CPUState.OCL:
+            return torch.device("ocl:0")
     else:
         if is_intel_xpu():
             return torch.device("xpu", torch.xpu.current_device())
@@ -193,7 +214,7 @@ def get_total_memory(dev=None, torch_total_too=False):
     if dev is None:
         dev = get_torch_device()
 
-    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
+    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps' or dev.type == 'ocl'):
         mem_total = psutil.virtual_memory().total
         mem_total_torch = mem_total
     else:
@@ -218,6 +239,9 @@ def get_total_memory(dev=None, torch_total_too=False):
             _, mem_total_mlu = torch.mlu.mem_get_info(dev)
             mem_total_torch = mem_reserved
             mem_total = mem_total_mlu
+        elif is_ocl():
+            mem_total = 1024 * 1024 * 1024 #TODO: query the real device memory.
+            mem_total_torch = mem_total
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_reserved = stats['reserved_bytes.all.current']
@@ -1280,7 +1304,7 @@ def get_free_memory(dev=None, torch_free_too=False):
     if dev is None:
         dev = get_torch_device()
 
-    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
+    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps' or dev.type == 'ocl'):
         mem_free_total = psutil.virtual_memory().available
         mem_free_torch = mem_free_total
     else:
@@ -1308,6 +1332,15 @@ def get_free_memory(dev=None, torch_free_too=False):
             mem_free_mlu, _ = torch.mlu.mem_get_info(dev)
             mem_free_torch = mem_reserved - mem_active
             mem_free_total = mem_free_mlu + mem_free_torch
+        elif is_ocl():
+            # stats = torch.ocl.memory_stats(dev)
+            # mem_active = stats['active_bytes.all.current']
+            # mem_reserved = stats['reserved_bytes.all.current']
+            # mem_free_ocl, _ = torch.ocl.mem_get_info(dev)
+            # mem_free_torch = mem_reserved - mem_active
+            # mem_free_total = mem_free_ocl + mem_free_torch
+            mem_free_total = 1024 * 1024 * 1024 #TODO: query the real free memory.
+            mem_free_torch = mem_free_total
         else:
            stats = torch.cuda.memory_stats(dev)
            mem_active = stats['active_bytes.all.current']
@@ -1386,6 +1419,10 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if is_mlu():
         return True
 
+    if is_ocl():
+        # TODO? RustiCL now supports fp16 at least.
+        return True
+
     if is_ixuca():
         return True
 
@@ -1462,6 +1499,10 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
             return True
         return False
 
+    if is_ocl():
+        # TODO
+        return True
+
     props = torch.cuda.get_device_properties(device)
 
     if is_mlu():
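
Note, outside the patch: a minimal smoke test for the new device path. This is a sketch under the assumption that importing pytorch_ocl registers an "ocl" device with torch, as the patch's import block and the `torch.device("ocl:0")` return in get_torch_device() imply; the tensor shapes and ops are illustrative only.

```python
# Hypothetical smoke test for the OCL path (assumes pytorch_ocl is installed
# and registers the "ocl" device with torch as a side effect of the import).
import torch
import pytorch_ocl  # noqa: F401  side-effect import: registers the backend

dev = torch.device("ocl:0")       # same device string the patch returns
x = torch.randn(16, 16, device=dev)
y = x @ x                         # run a matmul on the OpenCL device
print(y.cpu().sum())              # copy back to CPU to verify the result
```

One design consequence worth flagging: the hard-coded 1 GiB returned from get_total_memory()/get_free_memory() makes the model loader treat every OCL device as tiny, which will force lowvram-style behavior regardless of actual VRAM. Wiring up a real query (e.g. OpenCL's CL_DEVICE_GLOBAL_MEM_SIZE, if pytorch_ocl exposes it) is the open TODO the comments point at.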