feat: isolation core — adapter, loader, manifest, hooks, runtime helpers

This commit is contained in:
John Pollock 2026-03-29 19:02:33 -05:00
parent c02372936d
commit 878684d8b2
10 changed files with 1067 additions and 227 deletions

View File

@ -8,11 +8,21 @@ import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Set, TYPE_CHECKING
import folder_paths
from .extension_loader import load_isolated_node
from .manifest_loader import find_manifest_directories
from .runtime_helpers import build_stub_class, get_class_types_for_extension
from .shm_forensics import scan_shm_forensics, start_shm_forensics
_IMPORT_TORCH = os.environ.get("PYISOLATE_IMPORT_TORCH", "1") == "1"
load_isolated_node = None
find_manifest_directories = None
build_stub_class = None
get_class_types_for_extension = None
scan_shm_forensics = None
start_shm_forensics = None
if _IMPORT_TORCH:
import folder_paths
from .extension_loader import load_isolated_node
from .manifest_loader import find_manifest_directories
from .runtime_helpers import build_stub_class, get_class_types_for_extension
from .shm_forensics import scan_shm_forensics, start_shm_forensics
if TYPE_CHECKING:
from pyisolate import ExtensionManager
@ -21,8 +31,9 @@ if TYPE_CHECKING:
LOG_PREFIX = "]["
isolated_node_timings: List[tuple[float, Path, int]] = []
PYISOLATE_VENV_ROOT = Path(folder_paths.base_path) / ".pyisolate_venvs"
PYISOLATE_VENV_ROOT.mkdir(parents=True, exist_ok=True)
if _IMPORT_TORCH:
PYISOLATE_VENV_ROOT = Path(folder_paths.base_path) / ".pyisolate_venvs"
PYISOLATE_VENV_ROOT.mkdir(parents=True, exist_ok=True)
logger = logging.getLogger(__name__)
_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
@ -42,7 +53,8 @@ def initialize_proxies() -> None:
from .host_hooks import initialize_host_process
initialize_host_process()
start_shm_forensics()
if start_shm_forensics is not None:
start_shm_forensics()
@dataclass(frozen=True)
@ -88,6 +100,8 @@ async def initialize_isolation_nodes() -> List[IsolatedNodeSpec]:
return []
_ISOLATION_SCAN_ATTEMPTED = True
if find_manifest_directories is None or load_isolated_node is None or build_stub_class is None:
return []
manifest_entries = find_manifest_directories()
_CLAIMED_PATHS = {entry[0].resolve() for entry in manifest_entries}
@ -167,17 +181,27 @@ def _get_class_types_for_extension(extension_name: str) -> Set[str]:
return class_types
async def notify_execution_graph(needed_class_types: Set[str]) -> None:
"""Evict running extensions not needed for current execution."""
async def notify_execution_graph(needed_class_types: Set[str], caches: list | None = None) -> None:
"""Evict running extensions not needed for current execution.
When *caches* is provided, cache entries for evicted extensions' node
class_types are invalidated to prevent stale ``RemoteObjectHandle``
references from surviving in the output cache.
"""
await wait_for_model_patcher_quiescence(
timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS,
fail_loud=True,
marker="ISO:notify_graph_wait_idle",
)
evicted_class_types: Set[str] = set()
async def _stop_extension(
ext_name: str, extension: "ComfyNodeExtension", reason: str
) -> None:
# Collect class_types BEFORE stopping so we can invalidate cache entries.
ext_class_types = _get_class_types_for_extension(ext_name)
evicted_class_types.update(ext_class_types)
logger.info("%s ISO:eject_start ext=%s reason=%s", LOG_PREFIX, ext_name, reason)
logger.debug("%s ISO:stop_start ext=%s", LOG_PREFIX, ext_name)
stop_result = extension.stop()
@ -185,9 +209,11 @@ async def notify_execution_graph(needed_class_types: Set[str]) -> None:
await stop_result
_RUNNING_EXTENSIONS.pop(ext_name, None)
logger.debug("%s ISO:stop_done ext=%s", LOG_PREFIX, ext_name)
scan_shm_forensics("ISO:stop_extension", refresh_model_context=True)
if scan_shm_forensics is not None:
scan_shm_forensics("ISO:stop_extension", refresh_model_context=True)
scan_shm_forensics("ISO:notify_graph_start", refresh_model_context=True)
if scan_shm_forensics is not None:
scan_shm_forensics("ISO:notify_graph_start", refresh_model_context=True)
isolated_class_types_in_graph = needed_class_types.intersection(
{spec.node_name for spec in _ISOLATED_NODE_SPECS}
)
@ -247,6 +273,22 @@ async def notify_execution_graph(needed_class_types: Set[str]) -> None:
"%s workflow-boundary host VRAM relief failed", LOG_PREFIX, exc_info=True
)
finally:
# Invalidate cached outputs for evicted extensions so stale
# RemoteObjectHandle references are not served from cache.
if evicted_class_types and caches:
total_invalidated = 0
for cache in caches:
if hasattr(cache, "invalidate_by_class_types"):
total_invalidated += cache.invalidate_by_class_types(
evicted_class_types
)
if total_invalidated > 0:
logger.info(
"%s ISO:cache_invalidated count=%d class_types=%s",
LOG_PREFIX,
total_invalidated,
evicted_class_types,
)
scan_shm_forensics("ISO:notify_graph_done", refresh_model_context=True)
logger.debug(
"%s ISO:notify_graph_done running=%d", LOG_PREFIX, len(_RUNNING_EXTENSIONS)

View File

@ -3,13 +3,38 @@ from __future__ import annotations
import logging
import os
import inspect
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, cast
from pyisolate.interfaces import IsolationAdapter, SerializerRegistryProtocol # type: ignore[import-untyped]
from pyisolate._internal.rpc_protocol import AsyncRPC, ProxiedSingleton # type: ignore[import-untyped]
try:
_IMPORT_TORCH = os.environ.get("PYISOLATE_IMPORT_TORCH", "1") == "1"
# Singleton proxies that do NOT transitively import torch/PIL/psutil/aiohttp.
# Safe to import in sealed workers without host framework modules.
from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy
from comfy.isolation.proxies.helper_proxies import HelperProxiesService
from comfy.isolation.proxies.web_directory_proxy import WebDirectoryProxy
# Singleton proxies that transitively import torch, PIL, or heavy host modules.
# Only available when torch/host framework is present.
CLIPProxy = None
CLIPRegistry = None
ModelPatcherProxy = None
ModelPatcherRegistry = None
ModelSamplingProxy = None
ModelSamplingRegistry = None
VAEProxy = None
VAERegistry = None
FirstStageModelRegistry = None
ModelManagementProxy = None
PromptServerService = None
ProgressProxy = None
UtilsProxy = None
_HAS_TORCH_PROXIES = False
if _IMPORT_TORCH:
from comfy.isolation.clip_proxy import CLIPProxy, CLIPRegistry
from comfy.isolation.model_patcher_proxy import (
ModelPatcherProxy,
@ -20,13 +45,11 @@ try:
ModelSamplingRegistry,
)
from comfy.isolation.vae_proxy import VAEProxy, VAERegistry, FirstStageModelRegistry
from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy
from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy
from comfy.isolation.proxies.prompt_server_impl import PromptServerService
from comfy.isolation.proxies.utils_proxy import UtilsProxy
from comfy.isolation.proxies.progress_proxy import ProgressProxy
except ImportError as exc: # Fail loud if Comfy environment is incomplete
raise ImportError(f"ComfyUI environment incomplete: {exc}")
from comfy.isolation.proxies.utils_proxy import UtilsProxy
_HAS_TORCH_PROXIES = True
logger = logging.getLogger(__name__)
@ -62,9 +85,20 @@ class ComfyUIAdapter(IsolationAdapter):
os.path.join(comfy_root, "custom_nodes"),
os.path.join(comfy_root, "comfy"),
],
"filtered_subdirs": ["comfy", "app", "comfy_execution", "utils"],
}
return None
def get_sandbox_system_paths(self) -> Optional[List[str]]:
"""Returns required application paths to mount in the sandbox."""
# By inspecting where our adapter is loaded from, we can determine the comfy root
adapter_file = inspect.getfile(self.__class__)
# adapter_file = /home/johnj/ComfyUI/comfy/isolation/adapter.py
comfy_root = os.path.dirname(os.path.dirname(os.path.dirname(adapter_file)))
if os.path.exists(comfy_root):
return [comfy_root]
return None
def setup_child_environment(self, snapshot: Dict[str, Any]) -> None:
comfy_root = snapshot.get("preferred_root")
if not comfy_root:
@ -83,6 +117,59 @@ class ComfyUIAdapter(IsolationAdapter):
logging.getLogger(pkg_name).setLevel(logging.ERROR)
def register_serializers(self, registry: SerializerRegistryProtocol) -> None:
if not _IMPORT_TORCH:
# Sealed worker without torch — register torch-free TensorValue handler
# so IMAGE/MASK/LATENT tensors arrive as numpy arrays, not raw dicts.
import numpy as np
_TORCH_DTYPE_TO_NUMPY = {
"torch.float32": np.float32,
"torch.float64": np.float64,
"torch.float16": np.float16,
"torch.bfloat16": np.float32, # numpy has no bfloat16; upcast
"torch.int32": np.int32,
"torch.int64": np.int64,
"torch.int16": np.int16,
"torch.int8": np.int8,
"torch.uint8": np.uint8,
"torch.bool": np.bool_,
}
def _deserialize_tensor_value(data: Dict[str, Any]) -> Any:
dtype_str = data["dtype"]
np_dtype = _TORCH_DTYPE_TO_NUMPY.get(dtype_str, np.float32)
shape = tuple(data["tensor_size"])
arr = np.array(data["data"], dtype=np_dtype).reshape(shape)
return arr
_NUMPY_TO_TORCH_DTYPE = {
np.float32: "torch.float32",
np.float64: "torch.float64",
np.float16: "torch.float16",
np.int32: "torch.int32",
np.int64: "torch.int64",
np.int16: "torch.int16",
np.int8: "torch.int8",
np.uint8: "torch.uint8",
np.bool_: "torch.bool",
}
def _serialize_tensor_value(obj: Any) -> Dict[str, Any]:
arr = np.asarray(obj, dtype=np.float32) if obj.dtype not in _NUMPY_TO_TORCH_DTYPE else np.asarray(obj)
dtype_str = _NUMPY_TO_TORCH_DTYPE.get(arr.dtype.type, "torch.float32")
return {
"__type__": "TensorValue",
"dtype": dtype_str,
"tensor_size": list(arr.shape),
"requires_grad": False,
"data": arr.tolist(),
}
registry.register("TensorValue", _serialize_tensor_value, _deserialize_tensor_value, data_type=True)
# ndarray output from sealed workers serializes as TensorValue for host torch reconstruction
registry.register("ndarray", _serialize_tensor_value, _deserialize_tensor_value, data_type=True)
return
import torch
def serialize_device(obj: Any) -> Dict[str, Any]:
@ -110,6 +197,54 @@ class ComfyUIAdapter(IsolationAdapter):
registry.register("dtype", serialize_dtype, deserialize_dtype)
from comfy_api.latest._io import FolderType
from comfy_api.latest._ui import SavedImages, SavedResult
def serialize_saved_result(obj: Any) -> Dict[str, Any]:
return {
"__type__": "SavedResult",
"filename": obj.filename,
"subfolder": obj.subfolder,
"folder_type": obj.type.value,
}
def deserialize_saved_result(data: Dict[str, Any]) -> Any:
if isinstance(data, SavedResult):
return data
folder_type = data["folder_type"] if "folder_type" in data else data["type"]
return SavedResult(
filename=data["filename"],
subfolder=data["subfolder"],
type=FolderType(folder_type),
)
registry.register(
"SavedResult",
serialize_saved_result,
deserialize_saved_result,
data_type=True,
)
def serialize_saved_images(obj: Any) -> Dict[str, Any]:
return {
"__type__": "SavedImages",
"results": [serialize_saved_result(result) for result in obj.results],
"is_animated": obj.is_animated,
}
def deserialize_saved_images(data: Dict[str, Any]) -> Any:
return SavedImages(
results=[deserialize_saved_result(result) for result in data["results"]],
is_animated=data.get("is_animated", False),
)
registry.register(
"SavedImages",
serialize_saved_images,
deserialize_saved_images,
data_type=True,
)
def serialize_model_patcher(obj: Any) -> Dict[str, Any]:
# Child-side: must already have _instance_id (proxy)
if os.environ.get("PYISOLATE_CHILD") == "1":
@ -212,28 +347,93 @@ class ComfyUIAdapter(IsolationAdapter):
# copyreg removed - no pickle fallback allowed
def serialize_model_sampling(obj: Any) -> Dict[str, Any]:
# Child-side: must already have _instance_id (proxy)
if os.environ.get("PYISOLATE_CHILD") == "1":
if hasattr(obj, "_instance_id"):
return {"__type__": "ModelSamplingRef", "ms_id": obj._instance_id}
raise RuntimeError(
f"ModelSampling in child lacks _instance_id: "
f"{type(obj).__module__}.{type(obj).__name__}"
)
# Host-side pass-through for proxies: do not re-register a proxy as a
# new ModelSamplingRef, or we create proxy-of-proxy indirection.
# Proxy with _instance_id — return ref (works from both host and child)
if hasattr(obj, "_instance_id"):
return {"__type__": "ModelSamplingRef", "ms_id": obj._instance_id}
# Child-side: object created locally in child (e.g. ModelSamplingAdvanced
# in nodes_z_image_turbo.py). Serialize as inline data so the host can
# reconstruct the real torch.nn.Module.
if os.environ.get("PYISOLATE_CHILD") == "1":
import base64
import io as _io
# Identify base classes from comfy.model_sampling
bases = []
for base in type(obj).__mro__:
if base.__module__ == "comfy.model_sampling" and base.__name__ != "object":
bases.append(base.__name__)
# Serialize state_dict as base64 safetensors-like
sd = obj.state_dict()
sd_serialized = {}
for k, v in sd.items():
buf = _io.BytesIO()
torch.save(v, buf)
sd_serialized[k] = base64.b64encode(buf.getvalue()).decode("ascii")
# Capture plain attrs (shift, multiplier, sigma_data, etc.)
plain_attrs = {}
for k, v in obj.__dict__.items():
if k.startswith("_"):
continue
if isinstance(v, (bool, int, float, str)):
plain_attrs[k] = v
return {
"__type__": "ModelSamplingInline",
"bases": bases,
"state_dict": sd_serialized,
"attrs": plain_attrs,
}
# Host-side: register with ModelSamplingRegistry and return JSON-safe dict
ms_id = ModelSamplingRegistry().register(obj)
return {"__type__": "ModelSamplingRef", "ms_id": ms_id}
def deserialize_model_sampling(data: Any) -> Any:
"""Deserialize ModelSampling refs; pass through already-materialized objects."""
"""Deserialize ModelSampling refs or inline data."""
if isinstance(data, dict):
if data.get("__type__") == "ModelSamplingInline":
return _reconstruct_model_sampling_inline(data)
return ModelSamplingProxy(data["ms_id"])
return data
def _reconstruct_model_sampling_inline(data: Dict[str, Any]) -> Any:
    """Reconstruct a ModelSampling object on the host from inline child data.

    *data* is the ``ModelSamplingInline`` payload produced by the child:
    ``bases`` (class names from ``comfy.model_sampling``), ``state_dict``
    (base64-encoded torch-saved tensors keyed by state_dict name), and
    ``attrs`` (plain scalar attributes such as shift/multiplier).

    Raises ``RuntimeError`` when none of the named bases exist in the
    host's ``comfy.model_sampling`` module.
    """
    import comfy.model_sampling as _ms
    import base64
    import io as _io

    # Resolve base classes by name; unknown names are silently dropped so
    # a child on a newer comfy version degrades instead of hard-failing.
    base_classes = []
    for name in data["bases"]:
        cls = getattr(_ms, name, None)
        if cls is not None:
            base_classes.append(cls)
    if not base_classes:
        raise RuntimeError(
            f"Cannot reconstruct ModelSampling: no known bases in {data['bases']}"
        )
    # Create dynamic class matching the child's class hierarchy.
    ReconstructedSampling = type("ReconstructedSampling", tuple(base_classes), {})
    obj = ReconstructedSampling.__new__(ReconstructedSampling)
    torch.nn.Module.__init__(obj)
    # Restore plain attributes first (shift, multiplier, sigma_data, ...).
    for k, v in data.get("attrs", {}).items():
        setattr(obj, k, v)
    # Restore state_dict (buffers like sigmas).
    for k, v_b64 in data.get("state_dict", {}).items():
        buf = _io.BytesIO(base64.b64decode(v_b64))
        tensor = torch.load(buf, weights_only=True)
        parts = k.split(".")
        # Walk dotted keys down to the owning submodule instead of
        # clobbering the top-level attribute with the leaf tensor
        # (previous behavior set `parts[0]` to the tensor for nested keys).
        owner: Any = obj
        path_missing = False
        for attr in parts[:-1]:
            nxt = getattr(owner, attr, None)
            if nxt is None:
                path_missing = True
                break
            owner = nxt
        if path_missing:
            # Submodule path absent on the bare reconstructed object;
            # keep the old best-effort flat assignment rather than raising.
            setattr(obj, parts[0], tensor)
        elif isinstance(owner, torch.nn.Module):
            # Register as buffer so it participates in state_dict().
            owner.register_buffer(parts[-1], tensor)
        else:
            setattr(owner, parts[-1], tensor)
    # Register on host so future references use proxy pattern.
    # Skip in child process — register() is async RPC and cannot be
    # called synchronously during deserialization.
    if os.environ.get("PYISOLATE_CHILD") != "1":
        ModelSamplingRegistry().register(obj)
    return obj
def deserialize_model_sampling_ref(data: Dict[str, Any]) -> Any:
"""Context-aware ModelSamplingRef deserializer for both host and child."""
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
@ -262,6 +462,10 @@ class ComfyUIAdapter(IsolationAdapter):
)
# Register ModelSamplingRef for deserialization (context-aware: host or child)
registry.register("ModelSamplingRef", None, deserialize_model_sampling_ref)
# Register ModelSamplingInline for deserialization (child→host inline transfer)
registry.register(
"ModelSamplingInline", None, lambda data: _reconstruct_model_sampling_inline(data)
)
def serialize_cond(obj: Any) -> Dict[str, Any]:
type_key = f"{type(obj).__module__}.{type(obj).__name__}"
@ -393,52 +597,20 @@ class ComfyUIAdapter(IsolationAdapter):
# Fallback for non-numeric arrays (strings, objects, mixes)
return obj.tolist()
registry.register("ndarray", serialize_numpy, None)
def serialize_ply(obj: Any) -> Dict[str, Any]:
def deserialize_numpy_b64(data: Any) -> Any:
"""Deserialize base64-encoded ndarray from sealed worker."""
import base64
import torch
if obj.raw_data is not None:
return {
"__type__": "PLY",
"raw_data": base64.b64encode(obj.raw_data).decode("ascii"),
}
result: Dict[str, Any] = {"__type__": "PLY", "points": torch.from_numpy(obj.points)}
if obj.colors is not None:
result["colors"] = torch.from_numpy(obj.colors)
if obj.confidence is not None:
result["confidence"] = torch.from_numpy(obj.confidence)
if obj.view_id is not None:
result["view_id"] = torch.from_numpy(obj.view_id)
return result
import numpy as np
if isinstance(data, dict) and "data" in data and "dtype" in data:
raw = base64.b64decode(data["data"])
arr = np.frombuffer(raw, dtype=np.dtype(data["dtype"])).reshape(data["shape"])
return torch.from_numpy(arr.copy())
return data
def deserialize_ply(data: Any) -> Any:
import base64
from comfy_api.latest._util.ply_types import PLY
if "raw_data" in data:
return PLY(raw_data=base64.b64decode(data["raw_data"]))
return PLY(
points=data["points"],
colors=data.get("colors"),
confidence=data.get("confidence"),
view_id=data.get("view_id"),
)
registry.register("ndarray", serialize_numpy, deserialize_numpy_b64)
registry.register("PLY", serialize_ply, deserialize_ply, data_type=True)
def serialize_npz(obj: Any) -> Dict[str, Any]:
import base64
return {
"__type__": "NPZ",
"frames": [base64.b64encode(f).decode("ascii") for f in obj.frames],
}
def deserialize_npz(data: Any) -> Any:
import base64
from comfy_api.latest._util.npz_types import NPZ
return NPZ(frames=[base64.b64decode(f) for f in data["frames"]])
registry.register("NPZ", serialize_npz, deserialize_npz, data_type=True)
# -- File3D (comfy_api.latest._util.geometry_types) ---------------------
# Origin: comfy_api by ComfyOrg (Alexander Piskun), PR #12129
def serialize_file3d(obj: Any) -> Dict[str, Any]:
import base64
@ -456,6 +628,9 @@ class ComfyUIAdapter(IsolationAdapter):
registry.register("File3D", serialize_file3d, deserialize_file3d, data_type=True)
# -- VIDEO (comfy_api.latest._input_impl.video_types) -------------------
# Origin: ComfyAPI Core v0.0.2 by ComfyOrg (guill), PR #8962
def serialize_video(obj: Any) -> Dict[str, Any]:
components = obj.get_components()
images = components.images.detach() if components.images.requires_grad else components.images
@ -494,19 +669,156 @@ class ComfyUIAdapter(IsolationAdapter):
registry.register("VideoFromFile", serialize_video, deserialize_video, data_type=True)
registry.register("VideoFromComponents", serialize_video, deserialize_video, data_type=True)
    def setup_web_directory(self, module: Any) -> None:
        """Detect WEB_DIRECTORY on a module and populate/register it.

        Called by the sealed worker after loading the node module.
        Mirrors extension_wrapper.py:216-227 for host-coupled nodes.
        Does NOT import extension_wrapper.py (it has `import torch` at module level).

        Args:
            module: The loaded custom-node module; only its ``WEB_DIRECTORY``,
                ``__file__`` and optional ``_PRESTARTUP_WEB_COPY`` attributes
                are read.
        """
        import shutil

        web_dir_attr = getattr(module, "WEB_DIRECTORY", None)
        if web_dir_attr is None:
            # Node ships no frontend assets — nothing to register.
            return
        module_dir = os.path.dirname(os.path.abspath(module.__file__))
        web_dir_path = os.path.abspath(os.path.join(module_dir, web_dir_attr))
        # Read extension name from pyproject.toml; fall back to directory name.
        ext_name = os.path.basename(module_dir)
        pyproject = os.path.join(module_dir, "pyproject.toml")
        if os.path.exists(pyproject):
            try:
                import tomllib
            except ImportError:
                # tomllib is stdlib only from Python 3.11; older runtimes use tomli.
                import tomli as tomllib  # type: ignore[no-redef]
            try:
                with open(pyproject, "rb") as f:
                    data = tomllib.load(f)
                name = data.get("project", {}).get("name")
                if name:
                    ext_name = name
            except Exception:
                # Best-effort: a malformed pyproject keeps the directory name.
                pass
        # Populate web dir if empty (mirrors _run_prestartup_web_copy)
        if not (os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path))):
            os.makedirs(web_dir_path, exist_ok=True)
            # Module-defined copy spec
            copy_spec = getattr(module, "_PRESTARTUP_WEB_COPY", None)
            if copy_spec is not None and callable(copy_spec):
                try:
                    copy_spec(web_dir_path)
                except Exception as e:
                    logger.warning("][ _PRESTARTUP_WEB_COPY failed: %s", e)
            # Fallback: comfy_3d_viewers
            try:
                from comfy_3d_viewers import copy_viewer, VIEWER_FILES
                for viewer in VIEWER_FILES:
                    try:
                        copy_viewer(viewer, web_dir_path)
                    except Exception:
                        # Individual viewer failures are non-fatal.
                        pass
            except ImportError:
                pass
            # Fallback: comfy_dynamic_widgets
            try:
                from comfy_dynamic_widgets import get_js_path
                src = os.path.realpath(get_js_path())
                if os.path.exists(src):
                    dst_dir = os.path.join(web_dir_path, "js")
                    os.makedirs(dst_dir, exist_ok=True)
                    shutil.copy2(src, os.path.join(dst_dir, "dynamic_widgets.js"))
            except ImportError:
                pass
        # Only register if the directory actually ended up containing files.
        if os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path)):
            WebDirectoryProxy.register_web_dir(ext_name, web_dir_path)
            logger.info(
                "][ Adapter: registered web dir for %s (%d files)",
                ext_name,
                sum(1 for _ in Path(web_dir_path).rglob("*") if _.is_file()),
            )
@staticmethod
def register_host_event_handlers(extension: Any) -> None:
"""Register host-side event handlers for an isolated extension.
Wires ``"progress"`` events from the child to ``comfy.utils.PROGRESS_BAR_HOOK``
so the ComfyUI frontend receives progress bar updates.
"""
register_event_handler = inspect.getattr_static(
extension, "register_event_handler", None
)
if not callable(register_event_handler):
return
def _host_progress_handler(payload: dict) -> None:
import comfy.utils
hook = comfy.utils.PROGRESS_BAR_HOOK
if hook is not None:
hook(
payload.get("value", 0),
payload.get("total", 0),
payload.get("preview"),
payload.get("node_id"),
)
extension.register_event_handler("progress", _host_progress_handler)
    def setup_child_event_hooks(self, extension: Any) -> None:
        """Wire PROGRESS_BAR_HOOK in the child to emit_event on the extension.

        Host-coupled workers only; sealed workers do not have comfy.utils (torch).

        Args:
            extension: The extension object whose ``emit_event`` carries
                progress payloads back to the host.
        """
        is_child = os.environ.get("PYISOLATE_CHILD") == "1"
        logger.info("][ ISO:setup_child_event_hooks called, PYISOLATE_CHILD=%s", is_child)
        if not is_child:
            # Host process keeps its own PROGRESS_BAR_HOOK untouched.
            return
        if not _IMPORT_TORCH:
            logger.info("][ ISO:setup_child_event_hooks skipped — sealed worker (no torch)")
            return
        import comfy.utils

        def _event_progress_hook(value, total, preview=None, node_id=None):
            # Forward each progress tick over the extension's event channel.
            # NOTE(review): `preview` is accepted but not forwarded in the
            # payload — confirm this is intentional.
            logger.debug("][ ISO:event_progress value=%s/%s node_id=%s", value, total, node_id)
            extension.emit_event("progress", {
                "value": value,
                "total": total,
                "node_id": node_id,
            })

        comfy.utils.PROGRESS_BAR_HOOK = _event_progress_hook
        logger.info("][ ISO:PROGRESS_BAR_HOOK wired to event channel")
def provide_rpc_services(self) -> List[type[ProxiedSingleton]]:
return [
PromptServerService,
# Always available — no torch/PIL dependency
services: List[type[ProxiedSingleton]] = [
FolderPathsProxy,
ModelManagementProxy,
UtilsProxy,
ProgressProxy,
VAERegistry,
CLIPRegistry,
ModelPatcherRegistry,
ModelSamplingRegistry,
FirstStageModelRegistry,
HelperProxiesService,
WebDirectoryProxy,
]
# Torch/PIL-dependent proxies
if _HAS_TORCH_PROXIES:
services.extend([
PromptServerService,
ModelManagementProxy,
UtilsProxy,
ProgressProxy,
VAERegistry,
CLIPRegistry,
ModelPatcherRegistry,
ModelSamplingRegistry,
FirstStageModelRegistry,
])
return services
def handle_api_registration(self, api: ProxiedSingleton, rpc: AsyncRPC) -> None:
# Resolve the real name whether it's an instance or the Singleton class itself
@ -525,33 +837,45 @@ class ComfyUIAdapter(IsolationAdapter):
# Fence: isolated children get writable temp inside sandbox
if os.environ.get("PYISOLATE_CHILD") == "1":
_child_temp = os.path.join("/tmp", "comfyui_temp")
import tempfile
_child_temp = os.path.join(tempfile.gettempdir(), "comfyui_temp")
os.makedirs(_child_temp, exist_ok=True)
folder_paths.temp_directory = _child_temp
return
if api_name == "ModelManagementProxy":
import comfy.model_management
if _IMPORT_TORCH:
import comfy.model_management
instance = api() if isinstance(api, type) else api
# Replace module-level functions with proxy methods
for name in dir(instance):
if not name.startswith("_"):
setattr(comfy.model_management, name, getattr(instance, name))
instance = api() if isinstance(api, type) else api
# Replace module-level functions with proxy methods
for name in dir(instance):
if not name.startswith("_"):
setattr(comfy.model_management, name, getattr(instance, name))
return
if api_name == "UtilsProxy":
if not _IMPORT_TORCH:
logger.info("][ ISO:UtilsProxy handle_api_registration skipped — sealed worker (no torch)")
return
import comfy.utils
# Static Injection of RPC mechanism to ensure Child can access it
# independent of instance lifecycle.
api.set_rpc(rpc)
# Don't overwrite host hook (infinite recursion)
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
logger.info("][ ISO:UtilsProxy handle_api_registration PYISOLATE_CHILD=%s", is_child)
# Progress hook wiring moved to setup_child_event_hooks via event channel
return
if api_name == "PromptServerProxy":
if not _IMPORT_TORCH:
return
# Defer heavy import to child context
import server

View File

@ -11,130 +11,90 @@ def is_child_process() -> bool:
def initialize_child_process() -> None:
_setup_child_loop_bridge()
# Manual RPC injection
try:
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
rpc = get_child_rpc_instance()
if rpc:
_setup_prompt_server_stub(rpc)
_setup_utils_proxy(rpc)
_setup_proxy_callers(rpc)
else:
logger.warning("Could not get child RPC instance for manual injection")
_setup_prompt_server_stub()
_setup_utils_proxy()
_setup_proxy_callers()
except Exception as e:
logger.error(f"Manual RPC Injection failed: {e}")
_setup_prompt_server_stub()
_setup_utils_proxy()
_setup_proxy_callers()
_setup_logging()
def _setup_child_loop_bridge() -> None:
import asyncio
main_loop = None
try:
main_loop = asyncio.get_running_loop()
except RuntimeError:
try:
main_loop = asyncio.get_event_loop()
except RuntimeError:
pass
if main_loop is None:
return
try:
from .proxies.base import set_global_loop
set_global_loop(main_loop)
except ImportError:
pass
def _setup_prompt_server_stub(rpc=None) -> None:
try:
from .proxies.prompt_server_impl import PromptServerStub
import sys
import types
# Mock server module
if "server" not in sys.modules:
mock_server = types.ModuleType("server")
sys.modules["server"] = mock_server
server = sys.modules["server"]
if not hasattr(server, "PromptServer"):
class MockPromptServer:
pass
server.PromptServer = MockPromptServer
stub = PromptServerStub()
if rpc:
PromptServerStub.set_rpc(rpc)
if hasattr(stub, "set_rpc"):
stub.set_rpc(rpc)
server.PromptServer.instance = stub
elif hasattr(PromptServerStub, "clear_rpc"):
PromptServerStub.clear_rpc()
else:
PromptServerStub._rpc = None # type: ignore[attr-defined]
except Exception as e:
logger.error(f"Failed to setup PromptServerStub: {e}")
def _setup_utils_proxy(rpc=None) -> None:
def _setup_proxy_callers(rpc=None) -> None:
try:
import comfy.utils
import asyncio
from .proxies.folder_paths_proxy import FolderPathsProxy
from .proxies.helper_proxies import HelperProxiesService
from .proxies.model_management_proxy import ModelManagementProxy
from .proxies.progress_proxy import ProgressProxy
from .proxies.prompt_server_impl import PromptServerStub
from .proxies.utils_proxy import UtilsProxy
# Capture main loop during initialization (safe context)
main_loop = None
try:
main_loop = asyncio.get_running_loop()
except RuntimeError:
try:
main_loop = asyncio.get_event_loop()
except RuntimeError:
pass
if rpc is None:
FolderPathsProxy.clear_rpc()
HelperProxiesService.clear_rpc()
ModelManagementProxy.clear_rpc()
ProgressProxy.clear_rpc()
PromptServerStub.clear_rpc()
UtilsProxy.clear_rpc()
return
try:
from .proxies.base import set_global_loop
if main_loop:
set_global_loop(main_loop)
except ImportError:
pass
# Sync hook wrapper for progress updates
def sync_hook_wrapper(
value: int, total: int, preview: None = None, node_id: None = None
) -> None:
if node_id is None:
try:
from comfy_execution.utils import get_executing_context
ctx = get_executing_context()
if ctx:
node_id = ctx.node_id
else:
pass
except Exception:
pass
# Bypass blocked event loop by direct outbox injection
if rpc:
try:
# Use captured main loop if available (for threaded execution), or current loop
loop = main_loop
if loop is None:
loop = asyncio.get_event_loop()
rpc.outbox.put(
{
"kind": "call",
"object_id": "UtilsProxy",
"parent_call_id": None, # We are root here usually
"calling_loop": loop,
"future": loop.create_future(), # Dummy future
"method": "progress_bar_hook",
"args": (value, total, preview, node_id),
"kwargs": {},
}
)
except Exception as e:
logging.getLogger(__name__).error(f"Manual Inject Failed: {e}")
else:
logging.getLogger(__name__).warning(
"No RPC instance available for progress update"
)
comfy.utils.PROGRESS_BAR_HOOK = sync_hook_wrapper
FolderPathsProxy.set_rpc(rpc)
HelperProxiesService.set_rpc(rpc)
ModelManagementProxy.set_rpc(rpc)
ProgressProxy.set_rpc(rpc)
PromptServerStub.set_rpc(rpc)
UtilsProxy.set_rpc(rpc)
except Exception as e:
logger.error(f"Failed to setup UtilsProxy hook: {e}")
logger.error(f"Failed to setup child singleton proxy callers: {e}")
def _setup_logging() -> None:

View File

@ -0,0 +1,16 @@
"""Compatibility shim for the indexed serializer path."""
from __future__ import annotations
from typing import Any
def register_custom_node_serializers(_registry: Any) -> None:
"""Legacy no-op shim.
Serializer registration now lives directly in the active isolation adapter.
This module remains importable because the isolation index still references it.
"""
return None
__all__ = ["register_custom_node_serializers"]

View File

@ -8,14 +8,13 @@ import sys
import types
import platform
from pathlib import Path
from typing import Callable, Dict, List, Tuple
from typing import Any, Callable, Dict, List, Tuple
import pyisolate
from pyisolate import ExtensionManager, ExtensionManagerConfig
from packaging.requirements import InvalidRequirement, Requirement
from packaging.utils import canonicalize_name
from .extension_wrapper import ComfyNodeExtension
from .manifest_loader import is_cache_valid, load_from_cache, save_to_cache
from .host_policy import load_host_policy
@ -42,7 +41,11 @@ def _register_web_directory(extension_name: str, node_dir: Path) -> None:
web_dir_path = str(node_dir / web_dir_name)
if os.path.isdir(web_dir_path):
nodes.EXTENSION_WEB_DIRS[extension_name] = web_dir_path
logger.debug("][ Registered web dir for isolated %s: %s", extension_name, web_dir_path)
logger.debug(
"][ Registered web dir for isolated %s: %s",
extension_name,
web_dir_path,
)
return
except Exception:
pass
@ -62,15 +65,26 @@ def _register_web_directory(extension_name: str, node_dir: Path) -> None:
web_dir_path = str((node_dir / value).resolve())
if os.path.isdir(web_dir_path):
nodes.EXTENSION_WEB_DIRS[extension_name] = web_dir_path
logger.debug("][ Registered web dir for isolated %s: %s", extension_name, web_dir_path)
logger.debug(
"][ Registered web dir for isolated %s: %s",
extension_name,
web_dir_path,
)
return
except Exception:
pass
async def _stop_extension_safe(
extension: ComfyNodeExtension, extension_name: str
) -> None:
def _get_extension_type(execution_model: str) -> type[Any]:
    """Map an execution model name to its extension class.

    ``"sealed_worker"`` selects ``pyisolate.SealedNodeExtension``; every
    other value selects the host-coupled ``ComfyNodeExtension``, which is
    imported lazily here (extension_wrapper pulls in heavy host modules
    at import time).
    """
    if execution_model != "sealed_worker":
        from .extension_wrapper import ComfyNodeExtension

        return ComfyNodeExtension
    return pyisolate.SealedNodeExtension
async def _stop_extension_safe(extension: Any, extension_name: str) -> None:
try:
stop_result = extension.stop()
if inspect.isawaitable(stop_result):
@ -126,12 +140,18 @@ def _parse_cuda_wheels_config(
if raw_config is None:
return None
if not isinstance(raw_config, dict):
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels] must be a table"
)
raise ExtensionLoadError("[tool.comfy.isolation.cuda_wheels] must be a table")
index_url = raw_config.get("index_url")
if not isinstance(index_url, str) or not index_url.strip():
index_urls = raw_config.get("index_urls")
if index_urls is not None:
if not isinstance(index_urls, list) or not all(
isinstance(u, str) and u.strip() for u in index_urls
):
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels.index_urls] must be a list of non-empty strings"
)
elif not isinstance(index_url, str) or not index_url.strip():
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels.index_url] must be a non-empty string"
)
@ -187,11 +207,15 @@ def _parse_cuda_wheels_config(
)
normalized_package_map[canonical_dependency_name] = index_package_name.strip()
return {
"index_url": index_url.rstrip("/") + "/",
result: dict = {
"packages": normalized_packages,
"package_map": normalized_package_map,
}
if index_urls is not None:
result["index_urls"] = [u.rstrip("/") + "/" for u in index_urls]
else:
result["index_url"] = index_url.rstrip("/") + "/"
return result
def get_enforcement_policy() -> Dict[str, bool]:
@ -228,7 +252,7 @@ async def load_isolated_node(
node_dir: Path,
manifest_path: Path,
logger: logging.Logger,
build_stub_class: Callable[[str, Dict[str, object], ComfyNodeExtension], type],
build_stub_class: Callable[[str, Dict[str, object], Any], type],
venv_root: Path,
extension_managers: List[ExtensionManager],
) -> List[Tuple[str, str, type]]:
@ -243,6 +267,31 @@ async def load_isolated_node(
tool_config = manifest_data.get("tool", {}).get("comfy", {}).get("isolation", {})
can_isolate = tool_config.get("can_isolate", False)
share_torch = tool_config.get("share_torch", False)
package_manager = tool_config.get("package_manager", "uv")
is_conda = package_manager == "conda"
execution_model = tool_config.get("execution_model")
if execution_model is None:
execution_model = "sealed_worker" if is_conda else "host-coupled"
if "sealed_host_ro_paths" in tool_config:
raise ValueError(
"Manifest field 'sealed_host_ro_paths' is not allowed. "
"Configure [tool.comfy.host].sealed_worker_ro_import_paths in host policy."
)
# Conda-specific manifest fields
conda_channels: list[str] = (
tool_config.get("conda_channels", []) if is_conda else []
)
conda_dependencies: list[str] = (
tool_config.get("conda_dependencies", []) if is_conda else []
)
conda_platforms: list[str] = (
tool_config.get("conda_platforms", []) if is_conda else []
)
conda_python: str = (
tool_config.get("conda_python", "*") if is_conda else "*"
)
# Parse [project] dependencies
project_config = manifest_data.get("project", {})
@ -260,8 +309,6 @@ async def load_isolated_node(
if not isolated:
return []
logger.info(f"][ Loading isolated node: {extension_name}")
import folder_paths
base_paths = [Path(folder_paths.base_path), node_dir]
@ -272,8 +319,9 @@ async def load_isolated_node(
cuda_wheels = _parse_cuda_wheels_config(tool_config, dependencies)
manager_config = ExtensionManagerConfig(venv_root_path=str(venv_root))
extension_type = _get_extension_type(execution_model)
manager: ExtensionManager = pyisolate.ExtensionManager(
ComfyNodeExtension, manager_config
extension_type, manager_config
)
extension_managers.append(manager)
@ -281,15 +329,21 @@ async def load_isolated_node(
sandbox_config = {}
is_linux = platform.system() == "Linux"
if is_conda:
share_torch = False
share_cuda_ipc = False
else:
share_cuda_ipc = share_torch and is_linux
if is_linux and isolated:
sandbox_config = {
"network": host_policy["allow_network"],
"writable_paths": host_policy["writable_paths"],
"readonly_paths": host_policy["readonly_paths"],
}
share_cuda_ipc = share_torch and is_linux
extension_config = {
extension_config: dict = {
"name": extension_name,
"module_path": str(node_dir),
"isolated": True,
@ -299,17 +353,60 @@ async def load_isolated_node(
"sandbox_mode": host_policy["sandbox_mode"],
"sandbox": sandbox_config,
}
_is_sealed = execution_model == "sealed_worker"
_is_sandboxed = host_policy["sandbox_mode"] != "disabled" and is_linux
logger.info(
"][ Loading isolated node: %s (torch_share [%s], sealed [%s], sandboxed [%s])",
extension_name,
"x" if share_torch else " ",
"x" if _is_sealed else " ",
"x" if _is_sandboxed else " ",
)
if cuda_wheels is not None:
extension_config["cuda_wheels"] = cuda_wheels
# Conda-specific keys
if is_conda:
extension_config["package_manager"] = "conda"
extension_config["conda_channels"] = conda_channels
extension_config["conda_dependencies"] = conda_dependencies
extension_config["conda_python"] = conda_python
find_links = tool_config.get("find_links", [])
if find_links:
extension_config["find_links"] = find_links
if conda_platforms:
extension_config["conda_platforms"] = conda_platforms
if execution_model != "host-coupled":
extension_config["execution_model"] = execution_model
if execution_model == "sealed_worker":
policy_ro_paths = host_policy.get("sealed_worker_ro_import_paths", [])
if isinstance(policy_ro_paths, list) and policy_ro_paths:
extension_config["sealed_host_ro_paths"] = list(policy_ro_paths)
# Sealed workers keep the host RPC service inventory even when the
# child resolves no API classes locally.
extension = manager.load_extension(extension_config)
register_dummy_module(extension_name, node_dir)
# Register host-side event handlers via adapter
from .adapter import ComfyUIAdapter
ComfyUIAdapter.register_host_event_handlers(extension)
# Register web directory on the host — only when sandbox is disabled.
# In sandbox mode, serving untrusted JS to the browser is not safe.
if host_policy["sandbox_mode"] == "disabled":
_register_web_directory(extension_name, node_dir)
# Register for proxied web serving — the child's web dir may have
# content that doesn't exist on the host (e.g., pip-installed viewer
# bundles). The WebDirectoryCache will lazily fetch via RPC.
from .proxies.web_directory_proxy import WebDirectoryProxy, get_web_directory_cache
cache = get_web_directory_cache()
cache.register_proxy(extension_name, WebDirectoryProxy())
# Try cache first (lazy spawn)
if is_cache_valid(node_dir, manifest_path, venv_root):
cached_data = load_from_cache(node_dir, venv_root)
@ -379,6 +476,10 @@ async def load_isolated_node(
save_to_cache(node_dir, venv_root, cache_data, manifest_path)
logger.debug(f"][ {extension_name} metadata cached")
# Re-check web directory AFTER child has populated it
if host_policy["sandbox_mode"] == "disabled":
_register_web_directory(extension_name, node_dir)
# EJECT: Kill process after getting metadata (will respawn on first execution)
await _stop_extension_safe(extension, extension_name)

View File

@ -37,6 +37,88 @@ _PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
logger = logging.getLogger(__name__)
def _run_prestartup_web_copy(module: Any, module_dir: str, web_dir_path: str) -> None:
    """Run the web asset copy step that prestartup_script.py used to do.

    If the module's web/ directory is empty and the module had a
    prestartup_script.py that copied assets from pip packages, this
    function replicates that work inside the child process.

    Generic pattern: reads _PRESTARTUP_WEB_COPY from the module if
    defined, otherwise falls back to detecting common asset packages.

    Args:
        module: Imported node-pack module, inspected for an optional
            ``_PRESTARTUP_WEB_COPY`` callable hook.
        module_dir: Node-pack directory on disk (used for log messages only).
        web_dir_path: Absolute path of the web asset directory to populate.
    """
    import shutil

    def _dir_has_entries(path: str) -> bool:
        # os.scandir() returns a closeable iterator; use it as a context
        # manager so the directory handle is released deterministically
        # instead of leaking until GC (fixes a ResourceWarning/fd leak).
        with os.scandir(path) as entries:
            return any(entries)

    # Already populated — nothing to do
    if os.path.isdir(web_dir_path) and _dir_has_entries(web_dir_path):
        return

    os.makedirs(web_dir_path, exist_ok=True)

    # Try module-defined copy spec first (generic hook for any node pack).
    # callable(None) is False, so no separate None check is needed.
    copy_spec = getattr(module, "_PRESTARTUP_WEB_COPY", None)
    if callable(copy_spec):
        try:
            copy_spec(web_dir_path)
            logger.info(
                "%s Ran _PRESTARTUP_WEB_COPY for %s", LOG_PREFIX, module_dir
            )
            return
        except Exception as e:
            # Fall through to the package-detection fallbacks below.
            logger.warning(
                "%s _PRESTARTUP_WEB_COPY failed for %s: %s",
                LOG_PREFIX, module_dir, e,
            )

    # Fallback: detect comfy_3d_viewers and run copy_viewer()
    try:
        from comfy_3d_viewers import copy_viewer, VIEWER_FILES

        viewers = list(VIEWER_FILES.keys())
        for viewer in viewers:
            try:
                copy_viewer(viewer, web_dir_path)
            except Exception:
                # Best-effort: one failing viewer shouldn't abort the rest.
                pass
        if _dir_has_entries(web_dir_path):
            logger.info(
                "%s Copied %d viewer types from comfy_3d_viewers to %s",
                LOG_PREFIX, len(viewers), web_dir_path,
            )
    except ImportError:
        pass

    # Fallback: detect comfy_dynamic_widgets
    try:
        from comfy_dynamic_widgets import get_js_path

        src = os.path.realpath(get_js_path())
        if os.path.exists(src):
            dst_dir = os.path.join(web_dir_path, "js")
            os.makedirs(dst_dir, exist_ok=True)
            dst = os.path.join(dst_dir, "dynamic_widgets.js")
            shutil.copy2(src, dst)
    except ImportError:
        pass
def _read_extension_name(module_dir: str) -> str:
"""Read extension name from pyproject.toml, falling back to directory name."""
pyproject = os.path.join(module_dir, "pyproject.toml")
if os.path.exists(pyproject):
try:
import tomllib
except ImportError:
import tomli as tomllib # type: ignore[no-redef]
try:
with open(pyproject, "rb") as f:
data = tomllib.load(f)
name = data.get("project", {}).get("name")
if name:
return name
except Exception:
pass
return os.path.basename(module_dir)
def _flush_tensor_transport_state(marker: str) -> int:
try:
from pyisolate import flush_tensor_keeper # type: ignore[attr-defined]
@ -130,6 +212,20 @@ class ComfyNodeExtension(ExtensionBase):
self.node_classes = getattr(module, "NODE_CLASS_MAPPINGS", {}) or {}
self.display_names = getattr(module, "NODE_DISPLAY_NAME_MAPPINGS", {}) or {}
# Register web directory with WebDirectoryProxy (child-side)
web_dir_attr = getattr(module, "WEB_DIRECTORY", None)
if web_dir_attr is not None:
module_dir = os.path.dirname(os.path.abspath(module.__file__))
web_dir_path = os.path.abspath(os.path.join(module_dir, web_dir_attr))
ext_name = _read_extension_name(module_dir)
# If web dir is empty, run the copy step that prestartup_script.py did
_run_prestartup_web_copy(module, module_dir, web_dir_path)
if os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path)):
from comfy.isolation.proxies.web_directory_proxy import WebDirectoryProxy
WebDirectoryProxy.register_web_dir(ext_name, web_dir_path)
try:
from comfy_api.latest import ComfyExtension
@ -365,6 +461,12 @@ class ComfyNodeExtension(ExtensionBase):
for key, value in hidden_found.items():
setattr(node_cls.hidden, key.value.lower(), value)
# INPUT_IS_LIST: ComfyUI's executor passes all inputs as lists when this
# flag is set. The isolation RPC delivers unwrapped values, so we must
# wrap each input in a single-element list to match the contract.
if getattr(node_cls, "INPUT_IS_LIST", False):
resolved_inputs = {k: [v] for k, v in resolved_inputs.items()}
function_name = getattr(node_cls, "FUNCTION", "execute")
if not hasattr(instance, function_name):
raise AttributeError(f"Node {node_name} missing callable '{function_name}'")
@ -402,7 +504,7 @@ class ComfyNodeExtension(ExtensionBase):
"args": self._wrap_unpicklable_objects(result.args),
}
if result.ui is not None:
node_output_dict["ui"] = result.ui
node_output_dict["ui"] = self._wrap_unpicklable_objects(result.ui)
if getattr(result, "expand", None) is not None:
node_output_dict["expand"] = result.expand
if getattr(result, "block_execution", None) is not None:
@ -454,6 +556,85 @@ class ComfyNodeExtension(ExtensionBase):
return self.remote_objects[object_id]
def _store_remote_object_handle(self, obj: Any) -> RemoteObjectHandle:
    """Register *obj* in the child-side registry and return an RPC handle.

    The handle carries a fresh UUID key plus the object's type name so the
    host side can choose an appropriate proxy class.
    """
    handle_id = str(uuid.uuid4())
    self.remote_objects[handle_id] = obj
    return RemoteObjectHandle(handle_id, type(obj).__name__)
async def call_remote_object_method(
    self,
    object_id: str,
    method_name: str,
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Invoke a method or attribute-backed accessor on a child-owned object.

    Resolves *object_id* via ``get_remote_object`` and dispatches on
    *method_name*, in order:

    * named accessors (``get_model_options``, ``get_patches``, …) that
      read or write attributes on the object directly;
    * ``inner_model_*`` / latent helpers that forward to ``obj.model``,
      unpacking ``args[0]``/``args[1]`` as (positional, keyword) arguments;
    * a generic ``get_<attr>`` fallback returning ``obj.<attr>`` when the
      attribute exists;
    * otherwise the attribute itself — called (and awaited when the result
      is awaitable) if callable, or returned as a plain value when no
      args/kwargs were supplied.

    Objects that should not cross the RPC boundary directly (``parent``,
    any result whose type name is ``ModelPatcher``) are re-registered and
    returned as ``RemoteObjectHandle`` instead.

    Raises:
        TypeError: when args or kwargs are supplied for a non-callable
            attribute.
    """
    obj = await self.get_remote_object(object_id)
    # --- Named attribute accessors used by host-side proxies ---
    if method_name == "get_patcher_attr":
        # Generic attribute read: args[0] names the attribute.
        return getattr(obj, args[0])
    if method_name == "get_model_options":
        return getattr(obj, "model_options")
    if method_name == "set_model_options":
        setattr(obj, "model_options", args[0])
        return None
    if method_name == "get_object_patches":
        return getattr(obj, "object_patches")
    if method_name == "get_patches":
        return getattr(obj, "patches")
    if method_name == "get_wrappers":
        return getattr(obj, "wrappers")
    if method_name == "get_callbacks":
        return getattr(obj, "callbacks")
    if method_name == "get_load_device":
        return getattr(obj, "load_device")
    if method_name == "get_offload_device":
        return getattr(obj, "offload_device")
    if method_name == "get_hook_mode":
        return getattr(obj, "hook_mode")
    if method_name == "get_parent":
        # The parent stays child-owned: return a handle, not the object.
        parent = getattr(obj, "parent", None)
        if parent is None:
            return None
        return self._store_remote_object_handle(parent)
    if method_name == "get_inner_model_attr":
        # Prefer the wrapped inner model's attribute, then the object itself.
        attr_name = args[0]
        if hasattr(obj.model, attr_name):
            return getattr(obj.model, attr_name)
        if hasattr(obj, attr_name):
            return getattr(obj, attr_name)
        return None
    # --- Inner-model forwards: args[0]/args[1] carry (positional, keyword) ---
    if method_name == "inner_model_apply_model":
        return obj.model.apply_model(*args[0], **args[1])
    if method_name == "inner_model_extra_conds_shapes":
        return obj.model.extra_conds_shapes(*args[0], **args[1])
    if method_name == "inner_model_extra_conds":
        return obj.model.extra_conds(*args[0], **args[1])
    if method_name == "inner_model_memory_required":
        return obj.model.memory_required(*args[0], **args[1])
    if method_name == "process_latent_in":
        return obj.model.process_latent_in(*args[0], **args[1])
    if method_name == "process_latent_out":
        return obj.model.process_latent_out(*args[0], **args[1])
    if method_name == "scale_latent_inpaint":
        return obj.model.scale_latent_inpaint(*args[0], **args[1])
    # Generic accessor: get_<attr> reads obj.<attr> when it exists.
    if method_name.startswith("get_"):
        attr_name = method_name[4:]
        if hasattr(obj, attr_name):
            return getattr(obj, attr_name)
    # Fallback: resolve the attribute itself; invoke it when callable.
    target = getattr(obj, method_name)
    if callable(target):
        result = target(*args, **kwargs)
        if inspect.isawaitable(result):
            result = await result
        # ModelPatcher results must cross the RPC boundary as handles.
        if type(result).__name__ == "ModelPatcher":
            return self._store_remote_object_handle(result)
        return result
    if args or kwargs:
        raise TypeError(f"{method_name} is not callable on remote object {object_id}")
    return target
def _wrap_unpicklable_objects(self, data: Any) -> Any:
if isinstance(data, (str, int, float, bool, type(None))):
return data
@ -514,9 +695,7 @@ class ComfyNodeExtension(ExtensionBase):
if serializer:
return serializer(data)
object_id = str(uuid.uuid4())
self.remote_objects[object_id] = data
return RemoteObjectHandle(object_id, type(data).__name__)
return self._store_remote_object_handle(data)
def _resolve_remote_objects(self, data: Any) -> Any:
if isinstance(data, RemoteObjectHandle):

View File

@ -12,15 +12,19 @@ def initialize_host_process() -> None:
root.addHandler(logging.NullHandler())
from .proxies.folder_paths_proxy import FolderPathsProxy
from .proxies.helper_proxies import HelperProxiesService
from .proxies.model_management_proxy import ModelManagementProxy
from .proxies.progress_proxy import ProgressProxy
from .proxies.prompt_server_impl import PromptServerService
from .proxies.utils_proxy import UtilsProxy
from .proxies.web_directory_proxy import WebDirectoryProxy
from .vae_proxy import VAERegistry
FolderPathsProxy()
HelperProxiesService()
ModelManagementProxy()
ProgressProxy()
PromptServerService()
UtilsProxy()
WebDirectoryProxy()
VAERegistry()

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import logging
import os
from pathlib import Path
from pathlib import PurePosixPath
from typing import Dict, List, TypedDict
try:
@ -15,6 +16,7 @@ logger = logging.getLogger(__name__)
HOST_POLICY_PATH_ENV = "COMFY_HOST_POLICY_PATH"
VALID_SANDBOX_MODES = frozenset({"required", "disabled"})
FORBIDDEN_WRITABLE_PATHS = frozenset({"/tmp"})
class HostSecurityPolicy(TypedDict):
@ -22,14 +24,16 @@ class HostSecurityPolicy(TypedDict):
allow_network: bool
writable_paths: List[str]
readonly_paths: List[str]
sealed_worker_ro_import_paths: List[str]
whitelist: Dict[str, str]
DEFAULT_POLICY: HostSecurityPolicy = {
"sandbox_mode": "required",
"allow_network": False,
"writable_paths": ["/dev/shm", "/tmp"],
"writable_paths": ["/dev/shm"],
"readonly_paths": [],
"sealed_worker_ro_import_paths": [],
"whitelist": {},
}
@ -40,10 +44,68 @@ def _default_policy() -> HostSecurityPolicy:
"allow_network": DEFAULT_POLICY["allow_network"],
"writable_paths": list(DEFAULT_POLICY["writable_paths"]),
"readonly_paths": list(DEFAULT_POLICY["readonly_paths"]),
"sealed_worker_ro_import_paths": list(DEFAULT_POLICY["sealed_worker_ro_import_paths"]),
"whitelist": dict(DEFAULT_POLICY["whitelist"]),
}
def _normalize_writable_paths(paths: list[object]) -> list[str]:
    """Normalize configured writable paths, filtering out forbidden entries.

    Host-policy paths are contract-style POSIX paths; keeping the
    representation stable across Windows/Linux keeps tests and config
    behavior consistent. Entries found in FORBIDDEN_WRITABLE_PATHS are
    silently dropped.
    """

    def _to_posix(raw: object) -> str:
        return str(PurePosixPath(str(raw).replace("\\", "/")))

    candidates = [_to_posix(raw) for raw in paths]
    return [path for path in candidates if path not in FORBIDDEN_WRITABLE_PATHS]
def _load_whitelist_file(file_path: Path, config_path: Path) -> Dict[str, str]:
    """Load a node whitelist file into a ``{name: "*"}`` mapping.

    Relative paths resolve against the directory containing the policy
    config file. Blank lines and ``#`` comments are ignored; a missing
    file logs a warning and yields an empty mapping.
    """
    resolved = file_path if file_path.is_absolute() else config_path.parent / file_path
    if not resolved.exists():
        logger.warning("whitelist_file %s not found, skipping.", resolved)
        return {}
    entries: Dict[str, str] = {
        stripped: "*"
        for stripped in (raw.strip() for raw in resolved.read_text().splitlines())
        if stripped and not stripped.startswith("#")
    }
    logger.debug("Loaded %d whitelist entries from %s", len(entries), resolved)
    return entries
def _normalize_sealed_worker_ro_import_paths(raw_paths: object) -> list[str]:
if not isinstance(raw_paths, list):
raise ValueError(
"tool.comfy.host.sealed_worker_ro_import_paths must be a list of absolute paths."
)
normalized_paths: list[str] = []
seen: set[str] = set()
for raw_path in raw_paths:
if not isinstance(raw_path, str) or not raw_path.strip():
raise ValueError(
"tool.comfy.host.sealed_worker_ro_import_paths entries must be non-empty strings."
)
normalized_path = str(PurePosixPath(raw_path.replace("\\", "/")))
# Accept both POSIX absolute paths (/home/...) and Windows drive-letter paths (D:/...)
is_absolute = normalized_path.startswith("/") or (
len(normalized_path) >= 3 and normalized_path[1] == ":" and normalized_path[2] == "/"
)
if not is_absolute:
raise ValueError(
"tool.comfy.host.sealed_worker_ro_import_paths entries must be absolute paths."
)
if normalized_path not in seen:
seen.add(normalized_path)
normalized_paths.append(normalized_path)
return normalized_paths
def load_host_policy(comfy_root: Path) -> HostSecurityPolicy:
config_override = os.environ.get(HOST_POLICY_PATH_ENV)
config_path = Path(config_override) if config_override else comfy_root / "pyproject.toml"
@ -86,14 +148,23 @@ def load_host_policy(comfy_root: Path) -> HostSecurityPolicy:
policy["allow_network"] = bool(tool_config["allow_network"])
if "writable_paths" in tool_config:
policy["writable_paths"] = [str(p) for p in tool_config["writable_paths"]]
policy["writable_paths"] = _normalize_writable_paths(tool_config["writable_paths"])
if "readonly_paths" in tool_config:
policy["readonly_paths"] = [str(p) for p in tool_config["readonly_paths"]]
if "sealed_worker_ro_import_paths" in tool_config:
policy["sealed_worker_ro_import_paths"] = _normalize_sealed_worker_ro_import_paths(
tool_config["sealed_worker_ro_import_paths"]
)
whitelist_file = tool_config.get("whitelist_file")
if isinstance(whitelist_file, str):
policy["whitelist"].update(_load_whitelist_file(Path(whitelist_file), config_path))
whitelist_raw = tool_config.get("whitelist")
if isinstance(whitelist_raw, dict):
policy["whitelist"] = {str(k): str(v) for k, v in whitelist_raw.items()}
policy["whitelist"].update({str(k): str(v) for k, v in whitelist_raw.items()})
logger.debug(
"Loaded Host Policy: %d whitelisted nodes, Sandbox=%s, Network=%s",

View File

@ -24,6 +24,49 @@ CACHE_SUBDIR = "cache"
CACHE_KEY_FILE = "cache_key"
CACHE_DATA_FILE = "node_info.json"
CACHE_KEY_LENGTH = 16
_NESTED_SCAN_ROOT = "packages"
_IGNORED_MANIFEST_DIRS = {".git", ".venv", "__pycache__"}
def _read_manifest(manifest_path: Path) -> dict[str, Any] | None:
try:
with manifest_path.open("rb") as f:
data = tomllib.load(f)
if isinstance(data, dict):
return data
except Exception:
return None
return None
def _is_isolation_manifest(data: dict[str, Any]) -> bool:
return (
"tool" in data
and "comfy" in data["tool"]
and "isolation" in data["tool"]["comfy"]
)
def _discover_nested_manifests(entry: Path) -> List[Tuple[Path, Path]]:
    """Scan ``<entry>/packages`` for standalone isolation node packs.

    Walks every nested pyproject.toml under the packages/ subdirectory,
    skipping anything whose path contains an ignored directory name, and
    keeps only manifests that declare [tool.comfy.isolation] with
    ``standalone = true``. Results follow sorted manifest-path order.
    """
    packages_root = entry / _NESTED_SCAN_ROOT
    # Path.is_dir() is False for a nonexistent path, so no separate
    # exists() check is needed.
    if not packages_root.is_dir():
        return []
    discovered: List[Tuple[Path, Path]] = []
    for manifest_path in sorted(packages_root.rglob("pyproject.toml")):
        pack_dir = manifest_path.parent
        if any(part in _IGNORED_MANIFEST_DIRS for part in pack_dir.parts):
            continue
        manifest = _read_manifest(manifest_path)
        if not manifest or not _is_isolation_manifest(manifest):
            continue
        if manifest["tool"]["comfy"]["isolation"].get("standalone") is True:
            discovered.append((pack_dir, manifest_path))
    return discovered
def find_manifest_directories() -> List[Tuple[Path, Path]]:
@ -45,21 +88,13 @@ def find_manifest_directories() -> List[Tuple[Path, Path]]:
if not manifest.exists():
continue
# Validate [tool.comfy.isolation] section existence
try:
with manifest.open("rb") as f:
data = tomllib.load(f)
if (
"tool" in data
and "comfy" in data["tool"]
and "isolation" in data["tool"]["comfy"]
):
manifest_dirs.append((entry, manifest))
except Exception:
data = _read_manifest(manifest)
if not data or not _is_isolation_manifest(data):
continue
manifest_dirs.append((entry, manifest))
manifest_dirs.extend(_discover_nested_manifests(entry))
return manifest_dirs

View File

@ -8,10 +8,17 @@ from pathlib import Path
from typing import Any, Dict, List, Set, TYPE_CHECKING
from .proxies.helper_proxies import restore_input_types
from comfy_api.internal import _ComfyNodeInternal
from comfy_api.latest import _io as latest_io
from .shm_forensics import scan_shm_forensics
_IMPORT_TORCH = os.environ.get("PYISOLATE_IMPORT_TORCH", "1") == "1"
_ComfyNodeInternal = object
latest_io = None
if _IMPORT_TORCH:
from comfy_api.internal import _ComfyNodeInternal
from comfy_api.latest import _io as latest_io
if TYPE_CHECKING:
from .extension_wrapper import ComfyNodeExtension
@ -19,6 +26,68 @@ LOG_PREFIX = "]["
_PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
class _RemoteObjectRegistryCaller:
def __init__(self, extension: Any) -> None:
self._extension = extension
def __getattr__(self, method_name: str) -> Any:
async def _call(instance_id: str, *args: Any, **kwargs: Any) -> Any:
return await self._extension.call_remote_object_method(
instance_id,
method_name,
*args,
**kwargs,
)
return _call
def _wrap_remote_handles_as_host_proxies(value: Any, extension: Any) -> Any:
    """Recursively replace RemoteObjectHandle values with host-side proxies.

    Handles for ModelPatcher / VAE / CLIP / ModelSampling become the matching
    proxy class, wired with an RPC caller bound to *extension*; handles of
    any other type (and all non-handle leaves) pass through unchanged.
    Dicts, lists, and tuples are rebuilt with their elements wrapped.
    """
    from pyisolate._internal.remote_handle import RemoteObjectHandle

    if isinstance(value, RemoteObjectHandle):
        proxy_cls = None
        if value.type_name == "ModelPatcher":
            from comfy.isolation.model_patcher_proxy import ModelPatcherProxy

            proxy_cls = ModelPatcherProxy
        elif value.type_name == "VAE":
            from comfy.isolation.vae_proxy import VAEProxy

            proxy_cls = VAEProxy
        elif value.type_name == "CLIP":
            from comfy.isolation.clip_proxy import CLIPProxy

            proxy_cls = CLIPProxy
        elif value.type_name == "ModelSampling":
            from comfy.isolation.model_sampling_proxy import ModelSamplingProxy

            proxy_cls = ModelSamplingProxy
        if proxy_cls is None:
            # Unknown handle types pass through for the caller to deal with.
            return value
        proxy = proxy_cls(value.object_id, manage_lifecycle=False)
        proxy._rpc_caller = _RemoteObjectRegistryCaller(extension)  # type: ignore[attr-defined]
        proxy._pyisolate_remote_handle = value  # type: ignore[attr-defined]
        return proxy

    if isinstance(value, dict):
        return {
            key: _wrap_remote_handles_as_host_proxies(item, extension)
            for key, item in value.items()
        }
    if isinstance(value, (list, tuple)):
        rebuilt = [_wrap_remote_handles_as_host_proxies(item, extension) for item in value]
        # NOTE(review): type(value)(rebuilt) assumes plain list/tuple — a
        # namedtuple subclass would not reconstruct this way; confirm inputs.
        return type(value)(rebuilt)
    return value
def _resource_snapshot() -> Dict[str, int]:
fd_count = -1
shm_sender_files = 0
@ -146,6 +215,8 @@ def build_stub_class(
running_extensions: Dict[str, "ComfyNodeExtension"],
logger: logging.Logger,
) -> type:
if latest_io is None:
raise RuntimeError("comfy_api.latest._io is required to build isolation stubs")
is_v3 = bool(info.get("is_v3", False))
function_name = "_pyisolate_execute"
restored_input_types = restore_input_types(info.get("input_types", {}))
@ -160,6 +231,13 @@ def build_stub_class(
node_unique_id = _extract_hidden_unique_id(inputs)
summary = _tensor_transport_summary(inputs)
resources = _resource_snapshot()
logger.debug(
"%s ISO:execute_start ext=%s node=%s uid=%s",
LOG_PREFIX,
extension.name,
node_name,
node_unique_id or "-",
)
logger.debug(
"%s ISO:execute_start ext=%s node=%s uid=%s tensors=%d cpu=%d cuda=%d shared_cpu=%d bytes=%d fds=%d sender_shm=%d",
LOG_PREFIX,
@ -192,7 +270,20 @@ def build_stub_class(
node_name,
node_unique_id or "-",
)
serialized = serialize_for_isolation(inputs)
# Unwrap NodeOutput-like dicts before serialization.
# OUTPUT_NODE nodes return {"ui": {...}, "result": (outputs...)}
# and the executor may pass this dict as input to downstream nodes.
unwrapped_inputs = {}
for k, v in inputs.items():
if isinstance(v, dict) and "result" in v and ("ui" in v or "__node_output__" in v):
result = v.get("result")
if isinstance(result, (tuple, list)) and len(result) > 0:
unwrapped_inputs[k] = result[0]
else:
unwrapped_inputs[k] = result
else:
unwrapped_inputs[k] = v
serialized = serialize_for_isolation(unwrapped_inputs)
logger.debug(
"%s ISO:serialize_done ext=%s node=%s uid=%s",
LOG_PREFIX,
@ -220,15 +311,32 @@ def build_stub_class(
from comfy_api.latest import io as latest_io
args_raw = result.get("args", ())
deserialized_args = await deserialize_from_isolation(args_raw, extension)
deserialized_args = _wrap_remote_handles_as_host_proxies(
deserialized_args, extension
)
deserialized_args = _detach_shared_cpu_tensors(deserialized_args)
ui_raw = result.get("ui")
deserialized_ui = None
if ui_raw is not None:
deserialized_ui = await deserialize_from_isolation(ui_raw, extension)
deserialized_ui = _wrap_remote_handles_as_host_proxies(
deserialized_ui, extension
)
deserialized_ui = _detach_shared_cpu_tensors(deserialized_ui)
scan_shm_forensics("RUNTIME:post_execute", refresh_model_context=True)
return latest_io.NodeOutput(
*deserialized_args,
ui=result.get("ui"),
ui=deserialized_ui,
expand=result.get("expand"),
block_execution=result.get("block_execution"),
)
# OUTPUT_NODE: if sealed worker returned a tuple/list whose first
# element is a {"ui": ...} dict, unwrap it for the executor.
if (isinstance(result, (tuple, list)) and len(result) == 1
and isinstance(result[0], dict) and "ui" in result[0]):
return result[0]
deserialized = await deserialize_from_isolation(result, extension)
deserialized = _wrap_remote_handles_as_host_proxies(deserialized, extension)
scan_shm_forensics("RUNTIME:post_execute", refresh_model_context=True)
return _detach_shared_cpu_tensors(deserialized)
except ImportError: