diff --git a/main.py b/main.py index 50d3b9a62..cf7847fe9 100644 --- a/main.py +++ b/main.py @@ -26,6 +26,7 @@ import yaml import execution import folder_paths import server +import nodes from nodes import init_custom_nodes @@ -89,6 +90,8 @@ if __name__ == "__main__": server.add_routes() hijack_progress(server) + server.load_node_info() + threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start() if args.output_directory: diff --git a/nodes.py b/nodes.py index 90444a92c..92c4d336b 100644 --- a/nodes.py +++ b/nodes.py @@ -1369,6 +1369,28 @@ NODE_DISPLAY_NAME_MAPPINGS = { "VAEEncodeTiled": "VAE Encode (Tiled)", } +def get_node_info(node_class): + global NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS + + obj_class = NODE_CLASS_MAPPINGS[node_class] + info = {} + info['input'] = obj_class.INPUT_TYPES() + info['output'] = obj_class.RETURN_TYPES + info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) + info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] + info['name'] = node_class + info['display_name'] = NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class + info['description'] = '' + info['category'] = 'sd' + if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True: + info['output_node'] = True + else: + info['output_node'] = False + + if hasattr(obj_class, 'CATEGORY'): + info['category'] = obj_class.CATEGORY + return info + def load_custom_node(module_path): module_name = os.path.basename(module_path) if os.path.isfile(module_path): diff --git a/server.py b/server.py index 5be822a6f..f3d18634c 100644 --- a/server.py +++ b/server.py @@ -7,6 +7,7 @@ import execution import uuid import json import glob +import time from PIL import Image from io import BytesIO @@ -72,6 +73,8 @@ class PromptServer(): self.routes = routes self.last_node_id = None self.client_id = None + self.session_id = uuid.uuid4().hex + self.cached_object_info = {} @routes.get('/ws') async def websocket_handler(request): @@ -306,39 +309,24 @@ class PromptServer(): async def get_prompt(request): return web.json_response(self.get_queue_info()) - def node_info(node_class): - obj_class = nodes.NODE_CLASS_MAPPINGS[node_class] - info = {} - info['input'] = obj_class.INPUT_TYPES() - info['output'] = obj_class.RETURN_TYPES - info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) - info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] - info['name'] = node_class - info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class - info['description'] = '' - info['category'] = 'sd' - if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True: - info['output_node'] = True - else: - info['output_node'] = False - - if hasattr(obj_class, 'CATEGORY'): - info['category'] = obj_class.CATEGORY - return info - @routes.get("/object_info") async def get_object_info(request): - out = {} - for x in nodes.NODE_CLASS_MAPPINGS: - out[x] = node_info(x) + out = { + "object_info": self.cached_object_info, + "session_id": self.session_id + } return web.json_response(out) @routes.get("/object_info/{node_class}") async def get_object_info_node(request): node_class = request.match_info.get("node_class", None) - out = {} + objinfo = {} if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS): - out[node_class] = node_info(node_class) + objinfo[node_class] = node_info(node_class) + out = { + "object_info": objinfo, + "session_id": self.session_id + } return web.json_response(out) @routes.get("/history") @@ -421,7 +409,16 @@ class PromptServer(): self.prompt_queue.delete_history_item(id_to_delete) return web.Response(status=200) - + + def load_node_info(self): + # Precalculate node info to save on request time + start_time = time.time() + objinfo = {} + for x in nodes.NODE_CLASS_MAPPINGS: + objinfo[x] = nodes.get_node_info(x) + self.cached_object_info = objinfo + print(f"Loaded {len(objinfo)} node definitions in {time.time() - start_time} seconds") + def add_routes(self): self.app.add_routes(self.routes) self.app.add_routes([ diff --git a/web/scripts/api.js b/web/scripts/api.js index 378165b3a..3de0be56a 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -138,7 +138,8 @@ class ComfyApi extends EventTarget { */ async getNodeDefs() { const resp = await fetch("object_info", { cache: "no-store" }); - return await resp.json(); + const json = await resp.json(); + return json["object_info"]; } /**