mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-16 16:32:34 +08:00
Precalculate object_info on startup
This commit is contained in:
parent
67892b5ac5
commit
aedb246816
3
main.py
3
main.py
@ -26,6 +26,7 @@ import yaml
|
||||
import execution
|
||||
import folder_paths
|
||||
import server
|
||||
import nodes
|
||||
from nodes import init_custom_nodes
|
||||
|
||||
|
||||
@ -89,6 +90,8 @@ if __name__ == "__main__":
|
||||
server.add_routes()
|
||||
hijack_progress(server)
|
||||
|
||||
server.load_node_info()
|
||||
|
||||
threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start()
|
||||
|
||||
if args.output_directory:
|
||||
|
||||
22
nodes.py
22
nodes.py
@ -1369,6 +1369,28 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"VAEEncodeTiled": "VAE Encode (Tiled)",
|
||||
}
|
||||
|
||||
def get_node_info(node_class):
|
||||
global NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
|
||||
|
||||
obj_class = NODE_CLASS_MAPPINGS[node_class]
|
||||
info = {}
|
||||
info['input'] = obj_class.INPUT_TYPES()
|
||||
info['output'] = obj_class.RETURN_TYPES
|
||||
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
|
||||
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
|
||||
info['name'] = node_class
|
||||
info['display_name'] = NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
|
||||
info['description'] = ''
|
||||
info['category'] = 'sd'
|
||||
if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True:
|
||||
info['output_node'] = True
|
||||
else:
|
||||
info['output_node'] = False
|
||||
|
||||
if hasattr(obj_class, 'CATEGORY'):
|
||||
info['category'] = obj_class.CATEGORY
|
||||
return info
|
||||
|
||||
def load_custom_node(module_path):
|
||||
module_name = os.path.basename(module_path)
|
||||
if os.path.isfile(module_path):
|
||||
|
||||
49
server.py
49
server.py
@ -7,6 +7,7 @@ import execution
|
||||
import uuid
|
||||
import json
|
||||
import glob
|
||||
import time
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
|
||||
@ -72,6 +73,8 @@ class PromptServer():
|
||||
self.routes = routes
|
||||
self.last_node_id = None
|
||||
self.client_id = None
|
||||
self.session_id = uuid.uuid4().hex
|
||||
self.cached_object_info = {}
|
||||
|
||||
@routes.get('/ws')
|
||||
async def websocket_handler(request):
|
||||
@ -306,39 +309,24 @@ class PromptServer():
|
||||
async def get_prompt(request):
|
||||
return web.json_response(self.get_queue_info())
|
||||
|
||||
def node_info(node_class):
|
||||
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
|
||||
info = {}
|
||||
info['input'] = obj_class.INPUT_TYPES()
|
||||
info['output'] = obj_class.RETURN_TYPES
|
||||
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
|
||||
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
|
||||
info['name'] = node_class
|
||||
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
|
||||
info['description'] = ''
|
||||
info['category'] = 'sd'
|
||||
if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True:
|
||||
info['output_node'] = True
|
||||
else:
|
||||
info['output_node'] = False
|
||||
|
||||
if hasattr(obj_class, 'CATEGORY'):
|
||||
info['category'] = obj_class.CATEGORY
|
||||
return info
|
||||
|
||||
@routes.get("/object_info")
|
||||
async def get_object_info(request):
|
||||
out = {}
|
||||
for x in nodes.NODE_CLASS_MAPPINGS:
|
||||
out[x] = node_info(x)
|
||||
out = {
|
||||
"object_info": self.cached_object_info,
|
||||
"session_id": self.session_id
|
||||
}
|
||||
return web.json_response(out)
|
||||
|
||||
@routes.get("/object_info/{node_class}")
|
||||
async def get_object_info_node(request):
|
||||
node_class = request.match_info.get("node_class", None)
|
||||
out = {}
|
||||
objinfo = {}
|
||||
if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS):
|
||||
out[node_class] = node_info(node_class)
|
||||
objinfo[node_class] = node_info(node_class)
|
||||
out = {
|
||||
"object_info": objinfo,
|
||||
"session_id": self.session_id
|
||||
}
|
||||
return web.json_response(out)
|
||||
|
||||
@routes.get("/history")
|
||||
@ -421,7 +409,16 @@ class PromptServer():
|
||||
self.prompt_queue.delete_history_item(id_to_delete)
|
||||
|
||||
return web.Response(status=200)
|
||||
|
||||
|
||||
def load_node_info(self):
|
||||
# Precalculate node info to save on request time
|
||||
start_time = time.time()
|
||||
objinfo = {}
|
||||
for x in nodes.NODE_CLASS_MAPPINGS:
|
||||
objinfo[x] = nodes.get_node_info(x)
|
||||
self.cached_object_info = objinfo
|
||||
print(f"Loaded {len(objinfo)} node definitions in {time.time() - start_time} seconds")
|
||||
|
||||
def add_routes(self):
|
||||
self.app.add_routes(self.routes)
|
||||
self.app.add_routes([
|
||||
|
||||
@ -138,7 +138,8 @@ class ComfyApi extends EventTarget {
|
||||
*/
|
||||
async getNodeDefs() {
|
||||
const resp = await fetch("object_info", { cache: "no-store" });
|
||||
return await resp.json();
|
||||
const json = await resp.json();
|
||||
return json["object_info"];
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Loading…
Reference in New Issue
Block a user