Precalculate object_info on startup

This commit is contained in:
space-nuko 2023-06-02 15:01:31 -05:00
parent 67892b5ac5
commit aedb246816
4 changed files with 50 additions and 27 deletions

View File

@ -26,6 +26,7 @@ import yaml
import execution
import folder_paths
import server
import nodes
from nodes import init_custom_nodes
@ -89,6 +90,8 @@ if __name__ == "__main__":
server.add_routes()
hijack_progress(server)
server.load_node_info()
threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start()
if args.output_directory:

View File

@ -1369,6 +1369,28 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"VAEEncodeTiled": "VAE Encode (Tiled)",
}
def get_node_info(node_class):
global NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
obj_class = NODE_CLASS_MAPPINGS[node_class]
info = {}
info['input'] = obj_class.INPUT_TYPES()
info['output'] = obj_class.RETURN_TYPES
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
info['name'] = node_class
info['display_name'] = NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
info['description'] = ''
info['category'] = 'sd'
if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True:
info['output_node'] = True
else:
info['output_node'] = False
if hasattr(obj_class, 'CATEGORY'):
info['category'] = obj_class.CATEGORY
return info
def load_custom_node(module_path):
module_name = os.path.basename(module_path)
if os.path.isfile(module_path):

View File

@ -7,6 +7,7 @@ import execution
import uuid
import json
import glob
import time
from PIL import Image
from io import BytesIO
@ -72,6 +73,8 @@ class PromptServer():
self.routes = routes
self.last_node_id = None
self.client_id = None
self.session_id = uuid.uuid4().hex
self.cached_object_info = {}
@routes.get('/ws')
async def websocket_handler(request):
@ -306,39 +309,24 @@ class PromptServer():
async def get_prompt(request):
return web.json_response(self.get_queue_info())
def node_info(node_class):
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
info = {}
info['input'] = obj_class.INPUT_TYPES()
info['output'] = obj_class.RETURN_TYPES
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
info['name'] = node_class
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
info['description'] = ''
info['category'] = 'sd'
if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True:
info['output_node'] = True
else:
info['output_node'] = False
if hasattr(obj_class, 'CATEGORY'):
info['category'] = obj_class.CATEGORY
return info
@routes.get("/object_info")
async def get_object_info(request):
out = {}
for x in nodes.NODE_CLASS_MAPPINGS:
out[x] = node_info(x)
out = {
"object_info": self.cached_object_info,
"session_id": self.session_id
}
return web.json_response(out)
@routes.get("/object_info/{node_class}")
async def get_object_info_node(request):
node_class = request.match_info.get("node_class", None)
out = {}
objinfo = {}
if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS):
out[node_class] = node_info(node_class)
objinfo[node_class] = node_info(node_class)
out = {
"object_info": objinfo,
"session_id": self.session_id
}
return web.json_response(out)
@routes.get("/history")
@ -421,7 +409,16 @@ class PromptServer():
self.prompt_queue.delete_history_item(id_to_delete)
return web.Response(status=200)
def load_node_info(self):
# Precalculate node info to save on request time
start_time = time.time()
objinfo = {}
for x in nodes.NODE_CLASS_MAPPINGS:
objinfo[x] = nodes.get_node_info(x)
self.cached_object_info = objinfo
print(f"Loaded {len(objinfo)} node definitions in {time.time() - start_time} seconds")
def add_routes(self):
self.app.add_routes(self.routes)
self.app.add_routes([

View File

@ -138,7 +138,8 @@ class ComfyApi extends EventTarget {
*/
async getNodeDefs() {
const resp = await fetch("object_info", { cache: "no-store" });
return await resp.json();
const json = await resp.json();
return json["object_info"];
}
/**