mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-17 00:43:48 +08:00
Precalculate object_info on startup
This commit is contained in:
parent
67892b5ac5
commit
aedb246816
3
main.py
3
main.py
@ -26,6 +26,7 @@ import yaml
|
|||||||
import execution
|
import execution
|
||||||
import folder_paths
|
import folder_paths
|
||||||
import server
|
import server
|
||||||
|
import nodes
|
||||||
from nodes import init_custom_nodes
|
from nodes import init_custom_nodes
|
||||||
|
|
||||||
|
|
||||||
@ -89,6 +90,8 @@ if __name__ == "__main__":
|
|||||||
server.add_routes()
|
server.add_routes()
|
||||||
hijack_progress(server)
|
hijack_progress(server)
|
||||||
|
|
||||||
|
server.load_node_info()
|
||||||
|
|
||||||
threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start()
|
threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start()
|
||||||
|
|
||||||
if args.output_directory:
|
if args.output_directory:
|
||||||
|
|||||||
22
nodes.py
22
nodes.py
@ -1369,6 +1369,28 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"VAEEncodeTiled": "VAE Encode (Tiled)",
|
"VAEEncodeTiled": "VAE Encode (Tiled)",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def get_node_info(node_class):
|
||||||
|
global NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
|
||||||
|
|
||||||
|
obj_class = NODE_CLASS_MAPPINGS[node_class]
|
||||||
|
info = {}
|
||||||
|
info['input'] = obj_class.INPUT_TYPES()
|
||||||
|
info['output'] = obj_class.RETURN_TYPES
|
||||||
|
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
|
||||||
|
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
|
||||||
|
info['name'] = node_class
|
||||||
|
info['display_name'] = NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
|
||||||
|
info['description'] = ''
|
||||||
|
info['category'] = 'sd'
|
||||||
|
if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True:
|
||||||
|
info['output_node'] = True
|
||||||
|
else:
|
||||||
|
info['output_node'] = False
|
||||||
|
|
||||||
|
if hasattr(obj_class, 'CATEGORY'):
|
||||||
|
info['category'] = obj_class.CATEGORY
|
||||||
|
return info
|
||||||
|
|
||||||
def load_custom_node(module_path):
|
def load_custom_node(module_path):
|
||||||
module_name = os.path.basename(module_path)
|
module_name = os.path.basename(module_path)
|
||||||
if os.path.isfile(module_path):
|
if os.path.isfile(module_path):
|
||||||
|
|||||||
47
server.py
47
server.py
@ -7,6 +7,7 @@ import execution
|
|||||||
import uuid
|
import uuid
|
||||||
import json
|
import json
|
||||||
import glob
|
import glob
|
||||||
|
import time
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
@ -72,6 +73,8 @@ class PromptServer():
|
|||||||
self.routes = routes
|
self.routes = routes
|
||||||
self.last_node_id = None
|
self.last_node_id = None
|
||||||
self.client_id = None
|
self.client_id = None
|
||||||
|
self.session_id = uuid.uuid4().hex
|
||||||
|
self.cached_object_info = {}
|
||||||
|
|
||||||
@routes.get('/ws')
|
@routes.get('/ws')
|
||||||
async def websocket_handler(request):
|
async def websocket_handler(request):
|
||||||
@ -306,39 +309,24 @@ class PromptServer():
|
|||||||
async def get_prompt(request):
|
async def get_prompt(request):
|
||||||
return web.json_response(self.get_queue_info())
|
return web.json_response(self.get_queue_info())
|
||||||
|
|
||||||
def node_info(node_class):
|
|
||||||
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
|
|
||||||
info = {}
|
|
||||||
info['input'] = obj_class.INPUT_TYPES()
|
|
||||||
info['output'] = obj_class.RETURN_TYPES
|
|
||||||
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
|
|
||||||
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
|
|
||||||
info['name'] = node_class
|
|
||||||
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
|
|
||||||
info['description'] = ''
|
|
||||||
info['category'] = 'sd'
|
|
||||||
if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True:
|
|
||||||
info['output_node'] = True
|
|
||||||
else:
|
|
||||||
info['output_node'] = False
|
|
||||||
|
|
||||||
if hasattr(obj_class, 'CATEGORY'):
|
|
||||||
info['category'] = obj_class.CATEGORY
|
|
||||||
return info
|
|
||||||
|
|
||||||
@routes.get("/object_info")
|
@routes.get("/object_info")
|
||||||
async def get_object_info(request):
|
async def get_object_info(request):
|
||||||
out = {}
|
out = {
|
||||||
for x in nodes.NODE_CLASS_MAPPINGS:
|
"object_info": self.cached_object_info,
|
||||||
out[x] = node_info(x)
|
"session_id": self.session_id
|
||||||
|
}
|
||||||
return web.json_response(out)
|
return web.json_response(out)
|
||||||
|
|
||||||
@routes.get("/object_info/{node_class}")
|
@routes.get("/object_info/{node_class}")
|
||||||
async def get_object_info_node(request):
|
async def get_object_info_node(request):
|
||||||
node_class = request.match_info.get("node_class", None)
|
node_class = request.match_info.get("node_class", None)
|
||||||
out = {}
|
objinfo = {}
|
||||||
if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS):
|
if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS):
|
||||||
out[node_class] = node_info(node_class)
|
objinfo[node_class] = node_info(node_class)
|
||||||
|
out = {
|
||||||
|
"object_info": objinfo,
|
||||||
|
"session_id": self.session_id
|
||||||
|
}
|
||||||
return web.json_response(out)
|
return web.json_response(out)
|
||||||
|
|
||||||
@routes.get("/history")
|
@routes.get("/history")
|
||||||
@ -422,6 +410,15 @@ class PromptServer():
|
|||||||
|
|
||||||
return web.Response(status=200)
|
return web.Response(status=200)
|
||||||
|
|
||||||
|
def load_node_info(self):
|
||||||
|
# Precalculate node info to save on request time
|
||||||
|
start_time = time.time()
|
||||||
|
objinfo = {}
|
||||||
|
for x in nodes.NODE_CLASS_MAPPINGS:
|
||||||
|
objinfo[x] = nodes.get_node_info(x)
|
||||||
|
self.cached_object_info = objinfo
|
||||||
|
print(f"Loaded {len(objinfo)} node definitions in {time.time() - start_time} seconds")
|
||||||
|
|
||||||
def add_routes(self):
|
def add_routes(self):
|
||||||
self.app.add_routes(self.routes)
|
self.app.add_routes(self.routes)
|
||||||
self.app.add_routes([
|
self.app.add_routes([
|
||||||
|
|||||||
@ -138,7 +138,8 @@ class ComfyApi extends EventTarget {
|
|||||||
*/
|
*/
|
||||||
async getNodeDefs() {
|
async getNodeDefs() {
|
||||||
const resp = await fetch("object_info", { cache: "no-store" });
|
const resp = await fetch("object_info", { cache: "no-store" });
|
||||||
return await resp.json();
|
const json = await resp.json();
|
||||||
|
return json["object_info"];
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user