mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-07 18:27:40 +08:00
266 lines
8.9 KiB
Python
266 lines
8.9 KiB
Python
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,redefined-outer-name,reimported,super-init-not-called
|
|
"""Stateless RPC Implementation for PromptServer.
|
|
|
|
Replaces the legacy PromptServerProxy (Singleton) with a clean Service/Stub architecture.
|
|
- Host: PromptServerService (RPC Handler)
|
|
- Child: PromptServerStub (Interface Implementation)
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import os
|
|
from typing import Any, Dict, Optional, Callable
|
|
|
|
import logging
|
|
from aiohttp import web
|
|
|
|
# IMPORTS
|
|
from pyisolate import ProxiedSingleton
|
|
|
|
logger = logging.getLogger(__name__)
|
|
LOG_PREFIX = "[Isolation:C<->H]"
|
|
|
|
# ...
|
|
|
|
# =============================================================================
|
|
# CHILD SIDE: PromptServerStub
|
|
# =============================================================================
|
|
|
|
|
|
class PromptServerStub:
|
|
"""Stateless Stub for PromptServer."""
|
|
|
|
# Masquerade as the real server module
|
|
__module__ = "server"
|
|
|
|
_instance: Optional["PromptServerStub"] = None
|
|
_rpc: Optional[Any] = None # This will be the Caller object
|
|
_source_file: Optional[str] = None
|
|
|
|
def __init__(self):
|
|
self.routes = RouteStub(self)
|
|
|
|
@classmethod
|
|
def set_rpc(cls, rpc: Any) -> None:
|
|
"""Inject RPC client (called by adapter.py or manually)."""
|
|
# Create caller for HOST Service
|
|
# Assuming Host Service is registered as "PromptServerService" (class name)
|
|
# We target the Host Service Class
|
|
target_id = "PromptServerService"
|
|
# We need to pass a class to create_caller? Usually yes.
|
|
# But we don't have the Service class imported here necessarily (if running on child).
|
|
# pyisolate check verify_service type?
|
|
# If we pass PromptServerStub as the 'class', it might mismatch if checking types.
|
|
# But we can try passing PromptServerStub if it mirrors the service name? No, stub is PromptServerStub.
|
|
# We need a dummy class with right name?
|
|
# Or just rely on string ID if create_caller supports it?
|
|
# Standard: rpc.create_caller(PromptServerStub, target_id)
|
|
# But wait, PromptServerStub is the *Local* class.
|
|
# We want to call *Remote* class.
|
|
# If we use PromptServerStub as the type, returning object will be typed as PromptServerStub?
|
|
# The first arg is 'service_cls'.
|
|
cls._rpc = rpc.create_caller(
|
|
PromptServerService, target_id
|
|
) # We import Service below?
|
|
|
|
# We need PromptServerService available for the create_caller call?
|
|
# Or just use the Stub class if ID matches?
|
|
# prompt_server_impl.py defines BOTH. So PromptServerService IS available!
|
|
|
|
@property
|
|
def instance(self) -> "PromptServerStub":
|
|
return self
|
|
|
|
# ... Compatibility ...
|
|
@classmethod
|
|
def _get_source_file(cls) -> str:
|
|
if cls._source_file is None:
|
|
import folder_paths
|
|
|
|
cls._source_file = os.path.join(folder_paths.base_path, "server.py")
|
|
return cls._source_file
|
|
|
|
@property
|
|
def __file__(self) -> str:
|
|
return self._get_source_file()
|
|
|
|
# --- Properties ---
|
|
@property
|
|
def client_id(self) -> Optional[str]:
|
|
return "isolated_client"
|
|
|
|
def supports(self, feature: str) -> bool:
|
|
return True
|
|
|
|
@property
|
|
def app(self):
|
|
raise RuntimeError(
|
|
"PromptServer.app is not accessible in isolated nodes. Use RPC routes instead."
|
|
)
|
|
|
|
@property
|
|
def prompt_queue(self):
|
|
raise RuntimeError(
|
|
"PromptServer.prompt_queue is not accessible in isolated nodes."
|
|
)
|
|
|
|
# --- UI Communication (RPC Delegates) ---
|
|
async def send_sync(
|
|
self, event: str, data: Dict[str, Any], sid: Optional[str] = None
|
|
) -> None:
|
|
if self._rpc:
|
|
await self._rpc.ui_send_sync(event, data, sid)
|
|
|
|
async def send(
|
|
self, event: str, data: Dict[str, Any], sid: Optional[str] = None
|
|
) -> None:
|
|
if self._rpc:
|
|
await self._rpc.ui_send(event, data, sid)
|
|
|
|
def send_progress_text(self, text: str, node_id: str, sid=None) -> None:
|
|
if self._rpc:
|
|
# Fire and forget likely needed. If method is async on host, caller invocation returns coroutine.
|
|
# We must schedule it?
|
|
# Or use fire_remote equivalent?
|
|
# Caller object usually proxies calls. If host method is async, it returns coro.
|
|
# If we are sync here (send_progress_text checks imply sync usage), we must background it.
|
|
# But UtilsProxy hook wrapper creates task.
|
|
# Does send_progress_text need to be sync? Yes, node code calls it sync.
|
|
import asyncio
|
|
|
|
try:
|
|
loop = asyncio.get_running_loop()
|
|
loop.create_task(self._rpc.ui_send_progress_text(text, node_id, sid))
|
|
except RuntimeError:
|
|
pass # Sync context without loop?
|
|
|
|
# --- Route Registration Logic ---
|
|
def register_route(self, method: str, path: str, handler: Callable):
|
|
"""Register a route handler via RPC."""
|
|
if not self._rpc:
|
|
logger.error("RPC not initialized in PromptServerStub")
|
|
return
|
|
|
|
# Fire registration async
|
|
try:
|
|
loop = asyncio.get_running_loop()
|
|
loop.create_task(self._rpc.register_route_rpc(method, path, handler))
|
|
except RuntimeError:
|
|
pass
|
|
|
|
|
|
class RouteStub:
|
|
"""Simulates aiohttp.web.RouteTableDef."""
|
|
|
|
def __init__(self, stub: PromptServerStub):
|
|
self._stub = stub
|
|
|
|
def get(self, path: str):
|
|
def decorator(handler):
|
|
self._stub.register_route("GET", path, handler)
|
|
return handler
|
|
|
|
return decorator
|
|
|
|
def post(self, path: str):
|
|
def decorator(handler):
|
|
self._stub.register_route("POST", path, handler)
|
|
return handler
|
|
|
|
return decorator
|
|
|
|
def patch(self, path: str):
|
|
def decorator(handler):
|
|
self._stub.register_route("PATCH", path, handler)
|
|
return handler
|
|
|
|
return decorator
|
|
|
|
def put(self, path: str):
|
|
def decorator(handler):
|
|
self._stub.register_route("PUT", path, handler)
|
|
return handler
|
|
|
|
return decorator
|
|
|
|
def delete(self, path: str):
|
|
def decorator(handler):
|
|
self._stub.register_route("DELETE", path, handler)
|
|
return handler
|
|
|
|
return decorator
|
|
|
|
|
|
# =============================================================================
|
|
# HOST SIDE: PromptServerService
|
|
# =============================================================================
|
|
|
|
|
|
class PromptServerService(ProxiedSingleton):
|
|
"""Host-side RPC Service for PromptServer."""
|
|
|
|
def __init__(self):
|
|
# We will bind to the real server instance lazily or via global import
|
|
pass
|
|
|
|
@property
|
|
def server(self):
|
|
from server import PromptServer
|
|
|
|
return PromptServer.instance
|
|
|
|
async def ui_send_sync(
|
|
self, event: str, data: Dict[str, Any], sid: Optional[str] = None
|
|
):
|
|
await self.server.send_sync(event, data, sid)
|
|
|
|
async def ui_send(
|
|
self, event: str, data: Dict[str, Any], sid: Optional[str] = None
|
|
):
|
|
await self.server.send(event, data, sid)
|
|
|
|
async def ui_send_progress_text(self, text: str, node_id: str, sid=None):
|
|
# Made async to be awaitable by RPC layer
|
|
self.server.send_progress_text(text, node_id, sid)
|
|
|
|
async def register_route_rpc(self, method: str, path: str, child_handler_proxy):
|
|
"""RPC Target: Register a route that forwards to the Child."""
|
|
logger.debug(f"{LOG_PREFIX} Registering Isolated Route {method} {path}")
|
|
|
|
async def route_wrapper(request: web.Request) -> web.Response:
|
|
# 1. Capture request data
|
|
req_data = {
|
|
"method": request.method,
|
|
"path": request.path,
|
|
"query": dict(request.query),
|
|
}
|
|
if request.can_read_body:
|
|
req_data["text"] = await request.text()
|
|
|
|
try:
|
|
# 2. Call Child Handler via RPC (child_handler_proxy is async callable)
|
|
result = await child_handler_proxy(req_data)
|
|
|
|
# 3. Serialize Response
|
|
return self._serialize_response(result)
|
|
except Exception as e:
|
|
logger.error(f"{LOG_PREFIX} Isolated Route Error: {e}")
|
|
return web.Response(status=500, text=str(e))
|
|
|
|
# Register loop
|
|
self.server.app.router.add_route(method, path, route_wrapper)
|
|
|
|
def _serialize_response(self, result: Any) -> web.Response:
|
|
"""Helper to convert Child result -> web.Response"""
|
|
if isinstance(result, web.Response):
|
|
return result
|
|
# Handle dict (json)
|
|
if isinstance(result, dict):
|
|
return web.json_response(result)
|
|
# Handle string
|
|
if isinstance(result, str):
|
|
return web.Response(text=result)
|
|
# Fallback
|
|
return web.Response(text=str(result))
|