From 4b8b12cf5903e9530aa66de5b192cfe82add4b65 Mon Sep 17 00:00:00 2001 From: Hacker 17082006 Date: Wed, 15 Mar 2023 17:44:47 +0700 Subject: [PATCH] Custom endpoint is guz --- ...xample_node.py.example => example_node.py} | 17 ++++++++- server.py | 35 ++++++++++++++++++- 2 files changed, 50 insertions(+), 2 deletions(-) rename custom_nodes/{example_node.py.example => example_node.py} (86%) diff --git a/custom_nodes/example_node.py.example b/custom_nodes/example_node.py similarity index 86% rename from custom_nodes/example_node.py.example rename to custom_nodes/example_node.py index 1bb1a5a37..6037c31ad 100644 --- a/custom_nodes/example_node.py.example +++ b/custom_nodes/example_node.py @@ -1,3 +1,4 @@ +from aiohttp import web class Example: """ A example node @@ -78,9 +79,23 @@ class Example: image = 1.0 - image return (image,) - # A dictionary that contains all nodes you want to export with their names # NOTE: names should be globally unique NODE_CLASS_MAPPINGS = { "Example": Example } + +class CustomEndpoint: + routes = None + def __init__(self): + self.routes = web.RouteTableDef() + + @self.routes.get("/test") + async def get(request): + return web.Response(text="Hello World! This a test endpoint in example_node.py") + + @self.routes.post("/test") + async def post(request): + text_data = request.text() + return web.Response(text=f"Hello World! This a test endpoint in example_node.py\n\nYour request body: {text_data}") + diff --git a/server.py b/server.py index eb6857010..68eef9836 100644 --- a/server.py +++ b/server.py @@ -17,7 +17,8 @@ except ImportError: sys.exit() import mimetypes - +import importlib +import traceback @web.middleware async def cache_control(request: web.Request, handler): @@ -229,6 +230,38 @@ class PromptServer(): web.static('/', self.web_root), ]) + def load_custom_endpoint(module_path): + module_name = os.path.basename(module_path) + if os.path.isfile(module_path): + sp = os.path.splitext(module_path) + module_name = sp[0] + try: + if os.path.isfile(module_path): + module_spec = importlib.util.spec_from_file_location(module_name, module_path) + else: + module_spec = importlib.util.spec_from_file_location(module_name, os.path.join(module_path, "__init__.py")) + module = importlib.util.module_from_spec(module_spec) + sys.modules[module_name] = module + module_spec.loader.exec_module(module) + if hasattr(module, "CustomEndpoint") and getattr(module, "CustomEndpoint") is not None: + self.app.add_routes(module.CustomEndpoint().routes) + except Exception as e: + print(traceback.format_exc()) + print(f"Cannot import {module_path} module for custom endpoints:", e) + + def load_custom_endpoints(): + CUSTOM_NODE_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_nodes") + possible_modules = os.listdir(CUSTOM_NODE_PATH) + if "__pycache__" in possible_modules: + possible_modules.remove("__pycache__") + + for possible_module in possible_modules: + module_path = os.path.join(CUSTOM_NODE_PATH, possible_module) + if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue + load_custom_endpoint(module_path) + + load_custom_endpoints() + def get_queue_info(self): prompt_info = {} exec_info = {}