Precalculate object_info on startup

This commit is contained in:
space-nuko 2023-06-02 15:01:31 -05:00
parent 67892b5ac5
commit aedb246816
4 changed files with 50 additions and 27 deletions

View File

@ -26,6 +26,7 @@ import yaml
import execution import execution
import folder_paths import folder_paths
import server import server
import nodes
from nodes import init_custom_nodes from nodes import init_custom_nodes
@ -89,6 +90,8 @@ if __name__ == "__main__":
server.add_routes() server.add_routes()
hijack_progress(server) hijack_progress(server)
server.load_node_info()
threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start() threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start()
if args.output_directory: if args.output_directory:

View File

@ -1369,6 +1369,28 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"VAEEncodeTiled": "VAE Encode (Tiled)", "VAEEncodeTiled": "VAE Encode (Tiled)",
} }
def get_node_info(node_class):
global NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
obj_class = NODE_CLASS_MAPPINGS[node_class]
info = {}
info['input'] = obj_class.INPUT_TYPES()
info['output'] = obj_class.RETURN_TYPES
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
info['name'] = node_class
info['display_name'] = NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
info['description'] = ''
info['category'] = 'sd'
if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True:
info['output_node'] = True
else:
info['output_node'] = False
if hasattr(obj_class, 'CATEGORY'):
info['category'] = obj_class.CATEGORY
return info
def load_custom_node(module_path): def load_custom_node(module_path):
module_name = os.path.basename(module_path) module_name = os.path.basename(module_path)
if os.path.isfile(module_path): if os.path.isfile(module_path):

View File

@ -7,6 +7,7 @@ import execution
import uuid import uuid
import json import json
import glob import glob
import time
from PIL import Image from PIL import Image
from io import BytesIO from io import BytesIO
@ -72,6 +73,8 @@ class PromptServer():
self.routes = routes self.routes = routes
self.last_node_id = None self.last_node_id = None
self.client_id = None self.client_id = None
self.session_id = uuid.uuid4().hex
self.cached_object_info = {}
@routes.get('/ws') @routes.get('/ws')
async def websocket_handler(request): async def websocket_handler(request):
@ -306,39 +309,24 @@ class PromptServer():
async def get_prompt(request): async def get_prompt(request):
return web.json_response(self.get_queue_info()) return web.json_response(self.get_queue_info())
def node_info(node_class):
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
info = {}
info['input'] = obj_class.INPUT_TYPES()
info['output'] = obj_class.RETURN_TYPES
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
info['name'] = node_class
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
info['description'] = ''
info['category'] = 'sd'
if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True:
info['output_node'] = True
else:
info['output_node'] = False
if hasattr(obj_class, 'CATEGORY'):
info['category'] = obj_class.CATEGORY
return info
@routes.get("/object_info") @routes.get("/object_info")
async def get_object_info(request): async def get_object_info(request):
out = {} out = {
for x in nodes.NODE_CLASS_MAPPINGS: "object_info": self.cached_object_info,
out[x] = node_info(x) "session_id": self.session_id
}
return web.json_response(out) return web.json_response(out)
@routes.get("/object_info/{node_class}") @routes.get("/object_info/{node_class}")
async def get_object_info_node(request): async def get_object_info_node(request):
node_class = request.match_info.get("node_class", None) node_class = request.match_info.get("node_class", None)
out = {} objinfo = {}
if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS): if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS):
out[node_class] = node_info(node_class) objinfo[node_class] = node_info(node_class)
out = {
"object_info": objinfo,
"session_id": self.session_id
}
return web.json_response(out) return web.json_response(out)
@routes.get("/history") @routes.get("/history")
@ -422,6 +410,15 @@ class PromptServer():
return web.Response(status=200) return web.Response(status=200)
def load_node_info(self):
# Precalculate node info to save on request time
start_time = time.time()
objinfo = {}
for x in nodes.NODE_CLASS_MAPPINGS:
objinfo[x] = nodes.get_node_info(x)
self.cached_object_info = objinfo
print(f"Loaded {len(objinfo)} node definitions in {time.time() - start_time} seconds")
def add_routes(self): def add_routes(self):
self.app.add_routes(self.routes) self.app.add_routes(self.routes)
self.app.add_routes([ self.app.add_routes([

View File

@ -138,7 +138,8 @@ class ComfyApi extends EventTarget {
*/ */
async getNodeDefs() { async getNodeDefs() {
const resp = await fetch("object_info", { cache: "no-store" }); const resp = await fetch("object_info", { cache: "no-store" });
return await resp.json(); const json = await resp.json();
return json["object_info"];
} }
/** /**