diff --git a/comfy/isolation/__init__.py b/comfy/isolation/__init__.py index 18ce059c6..640092f45 100644 --- a/comfy/isolation/__init__.py +++ b/comfy/isolation/__init__.py @@ -44,21 +44,15 @@ def initialize_proxies() -> None: from .child_hooks import is_child_process is_child = is_child_process() - logger.warning( - "%s DIAG:initialize_proxies | is_child=%s | PYISOLATE_CHILD=%s", - LOG_PREFIX, is_child, os.environ.get("PYISOLATE_CHILD"), - ) if is_child: from .child_hooks import initialize_child_process initialize_child_process() - logger.warning("%s DIAG:initialize_proxies child_process initialized", LOG_PREFIX) else: from .host_hooks import initialize_host_process initialize_host_process() - logger.warning("%s DIAG:initialize_proxies host_process initialized", LOG_PREFIX) if start_shm_forensics is not None: start_shm_forensics() diff --git a/comfy/isolation/adapter.py b/comfy/isolation/adapter.py index 4751dee51..05fd5eb86 100644 --- a/comfy/isolation/adapter.py +++ b/comfy/isolation/adapter.py @@ -586,29 +586,6 @@ class ComfyUIAdapter(IsolationAdapter): register_hooks_serializers(registry) - # Generic Numpy Serializer - def serialize_numpy(obj: Any) -> Any: - import torch - - try: - # Attempt zero-copy conversion to Tensor - return torch.from_numpy(obj) - except Exception: - # Fallback for non-numeric arrays (strings, objects, mixes) - return obj.tolist() - - def deserialize_numpy_b64(data: Any) -> Any: - """Deserialize base64-encoded ndarray from sealed worker.""" - import base64 - import numpy as np - if isinstance(data, dict) and "data" in data and "dtype" in data: - raw = base64.b64decode(data["data"]) - arr = np.frombuffer(raw, dtype=np.dtype(data["dtype"])).reshape(data["shape"]) - return torch.from_numpy(arr.copy()) - return data - - registry.register("ndarray", serialize_numpy, deserialize_numpy_b64) - # -- File3D (comfy_api.latest._util.geometry_types) --------------------- # Origin: comfy_api by ComfyOrg (Alexander Piskun), PR #12129 @@ -873,93 +850,15 @@ class ComfyUIAdapter(IsolationAdapter): return - if api_name == "PromptServerProxy": + if api_name == "PromptServerService": if not _IMPORT_TORCH: return - # Defer heavy import to child context import server + from comfy.isolation.proxies.prompt_server_impl import PromptServerStub - instance = api() if isinstance(api, type) else api - proxy = ( - instance.instance - ) # PromptServerProxy instance has .instance property returning self - - original_register_route = proxy.register_route - - def register_route_wrapper( - method: str, path: str, handler: Callable[..., Any] - ) -> None: - callback_id = rpc.register_callback(handler) - loop = getattr(rpc, "loop", None) - if loop and loop.is_running(): - import asyncio - - asyncio.create_task( - original_register_route( - method, path, handler=callback_id, is_callback=True - ) - ) - else: - original_register_route( - method, path, handler=callback_id, is_callback=True - ) - return None - - proxy.register_route = register_route_wrapper - - class RouteTableDefProxy: - def __init__(self, proxy_instance: Any): - self.proxy = proxy_instance - - def get( - self, path: str, **kwargs: Any - ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: - def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: - self.proxy.register_route("GET", path, handler) - return handler - - return decorator - - def post( - self, path: str, **kwargs: Any - ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: - def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: - self.proxy.register_route("POST", path, handler) - return handler - - return decorator - - def patch( - self, path: str, **kwargs: Any - ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: - def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: - self.proxy.register_route("PATCH", path, handler) - return handler - - return decorator - - def put( - self, path: str, **kwargs: Any - ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: - def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: - self.proxy.register_route("PUT", path, handler) - return handler - - return decorator - - def delete( - self, path: str, **kwargs: Any - ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: - def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: - self.proxy.register_route("DELETE", path, handler) - return handler - - return decorator - - proxy.routes = RouteTableDefProxy(proxy) - + stub = PromptServerStub() if ( hasattr(server, "PromptServer") - and getattr(server.PromptServer, "instance", None) != proxy + and getattr(server.PromptServer, "instance", None) is not stub ): - server.PromptServer.instance = proxy + server.PromptServer.instance = stub diff --git a/comfy/isolation/child_hooks.py b/comfy/isolation/child_hooks.py index a009929eb..8aca5a18a 100644 --- a/comfy/isolation/child_hooks.py +++ b/comfy/isolation/child_hooks.py @@ -31,7 +31,6 @@ def _load_extra_model_paths() -> None: def initialize_child_process() -> None: - logger.warning("][ DIAG:child_hooks initialize_child_process START") if os.environ.get("PYISOLATE_IMPORT_TORCH", "1") != "0": _load_extra_model_paths() _setup_child_loop_bridge() @@ -41,15 +40,12 @@ def initialize_child_process() -> None: from pyisolate._internal.rpc_protocol import get_child_rpc_instance rpc = get_child_rpc_instance() - logger.warning("][ DIAG:child_hooks RPC instance: %s", rpc is not None) if rpc: _setup_proxy_callers(rpc) - logger.warning("][ DIAG:child_hooks proxy callers configured with RPC") else: - logger.warning("][ DIAG:child_hooks NO RPC — proxy callers cleared") _setup_proxy_callers() except Exception as e: - logger.error(f"][ DIAG:child_hooks Manual RPC Injection failed: {e}") + logger.error(f"][ child_hooks Manual RPC Injection failed: {e}") _setup_proxy_callers() _setup_logging() diff --git a/comfy/isolation/extension_loader.py b/comfy/isolation/extension_loader.py index 0c65b234e..10b149c60 100644 --- a/comfy/isolation/extension_loader.py +++ b/comfy/isolation/extension_loader.py @@ -354,6 +354,16 @@ async def load_isolated_node( "sandbox": sandbox_config, } + share_torch_no_deps = tool_config.get("share_torch_no_deps", []) + if share_torch_no_deps: + if not isinstance(share_torch_no_deps, list) or not all( + isinstance(dep, str) and dep.strip() for dep in share_torch_no_deps + ): + raise ExtensionLoadError( + "[tool.comfy.isolation.share_torch_no_deps] must be a list of non-empty strings" + ) + extension_config["share_torch_no_deps"] = share_torch_no_deps + _is_sealed = execution_model == "sealed_worker" _is_sandboxed = host_policy["sandbox_mode"] != "disabled" and is_linux logger.info( @@ -367,6 +377,16 @@ async def load_isolated_node( if cuda_wheels is not None: extension_config["cuda_wheels"] = cuda_wheels + extra_index_urls = tool_config.get("extra_index_urls", []) + if extra_index_urls: + if not isinstance(extra_index_urls, list) or not all( + isinstance(u, str) and u.strip() for u in extra_index_urls + ): + raise ExtensionLoadError( + "[tool.comfy.isolation.extra_index_urls] must be a list of non-empty strings" + ) + extension_config["extra_index_urls"] = extra_index_urls + # Conda-specific keys if is_conda: extension_config["package_manager"] = "conda" @@ -408,31 +428,17 @@ async def load_isolated_node( cache.register_proxy(extension_name, WebDirectoryProxy()) # Try cache first (lazy spawn) - logger.warning("][ DIAG:ext_loader cache_valid_check for %s", extension_name) if is_cache_valid(node_dir, manifest_path, venv_root): cached_data = load_from_cache(node_dir, venv_root) if cached_data: if _is_stale_node_cache(cached_data): - logger.warning( - "][ DIAG:ext_loader %s cache is stale/incompatible; rebuilding metadata", - extension_name, - ) + pass else: - logger.warning("][ DIAG:ext_loader %s USING CACHE — dumping combo options:", extension_name) - for node_name, details in cached_data.items(): - schema_v1 = details.get("schema_v1", {}) - inp = schema_v1.get("input", {}) if schema_v1 else {} - for section_name, section in inp.items(): - if isinstance(section, dict): - for field_name, field_def in section.items(): - if isinstance(field_def, (list, tuple)) and len(field_def) >= 2 and isinstance(field_def[1], dict) and "options" in field_def[1]: - opts = field_def[1]["options"] - logger.warning( - "][ DIAG:ext_loader CACHE %s.%s.%s options=%d first=%s", - node_name, section_name, field_name, - len(opts), - opts[:3] if opts else "EMPTY", - ) + try: + flushed = await extension.flush_pending_routes() + logger.info("][ %s flushed %d routes", extension_name, flushed) + except Exception as exc: + logger.warning("][ %s route flush failed: %s", extension_name, exc) specs: List[Tuple[str, str, type]] = [] for node_name, details in cached_data.items(): stub_cls = build_stub_class(node_name, details, extension) @@ -440,11 +446,7 @@ async def load_isolated_node( (node_name, details.get("display_name", node_name), stub_cls) ) return specs - else: - logger.warning("][ DIAG:ext_loader %s cache INVALID or MISSING", extension_name) - # Cache miss - spawn process and get metadata - logger.warning("][ DIAG:ext_loader %s cache miss, spawning process for metadata", extension_name) try: remote_nodes: Dict[str, str] = await extension.list_nodes() @@ -466,7 +468,6 @@ async def load_isolated_node( cache_data: Dict[str, Dict] = {} for node_name, display_name in remote_nodes.items(): - logger.warning("][ DIAG:ext_loader calling get_node_details for %s.%s", extension_name, node_name) try: details = await extension.get_node_details(node_name) except Exception as exc: @@ -477,20 +478,6 @@ async def load_isolated_node( exc, ) continue - # DIAG: dump combo options from freshly-fetched details - schema_v1 = details.get("schema_v1", {}) - inp = schema_v1.get("input", {}) if schema_v1 else {} - for section_name, section in inp.items(): - if isinstance(section, dict): - for field_name, field_def in section.items(): - if isinstance(field_def, (list, tuple)) and len(field_def) >= 2 and isinstance(field_def[1], dict) and "options" in field_def[1]: - opts = field_def[1]["options"] - logger.warning( - "][ DIAG:ext_loader FRESH %s.%s.%s options=%d first=%s", - node_name, section_name, field_name, - len(opts), - opts[:3] if opts else "EMPTY", - ) details["display_name"] = display_name cache_data[node_name] = details stub_cls = build_stub_class(node_name, details, extension) @@ -512,6 +499,14 @@ async def load_isolated_node( if host_policy["sandbox_mode"] == "disabled": _register_web_directory(extension_name, node_dir) + # Flush any routes the child buffered during module import — must happen + # before router freeze and before we kill the child process. + try: + flushed = await extension.flush_pending_routes() + logger.info("][ %s flushed %d routes", extension_name, flushed) + except Exception as exc: + logger.warning("][ %s route flush failed: %s", extension_name, exc) + # EJECT: Kill process after getting metadata (will respawn on first execution) await _stop_extension_safe(extension, extension_name) diff --git a/comfy/isolation/extension_wrapper.py b/comfy/isolation/extension_wrapper.py index 67ba1d5c4..059a788e1 100644 --- a/comfy/isolation/extension_wrapper.py +++ b/comfy/isolation/extension_wrapper.py @@ -211,6 +211,7 @@ class ComfyNodeExtension(ExtensionBase): self.node_classes = getattr(module, "NODE_CLASS_MAPPINGS", {}) or {} self.display_names = getattr(module, "NODE_DISPLAY_NAME_MAPPINGS", {}) or {} + self._register_module_routes(module) # Register web directory with WebDirectoryProxy (child-side) web_dir_attr = getattr(module, "WEB_DIRECTORY", None) @@ -280,6 +281,55 @@ class ComfyNodeExtension(ExtensionBase): self.node_instances = {} + def _register_module_routes(self, module: Any) -> None: + """Bridge legacy module-level ROUTES declarations into isolated routing.""" + routes = getattr(module, "ROUTES", None) or [] + if not routes: + return + + from comfy.isolation.proxies.prompt_server_impl import PromptServerStub + + prompt_server = PromptServerStub() + route_table = getattr(prompt_server, "routes", None) + if route_table is None: + logger.warning("%s Route registration unavailable for %s", LOG_PREFIX, module) + return + + for route_spec in routes: + if not isinstance(route_spec, dict): + logger.warning("%s Ignoring non-dict ROUTES entry: %r", LOG_PREFIX, route_spec) + continue + + method = str(route_spec.get("method", "")).strip().upper() + path = str(route_spec.get("path", "")).strip() + handler_ref = route_spec.get("handler") + if not method or not path: + logger.warning("%s Ignoring incomplete route spec: %r", LOG_PREFIX, route_spec) + continue + + if isinstance(handler_ref, str): + handler = getattr(module, handler_ref, None) + else: + handler = handler_ref + if not callable(handler): + logger.warning( + "%s Ignoring route with missing handler %r for %s %s", + LOG_PREFIX, + handler_ref, + method, + path, + ) + continue + + decorator = getattr(route_table, method.lower(), None) + if not callable(decorator): + logger.warning("%s Unsupported route method %s for %s", LOG_PREFIX, method, path) + continue + + decorator(path)(handler) + self._route_handlers[f"{method} {path}"] = handler + logger.info("%s buffered legacy route %s %s", LOG_PREFIX, method, path) + async def list_nodes(self) -> Dict[str, str]: return {name: self.display_names.get(name, name) for name in self.node_classes} @@ -289,10 +339,6 @@ class ComfyNodeExtension(ExtensionBase): async def get_node_details(self, node_name: str) -> Dict[str, Any]: node_cls = self._get_node_class(node_name) is_v3 = issubclass(node_cls, _ComfyNodeInternal) - logger.warning( - "%s DIAG:get_node_details START | node=%s | is_v3=%s | cls=%s", - LOG_PREFIX, node_name, is_v3, node_cls, - ) input_types_raw = ( node_cls.INPUT_TYPES() if hasattr(node_cls, "INPUT_TYPES") else {} @@ -316,16 +362,7 @@ class ComfyNodeExtension(ExtensionBase): if is_v3: try: - logger.warning( - "%s DIAG:get_node_details calling GET_SCHEMA for %s", - LOG_PREFIX, node_name, - ) schema = node_cls.GET_SCHEMA() - logger.warning( - "%s DIAG:get_node_details GET_SCHEMA returned for %s | schema_inputs=%s", - LOG_PREFIX, node_name, - [getattr(i, 'id', '?') for i in (schema.inputs or [])], - ) schema_v1 = asdict(schema.get_v1_info(node_cls)) try: schema_v3 = asdict(schema.get_v3_info(node_cls)) @@ -532,6 +569,11 @@ class ComfyNodeExtension(ExtensionBase): wrapped = self._wrap_unpicklable_objects(result) return wrapped + async def flush_pending_routes(self) -> int: + """Flush buffered route registrations to host via RPC. Called by host after node discovery.""" + from comfy.isolation.proxies.prompt_server_impl import PromptServerStub + return await PromptServerStub.flush_child_routes() + async def flush_transport_state(self) -> int: if os.environ.get("PYISOLATE_CHILD") != "1": return 0 @@ -750,19 +792,13 @@ class ComfyNodeExtension(ExtensionBase): return self.node_instances[node_name] async def before_module_loaded(self) -> None: - # Inject initialization here if we think this is the child - logger.warning( - "%s DIAG:before_module_loaded START | is_child=%s", - LOG_PREFIX, os.environ.get("PYISOLATE_CHILD"), - ) try: from comfy.isolation import initialize_proxies initialize_proxies() - logger.warning("%s DIAG:before_module_loaded initialize_proxies OK", LOG_PREFIX) except Exception as e: logger.error( - "%s DIAG:before_module_loaded initialize_proxies FAILED: %s", LOG_PREFIX, e + "%s before_module_loaded initialize_proxies FAILED: %s", LOG_PREFIX, e ) await super().before_module_loaded() diff --git a/comfy/isolation/host_policy.py b/comfy/isolation/host_policy.py index f637e89d9..91b8e5b96 100644 --- a/comfy/isolation/host_policy.py +++ b/comfy/isolation/host_policy.py @@ -166,6 +166,8 @@ def load_host_policy(comfy_root: Path) -> HostSecurityPolicy: if isinstance(whitelist_raw, dict): policy["whitelist"].update({str(k): str(v) for k, v in whitelist_raw.items()}) + os.environ["PYISOLATE_SANDBOX_MODE"] = policy["sandbox_mode"] + logger.debug( "Loaded Host Policy: %d whitelisted nodes, Sandbox=%s, Network=%s", len(policy["whitelist"]), diff --git a/comfy/isolation/proxies/folder_paths_proxy.py b/comfy/isolation/proxies/folder_paths_proxy.py index b324da4e5..1fe7cff16 100644 --- a/comfy/isolation/proxies/folder_paths_proxy.py +++ b/comfy/isolation/proxies/folder_paths_proxy.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging import os -import traceback + from typing import Any, Dict, Optional from pyisolate import ProxiedSingleton @@ -152,24 +152,9 @@ class FolderPathsProxy(ProxiedSingleton): return list(_folder_paths().get_folder_paths(folder_name)) def get_filename_list(self, folder_name: str) -> list[str]: - caller_stack = "".join(traceback.format_stack()[-4:-1]) - _fp_logger.warning( - "][ DIAG:FolderPathsProxy.get_filename_list called | folder=%s | is_child=%s | rpc_configured=%s\n%s", - folder_name, _is_child_process(), self._rpc is not None, caller_stack, - ) if _is_child_process(): - result = list(call_singleton_rpc(self._get_caller(), "rpc_get_filename_list", folder_name)) - _fp_logger.warning( - "][ DIAG:FolderPathsProxy.get_filename_list RPC result | folder=%s | count=%d | first=%s", - folder_name, len(result), result[:3] if result else "EMPTY", - ) - return result - result = list(_folder_paths().get_filename_list(folder_name)) - _fp_logger.warning( - "][ DIAG:FolderPathsProxy.get_filename_list LOCAL result | folder=%s | count=%d | first=%s", - folder_name, len(result), result[:3] if result else "EMPTY", - ) - return result + return list(call_singleton_rpc(self._get_caller(), "rpc_get_filename_list", folder_name)) + return list(_folder_paths().get_filename_list(folder_name)) def get_full_path(self, folder_name: str, filename: str) -> str | None: if _is_child_process(): diff --git a/comfy/isolation/proxies/prompt_server_impl.py b/comfy/isolation/proxies/prompt_server_impl.py index 3f500522e..04f4200a4 100644 --- a/comfy/isolation/proxies/prompt_server_impl.py +++ b/comfy/isolation/proxies/prompt_server_impl.py @@ -94,14 +94,13 @@ class PromptServerStub: def client_id(self) -> Optional[str]: return "isolated_client" - def supports(self, feature: str) -> bool: - return True + @property + def supports(self) -> set: + return {"custom_nodes_from_web"} @property def app(self): - raise RuntimeError( - "PromptServer.app is not accessible in isolated nodes. Use RPC routes instead." - ) + return _AppStub(self) @property def prompt_queue(self): @@ -140,18 +139,27 @@ class PromptServerStub: call_singleton_rpc(self._rpc, "ui_send_progress_text", text, node_id, sid) # --- Route Registration Logic --- - def register_route(self, method: str, path: str, handler: Callable): - """Register a route handler via RPC.""" - if not self._rpc: - logger.error("RPC not initialized in PromptServerStub") - return + _pending_child_routes: list = [] - # Fire registration async - try: - loop = asyncio.get_running_loop() - loop.create_task(self._rpc.register_route_rpc(method, path, handler)) - except RuntimeError: - call_singleton_rpc(self._rpc, "register_route_rpc", method, path, handler) + def register_route(self, method: str, path: str, handler: Callable): + """Buffer route registration. Routes are flushed via flush_child_routes().""" + PromptServerStub._pending_child_routes.append((method, path, handler)) + logger.info("%s Buffered isolated route %s %s", LOG_PREFIX, method, path) + + @classmethod + async def flush_child_routes(cls): + """Send all buffered route registrations to host via RPC. Call from on_module_loaded.""" + if not cls._rpc: + return 0 + flushed = 0 + for method, path, handler in cls._pending_child_routes: + try: + await cls._rpc.register_route_rpc(method, path, handler) + flushed += 1 + except Exception as e: + logger.error("%s Child route flush failed %s %s: %s", LOG_PREFIX, method, path, e) + cls._pending_child_routes = [] + return flushed class RouteStub: @@ -205,7 +213,6 @@ class PromptServerService(ProxiedSingleton): """Host-side RPC Service for PromptServer.""" def __init__(self): - # We will bind to the real server instance lazily or via global import pass @property @@ -231,7 +238,7 @@ class PromptServerService(ProxiedSingleton): async def register_route_rpc(self, method: str, path: str, child_handler_proxy): """RPC Target: Register a route that forwards to the Child.""" from aiohttp import web - logger.debug(f"{LOG_PREFIX} Registering Isolated Route {method} {path}") + logger.info("%s Registering isolated route %s %s", LOG_PREFIX, method, path) async def route_wrapper(request: web.Request) -> web.Response: # 1. Capture request data @@ -253,8 +260,8 @@ class PromptServerService(ProxiedSingleton): logger.error(f"{LOG_PREFIX} Isolated Route Error: {e}") return web.Response(status=500, text=str(e)) - # Register loop self.server.app.router.add_route(method, path, route_wrapper) + logger.info("%s Registered isolated route %s %s", LOG_PREFIX, method, path) def _serialize_response(self, result: Any) -> Any: """Helper to convert Child result -> web.Response""" @@ -269,3 +276,32 @@ class PromptServerService(ProxiedSingleton): return web.Response(text=result) # Fallback return web.Response(text=str(result)) + + +class _RouterStub: + """Captures router.add_route and router.add_static calls in isolation child.""" + + def __init__(self, stub): + self._stub = stub + + def add_route(self, method, path, handler, **kwargs): + self._stub.register_route(method, path, handler) + + def add_static(self, prefix, path, **kwargs): + # Static file serving not supported in isolation — silently skip + pass + + +class _AppStub: + """Captures PromptServer.app access patterns in isolation child.""" + + def __init__(self, stub): + self.router = _RouterStub(stub) + self.frozen = False + + def add_routes(self, routes): + # aiohttp route table — iterate and register each + for route in routes: + if hasattr(route, "method") and hasattr(route, "handler"): + self.router.add_route(route.method, route.path, route.handler) + # StaticDef and other non-method routes — silently skip diff --git a/tests/isolation/singleton_boundary_helpers.py b/tests/isolation/singleton_boundary_helpers.py index f113f6a81..471986cf0 100644 --- a/tests/isolation/singleton_boundary_helpers.py +++ b/tests/isolation/singleton_boundary_helpers.py @@ -724,6 +724,7 @@ def capture_prompt_web_exact_relay() -> dict[str, object]: from comfy.isolation.proxies.web_directory_proxy import WebDirectoryCache PromptServerStub.set_rpc(fake_rpc) + PromptServerStub._pending_child_routes = [] stub = PromptServerStub() cache = WebDirectoryCache() cache.register_proxy("demo_ext", FakeWebDirectoryProxy(fake_rpc.transcripts)) @@ -735,6 +736,7 @@ def capture_prompt_web_exact_relay() -> dict[str, object]: stub.send_progress_text("hello", "node-17") stub.routes.get("/demo")(demo_handler) + asyncio.run(PromptServerStub.flush_child_routes()) web_file = cache.get_file("demo_ext", "js/app.js") imported = set(sys.modules) - before return { diff --git a/tests/isolation/test_cuda_wheels_and_env_flags.py b/tests/isolation/test_cuda_wheels_and_env_flags.py index f0361d5ef..da44a5f53 100644 --- a/tests/isolation/test_cuda_wheels_and_env_flags.py +++ b/tests/isolation/test_cuda_wheels_and_env_flags.py @@ -108,6 +108,119 @@ flash_attn = "flash-attn-special" } +def test_load_isolated_node_passes_share_torch_no_deps(tmp_path, monkeypatch): + node_dir = tmp_path / "node" + node_dir.mkdir() + manifest_path = node_dir / "pyproject.toml" + _write_manifest( + node_dir, + """ +[project] +name = "demo-node" +dependencies = ["timm", "pyyaml"] + +[tool.comfy.isolation] +can_isolate = true +share_torch = true +share_torch_no_deps = ["timm"] +""".strip(), + ) + + captured: dict[str, object] = {} + + class DummyManager: + def __init__(self, *args, **kwargs) -> None: + return None + + def load_extension(self, config): + captured.update(config) + return _DummyExtension() + + monkeypatch.setattr(extension_loader_module.pyisolate, "ExtensionManager", DummyManager) + monkeypatch.setattr( + extension_loader_module, + "load_host_policy", + lambda base_path: { + "sandbox_mode": "disabled", + "allow_network": False, + "writable_paths": [], + "readonly_paths": [], + }, + ) + monkeypatch.setattr(extension_loader_module, "is_cache_valid", lambda *args, **kwargs: True) + monkeypatch.setattr( + extension_loader_module, + "load_from_cache", + lambda *args, **kwargs: {"Node": {"display_name": "Node", "schema_v1": {}}}, + ) + monkeypatch.setitem(sys.modules, "folder_paths", SimpleNamespace(base_path=str(tmp_path))) + + specs = asyncio.run( + load_isolated_node( + node_dir, + manifest_path, + logging.getLogger("test"), + lambda *args, **kwargs: object, + tmp_path / "venvs", + [], + ) + ) + + assert len(specs) == 1 + assert captured["share_torch_no_deps"] == ["timm"] + + +def test_on_module_loaded_registers_legacy_routes(monkeypatch): + captured: list[tuple[str, str, Any]] = [] + + def demo_handler(body): + return body + + module = SimpleNamespace( + __file__="/tmp/demo_node/__init__.py", + __name__="demo_node", + NODE_CLASS_MAPPINGS={}, + NODE_DISPLAY_NAME_MAPPINGS={}, + ROUTES=[ + {"method": "POST", "path": "/sam3/interactive_segment_one", "handler": "demo_handler"}, + ], + demo_handler=demo_handler, + ) + + def fake_register_route(self, method, path, handler): + captured.append((method, path, handler)) + + monkeypatch.setattr( + "comfy.isolation.proxies.prompt_server_impl.PromptServerStub.register_route", + fake_register_route, + ) + + extension = ComfyNodeExtension() + asyncio.run(extension.on_module_loaded(module)) + + assert captured == [("POST", "/sam3/interactive_segment_one", demo_handler)] + + +def test_prompt_server_stub_buffers_routes_without_rpc(): + from comfy.isolation.proxies.prompt_server_impl import PromptServerStub + + def demo_handler(body): + return body + + old_rpc = PromptServerStub._rpc + old_pending = list(PromptServerStub._pending_child_routes) + try: + PromptServerStub._rpc = None + PromptServerStub._pending_child_routes = [] + PromptServerStub().register_route("POST", "/sam3/interactive_segment_one", demo_handler) + assert PromptServerStub._pending_child_routes == [ + ("POST", "/sam3/interactive_segment_one", demo_handler) + ] + finally: + PromptServerStub._rpc = old_rpc + PromptServerStub._pending_child_routes = old_pending + + def test_load_isolated_node_rejects_undeclared_cuda_wheel_dependency( tmp_path, monkeypatch ): @@ -362,6 +475,70 @@ can_isolate = true assert "cuda_wheels" not in captured +def test_load_isolated_node_passes_extra_index_urls(tmp_path, monkeypatch): + node_dir = tmp_path / "node" + node_dir.mkdir() + manifest_path = node_dir / "pyproject.toml" + _write_manifest( + node_dir, + """ +[project] +name = "demo-node" +dependencies = ["fbxsdkpy==2020.1.post2", "numpy>=1.0"] + +[tool.comfy.isolation] +can_isolate = true +share_torch = true +extra_index_urls = ["https://gitlab.inria.fr/api/v4/projects/18692/packages/pypi/simple"] +""".strip(), + ) + + captured: dict[str, object] = {} + + class DummyManager: + def __init__(self, *args, **kwargs) -> None: + return None + + def load_extension(self, config): + captured.update(config) + return _DummyExtension() + + monkeypatch.setattr(extension_loader_module.pyisolate, "ExtensionManager", DummyManager) + monkeypatch.setattr( + extension_loader_module, + "load_host_policy", + lambda base_path: { + "sandbox_mode": "disabled", + "allow_network": False, + "writable_paths": [], + "readonly_paths": [], + }, + ) + monkeypatch.setattr(extension_loader_module, "is_cache_valid", lambda *args, **kwargs: True) + monkeypatch.setattr( + extension_loader_module, + "load_from_cache", + lambda *args, **kwargs: {"Node": {"display_name": "Node", "schema_v1": {}}}, + ) + monkeypatch.setitem(sys.modules, "folder_paths", SimpleNamespace(base_path=str(tmp_path))) + + specs = asyncio.run( + load_isolated_node( + node_dir, + manifest_path, + logging.getLogger("test"), + lambda *args, **kwargs: object, + tmp_path / "venvs", + [], + ) + ) + + assert len(specs) == 1 + assert captured["extra_index_urls"] == [ + "https://gitlab.inria.fr/api/v4/projects/18692/packages/pypi/simple" + ] + + def test_maybe_wrap_model_for_isolation_uses_runtime_flag(monkeypatch): class DummyRegistry: def register(self, model):