From 9ca799362ddf0c749d93668128b55e6243d4b1b5 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Fri, 27 Feb 2026 12:41:58 -0600 Subject: [PATCH] feat(isolation-proxies): proxy base + host service proxies --- comfy/isolation/proxies/__init__.py | 17 ++ comfy/isolation/proxies/base.py | 213 ++++++++++++++ comfy/isolation/proxies/folder_paths_proxy.py | 29 ++ comfy/isolation/proxies/helper_proxies.py | 98 +++++++ .../proxies/model_management_proxy.py | 27 ++ comfy/isolation/proxies/progress_proxy.py | 35 +++ comfy/isolation/proxies/prompt_server_impl.py | 265 ++++++++++++++++++ comfy/isolation/proxies/utils_proxy.py | 64 +++++ 8 files changed, 748 insertions(+) create mode 100644 comfy/isolation/proxies/__init__.py create mode 100644 comfy/isolation/proxies/base.py create mode 100644 comfy/isolation/proxies/folder_paths_proxy.py create mode 100644 comfy/isolation/proxies/helper_proxies.py create mode 100644 comfy/isolation/proxies/model_management_proxy.py create mode 100644 comfy/isolation/proxies/progress_proxy.py create mode 100644 comfy/isolation/proxies/prompt_server_impl.py create mode 100644 comfy/isolation/proxies/utils_proxy.py diff --git a/comfy/isolation/proxies/__init__.py b/comfy/isolation/proxies/__init__.py new file mode 100644 index 000000000..30d0089ad --- /dev/null +++ b/comfy/isolation/proxies/__init__.py @@ -0,0 +1,17 @@ +from .base import ( + IS_CHILD_PROCESS, + BaseProxy, + BaseRegistry, + detach_if_grad, + get_thread_loop, + run_coro_in_new_loop, +) + +__all__ = [ + "IS_CHILD_PROCESS", + "BaseRegistry", + "BaseProxy", + "get_thread_loop", + "run_coro_in_new_loop", + "detach_if_grad", +] diff --git a/comfy/isolation/proxies/base.py b/comfy/isolation/proxies/base.py new file mode 100644 index 000000000..90f75b1e3 --- /dev/null +++ b/comfy/isolation/proxies/base.py @@ -0,0 +1,213 @@ +# pylint: disable=global-statement,import-outside-toplevel,protected-access +from __future__ import annotations + +import asyncio +import logging +import os +import threading +import weakref +from typing import Any, Callable, Dict, Generic, Optional, TypeVar + +try: + from pyisolate import ProxiedSingleton +except ImportError: + + class ProxiedSingleton: # type: ignore[no-redef] + pass + + +logger = logging.getLogger(__name__) + +IS_CHILD_PROCESS = os.environ.get("PYISOLATE_CHILD") == "1" +_thread_local = threading.local() +T = TypeVar("T") + + +def get_thread_loop() -> asyncio.AbstractEventLoop: + loop = getattr(_thread_local, "loop", None) + if loop is None or loop.is_closed(): + loop = asyncio.new_event_loop() + _thread_local.loop = loop + return loop + + +def run_coro_in_new_loop(coro: Any) -> Any: + result_box: Dict[str, Any] = {} + exc_box: Dict[str, BaseException] = {} + + def runner() -> None: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + result_box["value"] = loop.run_until_complete(coro) + except Exception as exc: # noqa: BLE001 + exc_box["exc"] = exc + finally: + loop.close() + + t = threading.Thread(target=runner, daemon=True) + t.start() + t.join() + if "exc" in exc_box: + raise exc_box["exc"] + return result_box.get("value") + + +def detach_if_grad(obj: Any) -> Any: + try: + import torch + except Exception: + return obj + + if isinstance(obj, torch.Tensor): + return obj.detach() if obj.requires_grad else obj + if isinstance(obj, (list, tuple)): + return type(obj)(detach_if_grad(x) for x in obj) + if isinstance(obj, dict): + return {k: detach_if_grad(v) for k, v in obj.items()} + return obj + + +class BaseRegistry(ProxiedSingleton, Generic[T]): + _type_prefix: str = "base" + + def __init__(self) -> None: + if hasattr(ProxiedSingleton, "__init__") and ProxiedSingleton is not object: + super().__init__() + self._registry: Dict[str, T] = {} + self._id_map: Dict[int, str] = {} + self._counter = 0 + self._lock = threading.Lock() + + def register(self, instance: T) -> str: + with self._lock: + obj_id = id(instance) + if obj_id in self._id_map: + return self._id_map[obj_id] + instance_id = f"{self._type_prefix}_{self._counter}" + self._counter += 1 + self._registry[instance_id] = instance + self._id_map[obj_id] = instance_id + return instance_id + + def unregister_sync(self, instance_id: str) -> None: + with self._lock: + instance = self._registry.pop(instance_id, None) + if instance: + self._id_map.pop(id(instance), None) + + def _get_instance(self, instance_id: str) -> T: + if IS_CHILD_PROCESS: + raise RuntimeError( + f"[{self.__class__.__name__}] _get_instance called in child" + ) + with self._lock: + instance = self._registry.get(instance_id) + if instance is None: + raise ValueError(f"{instance_id} not found") + return instance + + +_GLOBAL_LOOP: Optional[asyncio.AbstractEventLoop] = None + + +def set_global_loop(loop: asyncio.AbstractEventLoop) -> None: + global _GLOBAL_LOOP + _GLOBAL_LOOP = loop + + +class BaseProxy(Generic[T]): + _registry_class: type = BaseRegistry # type: ignore[type-arg] + __module__: str = "comfy.isolation.proxies.base" + + def __init__( + self, + instance_id: str, + registry: Optional[Any] = None, + manage_lifecycle: bool = False, + ) -> None: + self._instance_id = instance_id + self._rpc_caller: Optional[Any] = None + self._registry = registry if registry is not None else self._registry_class() + self._manage_lifecycle = manage_lifecycle + self._cleaned_up = False + if manage_lifecycle and not IS_CHILD_PROCESS: + self._finalizer = weakref.finalize( + self, self._registry.unregister_sync, instance_id + ) + + def _get_rpc(self) -> Any: + if self._rpc_caller is None: + from pyisolate._internal.rpc_protocol import get_child_rpc_instance + + rpc = get_child_rpc_instance() + if rpc is None: + raise RuntimeError(f"[{self.__class__.__name__}] No RPC in child") + self._rpc_caller = rpc.create_caller( + self._registry_class, self._registry_class.get_remote_id() + ) + return self._rpc_caller + + def _call_rpc(self, method_name: str, *args: Any, **kwargs: Any) -> Any: + rpc = self._get_rpc() + method = getattr(rpc, method_name) + coro = method(self._instance_id, *args, **kwargs) + + # If we have a global loop (Main Thread Loop), use it for dispatch from worker threads + if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running(): + try: + # If we are already in the global loop, we can't block on it? + # Actually, this method is synchronous (__getattr__ -> lambda). + # If called from async context in main loop, we need to handle that. + curr_loop = asyncio.get_running_loop() + if curr_loop is _GLOBAL_LOOP: + # We are in the main loop. We cannot await/block here if we are just a sync function. + # But proxies are often called from sync code. + # If called from sync code in main loop, creating a new loop is bad. + # But we can't await `coro`. + # This implies proxies MUST be awaited if called from async context? + # Existing code used `run_coro_in_new_loop` which is weird. + # Let's trust that if we are in a thread (RuntimeError on get_running_loop), + # we use run_coroutine_threadsafe. + pass + except RuntimeError: + # No running loop - we are in a worker thread. + future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP) + return future.result() + + try: + asyncio.get_running_loop() + return run_coro_in_new_loop(coro) + except RuntimeError: + loop = get_thread_loop() + return loop.run_until_complete(coro) + + def __getstate__(self) -> Dict[str, Any]: + return {"_instance_id": self._instance_id} + + def __setstate__(self, state: Dict[str, Any]) -> None: + self._instance_id = state["_instance_id"] + self._rpc_caller = None + self._registry = self._registry_class() + self._manage_lifecycle = False + self._cleaned_up = False + + def cleanup(self) -> None: + if self._cleaned_up or IS_CHILD_PROCESS: + return + self._cleaned_up = True + finalizer = getattr(self, "_finalizer", None) + if finalizer is not None: + finalizer.detach() + self._registry.unregister_sync(self._instance_id) + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} {self._instance_id}>" + + +def create_rpc_method(method_name: str) -> Callable[..., Any]: + def method(self: BaseProxy[Any], *args: Any, **kwargs: Any) -> Any: + return self._call_rpc(method_name, *args, **kwargs) + + method.__name__ = method_name + return method diff --git a/comfy/isolation/proxies/folder_paths_proxy.py b/comfy/isolation/proxies/folder_paths_proxy.py new file mode 100644 index 000000000..a2996ec24 --- /dev/null +++ b/comfy/isolation/proxies/folder_paths_proxy.py @@ -0,0 +1,29 @@ +from __future__ import annotations +from typing import Dict + +import folder_paths +from pyisolate import ProxiedSingleton + + +class FolderPathsProxy(ProxiedSingleton): + """ + Dynamic proxy for folder_paths. + Uses __getattr__ for most lookups, with explicit handling for + mutable collections to ensure efficient by-value transfer. + """ + + def __getattr__(self, name): + return getattr(folder_paths, name) + + # Return dict snapshots (avoid RPC chatter) + @property + def folder_names_and_paths(self) -> Dict: + return dict(folder_paths.folder_names_and_paths) + + @property + def extension_mimetypes_cache(self) -> Dict: + return dict(folder_paths.extension_mimetypes_cache) + + @property + def filename_list_cache(self) -> Dict: + return dict(folder_paths.filename_list_cache) diff --git a/comfy/isolation/proxies/helper_proxies.py b/comfy/isolation/proxies/helper_proxies.py new file mode 100644 index 000000000..a50b9e4c4 --- /dev/null +++ b/comfy/isolation/proxies/helper_proxies.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +from typing import Any, Dict, Optional + + +class AnyTypeProxy(str): + """Replacement for custom AnyType objects used by some nodes.""" + + def __new__(cls, value: str = "*"): + return super().__new__(cls, value) + + def __ne__(self, other): # type: ignore[override] + return False + + +class FlexibleOptionalInputProxy(dict): + """Replacement for FlexibleOptionalInputType to allow dynamic inputs.""" + + def __init__(self, flex_type, data: Optional[Dict[str, object]] = None): + super().__init__() + self.type = flex_type + if data: + self.update(data) + + def __getitem__(self, key): # type: ignore[override] + return (self.type,) + + def __contains__(self, key): # type: ignore[override] + return True + + +class ByPassTypeTupleProxy(tuple): + """Replacement for ByPassTypeTuple to mirror wildcard fallback behavior.""" + + def __new__(cls, values): + return super().__new__(cls, values) + + def __getitem__(self, index): # type: ignore[override] + if index >= len(self): + return AnyTypeProxy("*") + return super().__getitem__(index) + + +def _restore_special_value(value: Any) -> Any: + if isinstance(value, dict): + if value.get("__pyisolate_any_type__"): + return AnyTypeProxy(value.get("value", "*")) + if value.get("__pyisolate_flexible_optional__"): + flex_type = _restore_special_value(value.get("type")) + data_raw = value.get("data") + data = ( + {k: _restore_special_value(v) for k, v in data_raw.items()} + if isinstance(data_raw, dict) + else {} + ) + return FlexibleOptionalInputProxy(flex_type, data) + if value.get("__pyisolate_tuple__") is not None: + return tuple( + _restore_special_value(v) for v in value["__pyisolate_tuple__"] + ) + if value.get("__pyisolate_bypass_tuple__") is not None: + return ByPassTypeTupleProxy( + tuple( + _restore_special_value(v) + for v in value["__pyisolate_bypass_tuple__"] + ) + ) + return {k: _restore_special_value(v) for k, v in value.items()} + if isinstance(value, list): + return [_restore_special_value(v) for v in value] + return value + + +def restore_input_types(raw: Dict[str, object]) -> Dict[str, object]: + """Restore serialized INPUT_TYPES payload back into ComfyUI-compatible objects.""" + + if not isinstance(raw, dict): + return raw # type: ignore[return-value] + + restored: Dict[str, object] = {} + for section, entries in raw.items(): + if isinstance(entries, dict) and entries.get("__pyisolate_flexible_optional__"): + restored[section] = _restore_special_value(entries) + elif isinstance(entries, dict): + restored[section] = { + k: _restore_special_value(v) for k, v in entries.items() + } + else: + restored[section] = _restore_special_value(entries) + return restored + + +__all__ = [ + "AnyTypeProxy", + "FlexibleOptionalInputProxy", + "ByPassTypeTupleProxy", + "restore_input_types", +] diff --git a/comfy/isolation/proxies/model_management_proxy.py b/comfy/isolation/proxies/model_management_proxy.py new file mode 100644 index 000000000..00e14d9b4 --- /dev/null +++ b/comfy/isolation/proxies/model_management_proxy.py @@ -0,0 +1,27 @@ +import comfy.model_management as mm +from pyisolate import ProxiedSingleton + + +class ModelManagementProxy(ProxiedSingleton): + """ + Dynamic proxy for comfy.model_management. + Uses __getattr__ to forward all calls to the underlying module, + reducing maintenance burden. + """ + + # Explicitly expose Enums/Classes as properties + @property + def VRAMState(self): + return mm.VRAMState + + @property + def CPUState(self): + return mm.CPUState + + @property + def OOM_EXCEPTION(self): + return mm.OOM_EXCEPTION + + def __getattr__(self, name): + """Forward all other attribute access to the module.""" + return getattr(mm, name) diff --git a/comfy/isolation/proxies/progress_proxy.py b/comfy/isolation/proxies/progress_proxy.py new file mode 100644 index 000000000..44494ea31 --- /dev/null +++ b/comfy/isolation/proxies/progress_proxy.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +import logging +from typing import Any, Optional + +try: + from pyisolate import ProxiedSingleton +except ImportError: + + class ProxiedSingleton: + pass + + +from comfy_execution.progress import get_progress_state + +logger = logging.getLogger(__name__) + + +class ProgressProxy(ProxiedSingleton): + def set_progress( + self, + value: float, + max_value: float, + node_id: Optional[str] = None, + image: Any = None, + ) -> None: + get_progress_state().update_progress( + node_id=node_id, + value=value, + max_value=max_value, + image=image, + ) + + +__all__ = ["ProgressProxy"] diff --git a/comfy/isolation/proxies/prompt_server_impl.py b/comfy/isolation/proxies/prompt_server_impl.py new file mode 100644 index 000000000..2a775e097 --- /dev/null +++ b/comfy/isolation/proxies/prompt_server_impl.py @@ -0,0 +1,265 @@ +# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,redefined-outer-name,reimported,super-init-not-called +"""Stateless RPC Implementation for PromptServer. + +Replaces the legacy PromptServerProxy (Singleton) with a clean Service/Stub architecture. +- Host: PromptServerService (RPC Handler) +- Child: PromptServerStub (Interface Implementation) +""" + +from __future__ import annotations + +import asyncio +import os +from typing import Any, Dict, Optional, Callable + +import logging +from aiohttp import web + +# IMPORTS +from pyisolate import ProxiedSingleton + +logger = logging.getLogger(__name__) +LOG_PREFIX = "[Isolation:C<->H]" + +# ... + +# ============================================================================= +# CHILD SIDE: PromptServerStub +# ============================================================================= + + +class PromptServerStub: + """Stateless Stub for PromptServer.""" + + # Masquerade as the real server module + __module__ = "server" + + _instance: Optional["PromptServerStub"] = None + _rpc: Optional[Any] = None # This will be the Caller object + _source_file: Optional[str] = None + + def __init__(self): + self.routes = RouteStub(self) + + @classmethod + def set_rpc(cls, rpc: Any) -> None: + """Inject RPC client (called by adapter.py or manually).""" + # Create caller for HOST Service + # Assuming Host Service is registered as "PromptServerService" (class name) + # We target the Host Service Class + target_id = "PromptServerService" + # We need to pass a class to create_caller? Usually yes. + # But we don't have the Service class imported here necessarily (if running on child). + # pyisolate check verify_service type? + # If we pass PromptServerStub as the 'class', it might mismatch if checking types. + # But we can try passing PromptServerStub if it mirrors the service name? No, stub is PromptServerStub. + # We need a dummy class with right name? + # Or just rely on string ID if create_caller supports it? + # Standard: rpc.create_caller(PromptServerStub, target_id) + # But wait, PromptServerStub is the *Local* class. + # We want to call *Remote* class. + # If we use PromptServerStub as the type, returning object will be typed as PromptServerStub? + # The first arg is 'service_cls'. + cls._rpc = rpc.create_caller( + PromptServerService, target_id + ) # We import Service below? + + # We need PromptServerService available for the create_caller call? + # Or just use the Stub class if ID matches? + # prompt_server_impl.py defines BOTH. So PromptServerService IS available! + + @property + def instance(self) -> "PromptServerStub": + return self + + # ... Compatibility ... + @classmethod + def _get_source_file(cls) -> str: + if cls._source_file is None: + import folder_paths + + cls._source_file = os.path.join(folder_paths.base_path, "server.py") + return cls._source_file + + @property + def __file__(self) -> str: + return self._get_source_file() + + # --- Properties --- + @property + def client_id(self) -> Optional[str]: + return "isolated_client" + + def supports(self, feature: str) -> bool: + return True + + @property + def app(self): + raise RuntimeError( + "PromptServer.app is not accessible in isolated nodes. Use RPC routes instead." + ) + + @property + def prompt_queue(self): + raise RuntimeError( + "PromptServer.prompt_queue is not accessible in isolated nodes." + ) + + # --- UI Communication (RPC Delegates) --- + async def send_sync( + self, event: str, data: Dict[str, Any], sid: Optional[str] = None + ) -> None: + if self._rpc: + await self._rpc.ui_send_sync(event, data, sid) + + async def send( + self, event: str, data: Dict[str, Any], sid: Optional[str] = None + ) -> None: + if self._rpc: + await self._rpc.ui_send(event, data, sid) + + def send_progress_text(self, text: str, node_id: str, sid=None) -> None: + if self._rpc: + # Fire and forget likely needed. If method is async on host, caller invocation returns coroutine. + # We must schedule it? + # Or use fire_remote equivalent? + # Caller object usually proxies calls. If host method is async, it returns coro. + # If we are sync here (send_progress_text checks imply sync usage), we must background it. + # But UtilsProxy hook wrapper creates task. + # Does send_progress_text need to be sync? Yes, node code calls it sync. + import asyncio + + try: + loop = asyncio.get_running_loop() + loop.create_task(self._rpc.ui_send_progress_text(text, node_id, sid)) + except RuntimeError: + pass # Sync context without loop? + + # --- Route Registration Logic --- + def register_route(self, method: str, path: str, handler: Callable): + """Register a route handler via RPC.""" + if not self._rpc: + logger.error("RPC not initialized in PromptServerStub") + return + + # Fire registration async + try: + loop = asyncio.get_running_loop() + loop.create_task(self._rpc.register_route_rpc(method, path, handler)) + except RuntimeError: + pass + + +class RouteStub: + """Simulates aiohttp.web.RouteTableDef.""" + + def __init__(self, stub: PromptServerStub): + self._stub = stub + + def get(self, path: str): + def decorator(handler): + self._stub.register_route("GET", path, handler) + return handler + + return decorator + + def post(self, path: str): + def decorator(handler): + self._stub.register_route("POST", path, handler) + return handler + + return decorator + + def patch(self, path: str): + def decorator(handler): + self._stub.register_route("PATCH", path, handler) + return handler + + return decorator + + def put(self, path: str): + def decorator(handler): + self._stub.register_route("PUT", path, handler) + return handler + + return decorator + + def delete(self, path: str): + def decorator(handler): + self._stub.register_route("DELETE", path, handler) + return handler + + return decorator + + +# ============================================================================= +# HOST SIDE: PromptServerService +# ============================================================================= + + +class PromptServerService(ProxiedSingleton): + """Host-side RPC Service for PromptServer.""" + + def __init__(self): + # We will bind to the real server instance lazily or via global import + pass + + @property + def server(self): + from server import PromptServer + + return PromptServer.instance + + async def ui_send_sync( + self, event: str, data: Dict[str, Any], sid: Optional[str] = None + ): + await self.server.send_sync(event, data, sid) + + async def ui_send( + self, event: str, data: Dict[str, Any], sid: Optional[str] = None + ): + await self.server.send(event, data, sid) + + async def ui_send_progress_text(self, text: str, node_id: str, sid=None): + # Made async to be awaitable by RPC layer + self.server.send_progress_text(text, node_id, sid) + + async def register_route_rpc(self, method: str, path: str, child_handler_proxy): + """RPC Target: Register a route that forwards to the Child.""" + logger.debug(f"{LOG_PREFIX} Registering Isolated Route {method} {path}") + + async def route_wrapper(request: web.Request) -> web.Response: + # 1. Capture request data + req_data = { + "method": request.method, + "path": request.path, + "query": dict(request.query), + } + if request.can_read_body: + req_data["text"] = await request.text() + + try: + # 2. Call Child Handler via RPC (child_handler_proxy is async callable) + result = await child_handler_proxy(req_data) + + # 3. Serialize Response + return self._serialize_response(result) + except Exception as e: + logger.error(f"{LOG_PREFIX} Isolated Route Error: {e}") + return web.Response(status=500, text=str(e)) + + # Register loop + self.server.app.router.add_route(method, path, route_wrapper) + + def _serialize_response(self, result: Any) -> web.Response: + """Helper to convert Child result -> web.Response""" + if isinstance(result, web.Response): + return result + # Handle dict (json) + if isinstance(result, dict): + return web.json_response(result) + # Handle string + if isinstance(result, str): + return web.Response(text=result) + # Fallback + return web.Response(text=str(result)) diff --git a/comfy/isolation/proxies/utils_proxy.py b/comfy/isolation/proxies/utils_proxy.py new file mode 100644 index 000000000..432f7ec90 --- /dev/null +++ b/comfy/isolation/proxies/utils_proxy.py @@ -0,0 +1,64 @@ +# pylint: disable=cyclic-import,import-outside-toplevel +from __future__ import annotations + +from typing import Optional, Any +import comfy.utils +from pyisolate import ProxiedSingleton + +import os + + +class UtilsProxy(ProxiedSingleton): + """ + Proxy for comfy.utils. + Primarily handles the PROGRESS_BAR_HOOK to ensure progress updates + from isolated nodes reach the host. + """ + + # _instance and __new__ removed to rely on SingletonMetaclass + _rpc: Optional[Any] = None + + @classmethod + def set_rpc(cls, rpc: Any) -> None: + # Create caller using class name as ID (standard for Singletons) + cls._rpc = rpc.create_caller(cls, "UtilsProxy") + + async def progress_bar_hook( + self, + value: int, + total: int, + preview: Optional[bytes] = None, + node_id: Optional[str] = None, + ) -> Any: + """ + Host-side implementation: forwards the call to the real global hook. + Child-side: this method call is intercepted by RPC and sent to host. + """ + if os.environ.get("PYISOLATE_CHILD") == "1": + # Manual RPC dispatch for Child process + # Use class-level RPC storage (Static Injection) + if UtilsProxy._rpc: + return await UtilsProxy._rpc.progress_bar_hook( + value, total, preview, node_id + ) + + # Fallback channel: global child rpc + try: + from pyisolate._internal.rpc_protocol import get_child_rpc_instance + + get_child_rpc_instance() + # If we have an RPC instance but no UtilsProxy._rpc, we *could* try to use it, + # but we need a caller. For now, just pass to avoid crashing. + pass + except (ImportError, LookupError): + pass + + return None + + # Host Execution + if comfy.utils.PROGRESS_BAR_HOOK is not None: + comfy.utils.PROGRESS_BAR_HOOK(value, total, preview, node_id) + + def set_progress_bar_global_hook(self, hook: Any) -> None: + """Forward hook registration (though usually not needed from child).""" + comfy.utils.set_progress_bar_global_hook(hook)