diff --git a/cuda_malloc.py b/cuda_malloc.py
new file mode 100644
index 000000000..382432215
--- /dev/null
+++ b/cuda_malloc.py
@@ -0,0 +1,77 @@
+import os
+import importlib.util
+from comfy.cli_args import args
+
+#Can't use pytorch to get the GPU names because the cuda malloc backend has to be set before torch is first imported.
+def get_gpu_names():
+    if os.name == 'nt':
+        import ctypes
+
+        # Define necessary C structures and types
+        class DISPLAY_DEVICEA(ctypes.Structure):
+            _fields_ = [
+                ('cb', ctypes.c_ulong),
+                ('DeviceName', ctypes.c_char * 32),
+                ('DeviceString', ctypes.c_char * 128),
+                ('StateFlags', ctypes.c_ulong),
+                ('DeviceID', ctypes.c_char * 128),
+                ('DeviceKey', ctypes.c_char * 128)
+            ]
+
+        # Load user32.dll
+        user32 = ctypes.windll.user32
+
+        # Call EnumDisplayDevicesA
+        def enum_display_devices():
+            device_info = DISPLAY_DEVICEA()
+            device_info.cb = ctypes.sizeof(device_info)
+            device_index = 0
+            gpu_names = set()
+
+            while user32.EnumDisplayDevicesA(None, device_index, ctypes.byref(device_info), 0):
+                device_index += 1
+                gpu_names.add(device_info.DeviceString.decode('utf-8'))
+            return gpu_names
+        return enum_display_devices()
+    else:
+        return set()
+
+def cuda_malloc_supported():
+    blacklist = {"GeForce GTX 960M", "GeForce GTX 950M", "GeForce 945M", "GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745"}
+    try:
+        names = get_gpu_names()
+    except Exception:
+        names = set()
+    for x in names:
+        if "NVIDIA" in x:
+            for b in blacklist:
+                if b in x:
+                    return False
+    return True
+
+
+if not args.cuda_malloc:
+    try:
+        version = ""
+        torch_spec = importlib.util.find_spec("torch")
+        for folder in torch_spec.submodule_search_locations:
+            ver_file = os.path.join(folder, "version.py")
+            if os.path.isfile(ver_file):
+                spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
+                module = importlib.util.module_from_spec(spec)
+                spec.loader.exec_module(module)
+                version = module.__version__
+        if int(version[0]) >= 2: #enable by default for torch version 2.0 and up
+            args.cuda_malloc = cuda_malloc_supported()
+    except Exception:
+        pass
+
+
+if args.cuda_malloc and not args.disable_cuda_malloc:
+    env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
+    if env_var is None:
+        env_var = "backend:cudaMallocAsync"
+    else:
+        env_var += ",backend:cudaMallocAsync"
+
+    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
diff --git a/latent_preview.py b/latent_preview.py
index 833e6822e..30c1d1317 100644
--- a/latent_preview.py
+++ b/latent_preview.py
@@ -1,6 +1,5 @@
 import torch
-from PIL import Image, ImageOps
-from io import BytesIO
+from PIL import Image
 import struct
 import numpy as np
 from comfy.cli_args import args, LatentPreviewMethod
@@ -15,26 +14,7 @@ class LatentPreviewer:
 
     def decode_latent_to_preview_image(self, preview_format, x0):
         preview_image = self.decode_latent_to_preview(x0)
-
-        if hasattr(Image, 'Resampling'):
-            resampling = Image.Resampling.BILINEAR
-        else:
-            resampling = Image.ANTIALIAS
-
-        preview_image = ImageOps.contain(preview_image, (MAX_PREVIEW_RESOLUTION, MAX_PREVIEW_RESOLUTION), resampling)
-
-        preview_type = 1
-        if preview_format == "JPEG":
-            preview_type = 1
-        elif preview_format == "PNG":
-            preview_type = 2
-
-        bytesIO = BytesIO()
-        header = struct.pack(">I", preview_type)
-        bytesIO.write(header)
-        preview_image.save(bytesIO, format=preview_format, quality=95)
-        preview_bytes = bytesIO.getvalue()
-        return preview_bytes
+        return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION)
 
 class TAESDPreviewerImpl(LatentPreviewer):
     def __init__(self, taesd):
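For context: decode_latent_to_preview_image() now returns an unencoded (format, PIL image, max size) tuple and leaves all resizing and encoding to the server (see the send_image() method added to server.py below), so a previewer only has to implement decode_latent_to_preview() and return a PIL image. A minimal sketch of that contract; the class name and the grayscale mapping are illustrative, not part of this patch:

import torch
from PIL import Image

class GrayscalePreviewer(LatentPreviewer):
    def decode_latent_to_preview(self, x0):
        # Collapse the latent channels of the first batch item to one plane
        # and normalize it into the 0..255 byte range.
        latent = x0[0].mean(dim=0)
        latent = (latent - latent.min()) / (latent.max() - latent.min() + 1e-8)
        arr = (latent * 255.0).clamp(0, 255).byte().cpu().numpy()
        return Image.fromarray(arr, mode="L")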
diff --git a/main.py b/main.py
index 61ba9e8e6..21f76b617 100644
--- a/main.py
+++ b/main.py
@@ -61,30 +61,7 @@ if __name__ == "__main__":
         os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
         print("Set cuda device to:", args.cuda_device)
 
-    if not args.cuda_malloc:
-        try: #if there's a better way to check the torch version without importing it let me know
-            version = ""
-            torch_spec = importlib.util.find_spec("torch")
-            for folder in torch_spec.submodule_search_locations:
-                ver_file = os.path.join(folder, "version.py")
-                if os.path.isfile(ver_file):
-                    spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
-                    module = importlib.util.module_from_spec(spec)
-                    spec.loader.exec_module(module)
-                    version = module.__version__
-            if int(version[0]) >= 2: #enable by default for torch version 2.0 and up
-                args.cuda_malloc = True
-        except:
-            pass
-
-    if args.cuda_malloc and not args.disable_cuda_malloc:
-        env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
-        if env_var is None:
-            env_var = "backend:cudaMallocAsync"
-        else:
-            env_var += ",backend:cudaMallocAsync"
-
-        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
+    import cuda_malloc
 
 import comfy.utils
 import yaml
@@ -115,10 +92,10 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
 
 def hijack_progress(server):
-    def hook(value, total, preview_image_bytes):
+    def hook(value, total, preview_image):
         server.send_sync("progress", {"value": value, "max": total}, server.client_id)
-        if preview_image_bytes is not None:
-            server.send_sync(BinaryEventTypes.PREVIEW_IMAGE, preview_image_bytes, server.client_id)
+        if preview_image is not None:
+            server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
     comfy.utils.set_progress_bar_global_hook(hook)
 
diff --git a/script_examples/basic_api_example.py b/script_examples/basic_api_example.py
index a0e22878b..242d3175f 100644
--- a/script_examples/basic_api_example.py
+++ b/script_examples/basic_api_example.py
@@ -2,8 +2,12 @@ import json
 from urllib import request, parse
 import random
 
-#this is the ComfyUI api prompt format. If you want it for a specific workflow you can copy it from the prompt section
-#of the image metadata of images generated with ComfyUI
+#This is the ComfyUI api prompt format.
+
+#If you want it for a specific workflow you can "enable dev mode options"
+#in the settings of the UI (the gear beside "Queue Size:"); this will enable
+#a button in the UI to save workflows in the api format.
+
 #keep in mind ComfyUI is pre alpha software so this format will change a bit.
 
 #this is the one for the default workflow
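The hunk above only touches the header comments; the rest of basic_api_example.py (unchanged by this patch) queues the prompt by POSTing it to the server's /prompt endpoint. A minimal sketch of that call, assuming a default local server on 127.0.0.1:8188:

import json
from urllib import request

def queue_prompt(prompt, server="http://127.0.0.1:8188"):
    # The API expects the node graph wrapped as {"prompt": ...} in the POST body.
    data = json.dumps({"prompt": prompt}).encode('utf-8')
    req = request.Request(server + "/prompt", data=data)
    return request.urlopen(req).read()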
diff --git a/server.py b/server.py
index 9ca131ede..f61b11a97 100644
--- a/server.py
+++ b/server.py
@@ -8,7 +8,7 @@ import uuid
 import json
 import glob
 import struct
-from PIL import Image
+from PIL import Image, ImageOps
 from io import BytesIO
 
 try:
@@ -29,6 +29,7 @@ import comfy.model_management
 
 class BinaryEventTypes:
     PREVIEW_IMAGE = 1
+    UNENCODED_PREVIEW_IMAGE = 2
 
 async def send_socket_catch_exception(function, message):
     try:
@@ -498,7 +499,9 @@ class PromptServer():
         return prompt_info
 
     async def send(self, event, data, sid=None):
-        if isinstance(data, (bytes, bytearray)):
+        if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
+            await self.send_image(data, sid=sid)
+        elif isinstance(data, (bytes, bytearray)):
             await self.send_bytes(event, data, sid)
         else:
             await self.send_json(event, data, sid)
@@ -512,6 +515,30 @@ class PromptServer():
         message.extend(data)
         return message
 
+    async def send_image(self, image_data, sid=None):
+        image_type = image_data[0]
+        image = image_data[1]
+        max_size = image_data[2]
+        if max_size is not None:
+            if hasattr(Image, 'Resampling'):
+                resampling = Image.Resampling.BILINEAR
+            else:
+                resampling = Image.ANTIALIAS
+
+            image = ImageOps.contain(image, (max_size, max_size), resampling)
+        type_num = 1
+        if image_type == "JPEG":
+            type_num = 1
+        elif image_type == "PNG":
+            type_num = 2
+
+        bytesIO = BytesIO()
+        header = struct.pack(">I", type_num)
+        bytesIO.write(header)
+        image.save(bytesIO, format=image_type, quality=95, compress_level=4)
+        preview_bytes = bytesIO.getvalue()
+        await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
+
     async def send_bytes(self, event, data, sid=None):
         message = self.encode_bytes(event, data)
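Taken together with encode_bytes(), which prepends a 4-byte big-endian event id (visible in the context lines above), send_image() produces binary websocket frames with an 8-byte header: event type (1 = PREVIEW_IMAGE), then image format (1 = JPEG, 2 = PNG), then the encoded image bytes. A sketch of how a client could decode such a frame; the helper name is illustrative:

import struct
from io import BytesIO
from PIL import Image

PREVIEW_IMAGE = 1  # BinaryEventTypes.PREVIEW_IMAGE

def decode_preview_frame(frame: bytes):
    # Unpack the two big-endian uint32 header fields, then hand the
    # remaining bytes (JPEG or PNG data) to PIL.
    event, image_format = struct.unpack(">II", frame[:8])
    if event != PREVIEW_IMAGE:
        return None
    return Image.open(BytesIO(frame[8:]))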