fix(isolation): refresh loader and isolated route handling

This commit is contained in:
John Pollock 2026-04-12 21:08:35 -05:00
parent 51e70fe033
commit 07fffdd593
10 changed files with 335 additions and 213 deletions

View File

@ -44,21 +44,15 @@ def initialize_proxies() -> None:
from .child_hooks import is_child_process
is_child = is_child_process()
logger.warning(
"%s DIAG:initialize_proxies | is_child=%s | PYISOLATE_CHILD=%s",
LOG_PREFIX, is_child, os.environ.get("PYISOLATE_CHILD"),
)
if is_child:
from .child_hooks import initialize_child_process
initialize_child_process()
logger.warning("%s DIAG:initialize_proxies child_process initialized", LOG_PREFIX)
else:
from .host_hooks import initialize_host_process
initialize_host_process()
logger.warning("%s DIAG:initialize_proxies host_process initialized", LOG_PREFIX)
if start_shm_forensics is not None:
start_shm_forensics()

View File

@ -586,29 +586,6 @@ class ComfyUIAdapter(IsolationAdapter):
register_hooks_serializers(registry)
# Generic Numpy Serializer
def serialize_numpy(obj: Any) -> Any:
    """Serialize an ndarray for transport: zero-copy Tensor when possible, plain list otherwise."""
    import torch  # deferred so torch is only paid for when serialization actually runs
    try:
        converted = torch.from_numpy(obj)  # zero-copy view onto the ndarray's buffer
    except Exception:
        # Non-numeric dtypes (strings, objects, mixed records) cannot become tensors.
        converted = obj.tolist()
    return converted
def deserialize_numpy_b64(data: Any) -> Any:
    """Deserialize base64-encoded ndarray from sealed worker.

    Expects a dict carrying "data" (base64 string/bytes), "dtype" and
    "shape"; any other payload is returned unchanged.
    """
    import base64
    import numpy as np
    import torch  # local deferred import, mirroring serialize_numpy (module scope not guaranteed)
    if (
        isinstance(data, dict)
        and "data" in data
        and "dtype" in data
        and "shape" in data  # was unchecked: reshape(data["shape"]) raised KeyError on partial dicts
    ):
        raw = base64.b64decode(data["data"])
        arr = np.frombuffer(raw, dtype=np.dtype(data["dtype"])).reshape(data["shape"])
        # copy(): frombuffer yields a read-only view; torch needs writable memory
        return torch.from_numpy(arr.copy())
    return data
registry.register("ndarray", serialize_numpy, deserialize_numpy_b64)
# -- File3D (comfy_api.latest._util.geometry_types) ---------------------
# Origin: comfy_api by ComfyOrg (Alexander Piskun), PR #12129
@ -873,93 +850,15 @@ class ComfyUIAdapter(IsolationAdapter):
return
if api_name == "PromptServerProxy":
if api_name == "PromptServerService":
if not _IMPORT_TORCH:
return
# Defer heavy import to child context
import server
from comfy.isolation.proxies.prompt_server_impl import PromptServerStub
instance = api() if isinstance(api, type) else api
proxy = (
instance.instance
) # PromptServerProxy instance has .instance property returning self
original_register_route = proxy.register_route
def register_route_wrapper(
    method: str, path: str, handler: Callable[..., Any]
) -> None:
    """Intercept a route registration and replace the live handler with an RPC callback id.

    The real handler stays on this side of the process boundary; the remote
    end receives callback_id and is_callback=True so it knows to call back
    over RPC instead of invoking a local function.
    """
    callback_id = rpc.register_callback(handler)
    # rpc may or may not expose a loop; absent/idle loop means we call synchronously.
    loop = getattr(rpc, "loop", None)
    if loop and loop.is_running():
        import asyncio
        # A running loop means original_register_route must be scheduled,
        # not awaited inline from this (synchronous) wrapper.
        asyncio.create_task(
            original_register_route(
                method, path, handler=callback_id, is_callback=True
            )
        )
    else:
        # NOTE(review): synchronous call path — presumably original_register_route
        # is safe to invoke without a loop here; confirm it is not a bare coroutine.
        original_register_route(
            method, path, handler=callback_id, is_callback=True
        )
    return None
proxy.register_route = register_route_wrapper
class RouteTableDefProxy:
    """Minimal aiohttp-RouteTableDef lookalike that forwards registrations to a proxy.

    Each HTTP-verb method (get/post/patch/put/delete) returns a decorator
    that registers (METHOD, path, handler) on the wrapped proxy and hands
    the handler back unchanged. Extra keyword arguments are accepted for
    signature compatibility with aiohttp but ignored, matching the original
    behavior. The five verbs were copy-pasted bodies; they now share one
    private factory.
    """

    def __init__(self, proxy_instance: Any):
        self.proxy = proxy_instance

    def _make_decorator(
        self, method: str, path: str
    ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
        # Shared factory: every public verb below differs only in the method name.
        def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
            self.proxy.register_route(method, path, handler)
            return handler
        return decorator

    def get(
        self, path: str, **kwargs: Any
    ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
        return self._make_decorator("GET", path)

    def post(
        self, path: str, **kwargs: Any
    ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
        return self._make_decorator("POST", path)

    def patch(
        self, path: str, **kwargs: Any
    ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
        return self._make_decorator("PATCH", path)

    def put(
        self, path: str, **kwargs: Any
    ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
        return self._make_decorator("PUT", path)

    def delete(
        self, path: str, **kwargs: Any
    ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
        return self._make_decorator("DELETE", path)
proxy.routes = RouteTableDefProxy(proxy)
stub = PromptServerStub()
if (
hasattr(server, "PromptServer")
and getattr(server.PromptServer, "instance", None) != proxy
and getattr(server.PromptServer, "instance", None) is not stub
):
server.PromptServer.instance = proxy
server.PromptServer.instance = stub

View File

@ -31,7 +31,6 @@ def _load_extra_model_paths() -> None:
def initialize_child_process() -> None:
logger.warning("][ DIAG:child_hooks initialize_child_process START")
if os.environ.get("PYISOLATE_IMPORT_TORCH", "1") != "0":
_load_extra_model_paths()
_setup_child_loop_bridge()
@ -41,15 +40,12 @@ def initialize_child_process() -> None:
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
rpc = get_child_rpc_instance()
logger.warning("][ DIAG:child_hooks RPC instance: %s", rpc is not None)
if rpc:
_setup_proxy_callers(rpc)
logger.warning("][ DIAG:child_hooks proxy callers configured with RPC")
else:
logger.warning("][ DIAG:child_hooks NO RPC — proxy callers cleared")
_setup_proxy_callers()
except Exception as e:
logger.error(f"][ DIAG:child_hooks Manual RPC Injection failed: {e}")
logger.error(f"][ child_hooks Manual RPC Injection failed: {e}")
_setup_proxy_callers()
_setup_logging()

View File

@ -354,6 +354,16 @@ async def load_isolated_node(
"sandbox": sandbox_config,
}
share_torch_no_deps = tool_config.get("share_torch_no_deps", [])
if share_torch_no_deps:
if not isinstance(share_torch_no_deps, list) or not all(
isinstance(dep, str) and dep.strip() for dep in share_torch_no_deps
):
raise ExtensionLoadError(
"[tool.comfy.isolation.share_torch_no_deps] must be a list of non-empty strings"
)
extension_config["share_torch_no_deps"] = share_torch_no_deps
_is_sealed = execution_model == "sealed_worker"
_is_sandboxed = host_policy["sandbox_mode"] != "disabled" and is_linux
logger.info(
@ -367,6 +377,16 @@ async def load_isolated_node(
if cuda_wheels is not None:
extension_config["cuda_wheels"] = cuda_wheels
extra_index_urls = tool_config.get("extra_index_urls", [])
if extra_index_urls:
if not isinstance(extra_index_urls, list) or not all(
isinstance(u, str) and u.strip() for u in extra_index_urls
):
raise ExtensionLoadError(
"[tool.comfy.isolation.extra_index_urls] must be a list of non-empty strings"
)
extension_config["extra_index_urls"] = extra_index_urls
# Conda-specific keys
if is_conda:
extension_config["package_manager"] = "conda"
@ -408,31 +428,17 @@ async def load_isolated_node(
cache.register_proxy(extension_name, WebDirectoryProxy())
# Try cache first (lazy spawn)
logger.warning("][ DIAG:ext_loader cache_valid_check for %s", extension_name)
if is_cache_valid(node_dir, manifest_path, venv_root):
cached_data = load_from_cache(node_dir, venv_root)
if cached_data:
if _is_stale_node_cache(cached_data):
logger.warning(
"][ DIAG:ext_loader %s cache is stale/incompatible; rebuilding metadata",
extension_name,
)
pass
else:
logger.warning("][ DIAG:ext_loader %s USING CACHE — dumping combo options:", extension_name)
for node_name, details in cached_data.items():
schema_v1 = details.get("schema_v1", {})
inp = schema_v1.get("input", {}) if schema_v1 else {}
for section_name, section in inp.items():
if isinstance(section, dict):
for field_name, field_def in section.items():
if isinstance(field_def, (list, tuple)) and len(field_def) >= 2 and isinstance(field_def[1], dict) and "options" in field_def[1]:
opts = field_def[1]["options"]
logger.warning(
"][ DIAG:ext_loader CACHE %s.%s.%s options=%d first=%s",
node_name, section_name, field_name,
len(opts),
opts[:3] if opts else "EMPTY",
)
try:
flushed = await extension.flush_pending_routes()
logger.info("][ %s flushed %d routes", extension_name, flushed)
except Exception as exc:
logger.warning("][ %s route flush failed: %s", extension_name, exc)
specs: List[Tuple[str, str, type]] = []
for node_name, details in cached_data.items():
stub_cls = build_stub_class(node_name, details, extension)
@ -440,11 +446,7 @@ async def load_isolated_node(
(node_name, details.get("display_name", node_name), stub_cls)
)
return specs
else:
logger.warning("][ DIAG:ext_loader %s cache INVALID or MISSING", extension_name)
# Cache miss - spawn process and get metadata
logger.warning("][ DIAG:ext_loader %s cache miss, spawning process for metadata", extension_name)
try:
remote_nodes: Dict[str, str] = await extension.list_nodes()
@ -466,7 +468,6 @@ async def load_isolated_node(
cache_data: Dict[str, Dict] = {}
for node_name, display_name in remote_nodes.items():
logger.warning("][ DIAG:ext_loader calling get_node_details for %s.%s", extension_name, node_name)
try:
details = await extension.get_node_details(node_name)
except Exception as exc:
@ -477,20 +478,6 @@ async def load_isolated_node(
exc,
)
continue
# DIAG: dump combo options from freshly-fetched details
schema_v1 = details.get("schema_v1", {})
inp = schema_v1.get("input", {}) if schema_v1 else {}
for section_name, section in inp.items():
if isinstance(section, dict):
for field_name, field_def in section.items():
if isinstance(field_def, (list, tuple)) and len(field_def) >= 2 and isinstance(field_def[1], dict) and "options" in field_def[1]:
opts = field_def[1]["options"]
logger.warning(
"][ DIAG:ext_loader FRESH %s.%s.%s options=%d first=%s",
node_name, section_name, field_name,
len(opts),
opts[:3] if opts else "EMPTY",
)
details["display_name"] = display_name
cache_data[node_name] = details
stub_cls = build_stub_class(node_name, details, extension)
@ -512,6 +499,14 @@ async def load_isolated_node(
if host_policy["sandbox_mode"] == "disabled":
_register_web_directory(extension_name, node_dir)
# Flush any routes the child buffered during module import — must happen
# before router freeze and before we kill the child process.
try:
flushed = await extension.flush_pending_routes()
logger.info("][ %s flushed %d routes", extension_name, flushed)
except Exception as exc:
logger.warning("][ %s route flush failed: %s", extension_name, exc)
# EJECT: Kill process after getting metadata (will respawn on first execution)
await _stop_extension_safe(extension, extension_name)

View File

@ -211,6 +211,7 @@ class ComfyNodeExtension(ExtensionBase):
self.node_classes = getattr(module, "NODE_CLASS_MAPPINGS", {}) or {}
self.display_names = getattr(module, "NODE_DISPLAY_NAME_MAPPINGS", {}) or {}
self._register_module_routes(module)
# Register web directory with WebDirectoryProxy (child-side)
web_dir_attr = getattr(module, "WEB_DIRECTORY", None)
@ -280,6 +281,55 @@ class ComfyNodeExtension(ExtensionBase):
self.node_instances = {}
    def _register_module_routes(self, module: Any) -> None:
        """Bridge legacy module-level ROUTES declarations into isolated routing.

        ROUTES entries are dicts shaped like
        {"method": "POST", "path": "/x", "handler": name-or-callable}.
        Valid entries are registered through the PromptServerStub route table
        (which buffers them for a later flush); malformed entries are logged
        and skipped rather than raising, so one bad route cannot abort load.
        """
        routes = getattr(module, "ROUTES", None) or []
        if not routes:
            return
        # Deferred import: the stub lives in child-side proxy code.
        from comfy.isolation.proxies.prompt_server_impl import PromptServerStub
        prompt_server = PromptServerStub()
        route_table = getattr(prompt_server, "routes", None)
        if route_table is None:
            logger.warning("%s Route registration unavailable for %s", LOG_PREFIX, module)
            return
        for route_spec in routes:
            if not isinstance(route_spec, dict):
                logger.warning("%s Ignoring non-dict ROUTES entry: %r", LOG_PREFIX, route_spec)
                continue
            method = str(route_spec.get("method", "")).strip().upper()
            path = str(route_spec.get("path", "")).strip()
            handler_ref = route_spec.get("handler")
            if not method or not path:
                logger.warning("%s Ignoring incomplete route spec: %r", LOG_PREFIX, route_spec)
                continue
            # A string handler is resolved as an attribute on the declaring module.
            if isinstance(handler_ref, str):
                handler = getattr(module, handler_ref, None)
            else:
                handler = handler_ref
            if not callable(handler):
                logger.warning(
                    "%s Ignoring route with missing handler %r for %s %s",
                    LOG_PREFIX,
                    handler_ref,
                    method,
                    path,
                )
                continue
            # The route table exposes one decorator per HTTP verb (get/post/...).
            decorator = getattr(route_table, method.lower(), None)
            if not callable(decorator):
                logger.warning("%s Unsupported route method %s for %s", LOG_PREFIX, method, path)
                continue
            decorator(path)(handler)
            # Keep a local record so this extension can answer for its own routes.
            self._route_handlers[f"{method} {path}"] = handler
            logger.info("%s buffered legacy route %s %s", LOG_PREFIX, method, path)
async def list_nodes(self) -> Dict[str, str]:
return {name: self.display_names.get(name, name) for name in self.node_classes}
@ -289,10 +339,6 @@ class ComfyNodeExtension(ExtensionBase):
async def get_node_details(self, node_name: str) -> Dict[str, Any]:
node_cls = self._get_node_class(node_name)
is_v3 = issubclass(node_cls, _ComfyNodeInternal)
logger.warning(
"%s DIAG:get_node_details START | node=%s | is_v3=%s | cls=%s",
LOG_PREFIX, node_name, is_v3, node_cls,
)
input_types_raw = (
node_cls.INPUT_TYPES() if hasattr(node_cls, "INPUT_TYPES") else {}
@ -316,16 +362,7 @@ class ComfyNodeExtension(ExtensionBase):
if is_v3:
try:
logger.warning(
"%s DIAG:get_node_details calling GET_SCHEMA for %s",
LOG_PREFIX, node_name,
)
schema = node_cls.GET_SCHEMA()
logger.warning(
"%s DIAG:get_node_details GET_SCHEMA returned for %s | schema_inputs=%s",
LOG_PREFIX, node_name,
[getattr(i, 'id', '?') for i in (schema.inputs or [])],
)
schema_v1 = asdict(schema.get_v1_info(node_cls))
try:
schema_v3 = asdict(schema.get_v3_info(node_cls))
@ -532,6 +569,11 @@ class ComfyNodeExtension(ExtensionBase):
wrapped = self._wrap_unpicklable_objects(result)
return wrapped
    async def flush_pending_routes(self) -> int:
        """Flush buffered route registrations to host via RPC. Called by host after node discovery.

        Returns the number of routes flushed (delegates to
        PromptServerStub.flush_child_routes, which drains its class-level buffer).
        """
        # Deferred import: proxy module only needed once flushing actually happens.
        from comfy.isolation.proxies.prompt_server_impl import PromptServerStub
        return await PromptServerStub.flush_child_routes()
async def flush_transport_state(self) -> int:
if os.environ.get("PYISOLATE_CHILD") != "1":
return 0
@ -750,19 +792,13 @@ class ComfyNodeExtension(ExtensionBase):
return self.node_instances[node_name]
async def before_module_loaded(self) -> None:
# Inject initialization here if we think this is the child
logger.warning(
"%s DIAG:before_module_loaded START | is_child=%s",
LOG_PREFIX, os.environ.get("PYISOLATE_CHILD"),
)
try:
from comfy.isolation import initialize_proxies
initialize_proxies()
logger.warning("%s DIAG:before_module_loaded initialize_proxies OK", LOG_PREFIX)
except Exception as e:
logger.error(
"%s DIAG:before_module_loaded initialize_proxies FAILED: %s", LOG_PREFIX, e
"%s before_module_loaded initialize_proxies FAILED: %s", LOG_PREFIX, e
)
await super().before_module_loaded()

View File

@ -166,6 +166,8 @@ def load_host_policy(comfy_root: Path) -> HostSecurityPolicy:
if isinstance(whitelist_raw, dict):
policy["whitelist"].update({str(k): str(v) for k, v in whitelist_raw.items()})
os.environ["PYISOLATE_SANDBOX_MODE"] = policy["sandbox_mode"]
logger.debug(
"Loaded Host Policy: %d whitelisted nodes, Sandbox=%s, Network=%s",
len(policy["whitelist"]),

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import logging
import os
import traceback
from typing import Any, Dict, Optional
from pyisolate import ProxiedSingleton
@ -152,24 +152,9 @@ class FolderPathsProxy(ProxiedSingleton):
return list(_folder_paths().get_folder_paths(folder_name))
def get_filename_list(self, folder_name: str) -> list[str]:
caller_stack = "".join(traceback.format_stack()[-4:-1])
_fp_logger.warning(
"][ DIAG:FolderPathsProxy.get_filename_list called | folder=%s | is_child=%s | rpc_configured=%s\n%s",
folder_name, _is_child_process(), self._rpc is not None, caller_stack,
)
if _is_child_process():
result = list(call_singleton_rpc(self._get_caller(), "rpc_get_filename_list", folder_name))
_fp_logger.warning(
"][ DIAG:FolderPathsProxy.get_filename_list RPC result | folder=%s | count=%d | first=%s",
folder_name, len(result), result[:3] if result else "EMPTY",
)
return result
result = list(_folder_paths().get_filename_list(folder_name))
_fp_logger.warning(
"][ DIAG:FolderPathsProxy.get_filename_list LOCAL result | folder=%s | count=%d | first=%s",
folder_name, len(result), result[:3] if result else "EMPTY",
)
return result
return list(call_singleton_rpc(self._get_caller(), "rpc_get_filename_list", folder_name))
return list(_folder_paths().get_filename_list(folder_name))
def get_full_path(self, folder_name: str, filename: str) -> str | None:
if _is_child_process():

View File

@ -94,14 +94,13 @@ class PromptServerStub:
def client_id(self) -> Optional[str]:
return "isolated_client"
def supports(self, feature: str) -> bool:
return True
@property
def supports(self) -> set:
return {"custom_nodes_from_web"}
@property
def app(self):
raise RuntimeError(
"PromptServer.app is not accessible in isolated nodes. Use RPC routes instead."
)
return _AppStub(self)
@property
def prompt_queue(self):
@ -140,18 +139,27 @@ class PromptServerStub:
call_singleton_rpc(self._rpc, "ui_send_progress_text", text, node_id, sid)
# --- Route Registration Logic ---
def register_route(self, method: str, path: str, handler: Callable):
"""Register a route handler via RPC."""
if not self._rpc:
logger.error("RPC not initialized in PromptServerStub")
return
_pending_child_routes: list = []
# Fire registration async
try:
loop = asyncio.get_running_loop()
loop.create_task(self._rpc.register_route_rpc(method, path, handler))
except RuntimeError:
call_singleton_rpc(self._rpc, "register_route_rpc", method, path, handler)
    def register_route(self, method: str, path: str, handler: Callable):
        """Buffer route registration. Routes are flushed via flush_child_routes().

        Registration is deferred because the child's RPC channel may not be
        ready at module-import time; the host triggers the flush later.
        """
        # Class-level buffer: shared across every stub instance in this process.
        PromptServerStub._pending_child_routes.append((method, path, handler))
        logger.info("%s Buffered isolated route %s %s", LOG_PREFIX, method, path)
    @classmethod
    async def flush_child_routes(cls):
        """Send all buffered route registrations to host via RPC. Call from on_module_loaded.

        Returns the number of routes flushed successfully. Per-route failures
        are logged and counted as not-flushed; note the buffer is cleared
        unconditionally afterwards, so a failed registration is NOT retried.
        """
        if not cls._rpc:
            # No RPC configured (host side, or tests) — nothing to send.
            return 0
        flushed = 0
        for method, path, handler in cls._pending_child_routes:
            try:
                await cls._rpc.register_route_rpc(method, path, handler)
                flushed += 1
            except Exception as e:
                logger.error("%s Child route flush failed %s %s: %s", LOG_PREFIX, method, path, e)
        cls._pending_child_routes = []
        return flushed
class RouteStub:
@ -205,7 +213,6 @@ class PromptServerService(ProxiedSingleton):
"""Host-side RPC Service for PromptServer."""
def __init__(self):
# We will bind to the real server instance lazily or via global import
pass
@property
@ -231,7 +238,7 @@ class PromptServerService(ProxiedSingleton):
async def register_route_rpc(self, method: str, path: str, child_handler_proxy):
"""RPC Target: Register a route that forwards to the Child."""
from aiohttp import web
logger.debug(f"{LOG_PREFIX} Registering Isolated Route {method} {path}")
logger.info("%s Registering isolated route %s %s", LOG_PREFIX, method, path)
async def route_wrapper(request: web.Request) -> web.Response:
# 1. Capture request data
@ -253,8 +260,8 @@ class PromptServerService(ProxiedSingleton):
logger.error(f"{LOG_PREFIX} Isolated Route Error: {e}")
return web.Response(status=500, text=str(e))
# Register loop
self.server.app.router.add_route(method, path, route_wrapper)
logger.info("%s Registered isolated route %s %s", LOG_PREFIX, method, path)
def _serialize_response(self, result: Any) -> Any:
"""Helper to convert Child result -> web.Response"""
@ -269,3 +276,32 @@ class PromptServerService(ProxiedSingleton):
return web.Response(text=result)
# Fallback
return web.Response(text=str(result))
class _RouterStub:
"""Captures router.add_route and router.add_static calls in isolation child."""
def __init__(self, stub):
self._stub = stub
def add_route(self, method, path, handler, **kwargs):
self._stub.register_route(method, path, handler)
def add_static(self, prefix, path, **kwargs):
# Static file serving not supported in isolation — silently skip
pass
class _AppStub:
    """Captures PromptServer.app access patterns in isolation child."""

    def __init__(self, stub):
        self.router = _RouterStub(stub)
        self.frozen = False  # aiohttp apps freeze after startup; this stub never does

    def add_routes(self, routes):
        """Register each method-bearing entry of an aiohttp-style route table."""
        for entry in routes:
            is_dynamic = hasattr(entry, "method") and hasattr(entry, "handler")
            if is_dynamic:
                self.router.add_route(entry.method, entry.path, entry.handler)
            # StaticDef and other non-method entries are silently skipped.

View File

@ -724,6 +724,7 @@ def capture_prompt_web_exact_relay() -> dict[str, object]:
from comfy.isolation.proxies.web_directory_proxy import WebDirectoryCache
PromptServerStub.set_rpc(fake_rpc)
PromptServerStub._pending_child_routes = []
stub = PromptServerStub()
cache = WebDirectoryCache()
cache.register_proxy("demo_ext", FakeWebDirectoryProxy(fake_rpc.transcripts))
@ -735,6 +736,7 @@ def capture_prompt_web_exact_relay() -> dict[str, object]:
stub.send_progress_text("hello", "node-17")
stub.routes.get("/demo")(demo_handler)
asyncio.run(PromptServerStub.flush_child_routes())
web_file = cache.get_file("demo_ext", "js/app.js")
imported = set(sys.modules) - before
return {

View File

@ -108,6 +108,119 @@ flash_attn = "flash-attn-special"
}
def test_load_isolated_node_passes_share_torch_no_deps(tmp_path, monkeypatch):
    """share_torch_no_deps declared in the manifest must reach the ExtensionManager config."""
    node_dir = tmp_path / "node"
    node_dir.mkdir()
    manifest_path = node_dir / "pyproject.toml"
    _write_manifest(
        node_dir,
        """
[project]
name = "demo-node"
dependencies = ["timm", "pyyaml"]
[tool.comfy.isolation]
can_isolate = true
share_torch = true
share_torch_no_deps = ["timm"]
""".strip(),
    )
    captured: dict[str, object] = {}

    class DummyManager:
        # Stand-in for pyisolate.ExtensionManager: records the config it receives.
        def __init__(self, *args, **kwargs) -> None:
            return None

        def load_extension(self, config):
            captured.update(config)
            return _DummyExtension()

    monkeypatch.setattr(extension_loader_module.pyisolate, "ExtensionManager", DummyManager)
    # Neutral host policy: sandboxing disabled so no platform-specific paths run.
    monkeypatch.setattr(
        extension_loader_module,
        "load_host_policy",
        lambda base_path: {
            "sandbox_mode": "disabled",
            "allow_network": False,
            "writable_paths": [],
            "readonly_paths": [],
        },
    )
    # Force the cache-hit path so no child process needs to be spawned.
    monkeypatch.setattr(extension_loader_module, "is_cache_valid", lambda *args, **kwargs: True)
    monkeypatch.setattr(
        extension_loader_module,
        "load_from_cache",
        lambda *args, **kwargs: {"Node": {"display_name": "Node", "schema_v1": {}}},
    )
    monkeypatch.setitem(sys.modules, "folder_paths", SimpleNamespace(base_path=str(tmp_path)))
    specs = asyncio.run(
        load_isolated_node(
            node_dir,
            manifest_path,
            logging.getLogger("test"),
            lambda *args, **kwargs: object,
            tmp_path / "venvs",
            [],
        )
    )
    assert len(specs) == 1
    assert captured["share_torch_no_deps"] == ["timm"]
def test_on_module_loaded_registers_legacy_routes(monkeypatch):
    """A legacy module-level ROUTES declaration is bridged through register_route."""
    seen: list[tuple[str, str, Any]] = []

    def demo_handler(body):
        return body

    module = SimpleNamespace(
        __file__="/tmp/demo_node/__init__.py",
        __name__="demo_node",
        NODE_CLASS_MAPPINGS={},
        NODE_DISPLAY_NAME_MAPPINGS={},
        ROUTES=[
            {"method": "POST", "path": "/sam3/interactive_segment_one", "handler": "demo_handler"},
        ],
        demo_handler=demo_handler,
    )

    # Capture the stub's registrations instead of letting them hit the RPC buffer.
    monkeypatch.setattr(
        "comfy.isolation.proxies.prompt_server_impl.PromptServerStub.register_route",
        lambda self, method, path, handler: seen.append((method, path, handler)),
    )

    asyncio.run(ComfyNodeExtension().on_module_loaded(module))

    assert seen == [("POST", "/sam3/interactive_segment_one", demo_handler)]
def test_prompt_server_stub_buffers_routes_without_rpc():
    """With no RPC channel configured, register_route buffers instead of sending."""
    from comfy.isolation.proxies.prompt_server_impl import PromptServerStub

    def demo_handler(body):
        return body

    saved_rpc = PromptServerStub._rpc
    saved_pending = list(PromptServerStub._pending_child_routes)
    try:
        PromptServerStub._rpc = None
        PromptServerStub._pending_child_routes = []
        stub = PromptServerStub()
        stub.register_route("POST", "/sam3/interactive_segment_one", demo_handler)
        expected = [("POST", "/sam3/interactive_segment_one", demo_handler)]
        assert PromptServerStub._pending_child_routes == expected
    finally:
        # Restore class-level state so other tests see the original buffer/RPC.
        PromptServerStub._rpc = saved_rpc
        PromptServerStub._pending_child_routes = saved_pending
def test_load_isolated_node_rejects_undeclared_cuda_wheel_dependency(
tmp_path, monkeypatch
):
@ -362,6 +475,70 @@ can_isolate = true
assert "cuda_wheels" not in captured
def test_load_isolated_node_passes_extra_index_urls(tmp_path, monkeypatch):
    """extra_index_urls declared in the manifest must reach the ExtensionManager config."""
    node_dir = tmp_path / "node"
    node_dir.mkdir()
    manifest_path = node_dir / "pyproject.toml"
    _write_manifest(
        node_dir,
        """
[project]
name = "demo-node"
dependencies = ["fbxsdkpy==2020.1.post2", "numpy>=1.0"]
[tool.comfy.isolation]
can_isolate = true
share_torch = true
extra_index_urls = ["https://gitlab.inria.fr/api/v4/projects/18692/packages/pypi/simple"]
""".strip(),
    )
    captured: dict[str, object] = {}

    class DummyManager:
        # Stand-in for pyisolate.ExtensionManager: records the config it receives.
        def __init__(self, *args, **kwargs) -> None:
            return None

        def load_extension(self, config):
            captured.update(config)
            return _DummyExtension()

    monkeypatch.setattr(extension_loader_module.pyisolate, "ExtensionManager", DummyManager)
    # Neutral host policy: sandboxing disabled so no platform-specific paths run.
    monkeypatch.setattr(
        extension_loader_module,
        "load_host_policy",
        lambda base_path: {
            "sandbox_mode": "disabled",
            "allow_network": False,
            "writable_paths": [],
            "readonly_paths": [],
        },
    )
    # Force the cache-hit path so no child process needs to be spawned.
    monkeypatch.setattr(extension_loader_module, "is_cache_valid", lambda *args, **kwargs: True)
    monkeypatch.setattr(
        extension_loader_module,
        "load_from_cache",
        lambda *args, **kwargs: {"Node": {"display_name": "Node", "schema_v1": {}}},
    )
    monkeypatch.setitem(sys.modules, "folder_paths", SimpleNamespace(base_path=str(tmp_path)))
    specs = asyncio.run(
        load_isolated_node(
            node_dir,
            manifest_path,
            logging.getLogger("test"),
            lambda *args, **kwargs: object,
            tmp_path / "venvs",
            [],
        )
    )
    assert len(specs) == 1
    assert captured["extra_index_urls"] == [
        "https://gitlab.inria.fr/api/v4/projects/18692/packages/pypi/simple"
    ]
def test_maybe_wrap_model_for_isolation_uses_runtime_flag(monkeypatch):
class DummyRegistry:
def register(self, model):