diff --git a/comfy/model_management.py b/comfy/model_management.py index e9af7f3a7..9c3147d76 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1,6 +1,7 @@ import psutil from enum import Enum from comfy.cli_args import args +import torch class VRAMState(Enum): CPU = 0 @@ -33,28 +34,67 @@ if args.directml is not None: lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default. try: - import torch - if directml_enabled: - pass #TODO - else: - try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - xpu_available = True - total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024) - except: - total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024) - total_ram = psutil.virtual_memory().total / (1024 * 1024) - if not args.normalvram and not args.cpu: - if lowvram_available and total_vram <= 4096: - print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram") - set_vram_to = VRAMState.LOW_VRAM - elif total_vram > total_ram * 1.1 and total_vram > 14336: - print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram") - vram_state = VRAMState.HIGH_VRAM + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + xpu_available = True except: pass +def get_torch_device(): + global xpu_available + global directml_enabled + if directml_enabled: + global directml_device + return directml_device + if vram_state == VRAMState.MPS: + return torch.device("mps") + if vram_state == VRAMState.CPU: + return torch.device("cpu") + else: + if xpu_available: + return torch.device("xpu") + else: + return torch.device(torch.cuda.current_device()) + +def get_total_memory(dev=None, torch_total_too=False): + global xpu_available + global directml_enabled + if dev is None: + dev = get_torch_device() + + if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): + mem_total = psutil.virtual_memory().total + mem_total_torch = mem_total + else: + if directml_enabled: + mem_total = 1024 * 1024 * 1024 #TODO + mem_total_torch = mem_total + elif xpu_available: + mem_total = torch.xpu.get_device_properties(dev).total_memory + mem_total_torch = mem_total + else: + stats = torch.cuda.memory_stats(dev) + mem_reserved = stats['reserved_bytes.all.current'] + _, mem_total_cuda = torch.cuda.mem_get_info(dev) + mem_total_torch = mem_reserved + mem_total = mem_total_cuda + + if torch_total_too: + return (mem_total, mem_total_torch) + else: + return mem_total + +total_vram = get_total_memory(get_torch_device()) / (1024 * 1024) +total_ram = psutil.virtual_memory().total / (1024 * 1024) +print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram)) +if not args.normalvram and not args.cpu: + if lowvram_available and total_vram <= 4096: + print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram") + set_vram_to = VRAMState.LOW_VRAM + elif total_vram > total_ram * 1.1 and total_vram > 14336: + print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram") + vram_state = VRAMState.HIGH_VRAM + try: OOM_EXCEPTION = torch.cuda.OutOfMemoryError except: @@ -128,29 +168,17 @@ if args.cpu: print(f"Set vram state to: {vram_state.name}") -def get_torch_device(): - global xpu_available - global directml_enabled - if directml_enabled: - global directml_device - return directml_device - if vram_state == VRAMState.MPS: - return torch.device("mps") - if vram_state == VRAMState.CPU: - return torch.device("cpu") - else: - if xpu_available: - return torch.device("xpu") - else: - return torch.cuda.current_device() - def get_torch_device_name(device): if hasattr(device, 'type'): - return "{}".format(device.type) - return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) + if device.type == "cuda": + return "{} {}".format(device, torch.cuda.get_device_name(device)) + else: + return "{}".format(device.type) + else: + return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) try: - print("Using device:", get_torch_device_name(get_torch_device())) + print("Device:", get_torch_device_name(get_torch_device())) except: print("Could not pick default device.") diff --git a/folder_paths.py b/folder_paths.py index e179a28d4..a1bf1444d 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -1,4 +1,5 @@ import os +import time supported_ckpt_extensions = set(['.ckpt', '.pth', '.safetensors']) supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors']) @@ -154,7 +155,7 @@ def get_filename_list_(folder_name): output_list.update(filter_files_extensions(files, folders[1])) output_folders = {**output_folders, **folders_all} - return (sorted(list(output_list)), output_folders) + return (sorted(list(output_list)), output_folders, time.perf_counter()) def cached_filename_list_(folder_name): global filename_list_cache @@ -162,6 +163,8 @@ def cached_filename_list_(folder_name): if folder_name not in filename_list_cache: return None out = filename_list_cache[folder_name] + if time.perf_counter() < (out[2] + 0.5): + return out for x in out[1]: time_modified = out[1][x] folder = x @@ -170,8 +173,9 @@ def cached_filename_list_(folder_name): folders = folder_names_and_paths[folder_name] for x in folders[0]: - if x not in out[1]: - return None + if os.path.isdir(x): + if x not in out[1]: + return None return out diff --git a/server.py b/server.py index 0b64df147..5be822a6f 100644 --- a/server.py +++ b/server.py @@ -23,6 +23,7 @@ except ImportError: import mimetypes from comfy.cli_args import args import comfy.utils +import comfy.model_management @web.middleware async def cache_control(request: web.Request, handler): @@ -280,6 +281,27 @@ class PromptServer(): return web.Response(status=404) return web.json_response(dt["__metadata__"]) + @routes.get("/system_stats") + async def get_queue(request): + device = comfy.model_management.get_torch_device() + device_name = comfy.model_management.get_torch_device_name(device) + vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True) + vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True) + system_stats = { + "devices": [ + { + "name": device_name, + "type": device.type, + "index": device.index, + "vram_total": vram_total, + "vram_free": vram_free, + "torch_vram_total": torch_vram_total, + "torch_vram_free": torch_vram_free, + } + ] + } + return web.json_response(system_stats) + @routes.get("/prompt") async def get_prompt(request): return web.json_response(self.get_queue_info())