diff --git a/.gitignore b/.gitignore index 2700ad5c2..f893b5f14 100644 --- a/.gitignore +++ b/.gitignore @@ -24,3 +24,4 @@ web_custom_versions/ openapi.yaml filtered-openapi.yaml uv.lock +.pyisolate_venvs/ diff --git a/comfy/cli_args.py b/comfy/cli_args.py index dbaadf723..7f57bc269 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -184,6 +184,8 @@ parser.add_argument("--disable-api-nodes", action="store_true", help="Disable lo parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.") +parser.add_argument("--use-process-isolation", action="store_true", help="Enable process isolation for custom nodes with pyisolate.yaml manifests.") + parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level') parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).") diff --git a/comfy/hooks.py b/comfy/hooks.py index 1a76c7ba4..7a5f69ca7 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -14,6 +14,9 @@ if TYPE_CHECKING: import comfy.lora import comfy.model_management import comfy.patcher_extension +from comfy.cli_args import args +import uuid +import os from node_helpers import conditioning_set_values # ####################################################################################################### @@ -61,8 +64,37 @@ class EnumHookScope(enum.Enum): HookedOnly = "hooked_only" +_ISOLATION_HOOKREF_MODE = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1" + + class _HookRef: - pass + def __init__(self): + if _ISOLATION_HOOKREF_MODE: + self._pyisolate_id = str(uuid.uuid4()) + + def _ensure_pyisolate_id(self): + pyisolate_id = getattr(self, "_pyisolate_id", None) + if pyisolate_id is None: + pyisolate_id = str(uuid.uuid4()) + self._pyisolate_id = pyisolate_id + return pyisolate_id + + def __eq__(self, other): + if not _ISOLATION_HOOKREF_MODE: + return self is other + if not isinstance(other, _HookRef): + return False + return self._ensure_pyisolate_id() == other._ensure_pyisolate_id() + + def __hash__(self): + if not _ISOLATION_HOOKREF_MODE: + return id(self) + return hash(self._ensure_pyisolate_id()) + + def __str__(self): + if not _ISOLATION_HOOKREF_MODE: + return super().__str__() + return f"PYISOLATE_HOOKREF:{self._ensure_pyisolate_id()}" def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): @@ -168,6 +200,8 @@ class WeightHook(Hook): key_map = comfy.lora.model_lora_keys_clip(model.model, key_map) else: key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) + if self.weights is None: + self.weights = {} weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False) else: if target == EnumWeightTarget.Clip: diff --git a/comfy/isolation/__init__.py b/comfy/isolation/__init__.py new file mode 100644 index 000000000..18ce059c6 --- /dev/null +++ b/comfy/isolation/__init__.py @@ -0,0 +1,442 @@ +# pylint: disable=consider-using-from-import,cyclic-import,global-statement,global-variable-not-assigned,import-outside-toplevel,logging-fstring-interpolation +from __future__ import annotations +import asyncio +import inspect +import logging +import os +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Set, TYPE_CHECKING +_IMPORT_TORCH = os.environ.get("PYISOLATE_IMPORT_TORCH", "1") == "1" + +load_isolated_node = None +find_manifest_directories = None +build_stub_class = None +get_class_types_for_extension = None +scan_shm_forensics = None +start_shm_forensics = None + +if _IMPORT_TORCH: + import folder_paths + from .extension_loader import load_isolated_node + from .manifest_loader import find_manifest_directories + from .runtime_helpers import build_stub_class, get_class_types_for_extension + from .shm_forensics import scan_shm_forensics, start_shm_forensics + +if TYPE_CHECKING: + from pyisolate import ExtensionManager + from .extension_wrapper import ComfyNodeExtension + +LOG_PREFIX = "][" +isolated_node_timings: List[tuple[float, Path, int]] = [] + +if _IMPORT_TORCH: + PYISOLATE_VENV_ROOT = Path(folder_paths.base_path) / ".pyisolate_venvs" + PYISOLATE_VENV_ROOT.mkdir(parents=True, exist_ok=True) + +logger = logging.getLogger(__name__) +_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024 +_MODEL_PATCHER_IDLE_TIMEOUT_MS = 120000 + + +def initialize_proxies() -> None: + from .child_hooks import is_child_process + + is_child = is_child_process() + logger.warning( + "%s DIAG:initialize_proxies | is_child=%s | PYISOLATE_CHILD=%s", + LOG_PREFIX, is_child, os.environ.get("PYISOLATE_CHILD"), + ) + + if is_child: + from .child_hooks import initialize_child_process + + initialize_child_process() + logger.warning("%s DIAG:initialize_proxies child_process initialized", LOG_PREFIX) + else: + from .host_hooks import initialize_host_process + + initialize_host_process() + logger.warning("%s DIAG:initialize_proxies host_process initialized", LOG_PREFIX) + if start_shm_forensics is not None: + start_shm_forensics() + + +@dataclass(frozen=True) +class IsolatedNodeSpec: + node_name: str + display_name: str + stub_class: type + module_path: Path + + +_ISOLATED_NODE_SPECS: List[IsolatedNodeSpec] = [] +_CLAIMED_PATHS: Set[Path] = set() +_ISOLATION_SCAN_ATTEMPTED = False +_EXTENSION_MANAGERS: List["ExtensionManager"] = [] +_RUNNING_EXTENSIONS: Dict[str, "ComfyNodeExtension"] = {} +_ISOLATION_BACKGROUND_TASK: Optional["asyncio.Task[List[IsolatedNodeSpec]]"] = None +_EARLY_START_TIME: Optional[float] = None + + +def start_isolation_loading_early(loop: "asyncio.AbstractEventLoop") -> None: + global _ISOLATION_BACKGROUND_TASK, _EARLY_START_TIME + if _ISOLATION_BACKGROUND_TASK is not None: + return + _EARLY_START_TIME = time.perf_counter() + _ISOLATION_BACKGROUND_TASK = loop.create_task(initialize_isolation_nodes()) + + +async def await_isolation_loading() -> List[IsolatedNodeSpec]: + global _ISOLATION_BACKGROUND_TASK, _EARLY_START_TIME + if _ISOLATION_BACKGROUND_TASK is not None: + specs = await _ISOLATION_BACKGROUND_TASK + return specs + return await initialize_isolation_nodes() + + +async def initialize_isolation_nodes() -> List[IsolatedNodeSpec]: + global _ISOLATED_NODE_SPECS, _ISOLATION_SCAN_ATTEMPTED, _CLAIMED_PATHS + + if _ISOLATED_NODE_SPECS: + return _ISOLATED_NODE_SPECS + + if _ISOLATION_SCAN_ATTEMPTED: + return [] + + _ISOLATION_SCAN_ATTEMPTED = True + if find_manifest_directories is None or load_isolated_node is None or build_stub_class is None: + return [] + manifest_entries = find_manifest_directories() + _CLAIMED_PATHS = {entry[0].resolve() for entry in manifest_entries} + + if not manifest_entries: + return [] + + os.environ["PYISOLATE_ISOLATION_ACTIVE"] = "1" + concurrency_limit = max(1, (os.cpu_count() or 4) // 2) + semaphore = asyncio.Semaphore(concurrency_limit) + + async def load_with_semaphore( + node_dir: Path, manifest: Path + ) -> List[IsolatedNodeSpec]: + async with semaphore: + load_start = time.perf_counter() + spec_list = await load_isolated_node( + node_dir, + manifest, + logger, + lambda name, info, extension: build_stub_class( + name, + info, + extension, + _RUNNING_EXTENSIONS, + logger, + ), + PYISOLATE_VENV_ROOT, + _EXTENSION_MANAGERS, + ) + spec_list = [ + IsolatedNodeSpec( + node_name=node_name, + display_name=display_name, + stub_class=stub_cls, + module_path=node_dir, + ) + for node_name, display_name, stub_cls in spec_list + ] + isolated_node_timings.append( + (time.perf_counter() - load_start, node_dir, len(spec_list)) + ) + return spec_list + + tasks = [ + load_with_semaphore(node_dir, manifest) + for node_dir, manifest in manifest_entries + ] + results = await asyncio.gather(*tasks, return_exceptions=True) + + specs: List[IsolatedNodeSpec] = [] + for result in results: + if isinstance(result, Exception): + logger.error( + "%s Isolated node failed during startup; continuing: %s", + LOG_PREFIX, + result, + ) + continue + specs.extend(result) + + _ISOLATED_NODE_SPECS = specs + return list(_ISOLATED_NODE_SPECS) + + +def _get_class_types_for_extension(extension_name: str) -> Set[str]: + """Get all node class types (node names) belonging to an extension.""" + extension = _RUNNING_EXTENSIONS.get(extension_name) + if not extension: + return set() + + ext_path = Path(extension.module_path) + class_types = set() + for spec in _ISOLATED_NODE_SPECS: + if spec.module_path.resolve() == ext_path.resolve(): + class_types.add(spec.node_name) + + return class_types + + +async def notify_execution_graph(needed_class_types: Set[str], caches: list | None = None) -> None: + """Evict running extensions not needed for current execution. + + When *caches* is provided, cache entries for evicted extensions' node + class_types are invalidated to prevent stale ``RemoteObjectHandle`` + references from surviving in the output cache. + """ + await wait_for_model_patcher_quiescence( + timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS, + fail_loud=True, + marker="ISO:notify_graph_wait_idle", + ) + + evicted_class_types: Set[str] = set() + + async def _stop_extension( + ext_name: str, extension: "ComfyNodeExtension", reason: str + ) -> None: + # Collect class_types BEFORE stopping so we can invalidate cache entries. + ext_class_types = _get_class_types_for_extension(ext_name) + evicted_class_types.update(ext_class_types) + logger.info("%s ISO:eject_start ext=%s reason=%s", LOG_PREFIX, ext_name, reason) + logger.debug("%s ISO:stop_start ext=%s", LOG_PREFIX, ext_name) + stop_result = extension.stop() + if inspect.isawaitable(stop_result): + await stop_result + _RUNNING_EXTENSIONS.pop(ext_name, None) + logger.debug("%s ISO:stop_done ext=%s", LOG_PREFIX, ext_name) + if scan_shm_forensics is not None: + scan_shm_forensics("ISO:stop_extension", refresh_model_context=True) + + if scan_shm_forensics is not None: + scan_shm_forensics("ISO:notify_graph_start", refresh_model_context=True) + isolated_class_types_in_graph = needed_class_types.intersection( + {spec.node_name for spec in _ISOLATED_NODE_SPECS} + ) + graph_uses_isolation = bool(isolated_class_types_in_graph) + logger.debug( + "%s ISO:notify_graph_start running=%d needed=%d", + LOG_PREFIX, + len(_RUNNING_EXTENSIONS), + len(needed_class_types), + ) + if graph_uses_isolation: + for ext_name, extension in list(_RUNNING_EXTENSIONS.items()): + ext_class_types = _get_class_types_for_extension(ext_name) + + # If NONE of this extension's nodes are in the execution graph -> evict. + if not ext_class_types.intersection(needed_class_types): + await _stop_extension( + ext_name, + extension, + "isolated custom_node not in execution graph, evicting", + ) + else: + logger.debug( + "%s ISO:notify_graph_skip_evict running=%d reason=no isolated nodes in graph", + LOG_PREFIX, + len(_RUNNING_EXTENSIONS), + ) + + # Isolated child processes add steady VRAM pressure; reclaim host-side models + # at workflow boundaries so subsequent host nodes (e.g. CLIP encode) keep headroom. + try: + import comfy.model_management as model_management + + device = model_management.get_torch_device() + if getattr(device, "type", None) == "cuda": + required = max( + model_management.minimum_inference_memory(), + _WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES, + ) + free_before = model_management.get_free_memory(device) + if free_before < required and _RUNNING_EXTENSIONS and graph_uses_isolation: + for ext_name, extension in list(_RUNNING_EXTENSIONS.items()): + await _stop_extension( + ext_name, + extension, + f"boundary low-vram restart (free={int(free_before)} target={int(required)})", + ) + if model_management.get_free_memory(device) < required: + model_management.unload_all_models() + model_management.cleanup_models_gc() + model_management.cleanup_models() + if model_management.get_free_memory(device) < required: + model_management.free_memory(required, device, for_dynamic=False) + model_management.soft_empty_cache() + except Exception: + logger.debug( + "%s workflow-boundary host VRAM relief failed", LOG_PREFIX, exc_info=True + ) + finally: + # Invalidate cached outputs for evicted extensions so stale + # RemoteObjectHandle references are not served from cache. + if evicted_class_types and caches: + total_invalidated = 0 + for cache in caches: + if hasattr(cache, "invalidate_by_class_types"): + total_invalidated += cache.invalidate_by_class_types( + evicted_class_types + ) + if total_invalidated > 0: + logger.info( + "%s ISO:cache_invalidated count=%d class_types=%s", + LOG_PREFIX, + total_invalidated, + evicted_class_types, + ) + scan_shm_forensics("ISO:notify_graph_done", refresh_model_context=True) + logger.debug( + "%s ISO:notify_graph_done running=%d", LOG_PREFIX, len(_RUNNING_EXTENSIONS) + ) + + +async def flush_running_extensions_transport_state() -> int: + await wait_for_model_patcher_quiescence( + timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS, + fail_loud=True, + marker="ISO:flush_transport_wait_idle", + ) + total_flushed = 0 + for ext_name, extension in list(_RUNNING_EXTENSIONS.items()): + flush_fn = getattr(extension, "flush_transport_state", None) + if not callable(flush_fn): + continue + try: + flushed = await flush_fn() + if isinstance(flushed, int): + total_flushed += flushed + if flushed > 0: + logger.debug( + "%s %s workflow-end flush released=%d", + LOG_PREFIX, + ext_name, + flushed, + ) + except Exception: + logger.debug( + "%s %s workflow-end flush failed", LOG_PREFIX, ext_name, exc_info=True + ) + scan_shm_forensics( + "ISO:flush_running_extensions_transport_state", refresh_model_context=True + ) + return total_flushed + + +async def wait_for_model_patcher_quiescence( + timeout_ms: int = _MODEL_PATCHER_IDLE_TIMEOUT_MS, + *, + fail_loud: bool = False, + marker: str = "ISO:wait_model_patcher_idle", +) -> bool: + try: + from comfy.isolation.model_patcher_proxy_registry import ModelPatcherRegistry + + registry = ModelPatcherRegistry() + start = time.perf_counter() + idle = await registry.wait_all_idle(timeout_ms) + elapsed_ms = (time.perf_counter() - start) * 1000.0 + if idle: + logger.debug( + "%s %s idle=1 timeout_ms=%d elapsed_ms=%.3f", + LOG_PREFIX, + marker, + timeout_ms, + elapsed_ms, + ) + return True + + states = await registry.get_all_operation_states() + logger.error( + "%s %s idle_timeout timeout_ms=%d elapsed_ms=%.3f states=%s", + LOG_PREFIX, + marker, + timeout_ms, + elapsed_ms, + states, + ) + if fail_loud: + raise TimeoutError( + f"ModelPatcherRegistry did not quiesce within {timeout_ms} ms" + ) + return False + except Exception: + if fail_loud: + raise + logger.debug("%s %s failed", LOG_PREFIX, marker, exc_info=True) + return False + + +def get_claimed_paths() -> Set[Path]: + return _CLAIMED_PATHS + + +def update_rpc_event_loops(loop: "asyncio.AbstractEventLoop | None" = None) -> None: + """Update all active RPC instances with the current event loop. + + This MUST be called at the start of each workflow execution to ensure + RPC calls are scheduled on the correct event loop. This handles the case + where asyncio.run() creates a new event loop for each workflow. + + Args: + loop: The event loop to use. If None, uses asyncio.get_running_loop(). + """ + if loop is None: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.get_event_loop() + + update_count = 0 + + # Update RPCs from ExtensionManagers + for manager in _EXTENSION_MANAGERS: + if not hasattr(manager, "extensions"): + continue + for name, extension in manager.extensions.items(): + if hasattr(extension, "rpc") and extension.rpc is not None: + if hasattr(extension.rpc, "update_event_loop"): + extension.rpc.update_event_loop(loop) + update_count += 1 + logger.debug(f"{LOG_PREFIX}Updated loop on extension '{name}'") + + # Also update RPCs from running extensions (they may have direct RPC refs) + for name, extension in _RUNNING_EXTENSIONS.items(): + if hasattr(extension, "rpc") and extension.rpc is not None: + if hasattr(extension.rpc, "update_event_loop"): + extension.rpc.update_event_loop(loop) + update_count += 1 + logger.debug(f"{LOG_PREFIX}Updated loop on running extension '{name}'") + + if update_count > 0: + logger.debug(f"{LOG_PREFIX}Updated event loop on {update_count} RPC instances") + else: + logger.debug( + f"{LOG_PREFIX}No RPC instances found to update (managers={len(_EXTENSION_MANAGERS)}, running={len(_RUNNING_EXTENSIONS)})" + ) + + +__all__ = [ + "LOG_PREFIX", + "initialize_proxies", + "initialize_isolation_nodes", + "start_isolation_loading_early", + "await_isolation_loading", + "notify_execution_graph", + "flush_running_extensions_transport_state", + "wait_for_model_patcher_quiescence", + "get_claimed_paths", + "update_rpc_event_loops", + "IsolatedNodeSpec", + "get_class_types_for_extension", +] diff --git a/comfy/isolation/adapter.py b/comfy/isolation/adapter.py new file mode 100644 index 000000000..4751dee51 --- /dev/null +++ b/comfy/isolation/adapter.py @@ -0,0 +1,965 @@ +# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,protected-access,raise-missing-from,useless-return,wrong-import-position +from __future__ import annotations + +import logging +import os +import inspect +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, cast + +from pyisolate.interfaces import IsolationAdapter, SerializerRegistryProtocol # type: ignore[import-untyped] +from pyisolate._internal.rpc_protocol import AsyncRPC, ProxiedSingleton # type: ignore[import-untyped] + +_IMPORT_TORCH = os.environ.get("PYISOLATE_IMPORT_TORCH", "1") == "1" + +# Singleton proxies that do NOT transitively import torch/PIL/psutil/aiohttp. +# Safe to import in sealed workers without host framework modules. +from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy +from comfy.isolation.proxies.helper_proxies import HelperProxiesService +from comfy.isolation.proxies.web_directory_proxy import WebDirectoryProxy + +# Singleton proxies that transitively import torch, PIL, or heavy host modules. +# Only available when torch/host framework is present. +CLIPProxy = None +CLIPRegistry = None +ModelPatcherProxy = None +ModelPatcherRegistry = None +ModelSamplingProxy = None +ModelSamplingRegistry = None +VAEProxy = None +VAERegistry = None +FirstStageModelRegistry = None +ModelManagementProxy = None +PromptServerService = None +ProgressProxy = None +UtilsProxy = None +_HAS_TORCH_PROXIES = False +if _IMPORT_TORCH: + from comfy.isolation.clip_proxy import CLIPProxy, CLIPRegistry + from comfy.isolation.model_patcher_proxy import ( + ModelPatcherProxy, + ModelPatcherRegistry, + ) + from comfy.isolation.model_sampling_proxy import ( + ModelSamplingProxy, + ModelSamplingRegistry, + ) + from comfy.isolation.vae_proxy import VAEProxy, VAERegistry, FirstStageModelRegistry + from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy + from comfy.isolation.proxies.prompt_server_impl import PromptServerService + from comfy.isolation.proxies.progress_proxy import ProgressProxy + from comfy.isolation.proxies.utils_proxy import UtilsProxy + _HAS_TORCH_PROXIES = True + +logger = logging.getLogger(__name__) + +# Force /dev/shm for shared memory (bwrap makes /tmp private) +import tempfile + +if os.path.exists("/dev/shm"): + # Only override if not already set or if default is not /dev/shm + current_tmp = tempfile.gettempdir() + if not current_tmp.startswith("/dev/shm"): + logger.debug( + f"Configuring shared memory: Changing TMPDIR from {current_tmp} to /dev/shm" + ) + os.environ["TMPDIR"] = "/dev/shm" + tempfile.tempdir = None # Clear cache to force re-evaluation + + +class ComfyUIAdapter(IsolationAdapter): + # ComfyUI-specific IsolationAdapter implementation + + @property + def identifier(self) -> str: + return "comfyui" + + def get_path_config(self, module_path: str) -> Optional[Dict[str, Any]]: + if "ComfyUI" in module_path and "custom_nodes" in module_path: + parts = module_path.split("ComfyUI") + if len(parts) > 1: + comfy_root = parts[0] + "ComfyUI" + return { + "preferred_root": comfy_root, + "additional_paths": [ + os.path.join(comfy_root, "custom_nodes"), + os.path.join(comfy_root, "comfy"), + ], + "filtered_subdirs": ["comfy", "app", "comfy_execution", "utils"], + } + return None + + def get_sandbox_system_paths(self) -> Optional[List[str]]: + """Returns required application paths to mount in the sandbox.""" + # By inspecting where our adapter is loaded from, we can determine the comfy root + adapter_file = inspect.getfile(self.__class__) + # adapter_file = /home/johnj/ComfyUI/comfy/isolation/adapter.py + comfy_root = os.path.dirname(os.path.dirname(os.path.dirname(adapter_file))) + if os.path.exists(comfy_root): + return [comfy_root] + return None + + def setup_child_environment(self, snapshot: Dict[str, Any]) -> None: + comfy_root = snapshot.get("preferred_root") + if not comfy_root: + return + + requirements_path = Path(comfy_root) / "requirements.txt" + if requirements_path.exists(): + import re + + for line in requirements_path.read_text().splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + pkg_name = re.split(r"[<>=!~\[]", line)[0].strip() + if pkg_name: + logging.getLogger(pkg_name).setLevel(logging.ERROR) + + def register_serializers(self, registry: SerializerRegistryProtocol) -> None: + if not _IMPORT_TORCH: + # Sealed worker without torch — register torch-free TensorValue handler + # so IMAGE/MASK/LATENT tensors arrive as numpy arrays, not raw dicts. + import numpy as np + + _TORCH_DTYPE_TO_NUMPY = { + "torch.float32": np.float32, + "torch.float64": np.float64, + "torch.float16": np.float16, + "torch.bfloat16": np.float32, # numpy has no bfloat16; upcast + "torch.int32": np.int32, + "torch.int64": np.int64, + "torch.int16": np.int16, + "torch.int8": np.int8, + "torch.uint8": np.uint8, + "torch.bool": np.bool_, + } + + def _deserialize_tensor_value(data: Dict[str, Any]) -> Any: + dtype_str = data["dtype"] + np_dtype = _TORCH_DTYPE_TO_NUMPY.get(dtype_str, np.float32) + shape = tuple(data["tensor_size"]) + arr = np.array(data["data"], dtype=np_dtype).reshape(shape) + return arr + + _NUMPY_TO_TORCH_DTYPE = { + np.float32: "torch.float32", + np.float64: "torch.float64", + np.float16: "torch.float16", + np.int32: "torch.int32", + np.int64: "torch.int64", + np.int16: "torch.int16", + np.int8: "torch.int8", + np.uint8: "torch.uint8", + np.bool_: "torch.bool", + } + + def _serialize_tensor_value(obj: Any) -> Dict[str, Any]: + arr = np.asarray(obj, dtype=np.float32) if obj.dtype not in _NUMPY_TO_TORCH_DTYPE else np.asarray(obj) + dtype_str = _NUMPY_TO_TORCH_DTYPE.get(arr.dtype.type, "torch.float32") + return { + "__type__": "TensorValue", + "dtype": dtype_str, + "tensor_size": list(arr.shape), + "requires_grad": False, + "data": arr.tolist(), + } + + registry.register("TensorValue", _serialize_tensor_value, _deserialize_tensor_value, data_type=True) + # ndarray output from sealed workers serializes as TensorValue for host torch reconstruction + registry.register("ndarray", _serialize_tensor_value, _deserialize_tensor_value, data_type=True) + return + + import torch + + def serialize_device(obj: Any) -> Dict[str, Any]: + return {"__type__": "device", "device_str": str(obj)} + + def deserialize_device(data: Dict[str, Any]) -> Any: + return torch.device(data["device_str"]) + + registry.register("device", serialize_device, deserialize_device) + + _VALID_DTYPES = { + "float16", "float32", "float64", "bfloat16", + "int8", "int16", "int32", "int64", + "uint8", "bool", + } + + def serialize_dtype(obj: Any) -> Dict[str, Any]: + return {"__type__": "dtype", "dtype_str": str(obj)} + + def deserialize_dtype(data: Dict[str, Any]) -> Any: + dtype_name = data["dtype_str"].replace("torch.", "") + if dtype_name not in _VALID_DTYPES: + raise ValueError(f"Invalid dtype: {data['dtype_str']}") + return getattr(torch, dtype_name) + + registry.register("dtype", serialize_dtype, deserialize_dtype) + + from comfy_api.latest._io import FolderType + from comfy_api.latest._ui import SavedImages, SavedResult + + def serialize_saved_result(obj: Any) -> Dict[str, Any]: + return { + "__type__": "SavedResult", + "filename": obj.filename, + "subfolder": obj.subfolder, + "folder_type": obj.type.value, + } + + def deserialize_saved_result(data: Dict[str, Any]) -> Any: + if isinstance(data, SavedResult): + return data + folder_type = data["folder_type"] if "folder_type" in data else data["type"] + return SavedResult( + filename=data["filename"], + subfolder=data["subfolder"], + type=FolderType(folder_type), + ) + + registry.register( + "SavedResult", + serialize_saved_result, + deserialize_saved_result, + data_type=True, + ) + + def serialize_saved_images(obj: Any) -> Dict[str, Any]: + return { + "__type__": "SavedImages", + "results": [serialize_saved_result(result) for result in obj.results], + "is_animated": obj.is_animated, + } + + def deserialize_saved_images(data: Dict[str, Any]) -> Any: + return SavedImages( + results=[deserialize_saved_result(result) for result in data["results"]], + is_animated=data.get("is_animated", False), + ) + + registry.register( + "SavedImages", + serialize_saved_images, + deserialize_saved_images, + data_type=True, + ) + + def serialize_model_patcher(obj: Any) -> Dict[str, Any]: + # Child-side: must already have _instance_id (proxy) + if os.environ.get("PYISOLATE_CHILD") == "1": + if hasattr(obj, "_instance_id"): + return {"__type__": "ModelPatcherRef", "model_id": obj._instance_id} + raise RuntimeError( + f"ModelPatcher in child lacks _instance_id: " + f"{type(obj).__module__}.{type(obj).__name__}" + ) + # Host-side: register with registry + if hasattr(obj, "_instance_id"): + return {"__type__": "ModelPatcherRef", "model_id": obj._instance_id} + model_id = ModelPatcherRegistry().register(obj) + return {"__type__": "ModelPatcherRef", "model_id": model_id} + + def deserialize_model_patcher(data: Any) -> Any: + """Deserialize ModelPatcher refs; pass through already-materialized objects.""" + if isinstance(data, dict): + return ModelPatcherProxy( + data["model_id"], registry=None, manage_lifecycle=False + ) + return data + + def deserialize_model_patcher_ref(data: Dict[str, Any]) -> Any: + """Context-aware ModelPatcherRef deserializer for both host and child.""" + is_child = os.environ.get("PYISOLATE_CHILD") == "1" + if is_child: + return ModelPatcherProxy( + data["model_id"], registry=None, manage_lifecycle=False + ) + else: + return ModelPatcherRegistry()._get_instance(data["model_id"]) + + # Register ModelPatcher type for serialization + registry.register( + "ModelPatcher", serialize_model_patcher, deserialize_model_patcher + ) + # Register ModelPatcherProxy type (already a proxy, just return ref) + registry.register( + "ModelPatcherProxy", serialize_model_patcher, deserialize_model_patcher + ) + # Register ModelPatcherRef for deserialization (context-aware: host or child) + registry.register("ModelPatcherRef", None, deserialize_model_patcher_ref) + + def serialize_clip(obj: Any) -> Dict[str, Any]: + if hasattr(obj, "_instance_id"): + return {"__type__": "CLIPRef", "clip_id": obj._instance_id} + clip_id = CLIPRegistry().register(obj) + return {"__type__": "CLIPRef", "clip_id": clip_id} + + def deserialize_clip(data: Any) -> Any: + if isinstance(data, dict): + return CLIPProxy(data["clip_id"], registry=None, manage_lifecycle=False) + return data + + def deserialize_clip_ref(data: Dict[str, Any]) -> Any: + """Context-aware CLIPRef deserializer for both host and child.""" + is_child = os.environ.get("PYISOLATE_CHILD") == "1" + if is_child: + return CLIPProxy(data["clip_id"], registry=None, manage_lifecycle=False) + else: + return CLIPRegistry()._get_instance(data["clip_id"]) + + # Register CLIP type for serialization + registry.register("CLIP", serialize_clip, deserialize_clip) + # Register CLIPProxy type (already a proxy, just return ref) + registry.register("CLIPProxy", serialize_clip, deserialize_clip) + # Register CLIPRef for deserialization (context-aware: host or child) + registry.register("CLIPRef", None, deserialize_clip_ref) + + def serialize_vae(obj: Any) -> Dict[str, Any]: + if hasattr(obj, "_instance_id"): + return {"__type__": "VAERef", "vae_id": obj._instance_id} + vae_id = VAERegistry().register(obj) + return {"__type__": "VAERef", "vae_id": vae_id} + + def deserialize_vae(data: Any) -> Any: + if isinstance(data, dict): + return VAEProxy(data["vae_id"]) + return data + + def deserialize_vae_ref(data: Dict[str, Any]) -> Any: + """Context-aware VAERef deserializer for both host and child.""" + is_child = os.environ.get("PYISOLATE_CHILD") == "1" + if is_child: + # Child: create a proxy + return VAEProxy(data["vae_id"]) + else: + # Host: lookup real VAE from registry + return VAERegistry()._get_instance(data["vae_id"]) + + # Register VAE type for serialization + registry.register("VAE", serialize_vae, deserialize_vae) + # Register VAEProxy type (already a proxy, just return ref) + registry.register("VAEProxy", serialize_vae, deserialize_vae) + # Register VAERef for deserialization (context-aware: host or child) + registry.register("VAERef", None, deserialize_vae_ref) + + # ModelSampling serialization - handles ModelSampling* types + # copyreg removed - no pickle fallback allowed + + def serialize_model_sampling(obj: Any) -> Dict[str, Any]: + # Proxy with _instance_id — return ref (works from both host and child) + if hasattr(obj, "_instance_id"): + return {"__type__": "ModelSamplingRef", "ms_id": obj._instance_id} + # Child-side: object created locally in child (e.g. ModelSamplingAdvanced + # in nodes_z_image_turbo.py). Serialize as inline data so the host can + # reconstruct the real torch.nn.Module. + if os.environ.get("PYISOLATE_CHILD") == "1": + import base64 + import io as _io + + # Identify base classes from comfy.model_sampling + bases = [] + for base in type(obj).__mro__: + if base.__module__ == "comfy.model_sampling" and base.__name__ != "object": + bases.append(base.__name__) + # Serialize state_dict as base64 safetensors-like + sd = obj.state_dict() + sd_serialized = {} + for k, v in sd.items(): + buf = _io.BytesIO() + torch.save(v, buf) + sd_serialized[k] = base64.b64encode(buf.getvalue()).decode("ascii") + # Capture plain attrs (shift, multiplier, sigma_data, etc.) + plain_attrs = {} + for k, v in obj.__dict__.items(): + if k.startswith("_"): + continue + if isinstance(v, (bool, int, float, str)): + plain_attrs[k] = v + return { + "__type__": "ModelSamplingInline", + "bases": bases, + "state_dict": sd_serialized, + "attrs": plain_attrs, + } + # Host-side: register with ModelSamplingRegistry and return JSON-safe dict + ms_id = ModelSamplingRegistry().register(obj) + return {"__type__": "ModelSamplingRef", "ms_id": ms_id} + + def deserialize_model_sampling(data: Any) -> Any: + """Deserialize ModelSampling refs or inline data.""" + if isinstance(data, dict): + if data.get("__type__") == "ModelSamplingInline": + return _reconstruct_model_sampling_inline(data) + return ModelSamplingProxy(data["ms_id"]) + return data + + def _reconstruct_model_sampling_inline(data: Dict[str, Any]) -> Any: + """Reconstruct a ModelSampling object on the host from inline child data.""" + import comfy.model_sampling as _ms + import base64 + import io as _io + + # Resolve base classes + base_classes = [] + for name in data["bases"]: + cls = getattr(_ms, name, None) + if cls is not None: + base_classes.append(cls) + if not base_classes: + raise RuntimeError( + f"Cannot reconstruct ModelSampling: no known bases in {data['bases']}" + ) + # Create dynamic class matching the child's class hierarchy + ReconstructedSampling = type("ReconstructedSampling", tuple(base_classes), {}) + obj = ReconstructedSampling.__new__(ReconstructedSampling) + torch.nn.Module.__init__(obj) + # Restore plain attributes first + for k, v in data.get("attrs", {}).items(): + setattr(obj, k, v) + # Restore state_dict (buffers like sigmas) + for k, v_b64 in data.get("state_dict", {}).items(): + buf = _io.BytesIO(base64.b64decode(v_b64)) + tensor = torch.load(buf, weights_only=True) + # Register as buffer so it's part of state_dict + parts = k.split(".") + if len(parts) == 1: + cast(Any, obj).register_buffer(parts[0], tensor) # pylint: disable=no-member + else: + setattr(obj, parts[0], tensor) + # Register on host so future references use proxy pattern. + # Skip in child process — register() is async RPC and cannot be + # called synchronously during deserialization. + if os.environ.get("PYISOLATE_CHILD") != "1": + ModelSamplingRegistry().register(obj) + return obj + + def deserialize_model_sampling_ref(data: Dict[str, Any]) -> Any: + """Context-aware ModelSamplingRef deserializer for both host and child.""" + is_child = os.environ.get("PYISOLATE_CHILD") == "1" + if is_child: + return ModelSamplingProxy(data["ms_id"]) + else: + return ModelSamplingRegistry()._get_instance(data["ms_id"]) + + # Register all ModelSampling* and StableCascadeSampling classes dynamically + import comfy.model_sampling + + for ms_cls in vars(comfy.model_sampling).values(): + if not isinstance(ms_cls, type): + continue + if not issubclass(ms_cls, torch.nn.Module): + continue + if not (ms_cls.__name__.startswith("ModelSampling") or ms_cls.__name__ == "StableCascadeSampling"): + continue + registry.register( + ms_cls.__name__, + serialize_model_sampling, + deserialize_model_sampling, + ) + registry.register( + "ModelSamplingProxy", serialize_model_sampling, deserialize_model_sampling + ) + # Register ModelSamplingRef for deserialization (context-aware: host or child) + registry.register("ModelSamplingRef", None, deserialize_model_sampling_ref) + # Register ModelSamplingInline for deserialization (child→host inline transfer) + registry.register( + "ModelSamplingInline", None, lambda data: _reconstruct_model_sampling_inline(data) + ) + + def serialize_cond(obj: Any) -> Dict[str, Any]: + type_key = f"{type(obj).__module__}.{type(obj).__name__}" + return { + "__type__": type_key, + "cond": obj.cond, + } + + def deserialize_cond(data: Dict[str, Any]) -> Any: + import importlib + + type_key = data["__type__"] + module_name, class_name = type_key.rsplit(".", 1) + module = importlib.import_module(module_name) + cls = getattr(module, class_name) + return cls(data["cond"]) + + def _serialize_public_state(obj: Any) -> Dict[str, Any]: + state: Dict[str, Any] = {} + for key, value in obj.__dict__.items(): + if key.startswith("_"): + continue + if callable(value): + continue + state[key] = value + return state + + def serialize_latent_format(obj: Any) -> Dict[str, Any]: + type_key = f"{type(obj).__module__}.{type(obj).__name__}" + return { + "__type__": type_key, + "state": _serialize_public_state(obj), + } + + def deserialize_latent_format(data: Dict[str, Any]) -> Any: + import importlib + + type_key = data["__type__"] + module_name, class_name = type_key.rsplit(".", 1) + module = importlib.import_module(module_name) + cls = getattr(module, class_name) + obj = cls() + for key, value in data.get("state", {}).items(): + prop = getattr(type(obj), key, None) + if isinstance(prop, property) and prop.fset is None: + continue + setattr(obj, key, value) + return obj + + import comfy.conds + + for cond_cls in vars(comfy.conds).values(): + if not isinstance(cond_cls, type): + continue + if not issubclass(cond_cls, comfy.conds.CONDRegular): + continue + type_key = f"{cond_cls.__module__}.{cond_cls.__name__}" + registry.register(type_key, serialize_cond, deserialize_cond) + registry.register(cond_cls.__name__, serialize_cond, deserialize_cond) + + import comfy.latent_formats + + for latent_cls in vars(comfy.latent_formats).values(): + if not isinstance(latent_cls, type): + continue + if not issubclass(latent_cls, comfy.latent_formats.LatentFormat): + continue + type_key = f"{latent_cls.__module__}.{latent_cls.__name__}" + registry.register( + type_key, serialize_latent_format, deserialize_latent_format + ) + registry.register( + latent_cls.__name__, serialize_latent_format, deserialize_latent_format + ) + + # V3 API: unwrap NodeOutput.args + def deserialize_node_output(data: Any) -> Any: + return getattr(data, "args", data) + + registry.register("NodeOutput", None, deserialize_node_output) + + # KSAMPLER serializer: stores sampler name instead of function object + # sampler_function is a callable which gets filtered out by JSONSocketTransport + def serialize_ksampler(obj: Any) -> Dict[str, Any]: + func_name = obj.sampler_function.__name__ + # Map function name back to sampler name + if func_name == "sample_unipc": + sampler_name = "uni_pc" + elif func_name == "sample_unipc_bh2": + sampler_name = "uni_pc_bh2" + elif func_name == "dpm_fast_function": + sampler_name = "dpm_fast" + elif func_name == "dpm_adaptive_function": + sampler_name = "dpm_adaptive" + elif func_name.startswith("sample_"): + sampler_name = func_name[7:] # Remove "sample_" prefix + else: + sampler_name = func_name + return { + "__type__": "KSAMPLER", + "sampler_name": sampler_name, + "extra_options": obj.extra_options, + "inpaint_options": obj.inpaint_options, + } + + def deserialize_ksampler(data: Dict[str, Any]) -> Any: + import comfy.samplers + + return comfy.samplers.ksampler( + data["sampler_name"], + data.get("extra_options", {}), + data.get("inpaint_options", {}), + ) + + registry.register("KSAMPLER", serialize_ksampler, deserialize_ksampler) + + from comfy.isolation.model_patcher_proxy_utils import register_hooks_serializers + + register_hooks_serializers(registry) + + # Generic Numpy Serializer + def serialize_numpy(obj: Any) -> Any: + import torch + + try: + # Attempt zero-copy conversion to Tensor + return torch.from_numpy(obj) + except Exception: + # Fallback for non-numeric arrays (strings, objects, mixes) + return obj.tolist() + + def deserialize_numpy_b64(data: Any) -> Any: + """Deserialize base64-encoded ndarray from sealed worker.""" + import base64 + import numpy as np + if isinstance(data, dict) and "data" in data and "dtype" in data: + raw = base64.b64decode(data["data"]) + arr = np.frombuffer(raw, dtype=np.dtype(data["dtype"])).reshape(data["shape"]) + return torch.from_numpy(arr.copy()) + return data + + registry.register("ndarray", serialize_numpy, deserialize_numpy_b64) + + # -- File3D (comfy_api.latest._util.geometry_types) --------------------- + # Origin: comfy_api by ComfyOrg (Alexander Piskun), PR #12129 + + def serialize_file3d(obj: Any) -> Dict[str, Any]: + import base64 + return { + "__type__": "File3D", + "format": obj.format, + "data": base64.b64encode(obj.get_bytes()).decode("ascii"), + } + + def deserialize_file3d(data: Any) -> Any: + import base64 + from io import BytesIO + from comfy_api.latest._util.geometry_types import File3D + return File3D(BytesIO(base64.b64decode(data["data"])), file_format=data["format"]) + + registry.register("File3D", serialize_file3d, deserialize_file3d, data_type=True) + + # -- VIDEO (comfy_api.latest._input_impl.video_types) ------------------- + # Origin: ComfyAPI Core v0.0.2 by ComfyOrg (guill), PR #8962 + + def serialize_video(obj: Any) -> Dict[str, Any]: + components = obj.get_components() + images = components.images.detach() if components.images.requires_grad else components.images + result: Dict[str, Any] = { + "__type__": "VIDEO", + "images": images, + "frame_rate_num": components.frame_rate.numerator, + "frame_rate_den": components.frame_rate.denominator, + } + if components.audio is not None: + waveform = components.audio["waveform"] + if waveform.requires_grad: + waveform = waveform.detach() + result["audio_waveform"] = waveform + result["audio_sample_rate"] = components.audio["sample_rate"] + if components.metadata is not None: + result["metadata"] = components.metadata + return result + + def deserialize_video(data: Any) -> Any: + from fractions import Fraction + from comfy_api.latest._input_impl.video_types import VideoFromComponents + from comfy_api.latest._util.video_types import VideoComponents + audio = None + if "audio_waveform" in data: + audio = {"waveform": data["audio_waveform"], "sample_rate": data["audio_sample_rate"]} + components = VideoComponents( + images=data["images"], + frame_rate=Fraction(data["frame_rate_num"], data["frame_rate_den"]), + audio=audio, + metadata=data.get("metadata"), + ) + return VideoFromComponents(components) + + registry.register("VIDEO", serialize_video, deserialize_video, data_type=True) + registry.register("VideoFromFile", serialize_video, deserialize_video, data_type=True) + registry.register("VideoFromComponents", serialize_video, deserialize_video, data_type=True) + + def setup_web_directory(self, module: Any) -> None: + """Detect WEB_DIRECTORY on a module and populate/register it. + + Called by the sealed worker after loading the node module. + Mirrors extension_wrapper.py:216-227 for host-coupled nodes. + Does NOT import extension_wrapper.py (it has `import torch` at module level). + """ + import shutil + + web_dir_attr = getattr(module, "WEB_DIRECTORY", None) + if web_dir_attr is None: + return + + module_dir = os.path.dirname(os.path.abspath(module.__file__)) + web_dir_path = os.path.abspath(os.path.join(module_dir, web_dir_attr)) + + # Read extension name from pyproject.toml + ext_name = os.path.basename(module_dir) + pyproject = os.path.join(module_dir, "pyproject.toml") + if os.path.exists(pyproject): + try: + import tomllib + except ImportError: + import tomli as tomllib # type: ignore[no-redef] + try: + with open(pyproject, "rb") as f: + data = tomllib.load(f) + name = data.get("project", {}).get("name") + if name: + ext_name = name + except Exception: + pass + + # Populate web dir if empty (mirrors _run_prestartup_web_copy) + if not (os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path))): + os.makedirs(web_dir_path, exist_ok=True) + + # Module-defined copy spec + copy_spec = getattr(module, "_PRESTARTUP_WEB_COPY", None) + if copy_spec is not None and callable(copy_spec): + try: + copy_spec(web_dir_path) + except Exception as e: + logger.warning("][ _PRESTARTUP_WEB_COPY failed: %s", e) + + # Fallback: comfy_3d_viewers + try: + from comfy_3d_viewers import copy_viewer, VIEWER_FILES + for viewer in VIEWER_FILES: + try: + copy_viewer(viewer, web_dir_path) + except Exception: + pass + except ImportError: + pass + + # Fallback: comfy_dynamic_widgets + try: + from comfy_dynamic_widgets import get_js_path + src = os.path.realpath(get_js_path()) + if os.path.exists(src): + dst_dir = os.path.join(web_dir_path, "js") + os.makedirs(dst_dir, exist_ok=True) + shutil.copy2(src, os.path.join(dst_dir, "dynamic_widgets.js")) + except ImportError: + pass + + if os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path)): + WebDirectoryProxy.register_web_dir(ext_name, web_dir_path) + logger.info( + "][ Adapter: registered web dir for %s (%d files)", + ext_name, + sum(1 for _ in Path(web_dir_path).rglob("*") if _.is_file()), + ) + + @staticmethod + def register_host_event_handlers(extension: Any) -> None: + """Register host-side event handlers for an isolated extension. + + Wires ``"progress"`` events from the child to ``comfy.utils.PROGRESS_BAR_HOOK`` + so the ComfyUI frontend receives progress bar updates. + """ + register_event_handler = inspect.getattr_static( + extension, "register_event_handler", None + ) + if not callable(register_event_handler): + return + + def _host_progress_handler(payload: dict) -> None: + import comfy.utils + + hook = comfy.utils.PROGRESS_BAR_HOOK + if hook is not None: + hook( + payload.get("value", 0), + payload.get("total", 0), + payload.get("preview"), + payload.get("node_id"), + ) + + extension.register_event_handler("progress", _host_progress_handler) + + def setup_child_event_hooks(self, extension: Any) -> None: + """Wire PROGRESS_BAR_HOOK in the child to emit_event on the extension. + + Host-coupled only — sealed workers do not have comfy.utils (torch). + """ + is_child = os.environ.get("PYISOLATE_CHILD") == "1" + logger.info("][ ISO:setup_child_event_hooks called, PYISOLATE_CHILD=%s", is_child) + if not is_child: + return + + if not _IMPORT_TORCH: + logger.info("][ ISO:setup_child_event_hooks skipped — sealed worker (no torch)") + return + + import comfy.utils + + def _event_progress_hook(value, total, preview=None, node_id=None): + logger.debug("][ ISO:event_progress value=%s/%s node_id=%s", value, total, node_id) + extension.emit_event("progress", { + "value": value, + "total": total, + "node_id": node_id, + }) + + comfy.utils.PROGRESS_BAR_HOOK = _event_progress_hook + logger.info("][ ISO:PROGRESS_BAR_HOOK wired to event channel") + + def provide_rpc_services(self) -> List[type[ProxiedSingleton]]: + # Always available — no torch/PIL dependency + services: List[type[ProxiedSingleton]] = [ + FolderPathsProxy, + HelperProxiesService, + WebDirectoryProxy, + ] + # Torch/PIL-dependent proxies + if _HAS_TORCH_PROXIES: + services.extend([ + PromptServerService, + ModelManagementProxy, + UtilsProxy, + ProgressProxy, + VAERegistry, + CLIPRegistry, + ModelPatcherRegistry, + ModelSamplingRegistry, + FirstStageModelRegistry, + ]) + return services + + def handle_api_registration(self, api: ProxiedSingleton, rpc: AsyncRPC) -> None: + # Resolve the real name whether it's an instance or the Singleton class itself + api_name = api.__name__ if isinstance(api, type) else api.__class__.__name__ + + if api_name == "FolderPathsProxy": + import folder_paths + + # Replace module-level functions with proxy methods + # This is aggressive but necessary for transparent proxying + # Handle both instance and class cases + instance = api() if isinstance(api, type) else api + for name in dir(instance): + if not name.startswith("_"): + setattr(folder_paths, name, getattr(instance, name)) + + # Fence: isolated children get writable temp inside sandbox + if os.environ.get("PYISOLATE_CHILD") == "1": + import tempfile + _child_temp = os.path.join(tempfile.gettempdir(), "comfyui_temp") + os.makedirs(_child_temp, exist_ok=True) + folder_paths.temp_directory = _child_temp + + return + + if api_name == "ModelManagementProxy": + if _IMPORT_TORCH: + import comfy.model_management + + instance = api() if isinstance(api, type) else api + # Replace module-level functions with proxy methods + for name in dir(instance): + if not name.startswith("_"): + setattr(comfy.model_management, name, getattr(instance, name)) + return + + if api_name == "UtilsProxy": + if not _IMPORT_TORCH: + logger.info("][ ISO:UtilsProxy handle_api_registration skipped — sealed worker (no torch)") + return + + import comfy.utils + + # Static Injection of RPC mechanism to ensure Child can access it + # independent of instance lifecycle. + api.set_rpc(rpc) + + is_child = os.environ.get("PYISOLATE_CHILD") == "1" + logger.info("][ ISO:UtilsProxy handle_api_registration PYISOLATE_CHILD=%s", is_child) + + # Progress hook wiring moved to setup_child_event_hooks via event channel + + return + + if api_name == "PromptServerProxy": + if not _IMPORT_TORCH: + return + # Defer heavy import to child context + import server + + instance = api() if isinstance(api, type) else api + proxy = ( + instance.instance + ) # PromptServerProxy instance has .instance property returning self + + original_register_route = proxy.register_route + + def register_route_wrapper( + method: str, path: str, handler: Callable[..., Any] + ) -> None: + callback_id = rpc.register_callback(handler) + loop = getattr(rpc, "loop", None) + if loop and loop.is_running(): + import asyncio + + asyncio.create_task( + original_register_route( + method, path, handler=callback_id, is_callback=True + ) + ) + else: + original_register_route( + method, path, handler=callback_id, is_callback=True + ) + return None + + proxy.register_route = register_route_wrapper + + class RouteTableDefProxy: + def __init__(self, proxy_instance: Any): + self.proxy = proxy_instance + + def get( + self, path: str, **kwargs: Any + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: + self.proxy.register_route("GET", path, handler) + return handler + + return decorator + + def post( + self, path: str, **kwargs: Any + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: + self.proxy.register_route("POST", path, handler) + return handler + + return decorator + + def patch( + self, path: str, **kwargs: Any + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: + self.proxy.register_route("PATCH", path, handler) + return handler + + return decorator + + def put( + self, path: str, **kwargs: Any + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: + self.proxy.register_route("PUT", path, handler) + return handler + + return decorator + + def delete( + self, path: str, **kwargs: Any + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: + self.proxy.register_route("DELETE", path, handler) + return handler + + return decorator + + proxy.routes = RouteTableDefProxy(proxy) + + if ( + hasattr(server, "PromptServer") + and getattr(server.PromptServer, "instance", None) != proxy + ): + server.PromptServer.instance = proxy diff --git a/comfy/isolation/child_hooks.py b/comfy/isolation/child_hooks.py new file mode 100644 index 000000000..a009929eb --- /dev/null +++ b/comfy/isolation/child_hooks.py @@ -0,0 +1,126 @@ +# pylint: disable=import-outside-toplevel,logging-fstring-interpolation +# Child process initialization for PyIsolate +import logging +import os + +logger = logging.getLogger(__name__) + + +def is_child_process() -> bool: + return os.environ.get("PYISOLATE_CHILD") == "1" + + +def _load_extra_model_paths() -> None: + """Load extra_model_paths.yaml so the child's folder_paths has the same search paths as the host. + + The host loads this in main.py:143-145. The child is spawned by + pyisolate's uds_client.py and never runs main.py, so folder_paths + only has the base model directories. Any isolated node calling + folder_paths.get_filename_list() in define_schema() would get empty + results for folders whose files live in extra_model_paths locations. + """ + import folder_paths # noqa: F401 — side-effect import; load_extra_path_config writes to folder_paths internals + from utils.extra_config import load_extra_path_config + + extra_config_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), + "extra_model_paths.yaml", + ) + if os.path.isfile(extra_config_path): + load_extra_path_config(extra_config_path) + + +def initialize_child_process() -> None: + logger.warning("][ DIAG:child_hooks initialize_child_process START") + if os.environ.get("PYISOLATE_IMPORT_TORCH", "1") != "0": + _load_extra_model_paths() + _setup_child_loop_bridge() + + # Manual RPC injection + try: + from pyisolate._internal.rpc_protocol import get_child_rpc_instance + + rpc = get_child_rpc_instance() + logger.warning("][ DIAG:child_hooks RPC instance: %s", rpc is not None) + if rpc: + _setup_proxy_callers(rpc) + logger.warning("][ DIAG:child_hooks proxy callers configured with RPC") + else: + logger.warning("][ DIAG:child_hooks NO RPC — proxy callers cleared") + _setup_proxy_callers() + except Exception as e: + logger.error(f"][ DIAG:child_hooks Manual RPC Injection failed: {e}") + _setup_proxy_callers() + + _setup_logging() + + +def _setup_child_loop_bridge() -> None: + import asyncio + + main_loop = None + try: + main_loop = asyncio.get_running_loop() + except RuntimeError: + try: + main_loop = asyncio.get_event_loop() + except RuntimeError: + pass + + if main_loop is None: + return + + try: + from .proxies.base import set_global_loop + + set_global_loop(main_loop) + except ImportError: + pass + + +def _setup_prompt_server_stub(rpc=None) -> None: + try: + from .proxies.prompt_server_impl import PromptServerStub + + if rpc: + PromptServerStub.set_rpc(rpc) + elif hasattr(PromptServerStub, "clear_rpc"): + PromptServerStub.clear_rpc() + else: + PromptServerStub._rpc = None # type: ignore[attr-defined] + + except Exception as e: + logger.error(f"Failed to setup PromptServerStub: {e}") + + +def _setup_proxy_callers(rpc=None) -> None: + try: + from .proxies.folder_paths_proxy import FolderPathsProxy + from .proxies.helper_proxies import HelperProxiesService + from .proxies.model_management_proxy import ModelManagementProxy + from .proxies.progress_proxy import ProgressProxy + from .proxies.prompt_server_impl import PromptServerStub + from .proxies.utils_proxy import UtilsProxy + + if rpc is None: + FolderPathsProxy.clear_rpc() + HelperProxiesService.clear_rpc() + ModelManagementProxy.clear_rpc() + ProgressProxy.clear_rpc() + PromptServerStub.clear_rpc() + UtilsProxy.clear_rpc() + return + + FolderPathsProxy.set_rpc(rpc) + HelperProxiesService.set_rpc(rpc) + ModelManagementProxy.set_rpc(rpc) + ProgressProxy.set_rpc(rpc) + PromptServerStub.set_rpc(rpc) + UtilsProxy.set_rpc(rpc) + + except Exception as e: + logger.error(f"Failed to setup child singleton proxy callers: {e}") + + +def _setup_logging() -> None: + logging.getLogger().setLevel(logging.INFO) diff --git a/comfy/isolation/clip_proxy.py b/comfy/isolation/clip_proxy.py new file mode 100644 index 000000000..371665314 --- /dev/null +++ b/comfy/isolation/clip_proxy.py @@ -0,0 +1,327 @@ +# pylint: disable=attribute-defined-outside-init,import-outside-toplevel,logging-fstring-interpolation +# CLIP Proxy implementation +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any, Optional + +from comfy.isolation.proxies.base import ( + IS_CHILD_PROCESS, + BaseProxy, + BaseRegistry, + detach_if_grad, +) + +if TYPE_CHECKING: + from comfy.isolation.model_patcher_proxy import ModelPatcherProxy + + +class CondStageModelRegistry(BaseRegistry[Any]): + _type_prefix = "cond_stage_model" + + async def get_property(self, instance_id: str, name: str) -> Any: + obj = self._get_instance(instance_id) + return getattr(obj, name) + + +class CondStageModelProxy(BaseProxy[CondStageModelRegistry]): + _registry_class = CondStageModelRegistry + __module__ = "comfy.sd" + + def __getattr__(self, name: str) -> Any: + try: + return self._call_rpc("get_property", name) + except Exception as e: + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{name}'" + ) from e + + def __repr__(self) -> str: + return f"" + + +class TokenizerRegistry(BaseRegistry[Any]): + _type_prefix = "tokenizer" + + async def get_property(self, instance_id: str, name: str) -> Any: + obj = self._get_instance(instance_id) + return getattr(obj, name) + + +class TokenizerProxy(BaseProxy[TokenizerRegistry]): + _registry_class = TokenizerRegistry + __module__ = "comfy.sd" + + def __getattr__(self, name: str) -> Any: + try: + return self._call_rpc("get_property", name) + except Exception as e: + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{name}'" + ) from e + + def __repr__(self) -> str: + return f"" + + +logger = logging.getLogger(__name__) + + +class CLIPRegistry(BaseRegistry[Any]): + _type_prefix = "clip" + _allowed_setters = { + "layer_idx", + "tokenizer_options", + "use_clip_schedule", + "apply_hooks_to_conds", + } + + async def get_ram_usage(self, instance_id: str) -> int: + return self._get_instance(instance_id).get_ram_usage() + + async def get_patcher_id(self, instance_id: str) -> str: + from comfy.isolation.model_patcher_proxy import ModelPatcherRegistry + + return ModelPatcherRegistry().register(self._get_instance(instance_id).patcher) + + async def get_cond_stage_model_id(self, instance_id: str) -> str: + return CondStageModelRegistry().register( + self._get_instance(instance_id).cond_stage_model + ) + + async def get_tokenizer_id(self, instance_id: str) -> str: + return TokenizerRegistry().register(self._get_instance(instance_id).tokenizer) + + async def load_model(self, instance_id: str) -> None: + self._get_instance(instance_id).load_model() + + async def clip_layer(self, instance_id: str, layer_idx: int) -> None: + self._get_instance(instance_id).clip_layer(layer_idx) + + async def set_tokenizer_option( + self, instance_id: str, option_name: str, value: Any + ) -> None: + self._get_instance(instance_id).set_tokenizer_option(option_name, value) + + async def get_property(self, instance_id: str, name: str) -> Any: + return getattr(self._get_instance(instance_id), name) + + async def set_property(self, instance_id: str, name: str, value: Any) -> None: + if name not in self._allowed_setters: + raise PermissionError(f"Setting '{name}' is not allowed via RPC") + setattr(self._get_instance(instance_id), name, value) + + async def tokenize( + self, instance_id: str, text: str, return_word_ids: bool = False, **kwargs: Any + ) -> Any: + return self._get_instance(instance_id).tokenize( + text, return_word_ids=return_word_ids, **kwargs + ) + + async def encode(self, instance_id: str, text: str) -> Any: + return detach_if_grad(self._get_instance(instance_id).encode(text)) + + async def encode_from_tokens( + self, + instance_id: str, + tokens: Any, + return_pooled: bool = False, + return_dict: bool = False, + ) -> Any: + return detach_if_grad( + self._get_instance(instance_id).encode_from_tokens( + tokens, return_pooled=return_pooled, return_dict=return_dict + ) + ) + + async def encode_from_tokens_scheduled( + self, + instance_id: str, + tokens: Any, + unprojected: bool = False, + add_dict: Optional[dict] = None, + show_pbar: bool = True, + ) -> Any: + add_dict = add_dict or {} + return detach_if_grad( + self._get_instance(instance_id).encode_from_tokens_scheduled( + tokens, unprojected=unprojected, add_dict=add_dict, show_pbar=show_pbar + ) + ) + + async def add_patches( + self, + instance_id: str, + patches: Any, + strength_patch: float = 1.0, + strength_model: float = 1.0, + ) -> Any: + return self._get_instance(instance_id).add_patches( + patches, strength_patch=strength_patch, strength_model=strength_model + ) + + async def get_key_patches(self, instance_id: str) -> Any: + return self._get_instance(instance_id).get_key_patches() + + async def load_sd( + self, instance_id: str, sd: dict, full_model: bool = False + ) -> Any: + return self._get_instance(instance_id).load_sd(sd, full_model=full_model) + + async def get_sd(self, instance_id: str) -> Any: + return self._get_instance(instance_id).get_sd() + + async def clone(self, instance_id: str) -> str: + return self.register(self._get_instance(instance_id).clone()) + + +class CLIPProxy(BaseProxy[CLIPRegistry]): + _registry_class = CLIPRegistry + __module__ = "comfy.sd" + + def get_ram_usage(self) -> int: + return self._call_rpc("get_ram_usage") + + @property + def patcher(self) -> "ModelPatcherProxy": + from comfy.isolation.model_patcher_proxy import ModelPatcherProxy + + if not hasattr(self, "_patcher_proxy"): + patcher_id = self._call_rpc("get_patcher_id") + self._patcher_proxy = ModelPatcherProxy(patcher_id, manage_lifecycle=False) + return self._patcher_proxy + + @patcher.setter + def patcher(self, value: Any) -> None: + from comfy.isolation.model_patcher_proxy import ModelPatcherProxy + + if isinstance(value, ModelPatcherProxy): + self._patcher_proxy = value + else: + logger.warning( + f"Attempted to set CLIPProxy.patcher to non-proxy object: {value}" + ) + + @property + def cond_stage_model(self) -> CondStageModelProxy: + if not hasattr(self, "_cond_stage_model_proxy"): + csm_id = self._call_rpc("get_cond_stage_model_id") + self._cond_stage_model_proxy = CondStageModelProxy( + csm_id, manage_lifecycle=False + ) + return self._cond_stage_model_proxy + + @property + def tokenizer(self) -> TokenizerProxy: + if not hasattr(self, "_tokenizer_proxy"): + tok_id = self._call_rpc("get_tokenizer_id") + self._tokenizer_proxy = TokenizerProxy(tok_id, manage_lifecycle=False) + return self._tokenizer_proxy + + def load_model(self) -> ModelPatcherProxy: + self._call_rpc("load_model") + return self.patcher + + @property + def layer_idx(self) -> Optional[int]: + return self._call_rpc("get_property", "layer_idx") + + @layer_idx.setter + def layer_idx(self, value: Optional[int]) -> None: + self._call_rpc("set_property", "layer_idx", value) + + @property + def tokenizer_options(self) -> dict: + return self._call_rpc("get_property", "tokenizer_options") + + @tokenizer_options.setter + def tokenizer_options(self, value: dict) -> None: + self._call_rpc("set_property", "tokenizer_options", value) + + @property + def use_clip_schedule(self) -> bool: + return self._call_rpc("get_property", "use_clip_schedule") + + @use_clip_schedule.setter + def use_clip_schedule(self, value: bool) -> None: + self._call_rpc("set_property", "use_clip_schedule", value) + + @property + def apply_hooks_to_conds(self) -> Any: + return self._call_rpc("get_property", "apply_hooks_to_conds") + + @apply_hooks_to_conds.setter + def apply_hooks_to_conds(self, value: Any) -> None: + self._call_rpc("set_property", "apply_hooks_to_conds", value) + + def clip_layer(self, layer_idx: int) -> None: + return self._call_rpc("clip_layer", layer_idx) + + def set_tokenizer_option(self, option_name: str, value: Any) -> None: + return self._call_rpc("set_tokenizer_option", option_name, value) + + def tokenize(self, text: str, return_word_ids: bool = False, **kwargs: Any) -> Any: + return self._call_rpc( + "tokenize", text, return_word_ids=return_word_ids, **kwargs + ) + + def encode(self, text: str) -> Any: + return self._call_rpc("encode", text) + + def encode_from_tokens( + self, tokens: Any, return_pooled: bool = False, return_dict: bool = False + ) -> Any: + res = self._call_rpc( + "encode_from_tokens", + tokens, + return_pooled=return_pooled, + return_dict=return_dict, + ) + if return_pooled and isinstance(res, list) and not return_dict: + return tuple(res) + return res + + def encode_from_tokens_scheduled( + self, + tokens: Any, + unprojected: bool = False, + add_dict: Optional[dict] = None, + show_pbar: bool = True, + ) -> Any: + add_dict = add_dict or {} + return self._call_rpc( + "encode_from_tokens_scheduled", + tokens, + unprojected=unprojected, + add_dict=add_dict, + show_pbar=show_pbar, + ) + + def add_patches( + self, patches: Any, strength_patch: float = 1.0, strength_model: float = 1.0 + ) -> Any: + return self._call_rpc( + "add_patches", + patches, + strength_patch=strength_patch, + strength_model=strength_model, + ) + + def get_key_patches(self) -> Any: + return self._call_rpc("get_key_patches") + + def load_sd(self, sd: dict, full_model: bool = False) -> Any: + return self._call_rpc("load_sd", sd, full_model=full_model) + + def get_sd(self) -> Any: + return self._call_rpc("get_sd") + + def clone(self) -> CLIPProxy: + new_id = self._call_rpc("clone") + return CLIPProxy(new_id, self._registry, manage_lifecycle=not IS_CHILD_PROCESS) + + +if not IS_CHILD_PROCESS: + _CLIP_REGISTRY_SINGLETON = CLIPRegistry() + _COND_STAGE_MODEL_REGISTRY_SINGLETON = CondStageModelRegistry() + _TOKENIZER_REGISTRY_SINGLETON = TokenizerRegistry() diff --git a/comfy/isolation/custom_node_serializers.py b/comfy/isolation/custom_node_serializers.py new file mode 100644 index 000000000..e7a6e78c2 --- /dev/null +++ b/comfy/isolation/custom_node_serializers.py @@ -0,0 +1,16 @@ +"""Compatibility shim for the indexed serializer path.""" + +from __future__ import annotations + +from typing import Any + + +def register_custom_node_serializers(_registry: Any) -> None: + """Legacy no-op shim. + + Serializer registration now lives directly in the active isolation adapter. + This module remains importable because the isolation index still references it. + """ + return None + +__all__ = ["register_custom_node_serializers"] diff --git a/comfy/isolation/extension_loader.py b/comfy/isolation/extension_loader.py new file mode 100644 index 000000000..0c65b234e --- /dev/null +++ b/comfy/isolation/extension_loader.py @@ -0,0 +1,521 @@ +# pylint: disable=cyclic-import,import-outside-toplevel,redefined-outer-name +from __future__ import annotations + +import logging +import os +import inspect +import sys +import types +import platform +from pathlib import Path +from typing import Any, Callable, Dict, List, Tuple + +import pyisolate +from pyisolate import ExtensionManager, ExtensionManagerConfig +from packaging.requirements import InvalidRequirement, Requirement +from packaging.utils import canonicalize_name + +from .manifest_loader import is_cache_valid, load_from_cache, save_to_cache +from .host_policy import load_host_policy + +try: + import tomllib +except ImportError: + import tomli as tomllib # type: ignore[no-redef] + +logger = logging.getLogger(__name__) + + +def _register_web_directory(extension_name: str, node_dir: Path) -> None: + """Register an isolated extension's web directory on the host side.""" + import nodes + + # Method 1: pyproject.toml [tool.comfy] web field + pyproject = node_dir / "pyproject.toml" + if pyproject.exists(): + try: + with pyproject.open("rb") as f: + data = tomllib.load(f) + web_dir_name = data.get("tool", {}).get("comfy", {}).get("web") + if web_dir_name: + web_dir_path = str(node_dir / web_dir_name) + if os.path.isdir(web_dir_path): + nodes.EXTENSION_WEB_DIRS[extension_name] = web_dir_path + logger.debug( + "][ Registered web dir for isolated %s: %s", + extension_name, + web_dir_path, + ) + return + except Exception: + pass + + # Method 2: __init__.py WEB_DIRECTORY constant (parse without importing) + init_file = node_dir / "__init__.py" + if init_file.exists(): + try: + source = init_file.read_text() + for line in source.splitlines(): + stripped = line.strip() + if stripped.startswith("WEB_DIRECTORY"): + # Parse: WEB_DIRECTORY = "./web" or WEB_DIRECTORY = "web" + _, _, value = stripped.partition("=") + value = value.strip().strip("\"'") + if value: + web_dir_path = str((node_dir / value).resolve()) + if os.path.isdir(web_dir_path): + nodes.EXTENSION_WEB_DIRS[extension_name] = web_dir_path + logger.debug( + "][ Registered web dir for isolated %s: %s", + extension_name, + web_dir_path, + ) + return + except Exception: + pass + + +def _get_extension_type(execution_model: str) -> type[Any]: + if execution_model == "sealed_worker": + return pyisolate.SealedNodeExtension + + from .extension_wrapper import ComfyNodeExtension + + return ComfyNodeExtension + + +async def _stop_extension_safe(extension: Any, extension_name: str) -> None: + try: + stop_result = extension.stop() + if inspect.isawaitable(stop_result): + await stop_result + except Exception: + logger.debug("][ %s stop failed", extension_name, exc_info=True) + + +def _normalize_dependency_spec(dep: str, base_paths: list[Path]) -> str: + req, sep, marker = dep.partition(";") + req = req.strip() + marker_suffix = f";{marker}" if sep else "" + + def _resolve_local_path(local_path: str) -> Path | None: + for base in base_paths: + candidate = (base / local_path).resolve() + if candidate.exists(): + return candidate + return None + + if req.startswith("./") or req.startswith("../"): + resolved = _resolve_local_path(req) + if resolved is not None: + return f"{resolved}{marker_suffix}" + + if req.startswith("file://"): + raw = req[len("file://") :] + if raw.startswith("./") or raw.startswith("../"): + resolved = _resolve_local_path(raw) + if resolved is not None: + return f"file://{resolved}{marker_suffix}" + + return dep + + +def _dependency_name_from_spec(dep: str) -> str | None: + stripped = dep.strip() + if not stripped or stripped == "-e" or stripped.startswith("-e "): + return None + if stripped.startswith(("/", "./", "../", "file://")): + return None + + try: + return canonicalize_name(Requirement(stripped).name) + except InvalidRequirement: + return None + + +def _parse_cuda_wheels_config( + tool_config: dict[str, object], dependencies: list[str] +) -> dict[str, object] | None: + raw_config = tool_config.get("cuda_wheels") + if raw_config is None: + return None + if not isinstance(raw_config, dict): + raise ExtensionLoadError("[tool.comfy.isolation.cuda_wheels] must be a table") + + index_url = raw_config.get("index_url") + index_urls = raw_config.get("index_urls") + if index_urls is not None: + if not isinstance(index_urls, list) or not all( + isinstance(u, str) and u.strip() for u in index_urls + ): + raise ExtensionLoadError( + "[tool.comfy.isolation.cuda_wheels.index_urls] must be a list of non-empty strings" + ) + elif not isinstance(index_url, str) or not index_url.strip(): + raise ExtensionLoadError( + "[tool.comfy.isolation.cuda_wheels.index_url] must be a non-empty string" + ) + + packages = raw_config.get("packages") + if not isinstance(packages, list) or not all( + isinstance(package_name, str) and package_name.strip() + for package_name in packages + ): + raise ExtensionLoadError( + "[tool.comfy.isolation.cuda_wheels.packages] must be a list of non-empty strings" + ) + + declared_dependencies = { + dependency_name + for dep in dependencies + if (dependency_name := _dependency_name_from_spec(dep)) is not None + } + normalized_packages = [canonicalize_name(package_name) for package_name in packages] + missing = [ + package_name + for package_name in normalized_packages + if package_name not in declared_dependencies + ] + if missing: + missing_joined = ", ".join(sorted(missing)) + raise ExtensionLoadError( + "[tool.comfy.isolation.cuda_wheels.packages] references undeclared dependencies: " + f"{missing_joined}" + ) + + package_map = raw_config.get("package_map", {}) + if not isinstance(package_map, dict): + raise ExtensionLoadError( + "[tool.comfy.isolation.cuda_wheels.package_map] must be a table" + ) + + normalized_package_map: dict[str, str] = {} + for dependency_name, index_package_name in package_map.items(): + if not isinstance(dependency_name, str) or not dependency_name.strip(): + raise ExtensionLoadError( + "[tool.comfy.isolation.cuda_wheels.package_map] keys must be non-empty strings" + ) + if not isinstance(index_package_name, str) or not index_package_name.strip(): + raise ExtensionLoadError( + "[tool.comfy.isolation.cuda_wheels.package_map] values must be non-empty strings" + ) + canonical_dependency_name = canonicalize_name(dependency_name) + if canonical_dependency_name not in normalized_packages: + raise ExtensionLoadError( + "[tool.comfy.isolation.cuda_wheels.package_map] can only override packages listed in " + "[tool.comfy.isolation.cuda_wheels.packages]" + ) + normalized_package_map[canonical_dependency_name] = index_package_name.strip() + + result: dict = { + "packages": normalized_packages, + "package_map": normalized_package_map, + } + if index_urls is not None: + result["index_urls"] = [u.rstrip("/") + "/" for u in index_urls] + else: + result["index_url"] = index_url.rstrip("/") + "/" + return result + + +def get_enforcement_policy() -> Dict[str, bool]: + return { + "force_isolated": os.environ.get("PYISOLATE_ENFORCE_ISOLATED") == "1", + "force_sandbox": os.environ.get("PYISOLATE_ENFORCE_SANDBOX") == "1", + } + + +class ExtensionLoadError(RuntimeError): + pass + + +def register_dummy_module(extension_name: str, node_dir: Path) -> None: + normalized_name = extension_name.replace("-", "_").replace(".", "_") + if normalized_name not in sys.modules: + dummy_module = types.ModuleType(normalized_name) + dummy_module.__file__ = str(node_dir / "__init__.py") + dummy_module.__path__ = [str(node_dir)] + dummy_module.__package__ = normalized_name + sys.modules[normalized_name] = dummy_module + + +def _is_stale_node_cache(cached_data: Dict[str, Dict]) -> bool: + for details in cached_data.values(): + if not isinstance(details, dict): + return True + if details.get("is_v3") and "schema_v1" not in details: + return True + return False + + +async def load_isolated_node( + node_dir: Path, + manifest_path: Path, + logger: logging.Logger, + build_stub_class: Callable[[str, Dict[str, object], Any], type], + venv_root: Path, + extension_managers: List[ExtensionManager], +) -> List[Tuple[str, str, type]]: + try: + with manifest_path.open("rb") as handle: + manifest_data = tomllib.load(handle) + except Exception as e: + logger.warning(f"][ Failed to parse {manifest_path}: {e}") + return [] + + # Parse [tool.comfy.isolation] + tool_config = manifest_data.get("tool", {}).get("comfy", {}).get("isolation", {}) + can_isolate = tool_config.get("can_isolate", False) + share_torch = tool_config.get("share_torch", False) + package_manager = tool_config.get("package_manager", "uv") + is_conda = package_manager == "conda" + execution_model = tool_config.get("execution_model") + if execution_model is None: + execution_model = "sealed_worker" if is_conda else "host-coupled" + + if "sealed_host_ro_paths" in tool_config: + raise ValueError( + "Manifest field 'sealed_host_ro_paths' is not allowed. " + "Configure [tool.comfy.host].sealed_worker_ro_import_paths in host policy." + ) + + # Conda-specific manifest fields + conda_channels: list[str] = ( + tool_config.get("conda_channels", []) if is_conda else [] + ) + conda_dependencies: list[str] = ( + tool_config.get("conda_dependencies", []) if is_conda else [] + ) + conda_platforms: list[str] = ( + tool_config.get("conda_platforms", []) if is_conda else [] + ) + conda_python: str = ( + tool_config.get("conda_python", "*") if is_conda else "*" + ) + + # Parse [project] dependencies + project_config = manifest_data.get("project", {}) + dependencies = project_config.get("dependencies", []) + if not isinstance(dependencies, list): + dependencies = [] + + # Get extension name (default to folder name if not in project.name) + extension_name = project_config.get("name", node_dir.name) + + # LOGIC: Isolation Decision + policy = get_enforcement_policy() + isolated = can_isolate or policy["force_isolated"] + + if not isolated: + return [] + + import folder_paths + + base_paths = [Path(folder_paths.base_path), node_dir] + dependencies = [ + _normalize_dependency_spec(dep, base_paths) if isinstance(dep, str) else dep + for dep in dependencies + ] + cuda_wheels = _parse_cuda_wheels_config(tool_config, dependencies) + + manager_config = ExtensionManagerConfig(venv_root_path=str(venv_root)) + extension_type = _get_extension_type(execution_model) + manager: ExtensionManager = pyisolate.ExtensionManager( + extension_type, manager_config + ) + extension_managers.append(manager) + + host_policy = load_host_policy(Path(folder_paths.base_path)) + + sandbox_config = {} + is_linux = platform.system() == "Linux" + + if is_conda: + share_torch = False + share_cuda_ipc = False + else: + share_cuda_ipc = share_torch and is_linux + + if is_linux and isolated: + sandbox_config = { + "network": host_policy["allow_network"], + "writable_paths": host_policy["writable_paths"], + "readonly_paths": host_policy["readonly_paths"], + } + + extension_config: dict = { + "name": extension_name, + "module_path": str(node_dir), + "isolated": True, + "dependencies": dependencies, + "share_torch": share_torch, + "share_cuda_ipc": share_cuda_ipc, + "sandbox_mode": host_policy["sandbox_mode"], + "sandbox": sandbox_config, + } + + _is_sealed = execution_model == "sealed_worker" + _is_sandboxed = host_policy["sandbox_mode"] != "disabled" and is_linux + logger.info( + "][ Loading isolated node: %s (torch_share [%s], sealed [%s], sandboxed [%s])", + extension_name, + "x" if share_torch else " ", + "x" if _is_sealed else " ", + "x" if _is_sandboxed else " ", + ) + + if cuda_wheels is not None: + extension_config["cuda_wheels"] = cuda_wheels + + # Conda-specific keys + if is_conda: + extension_config["package_manager"] = "conda" + extension_config["conda_channels"] = conda_channels + extension_config["conda_dependencies"] = conda_dependencies + extension_config["conda_python"] = conda_python + find_links = tool_config.get("find_links", []) + if find_links: + extension_config["find_links"] = find_links + if conda_platforms: + extension_config["conda_platforms"] = conda_platforms + + if execution_model != "host-coupled": + extension_config["execution_model"] = execution_model + if execution_model == "sealed_worker": + policy_ro_paths = host_policy.get("sealed_worker_ro_import_paths", []) + if isinstance(policy_ro_paths, list) and policy_ro_paths: + extension_config["sealed_host_ro_paths"] = list(policy_ro_paths) + # Sealed workers keep the host RPC service inventory even when the + # child resolves no API classes locally. + + extension = manager.load_extension(extension_config) + register_dummy_module(extension_name, node_dir) + + # Register host-side event handlers via adapter + from .adapter import ComfyUIAdapter + ComfyUIAdapter.register_host_event_handlers(extension) + + # Register web directory on the host — only when sandbox is disabled. + # In sandbox mode, serving untrusted JS to the browser is not safe. + if host_policy["sandbox_mode"] == "disabled": + _register_web_directory(extension_name, node_dir) + + # Register for proxied web serving — the child's web dir may have + # content that doesn't exist on the host (e.g., pip-installed viewer + # bundles). The WebDirectoryCache will lazily fetch via RPC. + from .proxies.web_directory_proxy import WebDirectoryProxy, get_web_directory_cache + cache = get_web_directory_cache() + cache.register_proxy(extension_name, WebDirectoryProxy()) + + # Try cache first (lazy spawn) + logger.warning("][ DIAG:ext_loader cache_valid_check for %s", extension_name) + if is_cache_valid(node_dir, manifest_path, venv_root): + cached_data = load_from_cache(node_dir, venv_root) + if cached_data: + if _is_stale_node_cache(cached_data): + logger.warning( + "][ DIAG:ext_loader %s cache is stale/incompatible; rebuilding metadata", + extension_name, + ) + else: + logger.warning("][ DIAG:ext_loader %s USING CACHE — dumping combo options:", extension_name) + for node_name, details in cached_data.items(): + schema_v1 = details.get("schema_v1", {}) + inp = schema_v1.get("input", {}) if schema_v1 else {} + for section_name, section in inp.items(): + if isinstance(section, dict): + for field_name, field_def in section.items(): + if isinstance(field_def, (list, tuple)) and len(field_def) >= 2 and isinstance(field_def[1], dict) and "options" in field_def[1]: + opts = field_def[1]["options"] + logger.warning( + "][ DIAG:ext_loader CACHE %s.%s.%s options=%d first=%s", + node_name, section_name, field_name, + len(opts), + opts[:3] if opts else "EMPTY", + ) + specs: List[Tuple[str, str, type]] = [] + for node_name, details in cached_data.items(): + stub_cls = build_stub_class(node_name, details, extension) + specs.append( + (node_name, details.get("display_name", node_name), stub_cls) + ) + return specs + else: + logger.warning("][ DIAG:ext_loader %s cache INVALID or MISSING", extension_name) + + # Cache miss - spawn process and get metadata + logger.warning("][ DIAG:ext_loader %s cache miss, spawning process for metadata", extension_name) + + try: + remote_nodes: Dict[str, str] = await extension.list_nodes() + except Exception as exc: + logger.warning( + "][ %s metadata discovery failed, skipping isolated load: %s", + extension_name, + exc, + ) + await _stop_extension_safe(extension, extension_name) + return [] + + if not remote_nodes: + logger.debug("][ %s exposed no isolated nodes; skipping", extension_name) + await _stop_extension_safe(extension, extension_name) + return [] + + specs: List[Tuple[str, str, type]] = [] + cache_data: Dict[str, Dict] = {} + + for node_name, display_name in remote_nodes.items(): + logger.warning("][ DIAG:ext_loader calling get_node_details for %s.%s", extension_name, node_name) + try: + details = await extension.get_node_details(node_name) + except Exception as exc: + logger.warning( + "][ %s failed to load metadata for %s, skipping node: %s", + extension_name, + node_name, + exc, + ) + continue + # DIAG: dump combo options from freshly-fetched details + schema_v1 = details.get("schema_v1", {}) + inp = schema_v1.get("input", {}) if schema_v1 else {} + for section_name, section in inp.items(): + if isinstance(section, dict): + for field_name, field_def in section.items(): + if isinstance(field_def, (list, tuple)) and len(field_def) >= 2 and isinstance(field_def[1], dict) and "options" in field_def[1]: + opts = field_def[1]["options"] + logger.warning( + "][ DIAG:ext_loader FRESH %s.%s.%s options=%d first=%s", + node_name, section_name, field_name, + len(opts), + opts[:3] if opts else "EMPTY", + ) + details["display_name"] = display_name + cache_data[node_name] = details + stub_cls = build_stub_class(node_name, details, extension) + specs.append((node_name, display_name, stub_cls)) + + if not specs: + logger.warning( + "][ %s produced no usable nodes after metadata scan; skipping", + extension_name, + ) + await _stop_extension_safe(extension, extension_name) + return [] + + # Save metadata to cache for future runs + save_to_cache(node_dir, venv_root, cache_data, manifest_path) + logger.debug(f"][ {extension_name} metadata cached") + + # Re-check web directory AFTER child has populated it + if host_policy["sandbox_mode"] == "disabled": + _register_web_directory(extension_name, node_dir) + + # EJECT: Kill process after getting metadata (will respawn on first execution) + await _stop_extension_safe(extension, extension_name) + + return specs + + +__all__ = ["ExtensionLoadError", "register_dummy_module", "load_isolated_node"] diff --git a/comfy/isolation/extension_wrapper.py b/comfy/isolation/extension_wrapper.py new file mode 100644 index 000000000..67ba1d5c4 --- /dev/null +++ b/comfy/isolation/extension_wrapper.py @@ -0,0 +1,896 @@ +# pylint: disable=consider-using-from-import,cyclic-import,import-outside-toplevel,logging-fstring-interpolation,protected-access,wrong-import-position +from __future__ import annotations + +import asyncio +import torch + + +class AttrDict(dict): + def __getattr__(self, item): + try: + return self[item] + except KeyError as e: + raise AttributeError(item) from e + + def copy(self): + return AttrDict(super().copy()) + + +import importlib +import inspect +import json +import logging +import os +import sys +import uuid +from dataclasses import asdict +from typing import Any, Dict, List, Tuple + +from pyisolate import ExtensionBase + +from comfy_api.internal import _ComfyNodeInternal + +LOG_PREFIX = "][" +V3_DISCOVERY_TIMEOUT = 30 +_PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024 + +logger = logging.getLogger(__name__) + + +def _run_prestartup_web_copy(module: Any, module_dir: str, web_dir_path: str) -> None: + """Run the web asset copy step that prestartup_script.py used to do. + + If the module's web/ directory is empty and the module had a + prestartup_script.py that copied assets from pip packages, this + function replicates that work inside the child process. + + Generic pattern: reads _PRESTARTUP_WEB_COPY from the module if + defined, otherwise falls back to detecting common asset packages. + """ + import shutil + + # Already populated — nothing to do + if os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path)): + return + + os.makedirs(web_dir_path, exist_ok=True) + + # Try module-defined copy spec first (generic hook for any node pack) + copy_spec = getattr(module, "_PRESTARTUP_WEB_COPY", None) + if copy_spec is not None and callable(copy_spec): + try: + copy_spec(web_dir_path) + logger.info( + "%s Ran _PRESTARTUP_WEB_COPY for %s", LOG_PREFIX, module_dir + ) + return + except Exception as e: + logger.warning( + "%s _PRESTARTUP_WEB_COPY failed for %s: %s", + LOG_PREFIX, module_dir, e, + ) + + # Fallback: detect comfy_3d_viewers and run copy_viewer() + try: + from comfy_3d_viewers import copy_viewer, VIEWER_FILES + viewers = list(VIEWER_FILES.keys()) + for viewer in viewers: + try: + copy_viewer(viewer, web_dir_path) + except Exception: + pass + if any(os.scandir(web_dir_path)): + logger.info( + "%s Copied %d viewer types from comfy_3d_viewers to %s", + LOG_PREFIX, len(viewers), web_dir_path, + ) + except ImportError: + pass + + # Fallback: detect comfy_dynamic_widgets + try: + from comfy_dynamic_widgets import get_js_path + src = os.path.realpath(get_js_path()) + if os.path.exists(src): + dst_dir = os.path.join(web_dir_path, "js") + os.makedirs(dst_dir, exist_ok=True) + dst = os.path.join(dst_dir, "dynamic_widgets.js") + shutil.copy2(src, dst) + except ImportError: + pass + + +def _read_extension_name(module_dir: str) -> str: + """Read extension name from pyproject.toml, falling back to directory name.""" + pyproject = os.path.join(module_dir, "pyproject.toml") + if os.path.exists(pyproject): + try: + import tomllib + except ImportError: + import tomli as tomllib # type: ignore[no-redef] + try: + with open(pyproject, "rb") as f: + data = tomllib.load(f) + name = data.get("project", {}).get("name") + if name: + return name + except Exception: + pass + return os.path.basename(module_dir) + + +def _flush_tensor_transport_state(marker: str) -> int: + try: + from pyisolate import flush_tensor_keeper # type: ignore[attr-defined] + except Exception: + return 0 + if not callable(flush_tensor_keeper): + return 0 + flushed = flush_tensor_keeper() + if flushed > 0: + logger.debug( + "%s %s flush_tensor_keeper released=%d", LOG_PREFIX, marker, flushed + ) + return flushed + + +def _relieve_child_vram_pressure(marker: str) -> None: + import comfy.model_management as model_management + + model_management.cleanup_models_gc() + model_management.cleanup_models() + + device = model_management.get_torch_device() + if not hasattr(device, "type") or device.type == "cpu": + return + + required = max( + model_management.minimum_inference_memory(), + _PRE_EXEC_MIN_FREE_VRAM_BYTES, + ) + if model_management.get_free_memory(device) < required: + model_management.free_memory(required, device, for_dynamic=True) + if model_management.get_free_memory(device) < required: + model_management.free_memory(required, device, for_dynamic=False) + model_management.cleanup_models() + model_management.soft_empty_cache() + logger.debug("%s %s free_memory target=%d", LOG_PREFIX, marker, required) + + +def _sanitize_for_transport(value): + primitives = (str, int, float, bool, type(None)) + if isinstance(value, primitives): + return value + + cls_name = value.__class__.__name__ + if cls_name == "FlexibleOptionalInputType": + return { + "__pyisolate_flexible_optional__": True, + "type": _sanitize_for_transport(getattr(value, "type", "*")), + } + if cls_name == "AnyType": + return {"__pyisolate_any_type__": True, "value": str(value)} + if cls_name == "ByPassTypeTuple": + return { + "__pyisolate_bypass_tuple__": [ + _sanitize_for_transport(v) for v in tuple(value) + ] + } + + if isinstance(value, dict): + return {k: _sanitize_for_transport(v) for k, v in value.items()} + if isinstance(value, tuple): + return {"__pyisolate_tuple__": [_sanitize_for_transport(v) for v in value]} + if isinstance(value, list): + return [_sanitize_for_transport(v) for v in value] + + return str(value) + + +# Re-export RemoteObjectHandle from pyisolate for backward compatibility +# The canonical definition is now in pyisolate._internal.remote_handle +from pyisolate._internal.remote_handle import RemoteObjectHandle # noqa: E402,F401 + + +class ComfyNodeExtension(ExtensionBase): + def __init__(self) -> None: + super().__init__() + self.node_classes: Dict[str, type] = {} + self.display_names: Dict[str, str] = {} + self.node_instances: Dict[str, Any] = {} + self.remote_objects: Dict[str, Any] = {} + self._route_handlers: Dict[str, Any] = {} + self._module: Any = None + + async def on_module_loaded(self, module: Any) -> None: + self._module = module + + # Registries are initialized in host_hooks.py initialize_host_process() + # They auto-register via ProxiedSingleton when instantiated + # NO additional setup required here - if a registry is missing from host_hooks, it WILL fail + + self.node_classes = getattr(module, "NODE_CLASS_MAPPINGS", {}) or {} + self.display_names = getattr(module, "NODE_DISPLAY_NAME_MAPPINGS", {}) or {} + + # Register web directory with WebDirectoryProxy (child-side) + web_dir_attr = getattr(module, "WEB_DIRECTORY", None) + if web_dir_attr is not None: + module_dir = os.path.dirname(os.path.abspath(module.__file__)) + web_dir_path = os.path.abspath(os.path.join(module_dir, web_dir_attr)) + ext_name = _read_extension_name(module_dir) + + # If web dir is empty, run the copy step that prestartup_script.py did + _run_prestartup_web_copy(module, module_dir, web_dir_path) + + if os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path)): + from comfy.isolation.proxies.web_directory_proxy import WebDirectoryProxy + WebDirectoryProxy.register_web_dir(ext_name, web_dir_path) + + try: + from comfy_api.latest import ComfyExtension + + for name, obj in inspect.getmembers(module): + if not ( + inspect.isclass(obj) + and issubclass(obj, ComfyExtension) + and obj is not ComfyExtension + ): + continue + if not obj.__module__.startswith(module.__name__): + continue + try: + ext_instance = obj() + try: + await asyncio.wait_for( + ext_instance.on_load(), timeout=V3_DISCOVERY_TIMEOUT + ) + except asyncio.TimeoutError: + logger.error( + "%s V3 Extension %s timed out in on_load()", + LOG_PREFIX, + name, + ) + continue + try: + v3_nodes = await asyncio.wait_for( + ext_instance.get_node_list(), timeout=V3_DISCOVERY_TIMEOUT + ) + except asyncio.TimeoutError: + logger.error( + "%s V3 Extension %s timed out in get_node_list()", + LOG_PREFIX, + name, + ) + continue + for node_cls in v3_nodes: + if hasattr(node_cls, "GET_SCHEMA"): + schema = node_cls.GET_SCHEMA() + self.node_classes[schema.node_id] = node_cls + if schema.display_name: + self.display_names[schema.node_id] = schema.display_name + except Exception as e: + logger.error("%s V3 Extension %s failed: %s", LOG_PREFIX, name, e) + except ImportError: + pass + + module_name = getattr(module, "__name__", "isolated_nodes") + for node_cls in self.node_classes.values(): + if hasattr(node_cls, "__module__") and "/" in str(node_cls.__module__): + node_cls.__module__ = module_name + + self.node_instances = {} + + async def list_nodes(self) -> Dict[str, str]: + return {name: self.display_names.get(name, name) for name in self.node_classes} + + async def get_node_info(self, node_name: str) -> Dict[str, Any]: + return await self.get_node_details(node_name) + + async def get_node_details(self, node_name: str) -> Dict[str, Any]: + node_cls = self._get_node_class(node_name) + is_v3 = issubclass(node_cls, _ComfyNodeInternal) + logger.warning( + "%s DIAG:get_node_details START | node=%s | is_v3=%s | cls=%s", + LOG_PREFIX, node_name, is_v3, node_cls, + ) + + input_types_raw = ( + node_cls.INPUT_TYPES() if hasattr(node_cls, "INPUT_TYPES") else {} + ) + output_is_list = getattr(node_cls, "OUTPUT_IS_LIST", None) + if output_is_list is not None: + output_is_list = tuple(bool(x) for x in output_is_list) + + details: Dict[str, Any] = { + "input_types": _sanitize_for_transport(input_types_raw), + "return_types": tuple( + str(t) for t in getattr(node_cls, "RETURN_TYPES", ()) + ), + "return_names": getattr(node_cls, "RETURN_NAMES", None), + "function": str(getattr(node_cls, "FUNCTION", "execute")), + "category": str(getattr(node_cls, "CATEGORY", "")), + "output_node": bool(getattr(node_cls, "OUTPUT_NODE", False)), + "output_is_list": output_is_list, + "is_v3": is_v3, + } + + if is_v3: + try: + logger.warning( + "%s DIAG:get_node_details calling GET_SCHEMA for %s", + LOG_PREFIX, node_name, + ) + schema = node_cls.GET_SCHEMA() + logger.warning( + "%s DIAG:get_node_details GET_SCHEMA returned for %s | schema_inputs=%s", + LOG_PREFIX, node_name, + [getattr(i, 'id', '?') for i in (schema.inputs or [])], + ) + schema_v1 = asdict(schema.get_v1_info(node_cls)) + try: + schema_v3 = asdict(schema.get_v3_info(node_cls)) + except (AttributeError, TypeError): + schema_v3 = self._build_schema_v3_fallback(schema) + details.update( + { + "schema_v1": schema_v1, + "schema_v3": schema_v3, + "hidden": [h.value for h in (schema.hidden or [])], + "description": getattr(schema, "description", ""), + "deprecated": bool(getattr(node_cls, "DEPRECATED", False)), + "experimental": bool(getattr(node_cls, "EXPERIMENTAL", False)), + "api_node": bool(getattr(node_cls, "API_NODE", False)), + "input_is_list": bool( + getattr(node_cls, "INPUT_IS_LIST", False) + ), + "not_idempotent": bool( + getattr(node_cls, "NOT_IDEMPOTENT", False) + ), + "accept_all_inputs": bool( + getattr(node_cls, "ACCEPT_ALL_INPUTS", False) + ), + } + ) + except Exception as exc: + logger.warning( + "%s V3 schema serialization failed for %s: %s", + LOG_PREFIX, + node_name, + exc, + ) + return details + + def _build_schema_v3_fallback(self, schema) -> Dict[str, Any]: + input_dict: Dict[str, Any] = {} + output_dict: Dict[str, Any] = {} + hidden_list: List[str] = [] + + if getattr(schema, "inputs", None): + for inp in schema.inputs: + self._add_schema_io_v3(inp, input_dict) + if getattr(schema, "outputs", None): + for out in schema.outputs: + self._add_schema_io_v3(out, output_dict) + if getattr(schema, "hidden", None): + for h in schema.hidden: + hidden_list.append(getattr(h, "value", str(h))) + + return { + "input": input_dict, + "output": output_dict, + "hidden": hidden_list, + "name": getattr(schema, "node_id", None), + "display_name": getattr(schema, "display_name", None), + "description": getattr(schema, "description", None), + "category": getattr(schema, "category", None), + "output_node": getattr(schema, "is_output_node", False), + "deprecated": getattr(schema, "is_deprecated", False), + "experimental": getattr(schema, "is_experimental", False), + "api_node": getattr(schema, "is_api_node", False), + } + + def _add_schema_io_v3(self, io_obj: Any, target: Dict[str, Any]) -> None: + io_id = getattr(io_obj, "id", None) + if io_id is None: + return + + io_type_fn = getattr(io_obj, "get_io_type", None) + io_type = ( + io_type_fn() if callable(io_type_fn) else getattr(io_obj, "io_type", None) + ) + + as_dict_fn = getattr(io_obj, "as_dict", None) + payload = as_dict_fn() if callable(as_dict_fn) else {} + + target[str(io_id)] = (io_type, payload) + + async def get_input_types(self, node_name: str) -> Dict[str, Any]: + node_cls = self._get_node_class(node_name) + if hasattr(node_cls, "INPUT_TYPES"): + return node_cls.INPUT_TYPES() + return {} + + async def execute_node(self, node_name: str, **inputs: Any) -> Tuple[Any, ...]: + logger.debug( + "%s ISO:child_execute_start ext=%s node=%s input_keys=%d", + LOG_PREFIX, + getattr(self, "name", "?"), + node_name, + len(inputs), + ) + if os.environ.get("PYISOLATE_CHILD") == "1": + _relieve_child_vram_pressure("EXT:pre_execute") + + resolved_inputs = self._resolve_remote_objects(inputs) + + instance = self._get_node_instance(node_name) + node_cls = self._get_node_class(node_name) + + # V3 API nodes expect hidden parameters in cls.hidden, not as kwargs + # Hidden params come through RPC as string keys like "Hidden.prompt" + from comfy_api.latest._io import Hidden, HiddenHolder + + # Map string representations back to Hidden enum keys + hidden_string_map = { + "Hidden.unique_id": Hidden.unique_id, + "Hidden.prompt": Hidden.prompt, + "Hidden.extra_pnginfo": Hidden.extra_pnginfo, + "Hidden.dynprompt": Hidden.dynprompt, + "Hidden.auth_token_comfy_org": Hidden.auth_token_comfy_org, + "Hidden.api_key_comfy_org": Hidden.api_key_comfy_org, + # Uppercase enum VALUE forms — V3 execution engine passes these + "UNIQUE_ID": Hidden.unique_id, + "PROMPT": Hidden.prompt, + "EXTRA_PNGINFO": Hidden.extra_pnginfo, + "DYNPROMPT": Hidden.dynprompt, + "AUTH_TOKEN_COMFY_ORG": Hidden.auth_token_comfy_org, + "API_KEY_COMFY_ORG": Hidden.api_key_comfy_org, + } + + # Find and extract hidden parameters (both enum and string form) + hidden_found = {} + keys_to_remove = [] + + for key in list(resolved_inputs.keys()): + # Check string form first (from RPC serialization) + if key in hidden_string_map: + hidden_found[hidden_string_map[key]] = resolved_inputs[key] + keys_to_remove.append(key) + # Also check enum form (direct calls) + elif isinstance(key, Hidden): + hidden_found[key] = resolved_inputs[key] + keys_to_remove.append(key) + + # Remove hidden params from kwargs + for key in keys_to_remove: + resolved_inputs.pop(key) + + # Set hidden on node class if any hidden params found + if hidden_found: + if not hasattr(node_cls, "hidden") or node_cls.hidden is None: + node_cls.hidden = HiddenHolder.from_dict(hidden_found) + else: + # Update existing hidden holder + for key, value in hidden_found.items(): + setattr(node_cls.hidden, key.value.lower(), value) + + # INPUT_IS_LIST: ComfyUI's executor passes all inputs as lists when this + # flag is set. The isolation RPC delivers unwrapped values, so we must + # wrap each input in a single-element list to match the contract. + if getattr(node_cls, "INPUT_IS_LIST", False): + resolved_inputs = {k: [v] for k, v in resolved_inputs.items()} + + function_name = getattr(node_cls, "FUNCTION", "execute") + if not hasattr(instance, function_name): + raise AttributeError(f"Node {node_name} missing callable '{function_name}'") + + handler = getattr(instance, function_name) + + try: + import torch + if asyncio.iscoroutinefunction(handler): + with torch.inference_mode(): + result = await handler(**resolved_inputs) + else: + import functools + + def _run_with_inference_mode(**kwargs): + with torch.inference_mode(): + return handler(**kwargs) + + loop = asyncio.get_running_loop() + result = await loop.run_in_executor( + None, functools.partial(_run_with_inference_mode, **resolved_inputs) + ) + except Exception: + logger.exception( + "%s ISO:child_execute_error ext=%s node=%s", + LOG_PREFIX, + getattr(self, "name", "?"), + node_name, + ) + raise + + if type(result).__name__ == "NodeOutput": + node_output_dict = { + "__node_output__": True, + "args": self._wrap_unpicklable_objects(result.args), + } + if result.ui is not None: + node_output_dict["ui"] = self._wrap_unpicklable_objects(result.ui) + if getattr(result, "expand", None) is not None: + node_output_dict["expand"] = result.expand + if getattr(result, "block_execution", None) is not None: + node_output_dict["block_execution"] = result.block_execution + return node_output_dict + if self._is_comfy_protocol_return(result): + wrapped = self._wrap_unpicklable_objects(result) + return wrapped + + if not isinstance(result, tuple): + result = (result,) + wrapped = self._wrap_unpicklable_objects(result) + return wrapped + + async def flush_transport_state(self) -> int: + if os.environ.get("PYISOLATE_CHILD") != "1": + return 0 + logger.debug( + "%s ISO:child_flush_start ext=%s", LOG_PREFIX, getattr(self, "name", "?") + ) + flushed = _flush_tensor_transport_state("EXT:workflow_end") + try: + from comfy.isolation.model_patcher_proxy_registry import ( + ModelPatcherRegistry, + ) + + registry = ModelPatcherRegistry() + removed = registry.sweep_pending_cleanup() + if removed > 0: + logger.debug( + "%s EXT:workflow_end registry sweep removed=%d", LOG_PREFIX, removed + ) + except Exception: + logger.debug( + "%s EXT:workflow_end registry sweep failed", LOG_PREFIX, exc_info=True + ) + logger.debug( + "%s ISO:child_flush_done ext=%s flushed=%d", + LOG_PREFIX, + getattr(self, "name", "?"), + flushed, + ) + return flushed + + async def get_remote_object(self, object_id: str) -> Any: + """Retrieve a remote object by ID for host-side deserialization.""" + if object_id not in self.remote_objects: + raise KeyError(f"Remote object {object_id} not found") + + return self.remote_objects[object_id] + + def _store_remote_object_handle(self, obj: Any) -> RemoteObjectHandle: + object_id = str(uuid.uuid4()) + self.remote_objects[object_id] = obj + return RemoteObjectHandle(object_id, type(obj).__name__) + + async def call_remote_object_method( + self, + object_id: str, + method_name: str, + *args: Any, + **kwargs: Any, + ) -> Any: + """Invoke a method or attribute-backed accessor on a child-owned object.""" + obj = await self.get_remote_object(object_id) + + if method_name == "get_patcher_attr": + return getattr(obj, args[0]) + if method_name == "get_model_options": + return getattr(obj, "model_options") + if method_name == "set_model_options": + setattr(obj, "model_options", args[0]) + return None + if method_name == "get_object_patches": + return getattr(obj, "object_patches") + if method_name == "get_patches": + return getattr(obj, "patches") + if method_name == "get_wrappers": + return getattr(obj, "wrappers") + if method_name == "get_callbacks": + return getattr(obj, "callbacks") + if method_name == "get_load_device": + return getattr(obj, "load_device") + if method_name == "get_offload_device": + return getattr(obj, "offload_device") + if method_name == "get_hook_mode": + return getattr(obj, "hook_mode") + if method_name == "get_parent": + parent = getattr(obj, "parent", None) + if parent is None: + return None + return self._store_remote_object_handle(parent) + if method_name == "get_inner_model_attr": + attr_name = args[0] + if hasattr(obj.model, attr_name): + return getattr(obj.model, attr_name) + if hasattr(obj, attr_name): + return getattr(obj, attr_name) + return None + if method_name == "inner_model_apply_model": + return obj.model.apply_model(*args[0], **args[1]) + if method_name == "inner_model_extra_conds_shapes": + return obj.model.extra_conds_shapes(*args[0], **args[1]) + if method_name == "inner_model_extra_conds": + return obj.model.extra_conds(*args[0], **args[1]) + if method_name == "inner_model_memory_required": + return obj.model.memory_required(*args[0], **args[1]) + if method_name == "process_latent_in": + return obj.model.process_latent_in(*args[0], **args[1]) + if method_name == "process_latent_out": + return obj.model.process_latent_out(*args[0], **args[1]) + if method_name == "scale_latent_inpaint": + return obj.model.scale_latent_inpaint(*args[0], **args[1]) + if method_name.startswith("get_"): + attr_name = method_name[4:] + if hasattr(obj, attr_name): + return getattr(obj, attr_name) + + target = getattr(obj, method_name) + if callable(target): + result = target(*args, **kwargs) + if inspect.isawaitable(result): + result = await result + if type(result).__name__ == "ModelPatcher": + return self._store_remote_object_handle(result) + return result + if args or kwargs: + raise TypeError(f"{method_name} is not callable on remote object {object_id}") + return target + + def _wrap_unpicklable_objects(self, data: Any) -> Any: + if isinstance(data, (str, int, float, bool, type(None))): + return data + if isinstance(data, torch.Tensor): + tensor = data.detach() if data.requires_grad else data + if os.environ.get("PYISOLATE_CHILD") == "1" and tensor.device.type != "cpu": + return tensor.cpu() + return tensor + + # Special-case clip vision outputs: preserve attribute access by packing fields + if hasattr(data, "penultimate_hidden_states") or hasattr( + data, "last_hidden_state" + ): + fields = {} + for attr in ( + "penultimate_hidden_states", + "last_hidden_state", + "image_embeds", + "text_embeds", + ): + if hasattr(data, attr): + try: + fields[attr] = self._wrap_unpicklable_objects( + getattr(data, attr) + ) + except Exception: + pass + if fields: + return {"__pyisolate_attribute_container__": True, "data": fields} + + # Avoid converting arbitrary objects with stateful methods (models, etc.) + # They will be handled via RemoteObjectHandle below. + + type_name = type(data).__name__ + if type_name == "ModelPatcherProxy": + return {"__type__": "ModelPatcherRef", "model_id": data._instance_id} + if type_name == "CLIPProxy": + return {"__type__": "CLIPRef", "clip_id": data._instance_id} + if type_name == "VAEProxy": + return {"__type__": "VAERef", "vae_id": data._instance_id} + if type_name == "ModelSamplingProxy": + return {"__type__": "ModelSamplingRef", "ms_id": data._instance_id} + + if isinstance(data, (list, tuple)): + wrapped = [self._wrap_unpicklable_objects(item) for item in data] + return tuple(wrapped) if isinstance(data, tuple) else wrapped + if isinstance(data, dict): + converted_dict = { + k: self._wrap_unpicklable_objects(v) for k, v in data.items() + } + return {"__pyisolate_attrdict__": True, "data": converted_dict} + + from pyisolate._internal.serialization_registry import SerializerRegistry + + registry = SerializerRegistry.get_instance() + if registry.is_data_type(type_name): + serializer = registry.get_serializer(type_name) + if serializer: + return serializer(data) + + return self._store_remote_object_handle(data) + + def _resolve_remote_objects(self, data: Any) -> Any: + if isinstance(data, RemoteObjectHandle): + if data.object_id not in self.remote_objects: + raise KeyError(f"Remote object {data.object_id} not found") + return self.remote_objects[data.object_id] + + if isinstance(data, dict): + ref_type = data.get("__type__") + if ref_type in ("CLIPRef", "ModelPatcherRef", "VAERef"): + from pyisolate._internal.model_serialization import ( + deserialize_proxy_result, + ) + + return deserialize_proxy_result(data) + if ref_type == "ModelSamplingRef": + from pyisolate._internal.model_serialization import ( + deserialize_proxy_result, + ) + + return deserialize_proxy_result(data) + return {k: self._resolve_remote_objects(v) for k, v in data.items()} + + if isinstance(data, (list, tuple)): + resolved = [self._resolve_remote_objects(item) for item in data] + return tuple(resolved) if isinstance(data, tuple) else resolved + return data + + def _get_node_class(self, node_name: str) -> type: + if node_name not in self.node_classes: + raise KeyError(f"Unknown node: {node_name}") + return self.node_classes[node_name] + + def _get_node_instance(self, node_name: str) -> Any: + if node_name not in self.node_instances: + if node_name not in self.node_classes: + raise KeyError(f"Unknown node: {node_name}") + self.node_instances[node_name] = self.node_classes[node_name]() + return self.node_instances[node_name] + + async def before_module_loaded(self) -> None: + # Inject initialization here if we think this is the child + logger.warning( + "%s DIAG:before_module_loaded START | is_child=%s", + LOG_PREFIX, os.environ.get("PYISOLATE_CHILD"), + ) + try: + from comfy.isolation import initialize_proxies + + initialize_proxies() + logger.warning("%s DIAG:before_module_loaded initialize_proxies OK", LOG_PREFIX) + except Exception as e: + logger.error( + "%s DIAG:before_module_loaded initialize_proxies FAILED: %s", LOG_PREFIX, e + ) + + await super().before_module_loaded() + try: + from comfy_api.latest import ComfyAPI_latest + from .proxies.progress_proxy import ProgressProxy + + ComfyAPI_latest.Execution = ProgressProxy + # ComfyAPI_latest.execution = ProgressProxy() # Eliminated to avoid Singleton collision + # fp_proxy = FolderPathsProxy() # Eliminated to avoid Singleton collision + # latest_ui.folder_paths = fp_proxy + # latest_resources.folder_paths = fp_proxy + except Exception: + pass + + async def call_route_handler( + self, + handler_module: str, + handler_func: str, + request_data: Dict[str, Any], + ) -> Any: + cache_key = f"{handler_module}.{handler_func}" + if cache_key not in self._route_handlers: + if self._module is not None and hasattr(self._module, "__file__"): + node_dir = os.path.dirname(self._module.__file__) + if node_dir not in sys.path: + sys.path.insert(0, node_dir) + try: + module = importlib.import_module(handler_module) + self._route_handlers[cache_key] = getattr(module, handler_func) + except (ImportError, AttributeError) as e: + raise ValueError(f"Route handler not found: {cache_key}") from e + + handler = self._route_handlers[cache_key] + mock_request = MockRequest(request_data) + + if asyncio.iscoroutinefunction(handler): + result = await handler(mock_request) + else: + result = handler(mock_request) + return self._serialize_response(result) + + def _is_comfy_protocol_return(self, result: Any) -> bool: + """ + Check if the result matches the ComfyUI 'Protocol Return' schema. + + A Protocol Return is a dictionary containing specific reserved keys that + ComfyUI's execution engine interprets as instructions (UI updates, + Workflow expansion, etc.) rather than purely data outputs. + + Schema: + - Must be a dict + - Must contain at least one of: 'ui', 'result', 'expand' + """ + if not isinstance(result, dict): + return False + return any(key in result for key in ("ui", "result", "expand")) + + def _serialize_response(self, response: Any) -> Dict[str, Any]: + if response is None: + return {"type": "text", "body": "", "status": 204} + if isinstance(response, dict): + return {"type": "json", "body": response, "status": 200} + if isinstance(response, str): + return {"type": "text", "body": response, "status": 200} + if hasattr(response, "text") and hasattr(response, "status"): + return { + "type": "text", + "body": response.text + if hasattr(response, "text") + else str(response.body), + "status": response.status, + "headers": dict(response.headers) + if hasattr(response, "headers") + else {}, + } + if hasattr(response, "body") and hasattr(response, "status"): + body = response.body + if isinstance(body, bytes): + try: + return { + "type": "text", + "body": body.decode("utf-8"), + "status": response.status, + } + except UnicodeDecodeError: + return { + "type": "binary", + "body": body.hex(), + "status": response.status, + } + return {"type": "json", "body": body, "status": response.status} + return {"type": "text", "body": str(response), "status": 200} + + +class MockRequest: + def __init__(self, data: Dict[str, Any]): + self.method = data.get("method", "GET") + self.path = data.get("path", "/") + self.query = data.get("query", {}) + self._body = data.get("body", {}) + self._text = data.get("text", "") + self.headers = data.get("headers", {}) + self.content_type = data.get( + "content_type", self.headers.get("Content-Type", "application/json") + ) + self.match_info = data.get("match_info", {}) + + async def json(self) -> Any: + if isinstance(self._body, dict): + return self._body + if isinstance(self._body, str): + return json.loads(self._body) + return {} + + async def post(self) -> Dict[str, Any]: + if isinstance(self._body, dict): + return self._body + return {} + + async def text(self) -> str: + if self._text: + return self._text + if isinstance(self._body, str): + return self._body + if isinstance(self._body, dict): + return json.dumps(self._body) + return "" + + async def read(self) -> bytes: + return (await self.text()).encode("utf-8") diff --git a/comfy/isolation/host_hooks.py b/comfy/isolation/host_hooks.py new file mode 100644 index 000000000..e20143591 --- /dev/null +++ b/comfy/isolation/host_hooks.py @@ -0,0 +1,30 @@ +# pylint: disable=import-outside-toplevel +# Host process initialization for PyIsolate +import logging + +logger = logging.getLogger(__name__) + + +def initialize_host_process() -> None: + root = logging.getLogger() + for handler in root.handlers[:]: + root.removeHandler(handler) + root.addHandler(logging.NullHandler()) + + from .proxies.folder_paths_proxy import FolderPathsProxy + from .proxies.helper_proxies import HelperProxiesService + from .proxies.model_management_proxy import ModelManagementProxy + from .proxies.progress_proxy import ProgressProxy + from .proxies.prompt_server_impl import PromptServerService + from .proxies.utils_proxy import UtilsProxy + from .proxies.web_directory_proxy import WebDirectoryProxy + from .vae_proxy import VAERegistry + + FolderPathsProxy() + HelperProxiesService() + ModelManagementProxy() + ProgressProxy() + PromptServerService() + UtilsProxy() + WebDirectoryProxy() + VAERegistry() diff --git a/comfy/isolation/host_policy.py b/comfy/isolation/host_policy.py new file mode 100644 index 000000000..f637e89d9 --- /dev/null +++ b/comfy/isolation/host_policy.py @@ -0,0 +1,178 @@ +# pylint: disable=logging-fstring-interpolation +from __future__ import annotations + +import logging +import os +from pathlib import Path +from pathlib import PurePosixPath +from typing import Dict, List, TypedDict + +try: + import tomllib +except ImportError: + import tomli as tomllib # type: ignore[no-redef] + +logger = logging.getLogger(__name__) + +HOST_POLICY_PATH_ENV = "COMFY_HOST_POLICY_PATH" +VALID_SANDBOX_MODES = frozenset({"required", "disabled"}) +FORBIDDEN_WRITABLE_PATHS = frozenset({"/tmp"}) + + +class HostSecurityPolicy(TypedDict): + sandbox_mode: str + allow_network: bool + writable_paths: List[str] + readonly_paths: List[str] + sealed_worker_ro_import_paths: List[str] + whitelist: Dict[str, str] + + +DEFAULT_POLICY: HostSecurityPolicy = { + "sandbox_mode": "required", + "allow_network": False, + "writable_paths": ["/dev/shm"], + "readonly_paths": [], + "sealed_worker_ro_import_paths": [], + "whitelist": {}, +} + + +def _default_policy() -> HostSecurityPolicy: + return { + "sandbox_mode": DEFAULT_POLICY["sandbox_mode"], + "allow_network": DEFAULT_POLICY["allow_network"], + "writable_paths": list(DEFAULT_POLICY["writable_paths"]), + "readonly_paths": list(DEFAULT_POLICY["readonly_paths"]), + "sealed_worker_ro_import_paths": list(DEFAULT_POLICY["sealed_worker_ro_import_paths"]), + "whitelist": dict(DEFAULT_POLICY["whitelist"]), + } + + +def _normalize_writable_paths(paths: list[object]) -> list[str]: + normalized_paths: list[str] = [] + for raw_path in paths: + # Host-policy paths are contract-style POSIX paths; keep representation + # stable across Windows/Linux so tests and config behavior stay consistent. + normalized_path = str(PurePosixPath(str(raw_path).replace("\\", "/"))) + if normalized_path in FORBIDDEN_WRITABLE_PATHS: + continue + normalized_paths.append(normalized_path) + return normalized_paths + + +def _load_whitelist_file(file_path: Path, config_path: Path) -> Dict[str, str]: + if not file_path.is_absolute(): + file_path = config_path.parent / file_path + if not file_path.exists(): + logger.warning("whitelist_file %s not found, skipping.", file_path) + return {} + entries: Dict[str, str] = {} + for line in file_path.read_text().splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + entries[line] = "*" + logger.debug("Loaded %d whitelist entries from %s", len(entries), file_path) + return entries + + +def _normalize_sealed_worker_ro_import_paths(raw_paths: object) -> list[str]: + if not isinstance(raw_paths, list): + raise ValueError( + "tool.comfy.host.sealed_worker_ro_import_paths must be a list of absolute paths." + ) + + normalized_paths: list[str] = [] + seen: set[str] = set() + for raw_path in raw_paths: + if not isinstance(raw_path, str) or not raw_path.strip(): + raise ValueError( + "tool.comfy.host.sealed_worker_ro_import_paths entries must be non-empty strings." + ) + normalized_path = str(PurePosixPath(raw_path.replace("\\", "/"))) + # Accept both POSIX absolute paths (/home/...) and Windows drive-letter paths (D:/...) + is_absolute = normalized_path.startswith("/") or ( + len(normalized_path) >= 3 and normalized_path[1] == ":" and normalized_path[2] == "/" + ) + if not is_absolute: + raise ValueError( + "tool.comfy.host.sealed_worker_ro_import_paths entries must be absolute paths." + ) + if normalized_path not in seen: + seen.add(normalized_path) + normalized_paths.append(normalized_path) + + return normalized_paths + + +def load_host_policy(comfy_root: Path) -> HostSecurityPolicy: + config_override = os.environ.get(HOST_POLICY_PATH_ENV) + config_path = Path(config_override) if config_override else comfy_root / "pyproject.toml" + policy = _default_policy() + + if not config_path.exists(): + logger.debug("Host policy file missing at %s, using defaults.", config_path) + return policy + + try: + with config_path.open("rb") as f: + data = tomllib.load(f) + except Exception: + logger.warning( + "Failed to parse host policy from %s, using defaults.", + config_path, + exc_info=True, + ) + return policy + + tool_config = data.get("tool", {}).get("comfy", {}).get("host", {}) + if not isinstance(tool_config, dict): + logger.debug("No [tool.comfy.host] section found, using defaults.") + return policy + + sandbox_mode = tool_config.get("sandbox_mode") + if isinstance(sandbox_mode, str): + normalized_sandbox_mode = sandbox_mode.strip().lower() + if normalized_sandbox_mode in VALID_SANDBOX_MODES: + policy["sandbox_mode"] = normalized_sandbox_mode + else: + logger.warning( + "Invalid host sandbox_mode %r in %s, using default %r.", + sandbox_mode, + config_path, + DEFAULT_POLICY["sandbox_mode"], + ) + + if "allow_network" in tool_config: + policy["allow_network"] = bool(tool_config["allow_network"]) + + if "writable_paths" in tool_config: + policy["writable_paths"] = _normalize_writable_paths(tool_config["writable_paths"]) + + if "readonly_paths" in tool_config: + policy["readonly_paths"] = [str(p) for p in tool_config["readonly_paths"]] + + if "sealed_worker_ro_import_paths" in tool_config: + policy["sealed_worker_ro_import_paths"] = _normalize_sealed_worker_ro_import_paths( + tool_config["sealed_worker_ro_import_paths"] + ) + + whitelist_file = tool_config.get("whitelist_file") + if isinstance(whitelist_file, str): + policy["whitelist"].update(_load_whitelist_file(Path(whitelist_file), config_path)) + + whitelist_raw = tool_config.get("whitelist") + if isinstance(whitelist_raw, dict): + policy["whitelist"].update({str(k): str(v) for k, v in whitelist_raw.items()}) + + logger.debug( + "Loaded Host Policy: %d whitelisted nodes, Sandbox=%s, Network=%s", + len(policy["whitelist"]), + policy["sandbox_mode"], + policy["allow_network"], + ) + return policy + + +__all__ = ["HostSecurityPolicy", "load_host_policy", "DEFAULT_POLICY"] diff --git a/comfy/isolation/manifest_loader.py b/comfy/isolation/manifest_loader.py new file mode 100644 index 000000000..4ae21d94d --- /dev/null +++ b/comfy/isolation/manifest_loader.py @@ -0,0 +1,221 @@ +# pylint: disable=import-outside-toplevel +from __future__ import annotations + +import hashlib +import json +import logging +import os +import sys +import tempfile +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import folder_paths + +try: + import tomllib +except ImportError: + import tomli as tomllib # type: ignore[no-redef] + +LOG_PREFIX = "][" +logger = logging.getLogger(__name__) + +CACHE_SUBDIR = "cache" +CACHE_KEY_FILE = "cache_key" +CACHE_DATA_FILE = "node_info.json" +CACHE_KEY_LENGTH = 16 +_NESTED_SCAN_ROOT = "packages" +_IGNORED_MANIFEST_DIRS = {".git", ".venv", "__pycache__"} + + +def _read_manifest(manifest_path: Path) -> dict[str, Any] | None: + try: + with manifest_path.open("rb") as f: + data = tomllib.load(f) + if isinstance(data, dict): + return data + except Exception: + return None + return None + + +def _is_isolation_manifest(data: dict[str, Any]) -> bool: + return ( + "tool" in data + and "comfy" in data["tool"] + and "isolation" in data["tool"]["comfy"] + ) + + +def _discover_nested_manifests(entry: Path) -> List[Tuple[Path, Path]]: + packages_root = entry / _NESTED_SCAN_ROOT + if not packages_root.exists() or not packages_root.is_dir(): + return [] + + nested: List[Tuple[Path, Path]] = [] + for manifest in sorted(packages_root.rglob("pyproject.toml")): + node_dir = manifest.parent + if any(part in _IGNORED_MANIFEST_DIRS for part in node_dir.parts): + continue + + data = _read_manifest(manifest) + if not data or not _is_isolation_manifest(data): + continue + + isolation = data["tool"]["comfy"]["isolation"] + if isolation.get("standalone") is True: + nested.append((node_dir, manifest)) + + return nested + + +def find_manifest_directories() -> List[Tuple[Path, Path]]: + """Find custom node directories containing a valid pyproject.toml with [tool.comfy.isolation].""" + manifest_dirs: List[Tuple[Path, Path]] = [] + + # Standard custom_nodes paths + for base_path in folder_paths.get_folder_paths("custom_nodes"): + base = Path(base_path) + if not base.exists() or not base.is_dir(): + continue + + for entry in base.iterdir(): + if not entry.is_dir(): + continue + + # Look for pyproject.toml + manifest = entry / "pyproject.toml" + if not manifest.exists(): + continue + + data = _read_manifest(manifest) + if not data or not _is_isolation_manifest(data): + continue + + manifest_dirs.append((entry, manifest)) + manifest_dirs.extend(_discover_nested_manifests(entry)) + + return manifest_dirs + + +def compute_cache_key(node_dir: Path, manifest_path: Path) -> str: + """Hash manifest + .py mtimes + Python version + PyIsolate version.""" + hasher = hashlib.sha256() + + try: + # Hashing the manifest content ensures config changes invalidate cache + hasher.update(manifest_path.read_bytes()) + except OSError: + hasher.update(b"__manifest_read_error__") + + try: + py_files = sorted(node_dir.rglob("*.py")) + for py_file in py_files: + rel_path = py_file.relative_to(node_dir) + if "__pycache__" in str(rel_path) or ".venv" in str(rel_path): + continue + hasher.update(str(rel_path).encode("utf-8")) + try: + hasher.update(str(py_file.stat().st_mtime).encode("utf-8")) + except OSError: + hasher.update(b"__file_stat_error__") + except OSError: + hasher.update(b"__dir_scan_error__") + + hasher.update(sys.version.encode("utf-8")) + + try: + import pyisolate + + hasher.update(pyisolate.__version__.encode("utf-8")) + except (ImportError, AttributeError): + hasher.update(b"__pyisolate_unknown__") + + return hasher.hexdigest()[:CACHE_KEY_LENGTH] + + +def get_cache_path(node_dir: Path, venv_root: Path) -> Tuple[Path, Path]: + """Return (cache_key_file, cache_data_file) in venv_root/{node}/cache/.""" + cache_dir = venv_root / node_dir.name / CACHE_SUBDIR + return (cache_dir / CACHE_KEY_FILE, cache_dir / CACHE_DATA_FILE) + + +def is_cache_valid(node_dir: Path, manifest_path: Path, venv_root: Path) -> bool: + """Return True only if stored cache key matches current computed key.""" + try: + cache_key_file, cache_data_file = get_cache_path(node_dir, venv_root) + if not cache_key_file.exists() or not cache_data_file.exists(): + return False + current_key = compute_cache_key(node_dir, manifest_path) + stored_key = cache_key_file.read_text(encoding="utf-8").strip() + return current_key == stored_key + except Exception as e: + logger.debug( + "%s Cache validation error for %s: %s", LOG_PREFIX, node_dir.name, e + ) + return False + + +def load_from_cache(node_dir: Path, venv_root: Path) -> Optional[Dict[str, Any]]: + """Load node metadata from cache, return None on any error.""" + try: + _, cache_data_file = get_cache_path(node_dir, venv_root) + if not cache_data_file.exists(): + return None + data = json.loads(cache_data_file.read_text(encoding="utf-8")) + if not isinstance(data, dict): + return None + return data + except Exception: + return None + + +def save_to_cache( + node_dir: Path, venv_root: Path, node_data: Dict[str, Any], manifest_path: Path +) -> None: + """Save node metadata and cache key atomically.""" + try: + cache_key_file, cache_data_file = get_cache_path(node_dir, venv_root) + cache_dir = cache_key_file.parent + cache_dir.mkdir(parents=True, exist_ok=True) + cache_key = compute_cache_key(node_dir, manifest_path) + + # Atomic write: data + tmp_data_fd, tmp_data_path = tempfile.mkstemp(dir=str(cache_dir), suffix=".tmp") + try: + with os.fdopen(tmp_data_fd, "w", encoding="utf-8") as f: + json.dump(node_data, f, indent=2) + os.replace(tmp_data_path, cache_data_file) + except Exception: + try: + os.unlink(tmp_data_path) + except OSError: + pass + raise + + # Atomic write: key + tmp_key_fd, tmp_key_path = tempfile.mkstemp(dir=str(cache_dir), suffix=".tmp") + try: + with os.fdopen(tmp_key_fd, "w", encoding="utf-8") as f: + f.write(cache_key) + os.replace(tmp_key_path, cache_key_file) + except Exception: + try: + os.unlink(tmp_key_path) + except OSError: + pass + raise + + except Exception as e: + logger.warning("%s Cache save failed for %s: %s", LOG_PREFIX, node_dir.name, e) + + +__all__ = [ + "LOG_PREFIX", + "find_manifest_directories", + "compute_cache_key", + "get_cache_path", + "is_cache_valid", + "load_from_cache", + "save_to_cache", +] diff --git a/comfy/isolation/model_patcher_proxy.py b/comfy/isolation/model_patcher_proxy.py new file mode 100644 index 000000000..f44de1d5a --- /dev/null +++ b/comfy/isolation/model_patcher_proxy.py @@ -0,0 +1,888 @@ +# pylint: disable=bare-except,consider-using-from-import,import-outside-toplevel,protected-access +# RPC proxy for ModelPatcher (parent process) +from __future__ import annotations + +import logging +from typing import Any, Optional, List, Set, Dict, Callable + +from comfy.isolation.proxies.base import ( + IS_CHILD_PROCESS, + BaseProxy, +) +from comfy.isolation.model_patcher_proxy_registry import ( + ModelPatcherRegistry, + AutoPatcherEjector, +) + +logger = logging.getLogger(__name__) + + +class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]): + _registry_class = ModelPatcherRegistry + __module__ = "comfy.model_patcher" + _APPLY_MODEL_GUARD_PADDING_BYTES = 32 * 1024 * 1024 + + def _spawn_related_proxy(self, instance_id: str) -> "ModelPatcherProxy": + proxy = ModelPatcherProxy( + instance_id, + self._registry, + manage_lifecycle=not IS_CHILD_PROCESS, + ) + if getattr(self, "_rpc_caller", None) is not None: + proxy._rpc_caller = self._rpc_caller + return proxy + + def _get_rpc(self) -> Any: + if self._rpc_caller is None: + from pyisolate._internal.rpc_protocol import get_child_rpc_instance + + rpc = get_child_rpc_instance() + if rpc is not None: + self._rpc_caller = rpc.create_caller( + self._registry_class, self._registry_class.get_remote_id() + ) + else: + self._rpc_caller = self._registry + return self._rpc_caller + + def get_all_callbacks(self, call_type: str = None) -> Any: + return self._call_rpc("get_all_callbacks", call_type) + + def get_all_wrappers(self, wrapper_type: str = None) -> Any: + return self._call_rpc("get_all_wrappers", wrapper_type) + + def _load_list(self, *args, **kwargs) -> Any: + return self._call_rpc("load_list_internal", *args, **kwargs) + + def prepare_hook_patches_current_keyframe( + self, t: Any, hook_group: Any, model_options: Any + ) -> None: + self._call_rpc( + "prepare_hook_patches_current_keyframe", t, hook_group, model_options + ) + + def add_hook_patches( + self, + hook: Any, + patches: Any, + strength_patch: float = 1.0, + strength_model: float = 1.0, + ) -> None: + self._call_rpc( + "add_hook_patches", hook, patches, strength_patch, strength_model + ) + + def clear_cached_hook_weights(self) -> None: + self._call_rpc("clear_cached_hook_weights") + + def get_combined_hook_patches(self, hooks: Any) -> Any: + return self._call_rpc("get_combined_hook_patches", hooks) + + def get_additional_models_with_key(self, key: str) -> Any: + return self._call_rpc("get_additional_models_with_key", key) + + @property + def object_patches(self) -> Any: + return self._call_rpc("get_object_patches") + + @property + def patches(self) -> Any: + res = self._call_rpc("get_patches") + if isinstance(res, dict): + new_res = {} + for k, v in res.items(): + new_list = [] + for item in v: + if isinstance(item, list): + new_list.append(tuple(item)) + else: + new_list.append(item) + new_res[k] = new_list + return new_res + return res + + @property + def pinned(self) -> Set: + val = self._call_rpc("get_patcher_attr", "pinned") + return set(val) if val is not None else set() + + @property + def hook_patches(self) -> Dict: + val = self._call_rpc("get_patcher_attr", "hook_patches") + if val is None: + return {} + try: + from comfy.hooks import _HookRef + import json + + new_val = {} + for k, v in val.items(): + if isinstance(k, str): + if k.startswith("PYISOLATE_HOOKREF:"): + ref_id = k.split(":", 1)[1] + h = _HookRef() + h._pyisolate_id = ref_id + new_val[h] = v + elif k.startswith("__pyisolate_key__"): + try: + json_str = k[len("__pyisolate_key__") :] + data = json.loads(json_str) + ref_id = None + if isinstance(data, list): + for item in data: + if ( + isinstance(item, list) + and len(item) == 2 + and item[0] == "id" + ): + ref_id = item[1] + break + if ref_id: + h = _HookRef() + h._pyisolate_id = ref_id + new_val[h] = v + else: + new_val[k] = v + except Exception: + new_val[k] = v + else: + new_val[k] = v + else: + new_val[k] = v + return new_val + except ImportError: + return val + + def set_hook_mode(self, hook_mode: Any) -> None: + self._call_rpc("set_hook_mode", hook_mode) + + def register_all_hook_patches( + self, + hooks: Any, + target_dict: Any, + model_options: Any = None, + registered: Any = None, + ) -> None: + self._call_rpc( + "register_all_hook_patches", hooks, target_dict, model_options, registered + ) + + def is_clone(self, other: Any) -> bool: + if isinstance(other, ModelPatcherProxy): + return self._call_rpc("is_clone_by_id", other._instance_id) + return False + + def clone(self) -> ModelPatcherProxy: + new_id = self._call_rpc("clone") + return self._spawn_related_proxy(new_id) + + def clone_has_same_weights(self, clone: Any) -> bool: + if isinstance(clone, ModelPatcherProxy): + return self._call_rpc("clone_has_same_weights_by_id", clone._instance_id) + if not IS_CHILD_PROCESS: + return self._call_rpc("is_clone", clone) + return False + + def get_model_object(self, name: str) -> Any: + return self._call_rpc("get_model_object", name) + + @property + def model_options(self) -> dict: + data = self._call_rpc("get_model_options") + import json + + def _decode_keys(obj): + if isinstance(obj, dict): + new_d = {} + for k, v in obj.items(): + if isinstance(k, str) and k.startswith("__pyisolate_key__"): + try: + json_str = k[17:] + val = json.loads(json_str) + if isinstance(val, list): + val = tuple(val) + new_d[val] = _decode_keys(v) + except: + new_d[k] = _decode_keys(v) + else: + new_d[k] = _decode_keys(v) + return new_d + if isinstance(obj, list): + return [_decode_keys(x) for x in obj] + return obj + + return _decode_keys(data) + + @model_options.setter + def model_options(self, value: dict) -> None: + self._call_rpc("set_model_options", value) + + def apply_hooks(self, hooks: Any) -> Any: + return self._call_rpc("apply_hooks", hooks) + + def prepare_state(self, timestep: Any) -> Any: + return self._call_rpc("prepare_state", timestep) + + def restore_hook_patches(self) -> None: + self._call_rpc("restore_hook_patches") + + def unpatch_hooks(self, whitelist_keys_set: Optional[Set[str]] = None) -> None: + self._call_rpc("unpatch_hooks", whitelist_keys_set) + + def model_patches_to(self, device: Any) -> Any: + return self._call_rpc("model_patches_to", device) + + def partially_load( + self, device: Any, extra_memory: Any, force_patch_weights: bool = False + ) -> Any: + return self._call_rpc( + "partially_load", device, extra_memory, force_patch_weights + ) + + def partially_unload( + self, device_to: Any, memory_to_free: int = 0, force_patch_weights: bool = False + ) -> int: + return self._call_rpc( + "partially_unload", device_to, memory_to_free, force_patch_weights + ) + + def load( + self, + device_to: Any = None, + lowvram_model_memory: int = 0, + force_patch_weights: bool = False, + full_load: bool = False, + ) -> None: + self._call_rpc( + "load", device_to, lowvram_model_memory, force_patch_weights, full_load + ) + + def patch_model( + self, + device_to: Any = None, + lowvram_model_memory: int = 0, + load_weights: bool = True, + force_patch_weights: bool = False, + ) -> Any: + self._call_rpc( + "patch_model", + device_to, + lowvram_model_memory, + load_weights, + force_patch_weights, + ) + return self + + def unpatch_model( + self, device_to: Any = None, unpatch_weights: bool = True + ) -> None: + self._call_rpc("unpatch_model", device_to, unpatch_weights) + + def detach(self, unpatch_all: bool = True) -> Any: + self._call_rpc("detach", unpatch_all) + return self.model + + def _cpu_tensor_bytes(self, obj: Any) -> int: + import torch + + if isinstance(obj, torch.Tensor): + if obj.device.type == "cpu": + return obj.nbytes + return 0 + if isinstance(obj, dict): + return sum(self._cpu_tensor_bytes(v) for v in obj.values()) + if isinstance(obj, (list, tuple)): + return sum(self._cpu_tensor_bytes(v) for v in obj) + return 0 + + def _ensure_apply_model_headroom(self, required_bytes: int) -> bool: + if required_bytes <= 0: + return True + + import torch + import comfy.model_management as model_management + + target_raw = self.load_device + try: + if isinstance(target_raw, torch.device): + target = target_raw + elif isinstance(target_raw, str): + target = torch.device(target_raw) + elif isinstance(target_raw, int): + target = torch.device(f"cuda:{target_raw}") + else: + target = torch.device(target_raw) + except Exception: + return True + + if target.type != "cuda": + return True + + required = required_bytes + self._APPLY_MODEL_GUARD_PADDING_BYTES + if model_management.get_free_memory(target) >= required: + return True + + model_management.cleanup_models_gc() + model_management.cleanup_models() + model_management.soft_empty_cache() + + if model_management.get_free_memory(target) < required: + model_management.free_memory(required, target, for_dynamic=True) + model_management.soft_empty_cache() + + if model_management.get_free_memory(target) < required: + # Escalate to non-dynamic unloading before dispatching CUDA transfer. + model_management.free_memory(required, target, for_dynamic=False) + model_management.soft_empty_cache() + + if model_management.get_free_memory(target) < required: + model_management.load_models_gpu( + [self], + minimum_memory_required=required, + ) + + return model_management.get_free_memory(target) >= required + + def apply_model(self, *args, **kwargs) -> Any: + import torch + + def _preferred_device() -> Any: + for value in args: + if isinstance(value, torch.Tensor): + return value.device + for value in kwargs.values(): + if isinstance(value, torch.Tensor): + return value.device + return None + + def _move_result_to_device(obj: Any, device: Any) -> Any: + if device is None: + return obj + if isinstance(obj, torch.Tensor): + return obj.to(device) if obj.device != device else obj + if isinstance(obj, dict): + return {k: _move_result_to_device(v, device) for k, v in obj.items()} + if isinstance(obj, list): + return [_move_result_to_device(v, device) for v in obj] + if isinstance(obj, tuple): + return tuple(_move_result_to_device(v, device) for v in obj) + return obj + + # DynamicVRAM models must keep load/offload decisions in host process. + # Child-side CUDA staging here can deadlock before first inference RPC. + if self.is_dynamic(): + out = self._call_rpc("inner_model_apply_model", args, kwargs) + return _move_result_to_device(out, _preferred_device()) + + required_bytes = self._cpu_tensor_bytes(args) + self._cpu_tensor_bytes(kwargs) + self._ensure_apply_model_headroom(required_bytes) + + def _to_cuda(obj: Any) -> Any: + if isinstance(obj, torch.Tensor) and obj.device.type == "cpu": + return obj.to("cuda") + if isinstance(obj, dict): + return {k: _to_cuda(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_to_cuda(v) for v in obj] + if isinstance(obj, tuple): + return tuple(_to_cuda(v) for v in obj) + return obj + + try: + args_cuda = _to_cuda(args) + kwargs_cuda = _to_cuda(kwargs) + except torch.OutOfMemoryError: + self._ensure_apply_model_headroom(required_bytes) + args_cuda = _to_cuda(args) + kwargs_cuda = _to_cuda(kwargs) + + out = self._call_rpc("inner_model_apply_model", args_cuda, kwargs_cuda) + return _move_result_to_device(out, _preferred_device()) + + def model_state_dict(self, filter_prefix: Optional[str] = None) -> Any: + keys = self._call_rpc("model_state_dict", filter_prefix) + return dict.fromkeys(keys, None) + + def add_patches(self, *args: Any, **kwargs: Any) -> Any: + res = self._call_rpc("add_patches", *args, **kwargs) + if isinstance(res, list): + return [tuple(x) if isinstance(x, list) else x for x in res] + return res + + def get_key_patches(self, filter_prefix: Optional[str] = None) -> Any: + return self._call_rpc("get_key_patches", filter_prefix) + + def patch_weight_to_device(self, key, device_to=None, inplace_update=False): + self._call_rpc("patch_weight_to_device", key, device_to, inplace_update) + + def pin_weight_to_device(self, key): + self._call_rpc("pin_weight_to_device", key) + + def unpin_weight(self, key): + self._call_rpc("unpin_weight", key) + + def unpin_all_weights(self): + self._call_rpc("unpin_all_weights") + + def calculate_weight(self, patches, weight, key, intermediate_dtype=None): + return self._call_rpc( + "calculate_weight", patches, weight, key, intermediate_dtype + ) + + def inject_model(self) -> None: + self._call_rpc("inject_model") + + def eject_model(self) -> None: + self._call_rpc("eject_model") + + def use_ejected(self, skip_and_inject_on_exit_only: bool = False) -> Any: + return AutoPatcherEjector( + self, skip_and_inject_on_exit_only=skip_and_inject_on_exit_only + ) + + @property + def is_injected(self) -> bool: + return self._call_rpc("get_is_injected") + + @property + def skip_injection(self) -> bool: + return self._call_rpc("get_skip_injection") + + @skip_injection.setter + def skip_injection(self, value: bool) -> None: + self._call_rpc("set_skip_injection", value) + + def clean_hooks(self) -> None: + self._call_rpc("clean_hooks") + + def pre_run(self) -> None: + self._call_rpc("pre_run") + + def cleanup(self) -> None: + try: + self._call_rpc("cleanup") + except Exception: + logger.debug( + "ModelPatcherProxy cleanup RPC failed for %s", + self._instance_id, + exc_info=True, + ) + finally: + super().cleanup() + + @property + def model(self) -> _InnerModelProxy: + return _InnerModelProxy(self) + + def __getattr__(self, name: str) -> Any: + _whitelisted_attrs = { + "hook_patches_backup", + "hook_backup", + "cached_hook_patches", + "current_hooks", + "forced_hooks", + "is_clip", + "patches_uuid", + "pinned", + "attachments", + "additional_models", + "injections", + "hook_patches", + "model_lowvram", + "model_loaded_weight_memory", + "backup", + "object_patches_backup", + "weight_wrapper_patches", + "weight_inplace_update", + "force_cast_weights", + } + if name in _whitelisted_attrs: + return self._call_rpc("get_patcher_attr", name) + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'" + ) + + def load_lora( + self, + lora_path: str, + strength_model: float, + clip: Optional[Any] = None, + strength_clip: float = 1.0, + ) -> tuple: + clip_id = None + if clip is not None: + clip_id = getattr(clip, "_instance_id", getattr(clip, "_clip_id", None)) + result = self._call_rpc( + "load_lora", lora_path, strength_model, clip_id, strength_clip + ) + new_model = None + if result.get("model_id"): + new_model = self._spawn_related_proxy(result["model_id"]) + new_clip = None + if result.get("clip_id"): + from comfy.isolation.clip_proxy import CLIPProxy + + new_clip = CLIPProxy(result["clip_id"]) + return (new_model, new_clip) + + @property + def load_device(self) -> Any: + return self._call_rpc("get_load_device") + + @property + def offload_device(self) -> Any: + return self._call_rpc("get_offload_device") + + @property + def device(self) -> Any: + return self.load_device + + def current_loaded_device(self) -> Any: + return self._call_rpc("current_loaded_device") + + @property + def size(self) -> int: + return self._call_rpc("get_size") + + def model_size(self) -> Any: + return self._call_rpc("model_size") + + def loaded_size(self) -> Any: + return self._call_rpc("loaded_size") + + def get_ram_usage(self) -> int: + return self._call_rpc("get_ram_usage") + + def lowvram_patch_counter(self) -> int: + return self._call_rpc("lowvram_patch_counter") + + def memory_required(self, input_shape: Any) -> Any: + return self._call_rpc("memory_required", input_shape) + + def get_operation_state(self) -> Dict[str, Any]: + state = self._call_rpc("get_operation_state") + return state if isinstance(state, dict) else {} + + def wait_for_idle(self, timeout_ms: int = 0) -> bool: + return bool(self._call_rpc("wait_for_idle", timeout_ms)) + + def is_dynamic(self) -> bool: + return bool(self._call_rpc("is_dynamic")) + + def get_free_memory(self, device: Any) -> Any: + return self._call_rpc("get_free_memory", device) + + def partially_unload_ram(self, ram_to_unload: int) -> Any: + return self._call_rpc("partially_unload_ram", ram_to_unload) + + def model_dtype(self) -> Any: + res = self._call_rpc("model_dtype") + if isinstance(res, str) and res.startswith("torch."): + try: + import torch + + attr = res.split(".")[-1] + if hasattr(torch, attr): + return getattr(torch, attr) + except ImportError: + pass + return res + + @property + def hook_mode(self) -> Any: + return self._call_rpc("get_hook_mode") + + @hook_mode.setter + def hook_mode(self, value: Any) -> None: + self._call_rpc("set_hook_mode", value) + + def set_model_sampler_cfg_function( + self, sampler_cfg_function: Any, disable_cfg1_optimization: bool = False + ) -> None: + self._call_rpc( + "set_model_sampler_cfg_function", + sampler_cfg_function, + disable_cfg1_optimization, + ) + + def set_model_sampler_post_cfg_function( + self, post_cfg_function: Any, disable_cfg1_optimization: bool = False + ) -> None: + self._call_rpc( + "set_model_sampler_post_cfg_function", + post_cfg_function, + disable_cfg1_optimization, + ) + + def set_model_sampler_pre_cfg_function( + self, pre_cfg_function: Any, disable_cfg1_optimization: bool = False + ) -> None: + self._call_rpc( + "set_model_sampler_pre_cfg_function", + pre_cfg_function, + disable_cfg1_optimization, + ) + + def set_model_sampler_calc_cond_batch_function(self, fn: Any) -> None: + self._call_rpc("set_model_sampler_calc_cond_batch_function", fn) + + def set_model_unet_function_wrapper(self, unet_wrapper_function: Any) -> None: + self._call_rpc("set_model_unet_function_wrapper", unet_wrapper_function) + + def set_model_denoise_mask_function(self, denoise_mask_function: Any) -> None: + self._call_rpc("set_model_denoise_mask_function", denoise_mask_function) + + def set_model_patch(self, patch: Any, name: str) -> None: + self._call_rpc("set_model_patch", patch, name) + + def set_model_patch_replace( + self, + patch: Any, + name: str, + block_name: str, + number: int, + transformer_index: Optional[int] = None, + ) -> None: + self._call_rpc( + "set_model_patch_replace", + patch, + name, + block_name, + number, + transformer_index, + ) + + def set_model_attn1_patch(self, patch: Any) -> None: + self.set_model_patch(patch, "attn1_patch") + + def set_model_attn2_patch(self, patch: Any) -> None: + self.set_model_patch(patch, "attn2_patch") + + def set_model_attn1_replace( + self, + patch: Any, + block_name: str, + number: int, + transformer_index: Optional[int] = None, + ) -> None: + self.set_model_patch_replace( + patch, "attn1", block_name, number, transformer_index + ) + + def set_model_attn2_replace( + self, + patch: Any, + block_name: str, + number: int, + transformer_index: Optional[int] = None, + ) -> None: + self.set_model_patch_replace( + patch, "attn2", block_name, number, transformer_index + ) + + def set_model_attn1_output_patch(self, patch: Any) -> None: + self.set_model_patch(patch, "attn1_output_patch") + + def set_model_attn2_output_patch(self, patch: Any) -> None: + self.set_model_patch(patch, "attn2_output_patch") + + def set_model_input_block_patch(self, patch: Any) -> None: + self.set_model_patch(patch, "input_block_patch") + + def set_model_input_block_patch_after_skip(self, patch: Any) -> None: + self.set_model_patch(patch, "input_block_patch_after_skip") + + def set_model_output_block_patch(self, patch: Any) -> None: + self.set_model_patch(patch, "output_block_patch") + + def set_model_emb_patch(self, patch: Any) -> None: + self.set_model_patch(patch, "emb_patch") + + def set_model_forward_timestep_embed_patch(self, patch: Any) -> None: + self.set_model_patch(patch, "forward_timestep_embed_patch") + + def set_model_double_block_patch(self, patch: Any) -> None: + self.set_model_patch(patch, "double_block") + + def set_model_post_input_patch(self, patch: Any) -> None: + self.set_model_patch(patch, "post_input") + + def set_model_rope_options( + self, + scale_x=1.0, + shift_x=0.0, + scale_y=1.0, + shift_y=0.0, + scale_t=1.0, + shift_t=0.0, + **kwargs: Any, + ) -> None: + options = { + "scale_x": scale_x, + "shift_x": shift_x, + "scale_y": scale_y, + "shift_y": shift_y, + "scale_t": scale_t, + "shift_t": shift_t, + } + options.update(kwargs) + self._call_rpc("set_model_rope_options", options) + + def set_model_compute_dtype(self, dtype: Any) -> None: + self._call_rpc("set_model_compute_dtype", dtype) + + def add_object_patch(self, name: str, obj: Any) -> None: + self._call_rpc("add_object_patch", name, obj) + + def add_weight_wrapper(self, name: str, function: Any) -> None: + self._call_rpc("add_weight_wrapper", name, function) + + def add_wrapper_with_key(self, wrapper_type: Any, key: str, fn: Any) -> None: + self._call_rpc("add_wrapper_with_key", wrapper_type, key, fn) + + def add_wrapper(self, wrapper_type: str, wrapper: Callable) -> None: + self.add_wrapper_with_key(wrapper_type, None, wrapper) + + def remove_wrappers_with_key(self, wrapper_type: str, key: str) -> None: + self._call_rpc("remove_wrappers_with_key", wrapper_type, key) + + @property + def wrappers(self) -> Any: + return self._call_rpc("get_wrappers") + + def add_callback_with_key(self, call_type: str, key: str, callback: Any) -> None: + self._call_rpc("add_callback_with_key", call_type, key, callback) + + def add_callback(self, call_type: str, callback: Any) -> None: + self.add_callback_with_key(call_type, None, callback) + + def remove_callbacks_with_key(self, call_type: str, key: str) -> None: + self._call_rpc("remove_callbacks_with_key", call_type, key) + + @property + def callbacks(self) -> Any: + return self._call_rpc("get_callbacks") + + def set_attachments(self, key: str, attachment: Any) -> None: + self._call_rpc("set_attachments", key, attachment) + + def get_attachment(self, key: str) -> Any: + return self._call_rpc("get_attachment", key) + + def remove_attachments(self, key: str) -> None: + self._call_rpc("remove_attachments", key) + + def set_injections(self, key: str, injections: Any) -> None: + self._call_rpc("set_injections", key, injections) + + def get_injections(self, key: str) -> Any: + return self._call_rpc("get_injections", key) + + def remove_injections(self, key: str) -> None: + self._call_rpc("remove_injections", key) + + def set_additional_models(self, key: str, models: Any) -> None: + ids = [m._instance_id for m in models] + self._call_rpc("set_additional_models", key, ids) + + def remove_additional_models(self, key: str) -> None: + self._call_rpc("remove_additional_models", key) + + def get_nested_additional_models(self) -> Any: + return self._call_rpc("get_nested_additional_models") + + def get_additional_models(self) -> List[ModelPatcherProxy]: + ids = self._call_rpc("get_additional_models") + return [self._spawn_related_proxy(mid) for mid in ids] + + def model_patches_models(self) -> Any: + return self._call_rpc("model_patches_models") + + @property + def parent(self) -> Any: + return self._call_rpc("get_parent") + + def model_mmap_residency(self, free: bool = False) -> tuple: + result = self._call_rpc("model_mmap_residency", free) + if isinstance(result, list): + return tuple(result) + return result + + def pinned_memory_size(self) -> int: + return self._call_rpc("pinned_memory_size") + + def get_non_dynamic_delegate(self) -> ModelPatcherProxy: + new_id = self._call_rpc("get_non_dynamic_delegate") + return self._spawn_related_proxy(new_id) + + def disable_model_cfg1_optimization(self) -> None: + self._call_rpc("disable_model_cfg1_optimization") + + def set_model_noise_refiner_patch(self, patch: Any) -> None: + self.set_model_patch(patch, "noise_refiner") + + +class _InnerModelProxy: + def __init__(self, parent: ModelPatcherProxy): + self._parent = parent + self._model_sampling = None + + def __getattr__(self, name: str) -> Any: + if name.startswith("_"): + raise AttributeError(name) + if name == "model_config": + from types import SimpleNamespace + + data = self._parent._call_rpc("get_inner_model_attr", name) + if isinstance(data, dict): + return SimpleNamespace(**data) + return data + if name in ( + "latent_format", + "model_type", + "current_weight_patches_uuid", + ): + return self._parent._call_rpc("get_inner_model_attr", name) + if name == "load_device": + return self._parent._call_rpc("get_inner_model_attr", "load_device") + if name == "device": + return self._parent._call_rpc("get_inner_model_attr", "device") + if name == "current_patcher": + proxy = ModelPatcherProxy( + self._parent._instance_id, + self._parent._registry, + manage_lifecycle=False, + ) + if getattr(self._parent, "_rpc_caller", None) is not None: + proxy._rpc_caller = self._parent._rpc_caller + return proxy + if name == "model_sampling": + if self._model_sampling is None: + self._model_sampling = self._parent._call_rpc( + "get_model_object", "model_sampling" + ) + return self._model_sampling + if name == "extra_conds_shapes": + return lambda *a, **k: self._parent._call_rpc( + "inner_model_extra_conds_shapes", a, k + ) + if name == "extra_conds": + return lambda *a, **k: self._parent._call_rpc( + "inner_model_extra_conds", a, k + ) + if name == "memory_required": + return lambda *a, **k: self._parent._call_rpc( + "inner_model_memory_required", a, k + ) + if name == "apply_model": + # Delegate to parent's method to get the CPU->CUDA optimization + return self._parent.apply_model + if name == "process_latent_in": + return lambda *a, **k: self._parent._call_rpc("process_latent_in", a, k) + if name == "process_latent_out": + return lambda *a, **k: self._parent._call_rpc("process_latent_out", a, k) + if name == "scale_latent_inpaint": + return lambda *a, **k: self._parent._call_rpc("scale_latent_inpaint", a, k) + if name == "diffusion_model": + return self._parent._call_rpc("get_inner_model_attr", "diffusion_model") + raise AttributeError(f"'{name}' not supported on isolated InnerModel") diff --git a/comfy/isolation/model_patcher_proxy_registry.py b/comfy/isolation/model_patcher_proxy_registry.py new file mode 100644 index 000000000..b657121eb --- /dev/null +++ b/comfy/isolation/model_patcher_proxy_registry.py @@ -0,0 +1,1311 @@ +# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,protected-access,unused-import +# RPC server for ModelPatcher isolation (child process) +from __future__ import annotations + +import asyncio +import gc +import logging +import threading +import time +from dataclasses import dataclass, field +from typing import Any, Optional, List + +try: + from comfy.model_patcher import AutoPatcherEjector +except ImportError: + + class AutoPatcherEjector: + def __init__(self, model, skip_and_inject_on_exit_only=False): + self.model = model + self.skip_and_inject_on_exit_only = skip_and_inject_on_exit_only + self.prev_skip_injection = False + self.was_injected = False + + def __enter__(self): + self.was_injected = False + self.prev_skip_injection = self.model.skip_injection + if self.skip_and_inject_on_exit_only: + self.model.skip_injection = True + if self.model.is_injected: + self.model.eject_model() + self.was_injected = True + + def __exit__(self, *args): + if self.skip_and_inject_on_exit_only: + self.model.skip_injection = self.prev_skip_injection + self.model.inject_model() + if self.was_injected and not self.model.skip_injection: + self.model.inject_model() + self.model.skip_injection = self.prev_skip_injection + + +from comfy.isolation.proxies.base import ( + BaseRegistry, + detach_if_grad, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class _OperationState: + lease: threading.Lock = field(default_factory=threading.Lock) + active_count: int = 0 + active_by_method: dict[str, int] = field(default_factory=dict) + total_operations: int = 0 + last_method: Optional[str] = None + last_started_ts: Optional[float] = None + last_ended_ts: Optional[float] = None + last_elapsed_ms: Optional[float] = None + last_error: Optional[str] = None + last_thread_id: Optional[int] = None + last_loop_id: Optional[int] = None + + +class ModelPatcherRegistry(BaseRegistry[Any]): + _type_prefix = "model" + + def __init__(self) -> None: + super().__init__() + self._pending_cleanup_ids: set[str] = set() + self._operation_states: dict[str, _OperationState] = {} + self._operation_state_cv = threading.Condition(self._lock) + + def _get_or_create_operation_state(self, instance_id: str) -> _OperationState: + state = self._operation_states.get(instance_id) + if state is None: + state = _OperationState() + self._operation_states[instance_id] = state + return state + + def _begin_operation(self, instance_id: str, method_name: str) -> tuple[float, float]: + start_epoch = time.time() + start_perf = time.perf_counter() + with self._operation_state_cv: + state = self._get_or_create_operation_state(instance_id) + state.active_count += 1 + state.active_by_method[method_name] = ( + state.active_by_method.get(method_name, 0) + 1 + ) + state.total_operations += 1 + state.last_method = method_name + state.last_started_ts = start_epoch + state.last_thread_id = threading.get_ident() + try: + state.last_loop_id = id(asyncio.get_running_loop()) + except RuntimeError: + state.last_loop_id = None + logger.debug( + "ISO:registry_op_start instance_id=%s method=%s start_ts=%.6f thread=%s loop=%s", + instance_id, + method_name, + start_epoch, + threading.get_ident(), + state.last_loop_id, + ) + return start_epoch, start_perf + + def _end_operation( + self, + instance_id: str, + method_name: str, + start_perf: float, + error: Optional[BaseException] = None, + ) -> None: + end_epoch = time.time() + elapsed_ms = (time.perf_counter() - start_perf) * 1000.0 + with self._operation_state_cv: + state = self._get_or_create_operation_state(instance_id) + state.active_count = max(0, state.active_count - 1) + if method_name in state.active_by_method: + remaining = state.active_by_method[method_name] - 1 + if remaining <= 0: + state.active_by_method.pop(method_name, None) + else: + state.active_by_method[method_name] = remaining + state.last_ended_ts = end_epoch + state.last_elapsed_ms = elapsed_ms + state.last_error = None if error is None else repr(error) + if state.active_count == 0: + self._operation_state_cv.notify_all() + logger.debug( + "ISO:registry_op_end instance_id=%s method=%s end_ts=%.6f elapsed_ms=%.3f error=%s", + instance_id, + method_name, + end_epoch, + elapsed_ms, + None if error is None else type(error).__name__, + ) + + def _run_operation_with_lease(self, instance_id: str, method_name: str, fn): + with self._operation_state_cv: + state = self._get_or_create_operation_state(instance_id) + lease = state.lease + with lease: + _, start_perf = self._begin_operation(instance_id, method_name) + try: + result = fn() + except Exception as exc: + self._end_operation(instance_id, method_name, start_perf, error=exc) + raise + self._end_operation(instance_id, method_name, start_perf) + return result + + def _snapshot_operation_state(self, instance_id: str) -> dict[str, Any]: + with self._operation_state_cv: + state = self._operation_states.get(instance_id) + if state is None: + return { + "instance_id": instance_id, + "active_count": 0, + "active_methods": [], + "total_operations": 0, + "last_method": None, + "last_started_ts": None, + "last_ended_ts": None, + "last_elapsed_ms": None, + "last_error": None, + "last_thread_id": None, + "last_loop_id": None, + } + return { + "instance_id": instance_id, + "active_count": state.active_count, + "active_methods": sorted(state.active_by_method.keys()), + "total_operations": state.total_operations, + "last_method": state.last_method, + "last_started_ts": state.last_started_ts, + "last_ended_ts": state.last_ended_ts, + "last_elapsed_ms": state.last_elapsed_ms, + "last_error": state.last_error, + "last_thread_id": state.last_thread_id, + "last_loop_id": state.last_loop_id, + } + + def unregister_sync(self, instance_id: str) -> None: + with self._operation_state_cv: + instance = self._registry.pop(instance_id, None) + if instance is not None: + self._id_map.pop(id(instance), None) + self._pending_cleanup_ids.discard(instance_id) + self._operation_states.pop(instance_id, None) + self._operation_state_cv.notify_all() + + async def get_operation_state(self, instance_id: str) -> dict[str, Any]: + return self._snapshot_operation_state(instance_id) + + async def get_all_operation_states(self) -> dict[str, dict[str, Any]]: + with self._operation_state_cv: + ids = sorted(self._operation_states.keys()) + return {instance_id: self._snapshot_operation_state(instance_id) for instance_id in ids} + + async def wait_for_idle(self, instance_id: str, timeout_ms: int = 0) -> bool: + timeout_s = None if timeout_ms <= 0 else (timeout_ms / 1000.0) + deadline = None if timeout_s is None else (time.monotonic() + timeout_s) + with self._operation_state_cv: + while True: + active = self._operation_states.get(instance_id) + if active is None or active.active_count == 0: + return True + if deadline is None: + self._operation_state_cv.wait() + continue + remaining = deadline - time.monotonic() + if remaining <= 0: + return False + self._operation_state_cv.wait(timeout=remaining) + + async def wait_all_idle(self, timeout_ms: int = 0) -> bool: + timeout_s = None if timeout_ms <= 0 else (timeout_ms / 1000.0) + deadline = None if timeout_s is None else (time.monotonic() + timeout_s) + with self._operation_state_cv: + while True: + has_active = any( + state.active_count > 0 for state in self._operation_states.values() + ) + if not has_active: + return True + if deadline is None: + self._operation_state_cv.wait() + continue + remaining = deadline - time.monotonic() + if remaining <= 0: + return False + self._operation_state_cv.wait(timeout=remaining) + + async def clone(self, instance_id: str) -> str: + instance = self._get_instance(instance_id) + new_model = instance.clone() + return self.register(new_model) + + async def is_clone(self, instance_id: str, other: Any) -> bool: + instance = self._get_instance(instance_id) + if hasattr(other, "model"): + return instance.is_clone(other) + return False + + async def get_model_object(self, instance_id: str, name: str) -> Any: + instance = self._get_instance(instance_id) + if name == "model": + return f"" + result = instance.get_model_object(name) + if name == "model_sampling": + # Return inline serialization so the child reconstructs the real + # class with correct isinstance behavior. Returning a + # ModelSamplingProxy breaks isinstance checks (e.g. + # offset_first_sigma_for_snr in k_diffusion/sampling.py:173). + return self._serialize_model_sampling_inline(result) + + return detach_if_grad(result) + + @staticmethod + def _serialize_model_sampling_inline(obj: Any) -> dict: + """Serialize a ModelSampling object as inline data for the child to reconstruct.""" + import torch + import base64 + import io as _io + + bases = [] + for base in type(obj).__mro__: + if base.__module__ == "comfy.model_sampling" and base.__name__ != "object": + bases.append(base.__name__) + + sd = obj.state_dict() + sd_serialized = {} + for k, v in sd.items(): + buf = _io.BytesIO() + torch.save(v, buf) + sd_serialized[k] = base64.b64encode(buf.getvalue()).decode("ascii") + + plain_attrs = {} + for k, v in obj.__dict__.items(): + if k.startswith("_"): + continue + if isinstance(v, (bool, int, float, str)): + plain_attrs[k] = v + + return { + "__type__": "ModelSamplingInline", + "bases": bases, + "state_dict": sd_serialized, + "attrs": plain_attrs, + } + + async def get_model_options(self, instance_id: str) -> dict: + instance = self._get_instance(instance_id) + import copy + + opts = copy.deepcopy(instance.model_options) + return self._sanitize_rpc_result(opts) + + async def set_model_options(self, instance_id: str, options: dict) -> None: + self._get_instance(instance_id).model_options = options + + async def get_patcher_attr(self, instance_id: str, name: str) -> Any: + return self._sanitize_rpc_result( + getattr(self._get_instance(instance_id), name, None) + ) + + async def model_state_dict(self, instance_id: str, filter_prefix=None) -> Any: + instance = self._get_instance(instance_id) + sd_keys = instance.model.state_dict().keys() + return dict.fromkeys(sd_keys, None) + + def _sanitize_rpc_result(self, obj, seen=None): + if seen is None: + seen = set() + if obj is None: + return None + if isinstance(obj, (bool, int, float, str)): + if isinstance(obj, str) and len(obj) > 500000: + return f"" + return obj + obj_id = id(obj) + if obj_id in seen: + return None + seen.add(obj_id) + if isinstance(obj, (list, tuple)): + return [self._sanitize_rpc_result(x, seen) for x in obj] + if isinstance(obj, set): + return [self._sanitize_rpc_result(x, seen) for x in obj] + if isinstance(obj, dict): + new_dict = {} + for k, v in obj.items(): + if isinstance(k, tuple): + import json + + try: + key_str = "__pyisolate_key__" + json.dumps(list(k)) + new_dict[key_str] = self._sanitize_rpc_result(v, seen) + except Exception: + new_dict[str(k)] = self._sanitize_rpc_result(v, seen) + else: + new_dict[str(k)] = self._sanitize_rpc_result(v, seen) + return new_dict + if ( + hasattr(obj, "__dict__") + and not hasattr(obj, "__get__") + and not hasattr(obj, "__call__") + ): + return self._sanitize_rpc_result(obj.__dict__, seen) + if hasattr(obj, "items") and hasattr(obj, "get"): + return {str(k): self._sanitize_rpc_result(v, seen) for k, v in obj.items()} + return None + + async def get_load_device(self, instance_id: str) -> Any: + return self._get_instance(instance_id).load_device + + async def get_offload_device(self, instance_id: str) -> Any: + return self._get_instance(instance_id).offload_device + + async def current_loaded_device(self, instance_id: str) -> Any: + return self._get_instance(instance_id).current_loaded_device() + + async def get_size(self, instance_id: str) -> int: + return self._get_instance(instance_id).size + + async def model_size(self, instance_id: str) -> Any: + return self._get_instance(instance_id).model_size() + + async def loaded_size(self, instance_id: str) -> Any: + return self._get_instance(instance_id).loaded_size() + + async def get_ram_usage(self, instance_id: str) -> int: + return self._get_instance(instance_id).get_ram_usage() + + async def model_mmap_residency(self, instance_id: str, free: bool = False) -> tuple: + return self._get_instance(instance_id).model_mmap_residency(free=free) + + async def pinned_memory_size(self, instance_id: str) -> int: + return self._get_instance(instance_id).pinned_memory_size() + + async def get_non_dynamic_delegate(self, instance_id: str) -> str: + instance = self._get_instance(instance_id) + delegate = instance.get_non_dynamic_delegate() + return self.register(delegate) + + async def disable_model_cfg1_optimization(self, instance_id: str) -> None: + self._get_instance(instance_id).disable_model_cfg1_optimization() + + async def lowvram_patch_counter(self, instance_id: str) -> int: + return self._get_instance(instance_id).lowvram_patch_counter() + + async def memory_required(self, instance_id: str, input_shape: Any) -> Any: + return self._run_operation_with_lease( + instance_id, + "memory_required", + lambda: self._get_instance(instance_id).memory_required(input_shape), + ) + + async def is_dynamic(self, instance_id: str) -> bool: + instance = self._get_instance(instance_id) + if hasattr(instance, "is_dynamic"): + return bool(instance.is_dynamic()) + return False + + async def get_free_memory(self, instance_id: str, device: Any) -> Any: + instance = self._get_instance(instance_id) + if hasattr(instance, "get_free_memory"): + return instance.get_free_memory(device) + import comfy.model_management + + return comfy.model_management.get_free_memory(device) + + async def partially_unload_ram(self, instance_id: str, ram_to_unload: int) -> Any: + instance = self._get_instance(instance_id) + if hasattr(instance, "partially_unload_ram"): + return instance.partially_unload_ram(ram_to_unload) + return None + + async def model_dtype(self, instance_id: str) -> Any: + return self._run_operation_with_lease( + instance_id, + "model_dtype", + lambda: self._get_instance(instance_id).model_dtype(), + ) + + async def model_patches_to(self, instance_id: str, device: Any) -> Any: + return self._get_instance(instance_id).model_patches_to(device) + + async def partially_load( + self, + instance_id: str, + device: Any, + extra_memory: Any, + force_patch_weights: bool = False, + ) -> Any: + return self._run_operation_with_lease( + instance_id, + "partially_load", + lambda: self._get_instance(instance_id).partially_load( + device, extra_memory, force_patch_weights=force_patch_weights + ), + ) + + async def partially_unload( + self, + instance_id: str, + device_to: Any, + memory_to_free: int = 0, + force_patch_weights: bool = False, + ) -> int: + return self._run_operation_with_lease( + instance_id, + "partially_unload", + lambda: self._get_instance(instance_id).partially_unload( + device_to, memory_to_free, force_patch_weights + ), + ) + + async def load( + self, + instance_id: str, + device_to: Any = None, + lowvram_model_memory: int = 0, + force_patch_weights: bool = False, + full_load: bool = False, + ) -> None: + self._run_operation_with_lease( + instance_id, + "load", + lambda: self._get_instance(instance_id).load( + device_to, lowvram_model_memory, force_patch_weights, full_load + ), + ) + + async def patch_model( + self, + instance_id: str, + device_to: Any = None, + lowvram_model_memory: int = 0, + load_weights: bool = True, + force_patch_weights: bool = False, + ) -> None: + def _invoke() -> None: + try: + self._get_instance(instance_id).patch_model( + device_to, lowvram_model_memory, load_weights, force_patch_weights + ) + except AttributeError as e: + logger.error( + f"Isolation Error: Failed to patch model attribute: {e}. Skipping." + ) + return + + self._run_operation_with_lease(instance_id, "patch_model", _invoke) + + async def unpatch_model( + self, instance_id: str, device_to: Any = None, unpatch_weights: bool = True + ) -> None: + self._run_operation_with_lease( + instance_id, + "unpatch_model", + lambda: self._get_instance(instance_id).unpatch_model( + device_to, unpatch_weights + ), + ) + + async def detach(self, instance_id: str, unpatch_all: bool = True) -> None: + self._get_instance(instance_id).detach(unpatch_all) + + async def prepare_state(self, instance_id: str, timestep: Any) -> Any: + instance = self._get_instance(instance_id) + cp = getattr(instance.model, "current_patcher", instance) + if cp is None: + cp = instance + return cp.prepare_state(timestep) + + async def pre_run(self, instance_id: str) -> None: + self._get_instance(instance_id).pre_run() + + async def cleanup(self, instance_id: str) -> None: + def _invoke() -> None: + try: + instance = self._get_instance(instance_id) + except Exception: + logger.debug( + "ModelPatcher cleanup requested for missing instance %s", + instance_id, + exc_info=True, + ) + return + + try: + instance.cleanup() + finally: + with self._lock: + self._pending_cleanup_ids.add(instance_id) + gc.collect() + + self._run_operation_with_lease(instance_id, "cleanup", _invoke) + + def sweep_pending_cleanup(self) -> int: + removed = 0 + with self._operation_state_cv: + pending_ids = list(self._pending_cleanup_ids) + self._pending_cleanup_ids.clear() + for instance_id in pending_ids: + instance = self._registry.pop(instance_id, None) + if instance is None: + continue + self._id_map.pop(id(instance), None) + self._operation_states.pop(instance_id, None) + removed += 1 + self._operation_state_cv.notify_all() + + gc.collect() + return removed + + def purge_all(self) -> int: + with self._operation_state_cv: + removed = len(self._registry) + self._registry.clear() + self._id_map.clear() + self._pending_cleanup_ids.clear() + self._operation_states.clear() + self._operation_state_cv.notify_all() + gc.collect() + return removed + + async def apply_hooks(self, instance_id: str, hooks: Any) -> Any: + instance = self._get_instance(instance_id) + cp = getattr(instance.model, "current_patcher", instance) + if cp is None: + cp = instance + return cp.apply_hooks(hooks=hooks) + + async def clean_hooks(self, instance_id: str) -> None: + self._get_instance(instance_id).clean_hooks() + + async def restore_hook_patches(self, instance_id: str) -> None: + self._get_instance(instance_id).restore_hook_patches() + + async def unpatch_hooks( + self, instance_id: str, whitelist_keys_set: Optional[set] = None + ) -> None: + self._get_instance(instance_id).unpatch_hooks(whitelist_keys_set) + + async def register_all_hook_patches( + self, + instance_id: str, + hooks: Any, + target_dict: Any, + model_options: Any, + registered: Any, + ) -> None: + from types import SimpleNamespace + import comfy.hooks + + instance = self._get_instance(instance_id) + if isinstance(hooks, SimpleNamespace) or hasattr(hooks, "__dict__"): + hook_data = hooks.__dict__ if hasattr(hooks, "__dict__") else hooks + new_hooks = comfy.hooks.HookGroup() + if hasattr(hook_data, "hooks"): + new_hooks.hooks = ( + hook_data["hooks"] + if isinstance(hook_data, dict) + else hook_data.hooks + ) + hooks = new_hooks + instance.register_all_hook_patches( + hooks, target_dict, model_options, registered + ) + + async def get_hook_mode(self, instance_id: str) -> Any: + return getattr(self._get_instance(instance_id), "hook_mode", None) + + async def set_hook_mode(self, instance_id: str, value: Any) -> None: + setattr(self._get_instance(instance_id), "hook_mode", value) + + async def inject_model(self, instance_id: str) -> None: + instance = self._get_instance(instance_id) + try: + instance.inject_model() + except AttributeError as e: + if "inject" in str(e): + logger.error( + "Isolation Error: Injector object lost method code during serialization. Cannot inject. Skipping." + ) + return + raise e + + async def eject_model(self, instance_id: str) -> None: + self._get_instance(instance_id).eject_model() + + async def get_is_injected(self, instance_id: str) -> bool: + return self._get_instance(instance_id).is_injected + + async def set_skip_injection(self, instance_id: str, value: bool) -> None: + self._get_instance(instance_id).skip_injection = value + + async def get_skip_injection(self, instance_id: str) -> bool: + return self._get_instance(instance_id).skip_injection + + async def set_model_sampler_cfg_function( + self, + instance_id: str, + sampler_cfg_function: Any, + disable_cfg1_optimization: bool = False, + ) -> None: + if not callable(sampler_cfg_function): + logger.error( + f"set_model_sampler_cfg_function: Expected callable, got {type(sampler_cfg_function)}. Skipping." + ) + return + self._get_instance(instance_id).set_model_sampler_cfg_function( + sampler_cfg_function, disable_cfg1_optimization + ) + + async def set_model_sampler_post_cfg_function( + self, + instance_id: str, + post_cfg_function: Any, + disable_cfg1_optimization: bool = False, + ) -> None: + self._get_instance(instance_id).set_model_sampler_post_cfg_function( + post_cfg_function, disable_cfg1_optimization + ) + + async def set_model_sampler_pre_cfg_function( + self, + instance_id: str, + pre_cfg_function: Any, + disable_cfg1_optimization: bool = False, + ) -> None: + self._get_instance(instance_id).set_model_sampler_pre_cfg_function( + pre_cfg_function, disable_cfg1_optimization + ) + + async def set_model_sampler_calc_cond_batch_function( + self, instance_id: str, fn: Any + ) -> None: + self._get_instance(instance_id).set_model_sampler_calc_cond_batch_function(fn) + + async def set_model_unet_function_wrapper( + self, instance_id: str, unet_wrapper_function: Any + ) -> None: + self._get_instance(instance_id).set_model_unet_function_wrapper( + unet_wrapper_function + ) + + async def set_model_denoise_mask_function( + self, instance_id: str, denoise_mask_function: Any + ) -> None: + self._get_instance(instance_id).set_model_denoise_mask_function( + denoise_mask_function + ) + + async def set_model_patch(self, instance_id: str, patch: Any, name: str) -> None: + self._get_instance(instance_id).set_model_patch(patch, name) + + async def set_model_patch_replace( + self, + instance_id: str, + patch: Any, + name: str, + block_name: str, + number: int, + transformer_index: Optional[int] = None, + ) -> None: + self._get_instance(instance_id).set_model_patch_replace( + patch, name, block_name, number, transformer_index + ) + + async def set_model_input_block_patch(self, instance_id: str, patch: Any) -> None: + self._get_instance(instance_id).set_model_input_block_patch(patch) + + async def set_model_input_block_patch_after_skip( + self, instance_id: str, patch: Any + ) -> None: + self._get_instance(instance_id).set_model_input_block_patch_after_skip(patch) + + async def set_model_output_block_patch(self, instance_id: str, patch: Any) -> None: + self._get_instance(instance_id).set_model_output_block_patch(patch) + + async def set_model_emb_patch(self, instance_id: str, patch: Any) -> None: + self._get_instance(instance_id).set_model_emb_patch(patch) + + async def set_model_forward_timestep_embed_patch( + self, instance_id: str, patch: Any + ) -> None: + self._get_instance(instance_id).set_model_forward_timestep_embed_patch(patch) + + async def set_model_double_block_patch(self, instance_id: str, patch: Any) -> None: + self._get_instance(instance_id).set_model_double_block_patch(patch) + + async def set_model_post_input_patch(self, instance_id: str, patch: Any) -> None: + self._get_instance(instance_id).set_model_post_input_patch(patch) + + async def set_model_rope_options(self, instance_id: str, options: dict) -> None: + self._get_instance(instance_id).set_model_rope_options(**options) + + async def set_model_compute_dtype(self, instance_id: str, dtype: Any) -> None: + self._get_instance(instance_id).set_model_compute_dtype(dtype) + + async def clone_has_same_weights_by_id( + self, instance_id: str, other_id: str + ) -> bool: + instance = self._get_instance(instance_id) + other = self._get_instance(other_id) + if not other: + return False + return instance.clone_has_same_weights(other) + + async def load_list_internal(self, instance_id: str, *args, **kwargs) -> Any: + return self._get_instance(instance_id)._load_list(*args, **kwargs) + + async def is_clone_by_id(self, instance_id: str, other_id: str) -> bool: + instance = self._get_instance(instance_id) + other = self._get_instance(other_id) + if hasattr(instance, "is_clone"): + return instance.is_clone(other) + return False + + async def add_object_patch(self, instance_id: str, name: str, obj: Any) -> None: + self._get_instance(instance_id).add_object_patch(name, obj) + + async def add_weight_wrapper( + self, instance_id: str, name: str, function: Any + ) -> None: + self._get_instance(instance_id).add_weight_wrapper(name, function) + + async def add_wrapper_with_key( + self, instance_id: str, wrapper_type: Any, key: str, fn: Any + ) -> None: + self._get_instance(instance_id).add_wrapper_with_key(wrapper_type, key, fn) + + async def remove_wrappers_with_key( + self, instance_id: str, wrapper_type: str, key: str + ) -> None: + self._get_instance(instance_id).remove_wrappers_with_key(wrapper_type, key) + + async def get_wrappers( + self, instance_id: str, wrapper_type: str = None, key: str = None + ) -> Any: + if wrapper_type is None and key is None: + return self._sanitize_rpc_result( + getattr(self._get_instance(instance_id), "wrappers", {}) + ) + return self._sanitize_rpc_result( + self._get_instance(instance_id).get_wrappers(wrapper_type, key) + ) + + async def get_all_wrappers(self, instance_id: str, wrapper_type: str = None) -> Any: + return self._sanitize_rpc_result( + getattr(self._get_instance(instance_id), "get_all_wrappers", lambda x: [])( + wrapper_type + ) + ) + + async def add_callback_with_key( + self, instance_id: str, call_type: str, key: str, callback: Any + ) -> None: + self._get_instance(instance_id).add_callback_with_key(call_type, key, callback) + + async def remove_callbacks_with_key( + self, instance_id: str, call_type: str, key: str + ) -> None: + self._get_instance(instance_id).remove_callbacks_with_key(call_type, key) + + async def get_callbacks( + self, instance_id: str, call_type: str = None, key: str = None + ) -> Any: + if call_type is None and key is None: + return self._sanitize_rpc_result( + getattr(self._get_instance(instance_id), "callbacks", {}) + ) + return self._sanitize_rpc_result( + self._get_instance(instance_id).get_callbacks(call_type, key) + ) + + async def get_all_callbacks(self, instance_id: str, call_type: str = None) -> Any: + return self._sanitize_rpc_result( + getattr(self._get_instance(instance_id), "get_all_callbacks", lambda x: [])( + call_type + ) + ) + + async def set_attachments( + self, instance_id: str, key: str, attachment: Any + ) -> None: + self._get_instance(instance_id).set_attachments(key, attachment) + + async def get_attachment(self, instance_id: str, key: str) -> Any: + return self._sanitize_rpc_result( + self._get_instance(instance_id).get_attachment(key) + ) + + async def remove_attachments(self, instance_id: str, key: str) -> None: + self._get_instance(instance_id).remove_attachments(key) + + async def set_injections(self, instance_id: str, key: str, injections: Any) -> None: + self._get_instance(instance_id).set_injections(key, injections) + + async def get_injections(self, instance_id: str, key: str) -> Any: + return self._sanitize_rpc_result( + self._get_instance(instance_id).get_injections(key) + ) + + async def remove_injections(self, instance_id: str, key: str) -> None: + self._get_instance(instance_id).remove_injections(key) + + async def set_additional_models( + self, instance_id: str, key: str, models: Any + ) -> None: + self._get_instance(instance_id).set_additional_models(key, models) + + async def remove_additional_models(self, instance_id: str, key: str) -> None: + self._get_instance(instance_id).remove_additional_models(key) + + async def get_nested_additional_models(self, instance_id: str) -> Any: + return self._sanitize_rpc_result( + self._get_instance(instance_id).get_nested_additional_models() + ) + + async def get_additional_models(self, instance_id: str) -> List[str]: + models = self._get_instance(instance_id).get_additional_models() + return [self.register(m) for m in models] + + async def get_additional_models_with_key(self, instance_id: str, key: str) -> Any: + return self._sanitize_rpc_result( + self._get_instance(instance_id).get_additional_models_with_key(key) + ) + + async def model_patches_models(self, instance_id: str) -> Any: + return self._sanitize_rpc_result( + self._get_instance(instance_id).model_patches_models() + ) + + async def get_patches(self, instance_id: str) -> Any: + return self._sanitize_rpc_result(self._get_instance(instance_id).patches.copy()) + + async def get_object_patches(self, instance_id: str) -> Any: + return self._sanitize_rpc_result( + self._get_instance(instance_id).object_patches.copy() + ) + + async def add_patches( + self, + instance_id: str, + patches: Any, + strength_patch: float = 1.0, + strength_model: float = 1.0, + ) -> Any: + return self._get_instance(instance_id).add_patches( + patches, strength_patch, strength_model + ) + + async def get_key_patches( + self, instance_id: str, filter_prefix: Optional[str] = None + ) -> Any: + res = self._get_instance(instance_id).get_key_patches() + if filter_prefix: + res = {k: v for k, v in res.items() if k.startswith(filter_prefix)} + safe_res = {} + for k, v in res.items(): + safe_res[k] = [ + f"" + if hasattr(t, "shape") + else str(t) + for t in v + ] + return safe_res + + async def add_hook_patches( + self, + instance_id: str, + hook: Any, + patches: Any, + strength_patch: float = 1.0, + strength_model: float = 1.0, + ) -> None: + if hasattr(hook, "hook_ref") and isinstance(hook.hook_ref, dict): + try: + hook.hook_ref = tuple(sorted(hook.hook_ref.items())) + except Exception: + hook.hook_ref = None + self._get_instance(instance_id).add_hook_patches( + hook, patches, strength_patch, strength_model + ) + + async def get_combined_hook_patches(self, instance_id: str, hooks: Any) -> Any: + if hooks is not None and hasattr(hooks, "hooks"): + for hook in getattr(hooks, "hooks", []): + hook_ref = getattr(hook, "hook_ref", None) + if isinstance(hook_ref, dict): + try: + hook.hook_ref = tuple(sorted(hook_ref.items())) + except Exception: + hook.hook_ref = None + res = self._get_instance(instance_id).get_combined_hook_patches(hooks) + return self._sanitize_rpc_result(res) + + async def clear_cached_hook_weights(self, instance_id: str) -> None: + self._get_instance(instance_id).clear_cached_hook_weights() + + async def prepare_hook_patches_current_keyframe( + self, instance_id: str, t: Any, hook_group: Any, model_options: Any + ) -> None: + self._get_instance(instance_id).prepare_hook_patches_current_keyframe( + t, hook_group, model_options + ) + + async def get_parent(self, instance_id: str) -> Any: + return getattr(self._get_instance(instance_id), "parent", None) + + async def patch_weight_to_device( + self, + instance_id: str, + key: str, + device_to: Any = None, + inplace_update: bool = False, + ) -> None: + self._get_instance(instance_id).patch_weight_to_device( + key, device_to, inplace_update + ) + + async def pin_weight_to_device(self, instance_id: str, key: str) -> None: + instance = self._get_instance(instance_id) + if hasattr(instance, "pinned") and isinstance(instance.pinned, list): + instance.pinned = set(instance.pinned) + instance.pin_weight_to_device(key) + + async def unpin_weight(self, instance_id: str, key: str) -> None: + instance = self._get_instance(instance_id) + if hasattr(instance, "pinned") and isinstance(instance.pinned, list): + instance.pinned = set(instance.pinned) + instance.unpin_weight(key) + + async def unpin_all_weights(self, instance_id: str) -> None: + instance = self._get_instance(instance_id) + if hasattr(instance, "pinned") and isinstance(instance.pinned, list): + instance.pinned = set(instance.pinned) + instance.unpin_all_weights() + + async def calculate_weight( + self, + instance_id: str, + patches: Any, + weight: Any, + key: str, + intermediate_dtype: Any = float, + ) -> Any: + return detach_if_grad( + self._get_instance(instance_id).calculate_weight( + patches, weight, key, intermediate_dtype + ) + ) + + async def get_inner_model_attr(self, instance_id: str, name: str) -> Any: + try: + value = getattr(self._get_instance(instance_id).model, name) + if name == "model_config": + value = self._extract_model_config(value) + return self._sanitize_rpc_result(value) + except AttributeError: + return None + + @staticmethod + def _extract_model_config(config: Any) -> dict: + """Extract JSON-safe attributes from a model config object. + + ComfyUI model config classes (supported_models_base.BASE subclasses) + have a permissive __getattr__ that returns None for any unknown + attribute instead of raising AttributeError. This defeats hasattr-based + duck-typing in _sanitize_rpc_result, causing TypeError when it tries + to call obj.items() (which resolves to None). We extract the real + class-level and instance-level attributes into a plain dict. + """ + # Attributes consumed by ModelSampling*.__init__ and other callers + _CONFIG_KEYS = ( + "sampling_settings", + "unet_config", + "unet_extra_config", + "latent_format", + "manual_cast_dtype", + "custom_operations", + "optimizations", + "memory_usage_factor", + "supported_inference_dtypes", + ) + result: dict = {} + for key in _CONFIG_KEYS: + # Use type(config).__dict__ first (class attrs), then instance __dict__ + # to avoid triggering the permissive __getattr__ + if key in type(config).__dict__: + val = type(config).__dict__[key] + # Skip classmethods/staticmethods/descriptors + if not callable(val) or isinstance(val, (dict, list, tuple)): + result[key] = val + elif hasattr(config, "__dict__") and key in config.__dict__: + result[key] = config.__dict__[key] + # Also include instance overrides (e.g. set_inference_dtype sets unet_config['dtype']) + if hasattr(config, "__dict__"): + for key, val in config.__dict__.items(): + if key in _CONFIG_KEYS: + result[key] = val + return result + + async def inner_model_memory_required( + self, instance_id: str, args: tuple, kwargs: dict + ) -> Any: + return self._run_operation_with_lease( + instance_id, + "inner_model_memory_required", + lambda: self._get_instance(instance_id).model.memory_required( + *args, **kwargs + ), + ) + + async def inner_model_extra_conds_shapes( + self, instance_id: str, args: tuple, kwargs: dict + ) -> Any: + return self._run_operation_with_lease( + instance_id, + "inner_model_extra_conds_shapes", + lambda: self._get_instance(instance_id).model.extra_conds_shapes( + *args, **kwargs + ), + ) + + async def inner_model_extra_conds( + self, instance_id: str, args: tuple, kwargs: dict + ) -> Any: + def _invoke() -> Any: + result = self._get_instance(instance_id).model.extra_conds(*args, **kwargs) + try: + import torch + import comfy.conds + except Exception: + return result + + def _to_cpu(obj: Any) -> Any: + if torch.is_tensor(obj): + return obj.detach().cpu() if obj.device.type != "cpu" else obj + if isinstance(obj, dict): + return {k: _to_cpu(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_to_cpu(v) for v in obj] + if isinstance(obj, tuple): + return tuple(_to_cpu(v) for v in obj) + if isinstance(obj, comfy.conds.CONDRegular): + return type(obj)(_to_cpu(obj.cond)) + return obj + + return _to_cpu(result) + + return self._run_operation_with_lease(instance_id, "inner_model_extra_conds", _invoke) + + async def inner_model_state_dict( + self, instance_id: str, args: tuple, kwargs: dict + ) -> Any: + sd = self._get_instance(instance_id).model.state_dict(*args, **kwargs) + return { + k: {"numel": v.numel(), "element_size": v.element_size()} + for k, v in sd.items() + } + + async def inner_model_apply_model( + self, instance_id: str, args: tuple, kwargs: dict + ) -> Any: + def _invoke() -> Any: + import torch + + instance = self._get_instance(instance_id) + target = getattr(instance, "load_device", None) + if target is None and args and hasattr(args[0], "device"): + target = args[0].device + elif target is None: + for v in kwargs.values(): + if hasattr(v, "device"): + target = v.device + break + + def _move(obj): + if target is None: + return obj + if isinstance(obj, (tuple, list)): + return type(obj)(_move(o) for o in obj) + if hasattr(obj, "to"): + return obj.to(target) + return obj + + moved_args = tuple(_move(a) for a in args) + moved_kwargs = {k: _move(v) for k, v in kwargs.items()} + result = instance.model.apply_model(*moved_args, **moved_kwargs) + moved_result = detach_if_grad(_move(result)) + + # DynamicVRAM + isolation: returning CUDA tensors across RPC can stall + # at the transport boundary. Marshal dynamic-path results as CPU and let + # the proxy restore device placement in the child process. + is_dynamic_fn = getattr(instance, "is_dynamic", None) + if callable(is_dynamic_fn) and is_dynamic_fn(): + def _to_cpu(obj: Any) -> Any: + if torch.is_tensor(obj): + return obj.detach().cpu() if obj.device.type != "cpu" else obj + if isinstance(obj, dict): + return {k: _to_cpu(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_to_cpu(v) for v in obj] + if isinstance(obj, tuple): + return tuple(_to_cpu(v) for v in obj) + return obj + + return _to_cpu(moved_result) + return moved_result + + return self._run_operation_with_lease(instance_id, "inner_model_apply_model", _invoke) + + async def process_latent_in( + self, instance_id: str, args: tuple, kwargs: dict + ) -> Any: + import torch + + def _invoke() -> Any: + instance = self._get_instance(instance_id) + result = detach_if_grad(instance.model.process_latent_in(*args, **kwargs)) + + # DynamicVRAM + isolation: returning CUDA tensors across RPC can stall + # at the transport boundary. Marshal dynamic-path results as CPU and let + # the proxy restore placement when needed. + is_dynamic_fn = getattr(instance, "is_dynamic", None) + if callable(is_dynamic_fn) and is_dynamic_fn(): + def _to_cpu(obj: Any) -> Any: + if torch.is_tensor(obj): + return obj.detach().cpu() if obj.device.type != "cpu" else obj + if isinstance(obj, dict): + return {k: _to_cpu(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_to_cpu(v) for v in obj] + if isinstance(obj, tuple): + return tuple(_to_cpu(v) for v in obj) + return obj + + return _to_cpu(result) + return result + + return self._run_operation_with_lease(instance_id, "process_latent_in", _invoke) + + async def process_latent_out( + self, instance_id: str, args: tuple, kwargs: dict + ) -> Any: + import torch + + def _invoke() -> Any: + instance = self._get_instance(instance_id) + result = instance.model.process_latent_out(*args, **kwargs) + moved_result = None + try: + target = None + if args and hasattr(args[0], "device"): + target = args[0].device + elif kwargs: + for v in kwargs.values(): + if hasattr(v, "device"): + target = v.device + break + if target is not None and hasattr(result, "to"): + moved_result = detach_if_grad(result.to(target)) + except Exception: + logger.debug( + "process_latent_out: failed to move result to target device", + exc_info=True, + ) + if moved_result is None: + moved_result = detach_if_grad(result) + + is_dynamic_fn = getattr(instance, "is_dynamic", None) + if callable(is_dynamic_fn) and is_dynamic_fn(): + def _to_cpu(obj: Any) -> Any: + if torch.is_tensor(obj): + return obj.detach().cpu() if obj.device.type != "cpu" else obj + if isinstance(obj, dict): + return {k: _to_cpu(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_to_cpu(v) for v in obj] + if isinstance(obj, tuple): + return tuple(_to_cpu(v) for v in obj) + return obj + + return _to_cpu(moved_result) + return moved_result + + return self._run_operation_with_lease(instance_id, "process_latent_out", _invoke) + + async def scale_latent_inpaint( + self, instance_id: str, args: tuple, kwargs: dict + ) -> Any: + import torch + + def _invoke() -> Any: + instance = self._get_instance(instance_id) + result = instance.model.scale_latent_inpaint(*args, **kwargs) + moved_result = None + try: + target = None + if args and hasattr(args[0], "device"): + target = args[0].device + elif kwargs: + for v in kwargs.values(): + if hasattr(v, "device"): + target = v.device + break + if target is not None and hasattr(result, "to"): + moved_result = detach_if_grad(result.to(target)) + except Exception: + logger.debug( + "scale_latent_inpaint: failed to move result to target device", + exc_info=True, + ) + if moved_result is None: + moved_result = detach_if_grad(result) + + is_dynamic_fn = getattr(instance, "is_dynamic", None) + if callable(is_dynamic_fn) and is_dynamic_fn(): + def _to_cpu(obj: Any) -> Any: + if torch.is_tensor(obj): + return obj.detach().cpu() if obj.device.type != "cpu" else obj + if isinstance(obj, dict): + return {k: _to_cpu(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_to_cpu(v) for v in obj] + if isinstance(obj, tuple): + return tuple(_to_cpu(v) for v in obj) + return obj + + return _to_cpu(moved_result) + return moved_result + + return self._run_operation_with_lease( + instance_id, "scale_latent_inpaint", _invoke + ) + + async def load_lora( + self, + instance_id: str, + lora_path: str, + strength_model: float, + clip_id: Optional[str] = None, + strength_clip: float = 1.0, + ) -> dict: + import comfy.utils + import comfy.sd + import folder_paths + from comfy.isolation.clip_proxy import CLIPRegistry + + model = self._get_instance(instance_id) + clip = None + if clip_id: + clip = CLIPRegistry()._get_instance(clip_id) + lora_full_path = folder_paths.get_full_path("loras", lora_path) + if lora_full_path is None: + raise ValueError(f"LoRA file not found: {lora_path}") + lora = comfy.utils.load_torch_file(lora_full_path) + new_model, new_clip = comfy.sd.load_lora_for_models( + model, clip, lora, strength_model, strength_clip + ) + new_model_id = self.register(new_model) if new_model else None + new_clip_id = ( + CLIPRegistry().register(new_clip) if (new_clip and clip_id) else None + ) + return {"model_id": new_model_id, "clip_id": new_clip_id} diff --git a/comfy/isolation/model_patcher_proxy_utils.py b/comfy/isolation/model_patcher_proxy_utils.py new file mode 100644 index 000000000..038687f01 --- /dev/null +++ b/comfy/isolation/model_patcher_proxy_utils.py @@ -0,0 +1,156 @@ +# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,protected-access +# Isolation utilities and serializers for ModelPatcherProxy +from __future__ import annotations + +import logging +import os +from typing import Any + +from comfy.cli_args import args + +logger = logging.getLogger(__name__) + + +def maybe_wrap_model_for_isolation(model_patcher: Any) -> Any: + from comfy.isolation.model_patcher_proxy_registry import ModelPatcherRegistry + from comfy.isolation.model_patcher_proxy import ModelPatcherProxy + + is_child = os.environ.get("PYISOLATE_CHILD") == "1" + isolation_active = args.use_process_isolation or is_child + + if not isolation_active: + return model_patcher + if is_child: + return model_patcher + if isinstance(model_patcher, ModelPatcherProxy): + return model_patcher + + registry = ModelPatcherRegistry() + model_id = registry.register(model_patcher) + logger.debug(f"Isolated ModelPatcher: {model_id}") + return ModelPatcherProxy(model_id, registry, manage_lifecycle=True) + + +def register_hooks_serializers(registry=None): + from pyisolate._internal.serialization_registry import SerializerRegistry + import comfy.hooks + + if registry is None: + registry = SerializerRegistry.get_instance() + + def serialize_enum(obj): + return {"__enum__": f"{type(obj).__name__}.{obj.name}"} + + def deserialize_enum(data): + cls_name, val_name = data["__enum__"].split(".") + cls = getattr(comfy.hooks, cls_name) + return cls[val_name] + + registry.register("EnumHookType", serialize_enum, deserialize_enum) + registry.register("EnumHookScope", serialize_enum, deserialize_enum) + registry.register("EnumHookMode", serialize_enum, deserialize_enum) + registry.register("EnumWeightTarget", serialize_enum, deserialize_enum) + + def serialize_hook_group(obj): + return {"__type__": "HookGroup", "hooks": obj.hooks} + + def deserialize_hook_group(data): + hg = comfy.hooks.HookGroup() + for h in data["hooks"]: + hg.add(h) + return hg + + registry.register("HookGroup", serialize_hook_group, deserialize_hook_group) + + def serialize_dict_state(obj): + d = obj.__dict__.copy() + d["__type__"] = type(obj).__name__ + if "custom_should_register" in d: + del d["custom_should_register"] + return d + + def deserialize_dict_state_generic(cls): + def _deserialize(data): + h = cls() + h.__dict__.update(data) + return h + + return _deserialize + + def deserialize_hook_keyframe(data): + h = comfy.hooks.HookKeyframe(strength=data.get("strength", 1.0)) + h.__dict__.update(data) + return h + + registry.register("HookKeyframe", serialize_dict_state, deserialize_hook_keyframe) + + def deserialize_hook_keyframe_group(data): + h = comfy.hooks.HookKeyframeGroup() + h.__dict__.update(data) + return h + + registry.register( + "HookKeyframeGroup", serialize_dict_state, deserialize_hook_keyframe_group + ) + + def deserialize_hook(data): + h = comfy.hooks.Hook() + h.__dict__.update(data) + return h + + registry.register("Hook", serialize_dict_state, deserialize_hook) + + def deserialize_weight_hook(data): + h = comfy.hooks.WeightHook() + h.__dict__.update(data) + return h + + registry.register("WeightHook", serialize_dict_state, deserialize_weight_hook) + + def serialize_set(obj): + return {"__set__": list(obj)} + + def deserialize_set(data): + return set(data["__set__"]) + + registry.register("set", serialize_set, deserialize_set) + + try: + from comfy.weight_adapter.lora import LoRAAdapter + + def serialize_lora(obj): + return {"weights": {}, "loaded_keys": list(obj.loaded_keys)} + + def deserialize_lora(data): + return LoRAAdapter(set(data["loaded_keys"]), data["weights"]) + + registry.register("LoRAAdapter", serialize_lora, deserialize_lora) + except Exception: + pass + + try: + from comfy.hooks import _HookRef + import uuid + + def serialize_hook_ref(obj): + return { + "__hook_ref__": True, + "id": getattr(obj, "_pyisolate_id", str(uuid.uuid4())), + } + + def deserialize_hook_ref(data): + h = _HookRef() + h._pyisolate_id = data.get("id", str(uuid.uuid4())) + return h + + registry.register("_HookRef", serialize_hook_ref, deserialize_hook_ref) + except ImportError: + pass + except Exception as e: + logger.warning(f"Failed to register _HookRef: {e}") + + +try: + register_hooks_serializers() +except Exception as e: + logger.error(f"Failed to initialize hook serializers: {e}") diff --git a/comfy/isolation/model_sampling_proxy.py b/comfy/isolation/model_sampling_proxy.py new file mode 100644 index 000000000..8fbfc5b93 --- /dev/null +++ b/comfy/isolation/model_sampling_proxy.py @@ -0,0 +1,360 @@ +# pylint: disable=import-outside-toplevel +from __future__ import annotations + +import asyncio +import logging +import os +import threading +import time +from typing import Any + +from comfy.isolation.proxies.base import ( + BaseProxy, + BaseRegistry, + detach_if_grad, + get_thread_loop, + run_coro_in_new_loop, +) + +logger = logging.getLogger(__name__) + + +def _describe_value(obj: Any) -> str: + try: + import torch + except Exception: + torch = None + try: + if torch is not None and isinstance(obj, torch.Tensor): + return ( + "Tensor(shape=%s,dtype=%s,device=%s,id=%s)" + % (tuple(obj.shape), obj.dtype, obj.device, id(obj)) + ) + except Exception: + pass + return "%s(id=%s)" % (type(obj).__name__, id(obj)) + + +def _prefer_device(*tensors: Any) -> Any: + try: + import torch + except Exception: + return None + for t in tensors: + if isinstance(t, torch.Tensor) and t.is_cuda: + return t.device + for t in tensors: + if isinstance(t, torch.Tensor): + return t.device + return None + + +def _to_device(obj: Any, device: Any) -> Any: + try: + import torch + except Exception: + return obj + if device is None: + return obj + if isinstance(obj, torch.Tensor): + if obj.device != device: + return obj.to(device) + return obj + if isinstance(obj, (list, tuple)): + converted = [_to_device(x, device) for x in obj] + return type(obj)(converted) if isinstance(obj, tuple) else converted + if isinstance(obj, dict): + return {k: _to_device(v, device) for k, v in obj.items()} + return obj + + +def _to_cpu_for_rpc(obj: Any) -> Any: + try: + import torch + except Exception: + return obj + if isinstance(obj, torch.Tensor): + t = obj.detach() if obj.requires_grad else obj + if t.is_cuda: + return t.to("cpu") + return t + if isinstance(obj, (list, tuple)): + converted = [_to_cpu_for_rpc(x) for x in obj] + return type(obj)(converted) if isinstance(obj, tuple) else converted + if isinstance(obj, dict): + return {k: _to_cpu_for_rpc(v) for k, v in obj.items()} + return obj + + +class ModelSamplingRegistry(BaseRegistry[Any]): + _type_prefix = "modelsampling" + + async def calculate_input(self, instance_id: str, sigma: Any, noise: Any) -> Any: + sampling = self._get_instance(instance_id) + return detach_if_grad(sampling.calculate_input(sigma, noise)) + + async def calculate_denoised( + self, instance_id: str, sigma: Any, model_output: Any, model_input: Any + ) -> Any: + sampling = self._get_instance(instance_id) + return detach_if_grad( + sampling.calculate_denoised(sigma, model_output, model_input) + ) + + async def noise_scaling( + self, + instance_id: str, + sigma: Any, + noise: Any, + latent_image: Any, + max_denoise: bool = False, + ) -> Any: + sampling = self._get_instance(instance_id) + return detach_if_grad( + sampling.noise_scaling(sigma, noise, latent_image, max_denoise=max_denoise) + ) + + async def inverse_noise_scaling( + self, instance_id: str, sigma: Any, latent: Any + ) -> Any: + sampling = self._get_instance(instance_id) + return detach_if_grad(sampling.inverse_noise_scaling(sigma, latent)) + + async def timestep(self, instance_id: str, sigma: Any) -> Any: + sampling = self._get_instance(instance_id) + return sampling.timestep(sigma) + + async def sigma(self, instance_id: str, timestep: Any) -> Any: + sampling = self._get_instance(instance_id) + return sampling.sigma(timestep) + + async def percent_to_sigma(self, instance_id: str, percent: float) -> Any: + sampling = self._get_instance(instance_id) + return sampling.percent_to_sigma(percent) + + async def get_sigma_min(self, instance_id: str) -> Any: + sampling = self._get_instance(instance_id) + return detach_if_grad(sampling.sigma_min) + + async def get_sigma_max(self, instance_id: str) -> Any: + sampling = self._get_instance(instance_id) + return detach_if_grad(sampling.sigma_max) + + async def get_sigma_data(self, instance_id: str) -> Any: + sampling = self._get_instance(instance_id) + return detach_if_grad(sampling.sigma_data) + + async def get_sigmas(self, instance_id: str) -> Any: + sampling = self._get_instance(instance_id) + return detach_if_grad(sampling.sigmas) + + async def set_sigmas(self, instance_id: str, sigmas: Any) -> None: + sampling = self._get_instance(instance_id) + sampling.set_sigmas(sigmas) + + +class ModelSamplingProxy(BaseProxy[ModelSamplingRegistry]): + _registry_class = ModelSamplingRegistry + __module__ = "comfy.isolation.model_sampling_proxy" + + def _get_rpc(self) -> Any: + if self._rpc_caller is None: + from pyisolate._internal.rpc_protocol import get_child_rpc_instance + + rpc = get_child_rpc_instance() + if rpc is not None: + self._rpc_caller = rpc.create_caller( + ModelSamplingRegistry, ModelSamplingRegistry.get_remote_id() + ) + else: + registry = ModelSamplingRegistry() + + class _LocalCaller: + def calculate_input( + self, instance_id: str, sigma: Any, noise: Any + ) -> Any: + return registry.calculate_input(instance_id, sigma, noise) + + def calculate_denoised( + self, + instance_id: str, + sigma: Any, + model_output: Any, + model_input: Any, + ) -> Any: + return registry.calculate_denoised( + instance_id, sigma, model_output, model_input + ) + + def noise_scaling( + self, + instance_id: str, + sigma: Any, + noise: Any, + latent_image: Any, + max_denoise: bool = False, + ) -> Any: + return registry.noise_scaling( + instance_id, sigma, noise, latent_image, max_denoise + ) + + def inverse_noise_scaling( + self, instance_id: str, sigma: Any, latent: Any + ) -> Any: + return registry.inverse_noise_scaling( + instance_id, sigma, latent + ) + + def timestep(self, instance_id: str, sigma: Any) -> Any: + return registry.timestep(instance_id, sigma) + + def sigma(self, instance_id: str, timestep: Any) -> Any: + return registry.sigma(instance_id, timestep) + + def percent_to_sigma(self, instance_id: str, percent: float) -> Any: + return registry.percent_to_sigma(instance_id, percent) + + def get_sigma_min(self, instance_id: str) -> Any: + return registry.get_sigma_min(instance_id) + + def get_sigma_max(self, instance_id: str) -> Any: + return registry.get_sigma_max(instance_id) + + def get_sigma_data(self, instance_id: str) -> Any: + return registry.get_sigma_data(instance_id) + + def get_sigmas(self, instance_id: str) -> Any: + return registry.get_sigmas(instance_id) + + def set_sigmas(self, instance_id: str, sigmas: Any) -> None: + return registry.set_sigmas(instance_id, sigmas) + + self._rpc_caller = _LocalCaller() + return self._rpc_caller + + def _call(self, method_name: str, *args: Any) -> Any: + rpc = self._get_rpc() + method = getattr(rpc, method_name) + result = method(self._instance_id, *args) + timeout_ms = self._rpc_timeout_ms() + start_epoch = time.time() + start_perf = time.perf_counter() + thread_id = threading.get_ident() + call_id = "%s:%s:%s:%.6f" % ( + self._instance_id, + method_name, + thread_id, + start_perf, + ) + logger.debug( + "ISO:modelsampling_rpc_start method=%s instance_id=%s call_id=%s start_ts=%.6f thread=%s timeout_ms=%s", + method_name, + self._instance_id, + call_id, + start_epoch, + thread_id, + timeout_ms, + ) + if asyncio.iscoroutine(result): + result = asyncio.wait_for(result, timeout=timeout_ms / 1000.0) + try: + asyncio.get_running_loop() + out = run_coro_in_new_loop(result) + except RuntimeError: + loop = get_thread_loop() + out = loop.run_until_complete(result) + else: + out = result + logger.debug( + "ISO:modelsampling_rpc_after_await method=%s instance_id=%s call_id=%s out=%s", + method_name, + self._instance_id, + call_id, + _describe_value(out), + ) + elapsed_ms = (time.perf_counter() - start_perf) * 1000.0 + logger.debug( + "ISO:modelsampling_rpc_end method=%s instance_id=%s call_id=%s elapsed_ms=%.3f thread=%s", + method_name, + self._instance_id, + call_id, + elapsed_ms, + thread_id, + ) + logger.debug( + "ISO:modelsampling_rpc_return method=%s instance_id=%s call_id=%s", + method_name, + self._instance_id, + call_id, + ) + return out + + @staticmethod + def _rpc_timeout_ms() -> int: + raw = os.environ.get( + "COMFY_ISOLATION_MODEL_SAMPLING_RPC_TIMEOUT_MS", + os.environ.get("COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS", "30000"), + ) + try: + timeout_ms = int(raw) + except ValueError: + timeout_ms = 30000 + return max(1, timeout_ms) + + @property + def sigma_min(self) -> Any: + return self._call("get_sigma_min") + + @property + def sigma_max(self) -> Any: + return self._call("get_sigma_max") + + @property + def sigma_data(self) -> Any: + return self._call("get_sigma_data") + + @property + def sigmas(self) -> Any: + return self._call("get_sigmas") + + def calculate_input(self, sigma: Any, noise: Any) -> Any: + return self._call("calculate_input", sigma, noise) + + def calculate_denoised( + self, sigma: Any, model_output: Any, model_input: Any + ) -> Any: + return self._call("calculate_denoised", sigma, model_output, model_input) + + def noise_scaling( + self, sigma: Any, noise: Any, latent_image: Any, max_denoise: bool = False + ) -> Any: + preferred_device = _prefer_device(noise, latent_image) + out = self._call( + "noise_scaling", + _to_cpu_for_rpc(sigma), + _to_cpu_for_rpc(noise), + _to_cpu_for_rpc(latent_image), + max_denoise, + ) + return _to_device(out, preferred_device) + + def inverse_noise_scaling(self, sigma: Any, latent: Any) -> Any: + preferred_device = _prefer_device(latent) + out = self._call( + "inverse_noise_scaling", + _to_cpu_for_rpc(sigma), + _to_cpu_for_rpc(latent), + ) + return _to_device(out, preferred_device) + + def timestep(self, sigma: Any) -> Any: + return self._call("timestep", sigma) + + def sigma(self, timestep: Any) -> Any: + return self._call("sigma", timestep) + + def percent_to_sigma(self, percent: float) -> Any: + return self._call("percent_to_sigma", percent) + + def set_sigmas(self, sigmas: Any) -> None: + return self._call("set_sigmas", sigmas) diff --git a/comfy/isolation/proxies/__init__.py b/comfy/isolation/proxies/__init__.py new file mode 100644 index 000000000..30d0089ad --- /dev/null +++ b/comfy/isolation/proxies/__init__.py @@ -0,0 +1,17 @@ +from .base import ( + IS_CHILD_PROCESS, + BaseProxy, + BaseRegistry, + detach_if_grad, + get_thread_loop, + run_coro_in_new_loop, +) + +__all__ = [ + "IS_CHILD_PROCESS", + "BaseRegistry", + "BaseProxy", + "get_thread_loop", + "run_coro_in_new_loop", + "detach_if_grad", +] diff --git a/comfy/isolation/proxies/base.py b/comfy/isolation/proxies/base.py new file mode 100644 index 000000000..498554217 --- /dev/null +++ b/comfy/isolation/proxies/base.py @@ -0,0 +1,301 @@ +# pylint: disable=global-statement,import-outside-toplevel,protected-access +from __future__ import annotations + +import asyncio +import concurrent.futures +import logging +import os +import threading +import time +import weakref +from typing import Any, Callable, Dict, Generic, Optional, TypeVar + +try: + from pyisolate import ProxiedSingleton +except ImportError: + + class ProxiedSingleton: # type: ignore[no-redef] + pass + + +logger = logging.getLogger(__name__) + +IS_CHILD_PROCESS = os.environ.get("PYISOLATE_CHILD") == "1" +_thread_local = threading.local() +T = TypeVar("T") + + +def get_thread_loop() -> asyncio.AbstractEventLoop: + loop = getattr(_thread_local, "loop", None) + if loop is None or loop.is_closed(): + loop = asyncio.new_event_loop() + _thread_local.loop = loop + return loop + + +def run_coro_in_new_loop(coro: Any) -> Any: + result_box: Dict[str, Any] = {} + exc_box: Dict[str, BaseException] = {} + + def runner() -> None: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + result_box["value"] = loop.run_until_complete(coro) + except Exception as exc: # noqa: BLE001 + exc_box["exc"] = exc + finally: + loop.close() + + t = threading.Thread(target=runner, daemon=True) + t.start() + t.join() + if "exc" in exc_box: + raise exc_box["exc"] + return result_box.get("value") + + +def detach_if_grad(obj: Any) -> Any: + try: + import torch + except Exception: + return obj + + if isinstance(obj, torch.Tensor): + return obj.detach() if obj.requires_grad else obj + if isinstance(obj, (list, tuple)): + return type(obj)(detach_if_grad(x) for x in obj) + if isinstance(obj, dict): + return {k: detach_if_grad(v) for k, v in obj.items()} + return obj + + +class BaseRegistry(ProxiedSingleton, Generic[T]): + _type_prefix: str = "base" + + def __init__(self) -> None: + if hasattr(ProxiedSingleton, "__init__") and ProxiedSingleton is not object: + super().__init__() + self._registry: Dict[str, T] = {} + self._id_map: Dict[int, str] = {} + self._counter = 0 + self._lock = threading.Lock() + + def register(self, instance: T) -> str: + with self._lock: + obj_id = id(instance) + if obj_id in self._id_map: + return self._id_map[obj_id] + instance_id = f"{self._type_prefix}_{self._counter}" + self._counter += 1 + self._registry[instance_id] = instance + self._id_map[obj_id] = instance_id + return instance_id + + def unregister_sync(self, instance_id: str) -> None: + with self._lock: + instance = self._registry.pop(instance_id, None) + if instance: + self._id_map.pop(id(instance), None) + + def _get_instance(self, instance_id: str) -> T: + if IS_CHILD_PROCESS: + raise RuntimeError( + f"[{self.__class__.__name__}] _get_instance called in child" + ) + with self._lock: + instance = self._registry.get(instance_id) + if instance is None: + raise ValueError(f"{instance_id} not found") + return instance + + +_GLOBAL_LOOP: Optional[asyncio.AbstractEventLoop] = None + + +def set_global_loop(loop: asyncio.AbstractEventLoop) -> None: + global _GLOBAL_LOOP + _GLOBAL_LOOP = loop + + +def run_sync_rpc_coro(coro: Any, timeout_ms: Optional[int] = None) -> Any: + if timeout_ms is not None: + coro = asyncio.wait_for(coro, timeout=timeout_ms / 1000.0) + + try: + if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running(): + try: + curr_loop = asyncio.get_running_loop() + if curr_loop is _GLOBAL_LOOP: + pass + except RuntimeError: + future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP) + return future.result( + timeout=(timeout_ms / 1000.0) if timeout_ms is not None else None + ) + + try: + asyncio.get_running_loop() + return run_coro_in_new_loop(coro) + except RuntimeError: + loop = get_thread_loop() + return loop.run_until_complete(coro) + except asyncio.TimeoutError as exc: + raise TimeoutError(f"Isolation RPC timeout (timeout_ms={timeout_ms})") from exc + except concurrent.futures.TimeoutError as exc: + raise TimeoutError(f"Isolation RPC timeout (timeout_ms={timeout_ms})") from exc + + +def call_singleton_rpc( + caller: Any, + method_name: str, + *args: Any, + timeout_ms: Optional[int] = None, + **kwargs: Any, +) -> Any: + if caller is None: + raise RuntimeError(f"No RPC caller available for {method_name}") + method = getattr(caller, method_name) + return run_sync_rpc_coro(method(*args, **kwargs), timeout_ms=timeout_ms) + + +class BaseProxy(Generic[T]): + _registry_class: type = BaseRegistry # type: ignore[type-arg] + __module__: str = "comfy.isolation.proxies.base" + _TIMEOUT_RPC_METHODS = frozenset( + { + "partially_load", + "partially_unload", + "load", + "patch_model", + "unpatch_model", + "inner_model_apply_model", + "memory_required", + "model_dtype", + "inner_model_memory_required", + "inner_model_extra_conds_shapes", + "inner_model_extra_conds", + "process_latent_in", + "process_latent_out", + "scale_latent_inpaint", + } + ) + + def __init__( + self, + instance_id: str, + registry: Optional[Any] = None, + manage_lifecycle: bool = False, + ) -> None: + self._instance_id = instance_id + self._rpc_caller: Optional[Any] = None + self._registry = registry if registry is not None else self._registry_class() + self._manage_lifecycle = manage_lifecycle + self._cleaned_up = False + if manage_lifecycle and not IS_CHILD_PROCESS: + self._finalizer = weakref.finalize( + self, self._registry.unregister_sync, instance_id + ) + + def _get_rpc(self) -> Any: + if self._rpc_caller is None: + from pyisolate._internal.rpc_protocol import get_child_rpc_instance + + rpc = get_child_rpc_instance() + if rpc is None: + raise RuntimeError(f"[{self.__class__.__name__}] No RPC in child") + self._rpc_caller = rpc.create_caller( + self._registry_class, self._registry_class.get_remote_id() + ) + return self._rpc_caller + + def _rpc_timeout_ms_for_method(self, method_name: str) -> Optional[int]: + if method_name not in self._TIMEOUT_RPC_METHODS: + return None + try: + timeout_ms = int( + os.environ.get("COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS", "120000") + ) + except ValueError: + timeout_ms = 120000 + return max(1, timeout_ms) + + def _call_rpc(self, method_name: str, *args: Any, **kwargs: Any) -> Any: + rpc = self._get_rpc() + method = getattr(rpc, method_name) + timeout_ms = self._rpc_timeout_ms_for_method(method_name) + coro = method(self._instance_id, *args, **kwargs) + if timeout_ms is not None: + coro = asyncio.wait_for(coro, timeout=timeout_ms / 1000.0) + + start_epoch = time.time() + start_perf = time.perf_counter() + thread_id = threading.get_ident() + try: + running_loop = asyncio.get_running_loop() + loop_id: Optional[int] = id(running_loop) + except RuntimeError: + loop_id = None + logger.debug( + "ISO:rpc_start proxy=%s method=%s instance_id=%s start_ts=%.6f " + "thread=%s loop=%s timeout_ms=%s", + self.__class__.__name__, + method_name, + self._instance_id, + start_epoch, + thread_id, + loop_id, + timeout_ms, + ) + + try: + return run_sync_rpc_coro(coro, timeout_ms=timeout_ms) + except TimeoutError as exc: + raise TimeoutError( + f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} " + f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})" + ) from exc + finally: + end_epoch = time.time() + elapsed_ms = (time.perf_counter() - start_perf) * 1000.0 + logger.debug( + "ISO:rpc_end proxy=%s method=%s instance_id=%s end_ts=%.6f " + "elapsed_ms=%.3f thread=%s loop=%s", + self.__class__.__name__, + method_name, + self._instance_id, + end_epoch, + elapsed_ms, + thread_id, + loop_id, + ) + + def __getstate__(self) -> Dict[str, Any]: + return {"_instance_id": self._instance_id} + + def __setstate__(self, state: Dict[str, Any]) -> None: + self._instance_id = state["_instance_id"] + self._rpc_caller = None + self._registry = self._registry_class() + self._manage_lifecycle = False + self._cleaned_up = False + + def cleanup(self) -> None: + if self._cleaned_up or IS_CHILD_PROCESS: + return + self._cleaned_up = True + finalizer = getattr(self, "_finalizer", None) + if finalizer is not None: + finalizer.detach() + self._registry.unregister_sync(self._instance_id) + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} {self._instance_id}>" + + +def create_rpc_method(method_name: str) -> Callable[..., Any]: + def method(self: BaseProxy[Any], *args: Any, **kwargs: Any) -> Any: + return self._call_rpc(method_name, *args, **kwargs) + + method.__name__ = method_name + return method diff --git a/comfy/isolation/proxies/folder_paths_proxy.py b/comfy/isolation/proxies/folder_paths_proxy.py new file mode 100644 index 000000000..b324da4e5 --- /dev/null +++ b/comfy/isolation/proxies/folder_paths_proxy.py @@ -0,0 +1,221 @@ +from __future__ import annotations +import logging +import os +import traceback +from typing import Any, Dict, Optional + +from pyisolate import ProxiedSingleton + +from .base import call_singleton_rpc + +_fp_logger = logging.getLogger(__name__) + + +def _folder_paths(): + import folder_paths + + return folder_paths + + +def _is_child_process() -> bool: + return os.environ.get("PYISOLATE_CHILD") == "1" + + +def _serialize_folder_names_and_paths(data: dict[str, tuple[list[str], set[str]]]) -> dict[str, dict[str, list[str]]]: + return { + key: {"paths": list(paths), "extensions": sorted(list(extensions))} + for key, (paths, extensions) in data.items() + } + + +def _deserialize_folder_names_and_paths(data: dict[str, dict[str, list[str]]]) -> dict[str, tuple[list[str], set[str]]]: + return { + key: (list(value.get("paths", [])), set(value.get("extensions", []))) + for key, value in data.items() + } + + +class FolderPathsProxy(ProxiedSingleton): + """ + Dynamic proxy for folder_paths. + Uses __getattr__ for most lookups, with explicit handling for + mutable collections to ensure efficient by-value transfer. + """ + + _rpc: Optional[Any] = None + + @classmethod + def set_rpc(cls, rpc: Any) -> None: + cls._rpc = rpc.create_caller(cls, cls.get_remote_id()) + + @classmethod + def clear_rpc(cls) -> None: + cls._rpc = None + + @classmethod + def _get_caller(cls) -> Any: + if cls._rpc is None: + raise RuntimeError("FolderPathsProxy RPC caller is not configured") + return cls._rpc + + def __getattr__(self, name): + if _is_child_process(): + property_rpc = { + "models_dir": "rpc_get_models_dir", + "folder_names_and_paths": "rpc_get_folder_names_and_paths", + "extension_mimetypes_cache": "rpc_get_extension_mimetypes_cache", + "filename_list_cache": "rpc_get_filename_list_cache", + } + rpc_name = property_rpc.get(name) + if rpc_name is not None: + return call_singleton_rpc(self._get_caller(), rpc_name) + raise AttributeError(name) + return getattr(_folder_paths(), name) + + @property + def folder_names_and_paths(self) -> Dict: + if _is_child_process(): + payload = call_singleton_rpc(self._get_caller(), "rpc_get_folder_names_and_paths") + return _deserialize_folder_names_and_paths(payload) + return _folder_paths().folder_names_and_paths + + @property + def extension_mimetypes_cache(self) -> Dict: + if _is_child_process(): + return dict(call_singleton_rpc(self._get_caller(), "rpc_get_extension_mimetypes_cache")) + return dict(_folder_paths().extension_mimetypes_cache) + + @property + def filename_list_cache(self) -> Dict: + if _is_child_process(): + return dict(call_singleton_rpc(self._get_caller(), "rpc_get_filename_list_cache")) + return dict(_folder_paths().filename_list_cache) + + @property + def models_dir(self) -> str: + if _is_child_process(): + return str(call_singleton_rpc(self._get_caller(), "rpc_get_models_dir")) + return _folder_paths().models_dir + + def get_temp_directory(self) -> str: + if _is_child_process(): + return call_singleton_rpc(self._get_caller(), "rpc_get_temp_directory") + return _folder_paths().get_temp_directory() + + def get_input_directory(self) -> str: + if _is_child_process(): + return call_singleton_rpc(self._get_caller(), "rpc_get_input_directory") + return _folder_paths().get_input_directory() + + def get_output_directory(self) -> str: + if _is_child_process(): + return call_singleton_rpc(self._get_caller(), "rpc_get_output_directory") + return _folder_paths().get_output_directory() + + def get_user_directory(self) -> str: + if _is_child_process(): + return call_singleton_rpc(self._get_caller(), "rpc_get_user_directory") + return _folder_paths().get_user_directory() + + def get_annotated_filepath(self, name: str, default_dir: str | None = None) -> str: + if _is_child_process(): + return call_singleton_rpc( + self._get_caller(), "rpc_get_annotated_filepath", name, default_dir + ) + return _folder_paths().get_annotated_filepath(name, default_dir) + + def exists_annotated_filepath(self, name: str) -> bool: + if _is_child_process(): + return bool( + call_singleton_rpc(self._get_caller(), "rpc_exists_annotated_filepath", name) + ) + return bool(_folder_paths().exists_annotated_filepath(name)) + + def add_model_folder_path( + self, folder_name: str, full_folder_path: str, is_default: bool = False + ) -> None: + if _is_child_process(): + call_singleton_rpc( + self._get_caller(), + "rpc_add_model_folder_path", + folder_name, + full_folder_path, + is_default, + ) + return None + _folder_paths().add_model_folder_path(folder_name, full_folder_path, is_default) + return None + + def get_folder_paths(self, folder_name: str) -> list[str]: + if _is_child_process(): + return list(call_singleton_rpc(self._get_caller(), "rpc_get_folder_paths", folder_name)) + return list(_folder_paths().get_folder_paths(folder_name)) + + def get_filename_list(self, folder_name: str) -> list[str]: + caller_stack = "".join(traceback.format_stack()[-4:-1]) + _fp_logger.warning( + "][ DIAG:FolderPathsProxy.get_filename_list called | folder=%s | is_child=%s | rpc_configured=%s\n%s", + folder_name, _is_child_process(), self._rpc is not None, caller_stack, + ) + if _is_child_process(): + result = list(call_singleton_rpc(self._get_caller(), "rpc_get_filename_list", folder_name)) + _fp_logger.warning( + "][ DIAG:FolderPathsProxy.get_filename_list RPC result | folder=%s | count=%d | first=%s", + folder_name, len(result), result[:3] if result else "EMPTY", + ) + return result + result = list(_folder_paths().get_filename_list(folder_name)) + _fp_logger.warning( + "][ DIAG:FolderPathsProxy.get_filename_list LOCAL result | folder=%s | count=%d | first=%s", + folder_name, len(result), result[:3] if result else "EMPTY", + ) + return result + + def get_full_path(self, folder_name: str, filename: str) -> str | None: + if _is_child_process(): + return call_singleton_rpc(self._get_caller(), "rpc_get_full_path", folder_name, filename) + return _folder_paths().get_full_path(folder_name, filename) + + async def rpc_get_models_dir(self) -> str: + return _folder_paths().models_dir + + async def rpc_get_folder_names_and_paths(self) -> dict[str, dict[str, list[str]]]: + return _serialize_folder_names_and_paths(_folder_paths().folder_names_and_paths) + + async def rpc_get_extension_mimetypes_cache(self) -> dict[str, Any]: + return dict(_folder_paths().extension_mimetypes_cache) + + async def rpc_get_filename_list_cache(self) -> dict[str, Any]: + return dict(_folder_paths().filename_list_cache) + + async def rpc_get_temp_directory(self) -> str: + return _folder_paths().get_temp_directory() + + async def rpc_get_input_directory(self) -> str: + return _folder_paths().get_input_directory() + + async def rpc_get_output_directory(self) -> str: + return _folder_paths().get_output_directory() + + async def rpc_get_user_directory(self) -> str: + return _folder_paths().get_user_directory() + + async def rpc_get_annotated_filepath(self, name: str, default_dir: str | None = None) -> str: + return _folder_paths().get_annotated_filepath(name, default_dir) + + async def rpc_exists_annotated_filepath(self, name: str) -> bool: + return _folder_paths().exists_annotated_filepath(name) + + async def rpc_add_model_folder_path( + self, folder_name: str, full_folder_path: str, is_default: bool = False + ) -> None: + _folder_paths().add_model_folder_path(folder_name, full_folder_path, is_default) + + async def rpc_get_folder_paths(self, folder_name: str) -> list[str]: + return _folder_paths().get_folder_paths(folder_name) + + async def rpc_get_filename_list(self, folder_name: str) -> list[str]: + return _folder_paths().get_filename_list(folder_name) + + async def rpc_get_full_path(self, folder_name: str, filename: str) -> str | None: + return _folder_paths().get_full_path(folder_name, filename) diff --git a/comfy/isolation/proxies/helper_proxies.py b/comfy/isolation/proxies/helper_proxies.py new file mode 100644 index 000000000..278c098f1 --- /dev/null +++ b/comfy/isolation/proxies/helper_proxies.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +import os +from typing import Any, Dict, Optional + +from pyisolate import ProxiedSingleton + +from .base import call_singleton_rpc + + +class AnyTypeProxy(str): + """Replacement for custom AnyType objects used by some nodes.""" + + def __new__(cls, value: str = "*"): + return super().__new__(cls, value) + + def __ne__(self, other): # type: ignore[override] + return False + + +class FlexibleOptionalInputProxy(dict): + """Replacement for FlexibleOptionalInputType to allow dynamic inputs.""" + + def __init__(self, flex_type, data: Optional[Dict[str, object]] = None): + super().__init__() + self.type = flex_type + if data: + self.update(data) + + def __getitem__(self, key): # type: ignore[override] + return (self.type,) + + def __contains__(self, key): # type: ignore[override] + return True + + +class ByPassTypeTupleProxy(tuple): + """Replacement for ByPassTypeTuple to mirror wildcard fallback behavior.""" + + def __new__(cls, values): + return super().__new__(cls, values) + + def __getitem__(self, index): # type: ignore[override] + if index >= len(self): + return AnyTypeProxy("*") + return super().__getitem__(index) + + +def _restore_special_value(value: Any) -> Any: + if isinstance(value, dict): + if value.get("__pyisolate_any_type__"): + return AnyTypeProxy(value.get("value", "*")) + if value.get("__pyisolate_flexible_optional__"): + flex_type = _restore_special_value(value.get("type")) + data_raw = value.get("data") + data = ( + {k: _restore_special_value(v) for k, v in data_raw.items()} + if isinstance(data_raw, dict) + else {} + ) + return FlexibleOptionalInputProxy(flex_type, data) + if value.get("__pyisolate_tuple__") is not None: + return tuple( + _restore_special_value(v) for v in value["__pyisolate_tuple__"] + ) + if value.get("__pyisolate_bypass_tuple__") is not None: + return ByPassTypeTupleProxy( + tuple( + _restore_special_value(v) + for v in value["__pyisolate_bypass_tuple__"] + ) + ) + return {k: _restore_special_value(v) for k, v in value.items()} + if isinstance(value, list): + return [_restore_special_value(v) for v in value] + return value + + +def _serialize_special_value(value: Any) -> Any: + if isinstance(value, AnyTypeProxy): + return {"__pyisolate_any_type__": True, "value": str(value)} + if isinstance(value, FlexibleOptionalInputProxy): + return { + "__pyisolate_flexible_optional__": True, + "type": _serialize_special_value(value.type), + "data": {k: _serialize_special_value(v) for k, v in value.items()}, + } + if isinstance(value, ByPassTypeTupleProxy): + return { + "__pyisolate_bypass_tuple__": [_serialize_special_value(v) for v in value] + } + if isinstance(value, tuple): + return {"__pyisolate_tuple__": [_serialize_special_value(v) for v in value]} + if isinstance(value, list): + return [_serialize_special_value(v) for v in value] + if isinstance(value, dict): + return {k: _serialize_special_value(v) for k, v in value.items()} + return value + + +def _restore_input_types_local(raw: Dict[str, object]) -> Dict[str, object]: + if not isinstance(raw, dict): + return raw # type: ignore[return-value] + + restored: Dict[str, object] = {} + for section, entries in raw.items(): + if isinstance(entries, dict) and entries.get("__pyisolate_flexible_optional__"): + restored[section] = _restore_special_value(entries) + elif isinstance(entries, dict): + restored[section] = { + k: _restore_special_value(v) for k, v in entries.items() + } + else: + restored[section] = _restore_special_value(entries) + return restored + + +class HelperProxiesService(ProxiedSingleton): + _rpc: Optional[Any] = None + + @classmethod + def set_rpc(cls, rpc: Any) -> None: + cls._rpc = rpc.create_caller(cls, cls.get_remote_id()) + + @classmethod + def clear_rpc(cls) -> None: + cls._rpc = None + + @classmethod + def _get_caller(cls) -> Any: + if cls._rpc is None: + raise RuntimeError("HelperProxiesService RPC caller is not configured") + return cls._rpc + + async def rpc_restore_input_types(self, raw: Dict[str, object]) -> Dict[str, object]: + restored = _restore_input_types_local(raw) + return _serialize_special_value(restored) + + +def restore_input_types(raw: Dict[str, object]) -> Dict[str, object]: + """Restore serialized INPUT_TYPES payload back into ComfyUI-compatible objects.""" + if os.environ.get("PYISOLATE_CHILD") == "1": + payload = call_singleton_rpc( + HelperProxiesService._get_caller(), + "rpc_restore_input_types", + raw, + ) + return _restore_input_types_local(payload) + return _restore_input_types_local(raw) + + +__all__ = [ + "AnyTypeProxy", + "FlexibleOptionalInputProxy", + "ByPassTypeTupleProxy", + "HelperProxiesService", + "restore_input_types", +] diff --git a/comfy/isolation/proxies/model_management_proxy.py b/comfy/isolation/proxies/model_management_proxy.py new file mode 100644 index 000000000..445210aa4 --- /dev/null +++ b/comfy/isolation/proxies/model_management_proxy.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +import os +from typing import Any, Optional + +from pyisolate import ProxiedSingleton + +from .base import call_singleton_rpc + + +def _mm(): + import comfy.model_management + + return comfy.model_management + + +def _is_child_process() -> bool: + return os.environ.get("PYISOLATE_CHILD") == "1" + + +class TorchDeviceProxy: + def __init__(self, device_str: str): + self._device_str = device_str + if ":" in device_str: + device_type, index = device_str.split(":", 1) + self.type = device_type + self.index = int(index) + else: + self.type = device_str + self.index = None + + def __str__(self) -> str: + return self._device_str + + def __repr__(self) -> str: + return f"TorchDeviceProxy({self._device_str!r})" + + +def _serialize_value(value: Any) -> Any: + value_type = type(value) + if value_type.__module__ == "torch" and value_type.__name__ == "device": + return {"__pyisolate_torch_device__": str(value)} + if isinstance(value, TorchDeviceProxy): + return {"__pyisolate_torch_device__": str(value)} + if isinstance(value, tuple): + return {"__pyisolate_tuple__": [_serialize_value(item) for item in value]} + if isinstance(value, list): + return [_serialize_value(item) for item in value] + if isinstance(value, dict): + return {key: _serialize_value(inner) for key, inner in value.items()} + return value + + +def _deserialize_value(value: Any) -> Any: + if isinstance(value, dict): + if "__pyisolate_torch_device__" in value: + return TorchDeviceProxy(value["__pyisolate_torch_device__"]) + if "__pyisolate_tuple__" in value: + return tuple(_deserialize_value(item) for item in value["__pyisolate_tuple__"]) + return {key: _deserialize_value(inner) for key, inner in value.items()} + if isinstance(value, list): + return [_deserialize_value(item) for item in value] + return value + + +def _normalize_argument(value: Any) -> Any: + if isinstance(value, TorchDeviceProxy): + import torch + + return torch.device(str(value)) + if isinstance(value, dict): + if "__pyisolate_torch_device__" in value: + import torch + + return torch.device(value["__pyisolate_torch_device__"]) + if "__pyisolate_tuple__" in value: + return tuple(_normalize_argument(item) for item in value["__pyisolate_tuple__"]) + return {key: _normalize_argument(inner) for key, inner in value.items()} + if isinstance(value, list): + return [_normalize_argument(item) for item in value] + return value + + +class ModelManagementProxy(ProxiedSingleton): + """ + Exact-relay proxy for comfy.model_management. + Child calls never import comfy.model_management directly; they serialize + arguments, relay to host, and deserialize the host result back. + """ + + _rpc: Optional[Any] = None + + @classmethod + def set_rpc(cls, rpc: Any) -> None: + cls._rpc = rpc.create_caller(cls, cls.get_remote_id()) + + @classmethod + def clear_rpc(cls) -> None: + cls._rpc = None + + @classmethod + def _get_caller(cls) -> Any: + if cls._rpc is None: + raise RuntimeError("ModelManagementProxy RPC caller is not configured") + return cls._rpc + + def _relay_call(self, method_name: str, *args: Any, **kwargs: Any) -> Any: + payload = call_singleton_rpc( + self._get_caller(), + "rpc_call", + method_name, + _serialize_value(args), + _serialize_value(kwargs), + ) + return _deserialize_value(payload) + + @property + def VRAMState(self): + return _mm().VRAMState + + @property + def CPUState(self): + return _mm().CPUState + + @property + def OOM_EXCEPTION(self): + return _mm().OOM_EXCEPTION + + def __getattr__(self, name: str): + if _is_child_process(): + def child_method(*args: Any, **kwargs: Any) -> Any: + return self._relay_call(name, *args, **kwargs) + + return child_method + return getattr(_mm(), name) + + async def rpc_call(self, method_name: str, args: Any, kwargs: Any) -> Any: + normalized_args = _normalize_argument(_deserialize_value(args)) + normalized_kwargs = _normalize_argument(_deserialize_value(kwargs)) + method = getattr(_mm(), method_name) + result = method(*normalized_args, **normalized_kwargs) + return _serialize_value(result) diff --git a/comfy/isolation/proxies/progress_proxy.py b/comfy/isolation/proxies/progress_proxy.py new file mode 100644 index 000000000..8f270afa0 --- /dev/null +++ b/comfy/isolation/proxies/progress_proxy.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +import logging +import os +from typing import Any, Optional + +try: + from pyisolate import ProxiedSingleton +except ImportError: + + class ProxiedSingleton: + pass + +from .base import call_singleton_rpc + + +def _get_progress_state(): + from comfy_execution.progress import get_progress_state + + return get_progress_state() + + +def _is_child_process() -> bool: + return os.environ.get("PYISOLATE_CHILD") == "1" + +logger = logging.getLogger(__name__) + + +class ProgressProxy(ProxiedSingleton): + _rpc: Optional[Any] = None + + @classmethod + def set_rpc(cls, rpc: Any) -> None: + cls._rpc = rpc.create_caller(cls, cls.get_remote_id()) + + @classmethod + def clear_rpc(cls) -> None: + cls._rpc = None + + @classmethod + def _get_caller(cls) -> Any: + if cls._rpc is None: + raise RuntimeError("ProgressProxy RPC caller is not configured") + return cls._rpc + + def set_progress( + self, + value: float, + max_value: float, + node_id: Optional[str] = None, + image: Any = None, + ) -> None: + if _is_child_process(): + call_singleton_rpc( + self._get_caller(), + "rpc_set_progress", + value, + max_value, + node_id, + image, + ) + return None + + _get_progress_state().update_progress( + node_id=node_id, + value=value, + max_value=max_value, + image=image, + ) + return None + + async def rpc_set_progress( + self, + value: float, + max_value: float, + node_id: Optional[str] = None, + image: Any = None, + ) -> None: + _get_progress_state().update_progress( + node_id=node_id, + value=value, + max_value=max_value, + image=image, + ) + + +__all__ = ["ProgressProxy"] diff --git a/comfy/isolation/proxies/prompt_server_impl.py b/comfy/isolation/proxies/prompt_server_impl.py new file mode 100644 index 000000000..3f500522e --- /dev/null +++ b/comfy/isolation/proxies/prompt_server_impl.py @@ -0,0 +1,271 @@ +# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,redefined-outer-name,reimported,super-init-not-called +"""Stateless RPC Implementation for PromptServer. + +Replaces the legacy PromptServerProxy (Singleton) with a clean Service/Stub architecture. +- Host: PromptServerService (RPC Handler) +- Child: PromptServerStub (Interface Implementation) +""" + +from __future__ import annotations + +import asyncio +import os +from typing import Any, Dict, Optional, Callable + +import logging + +# IMPORTS +from pyisolate import ProxiedSingleton +from .base import call_singleton_rpc + +logger = logging.getLogger(__name__) +LOG_PREFIX = "[Isolation:C<->H]" + +# ... + +# ============================================================================= +# CHILD SIDE: PromptServerStub +# ============================================================================= + + +class PromptServerStub: + """Stateless Stub for PromptServer.""" + + # Masquerade as the real server module + __module__ = "server" + + _instance: Optional["PromptServerStub"] = None + _rpc: Optional[Any] = None # This will be the Caller object + _source_file: Optional[str] = None + + def __init__(self): + self.routes = RouteStub(self) + + @classmethod + def set_rpc(cls, rpc: Any) -> None: + """Inject RPC client (called by adapter.py or manually).""" + # Create caller for HOST Service + # Assuming Host Service is registered as "PromptServerService" (class name) + # We target the Host Service Class + target_id = "PromptServerService" + # We need to pass a class to create_caller? Usually yes. + # But we don't have the Service class imported here necessarily (if running on child). + # pyisolate check verify_service type? + # If we pass PromptServerStub as the 'class', it might mismatch if checking types. + # But we can try passing PromptServerStub if it mirrors the service name? No, stub is PromptServerStub. + # We need a dummy class with right name? + # Or just rely on string ID if create_caller supports it? + # Standard: rpc.create_caller(PromptServerStub, target_id) + # But wait, PromptServerStub is the *Local* class. + # We want to call *Remote* class. + # If we use PromptServerStub as the type, returning object will be typed as PromptServerStub? + # The first arg is 'service_cls'. + cls._rpc = rpc.create_caller( + PromptServerService, target_id + ) # We import Service below? + + @classmethod + def clear_rpc(cls) -> None: + cls._rpc = None + + # We need PromptServerService available for the create_caller call? + # Or just use the Stub class if ID matches? + # prompt_server_impl.py defines BOTH. So PromptServerService IS available! + + @property + def instance(self) -> "PromptServerStub": + return self + + # ... Compatibility ... + @classmethod + def _get_source_file(cls) -> str: + if cls._source_file is None: + import folder_paths + + cls._source_file = os.path.join(folder_paths.base_path, "server.py") + return cls._source_file + + @property + def __file__(self) -> str: + return self._get_source_file() + + # --- Properties --- + @property + def client_id(self) -> Optional[str]: + return "isolated_client" + + def supports(self, feature: str) -> bool: + return True + + @property + def app(self): + raise RuntimeError( + "PromptServer.app is not accessible in isolated nodes. Use RPC routes instead." + ) + + @property + def prompt_queue(self): + raise RuntimeError( + "PromptServer.prompt_queue is not accessible in isolated nodes." + ) + + # --- UI Communication (RPC Delegates) --- + async def send_sync( + self, event: str, data: Dict[str, Any], sid: Optional[str] = None + ) -> None: + if self._rpc: + await self._rpc.ui_send_sync(event, data, sid) + + async def send( + self, event: str, data: Dict[str, Any], sid: Optional[str] = None + ) -> None: + if self._rpc: + await self._rpc.ui_send(event, data, sid) + + def send_progress_text(self, text: str, node_id: str, sid=None) -> None: + if self._rpc: + # Fire and forget likely needed. If method is async on host, caller invocation returns coroutine. + # We must schedule it? + # Or use fire_remote equivalent? + # Caller object usually proxies calls. If host method is async, it returns coro. + # If we are sync here (send_progress_text checks imply sync usage), we must background it. + # But UtilsProxy hook wrapper creates task. + # Does send_progress_text need to be sync? Yes, node code calls it sync. + import asyncio + + try: + loop = asyncio.get_running_loop() + loop.create_task(self._rpc.ui_send_progress_text(text, node_id, sid)) + except RuntimeError: + call_singleton_rpc(self._rpc, "ui_send_progress_text", text, node_id, sid) + + # --- Route Registration Logic --- + def register_route(self, method: str, path: str, handler: Callable): + """Register a route handler via RPC.""" + if not self._rpc: + logger.error("RPC not initialized in PromptServerStub") + return + + # Fire registration async + try: + loop = asyncio.get_running_loop() + loop.create_task(self._rpc.register_route_rpc(method, path, handler)) + except RuntimeError: + call_singleton_rpc(self._rpc, "register_route_rpc", method, path, handler) + + +class RouteStub: + """Simulates aiohttp.web.RouteTableDef.""" + + def __init__(self, stub: PromptServerStub): + self._stub = stub + + def get(self, path: str): + def decorator(handler): + self._stub.register_route("GET", path, handler) + return handler + + return decorator + + def post(self, path: str): + def decorator(handler): + self._stub.register_route("POST", path, handler) + return handler + + return decorator + + def patch(self, path: str): + def decorator(handler): + self._stub.register_route("PATCH", path, handler) + return handler + + return decorator + + def put(self, path: str): + def decorator(handler): + self._stub.register_route("PUT", path, handler) + return handler + + return decorator + + def delete(self, path: str): + def decorator(handler): + self._stub.register_route("DELETE", path, handler) + return handler + + return decorator + + +# ============================================================================= +# HOST SIDE: PromptServerService +# ============================================================================= + + +class PromptServerService(ProxiedSingleton): + """Host-side RPC Service for PromptServer.""" + + def __init__(self): + # We will bind to the real server instance lazily or via global import + pass + + @property + def server(self): + from server import PromptServer + + return PromptServer.instance + + async def ui_send_sync( + self, event: str, data: Dict[str, Any], sid: Optional[str] = None + ): + await self.server.send_sync(event, data, sid) + + async def ui_send( + self, event: str, data: Dict[str, Any], sid: Optional[str] = None + ): + await self.server.send(event, data, sid) + + async def ui_send_progress_text(self, text: str, node_id: str, sid=None): + # Made async to be awaitable by RPC layer + self.server.send_progress_text(text, node_id, sid) + + async def register_route_rpc(self, method: str, path: str, child_handler_proxy): + """RPC Target: Register a route that forwards to the Child.""" + from aiohttp import web + logger.debug(f"{LOG_PREFIX} Registering Isolated Route {method} {path}") + + async def route_wrapper(request: web.Request) -> web.Response: + # 1. Capture request data + req_data = { + "method": request.method, + "path": request.path, + "query": dict(request.query), + } + if request.can_read_body: + req_data["text"] = await request.text() + + try: + # 2. Call Child Handler via RPC (child_handler_proxy is async callable) + result = await child_handler_proxy(req_data) + + # 3. Serialize Response + return self._serialize_response(result) + except Exception as e: + logger.error(f"{LOG_PREFIX} Isolated Route Error: {e}") + return web.Response(status=500, text=str(e)) + + # Register loop + self.server.app.router.add_route(method, path, route_wrapper) + + def _serialize_response(self, result: Any) -> Any: + """Helper to convert Child result -> web.Response""" + from aiohttp import web + if isinstance(result, web.Response): + return result + # Handle dict (json) + if isinstance(result, dict): + return web.json_response(result) + # Handle string + if isinstance(result, str): + return web.Response(text=result) + # Fallback + return web.Response(text=str(result)) diff --git a/comfy/isolation/proxies/utils_proxy.py b/comfy/isolation/proxies/utils_proxy.py new file mode 100644 index 000000000..f84727bbb --- /dev/null +++ b/comfy/isolation/proxies/utils_proxy.py @@ -0,0 +1,64 @@ +# pylint: disable=cyclic-import,import-outside-toplevel +from __future__ import annotations + +from typing import Optional, Any +from pyisolate import ProxiedSingleton + +import os + + +def _comfy_utils(): + import comfy.utils + return comfy.utils + + +class UtilsProxy(ProxiedSingleton): + """ + Proxy for comfy.utils. + Primarily handles the PROGRESS_BAR_HOOK to ensure progress updates + from isolated nodes reach the host. + """ + + # _instance and __new__ removed to rely on SingletonMetaclass + _rpc: Optional[Any] = None + + @classmethod + def set_rpc(cls, rpc: Any) -> None: + # Create caller using class name as ID (standard for Singletons) + cls._rpc = rpc.create_caller(cls, "UtilsProxy") + + @classmethod + def clear_rpc(cls) -> None: + cls._rpc = None + + async def progress_bar_hook( + self, + value: int, + total: int, + preview: Optional[bytes] = None, + node_id: Optional[str] = None, + ) -> Any: + """ + Host-side implementation: forwards the call to the real global hook. + Child-side: this method call is intercepted by RPC and sent to host. + """ + if os.environ.get("PYISOLATE_CHILD") == "1": + if UtilsProxy._rpc is None: + raise RuntimeError("UtilsProxy RPC caller is not configured") + return await UtilsProxy._rpc.progress_bar_hook( + value, total, preview, node_id + ) + + # Host Execution + utils = _comfy_utils() + if utils.PROGRESS_BAR_HOOK is not None: + return utils.PROGRESS_BAR_HOOK(value, total, preview, node_id) + return None + + def set_progress_bar_global_hook(self, hook: Any) -> None: + """Forward hook registration (though usually not needed from child).""" + if os.environ.get("PYISOLATE_CHILD") == "1": + raise RuntimeError( + "UtilsProxy.set_progress_bar_global_hook is not available in child without exact relay support" + ) + _comfy_utils().set_progress_bar_global_hook(hook) diff --git a/comfy/isolation/proxies/web_directory_proxy.py b/comfy/isolation/proxies/web_directory_proxy.py new file mode 100644 index 000000000..3acf3f4fc --- /dev/null +++ b/comfy/isolation/proxies/web_directory_proxy.py @@ -0,0 +1,219 @@ +"""WebDirectoryProxy — serves isolated node web assets via RPC. + +Child side: enumerates and reads files from the extension's web/ directory. +Host side: gets an RPC proxy that fetches file listings and contents on demand. + +Only files with allowed extensions (.js, .html, .css) are served. +Directory traversal is rejected. File contents are base64-encoded for +safe JSON-RPC transport. +""" + +from __future__ import annotations + +import base64 +import logging +import os +from pathlib import Path, PurePosixPath +from typing import Any, Dict, List + +from pyisolate import ProxiedSingleton + +logger = logging.getLogger(__name__) + +ALLOWED_EXTENSIONS = frozenset({".js", ".html", ".css"}) + +MIME_TYPES = { + ".js": "application/javascript", + ".html": "text/html", + ".css": "text/css", +} + + +class WebDirectoryProxy(ProxiedSingleton): + """Proxy for serving isolated extension web directories. + + On the child side, this class has direct filesystem access to the + extension's web/ directory. On the host side, callers get an RPC + proxy whose method calls are forwarded to the child. + """ + + # {extension_name: absolute_path_to_web_dir} + _web_dirs: dict[str, str] = {} + + @classmethod + def register_web_dir(cls, extension_name: str, web_dir_path: str) -> None: + """Register an extension's web directory (child-side only).""" + cls._web_dirs[extension_name] = web_dir_path + logger.info( + "][ WebDirectoryProxy: registered %s -> %s", + extension_name, + web_dir_path, + ) + + def list_web_files(self, extension_name: str) -> List[Dict[str, str]]: + """Return a list of servable files in the extension's web directory. + + Each entry is {"relative_path": "js/foo.js", "content_type": "application/javascript"}. + Only files with allowed extensions are included. + """ + web_dir = self._web_dirs.get(extension_name) + if not web_dir: + return [] + + root = Path(web_dir) + if not root.is_dir(): + return [] + + result: List[Dict[str, str]] = [] + for path in sorted(root.rglob("*")): + if not path.is_file(): + continue + ext = path.suffix.lower() + if ext not in ALLOWED_EXTENSIONS: + continue + rel = path.relative_to(root) + result.append({ + "relative_path": str(PurePosixPath(rel)), + "content_type": MIME_TYPES[ext], + }) + return result + + def get_web_file( + self, extension_name: str, relative_path: str + ) -> Dict[str, Any]: + """Return the contents of a single web file as base64. + + Raises ValueError for traversal attempts or disallowed file types. + Returns {"content": , "content_type": }. + """ + _validate_path(relative_path) + + web_dir = self._web_dirs.get(extension_name) + if not web_dir: + raise FileNotFoundError( + f"No web directory registered for {extension_name}" + ) + + root = Path(web_dir) + target = (root / relative_path).resolve() + + # Ensure resolved path is under the web directory + if not str(target).startswith(str(root.resolve())): + raise ValueError(f"Path escapes web directory: {relative_path}") + + if not target.is_file(): + raise FileNotFoundError(f"File not found: {relative_path}") + + ext = target.suffix.lower() + if ext not in ALLOWED_EXTENSIONS: + raise ValueError(f"Disallowed file type: {ext}") + + content_type = MIME_TYPES[ext] + raw = target.read_bytes() + + return { + "content": base64.b64encode(raw).decode("ascii"), + "content_type": content_type, + } + + +def _validate_path(relative_path: str) -> None: + """Reject directory traversal and absolute paths.""" + if os.path.isabs(relative_path): + raise ValueError(f"Absolute paths are not allowed: {relative_path}") + if ".." in PurePosixPath(relative_path).parts: + raise ValueError(f"Directory traversal is not allowed: {relative_path}") + + +# --------------------------------------------------------------------------- +# Host-side cache and aiohttp handler +# --------------------------------------------------------------------------- + + +class WebDirectoryCache: + """Host-side in-memory cache for proxied web directory contents. + + Populated lazily via RPC calls to the child's WebDirectoryProxy. + Once a file is cached, subsequent requests are served from memory. + """ + + def __init__(self) -> None: + # {extension_name: {relative_path: {"content": bytes, "content_type": str}}} + self._file_cache: dict[str, dict[str, dict[str, Any]]] = {} + # {extension_name: [{"relative_path": str, "content_type": str}, ...]} + self._listing_cache: dict[str, list[dict[str, str]]] = {} + # {extension_name: WebDirectoryProxy (RPC proxy instance)} + self._proxies: dict[str, Any] = {} + + def register_proxy(self, extension_name: str, proxy: Any) -> None: + """Register an RPC proxy for an extension's web directory.""" + self._proxies[extension_name] = proxy + logger.info( + "][ WebDirectoryCache: registered proxy for %s", extension_name + ) + + @property + def extension_names(self) -> list[str]: + return list(self._proxies.keys()) + + def list_files(self, extension_name: str) -> list[dict[str, str]]: + """List servable files for an extension (cached after first call).""" + if extension_name not in self._listing_cache: + proxy = self._proxies.get(extension_name) + if proxy is None: + return [] + try: + self._listing_cache[extension_name] = proxy.list_web_files( + extension_name + ) + except Exception: + logger.warning( + "][ WebDirectoryCache: failed to list files for %s", + extension_name, + exc_info=True, + ) + return [] + return self._listing_cache[extension_name] + + def get_file( + self, extension_name: str, relative_path: str + ) -> dict[str, Any] | None: + """Get file content (cached after first fetch). Returns None on miss.""" + ext_cache = self._file_cache.get(extension_name) + if ext_cache and relative_path in ext_cache: + return ext_cache[relative_path] + + proxy = self._proxies.get(extension_name) + if proxy is None: + return None + + try: + result = proxy.get_web_file(extension_name, relative_path) + except (FileNotFoundError, ValueError): + return None + except Exception: + logger.warning( + "][ WebDirectoryCache: failed to fetch %s/%s", + extension_name, + relative_path, + exc_info=True, + ) + return None + + decoded = { + "content": base64.b64decode(result["content"]), + "content_type": result["content_type"], + } + + if extension_name not in self._file_cache: + self._file_cache[extension_name] = {} + self._file_cache[extension_name][relative_path] = decoded + return decoded + + +# Global cache instance — populated during isolation loading +_web_directory_cache = WebDirectoryCache() + + +def get_web_directory_cache() -> WebDirectoryCache: + return _web_directory_cache diff --git a/comfy/isolation/rpc_bridge.py b/comfy/isolation/rpc_bridge.py new file mode 100644 index 000000000..2beb0f09f --- /dev/null +++ b/comfy/isolation/rpc_bridge.py @@ -0,0 +1,49 @@ +import asyncio +import logging +import threading + +logger = logging.getLogger(__name__) + + +class RpcBridge: + """Minimal helper to run coroutines synchronously inside isolated processes. + + If an event loop is already running, the coroutine is executed on a fresh + thread with its own loop to avoid nested run_until_complete errors. + """ + + def run_sync(self, maybe_coro): + if not asyncio.iscoroutine(maybe_coro): + return maybe_coro + + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop and loop.is_running(): + result_container = {} + exc_container = {} + + def _runner(): + try: + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + result_container["value"] = new_loop.run_until_complete(maybe_coro) + except Exception as exc: # pragma: no cover + exc_container["error"] = exc + finally: + try: + new_loop.close() + except Exception: + pass + + t = threading.Thread(target=_runner, daemon=True) + t.start() + t.join() + + if "error" in exc_container: + raise exc_container["error"] + return result_container.get("value") + + return asyncio.run(maybe_coro) diff --git a/comfy/isolation/runtime_helpers.py b/comfy/isolation/runtime_helpers.py new file mode 100644 index 000000000..f56b1859a --- /dev/null +++ b/comfy/isolation/runtime_helpers.py @@ -0,0 +1,471 @@ +# pylint: disable=consider-using-from-import,import-outside-toplevel,no-member +from __future__ import annotations + +import copy +import logging +import os +from pathlib import Path +from typing import Any, Dict, List, Set, TYPE_CHECKING + +from .proxies.helper_proxies import restore_input_types +from .shm_forensics import scan_shm_forensics + +_IMPORT_TORCH = os.environ.get("PYISOLATE_IMPORT_TORCH", "1") == "1" + +_ComfyNodeInternal = object +latest_io = None + +if _IMPORT_TORCH: + from comfy_api.internal import _ComfyNodeInternal + from comfy_api.latest import _io as latest_io + +if TYPE_CHECKING: + from .extension_wrapper import ComfyNodeExtension + +LOG_PREFIX = "][" +_PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024 + + +class _RemoteObjectRegistryCaller: + def __init__(self, extension: Any) -> None: + self._extension = extension + + def __getattr__(self, method_name: str) -> Any: + async def _call(instance_id: str, *args: Any, **kwargs: Any) -> Any: + return await self._extension.call_remote_object_method( + instance_id, + method_name, + *args, + **kwargs, + ) + + return _call + + +def _wrap_remote_handles_as_host_proxies(value: Any, extension: Any) -> Any: + from pyisolate._internal.remote_handle import RemoteObjectHandle + + if isinstance(value, RemoteObjectHandle): + if value.type_name == "ModelPatcher": + from comfy.isolation.model_patcher_proxy import ModelPatcherProxy + + proxy = ModelPatcherProxy(value.object_id, manage_lifecycle=False) + proxy._rpc_caller = _RemoteObjectRegistryCaller(extension) # type: ignore[attr-defined] + proxy._pyisolate_remote_handle = value # type: ignore[attr-defined] + return proxy + if value.type_name == "VAE": + from comfy.isolation.vae_proxy import VAEProxy + + proxy = VAEProxy(value.object_id, manage_lifecycle=False) + proxy._rpc_caller = _RemoteObjectRegistryCaller(extension) # type: ignore[attr-defined] + proxy._pyisolate_remote_handle = value # type: ignore[attr-defined] + return proxy + if value.type_name == "CLIP": + from comfy.isolation.clip_proxy import CLIPProxy + + proxy = CLIPProxy(value.object_id, manage_lifecycle=False) + proxy._rpc_caller = _RemoteObjectRegistryCaller(extension) # type: ignore[attr-defined] + proxy._pyisolate_remote_handle = value # type: ignore[attr-defined] + return proxy + if value.type_name == "ModelSampling": + from comfy.isolation.model_sampling_proxy import ModelSamplingProxy + + proxy = ModelSamplingProxy(value.object_id, manage_lifecycle=False) + proxy._rpc_caller = _RemoteObjectRegistryCaller(extension) # type: ignore[attr-defined] + proxy._pyisolate_remote_handle = value # type: ignore[attr-defined] + return proxy + return value + + if isinstance(value, dict): + return { + k: _wrap_remote_handles_as_host_proxies(v, extension) for k, v in value.items() + } + + if isinstance(value, (list, tuple)): + wrapped = [_wrap_remote_handles_as_host_proxies(item, extension) for item in value] + return type(value)(wrapped) + + return value + + +def _resource_snapshot() -> Dict[str, int]: + fd_count = -1 + shm_sender_files = 0 + try: + fd_count = len(os.listdir("/proc/self/fd")) + except Exception: + pass + try: + shm_root = Path("/dev/shm") + if shm_root.exists(): + prefix = f"torch_{os.getpid()}_" + shm_sender_files = sum(1 for _ in shm_root.glob(f"{prefix}*")) + except Exception: + pass + return {"fd_count": fd_count, "shm_sender_files": shm_sender_files} + + +def _tensor_transport_summary(value: Any) -> Dict[str, int]: + summary: Dict[str, int] = { + "tensor_count": 0, + "cpu_tensors": 0, + "cuda_tensors": 0, + "shared_cpu_tensors": 0, + "tensor_bytes": 0, + } + try: + import torch + except Exception: + return summary + + def visit(node: Any) -> None: + if isinstance(node, torch.Tensor): + summary["tensor_count"] += 1 + summary["tensor_bytes"] += int(node.numel() * node.element_size()) + if node.device.type == "cpu": + summary["cpu_tensors"] += 1 + if node.is_shared(): + summary["shared_cpu_tensors"] += 1 + elif node.device.type == "cuda": + summary["cuda_tensors"] += 1 + return + if isinstance(node, dict): + for v in node.values(): + visit(v) + return + if isinstance(node, (list, tuple)): + for v in node: + visit(v) + + visit(value) + return summary + + +def _extract_hidden_unique_id(inputs: Dict[str, Any]) -> str | None: + for key, value in inputs.items(): + key_text = str(key) + if "unique_id" in key_text: + return str(value) + return None + + +def _flush_tensor_transport_state(marker: str, logger: logging.Logger) -> None: + try: + from pyisolate import flush_tensor_keeper # type: ignore[attr-defined] + except Exception: + return + if not callable(flush_tensor_keeper): + return + flushed = flush_tensor_keeper() + if flushed > 0: + logger.debug( + "%s %s flush_tensor_keeper released=%d", LOG_PREFIX, marker, flushed + ) + + +def _relieve_host_vram_pressure(marker: str, logger: logging.Logger) -> None: + import comfy.model_management as model_management + + model_management.cleanup_models_gc() + model_management.cleanup_models() + + device = model_management.get_torch_device() + if not hasattr(device, "type") or device.type == "cpu": + return + + required = max( + model_management.minimum_inference_memory(), + _PRE_EXEC_MIN_FREE_VRAM_BYTES, + ) + if model_management.get_free_memory(device) < required: + model_management.free_memory(required, device, for_dynamic=True) + if model_management.get_free_memory(device) < required: + model_management.free_memory(required, device, for_dynamic=False) + model_management.cleanup_models() + model_management.soft_empty_cache() + logger.debug("%s %s free_memory target=%d", LOG_PREFIX, marker, required) + + +def _detach_shared_cpu_tensors(value: Any) -> Any: + try: + import torch + except Exception: + return value + + if isinstance(value, torch.Tensor): + if value.device.type == "cpu" and value.is_shared(): + clone = value.clone() + if value.requires_grad: + clone.requires_grad_(True) + return clone + return value + if isinstance(value, list): + return [_detach_shared_cpu_tensors(v) for v in value] + if isinstance(value, tuple): + return tuple(_detach_shared_cpu_tensors(v) for v in value) + if isinstance(value, dict): + return {k: _detach_shared_cpu_tensors(v) for k, v in value.items()} + return value + + +def build_stub_class( + node_name: str, + info: Dict[str, object], + extension: "ComfyNodeExtension", + running_extensions: Dict[str, "ComfyNodeExtension"], + logger: logging.Logger, +) -> type: + if latest_io is None: + raise RuntimeError("comfy_api.latest._io is required to build isolation stubs") + is_v3 = bool(info.get("is_v3", False)) + function_name = "_pyisolate_execute" + restored_input_types = restore_input_types(info.get("input_types", {})) + + async def _execute(self, **inputs): + from comfy.isolation import _RUNNING_EXTENSIONS + + # Update BOTH the local dict AND the module-level dict + running_extensions[extension.name] = extension + _RUNNING_EXTENSIONS[extension.name] = extension + prev_child = None + node_unique_id = _extract_hidden_unique_id(inputs) + summary = _tensor_transport_summary(inputs) + resources = _resource_snapshot() + logger.debug( + "%s ISO:execute_start ext=%s node=%s uid=%s", + LOG_PREFIX, + extension.name, + node_name, + node_unique_id or "-", + ) + logger.debug( + "%s ISO:execute_start ext=%s node=%s uid=%s tensors=%d cpu=%d cuda=%d shared_cpu=%d bytes=%d fds=%d sender_shm=%d", + LOG_PREFIX, + extension.name, + node_name, + node_unique_id or "-", + summary["tensor_count"], + summary["cpu_tensors"], + summary["cuda_tensors"], + summary["shared_cpu_tensors"], + summary["tensor_bytes"], + resources["fd_count"], + resources["shm_sender_files"], + ) + scan_shm_forensics("RUNTIME:execute_start", refresh_model_context=True) + try: + if os.environ.get("PYISOLATE_CHILD") != "1": + _relieve_host_vram_pressure("RUNTIME:pre_execute", logger) + scan_shm_forensics("RUNTIME:pre_execute", refresh_model_context=True) + from pyisolate._internal.model_serialization import ( + serialize_for_isolation, + deserialize_from_isolation, + ) + + prev_child = os.environ.pop("PYISOLATE_CHILD", None) + logger.debug( + "%s ISO:serialize_start ext=%s node=%s uid=%s", + LOG_PREFIX, + extension.name, + node_name, + node_unique_id or "-", + ) + # Unwrap NodeOutput-like dicts before serialization. + # OUTPUT_NODE nodes return {"ui": {...}, "result": (outputs...)} + # and the executor may pass this dict as input to downstream nodes. + unwrapped_inputs = {} + for k, v in inputs.items(): + if isinstance(v, dict) and "result" in v and ("ui" in v or "__node_output__" in v): + result = v.get("result") + if isinstance(result, (tuple, list)) and len(result) > 0: + unwrapped_inputs[k] = result[0] + else: + unwrapped_inputs[k] = result + else: + unwrapped_inputs[k] = v + serialized = serialize_for_isolation(unwrapped_inputs) + logger.debug( + "%s ISO:serialize_done ext=%s node=%s uid=%s", + LOG_PREFIX, + extension.name, + node_name, + node_unique_id or "-", + ) + logger.debug( + "%s ISO:dispatch_start ext=%s node=%s uid=%s", + LOG_PREFIX, + extension.name, + node_name, + node_unique_id or "-", + ) + result = await extension.execute_node(node_name, **serialized) + logger.debug( + "%s ISO:dispatch_done ext=%s node=%s uid=%s", + LOG_PREFIX, + extension.name, + node_name, + node_unique_id or "-", + ) + # Reconstruct NodeOutput if the child serialized one + if isinstance(result, dict) and result.get("__node_output__"): + from comfy_api.latest import io as latest_io + args_raw = result.get("args", ()) + deserialized_args = await deserialize_from_isolation(args_raw, extension) + deserialized_args = _wrap_remote_handles_as_host_proxies( + deserialized_args, extension + ) + deserialized_args = _detach_shared_cpu_tensors(deserialized_args) + ui_raw = result.get("ui") + deserialized_ui = None + if ui_raw is not None: + deserialized_ui = await deserialize_from_isolation(ui_raw, extension) + deserialized_ui = _wrap_remote_handles_as_host_proxies( + deserialized_ui, extension + ) + deserialized_ui = _detach_shared_cpu_tensors(deserialized_ui) + scan_shm_forensics("RUNTIME:post_execute", refresh_model_context=True) + return latest_io.NodeOutput( + *deserialized_args, + ui=deserialized_ui, + expand=result.get("expand"), + block_execution=result.get("block_execution"), + ) + # OUTPUT_NODE: if sealed worker returned a tuple/list whose first + # element is a {"ui": ...} dict, unwrap it for the executor. + if (isinstance(result, (tuple, list)) and len(result) == 1 + and isinstance(result[0], dict) and "ui" in result[0]): + return result[0] + deserialized = await deserialize_from_isolation(result, extension) + deserialized = _wrap_remote_handles_as_host_proxies(deserialized, extension) + scan_shm_forensics("RUNTIME:post_execute", refresh_model_context=True) + return _detach_shared_cpu_tensors(deserialized) + except ImportError: + return await extension.execute_node(node_name, **inputs) + except Exception: + logger.exception( + "%s ISO:execute_error ext=%s node=%s uid=%s", + LOG_PREFIX, + extension.name, + node_name, + node_unique_id or "-", + ) + raise + finally: + if prev_child is not None: + os.environ["PYISOLATE_CHILD"] = prev_child + logger.debug( + "%s ISO:execute_end ext=%s node=%s uid=%s", + LOG_PREFIX, + extension.name, + node_name, + node_unique_id or "-", + ) + scan_shm_forensics("RUNTIME:execute_end", refresh_model_context=True) + + def _input_types( + cls, + include_hidden: bool = True, + return_schema: bool = False, + live_inputs: Any = None, + ): + if not is_v3: + return restored_input_types + + inputs_copy = copy.deepcopy(restored_input_types) + if not include_hidden: + inputs_copy.pop("hidden", None) + + v3_data: Dict[str, Any] = {"hidden_inputs": {}} + dynamic = inputs_copy.pop("dynamic_paths", None) + if dynamic is not None: + v3_data["dynamic_paths"] = dynamic + + if return_schema: + hidden_vals = info.get("hidden", []) or [] + hidden_enums = [] + for h in hidden_vals: + try: + hidden_enums.append(latest_io.Hidden(h)) + except Exception: + hidden_enums.append(h) + + class SchemaProxy: + hidden = hidden_enums + + return inputs_copy, SchemaProxy, v3_data + return inputs_copy + + def _validate_class(cls): + return True + + def _get_node_info_v1(cls): + node_info = copy.deepcopy(info.get("schema_v1", {})) + relative_python_module = node_info.get("python_module") + if not isinstance(relative_python_module, str) or not relative_python_module: + relative_python_module = f"custom_nodes.{extension.name}" + node_info["python_module"] = relative_python_module + return node_info + + def _get_base_class(cls): + return latest_io.ComfyNode + + attributes: Dict[str, object] = { + "FUNCTION": function_name, + "CATEGORY": info.get("category", ""), + "OUTPUT_NODE": info.get("output_node", False), + "RETURN_TYPES": tuple(info.get("return_types", ()) or ()), + "RETURN_NAMES": info.get("return_names"), + function_name: _execute, + "_pyisolate_extension": extension, + "_pyisolate_node_name": node_name, + "INPUT_TYPES": classmethod(_input_types), + } + + output_is_list = info.get("output_is_list") + if output_is_list is not None: + attributes["OUTPUT_IS_LIST"] = tuple(output_is_list) + + if is_v3: + attributes["VALIDATE_CLASS"] = classmethod(_validate_class) + attributes["GET_NODE_INFO_V1"] = classmethod(_get_node_info_v1) + attributes["GET_BASE_CLASS"] = classmethod(_get_base_class) + attributes["DESCRIPTION"] = info.get("description", "") + attributes["EXPERIMENTAL"] = info.get("experimental", False) + attributes["DEPRECATED"] = info.get("deprecated", False) + attributes["API_NODE"] = info.get("api_node", False) + attributes["NOT_IDEMPOTENT"] = info.get("not_idempotent", False) + attributes["ACCEPT_ALL_INPUTS"] = info.get("accept_all_inputs", False) + attributes["_ACCEPT_ALL_INPUTS"] = info.get("accept_all_inputs", False) + attributes["INPUT_IS_LIST"] = info.get("input_is_list", False) + + class_name = f"PyIsolate_{node_name}".replace(" ", "_") + bases = (_ComfyNodeInternal,) if is_v3 else () + stub_cls = type(class_name, bases, attributes) + + if is_v3: + try: + stub_cls.VALIDATE_CLASS() + except Exception as e: + logger.error("%s VALIDATE_CLASS failed: %s - %s", LOG_PREFIX, node_name, e) + + return stub_cls + + +def get_class_types_for_extension( + extension_name: str, + running_extensions: Dict[str, "ComfyNodeExtension"], + specs: List[Any], +) -> Set[str]: + extension = running_extensions.get(extension_name) + if not extension: + return set() + + ext_path = Path(extension.module_path) + class_types = set() + for spec in specs: + if spec.module_path.resolve() == ext_path.resolve(): + class_types.add(spec.node_name) + return class_types + + +__all__ = ["build_stub_class", "get_class_types_for_extension"] diff --git a/comfy/isolation/shm_forensics.py b/comfy/isolation/shm_forensics.py new file mode 100644 index 000000000..36223505a --- /dev/null +++ b/comfy/isolation/shm_forensics.py @@ -0,0 +1,217 @@ +# pylint: disable=consider-using-from-import,import-outside-toplevel +from __future__ import annotations + +import atexit +import hashlib +import logging +import os +from pathlib import Path +from typing import Any, Dict, List, Set + +LOG_PREFIX = "][" +logger = logging.getLogger(__name__) + + +def _shm_debug_enabled() -> bool: + return os.environ.get("COMFY_ISO_SHM_DEBUG") == "1" + + +class _SHMForensicsTracker: + def __init__(self) -> None: + self._started = False + self._tracked_files: Set[str] = set() + self._current_model_context: Dict[str, str] = { + "id": "unknown", + "name": "unknown", + "hash": "????", + } + + @staticmethod + def _snapshot_shm() -> Set[str]: + shm_path = Path("/dev/shm") + if not shm_path.exists(): + return set() + return {f.name for f in shm_path.glob("torch_*")} + + def start(self) -> None: + if self._started or not _shm_debug_enabled(): + return + self._tracked_files = self._snapshot_shm() + self._started = True + logger.debug( + "%s SHM:forensics_enabled tracked=%d", LOG_PREFIX, len(self._tracked_files) + ) + + def stop(self) -> None: + if not self._started: + return + self.scan("shutdown", refresh_model_context=True) + self._started = False + logger.debug("%s SHM:forensics_disabled", LOG_PREFIX) + + def _compute_model_hash(self, model_patcher: Any) -> str: + try: + model_instance_id = getattr(model_patcher, "_instance_id", None) + if model_instance_id is not None: + model_id_text = str(model_instance_id) + return model_id_text[-4:] if len(model_id_text) >= 4 else model_id_text + + import torch + + real_model = ( + model_patcher.model + if hasattr(model_patcher, "model") + else model_patcher + ) + tensor = None + if hasattr(real_model, "parameters"): + for p in real_model.parameters(): + if torch.is_tensor(p) and p.numel() > 0: + tensor = p + break + + if tensor is None: + return "0000" + + flat = tensor.flatten() + values = [] + indices = [0, flat.shape[0] // 2, flat.shape[0] - 1] + for i in indices: + if i < flat.shape[0]: + values.append(flat[i].item()) + + size = 0 + if hasattr(model_patcher, "model_size"): + size = model_patcher.model_size() + sample_str = f"{values}_{id(model_patcher):016x}_{size}" + return hashlib.sha256(sample_str.encode()).hexdigest()[-4:] + except Exception: + return "err!" + + def _get_models_snapshot(self) -> List[Dict[str, Any]]: + try: + import comfy.model_management as model_management + except Exception: + return [] + + snapshot: List[Dict[str, Any]] = [] + try: + for loaded_model in model_management.current_loaded_models: + model = loaded_model.model + if model is None: + continue + if str(getattr(loaded_model, "device", "")) != "cuda:0": + continue + + name = ( + model.model.__class__.__name__ + if hasattr(model, "model") + else type(model).__name__ + ) + model_hash = self._compute_model_hash(model) + model_instance_id = getattr(model, "_instance_id", None) + if model_instance_id is None: + model_instance_id = model_hash + snapshot.append( + { + "name": str(name), + "id": str(model_instance_id), + "hash": str(model_hash or "????"), + "used": bool(getattr(loaded_model, "currently_used", False)), + } + ) + except Exception: + return [] + + return snapshot + + def _update_model_context(self) -> None: + snapshot = self._get_models_snapshot() + selected = None + + used_models = [m for m in snapshot if m.get("used") and m.get("id")] + if used_models: + selected = used_models[-1] + else: + live_models = [m for m in snapshot if m.get("id")] + if live_models: + selected = live_models[-1] + + if selected is None: + self._current_model_context = { + "id": "unknown", + "name": "unknown", + "hash": "????", + } + return + + self._current_model_context = { + "id": str(selected.get("id", "unknown")), + "name": str(selected.get("name", "unknown")), + "hash": str(selected.get("hash", "????") or "????"), + } + + def scan(self, marker: str, refresh_model_context: bool = True) -> None: + if not self._started or not _shm_debug_enabled(): + return + + if refresh_model_context: + self._update_model_context() + + current = self._snapshot_shm() + added = current - self._tracked_files + removed = self._tracked_files - current + self._tracked_files = current + + if not added and not removed: + logger.debug("%s SHM:scan marker=%s changes=0", LOG_PREFIX, marker) + return + + for filename in sorted(added): + logger.info("%s SHM:created | %s", LOG_PREFIX, filename) + model_id = self._current_model_context["id"] + if model_id == "unknown": + logger.error( + "%s SHM:model_association_missing | file=%s | reason=no_active_model_context", + LOG_PREFIX, + filename, + ) + else: + logger.info( + "%s SHM:model_association | model=%s | file=%s | name=%s | hash=%s", + LOG_PREFIX, + model_id, + filename, + self._current_model_context["name"], + self._current_model_context["hash"], + ) + + for filename in sorted(removed): + logger.info("%s SHM:deleted | %s", LOG_PREFIX, filename) + + logger.debug( + "%s SHM:scan marker=%s created=%d deleted=%d active=%d", + LOG_PREFIX, + marker, + len(added), + len(removed), + len(self._tracked_files), + ) + + +_TRACKER = _SHMForensicsTracker() + + +def start_shm_forensics() -> None: + _TRACKER.start() + + +def scan_shm_forensics(marker: str, refresh_model_context: bool = True) -> None: + _TRACKER.scan(marker, refresh_model_context=refresh_model_context) + + +def stop_shm_forensics() -> None: + _TRACKER.stop() + + +atexit.register(stop_shm_forensics) diff --git a/comfy/isolation/vae_proxy.py b/comfy/isolation/vae_proxy.py new file mode 100644 index 000000000..8260d06a3 --- /dev/null +++ b/comfy/isolation/vae_proxy.py @@ -0,0 +1,214 @@ +# pylint: disable=attribute-defined-outside-init +import logging +from typing import Any + +from comfy.isolation.proxies.base import ( + IS_CHILD_PROCESS, + BaseProxy, + BaseRegistry, + detach_if_grad, +) +from comfy.isolation.model_patcher_proxy import ModelPatcherProxy, ModelPatcherRegistry + +logger = logging.getLogger(__name__) + + +class FirstStageModelRegistry(BaseRegistry[Any]): + _type_prefix = "first_stage_model" + + async def get_property(self, instance_id: str, name: str) -> Any: + obj = self._get_instance(instance_id) + return getattr(obj, name) + + async def has_property(self, instance_id: str, name: str) -> bool: + obj = self._get_instance(instance_id) + return hasattr(obj, name) + + +class FirstStageModelProxy(BaseProxy[FirstStageModelRegistry]): + _registry_class = FirstStageModelRegistry + __module__ = "comfy.ldm.models.autoencoder" + + def __getattr__(self, name: str) -> Any: + try: + return self._call_rpc("get_property", name) + except Exception as e: + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{name}'" + ) from e + + def __repr__(self) -> str: + return f"" + + +class VAERegistry(BaseRegistry[Any]): + _type_prefix = "vae" + + async def get_patcher_id(self, instance_id: str) -> str: + vae = self._get_instance(instance_id) + return ModelPatcherRegistry().register(vae.patcher) + + async def get_first_stage_model_id(self, instance_id: str) -> str: + vae = self._get_instance(instance_id) + return FirstStageModelRegistry().register(vae.first_stage_model) + + async def encode(self, instance_id: str, pixels: Any) -> Any: + return detach_if_grad(self._get_instance(instance_id).encode(pixels)) + + async def encode_tiled( + self, + instance_id: str, + pixels: Any, + tile_x: int = 512, + tile_y: int = 512, + overlap: int = 64, + ) -> Any: + return detach_if_grad( + self._get_instance(instance_id).encode_tiled( + pixels, tile_x=tile_x, tile_y=tile_y, overlap=overlap + ) + ) + + async def decode(self, instance_id: str, samples: Any, **kwargs: Any) -> Any: + return detach_if_grad(self._get_instance(instance_id).decode(samples, **kwargs)) + + async def decode_tiled( + self, + instance_id: str, + samples: Any, + tile_x: int = 64, + tile_y: int = 64, + overlap: int = 16, + **kwargs: Any, + ) -> Any: + return detach_if_grad( + self._get_instance(instance_id).decode_tiled( + samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap, **kwargs + ) + ) + + async def get_property(self, instance_id: str, name: str) -> Any: + return getattr(self._get_instance(instance_id), name) + + async def memory_used_encode(self, instance_id: str, shape: Any, dtype: Any) -> int: + return self._get_instance(instance_id).memory_used_encode(shape, dtype) + + async def memory_used_decode(self, instance_id: str, shape: Any, dtype: Any) -> int: + return self._get_instance(instance_id).memory_used_decode(shape, dtype) + + async def process_input(self, instance_id: str, image: Any) -> Any: + return detach_if_grad(self._get_instance(instance_id).process_input(image)) + + async def process_output(self, instance_id: str, image: Any) -> Any: + return detach_if_grad(self._get_instance(instance_id).process_output(image)) + + +class VAEProxy(BaseProxy[VAERegistry]): + _registry_class = VAERegistry + __module__ = "comfy.sd" + + @property + def patcher(self) -> ModelPatcherProxy: + if not hasattr(self, "_patcher_proxy"): + patcher_id = self._call_rpc("get_patcher_id") + self._patcher_proxy = ModelPatcherProxy(patcher_id, manage_lifecycle=False) + return self._patcher_proxy + + @property + def first_stage_model(self) -> FirstStageModelProxy: + if not hasattr(self, "_first_stage_model_proxy"): + fsm_id = self._call_rpc("get_first_stage_model_id") + self._first_stage_model_proxy = FirstStageModelProxy( + fsm_id, manage_lifecycle=False + ) + return self._first_stage_model_proxy + + @property + def vae_dtype(self) -> Any: + return self._get_property("vae_dtype") + + def encode(self, pixels: Any) -> Any: + return self._call_rpc("encode", pixels) + + def encode_tiled( + self, pixels: Any, tile_x: int = 512, tile_y: int = 512, overlap: int = 64 + ) -> Any: + return self._call_rpc("encode_tiled", pixels, tile_x, tile_y, overlap) + + def decode(self, samples: Any, **kwargs: Any) -> Any: + return self._call_rpc("decode", samples, **kwargs) + + def decode_tiled( + self, + samples: Any, + tile_x: int = 64, + tile_y: int = 64, + overlap: int = 16, + **kwargs: Any, + ) -> Any: + return self._call_rpc( + "decode_tiled", samples, tile_x, tile_y, overlap, **kwargs + ) + + def get_sd(self) -> Any: + return self._call_rpc("get_sd") + + def _get_property(self, name: str) -> Any: + return self._call_rpc("get_property", name) + + @property + def latent_dim(self) -> int: + return self._get_property("latent_dim") + + @property + def latent_channels(self) -> int: + return self._get_property("latent_channels") + + @property + def downscale_ratio(self) -> Any: + return self._get_property("downscale_ratio") + + @property + def upscale_ratio(self) -> Any: + return self._get_property("upscale_ratio") + + @property + def output_channels(self) -> int: + return self._get_property("output_channels") + + @property + def check_not_vide(self) -> bool: + return self._get_property("not_video") + + @property + def device(self) -> Any: + return self._get_property("device") + + @property + def working_dtypes(self) -> Any: + return self._get_property("working_dtypes") + + @property + def disable_offload(self) -> bool: + return self._get_property("disable_offload") + + @property + def size(self) -> Any: + return self._get_property("size") + + def memory_used_encode(self, shape: Any, dtype: Any) -> int: + return self._call_rpc("memory_used_encode", shape, dtype) + + def memory_used_decode(self, shape: Any, dtype: Any) -> int: + return self._call_rpc("memory_used_decode", shape, dtype) + + def process_input(self, image: Any) -> Any: + return self._call_rpc("process_input", image) + + def process_output(self, image: Any) -> Any: + return self._call_rpc("process_output", image) + + +if not IS_CHILD_PROCESS: + _VAE_REGISTRY_SINGLETON = VAERegistry() + _FIRST_STAGE_MODEL_REGISTRY_SINGLETON = FirstStageModelRegistry() diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 6978eb717..4ed4a9250 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1,4 +1,5 @@ import math +import os from functools import partial from scipy import integrate @@ -12,8 +13,8 @@ from . import deis from . import sa_solver import comfy.model_patcher import comfy.model_sampling - import comfy.memory_management +from comfy.cli_args import args from comfy.utils import model_trange as trange def append_zero(x): @@ -191,6 +192,13 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) + isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1" + if isolation_active: + target_device = sigmas.device + if x.device != target_device: + x = x.to(target_device) + s_in = s_in.to(target_device) + for i in trange(len(sigmas) - 1, disable=disable): if s_churn > 0: gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. diff --git a/comfy/model_base.py b/comfy/model_base.py index c2ae646aa..86005d018 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -20,6 +20,7 @@ import comfy.ldm.hunyuan3dv2_1 import comfy.ldm.hunyuan3dv2_1.hunyuandit import torch import logging +import os import comfy.ldm.lightricks.av_model import comfy.context_windows from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep @@ -114,8 +115,20 @@ def model_sampling(model_config, model_type): elif model_type == ModelType.IMG_TO_IMG_FLOW: c = comfy.model_sampling.IMG_TO_IMG_FLOW + from comfy.cli_args import args + isolation_runtime_enabled = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1" + class ModelSampling(s, c): - pass + if isolation_runtime_enabled: + def __reduce__(self): + """Ensure pickling yields a proxy instead of failing on local class.""" + try: + from comfy.isolation.model_sampling_proxy import ModelSamplingRegistry, ModelSamplingProxy + registry = ModelSamplingRegistry() + ms_id = registry.register(self) + return (ModelSamplingProxy, (ms_id,)) + except Exception as exc: + raise RuntimeError("Failed to serialize ModelSampling for isolation.") from exc return ModelSampling(model_config) diff --git a/comfy/model_management.py b/comfy/model_management.py index 0eebf1ded..40876d872 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -373,7 +373,7 @@ AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN' try: if is_amd(): - arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName.split(':')[0] + arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)): if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1': torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD @@ -498,6 +498,9 @@ except: current_loaded_models = [] +def _isolation_mode_enabled(): + return args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1" + def module_size(module): module_mem = 0 sd = module.state_dict() @@ -604,8 +607,9 @@ class LoadedModel: if freed >= memory_to_free: return False self.model.detach(unpatch_weights) - self.model_finalizer.detach() - self.model_finalizer = None + if self.model_finalizer is not None: + self.model_finalizer.detach() + self.model_finalizer = None self.real_model = None return True @@ -619,8 +623,15 @@ class LoadedModel: if self._patcher_finalizer is not None: self._patcher_finalizer.detach() + def dead_state(self): + model_ref_gone = self.model is None + real_model_ref = self.real_model + real_model_ref_gone = callable(real_model_ref) and real_model_ref() is None + return model_ref_gone, real_model_ref_gone + def is_dead(self): - return self.real_model() is not None and self.model is None + model_ref_gone, real_model_ref_gone = self.dead_state() + return model_ref_gone or real_model_ref_gone def use_more_memory(extra_memory, loaded_models, device): @@ -666,6 +677,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins unloaded_model = [] can_unload = [] unloaded_models = [] + isolation_active = _isolation_mode_enabled() for i in range(len(current_loaded_models) -1, -1, -1): shift_model = current_loaded_models[i] @@ -674,6 +686,17 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i)) shift_model.currently_used = False + if can_unload and isolation_active: + try: + from pyisolate import flush_tensor_keeper # type: ignore[attr-defined] + except Exception: + flush_tensor_keeper = None + if callable(flush_tensor_keeper): + flushed = flush_tensor_keeper() + if flushed > 0: + logging.debug("][ MM:tensor_keeper_flush | released=%d", flushed) + gc.collect() + can_unload_sorted = sorted(can_unload) for x in can_unload_sorted: i = x[-1] @@ -704,7 +727,13 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}") for i in sorted(unloaded_model, reverse=True): - unloaded_models.append(current_loaded_models.pop(i)) + unloaded = current_loaded_models.pop(i) + model_obj = unloaded.model + if model_obj is not None: + cleanup = getattr(model_obj, "cleanup", None) + if callable(cleanup): + cleanup() + unloaded_models.append(unloaded) if len(unloaded_model) > 0: soft_empty_cache() @@ -763,7 +792,9 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu for i in to_unload: model_to_unload = current_loaded_models.pop(i) model_to_unload.model.detach(unpatch_all=False) - model_to_unload.model_finalizer.detach() + if model_to_unload.model_finalizer is not None: + model_to_unload.model_finalizer.detach() + model_to_unload.model_finalizer = None total_memory_required = {} @@ -836,25 +867,62 @@ def loaded_models(only_currently_used=False): def cleanup_models_gc(): - do_gc = False - reset_cast_buffers() + if not _isolation_mode_enabled(): + dead_found = False + for i in range(len(current_loaded_models)): + if current_loaded_models[i].is_dead(): + dead_found = True + break + if dead_found: + logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.") + gc.collect() + soft_empty_cache() + + for i in range(len(current_loaded_models) - 1, -1, -1): + cur = current_loaded_models[i] + if cur.is_dead(): + logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.") + leaked = current_loaded_models.pop(i) + model_obj = getattr(leaked, "model", None) + if model_obj is not None: + cleanup = getattr(model_obj, "cleanup", None) + if callable(cleanup): + cleanup() + return + + dead_found = False + has_real_model_leak = False for i in range(len(current_loaded_models)): - cur = current_loaded_models[i] - if cur.is_dead(): - logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(cur.real_model().__class__.__name__)) - do_gc = True - break + model_ref_gone, real_model_ref_gone = current_loaded_models[i].dead_state() + if model_ref_gone or real_model_ref_gone: + dead_found = True + if real_model_ref_gone and not model_ref_gone: + has_real_model_leak = True - if do_gc: + if dead_found: + if has_real_model_leak: + logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.") + else: + logging.debug("Cleaning stale loaded-model entries with released patcher references.") gc.collect() soft_empty_cache() - for i in range(len(current_loaded_models)): + for i in range(len(current_loaded_models) - 1, -1, -1): cur = current_loaded_models[i] - if cur.is_dead(): - logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__)) + model_ref_gone, real_model_ref_gone = cur.dead_state() + if model_ref_gone or real_model_ref_gone: + if real_model_ref_gone and not model_ref_gone: + logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.") + else: + logging.debug("Cleaning stale loaded-model entry with released patcher reference.") + leaked = current_loaded_models.pop(i) + model_obj = getattr(leaked, "model", None) + if model_obj is not None: + cleanup = getattr(model_obj, "cleanup", None) + if callable(cleanup): + cleanup() def archive_model_dtypes(model): @@ -868,11 +936,20 @@ def archive_model_dtypes(model): def cleanup_models(): to_delete = [] for i in range(len(current_loaded_models)): - if current_loaded_models[i].real_model() is None: + real_model_ref = current_loaded_models[i].real_model + if real_model_ref is None: + to_delete = [i] + to_delete + continue + if callable(real_model_ref) and real_model_ref() is None: to_delete = [i] + to_delete for i in to_delete: x = current_loaded_models.pop(i) + model_obj = getattr(x, "model", None) + if model_obj is not None: + cleanup = getattr(model_obj, "cleanup", None) + if callable(cleanup): + cleanup() del x def dtype_size(dtype): diff --git a/comfy/samplers.py b/comfy/samplers.py index 0a4d062db..35bea21d8 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -11,12 +11,14 @@ from functools import partial import collections import math import logging +import os import comfy.sampler_helpers import comfy.model_patcher import comfy.patcher_extension import comfy.hooks import comfy.context_windows import comfy.utils +from comfy.cli_args import args import scipy.stats import numpy @@ -210,9 +212,11 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc _calc_cond_batch, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True) ) - return executor.execute(model, conds, x_in, timestep, model_options) + result = executor.execute(model, conds, x_in, timestep, model_options) + return result def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options): + isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1" out_conds = [] out_counts = [] # separate conds by matching hooks @@ -269,7 +273,8 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens for k, v in to_run[tt][0].conditioning.items(): cond_shapes[k].append(v.size()) - if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory: + memory_required = model.memory_required(input_shape, cond_shapes=cond_shapes) + if memory_required * 1.5 < free_memory: to_batch = batch_amount break @@ -294,9 +299,17 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens patches = p.patches batch_chunks = len(cond_or_uncond) - input_x = torch.cat(input_x) + if isolation_active: + target_device = model.load_device if hasattr(model, "load_device") else input_x[0].device + input_x = torch.cat(input_x).to(target_device) + else: + input_x = torch.cat(input_x) c = cond_cat(c) - timestep_ = torch.cat([timestep] * batch_chunks) + if isolation_active: + timestep_ = torch.cat([timestep] * batch_chunks).to(target_device) + mult = [m.to(target_device) if hasattr(m, "to") else m for m in mult] + else: + timestep_ = torch.cat([timestep] * batch_chunks) transformer_options = model.current_patcher.apply_hooks(hooks=hooks) if 'transformer_options' in model_options: @@ -327,9 +340,17 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens for o in range(batch_chunks): cond_index = cond_or_uncond[o] a = area[o] + out_t = output[o] + mult_t = mult[o] + if isolation_active: + target_dev = out_conds[cond_index].device + if hasattr(out_t, "device") and out_t.device != target_dev: + out_t = out_t.to(target_dev) + if hasattr(mult_t, "device") and mult_t.device != target_dev: + mult_t = mult_t.to(target_dev) if a is None: - out_conds[cond_index] += output[o] * mult[o] - out_counts[cond_index] += mult[o] + out_conds[cond_index] += out_t * mult_t + out_counts[cond_index] += mult_t else: out_c = out_conds[cond_index] out_cts = out_counts[cond_index] @@ -337,8 +358,8 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens for i in range(dims): out_c = out_c.narrow(i + 2, a[i + dims], a[i]) out_cts = out_cts.narrow(i + 2, a[i + dims], a[i]) - out_c += output[o] * mult[o] - out_cts += mult[o] + out_c += out_t * mult_t + out_cts += mult_t for i in range(len(out_conds)): out_conds[i] /= out_counts[i] @@ -392,14 +413,31 @@ class KSamplerX0Inpaint: self.inner_model = model self.sigmas = sigmas def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None): + isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1" if denoise_mask is not None: + if isolation_active and denoise_mask.device != x.device: + denoise_mask = denoise_mask.to(x.device) if "denoise_mask_function" in model_options: denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas}) latent_mask = 1. - denoise_mask - x = x * denoise_mask + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image) * latent_mask + if isolation_active: + latent_image = self.latent_image + if hasattr(latent_image, "device") and latent_image.device != x.device: + latent_image = latent_image.to(x.device) + scaled = self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=latent_image) + if hasattr(scaled, "device") and scaled.device != x.device: + scaled = scaled.to(x.device) + else: + scaled = self.inner_model.inner_model.scale_latent_inpaint( + x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image + ) + x = x * denoise_mask + scaled * latent_mask out = self.inner_model(x, sigma, model_options=model_options, seed=seed) if denoise_mask is not None: - out = out * denoise_mask + self.latent_image * latent_mask + latent_image = self.latent_image + if isolation_active and hasattr(latent_image, "device") and latent_image.device != out.device: + latent_image = latent_image.to(out.device) + out = out * denoise_mask + latent_image * latent_mask return out def simple_scheduler(model_sampling, steps): @@ -741,7 +779,11 @@ class KSAMPLER(Sampler): else: model_k.noise = noise - noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas[0], noise, latent_image, self.max_denoise(model_wrap, sigmas)) + max_denoise = self.max_denoise(model_wrap, sigmas) + model_sampling = model_wrap.inner_model.model_sampling + noise = model_sampling.noise_scaling( + sigmas[0], noise, latent_image, max_denoise + ) k_callback = None total_steps = len(sigmas) - 1 diff --git a/comfy_api/latest/_ui.py b/comfy_api/latest/_ui.py index e238cdf3c..5d2b931df 100644 --- a/comfy_api/latest/_ui.py +++ b/comfy_api/latest/_ui.py @@ -65,6 +65,22 @@ class SavedAudios(_UIOutput): return {"audio": self.results} +def _is_isolated_child() -> bool: + return os.environ.get("PYISOLATE_CHILD") == "1" + + +def _get_preview_folder_type() -> FolderType: + if _is_isolated_child(): + return FolderType.output + return FolderType.temp + + +def _get_preview_route_prefix(folder_type: FolderType) -> str: + if folder_type == FolderType.output: + return "output" + return "temp" + + def _get_directory_by_folder_type(folder_type: FolderType) -> str: if folder_type == FolderType.input: return folder_paths.get_input_directory() @@ -388,10 +404,11 @@ class AudioSaveHelper: class PreviewImage(_UIOutput): def __init__(self, image: Image.Type, animated: bool = False, cls: type[ComfyNode] = None, **kwargs): + folder_type = _get_preview_folder_type() self.values = ImageSaveHelper.save_images( image, filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)), - folder_type=FolderType.temp, + folder_type=folder_type, cls=cls, compress_level=1, ) @@ -412,10 +429,11 @@ class PreviewMask(PreviewImage): class PreviewAudio(_UIOutput): def __init__(self, audio: dict, cls: type[ComfyNode] = None, **kwargs): + folder_type = _get_preview_folder_type() self.values = AudioSaveHelper.save_audio( audio, filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)), - folder_type=FolderType.temp, + folder_type=folder_type, cls=cls, format="flac", quality="128k", @@ -438,15 +456,16 @@ class PreviewUI3D(_UIOutput): self.model_file = model_file self.camera_info = camera_info self.bg_image_path = None + folder_type = _get_preview_folder_type() bg_image = kwargs.get("bg_image", None) if bg_image is not None: img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8) img = PILImage.fromarray(img_array) - temp_dir = folder_paths.get_temp_directory() + preview_dir = _get_directory_by_folder_type(folder_type) filename = f"bg_{uuid.uuid4().hex}.png" - bg_image_path = os.path.join(temp_dir, filename) + bg_image_path = os.path.join(preview_dir, filename) img.save(bg_image_path, compress_level=1) - self.bg_image_path = f"temp/{filename}" + self.bg_image_path = f"{_get_preview_route_prefix(folder_type)}/{filename}" def as_dict(self): return {"result": [self.model_file, self.camera_info, self.bg_image_path]} diff --git a/comfy_api_sealed_worker/__init__.py b/comfy_api_sealed_worker/__init__.py new file mode 100644 index 000000000..269aa2644 --- /dev/null +++ b/comfy_api_sealed_worker/__init__.py @@ -0,0 +1,18 @@ +"""comfy_api_sealed_worker — torch-free type definitions for sealed worker children. + +Drop-in replacement for comfy_api.latest._util type imports in sealed workers +that do not have torch installed. Contains only data type definitions (TrimeshData, +PLY, NPZ, etc.) with numpy-only dependencies. + +Usage in serializers: + if _IMPORT_TORCH: + from comfy_api.latest._util.trimesh_types import TrimeshData + else: + from comfy_api_sealed_worker.trimesh_types import TrimeshData +""" + +from .trimesh_types import TrimeshData +from .ply_types import PLY +from .npz_types import NPZ + +__all__ = ["TrimeshData", "PLY", "NPZ"] diff --git a/comfy_api_sealed_worker/npz_types.py b/comfy_api_sealed_worker/npz_types.py new file mode 100644 index 000000000..a93eed68c --- /dev/null +++ b/comfy_api_sealed_worker/npz_types.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +import os + + +class NPZ: + """Ordered collection of NPZ file payloads. + + Each entry in ``frames`` is a complete compressed ``.npz`` file stored + as raw bytes (produced by ``numpy.savez_compressed`` into a BytesIO). + ``save_to`` writes numbered files into a directory. + """ + + def __init__(self, frames: list[bytes]) -> None: + self.frames = frames + + @property + def num_frames(self) -> int: + return len(self.frames) + + def save_to(self, directory: str, prefix: str = "frame") -> str: + os.makedirs(directory, exist_ok=True) + for i, frame_bytes in enumerate(self.frames): + path = os.path.join(directory, f"{prefix}_{i:06d}.npz") + with open(path, "wb") as f: + f.write(frame_bytes) + return directory diff --git a/comfy_api_sealed_worker/ply_types.py b/comfy_api_sealed_worker/ply_types.py new file mode 100644 index 000000000..8beb566bc --- /dev/null +++ b/comfy_api_sealed_worker/ply_types.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import numpy as np + + +class PLY: + """Point cloud payload for PLY file output. + + Supports two schemas: + - Pointcloud: xyz positions with optional colors, confidence, view_id (ASCII format) + - Gaussian: raw binary PLY data built by producer nodes using plyfile (binary format) + + When ``raw_data`` is provided, the object acts as an opaque binary PLY + carrier and ``save_to`` writes the bytes directly. + """ + + def __init__( + self, + points: np.ndarray | None = None, + colors: np.ndarray | None = None, + confidence: np.ndarray | None = None, + view_id: np.ndarray | None = None, + raw_data: bytes | None = None, + ) -> None: + self.raw_data = raw_data + if raw_data is not None: + self.points = None + self.colors = None + self.confidence = None + self.view_id = None + return + if points is None: + raise ValueError("Either points or raw_data must be provided") + if points.ndim != 2 or points.shape[1] != 3: + raise ValueError(f"points must be (N, 3), got {points.shape}") + self.points = np.ascontiguousarray(points, dtype=np.float32) + self.colors = np.ascontiguousarray(colors, dtype=np.float32) if colors is not None else None + self.confidence = np.ascontiguousarray(confidence, dtype=np.float32) if confidence is not None else None + self.view_id = np.ascontiguousarray(view_id, dtype=np.int32) if view_id is not None else None + + @property + def is_gaussian(self) -> bool: + return self.raw_data is not None + + @property + def num_points(self) -> int: + if self.points is not None: + return self.points.shape[0] + return 0 + + @staticmethod + def _to_numpy(arr, dtype): + if arr is None: + return None + if hasattr(arr, "numpy"): + arr = arr.cpu().numpy() if hasattr(arr, "cpu") else arr.numpy() + return np.ascontiguousarray(arr, dtype=dtype) + + def save_to(self, path: str) -> str: + if self.raw_data is not None: + with open(path, "wb") as f: + f.write(self.raw_data) + return path + self.points = self._to_numpy(self.points, np.float32) + self.colors = self._to_numpy(self.colors, np.float32) + self.confidence = self._to_numpy(self.confidence, np.float32) + self.view_id = self._to_numpy(self.view_id, np.int32) + N = self.num_points + header_lines = [ + "ply", + "format ascii 1.0", + f"element vertex {N}", + "property float x", + "property float y", + "property float z", + ] + if self.colors is not None: + header_lines += ["property uchar red", "property uchar green", "property uchar blue"] + if self.confidence is not None: + header_lines.append("property float confidence") + if self.view_id is not None: + header_lines.append("property int view_id") + header_lines.append("end_header") + + with open(path, "w") as f: + f.write("\n".join(header_lines) + "\n") + for i in range(N): + parts = [f"{self.points[i, 0]} {self.points[i, 1]} {self.points[i, 2]}"] + if self.colors is not None: + r, g, b = (self.colors[i] * 255).clip(0, 255).astype(np.uint8) + parts.append(f"{r} {g} {b}") + if self.confidence is not None: + parts.append(f"{self.confidence[i]}") + if self.view_id is not None: + parts.append(f"{int(self.view_id[i])}") + f.write(" ".join(parts) + "\n") + return path diff --git a/comfy_api_sealed_worker/trimesh_types.py b/comfy_api_sealed_worker/trimesh_types.py new file mode 100644 index 000000000..ff8e969e6 --- /dev/null +++ b/comfy_api_sealed_worker/trimesh_types.py @@ -0,0 +1,259 @@ +from __future__ import annotations + +import numpy as np + + +class TrimeshData: + """Triangular mesh payload for cross-process transfer. + + Lightweight carrier for mesh geometry that does not depend on the + ``trimesh`` library. Serializers create this on the host side; + isolated child processes convert to/from ``trimesh.Trimesh`` as needed. + + Supports both ColorVisuals (vertex_colors) and TextureVisuals + (uv + material with textures). + """ + + def __init__( + self, + vertices: np.ndarray, + faces: np.ndarray, + vertex_normals: np.ndarray | None = None, + face_normals: np.ndarray | None = None, + vertex_colors: np.ndarray | None = None, + uv: np.ndarray | None = None, + material: dict | None = None, + vertex_attributes: dict | None = None, + face_attributes: dict | None = None, + metadata: dict | None = None, + ) -> None: + self.vertices = np.ascontiguousarray(vertices, dtype=np.float64) + self.faces = np.ascontiguousarray(faces, dtype=np.int64) + self.vertex_normals = ( + np.ascontiguousarray(vertex_normals, dtype=np.float64) + if vertex_normals is not None + else None + ) + self.face_normals = ( + np.ascontiguousarray(face_normals, dtype=np.float64) + if face_normals is not None + else None + ) + self.vertex_colors = ( + np.ascontiguousarray(vertex_colors, dtype=np.uint8) + if vertex_colors is not None + else None + ) + self.uv = ( + np.ascontiguousarray(uv, dtype=np.float64) + if uv is not None + else None + ) + self.material = material + self.vertex_attributes = vertex_attributes or {} + self.face_attributes = face_attributes or {} + self.metadata = self._detensorize_dict(metadata) if metadata else {} + + @staticmethod + def _detensorize_dict(d): + """Recursively convert any tensors in a dict back to numpy arrays.""" + if not isinstance(d, dict): + return d + result = {} + for k, v in d.items(): + if hasattr(v, "numpy"): + result[k] = v.cpu().numpy() if hasattr(v, "cpu") else v.numpy() + elif isinstance(v, dict): + result[k] = TrimeshData._detensorize_dict(v) + elif isinstance(v, list): + result[k] = [ + item.cpu().numpy() if hasattr(item, "numpy") and hasattr(item, "cpu") + else item.numpy() if hasattr(item, "numpy") + else item + for item in v + ] + else: + result[k] = v + return result + + @staticmethod + def _to_numpy(arr, dtype): + if arr is None: + return None + if hasattr(arr, "numpy"): + arr = arr.cpu().numpy() if hasattr(arr, "cpu") else arr.numpy() + return np.ascontiguousarray(arr, dtype=dtype) + + @property + def num_vertices(self) -> int: + return self.vertices.shape[0] + + @property + def num_faces(self) -> int: + return self.faces.shape[0] + + @property + def has_texture(self) -> bool: + return self.uv is not None and self.material is not None + + def to_trimesh(self): + """Convert to trimesh.Trimesh (requires trimesh in the environment).""" + import trimesh + from trimesh.visual import TextureVisuals + + kwargs = {} + if self.vertex_normals is not None: + kwargs["vertex_normals"] = self.vertex_normals + if self.face_normals is not None: + kwargs["face_normals"] = self.face_normals + if self.metadata: + kwargs["metadata"] = self.metadata + + mesh = trimesh.Trimesh( + vertices=self.vertices, faces=self.faces, process=False, **kwargs + ) + + # Reconstruct visual + if self.has_texture: + material = self._dict_to_material(self.material) + mesh.visual = TextureVisuals(uv=self.uv, material=material) + elif self.vertex_colors is not None: + mesh.visual.vertex_colors = self.vertex_colors + + for k, v in self.vertex_attributes.items(): + mesh.vertex_attributes[k] = v + + for k, v in self.face_attributes.items(): + mesh.face_attributes[k] = v + + return mesh + + @staticmethod + def _material_to_dict(material) -> dict: + """Serialize a trimesh material to a plain dict.""" + import base64 + from io import BytesIO + from trimesh.visual.material import PBRMaterial, SimpleMaterial + + result = {"type": type(material).__name__, "name": getattr(material, "name", None)} + + if isinstance(material, PBRMaterial): + result["baseColorFactor"] = material.baseColorFactor + result["metallicFactor"] = material.metallicFactor + result["roughnessFactor"] = material.roughnessFactor + result["emissiveFactor"] = material.emissiveFactor + result["alphaMode"] = material.alphaMode + result["alphaCutoff"] = material.alphaCutoff + result["doubleSided"] = material.doubleSided + + for tex_name in ("baseColorTexture", "normalTexture", "emissiveTexture", + "metallicRoughnessTexture", "occlusionTexture"): + tex = getattr(material, tex_name, None) + if tex is not None: + buf = BytesIO() + tex.save(buf, format="PNG") + result[tex_name] = base64.b64encode(buf.getvalue()).decode("ascii") + + elif isinstance(material, SimpleMaterial): + result["main_color"] = list(material.main_color) if material.main_color is not None else None + result["glossiness"] = material.glossiness + if hasattr(material, "image") and material.image is not None: + buf = BytesIO() + material.image.save(buf, format="PNG") + result["image"] = base64.b64encode(buf.getvalue()).decode("ascii") + + return result + + @staticmethod + def _dict_to_material(d: dict): + """Reconstruct a trimesh material from a plain dict.""" + import base64 + from io import BytesIO + from PIL import Image + from trimesh.visual.material import PBRMaterial, SimpleMaterial + + mat_type = d.get("type", "PBRMaterial") + + if mat_type == "PBRMaterial": + kwargs = { + "name": d.get("name"), + "baseColorFactor": d.get("baseColorFactor"), + "metallicFactor": d.get("metallicFactor"), + "roughnessFactor": d.get("roughnessFactor"), + "emissiveFactor": d.get("emissiveFactor"), + "alphaMode": d.get("alphaMode"), + "alphaCutoff": d.get("alphaCutoff"), + "doubleSided": d.get("doubleSided"), + } + for tex_name in ("baseColorTexture", "normalTexture", "emissiveTexture", + "metallicRoughnessTexture", "occlusionTexture"): + if tex_name in d and d[tex_name] is not None: + img = Image.open(BytesIO(base64.b64decode(d[tex_name]))) + kwargs[tex_name] = img + return PBRMaterial(**{k: v for k, v in kwargs.items() if v is not None}) + + elif mat_type == "SimpleMaterial": + kwargs = { + "name": d.get("name"), + "glossiness": d.get("glossiness"), + } + if d.get("main_color") is not None: + kwargs["diffuse"] = d["main_color"] + if d.get("image") is not None: + kwargs["image"] = Image.open(BytesIO(base64.b64decode(d["image"]))) + return SimpleMaterial(**kwargs) + + raise ValueError(f"Unknown material type: {mat_type}") + + @classmethod + def from_trimesh(cls, mesh) -> TrimeshData: + """Create from a trimesh.Trimesh object.""" + from trimesh.visual.texture import TextureVisuals + + vertex_normals = None + if mesh._cache.cache.get("vertex_normals") is not None: + vertex_normals = np.asarray(mesh.vertex_normals) + + face_normals = None + if mesh._cache.cache.get("face_normals") is not None: + face_normals = np.asarray(mesh.face_normals) + + vertex_colors = None + uv = None + material = None + + if isinstance(mesh.visual, TextureVisuals): + if mesh.visual.uv is not None: + uv = np.asarray(mesh.visual.uv, dtype=np.float64) + if mesh.visual.material is not None: + material = cls._material_to_dict(mesh.visual.material) + else: + try: + vc = mesh.visual.vertex_colors + if vc is not None and len(vc) > 0: + vertex_colors = np.asarray(vc, dtype=np.uint8) + except Exception: + pass + + va = {} + if hasattr(mesh, "vertex_attributes") and mesh.vertex_attributes: + for k, v in mesh.vertex_attributes.items(): + va[k] = np.asarray(v) if hasattr(v, "__array__") else v + + fa = {} + if hasattr(mesh, "face_attributes") and mesh.face_attributes: + for k, v in mesh.face_attributes.items(): + fa[k] = np.asarray(v) if hasattr(v, "__array__") else v + + return cls( + vertices=np.asarray(mesh.vertices), + faces=np.asarray(mesh.faces), + vertex_normals=vertex_normals, + face_normals=face_normals, + vertex_colors=vertex_colors, + uv=uv, + material=material, + vertex_attributes=va if va else None, + face_attributes=fa if fa else None, + metadata=mesh.metadata if mesh.metadata else None, + ) diff --git a/comfy_extras/nodes_save_npz.py b/comfy_extras/nodes_save_npz.py new file mode 100644 index 000000000..fb9e4c877 --- /dev/null +++ b/comfy_extras/nodes_save_npz.py @@ -0,0 +1,40 @@ +import os + +import folder_paths +from comfy_api.latest import io +from comfy_api_sealed_worker.npz_types import NPZ + + +class SaveNPZ(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveNPZ", + display_name="Save NPZ", + category="3d", + is_output_node=True, + inputs=[ + io.Npz.Input("npz"), + io.String.Input("filename_prefix", default="da3_streaming/ComfyUI"), + ], + ) + + @classmethod + def execute(cls, npz: NPZ, filename_prefix: str) -> io.NodeOutput: + full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path( + filename_prefix, folder_paths.get_output_directory() + ) + batch_dir = os.path.join(full_output_folder, f"{filename}_{counter:05}") + os.makedirs(batch_dir, exist_ok=True) + filenames = [] + for i, frame_bytes in enumerate(npz.frames): + f = f"frame_{i:06d}.npz" + with open(os.path.join(batch_dir, f), "wb") as fh: + fh.write(frame_bytes) + filenames.append(f) + return io.NodeOutput(ui={"npz_files": [{"folder": os.path.join(subfolder, f"{filename}_{counter:05}"), "count": len(filenames), "type": "output"}]}) + + +NODE_CLASS_MAPPINGS = { + "SaveNPZ": SaveNPZ, +} diff --git a/comfy_extras/nodes_save_ply.py b/comfy_extras/nodes_save_ply.py new file mode 100644 index 000000000..64c32e4de --- /dev/null +++ b/comfy_extras/nodes_save_ply.py @@ -0,0 +1,34 @@ +import os + +import folder_paths +from comfy_api.latest import io +from comfy_api_sealed_worker.ply_types import PLY + + +class SavePLY(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SavePLY", + display_name="Save PLY", + category="3d", + is_output_node=True, + inputs=[ + io.Ply.Input("ply"), + io.String.Input("filename_prefix", default="pointcloud/ComfyUI"), + ], + ) + + @classmethod + def execute(cls, ply: PLY, filename_prefix: str) -> io.NodeOutput: + full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path( + filename_prefix, folder_paths.get_output_directory() + ) + f = f"{filename}_{counter:05}_.ply" + ply.save_to(os.path.join(full_output_folder, f)) + return io.NodeOutput(ui={"pointclouds": [{"filename": f, "subfolder": subfolder, "type": "output"}]}) + + +NODE_CLASS_MAPPINGS = { + "SavePLY": SavePLY, +} diff --git a/cuda_malloc.py b/cuda_malloc.py index f7651981c..f6d2063e9 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -92,7 +92,7 @@ if args.cuda_malloc: env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None) if env_var is None: env_var = "backend:cudaMallocAsync" - else: + elif not args.use_process_isolation: env_var += ",backend:cudaMallocAsync" os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var diff --git a/execution.py b/execution.py index 5e02dffb2..fc54c3e66 100644 --- a/execution.py +++ b/execution.py @@ -1,7 +1,9 @@ import copy +import gc import heapq import inspect import logging +import os import sys import threading import time @@ -42,6 +44,8 @@ from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_re from comfy_api.latest import io, _io from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger +_AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED = False + class ExecutionResult(Enum): SUCCESS = 0 @@ -262,20 +266,31 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f pre_execute_cb(index) # V3 if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)): - # if is just a class, then assign no state, just create clone - if is_class(obj): - type_obj = obj - obj.VALIDATE_CLASS() - class_clone = obj.PREPARE_CLASS_CLONE(v3_data) - # otherwise, use class instance to populate/reuse some fields + # Check for isolated node - skip validation and class cloning + if hasattr(obj, "_pyisolate_extension"): + # Isolated Node: The stub is just a proxy; real validation happens in child process + if v3_data is not None: + inputs = _io.build_nested_inputs(inputs, v3_data) + # Inject hidden inputs so they're available in the isolated child process + inputs.update(v3_data.get("hidden_inputs", {})) + f = getattr(obj, func) + # Standard V3 Node (Existing Logic) + else: - type_obj = type(obj) - type_obj.VALIDATE_CLASS() - class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data) - f = make_locked_method_func(type_obj, func, class_clone) - # in case of dynamic inputs, restructure inputs to expected nested dict - if v3_data is not None: - inputs = _io.build_nested_inputs(inputs, v3_data) + # if is just a class, then assign no resources or state, just create clone + if is_class(obj): + type_obj = obj + obj.VALIDATE_CLASS() + class_clone = obj.PREPARE_CLASS_CLONE(v3_data) + # otherwise, use class instance to populate/reuse some fields + else: + type_obj = type(obj) + type_obj.VALIDATE_CLASS() + class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data) + f = make_locked_method_func(type_obj, func, class_clone) + # in case of dynamic inputs, restructure inputs to expected nested dict + if v3_data is not None: + inputs = _io.build_nested_inputs(inputs, v3_data) # V1 else: f = getattr(obj, func) @@ -537,7 +552,17 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, if args.verbose == "DEBUG": comfy_aimdo.control.analyze() comfy.model_management.reset_cast_buffers() - comfy_aimdo.model_vbar.vbars_reset_watermark_limits() + vbar_lib = getattr(comfy_aimdo.model_vbar, "lib", None) + if vbar_lib is not None: + comfy_aimdo.model_vbar.vbars_reset_watermark_limits() + else: + global _AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED + if not _AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED: + logging.warning( + "DynamicVRAM backend unavailable for watermark reset; " + "skipping vbar reset for this process." + ) + _AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED = True if has_pending_tasks: pending_async_nodes[unique_id] = output_data @@ -546,6 +571,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, tasks = [x for x in output_data if isinstance(x, asyncio.Task)] await asyncio.gather(*tasks, return_exceptions=True) unblock() + + # Keep isolation node execution deterministic by default, but allow + # opt-out for diagnostics. + isolation_sequential = os.environ.get("COMFY_ISOLATE_SEQUENTIAL", "1").lower() in ("1", "true", "yes") + if args.use_process_isolation and isolation_sequential: + await await_completion() + return await execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs) + asyncio.create_task(await_completion()) return (ExecutionResult.PENDING, None, None) if len(output_ui) > 0: @@ -657,6 +690,46 @@ class PromptExecutor: self.status_messages = [] self.success = True + async def _notify_execution_graph_safe(self, class_types: set[str], *, fail_loud: bool = False) -> None: + if not args.use_process_isolation: + return + try: + from comfy.isolation import notify_execution_graph + await notify_execution_graph(class_types, caches=self.caches.all) + except Exception: + if fail_loud: + raise + logging.debug("][ EX:notify_execution_graph failed", exc_info=True) + + async def _flush_running_extensions_transport_state_safe(self) -> None: + if not args.use_process_isolation: + return + try: + from comfy.isolation import flush_running_extensions_transport_state + await flush_running_extensions_transport_state() + except Exception: + logging.debug("][ EX:flush_running_extensions_transport_state failed", exc_info=True) + + async def _wait_model_patcher_quiescence_safe( + self, + *, + fail_loud: bool = False, + timeout_ms: int = 120000, + marker: str = "EX:wait_model_patcher_idle", + ) -> None: + if not args.use_process_isolation: + return + try: + from comfy.isolation import wait_for_model_patcher_quiescence + + await wait_for_model_patcher_quiescence( + timeout_ms=timeout_ms, fail_loud=fail_loud, marker=marker + ) + except Exception: + if fail_loud: + raise + logging.debug("][ EX:wait_model_patcher_quiescence failed", exc_info=True) + def add_message(self, event, data: dict, broadcast: bool): data = { **data, @@ -711,6 +784,18 @@ class PromptExecutor: asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs)) async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): + if args.use_process_isolation: + # Update RPC event loops for all isolated extensions. + # This is critical for serial workflow execution - each asyncio.run() creates + # a new event loop, and RPC instances must be updated to use it. + try: + from comfy.isolation import update_rpc_event_loops + update_rpc_event_loops() + except ImportError: + pass # Isolation not available + except Exception as e: + logging.getLogger(__name__).warning(f"Failed to update RPC event loops: {e}") + set_preview_method(extra_data.get("preview_method")) nodes.interrupt_processing(False) @@ -723,6 +808,25 @@ class PromptExecutor: self.status_messages = [] self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) + if args.use_process_isolation: + try: + # Boundary cleanup runs at the start of the next workflow in + # isolation mode, matching non-isolated "next prompt" timing. + self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args) + await self._wait_model_patcher_quiescence_safe( + fail_loud=False, + timeout_ms=120000, + marker="EX:boundary_cleanup_wait_idle", + ) + await self._flush_running_extensions_transport_state_safe() + comfy.model_management.unload_all_models() + comfy.model_management.cleanup_models_gc() + comfy.model_management.cleanup_models() + gc.collect() + comfy.model_management.soft_empty_cache() + except Exception: + logging.debug("][ EX:isolation_boundary_cleanup_start failed", exc_info=True) + self._notify_prompt_lifecycle("start", prompt_id) ram_headroom = int(self.cache_args["ram"] * (1024 ** 3)) ram_release_callback = self.caches.outputs.ram_release if self.cache_type == CacheType.RAM_PRESSURE else None @@ -760,6 +864,18 @@ class PromptExecutor: for node_id in list(execute_outputs): execution_list.add_node(node_id) + if args.use_process_isolation: + pending_class_types = set() + for node_id in execution_list.pendingNodes.keys(): + class_type = dynamic_prompt.get_node(node_id)["class_type"] + pending_class_types.add(class_type) + await self._wait_model_patcher_quiescence_safe( + fail_loud=True, + timeout_ms=120000, + marker="EX:notify_graph_wait_idle", + ) + await self._notify_execution_graph_safe(pending_class_types, fail_loud=True) + while not execution_list.is_empty(): node_id, error, ex = await execution_list.stage_node_execution() if error is not None: diff --git a/main.py b/main.py index 12b04719d..2f97afb74 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,21 @@ +import os +import sys + +IS_PYISOLATE_CHILD = os.environ.get("PYISOLATE_CHILD") == "1" + +if __name__ == "__main__" and IS_PYISOLATE_CHILD: + del os.environ["PYISOLATE_CHILD"] + IS_PYISOLATE_CHILD = False + +CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) +if CURRENT_DIR not in sys.path: + sys.path.insert(0, CURRENT_DIR) + +IS_PRIMARY_PROCESS = (not IS_PYISOLATE_CHILD) and __name__ == "__main__" + import comfy.options comfy.options.enable_args_parsing() -import os import importlib.util import shutil import importlib.metadata @@ -12,7 +26,7 @@ from app.logger import setup_logger from app.assets.seeder import asset_seeder from app.assets.services import register_output_files import itertools -import utils.extra_config +import utils.extra_config # noqa: F401 from utils.mime_types import init_mime_types import faulthandler import logging @@ -22,12 +36,45 @@ from comfy_execution.utils import get_executing_context from comfy_api import feature_flags from app.database.db import init_db, dependencies_available -if __name__ == "__main__": - #NOTE: These do not do anything on core ComfyUI, they are for custom nodes. +import comfy_aimdo.control + +if enables_dynamic_vram(): + if not comfy_aimdo.control.init(): + logging.warning( + "DynamicVRAM requested, but comfy-aimdo failed to initialize early. " + "Will fall back to legacy model loading if device init fails." + ) + +if '--use-process-isolation' in sys.argv: + from comfy.isolation import initialize_proxies + initialize_proxies() + + # Explicitly register the ComfyUI adapter for pyisolate (v1.0 architecture) + try: + import pyisolate + from comfy.isolation.adapter import ComfyUIAdapter + pyisolate.register_adapter(ComfyUIAdapter()) + logging.info("PyIsolate adapter registered: comfyui") + except ImportError: + logging.warning("PyIsolate not installed or version too old for explicit registration") + except Exception as e: + logging.error(f"Failed to register PyIsolate adapter: {e}") + + if not IS_PYISOLATE_CHILD: + if 'PYTORCH_CUDA_ALLOC_CONF' not in os.environ: + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'backend:native' + +if not IS_PYISOLATE_CHILD: + from comfy_execution.progress import get_progress_state + from comfy_execution.utils import get_executing_context + from comfy_api import feature_flags + +if IS_PRIMARY_PROCESS: os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1' os.environ['DO_NOT_TRACK'] = '1' -setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) +if not IS_PYISOLATE_CHILD: + setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) faulthandler.enable(file=sys.stderr, all_threads=False) @@ -93,14 +140,15 @@ if args.enable_manager: def apply_custom_paths(): + from utils import extra_config # Deferred import - spawn re-runs main.py # extra model paths extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") if os.path.isfile(extra_model_paths_config_path): - utils.extra_config.load_extra_path_config(extra_model_paths_config_path) + extra_config.load_extra_path_config(extra_model_paths_config_path) if args.extra_model_paths_config: for config_path in itertools.chain(*args.extra_model_paths_config): - utils.extra_config.load_extra_path_config(config_path) + extra_config.load_extra_path_config(config_path) # --output-directory, --input-directory, --user-directory if args.output_directory: @@ -173,15 +221,17 @@ def execute_prestartup_script(): else: import_message = " (PRESTARTUP FAILED)" logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1])) - logging.info("") + logging.info("") -apply_custom_paths() -init_mime_types() +if not IS_PYISOLATE_CHILD: + apply_custom_paths() + init_mime_types() -if args.enable_manager: +if args.enable_manager and not IS_PYISOLATE_CHILD: comfyui_manager.prestartup() -execute_prestartup_script() +if not IS_PYISOLATE_CHILD: + execute_prestartup_script() # Main code @@ -192,17 +242,17 @@ import gc if 'torch' in sys.modules: logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") - import comfy.utils -import execution -import server -from protocol import BinaryEventTypes -import nodes -import comfy.model_management -import comfyui_version -import app.logger -import hook_breaker_ac10a0 +if not IS_PYISOLATE_CHILD: + import execution + import server + from protocol import BinaryEventTypes + import nodes + import comfy.model_management + import comfyui_version + import app.logger + import hook_breaker_ac10a0 import comfy.memory_management import comfy.model_patcher @@ -462,6 +512,10 @@ def start_comfyui(asyncio_loop=None): asyncio.set_event_loop(asyncio_loop) prompt_server = server.PromptServer(asyncio_loop) + if args.use_process_isolation: + from comfy.isolation import start_isolation_loading_early + start_isolation_loading_early(asyncio_loop) + if args.enable_manager and not args.disable_manager_ui: comfyui_manager.start() @@ -506,12 +560,13 @@ def start_comfyui(asyncio_loop=None): if __name__ == "__main__": # Running directly, just start ComfyUI. logging.info("Python version: {}".format(sys.version)) - logging.info("ComfyUI version: {}".format(comfyui_version.__version__)) - for package in ("comfy-aimdo", "comfy-kitchen"): - try: - logging.info("{} version: {}".format(package, importlib.metadata.version(package))) - except: - pass + if not IS_PYISOLATE_CHILD: + logging.info("ComfyUI version: {}".format(comfyui_version.__version__)) + for package in ("comfy-aimdo", "comfy-kitchen"): + try: + logging.info("{} version: {}".format(package, importlib.metadata.version(package))) + except: + pass if sys.version_info.major == 3 and sys.version_info.minor < 10: logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.") diff --git a/nodes.py b/nodes.py index 299b3d758..f05d15793 100644 --- a/nodes.py +++ b/nodes.py @@ -1927,6 +1927,7 @@ class ImageInvert: class ImageBatch: SEARCH_ALIASES = ["combine images", "merge images", "stack images"] + ESSENTIALS_CATEGORY = "Image Tools" @classmethod def INPUT_TYPES(s): @@ -2310,6 +2311,27 @@ async def init_external_custom_nodes(): Returns: None """ + whitelist = set() + isolated_module_paths = set() + if args.use_process_isolation: + from pathlib import Path + from comfy.isolation import await_isolation_loading, get_claimed_paths + from comfy.isolation.host_policy import load_host_policy + + # Load Global Host Policy + host_policy = load_host_policy(Path(folder_paths.base_path)) + whitelist_dict = host_policy.get("whitelist", {}) + # Normalize whitelist keys to lowercase for case-insensitive matching + # (matches ComfyUI-Manager's normalization: project.name.strip().lower()) + whitelist = set(k.strip().lower() for k in whitelist_dict.keys()) + logging.info(f"][ Loaded Whitelist: {len(whitelist)} nodes allowed.") + + isolated_specs = await await_isolation_loading() + for spec in isolated_specs: + NODE_CLASS_MAPPINGS.setdefault(spec.node_name, spec.stub_class) + NODE_DISPLAY_NAME_MAPPINGS.setdefault(spec.node_name, spec.display_name) + isolated_module_paths = get_claimed_paths() + base_node_names = set(NODE_CLASS_MAPPINGS.keys()) node_paths = folder_paths.get_folder_paths("custom_nodes") node_import_times = [] @@ -2333,6 +2355,16 @@ async def init_external_custom_nodes(): logging.info(f"Blocked by policy: {module_path}") continue + if args.use_process_isolation: + if Path(module_path).resolve() in isolated_module_paths: + continue + + # Tri-State Enforcement: If not Isolated (checked above), MUST be Whitelisted. + # Normalize to lowercase for case-insensitive matching (matches ComfyUI-Manager) + if possible_module.strip().lower() not in whitelist: + logging.warning(f"][ REJECTED: Node '{possible_module}' is blocked by security policy (not whitelisted/isolated).") + continue + time_before = time.perf_counter() success = await load_custom_node(module_path, base_node_names, module_parent="custom_nodes") node_import_times.append((time.perf_counter() - time_before, module_path, success)) @@ -2347,6 +2379,14 @@ async def init_external_custom_nodes(): logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1])) logging.info("") + if args.use_process_isolation: + from comfy.isolation import isolated_node_timings + if isolated_node_timings: + logging.info("\nImport times for isolated custom nodes:") + for timing, path, count in sorted(isolated_node_timings): + logging.info("{:6.1f} seconds: {} ({})".format(timing, path, count)) + logging.info("") + async def init_builtin_extra_nodes(): """ Initializes the built-in extra nodes in ComfyUI. @@ -2419,6 +2459,8 @@ async def init_builtin_extra_nodes(): "nodes_wan.py", "nodes_lotus.py", "nodes_hunyuan3d.py", + "nodes_save_ply.py", + "nodes_save_npz.py", "nodes_primitive.py", "nodes_cfg.py", "nodes_optimalsteps.py", @@ -2439,7 +2481,6 @@ async def init_builtin_extra_nodes(): "nodes_audio_encoder.py", "nodes_rope.py", "nodes_logic.py", - "nodes_resolution.py", "nodes_nop.py", "nodes_kandinsky5.py", "nodes_wanmove.py", @@ -2447,7 +2488,6 @@ async def init_builtin_extra_nodes(): "nodes_zimage.py", "nodes_glsl.py", "nodes_lora_debug.py", - "nodes_textgen.py", "nodes_color.py", "nodes_toolkit.py", "nodes_replacements.py", diff --git a/pyproject.toml b/pyproject.toml index 1fc9402a1..3f85e5a3e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,17 @@ homepage = "https://www.comfy.org/" repository = "https://github.com/comfyanonymous/ComfyUI" documentation = "https://docs.comfy.org/" +[tool.comfy.host] +sandbox_mode = "disabled" +allow_network = false +writable_paths = ["/dev/shm", "/tmp"] + +[tool.comfy.host.whitelist] +"ComfyUI-GGUF" = "*" +"ComfyUI-KJNodes" = "*" +"ComfyUI-Manager" = "*" +"websocket_image_save.py" = "*" + [tool.ruff] lint.select = [ "N805", # invalid-first-argument-name-for-method diff --git a/requirements.txt b/requirements.txt index 1a8e1ea1c..2c5a520c8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -35,3 +35,5 @@ pydantic~=2.0 pydantic-settings~=2.0 PyOpenGL glfw + +pyisolate==0.10.0 diff --git a/server.py b/server.py index 881da8e66..30fcff3bb 100644 --- a/server.py +++ b/server.py @@ -3,7 +3,6 @@ import sys import asyncio import traceback import time - import nodes import folder_paths import execution @@ -202,6 +201,8 @@ def create_block_external_middleware(): class PromptServer(): def __init__(self, loop): PromptServer.instance = self + if loop is None: + loop = asyncio.get_event_loop() self.user_manager = UserManager() self.model_file_manager = ModelFileManager() @@ -352,6 +353,17 @@ class PromptServer(): extensions.extend(list(map(lambda f: "/extensions/" + urllib.parse.quote( name) + "/" + os.path.relpath(f, dir).replace("\\", "/"), files))) + # Include JS files from proxied web directories (isolated nodes) + if args.use_process_isolation: + from comfy.isolation.proxies.web_directory_proxy import get_web_directory_cache + cache = get_web_directory_cache() + for ext_name in cache.extension_names: + for entry in cache.list_files(ext_name): + if entry["relative_path"].endswith(".js"): + extensions.append( + "/extensions/" + urllib.parse.quote(ext_name) + "/" + entry["relative_path"] + ) + return web.json_response(extensions) def get_dir_by_type(dir_type): @@ -1067,6 +1079,40 @@ class PromptServer(): for name, dir in nodes.EXTENSION_WEB_DIRS.items(): self.app.add_routes([web.static('/extensions/' + name, dir)]) + # Add dynamic handler for proxied web directories (isolated nodes) + if args.use_process_isolation: + from comfy.isolation.proxies.web_directory_proxy import ( + get_web_directory_cache, + ALLOWED_EXTENSIONS, + ) + + async def proxied_web_handler(request): + ext_name = request.match_info["ext_name"] + file_path = request.match_info["file_path"] + + suffix = os.path.splitext(file_path)[1].lower() + if suffix not in ALLOWED_EXTENSIONS: + return web.Response(status=403, text="Forbidden file type") + + cache = get_web_directory_cache() + result = cache.get_file(ext_name, file_path) + if result is None: + return web.Response(status=404, text="Not found") + + content_type = { + ".js": "application/javascript", + ".css": "text/css", + ".html": "text/html", + ".json": "application/json", + }.get(suffix, "application/octet-stream") + + return web.Response(body=result, content_type=content_type) + + self.app.router.add_get( + "/extensions/{ext_name}/{file_path:.*}", + proxied_web_handler, + ) + installed_templates_version = FrontendManager.get_installed_templates_version() use_legacy_templates = True if installed_templates_version: diff --git a/tests/isolation/conda_sealed_worker/__init__.py b/tests/isolation/conda_sealed_worker/__init__.py new file mode 100644 index 000000000..0208a4bde --- /dev/null +++ b/tests/isolation/conda_sealed_worker/__init__.py @@ -0,0 +1,209 @@ +# pylint: disable=import-outside-toplevel,import-error +from __future__ import annotations + +import logging +import os +import sys +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + + +def _artifact_dir() -> Path | None: + raw = os.environ.get("PYISOLATE_ARTIFACT_DIR") + if not raw: + return None + path = Path(raw) + path.mkdir(parents=True, exist_ok=True) + return path + + +def _write_artifact(name: str, content: str) -> None: + artifact_dir = _artifact_dir() + if artifact_dir is None: + return + (artifact_dir / name).write_text(content, encoding="utf-8") + + +def _contains_tensor_marker(value: Any) -> bool: + if isinstance(value, dict): + if value.get("__type__") == "TensorValue": + return True + return any(_contains_tensor_marker(v) for v in value.values()) + if isinstance(value, (list, tuple)): + return any(_contains_tensor_marker(v) for v in value) + return False + + +class InspectRuntimeNode: + RETURN_TYPES = ( + "STRING", + "STRING", + "BOOLEAN", + "BOOLEAN", + "STRING", + "STRING", + "BOOLEAN", + ) + RETURN_NAMES = ( + "path_dump", + "runtime_report", + "saw_comfy_root", + "imported_comfy_wrapper", + "comfy_module_dump", + "python_exe", + "saw_user_site", + ) + FUNCTION = "inspect" + CATEGORY = "PyIsolated/SealedWorker" + + @classmethod + def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802 + return {"required": {}} + + def inspect(self) -> tuple[str, str, bool, bool, str, str, bool]: + import cfgrib + import eccodes + import xarray as xr + + path_dump = "\n".join(sys.path) + comfy_root = "/home/johnj/ComfyUI" + saw_comfy_root = any( + entry == comfy_root + or entry.startswith(f"{comfy_root}/comfy") + or entry.startswith(f"{comfy_root}/.venv") + for entry in sys.path + ) + imported_comfy_wrapper = "comfy.isolation.extension_wrapper" in sys.modules + comfy_module_dump = "\n".join( + sorted(name for name in sys.modules if name.startswith("comfy")) + ) + saw_user_site = any("/.local/lib/" in entry for entry in sys.path) + python_exe = sys.executable + + runtime_lines = [ + "Conda sealed worker runtime probe", + f"python_exe={python_exe}", + f"xarray_origin={getattr(xr, '__file__', '')}", + f"cfgrib_origin={getattr(cfgrib, '__file__', '')}", + f"eccodes_origin={getattr(eccodes, '__file__', '')}", + f"saw_comfy_root={saw_comfy_root}", + f"imported_comfy_wrapper={imported_comfy_wrapper}", + f"saw_user_site={saw_user_site}", + ] + runtime_report = "\n".join(runtime_lines) + + _write_artifact("child_bootstrap_paths.txt", path_dump) + _write_artifact("child_import_trace.txt", comfy_module_dump) + _write_artifact("child_dependency_dump.txt", runtime_report) + logger.warning("][ Conda sealed runtime probe executed") + logger.warning("][ conda python executable: %s", python_exe) + logger.warning( + "][ conda dependency origins: xarray=%s cfgrib=%s eccodes=%s", + getattr(xr, "__file__", ""), + getattr(cfgrib, "__file__", ""), + getattr(eccodes, "__file__", ""), + ) + + return ( + path_dump, + runtime_report, + saw_comfy_root, + imported_comfy_wrapper, + comfy_module_dump, + python_exe, + saw_user_site, + ) + + +class OpenWeatherDatasetNode: + RETURN_TYPES = ("FLOAT", "STRING", "STRING") + RETURN_NAMES = ("sum_value", "grib_path", "dependency_report") + FUNCTION = "open_dataset" + CATEGORY = "PyIsolated/SealedWorker" + + @classmethod + def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802 + return {"required": {}} + + def open_dataset(self) -> tuple[float, str, str]: + import eccodes + import xarray as xr + + artifact_dir = _artifact_dir() + if artifact_dir is None: + artifact_dir = Path(os.environ.get("HOME", ".")) / "pyisolate_artifacts" + artifact_dir.mkdir(parents=True, exist_ok=True) + + grib_path = artifact_dir / "toolkit_weather_fixture.grib2" + + gid = eccodes.codes_grib_new_from_samples("GRIB2") + for key, value in [ + ("gridType", "regular_ll"), + ("Nx", 2), + ("Ny", 2), + ("latitudeOfFirstGridPointInDegrees", 1.0), + ("longitudeOfFirstGridPointInDegrees", 0.0), + ("latitudeOfLastGridPointInDegrees", 0.0), + ("longitudeOfLastGridPointInDegrees", 1.0), + ("iDirectionIncrementInDegrees", 1.0), + ("jDirectionIncrementInDegrees", 1.0), + ("jScansPositively", 0), + ("shortName", "t"), + ("typeOfLevel", "surface"), + ("level", 0), + ("date", 20260315), + ("time", 0), + ("step", 0), + ]: + eccodes.codes_set(gid, key, value) + + eccodes.codes_set_values(gid, [1.0, 2.0, 3.0, 4.0]) + with grib_path.open("wb") as handle: + eccodes.codes_write(gid, handle) + eccodes.codes_release(gid) + + dataset = xr.open_dataset(grib_path, engine="cfgrib") + sum_value = float(dataset["t"].sum().item()) + dependency_report = "\n".join( + [ + f"dataset_sum={sum_value}", + f"grib_path={grib_path}", + "xarray_engine=cfgrib", + ] + ) + _write_artifact("weather_dependency_report.txt", dependency_report) + logger.warning("][ cfgrib import ok") + logger.warning("][ xarray open_dataset engine=cfgrib path=%s", grib_path) + logger.warning("][ conda weather dataset sum=%s", sum_value) + return sum_value, str(grib_path), dependency_report + + +class EchoLatentNode: + RETURN_TYPES = ("LATENT", "BOOLEAN") + RETURN_NAMES = ("latent", "saw_json_tensor") + FUNCTION = "echo_latent" + CATEGORY = "PyIsolated/SealedWorker" + + @classmethod + def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802 + return {"required": {"latent": ("LATENT",)}} + + def echo_latent(self, latent: Any) -> tuple[Any, bool]: + saw_json_tensor = _contains_tensor_marker(latent) + logger.warning("][ conda latent echo json_marker=%s", saw_json_tensor) + return latent, saw_json_tensor + + +NODE_CLASS_MAPPINGS = { + "CondaSealedRuntimeProbe": InspectRuntimeNode, + "CondaSealedOpenWeatherDataset": OpenWeatherDatasetNode, + "CondaSealedLatentEcho": EchoLatentNode, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "CondaSealedRuntimeProbe": "Conda Sealed Runtime Probe", + "CondaSealedOpenWeatherDataset": "Conda Sealed Open Weather Dataset", + "CondaSealedLatentEcho": "Conda Sealed Latent Echo", +} diff --git a/tests/isolation/conda_sealed_worker/pyproject.toml b/tests/isolation/conda_sealed_worker/pyproject.toml new file mode 100644 index 000000000..6d6d7d804 --- /dev/null +++ b/tests/isolation/conda_sealed_worker/pyproject.toml @@ -0,0 +1,13 @@ +[project] +name = "comfyui-toolkit-conda-sealed-worker" +version = "0.1.0" +dependencies = ["xarray", "cfgrib"] + +[tool.comfy.isolation] +can_isolate = true +share_torch = false +package_manager = "conda" +execution_model = "sealed_worker" +standalone = true +conda_channels = ["conda-forge"] +conda_dependencies = ["eccodes", "cfgrib"] diff --git a/tests/isolation/internal_probe_host_policy.toml b/tests/isolation/internal_probe_host_policy.toml new file mode 100644 index 000000000..57bde615d --- /dev/null +++ b/tests/isolation/internal_probe_host_policy.toml @@ -0,0 +1,7 @@ +[tool.comfy.host] +sandbox_mode = "required" +allow_network = false +writable_paths = [ + "/dev/shm", + "/home/johnj/ComfyUI/output", +] diff --git a/tests/isolation/internal_probe_node/__init__.py b/tests/isolation/internal_probe_node/__init__.py new file mode 100644 index 000000000..f4155bf99 --- /dev/null +++ b/tests/isolation/internal_probe_node/__init__.py @@ -0,0 +1,6 @@ +from .probe_nodes import ( + NODE_CLASS_MAPPINGS as NODE_CLASS_MAPPINGS, + NODE_DISPLAY_NAME_MAPPINGS as NODE_DISPLAY_NAME_MAPPINGS, +) + +__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] diff --git a/tests/isolation/internal_probe_node/probe_nodes.py b/tests/isolation/internal_probe_node/probe_nodes.py new file mode 100644 index 000000000..1c29996e7 --- /dev/null +++ b/tests/isolation/internal_probe_node/probe_nodes.py @@ -0,0 +1,75 @@ +from __future__ import annotations + + +class InternalIsolationProbeImage: + CATEGORY = "tests/isolation" + RETURN_TYPES = () + FUNCTION = "run" + OUTPUT_NODE = True + + @classmethod + def INPUT_TYPES(cls): + return {"required": {}} + + def run(self): + from comfy_api.latest import UI + import torch + + image = torch.zeros((1, 2, 2, 3), dtype=torch.float32) + image[:, :, :, 0] = 1.0 + ui = UI.PreviewImage(image) + return {"ui": ui.as_dict(), "result": ()} + + +class InternalIsolationProbeAudio: + CATEGORY = "tests/isolation" + RETURN_TYPES = () + FUNCTION = "run" + OUTPUT_NODE = True + + @classmethod + def INPUT_TYPES(cls): + return {"required": {}} + + def run(self): + from comfy_api.latest import UI + import torch + + waveform = torch.zeros((1, 1, 32), dtype=torch.float32) + audio = {"waveform": waveform, "sample_rate": 44100} + ui = UI.PreviewAudio(audio) + return {"ui": ui.as_dict(), "result": ()} + + +class InternalIsolationProbeUI3D: + CATEGORY = "tests/isolation" + RETURN_TYPES = () + FUNCTION = "run" + OUTPUT_NODE = True + + @classmethod + def INPUT_TYPES(cls): + return {"required": {}} + + def run(self): + from comfy_api.latest import UI + import torch + + bg_image = torch.zeros((1, 2, 2, 3), dtype=torch.float32) + bg_image[:, :, :, 1] = 1.0 + camera_info = {"distance": 1.0} + ui = UI.PreviewUI3D("internal_probe_preview.obj", camera_info, bg_image=bg_image) + return {"ui": ui.as_dict(), "result": ()} + + +NODE_CLASS_MAPPINGS = { + "InternalIsolationProbeImage": InternalIsolationProbeImage, + "InternalIsolationProbeAudio": InternalIsolationProbeAudio, + "InternalIsolationProbeUI3D": InternalIsolationProbeUI3D, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "InternalIsolationProbeImage": "Internal Isolation Probe Image", + "InternalIsolationProbeAudio": "Internal Isolation Probe Audio", + "InternalIsolationProbeUI3D": "Internal Isolation Probe UI3D", +} diff --git a/tests/isolation/singleton_boundary_helpers.py b/tests/isolation/singleton_boundary_helpers.py new file mode 100644 index 000000000..f113f6a81 --- /dev/null +++ b/tests/isolation/singleton_boundary_helpers.py @@ -0,0 +1,955 @@ +from __future__ import annotations + +import asyncio +import importlib.util +import os +import sys +from pathlib import Path +from typing import Any + + +COMFYUI_ROOT = Path(__file__).resolve().parents[2] +UV_SEALED_WORKER_MODULE = COMFYUI_ROOT / "tests" / "isolation" / "uv_sealed_worker" / "__init__.py" +FORBIDDEN_MINIMAL_SEALED_MODULES = ( + "torch", + "folder_paths", + "comfy.utils", + "comfy.model_management", + "main", + "comfy.isolation.extension_wrapper", +) +FORBIDDEN_SEALED_SINGLETON_MODULES = ( + "torch", + "folder_paths", + "comfy.utils", + "comfy_execution.progress", +) +FORBIDDEN_EXACT_SMALL_PROXY_MODULES = FORBIDDEN_SEALED_SINGLETON_MODULES +FORBIDDEN_MODEL_MANAGEMENT_MODULES = ( + "comfy.model_management", +) + + +def _load_module_from_path(module_name: str, module_path: Path): + spec = importlib.util.spec_from_file_location(module_name, module_path) + if spec is None or spec.loader is None: + raise RuntimeError(f"unable to build import spec for {module_path}") + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + try: + spec.loader.exec_module(module) + except Exception: + sys.modules.pop(module_name, None) + raise + return module + + +def matching_modules(prefixes: tuple[str, ...], modules: set[str]) -> list[str]: + return sorted( + module_name + for module_name in modules + if any( + module_name == prefix or module_name.startswith(f"{prefix}.") + for prefix in prefixes + ) + ) + + +def _load_helper_proxy_service() -> Any | None: + try: + from comfy.isolation.proxies.helper_proxies import HelperProxiesService + except (ImportError, AttributeError): + return None + return HelperProxiesService + + +def _load_model_management_proxy() -> Any | None: + try: + from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy + except (ImportError, AttributeError): + return None + return ModelManagementProxy + + +async def _capture_minimal_sealed_worker_imports() -> dict[str, object]: + from pyisolate.sealed import SealedNodeExtension + + module_name = "tests.isolation.uv_sealed_worker_boundary_probe" + before = set(sys.modules) + extension = SealedNodeExtension() + module = _load_module_from_path(module_name, UV_SEALED_WORKER_MODULE) + try: + await extension.on_module_loaded(module) + node_list = await extension.list_nodes() + node_details = await extension.get_node_details("UVSealedRuntimeProbe") + imported = set(sys.modules) - before + return { + "mode": "minimal_sealed_worker", + "node_names": sorted(node_list), + "runtime_probe_function": node_details["function"], + "modules": sorted(imported), + "forbidden_matches": matching_modules(FORBIDDEN_MINIMAL_SEALED_MODULES, imported), + } + finally: + sys.modules.pop(module_name, None) + + +def capture_minimal_sealed_worker_imports() -> dict[str, object]: + return asyncio.run(_capture_minimal_sealed_worker_imports()) + + +class FakeSingletonCaller: + def __init__(self, methods: dict[str, Any], calls: list[dict[str, Any]], object_id: str): + self._methods = methods + self._calls = calls + self._object_id = object_id + + def __getattr__(self, name: str): + if name not in self._methods: + raise AttributeError(name) + + async def method(*args: Any, **kwargs: Any) -> Any: + self._calls.append( + { + "object_id": self._object_id, + "method": name, + "args": list(args), + "kwargs": dict(kwargs), + } + ) + result = self._methods[name] + return result(*args, **kwargs) if callable(result) else result + + return method + + +class FakeSingletonRPC: + def __init__(self) -> None: + self.calls: list[dict[str, Any]] = [] + self._device = {"__pyisolate_torch_device__": "cpu"} + self._services: dict[str, dict[str, Any]] = { + "FolderPathsProxy": { + "rpc_get_models_dir": lambda: "/sandbox/models", + "rpc_get_folder_names_and_paths": lambda: { + "checkpoints": { + "paths": ["/sandbox/models/checkpoints"], + "extensions": [".ckpt", ".safetensors"], + } + }, + "rpc_get_extension_mimetypes_cache": lambda: {"webp": "image"}, + "rpc_get_filename_list_cache": lambda: {}, + "rpc_get_temp_directory": lambda: "/sandbox/temp", + "rpc_get_input_directory": lambda: "/sandbox/input", + "rpc_get_output_directory": lambda: "/sandbox/output", + "rpc_get_user_directory": lambda: "/sandbox/user", + "rpc_get_annotated_filepath": self._get_annotated_filepath, + "rpc_exists_annotated_filepath": lambda _name: False, + "rpc_add_model_folder_path": lambda *_args, **_kwargs: None, + "rpc_get_folder_paths": lambda folder_name: [f"/sandbox/models/{folder_name}"], + "rpc_get_filename_list": lambda folder_name: [f"{folder_name}_fixture.safetensors"], + "rpc_get_full_path": lambda folder_name, filename: f"/sandbox/models/{folder_name}/{filename}", + }, + "UtilsProxy": { + "progress_bar_hook": lambda value, total, preview=None, node_id=None: { + "value": value, + "total": total, + "preview": preview, + "node_id": node_id, + } + }, + "ProgressProxy": { + "rpc_set_progress": lambda value, max_value, node_id=None, image=None: { + "value": value, + "max_value": max_value, + "node_id": node_id, + "image": image, + } + }, + "HelperProxiesService": { + "rpc_restore_input_types": lambda raw: raw, + }, + "ModelManagementProxy": { + "rpc_call": self._model_management_rpc_call, + }, + } + + def _model_management_rpc_call(self, method_name: str, args: Any = None, kwargs: Any = None) -> Any: + if method_name == "get_torch_device": + return self._device + elif method_name == "get_torch_device_name": + return "cpu" + elif method_name == "get_free_memory": + return 34359738368 + raise AssertionError(f"unexpected model_management method {method_name}") + + @staticmethod + def _get_annotated_filepath(name: str, default_dir: str | None = None) -> str: + if name.endswith("[output]"): + return f"/sandbox/output/{name[:-8]}" + if name.endswith("[input]"): + return f"/sandbox/input/{name[:-7]}" + if name.endswith("[temp]"): + return f"/sandbox/temp/{name[:-6]}" + base_dir = default_dir or "/sandbox/input" + return f"{base_dir}/{name}" + + def create_caller(self, cls: Any, object_id: str): + methods = self._services.get(object_id) or self._services.get(getattr(cls, "__name__", object_id)) + if methods is None: + raise KeyError(object_id) + return FakeSingletonCaller(methods, self.calls, object_id) + + +def _clear_proxy_rpcs() -> None: + from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy + from comfy.isolation.proxies.progress_proxy import ProgressProxy + from comfy.isolation.proxies.utils_proxy import UtilsProxy + + FolderPathsProxy.clear_rpc() + ProgressProxy.clear_rpc() + UtilsProxy.clear_rpc() + helper_proxy_service = _load_helper_proxy_service() + if helper_proxy_service is not None: + helper_proxy_service.clear_rpc() + model_management_proxy = _load_model_management_proxy() + if model_management_proxy is not None and hasattr(model_management_proxy, "clear_rpc"): + model_management_proxy.clear_rpc() + + +def prepare_sealed_singleton_proxies(fake_rpc: FakeSingletonRPC) -> None: + os.environ["PYISOLATE_CHILD"] = "1" + os.environ["PYISOLATE_IMPORT_TORCH"] = "0" + _clear_proxy_rpcs() + + from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy + from comfy.isolation.proxies.progress_proxy import ProgressProxy + from comfy.isolation.proxies.utils_proxy import UtilsProxy + + FolderPathsProxy.set_rpc(fake_rpc) + ProgressProxy.set_rpc(fake_rpc) + UtilsProxy.set_rpc(fake_rpc) + helper_proxy_service = _load_helper_proxy_service() + if helper_proxy_service is not None: + helper_proxy_service.set_rpc(fake_rpc) + model_management_proxy = _load_model_management_proxy() + if model_management_proxy is not None and hasattr(model_management_proxy, "set_rpc"): + model_management_proxy.set_rpc(fake_rpc) + + +def reset_forbidden_singleton_modules() -> None: + for module_name in ( + "folder_paths", + "comfy.utils", + "comfy_execution.progress", + ): + sys.modules.pop(module_name, None) + + +class FakeExactRelayCaller: + def __init__(self, methods: dict[str, Any], transcripts: list[dict[str, Any]], object_id: str): + self._methods = methods + self._transcripts = transcripts + self._object_id = object_id + + def __getattr__(self, name: str): + if name not in self._methods: + raise AttributeError(name) + + async def method(*args: Any, **kwargs: Any) -> Any: + self._transcripts.append( + { + "phase": "child_call", + "object_id": self._object_id, + "method": name, + "args": list(args), + "kwargs": dict(kwargs), + } + ) + impl = self._methods[name] + self._transcripts.append( + { + "phase": "host_invocation", + "object_id": self._object_id, + "method": name, + "target": impl["target"], + "args": list(args), + "kwargs": dict(kwargs), + } + ) + result = impl["result"](*args, **kwargs) if callable(impl["result"]) else impl["result"] + self._transcripts.append( + { + "phase": "result", + "object_id": self._object_id, + "method": name, + "result": result, + } + ) + return result + + return method + + +class FakeExactRelayRPC: + def __init__(self) -> None: + self.transcripts: list[dict[str, Any]] = [] + self._device = {"__pyisolate_torch_device__": "cpu"} + self._services: dict[str, dict[str, Any]] = { + "FolderPathsProxy": { + "rpc_get_models_dir": { + "target": "folder_paths.models_dir", + "result": "/sandbox/models", + }, + "rpc_get_temp_directory": { + "target": "folder_paths.get_temp_directory", + "result": "/sandbox/temp", + }, + "rpc_get_input_directory": { + "target": "folder_paths.get_input_directory", + "result": "/sandbox/input", + }, + "rpc_get_output_directory": { + "target": "folder_paths.get_output_directory", + "result": "/sandbox/output", + }, + "rpc_get_user_directory": { + "target": "folder_paths.get_user_directory", + "result": "/sandbox/user", + }, + "rpc_get_folder_names_and_paths": { + "target": "folder_paths.folder_names_and_paths", + "result": { + "checkpoints": { + "paths": ["/sandbox/models/checkpoints"], + "extensions": [".ckpt", ".safetensors"], + } + }, + }, + "rpc_get_extension_mimetypes_cache": { + "target": "folder_paths.extension_mimetypes_cache", + "result": {"webp": "image"}, + }, + "rpc_get_filename_list_cache": { + "target": "folder_paths.filename_list_cache", + "result": {}, + }, + "rpc_get_annotated_filepath": { + "target": "folder_paths.get_annotated_filepath", + "result": lambda name, default_dir=None: FakeSingletonRPC._get_annotated_filepath(name, default_dir), + }, + "rpc_exists_annotated_filepath": { + "target": "folder_paths.exists_annotated_filepath", + "result": False, + }, + "rpc_add_model_folder_path": { + "target": "folder_paths.add_model_folder_path", + "result": None, + }, + "rpc_get_folder_paths": { + "target": "folder_paths.get_folder_paths", + "result": lambda folder_name: [f"/sandbox/models/{folder_name}"], + }, + "rpc_get_filename_list": { + "target": "folder_paths.get_filename_list", + "result": lambda folder_name: [f"{folder_name}_fixture.safetensors"], + }, + "rpc_get_full_path": { + "target": "folder_paths.get_full_path", + "result": lambda folder_name, filename: f"/sandbox/models/{folder_name}/{filename}", + }, + }, + "UtilsProxy": { + "progress_bar_hook": { + "target": "comfy.utils.PROGRESS_BAR_HOOK", + "result": lambda value, total, preview=None, node_id=None: { + "value": value, + "total": total, + "preview": preview, + "node_id": node_id, + }, + }, + }, + "ProgressProxy": { + "rpc_set_progress": { + "target": "comfy_execution.progress.get_progress_state().update_progress", + "result": None, + }, + }, + "HelperProxiesService": { + "rpc_restore_input_types": { + "target": "comfy.isolation.proxies.helper_proxies.restore_input_types", + "result": lambda raw: raw, + } + }, + "ModelManagementProxy": { + "rpc_call": { + "target": "comfy.model_management.*", + "result": self._model_management_rpc_call, + }, + }, + } + + def _model_management_rpc_call(self, method_name: str, args: Any = None, kwargs: Any = None) -> Any: + device = {"__pyisolate_torch_device__": "cpu"} + if method_name == "get_torch_device": + return device + elif method_name == "get_torch_device_name": + return "cpu" + elif method_name == "get_free_memory": + return 34359738368 + raise AssertionError(f"unexpected exact-relay method {method_name}") + + def create_caller(self, cls: Any, object_id: str): + methods = self._services.get(object_id) or self._services.get(getattr(cls, "__name__", object_id)) + if methods is None: + raise KeyError(object_id) + return FakeExactRelayCaller(methods, self.transcripts, object_id) + + +def capture_exact_small_proxy_relay() -> dict[str, object]: + reset_forbidden_singleton_modules() + fake_rpc = FakeExactRelayRPC() + previous_child = os.environ.get("PYISOLATE_CHILD") + previous_import_torch = os.environ.get("PYISOLATE_IMPORT_TORCH") + try: + prepare_sealed_singleton_proxies(fake_rpc) + + from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy + from comfy.isolation.proxies.helper_proxies import restore_input_types + from comfy.isolation.proxies.progress_proxy import ProgressProxy + from comfy.isolation.proxies.utils_proxy import UtilsProxy + + folder_proxy = FolderPathsProxy() + utils_proxy = UtilsProxy() + progress_proxy = ProgressProxy() + before = set(sys.modules) + + restored = restore_input_types( + { + "required": { + "image": {"__pyisolate_any_type__": True, "value": "*"}, + } + } + ) + folder_path = folder_proxy.get_annotated_filepath("demo.png[input]") + models_dir = folder_proxy.models_dir + folder_names_and_paths = folder_proxy.folder_names_and_paths + asyncio.run(utils_proxy.progress_bar_hook(2, 5, node_id="node-17")) + progress_proxy.set_progress(1.5, 5.0, node_id="node-17") + + imported = set(sys.modules) - before + return { + "mode": "exact_small_proxy_relay", + "folder_path": folder_path, + "models_dir": models_dir, + "folder_names_and_paths": folder_names_and_paths, + "restored_any_type": str(restored["required"]["image"]), + "transcripts": fake_rpc.transcripts, + "modules": sorted(imported), + "forbidden_matches": matching_modules(FORBIDDEN_EXACT_SMALL_PROXY_MODULES, imported), + } + finally: + _clear_proxy_rpcs() + if previous_child is None: + os.environ.pop("PYISOLATE_CHILD", None) + else: + os.environ["PYISOLATE_CHILD"] = previous_child + if previous_import_torch is None: + os.environ.pop("PYISOLATE_IMPORT_TORCH", None) + else: + os.environ["PYISOLATE_IMPORT_TORCH"] = previous_import_torch + + +class FakeModelManagementExactRelayRPC: + def __init__(self) -> None: + self.transcripts: list[dict[str, object]] = [] + self._device = {"__pyisolate_torch_device__": "cpu"} + self._services: dict[str, dict[str, Any]] = { + "ModelManagementProxy": { + "rpc_call": self._rpc_call, + } + } + + def create_caller(self, cls: Any, object_id: str): + methods = self._services.get(object_id) or self._services.get(getattr(cls, "__name__", object_id)) + if methods is None: + raise KeyError(object_id) + return _ModelManagementExactRelayCaller(methods) + + def _rpc_call(self, method_name: str, args: Any, kwargs: Any) -> Any: + self.transcripts.append( + { + "phase": "child_call", + "object_id": "ModelManagementProxy", + "method": method_name, + "args": _json_safe(args), + "kwargs": _json_safe(kwargs), + } + ) + target = f"comfy.model_management.{method_name}" + self.transcripts.append( + { + "phase": "host_invocation", + "object_id": "ModelManagementProxy", + "method": method_name, + "target": target, + "args": _json_safe(args), + "kwargs": _json_safe(kwargs), + } + ) + if method_name == "get_torch_device": + result = self._device + elif method_name == "get_torch_device_name": + result = "cpu" + elif method_name == "get_free_memory": + result = 34359738368 + else: + raise AssertionError(f"unexpected exact-relay method {method_name}") + self.transcripts.append( + { + "phase": "result", + "object_id": "ModelManagementProxy", + "method": method_name, + "result": _json_safe(result), + } + ) + return result + + +class _ModelManagementExactRelayCaller: + def __init__(self, methods: dict[str, Any]): + self._methods = methods + + def __getattr__(self, name: str): + if name not in self._methods: + raise AttributeError(name) + + async def method(*args: Any, **kwargs: Any) -> Any: + impl = self._methods[name] + return impl(*args, **kwargs) if callable(impl) else impl + + return method + + +def _json_safe(value: Any) -> Any: + if callable(value): + return f"" + if isinstance(value, tuple): + return [_json_safe(item) for item in value] + if isinstance(value, list): + return [_json_safe(item) for item in value] + if isinstance(value, dict): + return {key: _json_safe(inner) for key, inner in value.items()} + return value + + +def capture_model_management_exact_relay() -> dict[str, object]: + for module_name in FORBIDDEN_MODEL_MANAGEMENT_MODULES: + sys.modules.pop(module_name, None) + + fake_rpc = FakeModelManagementExactRelayRPC() + previous_child = os.environ.get("PYISOLATE_CHILD") + previous_import_torch = os.environ.get("PYISOLATE_IMPORT_TORCH") + try: + os.environ["PYISOLATE_CHILD"] = "1" + os.environ["PYISOLATE_IMPORT_TORCH"] = "0" + + from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy + + if hasattr(ModelManagementProxy, "clear_rpc"): + ModelManagementProxy.clear_rpc() + if hasattr(ModelManagementProxy, "set_rpc"): + ModelManagementProxy.set_rpc(fake_rpc) + + proxy = ModelManagementProxy() + before = set(sys.modules) + device = proxy.get_torch_device() + device_name = proxy.get_torch_device_name(device) + free_memory = proxy.get_free_memory(device) + imported = set(sys.modules) - before + return { + "mode": "model_management_exact_relay", + "device": str(device), + "device_type": getattr(device, "type", None), + "device_name": device_name, + "free_memory": free_memory, + "transcripts": fake_rpc.transcripts, + "modules": sorted(imported), + "forbidden_matches": matching_modules(FORBIDDEN_MODEL_MANAGEMENT_MODULES, imported), + } + finally: + model_management_proxy = _load_model_management_proxy() + if model_management_proxy is not None and hasattr(model_management_proxy, "clear_rpc"): + model_management_proxy.clear_rpc() + if previous_child is None: + os.environ.pop("PYISOLATE_CHILD", None) + else: + os.environ["PYISOLATE_CHILD"] = previous_child + if previous_import_torch is None: + os.environ.pop("PYISOLATE_IMPORT_TORCH", None) + else: + os.environ["PYISOLATE_IMPORT_TORCH"] = previous_import_torch + + +FORBIDDEN_PROMPT_WEB_MODULES = ( + "server", + "aiohttp", + "comfy.isolation.extension_wrapper", +) +FORBIDDEN_EXACT_BOOTSTRAP_MODULES = ( + "comfy.isolation.adapter", + "folder_paths", + "comfy.utils", + "comfy.model_management", + "server", + "main", + "comfy.isolation.extension_wrapper", +) + + +class _PromptServiceExactRelayCaller: + def __init__(self, methods: dict[str, Any], transcripts: list[dict[str, Any]], object_id: str): + self._methods = methods + self._transcripts = transcripts + self._object_id = object_id + + def __getattr__(self, name: str): + if name not in self._methods: + raise AttributeError(name) + + async def method(*args: Any, **kwargs: Any) -> Any: + self._transcripts.append( + { + "phase": "child_call", + "object_id": self._object_id, + "method": name, + "args": _json_safe(args), + "kwargs": _json_safe(kwargs), + } + ) + impl = self._methods[name] + self._transcripts.append( + { + "phase": "host_invocation", + "object_id": self._object_id, + "method": name, + "target": impl["target"], + "args": _json_safe(args), + "kwargs": _json_safe(kwargs), + } + ) + result = impl["result"](*args, **kwargs) if callable(impl["result"]) else impl["result"] + self._transcripts.append( + { + "phase": "result", + "object_id": self._object_id, + "method": name, + "result": _json_safe(result), + } + ) + return result + + return method + + +class FakePromptWebRPC: + def __init__(self) -> None: + self.transcripts: list[dict[str, Any]] = [] + self._services = { + "PromptServerService": { + "ui_send_progress_text": { + "target": "server.PromptServer.instance.send_progress_text", + "result": None, + }, + "register_route_rpc": { + "target": "server.PromptServer.instance.routes.add_route", + "result": None, + }, + } + } + + def create_caller(self, cls: Any, object_id: str): + methods = self._services.get(object_id) or self._services.get(getattr(cls, "__name__", object_id)) + if methods is None: + raise KeyError(object_id) + return _PromptServiceExactRelayCaller(methods, self.transcripts, object_id) + + +class FakeWebDirectoryProxy: + def __init__(self, transcripts: list[dict[str, Any]]): + self._transcripts = transcripts + + def get_web_file(self, extension_name: str, relative_path: str) -> dict[str, Any]: + self._transcripts.append( + { + "phase": "child_call", + "object_id": "WebDirectoryProxy", + "method": "get_web_file", + "args": [extension_name, relative_path], + "kwargs": {}, + } + ) + self._transcripts.append( + { + "phase": "host_invocation", + "object_id": "WebDirectoryProxy", + "method": "get_web_file", + "target": "comfy.isolation.proxies.web_directory_proxy.WebDirectoryProxy.get_web_file", + "args": [extension_name, relative_path], + "kwargs": {}, + } + ) + result = { + "content": "Y29uc29sZS5sb2coJ2RlbycpOw==", + "content_type": "application/javascript", + } + self._transcripts.append( + { + "phase": "result", + "object_id": "WebDirectoryProxy", + "method": "get_web_file", + "result": result, + } + ) + return result + + +def capture_prompt_web_exact_relay() -> dict[str, object]: + for module_name in FORBIDDEN_PROMPT_WEB_MODULES: + sys.modules.pop(module_name, None) + + fake_rpc = FakePromptWebRPC() + + from comfy.isolation.proxies.prompt_server_impl import PromptServerStub + from comfy.isolation.proxies.web_directory_proxy import WebDirectoryCache + + PromptServerStub.set_rpc(fake_rpc) + stub = PromptServerStub() + cache = WebDirectoryCache() + cache.register_proxy("demo_ext", FakeWebDirectoryProxy(fake_rpc.transcripts)) + + before = set(sys.modules) + + def demo_handler(_request): + return {"ok": True} + + stub.send_progress_text("hello", "node-17") + stub.routes.get("/demo")(demo_handler) + web_file = cache.get_file("demo_ext", "js/app.js") + imported = set(sys.modules) - before + return { + "mode": "prompt_web_exact_relay", + "web_file": { + "content_type": web_file["content_type"] if web_file else None, + "content": web_file["content"].decode("utf-8") if web_file else None, + }, + "transcripts": fake_rpc.transcripts, + "modules": sorted(imported), + "forbidden_matches": matching_modules(FORBIDDEN_PROMPT_WEB_MODULES, imported), + } + + +class FakeExactBootstrapRPC: + def __init__(self) -> None: + self.transcripts: list[dict[str, Any]] = [] + self._device = {"__pyisolate_torch_device__": "cpu"} + self._services: dict[str, dict[str, Any]] = { + "FolderPathsProxy": FakeExactRelayRPC()._services["FolderPathsProxy"], + "HelperProxiesService": FakeExactRelayRPC()._services["HelperProxiesService"], + "ProgressProxy": FakeExactRelayRPC()._services["ProgressProxy"], + "UtilsProxy": FakeExactRelayRPC()._services["UtilsProxy"], + "PromptServerService": { + "ui_send_sync": { + "target": "server.PromptServer.instance.send_sync", + "result": None, + }, + "ui_send": { + "target": "server.PromptServer.instance.send", + "result": None, + }, + "ui_send_progress_text": { + "target": "server.PromptServer.instance.send_progress_text", + "result": None, + }, + "register_route_rpc": { + "target": "server.PromptServer.instance.routes.add_route", + "result": None, + }, + }, + "ModelManagementProxy": { + "rpc_call": self._rpc_call, + }, + } + + def create_caller(self, cls: Any, object_id: str): + methods = self._services.get(object_id) or self._services.get(getattr(cls, "__name__", object_id)) + if methods is None: + raise KeyError(object_id) + if object_id == "ModelManagementProxy": + return _ModelManagementExactRelayCaller(methods) + return _PromptServiceExactRelayCaller(methods, self.transcripts, object_id) + + def _rpc_call(self, method_name: str, args: Any, kwargs: Any) -> Any: + self.transcripts.append( + { + "phase": "child_call", + "object_id": "ModelManagementProxy", + "method": method_name, + "args": _json_safe(args), + "kwargs": _json_safe(kwargs), + } + ) + self.transcripts.append( + { + "phase": "host_invocation", + "object_id": "ModelManagementProxy", + "method": method_name, + "target": f"comfy.model_management.{method_name}", + "args": _json_safe(args), + "kwargs": _json_safe(kwargs), + } + ) + result = self._device if method_name == "get_torch_device" else None + self.transcripts.append( + { + "phase": "result", + "object_id": "ModelManagementProxy", + "method": method_name, + "result": _json_safe(result), + } + ) + return result + + +def capture_exact_proxy_bootstrap_contract() -> dict[str, object]: + from pyisolate._internal.rpc_protocol import get_child_rpc_instance, set_child_rpc_instance + + from comfy.isolation.adapter import ComfyUIAdapter + from comfy.isolation.child_hooks import initialize_child_process + from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy + from comfy.isolation.proxies.helper_proxies import HelperProxiesService + from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy + from comfy.isolation.proxies.progress_proxy import ProgressProxy + from comfy.isolation.proxies.prompt_server_impl import PromptServerStub + from comfy.isolation.proxies.utils_proxy import UtilsProxy + + host_services = sorted(cls.__name__ for cls in ComfyUIAdapter().provide_rpc_services()) + + for module_name in FORBIDDEN_EXACT_BOOTSTRAP_MODULES: + sys.modules.pop(module_name, None) + + previous_child = os.environ.get("PYISOLATE_CHILD") + previous_import_torch = os.environ.get("PYISOLATE_IMPORT_TORCH") + os.environ["PYISOLATE_CHILD"] = "1" + os.environ["PYISOLATE_IMPORT_TORCH"] = "0" + + _clear_proxy_rpcs() + if hasattr(PromptServerStub, "clear_rpc"): + PromptServerStub.clear_rpc() + else: + PromptServerStub._rpc = None # type: ignore[attr-defined] + fake_rpc = FakeExactBootstrapRPC() + set_child_rpc_instance(fake_rpc) + + before = set(sys.modules) + try: + initialize_child_process() + imported = set(sys.modules) - before + matrix = { + "base.py": { + "bound": get_child_rpc_instance() is fake_rpc, + "details": {"child_rpc_instance": get_child_rpc_instance() is fake_rpc}, + }, + "folder_paths_proxy.py": { + "bound": "FolderPathsProxy" in host_services and FolderPathsProxy._rpc is not None, + "details": {"host_service": "FolderPathsProxy" in host_services, "child_rpc": FolderPathsProxy._rpc is not None}, + }, + "helper_proxies.py": { + "bound": "HelperProxiesService" in host_services and HelperProxiesService._rpc is not None, + "details": {"host_service": "HelperProxiesService" in host_services, "child_rpc": HelperProxiesService._rpc is not None}, + }, + "model_management_proxy.py": { + "bound": "ModelManagementProxy" in host_services and ModelManagementProxy._rpc is not None, + "details": {"host_service": "ModelManagementProxy" in host_services, "child_rpc": ModelManagementProxy._rpc is not None}, + }, + "progress_proxy.py": { + "bound": "ProgressProxy" in host_services and ProgressProxy._rpc is not None, + "details": {"host_service": "ProgressProxy" in host_services, "child_rpc": ProgressProxy._rpc is not None}, + }, + "prompt_server_impl.py": { + "bound": "PromptServerService" in host_services and PromptServerStub._rpc is not None, + "details": {"host_service": "PromptServerService" in host_services, "child_rpc": PromptServerStub._rpc is not None}, + }, + "utils_proxy.py": { + "bound": "UtilsProxy" in host_services and UtilsProxy._rpc is not None, + "details": {"host_service": "UtilsProxy" in host_services, "child_rpc": UtilsProxy._rpc is not None}, + }, + "web_directory_proxy.py": { + "bound": "WebDirectoryProxy" in host_services, + "details": {"host_service": "WebDirectoryProxy" in host_services}, + }, + } + finally: + set_child_rpc_instance(None) + if previous_child is None: + os.environ.pop("PYISOLATE_CHILD", None) + else: + os.environ["PYISOLATE_CHILD"] = previous_child + if previous_import_torch is None: + os.environ.pop("PYISOLATE_IMPORT_TORCH", None) + else: + os.environ["PYISOLATE_IMPORT_TORCH"] = previous_import_torch + + omitted = sorted(name for name, status in matrix.items() if not status["bound"]) + return { + "mode": "exact_proxy_bootstrap_contract", + "host_services": host_services, + "matrix": matrix, + "omitted_proxies": omitted, + "modules": sorted(imported), + "forbidden_matches": matching_modules(FORBIDDEN_EXACT_BOOTSTRAP_MODULES, imported), + } + +def capture_sealed_singleton_imports() -> dict[str, object]: + reset_forbidden_singleton_modules() + fake_rpc = FakeSingletonRPC() + previous_child = os.environ.get("PYISOLATE_CHILD") + previous_import_torch = os.environ.get("PYISOLATE_IMPORT_TORCH") + before = set(sys.modules) + try: + prepare_sealed_singleton_proxies(fake_rpc) + + from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy + from comfy.isolation.proxies.progress_proxy import ProgressProxy + from comfy.isolation.proxies.utils_proxy import UtilsProxy + + folder_proxy = FolderPathsProxy() + progress_proxy = ProgressProxy() + utils_proxy = UtilsProxy() + + folder_path = folder_proxy.get_annotated_filepath("demo.png[input]") + temp_dir = folder_proxy.get_temp_directory() + models_dir = folder_proxy.models_dir + asyncio.run(utils_proxy.progress_bar_hook(2, 5, node_id="node-17")) + progress_proxy.set_progress(1.5, 5.0, node_id="node-17") + + imported = set(sys.modules) - before + return { + "mode": "sealed_singletons", + "folder_path": folder_path, + "temp_dir": temp_dir, + "models_dir": models_dir, + "rpc_calls": fake_rpc.calls, + "modules": sorted(imported), + "forbidden_matches": matching_modules(FORBIDDEN_SEALED_SINGLETON_MODULES, imported), + } + finally: + _clear_proxy_rpcs() + if previous_child is None: + os.environ.pop("PYISOLATE_CHILD", None) + else: + os.environ["PYISOLATE_CHILD"] = previous_child + if previous_import_torch is None: + os.environ.pop("PYISOLATE_IMPORT_TORCH", None) + else: + os.environ["PYISOLATE_IMPORT_TORCH"] = previous_import_torch diff --git a/tests/isolation/stage_internal_probe_node.py b/tests/isolation/stage_internal_probe_node.py new file mode 100644 index 000000000..b072ab43e --- /dev/null +++ b/tests/isolation/stage_internal_probe_node.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +import argparse +import shutil +import sys +import tempfile +from contextlib import contextmanager +from pathlib import Path +from typing import Iterator + + +COMFYUI_ROOT = Path(__file__).resolve().parents[2] +PROBE_SOURCE_ROOT = COMFYUI_ROOT / "tests" / "isolation" / "internal_probe_node" +PROBE_NODE_NAME = "InternalIsolationProbeNode" + +PYPROJECT_CONTENT = """[project] +name = "InternalIsolationProbeNode" +version = "0.0.1" + +[tool.comfy.isolation] +can_isolate = true +share_torch = true +""" + + +def _probe_target_root(comfy_root: Path) -> Path: + return Path(comfy_root) / "custom_nodes" / PROBE_NODE_NAME + + +def stage_probe_node(comfy_root: Path) -> Path: + if not PROBE_SOURCE_ROOT.is_dir(): + raise RuntimeError(f"Missing probe source directory: {PROBE_SOURCE_ROOT}") + + target_root = _probe_target_root(comfy_root) + target_root.mkdir(parents=True, exist_ok=True) + for source_path in PROBE_SOURCE_ROOT.iterdir(): + destination_path = target_root / source_path.name + if source_path.is_dir(): + shutil.copytree(source_path, destination_path, dirs_exist_ok=True) + else: + shutil.copy2(source_path, destination_path) + + (target_root / "pyproject.toml").write_text(PYPROJECT_CONTENT, encoding="utf-8") + return target_root + + +@contextmanager +def staged_probe_node() -> Iterator[Path]: + staging_root = Path(tempfile.mkdtemp(prefix="comfyui_internal_probe_")) + try: + yield stage_probe_node(staging_root) + finally: + shutil.rmtree(staging_root, ignore_errors=True) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Stage the internal isolation probe node under an explicit ComfyUI root." + ) + parser.add_argument( + "--target-root", + type=Path, + required=True, + help="Explicit ComfyUI root to stage under. Caller owns cleanup.", + ) + args = parser.parse_args() + + staged = stage_probe_node(args.target_root) + sys.stdout.write(f"{staged}\n") diff --git a/tests/isolation/test_client_snapshot.py b/tests/isolation/test_client_snapshot.py new file mode 100644 index 000000000..0eedf6b41 --- /dev/null +++ b/tests/isolation/test_client_snapshot.py @@ -0,0 +1,122 @@ +"""Tests for pyisolate._internal.client import-time snapshot handling.""" + +import json +import os +import subprocess +import sys +from pathlib import Path + +import pytest + +# Paths needed for subprocess +PYISOLATE_ROOT = str(Path(__file__).parent.parent) +COMFYUI_ROOT = os.environ.get("COMFYUI_ROOT") or str(Path.home() / "ComfyUI") + +SCRIPT = """ +import json, sys +import pyisolate._internal.client # noqa: F401 # triggers snapshot logic +print(json.dumps(sys.path[:6])) +""" + + +def _run_client_process(env): + # Ensure subprocess can find pyisolate and ComfyUI + pythonpath_parts = [PYISOLATE_ROOT, COMFYUI_ROOT] + existing = env.get("PYTHONPATH", "") + if existing: + pythonpath_parts.append(existing) + env["PYTHONPATH"] = os.pathsep.join(pythonpath_parts) + + result = subprocess.run( # noqa: S603 + [sys.executable, "-c", SCRIPT], + capture_output=True, + text=True, + env=env, + check=True, + ) + stdout = result.stdout.strip().splitlines()[-1] + return json.loads(stdout) + + +@pytest.fixture() +def comfy_module_path(tmp_path): + comfy_root = tmp_path / "ComfyUI" + module_path = comfy_root / "custom_nodes" / "TestNode" + module_path.mkdir(parents=True) + return comfy_root, module_path + + +def test_snapshot_applied_and_comfy_root_prepend(tmp_path, comfy_module_path): + comfy_root, module_path = comfy_module_path + # Must include real ComfyUI path for utils validation to pass + host_paths = [COMFYUI_ROOT, "/host/lib1", "/host/lib2"] + snapshot = { + "sys_path": host_paths, + "sys_executable": sys.executable, + "sys_prefix": sys.prefix, + "environment": {}, + } + snapshot_path = tmp_path / "snapshot.json" + snapshot_path.write_text(json.dumps(snapshot), encoding="utf-8") + + env = os.environ.copy() + env.update( + { + "PYISOLATE_CHILD": "1", + "PYISOLATE_HOST_SNAPSHOT": str(snapshot_path), + "PYISOLATE_MODULE_PATH": str(module_path), + } + ) + + path_prefix = _run_client_process(env) + + # Current client behavior preserves the runtime bootstrap path order and + # keeps the resolved ComfyUI root available for imports. + assert COMFYUI_ROOT in path_prefix + # Module path should not override runtime root selection. + assert str(comfy_root) not in path_prefix + + +def test_missing_snapshot_file_does_not_crash(tmp_path, comfy_module_path): + _, module_path = comfy_module_path + missing_snapshot = tmp_path / "missing.json" + + env = os.environ.copy() + env.update( + { + "PYISOLATE_CHILD": "1", + "PYISOLATE_HOST_SNAPSHOT": str(missing_snapshot), + "PYISOLATE_MODULE_PATH": str(module_path), + } + ) + + # Should not raise even though snapshot path is missing + paths = _run_client_process(env) + assert len(paths) > 0 + + +def test_no_comfy_root_when_module_path_absent(tmp_path): + # Must include real ComfyUI path for utils validation to pass + host_paths = [COMFYUI_ROOT, "/alpha", "/beta"] + snapshot = { + "sys_path": host_paths, + "sys_executable": sys.executable, + "sys_prefix": sys.prefix, + "environment": {}, + } + snapshot_path = tmp_path / "snapshot.json" + snapshot_path.write_text(json.dumps(snapshot), encoding="utf-8") + + env = os.environ.copy() + env.update( + { + "PYISOLATE_CHILD": "1", + "PYISOLATE_HOST_SNAPSHOT": str(snapshot_path), + } + ) + + paths = _run_client_process(env) + # Runtime path bootstrap keeps ComfyUI importability regardless of host + # snapshot extras. + assert COMFYUI_ROOT in paths + assert "/alpha" not in paths and "/beta" not in paths diff --git a/tests/isolation/test_cuda_wheels_and_env_flags.py b/tests/isolation/test_cuda_wheels_and_env_flags.py new file mode 100644 index 000000000..f0361d5ef --- /dev/null +++ b/tests/isolation/test_cuda_wheels_and_env_flags.py @@ -0,0 +1,460 @@ +"""Synthetic integration coverage for manifest plumbing and env flags. + +These tests do not perform a real wheel install or a real ComfyUI E2E run. +""" + +import asyncio +import logging +import os +import sys +from types import SimpleNamespace +from typing import Any, cast + +import pytest + +import comfy.isolation as isolation_pkg +from comfy.isolation import runtime_helpers +from comfy.isolation import extension_loader as extension_loader_module +from comfy.isolation import extension_wrapper as extension_wrapper_module +from comfy.isolation import model_patcher_proxy_utils +from comfy.isolation.extension_loader import ExtensionLoadError, load_isolated_node +from comfy.isolation.extension_wrapper import ComfyNodeExtension +from comfy.isolation.model_patcher_proxy_utils import maybe_wrap_model_for_isolation +from pyisolate._internal.environment_conda import _generate_pixi_toml + + +class _DummyExtension: + def __init__(self) -> None: + self.name = "demo-extension" + + async def stop(self) -> None: + return None + + +def _write_manifest(node_dir, manifest_text: str) -> None: + (node_dir / "pyproject.toml").write_text(manifest_text, encoding="utf-8") + + +def test_load_isolated_node_passes_normalized_cuda_wheels_config(tmp_path, monkeypatch): + node_dir = tmp_path / "node" + node_dir.mkdir() + manifest_path = node_dir / "pyproject.toml" + _write_manifest( + node_dir, + """ +[project] +name = "demo-node" +dependencies = ["flash-attn>=1.0", "sageattention==0.1"] + +[tool.comfy.isolation] +can_isolate = true +share_torch = true + +[tool.comfy.isolation.cuda_wheels] +index_url = "https://example.invalid/cuda-wheels" +packages = ["flash_attn", "sageattention"] + +[tool.comfy.isolation.cuda_wheels.package_map] +flash_attn = "flash-attn-special" +""".strip(), + ) + + captured: dict[str, object] = {} + + class DummyManager: + def __init__(self, *args, **kwargs) -> None: + return None + + def load_extension(self, config): + captured.update(config) + return _DummyExtension() + + monkeypatch.setattr(extension_loader_module.pyisolate, "ExtensionManager", DummyManager) + monkeypatch.setattr( + extension_loader_module, + "load_host_policy", + lambda base_path: { + "sandbox_mode": "required", + "allow_network": False, + "writable_paths": [], + "readonly_paths": [], + }, + ) + monkeypatch.setattr(extension_loader_module, "is_cache_valid", lambda *args, **kwargs: True) + monkeypatch.setattr( + extension_loader_module, + "load_from_cache", + lambda *args, **kwargs: {"Node": {"display_name": "Node", "schema_v1": {}}}, + ) + monkeypatch.setitem(sys.modules, "folder_paths", SimpleNamespace(base_path=str(tmp_path))) + + specs = asyncio.run( + load_isolated_node( + node_dir, + manifest_path, + logging.getLogger("test"), + lambda *args, **kwargs: object, + tmp_path / "venvs", + [], + ) + ) + + assert len(specs) == 1 + assert captured["sandbox_mode"] == "required" + assert captured["cuda_wheels"] == { + "index_url": "https://example.invalid/cuda-wheels/", + "packages": ["flash-attn", "sageattention"], + "package_map": {"flash-attn": "flash-attn-special"}, + } + + +def test_load_isolated_node_rejects_undeclared_cuda_wheel_dependency( + tmp_path, monkeypatch +): + node_dir = tmp_path / "node" + node_dir.mkdir() + manifest_path = node_dir / "pyproject.toml" + _write_manifest( + node_dir, + """ +[project] +name = "demo-node" +dependencies = ["numpy>=1.0"] + +[tool.comfy.isolation] +can_isolate = true + +[tool.comfy.isolation.cuda_wheels] +index_url = "https://example.invalid/cuda-wheels" +packages = ["flash-attn"] +""".strip(), + ) + + monkeypatch.setitem(sys.modules, "folder_paths", SimpleNamespace(base_path=str(tmp_path))) + + with pytest.raises(ExtensionLoadError, match="undeclared dependencies"): + asyncio.run( + load_isolated_node( + node_dir, + manifest_path, + logging.getLogger("test"), + lambda *args, **kwargs: object, + tmp_path / "venvs", + [], + ) + ) + + +def test_conda_cuda_wheels_declared_packages_do_not_force_pixi_solve(tmp_path, monkeypatch): + node_dir = tmp_path / "node" + node_dir.mkdir() + manifest_path = node_dir / "pyproject.toml" + _write_manifest( + node_dir, + """ +[project] +name = "demo-node" +dependencies = ["numpy>=1.0", "spconv", "cumm", "flash-attn"] + +[tool.comfy.isolation] +can_isolate = true +package_manager = "conda" +conda_channels = ["conda-forge"] + +[tool.comfy.isolation.cuda_wheels] +index_url = "https://example.invalid/cuda-wheels" +packages = ["spconv", "cumm", "flash-attn"] +""".strip(), + ) + + captured: dict[str, object] = {} + + class DummyManager: + def __init__(self, *args, **kwargs) -> None: + return None + + def load_extension(self, config): + captured.update(config) + return _DummyExtension() + + monkeypatch.setattr(extension_loader_module.pyisolate, "ExtensionManager", DummyManager) + monkeypatch.setattr( + extension_loader_module, + "load_host_policy", + lambda base_path: { + "sandbox_mode": "disabled", + "allow_network": False, + "writable_paths": [], + "readonly_paths": [], + }, + ) + monkeypatch.setattr(extension_loader_module, "is_cache_valid", lambda *args, **kwargs: True) + monkeypatch.setattr( + extension_loader_module, + "load_from_cache", + lambda *args, **kwargs: {"Node": {"display_name": "Node", "schema_v1": {}}}, + ) + monkeypatch.setitem(sys.modules, "folder_paths", SimpleNamespace(base_path=str(tmp_path))) + + asyncio.run( + load_isolated_node( + node_dir, + manifest_path, + logging.getLogger("test"), + lambda *args, **kwargs: object, + tmp_path / "venvs", + [], + ) + ) + + generated = _generate_pixi_toml(captured) + assert 'numpy = ">=1.0"' in generated + assert "spconv =" not in generated + assert "cumm =" not in generated + assert "flash-attn =" not in generated + + +def test_conda_cuda_wheels_loader_accepts_sam3d_contract(tmp_path, monkeypatch): + node_dir = tmp_path / "node" + node_dir.mkdir() + manifest_path = node_dir / "pyproject.toml" + _write_manifest( + node_dir, + """ +[project] +name = "demo-node" +dependencies = [ + "torch", + "torchvision", + "pytorch3d", + "gsplat", + "nvdiffrast", + "flash-attn", + "sageattention", + "spconv", + "cumm", +] + +[tool.comfy.isolation] +can_isolate = true +package_manager = "conda" +conda_channels = ["conda-forge"] + +[tool.comfy.isolation.cuda_wheels] +index_url = "https://example.invalid/cuda-wheels" +packages = ["pytorch3d", "gsplat", "nvdiffrast", "flash-attn", "sageattention", "spconv", "cumm"] +""".strip(), + ) + + captured: dict[str, object] = {} + + class DummyManager: + def __init__(self, *args, **kwargs) -> None: + return None + + def load_extension(self, config): + captured.update(config) + return _DummyExtension() + + monkeypatch.setattr(extension_loader_module.pyisolate, "ExtensionManager", DummyManager) + monkeypatch.setattr( + extension_loader_module, + "load_host_policy", + lambda base_path: { + "sandbox_mode": "disabled", + "allow_network": False, + "writable_paths": [], + "readonly_paths": [], + }, + ) + monkeypatch.setattr(extension_loader_module, "is_cache_valid", lambda *args, **kwargs: True) + monkeypatch.setattr( + extension_loader_module, + "load_from_cache", + lambda *args, **kwargs: {"Node": {"display_name": "Node", "schema_v1": {}}}, + ) + monkeypatch.setitem(sys.modules, "folder_paths", SimpleNamespace(base_path=str(tmp_path))) + + asyncio.run( + load_isolated_node( + node_dir, + manifest_path, + logging.getLogger("test"), + lambda *args, **kwargs: object, + tmp_path / "venvs", + [], + ) + ) + + assert captured["package_manager"] == "conda" + assert captured["cuda_wheels"] == { + "index_url": "https://example.invalid/cuda-wheels/", + "packages": [ + "pytorch3d", + "gsplat", + "nvdiffrast", + "flash-attn", + "sageattention", + "spconv", + "cumm", + ], + "package_map": {}, + } + + +def test_load_isolated_node_omits_cuda_wheels_when_not_configured(tmp_path, monkeypatch): + node_dir = tmp_path / "node" + node_dir.mkdir() + manifest_path = node_dir / "pyproject.toml" + _write_manifest( + node_dir, + """ +[project] +name = "demo-node" +dependencies = ["numpy>=1.0"] + +[tool.comfy.isolation] +can_isolate = true +""".strip(), + ) + + captured: dict[str, object] = {} + + class DummyManager: + def __init__(self, *args, **kwargs) -> None: + return None + + def load_extension(self, config): + captured.update(config) + return _DummyExtension() + + monkeypatch.setattr(extension_loader_module.pyisolate, "ExtensionManager", DummyManager) + monkeypatch.setattr( + extension_loader_module, + "load_host_policy", + lambda base_path: { + "sandbox_mode": "disabled", + "allow_network": False, + "writable_paths": [], + "readonly_paths": [], + }, + ) + monkeypatch.setattr(extension_loader_module, "is_cache_valid", lambda *args, **kwargs: True) + monkeypatch.setattr( + extension_loader_module, + "load_from_cache", + lambda *args, **kwargs: {"Node": {"display_name": "Node", "schema_v1": {}}}, + ) + monkeypatch.setitem(sys.modules, "folder_paths", SimpleNamespace(base_path=str(tmp_path))) + + asyncio.run( + load_isolated_node( + node_dir, + manifest_path, + logging.getLogger("test"), + lambda *args, **kwargs: object, + tmp_path / "venvs", + [], + ) + ) + + assert captured["sandbox_mode"] == "disabled" + assert "cuda_wheels" not in captured + + +def test_maybe_wrap_model_for_isolation_uses_runtime_flag(monkeypatch): + class DummyRegistry: + def register(self, model): + return "model-123" + + class DummyProxy: + def __init__(self, model_id, registry, manage_lifecycle): + self.model_id = model_id + self.registry = registry + self.manage_lifecycle = manage_lifecycle + + monkeypatch.setattr(model_patcher_proxy_utils.args, "use_process_isolation", True) + monkeypatch.delenv("PYISOLATE_ISOLATION_ACTIVE", raising=False) + monkeypatch.delenv("PYISOLATE_CHILD", raising=False) + monkeypatch.setitem( + sys.modules, + "comfy.isolation.model_patcher_proxy_registry", + SimpleNamespace(ModelPatcherRegistry=DummyRegistry), + ) + monkeypatch.setitem( + sys.modules, + "comfy.isolation.model_patcher_proxy", + SimpleNamespace(ModelPatcherProxy=DummyProxy), + ) + + wrapped = cast(Any, maybe_wrap_model_for_isolation(object())) + + assert isinstance(wrapped, DummyProxy) + assert getattr(wrapped, "model_id") == "model-123" + assert getattr(wrapped, "manage_lifecycle") is True + + +def test_flush_transport_state_uses_child_env_without_legacy_flag(monkeypatch): + monkeypatch.setenv("PYISOLATE_CHILD", "1") + monkeypatch.delenv("PYISOLATE_ISOLATION_ACTIVE", raising=False) + monkeypatch.setattr(extension_wrapper_module, "_flush_tensor_transport_state", lambda marker: 3) + monkeypatch.setitem( + sys.modules, + "comfy.isolation.model_patcher_proxy_registry", + SimpleNamespace( + ModelPatcherRegistry=lambda: SimpleNamespace( + sweep_pending_cleanup=lambda: 0 + ) + ), + ) + + flushed = asyncio.run( + ComfyNodeExtension.flush_transport_state(SimpleNamespace(name="demo")) + ) + + assert flushed == 3 + + +def test_build_stub_class_relieves_host_vram_without_legacy_flag(monkeypatch): + relieve_calls: list[str] = [] + + async def deserialize_from_isolation(result, extension): + return result + + monkeypatch.delenv("PYISOLATE_CHILD", raising=False) + monkeypatch.delenv("PYISOLATE_ISOLATION_ACTIVE", raising=False) + monkeypatch.setattr( + runtime_helpers, "_relieve_host_vram_pressure", lambda marker, logger: relieve_calls.append(marker) + ) + monkeypatch.setattr(runtime_helpers, "scan_shm_forensics", lambda *args, **kwargs: None) + monkeypatch.setattr(isolation_pkg, "_RUNNING_EXTENSIONS", {}, raising=False) + monkeypatch.setitem( + sys.modules, + "pyisolate._internal.model_serialization", + SimpleNamespace( + serialize_for_isolation=lambda payload: payload, + deserialize_from_isolation=deserialize_from_isolation, + ), + ) + + class DummyExtension: + name = "demo-extension" + module_path = os.getcwd() + + async def execute_node(self, node_name, **inputs): + return inputs + + stub_cls = runtime_helpers.build_stub_class( + "DemoNode", + {"input_types": {}}, + DummyExtension(), + {}, + logging.getLogger("test"), + ) + + result = asyncio.run( + getattr(stub_cls, "_pyisolate_execute")(SimpleNamespace(), value=1) + ) + + assert relieve_calls == ["RUNTIME:pre_execute"] + assert result == {"value": 1} diff --git a/tests/isolation/test_exact_proxy_bootstrap_contract.py b/tests/isolation/test_exact_proxy_bootstrap_contract.py new file mode 100644 index 000000000..c67fb5ac4 --- /dev/null +++ b/tests/isolation/test_exact_proxy_bootstrap_contract.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from tests.isolation.singleton_boundary_helpers import ( + capture_exact_proxy_bootstrap_contract, +) + + +def test_no_proxy_omission_allowed() -> None: + payload = capture_exact_proxy_bootstrap_contract() + + assert payload["omitted_proxies"] == [] + assert payload["forbidden_matches"] == [] + + matrix = payload["matrix"] + assert matrix["base.py"]["bound"] is True + assert matrix["folder_paths_proxy.py"]["bound"] is True + assert matrix["helper_proxies.py"]["bound"] is True + assert matrix["model_management_proxy.py"]["bound"] is True + assert matrix["progress_proxy.py"]["bound"] is True + assert matrix["prompt_server_impl.py"]["bound"] is True + assert matrix["utils_proxy.py"]["bound"] is True + assert matrix["web_directory_proxy.py"]["bound"] is True diff --git a/tests/isolation/test_exact_proxy_relay_matrix.py b/tests/isolation/test_exact_proxy_relay_matrix.py new file mode 100644 index 000000000..ca9dbf94d --- /dev/null +++ b/tests/isolation/test_exact_proxy_relay_matrix.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +from tests.isolation.singleton_boundary_helpers import ( + capture_exact_small_proxy_relay, + capture_model_management_exact_relay, + capture_prompt_web_exact_relay, +) + + +def _transcripts_for(payload: dict[str, object], object_id: str, method: str) -> list[dict[str, object]]: + return [ + entry + for entry in payload["transcripts"] + if entry["object_id"] == object_id and entry["method"] == method + ] + + +def test_folder_paths_exact_relay() -> None: + payload = capture_exact_small_proxy_relay() + + assert payload["forbidden_matches"] == [] + assert payload["models_dir"] == "/sandbox/models" + assert payload["folder_path"] == "/sandbox/input/demo.png" + + models_dir_calls = _transcripts_for(payload, "FolderPathsProxy", "rpc_get_models_dir") + annotated_calls = _transcripts_for(payload, "FolderPathsProxy", "rpc_get_annotated_filepath") + + assert models_dir_calls + assert annotated_calls + assert all(entry["phase"] != "child_call" or entry["method"] != "rpc_snapshot" for entry in payload["transcripts"]) + + +def test_progress_exact_relay() -> None: + payload = capture_exact_small_proxy_relay() + + progress_calls = _transcripts_for(payload, "ProgressProxy", "rpc_set_progress") + + assert progress_calls + host_targets = [entry["target"] for entry in progress_calls if entry["phase"] == "host_invocation"] + assert host_targets == ["comfy_execution.progress.get_progress_state().update_progress"] + result_entries = [entry for entry in progress_calls if entry["phase"] == "result"] + assert result_entries == [{"phase": "result", "object_id": "ProgressProxy", "method": "rpc_set_progress", "result": None}] + + +def test_utils_exact_relay() -> None: + payload = capture_exact_small_proxy_relay() + + utils_calls = _transcripts_for(payload, "UtilsProxy", "progress_bar_hook") + + assert utils_calls + host_targets = [entry["target"] for entry in utils_calls if entry["phase"] == "host_invocation"] + assert host_targets == ["comfy.utils.PROGRESS_BAR_HOOK"] + result_entries = [entry for entry in utils_calls if entry["phase"] == "result"] + assert result_entries + assert result_entries[0]["result"]["value"] == 2 + assert result_entries[0]["result"]["total"] == 5 + + +def test_helper_proxy_exact_relay() -> None: + payload = capture_exact_small_proxy_relay() + + helper_calls = _transcripts_for(payload, "HelperProxiesService", "rpc_restore_input_types") + + assert helper_calls + host_targets = [entry["target"] for entry in helper_calls if entry["phase"] == "host_invocation"] + assert host_targets == ["comfy.isolation.proxies.helper_proxies.restore_input_types"] + assert payload["restored_any_type"] == "*" + + +def test_model_management_exact_relay() -> None: + payload = capture_model_management_exact_relay() + + model_calls = _transcripts_for(payload, "ModelManagementProxy", "get_torch_device") + model_calls += _transcripts_for(payload, "ModelManagementProxy", "get_torch_device_name") + model_calls += _transcripts_for(payload, "ModelManagementProxy", "get_free_memory") + + assert payload["forbidden_matches"] == [] + assert model_calls + host_targets = [ + entry["target"] + for entry in payload["transcripts"] + if entry["phase"] == "host_invocation" + ] + assert host_targets == [ + "comfy.model_management.get_torch_device", + "comfy.model_management.get_torch_device_name", + "comfy.model_management.get_free_memory", + ] + + +def test_model_management_capability_preserved() -> None: + payload = capture_model_management_exact_relay() + + assert payload["device"] == "cpu" + assert payload["device_type"] == "cpu" + assert payload["device_name"] == "cpu" + assert payload["free_memory"] == 34359738368 + + +def test_prompt_server_exact_relay() -> None: + payload = capture_prompt_web_exact_relay() + + prompt_calls = _transcripts_for(payload, "PromptServerService", "ui_send_progress_text") + prompt_calls += _transcripts_for(payload, "PromptServerService", "register_route_rpc") + + assert payload["forbidden_matches"] == [] + assert prompt_calls + host_targets = [ + entry["target"] + for entry in payload["transcripts"] + if entry["object_id"] == "PromptServerService" and entry["phase"] == "host_invocation" + ] + assert host_targets == [ + "server.PromptServer.instance.send_progress_text", + "server.PromptServer.instance.routes.add_route", + ] + + +def test_web_directory_exact_relay() -> None: + payload = capture_prompt_web_exact_relay() + + web_calls = _transcripts_for(payload, "WebDirectoryProxy", "get_web_file") + + assert web_calls + host_targets = [entry["target"] for entry in web_calls if entry["phase"] == "host_invocation"] + assert host_targets == ["comfy.isolation.proxies.web_directory_proxy.WebDirectoryProxy.get_web_file"] + assert payload["web_file"]["content_type"] == "application/javascript" + assert payload["web_file"]["content"] == "console.log('deo');" diff --git a/tests/isolation/test_extension_loader_conda.py b/tests/isolation/test_extension_loader_conda.py new file mode 100644 index 000000000..21154655f --- /dev/null +++ b/tests/isolation/test_extension_loader_conda.py @@ -0,0 +1,428 @@ +"""Tests for conda config parsing in extension_loader.py (Slice 5). + +These tests verify that extension_loader.py correctly parses conda-related +fields from pyproject.toml manifests and passes them into the extension config +dict given to pyisolate. The torch import chain is broken by pre-mocking +extension_wrapper before importing extension_loader. +""" + +from __future__ import annotations + +import importlib +import sys +import types +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +def _make_manifest( + *, + package_manager: str = "uv", + conda_channels: list[str] | None = None, + conda_dependencies: list[str] | None = None, + conda_platforms: list[str] | None = None, + share_torch: bool = False, + can_isolate: bool = True, + dependencies: list[str] | None = None, + cuda_wheels: list[str] | None = None, +) -> dict: + """Build a manifest dict matching tomllib.load() output.""" + isolation: dict = {"can_isolate": can_isolate} + if package_manager != "uv": + isolation["package_manager"] = package_manager + if conda_channels is not None: + isolation["conda_channels"] = conda_channels + if conda_dependencies is not None: + isolation["conda_dependencies"] = conda_dependencies + if conda_platforms is not None: + isolation["conda_platforms"] = conda_platforms + if share_torch: + isolation["share_torch"] = True + if cuda_wheels is not None: + isolation["cuda_wheels"] = cuda_wheels + + return { + "project": { + "name": "test-extension", + "dependencies": dependencies or ["numpy"], + }, + "tool": {"comfy": {"isolation": isolation}}, + } + + +@pytest.fixture +def manifest_file(tmp_path): + """Create a dummy pyproject.toml so manifest_path.open('rb') succeeds.""" + path = tmp_path / "pyproject.toml" + path.write_bytes(b"") # content is overridden by tomllib mock + return path + + +@pytest.fixture +def loader_module(monkeypatch): + """Import extension_loader under a mocked isolation package for this test only.""" + mock_wrapper = MagicMock() + mock_wrapper.ComfyNodeExtension = type("ComfyNodeExtension", (), {}) + + iso_mod = types.ModuleType("comfy.isolation") + iso_mod.__path__ = [ # type: ignore[attr-defined] + str(Path(__file__).resolve().parent.parent.parent / "comfy" / "isolation") + ] + iso_mod.__package__ = "comfy.isolation" + + manifest_loader = types.SimpleNamespace( + is_cache_valid=lambda *args, **kwargs: False, + load_from_cache=lambda *args, **kwargs: None, + save_to_cache=lambda *args, **kwargs: None, + ) + host_policy = types.SimpleNamespace( + load_host_policy=lambda base_path: { + "sandbox_mode": "required", + "allow_network": False, + "writable_paths": [], + "readonly_paths": [], + } + ) + folder_paths = types.SimpleNamespace(base_path="/fake/comfyui") + + monkeypatch.setitem(sys.modules, "comfy.isolation", iso_mod) + monkeypatch.setitem(sys.modules, "comfy.isolation.extension_wrapper", mock_wrapper) + monkeypatch.setitem(sys.modules, "comfy.isolation.runtime_helpers", MagicMock()) + monkeypatch.setitem(sys.modules, "comfy.isolation.manifest_loader", manifest_loader) + monkeypatch.setitem(sys.modules, "comfy.isolation.host_policy", host_policy) + monkeypatch.setitem(sys.modules, "folder_paths", folder_paths) + sys.modules.pop("comfy.isolation.extension_loader", None) + + module = importlib.import_module("comfy.isolation.extension_loader") + try: + yield module, mock_wrapper + finally: + sys.modules.pop("comfy.isolation.extension_loader", None) + comfy_pkg = sys.modules.get("comfy") + if comfy_pkg is not None and hasattr(comfy_pkg, "isolation"): + delattr(comfy_pkg, "isolation") + + +@pytest.fixture +def mock_pyisolate(loader_module): + """Mock pyisolate to avoid real venv creation.""" + module, mock_wrapper = loader_module + mock_ext = AsyncMock() + mock_ext.list_nodes = AsyncMock(return_value={}) + + mock_manager = MagicMock() + mock_manager.load_extension = MagicMock(return_value=mock_ext) + sealed_type = type("SealedNodeExtension", (), {}) + + with patch.object(module, "pyisolate") as mock_pi: + mock_pi.ExtensionManager = MagicMock(return_value=mock_manager) + mock_pi.SealedNodeExtension = sealed_type + yield module, mock_pi, mock_manager, mock_ext, mock_wrapper + + +def load_isolated_node(*args, **kwargs): + return sys.modules["comfy.isolation.extension_loader"].load_isolated_node( + *args, **kwargs + ) + + +class TestCondaPackageManagerParsing: + """Verify extension_loader.py parses conda config from pyproject.toml.""" + + @pytest.mark.asyncio + async def test_conda_package_manager_in_config( + self, mock_pyisolate, manifest_file, tmp_path + ): + """package_manager='conda' must appear in extension_config.""" + + manifest = _make_manifest( + package_manager="conda", + conda_channels=["conda-forge"], + conda_dependencies=["eccodes"], + ) + + _, _, mock_manager, _, _ = mock_pyisolate + + with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib: + mock_tomllib.load.return_value = manifest + await load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_file, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) + + config = mock_manager.load_extension.call_args[0][0] + assert config["package_manager"] == "conda" + + @pytest.mark.asyncio + async def test_conda_channels_in_config( + self, mock_pyisolate, manifest_file, tmp_path + ): + """conda_channels must be passed through to extension_config.""" + + manifest = _make_manifest( + package_manager="conda", + conda_channels=["conda-forge", "nvidia"], + conda_dependencies=["eccodes"], + ) + + _, _, mock_manager, _, _ = mock_pyisolate + + with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib: + mock_tomllib.load.return_value = manifest + await load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_file, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) + + config = mock_manager.load_extension.call_args[0][0] + assert config["conda_channels"] == ["conda-forge", "nvidia"] + + @pytest.mark.asyncio + async def test_conda_dependencies_in_config( + self, mock_pyisolate, manifest_file, tmp_path + ): + """conda_dependencies must be passed through to extension_config.""" + + manifest = _make_manifest( + package_manager="conda", + conda_channels=["conda-forge"], + conda_dependencies=["eccodes", "cfgrib"], + ) + + _, _, mock_manager, _, _ = mock_pyisolate + + with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib: + mock_tomllib.load.return_value = manifest + await load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_file, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) + + config = mock_manager.load_extension.call_args[0][0] + assert config["conda_dependencies"] == ["eccodes", "cfgrib"] + + @pytest.mark.asyncio + async def test_conda_platforms_in_config( + self, mock_pyisolate, manifest_file, tmp_path + ): + """conda_platforms must be passed through to extension_config.""" + + manifest = _make_manifest( + package_manager="conda", + conda_channels=["conda-forge"], + conda_dependencies=["eccodes"], + conda_platforms=["linux-64"], + ) + + _, _, mock_manager, _, _ = mock_pyisolate + + with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib: + mock_tomllib.load.return_value = manifest + await load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_file, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) + + config = mock_manager.load_extension.call_args[0][0] + assert config["conda_platforms"] == ["linux-64"] + + +class TestCondaForcedOverrides: + """Verify conda forces share_torch=False, share_cuda_ipc=False.""" + + @pytest.mark.asyncio + async def test_conda_forces_share_torch_false( + self, mock_pyisolate, manifest_file, tmp_path + ): + """share_torch must be forced False for conda, even if manifest says True.""" + + manifest = _make_manifest( + package_manager="conda", + conda_channels=["conda-forge"], + conda_dependencies=["eccodes"], + share_torch=True, # manifest requests True — must be overridden + ) + + _, _, mock_manager, _, _ = mock_pyisolate + + with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib: + mock_tomllib.load.return_value = manifest + await load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_file, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) + + config = mock_manager.load_extension.call_args[0][0] + assert config["share_torch"] is False + + @pytest.mark.asyncio + async def test_conda_forces_share_cuda_ipc_false( + self, mock_pyisolate, manifest_file, tmp_path + ): + """share_cuda_ipc must be forced False for conda.""" + + manifest = _make_manifest( + package_manager="conda", + conda_channels=["conda-forge"], + conda_dependencies=["eccodes"], + share_torch=True, + ) + + _, _, mock_manager, _, _ = mock_pyisolate + + with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib: + mock_tomllib.load.return_value = manifest + await load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_file, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) + + config = mock_manager.load_extension.call_args[0][0] + assert config["share_cuda_ipc"] is False + + @pytest.mark.asyncio + async def test_conda_sealed_worker_uses_host_policy_sandbox_config( + self, mock_pyisolate, manifest_file, tmp_path + ): + """Conda sealed_worker must carry the host-policy sandbox config on Linux.""" + + manifest = _make_manifest( + package_manager="conda", + conda_channels=["conda-forge"], + conda_dependencies=["eccodes"], + ) + + _, _, mock_manager, _, _ = mock_pyisolate + + with ( + patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib, + patch( + "comfy.isolation.extension_loader.platform.system", + return_value="Linux", + ), + ): + mock_tomllib.load.return_value = manifest + await load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_file, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) + + config = mock_manager.load_extension.call_args[0][0] + assert config["sandbox"] == { + "network": False, + "writable_paths": [], + "readonly_paths": [], + } + + @pytest.mark.asyncio + async def test_conda_uses_sealed_extension_type( + self, mock_pyisolate, manifest_file, tmp_path + ): + """Conda must not launch through ComfyNodeExtension.""" + + _, mock_pi, _, _, mock_wrapper = mock_pyisolate + manifest = _make_manifest( + package_manager="conda", + conda_channels=["conda-forge"], + conda_dependencies=["eccodes"], + ) + + with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib: + mock_tomllib.load.return_value = manifest + await load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_file, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) + + extension_type = mock_pi.ExtensionManager.call_args[0][0] + assert extension_type.__name__ == "SealedNodeExtension" + assert extension_type is not mock_wrapper.ComfyNodeExtension + + +class TestUvUnchanged: + """Verify uv extensions are NOT affected by conda changes.""" + + @pytest.mark.asyncio + async def test_uv_default_no_conda_keys( + self, mock_pyisolate, manifest_file, tmp_path + ): + """Default uv extension must NOT have package_manager or conda keys.""" + + manifest = _make_manifest() # defaults: uv, no conda fields + + _, _, mock_manager, _, _ = mock_pyisolate + + with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib: + mock_tomllib.load.return_value = manifest + await load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_file, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) + + config = mock_manager.load_extension.call_args[0][0] + # uv extensions should not have conda-specific keys + assert config.get("package_manager", "uv") == "uv" + assert "conda_channels" not in config + assert "conda_dependencies" not in config + + @pytest.mark.asyncio + async def test_uv_keeps_comfy_extension_type( + self, mock_pyisolate, manifest_file, tmp_path + ): + """uv keeps the existing ComfyNodeExtension path.""" + + _, mock_pi, _, _, _ = mock_pyisolate + manifest = _make_manifest() + + with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib: + mock_tomllib.load.return_value = manifest + await load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_file, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) + + extension_type = mock_pi.ExtensionManager.call_args[0][0] + assert extension_type.__name__ == "ComfyNodeExtension" + assert extension_type is not mock_pi.SealedNodeExtension diff --git a/tests/isolation/test_extension_loader_sealed_worker.py b/tests/isolation/test_extension_loader_sealed_worker.py new file mode 100644 index 000000000..d694b178f --- /dev/null +++ b/tests/isolation/test_extension_loader_sealed_worker.py @@ -0,0 +1,281 @@ +"""Tests for execution_model parsing and sealed-worker loader selection.""" + +from __future__ import annotations + +import importlib +import sys +import types +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +def _make_manifest( + *, + package_manager: str = "uv", + execution_model: str | None = None, + can_isolate: bool = True, + dependencies: list[str] | None = None, + sealed_host_ro_paths: list[str] | None = None, +) -> dict: + isolation: dict = {"can_isolate": can_isolate} + if package_manager != "uv": + isolation["package_manager"] = package_manager + if execution_model is not None: + isolation["execution_model"] = execution_model + if sealed_host_ro_paths is not None: + isolation["sealed_host_ro_paths"] = sealed_host_ro_paths + + return { + "project": { + "name": "test-extension", + "dependencies": dependencies or ["numpy"], + }, + "tool": {"comfy": {"isolation": isolation}}, + } + + +@pytest.fixture +def manifest_file(tmp_path): + path = tmp_path / "pyproject.toml" + path.write_bytes(b"") + return path + + +@pytest.fixture +def loader_module(monkeypatch): + mock_wrapper = MagicMock() + mock_wrapper.ComfyNodeExtension = type("ComfyNodeExtension", (), {}) + + iso_mod = types.ModuleType("comfy.isolation") + iso_mod.__path__ = [ # type: ignore[attr-defined] + str(Path(__file__).resolve().parent.parent.parent / "comfy" / "isolation") + ] + iso_mod.__package__ = "comfy.isolation" + + manifest_loader = types.SimpleNamespace( + is_cache_valid=lambda *args, **kwargs: False, + load_from_cache=lambda *args, **kwargs: None, + save_to_cache=lambda *args, **kwargs: None, + ) + host_policy = types.SimpleNamespace( + load_host_policy=lambda base_path: { + "sandbox_mode": "required", + "allow_network": False, + "writable_paths": [], + "readonly_paths": [], + "sealed_worker_ro_import_paths": [], + } + ) + folder_paths = types.SimpleNamespace(base_path="/fake/comfyui") + + monkeypatch.setitem(sys.modules, "comfy.isolation", iso_mod) + monkeypatch.setitem(sys.modules, "comfy.isolation.extension_wrapper", mock_wrapper) + monkeypatch.setitem(sys.modules, "comfy.isolation.runtime_helpers", MagicMock()) + monkeypatch.setitem(sys.modules, "comfy.isolation.manifest_loader", manifest_loader) + monkeypatch.setitem(sys.modules, "comfy.isolation.host_policy", host_policy) + monkeypatch.setitem(sys.modules, "folder_paths", folder_paths) + sys.modules.pop("comfy.isolation.extension_loader", None) + + module = importlib.import_module("comfy.isolation.extension_loader") + try: + yield module + finally: + sys.modules.pop("comfy.isolation.extension_loader", None) + comfy_pkg = sys.modules.get("comfy") + if comfy_pkg is not None and hasattr(comfy_pkg, "isolation"): + delattr(comfy_pkg, "isolation") + + +@pytest.fixture +def mock_pyisolate(loader_module): + mock_ext = AsyncMock() + mock_ext.list_nodes = AsyncMock(return_value={}) + + mock_manager = MagicMock() + mock_manager.load_extension = MagicMock(return_value=mock_ext) + sealed_type = type("SealedNodeExtension", (), {}) + + with patch.object(loader_module, "pyisolate") as mock_pi: + mock_pi.ExtensionManager = MagicMock(return_value=mock_manager) + mock_pi.SealedNodeExtension = sealed_type + yield loader_module, mock_pi, mock_manager, mock_ext, sealed_type + + +def load_isolated_node(*args, **kwargs): + return sys.modules["comfy.isolation.extension_loader"].load_isolated_node(*args, **kwargs) + + +@pytest.mark.asyncio +async def test_uv_sealed_worker_selects_sealed_extension_type( + mock_pyisolate, manifest_file, tmp_path +): + manifest = _make_manifest(execution_model="sealed_worker") + + _, mock_pi, mock_manager, _, sealed_type = mock_pyisolate + + with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib: + mock_tomllib.load.return_value = manifest + await load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_file, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) + + extension_type = mock_pi.ExtensionManager.call_args[0][0] + config = mock_manager.load_extension.call_args[0][0] + assert extension_type is sealed_type + assert config["execution_model"] == "sealed_worker" + assert "apis" not in config + + +@pytest.mark.asyncio +async def test_default_uv_keeps_host_coupled_extension_type( + mock_pyisolate, manifest_file, tmp_path +): + manifest = _make_manifest() + + _, mock_pi, mock_manager, _, sealed_type = mock_pyisolate + + with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib: + mock_tomllib.load.return_value = manifest + await load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_file, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) + + extension_type = mock_pi.ExtensionManager.call_args[0][0] + config = mock_manager.load_extension.call_args[0][0] + assert extension_type is not sealed_type + assert "execution_model" not in config + + +@pytest.mark.asyncio +async def test_conda_without_execution_model_remains_sealed_worker( + mock_pyisolate, manifest_file, tmp_path +): + manifest = _make_manifest(package_manager="conda") + manifest["tool"]["comfy"]["isolation"]["conda_channels"] = ["conda-forge"] + manifest["tool"]["comfy"]["isolation"]["conda_dependencies"] = ["eccodes"] + + _, mock_pi, mock_manager, _, sealed_type = mock_pyisolate + + with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib: + mock_tomllib.load.return_value = manifest + await load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_file, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) + + extension_type = mock_pi.ExtensionManager.call_args[0][0] + config = mock_manager.load_extension.call_args[0][0] + assert extension_type is sealed_type + assert config["execution_model"] == "sealed_worker" + + +@pytest.mark.asyncio +async def test_sealed_worker_uses_host_policy_ro_import_paths( + mock_pyisolate, manifest_file, tmp_path +): + manifest = _make_manifest(execution_model="sealed_worker") + + module, _, mock_manager, _, _ = mock_pyisolate + + with ( + patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib, + patch.object( + module, + "load_host_policy", + return_value={ + "sandbox_mode": "required", + "allow_network": False, + "writable_paths": [], + "readonly_paths": [], + "sealed_worker_ro_import_paths": ["/home/johnj/ComfyUI"], + }, + ), + ): + mock_tomllib.load.return_value = manifest + await load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_file, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) + + config = mock_manager.load_extension.call_args[0][0] + assert config["sealed_host_ro_paths"] == ["/home/johnj/ComfyUI"] + + +@pytest.mark.asyncio +async def test_host_coupled_does_not_emit_sealed_host_ro_paths( + mock_pyisolate, manifest_file, tmp_path +): + manifest = _make_manifest(execution_model="host-coupled") + + module, _, mock_manager, _, _ = mock_pyisolate + + with ( + patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib, + patch.object( + module, + "load_host_policy", + return_value={ + "sandbox_mode": "required", + "allow_network": False, + "writable_paths": [], + "readonly_paths": [], + "sealed_worker_ro_import_paths": ["/home/johnj/ComfyUI"], + }, + ), + ): + mock_tomllib.load.return_value = manifest + await load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_file, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) + + config = mock_manager.load_extension.call_args[0][0] + assert "sealed_host_ro_paths" not in config + + +@pytest.mark.asyncio +async def test_sealed_worker_manifest_ro_import_paths_blocked( + mock_pyisolate, manifest_file, tmp_path +): + manifest = _make_manifest( + execution_model="sealed_worker", + sealed_host_ro_paths=["/home/johnj/ComfyUI"], + ) + + _, _, _mock_manager, _, _ = mock_pyisolate + + with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib: + mock_tomllib.load.return_value = manifest + with pytest.raises(ValueError, match="Manifest field 'sealed_host_ro_paths' is not allowed"): + await load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_file, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) diff --git a/tests/isolation/test_folder_paths_proxy.py b/tests/isolation/test_folder_paths_proxy.py new file mode 100644 index 000000000..451f5e607 --- /dev/null +++ b/tests/isolation/test_folder_paths_proxy.py @@ -0,0 +1,122 @@ +"""Unit tests for FolderPathsProxy.""" + +import pytest +from pathlib import Path + +from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy +from tests.isolation.singleton_boundary_helpers import capture_sealed_singleton_imports + + +class TestFolderPathsProxy: + """Test FolderPathsProxy methods.""" + + @pytest.fixture + def proxy(self): + """Create a FolderPathsProxy instance for testing.""" + return FolderPathsProxy() + + def test_get_temp_directory_returns_string(self, proxy): + """Verify get_temp_directory returns a non-empty string.""" + result = proxy.get_temp_directory() + assert isinstance(result, str), f"Expected str, got {type(result)}" + assert len(result) > 0, "Temp directory path is empty" + + def test_get_temp_directory_returns_absolute_path(self, proxy): + """Verify get_temp_directory returns an absolute path.""" + result = proxy.get_temp_directory() + path = Path(result) + assert path.is_absolute(), f"Path is not absolute: {result}" + + def test_get_input_directory_returns_string(self, proxy): + """Verify get_input_directory returns a non-empty string.""" + result = proxy.get_input_directory() + assert isinstance(result, str), f"Expected str, got {type(result)}" + assert len(result) > 0, "Input directory path is empty" + + def test_get_input_directory_returns_absolute_path(self, proxy): + """Verify get_input_directory returns an absolute path.""" + result = proxy.get_input_directory() + path = Path(result) + assert path.is_absolute(), f"Path is not absolute: {result}" + + def test_get_annotated_filepath_plain_name(self, proxy): + """Verify get_annotated_filepath works with plain filename.""" + result = proxy.get_annotated_filepath("test.png") + assert isinstance(result, str), f"Expected str, got {type(result)}" + assert "test.png" in result, f"Filename not in result: {result}" + + def test_get_annotated_filepath_with_output_annotation(self, proxy): + """Verify get_annotated_filepath handles [output] annotation.""" + result = proxy.get_annotated_filepath("test.png[output]") + assert isinstance(result, str), f"Expected str, got {type(result)}" + assert "test.pn" in result, f"Filename base not in result: {result}" + # Should resolve to output directory + assert "output" in result.lower() or Path(result).parent.name == "output" + + def test_get_annotated_filepath_with_input_annotation(self, proxy): + """Verify get_annotated_filepath handles [input] annotation.""" + result = proxy.get_annotated_filepath("test.png[input]") + assert isinstance(result, str), f"Expected str, got {type(result)}" + assert "test.pn" in result, f"Filename base not in result: {result}" + + def test_get_annotated_filepath_with_temp_annotation(self, proxy): + """Verify get_annotated_filepath handles [temp] annotation.""" + result = proxy.get_annotated_filepath("test.png[temp]") + assert isinstance(result, str), f"Expected str, got {type(result)}" + assert "test.pn" in result, f"Filename base not in result: {result}" + + def test_exists_annotated_filepath_returns_bool(self, proxy): + """Verify exists_annotated_filepath returns a boolean.""" + result = proxy.exists_annotated_filepath("nonexistent.png") + assert isinstance(result, bool), f"Expected bool, got {type(result)}" + + def test_exists_annotated_filepath_nonexistent_file(self, proxy): + """Verify exists_annotated_filepath returns False for nonexistent file.""" + result = proxy.exists_annotated_filepath("definitely_does_not_exist_12345.png") + assert result is False, "Expected False for nonexistent file" + + def test_exists_annotated_filepath_with_annotation(self, proxy): + """Verify exists_annotated_filepath works with annotation suffix.""" + # Even for nonexistent files, should return bool without error + result = proxy.exists_annotated_filepath("test.png[output]") + assert isinstance(result, bool), f"Expected bool, got {type(result)}" + + def test_models_dir_property_returns_string(self, proxy): + """Verify models_dir property returns valid path string.""" + result = proxy.models_dir + assert isinstance(result, str), f"Expected str, got {type(result)}" + assert len(result) > 0, "Models directory path is empty" + + def test_models_dir_is_absolute_path(self, proxy): + """Verify models_dir returns an absolute path.""" + result = proxy.models_dir + path = Path(result) + assert path.is_absolute(), f"Path is not absolute: {result}" + + def test_add_model_folder_path_runs_without_error(self, proxy): + """Verify add_model_folder_path executes without raising.""" + test_path = "/tmp/test_models_florence2" + # Should not raise + proxy.add_model_folder_path("TEST_FLORENCE2", test_path) + + def test_get_folder_paths_returns_list(self, proxy): + """Verify get_folder_paths returns a list.""" + # Use known folder type that should exist + result = proxy.get_folder_paths("checkpoints") + assert isinstance(result, list), f"Expected list, got {type(result)}" + + def test_get_folder_paths_checkpoints_not_empty(self, proxy): + """Verify checkpoints folder paths list is not empty.""" + result = proxy.get_folder_paths("checkpoints") + # Should have at least one checkpoint path registered + assert len(result) > 0, "Checkpoints folder paths is empty" + + def test_sealed_child_safe_uses_rpc_without_importing_folder_paths(self, monkeypatch): + monkeypatch.setenv("PYISOLATE_CHILD", "1") + monkeypatch.setenv("PYISOLATE_IMPORT_TORCH", "0") + + payload = capture_sealed_singleton_imports() + + assert payload["temp_dir"] == "/sandbox/temp" + assert payload["models_dir"] == "/sandbox/models" + assert "folder_paths" not in payload["modules"] diff --git a/tests/isolation/test_host_policy.py b/tests/isolation/test_host_policy.py new file mode 100644 index 000000000..46d06bb38 --- /dev/null +++ b/tests/isolation/test_host_policy.py @@ -0,0 +1,209 @@ +from pathlib import Path + +import pytest + + +def _write_pyproject(path: Path, content: str) -> None: + path.write_text(content, encoding="utf-8") + + +def test_load_host_policy_defaults_when_pyproject_missing(tmp_path): + from comfy.isolation.host_policy import DEFAULT_POLICY, load_host_policy + + policy = load_host_policy(tmp_path) + + assert policy["sandbox_mode"] == DEFAULT_POLICY["sandbox_mode"] + assert policy["allow_network"] == DEFAULT_POLICY["allow_network"] + assert policy["writable_paths"] == DEFAULT_POLICY["writable_paths"] + assert policy["readonly_paths"] == DEFAULT_POLICY["readonly_paths"] + assert policy["whitelist"] == DEFAULT_POLICY["whitelist"] + + +def test_load_host_policy_defaults_when_section_missing(tmp_path): + from comfy.isolation.host_policy import DEFAULT_POLICY, load_host_policy + + _write_pyproject( + tmp_path / "pyproject.toml", + """ +[project] +name = "ComfyUI" +""".strip(), + ) + + policy = load_host_policy(tmp_path) + assert policy["sandbox_mode"] == DEFAULT_POLICY["sandbox_mode"] + assert policy["allow_network"] == DEFAULT_POLICY["allow_network"] + assert policy["whitelist"] == {} + + +def test_load_host_policy_reads_values(tmp_path): + from comfy.isolation.host_policy import load_host_policy + + _write_pyproject( + tmp_path / "pyproject.toml", + """ +[tool.comfy.host] +sandbox_mode = "disabled" +allow_network = true +writable_paths = ["/tmp/a", "/tmp/b"] +readonly_paths = ["/opt/readonly"] + +[tool.comfy.host.whitelist] +ExampleNode = "*" +""".strip(), + ) + + policy = load_host_policy(tmp_path) + assert policy["sandbox_mode"] == "disabled" + assert policy["allow_network"] is True + assert policy["writable_paths"] == ["/tmp/a", "/tmp/b"] + assert policy["readonly_paths"] == ["/opt/readonly"] + assert policy["whitelist"] == {"ExampleNode": "*"} + + +def test_load_host_policy_ignores_invalid_whitelist_type(tmp_path): + from comfy.isolation.host_policy import DEFAULT_POLICY, load_host_policy + + _write_pyproject( + tmp_path / "pyproject.toml", + """ +[tool.comfy.host] +allow_network = true +whitelist = ["bad"] +""".strip(), + ) + + policy = load_host_policy(tmp_path) + assert policy["allow_network"] is True + assert policy["whitelist"] == DEFAULT_POLICY["whitelist"] + + +def test_load_host_policy_ignores_invalid_sandbox_mode(tmp_path): + from comfy.isolation.host_policy import DEFAULT_POLICY, load_host_policy + + _write_pyproject( + tmp_path / "pyproject.toml", + """ +[tool.comfy.host] +sandbox_mode = "surprise" +""".strip(), + ) + + policy = load_host_policy(tmp_path) + + assert policy["sandbox_mode"] == DEFAULT_POLICY["sandbox_mode"] + + +def test_load_host_policy_uses_env_override_path(tmp_path, monkeypatch): + from comfy.isolation.host_policy import load_host_policy + + override_path = tmp_path / "host_policy_override.toml" + _write_pyproject( + override_path, + """ +[tool.comfy.host] +sandbox_mode = "disabled" +allow_network = true +""".strip(), + ) + + monkeypatch.setenv("COMFY_HOST_POLICY_PATH", str(override_path)) + + policy = load_host_policy(tmp_path / "missing-root") + + assert policy["sandbox_mode"] == "disabled" + assert policy["allow_network"] is True + + +def test_disallows_host_tmp_default_or_override_defaults(tmp_path): + from comfy.isolation.host_policy import DEFAULT_POLICY, load_host_policy + + policy = load_host_policy(tmp_path) + + assert "/tmp" not in DEFAULT_POLICY["writable_paths"] + assert "/tmp" not in policy["writable_paths"] + + +def test_disallows_host_tmp_default_or_override_config(tmp_path): + from comfy.isolation.host_policy import load_host_policy + + _write_pyproject( + tmp_path / "pyproject.toml", + """ +[tool.comfy.host] +writable_paths = ["/dev/shm", "/tmp", "/tmp/", "/work/cache"] +""".strip(), + ) + + policy = load_host_policy(tmp_path) + + assert policy["writable_paths"] == ["/dev/shm", "/work/cache"] + + +def test_sealed_worker_ro_import_paths_defaults_off_and_parse(tmp_path): + from comfy.isolation.host_policy import load_host_policy + + policy = load_host_policy(tmp_path) + assert policy["sealed_worker_ro_import_paths"] == [] + + _write_pyproject( + tmp_path / "pyproject.toml", + """ +[tool.comfy.host] +sealed_worker_ro_import_paths = ["/home/johnj/ComfyUI", "/opt/comfy-shared"] +""".strip(), + ) + + policy = load_host_policy(tmp_path) + assert policy["sealed_worker_ro_import_paths"] == [ + "/home/johnj/ComfyUI", + "/opt/comfy-shared", + ] + + +def test_sealed_worker_ro_import_paths_rejects_non_list_or_relative(tmp_path): + from comfy.isolation.host_policy import load_host_policy + + _write_pyproject( + tmp_path / "pyproject.toml", + """ +[tool.comfy.host] +sealed_worker_ro_import_paths = "/home/johnj/ComfyUI" +""".strip(), + ) + with pytest.raises(ValueError, match="must be a list of absolute paths"): + load_host_policy(tmp_path) + + _write_pyproject( + tmp_path / "pyproject.toml", + """ +[tool.comfy.host] +sealed_worker_ro_import_paths = ["relative/path"] +""".strip(), + ) + with pytest.raises(ValueError, match="entries must be absolute paths"): + load_host_policy(tmp_path) + + +def test_host_policy_path_override_controls_ro_import_paths(tmp_path, monkeypatch): + from comfy.isolation.host_policy import load_host_policy + + _write_pyproject( + tmp_path / "pyproject.toml", + """ +[tool.comfy.host] +sealed_worker_ro_import_paths = ["/ignored/base/path"] +""".strip(), + ) + override_path = tmp_path / "host_policy_override.toml" + _write_pyproject( + override_path, + """ +[tool.comfy.host] +sealed_worker_ro_import_paths = ["/override/ro/path"] +""".strip(), + ) + monkeypatch.setenv("COMFY_HOST_POLICY_PATH", str(override_path)) + + policy = load_host_policy(tmp_path) + assert policy["sealed_worker_ro_import_paths"] == ["/override/ro/path"] diff --git a/tests/isolation/test_init.py b/tests/isolation/test_init.py new file mode 100644 index 000000000..c237fe904 --- /dev/null +++ b/tests/isolation/test_init.py @@ -0,0 +1,80 @@ +"""Unit tests for PyIsolate isolation system initialization.""" + +import importlib +import sys + +from tests.isolation.singleton_boundary_helpers import ( + FakeSingletonRPC, + reset_forbidden_singleton_modules, +) + + +def test_log_prefix(): + """Verify LOG_PREFIX constant is correctly defined.""" + from comfy.isolation import LOG_PREFIX + assert LOG_PREFIX == "][" + assert isinstance(LOG_PREFIX, str) + + +def test_module_initialization(): + """Verify module initializes without errors.""" + isolation_pkg = importlib.import_module("comfy.isolation") + assert hasattr(isolation_pkg, "LOG_PREFIX") + assert hasattr(isolation_pkg, "initialize_proxies") + + +class TestInitializeProxies: + def test_initialize_proxies_runs_without_error(self): + from comfy.isolation import initialize_proxies + initialize_proxies() + + def test_initialize_proxies_registers_folder_paths_proxy(self): + from comfy.isolation import initialize_proxies + from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy + initialize_proxies() + proxy = FolderPathsProxy() + assert proxy is not None + assert hasattr(proxy, "get_temp_directory") + + def test_initialize_proxies_registers_model_management_proxy(self): + from comfy.isolation import initialize_proxies + from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy + initialize_proxies() + proxy = ModelManagementProxy() + assert proxy is not None + assert hasattr(proxy, "get_torch_device") + + def test_initialize_proxies_can_be_called_multiple_times(self): + from comfy.isolation import initialize_proxies + initialize_proxies() + initialize_proxies() + initialize_proxies() + + def test_dev_proxies_accessible_when_dev_mode(self, monkeypatch): + """Verify dev mode does not break core proxy initialization.""" + monkeypatch.setenv("PYISOLATE_DEV", "1") + from comfy.isolation import initialize_proxies + from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy + from comfy.isolation.proxies.utils_proxy import UtilsProxy + initialize_proxies() + folder_proxy = FolderPathsProxy() + utils_proxy = UtilsProxy() + assert folder_proxy is not None + assert utils_proxy is not None + + def test_sealed_child_safe_initialize_proxies_avoids_real_utils_import(self, monkeypatch): + monkeypatch.setenv("PYISOLATE_CHILD", "1") + monkeypatch.setenv("PYISOLATE_IMPORT_TORCH", "0") + reset_forbidden_singleton_modules() + + from pyisolate._internal import rpc_protocol + from comfy.isolation import initialize_proxies + + fake_rpc = FakeSingletonRPC() + monkeypatch.setattr(rpc_protocol, "get_child_rpc_instance", lambda: fake_rpc) + + initialize_proxies() + + assert "comfy.utils" not in sys.modules + assert "folder_paths" not in sys.modules + assert "comfy_execution.progress" not in sys.modules diff --git a/tests/isolation/test_internal_probe_node_assets.py b/tests/isolation/test_internal_probe_node_assets.py new file mode 100644 index 000000000..c12cf4404 --- /dev/null +++ b/tests/isolation/test_internal_probe_node_assets.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +import importlib.util +import json +from pathlib import Path + + +COMFYUI_ROOT = Path(__file__).resolve().parents[2] +ISOLATION_ROOT = COMFYUI_ROOT / "tests" / "isolation" +PROBE_ROOT = ISOLATION_ROOT / "internal_probe_node" +WORKFLOW_ROOT = ISOLATION_ROOT / "workflows" +TOOLKIT_ROOT = COMFYUI_ROOT / "custom_nodes" / "ComfyUI-IsolationToolkit" + +EXPECTED_PROBE_FILES = { + "__init__.py", + "probe_nodes.py", +} +EXPECTED_WORKFLOWS = { + "internal_probe_preview_image_audio.json", + "internal_probe_ui3d.json", +} +BANNED_REFERENCES = ( + "ComfyUI-IsolationToolkit", + "toolkit_smoke_playlist", + "run_isolation_toolkit_smoke.sh", +) + + +def _text_assets() -> list[Path]: + return sorted(list(PROBE_ROOT.rglob("*.py")) + list(WORKFLOW_ROOT.glob("internal_probe_*.json"))) + + +def _load_probe_package(): + spec = importlib.util.spec_from_file_location( + "internal_probe_node", + PROBE_ROOT / "__init__.py", + submodule_search_locations=[str(PROBE_ROOT)], + ) + module = importlib.util.module_from_spec(spec) + assert spec is not None + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +def test_inventory_is_minimal_and_isolation_owned(): + assert PROBE_ROOT.is_dir() + assert WORKFLOW_ROOT.is_dir() + assert PROBE_ROOT.is_relative_to(ISOLATION_ROOT) + assert WORKFLOW_ROOT.is_relative_to(ISOLATION_ROOT) + assert not PROBE_ROOT.is_relative_to(TOOLKIT_ROOT) + + probe_files = {path.name for path in PROBE_ROOT.iterdir() if path.is_file()} + workflow_files = {path.name for path in WORKFLOW_ROOT.glob("internal_probe_*.json")} + + assert probe_files == EXPECTED_PROBE_FILES + assert workflow_files == EXPECTED_WORKFLOWS + + module = _load_probe_package() + mappings = module.NODE_CLASS_MAPPINGS + + assert sorted(mappings.keys()) == [ + "InternalIsolationProbeAudio", + "InternalIsolationProbeImage", + "InternalIsolationProbeUI3D", + ] + + preview_workflow = json.loads( + (WORKFLOW_ROOT / "internal_probe_preview_image_audio.json").read_text( + encoding="utf-8" + ) + ) + ui3d_workflow = json.loads( + (WORKFLOW_ROOT / "internal_probe_ui3d.json").read_text(encoding="utf-8") + ) + + assert [preview_workflow[node_id]["class_type"] for node_id in ("1", "2")] == [ + "InternalIsolationProbeImage", + "InternalIsolationProbeAudio", + ] + assert [ui3d_workflow[node_id]["class_type"] for node_id in ("1",)] == [ + "InternalIsolationProbeUI3D", + ] + + +def test_zero_toolkit_references_in_probe_assets(): + for asset in _text_assets(): + content = asset.read_text(encoding="utf-8") + for banned in BANNED_REFERENCES: + assert banned not in content, f"{asset} unexpectedly references {banned}" + + +def test_replacement_contract_has_zero_toolkit_references(): + contract_assets = [ + *(PROBE_ROOT.rglob("*.py")), + *WORKFLOW_ROOT.glob("internal_probe_*.json"), + ISOLATION_ROOT / "stage_internal_probe_node.py", + ISOLATION_ROOT / "internal_probe_host_policy.toml", + ] + + for asset in sorted(contract_assets): + assert asset.exists(), f"Missing replacement-contract asset: {asset}" + content = asset.read_text(encoding="utf-8") + for banned in BANNED_REFERENCES: + assert banned not in content, f"{asset} unexpectedly references {banned}" diff --git a/tests/isolation/test_internal_probe_node_loading.py b/tests/isolation/test_internal_probe_node_loading.py new file mode 100644 index 000000000..fd1a7268c --- /dev/null +++ b/tests/isolation/test_internal_probe_node_loading.py @@ -0,0 +1,180 @@ +from __future__ import annotations + +import json +import os +import shutil +import subprocess +import sys +from pathlib import Path + +import pytest + +import nodes +from tests.isolation.stage_internal_probe_node import ( + PROBE_NODE_NAME, + stage_probe_node, + staged_probe_node, +) + + +COMFYUI_ROOT = Path(__file__).resolve().parents[2] +ISOLATION_ROOT = COMFYUI_ROOT / "tests" / "isolation" +PROBE_SOURCE_ROOT = ISOLATION_ROOT / "internal_probe_node" +EXPECTED_NODE_IDS = [ + "InternalIsolationProbeAudio", + "InternalIsolationProbeImage", + "InternalIsolationProbeUI3D", +] + +CLIENT_SCRIPT = """ +import importlib.util +import json +import os +import sys + +import pyisolate._internal.client # noqa: F401 # triggers snapshot bootstrap + +module_path = os.environ["PYISOLATE_MODULE_PATH"] +spec = importlib.util.spec_from_file_location( + "internal_probe_node", + os.path.join(module_path, "__init__.py"), + submodule_search_locations=[module_path], +) +module = importlib.util.module_from_spec(spec) +assert spec is not None +assert spec.loader is not None +sys.modules["internal_probe_node"] = module +spec.loader.exec_module(module) +print( + json.dumps( + { + "sys_path": list(sys.path), + "module_path": module_path, + "node_ids": sorted(module.NODE_CLASS_MAPPINGS.keys()), + } + ) +) +""" + + +def _run_client_process(env: dict[str, str]) -> dict: + pythonpath_parts = [str(COMFYUI_ROOT)] + existing = env.get("PYTHONPATH", "") + if existing: + pythonpath_parts.append(existing) + env["PYTHONPATH"] = ":".join(pythonpath_parts) + + result = subprocess.run( # noqa: S603 + [sys.executable, "-c", CLIENT_SCRIPT], + capture_output=True, + text=True, + env=env, + check=True, + ) + return json.loads(result.stdout.strip().splitlines()[-1]) + + +@pytest.fixture() +def staged_probe_module(tmp_path: Path) -> tuple[Path, Path]: + staged_comfy_root = tmp_path / "ComfyUI" + module_path = staged_comfy_root / "custom_nodes" / "InternalIsolationProbeNode" + shutil.copytree(PROBE_SOURCE_ROOT, module_path) + return staged_comfy_root, module_path + + +@pytest.mark.asyncio +async def test_staged_probe_node_discovered(staged_probe_module: tuple[Path, Path]) -> None: + _, module_path = staged_probe_module + class_mappings_snapshot = dict(nodes.NODE_CLASS_MAPPINGS) + display_name_snapshot = dict(nodes.NODE_DISPLAY_NAME_MAPPINGS) + loaded_module_dirs_snapshot = dict(nodes.LOADED_MODULE_DIRS) + + try: + ignore = set(nodes.NODE_CLASS_MAPPINGS.keys()) + loaded = await nodes.load_custom_node( + str(module_path), ignore=ignore, module_parent="custom_nodes" + ) + + assert loaded is True + assert nodes.LOADED_MODULE_DIRS["InternalIsolationProbeNode"] == str( + module_path.resolve() + ) + + for node_id in EXPECTED_NODE_IDS: + assert node_id in nodes.NODE_CLASS_MAPPINGS + node_cls = nodes.NODE_CLASS_MAPPINGS[node_id] + assert ( + getattr(node_cls, "RELATIVE_PYTHON_MODULE", None) + == "custom_nodes.InternalIsolationProbeNode" + ) + finally: + nodes.NODE_CLASS_MAPPINGS.clear() + nodes.NODE_CLASS_MAPPINGS.update(class_mappings_snapshot) + nodes.NODE_DISPLAY_NAME_MAPPINGS.clear() + nodes.NODE_DISPLAY_NAME_MAPPINGS.update(display_name_snapshot) + nodes.LOADED_MODULE_DIRS.clear() + nodes.LOADED_MODULE_DIRS.update(loaded_module_dirs_snapshot) + + +def test_staged_probe_node_module_path_is_valid_for_child_bootstrap( + tmp_path: Path, staged_probe_module: tuple[Path, Path] +) -> None: + staged_comfy_root, module_path = staged_probe_module + snapshot = { + "sys_path": [str(COMFYUI_ROOT), "/host/lib1", "/host/lib2"], + "sys_executable": sys.executable, + "sys_prefix": sys.prefix, + "environment": {}, + } + snapshot_path = tmp_path / "snapshot.json" + snapshot_path.write_text(json.dumps(snapshot), encoding="utf-8") + + env = os.environ.copy() + env.update( + { + "PYISOLATE_CHILD": "1", + "PYISOLATE_HOST_SNAPSHOT": str(snapshot_path), + "PYISOLATE_MODULE_PATH": str(module_path), + } + ) + + payload = _run_client_process(env) + + assert payload["module_path"] == str(module_path) + assert payload["node_ids"] == EXPECTED_NODE_IDS + assert str(COMFYUI_ROOT) in payload["sys_path"] + assert str(staged_comfy_root) not in payload["sys_path"] + + +def test_stage_probe_node_stages_only_under_explicit_root(tmp_path: Path) -> None: + comfy_root = tmp_path / "sandbox-root" + + module_path = stage_probe_node(comfy_root) + + assert module_path == comfy_root / "custom_nodes" / PROBE_NODE_NAME + assert module_path.is_dir() + assert (module_path / "__init__.py").is_file() + assert (module_path / "probe_nodes.py").is_file() + assert (module_path / "pyproject.toml").is_file() + + +def test_staged_probe_node_context_cleans_up_temp_root() -> None: + with staged_probe_node() as module_path: + staging_root = module_path.parents[1] + assert module_path.name == PROBE_NODE_NAME + assert module_path.is_dir() + assert staging_root.is_dir() + + assert not staging_root.exists() + + +def test_stage_script_requires_explicit_target_root() -> None: + result = subprocess.run( # noqa: S603 + [sys.executable, str(ISOLATION_ROOT / "stage_internal_probe_node.py")], + capture_output=True, + text=True, + check=False, + ) + + assert result.returncode != 0 + assert "--target-root" in result.stderr diff --git a/tests/isolation/test_manifest_loader_cache.py b/tests/isolation/test_manifest_loader_cache.py new file mode 100644 index 000000000..ebee43b7e --- /dev/null +++ b/tests/isolation/test_manifest_loader_cache.py @@ -0,0 +1,434 @@ +""" +Unit tests for manifest_loader.py cache functions. + +Phase 1 tests verify: +1. Cache miss on first run (no cache exists) +2. Cache hit when nothing changes +3. Invalidation on .py file touch +4. Invalidation on manifest change +5. Cache location correctness (in venv_root, NOT in custom_nodes) +6. Corrupt cache handling (graceful failure) + +These tests verify the cache implementation is correct BEFORE it's activated +in extension_loader.py (Phase 2). +""" + +from __future__ import annotations + +import json +import sys +import time +from pathlib import Path +from unittest import mock + + + +class TestComputeCacheKey: + """Tests for compute_cache_key() function.""" + + def test_key_includes_manifest_content(self, tmp_path: Path) -> None: + """Cache key changes when manifest content changes.""" + from comfy.isolation.manifest_loader import compute_cache_key + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + + # Initial manifest + manifest.write_text("isolated: true\ndependencies: []\n") + key1 = compute_cache_key(node_dir, manifest) + + # Modified manifest + manifest.write_text("isolated: true\ndependencies: [numpy]\n") + key2 = compute_cache_key(node_dir, manifest) + + assert key1 != key2, "Key should change when manifest content changes" + + def test_key_includes_py_file_mtime(self, tmp_path: Path) -> None: + """Cache key changes when any .py file is touched.""" + from comfy.isolation.manifest_loader import compute_cache_key + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + + py_file = node_dir / "nodes.py" + py_file.write_text("# test code") + + key1 = compute_cache_key(node_dir, manifest) + + # Wait a moment to ensure mtime changes + time.sleep(0.01) + py_file.write_text("# modified code") + + key2 = compute_cache_key(node_dir, manifest) + + assert key1 != key2, "Key should change when .py file mtime changes" + + def test_key_includes_python_version(self, tmp_path: Path) -> None: + """Cache key changes when Python version changes.""" + from comfy.isolation.manifest_loader import compute_cache_key + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + + key1 = compute_cache_key(node_dir, manifest) + + # Mock different Python version + with mock.patch.object(sys, "version", "3.99.0 (fake)"): + key2 = compute_cache_key(node_dir, manifest) + + assert key1 != key2, "Key should change when Python version changes" + + def test_key_includes_pyisolate_version(self, tmp_path: Path) -> None: + """Cache key changes when PyIsolate version changes.""" + from comfy.isolation.manifest_loader import compute_cache_key + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + + key1 = compute_cache_key(node_dir, manifest) + + # Mock different pyisolate version + with mock.patch.dict(sys.modules, {"pyisolate": mock.MagicMock(__version__="99.99.99")}): + # Need to reimport to pick up the mock + import importlib + from comfy.isolation import manifest_loader + importlib.reload(manifest_loader) + key2 = manifest_loader.compute_cache_key(node_dir, manifest) + + # Keys should be different (though the mock approach is tricky) + # At minimum, verify key is a valid hex string + assert len(key1) == 16, "Key should be 16 hex characters" + assert all(c in "0123456789abcdef" for c in key1), "Key should be hex" + assert len(key2) == 16, "Key should be 16 hex characters" + assert all(c in "0123456789abcdef" for c in key2), "Key should be hex" + + def test_key_excludes_pycache(self, tmp_path: Path) -> None: + """Cache key ignores __pycache__ directory changes.""" + from comfy.isolation.manifest_loader import compute_cache_key + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + + py_file = node_dir / "nodes.py" + py_file.write_text("# test code") + + key1 = compute_cache_key(node_dir, manifest) + + # Add __pycache__ file + pycache = node_dir / "__pycache__" + pycache.mkdir() + (pycache / "nodes.cpython-310.pyc").write_bytes(b"compiled") + + key2 = compute_cache_key(node_dir, manifest) + + assert key1 == key2, "Key should NOT change when __pycache__ modified" + + def test_key_is_deterministic(self, tmp_path: Path) -> None: + """Same inputs produce same key.""" + from comfy.isolation.manifest_loader import compute_cache_key + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + (node_dir / "nodes.py").write_text("# code") + + key1 = compute_cache_key(node_dir, manifest) + key2 = compute_cache_key(node_dir, manifest) + + assert key1 == key2, "Key should be deterministic" + + +class TestGetCachePath: + """Tests for get_cache_path() function.""" + + def test_returns_correct_paths(self, tmp_path: Path) -> None: + """Cache paths are in venv_root, not in node_dir.""" + from comfy.isolation.manifest_loader import get_cache_path + + node_dir = tmp_path / "custom_nodes" / "MyNode" + venv_root = tmp_path / ".pyisolate_venvs" + + key_file, data_file = get_cache_path(node_dir, venv_root) + + assert key_file == venv_root / "MyNode" / "cache" / "cache_key" + assert data_file == venv_root / "MyNode" / "cache" / "node_info.json" + + def test_cache_not_in_custom_nodes(self, tmp_path: Path) -> None: + """Verify cache is NOT stored in custom_nodes directory.""" + from comfy.isolation.manifest_loader import get_cache_path + + node_dir = tmp_path / "custom_nodes" / "MyNode" + venv_root = tmp_path / ".pyisolate_venvs" + + key_file, data_file = get_cache_path(node_dir, venv_root) + + # Neither path should be under node_dir + assert not str(key_file).startswith(str(node_dir)) + assert not str(data_file).startswith(str(node_dir)) + + +class TestIsCacheValid: + """Tests for is_cache_valid() function.""" + + def test_false_when_no_cache_exists(self, tmp_path: Path) -> None: + """Returns False when cache files don't exist.""" + from comfy.isolation.manifest_loader import is_cache_valid + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + venv_root = tmp_path / ".pyisolate_venvs" + + assert is_cache_valid(node_dir, manifest, venv_root) is False + + def test_true_when_cache_matches(self, tmp_path: Path) -> None: + """Returns True when cache key matches current state.""" + from comfy.isolation.manifest_loader import ( + compute_cache_key, + get_cache_path, + is_cache_valid, + ) + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + (node_dir / "nodes.py").write_text("# code") + venv_root = tmp_path / ".pyisolate_venvs" + + # Create valid cache + cache_key = compute_cache_key(node_dir, manifest) + key_file, data_file = get_cache_path(node_dir, venv_root) + key_file.parent.mkdir(parents=True, exist_ok=True) + key_file.write_text(cache_key) + data_file.write_text("{}") + + assert is_cache_valid(node_dir, manifest, venv_root) is True + + def test_false_when_key_mismatch(self, tmp_path: Path) -> None: + """Returns False when stored key doesn't match current state.""" + from comfy.isolation.manifest_loader import get_cache_path, is_cache_valid + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + venv_root = tmp_path / ".pyisolate_venvs" + + # Create cache with wrong key + key_file, data_file = get_cache_path(node_dir, venv_root) + key_file.parent.mkdir(parents=True, exist_ok=True) + key_file.write_text("wrong_key_12345") + data_file.write_text("{}") + + assert is_cache_valid(node_dir, manifest, venv_root) is False + + def test_false_when_data_file_missing(self, tmp_path: Path) -> None: + """Returns False when node_info.json is missing.""" + from comfy.isolation.manifest_loader import ( + compute_cache_key, + get_cache_path, + is_cache_valid, + ) + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + venv_root = tmp_path / ".pyisolate_venvs" + + # Create only key file, not data file + cache_key = compute_cache_key(node_dir, manifest) + key_file, _ = get_cache_path(node_dir, venv_root) + key_file.parent.mkdir(parents=True, exist_ok=True) + key_file.write_text(cache_key) + + assert is_cache_valid(node_dir, manifest, venv_root) is False + + def test_invalidation_on_py_change(self, tmp_path: Path) -> None: + """Cache invalidates when .py file is modified.""" + from comfy.isolation.manifest_loader import ( + compute_cache_key, + get_cache_path, + is_cache_valid, + ) + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + py_file = node_dir / "nodes.py" + py_file.write_text("# original") + venv_root = tmp_path / ".pyisolate_venvs" + + # Create valid cache + cache_key = compute_cache_key(node_dir, manifest) + key_file, data_file = get_cache_path(node_dir, venv_root) + key_file.parent.mkdir(parents=True, exist_ok=True) + key_file.write_text(cache_key) + data_file.write_text("{}") + + # Verify cache is valid initially + assert is_cache_valid(node_dir, manifest, venv_root) is True + + # Modify .py file + time.sleep(0.01) # Ensure mtime changes + py_file.write_text("# modified") + + # Cache should now be invalid + assert is_cache_valid(node_dir, manifest, venv_root) is False + + +class TestLoadFromCache: + """Tests for load_from_cache() function.""" + + def test_returns_none_when_no_cache(self, tmp_path: Path) -> None: + """Returns None when cache doesn't exist.""" + from comfy.isolation.manifest_loader import load_from_cache + + node_dir = tmp_path / "test_node" + venv_root = tmp_path / ".pyisolate_venvs" + + assert load_from_cache(node_dir, venv_root) is None + + def test_returns_data_when_valid(self, tmp_path: Path) -> None: + """Returns cached data when file exists and is valid JSON.""" + from comfy.isolation.manifest_loader import get_cache_path, load_from_cache + + node_dir = tmp_path / "test_node" + venv_root = tmp_path / ".pyisolate_venvs" + + test_data = {"TestNode": {"inputs": [], "outputs": []}} + + _, data_file = get_cache_path(node_dir, venv_root) + data_file.parent.mkdir(parents=True, exist_ok=True) + data_file.write_text(json.dumps(test_data)) + + result = load_from_cache(node_dir, venv_root) + assert result == test_data + + def test_returns_none_on_corrupt_json(self, tmp_path: Path) -> None: + """Returns None when JSON is corrupt.""" + from comfy.isolation.manifest_loader import get_cache_path, load_from_cache + + node_dir = tmp_path / "test_node" + venv_root = tmp_path / ".pyisolate_venvs" + + _, data_file = get_cache_path(node_dir, venv_root) + data_file.parent.mkdir(parents=True, exist_ok=True) + data_file.write_text("{ corrupt json }") + + assert load_from_cache(node_dir, venv_root) is None + + def test_returns_none_on_invalid_structure(self, tmp_path: Path) -> None: + """Returns None when data is not a dict.""" + from comfy.isolation.manifest_loader import get_cache_path, load_from_cache + + node_dir = tmp_path / "test_node" + venv_root = tmp_path / ".pyisolate_venvs" + + _, data_file = get_cache_path(node_dir, venv_root) + data_file.parent.mkdir(parents=True, exist_ok=True) + data_file.write_text("[1, 2, 3]") # Array, not dict + + assert load_from_cache(node_dir, venv_root) is None + + +class TestSaveToCache: + """Tests for save_to_cache() function.""" + + def test_creates_cache_directory(self, tmp_path: Path) -> None: + """Creates cache directory if it doesn't exist.""" + from comfy.isolation.manifest_loader import get_cache_path, save_to_cache + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + venv_root = tmp_path / ".pyisolate_venvs" + + save_to_cache(node_dir, venv_root, {"TestNode": {}}, manifest) + + key_file, data_file = get_cache_path(node_dir, venv_root) + assert key_file.parent.exists() + + def test_writes_both_files(self, tmp_path: Path) -> None: + """Writes both cache_key and node_info.json.""" + from comfy.isolation.manifest_loader import get_cache_path, save_to_cache + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + venv_root = tmp_path / ".pyisolate_venvs" + + save_to_cache(node_dir, venv_root, {"TestNode": {"key": "value"}}, manifest) + + key_file, data_file = get_cache_path(node_dir, venv_root) + assert key_file.exists() + assert data_file.exists() + + def test_data_is_valid_json(self, tmp_path: Path) -> None: + """Written data can be parsed as JSON.""" + from comfy.isolation.manifest_loader import get_cache_path, save_to_cache + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + venv_root = tmp_path / ".pyisolate_venvs" + + test_data = {"TestNode": {"inputs": ["IMAGE"], "outputs": ["IMAGE"]}} + save_to_cache(node_dir, venv_root, test_data, manifest) + + _, data_file = get_cache_path(node_dir, venv_root) + loaded = json.loads(data_file.read_text()) + assert loaded == test_data + + def test_roundtrip_with_validation(self, tmp_path: Path) -> None: + """Saved cache is immediately valid.""" + from comfy.isolation.manifest_loader import ( + is_cache_valid, + load_from_cache, + save_to_cache, + ) + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + (node_dir / "nodes.py").write_text("# code") + venv_root = tmp_path / ".pyisolate_venvs" + + test_data = {"TestNode": {"foo": "bar"}} + save_to_cache(node_dir, venv_root, test_data, manifest) + + assert is_cache_valid(node_dir, manifest, venv_root) is True + assert load_from_cache(node_dir, venv_root) == test_data + + def test_cache_not_in_custom_nodes(self, tmp_path: Path) -> None: + """Verify no files written to custom_nodes directory.""" + from comfy.isolation.manifest_loader import save_to_cache + + node_dir = tmp_path / "custom_nodes" / "MyNode" + node_dir.mkdir(parents=True) + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + venv_root = tmp_path / ".pyisolate_venvs" + + save_to_cache(node_dir, venv_root, {"TestNode": {}}, manifest) + + # Check nothing was created under node_dir + for item in node_dir.iterdir(): + assert item.name == "pyisolate.yaml", f"Unexpected file in node_dir: {item}" diff --git a/tests/isolation/test_manifest_loader_discovery.py b/tests/isolation/test_manifest_loader_discovery.py new file mode 100644 index 000000000..101b5d1e2 --- /dev/null +++ b/tests/isolation/test_manifest_loader_discovery.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import importlib +import sys +from pathlib import Path +from types import ModuleType + + +def _write_manifest(path: Path, *, standalone: bool = False) -> None: + lines = [ + "[project]", + 'name = "test-node"', + 'version = "0.1.0"', + "", + "[tool.comfy.isolation]", + "can_isolate = true", + "share_torch = false", + ] + if standalone: + lines.append("standalone = true") + path.write_text("\n".join(lines) + "\n", encoding="utf-8") + + +def _load_manifest_loader(custom_nodes_root: Path): + folder_paths = ModuleType("folder_paths") + folder_paths.base_path = str(custom_nodes_root) + folder_paths.get_folder_paths = lambda kind: [str(custom_nodes_root)] if kind == "custom_nodes" else [] + sys.modules["folder_paths"] = folder_paths + + if "comfy.isolation" not in sys.modules: + iso_mod = ModuleType("comfy.isolation") + iso_mod.__path__ = [ # type: ignore[attr-defined] + str(Path(__file__).resolve().parent.parent.parent / "comfy" / "isolation") + ] + iso_mod.__package__ = "comfy.isolation" + sys.modules["comfy.isolation"] = iso_mod + + sys.modules.pop("comfy.isolation.manifest_loader", None) + + import comfy.isolation.manifest_loader as manifest_loader + + return importlib.reload(manifest_loader) + + +def test_finds_top_level_isolation_manifest(tmp_path: Path) -> None: + node_dir = tmp_path / "TopLevelNode" + node_dir.mkdir(parents=True) + _write_manifest(node_dir / "pyproject.toml") + + manifest_loader = _load_manifest_loader(tmp_path) + manifests = manifest_loader.find_manifest_directories() + + assert manifests == [(node_dir, node_dir / "pyproject.toml")] + + +def test_ignores_nested_manifest_without_standalone_flag(tmp_path: Path) -> None: + toolkit_dir = tmp_path / "ToolkitNode" + toolkit_dir.mkdir(parents=True) + _write_manifest(toolkit_dir / "pyproject.toml") + + nested_dir = toolkit_dir / "packages" / "nested_fixture" + nested_dir.mkdir(parents=True) + _write_manifest(nested_dir / "pyproject.toml", standalone=False) + + manifest_loader = _load_manifest_loader(tmp_path) + manifests = manifest_loader.find_manifest_directories() + + assert manifests == [(toolkit_dir, toolkit_dir / "pyproject.toml")] + + +def test_finds_nested_standalone_manifest(tmp_path: Path) -> None: + toolkit_dir = tmp_path / "ToolkitNode" + toolkit_dir.mkdir(parents=True) + _write_manifest(toolkit_dir / "pyproject.toml") + + nested_dir = toolkit_dir / "packages" / "uv_sealed_worker" + nested_dir.mkdir(parents=True) + _write_manifest(nested_dir / "pyproject.toml", standalone=True) + + manifest_loader = _load_manifest_loader(tmp_path) + manifests = manifest_loader.find_manifest_directories() + + assert manifests == [ + (toolkit_dir, toolkit_dir / "pyproject.toml"), + (nested_dir, nested_dir / "pyproject.toml"), + ] diff --git a/tests/isolation/test_model_management_proxy.py b/tests/isolation/test_model_management_proxy.py new file mode 100644 index 000000000..3a03bd54d --- /dev/null +++ b/tests/isolation/test_model_management_proxy.py @@ -0,0 +1,50 @@ +"""Unit tests for ModelManagementProxy.""" + +import pytest +import torch + +from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy + + +class TestModelManagementProxy: + """Test ModelManagementProxy methods.""" + + @pytest.fixture + def proxy(self): + """Create a ModelManagementProxy instance for testing.""" + return ModelManagementProxy() + + def test_get_torch_device_returns_device(self, proxy): + """Verify get_torch_device returns a torch.device object.""" + result = proxy.get_torch_device() + assert isinstance(result, torch.device), f"Expected torch.device, got {type(result)}" + + def test_get_torch_device_is_valid(self, proxy): + """Verify get_torch_device returns a valid device (cpu or cuda).""" + result = proxy.get_torch_device() + assert result.type in ("cpu", "cuda"), f"Unexpected device type: {result.type}" + + def test_get_torch_device_name_returns_string(self, proxy): + """Verify get_torch_device_name returns a non-empty string.""" + device = proxy.get_torch_device() + result = proxy.get_torch_device_name(device) + assert isinstance(result, str), f"Expected str, got {type(result)}" + assert len(result) > 0, "Device name is empty" + + def test_get_torch_device_name_with_cpu(self, proxy): + """Verify get_torch_device_name works with CPU device.""" + cpu_device = torch.device("cpu") + result = proxy.get_torch_device_name(cpu_device) + assert isinstance(result, str), f"Expected str, got {type(result)}" + assert "cpu" in result.lower(), f"Expected 'cpu' in device name, got: {result}" + + def test_get_torch_device_name_with_cuda_if_available(self, proxy): + """Verify get_torch_device_name works with CUDA device if available.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + cuda_device = torch.device("cuda:0") + result = proxy.get_torch_device_name(cuda_device) + assert isinstance(result, str), f"Expected str, got {type(result)}" + # Should contain device identifier + assert len(result) > 0, "CUDA device name is empty" diff --git a/tests/isolation/test_path_helpers.py b/tests/isolation/test_path_helpers.py new file mode 100644 index 000000000..af96f1fe0 --- /dev/null +++ b/tests/isolation/test_path_helpers.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +import json +import os +import sys +from pathlib import Path + +import pytest + +from pyisolate.path_helpers import build_child_sys_path, serialize_host_snapshot + + +def test_serialize_host_snapshot_includes_expected_keys(tmp_path: Path, monkeypatch) -> None: + output = tmp_path / "snapshot.json" + monkeypatch.setenv("EXTRA_FLAG", "1") + snapshot = serialize_host_snapshot(output_path=output, extra_env_keys=["EXTRA_FLAG"]) + + assert "sys_path" in snapshot + assert "sys_executable" in snapshot + assert "sys_prefix" in snapshot + assert "environment" in snapshot + assert output.exists() + assert snapshot["environment"].get("EXTRA_FLAG") == "1" + + persisted = json.loads(output.read_text(encoding="utf-8")) + assert persisted["sys_path"] == snapshot["sys_path"] + + +def test_build_child_sys_path_preserves_host_order() -> None: + host_paths = ["/host/root", "/host/site-packages"] + extra_paths = ["/node/.venv/lib/python3.12/site-packages"] + result = build_child_sys_path(host_paths, extra_paths, preferred_root=None) + assert result == host_paths + extra_paths + + +def test_build_child_sys_path_inserts_comfy_root_when_missing() -> None: + host_paths = ["/host/site-packages"] + comfy_root = os.environ.get("COMFYUI_ROOT") or str(Path.home() / "ComfyUI") + extra_paths: list[str] = [] + result = build_child_sys_path(host_paths, extra_paths, preferred_root=comfy_root) + assert result[0] == comfy_root + assert result[1:] == host_paths + + +def test_build_child_sys_path_deduplicates_entries(tmp_path: Path) -> None: + path_a = str(tmp_path / "a") + path_b = str(tmp_path / "b") + host_paths = [path_a, path_b] + extra_paths = [path_a, path_b, str(tmp_path / "c")] + result = build_child_sys_path(host_paths, extra_paths) + assert result == [path_a, path_b, str(tmp_path / "c")] + + +def test_build_child_sys_path_skips_duplicate_comfy_root() -> None: + comfy_root = os.environ.get("COMFYUI_ROOT") or str(Path.home() / "ComfyUI") + host_paths = [comfy_root, "/host/other"] + result = build_child_sys_path(host_paths, extra_paths=[], preferred_root=comfy_root) + assert result == host_paths + + +def test_child_import_succeeds_after_path_unification(tmp_path: Path, monkeypatch) -> None: + host_root = tmp_path / "host" + utils_pkg = host_root / "utils" + app_pkg = host_root / "app" + utils_pkg.mkdir(parents=True) + app_pkg.mkdir(parents=True) + + (utils_pkg / "__init__.py").write_text("from . import install_util\n", encoding="utf-8") + (utils_pkg / "install_util.py").write_text("VALUE = 'hello'\n", encoding="utf-8") + (app_pkg / "__init__.py").write_text("", encoding="utf-8") + (app_pkg / "frontend_management.py").write_text( + "from utils import install_util\nVALUE = install_util.VALUE\n", + encoding="utf-8", + ) + + child_only = tmp_path / "child_only" + child_only.mkdir() + + target_module = "app.frontend_management" + for name in [n for n in list(sys.modules) if n.startswith("app") or n.startswith("utils")]: + sys.modules.pop(name) + + monkeypatch.setattr(sys, "path", [str(child_only)]) + with pytest.raises(ModuleNotFoundError): + __import__(target_module) + + for name in [n for n in list(sys.modules) if n.startswith("app") or n.startswith("utils")]: + sys.modules.pop(name) + + unified = build_child_sys_path([], [], preferred_root=str(host_root)) + monkeypatch.setattr(sys, "path", unified) + module = __import__(target_module, fromlist=["VALUE"]) + assert module.VALUE == "hello" diff --git a/tests/isolation/test_runtime_helpers_stub_contract.py b/tests/isolation/test_runtime_helpers_stub_contract.py new file mode 100644 index 000000000..16e47eb06 --- /dev/null +++ b/tests/isolation/test_runtime_helpers_stub_contract.py @@ -0,0 +1,125 @@ +"""Generic runtime-helper stub contract tests.""" + +from __future__ import annotations + +import asyncio +import logging +import os +import subprocess +import sys +from pathlib import Path +from types import SimpleNamespace +from typing import Any, cast + +from comfy.isolation import runtime_helpers +from comfy_api.latest import io as latest_io +from tests.isolation.stage_internal_probe_node import PROBE_NODE_NAME, staged_probe_node + + +class _DummyExtension: + def __init__(self, *, name: str, module_path: str): + self.name = name + self.module_path = module_path + + async def execute_node(self, _node_name: str, **inputs): + return { + "__node_output__": True, + "args": (inputs,), + "ui": {"status": "ok"}, + "expand": False, + "block_execution": False, + } + + +def _install_model_serialization_stub(monkeypatch): + async def deserialize_from_isolation(payload, _extension): + return payload + + monkeypatch.setitem( + sys.modules, + "pyisolate._internal.model_serialization", + SimpleNamespace( + serialize_for_isolation=lambda payload: payload, + deserialize_from_isolation=deserialize_from_isolation, + ), + ) + + +def test_stub_sets_relative_python_module(monkeypatch): + _install_model_serialization_stub(monkeypatch) + monkeypatch.setattr(runtime_helpers, "scan_shm_forensics", lambda *args, **kwargs: None) + monkeypatch.setattr(runtime_helpers, "_relieve_host_vram_pressure", lambda *args, **kwargs: None) + + extension = _DummyExtension(name="internal_probe", module_path=os.getcwd()) + stub = cast(Any, runtime_helpers.build_stub_class( + "ProbeNode", + { + "is_v3": True, + "schema_v1": {}, + "input_types": {}, + }, + extension, + {}, + logging.getLogger("test"), + )) + + info = getattr(stub, "GET_NODE_INFO_V1")() + assert info["python_module"] == "custom_nodes.internal_probe" + + +def test_stub_ui_dispatch_roundtrip(monkeypatch): + _install_model_serialization_stub(monkeypatch) + monkeypatch.setattr(runtime_helpers, "scan_shm_forensics", lambda *args, **kwargs: None) + monkeypatch.setattr(runtime_helpers, "_relieve_host_vram_pressure", lambda *args, **kwargs: None) + + extension = _DummyExtension(name="internal_probe", module_path=os.getcwd()) + stub = runtime_helpers.build_stub_class( + "ProbeNode", + { + "is_v3": True, + "schema_v1": {"python_module": "custom_nodes.internal_probe"}, + "input_types": {}, + }, + extension, + {}, + logging.getLogger("test"), + ) + + result = asyncio.run(getattr(stub, "_pyisolate_execute")(SimpleNamespace(), token="value")) + + assert isinstance(result, latest_io.NodeOutput) + assert result.ui == {"status": "ok"} + + +def test_stub_class_types_align_with_extension(): + extension = SimpleNamespace(name="internal_probe", module_path="/sandbox/probe") + running_extensions = {"internal_probe": extension} + + specs = [ + SimpleNamespace(module_path=Path("/sandbox/probe"), node_name="ProbeImage"), + SimpleNamespace(module_path=Path("/sandbox/probe"), node_name="ProbeAudio"), + SimpleNamespace(module_path=Path("/sandbox/other"), node_name="OtherNode"), + ] + + class_types = runtime_helpers.get_class_types_for_extension( + "internal_probe", running_extensions, specs + ) + + assert class_types == {"ProbeImage", "ProbeAudio"} + + +def test_probe_stage_requires_explicit_root(): + script = Path(__file__).resolve().parent / "stage_internal_probe_node.py" + result = subprocess.run([sys.executable, str(script)], capture_output=True, text=True, check=False) + + assert result.returncode != 0 + assert "--target-root" in result.stderr + + +def test_probe_stage_cleans_up_context(): + with staged_probe_node() as module_path: + staged_root = module_path.parents[1] + assert module_path.name == PROBE_NODE_NAME + assert staged_root.exists() + + assert not staged_root.exists() diff --git a/tests/isolation/test_savedimages_serialization.py b/tests/isolation/test_savedimages_serialization.py new file mode 100644 index 000000000..f2f3df1cc --- /dev/null +++ b/tests/isolation/test_savedimages_serialization.py @@ -0,0 +1,53 @@ +import logging +import socket +import sys +from pathlib import Path + +repo_root = Path(__file__).resolve().parents[2] +pyisolate_root = repo_root.parent / "pyisolate" +if pyisolate_root.exists(): + sys.path.insert(0, str(pyisolate_root)) + +from comfy.isolation.adapter import ComfyUIAdapter +from comfy_api.latest._io import FolderType +from comfy_api.latest._ui import SavedImages, SavedResult +from pyisolate._internal.rpc_transports import JSONSocketTransport +from pyisolate._internal.serialization_registry import SerializerRegistry + + +def test_savedimages_roundtrip(caplog): + registry = SerializerRegistry.get_instance() + registry.clear() + ComfyUIAdapter().register_serializers(registry) + + payload = SavedImages( + results=[SavedResult("issue82.png", "slice2", FolderType.output)], + is_animated=True, + ) + + a, b = socket.socketpair() + sender = JSONSocketTransport(a) + receiver = JSONSocketTransport(b) + try: + with caplog.at_level(logging.WARNING, logger="pyisolate._internal.rpc_transports"): + sender.send({"ui": payload}) + result = receiver.recv() + finally: + sender.close() + receiver.close() + registry.clear() + + ui = result["ui"] + assert isinstance(ui, SavedImages) + assert ui.is_animated is True + assert len(ui.results) == 1 + assert isinstance(ui.results[0], SavedResult) + assert ui.results[0].filename == "issue82.png" + assert ui.results[0].subfolder == "slice2" + assert ui.results[0].type == FolderType.output + assert ui.as_dict() == { + "images": [SavedResult("issue82.png", "slice2", FolderType.output)], + "animated": (True,), + } + assert not any("GENERIC SERIALIZER USED" in record.message for record in caplog.records) + assert not any("GENERIC DESERIALIZER USED" in record.message for record in caplog.records) diff --git a/tests/isolation/test_sealed_worker_contract_matrix.py b/tests/isolation/test_sealed_worker_contract_matrix.py new file mode 100644 index 000000000..7395c334c --- /dev/null +++ b/tests/isolation/test_sealed_worker_contract_matrix.py @@ -0,0 +1,368 @@ +"""Generic sealed-worker loader contract matrix tests.""" + +from __future__ import annotations + +import importlib +import json +import sys +import types +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +COMFYUI_ROOT = Path(__file__).resolve().parents[2] +TEST_WORKFLOW_ROOT = COMFYUI_ROOT / "tests" / "isolation" / "workflows" +SEALED_WORKFLOW_CLASS_TYPES: dict[str, set[str]] = { + "quick_6_uv_sealed_worker.json": { + "EmptyLatentImage", + "ProxyTestSealedWorker", + "UVSealedBoltonsSlugify", + "UVSealedLatentEcho", + "UVSealedRuntimeProbe", + }, + "isolation_7_uv_sealed_worker.json": { + "EmptyLatentImage", + "ProxyTestSealedWorker", + "UVSealedBoltonsSlugify", + "UVSealedLatentEcho", + "UVSealedRuntimeProbe", + }, + "quick_8_conda_sealed_worker.json": { + "CondaSealedLatentEcho", + "CondaSealedOpenWeatherDataset", + "CondaSealedRuntimeProbe", + "EmptyLatentImage", + "ProxyTestCondaSealedWorker", + }, + "isolation_9_conda_sealed_worker.json": { + "CondaSealedLatentEcho", + "CondaSealedOpenWeatherDataset", + "CondaSealedRuntimeProbe", + "EmptyLatentImage", + "ProxyTestCondaSealedWorker", + }, +} + + +def _workflow_class_types(path: Path) -> set[str]: + payload = json.loads(path.read_text(encoding="utf-8")) + return { + node["class_type"] + for node in payload.values() + if isinstance(node, dict) and "class_type" in node + } + + +def _make_manifest( + *, + package_manager: str = "uv", + execution_model: str | None = None, + can_isolate: bool = True, + dependencies: list[str] | None = None, + share_torch: bool = False, + sealed_host_ro_paths: list[str] | None = None, +) -> dict: + isolation: dict[str, object] = { + "can_isolate": can_isolate, + } + if package_manager != "uv": + isolation["package_manager"] = package_manager + if execution_model is not None: + isolation["execution_model"] = execution_model + if share_torch: + isolation["share_torch"] = True + if sealed_host_ro_paths is not None: + isolation["sealed_host_ro_paths"] = sealed_host_ro_paths + + if package_manager == "conda": + isolation["conda_channels"] = ["conda-forge"] + isolation["conda_dependencies"] = ["numpy"] + + return { + "project": { + "name": "contract-extension", + "dependencies": dependencies or ["numpy"], + }, + "tool": {"comfy": {"isolation": isolation}}, + } + + +@pytest.fixture +def manifest_file(tmp_path: Path) -> Path: + path = tmp_path / "pyproject.toml" + path.write_bytes(b"") + return path + + +def _loader_module( + monkeypatch: pytest.MonkeyPatch, *, preload_extension_wrapper: bool +): + mock_wrapper = MagicMock() + mock_wrapper.ComfyNodeExtension = type("ComfyNodeExtension", (), {}) + + iso_mod = types.ModuleType("comfy.isolation") + iso_mod.__path__ = [ + str(Path(__file__).resolve().parent.parent.parent / "comfy" / "isolation") + ] + iso_mod.__package__ = "comfy.isolation" + + manifest_loader = types.SimpleNamespace( + is_cache_valid=lambda *args, **kwargs: False, + load_from_cache=lambda *args, **kwargs: None, + save_to_cache=lambda *args, **kwargs: None, + ) + host_policy = types.SimpleNamespace( + load_host_policy=lambda base_path: { + "sandbox_mode": "required", + "allow_network": False, + "writable_paths": [], + "readonly_paths": [], + "sealed_worker_ro_import_paths": [], + } + ) + folder_paths = types.SimpleNamespace(base_path="/fake/comfyui") + + monkeypatch.setitem(sys.modules, "comfy.isolation", iso_mod) + monkeypatch.setitem(sys.modules, "comfy.isolation.runtime_helpers", MagicMock()) + monkeypatch.setitem(sys.modules, "comfy.isolation.manifest_loader", manifest_loader) + monkeypatch.setitem(sys.modules, "comfy.isolation.host_policy", host_policy) + monkeypatch.setitem(sys.modules, "folder_paths", folder_paths) + if preload_extension_wrapper: + monkeypatch.setitem(sys.modules, "comfy.isolation.extension_wrapper", mock_wrapper) + else: + sys.modules.pop("comfy.isolation.extension_wrapper", None) + sys.modules.pop("comfy.isolation.extension_loader", None) + + module = importlib.import_module("comfy.isolation.extension_loader") + try: + yield module, mock_wrapper + finally: + sys.modules.pop("comfy.isolation.extension_loader", None) + comfy_pkg = sys.modules.get("comfy") + if comfy_pkg is not None and hasattr(comfy_pkg, "isolation"): + delattr(comfy_pkg, "isolation") + + +@pytest.fixture +def loader_module(monkeypatch: pytest.MonkeyPatch): + yield from _loader_module(monkeypatch, preload_extension_wrapper=True) + + +@pytest.fixture +def sealed_loader_module(monkeypatch: pytest.MonkeyPatch): + yield from _loader_module(monkeypatch, preload_extension_wrapper=False) + + +@pytest.fixture +def mocked_loader(loader_module): + module, mock_wrapper = loader_module + mock_ext = AsyncMock() + mock_ext.list_nodes = AsyncMock(return_value={}) + + mock_manager = MagicMock() + mock_manager.load_extension = MagicMock(return_value=mock_ext) + sealed_type = type("SealedNodeExtension", (), {}) + + with patch.object(module, "pyisolate") as mock_pi: + mock_pi.ExtensionManager = MagicMock(return_value=mock_manager) + mock_pi.SealedNodeExtension = sealed_type + yield module, mock_pi, mock_manager, sealed_type, mock_wrapper + + +@pytest.fixture +def sealed_mocked_loader(sealed_loader_module): + module, mock_wrapper = sealed_loader_module + mock_ext = AsyncMock() + mock_ext.list_nodes = AsyncMock(return_value={}) + + mock_manager = MagicMock() + mock_manager.load_extension = MagicMock(return_value=mock_ext) + sealed_type = type("SealedNodeExtension", (), {}) + + with patch.object(module, "pyisolate") as mock_pi: + mock_pi.ExtensionManager = MagicMock(return_value=mock_manager) + mock_pi.SealedNodeExtension = sealed_type + yield module, mock_pi, mock_manager, sealed_type, mock_wrapper + + +async def _load_node(module, manifest: dict, manifest_path: Path, tmp_path: Path) -> dict: + with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib: + mock_tomllib.load.return_value = manifest + await module.load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_path, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) + manager = module.pyisolate.ExtensionManager.return_value + return manager.load_extension.call_args[0][0] + + +@pytest.mark.asyncio +async def test_uv_host_coupled_default(mocked_loader, manifest_file: Path, tmp_path: Path): + module, mock_pi, _mock_manager, sealed_type, _ = mocked_loader + manifest = _make_manifest(package_manager="uv") + + config = await _load_node(module, manifest, manifest_file, tmp_path) + + extension_type = mock_pi.ExtensionManager.call_args[0][0] + assert extension_type is not sealed_type + assert "execution_model" not in config + + +@pytest.mark.asyncio +async def test_uv_sealed_worker_opt_in( + sealed_mocked_loader, manifest_file: Path, tmp_path: Path +): + module, mock_pi, _mock_manager, sealed_type, _ = sealed_mocked_loader + manifest = _make_manifest(package_manager="uv", execution_model="sealed_worker") + + config = await _load_node(module, manifest, manifest_file, tmp_path) + + extension_type = mock_pi.ExtensionManager.call_args[0][0] + assert extension_type is sealed_type + assert config["execution_model"] == "sealed_worker" + assert "apis" not in config + assert "comfy.isolation.extension_wrapper" not in sys.modules + + +@pytest.mark.asyncio +async def test_conda_defaults_to_sealed_worker( + sealed_mocked_loader, manifest_file: Path, tmp_path: Path +): + module, mock_pi, _mock_manager, sealed_type, _ = sealed_mocked_loader + manifest = _make_manifest(package_manager="conda") + + config = await _load_node(module, manifest, manifest_file, tmp_path) + + extension_type = mock_pi.ExtensionManager.call_args[0][0] + assert extension_type is sealed_type + assert config["execution_model"] == "sealed_worker" + assert config["package_manager"] == "conda" + assert "comfy.isolation.extension_wrapper" not in sys.modules + + +@pytest.mark.asyncio +async def test_conda_never_uses_comfy_extension_type( + mocked_loader, manifest_file: Path, tmp_path: Path +): + module, mock_pi, _mock_manager, sealed_type, mock_wrapper = mocked_loader + manifest = _make_manifest(package_manager="conda") + + await _load_node(module, manifest, manifest_file, tmp_path) + + extension_type = mock_pi.ExtensionManager.call_args[0][0] + assert extension_type is sealed_type + assert extension_type is not mock_wrapper.ComfyNodeExtension + + +@pytest.mark.asyncio +async def test_conda_forces_share_torch_false(mocked_loader, manifest_file: Path, tmp_path: Path): + module, _mock_pi, _mock_manager, _sealed_type, _ = mocked_loader + manifest = _make_manifest(package_manager="conda", share_torch=True) + + config = await _load_node(module, manifest, manifest_file, tmp_path) + + assert config["share_torch"] is False + + +@pytest.mark.asyncio +async def test_conda_forces_share_cuda_ipc_false( + mocked_loader, manifest_file: Path, tmp_path: Path +): + module, _mock_pi, _mock_manager, _sealed_type, _ = mocked_loader + manifest = _make_manifest(package_manager="conda", share_torch=True) + + config = await _load_node(module, manifest, manifest_file, tmp_path) + + assert config["share_cuda_ipc"] is False + + +@pytest.mark.asyncio +async def test_conda_sandbox_policy_applied(mocked_loader, manifest_file: Path, tmp_path: Path): + module, _mock_pi, _mock_manager, _sealed_type, _ = mocked_loader + manifest = _make_manifest(package_manager="conda") + + custom_policy = { + "sandbox_mode": "required", + "allow_network": True, + "writable_paths": ["/data/write"], + "readonly_paths": ["/data/read"], + } + + with patch("platform.system", return_value="Linux"): + with patch.object(module, "load_host_policy", return_value=custom_policy): + config = await _load_node(module, manifest, manifest_file, tmp_path) + + assert config["sandbox_mode"] == "required" + assert config["sandbox"] == { + "network": True, + "writable_paths": ["/data/write"], + "readonly_paths": ["/data/read"], + } + + +def test_sealed_worker_workflow_templates_present() -> None: + missing = [ + filename + for filename in SEALED_WORKFLOW_CLASS_TYPES + if not (TEST_WORKFLOW_ROOT / filename).is_file() + ] + assert not missing, f"missing sealed-worker workflow templates: {missing}" + + +@pytest.mark.parametrize( + "workflow_name,expected_class_types", + SEALED_WORKFLOW_CLASS_TYPES.items(), +) +def test_sealed_worker_workflow_class_type_contract( + workflow_name: str, expected_class_types: set[str] +) -> None: + workflow_path = TEST_WORKFLOW_ROOT / workflow_name + assert workflow_path.is_file(), f"workflow missing: {workflow_path}" + + assert _workflow_class_types(workflow_path) == expected_class_types + + +@pytest.mark.asyncio +async def test_sealed_worker_host_policy_ro_import_matrix( + mocked_loader, manifest_file: Path, tmp_path: Path +): + module, _mock_pi, _mock_manager, _sealed_type, _ = mocked_loader + manifest = _make_manifest(package_manager="uv", execution_model="sealed_worker") + + with patch.object( + module, + "load_host_policy", + return_value={ + "sandbox_mode": "required", + "allow_network": False, + "writable_paths": [], + "readonly_paths": [], + "sealed_worker_ro_import_paths": [], + }, + ): + default_config = await _load_node(module, manifest, manifest_file, tmp_path) + + with patch.object( + module, + "load_host_policy", + return_value={ + "sandbox_mode": "required", + "allow_network": False, + "writable_paths": [], + "readonly_paths": [], + "sealed_worker_ro_import_paths": ["/home/johnj/ComfyUI"], + }, + ): + opt_in_config = await _load_node(module, manifest, manifest_file, tmp_path) + + assert default_config["execution_model"] == "sealed_worker" + assert "sealed_host_ro_paths" not in default_config + + assert opt_in_config["execution_model"] == "sealed_worker" + assert opt_in_config["sealed_host_ro_paths"] == ["/home/johnj/ComfyUI"] + assert "apis" not in opt_in_config diff --git a/tests/isolation/test_shared_model_proxy_contract.py b/tests/isolation/test_shared_model_proxy_contract.py new file mode 100644 index 000000000..9e91c74f3 --- /dev/null +++ b/tests/isolation/test_shared_model_proxy_contract.py @@ -0,0 +1,44 @@ +import asyncio +import sys +from pathlib import Path + +repo_root = Path(__file__).resolve().parents[2] +pyisolate_root = repo_root.parent / "pyisolate" +if pyisolate_root.exists(): + sys.path.insert(0, str(pyisolate_root)) + +from comfy.isolation.adapter import ComfyUIAdapter +from comfy.isolation.runtime_helpers import _wrap_remote_handles_as_host_proxies +from pyisolate._internal.model_serialization import deserialize_from_isolation +from pyisolate._internal.remote_handle import RemoteObjectHandle +from pyisolate._internal.serialization_registry import SerializerRegistry + + +def test_shared_model_ksampler_contract(): + registry = SerializerRegistry.get_instance() + registry.clear() + ComfyUIAdapter().register_serializers(registry) + + handle = RemoteObjectHandle("model_0", "ModelPatcher") + + class FakeExtension: + async def call_remote_object_method(self, object_id, method_name, *args, **kwargs): + assert object_id == "model_0" + assert method_name == "get_model_object" + assert args == ("latent_format",) + assert kwargs == {} + return "resolved:latent_format" + + wrapped = (handle,) + assert isinstance(wrapped, tuple) + assert isinstance(wrapped[0], RemoteObjectHandle) + + deserialized = asyncio.run(deserialize_from_isolation(wrapped)) + proxied = _wrap_remote_handles_as_host_proxies(deserialized, FakeExtension()) + model_for_host = proxied[0] + + assert not isinstance(model_for_host, RemoteObjectHandle) + assert hasattr(model_for_host, "get_model_object") + assert model_for_host.get_model_object("latent_format") == "resolved:latent_format" + + registry.clear() diff --git a/tests/isolation/test_singleton_proxy_boundary_matrix.py b/tests/isolation/test_singleton_proxy_boundary_matrix.py new file mode 100644 index 000000000..31cc86e04 --- /dev/null +++ b/tests/isolation/test_singleton_proxy_boundary_matrix.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +import json + +from tests.isolation.singleton_boundary_helpers import ( + capture_minimal_sealed_worker_imports, + capture_sealed_singleton_imports, +) + + +def test_minimal_sealed_worker_forbidden_imports() -> None: + payload = capture_minimal_sealed_worker_imports() + + assert payload["mode"] == "minimal_sealed_worker" + assert payload["runtime_probe_function"] == "inspect" + assert payload["forbidden_matches"] == [] + + +def test_torch_share_subset_scope() -> None: + minimal = capture_minimal_sealed_worker_imports() + + allowed_torch_share_only = { + "torch", + "folder_paths", + "comfy.utils", + "comfy.model_management", + "main", + "comfy.isolation.extension_wrapper", + } + + assert minimal["forbidden_matches"] == [] + assert all( + module_name not in minimal["modules"] for module_name in sorted(allowed_torch_share_only) + ) + + +def test_capture_payload_is_json_serializable() -> None: + payload = capture_minimal_sealed_worker_imports() + + encoded = json.dumps(payload, sort_keys=True) + + assert "\"minimal_sealed_worker\"" in encoded + + +def test_folder_paths_child_safe() -> None: + payload = capture_sealed_singleton_imports() + + assert payload["mode"] == "sealed_singletons" + assert payload["folder_path"] == "/sandbox/input/demo.png" + assert payload["temp_dir"] == "/sandbox/temp" + assert payload["models_dir"] == "/sandbox/models" + assert payload["forbidden_matches"] == [] + + +def test_utils_child_safe() -> None: + payload = capture_sealed_singleton_imports() + + progress_calls = [ + call + for call in payload["rpc_calls"] + if call["object_id"] == "UtilsProxy" and call["method"] == "progress_bar_hook" + ] + + assert progress_calls + assert payload["forbidden_matches"] == [] + + +def test_progress_child_safe() -> None: + payload = capture_sealed_singleton_imports() + + progress_calls = [ + call + for call in payload["rpc_calls"] + if call["object_id"] == "ProgressProxy" and call["method"] == "rpc_set_progress" + ] + + assert progress_calls + assert payload["forbidden_matches"] == [] diff --git a/tests/isolation/test_web_directory_handler.py b/tests/isolation/test_web_directory_handler.py new file mode 100644 index 000000000..f50e01977 --- /dev/null +++ b/tests/isolation/test_web_directory_handler.py @@ -0,0 +1,129 @@ +"""Tests for WebDirectoryProxy host-side cache and aiohttp handler integration.""" + +from __future__ import annotations + +import base64 +import sys +from unittest.mock import MagicMock + +import pytest + +from comfy.isolation.proxies.web_directory_proxy import ( + ALLOWED_EXTENSIONS, + WebDirectoryCache, +) + + +@pytest.fixture() +def mock_proxy() -> MagicMock: + """Create a mock WebDirectoryProxy RPC proxy.""" + proxy = MagicMock() + proxy.list_web_files.return_value = [ + {"relative_path": "js/app.js", "content_type": "application/javascript"}, + {"relative_path": "js/utils.js", "content_type": "application/javascript"}, + {"relative_path": "index.html", "content_type": "text/html"}, + {"relative_path": "style.css", "content_type": "text/css"}, + ] + proxy.get_web_file.return_value = { + "content": base64.b64encode(b"console.log('hello');").decode("ascii"), + "content_type": "application/javascript", + } + return proxy + + +@pytest.fixture() +def cache_with_proxy(mock_proxy: MagicMock) -> WebDirectoryCache: + """Create a WebDirectoryCache with a registered mock proxy.""" + cache = WebDirectoryCache() + cache.register_proxy("test-extension", mock_proxy) + return cache + + +class TestExtensionsListing: + """AC-2: /extensions endpoint lists proxied JS files in URL format.""" + + def test_extensions_listing_produces_url_format_paths( + self, cache_with_proxy: WebDirectoryCache + ) -> None: + """Simulate what server.py does: build /extensions/{name}/{path} URLs.""" + import urllib.parse + + ext_name = "test-extension" + urls = [] + for entry in cache_with_proxy.list_files(ext_name): + if entry["relative_path"].endswith(".js"): + urls.append( + "/extensions/" + urllib.parse.quote(ext_name) + + "/" + entry["relative_path"] + ) + + # Emit the actual URL list so it appears in test log output. + sys.stdout.write(f"\n--- Proxied JS URLs ({len(urls)}) ---\n") + for url in urls: + sys.stdout.write(f" {url}\n") + sys.stdout.write("--- End URLs ---\n") + + # At least one proxied JS URL in /extensions/{name}/{path} format + assert len(urls) >= 1, f"Expected >= 1 proxied JS URL, got {len(urls)}" + assert "/extensions/test-extension/js/app.js" in urls, ( + f"Expected /extensions/test-extension/js/app.js in {urls}" + ) + + +class TestCacheHit: + """AC-3: Cache populated on first request, reused on second.""" + + def test_cache_hit_single_rpc_call( + self, cache_with_proxy: WebDirectoryCache, mock_proxy: MagicMock + ) -> None: + # First call — RPC + result1 = cache_with_proxy.get_file("test-extension", "js/app.js") + assert result1 is not None + assert result1["content"] == b"console.log('hello');" + + # Second call — cache hit + result2 = cache_with_proxy.get_file("test-extension", "js/app.js") + assert result2 is not None + assert result2["content"] == b"console.log('hello');" + + # Proxy was called exactly once + assert mock_proxy.get_web_file.call_count == 1 + + def test_cache_returns_none_for_unknown_extension( + self, cache_with_proxy: WebDirectoryCache + ) -> None: + result = cache_with_proxy.get_file("nonexistent", "js/app.js") + assert result is None + + +class TestForbiddenType: + """AC-4: Disallowed file types return HTTP 403 Forbidden.""" + + @pytest.mark.parametrize( + "disallowed_path,expected_status", + [ + ("backdoor.py", 403), + ("malware.exe", 403), + ("exploit.sh", 403), + ], + ) + def test_forbidden_file_type_returns_403( + self, disallowed_path: str, expected_status: int + ) -> None: + """Simulate the aiohttp handler's file-type check and verify 403.""" + import os + suffix = os.path.splitext(disallowed_path)[1].lower() + + # This mirrors the handler logic in server.py: + # if suffix not in ALLOWED_EXTENSIONS: return web.Response(status=403) + if suffix not in ALLOWED_EXTENSIONS: + status = 403 + else: + status = 200 + + sys.stdout.write( + f"\n--- HTTP status for {disallowed_path} (suffix={suffix}): {status} ---\n" + ) + assert status == expected_status, ( + f"Expected HTTP {expected_status} for {disallowed_path}, got {status}" + ) diff --git a/tests/isolation/test_web_directory_proxy.py b/tests/isolation/test_web_directory_proxy.py new file mode 100644 index 000000000..2922da92d --- /dev/null +++ b/tests/isolation/test_web_directory_proxy.py @@ -0,0 +1,130 @@ +"""Tests for WebDirectoryProxy — allow-list, traversal prevention, content serving.""" + +from __future__ import annotations + +import base64 +from pathlib import Path + +import pytest + +from comfy.isolation.proxies.web_directory_proxy import WebDirectoryProxy + + +@pytest.fixture() +def web_dir_with_mixed_files(tmp_path: Path) -> Path: + """Create a temp web directory with allowed and disallowed file types.""" + web = tmp_path / "web" + js_dir = web / "js" + js_dir.mkdir(parents=True) + + # Allowed types + (js_dir / "app.js").write_text("console.log('hello');") + (web / "index.html").write_text("") + (web / "style.css").write_text("body { margin: 0; }") + + # Disallowed types + (web / "backdoor.py").write_text("import os; os.system('rm -rf /')") + (web / "malware.exe").write_bytes(b"\x00" * 16) + (web / "exploit.sh").write_text("#!/bin/bash\nrm -rf /") + + return web + + +@pytest.fixture() +def proxy_with_web_dir(web_dir_with_mixed_files: Path) -> WebDirectoryProxy: + """Create a WebDirectoryProxy with a registered test web directory.""" + proxy = WebDirectoryProxy() + # Clear class-level state to avoid cross-test pollution + WebDirectoryProxy._web_dirs = {} + WebDirectoryProxy.register_web_dir("test-extension", str(web_dir_with_mixed_files)) + return proxy + + +class TestAllowList: + """AC-2: list_web_files returns only allowed file types.""" + + def test_allowlist_only_safe_types( + self, proxy_with_web_dir: WebDirectoryProxy + ) -> None: + files = proxy_with_web_dir.list_web_files("test-extension") + extensions = {Path(f["relative_path"]).suffix for f in files} + + # Only .js, .html, .css should appear + assert extensions == {".js", ".html", ".css"} + + def test_allowlist_excludes_dangerous_types( + self, proxy_with_web_dir: WebDirectoryProxy + ) -> None: + files = proxy_with_web_dir.list_web_files("test-extension") + paths = [f["relative_path"] for f in files] + + assert not any(p.endswith(".py") for p in paths) + assert not any(p.endswith(".exe") for p in paths) + assert not any(p.endswith(".sh") for p in paths) + + def test_allowlist_correct_count( + self, proxy_with_web_dir: WebDirectoryProxy + ) -> None: + files = proxy_with_web_dir.list_web_files("test-extension") + # 3 allowed files: app.js, index.html, style.css + assert len(files) == 3 + + def test_allowlist_unknown_extension_returns_empty( + self, proxy_with_web_dir: WebDirectoryProxy + ) -> None: + files = proxy_with_web_dir.list_web_files("nonexistent-extension") + assert files == [] + + +class TestTraversal: + """AC-3: get_web_file rejects directory traversal attempts.""" + + @pytest.mark.parametrize( + "malicious_path", + [ + "../../../etc/passwd", + "/etc/passwd", + "../../__init__.py", + ], + ) + def test_traversal_rejected( + self, proxy_with_web_dir: WebDirectoryProxy, malicious_path: str + ) -> None: + with pytest.raises(ValueError): + proxy_with_web_dir.get_web_file("test-extension", malicious_path) + + +class TestContent: + """AC-4: get_web_file returns base64 content with correct MIME types.""" + + def test_content_js_mime_type( + self, proxy_with_web_dir: WebDirectoryProxy + ) -> None: + result = proxy_with_web_dir.get_web_file("test-extension", "js/app.js") + assert result["content_type"] == "application/javascript" + + def test_content_html_mime_type( + self, proxy_with_web_dir: WebDirectoryProxy + ) -> None: + result = proxy_with_web_dir.get_web_file("test-extension", "index.html") + assert result["content_type"] == "text/html" + + def test_content_css_mime_type( + self, proxy_with_web_dir: WebDirectoryProxy + ) -> None: + result = proxy_with_web_dir.get_web_file("test-extension", "style.css") + assert result["content_type"] == "text/css" + + def test_content_base64_roundtrip( + self, proxy_with_web_dir: WebDirectoryProxy, web_dir_with_mixed_files: Path + ) -> None: + result = proxy_with_web_dir.get_web_file("test-extension", "js/app.js") + decoded = base64.b64decode(result["content"]) + source = (web_dir_with_mixed_files / "js" / "app.js").read_bytes() + assert decoded == source + + def test_content_disallowed_type_rejected( + self, proxy_with_web_dir: WebDirectoryProxy + ) -> None: + with pytest.raises(ValueError, match="Disallowed file type"): + proxy_with_web_dir.get_web_file("test-extension", "backdoor.py") diff --git a/tests/isolation/uv_sealed_worker/__init__.py b/tests/isolation/uv_sealed_worker/__init__.py new file mode 100644 index 000000000..453915a93 --- /dev/null +++ b/tests/isolation/uv_sealed_worker/__init__.py @@ -0,0 +1,230 @@ +# pylint: disable=import-outside-toplevel,import-error +from __future__ import annotations + +import os +import sys +import logging +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + + +def _artifact_dir() -> Path | None: + raw = os.environ.get("PYISOLATE_ARTIFACT_DIR") + if not raw: + return None + path = Path(raw) + path.mkdir(parents=True, exist_ok=True) + return path + + +def _write_artifact(name: str, content: str) -> None: + artifact_dir = _artifact_dir() + if artifact_dir is None: + return + (artifact_dir / name).write_text(content, encoding="utf-8") + + +def _contains_tensor_marker(value: Any) -> bool: + if isinstance(value, dict): + if value.get("__type__") == "TensorValue": + return True + return any(_contains_tensor_marker(v) for v in value.values()) + if isinstance(value, (list, tuple)): + return any(_contains_tensor_marker(v) for v in value) + return False + + +class InspectRuntimeNode: + RETURN_TYPES = ( + "STRING", + "STRING", + "BOOLEAN", + "BOOLEAN", + "STRING", + "STRING", + "BOOLEAN", + ) + RETURN_NAMES = ( + "path_dump", + "boltons_origin", + "saw_comfy_root", + "imported_comfy_wrapper", + "comfy_module_dump", + "report", + "saw_user_site", + ) + FUNCTION = "inspect" + CATEGORY = "PyIsolated/SealedWorker" + + @classmethod + def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802 + return {"required": {}} + + def inspect(self) -> tuple[str, str, bool, bool, str, str, bool]: + import boltons + + path_dump = "\n".join(sys.path) + comfy_root = "/home/johnj/ComfyUI" + saw_comfy_root = any( + entry == comfy_root + or entry.startswith(f"{comfy_root}/comfy") + or entry.startswith(f"{comfy_root}/.venv") + for entry in sys.path + ) + imported_comfy_wrapper = "comfy.isolation.extension_wrapper" in sys.modules + comfy_module_dump = "\n".join( + sorted(name for name in sys.modules if name.startswith("comfy")) + ) + saw_user_site = any("/.local/lib/" in entry for entry in sys.path) + boltons_origin = getattr(boltons, "__file__", "") + + report_lines = [ + "UV sealed worker runtime probe", + f"boltons_origin={boltons_origin}", + f"saw_comfy_root={saw_comfy_root}", + f"imported_comfy_wrapper={imported_comfy_wrapper}", + f"saw_user_site={saw_user_site}", + ] + report = "\n".join(report_lines) + + _write_artifact("child_bootstrap_paths.txt", path_dump) + _write_artifact("child_import_trace.txt", comfy_module_dump) + _write_artifact("child_dependency_dump.txt", boltons_origin) + logger.warning("][ UV sealed runtime probe executed") + logger.warning("][ boltons origin: %s", boltons_origin) + + return ( + path_dump, + boltons_origin, + saw_comfy_root, + imported_comfy_wrapper, + comfy_module_dump, + report, + saw_user_site, + ) + + +class BoltonsSlugifyNode: + RETURN_TYPES = ("STRING", "STRING") + RETURN_NAMES = ("slug", "boltons_origin") + FUNCTION = "slugify_text" + CATEGORY = "PyIsolated/SealedWorker" + + @classmethod + def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802 + return {"required": {"text": ("STRING", {"default": "Sealed Worker Rocks"})}} + + def slugify_text(self, text: str) -> tuple[str, str]: + import boltons + from boltons.strutils import slugify + + slug = slugify(text) + origin = getattr(boltons, "__file__", "") + logger.warning("][ boltons slugify: %r -> %r", text, slug) + return slug, origin + + +class FilesystemBarrierNode: + RETURN_TYPES = ("STRING", "BOOLEAN", "BOOLEAN", "BOOLEAN") + RETURN_NAMES = ( + "report", + "outside_blocked", + "module_mutation_blocked", + "artifact_write_ok", + ) + FUNCTION = "probe" + CATEGORY = "PyIsolated/SealedWorker" + + @classmethod + def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802 + return {"required": {}} + + def probe(self) -> tuple[str, bool, bool, bool]: + artifact_dir = _artifact_dir() + artifact_write_ok = False + if artifact_dir is not None: + probe_path = artifact_dir / "filesystem_barrier_probe.txt" + probe_path.write_text("artifact write ok\n", encoding="utf-8") + artifact_write_ok = probe_path.exists() + + module_target = Path(__file__).with_name( + "mutated_from_child_should_not_exist.txt" + ) + module_mutation_blocked = False + try: + module_target.write_text("mutation should fail\n", encoding="utf-8") + except Exception: + module_mutation_blocked = True + else: + module_target.unlink(missing_ok=True) + + outside_target = Path("/home/johnj/mysolate/.uv_sealed_worker_escape_probe") + outside_blocked = False + try: + outside_target.write_text("escape should fail\n", encoding="utf-8") + except Exception: + outside_blocked = True + else: + outside_target.unlink(missing_ok=True) + + report_lines = [ + "UV sealed worker filesystem barrier probe", + f"artifact_write_ok={artifact_write_ok}", + f"module_mutation_blocked={module_mutation_blocked}", + f"outside_blocked={outside_blocked}", + ] + report = "\n".join(report_lines) + _write_artifact("filesystem_barrier_report.txt", report) + logger.warning("][ filesystem barrier probe executed") + return report, outside_blocked, module_mutation_blocked, artifact_write_ok + + +class EchoTensorNode: + RETURN_TYPES = ("TENSOR", "BOOLEAN") + RETURN_NAMES = ("tensor", "saw_json_tensor") + FUNCTION = "echo" + CATEGORY = "PyIsolated/SealedWorker" + + @classmethod + def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802 + return {"required": {"tensor": ("TENSOR",)}} + + def echo(self, tensor: Any) -> tuple[Any, bool]: + saw_json_tensor = _contains_tensor_marker(tensor) + logger.warning("][ tensor echo json_marker=%s", saw_json_tensor) + return tensor, saw_json_tensor + + +class EchoLatentNode: + RETURN_TYPES = ("LATENT", "BOOLEAN") + RETURN_NAMES = ("latent", "saw_json_tensor") + FUNCTION = "echo_latent" + CATEGORY = "PyIsolated/SealedWorker" + + @classmethod + def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802 + return {"required": {"latent": ("LATENT",)}} + + def echo_latent(self, latent: Any) -> tuple[Any, bool]: + saw_json_tensor = _contains_tensor_marker(latent) + logger.warning("][ latent echo json_marker=%s", saw_json_tensor) + return latent, saw_json_tensor + + +NODE_CLASS_MAPPINGS = { + "UVSealedRuntimeProbe": InspectRuntimeNode, + "UVSealedBoltonsSlugify": BoltonsSlugifyNode, + "UVSealedFilesystemBarrier": FilesystemBarrierNode, + "UVSealedTensorEcho": EchoTensorNode, + "UVSealedLatentEcho": EchoLatentNode, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "UVSealedRuntimeProbe": "UV Sealed Runtime Probe", + "UVSealedBoltonsSlugify": "UV Sealed Boltons Slugify", + "UVSealedFilesystemBarrier": "UV Sealed Filesystem Barrier", + "UVSealedTensorEcho": "UV Sealed Tensor Echo", + "UVSealedLatentEcho": "UV Sealed Latent Echo", +} diff --git a/tests/isolation/uv_sealed_worker/pyproject.toml b/tests/isolation/uv_sealed_worker/pyproject.toml new file mode 100644 index 000000000..f50d21eb3 --- /dev/null +++ b/tests/isolation/uv_sealed_worker/pyproject.toml @@ -0,0 +1,11 @@ +[project] +name = "comfyui-toolkit-uv-sealed-worker" +version = "0.1.0" +dependencies = ["boltons"] + +[tool.comfy.isolation] +can_isolate = true +share_torch = false +package_manager = "uv" +execution_model = "sealed_worker" +standalone = true diff --git a/tests/isolation/workflows/internal_probe_preview_image_audio.json b/tests/isolation/workflows/internal_probe_preview_image_audio.json new file mode 100644 index 000000000..69f5f0d2f --- /dev/null +++ b/tests/isolation/workflows/internal_probe_preview_image_audio.json @@ -0,0 +1,10 @@ +{ + "1": { + "class_type": "InternalIsolationProbeImage", + "inputs": {} + }, + "2": { + "class_type": "InternalIsolationProbeAudio", + "inputs": {} + } +} diff --git a/tests/isolation/workflows/internal_probe_ui3d.json b/tests/isolation/workflows/internal_probe_ui3d.json new file mode 100644 index 000000000..fea2dc3e7 --- /dev/null +++ b/tests/isolation/workflows/internal_probe_ui3d.json @@ -0,0 +1,6 @@ +{ + "1": { + "class_type": "InternalIsolationProbeUI3D", + "inputs": {} + } +} diff --git a/tests/isolation/workflows/isolation_7_uv_sealed_worker.json b/tests/isolation/workflows/isolation_7_uv_sealed_worker.json new file mode 100644 index 000000000..3b83fa0db --- /dev/null +++ b/tests/isolation/workflows/isolation_7_uv_sealed_worker.json @@ -0,0 +1,22 @@ +{ + "1": { + "class_type": "EmptyLatentImage", + "inputs": {} + }, + "2": { + "class_type": "ProxyTestSealedWorker", + "inputs": {} + }, + "3": { + "class_type": "UVSealedBoltonsSlugify", + "inputs": {} + }, + "4": { + "class_type": "UVSealedLatentEcho", + "inputs": {} + }, + "5": { + "class_type": "UVSealedRuntimeProbe", + "inputs": {} + } +} diff --git a/tests/isolation/workflows/isolation_9_conda_sealed_worker.json b/tests/isolation/workflows/isolation_9_conda_sealed_worker.json new file mode 100644 index 000000000..acfa2e59b --- /dev/null +++ b/tests/isolation/workflows/isolation_9_conda_sealed_worker.json @@ -0,0 +1,22 @@ +{ + "1": { + "class_type": "CondaSealedLatentEcho", + "inputs": {} + }, + "2": { + "class_type": "CondaSealedOpenWeatherDataset", + "inputs": {} + }, + "3": { + "class_type": "CondaSealedRuntimeProbe", + "inputs": {} + }, + "4": { + "class_type": "EmptyLatentImage", + "inputs": {} + }, + "5": { + "class_type": "ProxyTestCondaSealedWorker", + "inputs": {} + } +} diff --git a/tests/isolation/workflows/quick_6_uv_sealed_worker.json b/tests/isolation/workflows/quick_6_uv_sealed_worker.json new file mode 100644 index 000000000..3b83fa0db --- /dev/null +++ b/tests/isolation/workflows/quick_6_uv_sealed_worker.json @@ -0,0 +1,22 @@ +{ + "1": { + "class_type": "EmptyLatentImage", + "inputs": {} + }, + "2": { + "class_type": "ProxyTestSealedWorker", + "inputs": {} + }, + "3": { + "class_type": "UVSealedBoltonsSlugify", + "inputs": {} + }, + "4": { + "class_type": "UVSealedLatentEcho", + "inputs": {} + }, + "5": { + "class_type": "UVSealedRuntimeProbe", + "inputs": {} + } +} diff --git a/tests/isolation/workflows/quick_8_conda_sealed_worker.json b/tests/isolation/workflows/quick_8_conda_sealed_worker.json new file mode 100644 index 000000000..acfa2e59b --- /dev/null +++ b/tests/isolation/workflows/quick_8_conda_sealed_worker.json @@ -0,0 +1,22 @@ +{ + "1": { + "class_type": "CondaSealedLatentEcho", + "inputs": {} + }, + "2": { + "class_type": "CondaSealedOpenWeatherDataset", + "inputs": {} + }, + "3": { + "class_type": "CondaSealedRuntimeProbe", + "inputs": {} + }, + "4": { + "class_type": "EmptyLatentImage", + "inputs": {} + }, + "5": { + "class_type": "ProxyTestCondaSealedWorker", + "inputs": {} + } +} diff --git a/tests/test_adapter.py b/tests/test_adapter.py new file mode 100644 index 000000000..298bc53f6 --- /dev/null +++ b/tests/test_adapter.py @@ -0,0 +1,124 @@ +import os +import subprocess +import sys +import textwrap +import types +from pathlib import Path + +repo_root = Path(__file__).resolve().parents[1] +pyisolate_root = repo_root.parent / "pyisolate" +if pyisolate_root.exists(): + sys.path.insert(0, str(pyisolate_root)) + +from comfy.isolation.adapter import ComfyUIAdapter +from pyisolate._internal.sandbox import build_bwrap_command +from pyisolate._internal.sandbox_detect import RestrictionModel +from pyisolate._internal.serialization_registry import SerializerRegistry + + +def test_identifier(): + adapter = ComfyUIAdapter() + assert adapter.identifier == "comfyui" + + +def test_get_path_config_valid(): + adapter = ComfyUIAdapter() + path = os.path.join("/opt", "ComfyUI", "custom_nodes", "demo") + cfg = adapter.get_path_config(path) + assert cfg is not None + assert cfg["preferred_root"].endswith("ComfyUI") + assert "custom_nodes" in cfg["additional_paths"][0] + + +def test_get_path_config_invalid(): + adapter = ComfyUIAdapter() + assert adapter.get_path_config("/random/path") is None + + +def test_provide_rpc_services(): + adapter = ComfyUIAdapter() + services = adapter.provide_rpc_services() + names = {s.__name__ for s in services} + assert "PromptServerService" in names + assert "FolderPathsProxy" in names + + +def test_register_serializers(): + adapter = ComfyUIAdapter() + registry = SerializerRegistry.get_instance() + registry.clear() + + adapter.register_serializers(registry) + assert registry.has_handler("ModelPatcher") + assert registry.has_handler("CLIP") + assert registry.has_handler("VAE") + + registry.clear() + + +def test_child_temp_directory_fence_uses_private_tmp(tmp_path): + adapter = ComfyUIAdapter() + child_script = textwrap.dedent( + """ + from pathlib import Path + + child_temp = Path("/tmp/comfyui_temp") + child_temp.mkdir(parents=True, exist_ok=True) + scratch = child_temp / "child_only.txt" + scratch.write_text("child-only", encoding="utf-8") + print(f"CHILD_TEMP={child_temp}") + print(f"CHILD_FILE={scratch}") + """ + ) + fake_folder_paths = types.SimpleNamespace( + temp_directory="/host/tmp/should_not_survive", + folder_names_and_paths={}, + extension_mimetypes_cache={}, + filename_list_cache={}, + ) + + class FolderPathsProxy: + def get_temp_directory(self): + return "/host/tmp/should_not_survive" + + original_folder_paths = sys.modules.get("folder_paths") + sys.modules["folder_paths"] = fake_folder_paths + try: + os.environ["PYISOLATE_CHILD"] = "1" + adapter.handle_api_registration(FolderPathsProxy, rpc=None) + finally: + os.environ.pop("PYISOLATE_CHILD", None) + if original_folder_paths is not None: + sys.modules["folder_paths"] = original_folder_paths + else: + sys.modules.pop("folder_paths", None) + + import tempfile as _tf + expected_temp = os.path.join(_tf.gettempdir(), "comfyui_temp") + assert fake_folder_paths.temp_directory == expected_temp + + host_child_file = Path(expected_temp) / "child_only.txt" + if host_child_file.exists(): + host_child_file.unlink() + + cmd = build_bwrap_command( + python_exe=sys.executable, + module_path=str(repo_root / "custom_nodes" / "ComfyUI-IsolationToolkit"), + venv_path=str(repo_root / ".venv"), + uds_address=str(tmp_path / "adapter.sock"), + allow_gpu=False, + restriction_model=RestrictionModel.NONE, + sandbox_config={"writable_paths": ["/dev/shm"], "readonly_paths": [], "network": False}, + adapter=adapter, + ) + assert "--tmpfs" in cmd and "/tmp" in cmd + assert ["--bind", "/tmp", "/tmp"] not in [cmd[i : i + 3] for i in range(len(cmd) - 2)] + + command_tail = cmd[-3:] + assert command_tail[1:] == ["-m", "pyisolate._internal.uds_client"] + cmd = cmd[:-3] + [sys.executable, "-c", child_script] + + completed = subprocess.run(cmd, check=True, capture_output=True, text=True) + + assert "CHILD_TEMP=/tmp/comfyui_temp" in completed.stdout + assert not host_child_file.exists(), "Child temp file leaked into host /tmp"