Merge pull request #13226 from pollockjj/issue_94
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run

feat: isolation layer — conda/sealed workers, model proxies, web directory, fencing
This commit is contained in:
John Pollock 2026-03-30 09:38:12 +00:00 committed by GitHub
commit 8b6fadeb71
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
63 changed files with 6896 additions and 389 deletions

View File

@ -8,11 +8,21 @@ import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Set, TYPE_CHECKING
import folder_paths
from .extension_loader import load_isolated_node
from .manifest_loader import find_manifest_directories
from .runtime_helpers import build_stub_class, get_class_types_for_extension
from .shm_forensics import scan_shm_forensics, start_shm_forensics
_IMPORT_TORCH = os.environ.get("PYISOLATE_IMPORT_TORCH", "1") == "1"
load_isolated_node = None
find_manifest_directories = None
build_stub_class = None
get_class_types_for_extension = None
scan_shm_forensics = None
start_shm_forensics = None
if _IMPORT_TORCH:
import folder_paths
from .extension_loader import load_isolated_node
from .manifest_loader import find_manifest_directories
from .runtime_helpers import build_stub_class, get_class_types_for_extension
from .shm_forensics import scan_shm_forensics, start_shm_forensics
if TYPE_CHECKING:
from pyisolate import ExtensionManager
@ -21,8 +31,9 @@ if TYPE_CHECKING:
LOG_PREFIX = "]["
isolated_node_timings: List[tuple[float, Path, int]] = []
PYISOLATE_VENV_ROOT = Path(folder_paths.base_path) / ".pyisolate_venvs"
PYISOLATE_VENV_ROOT.mkdir(parents=True, exist_ok=True)
if _IMPORT_TORCH:
PYISOLATE_VENV_ROOT = Path(folder_paths.base_path) / ".pyisolate_venvs"
PYISOLATE_VENV_ROOT.mkdir(parents=True, exist_ok=True)
logger = logging.getLogger(__name__)
_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
@ -42,7 +53,8 @@ def initialize_proxies() -> None:
from .host_hooks import initialize_host_process
initialize_host_process()
start_shm_forensics()
if start_shm_forensics is not None:
start_shm_forensics()
@dataclass(frozen=True)
@ -88,6 +100,8 @@ async def initialize_isolation_nodes() -> List[IsolatedNodeSpec]:
return []
_ISOLATION_SCAN_ATTEMPTED = True
if find_manifest_directories is None or load_isolated_node is None or build_stub_class is None:
return []
manifest_entries = find_manifest_directories()
_CLAIMED_PATHS = {entry[0].resolve() for entry in manifest_entries}
@ -167,17 +181,27 @@ def _get_class_types_for_extension(extension_name: str) -> Set[str]:
return class_types
async def notify_execution_graph(needed_class_types: Set[str]) -> None:
"""Evict running extensions not needed for current execution."""
async def notify_execution_graph(needed_class_types: Set[str], caches: list | None = None) -> None:
"""Evict running extensions not needed for current execution.
When *caches* is provided, cache entries for evicted extensions' node
class_types are invalidated to prevent stale ``RemoteObjectHandle``
references from surviving in the output cache.
"""
await wait_for_model_patcher_quiescence(
timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS,
fail_loud=True,
marker="ISO:notify_graph_wait_idle",
)
evicted_class_types: Set[str] = set()
async def _stop_extension(
ext_name: str, extension: "ComfyNodeExtension", reason: str
) -> None:
# Collect class_types BEFORE stopping so we can invalidate cache entries.
ext_class_types = _get_class_types_for_extension(ext_name)
evicted_class_types.update(ext_class_types)
logger.info("%s ISO:eject_start ext=%s reason=%s", LOG_PREFIX, ext_name, reason)
logger.debug("%s ISO:stop_start ext=%s", LOG_PREFIX, ext_name)
stop_result = extension.stop()
@ -185,9 +209,11 @@ async def notify_execution_graph(needed_class_types: Set[str]) -> None:
await stop_result
_RUNNING_EXTENSIONS.pop(ext_name, None)
logger.debug("%s ISO:stop_done ext=%s", LOG_PREFIX, ext_name)
scan_shm_forensics("ISO:stop_extension", refresh_model_context=True)
if scan_shm_forensics is not None:
scan_shm_forensics("ISO:stop_extension", refresh_model_context=True)
scan_shm_forensics("ISO:notify_graph_start", refresh_model_context=True)
if scan_shm_forensics is not None:
scan_shm_forensics("ISO:notify_graph_start", refresh_model_context=True)
isolated_class_types_in_graph = needed_class_types.intersection(
{spec.node_name for spec in _ISOLATED_NODE_SPECS}
)
@ -247,6 +273,22 @@ async def notify_execution_graph(needed_class_types: Set[str]) -> None:
"%s workflow-boundary host VRAM relief failed", LOG_PREFIX, exc_info=True
)
finally:
# Invalidate cached outputs for evicted extensions so stale
# RemoteObjectHandle references are not served from cache.
if evicted_class_types and caches:
total_invalidated = 0
for cache in caches:
if hasattr(cache, "invalidate_by_class_types"):
total_invalidated += cache.invalidate_by_class_types(
evicted_class_types
)
if total_invalidated > 0:
logger.info(
"%s ISO:cache_invalidated count=%d class_types=%s",
LOG_PREFIX,
total_invalidated,
evicted_class_types,
)
scan_shm_forensics("ISO:notify_graph_done", refresh_model_context=True)
logger.debug(
"%s ISO:notify_graph_done running=%d", LOG_PREFIX, len(_RUNNING_EXTENSIONS)

View File

@ -3,13 +3,38 @@ from __future__ import annotations
import logging
import os
import inspect
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, cast
from pyisolate.interfaces import IsolationAdapter, SerializerRegistryProtocol # type: ignore[import-untyped]
from pyisolate._internal.rpc_protocol import AsyncRPC, ProxiedSingleton # type: ignore[import-untyped]
try:
_IMPORT_TORCH = os.environ.get("PYISOLATE_IMPORT_TORCH", "1") == "1"
# Singleton proxies that do NOT transitively import torch/PIL/psutil/aiohttp.
# Safe to import in sealed workers without host framework modules.
from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy
from comfy.isolation.proxies.helper_proxies import HelperProxiesService
from comfy.isolation.proxies.web_directory_proxy import WebDirectoryProxy
# Singleton proxies that transitively import torch, PIL, or heavy host modules.
# Only available when torch/host framework is present.
CLIPProxy = None
CLIPRegistry = None
ModelPatcherProxy = None
ModelPatcherRegistry = None
ModelSamplingProxy = None
ModelSamplingRegistry = None
VAEProxy = None
VAERegistry = None
FirstStageModelRegistry = None
ModelManagementProxy = None
PromptServerService = None
ProgressProxy = None
UtilsProxy = None
_HAS_TORCH_PROXIES = False
if _IMPORT_TORCH:
from comfy.isolation.clip_proxy import CLIPProxy, CLIPRegistry
from comfy.isolation.model_patcher_proxy import (
ModelPatcherProxy,
@ -20,13 +45,11 @@ try:
ModelSamplingRegistry,
)
from comfy.isolation.vae_proxy import VAEProxy, VAERegistry, FirstStageModelRegistry
from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy
from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy
from comfy.isolation.proxies.prompt_server_impl import PromptServerService
from comfy.isolation.proxies.utils_proxy import UtilsProxy
from comfy.isolation.proxies.progress_proxy import ProgressProxy
except ImportError as exc: # Fail loud if Comfy environment is incomplete
raise ImportError(f"ComfyUI environment incomplete: {exc}")
from comfy.isolation.proxies.utils_proxy import UtilsProxy
_HAS_TORCH_PROXIES = True
logger = logging.getLogger(__name__)
@ -62,9 +85,20 @@ class ComfyUIAdapter(IsolationAdapter):
os.path.join(comfy_root, "custom_nodes"),
os.path.join(comfy_root, "comfy"),
],
"filtered_subdirs": ["comfy", "app", "comfy_execution", "utils"],
}
return None
def get_sandbox_system_paths(self) -> Optional[List[str]]:
"""Returns required application paths to mount in the sandbox."""
# By inspecting where our adapter is loaded from, we can determine the comfy root
adapter_file = inspect.getfile(self.__class__)
# adapter_file = /home/johnj/ComfyUI/comfy/isolation/adapter.py
comfy_root = os.path.dirname(os.path.dirname(os.path.dirname(adapter_file)))
if os.path.exists(comfy_root):
return [comfy_root]
return None
def setup_child_environment(self, snapshot: Dict[str, Any]) -> None:
comfy_root = snapshot.get("preferred_root")
if not comfy_root:
@ -83,6 +117,59 @@ class ComfyUIAdapter(IsolationAdapter):
logging.getLogger(pkg_name).setLevel(logging.ERROR)
def register_serializers(self, registry: SerializerRegistryProtocol) -> None:
if not _IMPORT_TORCH:
# Sealed worker without torch — register torch-free TensorValue handler
# so IMAGE/MASK/LATENT tensors arrive as numpy arrays, not raw dicts.
import numpy as np
_TORCH_DTYPE_TO_NUMPY = {
"torch.float32": np.float32,
"torch.float64": np.float64,
"torch.float16": np.float16,
"torch.bfloat16": np.float32, # numpy has no bfloat16; upcast
"torch.int32": np.int32,
"torch.int64": np.int64,
"torch.int16": np.int16,
"torch.int8": np.int8,
"torch.uint8": np.uint8,
"torch.bool": np.bool_,
}
def _deserialize_tensor_value(data: Dict[str, Any]) -> Any:
dtype_str = data["dtype"]
np_dtype = _TORCH_DTYPE_TO_NUMPY.get(dtype_str, np.float32)
shape = tuple(data["tensor_size"])
arr = np.array(data["data"], dtype=np_dtype).reshape(shape)
return arr
_NUMPY_TO_TORCH_DTYPE = {
np.float32: "torch.float32",
np.float64: "torch.float64",
np.float16: "torch.float16",
np.int32: "torch.int32",
np.int64: "torch.int64",
np.int16: "torch.int16",
np.int8: "torch.int8",
np.uint8: "torch.uint8",
np.bool_: "torch.bool",
}
def _serialize_tensor_value(obj: Any) -> Dict[str, Any]:
arr = np.asarray(obj, dtype=np.float32) if obj.dtype not in _NUMPY_TO_TORCH_DTYPE else np.asarray(obj)
dtype_str = _NUMPY_TO_TORCH_DTYPE.get(arr.dtype.type, "torch.float32")
return {
"__type__": "TensorValue",
"dtype": dtype_str,
"tensor_size": list(arr.shape),
"requires_grad": False,
"data": arr.tolist(),
}
registry.register("TensorValue", _serialize_tensor_value, _deserialize_tensor_value, data_type=True)
# ndarray output from sealed workers serializes as TensorValue for host torch reconstruction
registry.register("ndarray", _serialize_tensor_value, _deserialize_tensor_value, data_type=True)
return
import torch
def serialize_device(obj: Any) -> Dict[str, Any]:
@ -110,6 +197,54 @@ class ComfyUIAdapter(IsolationAdapter):
registry.register("dtype", serialize_dtype, deserialize_dtype)
from comfy_api.latest._io import FolderType
from comfy_api.latest._ui import SavedImages, SavedResult
def serialize_saved_result(obj: Any) -> Dict[str, Any]:
return {
"__type__": "SavedResult",
"filename": obj.filename,
"subfolder": obj.subfolder,
"folder_type": obj.type.value,
}
def deserialize_saved_result(data: Dict[str, Any]) -> Any:
if isinstance(data, SavedResult):
return data
folder_type = data["folder_type"] if "folder_type" in data else data["type"]
return SavedResult(
filename=data["filename"],
subfolder=data["subfolder"],
type=FolderType(folder_type),
)
registry.register(
"SavedResult",
serialize_saved_result,
deserialize_saved_result,
data_type=True,
)
def serialize_saved_images(obj: Any) -> Dict[str, Any]:
return {
"__type__": "SavedImages",
"results": [serialize_saved_result(result) for result in obj.results],
"is_animated": obj.is_animated,
}
def deserialize_saved_images(data: Dict[str, Any]) -> Any:
return SavedImages(
results=[deserialize_saved_result(result) for result in data["results"]],
is_animated=data.get("is_animated", False),
)
registry.register(
"SavedImages",
serialize_saved_images,
deserialize_saved_images,
data_type=True,
)
def serialize_model_patcher(obj: Any) -> Dict[str, Any]:
# Child-side: must already have _instance_id (proxy)
if os.environ.get("PYISOLATE_CHILD") == "1":
@ -212,28 +347,93 @@ class ComfyUIAdapter(IsolationAdapter):
# copyreg removed - no pickle fallback allowed
def serialize_model_sampling(obj: Any) -> Dict[str, Any]:
# Child-side: must already have _instance_id (proxy)
if os.environ.get("PYISOLATE_CHILD") == "1":
if hasattr(obj, "_instance_id"):
return {"__type__": "ModelSamplingRef", "ms_id": obj._instance_id}
raise RuntimeError(
f"ModelSampling in child lacks _instance_id: "
f"{type(obj).__module__}.{type(obj).__name__}"
)
# Host-side pass-through for proxies: do not re-register a proxy as a
# new ModelSamplingRef, or we create proxy-of-proxy indirection.
# Proxy with _instance_id — return ref (works from both host and child)
if hasattr(obj, "_instance_id"):
return {"__type__": "ModelSamplingRef", "ms_id": obj._instance_id}
# Child-side: object created locally in child (e.g. ModelSamplingAdvanced
# in nodes_z_image_turbo.py). Serialize as inline data so the host can
# reconstruct the real torch.nn.Module.
if os.environ.get("PYISOLATE_CHILD") == "1":
import base64
import io as _io
# Identify base classes from comfy.model_sampling
bases = []
for base in type(obj).__mro__:
if base.__module__ == "comfy.model_sampling" and base.__name__ != "object":
bases.append(base.__name__)
# Serialize state_dict as base64 safetensors-like
sd = obj.state_dict()
sd_serialized = {}
for k, v in sd.items():
buf = _io.BytesIO()
torch.save(v, buf)
sd_serialized[k] = base64.b64encode(buf.getvalue()).decode("ascii")
# Capture plain attrs (shift, multiplier, sigma_data, etc.)
plain_attrs = {}
for k, v in obj.__dict__.items():
if k.startswith("_"):
continue
if isinstance(v, (bool, int, float, str)):
plain_attrs[k] = v
return {
"__type__": "ModelSamplingInline",
"bases": bases,
"state_dict": sd_serialized,
"attrs": plain_attrs,
}
# Host-side: register with ModelSamplingRegistry and return JSON-safe dict
ms_id = ModelSamplingRegistry().register(obj)
return {"__type__": "ModelSamplingRef", "ms_id": ms_id}
def deserialize_model_sampling(data: Any) -> Any:
"""Deserialize ModelSampling refs; pass through already-materialized objects."""
"""Deserialize ModelSampling refs or inline data."""
if isinstance(data, dict):
if data.get("__type__") == "ModelSamplingInline":
return _reconstruct_model_sampling_inline(data)
return ModelSamplingProxy(data["ms_id"])
return data
def _reconstruct_model_sampling_inline(data: Dict[str, Any]) -> Any:
"""Reconstruct a ModelSampling object on the host from inline child data."""
import comfy.model_sampling as _ms
import base64
import io as _io
# Resolve base classes
base_classes = []
for name in data["bases"]:
cls = getattr(_ms, name, None)
if cls is not None:
base_classes.append(cls)
if not base_classes:
raise RuntimeError(
f"Cannot reconstruct ModelSampling: no known bases in {data['bases']}"
)
# Create dynamic class matching the child's class hierarchy
ReconstructedSampling = type("ReconstructedSampling", tuple(base_classes), {})
obj = ReconstructedSampling.__new__(ReconstructedSampling)
torch.nn.Module.__init__(obj)
# Restore plain attributes first
for k, v in data.get("attrs", {}).items():
setattr(obj, k, v)
# Restore state_dict (buffers like sigmas)
for k, v_b64 in data.get("state_dict", {}).items():
buf = _io.BytesIO(base64.b64decode(v_b64))
tensor = torch.load(buf, weights_only=True)
# Register as buffer so it's part of state_dict
parts = k.split(".")
if len(parts) == 1:
cast(Any, obj).register_buffer(parts[0], tensor) # pylint: disable=no-member
else:
setattr(obj, parts[0], tensor)
# Register on host so future references use proxy pattern.
# Skip in child process — register() is async RPC and cannot be
# called synchronously during deserialization.
if os.environ.get("PYISOLATE_CHILD") != "1":
ModelSamplingRegistry().register(obj)
return obj
def deserialize_model_sampling_ref(data: Dict[str, Any]) -> Any:
"""Context-aware ModelSamplingRef deserializer for both host and child."""
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
@ -262,6 +462,10 @@ class ComfyUIAdapter(IsolationAdapter):
)
# Register ModelSamplingRef for deserialization (context-aware: host or child)
registry.register("ModelSamplingRef", None, deserialize_model_sampling_ref)
# Register ModelSamplingInline for deserialization (child→host inline transfer)
registry.register(
"ModelSamplingInline", None, lambda data: _reconstruct_model_sampling_inline(data)
)
def serialize_cond(obj: Any) -> Dict[str, Any]:
type_key = f"{type(obj).__module__}.{type(obj).__name__}"
@ -393,52 +597,20 @@ class ComfyUIAdapter(IsolationAdapter):
# Fallback for non-numeric arrays (strings, objects, mixes)
return obj.tolist()
registry.register("ndarray", serialize_numpy, None)
def serialize_ply(obj: Any) -> Dict[str, Any]:
def deserialize_numpy_b64(data: Any) -> Any:
"""Deserialize base64-encoded ndarray from sealed worker."""
import base64
import torch
if obj.raw_data is not None:
return {
"__type__": "PLY",
"raw_data": base64.b64encode(obj.raw_data).decode("ascii"),
}
result: Dict[str, Any] = {"__type__": "PLY", "points": torch.from_numpy(obj.points)}
if obj.colors is not None:
result["colors"] = torch.from_numpy(obj.colors)
if obj.confidence is not None:
result["confidence"] = torch.from_numpy(obj.confidence)
if obj.view_id is not None:
result["view_id"] = torch.from_numpy(obj.view_id)
return result
import numpy as np
if isinstance(data, dict) and "data" in data and "dtype" in data:
raw = base64.b64decode(data["data"])
arr = np.frombuffer(raw, dtype=np.dtype(data["dtype"])).reshape(data["shape"])
return torch.from_numpy(arr.copy())
return data
def deserialize_ply(data: Any) -> Any:
import base64
from comfy_api.latest._util.ply_types import PLY
if "raw_data" in data:
return PLY(raw_data=base64.b64decode(data["raw_data"]))
return PLY(
points=data["points"],
colors=data.get("colors"),
confidence=data.get("confidence"),
view_id=data.get("view_id"),
)
registry.register("ndarray", serialize_numpy, deserialize_numpy_b64)
registry.register("PLY", serialize_ply, deserialize_ply, data_type=True)
def serialize_npz(obj: Any) -> Dict[str, Any]:
import base64
return {
"__type__": "NPZ",
"frames": [base64.b64encode(f).decode("ascii") for f in obj.frames],
}
def deserialize_npz(data: Any) -> Any:
import base64
from comfy_api.latest._util.npz_types import NPZ
return NPZ(frames=[base64.b64decode(f) for f in data["frames"]])
registry.register("NPZ", serialize_npz, deserialize_npz, data_type=True)
# -- File3D (comfy_api.latest._util.geometry_types) ---------------------
# Origin: comfy_api by ComfyOrg (Alexander Piskun), PR #12129
def serialize_file3d(obj: Any) -> Dict[str, Any]:
import base64
@ -456,6 +628,9 @@ class ComfyUIAdapter(IsolationAdapter):
registry.register("File3D", serialize_file3d, deserialize_file3d, data_type=True)
# -- VIDEO (comfy_api.latest._input_impl.video_types) -------------------
# Origin: ComfyAPI Core v0.0.2 by ComfyOrg (guill), PR #8962
def serialize_video(obj: Any) -> Dict[str, Any]:
components = obj.get_components()
images = components.images.detach() if components.images.requires_grad else components.images
@ -494,19 +669,156 @@ class ComfyUIAdapter(IsolationAdapter):
registry.register("VideoFromFile", serialize_video, deserialize_video, data_type=True)
registry.register("VideoFromComponents", serialize_video, deserialize_video, data_type=True)
def setup_web_directory(self, module: Any) -> None:
"""Detect WEB_DIRECTORY on a module and populate/register it.
Called by the sealed worker after loading the node module.
Mirrors extension_wrapper.py:216-227 for host-coupled nodes.
Does NOT import extension_wrapper.py (it has `import torch` at module level).
"""
import shutil
web_dir_attr = getattr(module, "WEB_DIRECTORY", None)
if web_dir_attr is None:
return
module_dir = os.path.dirname(os.path.abspath(module.__file__))
web_dir_path = os.path.abspath(os.path.join(module_dir, web_dir_attr))
# Read extension name from pyproject.toml
ext_name = os.path.basename(module_dir)
pyproject = os.path.join(module_dir, "pyproject.toml")
if os.path.exists(pyproject):
try:
import tomllib
except ImportError:
import tomli as tomllib # type: ignore[no-redef]
try:
with open(pyproject, "rb") as f:
data = tomllib.load(f)
name = data.get("project", {}).get("name")
if name:
ext_name = name
except Exception:
pass
# Populate web dir if empty (mirrors _run_prestartup_web_copy)
if not (os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path))):
os.makedirs(web_dir_path, exist_ok=True)
# Module-defined copy spec
copy_spec = getattr(module, "_PRESTARTUP_WEB_COPY", None)
if copy_spec is not None and callable(copy_spec):
try:
copy_spec(web_dir_path)
except Exception as e:
logger.warning("][ _PRESTARTUP_WEB_COPY failed: %s", e)
# Fallback: comfy_3d_viewers
try:
from comfy_3d_viewers import copy_viewer, VIEWER_FILES
for viewer in VIEWER_FILES:
try:
copy_viewer(viewer, web_dir_path)
except Exception:
pass
except ImportError:
pass
# Fallback: comfy_dynamic_widgets
try:
from comfy_dynamic_widgets import get_js_path
src = os.path.realpath(get_js_path())
if os.path.exists(src):
dst_dir = os.path.join(web_dir_path, "js")
os.makedirs(dst_dir, exist_ok=True)
shutil.copy2(src, os.path.join(dst_dir, "dynamic_widgets.js"))
except ImportError:
pass
if os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path)):
WebDirectoryProxy.register_web_dir(ext_name, web_dir_path)
logger.info(
"][ Adapter: registered web dir for %s (%d files)",
ext_name,
sum(1 for _ in Path(web_dir_path).rglob("*") if _.is_file()),
)
@staticmethod
def register_host_event_handlers(extension: Any) -> None:
"""Register host-side event handlers for an isolated extension.
Wires ``"progress"`` events from the child to ``comfy.utils.PROGRESS_BAR_HOOK``
so the ComfyUI frontend receives progress bar updates.
"""
register_event_handler = inspect.getattr_static(
extension, "register_event_handler", None
)
if not callable(register_event_handler):
return
def _host_progress_handler(payload: dict) -> None:
import comfy.utils
hook = comfy.utils.PROGRESS_BAR_HOOK
if hook is not None:
hook(
payload.get("value", 0),
payload.get("total", 0),
payload.get("preview"),
payload.get("node_id"),
)
extension.register_event_handler("progress", _host_progress_handler)
def setup_child_event_hooks(self, extension: Any) -> None:
"""Wire PROGRESS_BAR_HOOK in the child to emit_event on the extension.
Host-coupled only sealed workers do not have comfy.utils (torch).
"""
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
logger.info("][ ISO:setup_child_event_hooks called, PYISOLATE_CHILD=%s", is_child)
if not is_child:
return
if not _IMPORT_TORCH:
logger.info("][ ISO:setup_child_event_hooks skipped — sealed worker (no torch)")
return
import comfy.utils
def _event_progress_hook(value, total, preview=None, node_id=None):
logger.debug("][ ISO:event_progress value=%s/%s node_id=%s", value, total, node_id)
extension.emit_event("progress", {
"value": value,
"total": total,
"node_id": node_id,
})
comfy.utils.PROGRESS_BAR_HOOK = _event_progress_hook
logger.info("][ ISO:PROGRESS_BAR_HOOK wired to event channel")
def provide_rpc_services(self) -> List[type[ProxiedSingleton]]:
return [
PromptServerService,
# Always available — no torch/PIL dependency
services: List[type[ProxiedSingleton]] = [
FolderPathsProxy,
ModelManagementProxy,
UtilsProxy,
ProgressProxy,
VAERegistry,
CLIPRegistry,
ModelPatcherRegistry,
ModelSamplingRegistry,
FirstStageModelRegistry,
HelperProxiesService,
WebDirectoryProxy,
]
# Torch/PIL-dependent proxies
if _HAS_TORCH_PROXIES:
services.extend([
PromptServerService,
ModelManagementProxy,
UtilsProxy,
ProgressProxy,
VAERegistry,
CLIPRegistry,
ModelPatcherRegistry,
ModelSamplingRegistry,
FirstStageModelRegistry,
])
return services
def handle_api_registration(self, api: ProxiedSingleton, rpc: AsyncRPC) -> None:
# Resolve the real name whether it's an instance or the Singleton class itself
@ -525,33 +837,45 @@ class ComfyUIAdapter(IsolationAdapter):
# Fence: isolated children get writable temp inside sandbox
if os.environ.get("PYISOLATE_CHILD") == "1":
_child_temp = os.path.join("/tmp", "comfyui_temp")
import tempfile
_child_temp = os.path.join(tempfile.gettempdir(), "comfyui_temp")
os.makedirs(_child_temp, exist_ok=True)
folder_paths.temp_directory = _child_temp
return
if api_name == "ModelManagementProxy":
import comfy.model_management
if _IMPORT_TORCH:
import comfy.model_management
instance = api() if isinstance(api, type) else api
# Replace module-level functions with proxy methods
for name in dir(instance):
if not name.startswith("_"):
setattr(comfy.model_management, name, getattr(instance, name))
instance = api() if isinstance(api, type) else api
# Replace module-level functions with proxy methods
for name in dir(instance):
if not name.startswith("_"):
setattr(comfy.model_management, name, getattr(instance, name))
return
if api_name == "UtilsProxy":
if not _IMPORT_TORCH:
logger.info("][ ISO:UtilsProxy handle_api_registration skipped — sealed worker (no torch)")
return
import comfy.utils
# Static Injection of RPC mechanism to ensure Child can access it
# independent of instance lifecycle.
api.set_rpc(rpc)
# Don't overwrite host hook (infinite recursion)
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
logger.info("][ ISO:UtilsProxy handle_api_registration PYISOLATE_CHILD=%s", is_child)
# Progress hook wiring moved to setup_child_event_hooks via event channel
return
if api_name == "PromptServerProxy":
if not _IMPORT_TORCH:
return
# Defer heavy import to child context
import server

View File

@ -11,130 +11,90 @@ def is_child_process() -> bool:
def initialize_child_process() -> None:
_setup_child_loop_bridge()
# Manual RPC injection
try:
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
rpc = get_child_rpc_instance()
if rpc:
_setup_prompt_server_stub(rpc)
_setup_utils_proxy(rpc)
_setup_proxy_callers(rpc)
else:
logger.warning("Could not get child RPC instance for manual injection")
_setup_prompt_server_stub()
_setup_utils_proxy()
_setup_proxy_callers()
except Exception as e:
logger.error(f"Manual RPC Injection failed: {e}")
_setup_prompt_server_stub()
_setup_utils_proxy()
_setup_proxy_callers()
_setup_logging()
def _setup_child_loop_bridge() -> None:
import asyncio
main_loop = None
try:
main_loop = asyncio.get_running_loop()
except RuntimeError:
try:
main_loop = asyncio.get_event_loop()
except RuntimeError:
pass
if main_loop is None:
return
try:
from .proxies.base import set_global_loop
set_global_loop(main_loop)
except ImportError:
pass
def _setup_prompt_server_stub(rpc=None) -> None:
try:
from .proxies.prompt_server_impl import PromptServerStub
import sys
import types
# Mock server module
if "server" not in sys.modules:
mock_server = types.ModuleType("server")
sys.modules["server"] = mock_server
server = sys.modules["server"]
if not hasattr(server, "PromptServer"):
class MockPromptServer:
pass
server.PromptServer = MockPromptServer
stub = PromptServerStub()
if rpc:
PromptServerStub.set_rpc(rpc)
if hasattr(stub, "set_rpc"):
stub.set_rpc(rpc)
server.PromptServer.instance = stub
elif hasattr(PromptServerStub, "clear_rpc"):
PromptServerStub.clear_rpc()
else:
PromptServerStub._rpc = None # type: ignore[attr-defined]
except Exception as e:
logger.error(f"Failed to setup PromptServerStub: {e}")
def _setup_utils_proxy(rpc=None) -> None:
def _setup_proxy_callers(rpc=None) -> None:
try:
import comfy.utils
import asyncio
from .proxies.folder_paths_proxy import FolderPathsProxy
from .proxies.helper_proxies import HelperProxiesService
from .proxies.model_management_proxy import ModelManagementProxy
from .proxies.progress_proxy import ProgressProxy
from .proxies.prompt_server_impl import PromptServerStub
from .proxies.utils_proxy import UtilsProxy
# Capture main loop during initialization (safe context)
main_loop = None
try:
main_loop = asyncio.get_running_loop()
except RuntimeError:
try:
main_loop = asyncio.get_event_loop()
except RuntimeError:
pass
if rpc is None:
FolderPathsProxy.clear_rpc()
HelperProxiesService.clear_rpc()
ModelManagementProxy.clear_rpc()
ProgressProxy.clear_rpc()
PromptServerStub.clear_rpc()
UtilsProxy.clear_rpc()
return
try:
from .proxies.base import set_global_loop
if main_loop:
set_global_loop(main_loop)
except ImportError:
pass
# Sync hook wrapper for progress updates
def sync_hook_wrapper(
value: int, total: int, preview: None = None, node_id: None = None
) -> None:
if node_id is None:
try:
from comfy_execution.utils import get_executing_context
ctx = get_executing_context()
if ctx:
node_id = ctx.node_id
else:
pass
except Exception:
pass
# Bypass blocked event loop by direct outbox injection
if rpc:
try:
# Use captured main loop if available (for threaded execution), or current loop
loop = main_loop
if loop is None:
loop = asyncio.get_event_loop()
rpc.outbox.put(
{
"kind": "call",
"object_id": "UtilsProxy",
"parent_call_id": None, # We are root here usually
"calling_loop": loop,
"future": loop.create_future(), # Dummy future
"method": "progress_bar_hook",
"args": (value, total, preview, node_id),
"kwargs": {},
}
)
except Exception as e:
logging.getLogger(__name__).error(f"Manual Inject Failed: {e}")
else:
logging.getLogger(__name__).warning(
"No RPC instance available for progress update"
)
comfy.utils.PROGRESS_BAR_HOOK = sync_hook_wrapper
FolderPathsProxy.set_rpc(rpc)
HelperProxiesService.set_rpc(rpc)
ModelManagementProxy.set_rpc(rpc)
ProgressProxy.set_rpc(rpc)
PromptServerStub.set_rpc(rpc)
UtilsProxy.set_rpc(rpc)
except Exception as e:
logger.error(f"Failed to setup UtilsProxy hook: {e}")
logger.error(f"Failed to setup child singleton proxy callers: {e}")
def _setup_logging() -> None:

View File

@ -0,0 +1,16 @@
"""Compatibility shim for the indexed serializer path."""
from __future__ import annotations
from typing import Any
def register_custom_node_serializers(_registry: Any) -> None:
"""Legacy no-op shim.
Serializer registration now lives directly in the active isolation adapter.
This module remains importable because the isolation index still references it.
"""
return None
__all__ = ["register_custom_node_serializers"]

View File

@ -8,14 +8,13 @@ import sys
import types
import platform
from pathlib import Path
from typing import Callable, Dict, List, Tuple
from typing import Any, Callable, Dict, List, Tuple
import pyisolate
from pyisolate import ExtensionManager, ExtensionManagerConfig
from packaging.requirements import InvalidRequirement, Requirement
from packaging.utils import canonicalize_name
from .extension_wrapper import ComfyNodeExtension
from .manifest_loader import is_cache_valid, load_from_cache, save_to_cache
from .host_policy import load_host_policy
@ -42,7 +41,11 @@ def _register_web_directory(extension_name: str, node_dir: Path) -> None:
web_dir_path = str(node_dir / web_dir_name)
if os.path.isdir(web_dir_path):
nodes.EXTENSION_WEB_DIRS[extension_name] = web_dir_path
logger.debug("][ Registered web dir for isolated %s: %s", extension_name, web_dir_path)
logger.debug(
"][ Registered web dir for isolated %s: %s",
extension_name,
web_dir_path,
)
return
except Exception:
pass
@ -62,15 +65,26 @@ def _register_web_directory(extension_name: str, node_dir: Path) -> None:
web_dir_path = str((node_dir / value).resolve())
if os.path.isdir(web_dir_path):
nodes.EXTENSION_WEB_DIRS[extension_name] = web_dir_path
logger.debug("][ Registered web dir for isolated %s: %s", extension_name, web_dir_path)
logger.debug(
"][ Registered web dir for isolated %s: %s",
extension_name,
web_dir_path,
)
return
except Exception:
pass
async def _stop_extension_safe(
extension: ComfyNodeExtension, extension_name: str
) -> None:
def _get_extension_type(execution_model: str) -> type[Any]:
if execution_model == "sealed_worker":
return pyisolate.SealedNodeExtension
from .extension_wrapper import ComfyNodeExtension
return ComfyNodeExtension
async def _stop_extension_safe(extension: Any, extension_name: str) -> None:
try:
stop_result = extension.stop()
if inspect.isawaitable(stop_result):
@ -126,12 +140,18 @@ def _parse_cuda_wheels_config(
if raw_config is None:
return None
if not isinstance(raw_config, dict):
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels] must be a table"
)
raise ExtensionLoadError("[tool.comfy.isolation.cuda_wheels] must be a table")
index_url = raw_config.get("index_url")
if not isinstance(index_url, str) or not index_url.strip():
index_urls = raw_config.get("index_urls")
if index_urls is not None:
if not isinstance(index_urls, list) or not all(
isinstance(u, str) and u.strip() for u in index_urls
):
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels.index_urls] must be a list of non-empty strings"
)
elif not isinstance(index_url, str) or not index_url.strip():
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels.index_url] must be a non-empty string"
)
@ -187,11 +207,15 @@ def _parse_cuda_wheels_config(
)
normalized_package_map[canonical_dependency_name] = index_package_name.strip()
return {
"index_url": index_url.rstrip("/") + "/",
result: dict = {
"packages": normalized_packages,
"package_map": normalized_package_map,
}
if index_urls is not None:
result["index_urls"] = [u.rstrip("/") + "/" for u in index_urls]
else:
result["index_url"] = index_url.rstrip("/") + "/"
return result
def get_enforcement_policy() -> Dict[str, bool]:
@ -228,7 +252,7 @@ async def load_isolated_node(
node_dir: Path,
manifest_path: Path,
logger: logging.Logger,
build_stub_class: Callable[[str, Dict[str, object], ComfyNodeExtension], type],
build_stub_class: Callable[[str, Dict[str, object], Any], type],
venv_root: Path,
extension_managers: List[ExtensionManager],
) -> List[Tuple[str, str, type]]:
@ -243,6 +267,31 @@ async def load_isolated_node(
tool_config = manifest_data.get("tool", {}).get("comfy", {}).get("isolation", {})
can_isolate = tool_config.get("can_isolate", False)
share_torch = tool_config.get("share_torch", False)
package_manager = tool_config.get("package_manager", "uv")
is_conda = package_manager == "conda"
execution_model = tool_config.get("execution_model")
if execution_model is None:
execution_model = "sealed_worker" if is_conda else "host-coupled"
if "sealed_host_ro_paths" in tool_config:
raise ValueError(
"Manifest field 'sealed_host_ro_paths' is not allowed. "
"Configure [tool.comfy.host].sealed_worker_ro_import_paths in host policy."
)
# Conda-specific manifest fields
conda_channels: list[str] = (
tool_config.get("conda_channels", []) if is_conda else []
)
conda_dependencies: list[str] = (
tool_config.get("conda_dependencies", []) if is_conda else []
)
conda_platforms: list[str] = (
tool_config.get("conda_platforms", []) if is_conda else []
)
conda_python: str = (
tool_config.get("conda_python", "*") if is_conda else "*"
)
# Parse [project] dependencies
project_config = manifest_data.get("project", {})
@ -260,8 +309,6 @@ async def load_isolated_node(
if not isolated:
return []
logger.info(f"][ Loading isolated node: {extension_name}")
import folder_paths
base_paths = [Path(folder_paths.base_path), node_dir]
@ -272,8 +319,9 @@ async def load_isolated_node(
cuda_wheels = _parse_cuda_wheels_config(tool_config, dependencies)
manager_config = ExtensionManagerConfig(venv_root_path=str(venv_root))
extension_type = _get_extension_type(execution_model)
manager: ExtensionManager = pyisolate.ExtensionManager(
ComfyNodeExtension, manager_config
extension_type, manager_config
)
extension_managers.append(manager)
@ -281,15 +329,21 @@ async def load_isolated_node(
sandbox_config = {}
is_linux = platform.system() == "Linux"
if is_conda:
share_torch = False
share_cuda_ipc = False
else:
share_cuda_ipc = share_torch and is_linux
if is_linux and isolated:
sandbox_config = {
"network": host_policy["allow_network"],
"writable_paths": host_policy["writable_paths"],
"readonly_paths": host_policy["readonly_paths"],
}
share_cuda_ipc = share_torch and is_linux
extension_config = {
extension_config: dict = {
"name": extension_name,
"module_path": str(node_dir),
"isolated": True,
@ -299,17 +353,60 @@ async def load_isolated_node(
"sandbox_mode": host_policy["sandbox_mode"],
"sandbox": sandbox_config,
}
_is_sealed = execution_model == "sealed_worker"
_is_sandboxed = host_policy["sandbox_mode"] != "disabled" and is_linux
logger.info(
"][ Loading isolated node: %s (torch_share [%s], sealed [%s], sandboxed [%s])",
extension_name,
"x" if share_torch else " ",
"x" if _is_sealed else " ",
"x" if _is_sandboxed else " ",
)
if cuda_wheels is not None:
extension_config["cuda_wheels"] = cuda_wheels
# Conda-specific keys
if is_conda:
extension_config["package_manager"] = "conda"
extension_config["conda_channels"] = conda_channels
extension_config["conda_dependencies"] = conda_dependencies
extension_config["conda_python"] = conda_python
find_links = tool_config.get("find_links", [])
if find_links:
extension_config["find_links"] = find_links
if conda_platforms:
extension_config["conda_platforms"] = conda_platforms
if execution_model != "host-coupled":
extension_config["execution_model"] = execution_model
if execution_model == "sealed_worker":
policy_ro_paths = host_policy.get("sealed_worker_ro_import_paths", [])
if isinstance(policy_ro_paths, list) and policy_ro_paths:
extension_config["sealed_host_ro_paths"] = list(policy_ro_paths)
# Sealed workers keep the host RPC service inventory even when the
# child resolves no API classes locally.
extension = manager.load_extension(extension_config)
register_dummy_module(extension_name, node_dir)
# Register host-side event handlers via adapter
from .adapter import ComfyUIAdapter
ComfyUIAdapter.register_host_event_handlers(extension)
# Register web directory on the host — only when sandbox is disabled.
# In sandbox mode, serving untrusted JS to the browser is not safe.
if host_policy["sandbox_mode"] == "disabled":
_register_web_directory(extension_name, node_dir)
# Register for proxied web serving — the child's web dir may have
# content that doesn't exist on the host (e.g., pip-installed viewer
# bundles). The WebDirectoryCache will lazily fetch via RPC.
from .proxies.web_directory_proxy import WebDirectoryProxy, get_web_directory_cache
cache = get_web_directory_cache()
cache.register_proxy(extension_name, WebDirectoryProxy())
# Try cache first (lazy spawn)
if is_cache_valid(node_dir, manifest_path, venv_root):
cached_data = load_from_cache(node_dir, venv_root)
@ -379,6 +476,10 @@ async def load_isolated_node(
save_to_cache(node_dir, venv_root, cache_data, manifest_path)
logger.debug(f"][ {extension_name} metadata cached")
# Re-check web directory AFTER child has populated it
if host_policy["sandbox_mode"] == "disabled":
_register_web_directory(extension_name, node_dir)
# EJECT: Kill process after getting metadata (will respawn on first execution)
await _stop_extension_safe(extension, extension_name)

View File

@ -37,6 +37,88 @@ _PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
logger = logging.getLogger(__name__)
def _run_prestartup_web_copy(module: Any, module_dir: str, web_dir_path: str) -> None:
"""Run the web asset copy step that prestartup_script.py used to do.
If the module's web/ directory is empty and the module had a
prestartup_script.py that copied assets from pip packages, this
function replicates that work inside the child process.
Generic pattern: reads _PRESTARTUP_WEB_COPY from the module if
defined, otherwise falls back to detecting common asset packages.
"""
import shutil
# Already populated — nothing to do
if os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path)):
return
os.makedirs(web_dir_path, exist_ok=True)
# Try module-defined copy spec first (generic hook for any node pack)
copy_spec = getattr(module, "_PRESTARTUP_WEB_COPY", None)
if copy_spec is not None and callable(copy_spec):
try:
copy_spec(web_dir_path)
logger.info(
"%s Ran _PRESTARTUP_WEB_COPY for %s", LOG_PREFIX, module_dir
)
return
except Exception as e:
logger.warning(
"%s _PRESTARTUP_WEB_COPY failed for %s: %s",
LOG_PREFIX, module_dir, e,
)
# Fallback: detect comfy_3d_viewers and run copy_viewer()
try:
from comfy_3d_viewers import copy_viewer, VIEWER_FILES
viewers = list(VIEWER_FILES.keys())
for viewer in viewers:
try:
copy_viewer(viewer, web_dir_path)
except Exception:
pass
if any(os.scandir(web_dir_path)):
logger.info(
"%s Copied %d viewer types from comfy_3d_viewers to %s",
LOG_PREFIX, len(viewers), web_dir_path,
)
except ImportError:
pass
# Fallback: detect comfy_dynamic_widgets
try:
from comfy_dynamic_widgets import get_js_path
src = os.path.realpath(get_js_path())
if os.path.exists(src):
dst_dir = os.path.join(web_dir_path, "js")
os.makedirs(dst_dir, exist_ok=True)
dst = os.path.join(dst_dir, "dynamic_widgets.js")
shutil.copy2(src, dst)
except ImportError:
pass
def _read_extension_name(module_dir: str) -> str:
"""Read extension name from pyproject.toml, falling back to directory name."""
pyproject = os.path.join(module_dir, "pyproject.toml")
if os.path.exists(pyproject):
try:
import tomllib
except ImportError:
import tomli as tomllib # type: ignore[no-redef]
try:
with open(pyproject, "rb") as f:
data = tomllib.load(f)
name = data.get("project", {}).get("name")
if name:
return name
except Exception:
pass
return os.path.basename(module_dir)
def _flush_tensor_transport_state(marker: str) -> int:
try:
from pyisolate import flush_tensor_keeper # type: ignore[attr-defined]
@ -130,6 +212,20 @@ class ComfyNodeExtension(ExtensionBase):
self.node_classes = getattr(module, "NODE_CLASS_MAPPINGS", {}) or {}
self.display_names = getattr(module, "NODE_DISPLAY_NAME_MAPPINGS", {}) or {}
# Register web directory with WebDirectoryProxy (child-side)
web_dir_attr = getattr(module, "WEB_DIRECTORY", None)
if web_dir_attr is not None:
module_dir = os.path.dirname(os.path.abspath(module.__file__))
web_dir_path = os.path.abspath(os.path.join(module_dir, web_dir_attr))
ext_name = _read_extension_name(module_dir)
# If web dir is empty, run the copy step that prestartup_script.py did
_run_prestartup_web_copy(module, module_dir, web_dir_path)
if os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path)):
from comfy.isolation.proxies.web_directory_proxy import WebDirectoryProxy
WebDirectoryProxy.register_web_dir(ext_name, web_dir_path)
try:
from comfy_api.latest import ComfyExtension
@ -365,6 +461,12 @@ class ComfyNodeExtension(ExtensionBase):
for key, value in hidden_found.items():
setattr(node_cls.hidden, key.value.lower(), value)
# INPUT_IS_LIST: ComfyUI's executor passes all inputs as lists when this
# flag is set. The isolation RPC delivers unwrapped values, so we must
# wrap each input in a single-element list to match the contract.
if getattr(node_cls, "INPUT_IS_LIST", False):
resolved_inputs = {k: [v] for k, v in resolved_inputs.items()}
function_name = getattr(node_cls, "FUNCTION", "execute")
if not hasattr(instance, function_name):
raise AttributeError(f"Node {node_name} missing callable '{function_name}'")
@ -402,7 +504,7 @@ class ComfyNodeExtension(ExtensionBase):
"args": self._wrap_unpicklable_objects(result.args),
}
if result.ui is not None:
node_output_dict["ui"] = result.ui
node_output_dict["ui"] = self._wrap_unpicklable_objects(result.ui)
if getattr(result, "expand", None) is not None:
node_output_dict["expand"] = result.expand
if getattr(result, "block_execution", None) is not None:
@ -454,6 +556,85 @@ class ComfyNodeExtension(ExtensionBase):
return self.remote_objects[object_id]
def _store_remote_object_handle(self, obj: Any) -> RemoteObjectHandle:
object_id = str(uuid.uuid4())
self.remote_objects[object_id] = obj
return RemoteObjectHandle(object_id, type(obj).__name__)
async def call_remote_object_method(
self,
object_id: str,
method_name: str,
*args: Any,
**kwargs: Any,
) -> Any:
"""Invoke a method or attribute-backed accessor on a child-owned object."""
obj = await self.get_remote_object(object_id)
if method_name == "get_patcher_attr":
return getattr(obj, args[0])
if method_name == "get_model_options":
return getattr(obj, "model_options")
if method_name == "set_model_options":
setattr(obj, "model_options", args[0])
return None
if method_name == "get_object_patches":
return getattr(obj, "object_patches")
if method_name == "get_patches":
return getattr(obj, "patches")
if method_name == "get_wrappers":
return getattr(obj, "wrappers")
if method_name == "get_callbacks":
return getattr(obj, "callbacks")
if method_name == "get_load_device":
return getattr(obj, "load_device")
if method_name == "get_offload_device":
return getattr(obj, "offload_device")
if method_name == "get_hook_mode":
return getattr(obj, "hook_mode")
if method_name == "get_parent":
parent = getattr(obj, "parent", None)
if parent is None:
return None
return self._store_remote_object_handle(parent)
if method_name == "get_inner_model_attr":
attr_name = args[0]
if hasattr(obj.model, attr_name):
return getattr(obj.model, attr_name)
if hasattr(obj, attr_name):
return getattr(obj, attr_name)
return None
if method_name == "inner_model_apply_model":
return obj.model.apply_model(*args[0], **args[1])
if method_name == "inner_model_extra_conds_shapes":
return obj.model.extra_conds_shapes(*args[0], **args[1])
if method_name == "inner_model_extra_conds":
return obj.model.extra_conds(*args[0], **args[1])
if method_name == "inner_model_memory_required":
return obj.model.memory_required(*args[0], **args[1])
if method_name == "process_latent_in":
return obj.model.process_latent_in(*args[0], **args[1])
if method_name == "process_latent_out":
return obj.model.process_latent_out(*args[0], **args[1])
if method_name == "scale_latent_inpaint":
return obj.model.scale_latent_inpaint(*args[0], **args[1])
if method_name.startswith("get_"):
attr_name = method_name[4:]
if hasattr(obj, attr_name):
return getattr(obj, attr_name)
target = getattr(obj, method_name)
if callable(target):
result = target(*args, **kwargs)
if inspect.isawaitable(result):
result = await result
if type(result).__name__ == "ModelPatcher":
return self._store_remote_object_handle(result)
return result
if args or kwargs:
raise TypeError(f"{method_name} is not callable on remote object {object_id}")
return target
def _wrap_unpicklable_objects(self, data: Any) -> Any:
if isinstance(data, (str, int, float, bool, type(None))):
return data
@ -514,9 +695,7 @@ class ComfyNodeExtension(ExtensionBase):
if serializer:
return serializer(data)
object_id = str(uuid.uuid4())
self.remote_objects[object_id] = data
return RemoteObjectHandle(object_id, type(data).__name__)
return self._store_remote_object_handle(data)
def _resolve_remote_objects(self, data: Any) -> Any:
if isinstance(data, RemoteObjectHandle):

View File

@ -12,15 +12,19 @@ def initialize_host_process() -> None:
root.addHandler(logging.NullHandler())
from .proxies.folder_paths_proxy import FolderPathsProxy
from .proxies.helper_proxies import HelperProxiesService
from .proxies.model_management_proxy import ModelManagementProxy
from .proxies.progress_proxy import ProgressProxy
from .proxies.prompt_server_impl import PromptServerService
from .proxies.utils_proxy import UtilsProxy
from .proxies.web_directory_proxy import WebDirectoryProxy
from .vae_proxy import VAERegistry
FolderPathsProxy()
HelperProxiesService()
ModelManagementProxy()
ProgressProxy()
PromptServerService()
UtilsProxy()
WebDirectoryProxy()
VAERegistry()

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import logging
import os
from pathlib import Path
from pathlib import PurePosixPath
from typing import Dict, List, TypedDict
try:
@ -15,6 +16,7 @@ logger = logging.getLogger(__name__)
HOST_POLICY_PATH_ENV = "COMFY_HOST_POLICY_PATH"
VALID_SANDBOX_MODES = frozenset({"required", "disabled"})
FORBIDDEN_WRITABLE_PATHS = frozenset({"/tmp"})
class HostSecurityPolicy(TypedDict):
@ -22,14 +24,16 @@ class HostSecurityPolicy(TypedDict):
allow_network: bool
writable_paths: List[str]
readonly_paths: List[str]
sealed_worker_ro_import_paths: List[str]
whitelist: Dict[str, str]
DEFAULT_POLICY: HostSecurityPolicy = {
"sandbox_mode": "required",
"allow_network": False,
"writable_paths": ["/dev/shm", "/tmp"],
"writable_paths": ["/dev/shm"],
"readonly_paths": [],
"sealed_worker_ro_import_paths": [],
"whitelist": {},
}
@ -40,10 +44,68 @@ def _default_policy() -> HostSecurityPolicy:
"allow_network": DEFAULT_POLICY["allow_network"],
"writable_paths": list(DEFAULT_POLICY["writable_paths"]),
"readonly_paths": list(DEFAULT_POLICY["readonly_paths"]),
"sealed_worker_ro_import_paths": list(DEFAULT_POLICY["sealed_worker_ro_import_paths"]),
"whitelist": dict(DEFAULT_POLICY["whitelist"]),
}
def _normalize_writable_paths(paths: list[object]) -> list[str]:
normalized_paths: list[str] = []
for raw_path in paths:
# Host-policy paths are contract-style POSIX paths; keep representation
# stable across Windows/Linux so tests and config behavior stay consistent.
normalized_path = str(PurePosixPath(str(raw_path).replace("\\", "/")))
if normalized_path in FORBIDDEN_WRITABLE_PATHS:
continue
normalized_paths.append(normalized_path)
return normalized_paths
def _load_whitelist_file(file_path: Path, config_path: Path) -> Dict[str, str]:
if not file_path.is_absolute():
file_path = config_path.parent / file_path
if not file_path.exists():
logger.warning("whitelist_file %s not found, skipping.", file_path)
return {}
entries: Dict[str, str] = {}
for line in file_path.read_text().splitlines():
line = line.strip()
if not line or line.startswith("#"):
continue
entries[line] = "*"
logger.debug("Loaded %d whitelist entries from %s", len(entries), file_path)
return entries
def _normalize_sealed_worker_ro_import_paths(raw_paths: object) -> list[str]:
if not isinstance(raw_paths, list):
raise ValueError(
"tool.comfy.host.sealed_worker_ro_import_paths must be a list of absolute paths."
)
normalized_paths: list[str] = []
seen: set[str] = set()
for raw_path in raw_paths:
if not isinstance(raw_path, str) or not raw_path.strip():
raise ValueError(
"tool.comfy.host.sealed_worker_ro_import_paths entries must be non-empty strings."
)
normalized_path = str(PurePosixPath(raw_path.replace("\\", "/")))
# Accept both POSIX absolute paths (/home/...) and Windows drive-letter paths (D:/...)
is_absolute = normalized_path.startswith("/") or (
len(normalized_path) >= 3 and normalized_path[1] == ":" and normalized_path[2] == "/"
)
if not is_absolute:
raise ValueError(
"tool.comfy.host.sealed_worker_ro_import_paths entries must be absolute paths."
)
if normalized_path not in seen:
seen.add(normalized_path)
normalized_paths.append(normalized_path)
return normalized_paths
def load_host_policy(comfy_root: Path) -> HostSecurityPolicy:
config_override = os.environ.get(HOST_POLICY_PATH_ENV)
config_path = Path(config_override) if config_override else comfy_root / "pyproject.toml"
@ -86,14 +148,23 @@ def load_host_policy(comfy_root: Path) -> HostSecurityPolicy:
policy["allow_network"] = bool(tool_config["allow_network"])
if "writable_paths" in tool_config:
policy["writable_paths"] = [str(p) for p in tool_config["writable_paths"]]
policy["writable_paths"] = _normalize_writable_paths(tool_config["writable_paths"])
if "readonly_paths" in tool_config:
policy["readonly_paths"] = [str(p) for p in tool_config["readonly_paths"]]
if "sealed_worker_ro_import_paths" in tool_config:
policy["sealed_worker_ro_import_paths"] = _normalize_sealed_worker_ro_import_paths(
tool_config["sealed_worker_ro_import_paths"]
)
whitelist_file = tool_config.get("whitelist_file")
if isinstance(whitelist_file, str):
policy["whitelist"].update(_load_whitelist_file(Path(whitelist_file), config_path))
whitelist_raw = tool_config.get("whitelist")
if isinstance(whitelist_raw, dict):
policy["whitelist"] = {str(k): str(v) for k, v in whitelist_raw.items()}
policy["whitelist"].update({str(k): str(v) for k, v in whitelist_raw.items()})
logger.debug(
"Loaded Host Policy: %d whitelisted nodes, Sandbox=%s, Network=%s",

View File

@ -24,6 +24,49 @@ CACHE_SUBDIR = "cache"
CACHE_KEY_FILE = "cache_key"
CACHE_DATA_FILE = "node_info.json"
CACHE_KEY_LENGTH = 16
_NESTED_SCAN_ROOT = "packages"
_IGNORED_MANIFEST_DIRS = {".git", ".venv", "__pycache__"}
def _read_manifest(manifest_path: Path) -> dict[str, Any] | None:
try:
with manifest_path.open("rb") as f:
data = tomllib.load(f)
if isinstance(data, dict):
return data
except Exception:
return None
return None
def _is_isolation_manifest(data: dict[str, Any]) -> bool:
return (
"tool" in data
and "comfy" in data["tool"]
and "isolation" in data["tool"]["comfy"]
)
def _discover_nested_manifests(entry: Path) -> List[Tuple[Path, Path]]:
packages_root = entry / _NESTED_SCAN_ROOT
if not packages_root.exists() or not packages_root.is_dir():
return []
nested: List[Tuple[Path, Path]] = []
for manifest in sorted(packages_root.rglob("pyproject.toml")):
node_dir = manifest.parent
if any(part in _IGNORED_MANIFEST_DIRS for part in node_dir.parts):
continue
data = _read_manifest(manifest)
if not data or not _is_isolation_manifest(data):
continue
isolation = data["tool"]["comfy"]["isolation"]
if isolation.get("standalone") is True:
nested.append((node_dir, manifest))
return nested
def find_manifest_directories() -> List[Tuple[Path, Path]]:
@ -45,21 +88,13 @@ def find_manifest_directories() -> List[Tuple[Path, Path]]:
if not manifest.exists():
continue
# Validate [tool.comfy.isolation] section existence
try:
with manifest.open("rb") as f:
data = tomllib.load(f)
if (
"tool" in data
and "comfy" in data["tool"]
and "isolation" in data["tool"]["comfy"]
):
manifest_dirs.append((entry, manifest))
except Exception:
data = _read_manifest(manifest)
if not data or not _is_isolation_manifest(data):
continue
manifest_dirs.append((entry, manifest))
manifest_dirs.extend(_discover_nested_manifests(entry))
return manifest_dirs

View File

@ -22,6 +22,16 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
__module__ = "comfy.model_patcher"
_APPLY_MODEL_GUARD_PADDING_BYTES = 32 * 1024 * 1024
def _spawn_related_proxy(self, instance_id: str) -> "ModelPatcherProxy":
proxy = ModelPatcherProxy(
instance_id,
self._registry,
manage_lifecycle=not IS_CHILD_PROCESS,
)
if getattr(self, "_rpc_caller", None) is not None:
proxy._rpc_caller = self._rpc_caller
return proxy
def _get_rpc(self) -> Any:
if self._rpc_caller is None:
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
@ -164,9 +174,7 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
def clone(self) -> ModelPatcherProxy:
new_id = self._call_rpc("clone")
return ModelPatcherProxy(
new_id, self._registry, manage_lifecycle=not IS_CHILD_PROCESS
)
return self._spawn_related_proxy(new_id)
def clone_has_same_weights(self, clone: Any) -> bool:
if isinstance(clone, ModelPatcherProxy):
@ -509,11 +517,7 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
)
new_model = None
if result.get("model_id"):
new_model = ModelPatcherProxy(
result["model_id"],
self._registry,
manage_lifecycle=not IS_CHILD_PROCESS,
)
new_model = self._spawn_related_proxy(result["model_id"])
new_clip = None
if result.get("clip_id"):
from comfy.isolation.clip_proxy import CLIPProxy
@ -789,12 +793,7 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
def get_additional_models(self) -> List[ModelPatcherProxy]:
ids = self._call_rpc("get_additional_models")
return [
ModelPatcherProxy(
mid, self._registry, manage_lifecycle=not IS_CHILD_PROCESS
)
for mid in ids
]
return [self._spawn_related_proxy(mid) for mid in ids]
def model_patches_models(self) -> Any:
return self._call_rpc("model_patches_models")
@ -803,6 +802,25 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
def parent(self) -> Any:
return self._call_rpc("get_parent")
def model_mmap_residency(self, free: bool = False) -> tuple:
result = self._call_rpc("model_mmap_residency", free)
if isinstance(result, list):
return tuple(result)
return result
def pinned_memory_size(self) -> int:
return self._call_rpc("pinned_memory_size")
def get_non_dynamic_delegate(self) -> ModelPatcherProxy:
new_id = self._call_rpc("get_non_dynamic_delegate")
return self._spawn_related_proxy(new_id)
def disable_model_cfg1_optimization(self) -> None:
self._call_rpc("disable_model_cfg1_optimization")
def set_model_noise_refiner_patch(self, patch: Any) -> None:
self.set_model_patch(patch, "noise_refiner")
class _InnerModelProxy:
def __init__(self, parent: ModelPatcherProxy):
@ -812,8 +830,14 @@ class _InnerModelProxy:
def __getattr__(self, name: str) -> Any:
if name.startswith("_"):
raise AttributeError(name)
if name == "model_config":
from types import SimpleNamespace
data = self._parent._call_rpc("get_inner_model_attr", name)
if isinstance(data, dict):
return SimpleNamespace(**data)
return data
if name in (
"model_config",
"latent_format",
"model_type",
"current_weight_patches_uuid",
@ -824,11 +848,14 @@ class _InnerModelProxy:
if name == "device":
return self._parent._call_rpc("get_inner_model_attr", "device")
if name == "current_patcher":
return ModelPatcherProxy(
proxy = ModelPatcherProxy(
self._parent._instance_id,
self._parent._registry,
manage_lifecycle=False,
)
if getattr(self._parent, "_rpc_caller", None) is not None:
proxy._rpc_caller = self._parent._rpc_caller
return proxy
if name == "model_sampling":
if self._model_sampling is None:
self._model_sampling = self._parent._call_rpc(

View File

@ -250,22 +250,47 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
return f"<ModelObject: {type(instance.model).__name__}>"
result = instance.get_model_object(name)
if name == "model_sampling":
from comfy.isolation.model_sampling_proxy import (
ModelSamplingRegistry,
ModelSamplingProxy,
)
registry = ModelSamplingRegistry()
# Preserve identity when upstream already returned a proxy. Re-registering
# a proxy object creates proxy-of-proxy call chains.
if isinstance(result, ModelSamplingProxy):
sampling_id = result._instance_id
else:
sampling_id = registry.register(result)
return ModelSamplingProxy(sampling_id, registry)
# Return inline serialization so the child reconstructs the real
# class with correct isinstance behavior. Returning a
# ModelSamplingProxy breaks isinstance checks (e.g.
# offset_first_sigma_for_snr in k_diffusion/sampling.py:173).
return self._serialize_model_sampling_inline(result)
return detach_if_grad(result)
@staticmethod
def _serialize_model_sampling_inline(obj: Any) -> dict:
"""Serialize a ModelSampling object as inline data for the child to reconstruct."""
import torch
import base64
import io as _io
bases = []
for base in type(obj).__mro__:
if base.__module__ == "comfy.model_sampling" and base.__name__ != "object":
bases.append(base.__name__)
sd = obj.state_dict()
sd_serialized = {}
for k, v in sd.items():
buf = _io.BytesIO()
torch.save(v, buf)
sd_serialized[k] = base64.b64encode(buf.getvalue()).decode("ascii")
plain_attrs = {}
for k, v in obj.__dict__.items():
if k.startswith("_"):
continue
if isinstance(v, (bool, int, float, str)):
plain_attrs[k] = v
return {
"__type__": "ModelSamplingInline",
"bases": bases,
"state_dict": sd_serialized,
"attrs": plain_attrs,
}
async def get_model_options(self, instance_id: str) -> dict:
instance = self._get_instance(instance_id)
import copy
@ -348,6 +373,20 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
async def get_ram_usage(self, instance_id: str) -> int:
return self._get_instance(instance_id).get_ram_usage()
async def model_mmap_residency(self, instance_id: str, free: bool = False) -> tuple:
return self._get_instance(instance_id).model_mmap_residency(free=free)
async def pinned_memory_size(self, instance_id: str) -> int:
return self._get_instance(instance_id).pinned_memory_size()
async def get_non_dynamic_delegate(self, instance_id: str) -> str:
instance = self._get_instance(instance_id)
delegate = instance.get_non_dynamic_delegate()
return self.register(delegate)
async def disable_model_cfg1_optimization(self, instance_id: str) -> None:
self._get_instance(instance_id).disable_model_cfg1_optimization()
async def lowvram_patch_counter(self, instance_id: str) -> int:
return self._get_instance(instance_id).lowvram_patch_counter()
@ -959,12 +998,54 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
async def get_inner_model_attr(self, instance_id: str, name: str) -> Any:
try:
return self._sanitize_rpc_result(
getattr(self._get_instance(instance_id).model, name)
)
value = getattr(self._get_instance(instance_id).model, name)
if name == "model_config":
value = self._extract_model_config(value)
return self._sanitize_rpc_result(value)
except AttributeError:
return None
@staticmethod
def _extract_model_config(config: Any) -> dict:
"""Extract JSON-safe attributes from a model config object.
ComfyUI model config classes (supported_models_base.BASE subclasses)
have a permissive __getattr__ that returns None for any unknown
attribute instead of raising AttributeError. This defeats hasattr-based
duck-typing in _sanitize_rpc_result, causing TypeError when it tries
to call obj.items() (which resolves to None). We extract the real
class-level and instance-level attributes into a plain dict.
"""
# Attributes consumed by ModelSampling*.__init__ and other callers
_CONFIG_KEYS = (
"sampling_settings",
"unet_config",
"unet_extra_config",
"latent_format",
"manual_cast_dtype",
"custom_operations",
"optimizations",
"memory_usage_factor",
"supported_inference_dtypes",
)
result: dict = {}
for key in _CONFIG_KEYS:
# Use type(config).__dict__ first (class attrs), then instance __dict__
# to avoid triggering the permissive __getattr__
if key in type(config).__dict__:
val = type(config).__dict__[key]
# Skip classmethods/staticmethods/descriptors
if not callable(val) or isinstance(val, (dict, list, tuple)):
result[key] = val
elif hasattr(config, "__dict__") and key in config.__dict__:
result[key] = config.__dict__[key]
# Also include instance overrides (e.g. set_inference_dtype sets unet_config['dtype'])
if hasattr(config, "__dict__"):
for key, val in config.__dict__.items():
if key in _CONFIG_KEYS:
result[key] = val
return result
async def inner_model_memory_required(
self, instance_id: str, args: tuple, kwargs: dict
) -> Any:

View File

@ -118,6 +118,47 @@ def set_global_loop(loop: asyncio.AbstractEventLoop) -> None:
_GLOBAL_LOOP = loop
def run_sync_rpc_coro(coro: Any, timeout_ms: Optional[int] = None) -> Any:
if timeout_ms is not None:
coro = asyncio.wait_for(coro, timeout=timeout_ms / 1000.0)
try:
if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running():
try:
curr_loop = asyncio.get_running_loop()
if curr_loop is _GLOBAL_LOOP:
pass
except RuntimeError:
future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP)
return future.result(
timeout=(timeout_ms / 1000.0) if timeout_ms is not None else None
)
try:
asyncio.get_running_loop()
return run_coro_in_new_loop(coro)
except RuntimeError:
loop = get_thread_loop()
return loop.run_until_complete(coro)
except asyncio.TimeoutError as exc:
raise TimeoutError(f"Isolation RPC timeout (timeout_ms={timeout_ms})") from exc
except concurrent.futures.TimeoutError as exc:
raise TimeoutError(f"Isolation RPC timeout (timeout_ms={timeout_ms})") from exc
def call_singleton_rpc(
caller: Any,
method_name: str,
*args: Any,
timeout_ms: Optional[int] = None,
**kwargs: Any,
) -> Any:
if caller is None:
raise RuntimeError(f"No RPC caller available for {method_name}")
method = getattr(caller, method_name)
return run_sync_rpc_coro(method(*args, **kwargs), timeout_ms=timeout_ms)
class BaseProxy(Generic[T]):
_registry_class: type = BaseRegistry # type: ignore[type-arg]
__module__: str = "comfy.isolation.proxies.base"
@ -208,31 +249,8 @@ class BaseProxy(Generic[T]):
)
try:
# If we have a global loop (Main Thread Loop), use it for dispatch from worker threads
if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running():
try:
curr_loop = asyncio.get_running_loop()
if curr_loop is _GLOBAL_LOOP:
pass
except RuntimeError:
# No running loop - we are in a worker thread.
future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP)
return future.result(
timeout=(timeout_ms / 1000.0) if timeout_ms is not None else None
)
try:
asyncio.get_running_loop()
return run_coro_in_new_loop(coro)
except RuntimeError:
loop = get_thread_loop()
return loop.run_until_complete(coro)
except asyncio.TimeoutError as exc:
raise TimeoutError(
f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} "
f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})"
) from exc
except concurrent.futures.TimeoutError as exc:
return run_sync_rpc_coro(coro, timeout_ms=timeout_ms)
except TimeoutError as exc:
raise TimeoutError(
f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} "
f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})"

View File

@ -1,9 +1,35 @@
from __future__ import annotations
from typing import Dict
import os
from typing import Any, Dict, Optional
import folder_paths
from pyisolate import ProxiedSingleton
from .base import call_singleton_rpc
def _folder_paths():
import folder_paths
return folder_paths
def _is_child_process() -> bool:
return os.environ.get("PYISOLATE_CHILD") == "1"
def _serialize_folder_names_and_paths(data: dict[str, tuple[list[str], set[str]]]) -> dict[str, dict[str, list[str]]]:
return {
key: {"paths": list(paths), "extensions": sorted(list(extensions))}
for key, (paths, extensions) in data.items()
}
def _deserialize_folder_names_and_paths(data: dict[str, dict[str, list[str]]]) -> dict[str, tuple[list[str], set[str]]]:
return {
key: (list(value.get("paths", [])), set(value.get("extensions", [])))
for key, value in data.items()
}
class FolderPathsProxy(ProxiedSingleton):
"""
@ -12,18 +38,165 @@ class FolderPathsProxy(ProxiedSingleton):
mutable collections to ensure efficient by-value transfer.
"""
def __getattr__(self, name):
return getattr(folder_paths, name)
_rpc: Optional[Any] = None
@classmethod
def set_rpc(cls, rpc: Any) -> None:
cls._rpc = rpc.create_caller(cls, cls.get_remote_id())
@classmethod
def clear_rpc(cls) -> None:
cls._rpc = None
@classmethod
def _get_caller(cls) -> Any:
if cls._rpc is None:
raise RuntimeError("FolderPathsProxy RPC caller is not configured")
return cls._rpc
def __getattr__(self, name):
if _is_child_process():
property_rpc = {
"models_dir": "rpc_get_models_dir",
"folder_names_and_paths": "rpc_get_folder_names_and_paths",
"extension_mimetypes_cache": "rpc_get_extension_mimetypes_cache",
"filename_list_cache": "rpc_get_filename_list_cache",
}
rpc_name = property_rpc.get(name)
if rpc_name is not None:
return call_singleton_rpc(self._get_caller(), rpc_name)
raise AttributeError(name)
return getattr(_folder_paths(), name)
# Return dict snapshots (avoid RPC chatter)
@property
def folder_names_and_paths(self) -> Dict:
return dict(folder_paths.folder_names_and_paths)
if _is_child_process():
payload = call_singleton_rpc(self._get_caller(), "rpc_get_folder_names_and_paths")
return _deserialize_folder_names_and_paths(payload)
return _folder_paths().folder_names_and_paths
@property
def extension_mimetypes_cache(self) -> Dict:
return dict(folder_paths.extension_mimetypes_cache)
if _is_child_process():
return dict(call_singleton_rpc(self._get_caller(), "rpc_get_extension_mimetypes_cache"))
return dict(_folder_paths().extension_mimetypes_cache)
@property
def filename_list_cache(self) -> Dict:
return dict(folder_paths.filename_list_cache)
if _is_child_process():
return dict(call_singleton_rpc(self._get_caller(), "rpc_get_filename_list_cache"))
return dict(_folder_paths().filename_list_cache)
@property
def models_dir(self) -> str:
if _is_child_process():
return str(call_singleton_rpc(self._get_caller(), "rpc_get_models_dir"))
return _folder_paths().models_dir
def get_temp_directory(self) -> str:
if _is_child_process():
return call_singleton_rpc(self._get_caller(), "rpc_get_temp_directory")
return _folder_paths().get_temp_directory()
def get_input_directory(self) -> str:
if _is_child_process():
return call_singleton_rpc(self._get_caller(), "rpc_get_input_directory")
return _folder_paths().get_input_directory()
def get_output_directory(self) -> str:
if _is_child_process():
return call_singleton_rpc(self._get_caller(), "rpc_get_output_directory")
return _folder_paths().get_output_directory()
def get_user_directory(self) -> str:
if _is_child_process():
return call_singleton_rpc(self._get_caller(), "rpc_get_user_directory")
return _folder_paths().get_user_directory()
def get_annotated_filepath(self, name: str, default_dir: str | None = None) -> str:
if _is_child_process():
return call_singleton_rpc(
self._get_caller(), "rpc_get_annotated_filepath", name, default_dir
)
return _folder_paths().get_annotated_filepath(name, default_dir)
def exists_annotated_filepath(self, name: str) -> bool:
if _is_child_process():
return bool(
call_singleton_rpc(self._get_caller(), "rpc_exists_annotated_filepath", name)
)
return bool(_folder_paths().exists_annotated_filepath(name))
def add_model_folder_path(
self, folder_name: str, full_folder_path: str, is_default: bool = False
) -> None:
if _is_child_process():
call_singleton_rpc(
self._get_caller(),
"rpc_add_model_folder_path",
folder_name,
full_folder_path,
is_default,
)
return None
_folder_paths().add_model_folder_path(folder_name, full_folder_path, is_default)
return None
def get_folder_paths(self, folder_name: str) -> list[str]:
if _is_child_process():
return list(call_singleton_rpc(self._get_caller(), "rpc_get_folder_paths", folder_name))
return list(_folder_paths().get_folder_paths(folder_name))
def get_filename_list(self, folder_name: str) -> list[str]:
if _is_child_process():
return list(call_singleton_rpc(self._get_caller(), "rpc_get_filename_list", folder_name))
return list(_folder_paths().get_filename_list(folder_name))
def get_full_path(self, folder_name: str, filename: str) -> str | None:
if _is_child_process():
return call_singleton_rpc(self._get_caller(), "rpc_get_full_path", folder_name, filename)
return _folder_paths().get_full_path(folder_name, filename)
async def rpc_get_models_dir(self) -> str:
return _folder_paths().models_dir
async def rpc_get_folder_names_and_paths(self) -> dict[str, dict[str, list[str]]]:
return _serialize_folder_names_and_paths(_folder_paths().folder_names_and_paths)
async def rpc_get_extension_mimetypes_cache(self) -> dict[str, Any]:
return dict(_folder_paths().extension_mimetypes_cache)
async def rpc_get_filename_list_cache(self) -> dict[str, Any]:
return dict(_folder_paths().filename_list_cache)
async def rpc_get_temp_directory(self) -> str:
return _folder_paths().get_temp_directory()
async def rpc_get_input_directory(self) -> str:
return _folder_paths().get_input_directory()
async def rpc_get_output_directory(self) -> str:
return _folder_paths().get_output_directory()
async def rpc_get_user_directory(self) -> str:
return _folder_paths().get_user_directory()
async def rpc_get_annotated_filepath(self, name: str, default_dir: str | None = None) -> str:
return _folder_paths().get_annotated_filepath(name, default_dir)
async def rpc_exists_annotated_filepath(self, name: str) -> bool:
return _folder_paths().exists_annotated_filepath(name)
async def rpc_add_model_folder_path(
self, folder_name: str, full_folder_path: str, is_default: bool = False
) -> None:
_folder_paths().add_model_folder_path(folder_name, full_folder_path, is_default)
async def rpc_get_folder_paths(self, folder_name: str) -> list[str]:
return _folder_paths().get_folder_paths(folder_name)
async def rpc_get_filename_list(self, folder_name: str) -> list[str]:
return _folder_paths().get_filename_list(folder_name)
async def rpc_get_full_path(self, folder_name: str, filename: str) -> str | None:
return _folder_paths().get_full_path(folder_name, filename)

View File

@ -1,7 +1,12 @@
from __future__ import annotations
import os
from typing import Any, Dict, Optional
from pyisolate import ProxiedSingleton
from .base import call_singleton_rpc
class AnyTypeProxy(str):
"""Replacement for custom AnyType objects used by some nodes."""
@ -71,9 +76,29 @@ def _restore_special_value(value: Any) -> Any:
return value
def restore_input_types(raw: Dict[str, object]) -> Dict[str, object]:
"""Restore serialized INPUT_TYPES payload back into ComfyUI-compatible objects."""
def _serialize_special_value(value: Any) -> Any:
if isinstance(value, AnyTypeProxy):
return {"__pyisolate_any_type__": True, "value": str(value)}
if isinstance(value, FlexibleOptionalInputProxy):
return {
"__pyisolate_flexible_optional__": True,
"type": _serialize_special_value(value.type),
"data": {k: _serialize_special_value(v) for k, v in value.items()},
}
if isinstance(value, ByPassTypeTupleProxy):
return {
"__pyisolate_bypass_tuple__": [_serialize_special_value(v) for v in value]
}
if isinstance(value, tuple):
return {"__pyisolate_tuple__": [_serialize_special_value(v) for v in value]}
if isinstance(value, list):
return [_serialize_special_value(v) for v in value]
if isinstance(value, dict):
return {k: _serialize_special_value(v) for k, v in value.items()}
return value
def _restore_input_types_local(raw: Dict[str, object]) -> Dict[str, object]:
if not isinstance(raw, dict):
return raw # type: ignore[return-value]
@ -90,9 +115,44 @@ def restore_input_types(raw: Dict[str, object]) -> Dict[str, object]:
return restored
class HelperProxiesService(ProxiedSingleton):
_rpc: Optional[Any] = None
@classmethod
def set_rpc(cls, rpc: Any) -> None:
cls._rpc = rpc.create_caller(cls, cls.get_remote_id())
@classmethod
def clear_rpc(cls) -> None:
cls._rpc = None
@classmethod
def _get_caller(cls) -> Any:
if cls._rpc is None:
raise RuntimeError("HelperProxiesService RPC caller is not configured")
return cls._rpc
async def rpc_restore_input_types(self, raw: Dict[str, object]) -> Dict[str, object]:
restored = _restore_input_types_local(raw)
return _serialize_special_value(restored)
def restore_input_types(raw: Dict[str, object]) -> Dict[str, object]:
"""Restore serialized INPUT_TYPES payload back into ComfyUI-compatible objects."""
if os.environ.get("PYISOLATE_CHILD") == "1":
payload = call_singleton_rpc(
HelperProxiesService._get_caller(),
"rpc_restore_input_types",
raw,
)
return _restore_input_types_local(payload)
return _restore_input_types_local(raw)
__all__ = [
"AnyTypeProxy",
"FlexibleOptionalInputProxy",
"ByPassTypeTupleProxy",
"HelperProxiesService",
"restore_input_types",
]

View File

@ -1,27 +1,142 @@
import comfy.model_management as mm
from __future__ import annotations
import os
from typing import Any, Optional
from pyisolate import ProxiedSingleton
from .base import call_singleton_rpc
def _mm():
import comfy.model_management
return comfy.model_management
def _is_child_process() -> bool:
return os.environ.get("PYISOLATE_CHILD") == "1"
class TorchDeviceProxy:
def __init__(self, device_str: str):
self._device_str = device_str
if ":" in device_str:
device_type, index = device_str.split(":", 1)
self.type = device_type
self.index = int(index)
else:
self.type = device_str
self.index = None
def __str__(self) -> str:
return self._device_str
def __repr__(self) -> str:
return f"TorchDeviceProxy({self._device_str!r})"
def _serialize_value(value: Any) -> Any:
value_type = type(value)
if value_type.__module__ == "torch" and value_type.__name__ == "device":
return {"__pyisolate_torch_device__": str(value)}
if isinstance(value, TorchDeviceProxy):
return {"__pyisolate_torch_device__": str(value)}
if isinstance(value, tuple):
return {"__pyisolate_tuple__": [_serialize_value(item) for item in value]}
if isinstance(value, list):
return [_serialize_value(item) for item in value]
if isinstance(value, dict):
return {key: _serialize_value(inner) for key, inner in value.items()}
return value
def _deserialize_value(value: Any) -> Any:
if isinstance(value, dict):
if "__pyisolate_torch_device__" in value:
return TorchDeviceProxy(value["__pyisolate_torch_device__"])
if "__pyisolate_tuple__" in value:
return tuple(_deserialize_value(item) for item in value["__pyisolate_tuple__"])
return {key: _deserialize_value(inner) for key, inner in value.items()}
if isinstance(value, list):
return [_deserialize_value(item) for item in value]
return value
def _normalize_argument(value: Any) -> Any:
if isinstance(value, TorchDeviceProxy):
import torch
return torch.device(str(value))
if isinstance(value, dict):
if "__pyisolate_torch_device__" in value:
import torch
return torch.device(value["__pyisolate_torch_device__"])
if "__pyisolate_tuple__" in value:
return tuple(_normalize_argument(item) for item in value["__pyisolate_tuple__"])
return {key: _normalize_argument(inner) for key, inner in value.items()}
if isinstance(value, list):
return [_normalize_argument(item) for item in value]
return value
class ModelManagementProxy(ProxiedSingleton):
"""
Dynamic proxy for comfy.model_management.
Uses __getattr__ to forward all calls to the underlying module,
reducing maintenance burden.
Exact-relay proxy for comfy.model_management.
Child calls never import comfy.model_management directly; they serialize
arguments, relay to host, and deserialize the host result back.
"""
# Explicitly expose Enums/Classes as properties
_rpc: Optional[Any] = None
@classmethod
def set_rpc(cls, rpc: Any) -> None:
cls._rpc = rpc.create_caller(cls, cls.get_remote_id())
@classmethod
def clear_rpc(cls) -> None:
cls._rpc = None
@classmethod
def _get_caller(cls) -> Any:
if cls._rpc is None:
raise RuntimeError("ModelManagementProxy RPC caller is not configured")
return cls._rpc
def _relay_call(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
payload = call_singleton_rpc(
self._get_caller(),
"rpc_call",
method_name,
_serialize_value(args),
_serialize_value(kwargs),
)
return _deserialize_value(payload)
@property
def VRAMState(self):
return mm.VRAMState
return _mm().VRAMState
@property
def CPUState(self):
return mm.CPUState
return _mm().CPUState
@property
def OOM_EXCEPTION(self):
return mm.OOM_EXCEPTION
return _mm().OOM_EXCEPTION
def __getattr__(self, name):
"""Forward all other attribute access to the module."""
return getattr(mm, name)
def __getattr__(self, name: str):
if _is_child_process():
def child_method(*args: Any, **kwargs: Any) -> Any:
return self._relay_call(name, *args, **kwargs)
return child_method
return getattr(_mm(), name)
async def rpc_call(self, method_name: str, args: Any, kwargs: Any) -> Any:
normalized_args = _normalize_argument(_deserialize_value(args))
normalized_kwargs = _normalize_argument(_deserialize_value(kwargs))
method = getattr(_mm(), method_name)
result = method(*normalized_args, **normalized_kwargs)
return _serialize_value(result)

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import logging
import os
from typing import Any, Optional
try:
@ -10,13 +11,38 @@ except ImportError:
class ProxiedSingleton:
pass
from .base import call_singleton_rpc
from comfy_execution.progress import get_progress_state
def _get_progress_state():
from comfy_execution.progress import get_progress_state
return get_progress_state()
def _is_child_process() -> bool:
return os.environ.get("PYISOLATE_CHILD") == "1"
logger = logging.getLogger(__name__)
class ProgressProxy(ProxiedSingleton):
_rpc: Optional[Any] = None
@classmethod
def set_rpc(cls, rpc: Any) -> None:
cls._rpc = rpc.create_caller(cls, cls.get_remote_id())
@classmethod
def clear_rpc(cls) -> None:
cls._rpc = None
@classmethod
def _get_caller(cls) -> Any:
if cls._rpc is None:
raise RuntimeError("ProgressProxy RPC caller is not configured")
return cls._rpc
def set_progress(
self,
value: float,
@ -24,7 +50,33 @@ class ProgressProxy(ProxiedSingleton):
node_id: Optional[str] = None,
image: Any = None,
) -> None:
get_progress_state().update_progress(
if _is_child_process():
call_singleton_rpc(
self._get_caller(),
"rpc_set_progress",
value,
max_value,
node_id,
image,
)
return None
_get_progress_state().update_progress(
node_id=node_id,
value=value,
max_value=max_value,
image=image,
)
return None
async def rpc_set_progress(
self,
value: float,
max_value: float,
node_id: Optional[str] = None,
image: Any = None,
) -> None:
_get_progress_state().update_progress(
node_id=node_id,
value=value,
max_value=max_value,

View File

@ -13,10 +13,10 @@ import os
from typing import Any, Dict, Optional, Callable
import logging
from aiohttp import web
# IMPORTS
from pyisolate import ProxiedSingleton
from .base import call_singleton_rpc
logger = logging.getLogger(__name__)
LOG_PREFIX = "[Isolation:C<->H]"
@ -64,6 +64,10 @@ class PromptServerStub:
PromptServerService, target_id
) # We import Service below?
@classmethod
def clear_rpc(cls) -> None:
cls._rpc = None
# We need PromptServerService available for the create_caller call?
# Or just use the Stub class if ID matches?
# prompt_server_impl.py defines BOTH. So PromptServerService IS available!
@ -133,7 +137,7 @@ class PromptServerStub:
loop = asyncio.get_running_loop()
loop.create_task(self._rpc.ui_send_progress_text(text, node_id, sid))
except RuntimeError:
pass # Sync context without loop?
call_singleton_rpc(self._rpc, "ui_send_progress_text", text, node_id, sid)
# --- Route Registration Logic ---
def register_route(self, method: str, path: str, handler: Callable):
@ -147,7 +151,7 @@ class PromptServerStub:
loop = asyncio.get_running_loop()
loop.create_task(self._rpc.register_route_rpc(method, path, handler))
except RuntimeError:
pass
call_singleton_rpc(self._rpc, "register_route_rpc", method, path, handler)
class RouteStub:
@ -226,6 +230,7 @@ class PromptServerService(ProxiedSingleton):
async def register_route_rpc(self, method: str, path: str, child_handler_proxy):
"""RPC Target: Register a route that forwards to the Child."""
from aiohttp import web
logger.debug(f"{LOG_PREFIX} Registering Isolated Route {method} {path}")
async def route_wrapper(request: web.Request) -> web.Response:
@ -251,8 +256,9 @@ class PromptServerService(ProxiedSingleton):
# Register loop
self.server.app.router.add_route(method, path, route_wrapper)
def _serialize_response(self, result: Any) -> web.Response:
def _serialize_response(self, result: Any) -> Any:
"""Helper to convert Child result -> web.Response"""
from aiohttp import web
if isinstance(result, web.Response):
return result
# Handle dict (json)

View File

@ -2,12 +2,16 @@
from __future__ import annotations
from typing import Optional, Any
import comfy.utils
from pyisolate import ProxiedSingleton
import os
def _comfy_utils():
import comfy.utils
return comfy.utils
class UtilsProxy(ProxiedSingleton):
"""
Proxy for comfy.utils.
@ -23,6 +27,10 @@ class UtilsProxy(ProxiedSingleton):
# Create caller using class name as ID (standard for Singletons)
cls._rpc = rpc.create_caller(cls, "UtilsProxy")
@classmethod
def clear_rpc(cls) -> None:
cls._rpc = None
async def progress_bar_hook(
self,
value: int,
@ -35,30 +43,22 @@ class UtilsProxy(ProxiedSingleton):
Child-side: this method call is intercepted by RPC and sent to host.
"""
if os.environ.get("PYISOLATE_CHILD") == "1":
# Manual RPC dispatch for Child process
# Use class-level RPC storage (Static Injection)
if UtilsProxy._rpc:
return await UtilsProxy._rpc.progress_bar_hook(
value, total, preview, node_id
)
# Fallback channel: global child rpc
try:
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
get_child_rpc_instance()
# If we have an RPC instance but no UtilsProxy._rpc, we *could* try to use it,
# but we need a caller. For now, just pass to avoid crashing.
pass
except (ImportError, LookupError):
pass
return None
if UtilsProxy._rpc is None:
raise RuntimeError("UtilsProxy RPC caller is not configured")
return await UtilsProxy._rpc.progress_bar_hook(
value, total, preview, node_id
)
# Host Execution
if comfy.utils.PROGRESS_BAR_HOOK is not None:
comfy.utils.PROGRESS_BAR_HOOK(value, total, preview, node_id)
utils = _comfy_utils()
if utils.PROGRESS_BAR_HOOK is not None:
return utils.PROGRESS_BAR_HOOK(value, total, preview, node_id)
return None
def set_progress_bar_global_hook(self, hook: Any) -> None:
"""Forward hook registration (though usually not needed from child)."""
comfy.utils.set_progress_bar_global_hook(hook)
if os.environ.get("PYISOLATE_CHILD") == "1":
raise RuntimeError(
"UtilsProxy.set_progress_bar_global_hook is not available in child without exact relay support"
)
_comfy_utils().set_progress_bar_global_hook(hook)

View File

@ -0,0 +1,219 @@
"""WebDirectoryProxy — serves isolated node web assets via RPC.
Child side: enumerates and reads files from the extension's web/ directory.
Host side: gets an RPC proxy that fetches file listings and contents on demand.
Only files with allowed extensions (.js, .html, .css) are served.
Directory traversal is rejected. File contents are base64-encoded for
safe JSON-RPC transport.
"""
from __future__ import annotations
import base64
import logging
import os
from pathlib import Path, PurePosixPath
from typing import Any, Dict, List
from pyisolate import ProxiedSingleton
logger = logging.getLogger(__name__)
ALLOWED_EXTENSIONS = frozenset({".js", ".html", ".css"})
MIME_TYPES = {
".js": "application/javascript",
".html": "text/html",
".css": "text/css",
}
class WebDirectoryProxy(ProxiedSingleton):
"""Proxy for serving isolated extension web directories.
On the child side, this class has direct filesystem access to the
extension's web/ directory. On the host side, callers get an RPC
proxy whose method calls are forwarded to the child.
"""
# {extension_name: absolute_path_to_web_dir}
_web_dirs: dict[str, str] = {}
@classmethod
def register_web_dir(cls, extension_name: str, web_dir_path: str) -> None:
"""Register an extension's web directory (child-side only)."""
cls._web_dirs[extension_name] = web_dir_path
logger.info(
"][ WebDirectoryProxy: registered %s -> %s",
extension_name,
web_dir_path,
)
def list_web_files(self, extension_name: str) -> List[Dict[str, str]]:
"""Return a list of servable files in the extension's web directory.
Each entry is {"relative_path": "js/foo.js", "content_type": "application/javascript"}.
Only files with allowed extensions are included.
"""
web_dir = self._web_dirs.get(extension_name)
if not web_dir:
return []
root = Path(web_dir)
if not root.is_dir():
return []
result: List[Dict[str, str]] = []
for path in sorted(root.rglob("*")):
if not path.is_file():
continue
ext = path.suffix.lower()
if ext not in ALLOWED_EXTENSIONS:
continue
rel = path.relative_to(root)
result.append({
"relative_path": str(PurePosixPath(rel)),
"content_type": MIME_TYPES[ext],
})
return result
def get_web_file(
self, extension_name: str, relative_path: str
) -> Dict[str, Any]:
"""Return the contents of a single web file as base64.
Raises ValueError for traversal attempts or disallowed file types.
Returns {"content": <base64 str>, "content_type": <MIME str>}.
"""
_validate_path(relative_path)
web_dir = self._web_dirs.get(extension_name)
if not web_dir:
raise FileNotFoundError(
f"No web directory registered for {extension_name}"
)
root = Path(web_dir)
target = (root / relative_path).resolve()
# Ensure resolved path is under the web directory
if not str(target).startswith(str(root.resolve())):
raise ValueError(f"Path escapes web directory: {relative_path}")
if not target.is_file():
raise FileNotFoundError(f"File not found: {relative_path}")
ext = target.suffix.lower()
if ext not in ALLOWED_EXTENSIONS:
raise ValueError(f"Disallowed file type: {ext}")
content_type = MIME_TYPES[ext]
raw = target.read_bytes()
return {
"content": base64.b64encode(raw).decode("ascii"),
"content_type": content_type,
}
def _validate_path(relative_path: str) -> None:
"""Reject directory traversal and absolute paths."""
if os.path.isabs(relative_path):
raise ValueError(f"Absolute paths are not allowed: {relative_path}")
if ".." in PurePosixPath(relative_path).parts:
raise ValueError(f"Directory traversal is not allowed: {relative_path}")
# ---------------------------------------------------------------------------
# Host-side cache and aiohttp handler
# ---------------------------------------------------------------------------
class WebDirectoryCache:
"""Host-side in-memory cache for proxied web directory contents.
Populated lazily via RPC calls to the child's WebDirectoryProxy.
Once a file is cached, subsequent requests are served from memory.
"""
def __init__(self) -> None:
# {extension_name: {relative_path: {"content": bytes, "content_type": str}}}
self._file_cache: dict[str, dict[str, dict[str, Any]]] = {}
# {extension_name: [{"relative_path": str, "content_type": str}, ...]}
self._listing_cache: dict[str, list[dict[str, str]]] = {}
# {extension_name: WebDirectoryProxy (RPC proxy instance)}
self._proxies: dict[str, Any] = {}
def register_proxy(self, extension_name: str, proxy: Any) -> None:
"""Register an RPC proxy for an extension's web directory."""
self._proxies[extension_name] = proxy
logger.info(
"][ WebDirectoryCache: registered proxy for %s", extension_name
)
@property
def extension_names(self) -> list[str]:
return list(self._proxies.keys())
def list_files(self, extension_name: str) -> list[dict[str, str]]:
"""List servable files for an extension (cached after first call)."""
if extension_name not in self._listing_cache:
proxy = self._proxies.get(extension_name)
if proxy is None:
return []
try:
self._listing_cache[extension_name] = proxy.list_web_files(
extension_name
)
except Exception:
logger.warning(
"][ WebDirectoryCache: failed to list files for %s",
extension_name,
exc_info=True,
)
return []
return self._listing_cache[extension_name]
def get_file(
self, extension_name: str, relative_path: str
) -> dict[str, Any] | None:
"""Get file content (cached after first fetch). Returns None on miss."""
ext_cache = self._file_cache.get(extension_name)
if ext_cache and relative_path in ext_cache:
return ext_cache[relative_path]
proxy = self._proxies.get(extension_name)
if proxy is None:
return None
try:
result = proxy.get_web_file(extension_name, relative_path)
except (FileNotFoundError, ValueError):
return None
except Exception:
logger.warning(
"][ WebDirectoryCache: failed to fetch %s/%s",
extension_name,
relative_path,
exc_info=True,
)
return None
decoded = {
"content": base64.b64decode(result["content"]),
"content_type": result["content_type"],
}
if extension_name not in self._file_cache:
self._file_cache[extension_name] = {}
self._file_cache[extension_name][relative_path] = decoded
return decoded
# Global cache instance — populated during isolation loading
_web_directory_cache = WebDirectoryCache()
def get_web_directory_cache() -> WebDirectoryCache:
return _web_directory_cache

View File

@ -8,10 +8,17 @@ from pathlib import Path
from typing import Any, Dict, List, Set, TYPE_CHECKING
from .proxies.helper_proxies import restore_input_types
from comfy_api.internal import _ComfyNodeInternal
from comfy_api.latest import _io as latest_io
from .shm_forensics import scan_shm_forensics
_IMPORT_TORCH = os.environ.get("PYISOLATE_IMPORT_TORCH", "1") == "1"
_ComfyNodeInternal = object
latest_io = None
if _IMPORT_TORCH:
from comfy_api.internal import _ComfyNodeInternal
from comfy_api.latest import _io as latest_io
if TYPE_CHECKING:
from .extension_wrapper import ComfyNodeExtension
@ -19,6 +26,68 @@ LOG_PREFIX = "]["
_PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
class _RemoteObjectRegistryCaller:
def __init__(self, extension: Any) -> None:
self._extension = extension
def __getattr__(self, method_name: str) -> Any:
async def _call(instance_id: str, *args: Any, **kwargs: Any) -> Any:
return await self._extension.call_remote_object_method(
instance_id,
method_name,
*args,
**kwargs,
)
return _call
def _wrap_remote_handles_as_host_proxies(value: Any, extension: Any) -> Any:
from pyisolate._internal.remote_handle import RemoteObjectHandle
if isinstance(value, RemoteObjectHandle):
if value.type_name == "ModelPatcher":
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy
proxy = ModelPatcherProxy(value.object_id, manage_lifecycle=False)
proxy._rpc_caller = _RemoteObjectRegistryCaller(extension) # type: ignore[attr-defined]
proxy._pyisolate_remote_handle = value # type: ignore[attr-defined]
return proxy
if value.type_name == "VAE":
from comfy.isolation.vae_proxy import VAEProxy
proxy = VAEProxy(value.object_id, manage_lifecycle=False)
proxy._rpc_caller = _RemoteObjectRegistryCaller(extension) # type: ignore[attr-defined]
proxy._pyisolate_remote_handle = value # type: ignore[attr-defined]
return proxy
if value.type_name == "CLIP":
from comfy.isolation.clip_proxy import CLIPProxy
proxy = CLIPProxy(value.object_id, manage_lifecycle=False)
proxy._rpc_caller = _RemoteObjectRegistryCaller(extension) # type: ignore[attr-defined]
proxy._pyisolate_remote_handle = value # type: ignore[attr-defined]
return proxy
if value.type_name == "ModelSampling":
from comfy.isolation.model_sampling_proxy import ModelSamplingProxy
proxy = ModelSamplingProxy(value.object_id, manage_lifecycle=False)
proxy._rpc_caller = _RemoteObjectRegistryCaller(extension) # type: ignore[attr-defined]
proxy._pyisolate_remote_handle = value # type: ignore[attr-defined]
return proxy
return value
if isinstance(value, dict):
return {
k: _wrap_remote_handles_as_host_proxies(v, extension) for k, v in value.items()
}
if isinstance(value, (list, tuple)):
wrapped = [_wrap_remote_handles_as_host_proxies(item, extension) for item in value]
return type(value)(wrapped)
return value
def _resource_snapshot() -> Dict[str, int]:
fd_count = -1
shm_sender_files = 0
@ -146,6 +215,8 @@ def build_stub_class(
running_extensions: Dict[str, "ComfyNodeExtension"],
logger: logging.Logger,
) -> type:
if latest_io is None:
raise RuntimeError("comfy_api.latest._io is required to build isolation stubs")
is_v3 = bool(info.get("is_v3", False))
function_name = "_pyisolate_execute"
restored_input_types = restore_input_types(info.get("input_types", {}))
@ -160,6 +231,13 @@ def build_stub_class(
node_unique_id = _extract_hidden_unique_id(inputs)
summary = _tensor_transport_summary(inputs)
resources = _resource_snapshot()
logger.debug(
"%s ISO:execute_start ext=%s node=%s uid=%s",
LOG_PREFIX,
extension.name,
node_name,
node_unique_id or "-",
)
logger.debug(
"%s ISO:execute_start ext=%s node=%s uid=%s tensors=%d cpu=%d cuda=%d shared_cpu=%d bytes=%d fds=%d sender_shm=%d",
LOG_PREFIX,
@ -192,7 +270,20 @@ def build_stub_class(
node_name,
node_unique_id or "-",
)
serialized = serialize_for_isolation(inputs)
# Unwrap NodeOutput-like dicts before serialization.
# OUTPUT_NODE nodes return {"ui": {...}, "result": (outputs...)}
# and the executor may pass this dict as input to downstream nodes.
unwrapped_inputs = {}
for k, v in inputs.items():
if isinstance(v, dict) and "result" in v and ("ui" in v or "__node_output__" in v):
result = v.get("result")
if isinstance(result, (tuple, list)) and len(result) > 0:
unwrapped_inputs[k] = result[0]
else:
unwrapped_inputs[k] = result
else:
unwrapped_inputs[k] = v
serialized = serialize_for_isolation(unwrapped_inputs)
logger.debug(
"%s ISO:serialize_done ext=%s node=%s uid=%s",
LOG_PREFIX,
@ -220,15 +311,32 @@ def build_stub_class(
from comfy_api.latest import io as latest_io
args_raw = result.get("args", ())
deserialized_args = await deserialize_from_isolation(args_raw, extension)
deserialized_args = _wrap_remote_handles_as_host_proxies(
deserialized_args, extension
)
deserialized_args = _detach_shared_cpu_tensors(deserialized_args)
ui_raw = result.get("ui")
deserialized_ui = None
if ui_raw is not None:
deserialized_ui = await deserialize_from_isolation(ui_raw, extension)
deserialized_ui = _wrap_remote_handles_as_host_proxies(
deserialized_ui, extension
)
deserialized_ui = _detach_shared_cpu_tensors(deserialized_ui)
scan_shm_forensics("RUNTIME:post_execute", refresh_model_context=True)
return latest_io.NodeOutput(
*deserialized_args,
ui=result.get("ui"),
ui=deserialized_ui,
expand=result.get("expand"),
block_execution=result.get("block_execution"),
)
# OUTPUT_NODE: if sealed worker returned a tuple/list whose first
# element is a {"ui": ...} dict, unwrap it for the executor.
if (isinstance(result, (tuple, list)) and len(result) == 1
and isinstance(result[0], dict) and "ui" in result[0]):
return result[0]
deserialized = await deserialize_from_isolation(result, extension)
deserialized = _wrap_remote_handles_as_host_proxies(deserialized, extension)
scan_shm_forensics("RUNTIME:post_execute", refresh_model_context=True)
return _detach_shared_cpu_tensors(deserialized)
except ImportError:

View File

@ -65,6 +65,22 @@ class SavedAudios(_UIOutput):
return {"audio": self.results}
def _is_isolated_child() -> bool:
return os.environ.get("PYISOLATE_CHILD") == "1"
def _get_preview_folder_type() -> FolderType:
if _is_isolated_child():
return FolderType.output
return FolderType.temp
def _get_preview_route_prefix(folder_type: FolderType) -> str:
if folder_type == FolderType.output:
return "output"
return "temp"
def _get_directory_by_folder_type(folder_type: FolderType) -> str:
if folder_type == FolderType.input:
return folder_paths.get_input_directory()
@ -388,10 +404,11 @@ class AudioSaveHelper:
class PreviewImage(_UIOutput):
def __init__(self, image: Image.Type, animated: bool = False, cls: type[ComfyNode] = None, **kwargs):
folder_type = _get_preview_folder_type()
self.values = ImageSaveHelper.save_images(
image,
filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)),
folder_type=FolderType.temp,
folder_type=folder_type,
cls=cls,
compress_level=1,
)
@ -412,10 +429,11 @@ class PreviewMask(PreviewImage):
class PreviewAudio(_UIOutput):
def __init__(self, audio: dict, cls: type[ComfyNode] = None, **kwargs):
folder_type = _get_preview_folder_type()
self.values = AudioSaveHelper.save_audio(
audio,
filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)),
folder_type=FolderType.temp,
folder_type=folder_type,
cls=cls,
format="flac",
quality="128k",
@ -438,15 +456,16 @@ class PreviewUI3D(_UIOutput):
self.model_file = model_file
self.camera_info = camera_info
self.bg_image_path = None
folder_type = _get_preview_folder_type()
bg_image = kwargs.get("bg_image", None)
if bg_image is not None:
img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8)
img = PILImage.fromarray(img_array)
temp_dir = folder_paths.get_temp_directory()
preview_dir = _get_directory_by_folder_type(folder_type)
filename = f"bg_{uuid.uuid4().hex}.png"
bg_image_path = os.path.join(temp_dir, filename)
bg_image_path = os.path.join(preview_dir, filename)
img.save(bg_image_path, compress_level=1)
self.bg_image_path = f"temp/{filename}"
self.bg_image_path = f"{_get_preview_route_prefix(folder_type)}/{filename}"
def as_dict(self):
return {"result": [self.model_file, self.camera_info, self.bg_image_path]}

View File

@ -0,0 +1,259 @@
from __future__ import annotations
import numpy as np
class TrimeshData:
"""Triangular mesh payload for cross-process transfer.
Lightweight carrier for mesh geometry that does not depend on the
``trimesh`` library. Serializers create this on the host side;
isolated child processes convert to/from ``trimesh.Trimesh`` as needed.
Supports both ColorVisuals (vertex_colors) and TextureVisuals
(uv + material with textures).
"""
def __init__(
self,
vertices: np.ndarray,
faces: np.ndarray,
vertex_normals: np.ndarray | None = None,
face_normals: np.ndarray | None = None,
vertex_colors: np.ndarray | None = None,
uv: np.ndarray | None = None,
material: dict | None = None,
vertex_attributes: dict | None = None,
face_attributes: dict | None = None,
metadata: dict | None = None,
) -> None:
self.vertices = np.ascontiguousarray(vertices, dtype=np.float64)
self.faces = np.ascontiguousarray(faces, dtype=np.int64)
self.vertex_normals = (
np.ascontiguousarray(vertex_normals, dtype=np.float64)
if vertex_normals is not None
else None
)
self.face_normals = (
np.ascontiguousarray(face_normals, dtype=np.float64)
if face_normals is not None
else None
)
self.vertex_colors = (
np.ascontiguousarray(vertex_colors, dtype=np.uint8)
if vertex_colors is not None
else None
)
self.uv = (
np.ascontiguousarray(uv, dtype=np.float64)
if uv is not None
else None
)
self.material = material
self.vertex_attributes = vertex_attributes or {}
self.face_attributes = face_attributes or {}
self.metadata = self._detensorize_dict(metadata) if metadata else {}
@staticmethod
def _detensorize_dict(d):
"""Recursively convert any tensors in a dict back to numpy arrays."""
if not isinstance(d, dict):
return d
result = {}
for k, v in d.items():
if hasattr(v, "numpy"):
result[k] = v.cpu().numpy() if hasattr(v, "cpu") else v.numpy()
elif isinstance(v, dict):
result[k] = TrimeshData._detensorize_dict(v)
elif isinstance(v, list):
result[k] = [
item.cpu().numpy() if hasattr(item, "numpy") and hasattr(item, "cpu")
else item.numpy() if hasattr(item, "numpy")
else item
for item in v
]
else:
result[k] = v
return result
@staticmethod
def _to_numpy(arr, dtype):
if arr is None:
return None
if hasattr(arr, "numpy"):
arr = arr.cpu().numpy() if hasattr(arr, "cpu") else arr.numpy()
return np.ascontiguousarray(arr, dtype=dtype)
@property
def num_vertices(self) -> int:
return self.vertices.shape[0]
@property
def num_faces(self) -> int:
return self.faces.shape[0]
@property
def has_texture(self) -> bool:
return self.uv is not None and self.material is not None
def to_trimesh(self):
"""Convert to trimesh.Trimesh (requires trimesh in the environment)."""
import trimesh
from trimesh.visual import TextureVisuals
kwargs = {}
if self.vertex_normals is not None:
kwargs["vertex_normals"] = self.vertex_normals
if self.face_normals is not None:
kwargs["face_normals"] = self.face_normals
if self.metadata:
kwargs["metadata"] = self.metadata
mesh = trimesh.Trimesh(
vertices=self.vertices, faces=self.faces, process=False, **kwargs
)
# Reconstruct visual
if self.has_texture:
material = self._dict_to_material(self.material)
mesh.visual = TextureVisuals(uv=self.uv, material=material)
elif self.vertex_colors is not None:
mesh.visual.vertex_colors = self.vertex_colors
for k, v in self.vertex_attributes.items():
mesh.vertex_attributes[k] = v
for k, v in self.face_attributes.items():
mesh.face_attributes[k] = v
return mesh
@staticmethod
def _material_to_dict(material) -> dict:
"""Serialize a trimesh material to a plain dict."""
import base64
from io import BytesIO
from trimesh.visual.material import PBRMaterial, SimpleMaterial
result = {"type": type(material).__name__, "name": getattr(material, "name", None)}
if isinstance(material, PBRMaterial):
result["baseColorFactor"] = material.baseColorFactor
result["metallicFactor"] = material.metallicFactor
result["roughnessFactor"] = material.roughnessFactor
result["emissiveFactor"] = material.emissiveFactor
result["alphaMode"] = material.alphaMode
result["alphaCutoff"] = material.alphaCutoff
result["doubleSided"] = material.doubleSided
for tex_name in ("baseColorTexture", "normalTexture", "emissiveTexture",
"metallicRoughnessTexture", "occlusionTexture"):
tex = getattr(material, tex_name, None)
if tex is not None:
buf = BytesIO()
tex.save(buf, format="PNG")
result[tex_name] = base64.b64encode(buf.getvalue()).decode("ascii")
elif isinstance(material, SimpleMaterial):
result["main_color"] = list(material.main_color) if material.main_color is not None else None
result["glossiness"] = material.glossiness
if hasattr(material, "image") and material.image is not None:
buf = BytesIO()
material.image.save(buf, format="PNG")
result["image"] = base64.b64encode(buf.getvalue()).decode("ascii")
return result
@staticmethod
def _dict_to_material(d: dict):
"""Reconstruct a trimesh material from a plain dict."""
import base64
from io import BytesIO
from PIL import Image
from trimesh.visual.material import PBRMaterial, SimpleMaterial
mat_type = d.get("type", "PBRMaterial")
if mat_type == "PBRMaterial":
kwargs = {
"name": d.get("name"),
"baseColorFactor": d.get("baseColorFactor"),
"metallicFactor": d.get("metallicFactor"),
"roughnessFactor": d.get("roughnessFactor"),
"emissiveFactor": d.get("emissiveFactor"),
"alphaMode": d.get("alphaMode"),
"alphaCutoff": d.get("alphaCutoff"),
"doubleSided": d.get("doubleSided"),
}
for tex_name in ("baseColorTexture", "normalTexture", "emissiveTexture",
"metallicRoughnessTexture", "occlusionTexture"):
if tex_name in d and d[tex_name] is not None:
img = Image.open(BytesIO(base64.b64decode(d[tex_name])))
kwargs[tex_name] = img
return PBRMaterial(**{k: v for k, v in kwargs.items() if v is not None})
elif mat_type == "SimpleMaterial":
kwargs = {
"name": d.get("name"),
"glossiness": d.get("glossiness"),
}
if d.get("main_color") is not None:
kwargs["diffuse"] = d["main_color"]
if d.get("image") is not None:
kwargs["image"] = Image.open(BytesIO(base64.b64decode(d["image"])))
return SimpleMaterial(**kwargs)
raise ValueError(f"Unknown material type: {mat_type}")
@classmethod
def from_trimesh(cls, mesh) -> TrimeshData:
"""Create from a trimesh.Trimesh object."""
from trimesh.visual.texture import TextureVisuals
vertex_normals = None
if mesh._cache.cache.get("vertex_normals") is not None:
vertex_normals = np.asarray(mesh.vertex_normals)
face_normals = None
if mesh._cache.cache.get("face_normals") is not None:
face_normals = np.asarray(mesh.face_normals)
vertex_colors = None
uv = None
material = None
if isinstance(mesh.visual, TextureVisuals):
if mesh.visual.uv is not None:
uv = np.asarray(mesh.visual.uv, dtype=np.float64)
if mesh.visual.material is not None:
material = cls._material_to_dict(mesh.visual.material)
else:
try:
vc = mesh.visual.vertex_colors
if vc is not None and len(vc) > 0:
vertex_colors = np.asarray(vc, dtype=np.uint8)
except Exception:
pass
va = {}
if hasattr(mesh, "vertex_attributes") and mesh.vertex_attributes:
for k, v in mesh.vertex_attributes.items():
va[k] = np.asarray(v) if hasattr(v, "__array__") else v
fa = {}
if hasattr(mesh, "face_attributes") and mesh.face_attributes:
for k, v in mesh.face_attributes.items():
fa[k] = np.asarray(v) if hasattr(v, "__array__") else v
return cls(
vertices=np.asarray(mesh.vertices),
faces=np.asarray(mesh.faces),
vertex_normals=vertex_normals,
face_normals=face_normals,
vertex_colors=vertex_colors,
uv=uv,
material=material,
vertex_attributes=va if va else None,
face_attributes=fa if fa else None,
metadata=mesh.metadata if mesh.metadata else None,
)

View File

@ -0,0 +1,18 @@
"""comfy_api_sealed_worker — torch-free type definitions for sealed worker children.
Drop-in replacement for comfy_api.latest._util type imports in sealed workers
that do not have torch installed. Contains only data type definitions (TrimeshData,
PLY, NPZ, etc.) with numpy-only dependencies.
Usage in serializers:
if _IMPORT_TORCH:
from comfy_api.latest._util.trimesh_types import TrimeshData
else:
from comfy_api_sealed_worker.trimesh_types import TrimeshData
"""
from .trimesh_types import TrimeshData
from .ply_types import PLY
from .npz_types import NPZ
__all__ = ["TrimeshData", "PLY", "NPZ"]

View File

@ -0,0 +1,27 @@
from __future__ import annotations
import os
class NPZ:
"""Ordered collection of NPZ file payloads.
Each entry in ``frames`` is a complete compressed ``.npz`` file stored
as raw bytes (produced by ``numpy.savez_compressed`` into a BytesIO).
``save_to`` writes numbered files into a directory.
"""
def __init__(self, frames: list[bytes]) -> None:
self.frames = frames
@property
def num_frames(self) -> int:
return len(self.frames)
def save_to(self, directory: str, prefix: str = "frame") -> str:
os.makedirs(directory, exist_ok=True)
for i, frame_bytes in enumerate(self.frames):
path = os.path.join(directory, f"{prefix}_{i:06d}.npz")
with open(path, "wb") as f:
f.write(frame_bytes)
return directory

View File

@ -0,0 +1,97 @@
from __future__ import annotations
import numpy as np
class PLY:
"""Point cloud payload for PLY file output.
Supports two schemas:
- Pointcloud: xyz positions with optional colors, confidence, view_id (ASCII format)
- Gaussian: raw binary PLY data built by producer nodes using plyfile (binary format)
When ``raw_data`` is provided, the object acts as an opaque binary PLY
carrier and ``save_to`` writes the bytes directly.
"""
def __init__(
self,
points: np.ndarray | None = None,
colors: np.ndarray | None = None,
confidence: np.ndarray | None = None,
view_id: np.ndarray | None = None,
raw_data: bytes | None = None,
) -> None:
self.raw_data = raw_data
if raw_data is not None:
self.points = None
self.colors = None
self.confidence = None
self.view_id = None
return
if points is None:
raise ValueError("Either points or raw_data must be provided")
if points.ndim != 2 or points.shape[1] != 3:
raise ValueError(f"points must be (N, 3), got {points.shape}")
self.points = np.ascontiguousarray(points, dtype=np.float32)
self.colors = np.ascontiguousarray(colors, dtype=np.float32) if colors is not None else None
self.confidence = np.ascontiguousarray(confidence, dtype=np.float32) if confidence is not None else None
self.view_id = np.ascontiguousarray(view_id, dtype=np.int32) if view_id is not None else None
@property
def is_gaussian(self) -> bool:
return self.raw_data is not None
@property
def num_points(self) -> int:
if self.points is not None:
return self.points.shape[0]
return 0
@staticmethod
def _to_numpy(arr, dtype):
if arr is None:
return None
if hasattr(arr, "numpy"):
arr = arr.cpu().numpy() if hasattr(arr, "cpu") else arr.numpy()
return np.ascontiguousarray(arr, dtype=dtype)
def save_to(self, path: str) -> str:
if self.raw_data is not None:
with open(path, "wb") as f:
f.write(self.raw_data)
return path
self.points = self._to_numpy(self.points, np.float32)
self.colors = self._to_numpy(self.colors, np.float32)
self.confidence = self._to_numpy(self.confidence, np.float32)
self.view_id = self._to_numpy(self.view_id, np.int32)
N = self.num_points
header_lines = [
"ply",
"format ascii 1.0",
f"element vertex {N}",
"property float x",
"property float y",
"property float z",
]
if self.colors is not None:
header_lines += ["property uchar red", "property uchar green", "property uchar blue"]
if self.confidence is not None:
header_lines.append("property float confidence")
if self.view_id is not None:
header_lines.append("property int view_id")
header_lines.append("end_header")
with open(path, "w") as f:
f.write("\n".join(header_lines) + "\n")
for i in range(N):
parts = [f"{self.points[i, 0]} {self.points[i, 1]} {self.points[i, 2]}"]
if self.colors is not None:
r, g, b = (self.colors[i] * 255).clip(0, 255).astype(np.uint8)
parts.append(f"{r} {g} {b}")
if self.confidence is not None:
parts.append(f"{self.confidence[i]}")
if self.view_id is not None:
parts.append(f"{int(self.view_id[i])}")
f.write(" ".join(parts) + "\n")
return path

View File

@ -0,0 +1,259 @@
from __future__ import annotations
import numpy as np
class TrimeshData:
"""Triangular mesh payload for cross-process transfer.
Lightweight carrier for mesh geometry that does not depend on the
``trimesh`` library. Serializers create this on the host side;
isolated child processes convert to/from ``trimesh.Trimesh`` as needed.
Supports both ColorVisuals (vertex_colors) and TextureVisuals
(uv + material with textures).
"""
def __init__(
self,
vertices: np.ndarray,
faces: np.ndarray,
vertex_normals: np.ndarray | None = None,
face_normals: np.ndarray | None = None,
vertex_colors: np.ndarray | None = None,
uv: np.ndarray | None = None,
material: dict | None = None,
vertex_attributes: dict | None = None,
face_attributes: dict | None = None,
metadata: dict | None = None,
) -> None:
self.vertices = np.ascontiguousarray(vertices, dtype=np.float64)
self.faces = np.ascontiguousarray(faces, dtype=np.int64)
self.vertex_normals = (
np.ascontiguousarray(vertex_normals, dtype=np.float64)
if vertex_normals is not None
else None
)
self.face_normals = (
np.ascontiguousarray(face_normals, dtype=np.float64)
if face_normals is not None
else None
)
self.vertex_colors = (
np.ascontiguousarray(vertex_colors, dtype=np.uint8)
if vertex_colors is not None
else None
)
self.uv = (
np.ascontiguousarray(uv, dtype=np.float64)
if uv is not None
else None
)
self.material = material
self.vertex_attributes = vertex_attributes or {}
self.face_attributes = face_attributes or {}
self.metadata = self._detensorize_dict(metadata) if metadata else {}
@staticmethod
def _detensorize_dict(d):
"""Recursively convert any tensors in a dict back to numpy arrays."""
if not isinstance(d, dict):
return d
result = {}
for k, v in d.items():
if hasattr(v, "numpy"):
result[k] = v.cpu().numpy() if hasattr(v, "cpu") else v.numpy()
elif isinstance(v, dict):
result[k] = TrimeshData._detensorize_dict(v)
elif isinstance(v, list):
result[k] = [
item.cpu().numpy() if hasattr(item, "numpy") and hasattr(item, "cpu")
else item.numpy() if hasattr(item, "numpy")
else item
for item in v
]
else:
result[k] = v
return result
@staticmethod
def _to_numpy(arr, dtype):
if arr is None:
return None
if hasattr(arr, "numpy"):
arr = arr.cpu().numpy() if hasattr(arr, "cpu") else arr.numpy()
return np.ascontiguousarray(arr, dtype=dtype)
@property
def num_vertices(self) -> int:
return self.vertices.shape[0]
@property
def num_faces(self) -> int:
return self.faces.shape[0]
@property
def has_texture(self) -> bool:
return self.uv is not None and self.material is not None
def to_trimesh(self):
"""Convert to trimesh.Trimesh (requires trimesh in the environment)."""
import trimesh
from trimesh.visual import TextureVisuals
kwargs = {}
if self.vertex_normals is not None:
kwargs["vertex_normals"] = self.vertex_normals
if self.face_normals is not None:
kwargs["face_normals"] = self.face_normals
if self.metadata:
kwargs["metadata"] = self.metadata
mesh = trimesh.Trimesh(
vertices=self.vertices, faces=self.faces, process=False, **kwargs
)
# Reconstruct visual
if self.has_texture:
material = self._dict_to_material(self.material)
mesh.visual = TextureVisuals(uv=self.uv, material=material)
elif self.vertex_colors is not None:
mesh.visual.vertex_colors = self.vertex_colors
for k, v in self.vertex_attributes.items():
mesh.vertex_attributes[k] = v
for k, v in self.face_attributes.items():
mesh.face_attributes[k] = v
return mesh
@staticmethod
def _material_to_dict(material) -> dict:
"""Serialize a trimesh material to a plain dict."""
import base64
from io import BytesIO
from trimesh.visual.material import PBRMaterial, SimpleMaterial
result = {"type": type(material).__name__, "name": getattr(material, "name", None)}
if isinstance(material, PBRMaterial):
result["baseColorFactor"] = material.baseColorFactor
result["metallicFactor"] = material.metallicFactor
result["roughnessFactor"] = material.roughnessFactor
result["emissiveFactor"] = material.emissiveFactor
result["alphaMode"] = material.alphaMode
result["alphaCutoff"] = material.alphaCutoff
result["doubleSided"] = material.doubleSided
for tex_name in ("baseColorTexture", "normalTexture", "emissiveTexture",
"metallicRoughnessTexture", "occlusionTexture"):
tex = getattr(material, tex_name, None)
if tex is not None:
buf = BytesIO()
tex.save(buf, format="PNG")
result[tex_name] = base64.b64encode(buf.getvalue()).decode("ascii")
elif isinstance(material, SimpleMaterial):
result["main_color"] = list(material.main_color) if material.main_color is not None else None
result["glossiness"] = material.glossiness
if hasattr(material, "image") and material.image is not None:
buf = BytesIO()
material.image.save(buf, format="PNG")
result["image"] = base64.b64encode(buf.getvalue()).decode("ascii")
return result
@staticmethod
def _dict_to_material(d: dict):
"""Reconstruct a trimesh material from a plain dict."""
import base64
from io import BytesIO
from PIL import Image
from trimesh.visual.material import PBRMaterial, SimpleMaterial
mat_type = d.get("type", "PBRMaterial")
if mat_type == "PBRMaterial":
kwargs = {
"name": d.get("name"),
"baseColorFactor": d.get("baseColorFactor"),
"metallicFactor": d.get("metallicFactor"),
"roughnessFactor": d.get("roughnessFactor"),
"emissiveFactor": d.get("emissiveFactor"),
"alphaMode": d.get("alphaMode"),
"alphaCutoff": d.get("alphaCutoff"),
"doubleSided": d.get("doubleSided"),
}
for tex_name in ("baseColorTexture", "normalTexture", "emissiveTexture",
"metallicRoughnessTexture", "occlusionTexture"):
if tex_name in d and d[tex_name] is not None:
img = Image.open(BytesIO(base64.b64decode(d[tex_name])))
kwargs[tex_name] = img
return PBRMaterial(**{k: v for k, v in kwargs.items() if v is not None})
elif mat_type == "SimpleMaterial":
kwargs = {
"name": d.get("name"),
"glossiness": d.get("glossiness"),
}
if d.get("main_color") is not None:
kwargs["diffuse"] = d["main_color"]
if d.get("image") is not None:
kwargs["image"] = Image.open(BytesIO(base64.b64decode(d["image"])))
return SimpleMaterial(**kwargs)
raise ValueError(f"Unknown material type: {mat_type}")
@classmethod
def from_trimesh(cls, mesh) -> TrimeshData:
"""Create from a trimesh.Trimesh object."""
from trimesh.visual.texture import TextureVisuals
vertex_normals = None
if mesh._cache.cache.get("vertex_normals") is not None:
vertex_normals = np.asarray(mesh.vertex_normals)
face_normals = None
if mesh._cache.cache.get("face_normals") is not None:
face_normals = np.asarray(mesh.face_normals)
vertex_colors = None
uv = None
material = None
if isinstance(mesh.visual, TextureVisuals):
if mesh.visual.uv is not None:
uv = np.asarray(mesh.visual.uv, dtype=np.float64)
if mesh.visual.material is not None:
material = cls._material_to_dict(mesh.visual.material)
else:
try:
vc = mesh.visual.vertex_colors
if vc is not None and len(vc) > 0:
vertex_colors = np.asarray(vc, dtype=np.uint8)
except Exception:
pass
va = {}
if hasattr(mesh, "vertex_attributes") and mesh.vertex_attributes:
for k, v in mesh.vertex_attributes.items():
va[k] = np.asarray(v) if hasattr(v, "__array__") else v
fa = {}
if hasattr(mesh, "face_attributes") and mesh.face_attributes:
for k, v in mesh.face_attributes.items():
fa[k] = np.asarray(v) if hasattr(v, "__array__") else v
return cls(
vertices=np.asarray(mesh.vertices),
faces=np.asarray(mesh.faces),
vertex_normals=vertex_normals,
face_normals=face_normals,
vertex_colors=vertex_colors,
uv=uv,
material=material,
vertex_attributes=va if va else None,
face_attributes=fa if fa else None,
metadata=mesh.metadata if mesh.metadata else None,
)

View File

@ -685,7 +685,7 @@ class PromptExecutor:
return
try:
from comfy.isolation import notify_execution_graph
await notify_execution_graph(class_types)
await notify_execution_graph(class_types, caches=self.caches.all)
except Exception:
if fail_loud:
raise
@ -785,26 +785,26 @@ class PromptExecutor:
self.status_messages = []
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
with torch.inference_mode():
if args.use_process_isolation:
try:
# Boundary cleanup runs at the start of the next workflow in
# isolation mode, matching non-isolated "next prompt" timing.
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
await self._wait_model_patcher_quiescence_safe(
fail_loud=False,
timeout_ms=120000,
marker="EX:boundary_cleanup_wait_idle",
)
await self._flush_running_extensions_transport_state_safe()
comfy.model_management.unload_all_models()
comfy.model_management.cleanup_models_gc()
comfy.model_management.cleanup_models()
gc.collect()
comfy.model_management.soft_empty_cache()
except Exception:
logging.debug("][ EX:isolation_boundary_cleanup_start failed", exc_info=True)
if args.use_process_isolation:
try:
# Boundary cleanup runs at the start of the next workflow in
# isolation mode, matching non-isolated "next prompt" timing.
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
await self._wait_model_patcher_quiescence_safe(
fail_loud=False,
timeout_ms=120000,
marker="EX:boundary_cleanup_wait_idle",
)
await self._flush_running_extensions_transport_state_safe()
comfy.model_management.unload_all_models()
comfy.model_management.cleanup_models_gc()
comfy.model_management.cleanup_models()
gc.collect()
comfy.model_management.soft_empty_cache()
except Exception:
logging.debug("][ EX:isolation_boundary_cleanup_start failed", exc_info=True)
with torch.inference_mode():
dynamic_prompt = DynamicPrompt(prompt)
reset_progress_state(prompt_id, dynamic_prompt)
add_progress_handler(WebUIProgressHandler(self.server))

View File

@ -347,6 +347,17 @@ class PromptServer():
extensions.extend(list(map(lambda f: "/extensions/" + urllib.parse.quote(
name) + "/" + os.path.relpath(f, dir).replace("\\", "/"), files)))
# Include JS files from proxied web directories (isolated nodes)
if args.use_process_isolation:
from comfy.isolation.proxies.web_directory_proxy import get_web_directory_cache
cache = get_web_directory_cache()
for ext_name in cache.extension_names:
for entry in cache.list_files(ext_name):
if entry["relative_path"].endswith(".js"):
extensions.append(
"/extensions/" + urllib.parse.quote(ext_name) + "/" + entry["relative_path"]
)
return web.json_response(extensions)
def get_dir_by_type(dir_type):
@ -1022,6 +1033,40 @@ class PromptServer():
for name, dir in nodes.EXTENSION_WEB_DIRS.items():
self.app.add_routes([web.static('/extensions/' + name, dir)])
# Add dynamic handler for proxied web directories (isolated nodes)
if args.use_process_isolation:
from comfy.isolation.proxies.web_directory_proxy import (
get_web_directory_cache,
ALLOWED_EXTENSIONS,
)
async def proxied_web_handler(request):
ext_name = request.match_info["ext_name"]
file_path = request.match_info["file_path"]
suffix = os.path.splitext(file_path)[1].lower()
if suffix not in ALLOWED_EXTENSIONS:
return web.Response(status=403, text="Forbidden file type")
cache = get_web_directory_cache()
result = cache.get_file(ext_name, file_path)
if result is None:
return web.Response(status=404, text="Not found")
content_type = {
".js": "application/javascript",
".css": "text/css",
".html": "text/html",
".json": "application/json",
}.get(suffix, "application/octet-stream")
return web.Response(body=result, content_type=content_type)
self.app.router.add_get(
"/extensions/{ext_name}/{file_path:.*}",
proxied_web_handler,
)
installed_templates_version = FrontendManager.get_installed_templates_version()
use_legacy_templates = True
if installed_templates_version:

View File

@ -0,0 +1,209 @@
# pylint: disable=import-outside-toplevel,import-error
from __future__ import annotations
import logging
import os
import sys
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
def _artifact_dir() -> Path | None:
raw = os.environ.get("PYISOLATE_ARTIFACT_DIR")
if not raw:
return None
path = Path(raw)
path.mkdir(parents=True, exist_ok=True)
return path
def _write_artifact(name: str, content: str) -> None:
artifact_dir = _artifact_dir()
if artifact_dir is None:
return
(artifact_dir / name).write_text(content, encoding="utf-8")
def _contains_tensor_marker(value: Any) -> bool:
if isinstance(value, dict):
if value.get("__type__") == "TensorValue":
return True
return any(_contains_tensor_marker(v) for v in value.values())
if isinstance(value, (list, tuple)):
return any(_contains_tensor_marker(v) for v in value)
return False
class InspectRuntimeNode:
RETURN_TYPES = (
"STRING",
"STRING",
"BOOLEAN",
"BOOLEAN",
"STRING",
"STRING",
"BOOLEAN",
)
RETURN_NAMES = (
"path_dump",
"runtime_report",
"saw_comfy_root",
"imported_comfy_wrapper",
"comfy_module_dump",
"python_exe",
"saw_user_site",
)
FUNCTION = "inspect"
CATEGORY = "PyIsolated/SealedWorker"
@classmethod
def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802
return {"required": {}}
def inspect(self) -> tuple[str, str, bool, bool, str, str, bool]:
import cfgrib
import eccodes
import xarray as xr
path_dump = "\n".join(sys.path)
comfy_root = "/home/johnj/ComfyUI"
saw_comfy_root = any(
entry == comfy_root
or entry.startswith(f"{comfy_root}/comfy")
or entry.startswith(f"{comfy_root}/.venv")
for entry in sys.path
)
imported_comfy_wrapper = "comfy.isolation.extension_wrapper" in sys.modules
comfy_module_dump = "\n".join(
sorted(name for name in sys.modules if name.startswith("comfy"))
)
saw_user_site = any("/.local/lib/" in entry for entry in sys.path)
python_exe = sys.executable
runtime_lines = [
"Conda sealed worker runtime probe",
f"python_exe={python_exe}",
f"xarray_origin={getattr(xr, '__file__', '<missing>')}",
f"cfgrib_origin={getattr(cfgrib, '__file__', '<missing>')}",
f"eccodes_origin={getattr(eccodes, '__file__', '<missing>')}",
f"saw_comfy_root={saw_comfy_root}",
f"imported_comfy_wrapper={imported_comfy_wrapper}",
f"saw_user_site={saw_user_site}",
]
runtime_report = "\n".join(runtime_lines)
_write_artifact("child_bootstrap_paths.txt", path_dump)
_write_artifact("child_import_trace.txt", comfy_module_dump)
_write_artifact("child_dependency_dump.txt", runtime_report)
logger.warning("][ Conda sealed runtime probe executed")
logger.warning("][ conda python executable: %s", python_exe)
logger.warning(
"][ conda dependency origins: xarray=%s cfgrib=%s eccodes=%s",
getattr(xr, "__file__", "<missing>"),
getattr(cfgrib, "__file__", "<missing>"),
getattr(eccodes, "__file__", "<missing>"),
)
return (
path_dump,
runtime_report,
saw_comfy_root,
imported_comfy_wrapper,
comfy_module_dump,
python_exe,
saw_user_site,
)
class OpenWeatherDatasetNode:
RETURN_TYPES = ("FLOAT", "STRING", "STRING")
RETURN_NAMES = ("sum_value", "grib_path", "dependency_report")
FUNCTION = "open_dataset"
CATEGORY = "PyIsolated/SealedWorker"
@classmethod
def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802
return {"required": {}}
def open_dataset(self) -> tuple[float, str, str]:
import eccodes
import xarray as xr
artifact_dir = _artifact_dir()
if artifact_dir is None:
artifact_dir = Path(os.environ.get("HOME", ".")) / "pyisolate_artifacts"
artifact_dir.mkdir(parents=True, exist_ok=True)
grib_path = artifact_dir / "toolkit_weather_fixture.grib2"
gid = eccodes.codes_grib_new_from_samples("GRIB2")
for key, value in [
("gridType", "regular_ll"),
("Nx", 2),
("Ny", 2),
("latitudeOfFirstGridPointInDegrees", 1.0),
("longitudeOfFirstGridPointInDegrees", 0.0),
("latitudeOfLastGridPointInDegrees", 0.0),
("longitudeOfLastGridPointInDegrees", 1.0),
("iDirectionIncrementInDegrees", 1.0),
("jDirectionIncrementInDegrees", 1.0),
("jScansPositively", 0),
("shortName", "t"),
("typeOfLevel", "surface"),
("level", 0),
("date", 20260315),
("time", 0),
("step", 0),
]:
eccodes.codes_set(gid, key, value)
eccodes.codes_set_values(gid, [1.0, 2.0, 3.0, 4.0])
with grib_path.open("wb") as handle:
eccodes.codes_write(gid, handle)
eccodes.codes_release(gid)
dataset = xr.open_dataset(grib_path, engine="cfgrib")
sum_value = float(dataset["t"].sum().item())
dependency_report = "\n".join(
[
f"dataset_sum={sum_value}",
f"grib_path={grib_path}",
"xarray_engine=cfgrib",
]
)
_write_artifact("weather_dependency_report.txt", dependency_report)
logger.warning("][ cfgrib import ok")
logger.warning("][ xarray open_dataset engine=cfgrib path=%s", grib_path)
logger.warning("][ conda weather dataset sum=%s", sum_value)
return sum_value, str(grib_path), dependency_report
class EchoLatentNode:
RETURN_TYPES = ("LATENT", "BOOLEAN")
RETURN_NAMES = ("latent", "saw_json_tensor")
FUNCTION = "echo_latent"
CATEGORY = "PyIsolated/SealedWorker"
@classmethod
def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802
return {"required": {"latent": ("LATENT",)}}
def echo_latent(self, latent: Any) -> tuple[Any, bool]:
saw_json_tensor = _contains_tensor_marker(latent)
logger.warning("][ conda latent echo json_marker=%s", saw_json_tensor)
return latent, saw_json_tensor
NODE_CLASS_MAPPINGS = {
"CondaSealedRuntimeProbe": InspectRuntimeNode,
"CondaSealedOpenWeatherDataset": OpenWeatherDatasetNode,
"CondaSealedLatentEcho": EchoLatentNode,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"CondaSealedRuntimeProbe": "Conda Sealed Runtime Probe",
"CondaSealedOpenWeatherDataset": "Conda Sealed Open Weather Dataset",
"CondaSealedLatentEcho": "Conda Sealed Latent Echo",
}

View File

@ -0,0 +1,13 @@
[project]
name = "comfyui-toolkit-conda-sealed-worker"
version = "0.1.0"
dependencies = ["xarray", "cfgrib"]
[tool.comfy.isolation]
can_isolate = true
share_torch = false
package_manager = "conda"
execution_model = "sealed_worker"
standalone = true
conda_channels = ["conda-forge"]
conda_dependencies = ["eccodes", "cfgrib"]

View File

@ -0,0 +1,7 @@
[tool.comfy.host]
sandbox_mode = "required"
allow_network = false
writable_paths = [
"/dev/shm",
"/home/johnj/ComfyUI/output",
]

View File

@ -0,0 +1,6 @@
from .probe_nodes import (
NODE_CLASS_MAPPINGS as NODE_CLASS_MAPPINGS,
NODE_DISPLAY_NAME_MAPPINGS as NODE_DISPLAY_NAME_MAPPINGS,
)
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]

View File

@ -0,0 +1,75 @@
from __future__ import annotations
class InternalIsolationProbeImage:
CATEGORY = "tests/isolation"
RETURN_TYPES = ()
FUNCTION = "run"
OUTPUT_NODE = True
@classmethod
def INPUT_TYPES(cls):
return {"required": {}}
def run(self):
from comfy_api.latest import UI
import torch
image = torch.zeros((1, 2, 2, 3), dtype=torch.float32)
image[:, :, :, 0] = 1.0
ui = UI.PreviewImage(image)
return {"ui": ui.as_dict(), "result": ()}
class InternalIsolationProbeAudio:
CATEGORY = "tests/isolation"
RETURN_TYPES = ()
FUNCTION = "run"
OUTPUT_NODE = True
@classmethod
def INPUT_TYPES(cls):
return {"required": {}}
def run(self):
from comfy_api.latest import UI
import torch
waveform = torch.zeros((1, 1, 32), dtype=torch.float32)
audio = {"waveform": waveform, "sample_rate": 44100}
ui = UI.PreviewAudio(audio)
return {"ui": ui.as_dict(), "result": ()}
class InternalIsolationProbeUI3D:
CATEGORY = "tests/isolation"
RETURN_TYPES = ()
FUNCTION = "run"
OUTPUT_NODE = True
@classmethod
def INPUT_TYPES(cls):
return {"required": {}}
def run(self):
from comfy_api.latest import UI
import torch
bg_image = torch.zeros((1, 2, 2, 3), dtype=torch.float32)
bg_image[:, :, :, 1] = 1.0
camera_info = {"distance": 1.0}
ui = UI.PreviewUI3D("internal_probe_preview.obj", camera_info, bg_image=bg_image)
return {"ui": ui.as_dict(), "result": ()}
NODE_CLASS_MAPPINGS = {
"InternalIsolationProbeImage": InternalIsolationProbeImage,
"InternalIsolationProbeAudio": InternalIsolationProbeAudio,
"InternalIsolationProbeUI3D": InternalIsolationProbeUI3D,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"InternalIsolationProbeImage": "Internal Isolation Probe Image",
"InternalIsolationProbeAudio": "Internal Isolation Probe Audio",
"InternalIsolationProbeUI3D": "Internal Isolation Probe UI3D",
}

View File

@ -0,0 +1,955 @@
from __future__ import annotations
import asyncio
import importlib.util
import os
import sys
from pathlib import Path
from typing import Any
COMFYUI_ROOT = Path(__file__).resolve().parents[2]
UV_SEALED_WORKER_MODULE = COMFYUI_ROOT / "tests" / "isolation" / "uv_sealed_worker" / "__init__.py"
FORBIDDEN_MINIMAL_SEALED_MODULES = (
"torch",
"folder_paths",
"comfy.utils",
"comfy.model_management",
"main",
"comfy.isolation.extension_wrapper",
)
FORBIDDEN_SEALED_SINGLETON_MODULES = (
"torch",
"folder_paths",
"comfy.utils",
"comfy_execution.progress",
)
FORBIDDEN_EXACT_SMALL_PROXY_MODULES = FORBIDDEN_SEALED_SINGLETON_MODULES
FORBIDDEN_MODEL_MANAGEMENT_MODULES = (
"comfy.model_management",
)
def _load_module_from_path(module_name: str, module_path: Path):
spec = importlib.util.spec_from_file_location(module_name, module_path)
if spec is None or spec.loader is None:
raise RuntimeError(f"unable to build import spec for {module_path}")
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
try:
spec.loader.exec_module(module)
except Exception:
sys.modules.pop(module_name, None)
raise
return module
def matching_modules(prefixes: tuple[str, ...], modules: set[str]) -> list[str]:
return sorted(
module_name
for module_name in modules
if any(
module_name == prefix or module_name.startswith(f"{prefix}.")
for prefix in prefixes
)
)
def _load_helper_proxy_service() -> Any | None:
try:
from comfy.isolation.proxies.helper_proxies import HelperProxiesService
except (ImportError, AttributeError):
return None
return HelperProxiesService
def _load_model_management_proxy() -> Any | None:
try:
from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy
except (ImportError, AttributeError):
return None
return ModelManagementProxy
async def _capture_minimal_sealed_worker_imports() -> dict[str, object]:
from pyisolate.sealed import SealedNodeExtension
module_name = "tests.isolation.uv_sealed_worker_boundary_probe"
before = set(sys.modules)
extension = SealedNodeExtension()
module = _load_module_from_path(module_name, UV_SEALED_WORKER_MODULE)
try:
await extension.on_module_loaded(module)
node_list = await extension.list_nodes()
node_details = await extension.get_node_details("UVSealedRuntimeProbe")
imported = set(sys.modules) - before
return {
"mode": "minimal_sealed_worker",
"node_names": sorted(node_list),
"runtime_probe_function": node_details["function"],
"modules": sorted(imported),
"forbidden_matches": matching_modules(FORBIDDEN_MINIMAL_SEALED_MODULES, imported),
}
finally:
sys.modules.pop(module_name, None)
def capture_minimal_sealed_worker_imports() -> dict[str, object]:
return asyncio.run(_capture_minimal_sealed_worker_imports())
class FakeSingletonCaller:
def __init__(self, methods: dict[str, Any], calls: list[dict[str, Any]], object_id: str):
self._methods = methods
self._calls = calls
self._object_id = object_id
def __getattr__(self, name: str):
if name not in self._methods:
raise AttributeError(name)
async def method(*args: Any, **kwargs: Any) -> Any:
self._calls.append(
{
"object_id": self._object_id,
"method": name,
"args": list(args),
"kwargs": dict(kwargs),
}
)
result = self._methods[name]
return result(*args, **kwargs) if callable(result) else result
return method
class FakeSingletonRPC:
def __init__(self) -> None:
self.calls: list[dict[str, Any]] = []
self._device = {"__pyisolate_torch_device__": "cpu"}
self._services: dict[str, dict[str, Any]] = {
"FolderPathsProxy": {
"rpc_get_models_dir": lambda: "/sandbox/models",
"rpc_get_folder_names_and_paths": lambda: {
"checkpoints": {
"paths": ["/sandbox/models/checkpoints"],
"extensions": [".ckpt", ".safetensors"],
}
},
"rpc_get_extension_mimetypes_cache": lambda: {"webp": "image"},
"rpc_get_filename_list_cache": lambda: {},
"rpc_get_temp_directory": lambda: "/sandbox/temp",
"rpc_get_input_directory": lambda: "/sandbox/input",
"rpc_get_output_directory": lambda: "/sandbox/output",
"rpc_get_user_directory": lambda: "/sandbox/user",
"rpc_get_annotated_filepath": self._get_annotated_filepath,
"rpc_exists_annotated_filepath": lambda _name: False,
"rpc_add_model_folder_path": lambda *_args, **_kwargs: None,
"rpc_get_folder_paths": lambda folder_name: [f"/sandbox/models/{folder_name}"],
"rpc_get_filename_list": lambda folder_name: [f"{folder_name}_fixture.safetensors"],
"rpc_get_full_path": lambda folder_name, filename: f"/sandbox/models/{folder_name}/{filename}",
},
"UtilsProxy": {
"progress_bar_hook": lambda value, total, preview=None, node_id=None: {
"value": value,
"total": total,
"preview": preview,
"node_id": node_id,
}
},
"ProgressProxy": {
"rpc_set_progress": lambda value, max_value, node_id=None, image=None: {
"value": value,
"max_value": max_value,
"node_id": node_id,
"image": image,
}
},
"HelperProxiesService": {
"rpc_restore_input_types": lambda raw: raw,
},
"ModelManagementProxy": {
"rpc_call": self._model_management_rpc_call,
},
}
def _model_management_rpc_call(self, method_name: str, args: Any = None, kwargs: Any = None) -> Any:
if method_name == "get_torch_device":
return self._device
elif method_name == "get_torch_device_name":
return "cpu"
elif method_name == "get_free_memory":
return 34359738368
raise AssertionError(f"unexpected model_management method {method_name}")
@staticmethod
def _get_annotated_filepath(name: str, default_dir: str | None = None) -> str:
if name.endswith("[output]"):
return f"/sandbox/output/{name[:-8]}"
if name.endswith("[input]"):
return f"/sandbox/input/{name[:-7]}"
if name.endswith("[temp]"):
return f"/sandbox/temp/{name[:-6]}"
base_dir = default_dir or "/sandbox/input"
return f"{base_dir}/{name}"
def create_caller(self, cls: Any, object_id: str):
methods = self._services.get(object_id) or self._services.get(getattr(cls, "__name__", object_id))
if methods is None:
raise KeyError(object_id)
return FakeSingletonCaller(methods, self.calls, object_id)
def _clear_proxy_rpcs() -> None:
from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy
from comfy.isolation.proxies.progress_proxy import ProgressProxy
from comfy.isolation.proxies.utils_proxy import UtilsProxy
FolderPathsProxy.clear_rpc()
ProgressProxy.clear_rpc()
UtilsProxy.clear_rpc()
helper_proxy_service = _load_helper_proxy_service()
if helper_proxy_service is not None:
helper_proxy_service.clear_rpc()
model_management_proxy = _load_model_management_proxy()
if model_management_proxy is not None and hasattr(model_management_proxy, "clear_rpc"):
model_management_proxy.clear_rpc()
def prepare_sealed_singleton_proxies(fake_rpc: FakeSingletonRPC) -> None:
os.environ["PYISOLATE_CHILD"] = "1"
os.environ["PYISOLATE_IMPORT_TORCH"] = "0"
_clear_proxy_rpcs()
from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy
from comfy.isolation.proxies.progress_proxy import ProgressProxy
from comfy.isolation.proxies.utils_proxy import UtilsProxy
FolderPathsProxy.set_rpc(fake_rpc)
ProgressProxy.set_rpc(fake_rpc)
UtilsProxy.set_rpc(fake_rpc)
helper_proxy_service = _load_helper_proxy_service()
if helper_proxy_service is not None:
helper_proxy_service.set_rpc(fake_rpc)
model_management_proxy = _load_model_management_proxy()
if model_management_proxy is not None and hasattr(model_management_proxy, "set_rpc"):
model_management_proxy.set_rpc(fake_rpc)
def reset_forbidden_singleton_modules() -> None:
for module_name in (
"folder_paths",
"comfy.utils",
"comfy_execution.progress",
):
sys.modules.pop(module_name, None)
class FakeExactRelayCaller:
def __init__(self, methods: dict[str, Any], transcripts: list[dict[str, Any]], object_id: str):
self._methods = methods
self._transcripts = transcripts
self._object_id = object_id
def __getattr__(self, name: str):
if name not in self._methods:
raise AttributeError(name)
async def method(*args: Any, **kwargs: Any) -> Any:
self._transcripts.append(
{
"phase": "child_call",
"object_id": self._object_id,
"method": name,
"args": list(args),
"kwargs": dict(kwargs),
}
)
impl = self._methods[name]
self._transcripts.append(
{
"phase": "host_invocation",
"object_id": self._object_id,
"method": name,
"target": impl["target"],
"args": list(args),
"kwargs": dict(kwargs),
}
)
result = impl["result"](*args, **kwargs) if callable(impl["result"]) else impl["result"]
self._transcripts.append(
{
"phase": "result",
"object_id": self._object_id,
"method": name,
"result": result,
}
)
return result
return method
class FakeExactRelayRPC:
def __init__(self) -> None:
self.transcripts: list[dict[str, Any]] = []
self._device = {"__pyisolate_torch_device__": "cpu"}
self._services: dict[str, dict[str, Any]] = {
"FolderPathsProxy": {
"rpc_get_models_dir": {
"target": "folder_paths.models_dir",
"result": "/sandbox/models",
},
"rpc_get_temp_directory": {
"target": "folder_paths.get_temp_directory",
"result": "/sandbox/temp",
},
"rpc_get_input_directory": {
"target": "folder_paths.get_input_directory",
"result": "/sandbox/input",
},
"rpc_get_output_directory": {
"target": "folder_paths.get_output_directory",
"result": "/sandbox/output",
},
"rpc_get_user_directory": {
"target": "folder_paths.get_user_directory",
"result": "/sandbox/user",
},
"rpc_get_folder_names_and_paths": {
"target": "folder_paths.folder_names_and_paths",
"result": {
"checkpoints": {
"paths": ["/sandbox/models/checkpoints"],
"extensions": [".ckpt", ".safetensors"],
}
},
},
"rpc_get_extension_mimetypes_cache": {
"target": "folder_paths.extension_mimetypes_cache",
"result": {"webp": "image"},
},
"rpc_get_filename_list_cache": {
"target": "folder_paths.filename_list_cache",
"result": {},
},
"rpc_get_annotated_filepath": {
"target": "folder_paths.get_annotated_filepath",
"result": lambda name, default_dir=None: FakeSingletonRPC._get_annotated_filepath(name, default_dir),
},
"rpc_exists_annotated_filepath": {
"target": "folder_paths.exists_annotated_filepath",
"result": False,
},
"rpc_add_model_folder_path": {
"target": "folder_paths.add_model_folder_path",
"result": None,
},
"rpc_get_folder_paths": {
"target": "folder_paths.get_folder_paths",
"result": lambda folder_name: [f"/sandbox/models/{folder_name}"],
},
"rpc_get_filename_list": {
"target": "folder_paths.get_filename_list",
"result": lambda folder_name: [f"{folder_name}_fixture.safetensors"],
},
"rpc_get_full_path": {
"target": "folder_paths.get_full_path",
"result": lambda folder_name, filename: f"/sandbox/models/{folder_name}/{filename}",
},
},
"UtilsProxy": {
"progress_bar_hook": {
"target": "comfy.utils.PROGRESS_BAR_HOOK",
"result": lambda value, total, preview=None, node_id=None: {
"value": value,
"total": total,
"preview": preview,
"node_id": node_id,
},
},
},
"ProgressProxy": {
"rpc_set_progress": {
"target": "comfy_execution.progress.get_progress_state().update_progress",
"result": None,
},
},
"HelperProxiesService": {
"rpc_restore_input_types": {
"target": "comfy.isolation.proxies.helper_proxies.restore_input_types",
"result": lambda raw: raw,
}
},
"ModelManagementProxy": {
"rpc_call": {
"target": "comfy.model_management.*",
"result": self._model_management_rpc_call,
},
},
}
def _model_management_rpc_call(self, method_name: str, args: Any = None, kwargs: Any = None) -> Any:
device = {"__pyisolate_torch_device__": "cpu"}
if method_name == "get_torch_device":
return device
elif method_name == "get_torch_device_name":
return "cpu"
elif method_name == "get_free_memory":
return 34359738368
raise AssertionError(f"unexpected exact-relay method {method_name}")
def create_caller(self, cls: Any, object_id: str):
methods = self._services.get(object_id) or self._services.get(getattr(cls, "__name__", object_id))
if methods is None:
raise KeyError(object_id)
return FakeExactRelayCaller(methods, self.transcripts, object_id)
def capture_exact_small_proxy_relay() -> dict[str, object]:
reset_forbidden_singleton_modules()
fake_rpc = FakeExactRelayRPC()
previous_child = os.environ.get("PYISOLATE_CHILD")
previous_import_torch = os.environ.get("PYISOLATE_IMPORT_TORCH")
try:
prepare_sealed_singleton_proxies(fake_rpc)
from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy
from comfy.isolation.proxies.helper_proxies import restore_input_types
from comfy.isolation.proxies.progress_proxy import ProgressProxy
from comfy.isolation.proxies.utils_proxy import UtilsProxy
folder_proxy = FolderPathsProxy()
utils_proxy = UtilsProxy()
progress_proxy = ProgressProxy()
before = set(sys.modules)
restored = restore_input_types(
{
"required": {
"image": {"__pyisolate_any_type__": True, "value": "*"},
}
}
)
folder_path = folder_proxy.get_annotated_filepath("demo.png[input]")
models_dir = folder_proxy.models_dir
folder_names_and_paths = folder_proxy.folder_names_and_paths
asyncio.run(utils_proxy.progress_bar_hook(2, 5, node_id="node-17"))
progress_proxy.set_progress(1.5, 5.0, node_id="node-17")
imported = set(sys.modules) - before
return {
"mode": "exact_small_proxy_relay",
"folder_path": folder_path,
"models_dir": models_dir,
"folder_names_and_paths": folder_names_and_paths,
"restored_any_type": str(restored["required"]["image"]),
"transcripts": fake_rpc.transcripts,
"modules": sorted(imported),
"forbidden_matches": matching_modules(FORBIDDEN_EXACT_SMALL_PROXY_MODULES, imported),
}
finally:
_clear_proxy_rpcs()
if previous_child is None:
os.environ.pop("PYISOLATE_CHILD", None)
else:
os.environ["PYISOLATE_CHILD"] = previous_child
if previous_import_torch is None:
os.environ.pop("PYISOLATE_IMPORT_TORCH", None)
else:
os.environ["PYISOLATE_IMPORT_TORCH"] = previous_import_torch
class FakeModelManagementExactRelayRPC:
def __init__(self) -> None:
self.transcripts: list[dict[str, object]] = []
self._device = {"__pyisolate_torch_device__": "cpu"}
self._services: dict[str, dict[str, Any]] = {
"ModelManagementProxy": {
"rpc_call": self._rpc_call,
}
}
def create_caller(self, cls: Any, object_id: str):
methods = self._services.get(object_id) or self._services.get(getattr(cls, "__name__", object_id))
if methods is None:
raise KeyError(object_id)
return _ModelManagementExactRelayCaller(methods)
def _rpc_call(self, method_name: str, args: Any, kwargs: Any) -> Any:
self.transcripts.append(
{
"phase": "child_call",
"object_id": "ModelManagementProxy",
"method": method_name,
"args": _json_safe(args),
"kwargs": _json_safe(kwargs),
}
)
target = f"comfy.model_management.{method_name}"
self.transcripts.append(
{
"phase": "host_invocation",
"object_id": "ModelManagementProxy",
"method": method_name,
"target": target,
"args": _json_safe(args),
"kwargs": _json_safe(kwargs),
}
)
if method_name == "get_torch_device":
result = self._device
elif method_name == "get_torch_device_name":
result = "cpu"
elif method_name == "get_free_memory":
result = 34359738368
else:
raise AssertionError(f"unexpected exact-relay method {method_name}")
self.transcripts.append(
{
"phase": "result",
"object_id": "ModelManagementProxy",
"method": method_name,
"result": _json_safe(result),
}
)
return result
class _ModelManagementExactRelayCaller:
def __init__(self, methods: dict[str, Any]):
self._methods = methods
def __getattr__(self, name: str):
if name not in self._methods:
raise AttributeError(name)
async def method(*args: Any, **kwargs: Any) -> Any:
impl = self._methods[name]
return impl(*args, **kwargs) if callable(impl) else impl
return method
def _json_safe(value: Any) -> Any:
if callable(value):
return f"<callable {getattr(value, '__name__', 'anonymous')}>"
if isinstance(value, tuple):
return [_json_safe(item) for item in value]
if isinstance(value, list):
return [_json_safe(item) for item in value]
if isinstance(value, dict):
return {key: _json_safe(inner) for key, inner in value.items()}
return value
def capture_model_management_exact_relay() -> dict[str, object]:
for module_name in FORBIDDEN_MODEL_MANAGEMENT_MODULES:
sys.modules.pop(module_name, None)
fake_rpc = FakeModelManagementExactRelayRPC()
previous_child = os.environ.get("PYISOLATE_CHILD")
previous_import_torch = os.environ.get("PYISOLATE_IMPORT_TORCH")
try:
os.environ["PYISOLATE_CHILD"] = "1"
os.environ["PYISOLATE_IMPORT_TORCH"] = "0"
from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy
if hasattr(ModelManagementProxy, "clear_rpc"):
ModelManagementProxy.clear_rpc()
if hasattr(ModelManagementProxy, "set_rpc"):
ModelManagementProxy.set_rpc(fake_rpc)
proxy = ModelManagementProxy()
before = set(sys.modules)
device = proxy.get_torch_device()
device_name = proxy.get_torch_device_name(device)
free_memory = proxy.get_free_memory(device)
imported = set(sys.modules) - before
return {
"mode": "model_management_exact_relay",
"device": str(device),
"device_type": getattr(device, "type", None),
"device_name": device_name,
"free_memory": free_memory,
"transcripts": fake_rpc.transcripts,
"modules": sorted(imported),
"forbidden_matches": matching_modules(FORBIDDEN_MODEL_MANAGEMENT_MODULES, imported),
}
finally:
model_management_proxy = _load_model_management_proxy()
if model_management_proxy is not None and hasattr(model_management_proxy, "clear_rpc"):
model_management_proxy.clear_rpc()
if previous_child is None:
os.environ.pop("PYISOLATE_CHILD", None)
else:
os.environ["PYISOLATE_CHILD"] = previous_child
if previous_import_torch is None:
os.environ.pop("PYISOLATE_IMPORT_TORCH", None)
else:
os.environ["PYISOLATE_IMPORT_TORCH"] = previous_import_torch
FORBIDDEN_PROMPT_WEB_MODULES = (
"server",
"aiohttp",
"comfy.isolation.extension_wrapper",
)
FORBIDDEN_EXACT_BOOTSTRAP_MODULES = (
"comfy.isolation.adapter",
"folder_paths",
"comfy.utils",
"comfy.model_management",
"server",
"main",
"comfy.isolation.extension_wrapper",
)
class _PromptServiceExactRelayCaller:
def __init__(self, methods: dict[str, Any], transcripts: list[dict[str, Any]], object_id: str):
self._methods = methods
self._transcripts = transcripts
self._object_id = object_id
def __getattr__(self, name: str):
if name not in self._methods:
raise AttributeError(name)
async def method(*args: Any, **kwargs: Any) -> Any:
self._transcripts.append(
{
"phase": "child_call",
"object_id": self._object_id,
"method": name,
"args": _json_safe(args),
"kwargs": _json_safe(kwargs),
}
)
impl = self._methods[name]
self._transcripts.append(
{
"phase": "host_invocation",
"object_id": self._object_id,
"method": name,
"target": impl["target"],
"args": _json_safe(args),
"kwargs": _json_safe(kwargs),
}
)
result = impl["result"](*args, **kwargs) if callable(impl["result"]) else impl["result"]
self._transcripts.append(
{
"phase": "result",
"object_id": self._object_id,
"method": name,
"result": _json_safe(result),
}
)
return result
return method
class FakePromptWebRPC:
def __init__(self) -> None:
self.transcripts: list[dict[str, Any]] = []
self._services = {
"PromptServerService": {
"ui_send_progress_text": {
"target": "server.PromptServer.instance.send_progress_text",
"result": None,
},
"register_route_rpc": {
"target": "server.PromptServer.instance.routes.add_route",
"result": None,
},
}
}
def create_caller(self, cls: Any, object_id: str):
methods = self._services.get(object_id) or self._services.get(getattr(cls, "__name__", object_id))
if methods is None:
raise KeyError(object_id)
return _PromptServiceExactRelayCaller(methods, self.transcripts, object_id)
class FakeWebDirectoryProxy:
def __init__(self, transcripts: list[dict[str, Any]]):
self._transcripts = transcripts
def get_web_file(self, extension_name: str, relative_path: str) -> dict[str, Any]:
self._transcripts.append(
{
"phase": "child_call",
"object_id": "WebDirectoryProxy",
"method": "get_web_file",
"args": [extension_name, relative_path],
"kwargs": {},
}
)
self._transcripts.append(
{
"phase": "host_invocation",
"object_id": "WebDirectoryProxy",
"method": "get_web_file",
"target": "comfy.isolation.proxies.web_directory_proxy.WebDirectoryProxy.get_web_file",
"args": [extension_name, relative_path],
"kwargs": {},
}
)
result = {
"content": "Y29uc29sZS5sb2coJ2RlbycpOw==",
"content_type": "application/javascript",
}
self._transcripts.append(
{
"phase": "result",
"object_id": "WebDirectoryProxy",
"method": "get_web_file",
"result": result,
}
)
return result
def capture_prompt_web_exact_relay() -> dict[str, object]:
for module_name in FORBIDDEN_PROMPT_WEB_MODULES:
sys.modules.pop(module_name, None)
fake_rpc = FakePromptWebRPC()
from comfy.isolation.proxies.prompt_server_impl import PromptServerStub
from comfy.isolation.proxies.web_directory_proxy import WebDirectoryCache
PromptServerStub.set_rpc(fake_rpc)
stub = PromptServerStub()
cache = WebDirectoryCache()
cache.register_proxy("demo_ext", FakeWebDirectoryProxy(fake_rpc.transcripts))
before = set(sys.modules)
def demo_handler(_request):
return {"ok": True}
stub.send_progress_text("hello", "node-17")
stub.routes.get("/demo")(demo_handler)
web_file = cache.get_file("demo_ext", "js/app.js")
imported = set(sys.modules) - before
return {
"mode": "prompt_web_exact_relay",
"web_file": {
"content_type": web_file["content_type"] if web_file else None,
"content": web_file["content"].decode("utf-8") if web_file else None,
},
"transcripts": fake_rpc.transcripts,
"modules": sorted(imported),
"forbidden_matches": matching_modules(FORBIDDEN_PROMPT_WEB_MODULES, imported),
}
class FakeExactBootstrapRPC:
def __init__(self) -> None:
self.transcripts: list[dict[str, Any]] = []
self._device = {"__pyisolate_torch_device__": "cpu"}
self._services: dict[str, dict[str, Any]] = {
"FolderPathsProxy": FakeExactRelayRPC()._services["FolderPathsProxy"],
"HelperProxiesService": FakeExactRelayRPC()._services["HelperProxiesService"],
"ProgressProxy": FakeExactRelayRPC()._services["ProgressProxy"],
"UtilsProxy": FakeExactRelayRPC()._services["UtilsProxy"],
"PromptServerService": {
"ui_send_sync": {
"target": "server.PromptServer.instance.send_sync",
"result": None,
},
"ui_send": {
"target": "server.PromptServer.instance.send",
"result": None,
},
"ui_send_progress_text": {
"target": "server.PromptServer.instance.send_progress_text",
"result": None,
},
"register_route_rpc": {
"target": "server.PromptServer.instance.routes.add_route",
"result": None,
},
},
"ModelManagementProxy": {
"rpc_call": self._rpc_call,
},
}
def create_caller(self, cls: Any, object_id: str):
methods = self._services.get(object_id) or self._services.get(getattr(cls, "__name__", object_id))
if methods is None:
raise KeyError(object_id)
if object_id == "ModelManagementProxy":
return _ModelManagementExactRelayCaller(methods)
return _PromptServiceExactRelayCaller(methods, self.transcripts, object_id)
def _rpc_call(self, method_name: str, args: Any, kwargs: Any) -> Any:
self.transcripts.append(
{
"phase": "child_call",
"object_id": "ModelManagementProxy",
"method": method_name,
"args": _json_safe(args),
"kwargs": _json_safe(kwargs),
}
)
self.transcripts.append(
{
"phase": "host_invocation",
"object_id": "ModelManagementProxy",
"method": method_name,
"target": f"comfy.model_management.{method_name}",
"args": _json_safe(args),
"kwargs": _json_safe(kwargs),
}
)
result = self._device if method_name == "get_torch_device" else None
self.transcripts.append(
{
"phase": "result",
"object_id": "ModelManagementProxy",
"method": method_name,
"result": _json_safe(result),
}
)
return result
def capture_exact_proxy_bootstrap_contract() -> dict[str, object]:
from pyisolate._internal.rpc_protocol import get_child_rpc_instance, set_child_rpc_instance
from comfy.isolation.adapter import ComfyUIAdapter
from comfy.isolation.child_hooks import initialize_child_process
from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy
from comfy.isolation.proxies.helper_proxies import HelperProxiesService
from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy
from comfy.isolation.proxies.progress_proxy import ProgressProxy
from comfy.isolation.proxies.prompt_server_impl import PromptServerStub
from comfy.isolation.proxies.utils_proxy import UtilsProxy
host_services = sorted(cls.__name__ for cls in ComfyUIAdapter().provide_rpc_services())
for module_name in FORBIDDEN_EXACT_BOOTSTRAP_MODULES:
sys.modules.pop(module_name, None)
previous_child = os.environ.get("PYISOLATE_CHILD")
previous_import_torch = os.environ.get("PYISOLATE_IMPORT_TORCH")
os.environ["PYISOLATE_CHILD"] = "1"
os.environ["PYISOLATE_IMPORT_TORCH"] = "0"
_clear_proxy_rpcs()
if hasattr(PromptServerStub, "clear_rpc"):
PromptServerStub.clear_rpc()
else:
PromptServerStub._rpc = None # type: ignore[attr-defined]
fake_rpc = FakeExactBootstrapRPC()
set_child_rpc_instance(fake_rpc)
before = set(sys.modules)
try:
initialize_child_process()
imported = set(sys.modules) - before
matrix = {
"base.py": {
"bound": get_child_rpc_instance() is fake_rpc,
"details": {"child_rpc_instance": get_child_rpc_instance() is fake_rpc},
},
"folder_paths_proxy.py": {
"bound": "FolderPathsProxy" in host_services and FolderPathsProxy._rpc is not None,
"details": {"host_service": "FolderPathsProxy" in host_services, "child_rpc": FolderPathsProxy._rpc is not None},
},
"helper_proxies.py": {
"bound": "HelperProxiesService" in host_services and HelperProxiesService._rpc is not None,
"details": {"host_service": "HelperProxiesService" in host_services, "child_rpc": HelperProxiesService._rpc is not None},
},
"model_management_proxy.py": {
"bound": "ModelManagementProxy" in host_services and ModelManagementProxy._rpc is not None,
"details": {"host_service": "ModelManagementProxy" in host_services, "child_rpc": ModelManagementProxy._rpc is not None},
},
"progress_proxy.py": {
"bound": "ProgressProxy" in host_services and ProgressProxy._rpc is not None,
"details": {"host_service": "ProgressProxy" in host_services, "child_rpc": ProgressProxy._rpc is not None},
},
"prompt_server_impl.py": {
"bound": "PromptServerService" in host_services and PromptServerStub._rpc is not None,
"details": {"host_service": "PromptServerService" in host_services, "child_rpc": PromptServerStub._rpc is not None},
},
"utils_proxy.py": {
"bound": "UtilsProxy" in host_services and UtilsProxy._rpc is not None,
"details": {"host_service": "UtilsProxy" in host_services, "child_rpc": UtilsProxy._rpc is not None},
},
"web_directory_proxy.py": {
"bound": "WebDirectoryProxy" in host_services,
"details": {"host_service": "WebDirectoryProxy" in host_services},
},
}
finally:
set_child_rpc_instance(None)
if previous_child is None:
os.environ.pop("PYISOLATE_CHILD", None)
else:
os.environ["PYISOLATE_CHILD"] = previous_child
if previous_import_torch is None:
os.environ.pop("PYISOLATE_IMPORT_TORCH", None)
else:
os.environ["PYISOLATE_IMPORT_TORCH"] = previous_import_torch
omitted = sorted(name for name, status in matrix.items() if not status["bound"])
return {
"mode": "exact_proxy_bootstrap_contract",
"host_services": host_services,
"matrix": matrix,
"omitted_proxies": omitted,
"modules": sorted(imported),
"forbidden_matches": matching_modules(FORBIDDEN_EXACT_BOOTSTRAP_MODULES, imported),
}
def capture_sealed_singleton_imports() -> dict[str, object]:
reset_forbidden_singleton_modules()
fake_rpc = FakeSingletonRPC()
previous_child = os.environ.get("PYISOLATE_CHILD")
previous_import_torch = os.environ.get("PYISOLATE_IMPORT_TORCH")
before = set(sys.modules)
try:
prepare_sealed_singleton_proxies(fake_rpc)
from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy
from comfy.isolation.proxies.progress_proxy import ProgressProxy
from comfy.isolation.proxies.utils_proxy import UtilsProxy
folder_proxy = FolderPathsProxy()
progress_proxy = ProgressProxy()
utils_proxy = UtilsProxy()
folder_path = folder_proxy.get_annotated_filepath("demo.png[input]")
temp_dir = folder_proxy.get_temp_directory()
models_dir = folder_proxy.models_dir
asyncio.run(utils_proxy.progress_bar_hook(2, 5, node_id="node-17"))
progress_proxy.set_progress(1.5, 5.0, node_id="node-17")
imported = set(sys.modules) - before
return {
"mode": "sealed_singletons",
"folder_path": folder_path,
"temp_dir": temp_dir,
"models_dir": models_dir,
"rpc_calls": fake_rpc.calls,
"modules": sorted(imported),
"forbidden_matches": matching_modules(FORBIDDEN_SEALED_SINGLETON_MODULES, imported),
}
finally:
_clear_proxy_rpcs()
if previous_child is None:
os.environ.pop("PYISOLATE_CHILD", None)
else:
os.environ["PYISOLATE_CHILD"] = previous_child
if previous_import_torch is None:
os.environ.pop("PYISOLATE_IMPORT_TORCH", None)
else:
os.environ["PYISOLATE_IMPORT_TORCH"] = previous_import_torch

View File

@ -0,0 +1,69 @@
from __future__ import annotations
import argparse
import shutil
import sys
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import Iterator
COMFYUI_ROOT = Path(__file__).resolve().parents[2]
PROBE_SOURCE_ROOT = COMFYUI_ROOT / "tests" / "isolation" / "internal_probe_node"
PROBE_NODE_NAME = "InternalIsolationProbeNode"
PYPROJECT_CONTENT = """[project]
name = "InternalIsolationProbeNode"
version = "0.0.1"
[tool.comfy.isolation]
can_isolate = true
share_torch = true
"""
def _probe_target_root(comfy_root: Path) -> Path:
return Path(comfy_root) / "custom_nodes" / PROBE_NODE_NAME
def stage_probe_node(comfy_root: Path) -> Path:
if not PROBE_SOURCE_ROOT.is_dir():
raise RuntimeError(f"Missing probe source directory: {PROBE_SOURCE_ROOT}")
target_root = _probe_target_root(comfy_root)
target_root.mkdir(parents=True, exist_ok=True)
for source_path in PROBE_SOURCE_ROOT.iterdir():
destination_path = target_root / source_path.name
if source_path.is_dir():
shutil.copytree(source_path, destination_path, dirs_exist_ok=True)
else:
shutil.copy2(source_path, destination_path)
(target_root / "pyproject.toml").write_text(PYPROJECT_CONTENT, encoding="utf-8")
return target_root
@contextmanager
def staged_probe_node() -> Iterator[Path]:
staging_root = Path(tempfile.mkdtemp(prefix="comfyui_internal_probe_"))
try:
yield stage_probe_node(staging_root)
finally:
shutil.rmtree(staging_root, ignore_errors=True)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Stage the internal isolation probe node under an explicit ComfyUI root."
)
parser.add_argument(
"--target-root",
type=Path,
required=True,
help="Explicit ComfyUI root to stage under. Caller owns cleanup.",
)
args = parser.parse_args()
staged = stage_probe_node(args.target_root)
sys.stdout.write(f"{staged}\n")

View File

@ -25,7 +25,7 @@ def _run_client_process(env):
existing = env.get("PYTHONPATH", "")
if existing:
pythonpath_parts.append(existing)
env["PYTHONPATH"] = ":".join(pythonpath_parts)
env["PYTHONPATH"] = os.pathsep.join(pythonpath_parts)
result = subprocess.run( # noqa: S603
[sys.executable, "-c", SCRIPT],

View File

@ -8,13 +8,19 @@ import logging
import os
import sys
from types import SimpleNamespace
from typing import Any, cast
import pytest
import comfy.isolation as isolation_pkg
from comfy.isolation import runtime_helpers
from comfy.isolation import extension_loader as extension_loader_module
from comfy.isolation import extension_wrapper as extension_wrapper_module
from comfy.isolation import model_patcher_proxy_utils
from comfy.isolation.extension_loader import ExtensionLoadError, load_isolated_node
from comfy.isolation.extension_wrapper import ComfyNodeExtension
from comfy.isolation.model_patcher_proxy_utils import maybe_wrap_model_for_isolation
from pyisolate._internal.environment_conda import _generate_pixi_toml
class _DummyExtension:
@ -63,11 +69,10 @@ flash_attn = "flash-attn-special"
captured.update(config)
return _DummyExtension()
monkeypatch.setattr(extension_loader_module.pyisolate, "ExtensionManager", DummyManager)
monkeypatch.setattr(
"comfy.isolation.extension_loader.pyisolate.ExtensionManager", DummyManager
)
monkeypatch.setattr(
"comfy.isolation.extension_loader.load_host_policy",
extension_loader_module,
"load_host_policy",
lambda base_path: {
"sandbox_mode": "required",
"allow_network": False,
@ -75,11 +80,10 @@ flash_attn = "flash-attn-special"
"readonly_paths": [],
},
)
monkeypatch.setattr(extension_loader_module, "is_cache_valid", lambda *args, **kwargs: True)
monkeypatch.setattr(
"comfy.isolation.extension_loader.is_cache_valid", lambda *args, **kwargs: True
)
monkeypatch.setattr(
"comfy.isolation.extension_loader.load_from_cache",
extension_loader_module,
"load_from_cache",
lambda *args, **kwargs: {"Node": {"display_name": "Node", "schema_v1": {}}},
)
monkeypatch.setitem(sys.modules, "folder_paths", SimpleNamespace(base_path=str(tmp_path)))
@ -141,6 +145,163 @@ packages = ["flash-attn"]
)
def test_conda_cuda_wheels_declared_packages_do_not_force_pixi_solve(tmp_path, monkeypatch):
node_dir = tmp_path / "node"
node_dir.mkdir()
manifest_path = node_dir / "pyproject.toml"
_write_manifest(
node_dir,
"""
[project]
name = "demo-node"
dependencies = ["numpy>=1.0", "spconv", "cumm", "flash-attn"]
[tool.comfy.isolation]
can_isolate = true
package_manager = "conda"
conda_channels = ["conda-forge"]
[tool.comfy.isolation.cuda_wheels]
index_url = "https://example.invalid/cuda-wheels"
packages = ["spconv", "cumm", "flash-attn"]
""".strip(),
)
captured: dict[str, object] = {}
class DummyManager:
def __init__(self, *args, **kwargs) -> None:
return None
def load_extension(self, config):
captured.update(config)
return _DummyExtension()
monkeypatch.setattr(extension_loader_module.pyisolate, "ExtensionManager", DummyManager)
monkeypatch.setattr(
extension_loader_module,
"load_host_policy",
lambda base_path: {
"sandbox_mode": "disabled",
"allow_network": False,
"writable_paths": [],
"readonly_paths": [],
},
)
monkeypatch.setattr(extension_loader_module, "is_cache_valid", lambda *args, **kwargs: True)
monkeypatch.setattr(
extension_loader_module,
"load_from_cache",
lambda *args, **kwargs: {"Node": {"display_name": "Node", "schema_v1": {}}},
)
monkeypatch.setitem(sys.modules, "folder_paths", SimpleNamespace(base_path=str(tmp_path)))
asyncio.run(
load_isolated_node(
node_dir,
manifest_path,
logging.getLogger("test"),
lambda *args, **kwargs: object,
tmp_path / "venvs",
[],
)
)
generated = _generate_pixi_toml(captured)
assert 'numpy = ">=1.0"' in generated
assert "spconv =" not in generated
assert "cumm =" not in generated
assert "flash-attn =" not in generated
def test_conda_cuda_wheels_loader_accepts_sam3d_contract(tmp_path, monkeypatch):
node_dir = tmp_path / "node"
node_dir.mkdir()
manifest_path = node_dir / "pyproject.toml"
_write_manifest(
node_dir,
"""
[project]
name = "demo-node"
dependencies = [
"torch",
"torchvision",
"pytorch3d",
"gsplat",
"nvdiffrast",
"flash-attn",
"sageattention",
"spconv",
"cumm",
]
[tool.comfy.isolation]
can_isolate = true
package_manager = "conda"
conda_channels = ["conda-forge"]
[tool.comfy.isolation.cuda_wheels]
index_url = "https://example.invalid/cuda-wheels"
packages = ["pytorch3d", "gsplat", "nvdiffrast", "flash-attn", "sageattention", "spconv", "cumm"]
""".strip(),
)
captured: dict[str, object] = {}
class DummyManager:
def __init__(self, *args, **kwargs) -> None:
return None
def load_extension(self, config):
captured.update(config)
return _DummyExtension()
monkeypatch.setattr(extension_loader_module.pyisolate, "ExtensionManager", DummyManager)
monkeypatch.setattr(
extension_loader_module,
"load_host_policy",
lambda base_path: {
"sandbox_mode": "disabled",
"allow_network": False,
"writable_paths": [],
"readonly_paths": [],
},
)
monkeypatch.setattr(extension_loader_module, "is_cache_valid", lambda *args, **kwargs: True)
monkeypatch.setattr(
extension_loader_module,
"load_from_cache",
lambda *args, **kwargs: {"Node": {"display_name": "Node", "schema_v1": {}}},
)
monkeypatch.setitem(sys.modules, "folder_paths", SimpleNamespace(base_path=str(tmp_path)))
asyncio.run(
load_isolated_node(
node_dir,
manifest_path,
logging.getLogger("test"),
lambda *args, **kwargs: object,
tmp_path / "venvs",
[],
)
)
assert captured["package_manager"] == "conda"
assert captured["cuda_wheels"] == {
"index_url": "https://example.invalid/cuda-wheels/",
"packages": [
"pytorch3d",
"gsplat",
"nvdiffrast",
"flash-attn",
"sageattention",
"spconv",
"cumm",
],
"package_map": {},
}
def test_load_isolated_node_omits_cuda_wheels_when_not_configured(tmp_path, monkeypatch):
node_dir = tmp_path / "node"
node_dir.mkdir()
@ -167,11 +328,10 @@ can_isolate = true
captured.update(config)
return _DummyExtension()
monkeypatch.setattr(extension_loader_module.pyisolate, "ExtensionManager", DummyManager)
monkeypatch.setattr(
"comfy.isolation.extension_loader.pyisolate.ExtensionManager", DummyManager
)
monkeypatch.setattr(
"comfy.isolation.extension_loader.load_host_policy",
extension_loader_module,
"load_host_policy",
lambda base_path: {
"sandbox_mode": "disabled",
"allow_network": False,
@ -179,11 +339,10 @@ can_isolate = true
"readonly_paths": [],
},
)
monkeypatch.setattr(extension_loader_module, "is_cache_valid", lambda *args, **kwargs: True)
monkeypatch.setattr(
"comfy.isolation.extension_loader.is_cache_valid", lambda *args, **kwargs: True
)
monkeypatch.setattr(
"comfy.isolation.extension_loader.load_from_cache",
extension_loader_module,
"load_from_cache",
lambda *args, **kwargs: {"Node": {"display_name": "Node", "schema_v1": {}}},
)
monkeypatch.setitem(sys.modules, "folder_paths", SimpleNamespace(base_path=str(tmp_path)))
@ -214,7 +373,7 @@ def test_maybe_wrap_model_for_isolation_uses_runtime_flag(monkeypatch):
self.registry = registry
self.manage_lifecycle = manage_lifecycle
monkeypatch.setattr("comfy.isolation.model_patcher_proxy_utils.args.use_process_isolation", True)
monkeypatch.setattr(model_patcher_proxy_utils.args, "use_process_isolation", True)
monkeypatch.delenv("PYISOLATE_ISOLATION_ACTIVE", raising=False)
monkeypatch.delenv("PYISOLATE_CHILD", raising=False)
monkeypatch.setitem(
@ -228,20 +387,17 @@ def test_maybe_wrap_model_for_isolation_uses_runtime_flag(monkeypatch):
SimpleNamespace(ModelPatcherProxy=DummyProxy),
)
wrapped = maybe_wrap_model_for_isolation(object())
wrapped = cast(Any, maybe_wrap_model_for_isolation(object()))
assert isinstance(wrapped, DummyProxy)
assert wrapped.model_id == "model-123"
assert wrapped.manage_lifecycle is True
assert getattr(wrapped, "model_id") == "model-123"
assert getattr(wrapped, "manage_lifecycle") is True
def test_flush_transport_state_uses_child_env_without_legacy_flag(monkeypatch):
monkeypatch.setenv("PYISOLATE_CHILD", "1")
monkeypatch.delenv("PYISOLATE_ISOLATION_ACTIVE", raising=False)
monkeypatch.setattr(
"comfy.isolation.extension_wrapper._flush_tensor_transport_state",
lambda marker: 3,
)
monkeypatch.setattr(extension_wrapper_module, "_flush_tensor_transport_state", lambda marker: 3)
monkeypatch.setitem(
sys.modules,
"comfy.isolation.model_patcher_proxy_registry",
@ -260,8 +416,6 @@ def test_flush_transport_state_uses_child_env_without_legacy_flag(monkeypatch):
def test_build_stub_class_relieves_host_vram_without_legacy_flag(monkeypatch):
import comfy.isolation as isolation_pkg
relieve_calls: list[str] = []
async def deserialize_from_isolation(result, extension):

View File

@ -0,0 +1,22 @@
from __future__ import annotations
from tests.isolation.singleton_boundary_helpers import (
capture_exact_proxy_bootstrap_contract,
)
def test_no_proxy_omission_allowed() -> None:
payload = capture_exact_proxy_bootstrap_contract()
assert payload["omitted_proxies"] == []
assert payload["forbidden_matches"] == []
matrix = payload["matrix"]
assert matrix["base.py"]["bound"] is True
assert matrix["folder_paths_proxy.py"]["bound"] is True
assert matrix["helper_proxies.py"]["bound"] is True
assert matrix["model_management_proxy.py"]["bound"] is True
assert matrix["progress_proxy.py"]["bound"] is True
assert matrix["prompt_server_impl.py"]["bound"] is True
assert matrix["utils_proxy.py"]["bound"] is True
assert matrix["web_directory_proxy.py"]["bound"] is True

View File

@ -0,0 +1,128 @@
from __future__ import annotations
from tests.isolation.singleton_boundary_helpers import (
capture_exact_small_proxy_relay,
capture_model_management_exact_relay,
capture_prompt_web_exact_relay,
)
def _transcripts_for(payload: dict[str, object], object_id: str, method: str) -> list[dict[str, object]]:
return [
entry
for entry in payload["transcripts"]
if entry["object_id"] == object_id and entry["method"] == method
]
def test_folder_paths_exact_relay() -> None:
payload = capture_exact_small_proxy_relay()
assert payload["forbidden_matches"] == []
assert payload["models_dir"] == "/sandbox/models"
assert payload["folder_path"] == "/sandbox/input/demo.png"
models_dir_calls = _transcripts_for(payload, "FolderPathsProxy", "rpc_get_models_dir")
annotated_calls = _transcripts_for(payload, "FolderPathsProxy", "rpc_get_annotated_filepath")
assert models_dir_calls
assert annotated_calls
assert all(entry["phase"] != "child_call" or entry["method"] != "rpc_snapshot" for entry in payload["transcripts"])
def test_progress_exact_relay() -> None:
payload = capture_exact_small_proxy_relay()
progress_calls = _transcripts_for(payload, "ProgressProxy", "rpc_set_progress")
assert progress_calls
host_targets = [entry["target"] for entry in progress_calls if entry["phase"] == "host_invocation"]
assert host_targets == ["comfy_execution.progress.get_progress_state().update_progress"]
result_entries = [entry for entry in progress_calls if entry["phase"] == "result"]
assert result_entries == [{"phase": "result", "object_id": "ProgressProxy", "method": "rpc_set_progress", "result": None}]
def test_utils_exact_relay() -> None:
payload = capture_exact_small_proxy_relay()
utils_calls = _transcripts_for(payload, "UtilsProxy", "progress_bar_hook")
assert utils_calls
host_targets = [entry["target"] for entry in utils_calls if entry["phase"] == "host_invocation"]
assert host_targets == ["comfy.utils.PROGRESS_BAR_HOOK"]
result_entries = [entry for entry in utils_calls if entry["phase"] == "result"]
assert result_entries
assert result_entries[0]["result"]["value"] == 2
assert result_entries[0]["result"]["total"] == 5
def test_helper_proxy_exact_relay() -> None:
payload = capture_exact_small_proxy_relay()
helper_calls = _transcripts_for(payload, "HelperProxiesService", "rpc_restore_input_types")
assert helper_calls
host_targets = [entry["target"] for entry in helper_calls if entry["phase"] == "host_invocation"]
assert host_targets == ["comfy.isolation.proxies.helper_proxies.restore_input_types"]
assert payload["restored_any_type"] == "*"
def test_model_management_exact_relay() -> None:
payload = capture_model_management_exact_relay()
model_calls = _transcripts_for(payload, "ModelManagementProxy", "get_torch_device")
model_calls += _transcripts_for(payload, "ModelManagementProxy", "get_torch_device_name")
model_calls += _transcripts_for(payload, "ModelManagementProxy", "get_free_memory")
assert payload["forbidden_matches"] == []
assert model_calls
host_targets = [
entry["target"]
for entry in payload["transcripts"]
if entry["phase"] == "host_invocation"
]
assert host_targets == [
"comfy.model_management.get_torch_device",
"comfy.model_management.get_torch_device_name",
"comfy.model_management.get_free_memory",
]
def test_model_management_capability_preserved() -> None:
payload = capture_model_management_exact_relay()
assert payload["device"] == "cpu"
assert payload["device_type"] == "cpu"
assert payload["device_name"] == "cpu"
assert payload["free_memory"] == 34359738368
def test_prompt_server_exact_relay() -> None:
payload = capture_prompt_web_exact_relay()
prompt_calls = _transcripts_for(payload, "PromptServerService", "ui_send_progress_text")
prompt_calls += _transcripts_for(payload, "PromptServerService", "register_route_rpc")
assert payload["forbidden_matches"] == []
assert prompt_calls
host_targets = [
entry["target"]
for entry in payload["transcripts"]
if entry["object_id"] == "PromptServerService" and entry["phase"] == "host_invocation"
]
assert host_targets == [
"server.PromptServer.instance.send_progress_text",
"server.PromptServer.instance.routes.add_route",
]
def test_web_directory_exact_relay() -> None:
payload = capture_prompt_web_exact_relay()
web_calls = _transcripts_for(payload, "WebDirectoryProxy", "get_web_file")
assert web_calls
host_targets = [entry["target"] for entry in web_calls if entry["phase"] == "host_invocation"]
assert host_targets == ["comfy.isolation.proxies.web_directory_proxy.WebDirectoryProxy.get_web_file"]
assert payload["web_file"]["content_type"] == "application/javascript"
assert payload["web_file"]["content"] == "console.log('deo');"

View File

@ -0,0 +1,428 @@
"""Tests for conda config parsing in extension_loader.py (Slice 5).
These tests verify that extension_loader.py correctly parses conda-related
fields from pyproject.toml manifests and passes them into the extension config
dict given to pyisolate. The torch import chain is broken by pre-mocking
extension_wrapper before importing extension_loader.
"""
from __future__ import annotations
import importlib
import sys
import types
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
def _make_manifest(
*,
package_manager: str = "uv",
conda_channels: list[str] | None = None,
conda_dependencies: list[str] | None = None,
conda_platforms: list[str] | None = None,
share_torch: bool = False,
can_isolate: bool = True,
dependencies: list[str] | None = None,
cuda_wheels: list[str] | None = None,
) -> dict:
"""Build a manifest dict matching tomllib.load() output."""
isolation: dict = {"can_isolate": can_isolate}
if package_manager != "uv":
isolation["package_manager"] = package_manager
if conda_channels is not None:
isolation["conda_channels"] = conda_channels
if conda_dependencies is not None:
isolation["conda_dependencies"] = conda_dependencies
if conda_platforms is not None:
isolation["conda_platforms"] = conda_platforms
if share_torch:
isolation["share_torch"] = True
if cuda_wheels is not None:
isolation["cuda_wheels"] = cuda_wheels
return {
"project": {
"name": "test-extension",
"dependencies": dependencies or ["numpy"],
},
"tool": {"comfy": {"isolation": isolation}},
}
@pytest.fixture
def manifest_file(tmp_path):
"""Create a dummy pyproject.toml so manifest_path.open('rb') succeeds."""
path = tmp_path / "pyproject.toml"
path.write_bytes(b"") # content is overridden by tomllib mock
return path
@pytest.fixture
def loader_module(monkeypatch):
"""Import extension_loader under a mocked isolation package for this test only."""
mock_wrapper = MagicMock()
mock_wrapper.ComfyNodeExtension = type("ComfyNodeExtension", (), {})
iso_mod = types.ModuleType("comfy.isolation")
iso_mod.__path__ = [ # type: ignore[attr-defined]
str(Path(__file__).resolve().parent.parent.parent / "comfy" / "isolation")
]
iso_mod.__package__ = "comfy.isolation"
manifest_loader = types.SimpleNamespace(
is_cache_valid=lambda *args, **kwargs: False,
load_from_cache=lambda *args, **kwargs: None,
save_to_cache=lambda *args, **kwargs: None,
)
host_policy = types.SimpleNamespace(
load_host_policy=lambda base_path: {
"sandbox_mode": "required",
"allow_network": False,
"writable_paths": [],
"readonly_paths": [],
}
)
folder_paths = types.SimpleNamespace(base_path="/fake/comfyui")
monkeypatch.setitem(sys.modules, "comfy.isolation", iso_mod)
monkeypatch.setitem(sys.modules, "comfy.isolation.extension_wrapper", mock_wrapper)
monkeypatch.setitem(sys.modules, "comfy.isolation.runtime_helpers", MagicMock())
monkeypatch.setitem(sys.modules, "comfy.isolation.manifest_loader", manifest_loader)
monkeypatch.setitem(sys.modules, "comfy.isolation.host_policy", host_policy)
monkeypatch.setitem(sys.modules, "folder_paths", folder_paths)
sys.modules.pop("comfy.isolation.extension_loader", None)
module = importlib.import_module("comfy.isolation.extension_loader")
try:
yield module, mock_wrapper
finally:
sys.modules.pop("comfy.isolation.extension_loader", None)
comfy_pkg = sys.modules.get("comfy")
if comfy_pkg is not None and hasattr(comfy_pkg, "isolation"):
delattr(comfy_pkg, "isolation")
@pytest.fixture
def mock_pyisolate(loader_module):
"""Mock pyisolate to avoid real venv creation."""
module, mock_wrapper = loader_module
mock_ext = AsyncMock()
mock_ext.list_nodes = AsyncMock(return_value={})
mock_manager = MagicMock()
mock_manager.load_extension = MagicMock(return_value=mock_ext)
sealed_type = type("SealedNodeExtension", (), {})
with patch.object(module, "pyisolate") as mock_pi:
mock_pi.ExtensionManager = MagicMock(return_value=mock_manager)
mock_pi.SealedNodeExtension = sealed_type
yield module, mock_pi, mock_manager, mock_ext, mock_wrapper
def load_isolated_node(*args, **kwargs):
return sys.modules["comfy.isolation.extension_loader"].load_isolated_node(
*args, **kwargs
)
class TestCondaPackageManagerParsing:
"""Verify extension_loader.py parses conda config from pyproject.toml."""
@pytest.mark.asyncio
async def test_conda_package_manager_in_config(
self, mock_pyisolate, manifest_file, tmp_path
):
"""package_manager='conda' must appear in extension_config."""
manifest = _make_manifest(
package_manager="conda",
conda_channels=["conda-forge"],
conda_dependencies=["eccodes"],
)
_, _, mock_manager, _, _ = mock_pyisolate
with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib:
mock_tomllib.load.return_value = manifest
await load_isolated_node(
node_dir=tmp_path,
manifest_path=manifest_file,
logger=MagicMock(),
build_stub_class=MagicMock(),
venv_root=tmp_path / "venvs",
extension_managers=[],
)
config = mock_manager.load_extension.call_args[0][0]
assert config["package_manager"] == "conda"
@pytest.mark.asyncio
async def test_conda_channels_in_config(
self, mock_pyisolate, manifest_file, tmp_path
):
"""conda_channels must be passed through to extension_config."""
manifest = _make_manifest(
package_manager="conda",
conda_channels=["conda-forge", "nvidia"],
conda_dependencies=["eccodes"],
)
_, _, mock_manager, _, _ = mock_pyisolate
with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib:
mock_tomllib.load.return_value = manifest
await load_isolated_node(
node_dir=tmp_path,
manifest_path=manifest_file,
logger=MagicMock(),
build_stub_class=MagicMock(),
venv_root=tmp_path / "venvs",
extension_managers=[],
)
config = mock_manager.load_extension.call_args[0][0]
assert config["conda_channels"] == ["conda-forge", "nvidia"]
@pytest.mark.asyncio
async def test_conda_dependencies_in_config(
self, mock_pyisolate, manifest_file, tmp_path
):
"""conda_dependencies must be passed through to extension_config."""
manifest = _make_manifest(
package_manager="conda",
conda_channels=["conda-forge"],
conda_dependencies=["eccodes", "cfgrib"],
)
_, _, mock_manager, _, _ = mock_pyisolate
with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib:
mock_tomllib.load.return_value = manifest
await load_isolated_node(
node_dir=tmp_path,
manifest_path=manifest_file,
logger=MagicMock(),
build_stub_class=MagicMock(),
venv_root=tmp_path / "venvs",
extension_managers=[],
)
config = mock_manager.load_extension.call_args[0][0]
assert config["conda_dependencies"] == ["eccodes", "cfgrib"]
@pytest.mark.asyncio
async def test_conda_platforms_in_config(
self, mock_pyisolate, manifest_file, tmp_path
):
"""conda_platforms must be passed through to extension_config."""
manifest = _make_manifest(
package_manager="conda",
conda_channels=["conda-forge"],
conda_dependencies=["eccodes"],
conda_platforms=["linux-64"],
)
_, _, mock_manager, _, _ = mock_pyisolate
with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib:
mock_tomllib.load.return_value = manifest
await load_isolated_node(
node_dir=tmp_path,
manifest_path=manifest_file,
logger=MagicMock(),
build_stub_class=MagicMock(),
venv_root=tmp_path / "venvs",
extension_managers=[],
)
config = mock_manager.load_extension.call_args[0][0]
assert config["conda_platforms"] == ["linux-64"]
class TestCondaForcedOverrides:
"""Verify conda forces share_torch=False, share_cuda_ipc=False."""
@pytest.mark.asyncio
async def test_conda_forces_share_torch_false(
self, mock_pyisolate, manifest_file, tmp_path
):
"""share_torch must be forced False for conda, even if manifest says True."""
manifest = _make_manifest(
package_manager="conda",
conda_channels=["conda-forge"],
conda_dependencies=["eccodes"],
share_torch=True, # manifest requests True — must be overridden
)
_, _, mock_manager, _, _ = mock_pyisolate
with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib:
mock_tomllib.load.return_value = manifest
await load_isolated_node(
node_dir=tmp_path,
manifest_path=manifest_file,
logger=MagicMock(),
build_stub_class=MagicMock(),
venv_root=tmp_path / "venvs",
extension_managers=[],
)
config = mock_manager.load_extension.call_args[0][0]
assert config["share_torch"] is False
@pytest.mark.asyncio
async def test_conda_forces_share_cuda_ipc_false(
self, mock_pyisolate, manifest_file, tmp_path
):
"""share_cuda_ipc must be forced False for conda."""
manifest = _make_manifest(
package_manager="conda",
conda_channels=["conda-forge"],
conda_dependencies=["eccodes"],
share_torch=True,
)
_, _, mock_manager, _, _ = mock_pyisolate
with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib:
mock_tomllib.load.return_value = manifest
await load_isolated_node(
node_dir=tmp_path,
manifest_path=manifest_file,
logger=MagicMock(),
build_stub_class=MagicMock(),
venv_root=tmp_path / "venvs",
extension_managers=[],
)
config = mock_manager.load_extension.call_args[0][0]
assert config["share_cuda_ipc"] is False
@pytest.mark.asyncio
async def test_conda_sealed_worker_uses_host_policy_sandbox_config(
self, mock_pyisolate, manifest_file, tmp_path
):
"""Conda sealed_worker must carry the host-policy sandbox config on Linux."""
manifest = _make_manifest(
package_manager="conda",
conda_channels=["conda-forge"],
conda_dependencies=["eccodes"],
)
_, _, mock_manager, _, _ = mock_pyisolate
with (
patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib,
patch(
"comfy.isolation.extension_loader.platform.system",
return_value="Linux",
),
):
mock_tomllib.load.return_value = manifest
await load_isolated_node(
node_dir=tmp_path,
manifest_path=manifest_file,
logger=MagicMock(),
build_stub_class=MagicMock(),
venv_root=tmp_path / "venvs",
extension_managers=[],
)
config = mock_manager.load_extension.call_args[0][0]
assert config["sandbox"] == {
"network": False,
"writable_paths": [],
"readonly_paths": [],
}
@pytest.mark.asyncio
async def test_conda_uses_sealed_extension_type(
self, mock_pyisolate, manifest_file, tmp_path
):
"""Conda must not launch through ComfyNodeExtension."""
_, mock_pi, _, _, mock_wrapper = mock_pyisolate
manifest = _make_manifest(
package_manager="conda",
conda_channels=["conda-forge"],
conda_dependencies=["eccodes"],
)
with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib:
mock_tomllib.load.return_value = manifest
await load_isolated_node(
node_dir=tmp_path,
manifest_path=manifest_file,
logger=MagicMock(),
build_stub_class=MagicMock(),
venv_root=tmp_path / "venvs",
extension_managers=[],
)
extension_type = mock_pi.ExtensionManager.call_args[0][0]
assert extension_type.__name__ == "SealedNodeExtension"
assert extension_type is not mock_wrapper.ComfyNodeExtension
class TestUvUnchanged:
"""Verify uv extensions are NOT affected by conda changes."""
@pytest.mark.asyncio
async def test_uv_default_no_conda_keys(
self, mock_pyisolate, manifest_file, tmp_path
):
"""Default uv extension must NOT have package_manager or conda keys."""
manifest = _make_manifest() # defaults: uv, no conda fields
_, _, mock_manager, _, _ = mock_pyisolate
with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib:
mock_tomllib.load.return_value = manifest
await load_isolated_node(
node_dir=tmp_path,
manifest_path=manifest_file,
logger=MagicMock(),
build_stub_class=MagicMock(),
venv_root=tmp_path / "venvs",
extension_managers=[],
)
config = mock_manager.load_extension.call_args[0][0]
# uv extensions should not have conda-specific keys
assert config.get("package_manager", "uv") == "uv"
assert "conda_channels" not in config
assert "conda_dependencies" not in config
@pytest.mark.asyncio
async def test_uv_keeps_comfy_extension_type(
self, mock_pyisolate, manifest_file, tmp_path
):
"""uv keeps the existing ComfyNodeExtension path."""
_, mock_pi, _, _, _ = mock_pyisolate
manifest = _make_manifest()
with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib:
mock_tomllib.load.return_value = manifest
await load_isolated_node(
node_dir=tmp_path,
manifest_path=manifest_file,
logger=MagicMock(),
build_stub_class=MagicMock(),
venv_root=tmp_path / "venvs",
extension_managers=[],
)
extension_type = mock_pi.ExtensionManager.call_args[0][0]
assert extension_type.__name__ == "ComfyNodeExtension"
assert extension_type is not mock_pi.SealedNodeExtension

View File

@ -0,0 +1,281 @@
"""Tests for execution_model parsing and sealed-worker loader selection."""
from __future__ import annotations
import importlib
import sys
import types
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
def _make_manifest(
*,
package_manager: str = "uv",
execution_model: str | None = None,
can_isolate: bool = True,
dependencies: list[str] | None = None,
sealed_host_ro_paths: list[str] | None = None,
) -> dict:
isolation: dict = {"can_isolate": can_isolate}
if package_manager != "uv":
isolation["package_manager"] = package_manager
if execution_model is not None:
isolation["execution_model"] = execution_model
if sealed_host_ro_paths is not None:
isolation["sealed_host_ro_paths"] = sealed_host_ro_paths
return {
"project": {
"name": "test-extension",
"dependencies": dependencies or ["numpy"],
},
"tool": {"comfy": {"isolation": isolation}},
}
@pytest.fixture
def manifest_file(tmp_path):
path = tmp_path / "pyproject.toml"
path.write_bytes(b"")
return path
@pytest.fixture
def loader_module(monkeypatch):
mock_wrapper = MagicMock()
mock_wrapper.ComfyNodeExtension = type("ComfyNodeExtension", (), {})
iso_mod = types.ModuleType("comfy.isolation")
iso_mod.__path__ = [ # type: ignore[attr-defined]
str(Path(__file__).resolve().parent.parent.parent / "comfy" / "isolation")
]
iso_mod.__package__ = "comfy.isolation"
manifest_loader = types.SimpleNamespace(
is_cache_valid=lambda *args, **kwargs: False,
load_from_cache=lambda *args, **kwargs: None,
save_to_cache=lambda *args, **kwargs: None,
)
host_policy = types.SimpleNamespace(
load_host_policy=lambda base_path: {
"sandbox_mode": "required",
"allow_network": False,
"writable_paths": [],
"readonly_paths": [],
"sealed_worker_ro_import_paths": [],
}
)
folder_paths = types.SimpleNamespace(base_path="/fake/comfyui")
monkeypatch.setitem(sys.modules, "comfy.isolation", iso_mod)
monkeypatch.setitem(sys.modules, "comfy.isolation.extension_wrapper", mock_wrapper)
monkeypatch.setitem(sys.modules, "comfy.isolation.runtime_helpers", MagicMock())
monkeypatch.setitem(sys.modules, "comfy.isolation.manifest_loader", manifest_loader)
monkeypatch.setitem(sys.modules, "comfy.isolation.host_policy", host_policy)
monkeypatch.setitem(sys.modules, "folder_paths", folder_paths)
sys.modules.pop("comfy.isolation.extension_loader", None)
module = importlib.import_module("comfy.isolation.extension_loader")
try:
yield module
finally:
sys.modules.pop("comfy.isolation.extension_loader", None)
comfy_pkg = sys.modules.get("comfy")
if comfy_pkg is not None and hasattr(comfy_pkg, "isolation"):
delattr(comfy_pkg, "isolation")
@pytest.fixture
def mock_pyisolate(loader_module):
mock_ext = AsyncMock()
mock_ext.list_nodes = AsyncMock(return_value={})
mock_manager = MagicMock()
mock_manager.load_extension = MagicMock(return_value=mock_ext)
sealed_type = type("SealedNodeExtension", (), {})
with patch.object(loader_module, "pyisolate") as mock_pi:
mock_pi.ExtensionManager = MagicMock(return_value=mock_manager)
mock_pi.SealedNodeExtension = sealed_type
yield loader_module, mock_pi, mock_manager, mock_ext, sealed_type
def load_isolated_node(*args, **kwargs):
return sys.modules["comfy.isolation.extension_loader"].load_isolated_node(*args, **kwargs)
@pytest.mark.asyncio
async def test_uv_sealed_worker_selects_sealed_extension_type(
mock_pyisolate, manifest_file, tmp_path
):
manifest = _make_manifest(execution_model="sealed_worker")
_, mock_pi, mock_manager, _, sealed_type = mock_pyisolate
with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib:
mock_tomllib.load.return_value = manifest
await load_isolated_node(
node_dir=tmp_path,
manifest_path=manifest_file,
logger=MagicMock(),
build_stub_class=MagicMock(),
venv_root=tmp_path / "venvs",
extension_managers=[],
)
extension_type = mock_pi.ExtensionManager.call_args[0][0]
config = mock_manager.load_extension.call_args[0][0]
assert extension_type is sealed_type
assert config["execution_model"] == "sealed_worker"
assert "apis" not in config
@pytest.mark.asyncio
async def test_default_uv_keeps_host_coupled_extension_type(
mock_pyisolate, manifest_file, tmp_path
):
manifest = _make_manifest()
_, mock_pi, mock_manager, _, sealed_type = mock_pyisolate
with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib:
mock_tomllib.load.return_value = manifest
await load_isolated_node(
node_dir=tmp_path,
manifest_path=manifest_file,
logger=MagicMock(),
build_stub_class=MagicMock(),
venv_root=tmp_path / "venvs",
extension_managers=[],
)
extension_type = mock_pi.ExtensionManager.call_args[0][0]
config = mock_manager.load_extension.call_args[0][0]
assert extension_type is not sealed_type
assert "execution_model" not in config
@pytest.mark.asyncio
async def test_conda_without_execution_model_remains_sealed_worker(
mock_pyisolate, manifest_file, tmp_path
):
manifest = _make_manifest(package_manager="conda")
manifest["tool"]["comfy"]["isolation"]["conda_channels"] = ["conda-forge"]
manifest["tool"]["comfy"]["isolation"]["conda_dependencies"] = ["eccodes"]
_, mock_pi, mock_manager, _, sealed_type = mock_pyisolate
with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib:
mock_tomllib.load.return_value = manifest
await load_isolated_node(
node_dir=tmp_path,
manifest_path=manifest_file,
logger=MagicMock(),
build_stub_class=MagicMock(),
venv_root=tmp_path / "venvs",
extension_managers=[],
)
extension_type = mock_pi.ExtensionManager.call_args[0][0]
config = mock_manager.load_extension.call_args[0][0]
assert extension_type is sealed_type
assert config["execution_model"] == "sealed_worker"
@pytest.mark.asyncio
async def test_sealed_worker_uses_host_policy_ro_import_paths(
mock_pyisolate, manifest_file, tmp_path
):
manifest = _make_manifest(execution_model="sealed_worker")
module, _, mock_manager, _, _ = mock_pyisolate
with (
patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib,
patch.object(
module,
"load_host_policy",
return_value={
"sandbox_mode": "required",
"allow_network": False,
"writable_paths": [],
"readonly_paths": [],
"sealed_worker_ro_import_paths": ["/home/johnj/ComfyUI"],
},
),
):
mock_tomllib.load.return_value = manifest
await load_isolated_node(
node_dir=tmp_path,
manifest_path=manifest_file,
logger=MagicMock(),
build_stub_class=MagicMock(),
venv_root=tmp_path / "venvs",
extension_managers=[],
)
config = mock_manager.load_extension.call_args[0][0]
assert config["sealed_host_ro_paths"] == ["/home/johnj/ComfyUI"]
@pytest.mark.asyncio
async def test_host_coupled_does_not_emit_sealed_host_ro_paths(
mock_pyisolate, manifest_file, tmp_path
):
manifest = _make_manifest(execution_model="host-coupled")
module, _, mock_manager, _, _ = mock_pyisolate
with (
patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib,
patch.object(
module,
"load_host_policy",
return_value={
"sandbox_mode": "required",
"allow_network": False,
"writable_paths": [],
"readonly_paths": [],
"sealed_worker_ro_import_paths": ["/home/johnj/ComfyUI"],
},
),
):
mock_tomllib.load.return_value = manifest
await load_isolated_node(
node_dir=tmp_path,
manifest_path=manifest_file,
logger=MagicMock(),
build_stub_class=MagicMock(),
venv_root=tmp_path / "venvs",
extension_managers=[],
)
config = mock_manager.load_extension.call_args[0][0]
assert "sealed_host_ro_paths" not in config
@pytest.mark.asyncio
async def test_sealed_worker_manifest_ro_import_paths_blocked(
mock_pyisolate, manifest_file, tmp_path
):
manifest = _make_manifest(
execution_model="sealed_worker",
sealed_host_ro_paths=["/home/johnj/ComfyUI"],
)
_, _, _mock_manager, _, _ = mock_pyisolate
with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib:
mock_tomllib.load.return_value = manifest
with pytest.raises(ValueError, match="Manifest field 'sealed_host_ro_paths' is not allowed"):
await load_isolated_node(
node_dir=tmp_path,
manifest_path=manifest_file,
logger=MagicMock(),
build_stub_class=MagicMock(),
venv_root=tmp_path / "venvs",
extension_managers=[],
)

View File

@ -4,6 +4,7 @@ import pytest
from pathlib import Path
from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy
from tests.isolation.singleton_boundary_helpers import capture_sealed_singleton_imports
class TestFolderPathsProxy:
@ -109,3 +110,13 @@ class TestFolderPathsProxy:
result = proxy.get_folder_paths("checkpoints")
# Should have at least one checkpoint path registered
assert len(result) > 0, "Checkpoints folder paths is empty"
def test_sealed_child_safe_uses_rpc_without_importing_folder_paths(self, monkeypatch):
monkeypatch.setenv("PYISOLATE_CHILD", "1")
monkeypatch.setenv("PYISOLATE_IMPORT_TORCH", "0")
payload = capture_sealed_singleton_imports()
assert payload["temp_dir"] == "/sandbox/temp"
assert payload["models_dir"] == "/sandbox/models"
assert "folder_paths" not in payload["modules"]

View File

@ -1,5 +1,7 @@
from pathlib import Path
import pytest
def _write_pyproject(path: Path, content: str) -> None:
path.write_text(content, encoding="utf-8")
@ -111,3 +113,97 @@ allow_network = true
assert policy["sandbox_mode"] == "disabled"
assert policy["allow_network"] is True
def test_disallows_host_tmp_default_or_override_defaults(tmp_path):
from comfy.isolation.host_policy import DEFAULT_POLICY, load_host_policy
policy = load_host_policy(tmp_path)
assert "/tmp" not in DEFAULT_POLICY["writable_paths"]
assert "/tmp" not in policy["writable_paths"]
def test_disallows_host_tmp_default_or_override_config(tmp_path):
from comfy.isolation.host_policy import load_host_policy
_write_pyproject(
tmp_path / "pyproject.toml",
"""
[tool.comfy.host]
writable_paths = ["/dev/shm", "/tmp", "/tmp/", "/work/cache"]
""".strip(),
)
policy = load_host_policy(tmp_path)
assert policy["writable_paths"] == ["/dev/shm", "/work/cache"]
def test_sealed_worker_ro_import_paths_defaults_off_and_parse(tmp_path):
from comfy.isolation.host_policy import load_host_policy
policy = load_host_policy(tmp_path)
assert policy["sealed_worker_ro_import_paths"] == []
_write_pyproject(
tmp_path / "pyproject.toml",
"""
[tool.comfy.host]
sealed_worker_ro_import_paths = ["/home/johnj/ComfyUI", "/opt/comfy-shared"]
""".strip(),
)
policy = load_host_policy(tmp_path)
assert policy["sealed_worker_ro_import_paths"] == [
"/home/johnj/ComfyUI",
"/opt/comfy-shared",
]
def test_sealed_worker_ro_import_paths_rejects_non_list_or_relative(tmp_path):
from comfy.isolation.host_policy import load_host_policy
_write_pyproject(
tmp_path / "pyproject.toml",
"""
[tool.comfy.host]
sealed_worker_ro_import_paths = "/home/johnj/ComfyUI"
""".strip(),
)
with pytest.raises(ValueError, match="must be a list of absolute paths"):
load_host_policy(tmp_path)
_write_pyproject(
tmp_path / "pyproject.toml",
"""
[tool.comfy.host]
sealed_worker_ro_import_paths = ["relative/path"]
""".strip(),
)
with pytest.raises(ValueError, match="entries must be absolute paths"):
load_host_policy(tmp_path)
def test_host_policy_path_override_controls_ro_import_paths(tmp_path, monkeypatch):
from comfy.isolation.host_policy import load_host_policy
_write_pyproject(
tmp_path / "pyproject.toml",
"""
[tool.comfy.host]
sealed_worker_ro_import_paths = ["/ignored/base/path"]
""".strip(),
)
override_path = tmp_path / "host_policy_override.toml"
_write_pyproject(
override_path,
"""
[tool.comfy.host]
sealed_worker_ro_import_paths = ["/override/ro/path"]
""".strip(),
)
monkeypatch.setenv("COMFY_HOST_POLICY_PATH", str(override_path))
policy = load_host_policy(tmp_path)
assert policy["sealed_worker_ro_import_paths"] == ["/override/ro/path"]

View File

@ -1,5 +1,12 @@
"""Unit tests for PyIsolate isolation system initialization."""
import importlib
import sys
from tests.isolation.singleton_boundary_helpers import (
FakeSingletonRPC,
reset_forbidden_singleton_modules,
)
def test_log_prefix():
@ -11,9 +18,9 @@ def test_log_prefix():
def test_module_initialization():
"""Verify module initializes without errors."""
import comfy.isolation
assert hasattr(comfy.isolation, 'LOG_PREFIX')
assert hasattr(comfy.isolation, 'initialize_proxies')
isolation_pkg = importlib.import_module("comfy.isolation")
assert hasattr(isolation_pkg, "LOG_PREFIX")
assert hasattr(isolation_pkg, "initialize_proxies")
class TestInitializeProxies:
@ -54,3 +61,20 @@ class TestInitializeProxies:
utils_proxy = UtilsProxy()
assert folder_proxy is not None
assert utils_proxy is not None
def test_sealed_child_safe_initialize_proxies_avoids_real_utils_import(self, monkeypatch):
monkeypatch.setenv("PYISOLATE_CHILD", "1")
monkeypatch.setenv("PYISOLATE_IMPORT_TORCH", "0")
reset_forbidden_singleton_modules()
from pyisolate._internal import rpc_protocol
from comfy.isolation import initialize_proxies
fake_rpc = FakeSingletonRPC()
monkeypatch.setattr(rpc_protocol, "get_child_rpc_instance", lambda: fake_rpc)
initialize_proxies()
assert "comfy.utils" not in sys.modules
assert "folder_paths" not in sys.modules
assert "comfy_execution.progress" not in sys.modules

View File

@ -0,0 +1,105 @@
from __future__ import annotations
import importlib.util
import json
from pathlib import Path
COMFYUI_ROOT = Path(__file__).resolve().parents[2]
ISOLATION_ROOT = COMFYUI_ROOT / "tests" / "isolation"
PROBE_ROOT = ISOLATION_ROOT / "internal_probe_node"
WORKFLOW_ROOT = ISOLATION_ROOT / "workflows"
TOOLKIT_ROOT = COMFYUI_ROOT / "custom_nodes" / "ComfyUI-IsolationToolkit"
EXPECTED_PROBE_FILES = {
"__init__.py",
"probe_nodes.py",
}
EXPECTED_WORKFLOWS = {
"internal_probe_preview_image_audio.json",
"internal_probe_ui3d.json",
}
BANNED_REFERENCES = (
"ComfyUI-IsolationToolkit",
"toolkit_smoke_playlist",
"run_isolation_toolkit_smoke.sh",
)
def _text_assets() -> list[Path]:
return sorted(list(PROBE_ROOT.rglob("*.py")) + list(WORKFLOW_ROOT.glob("internal_probe_*.json")))
def _load_probe_package():
spec = importlib.util.spec_from_file_location(
"internal_probe_node",
PROBE_ROOT / "__init__.py",
submodule_search_locations=[str(PROBE_ROOT)],
)
module = importlib.util.module_from_spec(spec)
assert spec is not None
assert spec.loader is not None
spec.loader.exec_module(module)
return module
def test_inventory_is_minimal_and_isolation_owned():
assert PROBE_ROOT.is_dir()
assert WORKFLOW_ROOT.is_dir()
assert PROBE_ROOT.is_relative_to(ISOLATION_ROOT)
assert WORKFLOW_ROOT.is_relative_to(ISOLATION_ROOT)
assert not PROBE_ROOT.is_relative_to(TOOLKIT_ROOT)
probe_files = {path.name for path in PROBE_ROOT.iterdir() if path.is_file()}
workflow_files = {path.name for path in WORKFLOW_ROOT.glob("internal_probe_*.json")}
assert probe_files == EXPECTED_PROBE_FILES
assert workflow_files == EXPECTED_WORKFLOWS
module = _load_probe_package()
mappings = module.NODE_CLASS_MAPPINGS
assert sorted(mappings.keys()) == [
"InternalIsolationProbeAudio",
"InternalIsolationProbeImage",
"InternalIsolationProbeUI3D",
]
preview_workflow = json.loads(
(WORKFLOW_ROOT / "internal_probe_preview_image_audio.json").read_text(
encoding="utf-8"
)
)
ui3d_workflow = json.loads(
(WORKFLOW_ROOT / "internal_probe_ui3d.json").read_text(encoding="utf-8")
)
assert [preview_workflow[node_id]["class_type"] for node_id in ("1", "2")] == [
"InternalIsolationProbeImage",
"InternalIsolationProbeAudio",
]
assert [ui3d_workflow[node_id]["class_type"] for node_id in ("1",)] == [
"InternalIsolationProbeUI3D",
]
def test_zero_toolkit_references_in_probe_assets():
for asset in _text_assets():
content = asset.read_text(encoding="utf-8")
for banned in BANNED_REFERENCES:
assert banned not in content, f"{asset} unexpectedly references {banned}"
def test_replacement_contract_has_zero_toolkit_references():
contract_assets = [
*(PROBE_ROOT.rglob("*.py")),
*WORKFLOW_ROOT.glob("internal_probe_*.json"),
ISOLATION_ROOT / "stage_internal_probe_node.py",
ISOLATION_ROOT / "internal_probe_host_policy.toml",
]
for asset in sorted(contract_assets):
assert asset.exists(), f"Missing replacement-contract asset: {asset}"
content = asset.read_text(encoding="utf-8")
for banned in BANNED_REFERENCES:
assert banned not in content, f"{asset} unexpectedly references {banned}"

View File

@ -0,0 +1,180 @@
from __future__ import annotations
import json
import os
import shutil
import subprocess
import sys
from pathlib import Path
import pytest
import nodes
from tests.isolation.stage_internal_probe_node import (
PROBE_NODE_NAME,
stage_probe_node,
staged_probe_node,
)
COMFYUI_ROOT = Path(__file__).resolve().parents[2]
ISOLATION_ROOT = COMFYUI_ROOT / "tests" / "isolation"
PROBE_SOURCE_ROOT = ISOLATION_ROOT / "internal_probe_node"
EXPECTED_NODE_IDS = [
"InternalIsolationProbeAudio",
"InternalIsolationProbeImage",
"InternalIsolationProbeUI3D",
]
CLIENT_SCRIPT = """
import importlib.util
import json
import os
import sys
import pyisolate._internal.client # noqa: F401 # triggers snapshot bootstrap
module_path = os.environ["PYISOLATE_MODULE_PATH"]
spec = importlib.util.spec_from_file_location(
"internal_probe_node",
os.path.join(module_path, "__init__.py"),
submodule_search_locations=[module_path],
)
module = importlib.util.module_from_spec(spec)
assert spec is not None
assert spec.loader is not None
sys.modules["internal_probe_node"] = module
spec.loader.exec_module(module)
print(
json.dumps(
{
"sys_path": list(sys.path),
"module_path": module_path,
"node_ids": sorted(module.NODE_CLASS_MAPPINGS.keys()),
}
)
)
"""
def _run_client_process(env: dict[str, str]) -> dict:
pythonpath_parts = [str(COMFYUI_ROOT)]
existing = env.get("PYTHONPATH", "")
if existing:
pythonpath_parts.append(existing)
env["PYTHONPATH"] = ":".join(pythonpath_parts)
result = subprocess.run( # noqa: S603
[sys.executable, "-c", CLIENT_SCRIPT],
capture_output=True,
text=True,
env=env,
check=True,
)
return json.loads(result.stdout.strip().splitlines()[-1])
@pytest.fixture()
def staged_probe_module(tmp_path: Path) -> tuple[Path, Path]:
staged_comfy_root = tmp_path / "ComfyUI"
module_path = staged_comfy_root / "custom_nodes" / "InternalIsolationProbeNode"
shutil.copytree(PROBE_SOURCE_ROOT, module_path)
return staged_comfy_root, module_path
@pytest.mark.asyncio
async def test_staged_probe_node_discovered(staged_probe_module: tuple[Path, Path]) -> None:
_, module_path = staged_probe_module
class_mappings_snapshot = dict(nodes.NODE_CLASS_MAPPINGS)
display_name_snapshot = dict(nodes.NODE_DISPLAY_NAME_MAPPINGS)
loaded_module_dirs_snapshot = dict(nodes.LOADED_MODULE_DIRS)
try:
ignore = set(nodes.NODE_CLASS_MAPPINGS.keys())
loaded = await nodes.load_custom_node(
str(module_path), ignore=ignore, module_parent="custom_nodes"
)
assert loaded is True
assert nodes.LOADED_MODULE_DIRS["InternalIsolationProbeNode"] == str(
module_path.resolve()
)
for node_id in EXPECTED_NODE_IDS:
assert node_id in nodes.NODE_CLASS_MAPPINGS
node_cls = nodes.NODE_CLASS_MAPPINGS[node_id]
assert (
getattr(node_cls, "RELATIVE_PYTHON_MODULE", None)
== "custom_nodes.InternalIsolationProbeNode"
)
finally:
nodes.NODE_CLASS_MAPPINGS.clear()
nodes.NODE_CLASS_MAPPINGS.update(class_mappings_snapshot)
nodes.NODE_DISPLAY_NAME_MAPPINGS.clear()
nodes.NODE_DISPLAY_NAME_MAPPINGS.update(display_name_snapshot)
nodes.LOADED_MODULE_DIRS.clear()
nodes.LOADED_MODULE_DIRS.update(loaded_module_dirs_snapshot)
def test_staged_probe_node_module_path_is_valid_for_child_bootstrap(
tmp_path: Path, staged_probe_module: tuple[Path, Path]
) -> None:
staged_comfy_root, module_path = staged_probe_module
snapshot = {
"sys_path": [str(COMFYUI_ROOT), "/host/lib1", "/host/lib2"],
"sys_executable": sys.executable,
"sys_prefix": sys.prefix,
"environment": {},
}
snapshot_path = tmp_path / "snapshot.json"
snapshot_path.write_text(json.dumps(snapshot), encoding="utf-8")
env = os.environ.copy()
env.update(
{
"PYISOLATE_CHILD": "1",
"PYISOLATE_HOST_SNAPSHOT": str(snapshot_path),
"PYISOLATE_MODULE_PATH": str(module_path),
}
)
payload = _run_client_process(env)
assert payload["module_path"] == str(module_path)
assert payload["node_ids"] == EXPECTED_NODE_IDS
assert str(COMFYUI_ROOT) in payload["sys_path"]
assert str(staged_comfy_root) not in payload["sys_path"]
def test_stage_probe_node_stages_only_under_explicit_root(tmp_path: Path) -> None:
comfy_root = tmp_path / "sandbox-root"
module_path = stage_probe_node(comfy_root)
assert module_path == comfy_root / "custom_nodes" / PROBE_NODE_NAME
assert module_path.is_dir()
assert (module_path / "__init__.py").is_file()
assert (module_path / "probe_nodes.py").is_file()
assert (module_path / "pyproject.toml").is_file()
def test_staged_probe_node_context_cleans_up_temp_root() -> None:
with staged_probe_node() as module_path:
staging_root = module_path.parents[1]
assert module_path.name == PROBE_NODE_NAME
assert module_path.is_dir()
assert staging_root.is_dir()
assert not staging_root.exists()
def test_stage_script_requires_explicit_target_root() -> None:
result = subprocess.run( # noqa: S603
[sys.executable, str(ISOLATION_ROOT / "stage_internal_probe_node.py")],
capture_output=True,
text=True,
check=False,
)
assert result.returncode != 0
assert "--target-root" in result.stderr

View File

@ -0,0 +1,86 @@
from __future__ import annotations
import importlib
import sys
from pathlib import Path
from types import ModuleType
def _write_manifest(path: Path, *, standalone: bool = False) -> None:
lines = [
"[project]",
'name = "test-node"',
'version = "0.1.0"',
"",
"[tool.comfy.isolation]",
"can_isolate = true",
"share_torch = false",
]
if standalone:
lines.append("standalone = true")
path.write_text("\n".join(lines) + "\n", encoding="utf-8")
def _load_manifest_loader(custom_nodes_root: Path):
folder_paths = ModuleType("folder_paths")
folder_paths.base_path = str(custom_nodes_root)
folder_paths.get_folder_paths = lambda kind: [str(custom_nodes_root)] if kind == "custom_nodes" else []
sys.modules["folder_paths"] = folder_paths
if "comfy.isolation" not in sys.modules:
iso_mod = ModuleType("comfy.isolation")
iso_mod.__path__ = [ # type: ignore[attr-defined]
str(Path(__file__).resolve().parent.parent.parent / "comfy" / "isolation")
]
iso_mod.__package__ = "comfy.isolation"
sys.modules["comfy.isolation"] = iso_mod
sys.modules.pop("comfy.isolation.manifest_loader", None)
import comfy.isolation.manifest_loader as manifest_loader
return importlib.reload(manifest_loader)
def test_finds_top_level_isolation_manifest(tmp_path: Path) -> None:
node_dir = tmp_path / "TopLevelNode"
node_dir.mkdir(parents=True)
_write_manifest(node_dir / "pyproject.toml")
manifest_loader = _load_manifest_loader(tmp_path)
manifests = manifest_loader.find_manifest_directories()
assert manifests == [(node_dir, node_dir / "pyproject.toml")]
def test_ignores_nested_manifest_without_standalone_flag(tmp_path: Path) -> None:
toolkit_dir = tmp_path / "ToolkitNode"
toolkit_dir.mkdir(parents=True)
_write_manifest(toolkit_dir / "pyproject.toml")
nested_dir = toolkit_dir / "packages" / "nested_fixture"
nested_dir.mkdir(parents=True)
_write_manifest(nested_dir / "pyproject.toml", standalone=False)
manifest_loader = _load_manifest_loader(tmp_path)
manifests = manifest_loader.find_manifest_directories()
assert manifests == [(toolkit_dir, toolkit_dir / "pyproject.toml")]
def test_finds_nested_standalone_manifest(tmp_path: Path) -> None:
toolkit_dir = tmp_path / "ToolkitNode"
toolkit_dir.mkdir(parents=True)
_write_manifest(toolkit_dir / "pyproject.toml")
nested_dir = toolkit_dir / "packages" / "uv_sealed_worker"
nested_dir.mkdir(parents=True)
_write_manifest(nested_dir / "pyproject.toml", standalone=True)
manifest_loader = _load_manifest_loader(tmp_path)
manifests = manifest_loader.find_manifest_directories()
assert manifests == [
(toolkit_dir, toolkit_dir / "pyproject.toml"),
(nested_dir, nested_dir / "pyproject.toml"),
]

View File

@ -0,0 +1,125 @@
"""Generic runtime-helper stub contract tests."""
from __future__ import annotations
import asyncio
import logging
import os
import subprocess
import sys
from pathlib import Path
from types import SimpleNamespace
from typing import Any, cast
from comfy.isolation import runtime_helpers
from comfy_api.latest import io as latest_io
from tests.isolation.stage_internal_probe_node import PROBE_NODE_NAME, staged_probe_node
class _DummyExtension:
def __init__(self, *, name: str, module_path: str):
self.name = name
self.module_path = module_path
async def execute_node(self, _node_name: str, **inputs):
return {
"__node_output__": True,
"args": (inputs,),
"ui": {"status": "ok"},
"expand": False,
"block_execution": False,
}
def _install_model_serialization_stub(monkeypatch):
async def deserialize_from_isolation(payload, _extension):
return payload
monkeypatch.setitem(
sys.modules,
"pyisolate._internal.model_serialization",
SimpleNamespace(
serialize_for_isolation=lambda payload: payload,
deserialize_from_isolation=deserialize_from_isolation,
),
)
def test_stub_sets_relative_python_module(monkeypatch):
_install_model_serialization_stub(monkeypatch)
monkeypatch.setattr(runtime_helpers, "scan_shm_forensics", lambda *args, **kwargs: None)
monkeypatch.setattr(runtime_helpers, "_relieve_host_vram_pressure", lambda *args, **kwargs: None)
extension = _DummyExtension(name="internal_probe", module_path=os.getcwd())
stub = cast(Any, runtime_helpers.build_stub_class(
"ProbeNode",
{
"is_v3": True,
"schema_v1": {},
"input_types": {},
},
extension,
{},
logging.getLogger("test"),
))
info = getattr(stub, "GET_NODE_INFO_V1")()
assert info["python_module"] == "custom_nodes.internal_probe"
def test_stub_ui_dispatch_roundtrip(monkeypatch):
_install_model_serialization_stub(monkeypatch)
monkeypatch.setattr(runtime_helpers, "scan_shm_forensics", lambda *args, **kwargs: None)
monkeypatch.setattr(runtime_helpers, "_relieve_host_vram_pressure", lambda *args, **kwargs: None)
extension = _DummyExtension(name="internal_probe", module_path=os.getcwd())
stub = runtime_helpers.build_stub_class(
"ProbeNode",
{
"is_v3": True,
"schema_v1": {"python_module": "custom_nodes.internal_probe"},
"input_types": {},
},
extension,
{},
logging.getLogger("test"),
)
result = asyncio.run(getattr(stub, "_pyisolate_execute")(SimpleNamespace(), token="value"))
assert isinstance(result, latest_io.NodeOutput)
assert result.ui == {"status": "ok"}
def test_stub_class_types_align_with_extension():
extension = SimpleNamespace(name="internal_probe", module_path="/sandbox/probe")
running_extensions = {"internal_probe": extension}
specs = [
SimpleNamespace(module_path=Path("/sandbox/probe"), node_name="ProbeImage"),
SimpleNamespace(module_path=Path("/sandbox/probe"), node_name="ProbeAudio"),
SimpleNamespace(module_path=Path("/sandbox/other"), node_name="OtherNode"),
]
class_types = runtime_helpers.get_class_types_for_extension(
"internal_probe", running_extensions, specs
)
assert class_types == {"ProbeImage", "ProbeAudio"}
def test_probe_stage_requires_explicit_root():
script = Path(__file__).resolve().parent / "stage_internal_probe_node.py"
result = subprocess.run([sys.executable, str(script)], capture_output=True, text=True, check=False)
assert result.returncode != 0
assert "--target-root" in result.stderr
def test_probe_stage_cleans_up_context():
with staged_probe_node() as module_path:
staged_root = module_path.parents[1]
assert module_path.name == PROBE_NODE_NAME
assert staged_root.exists()
assert not staged_root.exists()

View File

@ -0,0 +1,53 @@
import logging
import socket
import sys
from pathlib import Path
repo_root = Path(__file__).resolve().parents[2]
pyisolate_root = repo_root.parent / "pyisolate"
if pyisolate_root.exists():
sys.path.insert(0, str(pyisolate_root))
from comfy.isolation.adapter import ComfyUIAdapter
from comfy_api.latest._io import FolderType
from comfy_api.latest._ui import SavedImages, SavedResult
from pyisolate._internal.rpc_transports import JSONSocketTransport
from pyisolate._internal.serialization_registry import SerializerRegistry
def test_savedimages_roundtrip(caplog):
registry = SerializerRegistry.get_instance()
registry.clear()
ComfyUIAdapter().register_serializers(registry)
payload = SavedImages(
results=[SavedResult("issue82.png", "slice2", FolderType.output)],
is_animated=True,
)
a, b = socket.socketpair()
sender = JSONSocketTransport(a)
receiver = JSONSocketTransport(b)
try:
with caplog.at_level(logging.WARNING, logger="pyisolate._internal.rpc_transports"):
sender.send({"ui": payload})
result = receiver.recv()
finally:
sender.close()
receiver.close()
registry.clear()
ui = result["ui"]
assert isinstance(ui, SavedImages)
assert ui.is_animated is True
assert len(ui.results) == 1
assert isinstance(ui.results[0], SavedResult)
assert ui.results[0].filename == "issue82.png"
assert ui.results[0].subfolder == "slice2"
assert ui.results[0].type == FolderType.output
assert ui.as_dict() == {
"images": [SavedResult("issue82.png", "slice2", FolderType.output)],
"animated": (True,),
}
assert not any("GENERIC SERIALIZER USED" in record.message for record in caplog.records)
assert not any("GENERIC DESERIALIZER USED" in record.message for record in caplog.records)

View File

@ -0,0 +1,368 @@
"""Generic sealed-worker loader contract matrix tests."""
from __future__ import annotations
import importlib
import json
import sys
import types
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
COMFYUI_ROOT = Path(__file__).resolve().parents[2]
TEST_WORKFLOW_ROOT = COMFYUI_ROOT / "tests" / "isolation" / "workflows"
SEALED_WORKFLOW_CLASS_TYPES: dict[str, set[str]] = {
"quick_6_uv_sealed_worker.json": {
"EmptyLatentImage",
"ProxyTestSealedWorker",
"UVSealedBoltonsSlugify",
"UVSealedLatentEcho",
"UVSealedRuntimeProbe",
},
"isolation_7_uv_sealed_worker.json": {
"EmptyLatentImage",
"ProxyTestSealedWorker",
"UVSealedBoltonsSlugify",
"UVSealedLatentEcho",
"UVSealedRuntimeProbe",
},
"quick_8_conda_sealed_worker.json": {
"CondaSealedLatentEcho",
"CondaSealedOpenWeatherDataset",
"CondaSealedRuntimeProbe",
"EmptyLatentImage",
"ProxyTestCondaSealedWorker",
},
"isolation_9_conda_sealed_worker.json": {
"CondaSealedLatentEcho",
"CondaSealedOpenWeatherDataset",
"CondaSealedRuntimeProbe",
"EmptyLatentImage",
"ProxyTestCondaSealedWorker",
},
}
def _workflow_class_types(path: Path) -> set[str]:
payload = json.loads(path.read_text(encoding="utf-8"))
return {
node["class_type"]
for node in payload.values()
if isinstance(node, dict) and "class_type" in node
}
def _make_manifest(
*,
package_manager: str = "uv",
execution_model: str | None = None,
can_isolate: bool = True,
dependencies: list[str] | None = None,
share_torch: bool = False,
sealed_host_ro_paths: list[str] | None = None,
) -> dict:
isolation: dict[str, object] = {
"can_isolate": can_isolate,
}
if package_manager != "uv":
isolation["package_manager"] = package_manager
if execution_model is not None:
isolation["execution_model"] = execution_model
if share_torch:
isolation["share_torch"] = True
if sealed_host_ro_paths is not None:
isolation["sealed_host_ro_paths"] = sealed_host_ro_paths
if package_manager == "conda":
isolation["conda_channels"] = ["conda-forge"]
isolation["conda_dependencies"] = ["numpy"]
return {
"project": {
"name": "contract-extension",
"dependencies": dependencies or ["numpy"],
},
"tool": {"comfy": {"isolation": isolation}},
}
@pytest.fixture
def manifest_file(tmp_path: Path) -> Path:
path = tmp_path / "pyproject.toml"
path.write_bytes(b"")
return path
def _loader_module(
monkeypatch: pytest.MonkeyPatch, *, preload_extension_wrapper: bool
):
mock_wrapper = MagicMock()
mock_wrapper.ComfyNodeExtension = type("ComfyNodeExtension", (), {})
iso_mod = types.ModuleType("comfy.isolation")
iso_mod.__path__ = [
str(Path(__file__).resolve().parent.parent.parent / "comfy" / "isolation")
]
iso_mod.__package__ = "comfy.isolation"
manifest_loader = types.SimpleNamespace(
is_cache_valid=lambda *args, **kwargs: False,
load_from_cache=lambda *args, **kwargs: None,
save_to_cache=lambda *args, **kwargs: None,
)
host_policy = types.SimpleNamespace(
load_host_policy=lambda base_path: {
"sandbox_mode": "required",
"allow_network": False,
"writable_paths": [],
"readonly_paths": [],
"sealed_worker_ro_import_paths": [],
}
)
folder_paths = types.SimpleNamespace(base_path="/fake/comfyui")
monkeypatch.setitem(sys.modules, "comfy.isolation", iso_mod)
monkeypatch.setitem(sys.modules, "comfy.isolation.runtime_helpers", MagicMock())
monkeypatch.setitem(sys.modules, "comfy.isolation.manifest_loader", manifest_loader)
monkeypatch.setitem(sys.modules, "comfy.isolation.host_policy", host_policy)
monkeypatch.setitem(sys.modules, "folder_paths", folder_paths)
if preload_extension_wrapper:
monkeypatch.setitem(sys.modules, "comfy.isolation.extension_wrapper", mock_wrapper)
else:
sys.modules.pop("comfy.isolation.extension_wrapper", None)
sys.modules.pop("comfy.isolation.extension_loader", None)
module = importlib.import_module("comfy.isolation.extension_loader")
try:
yield module, mock_wrapper
finally:
sys.modules.pop("comfy.isolation.extension_loader", None)
comfy_pkg = sys.modules.get("comfy")
if comfy_pkg is not None and hasattr(comfy_pkg, "isolation"):
delattr(comfy_pkg, "isolation")
@pytest.fixture
def loader_module(monkeypatch: pytest.MonkeyPatch):
yield from _loader_module(monkeypatch, preload_extension_wrapper=True)
@pytest.fixture
def sealed_loader_module(monkeypatch: pytest.MonkeyPatch):
yield from _loader_module(monkeypatch, preload_extension_wrapper=False)
@pytest.fixture
def mocked_loader(loader_module):
module, mock_wrapper = loader_module
mock_ext = AsyncMock()
mock_ext.list_nodes = AsyncMock(return_value={})
mock_manager = MagicMock()
mock_manager.load_extension = MagicMock(return_value=mock_ext)
sealed_type = type("SealedNodeExtension", (), {})
with patch.object(module, "pyisolate") as mock_pi:
mock_pi.ExtensionManager = MagicMock(return_value=mock_manager)
mock_pi.SealedNodeExtension = sealed_type
yield module, mock_pi, mock_manager, sealed_type, mock_wrapper
@pytest.fixture
def sealed_mocked_loader(sealed_loader_module):
module, mock_wrapper = sealed_loader_module
mock_ext = AsyncMock()
mock_ext.list_nodes = AsyncMock(return_value={})
mock_manager = MagicMock()
mock_manager.load_extension = MagicMock(return_value=mock_ext)
sealed_type = type("SealedNodeExtension", (), {})
with patch.object(module, "pyisolate") as mock_pi:
mock_pi.ExtensionManager = MagicMock(return_value=mock_manager)
mock_pi.SealedNodeExtension = sealed_type
yield module, mock_pi, mock_manager, sealed_type, mock_wrapper
async def _load_node(module, manifest: dict, manifest_path: Path, tmp_path: Path) -> dict:
with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib:
mock_tomllib.load.return_value = manifest
await module.load_isolated_node(
node_dir=tmp_path,
manifest_path=manifest_path,
logger=MagicMock(),
build_stub_class=MagicMock(),
venv_root=tmp_path / "venvs",
extension_managers=[],
)
manager = module.pyisolate.ExtensionManager.return_value
return manager.load_extension.call_args[0][0]
@pytest.mark.asyncio
async def test_uv_host_coupled_default(mocked_loader, manifest_file: Path, tmp_path: Path):
module, mock_pi, _mock_manager, sealed_type, _ = mocked_loader
manifest = _make_manifest(package_manager="uv")
config = await _load_node(module, manifest, manifest_file, tmp_path)
extension_type = mock_pi.ExtensionManager.call_args[0][0]
assert extension_type is not sealed_type
assert "execution_model" not in config
@pytest.mark.asyncio
async def test_uv_sealed_worker_opt_in(
sealed_mocked_loader, manifest_file: Path, tmp_path: Path
):
module, mock_pi, _mock_manager, sealed_type, _ = sealed_mocked_loader
manifest = _make_manifest(package_manager="uv", execution_model="sealed_worker")
config = await _load_node(module, manifest, manifest_file, tmp_path)
extension_type = mock_pi.ExtensionManager.call_args[0][0]
assert extension_type is sealed_type
assert config["execution_model"] == "sealed_worker"
assert "apis" not in config
assert "comfy.isolation.extension_wrapper" not in sys.modules
@pytest.mark.asyncio
async def test_conda_defaults_to_sealed_worker(
sealed_mocked_loader, manifest_file: Path, tmp_path: Path
):
module, mock_pi, _mock_manager, sealed_type, _ = sealed_mocked_loader
manifest = _make_manifest(package_manager="conda")
config = await _load_node(module, manifest, manifest_file, tmp_path)
extension_type = mock_pi.ExtensionManager.call_args[0][0]
assert extension_type is sealed_type
assert config["execution_model"] == "sealed_worker"
assert config["package_manager"] == "conda"
assert "comfy.isolation.extension_wrapper" not in sys.modules
@pytest.mark.asyncio
async def test_conda_never_uses_comfy_extension_type(
mocked_loader, manifest_file: Path, tmp_path: Path
):
module, mock_pi, _mock_manager, sealed_type, mock_wrapper = mocked_loader
manifest = _make_manifest(package_manager="conda")
await _load_node(module, manifest, manifest_file, tmp_path)
extension_type = mock_pi.ExtensionManager.call_args[0][0]
assert extension_type is sealed_type
assert extension_type is not mock_wrapper.ComfyNodeExtension
@pytest.mark.asyncio
async def test_conda_forces_share_torch_false(mocked_loader, manifest_file: Path, tmp_path: Path):
module, _mock_pi, _mock_manager, _sealed_type, _ = mocked_loader
manifest = _make_manifest(package_manager="conda", share_torch=True)
config = await _load_node(module, manifest, manifest_file, tmp_path)
assert config["share_torch"] is False
@pytest.mark.asyncio
async def test_conda_forces_share_cuda_ipc_false(
mocked_loader, manifest_file: Path, tmp_path: Path
):
module, _mock_pi, _mock_manager, _sealed_type, _ = mocked_loader
manifest = _make_manifest(package_manager="conda", share_torch=True)
config = await _load_node(module, manifest, manifest_file, tmp_path)
assert config["share_cuda_ipc"] is False
@pytest.mark.asyncio
async def test_conda_sandbox_policy_applied(mocked_loader, manifest_file: Path, tmp_path: Path):
module, _mock_pi, _mock_manager, _sealed_type, _ = mocked_loader
manifest = _make_manifest(package_manager="conda")
custom_policy = {
"sandbox_mode": "required",
"allow_network": True,
"writable_paths": ["/data/write"],
"readonly_paths": ["/data/read"],
}
with patch("platform.system", return_value="Linux"):
with patch.object(module, "load_host_policy", return_value=custom_policy):
config = await _load_node(module, manifest, manifest_file, tmp_path)
assert config["sandbox_mode"] == "required"
assert config["sandbox"] == {
"network": True,
"writable_paths": ["/data/write"],
"readonly_paths": ["/data/read"],
}
def test_sealed_worker_workflow_templates_present() -> None:
missing = [
filename
for filename in SEALED_WORKFLOW_CLASS_TYPES
if not (TEST_WORKFLOW_ROOT / filename).is_file()
]
assert not missing, f"missing sealed-worker workflow templates: {missing}"
@pytest.mark.parametrize(
"workflow_name,expected_class_types",
SEALED_WORKFLOW_CLASS_TYPES.items(),
)
def test_sealed_worker_workflow_class_type_contract(
workflow_name: str, expected_class_types: set[str]
) -> None:
workflow_path = TEST_WORKFLOW_ROOT / workflow_name
assert workflow_path.is_file(), f"workflow missing: {workflow_path}"
assert _workflow_class_types(workflow_path) == expected_class_types
@pytest.mark.asyncio
async def test_sealed_worker_host_policy_ro_import_matrix(
mocked_loader, manifest_file: Path, tmp_path: Path
):
module, _mock_pi, _mock_manager, _sealed_type, _ = mocked_loader
manifest = _make_manifest(package_manager="uv", execution_model="sealed_worker")
with patch.object(
module,
"load_host_policy",
return_value={
"sandbox_mode": "required",
"allow_network": False,
"writable_paths": [],
"readonly_paths": [],
"sealed_worker_ro_import_paths": [],
},
):
default_config = await _load_node(module, manifest, manifest_file, tmp_path)
with patch.object(
module,
"load_host_policy",
return_value={
"sandbox_mode": "required",
"allow_network": False,
"writable_paths": [],
"readonly_paths": [],
"sealed_worker_ro_import_paths": ["/home/johnj/ComfyUI"],
},
):
opt_in_config = await _load_node(module, manifest, manifest_file, tmp_path)
assert default_config["execution_model"] == "sealed_worker"
assert "sealed_host_ro_paths" not in default_config
assert opt_in_config["execution_model"] == "sealed_worker"
assert opt_in_config["sealed_host_ro_paths"] == ["/home/johnj/ComfyUI"]
assert "apis" not in opt_in_config

View File

@ -0,0 +1,44 @@
import asyncio
import sys
from pathlib import Path
repo_root = Path(__file__).resolve().parents[2]
pyisolate_root = repo_root.parent / "pyisolate"
if pyisolate_root.exists():
sys.path.insert(0, str(pyisolate_root))
from comfy.isolation.adapter import ComfyUIAdapter
from comfy.isolation.runtime_helpers import _wrap_remote_handles_as_host_proxies
from pyisolate._internal.model_serialization import deserialize_from_isolation
from pyisolate._internal.remote_handle import RemoteObjectHandle
from pyisolate._internal.serialization_registry import SerializerRegistry
def test_shared_model_ksampler_contract():
registry = SerializerRegistry.get_instance()
registry.clear()
ComfyUIAdapter().register_serializers(registry)
handle = RemoteObjectHandle("model_0", "ModelPatcher")
class FakeExtension:
async def call_remote_object_method(self, object_id, method_name, *args, **kwargs):
assert object_id == "model_0"
assert method_name == "get_model_object"
assert args == ("latent_format",)
assert kwargs == {}
return "resolved:latent_format"
wrapped = (handle,)
assert isinstance(wrapped, tuple)
assert isinstance(wrapped[0], RemoteObjectHandle)
deserialized = asyncio.run(deserialize_from_isolation(wrapped))
proxied = _wrap_remote_handles_as_host_proxies(deserialized, FakeExtension())
model_for_host = proxied[0]
assert not isinstance(model_for_host, RemoteObjectHandle)
assert hasattr(model_for_host, "get_model_object")
assert model_for_host.get_model_object("latent_format") == "resolved:latent_format"
registry.clear()

View File

@ -0,0 +1,78 @@
from __future__ import annotations
import json
from tests.isolation.singleton_boundary_helpers import (
capture_minimal_sealed_worker_imports,
capture_sealed_singleton_imports,
)
def test_minimal_sealed_worker_forbidden_imports() -> None:
payload = capture_minimal_sealed_worker_imports()
assert payload["mode"] == "minimal_sealed_worker"
assert payload["runtime_probe_function"] == "inspect"
assert payload["forbidden_matches"] == []
def test_torch_share_subset_scope() -> None:
minimal = capture_minimal_sealed_worker_imports()
allowed_torch_share_only = {
"torch",
"folder_paths",
"comfy.utils",
"comfy.model_management",
"main",
"comfy.isolation.extension_wrapper",
}
assert minimal["forbidden_matches"] == []
assert all(
module_name not in minimal["modules"] for module_name in sorted(allowed_torch_share_only)
)
def test_capture_payload_is_json_serializable() -> None:
payload = capture_minimal_sealed_worker_imports()
encoded = json.dumps(payload, sort_keys=True)
assert "\"minimal_sealed_worker\"" in encoded
def test_folder_paths_child_safe() -> None:
payload = capture_sealed_singleton_imports()
assert payload["mode"] == "sealed_singletons"
assert payload["folder_path"] == "/sandbox/input/demo.png"
assert payload["temp_dir"] == "/sandbox/temp"
assert payload["models_dir"] == "/sandbox/models"
assert payload["forbidden_matches"] == []
def test_utils_child_safe() -> None:
payload = capture_sealed_singleton_imports()
progress_calls = [
call
for call in payload["rpc_calls"]
if call["object_id"] == "UtilsProxy" and call["method"] == "progress_bar_hook"
]
assert progress_calls
assert payload["forbidden_matches"] == []
def test_progress_child_safe() -> None:
payload = capture_sealed_singleton_imports()
progress_calls = [
call
for call in payload["rpc_calls"]
if call["object_id"] == "ProgressProxy" and call["method"] == "rpc_set_progress"
]
assert progress_calls
assert payload["forbidden_matches"] == []

View File

@ -0,0 +1,129 @@
"""Tests for WebDirectoryProxy host-side cache and aiohttp handler integration."""
from __future__ import annotations
import base64
import sys
from unittest.mock import MagicMock
import pytest
from comfy.isolation.proxies.web_directory_proxy import (
ALLOWED_EXTENSIONS,
WebDirectoryCache,
)
@pytest.fixture()
def mock_proxy() -> MagicMock:
"""Create a mock WebDirectoryProxy RPC proxy."""
proxy = MagicMock()
proxy.list_web_files.return_value = [
{"relative_path": "js/app.js", "content_type": "application/javascript"},
{"relative_path": "js/utils.js", "content_type": "application/javascript"},
{"relative_path": "index.html", "content_type": "text/html"},
{"relative_path": "style.css", "content_type": "text/css"},
]
proxy.get_web_file.return_value = {
"content": base64.b64encode(b"console.log('hello');").decode("ascii"),
"content_type": "application/javascript",
}
return proxy
@pytest.fixture()
def cache_with_proxy(mock_proxy: MagicMock) -> WebDirectoryCache:
"""Create a WebDirectoryCache with a registered mock proxy."""
cache = WebDirectoryCache()
cache.register_proxy("test-extension", mock_proxy)
return cache
class TestExtensionsListing:
"""AC-2: /extensions endpoint lists proxied JS files in URL format."""
def test_extensions_listing_produces_url_format_paths(
self, cache_with_proxy: WebDirectoryCache
) -> None:
"""Simulate what server.py does: build /extensions/{name}/{path} URLs."""
import urllib.parse
ext_name = "test-extension"
urls = []
for entry in cache_with_proxy.list_files(ext_name):
if entry["relative_path"].endswith(".js"):
urls.append(
"/extensions/" + urllib.parse.quote(ext_name)
+ "/" + entry["relative_path"]
)
# Emit the actual URL list so it appears in test log output.
sys.stdout.write(f"\n--- Proxied JS URLs ({len(urls)}) ---\n")
for url in urls:
sys.stdout.write(f" {url}\n")
sys.stdout.write("--- End URLs ---\n")
# At least one proxied JS URL in /extensions/{name}/{path} format
assert len(urls) >= 1, f"Expected >= 1 proxied JS URL, got {len(urls)}"
assert "/extensions/test-extension/js/app.js" in urls, (
f"Expected /extensions/test-extension/js/app.js in {urls}"
)
class TestCacheHit:
"""AC-3: Cache populated on first request, reused on second."""
def test_cache_hit_single_rpc_call(
self, cache_with_proxy: WebDirectoryCache, mock_proxy: MagicMock
) -> None:
# First call — RPC
result1 = cache_with_proxy.get_file("test-extension", "js/app.js")
assert result1 is not None
assert result1["content"] == b"console.log('hello');"
# Second call — cache hit
result2 = cache_with_proxy.get_file("test-extension", "js/app.js")
assert result2 is not None
assert result2["content"] == b"console.log('hello');"
# Proxy was called exactly once
assert mock_proxy.get_web_file.call_count == 1
def test_cache_returns_none_for_unknown_extension(
self, cache_with_proxy: WebDirectoryCache
) -> None:
result = cache_with_proxy.get_file("nonexistent", "js/app.js")
assert result is None
class TestForbiddenType:
"""AC-4: Disallowed file types return HTTP 403 Forbidden."""
@pytest.mark.parametrize(
"disallowed_path,expected_status",
[
("backdoor.py", 403),
("malware.exe", 403),
("exploit.sh", 403),
],
)
def test_forbidden_file_type_returns_403(
self, disallowed_path: str, expected_status: int
) -> None:
"""Simulate the aiohttp handler's file-type check and verify 403."""
import os
suffix = os.path.splitext(disallowed_path)[1].lower()
# This mirrors the handler logic in server.py:
# if suffix not in ALLOWED_EXTENSIONS: return web.Response(status=403)
if suffix not in ALLOWED_EXTENSIONS:
status = 403
else:
status = 200
sys.stdout.write(
f"\n--- HTTP status for {disallowed_path} (suffix={suffix}): {status} ---\n"
)
assert status == expected_status, (
f"Expected HTTP {expected_status} for {disallowed_path}, got {status}"
)

View File

@ -0,0 +1,130 @@
"""Tests for WebDirectoryProxy — allow-list, traversal prevention, content serving."""
from __future__ import annotations
import base64
from pathlib import Path
import pytest
from comfy.isolation.proxies.web_directory_proxy import WebDirectoryProxy
@pytest.fixture()
def web_dir_with_mixed_files(tmp_path: Path) -> Path:
"""Create a temp web directory with allowed and disallowed file types."""
web = tmp_path / "web"
js_dir = web / "js"
js_dir.mkdir(parents=True)
# Allowed types
(js_dir / "app.js").write_text("console.log('hello');")
(web / "index.html").write_text("<html></html>")
(web / "style.css").write_text("body { margin: 0; }")
# Disallowed types
(web / "backdoor.py").write_text("import os; os.system('rm -rf /')")
(web / "malware.exe").write_bytes(b"\x00" * 16)
(web / "exploit.sh").write_text("#!/bin/bash\nrm -rf /")
return web
@pytest.fixture()
def proxy_with_web_dir(web_dir_with_mixed_files: Path) -> WebDirectoryProxy:
"""Create a WebDirectoryProxy with a registered test web directory."""
proxy = WebDirectoryProxy()
# Clear class-level state to avoid cross-test pollution
WebDirectoryProxy._web_dirs = {}
WebDirectoryProxy.register_web_dir("test-extension", str(web_dir_with_mixed_files))
return proxy
class TestAllowList:
"""AC-2: list_web_files returns only allowed file types."""
def test_allowlist_only_safe_types(
self, proxy_with_web_dir: WebDirectoryProxy
) -> None:
files = proxy_with_web_dir.list_web_files("test-extension")
extensions = {Path(f["relative_path"]).suffix for f in files}
# Only .js, .html, .css should appear
assert extensions == {".js", ".html", ".css"}
def test_allowlist_excludes_dangerous_types(
self, proxy_with_web_dir: WebDirectoryProxy
) -> None:
files = proxy_with_web_dir.list_web_files("test-extension")
paths = [f["relative_path"] for f in files]
assert not any(p.endswith(".py") for p in paths)
assert not any(p.endswith(".exe") for p in paths)
assert not any(p.endswith(".sh") for p in paths)
def test_allowlist_correct_count(
self, proxy_with_web_dir: WebDirectoryProxy
) -> None:
files = proxy_with_web_dir.list_web_files("test-extension")
# 3 allowed files: app.js, index.html, style.css
assert len(files) == 3
def test_allowlist_unknown_extension_returns_empty(
self, proxy_with_web_dir: WebDirectoryProxy
) -> None:
files = proxy_with_web_dir.list_web_files("nonexistent-extension")
assert files == []
class TestTraversal:
"""AC-3: get_web_file rejects directory traversal attempts."""
@pytest.mark.parametrize(
"malicious_path",
[
"../../../etc/passwd",
"/etc/passwd",
"../../__init__.py",
],
)
def test_traversal_rejected(
self, proxy_with_web_dir: WebDirectoryProxy, malicious_path: str
) -> None:
with pytest.raises(ValueError):
proxy_with_web_dir.get_web_file("test-extension", malicious_path)
class TestContent:
"""AC-4: get_web_file returns base64 content with correct MIME types."""
def test_content_js_mime_type(
self, proxy_with_web_dir: WebDirectoryProxy
) -> None:
result = proxy_with_web_dir.get_web_file("test-extension", "js/app.js")
assert result["content_type"] == "application/javascript"
def test_content_html_mime_type(
self, proxy_with_web_dir: WebDirectoryProxy
) -> None:
result = proxy_with_web_dir.get_web_file("test-extension", "index.html")
assert result["content_type"] == "text/html"
def test_content_css_mime_type(
self, proxy_with_web_dir: WebDirectoryProxy
) -> None:
result = proxy_with_web_dir.get_web_file("test-extension", "style.css")
assert result["content_type"] == "text/css"
def test_content_base64_roundtrip(
self, proxy_with_web_dir: WebDirectoryProxy, web_dir_with_mixed_files: Path
) -> None:
result = proxy_with_web_dir.get_web_file("test-extension", "js/app.js")
decoded = base64.b64decode(result["content"])
source = (web_dir_with_mixed_files / "js" / "app.js").read_bytes()
assert decoded == source
def test_content_disallowed_type_rejected(
self, proxy_with_web_dir: WebDirectoryProxy
) -> None:
with pytest.raises(ValueError, match="Disallowed file type"):
proxy_with_web_dir.get_web_file("test-extension", "backdoor.py")

View File

@ -0,0 +1,230 @@
# pylint: disable=import-outside-toplevel,import-error
from __future__ import annotations
import os
import sys
import logging
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
def _artifact_dir() -> Path | None:
raw = os.environ.get("PYISOLATE_ARTIFACT_DIR")
if not raw:
return None
path = Path(raw)
path.mkdir(parents=True, exist_ok=True)
return path
def _write_artifact(name: str, content: str) -> None:
artifact_dir = _artifact_dir()
if artifact_dir is None:
return
(artifact_dir / name).write_text(content, encoding="utf-8")
def _contains_tensor_marker(value: Any) -> bool:
if isinstance(value, dict):
if value.get("__type__") == "TensorValue":
return True
return any(_contains_tensor_marker(v) for v in value.values())
if isinstance(value, (list, tuple)):
return any(_contains_tensor_marker(v) for v in value)
return False
class InspectRuntimeNode:
RETURN_TYPES = (
"STRING",
"STRING",
"BOOLEAN",
"BOOLEAN",
"STRING",
"STRING",
"BOOLEAN",
)
RETURN_NAMES = (
"path_dump",
"boltons_origin",
"saw_comfy_root",
"imported_comfy_wrapper",
"comfy_module_dump",
"report",
"saw_user_site",
)
FUNCTION = "inspect"
CATEGORY = "PyIsolated/SealedWorker"
@classmethod
def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802
return {"required": {}}
def inspect(self) -> tuple[str, str, bool, bool, str, str, bool]:
import boltons
path_dump = "\n".join(sys.path)
comfy_root = "/home/johnj/ComfyUI"
saw_comfy_root = any(
entry == comfy_root
or entry.startswith(f"{comfy_root}/comfy")
or entry.startswith(f"{comfy_root}/.venv")
for entry in sys.path
)
imported_comfy_wrapper = "comfy.isolation.extension_wrapper" in sys.modules
comfy_module_dump = "\n".join(
sorted(name for name in sys.modules if name.startswith("comfy"))
)
saw_user_site = any("/.local/lib/" in entry for entry in sys.path)
boltons_origin = getattr(boltons, "__file__", "<missing>")
report_lines = [
"UV sealed worker runtime probe",
f"boltons_origin={boltons_origin}",
f"saw_comfy_root={saw_comfy_root}",
f"imported_comfy_wrapper={imported_comfy_wrapper}",
f"saw_user_site={saw_user_site}",
]
report = "\n".join(report_lines)
_write_artifact("child_bootstrap_paths.txt", path_dump)
_write_artifact("child_import_trace.txt", comfy_module_dump)
_write_artifact("child_dependency_dump.txt", boltons_origin)
logger.warning("][ UV sealed runtime probe executed")
logger.warning("][ boltons origin: %s", boltons_origin)
return (
path_dump,
boltons_origin,
saw_comfy_root,
imported_comfy_wrapper,
comfy_module_dump,
report,
saw_user_site,
)
class BoltonsSlugifyNode:
RETURN_TYPES = ("STRING", "STRING")
RETURN_NAMES = ("slug", "boltons_origin")
FUNCTION = "slugify_text"
CATEGORY = "PyIsolated/SealedWorker"
@classmethod
def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802
return {"required": {"text": ("STRING", {"default": "Sealed Worker Rocks"})}}
def slugify_text(self, text: str) -> tuple[str, str]:
import boltons
from boltons.strutils import slugify
slug = slugify(text)
origin = getattr(boltons, "__file__", "<missing>")
logger.warning("][ boltons slugify: %r -> %r", text, slug)
return slug, origin
class FilesystemBarrierNode:
RETURN_TYPES = ("STRING", "BOOLEAN", "BOOLEAN", "BOOLEAN")
RETURN_NAMES = (
"report",
"outside_blocked",
"module_mutation_blocked",
"artifact_write_ok",
)
FUNCTION = "probe"
CATEGORY = "PyIsolated/SealedWorker"
@classmethod
def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802
return {"required": {}}
def probe(self) -> tuple[str, bool, bool, bool]:
artifact_dir = _artifact_dir()
artifact_write_ok = False
if artifact_dir is not None:
probe_path = artifact_dir / "filesystem_barrier_probe.txt"
probe_path.write_text("artifact write ok\n", encoding="utf-8")
artifact_write_ok = probe_path.exists()
module_target = Path(__file__).with_name(
"mutated_from_child_should_not_exist.txt"
)
module_mutation_blocked = False
try:
module_target.write_text("mutation should fail\n", encoding="utf-8")
except Exception:
module_mutation_blocked = True
else:
module_target.unlink(missing_ok=True)
outside_target = Path("/home/johnj/mysolate/.uv_sealed_worker_escape_probe")
outside_blocked = False
try:
outside_target.write_text("escape should fail\n", encoding="utf-8")
except Exception:
outside_blocked = True
else:
outside_target.unlink(missing_ok=True)
report_lines = [
"UV sealed worker filesystem barrier probe",
f"artifact_write_ok={artifact_write_ok}",
f"module_mutation_blocked={module_mutation_blocked}",
f"outside_blocked={outside_blocked}",
]
report = "\n".join(report_lines)
_write_artifact("filesystem_barrier_report.txt", report)
logger.warning("][ filesystem barrier probe executed")
return report, outside_blocked, module_mutation_blocked, artifact_write_ok
class EchoTensorNode:
RETURN_TYPES = ("TENSOR", "BOOLEAN")
RETURN_NAMES = ("tensor", "saw_json_tensor")
FUNCTION = "echo"
CATEGORY = "PyIsolated/SealedWorker"
@classmethod
def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802
return {"required": {"tensor": ("TENSOR",)}}
def echo(self, tensor: Any) -> tuple[Any, bool]:
saw_json_tensor = _contains_tensor_marker(tensor)
logger.warning("][ tensor echo json_marker=%s", saw_json_tensor)
return tensor, saw_json_tensor
class EchoLatentNode:
RETURN_TYPES = ("LATENT", "BOOLEAN")
RETURN_NAMES = ("latent", "saw_json_tensor")
FUNCTION = "echo_latent"
CATEGORY = "PyIsolated/SealedWorker"
@classmethod
def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802
return {"required": {"latent": ("LATENT",)}}
def echo_latent(self, latent: Any) -> tuple[Any, bool]:
saw_json_tensor = _contains_tensor_marker(latent)
logger.warning("][ latent echo json_marker=%s", saw_json_tensor)
return latent, saw_json_tensor
NODE_CLASS_MAPPINGS = {
"UVSealedRuntimeProbe": InspectRuntimeNode,
"UVSealedBoltonsSlugify": BoltonsSlugifyNode,
"UVSealedFilesystemBarrier": FilesystemBarrierNode,
"UVSealedTensorEcho": EchoTensorNode,
"UVSealedLatentEcho": EchoLatentNode,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"UVSealedRuntimeProbe": "UV Sealed Runtime Probe",
"UVSealedBoltonsSlugify": "UV Sealed Boltons Slugify",
"UVSealedFilesystemBarrier": "UV Sealed Filesystem Barrier",
"UVSealedTensorEcho": "UV Sealed Tensor Echo",
"UVSealedLatentEcho": "UV Sealed Latent Echo",
}

View File

@ -0,0 +1,11 @@
[project]
name = "comfyui-toolkit-uv-sealed-worker"
version = "0.1.0"
dependencies = ["boltons"]
[tool.comfy.isolation]
can_isolate = true
share_torch = false
package_manager = "uv"
execution_model = "sealed_worker"
standalone = true

View File

@ -0,0 +1,10 @@
{
"1": {
"class_type": "InternalIsolationProbeImage",
"inputs": {}
},
"2": {
"class_type": "InternalIsolationProbeAudio",
"inputs": {}
}
}

View File

@ -0,0 +1,6 @@
{
"1": {
"class_type": "InternalIsolationProbeUI3D",
"inputs": {}
}
}

View File

@ -0,0 +1,22 @@
{
"1": {
"class_type": "EmptyLatentImage",
"inputs": {}
},
"2": {
"class_type": "ProxyTestSealedWorker",
"inputs": {}
},
"3": {
"class_type": "UVSealedBoltonsSlugify",
"inputs": {}
},
"4": {
"class_type": "UVSealedLatentEcho",
"inputs": {}
},
"5": {
"class_type": "UVSealedRuntimeProbe",
"inputs": {}
}
}

View File

@ -0,0 +1,22 @@
{
"1": {
"class_type": "CondaSealedLatentEcho",
"inputs": {}
},
"2": {
"class_type": "CondaSealedOpenWeatherDataset",
"inputs": {}
},
"3": {
"class_type": "CondaSealedRuntimeProbe",
"inputs": {}
},
"4": {
"class_type": "EmptyLatentImage",
"inputs": {}
},
"5": {
"class_type": "ProxyTestCondaSealedWorker",
"inputs": {}
}
}

View File

@ -0,0 +1,22 @@
{
"1": {
"class_type": "EmptyLatentImage",
"inputs": {}
},
"2": {
"class_type": "ProxyTestSealedWorker",
"inputs": {}
},
"3": {
"class_type": "UVSealedBoltonsSlugify",
"inputs": {}
},
"4": {
"class_type": "UVSealedLatentEcho",
"inputs": {}
},
"5": {
"class_type": "UVSealedRuntimeProbe",
"inputs": {}
}
}

View File

@ -0,0 +1,22 @@
{
"1": {
"class_type": "CondaSealedLatentEcho",
"inputs": {}
},
"2": {
"class_type": "CondaSealedOpenWeatherDataset",
"inputs": {}
},
"3": {
"class_type": "CondaSealedRuntimeProbe",
"inputs": {}
},
"4": {
"class_type": "EmptyLatentImage",
"inputs": {}
},
"5": {
"class_type": "ProxyTestCondaSealedWorker",
"inputs": {}
}
}

View File

@ -1,5 +1,8 @@
import os
import subprocess
import sys
import textwrap
import types
from pathlib import Path
repo_root = Path(__file__).resolve().parents[1]
@ -8,6 +11,8 @@ if pyisolate_root.exists():
sys.path.insert(0, str(pyisolate_root))
from comfy.isolation.adapter import ComfyUIAdapter
from pyisolate._internal.sandbox import build_bwrap_command
from pyisolate._internal.sandbox_detect import RestrictionModel
from pyisolate._internal.serialization_registry import SerializerRegistry
@ -49,3 +54,69 @@ def test_register_serializers():
assert registry.has_handler("VAE")
registry.clear()
def test_child_temp_directory_fence_uses_private_tmp(tmp_path):
adapter = ComfyUIAdapter()
child_script = textwrap.dedent(
"""
from pathlib import Path
child_temp = Path("/tmp/comfyui_temp")
child_temp.mkdir(parents=True, exist_ok=True)
scratch = child_temp / "child_only.txt"
scratch.write_text("child-only", encoding="utf-8")
print(f"CHILD_TEMP={child_temp}")
print(f"CHILD_FILE={scratch}")
"""
)
fake_folder_paths = types.SimpleNamespace(
temp_directory="/host/tmp/should_not_survive",
folder_names_and_paths={},
extension_mimetypes_cache={},
filename_list_cache={},
)
class FolderPathsProxy:
def get_temp_directory(self):
return "/host/tmp/should_not_survive"
original_folder_paths = sys.modules.get("folder_paths")
sys.modules["folder_paths"] = fake_folder_paths
try:
os.environ["PYISOLATE_CHILD"] = "1"
adapter.handle_api_registration(FolderPathsProxy, rpc=None)
finally:
os.environ.pop("PYISOLATE_CHILD", None)
if original_folder_paths is not None:
sys.modules["folder_paths"] = original_folder_paths
else:
sys.modules.pop("folder_paths", None)
assert fake_folder_paths.temp_directory == "/tmp/comfyui_temp"
host_child_file = Path("/tmp/comfyui_temp/child_only.txt")
if host_child_file.exists():
host_child_file.unlink()
cmd = build_bwrap_command(
python_exe=sys.executable,
module_path=str(repo_root / "custom_nodes" / "ComfyUI-IsolationToolkit"),
venv_path=str(repo_root / ".venv"),
uds_address=str(tmp_path / "adapter.sock"),
allow_gpu=False,
restriction_model=RestrictionModel.NONE,
sandbox_config={"writable_paths": ["/dev/shm"], "readonly_paths": [], "network": False},
adapter=adapter,
)
assert "--tmpfs" in cmd and "/tmp" in cmd
assert ["--bind", "/tmp", "/tmp"] not in [cmd[i : i + 3] for i in range(len(cmd) - 2)]
command_tail = cmd[-3:]
assert command_tail[1:] == ["-m", "pyisolate._internal.uds_client"]
cmd = cmd[:-3] + [sys.executable, "-c", child_script]
completed = subprocess.run(cmd, check=True, capture_output=True, text=True)
assert "CHILD_TEMP=/tmp/comfyui_temp" in completed.stdout
assert not host_child_file.exists(), "Child temp file leaked into host /tmp"