From e6fbc2bedf76848c9ff3da4b5b506bb380a41c8f Mon Sep 17 00:00:00 2001 From: Hacker 17082006 Date: Wed, 15 Mar 2023 22:28:32 +0700 Subject: [PATCH] I want decor --- .../{example_node.py => example_node.py.example} | 13 +++++-------- server.py | 14 ++++++-------- 2 files changed, 11 insertions(+), 16 deletions(-) rename custom_nodes/{example_node.py => example_node.py.example} (94%) diff --git a/custom_nodes/example_node.py b/custom_nodes/example_node.py.example similarity index 94% rename from custom_nodes/example_node.py rename to custom_nodes/example_node.py.example index 6037c31ad..24aaada2e 100644 --- a/custom_nodes/example_node.py +++ b/custom_nodes/example_node.py.example @@ -85,17 +85,14 @@ NODE_CLASS_MAPPINGS = { "Example": Example } -class CustomEndpoint: - routes = None - def __init__(self): - self.routes = web.RouteTableDef() - - @self.routes.get("/test") +class PatchRoutes: + def __init__(self, routes): + @routes.get("/test") async def get(request): return web.Response(text="Hello World! This a test endpoint in example_node.py") - @self.routes.post("/test") + @routes.post("/test") async def post(request): - text_data = request.text() + text_data = await request.text() return web.Response(text=f"Hello World! This a test endpoint in example_node.py\n\nYour request body: {text_data}") diff --git a/server.py b/server.py index 68eef9836..46587824a 100644 --- a/server.py +++ b/server.py @@ -225,11 +225,6 @@ class PromptServer(): return web.Response(status=200) - self.app.add_routes(routes) - self.app.add_routes([ - web.static('/', self.web_root), - ]) - def load_custom_endpoint(module_path): module_name = os.path.basename(module_path) if os.path.isfile(module_path): @@ -241,10 +236,9 @@ class PromptServer(): else: module_spec = importlib.util.spec_from_file_location(module_name, os.path.join(module_path, "__init__.py")) module = importlib.util.module_from_spec(module_spec) - sys.modules[module_name] = module module_spec.loader.exec_module(module) - if hasattr(module, "CustomEndpoint") and getattr(module, "CustomEndpoint") is not None: - self.app.add_routes(module.CustomEndpoint().routes) + if hasattr(module, "PatchRoutes") and getattr(module, "PatchRoutes") is not None: + module.PatchRoutes(routes) except Exception as e: print(traceback.format_exc()) print(f"Cannot import {module_path} module for custom endpoints:", e) @@ -261,6 +255,10 @@ class PromptServer(): load_custom_endpoint(module_path) load_custom_endpoints() + self.app.add_routes(routes) + self.app.add_routes([ + web.static('/', self.web_root), + ]) def get_queue_info(self): prompt_info = {}