basic caching

This commit is contained in:
Benjamin Berman 2023-08-01 16:35:00 -07:00
parent 7c197409be
commit 66b857d069
4 changed files with 298 additions and 43 deletions

View File

@ -8,7 +8,7 @@ servers:
paths:
/:
get:
summary: Web UI index.html
summary: (UI) index.html
operationId: get_root
responses:
200:
@ -18,7 +18,7 @@ paths:
example: "<!DOCTYPE html>..."
/embeddings:
get:
summary: Get embeddings
summary: (UI) Get embeddings
operationId: get_embeddings
responses:
200:
@ -35,7 +35,7 @@ paths:
type: string
/extensions:
get:
summary: Get extensions
summary: (UI) Get extensions
operationId: get_extensions
responses:
200:
@ -48,7 +48,7 @@ paths:
type: string
/upload/image:
post:
summary: Upload an image.
summary: (UI) Upload an image.
description: |
Uploads an image to the input/ directory.
@ -82,7 +82,7 @@ paths:
The request was missing an image upload.
/view:
get:
summary: View image
summary: (UI) View image
operationId: view_image
parameters:
- in: query
@ -118,7 +118,7 @@ paths:
description: Not Found
/prompt:
get:
summary: Get queue info
summary: (UI) Get queue info
operationId: get_prompt
responses:
200:
@ -134,7 +134,7 @@ paths:
queue_remaining:
type: integer
post:
summary: Post prompt
summary: (UI) Post prompt
operationId: post_prompt
requestBody:
content:
@ -157,7 +157,7 @@ paths:
type: string
/object_info:
get:
summary: Get object info
summary: (UI) Get object info
operationId: get_object_info
responses:
'200':
@ -172,7 +172,7 @@ paths:
$ref: "#/components/schemas/Node"
/history:
get:
summary: Get history
summary: (UI) Get history
operationId: get_history
responses:
"200":
@ -192,7 +192,7 @@ paths:
outputs:
type: object
post:
summary: Post history
summary: (UI) Post history
operationId: post_history
requestBody:
content:
@ -211,7 +211,7 @@ paths:
description: OK
/queue:
get:
summary: Get queue
summary: (UI) Get queue
operationId: get_queue
responses:
"200":
@ -230,7 +230,7 @@ paths:
items:
$ref: "#/components/schemas/QueueTuple"
post:
summary: Post queue
summary: (UI) Post queue
operationId: post_queue
requestBody:
content:
@ -249,15 +249,40 @@ paths:
description: OK
/interrupt:
post:
summary: Post interrupt
summary: (UI) Post interrupt
operationId: post_interrupt
responses:
'200':
description: OK
/api/v1/images/{digest}:
get:
summary: (API) Get image
description: |
Returns an image given a content hash.
parameters:
- name: digest
in: path
required: true
description: A digest of the request used to generate the imaeg
schema:
type: string
example: e5187160a7b2c496773c1c5a45bfd3ffbf25eaa5969328e6469d36f31cf240a3
responses:
404:
description: No image was found.
200:
description: An image.
content:
image/png:
schema:
type: string
format: binary
/api/v1/prompts:
get:
summary: Return the last prompt run anywhere that was used to produce an image.
summary: (API) Get prompt
description: |
Return the last prompt run anywhere that was used to produce an image
The prompt object can be POSTed to run the image again with your own parameters.
The last prompt, whether it was in the UI or via the API, will be returned here.
@ -273,8 +298,10 @@ paths:
description: |
There were no prompts in the history to return.
post:
summary: Run a prompt to produce an image.
summary: (API) Generate image
description: |
Run a prompt to generate an image.
Blocks until the image is produced. This may take an arbitrarily long amount of time due to model loading.
Prompts that produce multiple images will return the last SaveImage output node in the Prompt by default. To return a specific image, remove other
@ -284,11 +311,89 @@ paths:
filenames will be used in your Prompt.
responses:
200:
headers:
Location:
description: The URL to the file based on a hash of the request body.
example: /api/v1/images/e5187160a7b2c496773c1c5a45bfd3ffbf25eaa5969328e6469d36f31cf240a3
schema:
type: string
Digest:
description: The digest of the request body
example: SHA256=e5187160a7b2c496773c1c5a45bfd3ffbf25eaa5969328e6469d36f31cf240a3
schema:
type: string
Content-Disposition:
description: The filename when a SaveImage node is specified.
example: filename=ComfyUI_00001.png
schema:
type: string
description: |
The binary content of the last SaveImage node.
content:
text/uri-list:
schema:
description: |
The URI to retrieve the binary content of the image.
This will return two URLs. The first is the ordinary ComfyUI view image URL that exactly corresponds
to the UI call. The second is the URL that corresponds to sha256 hash of the request body.
Hashing function for web browsers:
```js
async function generateHash(body) {
// Stringify and sort keys in the JSON object
let str = JSON.stringify(body);
// Encode the string as a Uint8Array
let encoder = new TextEncoder();
let data = encoder.encode(str);
// Create a SHA-256 hash of the data
let hash = await window.crypto.subtle.digest('SHA-256', data);
// Convert the hash (which is an ArrayBuffer) to a hex string
let hashArray = Array.from(new Uint8Array(hash));
let hashHex = hashArray.map(b => b.toString(16).padStart(2, '0')).join('');
return hashHex;
}
```
Hashing function for nodejs:
```js
const crypto = require('crypto');
function generateHash(body) {
// Stringify and sort keys in the JSON object
let str = JSON.stringify(body);
// Create a SHA-256 hash of the string
let hash = crypto.createHash('sha256');
hash.update(str);
// Return the hexadecimal representation of the hash
return hash.digest('hex');
}
```
Hashing function for python:
```python
def digest(data: dict | str) -> str:
json_str = data if isinstance(data, str) else json.dumps(data)
json_bytes = json_str.encode('utf-8')
hash_object = hashlib.sha256(json_bytes)
return hash_object.hexdigest()
```
type: string
example: |
/api/v1/images/e5187160a7b2c496773c1c5a45bfd3ffbf25eaa5969328e6469d36f31cf240a3
http://127.0.0.1:8188/view?filename=ComfyUI_00001_.png&type=output
image/png:
schema:
description: The PNG binary content.
type: string
format: binary
204:
@ -398,6 +503,92 @@ components:
$ref: "#/components/schemas/Workflow"
Prompt:
type: object
example: {
"3": {
"class_type": "KSampler",
"inputs": {
"cfg": 8,
"denoise": 1,
"latent_image": [
"5",
0
],
"model": [
"4",
0
],
"negative": [
"7",
0
],
"positive": [
"6",
0
],
"sampler_name": "euler",
"scheduler": "normal",
"seed": 8566257,
"steps": 20
}
},
"4": {
"class_type": "CheckpointLoaderSimple",
"inputs": {
"ckpt_name": "v1-5-pruned-emaonly.ckpt"
}
},
"5": {
"class_type": "EmptyLatentImage",
"inputs": {
"batch_size": 1,
"height": 512,
"width": 512
}
},
"6": {
"class_type": "CLIPTextEncode",
"inputs": {
"clip": [
"4",
1
],
"text": "masterpiece best quality girl"
}
},
"7": {
"class_type": "CLIPTextEncode",
"inputs": {
"clip": [
"4",
1
],
"text": "bad hands"
}
},
"8": {
"class_type": "VAEDecode",
"inputs": {
"samples": [
"3",
0
],
"vae": [
"4",
2
]
}
},
"9": {
"class_type": "SaveImage",
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"8",
0
]
}
}
}
description: |
The keys are stringified integers corresponding to nodes.

9
comfy/digest.py Normal file
View File

@ -0,0 +1,9 @@
import hashlib
import json
def digest(data: dict | str) -> str:
json_str = data if isinstance(data, str) else json.dumps(data)
hash_object = hashlib.sha256()
hash_object.update(json_str.encode())
return hash_object.hexdigest()

View File

@ -166,6 +166,8 @@ if __name__ == "__main__":
webbrowser.open(f"http://{address}:{port}")
call_on_start = startup_server
server.address = args.listen
server.port = args.port
try:
loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
except KeyboardInterrupt:

109
server.py
View File

@ -6,7 +6,6 @@ from PIL import Image, ImageOps
from io import BytesIO
import json
import mimetypes
import os
import uuid
from asyncio import Future
@ -20,21 +19,26 @@ import execution
import folder_paths
import nodes
import mimetypes
from comfy.digest import digest
from comfy.cli_args import args
import comfy.utils
import comfy.model_management
from comfy.vendor.appdirs import user_data_dir
class BinaryEventTypes:
PREVIEW_IMAGE = 1
UNENCODED_PREVIEW_IMAGE = 2
async def send_socket_catch_exception(function, message):
try:
await function(message)
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err:
print("send error:", err)
@web.middleware
async def cache_control(request: web.Request, handler):
response: web.Response = await handler(request)
@ -42,6 +46,7 @@ async def cache_control(request: web.Request, handler):
response.headers.setdefault('Cache-Control', 'no-cache')
return response
def create_cors_middleware(allowed_origin: str):
@web.middleware
async def cors_middleware(request: web.Request, handler):
@ -59,8 +64,11 @@ def create_cors_middleware(allowed_origin: str):
return cors_middleware
class PromptServer():
prompt_queue: execution.PromptQueue | None
address: str
port: int
def __init__(self, loop):
PromptServer.instance = self
@ -76,7 +84,8 @@ class PromptServer():
if args.enable_cors_header:
middlewares.append(create_cors_middleware(args.enable_cors_header))
self.app = web.Application(client_max_size=20971520, middlewares=middlewares)
self.app = web.Application(client_max_size=20971520, handler_args={'max_field_size': 16380},
middlewares=middlewares)
self.sockets = dict()
self.web_root = os.path.join(os.path.dirname(
os.path.realpath(__file__)), "web")
@ -100,10 +109,10 @@ class PromptServer():
try:
# Send initial state to the new client
await self.send("status", { "status": self.get_queue_info(), 'sid': sid }, sid)
await self.send("status", {"status": self.get_queue_info(), 'sid': sid}, sid)
# On reconnect if we are the currently executing client send the current node
if self.client_id == sid and self.last_node_id is not None:
await self.send("executing", { "node": self.last_node_id }, sid)
await self.send("executing", {"node": self.last_node_id}, sid)
async for msg in ws:
if msg.type == aiohttp.WSMsgType.ERROR:
@ -124,7 +133,8 @@ class PromptServer():
@routes.get("/extensions")
async def get_extensions(request):
files = glob.glob(os.path.join(self.web_root, 'extensions/**/*.js'), recursive=True)
return web.json_response(list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files)))
return web.json_response(
list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files)))
def get_dir_by_type(dir_type=None):
type_dir = ""
@ -179,7 +189,7 @@ class PromptServer():
async with aiofiles.open(filepath, mode='wb') as file:
await file.write(image.file.read())
return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type})
return web.json_response({"name": filename, "subfolder": subfolder, "type": image_upload_type})
else:
return web.Response(status=400)
@ -188,7 +198,6 @@ class PromptServer():
post = await request.post()
return image_upload(post)
@routes.post("/upload/mask")
async def upload_mask(request):
post = await request.post()
@ -232,7 +241,7 @@ class PromptServer():
async def view_image(request):
if "filename" in request.rel_url.query:
filename = request.rel_url.query["filename"]
filename,output_dir = folder_paths.annotated_filepath(filename)
filename, output_dir = folder_paths.annotated_filepath(filename)
# validation for security: prevent accessing arbitrary path
if filename[0] == '/' or '..' in filename:
@ -330,7 +339,7 @@ class PromptServer():
safetensors_path = folder_paths.get_full_path(folder_name, filename)
if safetensors_path is None:
return web.Response(status=404)
out = comfy.utils.safetensors_header(safetensors_path, max_size=1024*1024)
out = comfy.utils.safetensors_header(safetensors_path, max_size=1024 * 1024)
if out is None:
return web.Response(status=404)
dt = json.loads(out)
@ -373,10 +382,13 @@ class PromptServer():
info = {}
info['input'] = obj_class.INPUT_TYPES()
info['output'] = obj_class.RETURN_TYPES
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [
False] * len(
obj_class.RETURN_TYPES)
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
info['name'] = node_class
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[
node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
info['description'] = ''
info['category'] = 'sd'
if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True:
@ -425,7 +437,7 @@ class PromptServer():
print("got prompt")
resp_code = 200
out_string = ""
json_data = await request.json()
json_data = await request.json()
if "number" in json_data:
number = float(json_data['number'])
@ -449,7 +461,9 @@ class PromptServer():
if valid[0]:
prompt_id = str(uuid.uuid4())
outputs_to_execute = valid[2]
self.prompt_queue.put(execution.QueueItem(queue_tuple=(number, prompt_id, prompt, extra_data, outputs_to_execute), completed=None))
self.prompt_queue.put(
execution.QueueItem(queue_tuple=(number, prompt_id, prompt, extra_data, outputs_to_execute),
completed=None))
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
return web.json_response(response)
else:
@ -460,7 +474,7 @@ class PromptServer():
@routes.post("/queue")
async def post_queue(request):
json_data = await request.json()
json_data = await request.json()
if "clear" in json_data:
if json_data["clear"]:
self.prompt_queue.wipe_queue()
@ -479,7 +493,7 @@ class PromptServer():
@routes.post("/history")
async def post_history(request):
json_data = await request.json()
json_data = await request.json()
if "clear" in json_data:
if json_data["clear"]:
self.prompt_queue.wipe_history()
@ -490,13 +504,21 @@ class PromptServer():
return web.Response(status=200)
@routes.get("/api/v1/images/{content_digest}")
async def get_image(request: web.Request) -> web.FileResponse:
digest_ = request.match_info['content_digest']
path = os.path.join(user_data_dir("comfyui", "comfyanonymous", roaming=False), digest_)
return web.FileResponse(path,
headers={"Content-Disposition": f"filename=\"{digest_}.png\""})
@routes.post("/api/v1/prompts")
async def post_prompt(request: web.Request) -> web.Response | web.FileResponse:
# check if the queue is too long
queue_size = self.prompt_queue.size()
queue_too_busy_size = PromptServer.get_too_busy_queue_size()
if queue_size > queue_too_busy_size:
return web.Response(status=429, reason=f"the queue has {queue_size} elements and {queue_too_busy_size} is the limit for this worker")
return web.Response(status=429,
reason=f"the queue has {queue_size} elements and {queue_too_busy_size} is the limit for this worker")
# read the request
upload_dir = PromptServer.get_upload_dir()
prompt_dict: dict = {}
@ -525,9 +547,15 @@ class PromptServer():
if len(prompt_dict) == 0:
return web.Response(status=400, reason="no prompt was specified")
valid, error_message = execution.validate_prompt(prompt_dict)
if not valid:
return web.Response(status=400, body=error_message)
valid = execution.validate_prompt(prompt_dict)
if not valid[0]:
return web.Response(status=400, body=valid[1])
content_digest = digest(prompt_dict)
cache_path = os.path.join(user_data_dir("comfyui", "comfyanonymous", roaming=False), content_digest)
if os.path.exists(cache_path):
return web.FileResponse(path=cache_path,
headers={"Content-Disposition": f"filename=\"{content_digest}.png\""})
# todo: check that the files specified in the InputFile nodes exist
@ -535,7 +563,9 @@ class PromptServer():
completed: Future = self.loop.create_future()
number = self.number
self.number += 1
self.prompt_queue.put(execution.QueueItem(queue_tuple=(number, id(prompt_dict), prompt_dict, dict()), completed=completed))
self.prompt_queue.put(
execution.QueueItem(queue_tuple=(number, str(uuid.uuid4()), prompt_dict, {}, valid[2]),
completed=completed))
try:
await completed
@ -547,15 +577,36 @@ class PromptServer():
output_images: List[str] = []
for node_id, node in outputs_dict.items():
if isinstance(node, dict) and 'ui' in node and isinstance(node['ui'], dict) and 'images' in node['ui']:
for image_tuple in node['ui']['images']:
subfolder_ = image_tuple['subfolder']
filename_ = image_tuple['filename']
output_images.append(PromptServer.get_output_path(subfolder=subfolder_, filename=filename_))
images: List[dict] = []
if 'images' in node:
images = node['images']
elif isinstance(node, dict) and 'ui' in node and isinstance(node['ui'], dict) and 'images' in node[
'ui']:
images = node['ui']['images']
for image_tuple in images:
subfolder_ = image_tuple['subfolder']
filename_ = image_tuple['filename']
output_images.append(PromptServer.get_output_path(subfolder=subfolder_, filename=filename_))
if len(output_images) > 0:
image_ = output_images[-1]
return web.FileResponse(path=image_, headers={"Content-Disposition": f"filename=\"{os.path.basename(image_)}\""})
if not os.path.exists(os.path.dirname(cache_path)):
os.makedirs(os.path.dirname(cache_path))
os.symlink(image_, cache_path)
cache_url = "/api/v1/images/{content_digest}"
filename = os.path.basename(image_)
if 'Accept' in request.headers and request.headers['Accept'] == 'text/uri-list':
res = web.Response(status=200, text=f"""
{cache_url}
http://{self.address}:{self.port}/view?filename={filename}&type=output
""")
else:
res = web.FileResponse(path=image_,
headers={
"Digest": f"SHA-256={content_digest}",
"Location": f"/api/v1/images/{content_digest}",
"Content-Disposition": f"filename=\"{filename}\""})
return res
else:
return web.Response(status=204)
@ -569,7 +620,9 @@ class PromptServer():
# argmax
def _history_item_timestamp(i: int):
return history_items[i]['timestamp']
last_history_item: execution.HistoryEntry = history_items[max(range(len(history_items)), key=_history_item_timestamp)]
last_history_item: execution.HistoryEntry = history_items[
max(range(len(history_items)), key=_history_item_timestamp)]
prompt = last_history_item['prompt'][2]
return web.json_response(prompt, status=200)
@ -650,7 +703,7 @@ class PromptServer():
self.messages.put_nowait, (event, data, sid))
def queue_updated(self):
self.send_sync("status", { "status": self.get_queue_info() })
self.send_sync("status", {"status": self.get_queue_info()})
async def publish_loop(self):
while True: