fix(isolation): refresh loader and isolated route handling

This commit is contained in:
John Pollock 2026-04-12 21:08:35 -05:00
parent 51e70fe033
commit 07fffdd593
10 changed files with 335 additions and 213 deletions

View File

@ -44,21 +44,15 @@ def initialize_proxies() -> None:
from .child_hooks import is_child_process from .child_hooks import is_child_process
is_child = is_child_process() is_child = is_child_process()
logger.warning(
"%s DIAG:initialize_proxies | is_child=%s | PYISOLATE_CHILD=%s",
LOG_PREFIX, is_child, os.environ.get("PYISOLATE_CHILD"),
)
if is_child: if is_child:
from .child_hooks import initialize_child_process from .child_hooks import initialize_child_process
initialize_child_process() initialize_child_process()
logger.warning("%s DIAG:initialize_proxies child_process initialized", LOG_PREFIX)
else: else:
from .host_hooks import initialize_host_process from .host_hooks import initialize_host_process
initialize_host_process() initialize_host_process()
logger.warning("%s DIAG:initialize_proxies host_process initialized", LOG_PREFIX)
if start_shm_forensics is not None: if start_shm_forensics is not None:
start_shm_forensics() start_shm_forensics()

View File

@ -586,29 +586,6 @@ class ComfyUIAdapter(IsolationAdapter):
register_hooks_serializers(registry) register_hooks_serializers(registry)
# Generic Numpy Serializer
def serialize_numpy(obj: Any) -> Any:
import torch
try:
# Attempt zero-copy conversion to Tensor
return torch.from_numpy(obj)
except Exception:
# Fallback for non-numeric arrays (strings, objects, mixes)
return obj.tolist()
def deserialize_numpy_b64(data: Any) -> Any:
"""Deserialize base64-encoded ndarray from sealed worker."""
import base64
import numpy as np
if isinstance(data, dict) and "data" in data and "dtype" in data:
raw = base64.b64decode(data["data"])
arr = np.frombuffer(raw, dtype=np.dtype(data["dtype"])).reshape(data["shape"])
return torch.from_numpy(arr.copy())
return data
registry.register("ndarray", serialize_numpy, deserialize_numpy_b64)
# -- File3D (comfy_api.latest._util.geometry_types) --------------------- # -- File3D (comfy_api.latest._util.geometry_types) ---------------------
# Origin: comfy_api by ComfyOrg (Alexander Piskun), PR #12129 # Origin: comfy_api by ComfyOrg (Alexander Piskun), PR #12129
@ -873,93 +850,15 @@ class ComfyUIAdapter(IsolationAdapter):
return return
if api_name == "PromptServerProxy": if api_name == "PromptServerService":
if not _IMPORT_TORCH: if not _IMPORT_TORCH:
return return
# Defer heavy import to child context
import server import server
from comfy.isolation.proxies.prompt_server_impl import PromptServerStub
instance = api() if isinstance(api, type) else api stub = PromptServerStub()
proxy = (
instance.instance
) # PromptServerProxy instance has .instance property returning self
original_register_route = proxy.register_route
def register_route_wrapper(
method: str, path: str, handler: Callable[..., Any]
) -> None:
callback_id = rpc.register_callback(handler)
loop = getattr(rpc, "loop", None)
if loop and loop.is_running():
import asyncio
asyncio.create_task(
original_register_route(
method, path, handler=callback_id, is_callback=True
)
)
else:
original_register_route(
method, path, handler=callback_id, is_callback=True
)
return None
proxy.register_route = register_route_wrapper
class RouteTableDefProxy:
def __init__(self, proxy_instance: Any):
self.proxy = proxy_instance
def get(
self, path: str, **kwargs: Any
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
self.proxy.register_route("GET", path, handler)
return handler
return decorator
def post(
self, path: str, **kwargs: Any
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
self.proxy.register_route("POST", path, handler)
return handler
return decorator
def patch(
self, path: str, **kwargs: Any
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
self.proxy.register_route("PATCH", path, handler)
return handler
return decorator
def put(
self, path: str, **kwargs: Any
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
self.proxy.register_route("PUT", path, handler)
return handler
return decorator
def delete(
self, path: str, **kwargs: Any
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
self.proxy.register_route("DELETE", path, handler)
return handler
return decorator
proxy.routes = RouteTableDefProxy(proxy)
if ( if (
hasattr(server, "PromptServer") hasattr(server, "PromptServer")
and getattr(server.PromptServer, "instance", None) != proxy and getattr(server.PromptServer, "instance", None) is not stub
): ):
server.PromptServer.instance = proxy server.PromptServer.instance = stub

View File

@ -31,7 +31,6 @@ def _load_extra_model_paths() -> None:
def initialize_child_process() -> None: def initialize_child_process() -> None:
logger.warning("][ DIAG:child_hooks initialize_child_process START")
if os.environ.get("PYISOLATE_IMPORT_TORCH", "1") != "0": if os.environ.get("PYISOLATE_IMPORT_TORCH", "1") != "0":
_load_extra_model_paths() _load_extra_model_paths()
_setup_child_loop_bridge() _setup_child_loop_bridge()
@ -41,15 +40,12 @@ def initialize_child_process() -> None:
from pyisolate._internal.rpc_protocol import get_child_rpc_instance from pyisolate._internal.rpc_protocol import get_child_rpc_instance
rpc = get_child_rpc_instance() rpc = get_child_rpc_instance()
logger.warning("][ DIAG:child_hooks RPC instance: %s", rpc is not None)
if rpc: if rpc:
_setup_proxy_callers(rpc) _setup_proxy_callers(rpc)
logger.warning("][ DIAG:child_hooks proxy callers configured with RPC")
else: else:
logger.warning("][ DIAG:child_hooks NO RPC — proxy callers cleared")
_setup_proxy_callers() _setup_proxy_callers()
except Exception as e: except Exception as e:
logger.error(f"][ DIAG:child_hooks Manual RPC Injection failed: {e}") logger.error(f"][ child_hooks Manual RPC Injection failed: {e}")
_setup_proxy_callers() _setup_proxy_callers()
_setup_logging() _setup_logging()

View File

@ -354,6 +354,16 @@ async def load_isolated_node(
"sandbox": sandbox_config, "sandbox": sandbox_config,
} }
share_torch_no_deps = tool_config.get("share_torch_no_deps", [])
if share_torch_no_deps:
if not isinstance(share_torch_no_deps, list) or not all(
isinstance(dep, str) and dep.strip() for dep in share_torch_no_deps
):
raise ExtensionLoadError(
"[tool.comfy.isolation.share_torch_no_deps] must be a list of non-empty strings"
)
extension_config["share_torch_no_deps"] = share_torch_no_deps
_is_sealed = execution_model == "sealed_worker" _is_sealed = execution_model == "sealed_worker"
_is_sandboxed = host_policy["sandbox_mode"] != "disabled" and is_linux _is_sandboxed = host_policy["sandbox_mode"] != "disabled" and is_linux
logger.info( logger.info(
@ -367,6 +377,16 @@ async def load_isolated_node(
if cuda_wheels is not None: if cuda_wheels is not None:
extension_config["cuda_wheels"] = cuda_wheels extension_config["cuda_wheels"] = cuda_wheels
extra_index_urls = tool_config.get("extra_index_urls", [])
if extra_index_urls:
if not isinstance(extra_index_urls, list) or not all(
isinstance(u, str) and u.strip() for u in extra_index_urls
):
raise ExtensionLoadError(
"[tool.comfy.isolation.extra_index_urls] must be a list of non-empty strings"
)
extension_config["extra_index_urls"] = extra_index_urls
# Conda-specific keys # Conda-specific keys
if is_conda: if is_conda:
extension_config["package_manager"] = "conda" extension_config["package_manager"] = "conda"
@ -408,31 +428,17 @@ async def load_isolated_node(
cache.register_proxy(extension_name, WebDirectoryProxy()) cache.register_proxy(extension_name, WebDirectoryProxy())
# Try cache first (lazy spawn) # Try cache first (lazy spawn)
logger.warning("][ DIAG:ext_loader cache_valid_check for %s", extension_name)
if is_cache_valid(node_dir, manifest_path, venv_root): if is_cache_valid(node_dir, manifest_path, venv_root):
cached_data = load_from_cache(node_dir, venv_root) cached_data = load_from_cache(node_dir, venv_root)
if cached_data: if cached_data:
if _is_stale_node_cache(cached_data): if _is_stale_node_cache(cached_data):
logger.warning( pass
"][ DIAG:ext_loader %s cache is stale/incompatible; rebuilding metadata",
extension_name,
)
else: else:
logger.warning("][ DIAG:ext_loader %s USING CACHE — dumping combo options:", extension_name) try:
for node_name, details in cached_data.items(): flushed = await extension.flush_pending_routes()
schema_v1 = details.get("schema_v1", {}) logger.info("][ %s flushed %d routes", extension_name, flushed)
inp = schema_v1.get("input", {}) if schema_v1 else {} except Exception as exc:
for section_name, section in inp.items(): logger.warning("][ %s route flush failed: %s", extension_name, exc)
if isinstance(section, dict):
for field_name, field_def in section.items():
if isinstance(field_def, (list, tuple)) and len(field_def) >= 2 and isinstance(field_def[1], dict) and "options" in field_def[1]:
opts = field_def[1]["options"]
logger.warning(
"][ DIAG:ext_loader CACHE %s.%s.%s options=%d first=%s",
node_name, section_name, field_name,
len(opts),
opts[:3] if opts else "EMPTY",
)
specs: List[Tuple[str, str, type]] = [] specs: List[Tuple[str, str, type]] = []
for node_name, details in cached_data.items(): for node_name, details in cached_data.items():
stub_cls = build_stub_class(node_name, details, extension) stub_cls = build_stub_class(node_name, details, extension)
@ -440,11 +446,7 @@ async def load_isolated_node(
(node_name, details.get("display_name", node_name), stub_cls) (node_name, details.get("display_name", node_name), stub_cls)
) )
return specs return specs
else:
logger.warning("][ DIAG:ext_loader %s cache INVALID or MISSING", extension_name)
# Cache miss - spawn process and get metadata # Cache miss - spawn process and get metadata
logger.warning("][ DIAG:ext_loader %s cache miss, spawning process for metadata", extension_name)
try: try:
remote_nodes: Dict[str, str] = await extension.list_nodes() remote_nodes: Dict[str, str] = await extension.list_nodes()
@ -466,7 +468,6 @@ async def load_isolated_node(
cache_data: Dict[str, Dict] = {} cache_data: Dict[str, Dict] = {}
for node_name, display_name in remote_nodes.items(): for node_name, display_name in remote_nodes.items():
logger.warning("][ DIAG:ext_loader calling get_node_details for %s.%s", extension_name, node_name)
try: try:
details = await extension.get_node_details(node_name) details = await extension.get_node_details(node_name)
except Exception as exc: except Exception as exc:
@ -477,20 +478,6 @@ async def load_isolated_node(
exc, exc,
) )
continue continue
# DIAG: dump combo options from freshly-fetched details
schema_v1 = details.get("schema_v1", {})
inp = schema_v1.get("input", {}) if schema_v1 else {}
for section_name, section in inp.items():
if isinstance(section, dict):
for field_name, field_def in section.items():
if isinstance(field_def, (list, tuple)) and len(field_def) >= 2 and isinstance(field_def[1], dict) and "options" in field_def[1]:
opts = field_def[1]["options"]
logger.warning(
"][ DIAG:ext_loader FRESH %s.%s.%s options=%d first=%s",
node_name, section_name, field_name,
len(opts),
opts[:3] if opts else "EMPTY",
)
details["display_name"] = display_name details["display_name"] = display_name
cache_data[node_name] = details cache_data[node_name] = details
stub_cls = build_stub_class(node_name, details, extension) stub_cls = build_stub_class(node_name, details, extension)
@ -512,6 +499,14 @@ async def load_isolated_node(
if host_policy["sandbox_mode"] == "disabled": if host_policy["sandbox_mode"] == "disabled":
_register_web_directory(extension_name, node_dir) _register_web_directory(extension_name, node_dir)
# Flush any routes the child buffered during module import — must happen
# before router freeze and before we kill the child process.
try:
flushed = await extension.flush_pending_routes()
logger.info("][ %s flushed %d routes", extension_name, flushed)
except Exception as exc:
logger.warning("][ %s route flush failed: %s", extension_name, exc)
# EJECT: Kill process after getting metadata (will respawn on first execution) # EJECT: Kill process after getting metadata (will respawn on first execution)
await _stop_extension_safe(extension, extension_name) await _stop_extension_safe(extension, extension_name)

View File

@ -211,6 +211,7 @@ class ComfyNodeExtension(ExtensionBase):
self.node_classes = getattr(module, "NODE_CLASS_MAPPINGS", {}) or {} self.node_classes = getattr(module, "NODE_CLASS_MAPPINGS", {}) or {}
self.display_names = getattr(module, "NODE_DISPLAY_NAME_MAPPINGS", {}) or {} self.display_names = getattr(module, "NODE_DISPLAY_NAME_MAPPINGS", {}) or {}
self._register_module_routes(module)
# Register web directory with WebDirectoryProxy (child-side) # Register web directory with WebDirectoryProxy (child-side)
web_dir_attr = getattr(module, "WEB_DIRECTORY", None) web_dir_attr = getattr(module, "WEB_DIRECTORY", None)
@ -280,6 +281,55 @@ class ComfyNodeExtension(ExtensionBase):
self.node_instances = {} self.node_instances = {}
def _register_module_routes(self, module: Any) -> None:
"""Bridge legacy module-level ROUTES declarations into isolated routing."""
routes = getattr(module, "ROUTES", None) or []
if not routes:
return
from comfy.isolation.proxies.prompt_server_impl import PromptServerStub
prompt_server = PromptServerStub()
route_table = getattr(prompt_server, "routes", None)
if route_table is None:
logger.warning("%s Route registration unavailable for %s", LOG_PREFIX, module)
return
for route_spec in routes:
if not isinstance(route_spec, dict):
logger.warning("%s Ignoring non-dict ROUTES entry: %r", LOG_PREFIX, route_spec)
continue
method = str(route_spec.get("method", "")).strip().upper()
path = str(route_spec.get("path", "")).strip()
handler_ref = route_spec.get("handler")
if not method or not path:
logger.warning("%s Ignoring incomplete route spec: %r", LOG_PREFIX, route_spec)
continue
if isinstance(handler_ref, str):
handler = getattr(module, handler_ref, None)
else:
handler = handler_ref
if not callable(handler):
logger.warning(
"%s Ignoring route with missing handler %r for %s %s",
LOG_PREFIX,
handler_ref,
method,
path,
)
continue
decorator = getattr(route_table, method.lower(), None)
if not callable(decorator):
logger.warning("%s Unsupported route method %s for %s", LOG_PREFIX, method, path)
continue
decorator(path)(handler)
self._route_handlers[f"{method} {path}"] = handler
logger.info("%s buffered legacy route %s %s", LOG_PREFIX, method, path)
async def list_nodes(self) -> Dict[str, str]: async def list_nodes(self) -> Dict[str, str]:
return {name: self.display_names.get(name, name) for name in self.node_classes} return {name: self.display_names.get(name, name) for name in self.node_classes}
@ -289,10 +339,6 @@ class ComfyNodeExtension(ExtensionBase):
async def get_node_details(self, node_name: str) -> Dict[str, Any]: async def get_node_details(self, node_name: str) -> Dict[str, Any]:
node_cls = self._get_node_class(node_name) node_cls = self._get_node_class(node_name)
is_v3 = issubclass(node_cls, _ComfyNodeInternal) is_v3 = issubclass(node_cls, _ComfyNodeInternal)
logger.warning(
"%s DIAG:get_node_details START | node=%s | is_v3=%s | cls=%s",
LOG_PREFIX, node_name, is_v3, node_cls,
)
input_types_raw = ( input_types_raw = (
node_cls.INPUT_TYPES() if hasattr(node_cls, "INPUT_TYPES") else {} node_cls.INPUT_TYPES() if hasattr(node_cls, "INPUT_TYPES") else {}
@ -316,16 +362,7 @@ class ComfyNodeExtension(ExtensionBase):
if is_v3: if is_v3:
try: try:
logger.warning(
"%s DIAG:get_node_details calling GET_SCHEMA for %s",
LOG_PREFIX, node_name,
)
schema = node_cls.GET_SCHEMA() schema = node_cls.GET_SCHEMA()
logger.warning(
"%s DIAG:get_node_details GET_SCHEMA returned for %s | schema_inputs=%s",
LOG_PREFIX, node_name,
[getattr(i, 'id', '?') for i in (schema.inputs or [])],
)
schema_v1 = asdict(schema.get_v1_info(node_cls)) schema_v1 = asdict(schema.get_v1_info(node_cls))
try: try:
schema_v3 = asdict(schema.get_v3_info(node_cls)) schema_v3 = asdict(schema.get_v3_info(node_cls))
@ -532,6 +569,11 @@ class ComfyNodeExtension(ExtensionBase):
wrapped = self._wrap_unpicklable_objects(result) wrapped = self._wrap_unpicklable_objects(result)
return wrapped return wrapped
async def flush_pending_routes(self) -> int:
"""Flush buffered route registrations to host via RPC. Called by host after node discovery."""
from comfy.isolation.proxies.prompt_server_impl import PromptServerStub
return await PromptServerStub.flush_child_routes()
async def flush_transport_state(self) -> int: async def flush_transport_state(self) -> int:
if os.environ.get("PYISOLATE_CHILD") != "1": if os.environ.get("PYISOLATE_CHILD") != "1":
return 0 return 0
@ -750,19 +792,13 @@ class ComfyNodeExtension(ExtensionBase):
return self.node_instances[node_name] return self.node_instances[node_name]
async def before_module_loaded(self) -> None: async def before_module_loaded(self) -> None:
# Inject initialization here if we think this is the child
logger.warning(
"%s DIAG:before_module_loaded START | is_child=%s",
LOG_PREFIX, os.environ.get("PYISOLATE_CHILD"),
)
try: try:
from comfy.isolation import initialize_proxies from comfy.isolation import initialize_proxies
initialize_proxies() initialize_proxies()
logger.warning("%s DIAG:before_module_loaded initialize_proxies OK", LOG_PREFIX)
except Exception as e: except Exception as e:
logger.error( logger.error(
"%s DIAG:before_module_loaded initialize_proxies FAILED: %s", LOG_PREFIX, e "%s before_module_loaded initialize_proxies FAILED: %s", LOG_PREFIX, e
) )
await super().before_module_loaded() await super().before_module_loaded()

View File

@ -166,6 +166,8 @@ def load_host_policy(comfy_root: Path) -> HostSecurityPolicy:
if isinstance(whitelist_raw, dict): if isinstance(whitelist_raw, dict):
policy["whitelist"].update({str(k): str(v) for k, v in whitelist_raw.items()}) policy["whitelist"].update({str(k): str(v) for k, v in whitelist_raw.items()})
os.environ["PYISOLATE_SANDBOX_MODE"] = policy["sandbox_mode"]
logger.debug( logger.debug(
"Loaded Host Policy: %d whitelisted nodes, Sandbox=%s, Network=%s", "Loaded Host Policy: %d whitelisted nodes, Sandbox=%s, Network=%s",
len(policy["whitelist"]), len(policy["whitelist"]),

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
import os import os
import traceback
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from pyisolate import ProxiedSingleton from pyisolate import ProxiedSingleton
@ -152,24 +152,9 @@ class FolderPathsProxy(ProxiedSingleton):
return list(_folder_paths().get_folder_paths(folder_name)) return list(_folder_paths().get_folder_paths(folder_name))
def get_filename_list(self, folder_name: str) -> list[str]: def get_filename_list(self, folder_name: str) -> list[str]:
caller_stack = "".join(traceback.format_stack()[-4:-1])
_fp_logger.warning(
"][ DIAG:FolderPathsProxy.get_filename_list called | folder=%s | is_child=%s | rpc_configured=%s\n%s",
folder_name, _is_child_process(), self._rpc is not None, caller_stack,
)
if _is_child_process(): if _is_child_process():
result = list(call_singleton_rpc(self._get_caller(), "rpc_get_filename_list", folder_name)) return list(call_singleton_rpc(self._get_caller(), "rpc_get_filename_list", folder_name))
_fp_logger.warning( return list(_folder_paths().get_filename_list(folder_name))
"][ DIAG:FolderPathsProxy.get_filename_list RPC result | folder=%s | count=%d | first=%s",
folder_name, len(result), result[:3] if result else "EMPTY",
)
return result
result = list(_folder_paths().get_filename_list(folder_name))
_fp_logger.warning(
"][ DIAG:FolderPathsProxy.get_filename_list LOCAL result | folder=%s | count=%d | first=%s",
folder_name, len(result), result[:3] if result else "EMPTY",
)
return result
def get_full_path(self, folder_name: str, filename: str) -> str | None: def get_full_path(self, folder_name: str, filename: str) -> str | None:
if _is_child_process(): if _is_child_process():

View File

@ -94,14 +94,13 @@ class PromptServerStub:
def client_id(self) -> Optional[str]: def client_id(self) -> Optional[str]:
return "isolated_client" return "isolated_client"
def supports(self, feature: str) -> bool: @property
return True def supports(self) -> set:
return {"custom_nodes_from_web"}
@property @property
def app(self): def app(self):
raise RuntimeError( return _AppStub(self)
"PromptServer.app is not accessible in isolated nodes. Use RPC routes instead."
)
@property @property
def prompt_queue(self): def prompt_queue(self):
@ -140,18 +139,27 @@ class PromptServerStub:
call_singleton_rpc(self._rpc, "ui_send_progress_text", text, node_id, sid) call_singleton_rpc(self._rpc, "ui_send_progress_text", text, node_id, sid)
# --- Route Registration Logic --- # --- Route Registration Logic ---
def register_route(self, method: str, path: str, handler: Callable): _pending_child_routes: list = []
"""Register a route handler via RPC."""
if not self._rpc:
logger.error("RPC not initialized in PromptServerStub")
return
# Fire registration async def register_route(self, method: str, path: str, handler: Callable):
try: """Buffer route registration. Routes are flushed via flush_child_routes()."""
loop = asyncio.get_running_loop() PromptServerStub._pending_child_routes.append((method, path, handler))
loop.create_task(self._rpc.register_route_rpc(method, path, handler)) logger.info("%s Buffered isolated route %s %s", LOG_PREFIX, method, path)
except RuntimeError:
call_singleton_rpc(self._rpc, "register_route_rpc", method, path, handler) @classmethod
async def flush_child_routes(cls):
"""Send all buffered route registrations to host via RPC. Call from on_module_loaded."""
if not cls._rpc:
return 0
flushed = 0
for method, path, handler in cls._pending_child_routes:
try:
await cls._rpc.register_route_rpc(method, path, handler)
flushed += 1
except Exception as e:
logger.error("%s Child route flush failed %s %s: %s", LOG_PREFIX, method, path, e)
cls._pending_child_routes = []
return flushed
class RouteStub: class RouteStub:
@ -205,7 +213,6 @@ class PromptServerService(ProxiedSingleton):
"""Host-side RPC Service for PromptServer.""" """Host-side RPC Service for PromptServer."""
def __init__(self): def __init__(self):
# We will bind to the real server instance lazily or via global import
pass pass
@property @property
@ -231,7 +238,7 @@ class PromptServerService(ProxiedSingleton):
async def register_route_rpc(self, method: str, path: str, child_handler_proxy): async def register_route_rpc(self, method: str, path: str, child_handler_proxy):
"""RPC Target: Register a route that forwards to the Child.""" """RPC Target: Register a route that forwards to the Child."""
from aiohttp import web from aiohttp import web
logger.debug(f"{LOG_PREFIX} Registering Isolated Route {method} {path}") logger.info("%s Registering isolated route %s %s", LOG_PREFIX, method, path)
async def route_wrapper(request: web.Request) -> web.Response: async def route_wrapper(request: web.Request) -> web.Response:
# 1. Capture request data # 1. Capture request data
@ -253,8 +260,8 @@ class PromptServerService(ProxiedSingleton):
logger.error(f"{LOG_PREFIX} Isolated Route Error: {e}") logger.error(f"{LOG_PREFIX} Isolated Route Error: {e}")
return web.Response(status=500, text=str(e)) return web.Response(status=500, text=str(e))
# Register loop
self.server.app.router.add_route(method, path, route_wrapper) self.server.app.router.add_route(method, path, route_wrapper)
logger.info("%s Registered isolated route %s %s", LOG_PREFIX, method, path)
def _serialize_response(self, result: Any) -> Any: def _serialize_response(self, result: Any) -> Any:
"""Helper to convert Child result -> web.Response""" """Helper to convert Child result -> web.Response"""
@ -269,3 +276,32 @@ class PromptServerService(ProxiedSingleton):
return web.Response(text=result) return web.Response(text=result)
# Fallback # Fallback
return web.Response(text=str(result)) return web.Response(text=str(result))
class _RouterStub:
"""Captures router.add_route and router.add_static calls in isolation child."""
def __init__(self, stub):
self._stub = stub
def add_route(self, method, path, handler, **kwargs):
self._stub.register_route(method, path, handler)
def add_static(self, prefix, path, **kwargs):
# Static file serving not supported in isolation — silently skip
pass
class _AppStub:
"""Captures PromptServer.app access patterns in isolation child."""
def __init__(self, stub):
self.router = _RouterStub(stub)
self.frozen = False
def add_routes(self, routes):
# aiohttp route table — iterate and register each
for route in routes:
if hasattr(route, "method") and hasattr(route, "handler"):
self.router.add_route(route.method, route.path, route.handler)
# StaticDef and other non-method routes — silently skip

View File

@ -724,6 +724,7 @@ def capture_prompt_web_exact_relay() -> dict[str, object]:
from comfy.isolation.proxies.web_directory_proxy import WebDirectoryCache from comfy.isolation.proxies.web_directory_proxy import WebDirectoryCache
PromptServerStub.set_rpc(fake_rpc) PromptServerStub.set_rpc(fake_rpc)
PromptServerStub._pending_child_routes = []
stub = PromptServerStub() stub = PromptServerStub()
cache = WebDirectoryCache() cache = WebDirectoryCache()
cache.register_proxy("demo_ext", FakeWebDirectoryProxy(fake_rpc.transcripts)) cache.register_proxy("demo_ext", FakeWebDirectoryProxy(fake_rpc.transcripts))
@ -735,6 +736,7 @@ def capture_prompt_web_exact_relay() -> dict[str, object]:
stub.send_progress_text("hello", "node-17") stub.send_progress_text("hello", "node-17")
stub.routes.get("/demo")(demo_handler) stub.routes.get("/demo")(demo_handler)
asyncio.run(PromptServerStub.flush_child_routes())
web_file = cache.get_file("demo_ext", "js/app.js") web_file = cache.get_file("demo_ext", "js/app.js")
imported = set(sys.modules) - before imported = set(sys.modules) - before
return { return {

View File

@ -108,6 +108,119 @@ flash_attn = "flash-attn-special"
} }
def test_load_isolated_node_passes_share_torch_no_deps(tmp_path, monkeypatch):
node_dir = tmp_path / "node"
node_dir.mkdir()
manifest_path = node_dir / "pyproject.toml"
_write_manifest(
node_dir,
"""
[project]
name = "demo-node"
dependencies = ["timm", "pyyaml"]
[tool.comfy.isolation]
can_isolate = true
share_torch = true
share_torch_no_deps = ["timm"]
""".strip(),
)
captured: dict[str, object] = {}
class DummyManager:
def __init__(self, *args, **kwargs) -> None:
return None
def load_extension(self, config):
captured.update(config)
return _DummyExtension()
monkeypatch.setattr(extension_loader_module.pyisolate, "ExtensionManager", DummyManager)
monkeypatch.setattr(
extension_loader_module,
"load_host_policy",
lambda base_path: {
"sandbox_mode": "disabled",
"allow_network": False,
"writable_paths": [],
"readonly_paths": [],
},
)
monkeypatch.setattr(extension_loader_module, "is_cache_valid", lambda *args, **kwargs: True)
monkeypatch.setattr(
extension_loader_module,
"load_from_cache",
lambda *args, **kwargs: {"Node": {"display_name": "Node", "schema_v1": {}}},
)
monkeypatch.setitem(sys.modules, "folder_paths", SimpleNamespace(base_path=str(tmp_path)))
specs = asyncio.run(
load_isolated_node(
node_dir,
manifest_path,
logging.getLogger("test"),
lambda *args, **kwargs: object,
tmp_path / "venvs",
[],
)
)
assert len(specs) == 1
assert captured["share_torch_no_deps"] == ["timm"]
def test_on_module_loaded_registers_legacy_routes(monkeypatch):
captured: list[tuple[str, str, Any]] = []
def demo_handler(body):
return body
module = SimpleNamespace(
__file__="/tmp/demo_node/__init__.py",
__name__="demo_node",
NODE_CLASS_MAPPINGS={},
NODE_DISPLAY_NAME_MAPPINGS={},
ROUTES=[
{"method": "POST", "path": "/sam3/interactive_segment_one", "handler": "demo_handler"},
],
demo_handler=demo_handler,
)
def fake_register_route(self, method, path, handler):
captured.append((method, path, handler))
monkeypatch.setattr(
"comfy.isolation.proxies.prompt_server_impl.PromptServerStub.register_route",
fake_register_route,
)
extension = ComfyNodeExtension()
asyncio.run(extension.on_module_loaded(module))
assert captured == [("POST", "/sam3/interactive_segment_one", demo_handler)]
def test_prompt_server_stub_buffers_routes_without_rpc():
from comfy.isolation.proxies.prompt_server_impl import PromptServerStub
def demo_handler(body):
return body
old_rpc = PromptServerStub._rpc
old_pending = list(PromptServerStub._pending_child_routes)
try:
PromptServerStub._rpc = None
PromptServerStub._pending_child_routes = []
PromptServerStub().register_route("POST", "/sam3/interactive_segment_one", demo_handler)
assert PromptServerStub._pending_child_routes == [
("POST", "/sam3/interactive_segment_one", demo_handler)
]
finally:
PromptServerStub._rpc = old_rpc
PromptServerStub._pending_child_routes = old_pending
def test_load_isolated_node_rejects_undeclared_cuda_wheel_dependency( def test_load_isolated_node_rejects_undeclared_cuda_wheel_dependency(
tmp_path, monkeypatch tmp_path, monkeypatch
): ):
@ -362,6 +475,70 @@ can_isolate = true
assert "cuda_wheels" not in captured assert "cuda_wheels" not in captured
def test_load_isolated_node_passes_extra_index_urls(tmp_path, monkeypatch):
node_dir = tmp_path / "node"
node_dir.mkdir()
manifest_path = node_dir / "pyproject.toml"
_write_manifest(
node_dir,
"""
[project]
name = "demo-node"
dependencies = ["fbxsdkpy==2020.1.post2", "numpy>=1.0"]
[tool.comfy.isolation]
can_isolate = true
share_torch = true
extra_index_urls = ["https://gitlab.inria.fr/api/v4/projects/18692/packages/pypi/simple"]
""".strip(),
)
captured: dict[str, object] = {}
class DummyManager:
def __init__(self, *args, **kwargs) -> None:
return None
def load_extension(self, config):
captured.update(config)
return _DummyExtension()
monkeypatch.setattr(extension_loader_module.pyisolate, "ExtensionManager", DummyManager)
monkeypatch.setattr(
extension_loader_module,
"load_host_policy",
lambda base_path: {
"sandbox_mode": "disabled",
"allow_network": False,
"writable_paths": [],
"readonly_paths": [],
},
)
monkeypatch.setattr(extension_loader_module, "is_cache_valid", lambda *args, **kwargs: True)
monkeypatch.setattr(
extension_loader_module,
"load_from_cache",
lambda *args, **kwargs: {"Node": {"display_name": "Node", "schema_v1": {}}},
)
monkeypatch.setitem(sys.modules, "folder_paths", SimpleNamespace(base_path=str(tmp_path)))
specs = asyncio.run(
load_isolated_node(
node_dir,
manifest_path,
logging.getLogger("test"),
lambda *args, **kwargs: object,
tmp_path / "venvs",
[],
)
)
assert len(specs) == 1
assert captured["extra_index_urls"] == [
"https://gitlab.inria.fr/api/v4/projects/18692/packages/pypi/simple"
]
def test_maybe_wrap_model_for_isolation_uses_runtime_flag(monkeypatch): def test_maybe_wrap_model_for_isolation_uses_runtime_flag(monkeypatch):
class DummyRegistry: class DummyRegistry:
def register(self, model): def register(self, model):