This commit is contained in:
Benjamin Berman 2023-08-23 16:14:12 -07:00
parent b98e1c0f01
commit e9365c4678
5 changed files with 24 additions and 6 deletions

View File

View File

@ -12,7 +12,7 @@ from io import BytesIO
import json import json
import os import os
import uuid import uuid
from asyncio import Future from asyncio import Future, AbstractEventLoop
from typing import List from typing import List
import aiofiles import aiofiles
@ -75,6 +75,12 @@ class PromptServer():
prompt_queue: execution.PromptQueue | None prompt_queue: execution.PromptQueue | None
address: str address: str
port: int port: int
loop: AbstractEventLoop
messages: asyncio.Queue
number: int
supports: List[str]
app: web.Application
routes: web.RouteTableDef
def __init__(self, loop): def __init__(self, loop):
PromptServer.instance = self PromptServer.instance = self
@ -539,7 +545,6 @@ class PromptServer():
return web.Response(status=429, return web.Response(status=429,
reason=f"the queue has {queue_size} elements and {queue_too_busy_size} is the limit for this worker") reason=f"the queue has {queue_size} elements and {queue_too_busy_size} is the limit for this worker")
# read the request # read the request
upload_dir = PromptServer.get_upload_dir()
prompt_dict: dict = {} prompt_dict: dict = {}
if request.headers[aiohttp.hdrs.CONTENT_TYPE] == 'application/json': if request.headers[aiohttp.hdrs.CONTENT_TYPE] == 'application/json':
prompt_dict = await request.json() prompt_dict = await request.json()
@ -556,6 +561,7 @@ class PromptServer():
elif part.filename: elif part.filename:
file_data = await part.read(decode=True) file_data = await part.read(decode=True)
# overwrite existing files # overwrite existing files
upload_dir = PromptServer.get_upload_dir()
async with aiofiles.open(os.path.join(upload_dir, part.filename), mode='wb') as file: async with aiofiles.open(os.path.join(upload_dir, part.filename), mode='wb') as file:
await file.write(file_data) await file.write(file_data)
except IOError | MemoryError as ioError: except IOError | MemoryError as ioError:
@ -567,8 +573,7 @@ class PromptServer():
return web.Response(status=400, reason="no prompt was specified") return web.Response(status=400, reason="no prompt was specified")
content_digest = digest(prompt_dict) content_digest = digest(prompt_dict)
dump = json.dumps(prompt_dict)
valid = execution.validate_prompt(prompt_dict) valid = execution.validate_prompt(prompt_dict)
if not valid[0]: if not valid[0]:
return web.Response(status=400, body=valid[1]) return web.Response(status=400, body=valid[1])

View File

@ -0,0 +1,13 @@
from asyncio import AbstractEventLoop
from ..cmd.execution import PromptQueue
from ..cmd.server import PromptServer
class Comfy:
loop: AbstractEventLoop
server: PromptServer
queue: PromptQueue
def __init__(self):
pass

View File

@ -7,7 +7,7 @@ import time
import types import types
from . import base_nodes from . import base_nodes
from ..comfy_extras import nodes as comfy_extras_nodes from comfy_extras import nodes as comfy_extras_nodes
try: try:
import custom_nodes import custom_nodes

View File

@ -17,7 +17,7 @@ from setuptools import setup, find_packages
""" """
The name of the package. The name of the package.
""" """
package_name = "comfy" package_name = "comfyui"
""" """
The current version. The current version.