Merge branch 'comfyanonymous:master' into fix/secure-combo

Author:    Dr.Lt.Data
Date:      2023-07-20 12:49:26 +09:00
Committer: GitHub
Commit:    5000f0e2f9
5 changed files with 118 additions and 53 deletions

cuda_malloc.py (new file, 77 additions)

@@ -0,0 +1,77 @@
import os
import importlib.util
from comfy.cli_args import args

#Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
def get_gpu_names():
    if os.name == 'nt':
        import ctypes

        # Define necessary C structures and types
        class DISPLAY_DEVICEA(ctypes.Structure):
            _fields_ = [
                ('cb', ctypes.c_ulong),
                ('DeviceName', ctypes.c_char * 32),
                ('DeviceString', ctypes.c_char * 128),
                ('StateFlags', ctypes.c_ulong),
                ('DeviceID', ctypes.c_char * 128),
                ('DeviceKey', ctypes.c_char * 128)
            ]

        # Load user32.dll
        user32 = ctypes.windll.user32

        # Call EnumDisplayDevicesA
        def enum_display_devices():
            device_info = DISPLAY_DEVICEA()
            device_info.cb = ctypes.sizeof(device_info)
            device_index = 0
            gpu_names = set()

            while user32.EnumDisplayDevicesA(None, device_index, ctypes.byref(device_info), 0):
                device_index += 1
                gpu_names.add(device_info.DeviceString.decode('utf-8'))
            return gpu_names
        return enum_display_devices()
    else:
        return set()

def cuda_malloc_supported():
    blacklist = {"GeForce GTX 960M", "GeForce GTX 950M", "GeForce 945M", "GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745"}
    try:
        names = get_gpu_names()
    except:
        names = set()
    for x in names:
        if "NVIDIA" in x:
            for b in blacklist:
                if b in x:
                    return False
    return True

if not args.cuda_malloc:
    try:
        version = ""
        torch_spec = importlib.util.find_spec("torch")
        for folder in torch_spec.submodule_search_locations:
            ver_file = os.path.join(folder, "version.py")
            if os.path.isfile(ver_file):
                spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
                module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(module)
                version = module.__version__
        if int(version[0]) >= 2: #enable by default for torch version 2.0 and up
            args.cuda_malloc = cuda_malloc_supported()
    except:
        pass

if args.cuda_malloc and not args.disable_cuda_malloc:
    env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
    if env_var is None:
        env_var = "backend:cudaMallocAsync"
    else:
        env_var += ",backend:cudaMallocAsync"
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
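
For context: the cudaMallocAsync backend has to be selected through PYTORCH_CUDA_ALLOC_CONF before torch first initializes, which is why this module reads version.py directly instead of importing torch. A minimal sketch of verifying the effect (get_allocator_backend() assumes PyTorch 2.0+):

    import os
    # Must be set before the first `import torch` in the process.
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "backend:cudaMallocAsync"

    import torch
    if torch.cuda.is_available():
        # Expected to print "cudaMallocAsync" rather than the default "native".
        print(torch.cuda.get_allocator_backend())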

latent_preview.py

@@ -1,6 +1,5 @@
 import torch
-from PIL import Image, ImageOps
-from io import BytesIO
+from PIL import Image
 import struct
 import numpy as np
 from comfy.cli_args import args, LatentPreviewMethod
@@ -15,26 +14,7 @@ class LatentPreviewer:
     def decode_latent_to_preview_image(self, preview_format, x0):
         preview_image = self.decode_latent_to_preview(x0)
-
-        if hasattr(Image, 'Resampling'):
-            resampling = Image.Resampling.BILINEAR
-        else:
-            resampling = Image.ANTIALIAS
-        preview_image = ImageOps.contain(preview_image, (MAX_PREVIEW_RESOLUTION, MAX_PREVIEW_RESOLUTION), resampling)
-
-        preview_type = 1
-        if preview_format == "JPEG":
-            preview_type = 1
-        elif preview_format == "PNG":
-            preview_type = 2
-
-        bytesIO = BytesIO()
-        header = struct.pack(">I", preview_type)
-        bytesIO.write(header)
-        preview_image.save(bytesIO, format=preview_format, quality=95)
-        preview_bytes = bytesIO.getvalue()
-        return preview_bytes
+        return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION)
 
 class TAESDPreviewerImpl(LatentPreviewer):
     def __init__(self, taesd):

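With this patch a previewer returns its image unencoded: decode_latent_to_preview_image() now yields a (format, PIL image, max size) tuple and leaves resizing and JPEG/PNG encoding to the server. A minimal sketch of a previewer written against that contract; the subclass and its tensor math are purely illustrative:

    import torch
    from PIL import Image

    class GrayscalePreviewer(LatentPreviewer):  # hypothetical example subclass
        def decode_latent_to_preview(self, x0):
            # Collapse the latent channels of the first batch item to one
            # grayscale channel and rescale from [-1, 1] to [0, 255].
            chan = x0[0].mean(dim=0).clamp(-1, 1)
            arr = ((chan + 1.0) * 127.5).byte().cpu().numpy()
            return Image.fromarray(arr, mode="L")

    # The inherited decode_latent_to_preview_image() wraps the result as
    # ("JPEG", pil_image, MAX_PREVIEW_RESOLUTION) instead of encoded bytes.
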
main.py (31 changed lines)

@@ -61,30 +61,7 @@ if __name__ == "__main__":
         os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
         print("Set cuda device to:", args.cuda_device)
 
-    if not args.cuda_malloc:
-        try: #if there's a better way to check the torch version without importing it let me know
-            version = ""
-            torch_spec = importlib.util.find_spec("torch")
-            for folder in torch_spec.submodule_search_locations:
-                ver_file = os.path.join(folder, "version.py")
-                if os.path.isfile(ver_file):
-                    spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
-                    module = importlib.util.module_from_spec(spec)
-                    spec.loader.exec_module(module)
-                    version = module.__version__
-            if int(version[0]) >= 2: #enable by default for torch version 2.0 and up
-                args.cuda_malloc = True
-        except:
-            pass
-
-    if args.cuda_malloc and not args.disable_cuda_malloc:
-        env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
-        if env_var is None:
-            env_var = "backend:cudaMallocAsync"
-        else:
-            env_var += ",backend:cudaMallocAsync"
-        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
+    import cuda_malloc
 
 import comfy.utils
 import yaml
@@ -115,10 +92,10 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
 
 def hijack_progress(server):
-    def hook(value, total, preview_image_bytes):
+    def hook(value, total, preview_image):
         server.send_sync("progress", {"value": value, "max": total}, server.client_id)
-        if preview_image_bytes is not None:
-            server.send_sync(BinaryEventTypes.PREVIEW_IMAGE, preview_image_bytes, server.client_id)
+        if preview_image is not None:
+            server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
     comfy.utils.set_progress_bar_global_hook(hook)
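
A short sketch of what now flows through the hook defined above; the values are illustrative, and 512 stands in for latent_preview.py's MAX_PREVIEW_RESOLUTION:

    from PIL import Image

    # Step 5 of 20, with a placeholder 64x64 preview image.
    pil_image = Image.new("RGB", (64, 64), "gray")
    # Old contract: hook(value, total, preview_image_bytes) took pre-encoded bytes.
    # New contract: the unencoded tuple from the previewer; server.send_image()
    # performs the resize and JPEG/PNG encoding on send.
    hook(5, 20, ("JPEG", pil_image, 512))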

script_examples/basic_api_example.py

@@ -2,8 +2,12 @@ import json
 from urllib import request, parse
 import random
 
-#this is the ComfyUI api prompt format. If you want it for a specific workflow you can copy it from the prompt section
-#of the image metadata of images generated with ComfyUI
+#This is the ComfyUI api prompt format.
+
+#If you want it for a specific workflow you can "enable dev mode options"
+#in the settings of the UI (gear beside the "Queue Size: ") this will enable
+#a button on the UI to save workflows in api format.
 
 #keep in mind ComfyUI is pre alpha software so this format will change a bit.
 
 #this is the one for the default workflow
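
For context, the "save workflows in api format" button produces exactly the JSON this script posts; a minimal sketch of queuing such a prompt against a locally running server (127.0.0.1:8188 is the default address):

    import json
    from urllib import request

    def queue_prompt(prompt):
        # Wrap the API-format workflow dict and POST it to the /prompt endpoint.
        data = json.dumps({"prompt": prompt}).encode('utf-8')
        req = request.Request("http://127.0.0.1:8188/prompt", data=data)
        request.urlopen(req)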

server.py

@@ -8,7 +8,7 @@ import uuid
 import json
 import glob
 import struct
-from PIL import Image
+from PIL import Image, ImageOps
 from io import BytesIO
 try:
@@ -29,6 +29,7 @@ import comfy.model_management
 
 class BinaryEventTypes:
     PREVIEW_IMAGE = 1
+    UNENCODED_PREVIEW_IMAGE = 2
 
 async def send_socket_catch_exception(function, message):
     try:
@@ -498,7 +499,9 @@ class PromptServer():
         return prompt_info
 
     async def send(self, event, data, sid=None):
-        if isinstance(data, (bytes, bytearray)):
+        if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
+            await self.send_image(data, sid=sid)
+        elif isinstance(data, (bytes, bytearray)):
             await self.send_bytes(event, data, sid)
         else:
             await self.send_json(event, data, sid)
@@ -512,6 +515,30 @@ class PromptServer():
             message.extend(data)
         return message
 
+    async def send_image(self, image_data, sid=None):
+        image_type = image_data[0]
+        image = image_data[1]
+        max_size = image_data[2]
+        if max_size is not None:
+            if hasattr(Image, 'Resampling'):
+                resampling = Image.Resampling.BILINEAR
+            else:
+                resampling = Image.ANTIALIAS
+            image = ImageOps.contain(image, (max_size, max_size), resampling)
+        type_num = 1
+        if image_type == "JPEG":
+            type_num = 1
+        elif image_type == "PNG":
+            type_num = 2
+
+        bytesIO = BytesIO()
+        header = struct.pack(">I", type_num)
+        bytesIO.write(header)
+        image.save(bytesIO, format=image_type, quality=95, compress_level=4)
+        preview_bytes = bytesIO.getvalue()
+        await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
+
     async def send_bytes(self, event, data, sid=None):
         message = self.encode_bytes(event, data)
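
On the wire, a preview frame is therefore an 8-byte header followed by the encoded image: a big-endian uint32 event type (PREVIEW_IMAGE = 1, prepended by encode_bytes) and a big-endian uint32 image format (1 = JPEG, 2 = PNG, prepended by send_image). A sketch of decoding it client side, assuming the websocket-client package and a server on the default port:

    import io
    import struct

    import websocket  # pip install websocket-client
    from PIL import Image

    ws = websocket.WebSocket()
    ws.connect("ws://127.0.0.1:8188/ws?clientId=example")
    msg = ws.recv()  # binary frames arrive as bytes, JSON events as str
    if isinstance(msg, bytes):
        event, image_format = struct.unpack(">II", msg[:8])
        if event == 1:  # BinaryEventTypes.PREVIEW_IMAGE
            preview = Image.open(io.BytesIO(msg[8:]))
            print(image_format, preview.size)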