From 7d512fa9c375f3a3cf4baea011322b5d4c92844f Mon Sep 17 00:00:00 2001 From: John Pollock Date: Tue, 7 Apr 2026 05:58:03 -0500 Subject: [PATCH] feat(isolation): core infrastructure and pyisolate integration Adds the isolation system foundation: ComfyUIAdapter, extension loader, manifest discovery, child/host process hooks, RPC bridge, runtime helpers, SHM forensics, and the --use-process-isolation CLI flag. pyisolate added to requirements.txt. .pyisolate_venvs/ added to .gitignore. --- .gitignore | 1 + comfy/cli_args.py | 2 + comfy/isolation/__init__.py | 442 ++++++++++++ comfy/isolation/adapter.py | 965 +++++++++++++++++++++++++++ comfy/isolation/child_hooks.py | 126 ++++ comfy/isolation/extension_loader.py | 521 +++++++++++++++ comfy/isolation/extension_wrapper.py | 896 +++++++++++++++++++++++++ comfy/isolation/host_hooks.py | 30 + comfy/isolation/manifest_loader.py | 221 ++++++ comfy/isolation/rpc_bridge.py | 49 ++ comfy/isolation/runtime_helpers.py | 471 +++++++++++++ comfy/isolation/shm_forensics.py | 217 ++++++ requirements.txt | 2 + 13 files changed, 3943 insertions(+) create mode 100644 comfy/isolation/__init__.py create mode 100644 comfy/isolation/adapter.py create mode 100644 comfy/isolation/child_hooks.py create mode 100644 comfy/isolation/extension_loader.py create mode 100644 comfy/isolation/extension_wrapper.py create mode 100644 comfy/isolation/host_hooks.py create mode 100644 comfy/isolation/manifest_loader.py create mode 100644 comfy/isolation/rpc_bridge.py create mode 100644 comfy/isolation/runtime_helpers.py create mode 100644 comfy/isolation/shm_forensics.py diff --git a/.gitignore b/.gitignore index 2700ad5c2..f893b5f14 100644 --- a/.gitignore +++ b/.gitignore @@ -24,3 +24,4 @@ web_custom_versions/ openapi.yaml filtered-openapi.yaml uv.lock +.pyisolate_venvs/ diff --git a/comfy/cli_args.py b/comfy/cli_args.py index dbaadf723..7f57bc269 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -184,6 +184,8 @@ 
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable lo parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.") +parser.add_argument("--use-process-isolation", action="store_true", help="Enable process isolation for custom nodes with pyisolate.yaml manifests.") + parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level') parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).") diff --git a/comfy/isolation/__init__.py b/comfy/isolation/__init__.py new file mode 100644 index 000000000..18ce059c6 --- /dev/null +++ b/comfy/isolation/__init__.py @@ -0,0 +1,442 @@ +# pylint: disable=consider-using-from-import,cyclic-import,global-statement,global-variable-not-assigned,import-outside-toplevel,logging-fstring-interpolation +from __future__ import annotations +import asyncio +import inspect +import logging +import os +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Set, TYPE_CHECKING +_IMPORT_TORCH = os.environ.get("PYISOLATE_IMPORT_TORCH", "1") == "1" + +load_isolated_node = None +find_manifest_directories = None +build_stub_class = None +get_class_types_for_extension = None +scan_shm_forensics = None +start_shm_forensics = None + +if _IMPORT_TORCH: + import folder_paths + from .extension_loader import load_isolated_node + from .manifest_loader import find_manifest_directories + from .runtime_helpers import build_stub_class, get_class_types_for_extension + from .shm_forensics import scan_shm_forensics, start_shm_forensics + +if TYPE_CHECKING: + from pyisolate import ExtensionManager + from .extension_wrapper import ComfyNodeExtension + +LOG_PREFIX = "][" +isolated_node_timings: List[tuple[float, Path, int]] = [] + +if _IMPORT_TORCH: + PYISOLATE_VENV_ROOT = 
Path(folder_paths.base_path) / ".pyisolate_venvs" + PYISOLATE_VENV_ROOT.mkdir(parents=True, exist_ok=True) + +logger = logging.getLogger(__name__) +_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024 +_MODEL_PATCHER_IDLE_TIMEOUT_MS = 120000 + + +def initialize_proxies() -> None: + from .child_hooks import is_child_process + + is_child = is_child_process() + logger.warning( + "%s DIAG:initialize_proxies | is_child=%s | PYISOLATE_CHILD=%s", + LOG_PREFIX, is_child, os.environ.get("PYISOLATE_CHILD"), + ) + + if is_child: + from .child_hooks import initialize_child_process + + initialize_child_process() + logger.warning("%s DIAG:initialize_proxies child_process initialized", LOG_PREFIX) + else: + from .host_hooks import initialize_host_process + + initialize_host_process() + logger.warning("%s DIAG:initialize_proxies host_process initialized", LOG_PREFIX) + if start_shm_forensics is not None: + start_shm_forensics() + + +@dataclass(frozen=True) +class IsolatedNodeSpec: + node_name: str + display_name: str + stub_class: type + module_path: Path + + +_ISOLATED_NODE_SPECS: List[IsolatedNodeSpec] = [] +_CLAIMED_PATHS: Set[Path] = set() +_ISOLATION_SCAN_ATTEMPTED = False +_EXTENSION_MANAGERS: List["ExtensionManager"] = [] +_RUNNING_EXTENSIONS: Dict[str, "ComfyNodeExtension"] = {} +_ISOLATION_BACKGROUND_TASK: Optional["asyncio.Task[List[IsolatedNodeSpec]]"] = None +_EARLY_START_TIME: Optional[float] = None + + +def start_isolation_loading_early(loop: "asyncio.AbstractEventLoop") -> None: + global _ISOLATION_BACKGROUND_TASK, _EARLY_START_TIME + if _ISOLATION_BACKGROUND_TASK is not None: + return + _EARLY_START_TIME = time.perf_counter() + _ISOLATION_BACKGROUND_TASK = loop.create_task(initialize_isolation_nodes()) + + +async def await_isolation_loading() -> List[IsolatedNodeSpec]: + global _ISOLATION_BACKGROUND_TASK, _EARLY_START_TIME + if _ISOLATION_BACKGROUND_TASK is not None: + specs = await _ISOLATION_BACKGROUND_TASK + return specs + return await 
initialize_isolation_nodes() + + +async def initialize_isolation_nodes() -> List[IsolatedNodeSpec]: + global _ISOLATED_NODE_SPECS, _ISOLATION_SCAN_ATTEMPTED, _CLAIMED_PATHS + + if _ISOLATED_NODE_SPECS: + return _ISOLATED_NODE_SPECS + + if _ISOLATION_SCAN_ATTEMPTED: + return [] + + _ISOLATION_SCAN_ATTEMPTED = True + if find_manifest_directories is None or load_isolated_node is None or build_stub_class is None: + return [] + manifest_entries = find_manifest_directories() + _CLAIMED_PATHS = {entry[0].resolve() for entry in manifest_entries} + + if not manifest_entries: + return [] + + os.environ["PYISOLATE_ISOLATION_ACTIVE"] = "1" + concurrency_limit = max(1, (os.cpu_count() or 4) // 2) + semaphore = asyncio.Semaphore(concurrency_limit) + + async def load_with_semaphore( + node_dir: Path, manifest: Path + ) -> List[IsolatedNodeSpec]: + async with semaphore: + load_start = time.perf_counter() + spec_list = await load_isolated_node( + node_dir, + manifest, + logger, + lambda name, info, extension: build_stub_class( + name, + info, + extension, + _RUNNING_EXTENSIONS, + logger, + ), + PYISOLATE_VENV_ROOT, + _EXTENSION_MANAGERS, + ) + spec_list = [ + IsolatedNodeSpec( + node_name=node_name, + display_name=display_name, + stub_class=stub_cls, + module_path=node_dir, + ) + for node_name, display_name, stub_cls in spec_list + ] + isolated_node_timings.append( + (time.perf_counter() - load_start, node_dir, len(spec_list)) + ) + return spec_list + + tasks = [ + load_with_semaphore(node_dir, manifest) + for node_dir, manifest in manifest_entries + ] + results = await asyncio.gather(*tasks, return_exceptions=True) + + specs: List[IsolatedNodeSpec] = [] + for result in results: + if isinstance(result, Exception): + logger.error( + "%s Isolated node failed during startup; continuing: %s", + LOG_PREFIX, + result, + ) + continue + specs.extend(result) + + _ISOLATED_NODE_SPECS = specs + return list(_ISOLATED_NODE_SPECS) + + +def _get_class_types_for_extension(extension_name: str) -> 
Set[str]: + """Get all node class types (node names) belonging to an extension.""" + extension = _RUNNING_EXTENSIONS.get(extension_name) + if not extension: + return set() + + ext_path = Path(extension.module_path) + class_types = set() + for spec in _ISOLATED_NODE_SPECS: + if spec.module_path.resolve() == ext_path.resolve(): + class_types.add(spec.node_name) + + return class_types + + +async def notify_execution_graph(needed_class_types: Set[str], caches: list | None = None) -> None: + """Evict running extensions not needed for current execution. + + When *caches* is provided, cache entries for evicted extensions' node + class_types are invalidated to prevent stale ``RemoteObjectHandle`` + references from surviving in the output cache. + """ + await wait_for_model_patcher_quiescence( + timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS, + fail_loud=True, + marker="ISO:notify_graph_wait_idle", + ) + + evicted_class_types: Set[str] = set() + + async def _stop_extension( + ext_name: str, extension: "ComfyNodeExtension", reason: str + ) -> None: + # Collect class_types BEFORE stopping so we can invalidate cache entries. 
+ ext_class_types = _get_class_types_for_extension(ext_name) + evicted_class_types.update(ext_class_types) + logger.info("%s ISO:eject_start ext=%s reason=%s", LOG_PREFIX, ext_name, reason) + logger.debug("%s ISO:stop_start ext=%s", LOG_PREFIX, ext_name) + stop_result = extension.stop() + if inspect.isawaitable(stop_result): + await stop_result + _RUNNING_EXTENSIONS.pop(ext_name, None) + logger.debug("%s ISO:stop_done ext=%s", LOG_PREFIX, ext_name) + if scan_shm_forensics is not None: + scan_shm_forensics("ISO:stop_extension", refresh_model_context=True) + + if scan_shm_forensics is not None: + scan_shm_forensics("ISO:notify_graph_start", refresh_model_context=True) + isolated_class_types_in_graph = needed_class_types.intersection( + {spec.node_name for spec in _ISOLATED_NODE_SPECS} + ) + graph_uses_isolation = bool(isolated_class_types_in_graph) + logger.debug( + "%s ISO:notify_graph_start running=%d needed=%d", + LOG_PREFIX, + len(_RUNNING_EXTENSIONS), + len(needed_class_types), + ) + if graph_uses_isolation: + for ext_name, extension in list(_RUNNING_EXTENSIONS.items()): + ext_class_types = _get_class_types_for_extension(ext_name) + + # If NONE of this extension's nodes are in the execution graph -> evict. + if not ext_class_types.intersection(needed_class_types): + await _stop_extension( + ext_name, + extension, + "isolated custom_node not in execution graph, evicting", + ) + else: + logger.debug( + "%s ISO:notify_graph_skip_evict running=%d reason=no isolated nodes in graph", + LOG_PREFIX, + len(_RUNNING_EXTENSIONS), + ) + + # Isolated child processes add steady VRAM pressure; reclaim host-side models + # at workflow boundaries so subsequent host nodes (e.g. CLIP encode) keep headroom. 
+ try: + import comfy.model_management as model_management + + device = model_management.get_torch_device() + if getattr(device, "type", None) == "cuda": + required = max( + model_management.minimum_inference_memory(), + _WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES, + ) + free_before = model_management.get_free_memory(device) + if free_before < required and _RUNNING_EXTENSIONS and graph_uses_isolation: + for ext_name, extension in list(_RUNNING_EXTENSIONS.items()): + await _stop_extension( + ext_name, + extension, + f"boundary low-vram restart (free={int(free_before)} target={int(required)})", + ) + if model_management.get_free_memory(device) < required: + model_management.unload_all_models() + model_management.cleanup_models_gc() + model_management.cleanup_models() + if model_management.get_free_memory(device) < required: + model_management.free_memory(required, device, for_dynamic=False) + model_management.soft_empty_cache() + except Exception: + logger.debug( + "%s workflow-boundary host VRAM relief failed", LOG_PREFIX, exc_info=True + ) + finally: + # Invalidate cached outputs for evicted extensions so stale + # RemoteObjectHandle references are not served from cache. 
+ if evicted_class_types and caches: + total_invalidated = 0 + for cache in caches: + if hasattr(cache, "invalidate_by_class_types"): + total_invalidated += cache.invalidate_by_class_types( + evicted_class_types + ) + if total_invalidated > 0: + logger.info( + "%s ISO:cache_invalidated count=%d class_types=%s", + LOG_PREFIX, + total_invalidated, + evicted_class_types, + ) + scan_shm_forensics("ISO:notify_graph_done", refresh_model_context=True) + logger.debug( + "%s ISO:notify_graph_done running=%d", LOG_PREFIX, len(_RUNNING_EXTENSIONS) + ) + + +async def flush_running_extensions_transport_state() -> int: + await wait_for_model_patcher_quiescence( + timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS, + fail_loud=True, + marker="ISO:flush_transport_wait_idle", + ) + total_flushed = 0 + for ext_name, extension in list(_RUNNING_EXTENSIONS.items()): + flush_fn = getattr(extension, "flush_transport_state", None) + if not callable(flush_fn): + continue + try: + flushed = await flush_fn() + if isinstance(flushed, int): + total_flushed += flushed + if flushed > 0: + logger.debug( + "%s %s workflow-end flush released=%d", + LOG_PREFIX, + ext_name, + flushed, + ) + except Exception: + logger.debug( + "%s %s workflow-end flush failed", LOG_PREFIX, ext_name, exc_info=True + ) + scan_shm_forensics( + "ISO:flush_running_extensions_transport_state", refresh_model_context=True + ) + return total_flushed + + +async def wait_for_model_patcher_quiescence( + timeout_ms: int = _MODEL_PATCHER_IDLE_TIMEOUT_MS, + *, + fail_loud: bool = False, + marker: str = "ISO:wait_model_patcher_idle", +) -> bool: + try: + from comfy.isolation.model_patcher_proxy_registry import ModelPatcherRegistry + + registry = ModelPatcherRegistry() + start = time.perf_counter() + idle = await registry.wait_all_idle(timeout_ms) + elapsed_ms = (time.perf_counter() - start) * 1000.0 + if idle: + logger.debug( + "%s %s idle=1 timeout_ms=%d elapsed_ms=%.3f", + LOG_PREFIX, + marker, + timeout_ms, + elapsed_ms, + ) + return True 
+ + states = await registry.get_all_operation_states() + logger.error( + "%s %s idle_timeout timeout_ms=%d elapsed_ms=%.3f states=%s", + LOG_PREFIX, + marker, + timeout_ms, + elapsed_ms, + states, + ) + if fail_loud: + raise TimeoutError( + f"ModelPatcherRegistry did not quiesce within {timeout_ms} ms" + ) + return False + except Exception: + if fail_loud: + raise + logger.debug("%s %s failed", LOG_PREFIX, marker, exc_info=True) + return False + + +def get_claimed_paths() -> Set[Path]: + return _CLAIMED_PATHS + + +def update_rpc_event_loops(loop: "asyncio.AbstractEventLoop | None" = None) -> None: + """Update all active RPC instances with the current event loop. + + This MUST be called at the start of each workflow execution to ensure + RPC calls are scheduled on the correct event loop. This handles the case + where asyncio.run() creates a new event loop for each workflow. + + Args: + loop: The event loop to use. If None, uses asyncio.get_running_loop(). + """ + if loop is None: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.get_event_loop() + + update_count = 0 + + # Update RPCs from ExtensionManagers + for manager in _EXTENSION_MANAGERS: + if not hasattr(manager, "extensions"): + continue + for name, extension in manager.extensions.items(): + if hasattr(extension, "rpc") and extension.rpc is not None: + if hasattr(extension.rpc, "update_event_loop"): + extension.rpc.update_event_loop(loop) + update_count += 1 + logger.debug(f"{LOG_PREFIX}Updated loop on extension '{name}'") + + # Also update RPCs from running extensions (they may have direct RPC refs) + for name, extension in _RUNNING_EXTENSIONS.items(): + if hasattr(extension, "rpc") and extension.rpc is not None: + if hasattr(extension.rpc, "update_event_loop"): + extension.rpc.update_event_loop(loop) + update_count += 1 + logger.debug(f"{LOG_PREFIX}Updated loop on running extension '{name}'") + + if update_count > 0: + logger.debug(f"{LOG_PREFIX}Updated event loop on 
{update_count} RPC instances") + else: + logger.debug( + f"{LOG_PREFIX}No RPC instances found to update (managers={len(_EXTENSION_MANAGERS)}, running={len(_RUNNING_EXTENSIONS)})" + ) + + +__all__ = [ + "LOG_PREFIX", + "initialize_proxies", + "initialize_isolation_nodes", + "start_isolation_loading_early", + "await_isolation_loading", + "notify_execution_graph", + "flush_running_extensions_transport_state", + "wait_for_model_patcher_quiescence", + "get_claimed_paths", + "update_rpc_event_loops", + "IsolatedNodeSpec", + "get_class_types_for_extension", +] diff --git a/comfy/isolation/adapter.py b/comfy/isolation/adapter.py new file mode 100644 index 000000000..4751dee51 --- /dev/null +++ b/comfy/isolation/adapter.py @@ -0,0 +1,965 @@ +# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,protected-access,raise-missing-from,useless-return,wrong-import-position +from __future__ import annotations + +import logging +import os +import inspect +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, cast + +from pyisolate.interfaces import IsolationAdapter, SerializerRegistryProtocol # type: ignore[import-untyped] +from pyisolate._internal.rpc_protocol import AsyncRPC, ProxiedSingleton # type: ignore[import-untyped] + +_IMPORT_TORCH = os.environ.get("PYISOLATE_IMPORT_TORCH", "1") == "1" + +# Singleton proxies that do NOT transitively import torch/PIL/psutil/aiohttp. +# Safe to import in sealed workers without host framework modules. +from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy +from comfy.isolation.proxies.helper_proxies import HelperProxiesService +from comfy.isolation.proxies.web_directory_proxy import WebDirectoryProxy + +# Singleton proxies that transitively import torch, PIL, or heavy host modules. +# Only available when torch/host framework is present. 
+CLIPProxy = None +CLIPRegistry = None +ModelPatcherProxy = None +ModelPatcherRegistry = None +ModelSamplingProxy = None +ModelSamplingRegistry = None +VAEProxy = None +VAERegistry = None +FirstStageModelRegistry = None +ModelManagementProxy = None +PromptServerService = None +ProgressProxy = None +UtilsProxy = None +_HAS_TORCH_PROXIES = False +if _IMPORT_TORCH: + from comfy.isolation.clip_proxy import CLIPProxy, CLIPRegistry + from comfy.isolation.model_patcher_proxy import ( + ModelPatcherProxy, + ModelPatcherRegistry, + ) + from comfy.isolation.model_sampling_proxy import ( + ModelSamplingProxy, + ModelSamplingRegistry, + ) + from comfy.isolation.vae_proxy import VAEProxy, VAERegistry, FirstStageModelRegistry + from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy + from comfy.isolation.proxies.prompt_server_impl import PromptServerService + from comfy.isolation.proxies.progress_proxy import ProgressProxy + from comfy.isolation.proxies.utils_proxy import UtilsProxy + _HAS_TORCH_PROXIES = True + +logger = logging.getLogger(__name__) + +# Force /dev/shm for shared memory (bwrap makes /tmp private) +import tempfile + +if os.path.exists("/dev/shm"): + # Only override if not already set or if default is not /dev/shm + current_tmp = tempfile.gettempdir() + if not current_tmp.startswith("/dev/shm"): + logger.debug( + f"Configuring shared memory: Changing TMPDIR from {current_tmp} to /dev/shm" + ) + os.environ["TMPDIR"] = "/dev/shm" + tempfile.tempdir = None # Clear cache to force re-evaluation + + +class ComfyUIAdapter(IsolationAdapter): + # ComfyUI-specific IsolationAdapter implementation + + @property + def identifier(self) -> str: + return "comfyui" + + def get_path_config(self, module_path: str) -> Optional[Dict[str, Any]]: + if "ComfyUI" in module_path and "custom_nodes" in module_path: + parts = module_path.split("ComfyUI") + if len(parts) > 1: + comfy_root = parts[0] + "ComfyUI" + return { + "preferred_root": comfy_root, + 
"additional_paths": [ + os.path.join(comfy_root, "custom_nodes"), + os.path.join(comfy_root, "comfy"), + ], + "filtered_subdirs": ["comfy", "app", "comfy_execution", "utils"], + } + return None + + def get_sandbox_system_paths(self) -> Optional[List[str]]: + """Returns required application paths to mount in the sandbox.""" + # By inspecting where our adapter is loaded from, we can determine the comfy root + adapter_file = inspect.getfile(self.__class__) + # adapter_file = /home/johnj/ComfyUI/comfy/isolation/adapter.py + comfy_root = os.path.dirname(os.path.dirname(os.path.dirname(adapter_file))) + if os.path.exists(comfy_root): + return [comfy_root] + return None + + def setup_child_environment(self, snapshot: Dict[str, Any]) -> None: + comfy_root = snapshot.get("preferred_root") + if not comfy_root: + return + + requirements_path = Path(comfy_root) / "requirements.txt" + if requirements_path.exists(): + import re + + for line in requirements_path.read_text().splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + pkg_name = re.split(r"[<>=!~\[]", line)[0].strip() + if pkg_name: + logging.getLogger(pkg_name).setLevel(logging.ERROR) + + def register_serializers(self, registry: SerializerRegistryProtocol) -> None: + if not _IMPORT_TORCH: + # Sealed worker without torch — register torch-free TensorValue handler + # so IMAGE/MASK/LATENT tensors arrive as numpy arrays, not raw dicts. 
+ import numpy as np + + _TORCH_DTYPE_TO_NUMPY = { + "torch.float32": np.float32, + "torch.float64": np.float64, + "torch.float16": np.float16, + "torch.bfloat16": np.float32, # numpy has no bfloat16; upcast + "torch.int32": np.int32, + "torch.int64": np.int64, + "torch.int16": np.int16, + "torch.int8": np.int8, + "torch.uint8": np.uint8, + "torch.bool": np.bool_, + } + + def _deserialize_tensor_value(data: Dict[str, Any]) -> Any: + dtype_str = data["dtype"] + np_dtype = _TORCH_DTYPE_TO_NUMPY.get(dtype_str, np.float32) + shape = tuple(data["tensor_size"]) + arr = np.array(data["data"], dtype=np_dtype).reshape(shape) + return arr + + _NUMPY_TO_TORCH_DTYPE = { + np.float32: "torch.float32", + np.float64: "torch.float64", + np.float16: "torch.float16", + np.int32: "torch.int32", + np.int64: "torch.int64", + np.int16: "torch.int16", + np.int8: "torch.int8", + np.uint8: "torch.uint8", + np.bool_: "torch.bool", + } + + def _serialize_tensor_value(obj: Any) -> Dict[str, Any]: + arr = np.asarray(obj, dtype=np.float32) if obj.dtype not in _NUMPY_TO_TORCH_DTYPE else np.asarray(obj) + dtype_str = _NUMPY_TO_TORCH_DTYPE.get(arr.dtype.type, "torch.float32") + return { + "__type__": "TensorValue", + "dtype": dtype_str, + "tensor_size": list(arr.shape), + "requires_grad": False, + "data": arr.tolist(), + } + + registry.register("TensorValue", _serialize_tensor_value, _deserialize_tensor_value, data_type=True) + # ndarray output from sealed workers serializes as TensorValue for host torch reconstruction + registry.register("ndarray", _serialize_tensor_value, _deserialize_tensor_value, data_type=True) + return + + import torch + + def serialize_device(obj: Any) -> Dict[str, Any]: + return {"__type__": "device", "device_str": str(obj)} + + def deserialize_device(data: Dict[str, Any]) -> Any: + return torch.device(data["device_str"]) + + registry.register("device", serialize_device, deserialize_device) + + _VALID_DTYPES = { + "float16", "float32", "float64", "bfloat16", + "int8", 
"int16", "int32", "int64", + "uint8", "bool", + } + + def serialize_dtype(obj: Any) -> Dict[str, Any]: + return {"__type__": "dtype", "dtype_str": str(obj)} + + def deserialize_dtype(data: Dict[str, Any]) -> Any: + dtype_name = data["dtype_str"].replace("torch.", "") + if dtype_name not in _VALID_DTYPES: + raise ValueError(f"Invalid dtype: {data['dtype_str']}") + return getattr(torch, dtype_name) + + registry.register("dtype", serialize_dtype, deserialize_dtype) + + from comfy_api.latest._io import FolderType + from comfy_api.latest._ui import SavedImages, SavedResult + + def serialize_saved_result(obj: Any) -> Dict[str, Any]: + return { + "__type__": "SavedResult", + "filename": obj.filename, + "subfolder": obj.subfolder, + "folder_type": obj.type.value, + } + + def deserialize_saved_result(data: Dict[str, Any]) -> Any: + if isinstance(data, SavedResult): + return data + folder_type = data["folder_type"] if "folder_type" in data else data["type"] + return SavedResult( + filename=data["filename"], + subfolder=data["subfolder"], + type=FolderType(folder_type), + ) + + registry.register( + "SavedResult", + serialize_saved_result, + deserialize_saved_result, + data_type=True, + ) + + def serialize_saved_images(obj: Any) -> Dict[str, Any]: + return { + "__type__": "SavedImages", + "results": [serialize_saved_result(result) for result in obj.results], + "is_animated": obj.is_animated, + } + + def deserialize_saved_images(data: Dict[str, Any]) -> Any: + return SavedImages( + results=[deserialize_saved_result(result) for result in data["results"]], + is_animated=data.get("is_animated", False), + ) + + registry.register( + "SavedImages", + serialize_saved_images, + deserialize_saved_images, + data_type=True, + ) + + def serialize_model_patcher(obj: Any) -> Dict[str, Any]: + # Child-side: must already have _instance_id (proxy) + if os.environ.get("PYISOLATE_CHILD") == "1": + if hasattr(obj, "_instance_id"): + return {"__type__": "ModelPatcherRef", "model_id": 
obj._instance_id} + raise RuntimeError( + f"ModelPatcher in child lacks _instance_id: " + f"{type(obj).__module__}.{type(obj).__name__}" + ) + # Host-side: register with registry + if hasattr(obj, "_instance_id"): + return {"__type__": "ModelPatcherRef", "model_id": obj._instance_id} + model_id = ModelPatcherRegistry().register(obj) + return {"__type__": "ModelPatcherRef", "model_id": model_id} + + def deserialize_model_patcher(data: Any) -> Any: + """Deserialize ModelPatcher refs; pass through already-materialized objects.""" + if isinstance(data, dict): + return ModelPatcherProxy( + data["model_id"], registry=None, manage_lifecycle=False + ) + return data + + def deserialize_model_patcher_ref(data: Dict[str, Any]) -> Any: + """Context-aware ModelPatcherRef deserializer for both host and child.""" + is_child = os.environ.get("PYISOLATE_CHILD") == "1" + if is_child: + return ModelPatcherProxy( + data["model_id"], registry=None, manage_lifecycle=False + ) + else: + return ModelPatcherRegistry()._get_instance(data["model_id"]) + + # Register ModelPatcher type for serialization + registry.register( + "ModelPatcher", serialize_model_patcher, deserialize_model_patcher + ) + # Register ModelPatcherProxy type (already a proxy, just return ref) + registry.register( + "ModelPatcherProxy", serialize_model_patcher, deserialize_model_patcher + ) + # Register ModelPatcherRef for deserialization (context-aware: host or child) + registry.register("ModelPatcherRef", None, deserialize_model_patcher_ref) + + def serialize_clip(obj: Any) -> Dict[str, Any]: + if hasattr(obj, "_instance_id"): + return {"__type__": "CLIPRef", "clip_id": obj._instance_id} + clip_id = CLIPRegistry().register(obj) + return {"__type__": "CLIPRef", "clip_id": clip_id} + + def deserialize_clip(data: Any) -> Any: + if isinstance(data, dict): + return CLIPProxy(data["clip_id"], registry=None, manage_lifecycle=False) + return data + + def deserialize_clip_ref(data: Dict[str, Any]) -> Any: + """Context-aware 
CLIPRef deserializer for both host and child.""" + is_child = os.environ.get("PYISOLATE_CHILD") == "1" + if is_child: + return CLIPProxy(data["clip_id"], registry=None, manage_lifecycle=False) + else: + return CLIPRegistry()._get_instance(data["clip_id"]) + + # Register CLIP type for serialization + registry.register("CLIP", serialize_clip, deserialize_clip) + # Register CLIPProxy type (already a proxy, just return ref) + registry.register("CLIPProxy", serialize_clip, deserialize_clip) + # Register CLIPRef for deserialization (context-aware: host or child) + registry.register("CLIPRef", None, deserialize_clip_ref) + + def serialize_vae(obj: Any) -> Dict[str, Any]: + if hasattr(obj, "_instance_id"): + return {"__type__": "VAERef", "vae_id": obj._instance_id} + vae_id = VAERegistry().register(obj) + return {"__type__": "VAERef", "vae_id": vae_id} + + def deserialize_vae(data: Any) -> Any: + if isinstance(data, dict): + return VAEProxy(data["vae_id"]) + return data + + def deserialize_vae_ref(data: Dict[str, Any]) -> Any: + """Context-aware VAERef deserializer for both host and child.""" + is_child = os.environ.get("PYISOLATE_CHILD") == "1" + if is_child: + # Child: create a proxy + return VAEProxy(data["vae_id"]) + else: + # Host: lookup real VAE from registry + return VAERegistry()._get_instance(data["vae_id"]) + + # Register VAE type for serialization + registry.register("VAE", serialize_vae, deserialize_vae) + # Register VAEProxy type (already a proxy, just return ref) + registry.register("VAEProxy", serialize_vae, deserialize_vae) + # Register VAERef for deserialization (context-aware: host or child) + registry.register("VAERef", None, deserialize_vae_ref) + + # ModelSampling serialization - handles ModelSampling* types + # copyreg removed - no pickle fallback allowed + + def serialize_model_sampling(obj: Any) -> Dict[str, Any]: + # Proxy with _instance_id — return ref (works from both host and child) + if hasattr(obj, "_instance_id"): + return {"__type__": 
"ModelSamplingRef", "ms_id": obj._instance_id} + # Child-side: object created locally in child (e.g. ModelSamplingAdvanced + # in nodes_z_image_turbo.py). Serialize as inline data so the host can + # reconstruct the real torch.nn.Module. + if os.environ.get("PYISOLATE_CHILD") == "1": + import base64 + import io as _io + + # Identify base classes from comfy.model_sampling + bases = [] + for base in type(obj).__mro__: + if base.__module__ == "comfy.model_sampling" and base.__name__ != "object": + bases.append(base.__name__) + # Serialize state_dict as base64 safetensors-like + sd = obj.state_dict() + sd_serialized = {} + for k, v in sd.items(): + buf = _io.BytesIO() + torch.save(v, buf) + sd_serialized[k] = base64.b64encode(buf.getvalue()).decode("ascii") + # Capture plain attrs (shift, multiplier, sigma_data, etc.) + plain_attrs = {} + for k, v in obj.__dict__.items(): + if k.startswith("_"): + continue + if isinstance(v, (bool, int, float, str)): + plain_attrs[k] = v + return { + "__type__": "ModelSamplingInline", + "bases": bases, + "state_dict": sd_serialized, + "attrs": plain_attrs, + } + # Host-side: register with ModelSamplingRegistry and return JSON-safe dict + ms_id = ModelSamplingRegistry().register(obj) + return {"__type__": "ModelSamplingRef", "ms_id": ms_id} + + def deserialize_model_sampling(data: Any) -> Any: + """Deserialize ModelSampling refs or inline data.""" + if isinstance(data, dict): + if data.get("__type__") == "ModelSamplingInline": + return _reconstruct_model_sampling_inline(data) + return ModelSamplingProxy(data["ms_id"]) + return data + + def _reconstruct_model_sampling_inline(data: Dict[str, Any]) -> Any: + """Reconstruct a ModelSampling object on the host from inline child data.""" + import comfy.model_sampling as _ms + import base64 + import io as _io + + # Resolve base classes + base_classes = [] + for name in data["bases"]: + cls = getattr(_ms, name, None) + if cls is not None: + base_classes.append(cls) + if not base_classes: + raise 
RuntimeError( + f"Cannot reconstruct ModelSampling: no known bases in {data['bases']}" + ) + # Create dynamic class matching the child's class hierarchy + ReconstructedSampling = type("ReconstructedSampling", tuple(base_classes), {}) + obj = ReconstructedSampling.__new__(ReconstructedSampling) + torch.nn.Module.__init__(obj) + # Restore plain attributes first + for k, v in data.get("attrs", {}).items(): + setattr(obj, k, v) + # Restore state_dict (buffers like sigmas) + for k, v_b64 in data.get("state_dict", {}).items(): + buf = _io.BytesIO(base64.b64decode(v_b64)) + tensor = torch.load(buf, weights_only=True) + # Register as buffer so it's part of state_dict + parts = k.split(".") + if len(parts) == 1: + cast(Any, obj).register_buffer(parts[0], tensor) # pylint: disable=no-member + else: + setattr(obj, parts[0], tensor) + # Register on host so future references use proxy pattern. + # Skip in child process — register() is async RPC and cannot be + # called synchronously during deserialization. 
+ if os.environ.get("PYISOLATE_CHILD") != "1": + ModelSamplingRegistry().register(obj) + return obj + + def deserialize_model_sampling_ref(data: Dict[str, Any]) -> Any: + """Context-aware ModelSamplingRef deserializer for both host and child.""" + is_child = os.environ.get("PYISOLATE_CHILD") == "1" + if is_child: + return ModelSamplingProxy(data["ms_id"]) + else: + return ModelSamplingRegistry()._get_instance(data["ms_id"]) + + # Register all ModelSampling* and StableCascadeSampling classes dynamically + import comfy.model_sampling + + for ms_cls in vars(comfy.model_sampling).values(): + if not isinstance(ms_cls, type): + continue + if not issubclass(ms_cls, torch.nn.Module): + continue + if not (ms_cls.__name__.startswith("ModelSampling") or ms_cls.__name__ == "StableCascadeSampling"): + continue + registry.register( + ms_cls.__name__, + serialize_model_sampling, + deserialize_model_sampling, + ) + registry.register( + "ModelSamplingProxy", serialize_model_sampling, deserialize_model_sampling + ) + # Register ModelSamplingRef for deserialization (context-aware: host or child) + registry.register("ModelSamplingRef", None, deserialize_model_sampling_ref) + # Register ModelSamplingInline for deserialization (child→host inline transfer) + registry.register( + "ModelSamplingInline", None, lambda data: _reconstruct_model_sampling_inline(data) + ) + + def serialize_cond(obj: Any) -> Dict[str, Any]: + type_key = f"{type(obj).__module__}.{type(obj).__name__}" + return { + "__type__": type_key, + "cond": obj.cond, + } + + def deserialize_cond(data: Dict[str, Any]) -> Any: + import importlib + + type_key = data["__type__"] + module_name, class_name = type_key.rsplit(".", 1) + module = importlib.import_module(module_name) + cls = getattr(module, class_name) + return cls(data["cond"]) + + def _serialize_public_state(obj: Any) -> Dict[str, Any]: + state: Dict[str, Any] = {} + for key, value in obj.__dict__.items(): + if key.startswith("_"): + continue + if callable(value): + 
continue + state[key] = value + return state + + def serialize_latent_format(obj: Any) -> Dict[str, Any]: + type_key = f"{type(obj).__module__}.{type(obj).__name__}" + return { + "__type__": type_key, + "state": _serialize_public_state(obj), + } + + def deserialize_latent_format(data: Dict[str, Any]) -> Any: + import importlib + + type_key = data["__type__"] + module_name, class_name = type_key.rsplit(".", 1) + module = importlib.import_module(module_name) + cls = getattr(module, class_name) + obj = cls() + for key, value in data.get("state", {}).items(): + prop = getattr(type(obj), key, None) + if isinstance(prop, property) and prop.fset is None: + continue + setattr(obj, key, value) + return obj + + import comfy.conds + + for cond_cls in vars(comfy.conds).values(): + if not isinstance(cond_cls, type): + continue + if not issubclass(cond_cls, comfy.conds.CONDRegular): + continue + type_key = f"{cond_cls.__module__}.{cond_cls.__name__}" + registry.register(type_key, serialize_cond, deserialize_cond) + registry.register(cond_cls.__name__, serialize_cond, deserialize_cond) + + import comfy.latent_formats + + for latent_cls in vars(comfy.latent_formats).values(): + if not isinstance(latent_cls, type): + continue + if not issubclass(latent_cls, comfy.latent_formats.LatentFormat): + continue + type_key = f"{latent_cls.__module__}.{latent_cls.__name__}" + registry.register( + type_key, serialize_latent_format, deserialize_latent_format + ) + registry.register( + latent_cls.__name__, serialize_latent_format, deserialize_latent_format + ) + + # V3 API: unwrap NodeOutput.args + def deserialize_node_output(data: Any) -> Any: + return getattr(data, "args", data) + + registry.register("NodeOutput", None, deserialize_node_output) + + # KSAMPLER serializer: stores sampler name instead of function object + # sampler_function is a callable which gets filtered out by JSONSocketTransport + def serialize_ksampler(obj: Any) -> Dict[str, Any]: + func_name = 
obj.sampler_function.__name__ + # Map function name back to sampler name + if func_name == "sample_unipc": + sampler_name = "uni_pc" + elif func_name == "sample_unipc_bh2": + sampler_name = "uni_pc_bh2" + elif func_name == "dpm_fast_function": + sampler_name = "dpm_fast" + elif func_name == "dpm_adaptive_function": + sampler_name = "dpm_adaptive" + elif func_name.startswith("sample_"): + sampler_name = func_name[7:] # Remove "sample_" prefix + else: + sampler_name = func_name + return { + "__type__": "KSAMPLER", + "sampler_name": sampler_name, + "extra_options": obj.extra_options, + "inpaint_options": obj.inpaint_options, + } + + def deserialize_ksampler(data: Dict[str, Any]) -> Any: + import comfy.samplers + + return comfy.samplers.ksampler( + data["sampler_name"], + data.get("extra_options", {}), + data.get("inpaint_options", {}), + ) + + registry.register("KSAMPLER", serialize_ksampler, deserialize_ksampler) + + from comfy.isolation.model_patcher_proxy_utils import register_hooks_serializers + + register_hooks_serializers(registry) + + # Generic Numpy Serializer + def serialize_numpy(obj: Any) -> Any: + import torch + + try: + # Attempt zero-copy conversion to Tensor + return torch.from_numpy(obj) + except Exception: + # Fallback for non-numeric arrays (strings, objects, mixes) + return obj.tolist() + + def deserialize_numpy_b64(data: Any) -> Any: + """Deserialize base64-encoded ndarray from sealed worker.""" + import base64 + import numpy as np + if isinstance(data, dict) and "data" in data and "dtype" in data: + raw = base64.b64decode(data["data"]) + arr = np.frombuffer(raw, dtype=np.dtype(data["dtype"])).reshape(data["shape"]) + return torch.from_numpy(arr.copy()) + return data + + registry.register("ndarray", serialize_numpy, deserialize_numpy_b64) + + # -- File3D (comfy_api.latest._util.geometry_types) --------------------- + # Origin: comfy_api by ComfyOrg (Alexander Piskun), PR #12129 + + def serialize_file3d(obj: Any) -> Dict[str, Any]: + import base64 + 
return { + "__type__": "File3D", + "format": obj.format, + "data": base64.b64encode(obj.get_bytes()).decode("ascii"), + } + + def deserialize_file3d(data: Any) -> Any: + import base64 + from io import BytesIO + from comfy_api.latest._util.geometry_types import File3D + return File3D(BytesIO(base64.b64decode(data["data"])), file_format=data["format"]) + + registry.register("File3D", serialize_file3d, deserialize_file3d, data_type=True) + + # -- VIDEO (comfy_api.latest._input_impl.video_types) ------------------- + # Origin: ComfyAPI Core v0.0.2 by ComfyOrg (guill), PR #8962 + + def serialize_video(obj: Any) -> Dict[str, Any]: + components = obj.get_components() + images = components.images.detach() if components.images.requires_grad else components.images + result: Dict[str, Any] = { + "__type__": "VIDEO", + "images": images, + "frame_rate_num": components.frame_rate.numerator, + "frame_rate_den": components.frame_rate.denominator, + } + if components.audio is not None: + waveform = components.audio["waveform"] + if waveform.requires_grad: + waveform = waveform.detach() + result["audio_waveform"] = waveform + result["audio_sample_rate"] = components.audio["sample_rate"] + if components.metadata is not None: + result["metadata"] = components.metadata + return result + + def deserialize_video(data: Any) -> Any: + from fractions import Fraction + from comfy_api.latest._input_impl.video_types import VideoFromComponents + from comfy_api.latest._util.video_types import VideoComponents + audio = None + if "audio_waveform" in data: + audio = {"waveform": data["audio_waveform"], "sample_rate": data["audio_sample_rate"]} + components = VideoComponents( + images=data["images"], + frame_rate=Fraction(data["frame_rate_num"], data["frame_rate_den"]), + audio=audio, + metadata=data.get("metadata"), + ) + return VideoFromComponents(components) + + registry.register("VIDEO", serialize_video, deserialize_video, data_type=True) + registry.register("VideoFromFile", serialize_video, 
deserialize_video, data_type=True) + registry.register("VideoFromComponents", serialize_video, deserialize_video, data_type=True) + + def setup_web_directory(self, module: Any) -> None: + """Detect WEB_DIRECTORY on a module and populate/register it. + + Called by the sealed worker after loading the node module. + Mirrors extension_wrapper.py:216-227 for host-coupled nodes. + Does NOT import extension_wrapper.py (it has `import torch` at module level). + """ + import shutil + + web_dir_attr = getattr(module, "WEB_DIRECTORY", None) + if web_dir_attr is None: + return + + module_dir = os.path.dirname(os.path.abspath(module.__file__)) + web_dir_path = os.path.abspath(os.path.join(module_dir, web_dir_attr)) + + # Read extension name from pyproject.toml + ext_name = os.path.basename(module_dir) + pyproject = os.path.join(module_dir, "pyproject.toml") + if os.path.exists(pyproject): + try: + import tomllib + except ImportError: + import tomli as tomllib # type: ignore[no-redef] + try: + with open(pyproject, "rb") as f: + data = tomllib.load(f) + name = data.get("project", {}).get("name") + if name: + ext_name = name + except Exception: + pass + + # Populate web dir if empty (mirrors _run_prestartup_web_copy) + if not (os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path))): + os.makedirs(web_dir_path, exist_ok=True) + + # Module-defined copy spec + copy_spec = getattr(module, "_PRESTARTUP_WEB_COPY", None) + if copy_spec is not None and callable(copy_spec): + try: + copy_spec(web_dir_path) + except Exception as e: + logger.warning("][ _PRESTARTUP_WEB_COPY failed: %s", e) + + # Fallback: comfy_3d_viewers + try: + from comfy_3d_viewers import copy_viewer, VIEWER_FILES + for viewer in VIEWER_FILES: + try: + copy_viewer(viewer, web_dir_path) + except Exception: + pass + except ImportError: + pass + + # Fallback: comfy_dynamic_widgets + try: + from comfy_dynamic_widgets import get_js_path + src = os.path.realpath(get_js_path()) + if os.path.exists(src): + dst_dir = 
os.path.join(web_dir_path, "js") + os.makedirs(dst_dir, exist_ok=True) + shutil.copy2(src, os.path.join(dst_dir, "dynamic_widgets.js")) + except ImportError: + pass + + if os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path)): + WebDirectoryProxy.register_web_dir(ext_name, web_dir_path) + logger.info( + "][ Adapter: registered web dir for %s (%d files)", + ext_name, + sum(1 for _ in Path(web_dir_path).rglob("*") if _.is_file()), + ) + + @staticmethod + def register_host_event_handlers(extension: Any) -> None: + """Register host-side event handlers for an isolated extension. + + Wires ``"progress"`` events from the child to ``comfy.utils.PROGRESS_BAR_HOOK`` + so the ComfyUI frontend receives progress bar updates. + """ + register_event_handler = inspect.getattr_static( + extension, "register_event_handler", None + ) + if not callable(register_event_handler): + return + + def _host_progress_handler(payload: dict) -> None: + import comfy.utils + + hook = comfy.utils.PROGRESS_BAR_HOOK + if hook is not None: + hook( + payload.get("value", 0), + payload.get("total", 0), + payload.get("preview"), + payload.get("node_id"), + ) + + extension.register_event_handler("progress", _host_progress_handler) + + def setup_child_event_hooks(self, extension: Any) -> None: + """Wire PROGRESS_BAR_HOOK in the child to emit_event on the extension. + + Host-coupled only — sealed workers do not have comfy.utils (torch). 
+ """ + is_child = os.environ.get("PYISOLATE_CHILD") == "1" + logger.info("][ ISO:setup_child_event_hooks called, PYISOLATE_CHILD=%s", is_child) + if not is_child: + return + + if not _IMPORT_TORCH: + logger.info("][ ISO:setup_child_event_hooks skipped — sealed worker (no torch)") + return + + import comfy.utils + + def _event_progress_hook(value, total, preview=None, node_id=None): + logger.debug("][ ISO:event_progress value=%s/%s node_id=%s", value, total, node_id) + extension.emit_event("progress", { + "value": value, + "total": total, + "node_id": node_id, + }) + + comfy.utils.PROGRESS_BAR_HOOK = _event_progress_hook + logger.info("][ ISO:PROGRESS_BAR_HOOK wired to event channel") + + def provide_rpc_services(self) -> List[type[ProxiedSingleton]]: + # Always available — no torch/PIL dependency + services: List[type[ProxiedSingleton]] = [ + FolderPathsProxy, + HelperProxiesService, + WebDirectoryProxy, + ] + # Torch/PIL-dependent proxies + if _HAS_TORCH_PROXIES: + services.extend([ + PromptServerService, + ModelManagementProxy, + UtilsProxy, + ProgressProxy, + VAERegistry, + CLIPRegistry, + ModelPatcherRegistry, + ModelSamplingRegistry, + FirstStageModelRegistry, + ]) + return services + + def handle_api_registration(self, api: ProxiedSingleton, rpc: AsyncRPC) -> None: + # Resolve the real name whether it's an instance or the Singleton class itself + api_name = api.__name__ if isinstance(api, type) else api.__class__.__name__ + + if api_name == "FolderPathsProxy": + import folder_paths + + # Replace module-level functions with proxy methods + # This is aggressive but necessary for transparent proxying + # Handle both instance and class cases + instance = api() if isinstance(api, type) else api + for name in dir(instance): + if not name.startswith("_"): + setattr(folder_paths, name, getattr(instance, name)) + + # Fence: isolated children get writable temp inside sandbox + if os.environ.get("PYISOLATE_CHILD") == "1": + import tempfile + _child_temp = 
os.path.join(tempfile.gettempdir(), "comfyui_temp") + os.makedirs(_child_temp, exist_ok=True) + folder_paths.temp_directory = _child_temp + + return + + if api_name == "ModelManagementProxy": + if _IMPORT_TORCH: + import comfy.model_management + + instance = api() if isinstance(api, type) else api + # Replace module-level functions with proxy methods + for name in dir(instance): + if not name.startswith("_"): + setattr(comfy.model_management, name, getattr(instance, name)) + return + + if api_name == "UtilsProxy": + if not _IMPORT_TORCH: + logger.info("][ ISO:UtilsProxy handle_api_registration skipped — sealed worker (no torch)") + return + + import comfy.utils + + # Static Injection of RPC mechanism to ensure Child can access it + # independent of instance lifecycle. + api.set_rpc(rpc) + + is_child = os.environ.get("PYISOLATE_CHILD") == "1" + logger.info("][ ISO:UtilsProxy handle_api_registration PYISOLATE_CHILD=%s", is_child) + + # Progress hook wiring moved to setup_child_event_hooks via event channel + + return + + if api_name == "PromptServerProxy": + if not _IMPORT_TORCH: + return + # Defer heavy import to child context + import server + + instance = api() if isinstance(api, type) else api + proxy = ( + instance.instance + ) # PromptServerProxy instance has .instance property returning self + + original_register_route = proxy.register_route + + def register_route_wrapper( + method: str, path: str, handler: Callable[..., Any] + ) -> None: + callback_id = rpc.register_callback(handler) + loop = getattr(rpc, "loop", None) + if loop and loop.is_running(): + import asyncio + + asyncio.create_task( + original_register_route( + method, path, handler=callback_id, is_callback=True + ) + ) + else: + original_register_route( + method, path, handler=callback_id, is_callback=True + ) + return None + + proxy.register_route = register_route_wrapper + + class RouteTableDefProxy: + def __init__(self, proxy_instance: Any): + self.proxy = proxy_instance + + def get( + self, 
path: str, **kwargs: Any + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: + self.proxy.register_route("GET", path, handler) + return handler + + return decorator + + def post( + self, path: str, **kwargs: Any + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: + self.proxy.register_route("POST", path, handler) + return handler + + return decorator + + def patch( + self, path: str, **kwargs: Any + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: + self.proxy.register_route("PATCH", path, handler) + return handler + + return decorator + + def put( + self, path: str, **kwargs: Any + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: + self.proxy.register_route("PUT", path, handler) + return handler + + return decorator + + def delete( + self, path: str, **kwargs: Any + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: + self.proxy.register_route("DELETE", path, handler) + return handler + + return decorator + + proxy.routes = RouteTableDefProxy(proxy) + + if ( + hasattr(server, "PromptServer") + and getattr(server.PromptServer, "instance", None) != proxy + ): + server.PromptServer.instance = proxy diff --git a/comfy/isolation/child_hooks.py b/comfy/isolation/child_hooks.py new file mode 100644 index 000000000..a009929eb --- /dev/null +++ b/comfy/isolation/child_hooks.py @@ -0,0 +1,126 @@ +# pylint: disable=import-outside-toplevel,logging-fstring-interpolation +# Child process initialization for PyIsolate +import logging +import os + +logger = logging.getLogger(__name__) + + +def is_child_process() -> bool: + return os.environ.get("PYISOLATE_CHILD") == "1" + + +def _load_extra_model_paths() -> 
None: + """Load extra_model_paths.yaml so the child's folder_paths has the same search paths as the host. + + The host loads this in main.py:143-145. The child is spawned by + pyisolate's uds_client.py and never runs main.py, so folder_paths + only has the base model directories. Any isolated node calling + folder_paths.get_filename_list() in define_schema() would get empty + results for folders whose files live in extra_model_paths locations. + """ + import folder_paths # noqa: F401 — side-effect import; load_extra_path_config writes to folder_paths internals + from utils.extra_config import load_extra_path_config + + extra_config_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), + "extra_model_paths.yaml", + ) + if os.path.isfile(extra_config_path): + load_extra_path_config(extra_config_path) + + +def initialize_child_process() -> None: + logger.warning("][ DIAG:child_hooks initialize_child_process START") + if os.environ.get("PYISOLATE_IMPORT_TORCH", "1") != "0": + _load_extra_model_paths() + _setup_child_loop_bridge() + + # Manual RPC injection + try: + from pyisolate._internal.rpc_protocol import get_child_rpc_instance + + rpc = get_child_rpc_instance() + logger.warning("][ DIAG:child_hooks RPC instance: %s", rpc is not None) + if rpc: + _setup_proxy_callers(rpc) + logger.warning("][ DIAG:child_hooks proxy callers configured with RPC") + else: + logger.warning("][ DIAG:child_hooks NO RPC — proxy callers cleared") + _setup_proxy_callers() + except Exception as e: + logger.error(f"][ DIAG:child_hooks Manual RPC Injection failed: {e}") + _setup_proxy_callers() + + _setup_logging() + + +def _setup_child_loop_bridge() -> None: + import asyncio + + main_loop = None + try: + main_loop = asyncio.get_running_loop() + except RuntimeError: + try: + main_loop = asyncio.get_event_loop() + except RuntimeError: + pass + + if main_loop is None: + return + + try: + from .proxies.base import set_global_loop + + 
set_global_loop(main_loop) + except ImportError: + pass + + +def _setup_prompt_server_stub(rpc=None) -> None: + try: + from .proxies.prompt_server_impl import PromptServerStub + + if rpc: + PromptServerStub.set_rpc(rpc) + elif hasattr(PromptServerStub, "clear_rpc"): + PromptServerStub.clear_rpc() + else: + PromptServerStub._rpc = None # type: ignore[attr-defined] + + except Exception as e: + logger.error(f"Failed to setup PromptServerStub: {e}") + + +def _setup_proxy_callers(rpc=None) -> None: + try: + from .proxies.folder_paths_proxy import FolderPathsProxy + from .proxies.helper_proxies import HelperProxiesService + from .proxies.model_management_proxy import ModelManagementProxy + from .proxies.progress_proxy import ProgressProxy + from .proxies.prompt_server_impl import PromptServerStub + from .proxies.utils_proxy import UtilsProxy + + if rpc is None: + FolderPathsProxy.clear_rpc() + HelperProxiesService.clear_rpc() + ModelManagementProxy.clear_rpc() + ProgressProxy.clear_rpc() + PromptServerStub.clear_rpc() + UtilsProxy.clear_rpc() + return + + FolderPathsProxy.set_rpc(rpc) + HelperProxiesService.set_rpc(rpc) + ModelManagementProxy.set_rpc(rpc) + ProgressProxy.set_rpc(rpc) + PromptServerStub.set_rpc(rpc) + UtilsProxy.set_rpc(rpc) + + except Exception as e: + logger.error(f"Failed to setup child singleton proxy callers: {e}") + + +def _setup_logging() -> None: + logging.getLogger().setLevel(logging.INFO) diff --git a/comfy/isolation/extension_loader.py b/comfy/isolation/extension_loader.py new file mode 100644 index 000000000..0c65b234e --- /dev/null +++ b/comfy/isolation/extension_loader.py @@ -0,0 +1,521 @@ +# pylint: disable=cyclic-import,import-outside-toplevel,redefined-outer-name +from __future__ import annotations + +import logging +import os +import inspect +import sys +import types +import platform +from pathlib import Path +from typing import Any, Callable, Dict, List, Tuple + +import pyisolate +from pyisolate import ExtensionManager, 
ExtensionManagerConfig +from packaging.requirements import InvalidRequirement, Requirement +from packaging.utils import canonicalize_name + +from .manifest_loader import is_cache_valid, load_from_cache, save_to_cache +from .host_policy import load_host_policy + +try: + import tomllib +except ImportError: + import tomli as tomllib # type: ignore[no-redef] + +logger = logging.getLogger(__name__) + + +def _register_web_directory(extension_name: str, node_dir: Path) -> None: + """Register an isolated extension's web directory on the host side.""" + import nodes + + # Method 1: pyproject.toml [tool.comfy] web field + pyproject = node_dir / "pyproject.toml" + if pyproject.exists(): + try: + with pyproject.open("rb") as f: + data = tomllib.load(f) + web_dir_name = data.get("tool", {}).get("comfy", {}).get("web") + if web_dir_name: + web_dir_path = str(node_dir / web_dir_name) + if os.path.isdir(web_dir_path): + nodes.EXTENSION_WEB_DIRS[extension_name] = web_dir_path + logger.debug( + "][ Registered web dir for isolated %s: %s", + extension_name, + web_dir_path, + ) + return + except Exception: + pass + + # Method 2: __init__.py WEB_DIRECTORY constant (parse without importing) + init_file = node_dir / "__init__.py" + if init_file.exists(): + try: + source = init_file.read_text() + for line in source.splitlines(): + stripped = line.strip() + if stripped.startswith("WEB_DIRECTORY"): + # Parse: WEB_DIRECTORY = "./web" or WEB_DIRECTORY = "web" + _, _, value = stripped.partition("=") + value = value.strip().strip("\"'") + if value: + web_dir_path = str((node_dir / value).resolve()) + if os.path.isdir(web_dir_path): + nodes.EXTENSION_WEB_DIRS[extension_name] = web_dir_path + logger.debug( + "][ Registered web dir for isolated %s: %s", + extension_name, + web_dir_path, + ) + return + except Exception: + pass + + +def _get_extension_type(execution_model: str) -> type[Any]: + if execution_model == "sealed_worker": + return pyisolate.SealedNodeExtension + + from .extension_wrapper 
import ComfyNodeExtension + + return ComfyNodeExtension + + +async def _stop_extension_safe(extension: Any, extension_name: str) -> None: + try: + stop_result = extension.stop() + if inspect.isawaitable(stop_result): + await stop_result + except Exception: + logger.debug("][ %s stop failed", extension_name, exc_info=True) + + +def _normalize_dependency_spec(dep: str, base_paths: list[Path]) -> str: + req, sep, marker = dep.partition(";") + req = req.strip() + marker_suffix = f";{marker}" if sep else "" + + def _resolve_local_path(local_path: str) -> Path | None: + for base in base_paths: + candidate = (base / local_path).resolve() + if candidate.exists(): + return candidate + return None + + if req.startswith("./") or req.startswith("../"): + resolved = _resolve_local_path(req) + if resolved is not None: + return f"{resolved}{marker_suffix}" + + if req.startswith("file://"): + raw = req[len("file://") :] + if raw.startswith("./") or raw.startswith("../"): + resolved = _resolve_local_path(raw) + if resolved is not None: + return f"file://{resolved}{marker_suffix}" + + return dep + + +def _dependency_name_from_spec(dep: str) -> str | None: + stripped = dep.strip() + if not stripped or stripped == "-e" or stripped.startswith("-e "): + return None + if stripped.startswith(("/", "./", "../", "file://")): + return None + + try: + return canonicalize_name(Requirement(stripped).name) + except InvalidRequirement: + return None + + +def _parse_cuda_wheels_config( + tool_config: dict[str, object], dependencies: list[str] +) -> dict[str, object] | None: + raw_config = tool_config.get("cuda_wheels") + if raw_config is None: + return None + if not isinstance(raw_config, dict): + raise ExtensionLoadError("[tool.comfy.isolation.cuda_wheels] must be a table") + + index_url = raw_config.get("index_url") + index_urls = raw_config.get("index_urls") + if index_urls is not None: + if not isinstance(index_urls, list) or not all( + isinstance(u, str) and u.strip() for u in index_urls + ): 
+ raise ExtensionLoadError( + "[tool.comfy.isolation.cuda_wheels.index_urls] must be a list of non-empty strings" + ) + elif not isinstance(index_url, str) or not index_url.strip(): + raise ExtensionLoadError( + "[tool.comfy.isolation.cuda_wheels.index_url] must be a non-empty string" + ) + + packages = raw_config.get("packages") + if not isinstance(packages, list) or not all( + isinstance(package_name, str) and package_name.strip() + for package_name in packages + ): + raise ExtensionLoadError( + "[tool.comfy.isolation.cuda_wheels.packages] must be a list of non-empty strings" + ) + + declared_dependencies = { + dependency_name + for dep in dependencies + if (dependency_name := _dependency_name_from_spec(dep)) is not None + } + normalized_packages = [canonicalize_name(package_name) for package_name in packages] + missing = [ + package_name + for package_name in normalized_packages + if package_name not in declared_dependencies + ] + if missing: + missing_joined = ", ".join(sorted(missing)) + raise ExtensionLoadError( + "[tool.comfy.isolation.cuda_wheels.packages] references undeclared dependencies: " + f"{missing_joined}" + ) + + package_map = raw_config.get("package_map", {}) + if not isinstance(package_map, dict): + raise ExtensionLoadError( + "[tool.comfy.isolation.cuda_wheels.package_map] must be a table" + ) + + normalized_package_map: dict[str, str] = {} + for dependency_name, index_package_name in package_map.items(): + if not isinstance(dependency_name, str) or not dependency_name.strip(): + raise ExtensionLoadError( + "[tool.comfy.isolation.cuda_wheels.package_map] keys must be non-empty strings" + ) + if not isinstance(index_package_name, str) or not index_package_name.strip(): + raise ExtensionLoadError( + "[tool.comfy.isolation.cuda_wheels.package_map] values must be non-empty strings" + ) + canonical_dependency_name = canonicalize_name(dependency_name) + if canonical_dependency_name not in normalized_packages: + raise ExtensionLoadError( + 
"[tool.comfy.isolation.cuda_wheels.package_map] can only override packages listed in " + "[tool.comfy.isolation.cuda_wheels.packages]" + ) + normalized_package_map[canonical_dependency_name] = index_package_name.strip() + + result: dict = { + "packages": normalized_packages, + "package_map": normalized_package_map, + } + if index_urls is not None: + result["index_urls"] = [u.rstrip("/") + "/" for u in index_urls] + else: + result["index_url"] = index_url.rstrip("/") + "/" + return result + + +def get_enforcement_policy() -> Dict[str, bool]: + return { + "force_isolated": os.environ.get("PYISOLATE_ENFORCE_ISOLATED") == "1", + "force_sandbox": os.environ.get("PYISOLATE_ENFORCE_SANDBOX") == "1", + } + + +class ExtensionLoadError(RuntimeError): + pass + + +def register_dummy_module(extension_name: str, node_dir: Path) -> None: + normalized_name = extension_name.replace("-", "_").replace(".", "_") + if normalized_name not in sys.modules: + dummy_module = types.ModuleType(normalized_name) + dummy_module.__file__ = str(node_dir / "__init__.py") + dummy_module.__path__ = [str(node_dir)] + dummy_module.__package__ = normalized_name + sys.modules[normalized_name] = dummy_module + + +def _is_stale_node_cache(cached_data: Dict[str, Dict]) -> bool: + for details in cached_data.values(): + if not isinstance(details, dict): + return True + if details.get("is_v3") and "schema_v1" not in details: + return True + return False + + +async def load_isolated_node( + node_dir: Path, + manifest_path: Path, + logger: logging.Logger, + build_stub_class: Callable[[str, Dict[str, object], Any], type], + venv_root: Path, + extension_managers: List[ExtensionManager], +) -> List[Tuple[str, str, type]]: + try: + with manifest_path.open("rb") as handle: + manifest_data = tomllib.load(handle) + except Exception as e: + logger.warning(f"][ Failed to parse {manifest_path}: {e}") + return [] + + # Parse [tool.comfy.isolation] + tool_config = manifest_data.get("tool", {}).get("comfy", 
{}).get("isolation", {}) + can_isolate = tool_config.get("can_isolate", False) + share_torch = tool_config.get("share_torch", False) + package_manager = tool_config.get("package_manager", "uv") + is_conda = package_manager == "conda" + execution_model = tool_config.get("execution_model") + if execution_model is None: + execution_model = "sealed_worker" if is_conda else "host-coupled" + + if "sealed_host_ro_paths" in tool_config: + raise ValueError( + "Manifest field 'sealed_host_ro_paths' is not allowed. " + "Configure [tool.comfy.host].sealed_worker_ro_import_paths in host policy." + ) + + # Conda-specific manifest fields + conda_channels: list[str] = ( + tool_config.get("conda_channels", []) if is_conda else [] + ) + conda_dependencies: list[str] = ( + tool_config.get("conda_dependencies", []) if is_conda else [] + ) + conda_platforms: list[str] = ( + tool_config.get("conda_platforms", []) if is_conda else [] + ) + conda_python: str = ( + tool_config.get("conda_python", "*") if is_conda else "*" + ) + + # Parse [project] dependencies + project_config = manifest_data.get("project", {}) + dependencies = project_config.get("dependencies", []) + if not isinstance(dependencies, list): + dependencies = [] + + # Get extension name (default to folder name if not in project.name) + extension_name = project_config.get("name", node_dir.name) + + # LOGIC: Isolation Decision + policy = get_enforcement_policy() + isolated = can_isolate or policy["force_isolated"] + + if not isolated: + return [] + + import folder_paths + + base_paths = [Path(folder_paths.base_path), node_dir] + dependencies = [ + _normalize_dependency_spec(dep, base_paths) if isinstance(dep, str) else dep + for dep in dependencies + ] + cuda_wheels = _parse_cuda_wheels_config(tool_config, dependencies) + + manager_config = ExtensionManagerConfig(venv_root_path=str(venv_root)) + extension_type = _get_extension_type(execution_model) + manager: ExtensionManager = pyisolate.ExtensionManager( + extension_type, 
manager_config + ) + extension_managers.append(manager) + + host_policy = load_host_policy(Path(folder_paths.base_path)) + + sandbox_config = {} + is_linux = platform.system() == "Linux" + + if is_conda: + share_torch = False + share_cuda_ipc = False + else: + share_cuda_ipc = share_torch and is_linux + + if is_linux and isolated: + sandbox_config = { + "network": host_policy["allow_network"], + "writable_paths": host_policy["writable_paths"], + "readonly_paths": host_policy["readonly_paths"], + } + + extension_config: dict = { + "name": extension_name, + "module_path": str(node_dir), + "isolated": True, + "dependencies": dependencies, + "share_torch": share_torch, + "share_cuda_ipc": share_cuda_ipc, + "sandbox_mode": host_policy["sandbox_mode"], + "sandbox": sandbox_config, + } + + _is_sealed = execution_model == "sealed_worker" + _is_sandboxed = host_policy["sandbox_mode"] != "disabled" and is_linux + logger.info( + "][ Loading isolated node: %s (torch_share [%s], sealed [%s], sandboxed [%s])", + extension_name, + "x" if share_torch else " ", + "x" if _is_sealed else " ", + "x" if _is_sandboxed else " ", + ) + + if cuda_wheels is not None: + extension_config["cuda_wheels"] = cuda_wheels + + # Conda-specific keys + if is_conda: + extension_config["package_manager"] = "conda" + extension_config["conda_channels"] = conda_channels + extension_config["conda_dependencies"] = conda_dependencies + extension_config["conda_python"] = conda_python + find_links = tool_config.get("find_links", []) + if find_links: + extension_config["find_links"] = find_links + if conda_platforms: + extension_config["conda_platforms"] = conda_platforms + + if execution_model != "host-coupled": + extension_config["execution_model"] = execution_model + if execution_model == "sealed_worker": + policy_ro_paths = host_policy.get("sealed_worker_ro_import_paths", []) + if isinstance(policy_ro_paths, list) and policy_ro_paths: + extension_config["sealed_host_ro_paths"] = list(policy_ro_paths) + # 
Sealed workers keep the host RPC service inventory even when the + # child resolves no API classes locally. + + extension = manager.load_extension(extension_config) + register_dummy_module(extension_name, node_dir) + + # Register host-side event handlers via adapter + from .adapter import ComfyUIAdapter + ComfyUIAdapter.register_host_event_handlers(extension) + + # Register web directory on the host — only when sandbox is disabled. + # In sandbox mode, serving untrusted JS to the browser is not safe. + if host_policy["sandbox_mode"] == "disabled": + _register_web_directory(extension_name, node_dir) + + # Register for proxied web serving — the child's web dir may have + # content that doesn't exist on the host (e.g., pip-installed viewer + # bundles). The WebDirectoryCache will lazily fetch via RPC. + from .proxies.web_directory_proxy import WebDirectoryProxy, get_web_directory_cache + cache = get_web_directory_cache() + cache.register_proxy(extension_name, WebDirectoryProxy()) + + # Try cache first (lazy spawn) + logger.warning("][ DIAG:ext_loader cache_valid_check for %s", extension_name) + if is_cache_valid(node_dir, manifest_path, venv_root): + cached_data = load_from_cache(node_dir, venv_root) + if cached_data: + if _is_stale_node_cache(cached_data): + logger.warning( + "][ DIAG:ext_loader %s cache is stale/incompatible; rebuilding metadata", + extension_name, + ) + else: + logger.warning("][ DIAG:ext_loader %s USING CACHE — dumping combo options:", extension_name) + for node_name, details in cached_data.items(): + schema_v1 = details.get("schema_v1", {}) + inp = schema_v1.get("input", {}) if schema_v1 else {} + for section_name, section in inp.items(): + if isinstance(section, dict): + for field_name, field_def in section.items(): + if isinstance(field_def, (list, tuple)) and len(field_def) >= 2 and isinstance(field_def[1], dict) and "options" in field_def[1]: + opts = field_def[1]["options"] + logger.warning( + "][ DIAG:ext_loader CACHE %s.%s.%s options=%d 
first=%s", + node_name, section_name, field_name, + len(opts), + opts[:3] if opts else "EMPTY", + ) + specs: List[Tuple[str, str, type]] = [] + for node_name, details in cached_data.items(): + stub_cls = build_stub_class(node_name, details, extension) + specs.append( + (node_name, details.get("display_name", node_name), stub_cls) + ) + return specs + else: + logger.warning("][ DIAG:ext_loader %s cache INVALID or MISSING", extension_name) + + # Cache miss - spawn process and get metadata + logger.warning("][ DIAG:ext_loader %s cache miss, spawning process for metadata", extension_name) + + try: + remote_nodes: Dict[str, str] = await extension.list_nodes() + except Exception as exc: + logger.warning( + "][ %s metadata discovery failed, skipping isolated load: %s", + extension_name, + exc, + ) + await _stop_extension_safe(extension, extension_name) + return [] + + if not remote_nodes: + logger.debug("][ %s exposed no isolated nodes; skipping", extension_name) + await _stop_extension_safe(extension, extension_name) + return [] + + specs: List[Tuple[str, str, type]] = [] + cache_data: Dict[str, Dict] = {} + + for node_name, display_name in remote_nodes.items(): + logger.warning("][ DIAG:ext_loader calling get_node_details for %s.%s", extension_name, node_name) + try: + details = await extension.get_node_details(node_name) + except Exception as exc: + logger.warning( + "][ %s failed to load metadata for %s, skipping node: %s", + extension_name, + node_name, + exc, + ) + continue + # DIAG: dump combo options from freshly-fetched details + schema_v1 = details.get("schema_v1", {}) + inp = schema_v1.get("input", {}) if schema_v1 else {} + for section_name, section in inp.items(): + if isinstance(section, dict): + for field_name, field_def in section.items(): + if isinstance(field_def, (list, tuple)) and len(field_def) >= 2 and isinstance(field_def[1], dict) and "options" in field_def[1]: + opts = field_def[1]["options"] + logger.warning( + "][ DIAG:ext_loader FRESH %s.%s.%s 
options=%d first=%s", + node_name, section_name, field_name, + len(opts), + opts[:3] if opts else "EMPTY", + ) + details["display_name"] = display_name + cache_data[node_name] = details + stub_cls = build_stub_class(node_name, details, extension) + specs.append((node_name, display_name, stub_cls)) + + if not specs: + logger.warning( + "][ %s produced no usable nodes after metadata scan; skipping", + extension_name, + ) + await _stop_extension_safe(extension, extension_name) + return [] + + # Save metadata to cache for future runs + save_to_cache(node_dir, venv_root, cache_data, manifest_path) + logger.debug(f"][ {extension_name} metadata cached") + + # Re-check web directory AFTER child has populated it + if host_policy["sandbox_mode"] == "disabled": + _register_web_directory(extension_name, node_dir) + + # EJECT: Kill process after getting metadata (will respawn on first execution) + await _stop_extension_safe(extension, extension_name) + + return specs + + +__all__ = ["ExtensionLoadError", "register_dummy_module", "load_isolated_node"] diff --git a/comfy/isolation/extension_wrapper.py b/comfy/isolation/extension_wrapper.py new file mode 100644 index 000000000..67ba1d5c4 --- /dev/null +++ b/comfy/isolation/extension_wrapper.py @@ -0,0 +1,896 @@ +# pylint: disable=consider-using-from-import,cyclic-import,import-outside-toplevel,logging-fstring-interpolation,protected-access,wrong-import-position +from __future__ import annotations + +import asyncio +import torch + + +class AttrDict(dict): + def __getattr__(self, item): + try: + return self[item] + except KeyError as e: + raise AttributeError(item) from e + + def copy(self): + return AttrDict(super().copy()) + + +import importlib +import inspect +import json +import logging +import os +import sys +import uuid +from dataclasses import asdict +from typing import Any, Dict, List, Tuple + +from pyisolate import ExtensionBase + +from comfy_api.internal import _ComfyNodeInternal + +LOG_PREFIX = "][" +V3_DISCOVERY_TIMEOUT 
= 30 +_PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024 + +logger = logging.getLogger(__name__) + + +def _run_prestartup_web_copy(module: Any, module_dir: str, web_dir_path: str) -> None: + """Run the web asset copy step that prestartup_script.py used to do. + + If the module's web/ directory is empty and the module had a + prestartup_script.py that copied assets from pip packages, this + function replicates that work inside the child process. + + Generic pattern: reads _PRESTARTUP_WEB_COPY from the module if + defined, otherwise falls back to detecting common asset packages. + """ + import shutil + + # Already populated — nothing to do + if os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path)): + return + + os.makedirs(web_dir_path, exist_ok=True) + + # Try module-defined copy spec first (generic hook for any node pack) + copy_spec = getattr(module, "_PRESTARTUP_WEB_COPY", None) + if copy_spec is not None and callable(copy_spec): + try: + copy_spec(web_dir_path) + logger.info( + "%s Ran _PRESTARTUP_WEB_COPY for %s", LOG_PREFIX, module_dir + ) + return + except Exception as e: + logger.warning( + "%s _PRESTARTUP_WEB_COPY failed for %s: %s", + LOG_PREFIX, module_dir, e, + ) + + # Fallback: detect comfy_3d_viewers and run copy_viewer() + try: + from comfy_3d_viewers import copy_viewer, VIEWER_FILES + viewers = list(VIEWER_FILES.keys()) + for viewer in viewers: + try: + copy_viewer(viewer, web_dir_path) + except Exception: + pass + if any(os.scandir(web_dir_path)): + logger.info( + "%s Copied %d viewer types from comfy_3d_viewers to %s", + LOG_PREFIX, len(viewers), web_dir_path, + ) + except ImportError: + pass + + # Fallback: detect comfy_dynamic_widgets + try: + from comfy_dynamic_widgets import get_js_path + src = os.path.realpath(get_js_path()) + if os.path.exists(src): + dst_dir = os.path.join(web_dir_path, "js") + os.makedirs(dst_dir, exist_ok=True) + dst = os.path.join(dst_dir, "dynamic_widgets.js") + shutil.copy2(src, dst) + except ImportError: + 
pass


def _read_extension_name(module_dir: str) -> str:
    """Read extension name from pyproject.toml, falling back to directory name.

    Any read/parse failure, or a missing ``[project].name``, falls back to
    ``os.path.basename(module_dir)``.
    """
    pyproject = os.path.join(module_dir, "pyproject.toml")
    if os.path.exists(pyproject):
        # tomllib is stdlib on Python 3.11+; tomli is the pre-3.11 backport.
        try:
            import tomllib
        except ImportError:
            import tomli as tomllib  # type: ignore[no-redef]
        try:
            with open(pyproject, "rb") as f:
                data = tomllib.load(f)
            name = data.get("project", {}).get("name")
            if name:
                return name
        except Exception:
            # Best-effort: fall through to the directory-name fallback.
            pass
    return os.path.basename(module_dir)


def _flush_tensor_transport_state(marker: str) -> int:
    """Ask pyisolate's tensor keeper to release cached transport tensors.

    Returns the number of released entries, or 0 when the hook is absent
    (older pyisolate) or not callable. ``marker`` only tags the debug log.
    """
    try:
        from pyisolate import flush_tensor_keeper  # type: ignore[attr-defined]
    except Exception:
        return 0
    if not callable(flush_tensor_keeper):
        return 0
    flushed = flush_tensor_keeper()
    if flushed > 0:
        logger.debug(
            "%s %s flush_tensor_keeper released=%d", LOG_PREFIX, marker, flushed
        )
    return flushed


def _relieve_child_vram_pressure(marker: str) -> None:
    """Run ComfyUI model-management cleanup and free VRAM before node execution.

    When free memory on the torch device is below the larger of
    ``minimum_inference_memory()`` and ``_PRE_EXEC_MIN_FREE_VRAM_BYTES``
    (2 GiB), asks model management to free memory. No-op on CPU devices.
    """
    import comfy.model_management as model_management

    model_management.cleanup_models_gc()
    model_management.cleanup_models()

    device = model_management.get_torch_device()
    if not hasattr(device, "type") or device.type == "cpu":
        return

    required = max(
        model_management.minimum_inference_memory(),
        _PRE_EXEC_MIN_FREE_VRAM_BYTES,
    )
    if model_management.get_free_memory(device) < required:
        # Gentle pass first (for_dynamic=True), forced pass only if still short.
        # NOTE(review): nesting reconstructed from a flattened patch — confirm
        # the forced free/cleanup are conditional on the first pass falling short.
        model_management.free_memory(required, device, for_dynamic=True)
        if model_management.get_free_memory(device) < required:
            model_management.free_memory(required, device, for_dynamic=False)
            model_management.cleanup_models()
        model_management.soft_empty_cache()
        logger.debug("%s %s free_memory target=%d", LOG_PREFIX, marker, required)


def _sanitize_for_transport(value):
    # Primitives cross the RPC boundary untouched.
    primitives = (str, int, float, bool, type(None))
    if isinstance(value, primitives):
        return value

    cls_name = value.__class__.__name__
    # Class-NAME checks (not isinstance) avoid importing node-pack classes here.
    if cls_name == "FlexibleOptionalInputType":
        return {
            "__pyisolate_flexible_optional__":
True, + "type": _sanitize_for_transport(getattr(value, "type", "*")), + } + if cls_name == "AnyType": + return {"__pyisolate_any_type__": True, "value": str(value)} + if cls_name == "ByPassTypeTuple": + return { + "__pyisolate_bypass_tuple__": [ + _sanitize_for_transport(v) for v in tuple(value) + ] + } + + if isinstance(value, dict): + return {k: _sanitize_for_transport(v) for k, v in value.items()} + if isinstance(value, tuple): + return {"__pyisolate_tuple__": [_sanitize_for_transport(v) for v in value]} + if isinstance(value, list): + return [_sanitize_for_transport(v) for v in value] + + return str(value) + + +# Re-export RemoteObjectHandle from pyisolate for backward compatibility +# The canonical definition is now in pyisolate._internal.remote_handle +from pyisolate._internal.remote_handle import RemoteObjectHandle # noqa: E402,F401 + + +class ComfyNodeExtension(ExtensionBase): + def __init__(self) -> None: + super().__init__() + self.node_classes: Dict[str, type] = {} + self.display_names: Dict[str, str] = {} + self.node_instances: Dict[str, Any] = {} + self.remote_objects: Dict[str, Any] = {} + self._route_handlers: Dict[str, Any] = {} + self._module: Any = None + + async def on_module_loaded(self, module: Any) -> None: + self._module = module + + # Registries are initialized in host_hooks.py initialize_host_process() + # They auto-register via ProxiedSingleton when instantiated + # NO additional setup required here - if a registry is missing from host_hooks, it WILL fail + + self.node_classes = getattr(module, "NODE_CLASS_MAPPINGS", {}) or {} + self.display_names = getattr(module, "NODE_DISPLAY_NAME_MAPPINGS", {}) or {} + + # Register web directory with WebDirectoryProxy (child-side) + web_dir_attr = getattr(module, "WEB_DIRECTORY", None) + if web_dir_attr is not None: + module_dir = os.path.dirname(os.path.abspath(module.__file__)) + web_dir_path = os.path.abspath(os.path.join(module_dir, web_dir_attr)) + ext_name = _read_extension_name(module_dir) + + # 
If web dir is empty, run the copy step that prestartup_script.py did + _run_prestartup_web_copy(module, module_dir, web_dir_path) + + if os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path)): + from comfy.isolation.proxies.web_directory_proxy import WebDirectoryProxy + WebDirectoryProxy.register_web_dir(ext_name, web_dir_path) + + try: + from comfy_api.latest import ComfyExtension + + for name, obj in inspect.getmembers(module): + if not ( + inspect.isclass(obj) + and issubclass(obj, ComfyExtension) + and obj is not ComfyExtension + ): + continue + if not obj.__module__.startswith(module.__name__): + continue + try: + ext_instance = obj() + try: + await asyncio.wait_for( + ext_instance.on_load(), timeout=V3_DISCOVERY_TIMEOUT + ) + except asyncio.TimeoutError: + logger.error( + "%s V3 Extension %s timed out in on_load()", + LOG_PREFIX, + name, + ) + continue + try: + v3_nodes = await asyncio.wait_for( + ext_instance.get_node_list(), timeout=V3_DISCOVERY_TIMEOUT + ) + except asyncio.TimeoutError: + logger.error( + "%s V3 Extension %s timed out in get_node_list()", + LOG_PREFIX, + name, + ) + continue + for node_cls in v3_nodes: + if hasattr(node_cls, "GET_SCHEMA"): + schema = node_cls.GET_SCHEMA() + self.node_classes[schema.node_id] = node_cls + if schema.display_name: + self.display_names[schema.node_id] = schema.display_name + except Exception as e: + logger.error("%s V3 Extension %s failed: %s", LOG_PREFIX, name, e) + except ImportError: + pass + + module_name = getattr(module, "__name__", "isolated_nodes") + for node_cls in self.node_classes.values(): + if hasattr(node_cls, "__module__") and "/" in str(node_cls.__module__): + node_cls.__module__ = module_name + + self.node_instances = {} + + async def list_nodes(self) -> Dict[str, str]: + return {name: self.display_names.get(name, name) for name in self.node_classes} + + async def get_node_info(self, node_name: str) -> Dict[str, Any]: + return await self.get_node_details(node_name) + + async def 
get_node_details(self, node_name: str) -> Dict[str, Any]: + node_cls = self._get_node_class(node_name) + is_v3 = issubclass(node_cls, _ComfyNodeInternal) + logger.warning( + "%s DIAG:get_node_details START | node=%s | is_v3=%s | cls=%s", + LOG_PREFIX, node_name, is_v3, node_cls, + ) + + input_types_raw = ( + node_cls.INPUT_TYPES() if hasattr(node_cls, "INPUT_TYPES") else {} + ) + output_is_list = getattr(node_cls, "OUTPUT_IS_LIST", None) + if output_is_list is not None: + output_is_list = tuple(bool(x) for x in output_is_list) + + details: Dict[str, Any] = { + "input_types": _sanitize_for_transport(input_types_raw), + "return_types": tuple( + str(t) for t in getattr(node_cls, "RETURN_TYPES", ()) + ), + "return_names": getattr(node_cls, "RETURN_NAMES", None), + "function": str(getattr(node_cls, "FUNCTION", "execute")), + "category": str(getattr(node_cls, "CATEGORY", "")), + "output_node": bool(getattr(node_cls, "OUTPUT_NODE", False)), + "output_is_list": output_is_list, + "is_v3": is_v3, + } + + if is_v3: + try: + logger.warning( + "%s DIAG:get_node_details calling GET_SCHEMA for %s", + LOG_PREFIX, node_name, + ) + schema = node_cls.GET_SCHEMA() + logger.warning( + "%s DIAG:get_node_details GET_SCHEMA returned for %s | schema_inputs=%s", + LOG_PREFIX, node_name, + [getattr(i, 'id', '?') for i in (schema.inputs or [])], + ) + schema_v1 = asdict(schema.get_v1_info(node_cls)) + try: + schema_v3 = asdict(schema.get_v3_info(node_cls)) + except (AttributeError, TypeError): + schema_v3 = self._build_schema_v3_fallback(schema) + details.update( + { + "schema_v1": schema_v1, + "schema_v3": schema_v3, + "hidden": [h.value for h in (schema.hidden or [])], + "description": getattr(schema, "description", ""), + "deprecated": bool(getattr(node_cls, "DEPRECATED", False)), + "experimental": bool(getattr(node_cls, "EXPERIMENTAL", False)), + "api_node": bool(getattr(node_cls, "API_NODE", False)), + "input_is_list": bool( + getattr(node_cls, "INPUT_IS_LIST", False) + ), + 
                        "not_idempotent": bool(
                            getattr(node_cls, "NOT_IDEMPOTENT", False)
                        ),
                        "accept_all_inputs": bool(
                            getattr(node_cls, "ACCEPT_ALL_INPUTS", False)
                        ),
                    }
                )
            except Exception as exc:
                # V3 schema extras are best-effort; the base `details` dict
                # built above is still returned on failure.
                logger.warning(
                    "%s V3 schema serialization failed for %s: %s",
                    LOG_PREFIX,
                    node_name,
                    exc,
                )
        return details

    def _build_schema_v3_fallback(self, schema) -> Dict[str, Any]:
        """Build a minimal v3 schema dict when ``get_v3_info`` is unavailable.

        Uses only attributes present on the schema object; missing attributes
        default to None/False so the host can still render the node.
        """
        input_dict: Dict[str, Any] = {}
        output_dict: Dict[str, Any] = {}
        hidden_list: List[str] = []

        if getattr(schema, "inputs", None):
            for inp in schema.inputs:
                self._add_schema_io_v3(inp, input_dict)
        if getattr(schema, "outputs", None):
            for out in schema.outputs:
                self._add_schema_io_v3(out, output_dict)
        if getattr(schema, "hidden", None):
            for h in schema.hidden:
                # Hidden entries may be enum members (use .value) or plain objects.
                hidden_list.append(getattr(h, "value", str(h)))

        return {
            "input": input_dict,
            "output": output_dict,
            "hidden": hidden_list,
            "name": getattr(schema, "node_id", None),
            "display_name": getattr(schema, "display_name", None),
            "description": getattr(schema, "description", None),
            "category": getattr(schema, "category", None),
            "output_node": getattr(schema, "is_output_node", False),
            "deprecated": getattr(schema, "is_deprecated", False),
            "experimental": getattr(schema, "is_experimental", False),
            "api_node": getattr(schema, "is_api_node", False),
        }

    def _add_schema_io_v3(self, io_obj: Any, target: Dict[str, Any]) -> None:
        """Record one schema input/output as ``target[str(id)] = (io_type, payload)``.

        Entries without an ``id`` are skipped. ``get_io_type()`` / ``as_dict()``
        are used when callable, with attribute / empty-dict fallbacks otherwise.
        """
        io_id = getattr(io_obj, "id", None)
        if io_id is None:
            return

        io_type_fn = getattr(io_obj, "get_io_type", None)
        io_type = (
            io_type_fn() if callable(io_type_fn) else getattr(io_obj, "io_type", None)
        )

        as_dict_fn = getattr(io_obj, "as_dict", None)
        payload = as_dict_fn() if callable(as_dict_fn) else {}

        target[str(io_id)] = (io_type, payload)

    async def get_input_types(self, node_name: str) -> Dict[str, Any]:
        """RPC: return the node class's raw INPUT_TYPES() dict (or {})."""
        node_cls = self._get_node_class(node_name)
        if hasattr(node_cls, "INPUT_TYPES"):
            return node_cls.INPUT_TYPES()
        return {}

    async def
execute_node(self, node_name: str, **inputs: Any) -> Tuple[Any, ...]: + logger.debug( + "%s ISO:child_execute_start ext=%s node=%s input_keys=%d", + LOG_PREFIX, + getattr(self, "name", "?"), + node_name, + len(inputs), + ) + if os.environ.get("PYISOLATE_CHILD") == "1": + _relieve_child_vram_pressure("EXT:pre_execute") + + resolved_inputs = self._resolve_remote_objects(inputs) + + instance = self._get_node_instance(node_name) + node_cls = self._get_node_class(node_name) + + # V3 API nodes expect hidden parameters in cls.hidden, not as kwargs + # Hidden params come through RPC as string keys like "Hidden.prompt" + from comfy_api.latest._io import Hidden, HiddenHolder + + # Map string representations back to Hidden enum keys + hidden_string_map = { + "Hidden.unique_id": Hidden.unique_id, + "Hidden.prompt": Hidden.prompt, + "Hidden.extra_pnginfo": Hidden.extra_pnginfo, + "Hidden.dynprompt": Hidden.dynprompt, + "Hidden.auth_token_comfy_org": Hidden.auth_token_comfy_org, + "Hidden.api_key_comfy_org": Hidden.api_key_comfy_org, + # Uppercase enum VALUE forms — V3 execution engine passes these + "UNIQUE_ID": Hidden.unique_id, + "PROMPT": Hidden.prompt, + "EXTRA_PNGINFO": Hidden.extra_pnginfo, + "DYNPROMPT": Hidden.dynprompt, + "AUTH_TOKEN_COMFY_ORG": Hidden.auth_token_comfy_org, + "API_KEY_COMFY_ORG": Hidden.api_key_comfy_org, + } + + # Find and extract hidden parameters (both enum and string form) + hidden_found = {} + keys_to_remove = [] + + for key in list(resolved_inputs.keys()): + # Check string form first (from RPC serialization) + if key in hidden_string_map: + hidden_found[hidden_string_map[key]] = resolved_inputs[key] + keys_to_remove.append(key) + # Also check enum form (direct calls) + elif isinstance(key, Hidden): + hidden_found[key] = resolved_inputs[key] + keys_to_remove.append(key) + + # Remove hidden params from kwargs + for key in keys_to_remove: + resolved_inputs.pop(key) + + # Set hidden on node class if any hidden params found + if hidden_found: + if not 
hasattr(node_cls, "hidden") or node_cls.hidden is None: + node_cls.hidden = HiddenHolder.from_dict(hidden_found) + else: + # Update existing hidden holder + for key, value in hidden_found.items(): + setattr(node_cls.hidden, key.value.lower(), value) + + # INPUT_IS_LIST: ComfyUI's executor passes all inputs as lists when this + # flag is set. The isolation RPC delivers unwrapped values, so we must + # wrap each input in a single-element list to match the contract. + if getattr(node_cls, "INPUT_IS_LIST", False): + resolved_inputs = {k: [v] for k, v in resolved_inputs.items()} + + function_name = getattr(node_cls, "FUNCTION", "execute") + if not hasattr(instance, function_name): + raise AttributeError(f"Node {node_name} missing callable '{function_name}'") + + handler = getattr(instance, function_name) + + try: + import torch + if asyncio.iscoroutinefunction(handler): + with torch.inference_mode(): + result = await handler(**resolved_inputs) + else: + import functools + + def _run_with_inference_mode(**kwargs): + with torch.inference_mode(): + return handler(**kwargs) + + loop = asyncio.get_running_loop() + result = await loop.run_in_executor( + None, functools.partial(_run_with_inference_mode, **resolved_inputs) + ) + except Exception: + logger.exception( + "%s ISO:child_execute_error ext=%s node=%s", + LOG_PREFIX, + getattr(self, "name", "?"), + node_name, + ) + raise + + if type(result).__name__ == "NodeOutput": + node_output_dict = { + "__node_output__": True, + "args": self._wrap_unpicklable_objects(result.args), + } + if result.ui is not None: + node_output_dict["ui"] = self._wrap_unpicklable_objects(result.ui) + if getattr(result, "expand", None) is not None: + node_output_dict["expand"] = result.expand + if getattr(result, "block_execution", None) is not None: + node_output_dict["block_execution"] = result.block_execution + return node_output_dict + if self._is_comfy_protocol_return(result): + wrapped = self._wrap_unpicklable_objects(result) + return wrapped + 

        # Plain data return: normalize to a tuple (ComfyUI node contract)
        # and wrap anything that can't cross the RPC boundary.
        if not isinstance(result, tuple):
            result = (result,)
        wrapped = self._wrap_unpicklable_objects(result)
        return wrapped

    async def flush_transport_state(self) -> int:
        """RPC: release cached transport tensors at workflow end (child only).

        Returns the number of flushed entries; 0 when this process is not a
        pyisolate child. Also sweeps the ModelPatcher registry, best-effort.
        """
        if os.environ.get("PYISOLATE_CHILD") != "1":
            return 0
        logger.debug(
            "%s ISO:child_flush_start ext=%s", LOG_PREFIX, getattr(self, "name", "?")
        )
        flushed = _flush_tensor_transport_state("EXT:workflow_end")
        try:
            from comfy.isolation.model_patcher_proxy_registry import (
                ModelPatcherRegistry,
            )

            registry = ModelPatcherRegistry()
            removed = registry.sweep_pending_cleanup()
            if removed > 0:
                logger.debug(
                    "%s EXT:workflow_end registry sweep removed=%d", LOG_PREFIX, removed
                )
        except Exception:
            # Best-effort sweep; failure is only recorded at debug level.
            logger.debug(
                "%s EXT:workflow_end registry sweep failed", LOG_PREFIX, exc_info=True
            )
        logger.debug(
            "%s ISO:child_flush_done ext=%s flushed=%d",
            LOG_PREFIX,
            getattr(self, "name", "?"),
            flushed,
        )
        return flushed

    async def get_remote_object(self, object_id: str) -> Any:
        """Retrieve a remote object by ID for host-side deserialization."""
        if object_id not in self.remote_objects:
            raise KeyError(f"Remote object {object_id} not found")

        return self.remote_objects[object_id]

    def _store_remote_object_handle(self, obj: Any) -> RemoteObjectHandle:
        # Keep the object alive child-side; hand the host an opaque handle.
        object_id = str(uuid.uuid4())
        self.remote_objects[object_id] = obj
        return RemoteObjectHandle(object_id, type(obj).__name__)

    async def call_remote_object_method(
        self,
        object_id: str,
        method_name: str,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """Invoke a method or attribute-backed accessor on a child-owned object.

        A fixed set of ``get_*``/``set_*`` method names is mapped to attribute
        access on the object (or its ``.model``); anything else falls through
        to a real attribute lookup further below.
        """
        obj = await self.get_remote_object(object_id)

        if method_name == "get_patcher_attr":
            return getattr(obj, args[0])
        if method_name == "get_model_options":
            return getattr(obj, "model_options")
        if method_name == "set_model_options":
            setattr(obj, "model_options", args[0])
            return None
        if method_name == "get_object_patches":
            return getattr(obj, "object_patches")
if method_name == "get_patches": + return getattr(obj, "patches") + if method_name == "get_wrappers": + return getattr(obj, "wrappers") + if method_name == "get_callbacks": + return getattr(obj, "callbacks") + if method_name == "get_load_device": + return getattr(obj, "load_device") + if method_name == "get_offload_device": + return getattr(obj, "offload_device") + if method_name == "get_hook_mode": + return getattr(obj, "hook_mode") + if method_name == "get_parent": + parent = getattr(obj, "parent", None) + if parent is None: + return None + return self._store_remote_object_handle(parent) + if method_name == "get_inner_model_attr": + attr_name = args[0] + if hasattr(obj.model, attr_name): + return getattr(obj.model, attr_name) + if hasattr(obj, attr_name): + return getattr(obj, attr_name) + return None + if method_name == "inner_model_apply_model": + return obj.model.apply_model(*args[0], **args[1]) + if method_name == "inner_model_extra_conds_shapes": + return obj.model.extra_conds_shapes(*args[0], **args[1]) + if method_name == "inner_model_extra_conds": + return obj.model.extra_conds(*args[0], **args[1]) + if method_name == "inner_model_memory_required": + return obj.model.memory_required(*args[0], **args[1]) + if method_name == "process_latent_in": + return obj.model.process_latent_in(*args[0], **args[1]) + if method_name == "process_latent_out": + return obj.model.process_latent_out(*args[0], **args[1]) + if method_name == "scale_latent_inpaint": + return obj.model.scale_latent_inpaint(*args[0], **args[1]) + if method_name.startswith("get_"): + attr_name = method_name[4:] + if hasattr(obj, attr_name): + return getattr(obj, attr_name) + + target = getattr(obj, method_name) + if callable(target): + result = target(*args, **kwargs) + if inspect.isawaitable(result): + result = await result + if type(result).__name__ == "ModelPatcher": + return self._store_remote_object_handle(result) + return result + if args or kwargs: + raise TypeError(f"{method_name} is not 
callable on remote object {object_id}") + return target + + def _wrap_unpicklable_objects(self, data: Any) -> Any: + if isinstance(data, (str, int, float, bool, type(None))): + return data + if isinstance(data, torch.Tensor): + tensor = data.detach() if data.requires_grad else data + if os.environ.get("PYISOLATE_CHILD") == "1" and tensor.device.type != "cpu": + return tensor.cpu() + return tensor + + # Special-case clip vision outputs: preserve attribute access by packing fields + if hasattr(data, "penultimate_hidden_states") or hasattr( + data, "last_hidden_state" + ): + fields = {} + for attr in ( + "penultimate_hidden_states", + "last_hidden_state", + "image_embeds", + "text_embeds", + ): + if hasattr(data, attr): + try: + fields[attr] = self._wrap_unpicklable_objects( + getattr(data, attr) + ) + except Exception: + pass + if fields: + return {"__pyisolate_attribute_container__": True, "data": fields} + + # Avoid converting arbitrary objects with stateful methods (models, etc.) + # They will be handled via RemoteObjectHandle below. 
+ + type_name = type(data).__name__ + if type_name == "ModelPatcherProxy": + return {"__type__": "ModelPatcherRef", "model_id": data._instance_id} + if type_name == "CLIPProxy": + return {"__type__": "CLIPRef", "clip_id": data._instance_id} + if type_name == "VAEProxy": + return {"__type__": "VAERef", "vae_id": data._instance_id} + if type_name == "ModelSamplingProxy": + return {"__type__": "ModelSamplingRef", "ms_id": data._instance_id} + + if isinstance(data, (list, tuple)): + wrapped = [self._wrap_unpicklable_objects(item) for item in data] + return tuple(wrapped) if isinstance(data, tuple) else wrapped + if isinstance(data, dict): + converted_dict = { + k: self._wrap_unpicklable_objects(v) for k, v in data.items() + } + return {"__pyisolate_attrdict__": True, "data": converted_dict} + + from pyisolate._internal.serialization_registry import SerializerRegistry + + registry = SerializerRegistry.get_instance() + if registry.is_data_type(type_name): + serializer = registry.get_serializer(type_name) + if serializer: + return serializer(data) + + return self._store_remote_object_handle(data) + + def _resolve_remote_objects(self, data: Any) -> Any: + if isinstance(data, RemoteObjectHandle): + if data.object_id not in self.remote_objects: + raise KeyError(f"Remote object {data.object_id} not found") + return self.remote_objects[data.object_id] + + if isinstance(data, dict): + ref_type = data.get("__type__") + if ref_type in ("CLIPRef", "ModelPatcherRef", "VAERef"): + from pyisolate._internal.model_serialization import ( + deserialize_proxy_result, + ) + + return deserialize_proxy_result(data) + if ref_type == "ModelSamplingRef": + from pyisolate._internal.model_serialization import ( + deserialize_proxy_result, + ) + + return deserialize_proxy_result(data) + return {k: self._resolve_remote_objects(v) for k, v in data.items()} + + if isinstance(data, (list, tuple)): + resolved = [self._resolve_remote_objects(item) for item in data] + return tuple(resolved) if 
isinstance(data, tuple) else resolved
        return data

    def _get_node_class(self, node_name: str) -> type:
        """Look up a node class by name; raises KeyError for unknown nodes."""
        if node_name not in self.node_classes:
            raise KeyError(f"Unknown node: {node_name}")
        return self.node_classes[node_name]

    def _get_node_instance(self, node_name: str) -> Any:
        """Return the cached instance of a node, creating it on first use."""
        if node_name not in self.node_instances:
            if node_name not in self.node_classes:
                raise KeyError(f"Unknown node: {node_name}")
            self.node_instances[node_name] = self.node_classes[node_name]()
        return self.node_instances[node_name]

    async def before_module_loaded(self) -> None:
        # Inject initialization here if we think this is the child
        logger.warning(
            "%s DIAG:before_module_loaded START | is_child=%s",
            LOG_PREFIX, os.environ.get("PYISOLATE_CHILD"),
        )
        try:
            from comfy.isolation import initialize_proxies

            initialize_proxies()
            logger.warning("%s DIAG:before_module_loaded initialize_proxies OK", LOG_PREFIX)
        except Exception as e:
            logger.error(
                "%s DIAG:before_module_loaded initialize_proxies FAILED: %s", LOG_PREFIX, e
            )

        await super().before_module_loaded()
        try:
            from comfy_api.latest import ComfyAPI_latest
            from .proxies.progress_proxy import ProgressProxy

            ComfyAPI_latest.Execution = ProgressProxy
            # ComfyAPI_latest.execution = ProgressProxy()  # Eliminated to avoid Singleton collision
            # fp_proxy = FolderPathsProxy()  # Eliminated to avoid Singleton collision
            # latest_ui.folder_paths = fp_proxy
            # latest_resources.folder_paths = fp_proxy
        except Exception:
            # Optional API surface — absence of comfy_api.latest is tolerated.
            pass

    async def call_route_handler(
        self,
        handler_module: str,
        handler_func: str,
        request_data: Dict[str, Any],
    ) -> Any:
        """RPC: import and invoke a node-pack HTTP route handler in the child.

        Handlers are cached under ``"module.func"``. The loaded module's
        directory is prepended to sys.path so the pack's own imports resolve.
        """
        cache_key = f"{handler_module}.{handler_func}"
        if cache_key not in self._route_handlers:
            if self._module is not None and hasattr(self._module, "__file__"):
                node_dir = os.path.dirname(self._module.__file__)
                if node_dir not in sys.path:
                    sys.path.insert(0, node_dir)
            try:
                module = importlib.import_module(handler_module)
self._route_handlers[cache_key] = getattr(module, handler_func) + except (ImportError, AttributeError) as e: + raise ValueError(f"Route handler not found: {cache_key}") from e + + handler = self._route_handlers[cache_key] + mock_request = MockRequest(request_data) + + if asyncio.iscoroutinefunction(handler): + result = await handler(mock_request) + else: + result = handler(mock_request) + return self._serialize_response(result) + + def _is_comfy_protocol_return(self, result: Any) -> bool: + """ + Check if the result matches the ComfyUI 'Protocol Return' schema. + + A Protocol Return is a dictionary containing specific reserved keys that + ComfyUI's execution engine interprets as instructions (UI updates, + Workflow expansion, etc.) rather than purely data outputs. + + Schema: + - Must be a dict + - Must contain at least one of: 'ui', 'result', 'expand' + """ + if not isinstance(result, dict): + return False + return any(key in result for key in ("ui", "result", "expand")) + + def _serialize_response(self, response: Any) -> Dict[str, Any]: + if response is None: + return {"type": "text", "body": "", "status": 204} + if isinstance(response, dict): + return {"type": "json", "body": response, "status": 200} + if isinstance(response, str): + return {"type": "text", "body": response, "status": 200} + if hasattr(response, "text") and hasattr(response, "status"): + return { + "type": "text", + "body": response.text + if hasattr(response, "text") + else str(response.body), + "status": response.status, + "headers": dict(response.headers) + if hasattr(response, "headers") + else {}, + } + if hasattr(response, "body") and hasattr(response, "status"): + body = response.body + if isinstance(body, bytes): + try: + return { + "type": "text", + "body": body.decode("utf-8"), + "status": response.status, + } + except UnicodeDecodeError: + return { + "type": "binary", + "body": body.hex(), + "status": response.status, + } + return {"type": "json", "body": body, "status": 
class MockRequest:
    """Duck-typed stand-in for an aiohttp ``web.Request`` built from a plain dict.

    Route invocations are replayed inside the isolated child process, where the
    real aiohttp request object cannot cross the RPC boundary; handlers see an
    object with the familiar ``method``/``path``/``json()``/``text()`` surface.
    """

    def __init__(self, data: Dict[str, Any]):
        self.method = data.get("method", "GET")
        self.path = data.get("path", "/")
        self.query = data.get("query", {})
        self.headers = data.get("headers", {})
        self.match_info = data.get("match_info", {})
        self._body = data.get("body", {})
        self._text = data.get("text", "")
        # Explicit content_type wins, then the Content-Type header, then JSON.
        header_default = self.headers.get("Content-Type", "application/json")
        self.content_type = data.get("content_type", header_default)

    async def json(self) -> Any:
        """Parsed body: dicts pass through, strings are JSON-decoded, else {}."""
        payload = self._body
        if isinstance(payload, dict):
            return payload
        if isinstance(payload, str):
            return json.loads(payload)
        return {}

    async def post(self) -> Dict[str, Any]:
        """Form-style body; only dict bodies are representable, else {}."""
        payload = self._body
        return payload if isinstance(payload, dict) else {}

    async def text(self) -> str:
        """Raw text: explicit text first, then string body, then JSON-encoded dict."""
        if self._text:
            return self._text
        payload = self._body
        if isinstance(payload, str):
            return payload
        if isinstance(payload, dict):
            return json.dumps(payload)
        return ""

    async def read(self) -> bytes:
        """UTF-8 bytes of ``text()``."""
        text_value = await self.text()
        return text_value.encode("utf-8")
WebDirectoryProxy + from .vae_proxy import VAERegistry + + FolderPathsProxy() + HelperProxiesService() + ModelManagementProxy() + ProgressProxy() + PromptServerService() + UtilsProxy() + WebDirectoryProxy() + VAERegistry() diff --git a/comfy/isolation/manifest_loader.py b/comfy/isolation/manifest_loader.py new file mode 100644 index 000000000..4ae21d94d --- /dev/null +++ b/comfy/isolation/manifest_loader.py @@ -0,0 +1,221 @@ +# pylint: disable=import-outside-toplevel +from __future__ import annotations + +import hashlib +import json +import logging +import os +import sys +import tempfile +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import folder_paths + +try: + import tomllib +except ImportError: + import tomli as tomllib # type: ignore[no-redef] + +LOG_PREFIX = "][" +logger = logging.getLogger(__name__) + +CACHE_SUBDIR = "cache" +CACHE_KEY_FILE = "cache_key" +CACHE_DATA_FILE = "node_info.json" +CACHE_KEY_LENGTH = 16 +_NESTED_SCAN_ROOT = "packages" +_IGNORED_MANIFEST_DIRS = {".git", ".venv", "__pycache__"} + + +def _read_manifest(manifest_path: Path) -> dict[str, Any] | None: + try: + with manifest_path.open("rb") as f: + data = tomllib.load(f) + if isinstance(data, dict): + return data + except Exception: + return None + return None + + +def _is_isolation_manifest(data: dict[str, Any]) -> bool: + return ( + "tool" in data + and "comfy" in data["tool"] + and "isolation" in data["tool"]["comfy"] + ) + + +def _discover_nested_manifests(entry: Path) -> List[Tuple[Path, Path]]: + packages_root = entry / _NESTED_SCAN_ROOT + if not packages_root.exists() or not packages_root.is_dir(): + return [] + + nested: List[Tuple[Path, Path]] = [] + for manifest in sorted(packages_root.rglob("pyproject.toml")): + node_dir = manifest.parent + if any(part in _IGNORED_MANIFEST_DIRS for part in node_dir.parts): + continue + + data = _read_manifest(manifest) + if not data or not _is_isolation_manifest(data): + continue + + isolation = 
def find_manifest_directories() -> List[Tuple[Path, Path]]:
    """Scan every ``custom_nodes`` root for isolation-enabled packages.

    A package qualifies when its top-level ``pyproject.toml`` parses and
    declares a ``[tool.comfy.isolation]`` table. For each qualifying package,
    nested standalone packages are discovered as well.

    Returns:
        List of ``(node_dir, manifest_path)`` pairs.
    """
    discovered: List[Tuple[Path, Path]] = []

    for root in folder_paths.get_folder_paths("custom_nodes"):
        root_path = Path(root)
        if not (root_path.exists() and root_path.is_dir()):
            continue

        for candidate in root_path.iterdir():
            if not candidate.is_dir():
                continue

            manifest = candidate / "pyproject.toml"
            if not manifest.exists():
                continue

            parsed = _read_manifest(manifest)
            if not parsed or not _is_isolation_manifest(parsed):
                continue

            discovered.append((candidate, manifest))
            discovered.extend(_discover_nested_manifests(candidate))

    return discovered


def compute_cache_key(node_dir: Path, manifest_path: Path) -> str:
    """Fingerprint a node package for cache invalidation.

    Folds together: raw manifest bytes (so config edits invalidate), every
    ``*.py`` relative path plus its mtime, the interpreter version, and the
    installed pyisolate version. Read/stat failures contribute fixed sentinel
    bytes instead of raising.

    Returns:
        First ``CACHE_KEY_LENGTH`` hex chars of the SHA-256 digest.
    """
    digest = hashlib.sha256()

    try:
        # Hashing the manifest content ensures config changes invalidate cache.
        digest.update(manifest_path.read_bytes())
    except OSError:
        digest.update(b"__manifest_read_error__")

    try:
        for source in sorted(node_dir.rglob("*.py")):
            rel_text = str(source.relative_to(node_dir))
            # NOTE: substring filter — any path mentioning __pycache__/.venv is skipped.
            if "__pycache__" in rel_text or ".venv" in rel_text:
                continue
            digest.update(rel_text.encode("utf-8"))
            try:
                digest.update(str(source.stat().st_mtime).encode("utf-8"))
            except OSError:
                digest.update(b"__file_stat_error__")
    except OSError:
        digest.update(b"__dir_scan_error__")

    digest.update(sys.version.encode("utf-8"))

    try:
        import pyisolate

        digest.update(pyisolate.__version__.encode("utf-8"))
    except (ImportError, AttributeError):
        digest.update(b"__pyisolate_unknown__")

    return digest.hexdigest()[:CACHE_KEY_LENGTH]
def get_cache_path(node_dir: Path, venv_root: Path) -> Tuple[Path, Path]:
    """Return (cache_key_file, cache_data_file) under ``venv_root/<node>/cache/``."""
    cache_base = venv_root / node_dir.name / CACHE_SUBDIR
    return cache_base / CACHE_KEY_FILE, cache_base / CACHE_DATA_FILE


def is_cache_valid(node_dir: Path, manifest_path: Path, venv_root: Path) -> bool:
    """Return True only when both cache files exist and the stored key is current.

    Any error (unreadable files, key computation failure) is logged at debug
    level and treated as a cache miss.
    """
    try:
        key_file, data_file = get_cache_path(node_dir, venv_root)
        if not (key_file.exists() and data_file.exists()):
            return False
        fresh_key = compute_cache_key(node_dir, manifest_path)
        stored_key = key_file.read_text(encoding="utf-8").strip()
        return fresh_key == stored_key
    except Exception as e:
        logger.debug(
            "%s Cache validation error for %s: %s", LOG_PREFIX, node_dir.name, e
        )
        return False


def load_from_cache(node_dir: Path, venv_root: Path) -> Optional[Dict[str, Any]]:
    """Load cached node metadata; any problem (missing, corrupt, non-dict) yields None."""
    try:
        data_file = get_cache_path(node_dir, venv_root)[1]
        if not data_file.exists():
            return None
        parsed = json.loads(data_file.read_text(encoding="utf-8"))
        return parsed if isinstance(parsed, dict) else None
    except Exception:
        return None
class RpcBridge:
    """Run coroutine results synchronously inside isolated processes.

    ``run_sync`` returns plain values untouched. Coroutines are driven to
    completion: directly via ``asyncio.run`` when no loop is running, or on a
    throwaway thread with its own event loop when called from within a running
    loop (avoiding nested ``run_until_complete`` errors).
    """

    def run_sync(self, maybe_coro):
        """Return *maybe_coro*'s result, awaiting it if it is a coroutine."""
        # Non-coroutines pass straight through.
        if not asyncio.iscoroutine(maybe_coro):
            return maybe_coro

        try:
            current_loop = asyncio.get_running_loop()
        except RuntimeError:
            current_loop = None

        if current_loop is None or not current_loop.is_running():
            return asyncio.run(maybe_coro)

        # A loop is already running on this thread: drive the coroutine on a
        # fresh thread with its own private loop, then relay result/exception.
        outcome: dict = {}

        def _drive() -> None:
            try:
                side_loop = asyncio.new_event_loop()
                asyncio.set_event_loop(side_loop)
                outcome["value"] = side_loop.run_until_complete(maybe_coro)
            except Exception as exc:  # pragma: no cover
                outcome["error"] = exc
            finally:
                try:
                    side_loop.close()
                except Exception:
                    pass

        worker = threading.Thread(target=_drive, daemon=True)
        worker.start()
        worker.join()

        if "error" in outcome:
            raise outcome["error"]
        return outcome.get("value")
def _wrap_remote_handles_as_host_proxies(value: Any, extension: Any) -> Any:
    """Recursively replace pyisolate RemoteObjectHandles with host-side proxies.

    Handles of known types (ModelPatcher, VAE, CLIP, ModelSampling) become the
    matching proxy wired to *extension*'s remote-object RPC; unknown handle
    types and all non-container values pass through unchanged. Dicts, lists and
    tuples are rebuilt with the same container type.
    """
    from pyisolate._internal.remote_handle import RemoteObjectHandle

    if isinstance(value, RemoteObjectHandle):
        proxy = None
        if value.type_name == "ModelPatcher":
            from comfy.isolation.model_patcher_proxy import ModelPatcherProxy

            proxy = ModelPatcherProxy(value.object_id, manage_lifecycle=False)
        elif value.type_name == "VAE":
            from comfy.isolation.vae_proxy import VAEProxy

            proxy = VAEProxy(value.object_id, manage_lifecycle=False)
        elif value.type_name == "CLIP":
            from comfy.isolation.clip_proxy import CLIPProxy

            proxy = CLIPProxy(value.object_id, manage_lifecycle=False)
        elif value.type_name == "ModelSampling":
            from comfy.isolation.model_sampling_proxy import ModelSamplingProxy

            proxy = ModelSamplingProxy(value.object_id, manage_lifecycle=False)

        if proxy is None:
            # Unknown handle type: hand it back unwrapped.
            return value

        proxy._rpc_caller = _RemoteObjectRegistryCaller(extension)  # type: ignore[attr-defined]
        proxy._pyisolate_remote_handle = value  # type: ignore[attr-defined]
        return proxy

    if isinstance(value, dict):
        return {
            key: _wrap_remote_handles_as_host_proxies(item, extension)
            for key, item in value.items()
        }

    if isinstance(value, (list, tuple)):
        rebuilt = [_wrap_remote_handles_as_host_proxies(item, extension) for item in value]
        # NOTE(review): type(value)(rebuilt) assumes a list/tuple subclass with a
        # single-iterable constructor; namedtuples would need *rebuilt — same as original.
        return type(value)(rebuilt)

    return value
_resource_snapshot() -> Dict[str, int]: + fd_count = -1 + shm_sender_files = 0 + try: + fd_count = len(os.listdir("/proc/self/fd")) + except Exception: + pass + try: + shm_root = Path("/dev/shm") + if shm_root.exists(): + prefix = f"torch_{os.getpid()}_" + shm_sender_files = sum(1 for _ in shm_root.glob(f"{prefix}*")) + except Exception: + pass + return {"fd_count": fd_count, "shm_sender_files": shm_sender_files} + + +def _tensor_transport_summary(value: Any) -> Dict[str, int]: + summary: Dict[str, int] = { + "tensor_count": 0, + "cpu_tensors": 0, + "cuda_tensors": 0, + "shared_cpu_tensors": 0, + "tensor_bytes": 0, + } + try: + import torch + except Exception: + return summary + + def visit(node: Any) -> None: + if isinstance(node, torch.Tensor): + summary["tensor_count"] += 1 + summary["tensor_bytes"] += int(node.numel() * node.element_size()) + if node.device.type == "cpu": + summary["cpu_tensors"] += 1 + if node.is_shared(): + summary["shared_cpu_tensors"] += 1 + elif node.device.type == "cuda": + summary["cuda_tensors"] += 1 + return + if isinstance(node, dict): + for v in node.values(): + visit(v) + return + if isinstance(node, (list, tuple)): + for v in node: + visit(v) + + visit(value) + return summary + + +def _extract_hidden_unique_id(inputs: Dict[str, Any]) -> str | None: + for key, value in inputs.items(): + key_text = str(key) + if "unique_id" in key_text: + return str(value) + return None + + +def _flush_tensor_transport_state(marker: str, logger: logging.Logger) -> None: + try: + from pyisolate import flush_tensor_keeper # type: ignore[attr-defined] + except Exception: + return + if not callable(flush_tensor_keeper): + return + flushed = flush_tensor_keeper() + if flushed > 0: + logger.debug( + "%s %s flush_tensor_keeper released=%d", LOG_PREFIX, marker, flushed + ) + + +def _relieve_host_vram_pressure(marker: str, logger: logging.Logger) -> None: + import comfy.model_management as model_management + + model_management.cleanup_models_gc() + 
model_management.cleanup_models() + + device = model_management.get_torch_device() + if not hasattr(device, "type") or device.type == "cpu": + return + + required = max( + model_management.minimum_inference_memory(), + _PRE_EXEC_MIN_FREE_VRAM_BYTES, + ) + if model_management.get_free_memory(device) < required: + model_management.free_memory(required, device, for_dynamic=True) + if model_management.get_free_memory(device) < required: + model_management.free_memory(required, device, for_dynamic=False) + model_management.cleanup_models() + model_management.soft_empty_cache() + logger.debug("%s %s free_memory target=%d", LOG_PREFIX, marker, required) + + +def _detach_shared_cpu_tensors(value: Any) -> Any: + try: + import torch + except Exception: + return value + + if isinstance(value, torch.Tensor): + if value.device.type == "cpu" and value.is_shared(): + clone = value.clone() + if value.requires_grad: + clone.requires_grad_(True) + return clone + return value + if isinstance(value, list): + return [_detach_shared_cpu_tensors(v) for v in value] + if isinstance(value, tuple): + return tuple(_detach_shared_cpu_tensors(v) for v in value) + if isinstance(value, dict): + return {k: _detach_shared_cpu_tensors(v) for k, v in value.items()} + return value + + +def build_stub_class( + node_name: str, + info: Dict[str, object], + extension: "ComfyNodeExtension", + running_extensions: Dict[str, "ComfyNodeExtension"], + logger: logging.Logger, +) -> type: + if latest_io is None: + raise RuntimeError("comfy_api.latest._io is required to build isolation stubs") + is_v3 = bool(info.get("is_v3", False)) + function_name = "_pyisolate_execute" + restored_input_types = restore_input_types(info.get("input_types", {})) + + async def _execute(self, **inputs): + from comfy.isolation import _RUNNING_EXTENSIONS + + # Update BOTH the local dict AND the module-level dict + running_extensions[extension.name] = extension + _RUNNING_EXTENSIONS[extension.name] = extension + prev_child = None + 
def build_stub_class(
    node_name: str,
    info: Dict[str, object],
    extension: "ComfyNodeExtension",
    running_extensions: Dict[str, "ComfyNodeExtension"],
    logger: logging.Logger,
) -> type:
    """Build a host-side stub node class that forwards execution over RPC.

    The returned class mimics a regular ComfyUI node (v1 or v3 surface,
    depending on ``info["is_v3"]``): its execute function serializes inputs,
    dispatches to the isolated *extension*, and deserializes/rewraps the
    result back into host objects.

    Args:
        node_name: class_type name the stub is registered under.
        info: metadata dict produced by the isolated child (input types,
            return types, category, v3 flags, schema_v1, ...).
        extension: running ComfyNodeExtension to dispatch execution to.
        running_extensions: registry updated on every execution so lookups
            by extension name resolve during a run.
        logger: logger used for ISO diagnostics.

    Raises:
        RuntimeError: if ``comfy_api.latest._io`` was not importable.
    """
    if latest_io is None:
        raise RuntimeError("comfy_api.latest._io is required to build isolation stubs")
    is_v3 = bool(info.get("is_v3", False))
    function_name = "_pyisolate_execute"
    restored_input_types = restore_input_types(info.get("input_types", {}))

    async def _execute(self, **inputs):
        """Forward one node execution into the isolated process and unwrap the result."""
        from comfy.isolation import _RUNNING_EXTENSIONS

        # Update BOTH the local dict AND the module-level dict
        running_extensions[extension.name] = extension
        _RUNNING_EXTENSIONS[extension.name] = extension
        prev_child = None
        node_unique_id = _extract_hidden_unique_id(inputs)
        summary = _tensor_transport_summary(inputs)
        resources = _resource_snapshot()
        logger.debug(
            "%s ISO:execute_start ext=%s node=%s uid=%s",
            LOG_PREFIX,
            extension.name,
            node_name,
            node_unique_id or "-",
        )
        logger.debug(
            "%s ISO:execute_start ext=%s node=%s uid=%s tensors=%d cpu=%d cuda=%d shared_cpu=%d bytes=%d fds=%d sender_shm=%d",
            LOG_PREFIX,
            extension.name,
            node_name,
            node_unique_id or "-",
            summary["tensor_count"],
            summary["cpu_tensors"],
            summary["cuda_tensors"],
            summary["shared_cpu_tensors"],
            summary["tensor_bytes"],
            resources["fd_count"],
            resources["shm_sender_files"],
        )
        scan_shm_forensics("RUNTIME:execute_start", refresh_model_context=True)
        try:
            # Host-side only: free VRAM before shipping tensors to the child.
            if os.environ.get("PYISOLATE_CHILD") != "1":
                _relieve_host_vram_pressure("RUNTIME:pre_execute", logger)
            scan_shm_forensics("RUNTIME:pre_execute", refresh_model_context=True)
            from pyisolate._internal.model_serialization import (
                serialize_for_isolation,
                deserialize_from_isolation,
            )

            # Temporarily drop PYISOLATE_CHILD so serialization takes the
            # host-side path; restored in the finally block below.
            prev_child = os.environ.pop("PYISOLATE_CHILD", None)
            logger.debug(
                "%s ISO:serialize_start ext=%s node=%s uid=%s",
                LOG_PREFIX,
                extension.name,
                node_name,
                node_unique_id or "-",
            )
            # Unwrap NodeOutput-like dicts before serialization.
            # OUTPUT_NODE nodes return {"ui": {...}, "result": (outputs...)}
            # and the executor may pass this dict as input to downstream nodes.
            unwrapped_inputs = {}
            for k, v in inputs.items():
                if isinstance(v, dict) and "result" in v and ("ui" in v or "__node_output__" in v):
                    result = v.get("result")
                    if isinstance(result, (tuple, list)) and len(result) > 0:
                        unwrapped_inputs[k] = result[0]
                    else:
                        unwrapped_inputs[k] = result
                else:
                    unwrapped_inputs[k] = v
            serialized = serialize_for_isolation(unwrapped_inputs)
            logger.debug(
                "%s ISO:serialize_done ext=%s node=%s uid=%s",
                LOG_PREFIX,
                extension.name,
                node_name,
                node_unique_id or "-",
            )
            logger.debug(
                "%s ISO:dispatch_start ext=%s node=%s uid=%s",
                LOG_PREFIX,
                extension.name,
                node_name,
                node_unique_id or "-",
            )
            result = await extension.execute_node(node_name, **serialized)
            logger.debug(
                "%s ISO:dispatch_done ext=%s node=%s uid=%s",
                LOG_PREFIX,
                extension.name,
                node_name,
                node_unique_id or "-",
            )
            # Reconstruct NodeOutput if the child serialized one
            if isinstance(result, dict) and result.get("__node_output__"):
                from comfy_api.latest import io as latest_io
                args_raw = result.get("args", ())
                deserialized_args = await deserialize_from_isolation(args_raw, extension)
                deserialized_args = _wrap_remote_handles_as_host_proxies(
                    deserialized_args, extension
                )
                deserialized_args = _detach_shared_cpu_tensors(deserialized_args)
                ui_raw = result.get("ui")
                deserialized_ui = None
                if ui_raw is not None:
                    deserialized_ui = await deserialize_from_isolation(ui_raw, extension)
                    deserialized_ui = _wrap_remote_handles_as_host_proxies(
                        deserialized_ui, extension
                    )
                    deserialized_ui = _detach_shared_cpu_tensors(deserialized_ui)
                scan_shm_forensics("RUNTIME:post_execute", refresh_model_context=True)
                return latest_io.NodeOutput(
                    *deserialized_args,
                    ui=deserialized_ui,
                    expand=result.get("expand"),
                    block_execution=result.get("block_execution"),
                )
            # OUTPUT_NODE: if sealed worker returned a tuple/list whose first
            # element is a {"ui": ...} dict, unwrap it for the executor.
            if (isinstance(result, (tuple, list)) and len(result) == 1
                    and isinstance(result[0], dict) and "ui" in result[0]):
                return result[0]
            deserialized = await deserialize_from_isolation(result, extension)
            deserialized = _wrap_remote_handles_as_host_proxies(deserialized, extension)
            scan_shm_forensics("RUNTIME:post_execute", refresh_model_context=True)
            return _detach_shared_cpu_tensors(deserialized)
        except ImportError:
            # pyisolate serialization helpers unavailable: dispatch raw inputs.
            return await extension.execute_node(node_name, **inputs)
        except Exception:
            logger.exception(
                "%s ISO:execute_error ext=%s node=%s uid=%s",
                LOG_PREFIX,
                extension.name,
                node_name,
                node_unique_id or "-",
            )
            raise
        finally:
            if prev_child is not None:
                os.environ["PYISOLATE_CHILD"] = prev_child
            logger.debug(
                "%s ISO:execute_end ext=%s node=%s uid=%s",
                LOG_PREFIX,
                extension.name,
                node_name,
                node_unique_id or "-",
            )
            scan_shm_forensics("RUNTIME:execute_end", refresh_model_context=True)

    def _input_types(
        cls,
        include_hidden: bool = True,
        return_schema: bool = False,
        live_inputs: Any = None,
    ):
        """Serve INPUT_TYPES from cached child metadata (v1 passthrough, v3 deep copy)."""
        if not is_v3:
            return restored_input_types

        inputs_copy = copy.deepcopy(restored_input_types)
        if not include_hidden:
            inputs_copy.pop("hidden", None)

        v3_data: Dict[str, Any] = {"hidden_inputs": {}}
        dynamic = inputs_copy.pop("dynamic_paths", None)
        if dynamic is not None:
            v3_data["dynamic_paths"] = dynamic

        if return_schema:
            hidden_vals = info.get("hidden", []) or []
            hidden_enums = []
            for h in hidden_vals:
                # Prefer real Hidden enum members; fall back to the raw value.
                try:
                    hidden_enums.append(latest_io.Hidden(h))
                except Exception:
                    hidden_enums.append(h)

            class SchemaProxy:
                hidden = hidden_enums

            return inputs_copy, SchemaProxy, v3_data
        return inputs_copy

    def _validate_class(cls):
        """Validation already ran in the child; the stub always passes."""
        return True

    def _get_node_info_v1(cls):
        """Return the child-provided v1 node info with python_module backfilled."""
        node_info = copy.deepcopy(info.get("schema_v1", {}))
        relative_python_module = node_info.get("python_module")
        if not isinstance(relative_python_module, str) or not relative_python_module:
            relative_python_module = f"custom_nodes.{extension.name}"
        node_info["python_module"] = relative_python_module
        return node_info

    def _get_base_class(cls):
        return latest_io.ComfyNode

    # Class attribute table shared by v1 and v3 stubs.
    attributes: Dict[str, object] = {
        "FUNCTION": function_name,
        "CATEGORY": info.get("category", ""),
        "OUTPUT_NODE": info.get("output_node", False),
        "RETURN_TYPES": tuple(info.get("return_types", ()) or ()),
        "RETURN_NAMES": info.get("return_names"),
        function_name: _execute,
        "_pyisolate_extension": extension,
        "_pyisolate_node_name": node_name,
        "INPUT_TYPES": classmethod(_input_types),
    }

    output_is_list = info.get("output_is_list")
    if output_is_list is not None:
        attributes["OUTPUT_IS_LIST"] = tuple(output_is_list)

    if is_v3:
        attributes["VALIDATE_CLASS"] = classmethod(_validate_class)
        attributes["GET_NODE_INFO_V1"] = classmethod(_get_node_info_v1)
        attributes["GET_BASE_CLASS"] = classmethod(_get_base_class)
        attributes["DESCRIPTION"] = info.get("description", "")
        attributes["EXPERIMENTAL"] = info.get("experimental", False)
        attributes["DEPRECATED"] = info.get("deprecated", False)
        attributes["API_NODE"] = info.get("api_node", False)
        attributes["NOT_IDEMPOTENT"] = info.get("not_idempotent", False)
        attributes["ACCEPT_ALL_INPUTS"] = info.get("accept_all_inputs", False)
        attributes["_ACCEPT_ALL_INPUTS"] = info.get("accept_all_inputs", False)
        attributes["INPUT_IS_LIST"] = info.get("input_is_list", False)

    class_name = f"PyIsolate_{node_name}".replace(" ", "_")
    bases = (_ComfyNodeInternal,) if is_v3 else ()
    stub_cls = type(class_name, bases, attributes)

    if is_v3:
        # Smoke-check the generated class; failure is logged, not fatal.
        try:
            stub_cls.VALIDATE_CLASS()
        except Exception as e:
            logger.error("%s VALIDATE_CLASS failed: %s - %s", LOG_PREFIX, node_name, e)

    return stub_cls


def get_class_types_for_extension(
    extension_name: str,
    running_extensions: Dict[str, "ComfyNodeExtension"],
    specs: List[Any],
) -> Set[str]:
    """Return node_name values of *specs* whose module path matches the extension.

    Empty set when the extension is not currently running.
    """
    extension = running_extensions.get(extension_name)
    if not extension:
        return set()

    ext_path = Path(extension.module_path)
    class_types = set()
    for spec in specs:
        if spec.module_path.resolve() == ext_path.resolve():
            class_types.add(spec.node_name)
    return class_types
extension: + return set() + + ext_path = Path(extension.module_path) + class_types = set() + for spec in specs: + if spec.module_path.resolve() == ext_path.resolve(): + class_types.add(spec.node_name) + return class_types + + +__all__ = ["build_stub_class", "get_class_types_for_extension"] diff --git a/comfy/isolation/shm_forensics.py b/comfy/isolation/shm_forensics.py new file mode 100644 index 000000000..36223505a --- /dev/null +++ b/comfy/isolation/shm_forensics.py @@ -0,0 +1,217 @@ +# pylint: disable=consider-using-from-import,import-outside-toplevel +from __future__ import annotations + +import atexit +import hashlib +import logging +import os +from pathlib import Path +from typing import Any, Dict, List, Set + +LOG_PREFIX = "][" +logger = logging.getLogger(__name__) + + +def _shm_debug_enabled() -> bool: + return os.environ.get("COMFY_ISO_SHM_DEBUG") == "1" + + +class _SHMForensicsTracker: + def __init__(self) -> None: + self._started = False + self._tracked_files: Set[str] = set() + self._current_model_context: Dict[str, str] = { + "id": "unknown", + "name": "unknown", + "hash": "????", + } + + @staticmethod + def _snapshot_shm() -> Set[str]: + shm_path = Path("/dev/shm") + if not shm_path.exists(): + return set() + return {f.name for f in shm_path.glob("torch_*")} + + def start(self) -> None: + if self._started or not _shm_debug_enabled(): + return + self._tracked_files = self._snapshot_shm() + self._started = True + logger.debug( + "%s SHM:forensics_enabled tracked=%d", LOG_PREFIX, len(self._tracked_files) + ) + + def stop(self) -> None: + if not self._started: + return + self.scan("shutdown", refresh_model_context=True) + self._started = False + logger.debug("%s SHM:forensics_disabled", LOG_PREFIX) + + def _compute_model_hash(self, model_patcher: Any) -> str: + try: + model_instance_id = getattr(model_patcher, "_instance_id", None) + if model_instance_id is not None: + model_id_text = str(model_instance_id) + return model_id_text[-4:] if 
    def _get_models_snapshot(self) -> List[Dict[str, Any]]:
        """Summarize models currently loaded on cuda:0 for SHM attribution.

        Returns a list of ``{name, id, hash, used}`` dicts. Any failure —
        comfy.model_management not importable, or the loaded-model list
        changing mid-scan — yields an empty list rather than raising.
        """
        try:
            import comfy.model_management as model_management
        except Exception:
            return []

        snapshot: List[Dict[str, Any]] = []
        try:
            for loaded_model in model_management.current_loaded_models:
                model = loaded_model.model
                if model is None:
                    continue
                # NOTE(review): hard-coded to cuda:0 — models on other devices
                # are invisible to forensics; confirm this is intentional.
                if str(getattr(loaded_model, "device", "")) != "cuda:0":
                    continue

                # Prefer the inner model's class name when the patcher wraps one.
                name = (
                    model.model.__class__.__name__
                    if hasattr(model, "model")
                    else type(model).__name__
                )
                model_hash = self._compute_model_hash(model)
                model_instance_id = getattr(model, "_instance_id", None)
                if model_instance_id is None:
                    # No explicit instance id: fall back to the content hash.
                    model_instance_id = model_hash
                snapshot.append(
                    {
                        "name": str(name),
                        "id": str(model_instance_id),
                        "hash": str(model_hash or "????"),
                        "used": bool(getattr(loaded_model, "currently_used", False)),
                    }
                )
        except Exception:
            return []

        return snapshot
    def scan(self, marker: str, refresh_model_context: bool = True) -> None:
        """Diff /dev/shm against the last snapshot and log created/deleted segments.

        No-op unless the tracker is started AND COMFY_ISO_SHM_DEBUG=1. Each
        created file is attributed to the most recently active model context
        (refreshed first when *refresh_model_context* is True); a missing
        context is logged as an error so leaks can be traced.
        """
        if not self._started or not _shm_debug_enabled():
            return

        if refresh_model_context:
            self._update_model_context()

        current = self._snapshot_shm()
        added = current - self._tracked_files
        removed = self._tracked_files - current
        self._tracked_files = current

        if not added and not removed:
            logger.debug("%s SHM:scan marker=%s changes=0", LOG_PREFIX, marker)
            return

        for filename in sorted(added):
            logger.info("%s SHM:created | %s", LOG_PREFIX, filename)
            model_id = self._current_model_context["id"]
            if model_id == "unknown":
                logger.error(
                    "%s SHM:model_association_missing | file=%s | reason=no_active_model_context",
                    LOG_PREFIX,
                    filename,
                )
            else:
                logger.info(
                    "%s SHM:model_association | model=%s | file=%s | name=%s | hash=%s",
                    LOG_PREFIX,
                    model_id,
                    filename,
                    self._current_model_context["name"],
                    self._current_model_context["hash"],
                )

        for filename in sorted(removed):
            logger.info("%s SHM:deleted | %s", LOG_PREFIX, filename)

        logger.debug(
            "%s SHM:scan marker=%s created=%d deleted=%d active=%d",
            LOG_PREFIX,
            marker,
            len(added),
            len(removed),
            len(self._tracked_files),
        )


# Module-level singleton; the public functions below are thin wrappers over it.
_TRACKER = _SHMForensicsTracker()


def start_shm_forensics() -> None:
    """Begin tracking /dev/shm torch_* segments (no-op unless debug env is set)."""
    _TRACKER.start()


def scan_shm_forensics(marker: str, refresh_model_context: bool = True) -> None:
    """Run one forensics diff labelled with *marker*."""
    _TRACKER.scan(marker, refresh_model_context=refresh_model_context)


def stop_shm_forensics() -> None:
    """Final scan and disable tracking; also registered atexit below."""
    _TRACKER.stop()


atexit.register(stop_shm_forensics)
--- a/requirements.txt +++ b/requirements.txt @@ -35,3 +35,5 @@ pydantic~=2.0 pydantic-settings~=2.0 PyOpenGL glfw + +pyisolate==0.10.0