diff --git a/comfy/isolation/__init__.py b/comfy/isolation/__init__.py index 34ccb34dc..640092f45 100644 --- a/comfy/isolation/__init__.py +++ b/comfy/isolation/__init__.py @@ -8,11 +8,21 @@ import time from dataclasses import dataclass from pathlib import Path from typing import Dict, List, Optional, Set, TYPE_CHECKING -import folder_paths -from .extension_loader import load_isolated_node -from .manifest_loader import find_manifest_directories -from .runtime_helpers import build_stub_class, get_class_types_for_extension -from .shm_forensics import scan_shm_forensics, start_shm_forensics +_IMPORT_TORCH = os.environ.get("PYISOLATE_IMPORT_TORCH", "1") == "1" + +load_isolated_node = None +find_manifest_directories = None +build_stub_class = None +get_class_types_for_extension = None +scan_shm_forensics = None +start_shm_forensics = None + +if _IMPORT_TORCH: + import folder_paths + from .extension_loader import load_isolated_node + from .manifest_loader import find_manifest_directories + from .runtime_helpers import build_stub_class, get_class_types_for_extension + from .shm_forensics import scan_shm_forensics, start_shm_forensics if TYPE_CHECKING: from pyisolate import ExtensionManager @@ -21,8 +31,9 @@ if TYPE_CHECKING: LOG_PREFIX = "][" isolated_node_timings: List[tuple[float, Path, int]] = [] -PYISOLATE_VENV_ROOT = Path(folder_paths.base_path) / ".pyisolate_venvs" -PYISOLATE_VENV_ROOT.mkdir(parents=True, exist_ok=True) +if _IMPORT_TORCH: + PYISOLATE_VENV_ROOT = Path(folder_paths.base_path) / ".pyisolate_venvs" + PYISOLATE_VENV_ROOT.mkdir(parents=True, exist_ok=True) logger = logging.getLogger(__name__) _WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024 @@ -42,7 +53,8 @@ def initialize_proxies() -> None: from .host_hooks import initialize_host_process initialize_host_process() - start_shm_forensics() + if start_shm_forensics is not None: + start_shm_forensics() @dataclass(frozen=True) @@ -88,6 +100,8 @@ async def initialize_isolation_nodes() -> List[IsolatedNodeSpec]: return [] _ISOLATION_SCAN_ATTEMPTED = True + if find_manifest_directories is None or load_isolated_node is None or build_stub_class is None: + return [] manifest_entries = find_manifest_directories() _CLAIMED_PATHS = {entry[0].resolve() for entry in manifest_entries} @@ -167,17 +181,27 @@ def _get_class_types_for_extension(extension_name: str) -> Set[str]: return class_types -async def notify_execution_graph(needed_class_types: Set[str]) -> None: - """Evict running extensions not needed for current execution.""" +async def notify_execution_graph(needed_class_types: Set[str], caches: list | None = None) -> None: + """Evict running extensions not needed for current execution. + + When *caches* is provided, cache entries for evicted extensions' node + class_types are invalidated to prevent stale ``RemoteObjectHandle`` + references from surviving in the output cache. + """ await wait_for_model_patcher_quiescence( timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS, fail_loud=True, marker="ISO:notify_graph_wait_idle", ) + evicted_class_types: Set[str] = set() + async def _stop_extension( ext_name: str, extension: "ComfyNodeExtension", reason: str ) -> None: + # Collect class_types BEFORE stopping so we can invalidate cache entries. + ext_class_types = _get_class_types_for_extension(ext_name) + evicted_class_types.update(ext_class_types) logger.info("%s ISO:eject_start ext=%s reason=%s", LOG_PREFIX, ext_name, reason) logger.debug("%s ISO:stop_start ext=%s", LOG_PREFIX, ext_name) stop_result = extension.stop() @@ -185,9 +209,11 @@ async def notify_execution_graph(needed_class_types: Set[str]) -> None: await stop_result _RUNNING_EXTENSIONS.pop(ext_name, None) logger.debug("%s ISO:stop_done ext=%s", LOG_PREFIX, ext_name) - scan_shm_forensics("ISO:stop_extension", refresh_model_context=True) + if scan_shm_forensics is not None: + scan_shm_forensics("ISO:stop_extension", refresh_model_context=True) - scan_shm_forensics("ISO:notify_graph_start", refresh_model_context=True) + if scan_shm_forensics is not None: + scan_shm_forensics("ISO:notify_graph_start", refresh_model_context=True) isolated_class_types_in_graph = needed_class_types.intersection( {spec.node_name for spec in _ISOLATED_NODE_SPECS} ) @@ -247,6 +273,22 @@ async def notify_execution_graph(needed_class_types: Set[str]) -> None: "%s workflow-boundary host VRAM relief failed", LOG_PREFIX, exc_info=True ) finally: + # Invalidate cached outputs for evicted extensions so stale + # RemoteObjectHandle references are not served from cache. + if evicted_class_types and caches: + total_invalidated = 0 + for cache in caches: + if hasattr(cache, "invalidate_by_class_types"): + total_invalidated += cache.invalidate_by_class_types( + evicted_class_types + ) + if total_invalidated > 0: + logger.info( + "%s ISO:cache_invalidated count=%d class_types=%s", + LOG_PREFIX, + total_invalidated, + evicted_class_types, + ) scan_shm_forensics("ISO:notify_graph_done", refresh_model_context=True) logger.debug( "%s ISO:notify_graph_done running=%d", LOG_PREFIX, len(_RUNNING_EXTENSIONS) diff --git a/comfy/isolation/adapter.py b/comfy/isolation/adapter.py index 99beaa191..4751dee51 100644 --- a/comfy/isolation/adapter.py +++ b/comfy/isolation/adapter.py @@ -3,13 +3,38 @@ from __future__ import annotations import logging import os +import inspect from pathlib import Path -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, cast from pyisolate.interfaces import IsolationAdapter, SerializerRegistryProtocol # type: ignore[import-untyped] from pyisolate._internal.rpc_protocol import AsyncRPC, ProxiedSingleton # type: ignore[import-untyped] -try: +_IMPORT_TORCH = os.environ.get("PYISOLATE_IMPORT_TORCH", "1") == "1" + +# Singleton proxies that do NOT transitively import torch/PIL/psutil/aiohttp. +# Safe to import in sealed workers without host framework modules. +from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy +from comfy.isolation.proxies.helper_proxies import HelperProxiesService +from comfy.isolation.proxies.web_directory_proxy import WebDirectoryProxy + +# Singleton proxies that transitively import torch, PIL, or heavy host modules. +# Only available when torch/host framework is present. +CLIPProxy = None +CLIPRegistry = None +ModelPatcherProxy = None +ModelPatcherRegistry = None +ModelSamplingProxy = None +ModelSamplingRegistry = None +VAEProxy = None +VAERegistry = None +FirstStageModelRegistry = None +ModelManagementProxy = None +PromptServerService = None +ProgressProxy = None +UtilsProxy = None +_HAS_TORCH_PROXIES = False +if _IMPORT_TORCH: from comfy.isolation.clip_proxy import CLIPProxy, CLIPRegistry from comfy.isolation.model_patcher_proxy import ( ModelPatcherProxy, @@ -20,13 +45,11 @@ try: ModelSamplingRegistry, ) from comfy.isolation.vae_proxy import VAEProxy, VAERegistry, FirstStageModelRegistry - from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy from comfy.isolation.proxies.prompt_server_impl import PromptServerService - from comfy.isolation.proxies.utils_proxy import UtilsProxy from comfy.isolation.proxies.progress_proxy import ProgressProxy -except ImportError as exc: # Fail loud if Comfy environment is incomplete - raise ImportError(f"ComfyUI environment incomplete: {exc}") + from comfy.isolation.proxies.utils_proxy import UtilsProxy + _HAS_TORCH_PROXIES = True logger = logging.getLogger(__name__) @@ -62,9 +85,20 @@ class ComfyUIAdapter(IsolationAdapter): os.path.join(comfy_root, "custom_nodes"), os.path.join(comfy_root, "comfy"), ], + "filtered_subdirs": ["comfy", "app", "comfy_execution", "utils"], } return None + def get_sandbox_system_paths(self) -> Optional[List[str]]: + """Returns required application paths to mount in the sandbox.""" + # By inspecting where our adapter is loaded from, we can determine the comfy root + adapter_file = inspect.getfile(self.__class__) + # adapter_file = /home/johnj/ComfyUI/comfy/isolation/adapter.py + comfy_root = os.path.dirname(os.path.dirname(os.path.dirname(adapter_file))) + if os.path.exists(comfy_root): + return [comfy_root] + return None + def setup_child_environment(self, snapshot: Dict[str, Any]) -> None: comfy_root = snapshot.get("preferred_root") if not comfy_root: @@ -83,6 +117,59 @@ class ComfyUIAdapter(IsolationAdapter): logging.getLogger(pkg_name).setLevel(logging.ERROR) def register_serializers(self, registry: SerializerRegistryProtocol) -> None: + if not _IMPORT_TORCH: + # Sealed worker without torch — register torch-free TensorValue handler + # so IMAGE/MASK/LATENT tensors arrive as numpy arrays, not raw dicts. + import numpy as np + + _TORCH_DTYPE_TO_NUMPY = { + "torch.float32": np.float32, + "torch.float64": np.float64, + "torch.float16": np.float16, + "torch.bfloat16": np.float32, # numpy has no bfloat16; upcast + "torch.int32": np.int32, + "torch.int64": np.int64, + "torch.int16": np.int16, + "torch.int8": np.int8, + "torch.uint8": np.uint8, + "torch.bool": np.bool_, + } + + def _deserialize_tensor_value(data: Dict[str, Any]) -> Any: + dtype_str = data["dtype"] + np_dtype = _TORCH_DTYPE_TO_NUMPY.get(dtype_str, np.float32) + shape = tuple(data["tensor_size"]) + arr = np.array(data["data"], dtype=np_dtype).reshape(shape) + return arr + + _NUMPY_TO_TORCH_DTYPE = { + np.float32: "torch.float32", + np.float64: "torch.float64", + np.float16: "torch.float16", + np.int32: "torch.int32", + np.int64: "torch.int64", + np.int16: "torch.int16", + np.int8: "torch.int8", + np.uint8: "torch.uint8", + np.bool_: "torch.bool", + } + + def _serialize_tensor_value(obj: Any) -> Dict[str, Any]: + arr = np.asarray(obj, dtype=np.float32) if obj.dtype not in _NUMPY_TO_TORCH_DTYPE else np.asarray(obj) + dtype_str = _NUMPY_TO_TORCH_DTYPE.get(arr.dtype.type, "torch.float32") + return { + "__type__": "TensorValue", + "dtype": dtype_str, + "tensor_size": list(arr.shape), + "requires_grad": False, + "data": arr.tolist(), + } + + registry.register("TensorValue", _serialize_tensor_value, _deserialize_tensor_value, data_type=True) + # ndarray output from sealed workers serializes as TensorValue for host torch reconstruction + registry.register("ndarray", _serialize_tensor_value, _deserialize_tensor_value, data_type=True) + return + import torch def serialize_device(obj: Any) -> Dict[str, Any]: @@ -110,6 +197,54 @@ class ComfyUIAdapter(IsolationAdapter): registry.register("dtype", serialize_dtype, deserialize_dtype) + from comfy_api.latest._io import FolderType + from comfy_api.latest._ui import SavedImages, SavedResult + + def serialize_saved_result(obj: Any) -> Dict[str, Any]: + return { + "__type__": "SavedResult", + "filename": obj.filename, + "subfolder": obj.subfolder, + "folder_type": obj.type.value, + } + + def deserialize_saved_result(data: Dict[str, Any]) -> Any: + if isinstance(data, SavedResult): + return data + folder_type = data["folder_type"] if "folder_type" in data else data["type"] + return SavedResult( + filename=data["filename"], + subfolder=data["subfolder"], + type=FolderType(folder_type), + ) + + registry.register( + "SavedResult", + serialize_saved_result, + deserialize_saved_result, + data_type=True, + ) + + def serialize_saved_images(obj: Any) -> Dict[str, Any]: + return { + "__type__": "SavedImages", + "results": [serialize_saved_result(result) for result in obj.results], + "is_animated": obj.is_animated, + } + + def deserialize_saved_images(data: Dict[str, Any]) -> Any: + return SavedImages( + results=[deserialize_saved_result(result) for result in data["results"]], + is_animated=data.get("is_animated", False), + ) + + registry.register( + "SavedImages", + serialize_saved_images, + deserialize_saved_images, + data_type=True, + ) + def serialize_model_patcher(obj: Any) -> Dict[str, Any]: # Child-side: must already have _instance_id (proxy) if os.environ.get("PYISOLATE_CHILD") == "1": @@ -212,28 +347,93 @@ class ComfyUIAdapter(IsolationAdapter): # copyreg removed - no pickle fallback allowed def serialize_model_sampling(obj: Any) -> Dict[str, Any]: - # Child-side: must already have _instance_id (proxy) - if os.environ.get("PYISOLATE_CHILD") == "1": - if hasattr(obj, "_instance_id"): - return {"__type__": "ModelSamplingRef", "ms_id": obj._instance_id} - raise RuntimeError( - f"ModelSampling in child lacks _instance_id: " - f"{type(obj).__module__}.{type(obj).__name__}" - ) - # Host-side pass-through for proxies: do not re-register a proxy as a - # new ModelSamplingRef, or we create proxy-of-proxy indirection. + # Proxy with _instance_id — return ref (works from both host and child) if hasattr(obj, "_instance_id"): return {"__type__": "ModelSamplingRef", "ms_id": obj._instance_id} + # Child-side: object created locally in child (e.g. ModelSamplingAdvanced + # in nodes_z_image_turbo.py). Serialize as inline data so the host can + # reconstruct the real torch.nn.Module. + if os.environ.get("PYISOLATE_CHILD") == "1": + import base64 + import io as _io + + # Identify base classes from comfy.model_sampling + bases = [] + for base in type(obj).__mro__: + if base.__module__ == "comfy.model_sampling" and base.__name__ != "object": + bases.append(base.__name__) + # Serialize state_dict as base64 safetensors-like + sd = obj.state_dict() + sd_serialized = {} + for k, v in sd.items(): + buf = _io.BytesIO() + torch.save(v, buf) + sd_serialized[k] = base64.b64encode(buf.getvalue()).decode("ascii") + # Capture plain attrs (shift, multiplier, sigma_data, etc.) + plain_attrs = {} + for k, v in obj.__dict__.items(): + if k.startswith("_"): + continue + if isinstance(v, (bool, int, float, str)): + plain_attrs[k] = v + return { + "__type__": "ModelSamplingInline", + "bases": bases, + "state_dict": sd_serialized, + "attrs": plain_attrs, + } # Host-side: register with ModelSamplingRegistry and return JSON-safe dict ms_id = ModelSamplingRegistry().register(obj) return {"__type__": "ModelSamplingRef", "ms_id": ms_id} def deserialize_model_sampling(data: Any) -> Any: - """Deserialize ModelSampling refs; pass through already-materialized objects.""" + """Deserialize ModelSampling refs or inline data.""" if isinstance(data, dict): + if data.get("__type__") == "ModelSamplingInline": + return _reconstruct_model_sampling_inline(data) return ModelSamplingProxy(data["ms_id"]) return data + def _reconstruct_model_sampling_inline(data: Dict[str, Any]) -> Any: + """Reconstruct a ModelSampling object on the host from inline child data.""" + import comfy.model_sampling as _ms + import base64 + import io as _io + + # Resolve base classes + base_classes = [] + for name in data["bases"]: + cls = getattr(_ms, name, None) + if cls is not None: + base_classes.append(cls) + if not base_classes: + raise RuntimeError( + f"Cannot reconstruct ModelSampling: no known bases in {data['bases']}" + ) + # Create dynamic class matching the child's class hierarchy + ReconstructedSampling = type("ReconstructedSampling", tuple(base_classes), {}) + obj = ReconstructedSampling.__new__(ReconstructedSampling) + torch.nn.Module.__init__(obj) + # Restore plain attributes first + for k, v in data.get("attrs", {}).items(): + setattr(obj, k, v) + # Restore state_dict (buffers like sigmas) + for k, v_b64 in data.get("state_dict", {}).items(): + buf = _io.BytesIO(base64.b64decode(v_b64)) + tensor = torch.load(buf, weights_only=True) + # Register as buffer so it's part of state_dict + parts = k.split(".") + if len(parts) == 1: + cast(Any, obj).register_buffer(parts[0], tensor) # pylint: disable=no-member + else: + setattr(obj, parts[0], tensor) + # Register on host so future references use proxy pattern. + # Skip in child process — register() is async RPC and cannot be + # called synchronously during deserialization. + if os.environ.get("PYISOLATE_CHILD") != "1": + ModelSamplingRegistry().register(obj) + return obj + def deserialize_model_sampling_ref(data: Dict[str, Any]) -> Any: """Context-aware ModelSamplingRef deserializer for both host and child.""" is_child = os.environ.get("PYISOLATE_CHILD") == "1" @@ -262,6 +462,10 @@ class ComfyUIAdapter(IsolationAdapter): ) # Register ModelSamplingRef for deserialization (context-aware: host or child) registry.register("ModelSamplingRef", None, deserialize_model_sampling_ref) + # Register ModelSamplingInline for deserialization (child→host inline transfer) + registry.register( + "ModelSamplingInline", None, lambda data: _reconstruct_model_sampling_inline(data) + ) def serialize_cond(obj: Any) -> Dict[str, Any]: type_key = f"{type(obj).__module__}.{type(obj).__name__}" @@ -393,52 +597,20 @@ class ComfyUIAdapter(IsolationAdapter): # Fallback for non-numeric arrays (strings, objects, mixes) return obj.tolist() - registry.register("ndarray", serialize_numpy, None) - - def serialize_ply(obj: Any) -> Dict[str, Any]: + def deserialize_numpy_b64(data: Any) -> Any: + """Deserialize base64-encoded ndarray from sealed worker.""" import base64 - import torch - if obj.raw_data is not None: - return { - "__type__": "PLY", - "raw_data": base64.b64encode(obj.raw_data).decode("ascii"), - } - result: Dict[str, Any] = {"__type__": "PLY", "points": torch.from_numpy(obj.points)} - if obj.colors is not None: - result["colors"] = torch.from_numpy(obj.colors) - if obj.confidence is not None: - result["confidence"] = torch.from_numpy(obj.confidence) - if obj.view_id is not None: - result["view_id"] = torch.from_numpy(obj.view_id) - return result + import numpy as np + if isinstance(data, dict) and "data" in data and "dtype" in data: + raw = base64.b64decode(data["data"]) + arr = np.frombuffer(raw, dtype=np.dtype(data["dtype"])).reshape(data["shape"]) + return torch.from_numpy(arr.copy()) + return data - def deserialize_ply(data: Any) -> Any: - import base64 - from comfy_api.latest._util.ply_types import PLY - if "raw_data" in data: - return PLY(raw_data=base64.b64decode(data["raw_data"])) - return PLY( - points=data["points"], - colors=data.get("colors"), - confidence=data.get("confidence"), - view_id=data.get("view_id"), - ) + registry.register("ndarray", serialize_numpy, deserialize_numpy_b64) - registry.register("PLY", serialize_ply, deserialize_ply, data_type=True) - - def serialize_npz(obj: Any) -> Dict[str, Any]: - import base64 - return { - "__type__": "NPZ", - "frames": [base64.b64encode(f).decode("ascii") for f in obj.frames], - } - - def deserialize_npz(data: Any) -> Any: - import base64 - from comfy_api.latest._util.npz_types import NPZ - return NPZ(frames=[base64.b64decode(f) for f in data["frames"]]) - - registry.register("NPZ", serialize_npz, deserialize_npz, data_type=True) + # -- File3D (comfy_api.latest._util.geometry_types) --------------------- + # Origin: comfy_api by ComfyOrg (Alexander Piskun), PR #12129 def serialize_file3d(obj: Any) -> Dict[str, Any]: import base64 @@ -456,6 +628,9 @@ class ComfyUIAdapter(IsolationAdapter): registry.register("File3D", serialize_file3d, deserialize_file3d, data_type=True) + # -- VIDEO (comfy_api.latest._input_impl.video_types) ------------------- + # Origin: ComfyAPI Core v0.0.2 by ComfyOrg (guill), PR #8962 + def serialize_video(obj: Any) -> Dict[str, Any]: components = obj.get_components() images = components.images.detach() if components.images.requires_grad else components.images @@ -494,19 +669,156 @@ class ComfyUIAdapter(IsolationAdapter): registry.register("VideoFromFile", serialize_video, deserialize_video, data_type=True) registry.register("VideoFromComponents", serialize_video, deserialize_video, data_type=True) + def setup_web_directory(self, module: Any) -> None: + """Detect WEB_DIRECTORY on a module and populate/register it. + + Called by the sealed worker after loading the node module. + Mirrors extension_wrapper.py:216-227 for host-coupled nodes. + Does NOT import extension_wrapper.py (it has `import torch` at module level). + """ + import shutil + + web_dir_attr = getattr(module, "WEB_DIRECTORY", None) + if web_dir_attr is None: + return + + module_dir = os.path.dirname(os.path.abspath(module.__file__)) + web_dir_path = os.path.abspath(os.path.join(module_dir, web_dir_attr)) + + # Read extension name from pyproject.toml + ext_name = os.path.basename(module_dir) + pyproject = os.path.join(module_dir, "pyproject.toml") + if os.path.exists(pyproject): + try: + import tomllib + except ImportError: + import tomli as tomllib # type: ignore[no-redef] + try: + with open(pyproject, "rb") as f: + data = tomllib.load(f) + name = data.get("project", {}).get("name") + if name: + ext_name = name + except Exception: + pass + + # Populate web dir if empty (mirrors _run_prestartup_web_copy) + if not (os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path))): + os.makedirs(web_dir_path, exist_ok=True) + + # Module-defined copy spec + copy_spec = getattr(module, "_PRESTARTUP_WEB_COPY", None) + if copy_spec is not None and callable(copy_spec): + try: + copy_spec(web_dir_path) + except Exception as e: + logger.warning("][ _PRESTARTUP_WEB_COPY failed: %s", e) + + # Fallback: comfy_3d_viewers + try: + from comfy_3d_viewers import copy_viewer, VIEWER_FILES + for viewer in VIEWER_FILES: + try: + copy_viewer(viewer, web_dir_path) + except Exception: + pass + except ImportError: + pass + + # Fallback: comfy_dynamic_widgets + try: + from comfy_dynamic_widgets import get_js_path + src = os.path.realpath(get_js_path()) + if os.path.exists(src): + dst_dir = os.path.join(web_dir_path, "js") + os.makedirs(dst_dir, exist_ok=True) + shutil.copy2(src, os.path.join(dst_dir, "dynamic_widgets.js")) + except ImportError: + pass + + if os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path)): + WebDirectoryProxy.register_web_dir(ext_name, web_dir_path) + logger.info( + "][ Adapter: registered web dir for %s (%d files)", + ext_name, + sum(1 for _ in Path(web_dir_path).rglob("*") if _.is_file()), + ) + + @staticmethod + def register_host_event_handlers(extension: Any) -> None: + """Register host-side event handlers for an isolated extension. + + Wires ``"progress"`` events from the child to ``comfy.utils.PROGRESS_BAR_HOOK`` + so the ComfyUI frontend receives progress bar updates. + """ + register_event_handler = inspect.getattr_static( + extension, "register_event_handler", None + ) + if not callable(register_event_handler): + return + + def _host_progress_handler(payload: dict) -> None: + import comfy.utils + + hook = comfy.utils.PROGRESS_BAR_HOOK + if hook is not None: + hook( + payload.get("value", 0), + payload.get("total", 0), + payload.get("preview"), + payload.get("node_id"), + ) + + extension.register_event_handler("progress", _host_progress_handler) + + def setup_child_event_hooks(self, extension: Any) -> None: + """Wire PROGRESS_BAR_HOOK in the child to emit_event on the extension. + + Host-coupled only — sealed workers do not have comfy.utils (torch). + """ + is_child = os.environ.get("PYISOLATE_CHILD") == "1" + logger.info("][ ISO:setup_child_event_hooks called, PYISOLATE_CHILD=%s", is_child) + if not is_child: + return + + if not _IMPORT_TORCH: + logger.info("][ ISO:setup_child_event_hooks skipped — sealed worker (no torch)") + return + + import comfy.utils + + def _event_progress_hook(value, total, preview=None, node_id=None): + logger.debug("][ ISO:event_progress value=%s/%s node_id=%s", value, total, node_id) + extension.emit_event("progress", { + "value": value, + "total": total, + "node_id": node_id, + }) + + comfy.utils.PROGRESS_BAR_HOOK = _event_progress_hook + logger.info("][ ISO:PROGRESS_BAR_HOOK wired to event channel") + def provide_rpc_services(self) -> List[type[ProxiedSingleton]]: - return [ - PromptServerService, + # Always available — no torch/PIL dependency + services: List[type[ProxiedSingleton]] = [ FolderPathsProxy, - ModelManagementProxy, - UtilsProxy, - ProgressProxy, - VAERegistry, - CLIPRegistry, - ModelPatcherRegistry, - ModelSamplingRegistry, - FirstStageModelRegistry, + HelperProxiesService, + WebDirectoryProxy, ] + # Torch/PIL-dependent proxies + if _HAS_TORCH_PROXIES: + services.extend([ + PromptServerService, + ModelManagementProxy, + UtilsProxy, + ProgressProxy, + VAERegistry, + CLIPRegistry, + ModelPatcherRegistry, + ModelSamplingRegistry, + FirstStageModelRegistry, + ]) + return services def handle_api_registration(self, api: ProxiedSingleton, rpc: AsyncRPC) -> None: # Resolve the real name whether it's an instance or the Singleton class itself @@ -525,33 +837,45 @@ class ComfyUIAdapter(IsolationAdapter): # Fence: isolated children get writable temp inside sandbox if os.environ.get("PYISOLATE_CHILD") == "1": - _child_temp = os.path.join("/tmp", "comfyui_temp") + import tempfile + _child_temp = os.path.join(tempfile.gettempdir(), "comfyui_temp") os.makedirs(_child_temp, exist_ok=True) folder_paths.temp_directory = _child_temp return if api_name == "ModelManagementProxy": - import comfy.model_management + if _IMPORT_TORCH: + import comfy.model_management - instance = api() if isinstance(api, type) else api - # Replace module-level functions with proxy methods - for name in dir(instance): - if not name.startswith("_"): - setattr(comfy.model_management, name, getattr(instance, name)) + instance = api() if isinstance(api, type) else api + # Replace module-level functions with proxy methods + for name in dir(instance): + if not name.startswith("_"): + setattr(comfy.model_management, name, getattr(instance, name)) return if api_name == "UtilsProxy": + if not _IMPORT_TORCH: + logger.info("][ ISO:UtilsProxy handle_api_registration skipped — sealed worker (no torch)") + return + import comfy.utils # Static Injection of RPC mechanism to ensure Child can access it # independent of instance lifecycle. api.set_rpc(rpc) - # Don't overwrite host hook (infinite recursion) + is_child = os.environ.get("PYISOLATE_CHILD") == "1" + logger.info("][ ISO:UtilsProxy handle_api_registration PYISOLATE_CHILD=%s", is_child) + + # Progress hook wiring moved to setup_child_event_hooks via event channel + return if api_name == "PromptServerProxy": + if not _IMPORT_TORCH: + return # Defer heavy import to child context import server diff --git a/comfy/isolation/child_hooks.py b/comfy/isolation/child_hooks.py index a1ba201ac..a5ec33c21 100644 --- a/comfy/isolation/child_hooks.py +++ b/comfy/isolation/child_hooks.py @@ -11,130 +11,90 @@ def is_child_process() -> bool: def initialize_child_process() -> None: + _setup_child_loop_bridge() + # Manual RPC injection try: from pyisolate._internal.rpc_protocol import get_child_rpc_instance rpc = get_child_rpc_instance() if rpc: - _setup_prompt_server_stub(rpc) - _setup_utils_proxy(rpc) + _setup_proxy_callers(rpc) else: logger.warning("Could not get child RPC instance for manual injection") - _setup_prompt_server_stub() - _setup_utils_proxy() + _setup_proxy_callers() except Exception as e: logger.error(f"Manual RPC Injection failed: {e}") - _setup_prompt_server_stub() - _setup_utils_proxy() + _setup_proxy_callers() _setup_logging() +def _setup_child_loop_bridge() -> None: + import asyncio + + main_loop = None + try: + main_loop = asyncio.get_running_loop() + except RuntimeError: + try: + main_loop = asyncio.get_event_loop() + except RuntimeError: + pass + + if main_loop is None: + return + + try: + from .proxies.base import set_global_loop + + set_global_loop(main_loop) + except ImportError: + pass + + def _setup_prompt_server_stub(rpc=None) -> None: try: from .proxies.prompt_server_impl import PromptServerStub - import sys - import types - - # Mock server module - if "server" not in sys.modules: - mock_server = types.ModuleType("server") - sys.modules["server"] = mock_server - - server = sys.modules["server"] - - if not hasattr(server, "PromptServer"): - - class MockPromptServer: - pass - - server.PromptServer = MockPromptServer - - stub = PromptServerStub() if rpc: PromptServerStub.set_rpc(rpc) - if hasattr(stub, "set_rpc"): - stub.set_rpc(rpc) - - server.PromptServer.instance = stub + elif hasattr(PromptServerStub, "clear_rpc"): + PromptServerStub.clear_rpc() + else: + PromptServerStub._rpc = None # type: ignore[attr-defined] except Exception as e: logger.error(f"Failed to setup PromptServerStub: {e}") -def _setup_utils_proxy(rpc=None) -> None: +def _setup_proxy_callers(rpc=None) -> None: try: - import comfy.utils - import asyncio + from .proxies.folder_paths_proxy import FolderPathsProxy + from .proxies.helper_proxies import HelperProxiesService + from .proxies.model_management_proxy import ModelManagementProxy + from .proxies.progress_proxy import ProgressProxy + from .proxies.prompt_server_impl import PromptServerStub + from .proxies.utils_proxy import UtilsProxy - # Capture main loop during initialization (safe context) - main_loop = None - try: - main_loop = asyncio.get_running_loop() - except RuntimeError: - try: - main_loop = asyncio.get_event_loop() - except RuntimeError: - pass + if rpc is None: + FolderPathsProxy.clear_rpc() + HelperProxiesService.clear_rpc() + ModelManagementProxy.clear_rpc() + ProgressProxy.clear_rpc() + PromptServerStub.clear_rpc() + UtilsProxy.clear_rpc() + return - try: - from .proxies.base import set_global_loop - - if main_loop: - set_global_loop(main_loop) - except ImportError: - pass - - # Sync hook wrapper for progress updates - def sync_hook_wrapper( - value: int, total: int, preview: None = None, node_id: None = None - ) -> None: - if node_id is None: - try: - from comfy_execution.utils import get_executing_context - - ctx = get_executing_context() - if ctx: - node_id = ctx.node_id - else: - pass - except Exception: - pass - - # Bypass blocked event loop by direct outbox injection - if rpc: - try: - # Use captured main loop if available (for threaded execution), or current loop - loop = main_loop - if loop is None: - loop = asyncio.get_event_loop() - - rpc.outbox.put( - { - "kind": "call", - "object_id": "UtilsProxy", - "parent_call_id": None, # We are root here usually - "calling_loop": loop, - "future": loop.create_future(), # Dummy future - "method": "progress_bar_hook", - "args": (value, total, preview, node_id), - "kwargs": {}, - } - ) - - except Exception as e: - logging.getLogger(__name__).error(f"Manual Inject Failed: {e}") - else: - logging.getLogger(__name__).warning( - "No RPC instance available for progress update" - ) - - comfy.utils.PROGRESS_BAR_HOOK = sync_hook_wrapper + FolderPathsProxy.set_rpc(rpc) + HelperProxiesService.set_rpc(rpc) + ModelManagementProxy.set_rpc(rpc) + ProgressProxy.set_rpc(rpc) + PromptServerStub.set_rpc(rpc) + UtilsProxy.set_rpc(rpc) except Exception as e: - logger.error(f"Failed to setup UtilsProxy hook: {e}") + logger.error(f"Failed to setup child singleton proxy callers: {e}") def _setup_logging() -> None: diff --git a/comfy/isolation/custom_node_serializers.py b/comfy/isolation/custom_node_serializers.py new file mode 100644 index 000000000..e7a6e78c2 --- /dev/null +++ b/comfy/isolation/custom_node_serializers.py @@ -0,0 +1,16 @@ +"""Compatibility shim for the indexed serializer path.""" + +from __future__ import annotations + +from typing import Any + + +def register_custom_node_serializers(_registry: Any) -> None: + """Legacy no-op shim. + + Serializer registration now lives directly in the active isolation adapter. + This module remains importable because the isolation index still references it. + """ + return None + +__all__ = ["register_custom_node_serializers"] diff --git a/comfy/isolation/extension_loader.py b/comfy/isolation/extension_loader.py index a93681ffa..1b644a305 100644 --- a/comfy/isolation/extension_loader.py +++ b/comfy/isolation/extension_loader.py @@ -8,14 +8,13 @@ import sys import types import platform from pathlib import Path -from typing import Callable, Dict, List, Tuple +from typing import Any, Callable, Dict, List, Tuple import pyisolate from pyisolate import ExtensionManager, ExtensionManagerConfig from packaging.requirements import InvalidRequirement, Requirement from packaging.utils import canonicalize_name -from .extension_wrapper import ComfyNodeExtension from .manifest_loader import is_cache_valid, load_from_cache, save_to_cache from .host_policy import load_host_policy @@ -42,7 +41,11 @@ def _register_web_directory(extension_name: str, node_dir: Path) -> None: web_dir_path = str(node_dir / web_dir_name) if os.path.isdir(web_dir_path): nodes.EXTENSION_WEB_DIRS[extension_name] = web_dir_path - logger.debug("][ Registered web dir for isolated %s: %s", extension_name, web_dir_path) + logger.debug( + "][ Registered web dir for isolated %s: %s", + extension_name, + web_dir_path, + ) return except Exception: pass @@ -62,15 +65,26 @@ def _register_web_directory(extension_name: str, node_dir: Path) -> None: web_dir_path = str((node_dir / value).resolve()) if os.path.isdir(web_dir_path): nodes.EXTENSION_WEB_DIRS[extension_name] = web_dir_path - logger.debug("][ Registered web dir for isolated %s: %s", extension_name, web_dir_path) + logger.debug( + "][ Registered web dir for isolated %s: %s", + extension_name, + web_dir_path, + ) return except Exception: pass -async def _stop_extension_safe( - extension: ComfyNodeExtension, extension_name: str -) -> None: +def _get_extension_type(execution_model: str) -> type[Any]: + if execution_model == "sealed_worker": + return pyisolate.SealedNodeExtension + + from .extension_wrapper import ComfyNodeExtension + + return ComfyNodeExtension + + +async def _stop_extension_safe(extension: Any, extension_name: str) -> None: try: stop_result = extension.stop() if inspect.isawaitable(stop_result): @@ -126,12 +140,18 @@ def _parse_cuda_wheels_config( if raw_config is None: return None if not isinstance(raw_config, dict): - raise ExtensionLoadError( - "[tool.comfy.isolation.cuda_wheels] must be a table" - ) + raise ExtensionLoadError("[tool.comfy.isolation.cuda_wheels] must be a table") index_url = raw_config.get("index_url") - if not isinstance(index_url, str) or not index_url.strip(): + index_urls = raw_config.get("index_urls") + if index_urls is not None: + if not isinstance(index_urls, list) or not all( + isinstance(u, str) and u.strip() for u in index_urls + ): + raise ExtensionLoadError( + "[tool.comfy.isolation.cuda_wheels.index_urls] must be a list of non-empty strings" + ) + elif not isinstance(index_url, str) or not index_url.strip(): raise ExtensionLoadError( "[tool.comfy.isolation.cuda_wheels.index_url] must be a non-empty string" ) @@ -187,11 +207,15 @@ def _parse_cuda_wheels_config( ) normalized_package_map[canonical_dependency_name] = index_package_name.strip() - return { - "index_url": index_url.rstrip("/") + "/", + result: dict = { "packages": normalized_packages, "package_map": normalized_package_map, } + if index_urls is not None: + result["index_urls"] = [u.rstrip("/") + "/" for u in index_urls] + else: + result["index_url"] = index_url.rstrip("/") + "/" + return result def get_enforcement_policy() -> Dict[str, bool]: @@ -228,7 +252,7 @@ async def load_isolated_node( node_dir: Path, manifest_path: Path, logger: logging.Logger, - build_stub_class: Callable[[str, Dict[str, object], ComfyNodeExtension], type], + build_stub_class: Callable[[str, Dict[str, object], Any], type], venv_root: Path, extension_managers: List[ExtensionManager], ) -> List[Tuple[str, str, type]]: @@ -243,6 +267,31 @@ async def load_isolated_node( tool_config = manifest_data.get("tool", {}).get("comfy", {}).get("isolation", {}) can_isolate = tool_config.get("can_isolate", False) share_torch = tool_config.get("share_torch", False) + package_manager = tool_config.get("package_manager", "uv") + is_conda = package_manager == "conda" + execution_model = tool_config.get("execution_model") + if execution_model is None: + execution_model = "sealed_worker" if is_conda else "host-coupled" + + if "sealed_host_ro_paths" in tool_config: + raise ValueError( + "Manifest field 'sealed_host_ro_paths' is not allowed. " + "Configure [tool.comfy.host].sealed_worker_ro_import_paths in host policy." + ) + + # Conda-specific manifest fields + conda_channels: list[str] = ( + tool_config.get("conda_channels", []) if is_conda else [] + ) + conda_dependencies: list[str] = ( + tool_config.get("conda_dependencies", []) if is_conda else [] + ) + conda_platforms: list[str] = ( + tool_config.get("conda_platforms", []) if is_conda else [] + ) + conda_python: str = ( + tool_config.get("conda_python", "*") if is_conda else "*" + ) # Parse [project] dependencies project_config = manifest_data.get("project", {}) @@ -260,8 +309,6 @@ async def load_isolated_node( if not isolated: return [] - logger.info(f"][ Loading isolated node: {extension_name}") - import folder_paths base_paths = [Path(folder_paths.base_path), node_dir] @@ -272,8 +319,9 @@ async def load_isolated_node( cuda_wheels = _parse_cuda_wheels_config(tool_config, dependencies) manager_config = ExtensionManagerConfig(venv_root_path=str(venv_root)) + extension_type = _get_extension_type(execution_model) manager: ExtensionManager = pyisolate.ExtensionManager( - ComfyNodeExtension, manager_config + extension_type, manager_config ) extension_managers.append(manager) @@ -281,15 +329,21 @@ async def load_isolated_node( sandbox_config = {} is_linux = platform.system() == "Linux" + + if is_conda: + share_torch = False + share_cuda_ipc = False + else: + share_cuda_ipc = share_torch and is_linux + if is_linux and isolated: sandbox_config = { "network": host_policy["allow_network"], "writable_paths": host_policy["writable_paths"], "readonly_paths": host_policy["readonly_paths"], } - share_cuda_ipc = share_torch and is_linux - extension_config = { + extension_config: dict = { "name": extension_name, "module_path": str(node_dir), "isolated": True, @@ -299,17 +353,60 @@ async def load_isolated_node( "sandbox_mode": host_policy["sandbox_mode"], "sandbox": sandbox_config, } + + _is_sealed = execution_model == "sealed_worker" + _is_sandboxed = host_policy["sandbox_mode"] != "disabled" and is_linux + logger.info( + "][ Loading isolated node: %s (torch_share [%s], sealed [%s], sandboxed [%s])", + extension_name, + "x" if share_torch else " ", + "x" if _is_sealed else " ", + "x" if _is_sandboxed else " ", + ) + if cuda_wheels is not None: extension_config["cuda_wheels"] = cuda_wheels + # Conda-specific keys + if is_conda: + extension_config["package_manager"] = "conda" + extension_config["conda_channels"] = conda_channels + extension_config["conda_dependencies"] = conda_dependencies + extension_config["conda_python"] = conda_python + find_links = tool_config.get("find_links", []) + if find_links: + extension_config["find_links"] = find_links + if conda_platforms: + extension_config["conda_platforms"] = conda_platforms + + if execution_model != "host-coupled": + extension_config["execution_model"] = execution_model + if execution_model == "sealed_worker": + policy_ro_paths = host_policy.get("sealed_worker_ro_import_paths", []) + if isinstance(policy_ro_paths, list) and policy_ro_paths: + extension_config["sealed_host_ro_paths"] = list(policy_ro_paths) + # Sealed workers keep the host RPC service inventory even when the + # child resolves no API classes locally. + extension = manager.load_extension(extension_config) register_dummy_module(extension_name, node_dir) + # Register host-side event handlers via adapter + from .adapter import ComfyUIAdapter + ComfyUIAdapter.register_host_event_handlers(extension) + # Register web directory on the host — only when sandbox is disabled. # In sandbox mode, serving untrusted JS to the browser is not safe. if host_policy["sandbox_mode"] == "disabled": _register_web_directory(extension_name, node_dir) + # Register for proxied web serving — the child's web dir may have + # content that doesn't exist on the host (e.g., pip-installed viewer + # bundles). The WebDirectoryCache will lazily fetch via RPC. + from .proxies.web_directory_proxy import WebDirectoryProxy, get_web_directory_cache + cache = get_web_directory_cache() + cache.register_proxy(extension_name, WebDirectoryProxy()) + # Try cache first (lazy spawn) if is_cache_valid(node_dir, manifest_path, venv_root): cached_data = load_from_cache(node_dir, venv_root) @@ -379,6 +476,10 @@ async def load_isolated_node( save_to_cache(node_dir, venv_root, cache_data, manifest_path) logger.debug(f"][ {extension_name} metadata cached") + # Re-check web directory AFTER child has populated it + if host_policy["sandbox_mode"] == "disabled": + _register_web_directory(extension_name, node_dir) + # EJECT: Kill process after getting metadata (will respawn on first execution) await _stop_extension_safe(extension, extension_name) diff --git a/comfy/isolation/extension_wrapper.py b/comfy/isolation/extension_wrapper.py index 60f6288a4..a6a8a1d2f 100644 --- a/comfy/isolation/extension_wrapper.py +++ b/comfy/isolation/extension_wrapper.py @@ -37,6 +37,88 @@ _PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024 logger = logging.getLogger(__name__) +def _run_prestartup_web_copy(module: Any, module_dir: str, web_dir_path: str) -> None: + """Run the web asset copy step that prestartup_script.py used to do. + + If the module's web/ directory is empty and the module had a + prestartup_script.py that copied assets from pip packages, this + function replicates that work inside the child process. + + Generic pattern: reads _PRESTARTUP_WEB_COPY from the module if + defined, otherwise falls back to detecting common asset packages. + """ + import shutil + + # Already populated — nothing to do + if os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path)): + return + + os.makedirs(web_dir_path, exist_ok=True) + + # Try module-defined copy spec first (generic hook for any node pack) + copy_spec = getattr(module, "_PRESTARTUP_WEB_COPY", None) + if copy_spec is not None and callable(copy_spec): + try: + copy_spec(web_dir_path) + logger.info( + "%s Ran _PRESTARTUP_WEB_COPY for %s", LOG_PREFIX, module_dir + ) + return + except Exception as e: + logger.warning( + "%s _PRESTARTUP_WEB_COPY failed for %s: %s", + LOG_PREFIX, module_dir, e, + ) + + # Fallback: detect comfy_3d_viewers and run copy_viewer() + try: + from comfy_3d_viewers import copy_viewer, VIEWER_FILES + viewers = list(VIEWER_FILES.keys()) + for viewer in viewers: + try: + copy_viewer(viewer, web_dir_path) + except Exception: + pass + if any(os.scandir(web_dir_path)): + logger.info( + "%s Copied %d viewer types from comfy_3d_viewers to %s", + LOG_PREFIX, len(viewers), web_dir_path, + ) + except ImportError: + pass + + # Fallback: detect comfy_dynamic_widgets + try: + from comfy_dynamic_widgets import get_js_path + src = os.path.realpath(get_js_path()) + if os.path.exists(src): + dst_dir = os.path.join(web_dir_path, "js") + os.makedirs(dst_dir, exist_ok=True) + dst = os.path.join(dst_dir, "dynamic_widgets.js") + shutil.copy2(src, dst) + except ImportError: + pass + + +def _read_extension_name(module_dir: str) -> str: + """Read extension name from pyproject.toml, falling back to directory name.""" + pyproject = os.path.join(module_dir, "pyproject.toml") + if os.path.exists(pyproject): + try: + import tomllib + except ImportError: + import tomli as tomllib # type: ignore[no-redef] + try: + with open(pyproject, "rb") as f: + data = tomllib.load(f) + name = data.get("project", {}).get("name") + if name: + return name + except Exception: + pass + return os.path.basename(module_dir) + + def _flush_tensor_transport_state(marker: str) -> int: try: from pyisolate import flush_tensor_keeper # type: ignore[attr-defined] @@ -130,6 +212,20 @@ class ComfyNodeExtension(ExtensionBase): self.node_classes = getattr(module, "NODE_CLASS_MAPPINGS", {}) or {} self.display_names = getattr(module, "NODE_DISPLAY_NAME_MAPPINGS", {}) or {} + # Register web directory with WebDirectoryProxy (child-side) + web_dir_attr = getattr(module, "WEB_DIRECTORY", None) + if web_dir_attr is not None: + module_dir = os.path.dirname(os.path.abspath(module.__file__)) + web_dir_path = os.path.abspath(os.path.join(module_dir, web_dir_attr)) + ext_name = _read_extension_name(module_dir) + + # If web dir is empty, run the copy step that prestartup_script.py did + _run_prestartup_web_copy(module, module_dir, web_dir_path) + + if os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path)): + from comfy.isolation.proxies.web_directory_proxy import WebDirectoryProxy + WebDirectoryProxy.register_web_dir(ext_name, web_dir_path) + try: from comfy_api.latest import ComfyExtension @@ -365,6 +461,12 @@ class ComfyNodeExtension(ExtensionBase): for key, value in hidden_found.items(): setattr(node_cls.hidden, key.value.lower(), value) + # INPUT_IS_LIST: ComfyUI's executor passes all inputs as lists when this + # flag is set. The isolation RPC delivers unwrapped values, so we must + # wrap each input in a single-element list to match the contract. + if getattr(node_cls, "INPUT_IS_LIST", False): + resolved_inputs = {k: [v] for k, v in resolved_inputs.items()} + function_name = getattr(node_cls, "FUNCTION", "execute") if not hasattr(instance, function_name): raise AttributeError(f"Node {node_name} missing callable '{function_name}'") @@ -402,7 +504,7 @@ class ComfyNodeExtension(ExtensionBase): "args": self._wrap_unpicklable_objects(result.args), } if result.ui is not None: - node_output_dict["ui"] = result.ui + node_output_dict["ui"] = self._wrap_unpicklable_objects(result.ui) if getattr(result, "expand", None) is not None: node_output_dict["expand"] = result.expand if getattr(result, "block_execution", None) is not None: @@ -454,6 +556,85 @@ class ComfyNodeExtension(ExtensionBase): return self.remote_objects[object_id] + def _store_remote_object_handle(self, obj: Any) -> RemoteObjectHandle: + object_id = str(uuid.uuid4()) + self.remote_objects[object_id] = obj + return RemoteObjectHandle(object_id, type(obj).__name__) + + async def call_remote_object_method( + self, + object_id: str, + method_name: str, + *args: Any, + **kwargs: Any, + ) -> Any: + """Invoke a method or attribute-backed accessor on a child-owned object.""" + obj = await self.get_remote_object(object_id) + + if method_name == "get_patcher_attr": + return getattr(obj, args[0]) + if method_name == "get_model_options": + return getattr(obj, "model_options") + if method_name == "set_model_options": + setattr(obj, "model_options", args[0]) + return None + if method_name == "get_object_patches": + return getattr(obj, "object_patches") + if method_name == "get_patches": + return getattr(obj, "patches") + if method_name == "get_wrappers": + return getattr(obj, "wrappers") + if method_name == "get_callbacks": + return getattr(obj, "callbacks") + if method_name == "get_load_device": + return getattr(obj, "load_device") + if method_name == "get_offload_device": + return getattr(obj, "offload_device") + if method_name == "get_hook_mode": + return getattr(obj, "hook_mode") + if method_name == "get_parent": + parent = getattr(obj, "parent", None) + if parent is None: + return None + return self._store_remote_object_handle(parent) + if method_name == "get_inner_model_attr": + attr_name = args[0] + if hasattr(obj.model, attr_name): + return getattr(obj.model, attr_name) + if hasattr(obj, attr_name): + return getattr(obj, attr_name) + return None + if method_name == "inner_model_apply_model": + return obj.model.apply_model(*args[0], **args[1]) + if method_name == "inner_model_extra_conds_shapes": + return obj.model.extra_conds_shapes(*args[0], **args[1]) + if method_name == "inner_model_extra_conds": + return obj.model.extra_conds(*args[0], **args[1]) + if method_name == "inner_model_memory_required": + return obj.model.memory_required(*args[0], **args[1]) + if method_name == "process_latent_in": + return obj.model.process_latent_in(*args[0], **args[1]) + if method_name == "process_latent_out": + return obj.model.process_latent_out(*args[0], **args[1]) + if method_name == "scale_latent_inpaint": + return obj.model.scale_latent_inpaint(*args[0], **args[1]) + if method_name.startswith("get_"): + attr_name = method_name[4:] + if hasattr(obj, attr_name): + return getattr(obj, attr_name) + + target = getattr(obj, method_name) + if callable(target): + result = target(*args, **kwargs) + if inspect.isawaitable(result): + result = await result + if type(result).__name__ == "ModelPatcher": + return self._store_remote_object_handle(result) + return result + if args or kwargs: + raise TypeError(f"{method_name} is not callable on remote object {object_id}") + return target + def _wrap_unpicklable_objects(self, data: Any) -> Any: if isinstance(data, (str, int, float, bool, type(None))): return data @@ -514,9 +695,7 @@ class ComfyNodeExtension(ExtensionBase): if serializer: return serializer(data) - object_id = str(uuid.uuid4()) - self.remote_objects[object_id] = data - return RemoteObjectHandle(object_id, type(data).__name__) + return self._store_remote_object_handle(data) def _resolve_remote_objects(self, data: Any) -> Any: if isinstance(data, RemoteObjectHandle): diff --git a/comfy/isolation/host_hooks.py b/comfy/isolation/host_hooks.py index 86cde10a8..e20143591 100644 --- a/comfy/isolation/host_hooks.py +++ b/comfy/isolation/host_hooks.py @@ -12,15 +12,19 @@ def initialize_host_process() -> None: root.addHandler(logging.NullHandler()) from .proxies.folder_paths_proxy import FolderPathsProxy + from .proxies.helper_proxies import HelperProxiesService from .proxies.model_management_proxy import ModelManagementProxy from .proxies.progress_proxy import ProgressProxy from .proxies.prompt_server_impl import PromptServerService from .proxies.utils_proxy import UtilsProxy + from .proxies.web_directory_proxy import WebDirectoryProxy from .vae_proxy import VAERegistry FolderPathsProxy() + HelperProxiesService() ModelManagementProxy() ProgressProxy() PromptServerService() UtilsProxy() + WebDirectoryProxy() VAERegistry() diff --git a/comfy/isolation/host_policy.py b/comfy/isolation/host_policy.py index f9b48cf4d..f637e89d9 100644 --- a/comfy/isolation/host_policy.py +++ b/comfy/isolation/host_policy.py @@ -4,6 +4,7 @@ from __future__ import annotations import logging import os from pathlib import Path +from pathlib import PurePosixPath from typing import Dict, List, TypedDict try: @@ -15,6 +16,7 @@ logger = logging.getLogger(__name__) HOST_POLICY_PATH_ENV = "COMFY_HOST_POLICY_PATH" VALID_SANDBOX_MODES = frozenset({"required", "disabled"}) +FORBIDDEN_WRITABLE_PATHS = frozenset({"/tmp"}) class HostSecurityPolicy(TypedDict): @@ -22,14 +24,16 @@ class HostSecurityPolicy(TypedDict): allow_network: bool writable_paths: List[str] readonly_paths: List[str] + sealed_worker_ro_import_paths: List[str] whitelist: Dict[str, str] DEFAULT_POLICY: HostSecurityPolicy = { "sandbox_mode": "required", "allow_network": False, - "writable_paths": ["/dev/shm", "/tmp"], + "writable_paths": ["/dev/shm"], "readonly_paths": [], + "sealed_worker_ro_import_paths": [], "whitelist": {}, } @@ -40,10 +44,68 @@ def _default_policy() -> HostSecurityPolicy: "allow_network": DEFAULT_POLICY["allow_network"], "writable_paths": list(DEFAULT_POLICY["writable_paths"]), "readonly_paths": list(DEFAULT_POLICY["readonly_paths"]), + "sealed_worker_ro_import_paths": list(DEFAULT_POLICY["sealed_worker_ro_import_paths"]), "whitelist": dict(DEFAULT_POLICY["whitelist"]), } +def _normalize_writable_paths(paths: list[object]) -> list[str]: + normalized_paths: list[str] = [] + for raw_path in paths: + # Host-policy paths are contract-style POSIX paths; keep representation + # stable across Windows/Linux so tests and config behavior stay consistent. + normalized_path = str(PurePosixPath(str(raw_path).replace("\\", "/"))) + if normalized_path in FORBIDDEN_WRITABLE_PATHS: + continue + normalized_paths.append(normalized_path) + return normalized_paths + + +def _load_whitelist_file(file_path: Path, config_path: Path) -> Dict[str, str]: + if not file_path.is_absolute(): + file_path = config_path.parent / file_path + if not file_path.exists(): + logger.warning("whitelist_file %s not found, skipping.", file_path) + return {} + entries: Dict[str, str] = {} + for line in file_path.read_text().splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + entries[line] = "*" + logger.debug("Loaded %d whitelist entries from %s", len(entries), file_path) + return entries + + +def _normalize_sealed_worker_ro_import_paths(raw_paths: object) -> list[str]: + if not isinstance(raw_paths, list): + raise ValueError( + "tool.comfy.host.sealed_worker_ro_import_paths must be a list of absolute paths." + ) + + normalized_paths: list[str] = [] + seen: set[str] = set() + for raw_path in raw_paths: + if not isinstance(raw_path, str) or not raw_path.strip(): + raise ValueError( + "tool.comfy.host.sealed_worker_ro_import_paths entries must be non-empty strings." + ) + normalized_path = str(PurePosixPath(raw_path.replace("\\", "/"))) + # Accept both POSIX absolute paths (/home/...) and Windows drive-letter paths (D:/...) + is_absolute = normalized_path.startswith("/") or ( + len(normalized_path) >= 3 and normalized_path[1] == ":" and normalized_path[2] == "/" + ) + if not is_absolute: + raise ValueError( + "tool.comfy.host.sealed_worker_ro_import_paths entries must be absolute paths." + ) + if normalized_path not in seen: + seen.add(normalized_path) + normalized_paths.append(normalized_path) + + return normalized_paths + + def load_host_policy(comfy_root: Path) -> HostSecurityPolicy: config_override = os.environ.get(HOST_POLICY_PATH_ENV) config_path = Path(config_override) if config_override else comfy_root / "pyproject.toml" @@ -86,14 +148,23 @@ def load_host_policy(comfy_root: Path) -> HostSecurityPolicy: policy["allow_network"] = bool(tool_config["allow_network"]) if "writable_paths" in tool_config: - policy["writable_paths"] = [str(p) for p in tool_config["writable_paths"]] + policy["writable_paths"] = _normalize_writable_paths(tool_config["writable_paths"]) if "readonly_paths" in tool_config: policy["readonly_paths"] = [str(p) for p in tool_config["readonly_paths"]] + if "sealed_worker_ro_import_paths" in tool_config: + policy["sealed_worker_ro_import_paths"] = _normalize_sealed_worker_ro_import_paths( + tool_config["sealed_worker_ro_import_paths"] + ) + + whitelist_file = tool_config.get("whitelist_file") + if isinstance(whitelist_file, str): + policy["whitelist"].update(_load_whitelist_file(Path(whitelist_file), config_path)) + whitelist_raw = tool_config.get("whitelist") if isinstance(whitelist_raw, dict): - policy["whitelist"] = {str(k): str(v) for k, v in whitelist_raw.items()} + policy["whitelist"].update({str(k): str(v) for k, v in whitelist_raw.items()}) logger.debug( "Loaded Host Policy: %d whitelisted nodes, Sandbox=%s, Network=%s", diff --git a/comfy/isolation/manifest_loader.py b/comfy/isolation/manifest_loader.py index 42007302f..4ae21d94d 100644 --- a/comfy/isolation/manifest_loader.py +++ b/comfy/isolation/manifest_loader.py @@ -24,6 +24,49 @@ CACHE_SUBDIR = "cache" CACHE_KEY_FILE = "cache_key" CACHE_DATA_FILE = "node_info.json" CACHE_KEY_LENGTH = 16 +_NESTED_SCAN_ROOT = "packages" +_IGNORED_MANIFEST_DIRS = {".git", ".venv", "__pycache__"} + + +def _read_manifest(manifest_path: Path) -> dict[str, Any] | None: + try: + with manifest_path.open("rb") as f: + data = tomllib.load(f) + if isinstance(data, dict): + return data + except Exception: + return None + return None + + +def _is_isolation_manifest(data: dict[str, Any]) -> bool: + return ( + "tool" in data + and "comfy" in data["tool"] + and "isolation" in data["tool"]["comfy"] + ) + + +def _discover_nested_manifests(entry: Path) -> List[Tuple[Path, Path]]: + packages_root = entry / _NESTED_SCAN_ROOT + if not packages_root.exists() or not packages_root.is_dir(): + return [] + + nested: List[Tuple[Path, Path]] = [] + for manifest in sorted(packages_root.rglob("pyproject.toml")): + node_dir = manifest.parent + if any(part in _IGNORED_MANIFEST_DIRS for part in node_dir.parts): + continue + + data = _read_manifest(manifest) + if not data or not _is_isolation_manifest(data): + continue + + isolation = data["tool"]["comfy"]["isolation"] + if isolation.get("standalone") is True: + nested.append((node_dir, manifest)) + + return nested def find_manifest_directories() -> List[Tuple[Path, Path]]: @@ -45,21 +88,13 @@ def find_manifest_directories() -> List[Tuple[Path, Path]]: if not manifest.exists(): continue - # Validate [tool.comfy.isolation] section existence - try: - with manifest.open("rb") as f: - data = tomllib.load(f) - - if ( - "tool" in data - and "comfy" in data["tool"] - and "isolation" in data["tool"]["comfy"] - ): - manifest_dirs.append((entry, manifest)) - - except Exception: + data = _read_manifest(manifest) + if not data or not _is_isolation_manifest(data): continue + manifest_dirs.append((entry, manifest)) + manifest_dirs.extend(_discover_nested_manifests(entry)) + return manifest_dirs diff --git a/comfy/isolation/runtime_helpers.py b/comfy/isolation/runtime_helpers.py index ccd1bed77..f56b1859a 100644 --- a/comfy/isolation/runtime_helpers.py +++ b/comfy/isolation/runtime_helpers.py @@ -8,10 +8,17 @@ from pathlib import Path from typing import Any, Dict, List, Set, TYPE_CHECKING from .proxies.helper_proxies import restore_input_types -from comfy_api.internal import _ComfyNodeInternal -from comfy_api.latest import _io as latest_io from .shm_forensics import scan_shm_forensics +_IMPORT_TORCH = os.environ.get("PYISOLATE_IMPORT_TORCH", "1") == "1" + +_ComfyNodeInternal = object +latest_io = None + +if _IMPORT_TORCH: + from comfy_api.internal import _ComfyNodeInternal + from comfy_api.latest import _io as latest_io + if TYPE_CHECKING: from .extension_wrapper import ComfyNodeExtension @@ -19,6 +26,68 @@ LOG_PREFIX = "][" _PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024 +class _RemoteObjectRegistryCaller: + def __init__(self, extension: Any) -> None: + self._extension = extension + + def __getattr__(self, method_name: str) -> Any: + async def _call(instance_id: str, *args: Any, **kwargs: Any) -> Any: + return await self._extension.call_remote_object_method( + instance_id, + method_name, + *args, + **kwargs, + ) + + return _call + + +def _wrap_remote_handles_as_host_proxies(value: Any, extension: Any) -> Any: + from pyisolate._internal.remote_handle import RemoteObjectHandle + + if isinstance(value, RemoteObjectHandle): + if value.type_name == "ModelPatcher": + from comfy.isolation.model_patcher_proxy import ModelPatcherProxy + + proxy = ModelPatcherProxy(value.object_id, manage_lifecycle=False) + proxy._rpc_caller = _RemoteObjectRegistryCaller(extension) # type: ignore[attr-defined] + proxy._pyisolate_remote_handle = value # type: ignore[attr-defined] + return proxy + if value.type_name == "VAE": + from comfy.isolation.vae_proxy import VAEProxy + + proxy = VAEProxy(value.object_id, manage_lifecycle=False) + proxy._rpc_caller = _RemoteObjectRegistryCaller(extension) # type: ignore[attr-defined] + proxy._pyisolate_remote_handle = value # type: ignore[attr-defined] + return proxy + if value.type_name == "CLIP": + from comfy.isolation.clip_proxy import CLIPProxy + + proxy = CLIPProxy(value.object_id, manage_lifecycle=False) + proxy._rpc_caller = _RemoteObjectRegistryCaller(extension) # type: ignore[attr-defined] + proxy._pyisolate_remote_handle = value # type: ignore[attr-defined] + return proxy + if value.type_name == "ModelSampling": + from comfy.isolation.model_sampling_proxy import ModelSamplingProxy + + proxy = ModelSamplingProxy(value.object_id, manage_lifecycle=False) + proxy._rpc_caller = _RemoteObjectRegistryCaller(extension) # type: ignore[attr-defined] + proxy._pyisolate_remote_handle = value # type: ignore[attr-defined] + return proxy + return value + + if isinstance(value, dict): + return { + k: _wrap_remote_handles_as_host_proxies(v, extension) for k, v in value.items() + } + + if isinstance(value, (list, tuple)): + wrapped = [_wrap_remote_handles_as_host_proxies(item, extension) for item in value] + return type(value)(wrapped) + + return value + + def _resource_snapshot() -> Dict[str, int]: fd_count = -1 shm_sender_files = 0 @@ -146,6 +215,8 @@ def build_stub_class( running_extensions: Dict[str, "ComfyNodeExtension"], logger: logging.Logger, ) -> type: + if latest_io is None: + raise RuntimeError("comfy_api.latest._io is required to build isolation stubs") is_v3 = bool(info.get("is_v3", False)) function_name = "_pyisolate_execute" restored_input_types = restore_input_types(info.get("input_types", {})) @@ -160,6 +231,13 @@ def build_stub_class( node_unique_id = _extract_hidden_unique_id(inputs) summary = _tensor_transport_summary(inputs) resources = _resource_snapshot() + logger.debug( + "%s ISO:execute_start ext=%s node=%s uid=%s", + LOG_PREFIX, + extension.name, + node_name, + node_unique_id or "-", + ) logger.debug( "%s ISO:execute_start ext=%s node=%s uid=%s tensors=%d cpu=%d cuda=%d shared_cpu=%d bytes=%d fds=%d sender_shm=%d", LOG_PREFIX, @@ -192,7 +270,20 @@ def build_stub_class( node_name, node_unique_id or "-", ) - serialized = serialize_for_isolation(inputs) + # Unwrap NodeOutput-like dicts before serialization. + # OUTPUT_NODE nodes return {"ui": {...}, "result": (outputs...)} + # and the executor may pass this dict as input to downstream nodes. + unwrapped_inputs = {} + for k, v in inputs.items(): + if isinstance(v, dict) and "result" in v and ("ui" in v or "__node_output__" in v): + result = v.get("result") + if isinstance(result, (tuple, list)) and len(result) > 0: + unwrapped_inputs[k] = result[0] + else: + unwrapped_inputs[k] = result + else: + unwrapped_inputs[k] = v + serialized = serialize_for_isolation(unwrapped_inputs) logger.debug( "%s ISO:serialize_done ext=%s node=%s uid=%s", LOG_PREFIX, @@ -220,15 +311,32 @@ def build_stub_class( from comfy_api.latest import io as latest_io args_raw = result.get("args", ()) deserialized_args = await deserialize_from_isolation(args_raw, extension) + deserialized_args = _wrap_remote_handles_as_host_proxies( + deserialized_args, extension + ) deserialized_args = _detach_shared_cpu_tensors(deserialized_args) + ui_raw = result.get("ui") + deserialized_ui = None + if ui_raw is not None: + deserialized_ui = await deserialize_from_isolation(ui_raw, extension) + deserialized_ui = _wrap_remote_handles_as_host_proxies( + deserialized_ui, extension + ) + deserialized_ui = _detach_shared_cpu_tensors(deserialized_ui) scan_shm_forensics("RUNTIME:post_execute", refresh_model_context=True) return latest_io.NodeOutput( *deserialized_args, - ui=result.get("ui"), + ui=deserialized_ui, expand=result.get("expand"), block_execution=result.get("block_execution"), ) + # OUTPUT_NODE: if sealed worker returned a tuple/list whose first + # element is a {"ui": ...} dict, unwrap it for the executor. + if (isinstance(result, (tuple, list)) and len(result) == 1 + and isinstance(result[0], dict) and "ui" in result[0]): + return result[0] deserialized = await deserialize_from_isolation(result, extension) + deserialized = _wrap_remote_handles_as_host_proxies(deserialized, extension) scan_shm_forensics("RUNTIME:post_execute", refresh_model_context=True) return _detach_shared_cpu_tensors(deserialized) except ImportError: