diff --git a/tests/isolation/conda_sealed_worker/__init__.py b/tests/isolation/conda_sealed_worker/__init__.py new file mode 100644 index 000000000..0208a4bde --- /dev/null +++ b/tests/isolation/conda_sealed_worker/__init__.py @@ -0,0 +1,209 @@ +# pylint: disable=import-outside-toplevel,import-error +from __future__ import annotations + +import logging +import os +import sys +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + + +def _artifact_dir() -> Path | None: + raw = os.environ.get("PYISOLATE_ARTIFACT_DIR") + if not raw: + return None + path = Path(raw) + path.mkdir(parents=True, exist_ok=True) + return path + + +def _write_artifact(name: str, content: str) -> None: + artifact_dir = _artifact_dir() + if artifact_dir is None: + return + (artifact_dir / name).write_text(content, encoding="utf-8") + + +def _contains_tensor_marker(value: Any) -> bool: + if isinstance(value, dict): + if value.get("__type__") == "TensorValue": + return True + return any(_contains_tensor_marker(v) for v in value.values()) + if isinstance(value, (list, tuple)): + return any(_contains_tensor_marker(v) for v in value) + return False + + +class InspectRuntimeNode: + RETURN_TYPES = ( + "STRING", + "STRING", + "BOOLEAN", + "BOOLEAN", + "STRING", + "STRING", + "BOOLEAN", + ) + RETURN_NAMES = ( + "path_dump", + "runtime_report", + "saw_comfy_root", + "imported_comfy_wrapper", + "comfy_module_dump", + "python_exe", + "saw_user_site", + ) + FUNCTION = "inspect" + CATEGORY = "PyIsolated/SealedWorker" + + @classmethod + def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802 + return {"required": {}} + + def inspect(self) -> tuple[str, str, bool, bool, str, str, bool]: + import cfgrib + import eccodes + import xarray as xr + + path_dump = "\n".join(sys.path) + comfy_root = "/home/johnj/ComfyUI" + saw_comfy_root = any( + entry == comfy_root + or entry.startswith(f"{comfy_root}/comfy") + or entry.startswith(f"{comfy_root}/.venv") + for entry in sys.path + ) + imported_comfy_wrapper = "comfy.isolation.extension_wrapper" in sys.modules + comfy_module_dump = "\n".join( + sorted(name for name in sys.modules if name.startswith("comfy")) + ) + saw_user_site = any("/.local/lib/" in entry for entry in sys.path) + python_exe = sys.executable + + runtime_lines = [ + "Conda sealed worker runtime probe", + f"python_exe={python_exe}", + f"xarray_origin={getattr(xr, '__file__', '')}", + f"cfgrib_origin={getattr(cfgrib, '__file__', '')}", + f"eccodes_origin={getattr(eccodes, '__file__', '')}", + f"saw_comfy_root={saw_comfy_root}", + f"imported_comfy_wrapper={imported_comfy_wrapper}", + f"saw_user_site={saw_user_site}", + ] + runtime_report = "\n".join(runtime_lines) + + _write_artifact("child_bootstrap_paths.txt", path_dump) + _write_artifact("child_import_trace.txt", comfy_module_dump) + _write_artifact("child_dependency_dump.txt", runtime_report) + logger.warning("][ Conda sealed runtime probe executed") + logger.warning("][ conda python executable: %s", python_exe) + logger.warning( + "][ conda dependency origins: xarray=%s cfgrib=%s eccodes=%s", + getattr(xr, "__file__", ""), + getattr(cfgrib, "__file__", ""), + getattr(eccodes, "__file__", ""), + ) + + return ( + path_dump, + runtime_report, + saw_comfy_root, + imported_comfy_wrapper, + comfy_module_dump, + python_exe, + saw_user_site, + ) + + +class OpenWeatherDatasetNode: + RETURN_TYPES = ("FLOAT", "STRING", "STRING") + RETURN_NAMES = ("sum_value", "grib_path", "dependency_report") + FUNCTION = "open_dataset" + CATEGORY = "PyIsolated/SealedWorker" + + @classmethod + def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802 + return {"required": {}} + + def open_dataset(self) -> tuple[float, str, str]: + import eccodes + import xarray as xr + + artifact_dir = _artifact_dir() + if artifact_dir is None: + artifact_dir = Path(os.environ.get("HOME", ".")) / "pyisolate_artifacts" + artifact_dir.mkdir(parents=True, exist_ok=True) + + grib_path = artifact_dir / "toolkit_weather_fixture.grib2" + + gid = eccodes.codes_grib_new_from_samples("GRIB2") + for key, value in [ + ("gridType", "regular_ll"), + ("Nx", 2), + ("Ny", 2), + ("latitudeOfFirstGridPointInDegrees", 1.0), + ("longitudeOfFirstGridPointInDegrees", 0.0), + ("latitudeOfLastGridPointInDegrees", 0.0), + ("longitudeOfLastGridPointInDegrees", 1.0), + ("iDirectionIncrementInDegrees", 1.0), + ("jDirectionIncrementInDegrees", 1.0), + ("jScansPositively", 0), + ("shortName", "t"), + ("typeOfLevel", "surface"), + ("level", 0), + ("date", 20260315), + ("time", 0), + ("step", 0), + ]: + eccodes.codes_set(gid, key, value) + + eccodes.codes_set_values(gid, [1.0, 2.0, 3.0, 4.0]) + with grib_path.open("wb") as handle: + eccodes.codes_write(gid, handle) + eccodes.codes_release(gid) + + dataset = xr.open_dataset(grib_path, engine="cfgrib") + sum_value = float(dataset["t"].sum().item()) + dependency_report = "\n".join( + [ + f"dataset_sum={sum_value}", + f"grib_path={grib_path}", + "xarray_engine=cfgrib", + ] + ) + _write_artifact("weather_dependency_report.txt", dependency_report) + logger.warning("][ cfgrib import ok") + logger.warning("][ xarray open_dataset engine=cfgrib path=%s", grib_path) + logger.warning("][ conda weather dataset sum=%s", sum_value) + return sum_value, str(grib_path), dependency_report + + +class EchoLatentNode: + RETURN_TYPES = ("LATENT", "BOOLEAN") + RETURN_NAMES = ("latent", "saw_json_tensor") + FUNCTION = "echo_latent" + CATEGORY = "PyIsolated/SealedWorker" + + @classmethod + def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802 + return {"required": {"latent": ("LATENT",)}} + + def echo_latent(self, latent: Any) -> tuple[Any, bool]: + saw_json_tensor = _contains_tensor_marker(latent) + logger.warning("][ conda latent echo json_marker=%s", saw_json_tensor) + return latent, saw_json_tensor + + +NODE_CLASS_MAPPINGS = { + "CondaSealedRuntimeProbe": InspectRuntimeNode, + "CondaSealedOpenWeatherDataset": OpenWeatherDatasetNode, + "CondaSealedLatentEcho": EchoLatentNode, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "CondaSealedRuntimeProbe": "Conda Sealed Runtime Probe", + "CondaSealedOpenWeatherDataset": "Conda Sealed Open Weather Dataset", + "CondaSealedLatentEcho": "Conda Sealed Latent Echo", +} diff --git a/tests/isolation/conda_sealed_worker/pyproject.toml b/tests/isolation/conda_sealed_worker/pyproject.toml new file mode 100644 index 000000000..6d6d7d804 --- /dev/null +++ b/tests/isolation/conda_sealed_worker/pyproject.toml @@ -0,0 +1,13 @@ +[project] +name = "comfyui-toolkit-conda-sealed-worker" +version = "0.1.0" +dependencies = ["xarray", "cfgrib"] + +[tool.comfy.isolation] +can_isolate = true +share_torch = false +package_manager = "conda" +execution_model = "sealed_worker" +standalone = true +conda_channels = ["conda-forge"] +conda_dependencies = ["eccodes", "cfgrib"] diff --git a/tests/isolation/internal_probe_host_policy.toml b/tests/isolation/internal_probe_host_policy.toml new file mode 100644 index 000000000..57bde615d --- /dev/null +++ b/tests/isolation/internal_probe_host_policy.toml @@ -0,0 +1,7 @@ +[tool.comfy.host] +sandbox_mode = "required" +allow_network = false +writable_paths = [ + "/dev/shm", + "/home/johnj/ComfyUI/output", +] diff --git a/tests/isolation/internal_probe_node/__init__.py b/tests/isolation/internal_probe_node/__init__.py new file mode 100644 index 000000000..f4155bf99 --- /dev/null +++ b/tests/isolation/internal_probe_node/__init__.py @@ -0,0 +1,6 @@ +from .probe_nodes import ( + NODE_CLASS_MAPPINGS as NODE_CLASS_MAPPINGS, + NODE_DISPLAY_NAME_MAPPINGS as NODE_DISPLAY_NAME_MAPPINGS, +) + +__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] diff --git a/tests/isolation/internal_probe_node/probe_nodes.py b/tests/isolation/internal_probe_node/probe_nodes.py new file mode 100644 index 000000000..1c29996e7 --- /dev/null +++ b/tests/isolation/internal_probe_node/probe_nodes.py @@ -0,0 +1,75 @@ +from __future__ import annotations + + +class InternalIsolationProbeImage: + CATEGORY = "tests/isolation" + RETURN_TYPES = () + FUNCTION = "run" + OUTPUT_NODE = True + + @classmethod + def INPUT_TYPES(cls): + return {"required": {}} + + def run(self): + from comfy_api.latest import UI + import torch + + image = torch.zeros((1, 2, 2, 3), dtype=torch.float32) + image[:, :, :, 0] = 1.0 + ui = UI.PreviewImage(image) + return {"ui": ui.as_dict(), "result": ()} + + +class InternalIsolationProbeAudio: + CATEGORY = "tests/isolation" + RETURN_TYPES = () + FUNCTION = "run" + OUTPUT_NODE = True + + @classmethod + def INPUT_TYPES(cls): + return {"required": {}} + + def run(self): + from comfy_api.latest import UI + import torch + + waveform = torch.zeros((1, 1, 32), dtype=torch.float32) + audio = {"waveform": waveform, "sample_rate": 44100} + ui = UI.PreviewAudio(audio) + return {"ui": ui.as_dict(), "result": ()} + + +class InternalIsolationProbeUI3D: + CATEGORY = "tests/isolation" + RETURN_TYPES = () + FUNCTION = "run" + OUTPUT_NODE = True + + @classmethod + def INPUT_TYPES(cls): + return {"required": {}} + + def run(self): + from comfy_api.latest import UI + import torch + + bg_image = torch.zeros((1, 2, 2, 3), dtype=torch.float32) + bg_image[:, :, :, 1] = 1.0 + camera_info = {"distance": 1.0} + ui = UI.PreviewUI3D("internal_probe_preview.obj", camera_info, bg_image=bg_image) + return {"ui": ui.as_dict(), "result": ()} + + +NODE_CLASS_MAPPINGS = { + "InternalIsolationProbeImage": InternalIsolationProbeImage, + "InternalIsolationProbeAudio": InternalIsolationProbeAudio, + "InternalIsolationProbeUI3D": InternalIsolationProbeUI3D, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "InternalIsolationProbeImage": "Internal Isolation Probe Image", + "InternalIsolationProbeAudio": "Internal Isolation Probe Audio", + "InternalIsolationProbeUI3D": "Internal Isolation Probe UI3D", +} diff --git a/tests/isolation/singleton_boundary_helpers.py b/tests/isolation/singleton_boundary_helpers.py new file mode 100644 index 000000000..f113f6a81 --- /dev/null +++ b/tests/isolation/singleton_boundary_helpers.py @@ -0,0 +1,955 @@ +from __future__ import annotations + +import asyncio +import importlib.util +import os +import sys +from pathlib import Path +from typing import Any + + +COMFYUI_ROOT = Path(__file__).resolve().parents[2] +UV_SEALED_WORKER_MODULE = COMFYUI_ROOT / "tests" / "isolation" / "uv_sealed_worker" / "__init__.py" +FORBIDDEN_MINIMAL_SEALED_MODULES = ( + "torch", + "folder_paths", + "comfy.utils", + "comfy.model_management", + "main", + "comfy.isolation.extension_wrapper", +) +FORBIDDEN_SEALED_SINGLETON_MODULES = ( + "torch", + "folder_paths", + "comfy.utils", + "comfy_execution.progress", +) +FORBIDDEN_EXACT_SMALL_PROXY_MODULES = FORBIDDEN_SEALED_SINGLETON_MODULES +FORBIDDEN_MODEL_MANAGEMENT_MODULES = ( + "comfy.model_management", +) + + +def _load_module_from_path(module_name: str, module_path: Path): + spec = importlib.util.spec_from_file_location(module_name, module_path) + if spec is None or spec.loader is None: + raise RuntimeError(f"unable to build import spec for {module_path}") + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + try: + spec.loader.exec_module(module) + except Exception: + sys.modules.pop(module_name, None) + raise + return module + + +def matching_modules(prefixes: tuple[str, ...], modules: set[str]) -> list[str]: + return sorted( + module_name + for module_name in modules + if any( + module_name == prefix or module_name.startswith(f"{prefix}.") + for prefix in prefixes + ) + ) + + +def _load_helper_proxy_service() -> Any | None: + try: + from comfy.isolation.proxies.helper_proxies import HelperProxiesService + except (ImportError, AttributeError): + return None + return HelperProxiesService + + +def _load_model_management_proxy() -> Any | None: + try: + from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy + except (ImportError, AttributeError): + return None + return ModelManagementProxy + + +async def _capture_minimal_sealed_worker_imports() -> dict[str, object]: + from pyisolate.sealed import SealedNodeExtension + + module_name = "tests.isolation.uv_sealed_worker_boundary_probe" + before = set(sys.modules) + extension = SealedNodeExtension() + module = _load_module_from_path(module_name, UV_SEALED_WORKER_MODULE) + try: + await extension.on_module_loaded(module) + node_list = await extension.list_nodes() + node_details = await extension.get_node_details("UVSealedRuntimeProbe") + imported = set(sys.modules) - before + return { + "mode": "minimal_sealed_worker", + "node_names": sorted(node_list), + "runtime_probe_function": node_details["function"], + "modules": sorted(imported), + "forbidden_matches": matching_modules(FORBIDDEN_MINIMAL_SEALED_MODULES, imported), + } + finally: + sys.modules.pop(module_name, None) + + +def capture_minimal_sealed_worker_imports() -> dict[str, object]: + return asyncio.run(_capture_minimal_sealed_worker_imports()) + + +class FakeSingletonCaller: + def __init__(self, methods: dict[str, Any], calls: list[dict[str, Any]], object_id: str): + self._methods = methods + self._calls = calls + self._object_id = object_id + + def __getattr__(self, name: str): + if name not in self._methods: + raise AttributeError(name) + + async def method(*args: Any, **kwargs: Any) -> Any: + self._calls.append( + { + "object_id": self._object_id, + "method": name, + "args": list(args), + "kwargs": dict(kwargs), + } + ) + result = self._methods[name] + return result(*args, **kwargs) if callable(result) else result + + return method + + +class FakeSingletonRPC: + def __init__(self) -> None: + self.calls: list[dict[str, Any]] = [] + self._device = {"__pyisolate_torch_device__": "cpu"} + self._services: dict[str, dict[str, Any]] = { + "FolderPathsProxy": { + "rpc_get_models_dir": lambda: "/sandbox/models", + "rpc_get_folder_names_and_paths": lambda: { + "checkpoints": { + "paths": ["/sandbox/models/checkpoints"], + "extensions": [".ckpt", ".safetensors"], + } + }, + "rpc_get_extension_mimetypes_cache": lambda: {"webp": "image"}, + "rpc_get_filename_list_cache": lambda: {}, + "rpc_get_temp_directory": lambda: "/sandbox/temp", + "rpc_get_input_directory": lambda: "/sandbox/input", + "rpc_get_output_directory": lambda: "/sandbox/output", + "rpc_get_user_directory": lambda: "/sandbox/user", + "rpc_get_annotated_filepath": self._get_annotated_filepath, + "rpc_exists_annotated_filepath": lambda _name: False, + "rpc_add_model_folder_path": lambda *_args, **_kwargs: None, + "rpc_get_folder_paths": lambda folder_name: [f"/sandbox/models/{folder_name}"], + "rpc_get_filename_list": lambda folder_name: [f"{folder_name}_fixture.safetensors"], + "rpc_get_full_path": lambda folder_name, filename: f"/sandbox/models/{folder_name}/{filename}", + }, + "UtilsProxy": { + "progress_bar_hook": lambda value, total, preview=None, node_id=None: { + "value": value, + "total": total, + "preview": preview, + "node_id": node_id, + } + }, + "ProgressProxy": { + "rpc_set_progress": lambda value, max_value, node_id=None, image=None: { + "value": value, + "max_value": max_value, + "node_id": node_id, + "image": image, + } + }, + "HelperProxiesService": { + "rpc_restore_input_types": lambda raw: raw, + }, + "ModelManagementProxy": { + "rpc_call": self._model_management_rpc_call, + }, + } + + def _model_management_rpc_call(self, method_name: str, args: Any = None, kwargs: Any = None) -> Any: + if method_name == "get_torch_device": + return self._device + elif method_name == "get_torch_device_name": + return "cpu" + elif method_name == "get_free_memory": + return 34359738368 + raise AssertionError(f"unexpected model_management method {method_name}") + + @staticmethod + def _get_annotated_filepath(name: str, default_dir: str | None = None) -> str: + if name.endswith("[output]"): + return f"/sandbox/output/{name[:-8]}" + if name.endswith("[input]"): + return f"/sandbox/input/{name[:-7]}" + if name.endswith("[temp]"): + return f"/sandbox/temp/{name[:-6]}" + base_dir = default_dir or "/sandbox/input" + return f"{base_dir}/{name}" + + def create_caller(self, cls: Any, object_id: str): + methods = self._services.get(object_id) or self._services.get(getattr(cls, "__name__", object_id)) + if methods is None: + raise KeyError(object_id) + return FakeSingletonCaller(methods, self.calls, object_id) + + +def _clear_proxy_rpcs() -> None: + from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy + from comfy.isolation.proxies.progress_proxy import ProgressProxy + from comfy.isolation.proxies.utils_proxy import UtilsProxy + + FolderPathsProxy.clear_rpc() + ProgressProxy.clear_rpc() + UtilsProxy.clear_rpc() + helper_proxy_service = _load_helper_proxy_service() + if helper_proxy_service is not None: + helper_proxy_service.clear_rpc() + model_management_proxy = _load_model_management_proxy() + if model_management_proxy is not None and hasattr(model_management_proxy, "clear_rpc"): + model_management_proxy.clear_rpc() + + +def prepare_sealed_singleton_proxies(fake_rpc: FakeSingletonRPC) -> None: + os.environ["PYISOLATE_CHILD"] = "1" + os.environ["PYISOLATE_IMPORT_TORCH"] = "0" + _clear_proxy_rpcs() + + from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy + from comfy.isolation.proxies.progress_proxy import ProgressProxy + from comfy.isolation.proxies.utils_proxy import UtilsProxy + + FolderPathsProxy.set_rpc(fake_rpc) + ProgressProxy.set_rpc(fake_rpc) + UtilsProxy.set_rpc(fake_rpc) + helper_proxy_service = _load_helper_proxy_service() + if helper_proxy_service is not None: + helper_proxy_service.set_rpc(fake_rpc) + model_management_proxy = _load_model_management_proxy() + if model_management_proxy is not None and hasattr(model_management_proxy, "set_rpc"): + model_management_proxy.set_rpc(fake_rpc) + + +def reset_forbidden_singleton_modules() -> None: + for module_name in ( + "folder_paths", + "comfy.utils", + "comfy_execution.progress", + ): + sys.modules.pop(module_name, None) + + +class FakeExactRelayCaller: + def __init__(self, methods: dict[str, Any], transcripts: list[dict[str, Any]], object_id: str): + self._methods = methods + self._transcripts = transcripts + self._object_id = object_id + + def __getattr__(self, name: str): + if name not in self._methods: + raise AttributeError(name) + + async def method(*args: Any, **kwargs: Any) -> Any: + self._transcripts.append( + { + "phase": "child_call", + "object_id": self._object_id, + "method": name, + "args": list(args), + "kwargs": dict(kwargs), + } + ) + impl = self._methods[name] + self._transcripts.append( + { + "phase": "host_invocation", + "object_id": self._object_id, + "method": name, + "target": impl["target"], + "args": list(args), + "kwargs": dict(kwargs), + } + ) + result = impl["result"](*args, **kwargs) if callable(impl["result"]) else impl["result"] + self._transcripts.append( + { + "phase": "result", + "object_id": self._object_id, + "method": name, + "result": result, + } + ) + return result + + return method + + +class FakeExactRelayRPC: + def __init__(self) -> None: + self.transcripts: list[dict[str, Any]] = [] + self._device = {"__pyisolate_torch_device__": "cpu"} + self._services: dict[str, dict[str, Any]] = { + "FolderPathsProxy": { + "rpc_get_models_dir": { + "target": "folder_paths.models_dir", + "result": "/sandbox/models", + }, + "rpc_get_temp_directory": { + "target": "folder_paths.get_temp_directory", + "result": "/sandbox/temp", + }, + "rpc_get_input_directory": { + "target": "folder_paths.get_input_directory", + "result": "/sandbox/input", + }, + "rpc_get_output_directory": { + "target": "folder_paths.get_output_directory", + "result": "/sandbox/output", + }, + "rpc_get_user_directory": { + "target": "folder_paths.get_user_directory", + "result": "/sandbox/user", + }, + "rpc_get_folder_names_and_paths": { + "target": "folder_paths.folder_names_and_paths", + "result": { + "checkpoints": { + "paths": ["/sandbox/models/checkpoints"], + "extensions": [".ckpt", ".safetensors"], + } + }, + }, + "rpc_get_extension_mimetypes_cache": { + "target": "folder_paths.extension_mimetypes_cache", + "result": {"webp": "image"}, + }, + "rpc_get_filename_list_cache": { + "target": "folder_paths.filename_list_cache", + "result": {}, + }, + "rpc_get_annotated_filepath": { + "target": "folder_paths.get_annotated_filepath", + "result": lambda name, default_dir=None: FakeSingletonRPC._get_annotated_filepath(name, default_dir), + }, + "rpc_exists_annotated_filepath": { + "target": "folder_paths.exists_annotated_filepath", + "result": False, + }, + "rpc_add_model_folder_path": { + "target": "folder_paths.add_model_folder_path", + "result": None, + }, + "rpc_get_folder_paths": { + "target": "folder_paths.get_folder_paths", + "result": lambda folder_name: [f"/sandbox/models/{folder_name}"], + }, + "rpc_get_filename_list": { + "target": "folder_paths.get_filename_list", + "result": lambda folder_name: [f"{folder_name}_fixture.safetensors"], + }, + "rpc_get_full_path": { + "target": "folder_paths.get_full_path", + "result": lambda folder_name, filename: f"/sandbox/models/{folder_name}/{filename}", + }, + }, + "UtilsProxy": { + "progress_bar_hook": { + "target": "comfy.utils.PROGRESS_BAR_HOOK", + "result": lambda value, total, preview=None, node_id=None: { + "value": value, + "total": total, + "preview": preview, + "node_id": node_id, + }, + }, + }, + "ProgressProxy": { + "rpc_set_progress": { + "target": "comfy_execution.progress.get_progress_state().update_progress", + "result": None, + }, + }, + "HelperProxiesService": { + "rpc_restore_input_types": { + "target": "comfy.isolation.proxies.helper_proxies.restore_input_types", + "result": lambda raw: raw, + } + }, + "ModelManagementProxy": { + "rpc_call": { + "target": "comfy.model_management.*", + "result": self._model_management_rpc_call, + }, + }, + } + + def _model_management_rpc_call(self, method_name: str, args: Any = None, kwargs: Any = None) -> Any: + device = {"__pyisolate_torch_device__": "cpu"} + if method_name == "get_torch_device": + return device + elif method_name == "get_torch_device_name": + return "cpu" + elif method_name == "get_free_memory": + return 34359738368 + raise AssertionError(f"unexpected exact-relay method {method_name}") + + def create_caller(self, cls: Any, object_id: str): + methods = self._services.get(object_id) or self._services.get(getattr(cls, "__name__", object_id)) + if methods is None: + raise KeyError(object_id) + return FakeExactRelayCaller(methods, self.transcripts, object_id) + + +def capture_exact_small_proxy_relay() -> dict[str, object]: + reset_forbidden_singleton_modules() + fake_rpc = FakeExactRelayRPC() + previous_child = os.environ.get("PYISOLATE_CHILD") + previous_import_torch = os.environ.get("PYISOLATE_IMPORT_TORCH") + try: + prepare_sealed_singleton_proxies(fake_rpc) + + from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy + from comfy.isolation.proxies.helper_proxies import restore_input_types + from comfy.isolation.proxies.progress_proxy import ProgressProxy + from comfy.isolation.proxies.utils_proxy import UtilsProxy + + folder_proxy = FolderPathsProxy() + utils_proxy = UtilsProxy() + progress_proxy = ProgressProxy() + before = set(sys.modules) + + restored = restore_input_types( + { + "required": { + "image": {"__pyisolate_any_type__": True, "value": "*"}, + } + } + ) + folder_path = folder_proxy.get_annotated_filepath("demo.png[input]") + models_dir = folder_proxy.models_dir + folder_names_and_paths = folder_proxy.folder_names_and_paths + asyncio.run(utils_proxy.progress_bar_hook(2, 5, node_id="node-17")) + progress_proxy.set_progress(1.5, 5.0, node_id="node-17") + + imported = set(sys.modules) - before + return { + "mode": "exact_small_proxy_relay", + "folder_path": folder_path, + "models_dir": models_dir, + "folder_names_and_paths": folder_names_and_paths, + "restored_any_type": str(restored["required"]["image"]), + "transcripts": fake_rpc.transcripts, + "modules": sorted(imported), + "forbidden_matches": matching_modules(FORBIDDEN_EXACT_SMALL_PROXY_MODULES, imported), + } + finally: + _clear_proxy_rpcs() + if previous_child is None: + os.environ.pop("PYISOLATE_CHILD", None) + else: + os.environ["PYISOLATE_CHILD"] = previous_child + if previous_import_torch is None: + os.environ.pop("PYISOLATE_IMPORT_TORCH", None) + else: + os.environ["PYISOLATE_IMPORT_TORCH"] = previous_import_torch + + +class FakeModelManagementExactRelayRPC: + def __init__(self) -> None: + self.transcripts: list[dict[str, object]] = [] + self._device = {"__pyisolate_torch_device__": "cpu"} + self._services: dict[str, dict[str, Any]] = { + "ModelManagementProxy": { + "rpc_call": self._rpc_call, + } + } + + def create_caller(self, cls: Any, object_id: str): + methods = self._services.get(object_id) or self._services.get(getattr(cls, "__name__", object_id)) + if methods is None: + raise KeyError(object_id) + return _ModelManagementExactRelayCaller(methods) + + def _rpc_call(self, method_name: str, args: Any, kwargs: Any) -> Any: + self.transcripts.append( + { + "phase": "child_call", + "object_id": "ModelManagementProxy", + "method": method_name, + "args": _json_safe(args), + "kwargs": _json_safe(kwargs), + } + ) + target = f"comfy.model_management.{method_name}" + self.transcripts.append( + { + "phase": "host_invocation", + "object_id": "ModelManagementProxy", + "method": method_name, + "target": target, + "args": _json_safe(args), + "kwargs": _json_safe(kwargs), + } + ) + if method_name == "get_torch_device": + result = self._device + elif method_name == "get_torch_device_name": + result = "cpu" + elif method_name == "get_free_memory": + result = 34359738368 + else: + raise AssertionError(f"unexpected exact-relay method {method_name}") + self.transcripts.append( + { + "phase": "result", + "object_id": "ModelManagementProxy", + "method": method_name, + "result": _json_safe(result), + } + ) + return result + + +class _ModelManagementExactRelayCaller: + def __init__(self, methods: dict[str, Any]): + self._methods = methods + + def __getattr__(self, name: str): + if name not in self._methods: + raise AttributeError(name) + + async def method(*args: Any, **kwargs: Any) -> Any: + impl = self._methods[name] + return impl(*args, **kwargs) if callable(impl) else impl + + return method + + +def _json_safe(value: Any) -> Any: + if callable(value): + return f"" + if isinstance(value, tuple): + return [_json_safe(item) for item in value] + if isinstance(value, list): + return [_json_safe(item) for item in value] + if isinstance(value, dict): + return {key: _json_safe(inner) for key, inner in value.items()} + return value + + +def capture_model_management_exact_relay() -> dict[str, object]: + for module_name in FORBIDDEN_MODEL_MANAGEMENT_MODULES: + sys.modules.pop(module_name, None) + + fake_rpc = FakeModelManagementExactRelayRPC() + previous_child = os.environ.get("PYISOLATE_CHILD") + previous_import_torch = os.environ.get("PYISOLATE_IMPORT_TORCH") + try: + os.environ["PYISOLATE_CHILD"] = "1" + os.environ["PYISOLATE_IMPORT_TORCH"] = "0" + + from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy + + if hasattr(ModelManagementProxy, "clear_rpc"): + ModelManagementProxy.clear_rpc() + if hasattr(ModelManagementProxy, "set_rpc"): + ModelManagementProxy.set_rpc(fake_rpc) + + proxy = ModelManagementProxy() + before = set(sys.modules) + device = proxy.get_torch_device() + device_name = proxy.get_torch_device_name(device) + free_memory = proxy.get_free_memory(device) + imported = set(sys.modules) - before + return { + "mode": "model_management_exact_relay", + "device": str(device), + "device_type": getattr(device, "type", None), + "device_name": device_name, + "free_memory": free_memory, + "transcripts": fake_rpc.transcripts, + "modules": sorted(imported), + "forbidden_matches": matching_modules(FORBIDDEN_MODEL_MANAGEMENT_MODULES, imported), + } + finally: + model_management_proxy = _load_model_management_proxy() + if model_management_proxy is not None and hasattr(model_management_proxy, "clear_rpc"): + model_management_proxy.clear_rpc() + if previous_child is None: + os.environ.pop("PYISOLATE_CHILD", None) + else: + os.environ["PYISOLATE_CHILD"] = previous_child + if previous_import_torch is None: + os.environ.pop("PYISOLATE_IMPORT_TORCH", None) + else: + os.environ["PYISOLATE_IMPORT_TORCH"] = previous_import_torch + + +FORBIDDEN_PROMPT_WEB_MODULES = ( + "server", + "aiohttp", + "comfy.isolation.extension_wrapper", +) +FORBIDDEN_EXACT_BOOTSTRAP_MODULES = ( + "comfy.isolation.adapter", + "folder_paths", + "comfy.utils", + "comfy.model_management", + "server", + "main", + "comfy.isolation.extension_wrapper", +) + + +class _PromptServiceExactRelayCaller: + def __init__(self, methods: dict[str, Any], transcripts: list[dict[str, Any]], object_id: str): + self._methods = methods + self._transcripts = transcripts + self._object_id = object_id + + def __getattr__(self, name: str): + if name not in self._methods: + raise AttributeError(name) + + async def method(*args: Any, **kwargs: Any) -> Any: + self._transcripts.append( + { + "phase": "child_call", + "object_id": self._object_id, + "method": name, + "args": _json_safe(args), + "kwargs": _json_safe(kwargs), + } + ) + impl = self._methods[name] + self._transcripts.append( + { + "phase": "host_invocation", + "object_id": self._object_id, + "method": name, + "target": impl["target"], + "args": _json_safe(args), + "kwargs": _json_safe(kwargs), + } + ) + result = impl["result"](*args, **kwargs) if callable(impl["result"]) else impl["result"] + self._transcripts.append( + { + "phase": "result", + "object_id": self._object_id, + "method": name, + "result": _json_safe(result), + } + ) + return result + + return method + + +class FakePromptWebRPC: + def __init__(self) -> None: + self.transcripts: list[dict[str, Any]] = [] + self._services = { + "PromptServerService": { + "ui_send_progress_text": { + "target": "server.PromptServer.instance.send_progress_text", + "result": None, + }, + "register_route_rpc": { + "target": "server.PromptServer.instance.routes.add_route", + "result": None, + }, + } + } + + def create_caller(self, cls: Any, object_id: str): + methods = self._services.get(object_id) or self._services.get(getattr(cls, "__name__", object_id)) + if methods is None: + raise KeyError(object_id) + return _PromptServiceExactRelayCaller(methods, self.transcripts, object_id) + + +class FakeWebDirectoryProxy: + def __init__(self, transcripts: list[dict[str, Any]]): + self._transcripts = transcripts + + def get_web_file(self, extension_name: str, relative_path: str) -> dict[str, Any]: + self._transcripts.append( + { + "phase": "child_call", + "object_id": "WebDirectoryProxy", + "method": "get_web_file", + "args": [extension_name, relative_path], + "kwargs": {}, + } + ) + self._transcripts.append( + { + "phase": "host_invocation", + "object_id": "WebDirectoryProxy", + "method": "get_web_file", + "target": "comfy.isolation.proxies.web_directory_proxy.WebDirectoryProxy.get_web_file", + "args": [extension_name, relative_path], + "kwargs": {}, + } + ) + result = { + "content": "Y29uc29sZS5sb2coJ2RlbycpOw==", + "content_type": "application/javascript", + } + self._transcripts.append( + { + "phase": "result", + "object_id": "WebDirectoryProxy", + "method": "get_web_file", + "result": result, + } + ) + return result + + +def capture_prompt_web_exact_relay() -> dict[str, object]: + for module_name in FORBIDDEN_PROMPT_WEB_MODULES: + sys.modules.pop(module_name, None) + + fake_rpc = FakePromptWebRPC() + + from comfy.isolation.proxies.prompt_server_impl import PromptServerStub + from comfy.isolation.proxies.web_directory_proxy import WebDirectoryCache + + PromptServerStub.set_rpc(fake_rpc) + stub = PromptServerStub() + cache = WebDirectoryCache() + cache.register_proxy("demo_ext", FakeWebDirectoryProxy(fake_rpc.transcripts)) + + before = set(sys.modules) + + def demo_handler(_request): + return {"ok": True} + + stub.send_progress_text("hello", "node-17") + stub.routes.get("/demo")(demo_handler) + web_file = cache.get_file("demo_ext", "js/app.js") + imported = set(sys.modules) - before + return { + "mode": "prompt_web_exact_relay", + "web_file": { + "content_type": web_file["content_type"] if web_file else None, + "content": web_file["content"].decode("utf-8") if web_file else None, + }, + "transcripts": fake_rpc.transcripts, + "modules": sorted(imported), + "forbidden_matches": matching_modules(FORBIDDEN_PROMPT_WEB_MODULES, imported), + } + + +class FakeExactBootstrapRPC: + def __init__(self) -> None: + self.transcripts: list[dict[str, Any]] = [] + self._device = {"__pyisolate_torch_device__": "cpu"} + self._services: dict[str, dict[str, Any]] = { + "FolderPathsProxy": FakeExactRelayRPC()._services["FolderPathsProxy"], + "HelperProxiesService": FakeExactRelayRPC()._services["HelperProxiesService"], + "ProgressProxy": FakeExactRelayRPC()._services["ProgressProxy"], + "UtilsProxy": FakeExactRelayRPC()._services["UtilsProxy"], + "PromptServerService": { + "ui_send_sync": { + "target": "server.PromptServer.instance.send_sync", + "result": None, + }, + "ui_send": { + "target": "server.PromptServer.instance.send", + "result": None, + }, + "ui_send_progress_text": { + "target": "server.PromptServer.instance.send_progress_text", + "result": None, + }, + "register_route_rpc": { + "target": "server.PromptServer.instance.routes.add_route", + "result": None, + }, + }, + "ModelManagementProxy": { + "rpc_call": self._rpc_call, + }, + } + + def create_caller(self, cls: Any, object_id: str): + methods = self._services.get(object_id) or self._services.get(getattr(cls, "__name__", object_id)) + if methods is None: + raise KeyError(object_id) + if object_id == "ModelManagementProxy": + return _ModelManagementExactRelayCaller(methods) + return _PromptServiceExactRelayCaller(methods, self.transcripts, object_id) + + def _rpc_call(self, method_name: str, args: Any, kwargs: Any) -> Any: + self.transcripts.append( + { + "phase": "child_call", + "object_id": "ModelManagementProxy", + "method": method_name, + "args": _json_safe(args), + "kwargs": _json_safe(kwargs), + } + ) + self.transcripts.append( + { + "phase": "host_invocation", + "object_id": "ModelManagementProxy", + "method": method_name, + "target": f"comfy.model_management.{method_name}", + "args": _json_safe(args), + "kwargs": _json_safe(kwargs), + } + ) + result = self._device if method_name == "get_torch_device" else None + self.transcripts.append( + { + "phase": "result", + "object_id": "ModelManagementProxy", + "method": method_name, + "result": _json_safe(result), + } + ) + return result + + +def capture_exact_proxy_bootstrap_contract() -> dict[str, object]: + from pyisolate._internal.rpc_protocol import get_child_rpc_instance, set_child_rpc_instance + + from comfy.isolation.adapter import ComfyUIAdapter + from comfy.isolation.child_hooks import initialize_child_process + from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy + from comfy.isolation.proxies.helper_proxies import HelperProxiesService + from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy + from comfy.isolation.proxies.progress_proxy import ProgressProxy + from comfy.isolation.proxies.prompt_server_impl import PromptServerStub + from comfy.isolation.proxies.utils_proxy import UtilsProxy + + host_services = sorted(cls.__name__ for cls in ComfyUIAdapter().provide_rpc_services()) + + for module_name in FORBIDDEN_EXACT_BOOTSTRAP_MODULES: + sys.modules.pop(module_name, None) + + previous_child = os.environ.get("PYISOLATE_CHILD") + previous_import_torch = os.environ.get("PYISOLATE_IMPORT_TORCH") + os.environ["PYISOLATE_CHILD"] = "1" + os.environ["PYISOLATE_IMPORT_TORCH"] = "0" + + _clear_proxy_rpcs() + if hasattr(PromptServerStub, "clear_rpc"): + PromptServerStub.clear_rpc() + else: + PromptServerStub._rpc = None # type: ignore[attr-defined] + fake_rpc = FakeExactBootstrapRPC() + set_child_rpc_instance(fake_rpc) + + before = set(sys.modules) + try: + initialize_child_process() + imported = set(sys.modules) - before + matrix = { + "base.py": { + "bound": get_child_rpc_instance() is fake_rpc, + "details": {"child_rpc_instance": get_child_rpc_instance() is fake_rpc}, + }, + "folder_paths_proxy.py": { + "bound": "FolderPathsProxy" in host_services and FolderPathsProxy._rpc is not None, + "details": {"host_service": "FolderPathsProxy" in host_services, "child_rpc": FolderPathsProxy._rpc is not None}, + }, + "helper_proxies.py": { + "bound": "HelperProxiesService" in host_services and HelperProxiesService._rpc is not None, + "details": {"host_service": "HelperProxiesService" in host_services, "child_rpc": HelperProxiesService._rpc is not None}, + }, + "model_management_proxy.py": { + "bound": "ModelManagementProxy" in host_services and ModelManagementProxy._rpc is not None, + "details": {"host_service": "ModelManagementProxy" in host_services, "child_rpc": ModelManagementProxy._rpc is not None}, + }, + "progress_proxy.py": { + "bound": "ProgressProxy" in host_services and ProgressProxy._rpc is not None, + "details": {"host_service": "ProgressProxy" in host_services, "child_rpc": ProgressProxy._rpc is not None}, + }, + "prompt_server_impl.py": { + "bound": "PromptServerService" in host_services and PromptServerStub._rpc is not None, + "details": {"host_service": "PromptServerService" in host_services, "child_rpc": PromptServerStub._rpc is not None}, + }, + "utils_proxy.py": { + "bound": "UtilsProxy" in host_services and UtilsProxy._rpc is not None, + "details": {"host_service": "UtilsProxy" in host_services, "child_rpc": UtilsProxy._rpc is not None}, + }, + "web_directory_proxy.py": { + "bound": "WebDirectoryProxy" in host_services, + "details": {"host_service": "WebDirectoryProxy" in host_services}, + }, + } + finally: + set_child_rpc_instance(None) + if previous_child is None: + os.environ.pop("PYISOLATE_CHILD", None) + else: + os.environ["PYISOLATE_CHILD"] = previous_child + if previous_import_torch is None: + os.environ.pop("PYISOLATE_IMPORT_TORCH", None) + else: + os.environ["PYISOLATE_IMPORT_TORCH"] = previous_import_torch + + omitted = sorted(name for name, status in matrix.items() if not status["bound"]) + return { + "mode": "exact_proxy_bootstrap_contract", + "host_services": host_services, + "matrix": matrix, + "omitted_proxies": omitted, + "modules": sorted(imported), + "forbidden_matches": matching_modules(FORBIDDEN_EXACT_BOOTSTRAP_MODULES, imported), + } + +def capture_sealed_singleton_imports() -> dict[str, object]: + reset_forbidden_singleton_modules() + fake_rpc = FakeSingletonRPC() + previous_child = os.environ.get("PYISOLATE_CHILD") + previous_import_torch = os.environ.get("PYISOLATE_IMPORT_TORCH") + before = set(sys.modules) + try: + prepare_sealed_singleton_proxies(fake_rpc) + + from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy + from comfy.isolation.proxies.progress_proxy import ProgressProxy + from comfy.isolation.proxies.utils_proxy import UtilsProxy + + folder_proxy = FolderPathsProxy() + progress_proxy = ProgressProxy() + utils_proxy = UtilsProxy() + + folder_path = folder_proxy.get_annotated_filepath("demo.png[input]") + temp_dir = folder_proxy.get_temp_directory() + models_dir = folder_proxy.models_dir + asyncio.run(utils_proxy.progress_bar_hook(2, 5, node_id="node-17")) + progress_proxy.set_progress(1.5, 5.0, node_id="node-17") + + imported = set(sys.modules) - before + return { + "mode": "sealed_singletons", + "folder_path": folder_path, + "temp_dir": temp_dir, + "models_dir": models_dir, + "rpc_calls": fake_rpc.calls, + "modules": sorted(imported), + "forbidden_matches": matching_modules(FORBIDDEN_SEALED_SINGLETON_MODULES, imported), + } + finally: + _clear_proxy_rpcs() + if previous_child is None: + os.environ.pop("PYISOLATE_CHILD", None) + else: + os.environ["PYISOLATE_CHILD"] = previous_child + if previous_import_torch is None: + os.environ.pop("PYISOLATE_IMPORT_TORCH", None) + else: + os.environ["PYISOLATE_IMPORT_TORCH"] = previous_import_torch diff --git a/tests/isolation/stage_internal_probe_node.py b/tests/isolation/stage_internal_probe_node.py new file mode 100644 index 000000000..b072ab43e --- /dev/null +++ b/tests/isolation/stage_internal_probe_node.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +import argparse +import shutil +import sys +import tempfile +from contextlib import contextmanager +from pathlib import Path +from typing import Iterator + + +COMFYUI_ROOT = Path(__file__).resolve().parents[2] +PROBE_SOURCE_ROOT = COMFYUI_ROOT / "tests" / "isolation" / "internal_probe_node" +PROBE_NODE_NAME = "InternalIsolationProbeNode" + +PYPROJECT_CONTENT = """[project] +name = "InternalIsolationProbeNode" +version = "0.0.1" + +[tool.comfy.isolation] +can_isolate = true +share_torch = true +""" + + +def _probe_target_root(comfy_root: Path) -> Path: + return Path(comfy_root) / "custom_nodes" / PROBE_NODE_NAME + + +def stage_probe_node(comfy_root: Path) -> Path: + if not PROBE_SOURCE_ROOT.is_dir(): + raise RuntimeError(f"Missing probe source directory: {PROBE_SOURCE_ROOT}") + + target_root = _probe_target_root(comfy_root) + target_root.mkdir(parents=True, exist_ok=True) + for source_path in PROBE_SOURCE_ROOT.iterdir(): + destination_path = target_root / source_path.name + if source_path.is_dir(): + shutil.copytree(source_path, destination_path, dirs_exist_ok=True) + else: + shutil.copy2(source_path, destination_path) + + (target_root / "pyproject.toml").write_text(PYPROJECT_CONTENT, encoding="utf-8") + return target_root + + +@contextmanager +def staged_probe_node() -> Iterator[Path]: + staging_root = Path(tempfile.mkdtemp(prefix="comfyui_internal_probe_")) + try: + yield stage_probe_node(staging_root) + finally: + shutil.rmtree(staging_root, ignore_errors=True) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Stage the internal isolation probe node under an explicit ComfyUI root." + ) + parser.add_argument( + "--target-root", + type=Path, + required=True, + help="Explicit ComfyUI root to stage under. Caller owns cleanup.", + ) + args = parser.parse_args() + + staged = stage_probe_node(args.target_root) + sys.stdout.write(f"{staged}\n") diff --git a/tests/isolation/test_client_snapshot.py b/tests/isolation/test_client_snapshot.py new file mode 100644 index 000000000..0eedf6b41 --- /dev/null +++ b/tests/isolation/test_client_snapshot.py @@ -0,0 +1,122 @@ +"""Tests for pyisolate._internal.client import-time snapshot handling.""" + +import json +import os +import subprocess +import sys +from pathlib import Path + +import pytest + +# Paths needed for subprocess +PYISOLATE_ROOT = str(Path(__file__).parent.parent) +COMFYUI_ROOT = os.environ.get("COMFYUI_ROOT") or str(Path.home() / "ComfyUI") + +SCRIPT = """ +import json, sys +import pyisolate._internal.client # noqa: F401 # triggers snapshot logic +print(json.dumps(sys.path[:6])) +""" + + +def _run_client_process(env): + # Ensure subprocess can find pyisolate and ComfyUI + pythonpath_parts = [PYISOLATE_ROOT, COMFYUI_ROOT] + existing = env.get("PYTHONPATH", "") + if existing: + pythonpath_parts.append(existing) + env["PYTHONPATH"] = os.pathsep.join(pythonpath_parts) + + result = subprocess.run( # noqa: S603 + [sys.executable, "-c", SCRIPT], + capture_output=True, + text=True, + env=env, + check=True, + ) + stdout = result.stdout.strip().splitlines()[-1] + return json.loads(stdout) + + +@pytest.fixture() +def comfy_module_path(tmp_path): + comfy_root = tmp_path / "ComfyUI" + module_path = comfy_root / "custom_nodes" / "TestNode" + module_path.mkdir(parents=True) + return comfy_root, module_path + + +def test_snapshot_applied_and_comfy_root_prepend(tmp_path, comfy_module_path): + comfy_root, module_path = comfy_module_path + # Must include real ComfyUI path for utils validation to pass + host_paths = [COMFYUI_ROOT, "/host/lib1", "/host/lib2"] + snapshot = { + "sys_path": host_paths, + "sys_executable": sys.executable, + "sys_prefix": sys.prefix, + "environment": {}, + } + snapshot_path = tmp_path / "snapshot.json" + snapshot_path.write_text(json.dumps(snapshot), encoding="utf-8") + + env = os.environ.copy() + env.update( + { + "PYISOLATE_CHILD": "1", + "PYISOLATE_HOST_SNAPSHOT": str(snapshot_path), + "PYISOLATE_MODULE_PATH": str(module_path), + } + ) + + path_prefix = _run_client_process(env) + + # Current client behavior preserves the runtime bootstrap path order and + # keeps the resolved ComfyUI root available for imports. + assert COMFYUI_ROOT in path_prefix + # Module path should not override runtime root selection. + assert str(comfy_root) not in path_prefix + + +def test_missing_snapshot_file_does_not_crash(tmp_path, comfy_module_path): + _, module_path = comfy_module_path + missing_snapshot = tmp_path / "missing.json" + + env = os.environ.copy() + env.update( + { + "PYISOLATE_CHILD": "1", + "PYISOLATE_HOST_SNAPSHOT": str(missing_snapshot), + "PYISOLATE_MODULE_PATH": str(module_path), + } + ) + + # Should not raise even though snapshot path is missing + paths = _run_client_process(env) + assert len(paths) > 0 + + +def test_no_comfy_root_when_module_path_absent(tmp_path): + # Must include real ComfyUI path for utils validation to pass + host_paths = [COMFYUI_ROOT, "/alpha", "/beta"] + snapshot = { + "sys_path": host_paths, + "sys_executable": sys.executable, + "sys_prefix": sys.prefix, + "environment": {}, + } + snapshot_path = tmp_path / "snapshot.json" + snapshot_path.write_text(json.dumps(snapshot), encoding="utf-8") + + env = os.environ.copy() + env.update( + { + "PYISOLATE_CHILD": "1", + "PYISOLATE_HOST_SNAPSHOT": str(snapshot_path), + } + ) + + paths = _run_client_process(env) + # Runtime path bootstrap keeps ComfyUI importability regardless of host + # snapshot extras. + assert COMFYUI_ROOT in paths + assert "/alpha" not in paths and "/beta" not in paths diff --git a/tests/isolation/test_cuda_wheels_and_env_flags.py b/tests/isolation/test_cuda_wheels_and_env_flags.py new file mode 100644 index 000000000..f0361d5ef --- /dev/null +++ b/tests/isolation/test_cuda_wheels_and_env_flags.py @@ -0,0 +1,460 @@ +"""Synthetic integration coverage for manifest plumbing and env flags. + +These tests do not perform a real wheel install or a real ComfyUI E2E run. +""" + +import asyncio +import logging +import os +import sys +from types import SimpleNamespace +from typing import Any, cast + +import pytest + +import comfy.isolation as isolation_pkg +from comfy.isolation import runtime_helpers +from comfy.isolation import extension_loader as extension_loader_module +from comfy.isolation import extension_wrapper as extension_wrapper_module +from comfy.isolation import model_patcher_proxy_utils +from comfy.isolation.extension_loader import ExtensionLoadError, load_isolated_node +from comfy.isolation.extension_wrapper import ComfyNodeExtension +from comfy.isolation.model_patcher_proxy_utils import maybe_wrap_model_for_isolation +from pyisolate._internal.environment_conda import _generate_pixi_toml + + +class _DummyExtension: + def __init__(self) -> None: + self.name = "demo-extension" + + async def stop(self) -> None: + return None + + +def _write_manifest(node_dir, manifest_text: str) -> None: + (node_dir / "pyproject.toml").write_text(manifest_text, encoding="utf-8") + + +def test_load_isolated_node_passes_normalized_cuda_wheels_config(tmp_path, monkeypatch): + node_dir = tmp_path / "node" + node_dir.mkdir() + manifest_path = node_dir / "pyproject.toml" + _write_manifest( + node_dir, + """ +[project] +name = "demo-node" +dependencies = ["flash-attn>=1.0", "sageattention==0.1"] + +[tool.comfy.isolation] +can_isolate = true +share_torch = true + +[tool.comfy.isolation.cuda_wheels] +index_url = "https://example.invalid/cuda-wheels" +packages = ["flash_attn", "sageattention"] + +[tool.comfy.isolation.cuda_wheels.package_map] +flash_attn = "flash-attn-special" +""".strip(), + ) + + captured: dict[str, object] = {} + + class DummyManager: + def __init__(self, *args, **kwargs) -> None: + return None + + def load_extension(self, config): + captured.update(config) + return _DummyExtension() + + monkeypatch.setattr(extension_loader_module.pyisolate, "ExtensionManager", DummyManager) + monkeypatch.setattr( + extension_loader_module, + "load_host_policy", + lambda base_path: { + "sandbox_mode": "required", + "allow_network": False, + "writable_paths": [], + "readonly_paths": [], + }, + ) + monkeypatch.setattr(extension_loader_module, "is_cache_valid", lambda *args, **kwargs: True) + monkeypatch.setattr( + extension_loader_module, + "load_from_cache", + lambda *args, **kwargs: {"Node": {"display_name": "Node", "schema_v1": {}}}, + ) + monkeypatch.setitem(sys.modules, "folder_paths", SimpleNamespace(base_path=str(tmp_path))) + + specs = asyncio.run( + load_isolated_node( + node_dir, + manifest_path, + logging.getLogger("test"), + lambda *args, **kwargs: object, + tmp_path / "venvs", + [], + ) + ) + + assert len(specs) == 1 + assert captured["sandbox_mode"] == "required" + assert captured["cuda_wheels"] == { + "index_url": "https://example.invalid/cuda-wheels/", + "packages": ["flash-attn", "sageattention"], + "package_map": {"flash-attn": "flash-attn-special"}, + } + + +def test_load_isolated_node_rejects_undeclared_cuda_wheel_dependency( + tmp_path, monkeypatch +): + node_dir = tmp_path / "node" + node_dir.mkdir() + manifest_path = node_dir / "pyproject.toml" + _write_manifest( + node_dir, + """ +[project] +name = "demo-node" +dependencies = ["numpy>=1.0"] + +[tool.comfy.isolation] +can_isolate = true + +[tool.comfy.isolation.cuda_wheels] +index_url = "https://example.invalid/cuda-wheels" +packages = ["flash-attn"] +""".strip(), + ) + + monkeypatch.setitem(sys.modules, "folder_paths", SimpleNamespace(base_path=str(tmp_path))) + + with pytest.raises(ExtensionLoadError, match="undeclared dependencies"): + asyncio.run( + load_isolated_node( + node_dir, + manifest_path, + logging.getLogger("test"), + lambda *args, **kwargs: object, + tmp_path / "venvs", + [], + ) + ) + + +def test_conda_cuda_wheels_declared_packages_do_not_force_pixi_solve(tmp_path, monkeypatch): + node_dir = tmp_path / "node" + node_dir.mkdir() + manifest_path = node_dir / "pyproject.toml" + _write_manifest( + node_dir, + """ +[project] +name = "demo-node" +dependencies = ["numpy>=1.0", "spconv", "cumm", "flash-attn"] + +[tool.comfy.isolation] +can_isolate = true +package_manager = "conda" +conda_channels = ["conda-forge"] + +[tool.comfy.isolation.cuda_wheels] +index_url = "https://example.invalid/cuda-wheels" +packages = ["spconv", "cumm", "flash-attn"] +""".strip(), + ) + + captured: dict[str, object] = {} + + class DummyManager: + def __init__(self, *args, **kwargs) -> None: + return None + + def load_extension(self, config): + captured.update(config) + return _DummyExtension() + + monkeypatch.setattr(extension_loader_module.pyisolate, "ExtensionManager", DummyManager) + monkeypatch.setattr( + extension_loader_module, + "load_host_policy", + lambda base_path: { + "sandbox_mode": "disabled", + "allow_network": False, + "writable_paths": [], + "readonly_paths": [], + }, + ) + monkeypatch.setattr(extension_loader_module, "is_cache_valid", lambda *args, **kwargs: True) + monkeypatch.setattr( + extension_loader_module, + "load_from_cache", + lambda *args, **kwargs: {"Node": {"display_name": "Node", "schema_v1": {}}}, + ) + monkeypatch.setitem(sys.modules, "folder_paths", SimpleNamespace(base_path=str(tmp_path))) + + asyncio.run( + load_isolated_node( + node_dir, + manifest_path, + logging.getLogger("test"), + lambda *args, **kwargs: object, + tmp_path / "venvs", + [], + ) + ) + + generated = _generate_pixi_toml(captured) + assert 'numpy = ">=1.0"' in generated + assert "spconv =" not in generated + assert "cumm =" not in generated + assert "flash-attn =" not in generated + + +def test_conda_cuda_wheels_loader_accepts_sam3d_contract(tmp_path, monkeypatch): + node_dir = tmp_path / "node" + node_dir.mkdir() + manifest_path = node_dir / "pyproject.toml" + _write_manifest( + node_dir, + """ +[project] +name = "demo-node" +dependencies = [ + "torch", + "torchvision", + "pytorch3d", + "gsplat", + "nvdiffrast", + "flash-attn", + "sageattention", + "spconv", + "cumm", +] + +[tool.comfy.isolation] +can_isolate = true +package_manager = "conda" +conda_channels = ["conda-forge"] + +[tool.comfy.isolation.cuda_wheels] +index_url = "https://example.invalid/cuda-wheels" +packages = ["pytorch3d", "gsplat", "nvdiffrast", "flash-attn", "sageattention", "spconv", "cumm"] +""".strip(), + ) + + captured: dict[str, object] = {} + + class DummyManager: + def __init__(self, *args, **kwargs) -> None: + return None + + def load_extension(self, config): + captured.update(config) + return _DummyExtension() + + monkeypatch.setattr(extension_loader_module.pyisolate, "ExtensionManager", DummyManager) + monkeypatch.setattr( + extension_loader_module, + "load_host_policy", + lambda base_path: { + "sandbox_mode": "disabled", + "allow_network": False, + "writable_paths": [], + "readonly_paths": [], + }, + ) + monkeypatch.setattr(extension_loader_module, "is_cache_valid", lambda *args, **kwargs: True) + monkeypatch.setattr( + extension_loader_module, + "load_from_cache", + lambda *args, **kwargs: {"Node": {"display_name": "Node", "schema_v1": {}}}, + ) + monkeypatch.setitem(sys.modules, "folder_paths", SimpleNamespace(base_path=str(tmp_path))) + + asyncio.run( + load_isolated_node( + node_dir, + manifest_path, + logging.getLogger("test"), + lambda *args, **kwargs: object, + tmp_path / "venvs", + [], + ) + ) + + assert captured["package_manager"] == "conda" + assert captured["cuda_wheels"] == { + "index_url": "https://example.invalid/cuda-wheels/", + "packages": [ + "pytorch3d", + "gsplat", + "nvdiffrast", + "flash-attn", + "sageattention", + "spconv", + "cumm", + ], + "package_map": {}, + } + + +def test_load_isolated_node_omits_cuda_wheels_when_not_configured(tmp_path, monkeypatch): + node_dir = tmp_path / "node" + node_dir.mkdir() + manifest_path = node_dir / "pyproject.toml" + _write_manifest( + node_dir, + """ +[project] +name = "demo-node" +dependencies = ["numpy>=1.0"] + +[tool.comfy.isolation] +can_isolate = true +""".strip(), + ) + + captured: dict[str, object] = {} + + class DummyManager: + def __init__(self, *args, **kwargs) -> None: + return None + + def load_extension(self, config): + captured.update(config) + return _DummyExtension() + + monkeypatch.setattr(extension_loader_module.pyisolate, "ExtensionManager", DummyManager) + monkeypatch.setattr( + extension_loader_module, + "load_host_policy", + lambda base_path: { + "sandbox_mode": "disabled", + "allow_network": False, + "writable_paths": [], + "readonly_paths": [], + }, + ) + monkeypatch.setattr(extension_loader_module, "is_cache_valid", lambda *args, **kwargs: True) + monkeypatch.setattr( + extension_loader_module, + "load_from_cache", + lambda *args, **kwargs: {"Node": {"display_name": "Node", "schema_v1": {}}}, + ) + monkeypatch.setitem(sys.modules, "folder_paths", SimpleNamespace(base_path=str(tmp_path))) + + asyncio.run( + load_isolated_node( + node_dir, + manifest_path, + logging.getLogger("test"), + lambda *args, **kwargs: object, + tmp_path / "venvs", + [], + ) + ) + + assert captured["sandbox_mode"] == "disabled" + assert "cuda_wheels" not in captured + + +def test_maybe_wrap_model_for_isolation_uses_runtime_flag(monkeypatch): + class DummyRegistry: + def register(self, model): + return "model-123" + + class DummyProxy: + def __init__(self, model_id, registry, manage_lifecycle): + self.model_id = model_id + self.registry = registry + self.manage_lifecycle = manage_lifecycle + + monkeypatch.setattr(model_patcher_proxy_utils.args, "use_process_isolation", True) + monkeypatch.delenv("PYISOLATE_ISOLATION_ACTIVE", raising=False) + monkeypatch.delenv("PYISOLATE_CHILD", raising=False) + monkeypatch.setitem( + sys.modules, + "comfy.isolation.model_patcher_proxy_registry", + SimpleNamespace(ModelPatcherRegistry=DummyRegistry), + ) + monkeypatch.setitem( + sys.modules, + "comfy.isolation.model_patcher_proxy", + SimpleNamespace(ModelPatcherProxy=DummyProxy), + ) + + wrapped = cast(Any, maybe_wrap_model_for_isolation(object())) + + assert isinstance(wrapped, DummyProxy) + assert getattr(wrapped, "model_id") == "model-123" + assert getattr(wrapped, "manage_lifecycle") is True + + +def test_flush_transport_state_uses_child_env_without_legacy_flag(monkeypatch): + monkeypatch.setenv("PYISOLATE_CHILD", "1") + monkeypatch.delenv("PYISOLATE_ISOLATION_ACTIVE", raising=False) + monkeypatch.setattr(extension_wrapper_module, "_flush_tensor_transport_state", lambda marker: 3) + monkeypatch.setitem( + sys.modules, + "comfy.isolation.model_patcher_proxy_registry", + SimpleNamespace( + ModelPatcherRegistry=lambda: SimpleNamespace( + sweep_pending_cleanup=lambda: 0 + ) + ), + ) + + flushed = asyncio.run( + ComfyNodeExtension.flush_transport_state(SimpleNamespace(name="demo")) + ) + + assert flushed == 3 + + +def test_build_stub_class_relieves_host_vram_without_legacy_flag(monkeypatch): + relieve_calls: list[str] = [] + + async def deserialize_from_isolation(result, extension): + return result + + monkeypatch.delenv("PYISOLATE_CHILD", raising=False) + monkeypatch.delenv("PYISOLATE_ISOLATION_ACTIVE", raising=False) + monkeypatch.setattr( + runtime_helpers, "_relieve_host_vram_pressure", lambda marker, logger: relieve_calls.append(marker) + ) + monkeypatch.setattr(runtime_helpers, "scan_shm_forensics", lambda *args, **kwargs: None) + monkeypatch.setattr(isolation_pkg, "_RUNNING_EXTENSIONS", {}, raising=False) + monkeypatch.setitem( + sys.modules, + "pyisolate._internal.model_serialization", + SimpleNamespace( + serialize_for_isolation=lambda payload: payload, + deserialize_from_isolation=deserialize_from_isolation, + ), + ) + + class DummyExtension: + name = "demo-extension" + module_path = os.getcwd() + + async def execute_node(self, node_name, **inputs): + return inputs + + stub_cls = runtime_helpers.build_stub_class( + "DemoNode", + {"input_types": {}}, + DummyExtension(), + {}, + logging.getLogger("test"), + ) + + result = asyncio.run( + getattr(stub_cls, "_pyisolate_execute")(SimpleNamespace(), value=1) + ) + + assert relieve_calls == ["RUNTIME:pre_execute"] + assert result == {"value": 1} diff --git a/tests/isolation/test_exact_proxy_bootstrap_contract.py b/tests/isolation/test_exact_proxy_bootstrap_contract.py new file mode 100644 index 000000000..c67fb5ac4 --- /dev/null +++ b/tests/isolation/test_exact_proxy_bootstrap_contract.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from tests.isolation.singleton_boundary_helpers import ( + capture_exact_proxy_bootstrap_contract, +) + + +def test_no_proxy_omission_allowed() -> None: + payload = capture_exact_proxy_bootstrap_contract() + + assert payload["omitted_proxies"] == [] + assert payload["forbidden_matches"] == [] + + matrix = payload["matrix"] + assert matrix["base.py"]["bound"] is True + assert matrix["folder_paths_proxy.py"]["bound"] is True + assert matrix["helper_proxies.py"]["bound"] is True + assert matrix["model_management_proxy.py"]["bound"] is True + assert matrix["progress_proxy.py"]["bound"] is True + assert matrix["prompt_server_impl.py"]["bound"] is True + assert matrix["utils_proxy.py"]["bound"] is True + assert matrix["web_directory_proxy.py"]["bound"] is True diff --git a/tests/isolation/test_exact_proxy_relay_matrix.py b/tests/isolation/test_exact_proxy_relay_matrix.py new file mode 100644 index 000000000..ca9dbf94d --- /dev/null +++ b/tests/isolation/test_exact_proxy_relay_matrix.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +from tests.isolation.singleton_boundary_helpers import ( + capture_exact_small_proxy_relay, + capture_model_management_exact_relay, + capture_prompt_web_exact_relay, +) + + +def _transcripts_for(payload: dict[str, object], object_id: str, method: str) -> list[dict[str, object]]: + return [ + entry + for entry in payload["transcripts"] + if entry["object_id"] == object_id and entry["method"] == method + ] + + +def test_folder_paths_exact_relay() -> None: + payload = capture_exact_small_proxy_relay() + + assert payload["forbidden_matches"] == [] + assert payload["models_dir"] == "/sandbox/models" + assert payload["folder_path"] == "/sandbox/input/demo.png" + + models_dir_calls = _transcripts_for(payload, "FolderPathsProxy", "rpc_get_models_dir") + annotated_calls = _transcripts_for(payload, "FolderPathsProxy", "rpc_get_annotated_filepath") + + assert models_dir_calls + assert annotated_calls + assert all(entry["phase"] != "child_call" or entry["method"] != "rpc_snapshot" for entry in payload["transcripts"]) + + +def test_progress_exact_relay() -> None: + payload = capture_exact_small_proxy_relay() + + progress_calls = _transcripts_for(payload, "ProgressProxy", "rpc_set_progress") + + assert progress_calls + host_targets = [entry["target"] for entry in progress_calls if entry["phase"] == "host_invocation"] + assert host_targets == ["comfy_execution.progress.get_progress_state().update_progress"] + result_entries = [entry for entry in progress_calls if entry["phase"] == "result"] + assert result_entries == [{"phase": "result", "object_id": "ProgressProxy", "method": "rpc_set_progress", "result": None}] + + +def test_utils_exact_relay() -> None: + payload = capture_exact_small_proxy_relay() + + utils_calls = _transcripts_for(payload, "UtilsProxy", "progress_bar_hook") + + assert utils_calls + host_targets = [entry["target"] for entry in utils_calls if entry["phase"] == "host_invocation"] + assert host_targets == ["comfy.utils.PROGRESS_BAR_HOOK"] + result_entries = [entry for entry in utils_calls if entry["phase"] == "result"] + assert result_entries + assert result_entries[0]["result"]["value"] == 2 + assert result_entries[0]["result"]["total"] == 5 + + +def test_helper_proxy_exact_relay() -> None: + payload = capture_exact_small_proxy_relay() + + helper_calls = _transcripts_for(payload, "HelperProxiesService", "rpc_restore_input_types") + + assert helper_calls + host_targets = [entry["target"] for entry in helper_calls if entry["phase"] == "host_invocation"] + assert host_targets == ["comfy.isolation.proxies.helper_proxies.restore_input_types"] + assert payload["restored_any_type"] == "*" + + +def test_model_management_exact_relay() -> None: + payload = capture_model_management_exact_relay() + + model_calls = _transcripts_for(payload, "ModelManagementProxy", "get_torch_device") + model_calls += _transcripts_for(payload, "ModelManagementProxy", "get_torch_device_name") + model_calls += _transcripts_for(payload, "ModelManagementProxy", "get_free_memory") + + assert payload["forbidden_matches"] == [] + assert model_calls + host_targets = [ + entry["target"] + for entry in payload["transcripts"] + if entry["phase"] == "host_invocation" + ] + assert host_targets == [ + "comfy.model_management.get_torch_device", + "comfy.model_management.get_torch_device_name", + "comfy.model_management.get_free_memory", + ] + + +def test_model_management_capability_preserved() -> None: + payload = capture_model_management_exact_relay() + + assert payload["device"] == "cpu" + assert payload["device_type"] == "cpu" + assert payload["device_name"] == "cpu" + assert payload["free_memory"] == 34359738368 + + +def test_prompt_server_exact_relay() -> None: + payload = capture_prompt_web_exact_relay() + + prompt_calls = _transcripts_for(payload, "PromptServerService", "ui_send_progress_text") + prompt_calls += _transcripts_for(payload, "PromptServerService", "register_route_rpc") + + assert payload["forbidden_matches"] == [] + assert prompt_calls + host_targets = [ + entry["target"] + for entry in payload["transcripts"] + if entry["object_id"] == "PromptServerService" and entry["phase"] == "host_invocation" + ] + assert host_targets == [ + "server.PromptServer.instance.send_progress_text", + "server.PromptServer.instance.routes.add_route", + ] + + +def test_web_directory_exact_relay() -> None: + payload = capture_prompt_web_exact_relay() + + web_calls = _transcripts_for(payload, "WebDirectoryProxy", "get_web_file") + + assert web_calls + host_targets = [entry["target"] for entry in web_calls if entry["phase"] == "host_invocation"] + assert host_targets == ["comfy.isolation.proxies.web_directory_proxy.WebDirectoryProxy.get_web_file"] + assert payload["web_file"]["content_type"] == "application/javascript" + assert payload["web_file"]["content"] == "console.log('deo');" diff --git a/tests/isolation/test_extension_loader_conda.py b/tests/isolation/test_extension_loader_conda.py new file mode 100644 index 000000000..21154655f --- /dev/null +++ b/tests/isolation/test_extension_loader_conda.py @@ -0,0 +1,428 @@ +"""Tests for conda config parsing in extension_loader.py (Slice 5). + +These tests verify that extension_loader.py correctly parses conda-related +fields from pyproject.toml manifests and passes them into the extension config +dict given to pyisolate. The torch import chain is broken by pre-mocking +extension_wrapper before importing extension_loader. +""" + +from __future__ import annotations + +import importlib +import sys +import types +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +def _make_manifest( + *, + package_manager: str = "uv", + conda_channels: list[str] | None = None, + conda_dependencies: list[str] | None = None, + conda_platforms: list[str] | None = None, + share_torch: bool = False, + can_isolate: bool = True, + dependencies: list[str] | None = None, + cuda_wheels: list[str] | None = None, +) -> dict: + """Build a manifest dict matching tomllib.load() output.""" + isolation: dict = {"can_isolate": can_isolate} + if package_manager != "uv": + isolation["package_manager"] = package_manager + if conda_channels is not None: + isolation["conda_channels"] = conda_channels + if conda_dependencies is not None: + isolation["conda_dependencies"] = conda_dependencies + if conda_platforms is not None: + isolation["conda_platforms"] = conda_platforms + if share_torch: + isolation["share_torch"] = True + if cuda_wheels is not None: + isolation["cuda_wheels"] = cuda_wheels + + return { + "project": { + "name": "test-extension", + "dependencies": dependencies or ["numpy"], + }, + "tool": {"comfy": {"isolation": isolation}}, + } + + +@pytest.fixture +def manifest_file(tmp_path): + """Create a dummy pyproject.toml so manifest_path.open('rb') succeeds.""" + path = tmp_path / "pyproject.toml" + path.write_bytes(b"") # content is overridden by tomllib mock + return path + + +@pytest.fixture +def loader_module(monkeypatch): + """Import extension_loader under a mocked isolation package for this test only.""" + mock_wrapper = MagicMock() + mock_wrapper.ComfyNodeExtension = type("ComfyNodeExtension", (), {}) + + iso_mod = types.ModuleType("comfy.isolation") + iso_mod.__path__ = [ # type: ignore[attr-defined] + str(Path(__file__).resolve().parent.parent.parent / "comfy" / "isolation") + ] + iso_mod.__package__ = "comfy.isolation" + + manifest_loader = types.SimpleNamespace( + is_cache_valid=lambda *args, **kwargs: False, + load_from_cache=lambda *args, **kwargs: None, + save_to_cache=lambda *args, **kwargs: None, + ) + host_policy = types.SimpleNamespace( + load_host_policy=lambda base_path: { + "sandbox_mode": "required", + "allow_network": False, + "writable_paths": [], + "readonly_paths": [], + } + ) + folder_paths = types.SimpleNamespace(base_path="/fake/comfyui") + + monkeypatch.setitem(sys.modules, "comfy.isolation", iso_mod) + monkeypatch.setitem(sys.modules, "comfy.isolation.extension_wrapper", mock_wrapper) + monkeypatch.setitem(sys.modules, "comfy.isolation.runtime_helpers", MagicMock()) + monkeypatch.setitem(sys.modules, "comfy.isolation.manifest_loader", manifest_loader) + monkeypatch.setitem(sys.modules, "comfy.isolation.host_policy", host_policy) + monkeypatch.setitem(sys.modules, "folder_paths", folder_paths) + sys.modules.pop("comfy.isolation.extension_loader", None) + + module = importlib.import_module("comfy.isolation.extension_loader") + try: + yield module, mock_wrapper + finally: + sys.modules.pop("comfy.isolation.extension_loader", None) + comfy_pkg = sys.modules.get("comfy") + if comfy_pkg is not None and hasattr(comfy_pkg, "isolation"): + delattr(comfy_pkg, "isolation") + + +@pytest.fixture +def mock_pyisolate(loader_module): + """Mock pyisolate to avoid real venv creation.""" + module, mock_wrapper = loader_module + mock_ext = AsyncMock() + mock_ext.list_nodes = AsyncMock(return_value={}) + + mock_manager = MagicMock() + mock_manager.load_extension = MagicMock(return_value=mock_ext) + sealed_type = type("SealedNodeExtension", (), {}) + + with patch.object(module, "pyisolate") as mock_pi: + mock_pi.ExtensionManager = MagicMock(return_value=mock_manager) + mock_pi.SealedNodeExtension = sealed_type + yield module, mock_pi, mock_manager, mock_ext, mock_wrapper + + +def load_isolated_node(*args, **kwargs): + return sys.modules["comfy.isolation.extension_loader"].load_isolated_node( + *args, **kwargs + ) + + +class TestCondaPackageManagerParsing: + """Verify extension_loader.py parses conda config from pyproject.toml.""" + + @pytest.mark.asyncio + async def test_conda_package_manager_in_config( + self, mock_pyisolate, manifest_file, tmp_path + ): + """package_manager='conda' must appear in extension_config.""" + + manifest = _make_manifest( + package_manager="conda", + conda_channels=["conda-forge"], + conda_dependencies=["eccodes"], + ) + + _, _, mock_manager, _, _ = mock_pyisolate + + with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib: + mock_tomllib.load.return_value = manifest + await load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_file, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) + + config = mock_manager.load_extension.call_args[0][0] + assert config["package_manager"] == "conda" + + @pytest.mark.asyncio + async def test_conda_channels_in_config( + self, mock_pyisolate, manifest_file, tmp_path + ): + """conda_channels must be passed through to extension_config.""" + + manifest = _make_manifest( + package_manager="conda", + conda_channels=["conda-forge", "nvidia"], + conda_dependencies=["eccodes"], + ) + + _, _, mock_manager, _, _ = mock_pyisolate + + with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib: + mock_tomllib.load.return_value = manifest + await load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_file, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) + + config = mock_manager.load_extension.call_args[0][0] + assert config["conda_channels"] == ["conda-forge", "nvidia"] + + @pytest.mark.asyncio + async def test_conda_dependencies_in_config( + self, mock_pyisolate, manifest_file, tmp_path + ): + """conda_dependencies must be passed through to extension_config.""" + + manifest = _make_manifest( + package_manager="conda", + conda_channels=["conda-forge"], + conda_dependencies=["eccodes", "cfgrib"], + ) + + _, _, mock_manager, _, _ = mock_pyisolate + + with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib: + mock_tomllib.load.return_value = manifest + await load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_file, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) + + config = mock_manager.load_extension.call_args[0][0] + assert config["conda_dependencies"] == ["eccodes", "cfgrib"] + + @pytest.mark.asyncio + async def test_conda_platforms_in_config( + self, mock_pyisolate, manifest_file, tmp_path + ): + """conda_platforms must be passed through to extension_config.""" + + manifest = _make_manifest( + package_manager="conda", + conda_channels=["conda-forge"], + conda_dependencies=["eccodes"], + conda_platforms=["linux-64"], + ) + + _, _, mock_manager, _, _ = mock_pyisolate + + with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib: + mock_tomllib.load.return_value = manifest + await load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_file, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) + + config = mock_manager.load_extension.call_args[0][0] + assert config["conda_platforms"] == ["linux-64"] + + +class TestCondaForcedOverrides: + """Verify conda forces share_torch=False, share_cuda_ipc=False.""" + + @pytest.mark.asyncio + async def test_conda_forces_share_torch_false( + self, mock_pyisolate, manifest_file, tmp_path + ): + """share_torch must be forced False for conda, even if manifest says True.""" + + manifest = _make_manifest( + package_manager="conda", + conda_channels=["conda-forge"], + conda_dependencies=["eccodes"], + share_torch=True, # manifest requests True — must be overridden + ) + + _, _, mock_manager, _, _ = mock_pyisolate + + with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib: + mock_tomllib.load.return_value = manifest + await load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_file, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) + + config = mock_manager.load_extension.call_args[0][0] + assert config["share_torch"] is False + + @pytest.mark.asyncio + async def test_conda_forces_share_cuda_ipc_false( + self, mock_pyisolate, manifest_file, tmp_path + ): + """share_cuda_ipc must be forced False for conda.""" + + manifest = _make_manifest( + package_manager="conda", + conda_channels=["conda-forge"], + conda_dependencies=["eccodes"], + share_torch=True, + ) + + _, _, mock_manager, _, _ = mock_pyisolate + + with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib: + mock_tomllib.load.return_value = manifest + await load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_file, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) + + config = mock_manager.load_extension.call_args[0][0] + assert config["share_cuda_ipc"] is False + + @pytest.mark.asyncio + async def test_conda_sealed_worker_uses_host_policy_sandbox_config( + self, mock_pyisolate, manifest_file, tmp_path + ): + """Conda sealed_worker must carry the host-policy sandbox config on Linux.""" + + manifest = _make_manifest( + package_manager="conda", + conda_channels=["conda-forge"], + conda_dependencies=["eccodes"], + ) + + _, _, mock_manager, _, _ = mock_pyisolate + + with ( + patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib, + patch( + "comfy.isolation.extension_loader.platform.system", + return_value="Linux", + ), + ): + mock_tomllib.load.return_value = manifest + await load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_file, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) + + config = mock_manager.load_extension.call_args[0][0] + assert config["sandbox"] == { + "network": False, + "writable_paths": [], + "readonly_paths": [], + } + + @pytest.mark.asyncio + async def test_conda_uses_sealed_extension_type( + self, mock_pyisolate, manifest_file, tmp_path + ): + """Conda must not launch through ComfyNodeExtension.""" + + _, mock_pi, _, _, mock_wrapper = mock_pyisolate + manifest = _make_manifest( + package_manager="conda", + conda_channels=["conda-forge"], + conda_dependencies=["eccodes"], + ) + + with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib: + mock_tomllib.load.return_value = manifest + await load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_file, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) + + extension_type = mock_pi.ExtensionManager.call_args[0][0] + assert extension_type.__name__ == "SealedNodeExtension" + assert extension_type is not mock_wrapper.ComfyNodeExtension + + +class TestUvUnchanged: + """Verify uv extensions are NOT affected by conda changes.""" + + @pytest.mark.asyncio + async def test_uv_default_no_conda_keys( + self, mock_pyisolate, manifest_file, tmp_path + ): + """Default uv extension must NOT have package_manager or conda keys.""" + + manifest = _make_manifest() # defaults: uv, no conda fields + + _, _, mock_manager, _, _ = mock_pyisolate + + with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib: + mock_tomllib.load.return_value = manifest + await load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_file, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) + + config = mock_manager.load_extension.call_args[0][0] + # uv extensions should not have conda-specific keys + assert config.get("package_manager", "uv") == "uv" + assert "conda_channels" not in config + assert "conda_dependencies" not in config + + @pytest.mark.asyncio + async def test_uv_keeps_comfy_extension_type( + self, mock_pyisolate, manifest_file, tmp_path + ): + """uv keeps the existing ComfyNodeExtension path.""" + + _, mock_pi, _, _, _ = mock_pyisolate + manifest = _make_manifest() + + with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib: + mock_tomllib.load.return_value = manifest + await load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_file, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) + + extension_type = mock_pi.ExtensionManager.call_args[0][0] + assert extension_type.__name__ == "ComfyNodeExtension" + assert extension_type is not mock_pi.SealedNodeExtension diff --git a/tests/isolation/test_extension_loader_sealed_worker.py b/tests/isolation/test_extension_loader_sealed_worker.py new file mode 100644 index 000000000..d694b178f --- /dev/null +++ b/tests/isolation/test_extension_loader_sealed_worker.py @@ -0,0 +1,281 @@ +"""Tests for execution_model parsing and sealed-worker loader selection.""" + +from __future__ import annotations + +import importlib +import sys +import types +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +def _make_manifest( + *, + package_manager: str = "uv", + execution_model: str | None = None, + can_isolate: bool = True, + dependencies: list[str] | None = None, + sealed_host_ro_paths: list[str] | None = None, +) -> dict: + isolation: dict = {"can_isolate": can_isolate} + if package_manager != "uv": + isolation["package_manager"] = package_manager + if execution_model is not None: + isolation["execution_model"] = execution_model + if sealed_host_ro_paths is not None: + isolation["sealed_host_ro_paths"] = sealed_host_ro_paths + + return { + "project": { + "name": "test-extension", + "dependencies": dependencies or ["numpy"], + }, + "tool": {"comfy": {"isolation": isolation}}, + } + + +@pytest.fixture +def manifest_file(tmp_path): + path = tmp_path / "pyproject.toml" + path.write_bytes(b"") + return path + + +@pytest.fixture +def loader_module(monkeypatch): + mock_wrapper = MagicMock() + mock_wrapper.ComfyNodeExtension = type("ComfyNodeExtension", (), {}) + + iso_mod = types.ModuleType("comfy.isolation") + iso_mod.__path__ = [ # type: ignore[attr-defined] + str(Path(__file__).resolve().parent.parent.parent / "comfy" / "isolation") + ] + iso_mod.__package__ = "comfy.isolation" + + manifest_loader = types.SimpleNamespace( + is_cache_valid=lambda *args, **kwargs: False, + load_from_cache=lambda *args, **kwargs: None, + save_to_cache=lambda *args, **kwargs: None, + ) + host_policy = types.SimpleNamespace( + load_host_policy=lambda base_path: { + "sandbox_mode": "required", + "allow_network": False, + "writable_paths": [], + "readonly_paths": [], + "sealed_worker_ro_import_paths": [], + } + ) + folder_paths = types.SimpleNamespace(base_path="/fake/comfyui") + + monkeypatch.setitem(sys.modules, "comfy.isolation", iso_mod) + monkeypatch.setitem(sys.modules, "comfy.isolation.extension_wrapper", mock_wrapper) + monkeypatch.setitem(sys.modules, "comfy.isolation.runtime_helpers", MagicMock()) + monkeypatch.setitem(sys.modules, "comfy.isolation.manifest_loader", manifest_loader) + monkeypatch.setitem(sys.modules, "comfy.isolation.host_policy", host_policy) + monkeypatch.setitem(sys.modules, "folder_paths", folder_paths) + sys.modules.pop("comfy.isolation.extension_loader", None) + + module = importlib.import_module("comfy.isolation.extension_loader") + try: + yield module + finally: + sys.modules.pop("comfy.isolation.extension_loader", None) + comfy_pkg = sys.modules.get("comfy") + if comfy_pkg is not None and hasattr(comfy_pkg, "isolation"): + delattr(comfy_pkg, "isolation") + + +@pytest.fixture +def mock_pyisolate(loader_module): + mock_ext = AsyncMock() + mock_ext.list_nodes = AsyncMock(return_value={}) + + mock_manager = MagicMock() + mock_manager.load_extension = MagicMock(return_value=mock_ext) + sealed_type = type("SealedNodeExtension", (), {}) + + with patch.object(loader_module, "pyisolate") as mock_pi: + mock_pi.ExtensionManager = MagicMock(return_value=mock_manager) + mock_pi.SealedNodeExtension = sealed_type + yield loader_module, mock_pi, mock_manager, mock_ext, sealed_type + + +def load_isolated_node(*args, **kwargs): + return sys.modules["comfy.isolation.extension_loader"].load_isolated_node(*args, **kwargs) + + +@pytest.mark.asyncio +async def test_uv_sealed_worker_selects_sealed_extension_type( + mock_pyisolate, manifest_file, tmp_path +): + manifest = _make_manifest(execution_model="sealed_worker") + + _, mock_pi, mock_manager, _, sealed_type = mock_pyisolate + + with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib: + mock_tomllib.load.return_value = manifest + await load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_file, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) + + extension_type = mock_pi.ExtensionManager.call_args[0][0] + config = mock_manager.load_extension.call_args[0][0] + assert extension_type is sealed_type + assert config["execution_model"] == "sealed_worker" + assert "apis" not in config + + +@pytest.mark.asyncio +async def test_default_uv_keeps_host_coupled_extension_type( + mock_pyisolate, manifest_file, tmp_path +): + manifest = _make_manifest() + + _, mock_pi, mock_manager, _, sealed_type = mock_pyisolate + + with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib: + mock_tomllib.load.return_value = manifest + await load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_file, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) + + extension_type = mock_pi.ExtensionManager.call_args[0][0] + config = mock_manager.load_extension.call_args[0][0] + assert extension_type is not sealed_type + assert "execution_model" not in config + + +@pytest.mark.asyncio +async def test_conda_without_execution_model_remains_sealed_worker( + mock_pyisolate, manifest_file, tmp_path +): + manifest = _make_manifest(package_manager="conda") + manifest["tool"]["comfy"]["isolation"]["conda_channels"] = ["conda-forge"] + manifest["tool"]["comfy"]["isolation"]["conda_dependencies"] = ["eccodes"] + + _, mock_pi, mock_manager, _, sealed_type = mock_pyisolate + + with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib: + mock_tomllib.load.return_value = manifest + await load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_file, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) + + extension_type = mock_pi.ExtensionManager.call_args[0][0] + config = mock_manager.load_extension.call_args[0][0] + assert extension_type is sealed_type + assert config["execution_model"] == "sealed_worker" + + +@pytest.mark.asyncio +async def test_sealed_worker_uses_host_policy_ro_import_paths( + mock_pyisolate, manifest_file, tmp_path +): + manifest = _make_manifest(execution_model="sealed_worker") + + module, _, mock_manager, _, _ = mock_pyisolate + + with ( + patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib, + patch.object( + module, + "load_host_policy", + return_value={ + "sandbox_mode": "required", + "allow_network": False, + "writable_paths": [], + "readonly_paths": [], + "sealed_worker_ro_import_paths": ["/home/johnj/ComfyUI"], + }, + ), + ): + mock_tomllib.load.return_value = manifest + await load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_file, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) + + config = mock_manager.load_extension.call_args[0][0] + assert config["sealed_host_ro_paths"] == ["/home/johnj/ComfyUI"] + + +@pytest.mark.asyncio +async def test_host_coupled_does_not_emit_sealed_host_ro_paths( + mock_pyisolate, manifest_file, tmp_path +): + manifest = _make_manifest(execution_model="host-coupled") + + module, _, mock_manager, _, _ = mock_pyisolate + + with ( + patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib, + patch.object( + module, + "load_host_policy", + return_value={ + "sandbox_mode": "required", + "allow_network": False, + "writable_paths": [], + "readonly_paths": [], + "sealed_worker_ro_import_paths": ["/home/johnj/ComfyUI"], + }, + ), + ): + mock_tomllib.load.return_value = manifest + await load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_file, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) + + config = mock_manager.load_extension.call_args[0][0] + assert "sealed_host_ro_paths" not in config + + +@pytest.mark.asyncio +async def test_sealed_worker_manifest_ro_import_paths_blocked( + mock_pyisolate, manifest_file, tmp_path +): + manifest = _make_manifest( + execution_model="sealed_worker", + sealed_host_ro_paths=["/home/johnj/ComfyUI"], + ) + + _, _, _mock_manager, _, _ = mock_pyisolate + + with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib: + mock_tomllib.load.return_value = manifest + with pytest.raises(ValueError, match="Manifest field 'sealed_host_ro_paths' is not allowed"): + await load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_file, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) diff --git a/tests/isolation/test_folder_paths_proxy.py b/tests/isolation/test_folder_paths_proxy.py new file mode 100644 index 000000000..451f5e607 --- /dev/null +++ b/tests/isolation/test_folder_paths_proxy.py @@ -0,0 +1,122 @@ +"""Unit tests for FolderPathsProxy.""" + +import pytest +from pathlib import Path + +from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy +from tests.isolation.singleton_boundary_helpers import capture_sealed_singleton_imports + + +class TestFolderPathsProxy: + """Test FolderPathsProxy methods.""" + + @pytest.fixture + def proxy(self): + """Create a FolderPathsProxy instance for testing.""" + return FolderPathsProxy() + + def test_get_temp_directory_returns_string(self, proxy): + """Verify get_temp_directory returns a non-empty string.""" + result = proxy.get_temp_directory() + assert isinstance(result, str), f"Expected str, got {type(result)}" + assert len(result) > 0, "Temp directory path is empty" + + def test_get_temp_directory_returns_absolute_path(self, proxy): + """Verify get_temp_directory returns an absolute path.""" + result = proxy.get_temp_directory() + path = Path(result) + assert path.is_absolute(), f"Path is not absolute: {result}" + + def test_get_input_directory_returns_string(self, proxy): + """Verify get_input_directory returns a non-empty string.""" + result = proxy.get_input_directory() + assert isinstance(result, str), f"Expected str, got {type(result)}" + assert len(result) > 0, "Input directory path is empty" + + def test_get_input_directory_returns_absolute_path(self, proxy): + """Verify get_input_directory returns an absolute path.""" + result = proxy.get_input_directory() + path = Path(result) + assert path.is_absolute(), f"Path is not absolute: {result}" + + def test_get_annotated_filepath_plain_name(self, proxy): + """Verify get_annotated_filepath works with plain filename.""" + result = proxy.get_annotated_filepath("test.png") + assert isinstance(result, str), f"Expected str, got {type(result)}" + assert "test.png" in result, f"Filename not in result: {result}" + + def test_get_annotated_filepath_with_output_annotation(self, proxy): + """Verify get_annotated_filepath handles [output] annotation.""" + result = proxy.get_annotated_filepath("test.png[output]") + assert isinstance(result, str), f"Expected str, got {type(result)}" + assert "test.pn" in result, f"Filename base not in result: {result}" + # Should resolve to output directory + assert "output" in result.lower() or Path(result).parent.name == "output" + + def test_get_annotated_filepath_with_input_annotation(self, proxy): + """Verify get_annotated_filepath handles [input] annotation.""" + result = proxy.get_annotated_filepath("test.png[input]") + assert isinstance(result, str), f"Expected str, got {type(result)}" + assert "test.pn" in result, f"Filename base not in result: {result}" + + def test_get_annotated_filepath_with_temp_annotation(self, proxy): + """Verify get_annotated_filepath handles [temp] annotation.""" + result = proxy.get_annotated_filepath("test.png[temp]") + assert isinstance(result, str), f"Expected str, got {type(result)}" + assert "test.pn" in result, f"Filename base not in result: {result}" + + def test_exists_annotated_filepath_returns_bool(self, proxy): + """Verify exists_annotated_filepath returns a boolean.""" + result = proxy.exists_annotated_filepath("nonexistent.png") + assert isinstance(result, bool), f"Expected bool, got {type(result)}" + + def test_exists_annotated_filepath_nonexistent_file(self, proxy): + """Verify exists_annotated_filepath returns False for nonexistent file.""" + result = proxy.exists_annotated_filepath("definitely_does_not_exist_12345.png") + assert result is False, "Expected False for nonexistent file" + + def test_exists_annotated_filepath_with_annotation(self, proxy): + """Verify exists_annotated_filepath works with annotation suffix.""" + # Even for nonexistent files, should return bool without error + result = proxy.exists_annotated_filepath("test.png[output]") + assert isinstance(result, bool), f"Expected bool, got {type(result)}" + + def test_models_dir_property_returns_string(self, proxy): + """Verify models_dir property returns valid path string.""" + result = proxy.models_dir + assert isinstance(result, str), f"Expected str, got {type(result)}" + assert len(result) > 0, "Models directory path is empty" + + def test_models_dir_is_absolute_path(self, proxy): + """Verify models_dir returns an absolute path.""" + result = proxy.models_dir + path = Path(result) + assert path.is_absolute(), f"Path is not absolute: {result}" + + def test_add_model_folder_path_runs_without_error(self, proxy): + """Verify add_model_folder_path executes without raising.""" + test_path = "/tmp/test_models_florence2" + # Should not raise + proxy.add_model_folder_path("TEST_FLORENCE2", test_path) + + def test_get_folder_paths_returns_list(self, proxy): + """Verify get_folder_paths returns a list.""" + # Use known folder type that should exist + result = proxy.get_folder_paths("checkpoints") + assert isinstance(result, list), f"Expected list, got {type(result)}" + + def test_get_folder_paths_checkpoints_not_empty(self, proxy): + """Verify checkpoints folder paths list is not empty.""" + result = proxy.get_folder_paths("checkpoints") + # Should have at least one checkpoint path registered + assert len(result) > 0, "Checkpoints folder paths is empty" + + def test_sealed_child_safe_uses_rpc_without_importing_folder_paths(self, monkeypatch): + monkeypatch.setenv("PYISOLATE_CHILD", "1") + monkeypatch.setenv("PYISOLATE_IMPORT_TORCH", "0") + + payload = capture_sealed_singleton_imports() + + assert payload["temp_dir"] == "/sandbox/temp" + assert payload["models_dir"] == "/sandbox/models" + assert "folder_paths" not in payload["modules"] diff --git a/tests/isolation/test_host_policy.py b/tests/isolation/test_host_policy.py new file mode 100644 index 000000000..46d06bb38 --- /dev/null +++ b/tests/isolation/test_host_policy.py @@ -0,0 +1,209 @@ +from pathlib import Path + +import pytest + + +def _write_pyproject(path: Path, content: str) -> None: + path.write_text(content, encoding="utf-8") + + +def test_load_host_policy_defaults_when_pyproject_missing(tmp_path): + from comfy.isolation.host_policy import DEFAULT_POLICY, load_host_policy + + policy = load_host_policy(tmp_path) + + assert policy["sandbox_mode"] == DEFAULT_POLICY["sandbox_mode"] + assert policy["allow_network"] == DEFAULT_POLICY["allow_network"] + assert policy["writable_paths"] == DEFAULT_POLICY["writable_paths"] + assert policy["readonly_paths"] == DEFAULT_POLICY["readonly_paths"] + assert policy["whitelist"] == DEFAULT_POLICY["whitelist"] + + +def test_load_host_policy_defaults_when_section_missing(tmp_path): + from comfy.isolation.host_policy import DEFAULT_POLICY, load_host_policy + + _write_pyproject( + tmp_path / "pyproject.toml", + """ +[project] +name = "ComfyUI" +""".strip(), + ) + + policy = load_host_policy(tmp_path) + assert policy["sandbox_mode"] == DEFAULT_POLICY["sandbox_mode"] + assert policy["allow_network"] == DEFAULT_POLICY["allow_network"] + assert policy["whitelist"] == {} + + +def test_load_host_policy_reads_values(tmp_path): + from comfy.isolation.host_policy import load_host_policy + + _write_pyproject( + tmp_path / "pyproject.toml", + """ +[tool.comfy.host] +sandbox_mode = "disabled" +allow_network = true +writable_paths = ["/tmp/a", "/tmp/b"] +readonly_paths = ["/opt/readonly"] + +[tool.comfy.host.whitelist] +ExampleNode = "*" +""".strip(), + ) + + policy = load_host_policy(tmp_path) + assert policy["sandbox_mode"] == "disabled" + assert policy["allow_network"] is True + assert policy["writable_paths"] == ["/tmp/a", "/tmp/b"] + assert policy["readonly_paths"] == ["/opt/readonly"] + assert policy["whitelist"] == {"ExampleNode": "*"} + + +def test_load_host_policy_ignores_invalid_whitelist_type(tmp_path): + from comfy.isolation.host_policy import DEFAULT_POLICY, load_host_policy + + _write_pyproject( + tmp_path / "pyproject.toml", + """ +[tool.comfy.host] +allow_network = true +whitelist = ["bad"] +""".strip(), + ) + + policy = load_host_policy(tmp_path) + assert policy["allow_network"] is True + assert policy["whitelist"] == DEFAULT_POLICY["whitelist"] + + +def test_load_host_policy_ignores_invalid_sandbox_mode(tmp_path): + from comfy.isolation.host_policy import DEFAULT_POLICY, load_host_policy + + _write_pyproject( + tmp_path / "pyproject.toml", + """ +[tool.comfy.host] +sandbox_mode = "surprise" +""".strip(), + ) + + policy = load_host_policy(tmp_path) + + assert policy["sandbox_mode"] == DEFAULT_POLICY["sandbox_mode"] + + +def test_load_host_policy_uses_env_override_path(tmp_path, monkeypatch): + from comfy.isolation.host_policy import load_host_policy + + override_path = tmp_path / "host_policy_override.toml" + _write_pyproject( + override_path, + """ +[tool.comfy.host] +sandbox_mode = "disabled" +allow_network = true +""".strip(), + ) + + monkeypatch.setenv("COMFY_HOST_POLICY_PATH", str(override_path)) + + policy = load_host_policy(tmp_path / "missing-root") + + assert policy["sandbox_mode"] == "disabled" + assert policy["allow_network"] is True + + +def test_disallows_host_tmp_default_or_override_defaults(tmp_path): + from comfy.isolation.host_policy import DEFAULT_POLICY, load_host_policy + + policy = load_host_policy(tmp_path) + + assert "/tmp" not in DEFAULT_POLICY["writable_paths"] + assert "/tmp" not in policy["writable_paths"] + + +def test_disallows_host_tmp_default_or_override_config(tmp_path): + from comfy.isolation.host_policy import load_host_policy + + _write_pyproject( + tmp_path / "pyproject.toml", + """ +[tool.comfy.host] +writable_paths = ["/dev/shm", "/tmp", "/tmp/", "/work/cache"] +""".strip(), + ) + + policy = load_host_policy(tmp_path) + + assert policy["writable_paths"] == ["/dev/shm", "/work/cache"] + + +def test_sealed_worker_ro_import_paths_defaults_off_and_parse(tmp_path): + from comfy.isolation.host_policy import load_host_policy + + policy = load_host_policy(tmp_path) + assert policy["sealed_worker_ro_import_paths"] == [] + + _write_pyproject( + tmp_path / "pyproject.toml", + """ +[tool.comfy.host] +sealed_worker_ro_import_paths = ["/home/johnj/ComfyUI", "/opt/comfy-shared"] +""".strip(), + ) + + policy = load_host_policy(tmp_path) + assert policy["sealed_worker_ro_import_paths"] == [ + "/home/johnj/ComfyUI", + "/opt/comfy-shared", + ] + + +def test_sealed_worker_ro_import_paths_rejects_non_list_or_relative(tmp_path): + from comfy.isolation.host_policy import load_host_policy + + _write_pyproject( + tmp_path / "pyproject.toml", + """ +[tool.comfy.host] +sealed_worker_ro_import_paths = "/home/johnj/ComfyUI" +""".strip(), + ) + with pytest.raises(ValueError, match="must be a list of absolute paths"): + load_host_policy(tmp_path) + + _write_pyproject( + tmp_path / "pyproject.toml", + """ +[tool.comfy.host] +sealed_worker_ro_import_paths = ["relative/path"] +""".strip(), + ) + with pytest.raises(ValueError, match="entries must be absolute paths"): + load_host_policy(tmp_path) + + +def test_host_policy_path_override_controls_ro_import_paths(tmp_path, monkeypatch): + from comfy.isolation.host_policy import load_host_policy + + _write_pyproject( + tmp_path / "pyproject.toml", + """ +[tool.comfy.host] +sealed_worker_ro_import_paths = ["/ignored/base/path"] +""".strip(), + ) + override_path = tmp_path / "host_policy_override.toml" + _write_pyproject( + override_path, + """ +[tool.comfy.host] +sealed_worker_ro_import_paths = ["/override/ro/path"] +""".strip(), + ) + monkeypatch.setenv("COMFY_HOST_POLICY_PATH", str(override_path)) + + policy = load_host_policy(tmp_path) + assert policy["sealed_worker_ro_import_paths"] == ["/override/ro/path"] diff --git a/tests/isolation/test_init.py b/tests/isolation/test_init.py new file mode 100644 index 000000000..c237fe904 --- /dev/null +++ b/tests/isolation/test_init.py @@ -0,0 +1,80 @@ +"""Unit tests for PyIsolate isolation system initialization.""" + +import importlib +import sys + +from tests.isolation.singleton_boundary_helpers import ( + FakeSingletonRPC, + reset_forbidden_singleton_modules, +) + + +def test_log_prefix(): + """Verify LOG_PREFIX constant is correctly defined.""" + from comfy.isolation import LOG_PREFIX + assert LOG_PREFIX == "][" + assert isinstance(LOG_PREFIX, str) + + +def test_module_initialization(): + """Verify module initializes without errors.""" + isolation_pkg = importlib.import_module("comfy.isolation") + assert hasattr(isolation_pkg, "LOG_PREFIX") + assert hasattr(isolation_pkg, "initialize_proxies") + + +class TestInitializeProxies: + def test_initialize_proxies_runs_without_error(self): + from comfy.isolation import initialize_proxies + initialize_proxies() + + def test_initialize_proxies_registers_folder_paths_proxy(self): + from comfy.isolation import initialize_proxies + from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy + initialize_proxies() + proxy = FolderPathsProxy() + assert proxy is not None + assert hasattr(proxy, "get_temp_directory") + + def test_initialize_proxies_registers_model_management_proxy(self): + from comfy.isolation import initialize_proxies + from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy + initialize_proxies() + proxy = ModelManagementProxy() + assert proxy is not None + assert hasattr(proxy, "get_torch_device") + + def test_initialize_proxies_can_be_called_multiple_times(self): + from comfy.isolation import initialize_proxies + initialize_proxies() + initialize_proxies() + initialize_proxies() + + def test_dev_proxies_accessible_when_dev_mode(self, monkeypatch): + """Verify dev mode does not break core proxy initialization.""" + monkeypatch.setenv("PYISOLATE_DEV", "1") + from comfy.isolation import initialize_proxies + from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy + from comfy.isolation.proxies.utils_proxy import UtilsProxy + initialize_proxies() + folder_proxy = FolderPathsProxy() + utils_proxy = UtilsProxy() + assert folder_proxy is not None + assert utils_proxy is not None + + def test_sealed_child_safe_initialize_proxies_avoids_real_utils_import(self, monkeypatch): + monkeypatch.setenv("PYISOLATE_CHILD", "1") + monkeypatch.setenv("PYISOLATE_IMPORT_TORCH", "0") + reset_forbidden_singleton_modules() + + from pyisolate._internal import rpc_protocol + from comfy.isolation import initialize_proxies + + fake_rpc = FakeSingletonRPC() + monkeypatch.setattr(rpc_protocol, "get_child_rpc_instance", lambda: fake_rpc) + + initialize_proxies() + + assert "comfy.utils" not in sys.modules + assert "folder_paths" not in sys.modules + assert "comfy_execution.progress" not in sys.modules diff --git a/tests/isolation/test_internal_probe_node_assets.py b/tests/isolation/test_internal_probe_node_assets.py new file mode 100644 index 000000000..c12cf4404 --- /dev/null +++ b/tests/isolation/test_internal_probe_node_assets.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +import importlib.util +import json +from pathlib import Path + + +COMFYUI_ROOT = Path(__file__).resolve().parents[2] +ISOLATION_ROOT = COMFYUI_ROOT / "tests" / "isolation" +PROBE_ROOT = ISOLATION_ROOT / "internal_probe_node" +WORKFLOW_ROOT = ISOLATION_ROOT / "workflows" +TOOLKIT_ROOT = COMFYUI_ROOT / "custom_nodes" / "ComfyUI-IsolationToolkit" + +EXPECTED_PROBE_FILES = { + "__init__.py", + "probe_nodes.py", +} +EXPECTED_WORKFLOWS = { + "internal_probe_preview_image_audio.json", + "internal_probe_ui3d.json", +} +BANNED_REFERENCES = ( + "ComfyUI-IsolationToolkit", + "toolkit_smoke_playlist", + "run_isolation_toolkit_smoke.sh", +) + + +def _text_assets() -> list[Path]: + return sorted(list(PROBE_ROOT.rglob("*.py")) + list(WORKFLOW_ROOT.glob("internal_probe_*.json"))) + + +def _load_probe_package(): + spec = importlib.util.spec_from_file_location( + "internal_probe_node", + PROBE_ROOT / "__init__.py", + submodule_search_locations=[str(PROBE_ROOT)], + ) + module = importlib.util.module_from_spec(spec) + assert spec is not None + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +def test_inventory_is_minimal_and_isolation_owned(): + assert PROBE_ROOT.is_dir() + assert WORKFLOW_ROOT.is_dir() + assert PROBE_ROOT.is_relative_to(ISOLATION_ROOT) + assert WORKFLOW_ROOT.is_relative_to(ISOLATION_ROOT) + assert not PROBE_ROOT.is_relative_to(TOOLKIT_ROOT) + + probe_files = {path.name for path in PROBE_ROOT.iterdir() if path.is_file()} + workflow_files = {path.name for path in WORKFLOW_ROOT.glob("internal_probe_*.json")} + + assert probe_files == EXPECTED_PROBE_FILES + assert workflow_files == EXPECTED_WORKFLOWS + + module = _load_probe_package() + mappings = module.NODE_CLASS_MAPPINGS + + assert sorted(mappings.keys()) == [ + "InternalIsolationProbeAudio", + "InternalIsolationProbeImage", + "InternalIsolationProbeUI3D", + ] + + preview_workflow = json.loads( + (WORKFLOW_ROOT / "internal_probe_preview_image_audio.json").read_text( + encoding="utf-8" + ) + ) + ui3d_workflow = json.loads( + (WORKFLOW_ROOT / "internal_probe_ui3d.json").read_text(encoding="utf-8") + ) + + assert [preview_workflow[node_id]["class_type"] for node_id in ("1", "2")] == [ + "InternalIsolationProbeImage", + "InternalIsolationProbeAudio", + ] + assert [ui3d_workflow[node_id]["class_type"] for node_id in ("1",)] == [ + "InternalIsolationProbeUI3D", + ] + + +def test_zero_toolkit_references_in_probe_assets(): + for asset in _text_assets(): + content = asset.read_text(encoding="utf-8") + for banned in BANNED_REFERENCES: + assert banned not in content, f"{asset} unexpectedly references {banned}" + + +def test_replacement_contract_has_zero_toolkit_references(): + contract_assets = [ + *(PROBE_ROOT.rglob("*.py")), + *WORKFLOW_ROOT.glob("internal_probe_*.json"), + ISOLATION_ROOT / "stage_internal_probe_node.py", + ISOLATION_ROOT / "internal_probe_host_policy.toml", + ] + + for asset in sorted(contract_assets): + assert asset.exists(), f"Missing replacement-contract asset: {asset}" + content = asset.read_text(encoding="utf-8") + for banned in BANNED_REFERENCES: + assert banned not in content, f"{asset} unexpectedly references {banned}" diff --git a/tests/isolation/test_internal_probe_node_loading.py b/tests/isolation/test_internal_probe_node_loading.py new file mode 100644 index 000000000..fd1a7268c --- /dev/null +++ b/tests/isolation/test_internal_probe_node_loading.py @@ -0,0 +1,180 @@ +from __future__ import annotations + +import json +import os +import shutil +import subprocess +import sys +from pathlib import Path + +import pytest + +import nodes +from tests.isolation.stage_internal_probe_node import ( + PROBE_NODE_NAME, + stage_probe_node, + staged_probe_node, +) + + +COMFYUI_ROOT = Path(__file__).resolve().parents[2] +ISOLATION_ROOT = COMFYUI_ROOT / "tests" / "isolation" +PROBE_SOURCE_ROOT = ISOLATION_ROOT / "internal_probe_node" +EXPECTED_NODE_IDS = [ + "InternalIsolationProbeAudio", + "InternalIsolationProbeImage", + "InternalIsolationProbeUI3D", +] + +CLIENT_SCRIPT = """ +import importlib.util +import json +import os +import sys + +import pyisolate._internal.client # noqa: F401 # triggers snapshot bootstrap + +module_path = os.environ["PYISOLATE_MODULE_PATH"] +spec = importlib.util.spec_from_file_location( + "internal_probe_node", + os.path.join(module_path, "__init__.py"), + submodule_search_locations=[module_path], +) +module = importlib.util.module_from_spec(spec) +assert spec is not None +assert spec.loader is not None +sys.modules["internal_probe_node"] = module +spec.loader.exec_module(module) +print( + json.dumps( + { + "sys_path": list(sys.path), + "module_path": module_path, + "node_ids": sorted(module.NODE_CLASS_MAPPINGS.keys()), + } + ) +) +""" + + +def _run_client_process(env: dict[str, str]) -> dict: + pythonpath_parts = [str(COMFYUI_ROOT)] + existing = env.get("PYTHONPATH", "") + if existing: + pythonpath_parts.append(existing) + env["PYTHONPATH"] = ":".join(pythonpath_parts) + + result = subprocess.run( # noqa: S603 + [sys.executable, "-c", CLIENT_SCRIPT], + capture_output=True, + text=True, + env=env, + check=True, + ) + return json.loads(result.stdout.strip().splitlines()[-1]) + + +@pytest.fixture() +def staged_probe_module(tmp_path: Path) -> tuple[Path, Path]: + staged_comfy_root = tmp_path / "ComfyUI" + module_path = staged_comfy_root / "custom_nodes" / "InternalIsolationProbeNode" + shutil.copytree(PROBE_SOURCE_ROOT, module_path) + return staged_comfy_root, module_path + + +@pytest.mark.asyncio +async def test_staged_probe_node_discovered(staged_probe_module: tuple[Path, Path]) -> None: + _, module_path = staged_probe_module + class_mappings_snapshot = dict(nodes.NODE_CLASS_MAPPINGS) + display_name_snapshot = dict(nodes.NODE_DISPLAY_NAME_MAPPINGS) + loaded_module_dirs_snapshot = dict(nodes.LOADED_MODULE_DIRS) + + try: + ignore = set(nodes.NODE_CLASS_MAPPINGS.keys()) + loaded = await nodes.load_custom_node( + str(module_path), ignore=ignore, module_parent="custom_nodes" + ) + + assert loaded is True + assert nodes.LOADED_MODULE_DIRS["InternalIsolationProbeNode"] == str( + module_path.resolve() + ) + + for node_id in EXPECTED_NODE_IDS: + assert node_id in nodes.NODE_CLASS_MAPPINGS + node_cls = nodes.NODE_CLASS_MAPPINGS[node_id] + assert ( + getattr(node_cls, "RELATIVE_PYTHON_MODULE", None) + == "custom_nodes.InternalIsolationProbeNode" + ) + finally: + nodes.NODE_CLASS_MAPPINGS.clear() + nodes.NODE_CLASS_MAPPINGS.update(class_mappings_snapshot) + nodes.NODE_DISPLAY_NAME_MAPPINGS.clear() + nodes.NODE_DISPLAY_NAME_MAPPINGS.update(display_name_snapshot) + nodes.LOADED_MODULE_DIRS.clear() + nodes.LOADED_MODULE_DIRS.update(loaded_module_dirs_snapshot) + + +def test_staged_probe_node_module_path_is_valid_for_child_bootstrap( + tmp_path: Path, staged_probe_module: tuple[Path, Path] +) -> None: + staged_comfy_root, module_path = staged_probe_module + snapshot = { + "sys_path": [str(COMFYUI_ROOT), "/host/lib1", "/host/lib2"], + "sys_executable": sys.executable, + "sys_prefix": sys.prefix, + "environment": {}, + } + snapshot_path = tmp_path / "snapshot.json" + snapshot_path.write_text(json.dumps(snapshot), encoding="utf-8") + + env = os.environ.copy() + env.update( + { + "PYISOLATE_CHILD": "1", + "PYISOLATE_HOST_SNAPSHOT": str(snapshot_path), + "PYISOLATE_MODULE_PATH": str(module_path), + } + ) + + payload = _run_client_process(env) + + assert payload["module_path"] == str(module_path) + assert payload["node_ids"] == EXPECTED_NODE_IDS + assert str(COMFYUI_ROOT) in payload["sys_path"] + assert str(staged_comfy_root) not in payload["sys_path"] + + +def test_stage_probe_node_stages_only_under_explicit_root(tmp_path: Path) -> None: + comfy_root = tmp_path / "sandbox-root" + + module_path = stage_probe_node(comfy_root) + + assert module_path == comfy_root / "custom_nodes" / PROBE_NODE_NAME + assert module_path.is_dir() + assert (module_path / "__init__.py").is_file() + assert (module_path / "probe_nodes.py").is_file() + assert (module_path / "pyproject.toml").is_file() + + +def test_staged_probe_node_context_cleans_up_temp_root() -> None: + with staged_probe_node() as module_path: + staging_root = module_path.parents[1] + assert module_path.name == PROBE_NODE_NAME + assert module_path.is_dir() + assert staging_root.is_dir() + + assert not staging_root.exists() + + +def test_stage_script_requires_explicit_target_root() -> None: + result = subprocess.run( # noqa: S603 + [sys.executable, str(ISOLATION_ROOT / "stage_internal_probe_node.py")], + capture_output=True, + text=True, + check=False, + ) + + assert result.returncode != 0 + assert "--target-root" in result.stderr diff --git a/tests/isolation/test_manifest_loader_cache.py b/tests/isolation/test_manifest_loader_cache.py new file mode 100644 index 000000000..ebee43b7e --- /dev/null +++ b/tests/isolation/test_manifest_loader_cache.py @@ -0,0 +1,434 @@ +""" +Unit tests for manifest_loader.py cache functions. + +Phase 1 tests verify: +1. Cache miss on first run (no cache exists) +2. Cache hit when nothing changes +3. Invalidation on .py file touch +4. Invalidation on manifest change +5. Cache location correctness (in venv_root, NOT in custom_nodes) +6. Corrupt cache handling (graceful failure) + +These tests verify the cache implementation is correct BEFORE it's activated +in extension_loader.py (Phase 2). +""" + +from __future__ import annotations + +import json +import sys +import time +from pathlib import Path +from unittest import mock + + + +class TestComputeCacheKey: + """Tests for compute_cache_key() function.""" + + def test_key_includes_manifest_content(self, tmp_path: Path) -> None: + """Cache key changes when manifest content changes.""" + from comfy.isolation.manifest_loader import compute_cache_key + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + + # Initial manifest + manifest.write_text("isolated: true\ndependencies: []\n") + key1 = compute_cache_key(node_dir, manifest) + + # Modified manifest + manifest.write_text("isolated: true\ndependencies: [numpy]\n") + key2 = compute_cache_key(node_dir, manifest) + + assert key1 != key2, "Key should change when manifest content changes" + + def test_key_includes_py_file_mtime(self, tmp_path: Path) -> None: + """Cache key changes when any .py file is touched.""" + from comfy.isolation.manifest_loader import compute_cache_key + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + + py_file = node_dir / "nodes.py" + py_file.write_text("# test code") + + key1 = compute_cache_key(node_dir, manifest) + + # Wait a moment to ensure mtime changes + time.sleep(0.01) + py_file.write_text("# modified code") + + key2 = compute_cache_key(node_dir, manifest) + + assert key1 != key2, "Key should change when .py file mtime changes" + + def test_key_includes_python_version(self, tmp_path: Path) -> None: + """Cache key changes when Python version changes.""" + from comfy.isolation.manifest_loader import compute_cache_key + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + + key1 = compute_cache_key(node_dir, manifest) + + # Mock different Python version + with mock.patch.object(sys, "version", "3.99.0 (fake)"): + key2 = compute_cache_key(node_dir, manifest) + + assert key1 != key2, "Key should change when Python version changes" + + def test_key_includes_pyisolate_version(self, tmp_path: Path) -> None: + """Cache key changes when PyIsolate version changes.""" + from comfy.isolation.manifest_loader import compute_cache_key + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + + key1 = compute_cache_key(node_dir, manifest) + + # Mock different pyisolate version + with mock.patch.dict(sys.modules, {"pyisolate": mock.MagicMock(__version__="99.99.99")}): + # Need to reimport to pick up the mock + import importlib + from comfy.isolation import manifest_loader + importlib.reload(manifest_loader) + key2 = manifest_loader.compute_cache_key(node_dir, manifest) + + # Keys should be different (though the mock approach is tricky) + # At minimum, verify key is a valid hex string + assert len(key1) == 16, "Key should be 16 hex characters" + assert all(c in "0123456789abcdef" for c in key1), "Key should be hex" + assert len(key2) == 16, "Key should be 16 hex characters" + assert all(c in "0123456789abcdef" for c in key2), "Key should be hex" + + def test_key_excludes_pycache(self, tmp_path: Path) -> None: + """Cache key ignores __pycache__ directory changes.""" + from comfy.isolation.manifest_loader import compute_cache_key + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + + py_file = node_dir / "nodes.py" + py_file.write_text("# test code") + + key1 = compute_cache_key(node_dir, manifest) + + # Add __pycache__ file + pycache = node_dir / "__pycache__" + pycache.mkdir() + (pycache / "nodes.cpython-310.pyc").write_bytes(b"compiled") + + key2 = compute_cache_key(node_dir, manifest) + + assert key1 == key2, "Key should NOT change when __pycache__ modified" + + def test_key_is_deterministic(self, tmp_path: Path) -> None: + """Same inputs produce same key.""" + from comfy.isolation.manifest_loader import compute_cache_key + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + (node_dir / "nodes.py").write_text("# code") + + key1 = compute_cache_key(node_dir, manifest) + key2 = compute_cache_key(node_dir, manifest) + + assert key1 == key2, "Key should be deterministic" + + +class TestGetCachePath: + """Tests for get_cache_path() function.""" + + def test_returns_correct_paths(self, tmp_path: Path) -> None: + """Cache paths are in venv_root, not in node_dir.""" + from comfy.isolation.manifest_loader import get_cache_path + + node_dir = tmp_path / "custom_nodes" / "MyNode" + venv_root = tmp_path / ".pyisolate_venvs" + + key_file, data_file = get_cache_path(node_dir, venv_root) + + assert key_file == venv_root / "MyNode" / "cache" / "cache_key" + assert data_file == venv_root / "MyNode" / "cache" / "node_info.json" + + def test_cache_not_in_custom_nodes(self, tmp_path: Path) -> None: + """Verify cache is NOT stored in custom_nodes directory.""" + from comfy.isolation.manifest_loader import get_cache_path + + node_dir = tmp_path / "custom_nodes" / "MyNode" + venv_root = tmp_path / ".pyisolate_venvs" + + key_file, data_file = get_cache_path(node_dir, venv_root) + + # Neither path should be under node_dir + assert not str(key_file).startswith(str(node_dir)) + assert not str(data_file).startswith(str(node_dir)) + + +class TestIsCacheValid: + """Tests for is_cache_valid() function.""" + + def test_false_when_no_cache_exists(self, tmp_path: Path) -> None: + """Returns False when cache files don't exist.""" + from comfy.isolation.manifest_loader import is_cache_valid + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + venv_root = tmp_path / ".pyisolate_venvs" + + assert is_cache_valid(node_dir, manifest, venv_root) is False + + def test_true_when_cache_matches(self, tmp_path: Path) -> None: + """Returns True when cache key matches current state.""" + from comfy.isolation.manifest_loader import ( + compute_cache_key, + get_cache_path, + is_cache_valid, + ) + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + (node_dir / "nodes.py").write_text("# code") + venv_root = tmp_path / ".pyisolate_venvs" + + # Create valid cache + cache_key = compute_cache_key(node_dir, manifest) + key_file, data_file = get_cache_path(node_dir, venv_root) + key_file.parent.mkdir(parents=True, exist_ok=True) + key_file.write_text(cache_key) + data_file.write_text("{}") + + assert is_cache_valid(node_dir, manifest, venv_root) is True + + def test_false_when_key_mismatch(self, tmp_path: Path) -> None: + """Returns False when stored key doesn't match current state.""" + from comfy.isolation.manifest_loader import get_cache_path, is_cache_valid + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + venv_root = tmp_path / ".pyisolate_venvs" + + # Create cache with wrong key + key_file, data_file = get_cache_path(node_dir, venv_root) + key_file.parent.mkdir(parents=True, exist_ok=True) + key_file.write_text("wrong_key_12345") + data_file.write_text("{}") + + assert is_cache_valid(node_dir, manifest, venv_root) is False + + def test_false_when_data_file_missing(self, tmp_path: Path) -> None: + """Returns False when node_info.json is missing.""" + from comfy.isolation.manifest_loader import ( + compute_cache_key, + get_cache_path, + is_cache_valid, + ) + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + venv_root = tmp_path / ".pyisolate_venvs" + + # Create only key file, not data file + cache_key = compute_cache_key(node_dir, manifest) + key_file, _ = get_cache_path(node_dir, venv_root) + key_file.parent.mkdir(parents=True, exist_ok=True) + key_file.write_text(cache_key) + + assert is_cache_valid(node_dir, manifest, venv_root) is False + + def test_invalidation_on_py_change(self, tmp_path: Path) -> None: + """Cache invalidates when .py file is modified.""" + from comfy.isolation.manifest_loader import ( + compute_cache_key, + get_cache_path, + is_cache_valid, + ) + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + py_file = node_dir / "nodes.py" + py_file.write_text("# original") + venv_root = tmp_path / ".pyisolate_venvs" + + # Create valid cache + cache_key = compute_cache_key(node_dir, manifest) + key_file, data_file = get_cache_path(node_dir, venv_root) + key_file.parent.mkdir(parents=True, exist_ok=True) + key_file.write_text(cache_key) + data_file.write_text("{}") + + # Verify cache is valid initially + assert is_cache_valid(node_dir, manifest, venv_root) is True + + # Modify .py file + time.sleep(0.01) # Ensure mtime changes + py_file.write_text("# modified") + + # Cache should now be invalid + assert is_cache_valid(node_dir, manifest, venv_root) is False + + +class TestLoadFromCache: + """Tests for load_from_cache() function.""" + + def test_returns_none_when_no_cache(self, tmp_path: Path) -> None: + """Returns None when cache doesn't exist.""" + from comfy.isolation.manifest_loader import load_from_cache + + node_dir = tmp_path / "test_node" + venv_root = tmp_path / ".pyisolate_venvs" + + assert load_from_cache(node_dir, venv_root) is None + + def test_returns_data_when_valid(self, tmp_path: Path) -> None: + """Returns cached data when file exists and is valid JSON.""" + from comfy.isolation.manifest_loader import get_cache_path, load_from_cache + + node_dir = tmp_path / "test_node" + venv_root = tmp_path / ".pyisolate_venvs" + + test_data = {"TestNode": {"inputs": [], "outputs": []}} + + _, data_file = get_cache_path(node_dir, venv_root) + data_file.parent.mkdir(parents=True, exist_ok=True) + data_file.write_text(json.dumps(test_data)) + + result = load_from_cache(node_dir, venv_root) + assert result == test_data + + def test_returns_none_on_corrupt_json(self, tmp_path: Path) -> None: + """Returns None when JSON is corrupt.""" + from comfy.isolation.manifest_loader import get_cache_path, load_from_cache + + node_dir = tmp_path / "test_node" + venv_root = tmp_path / ".pyisolate_venvs" + + _, data_file = get_cache_path(node_dir, venv_root) + data_file.parent.mkdir(parents=True, exist_ok=True) + data_file.write_text("{ corrupt json }") + + assert load_from_cache(node_dir, venv_root) is None + + def test_returns_none_on_invalid_structure(self, tmp_path: Path) -> None: + """Returns None when data is not a dict.""" + from comfy.isolation.manifest_loader import get_cache_path, load_from_cache + + node_dir = tmp_path / "test_node" + venv_root = tmp_path / ".pyisolate_venvs" + + _, data_file = get_cache_path(node_dir, venv_root) + data_file.parent.mkdir(parents=True, exist_ok=True) + data_file.write_text("[1, 2, 3]") # Array, not dict + + assert load_from_cache(node_dir, venv_root) is None + + +class TestSaveToCache: + """Tests for save_to_cache() function.""" + + def test_creates_cache_directory(self, tmp_path: Path) -> None: + """Creates cache directory if it doesn't exist.""" + from comfy.isolation.manifest_loader import get_cache_path, save_to_cache + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + venv_root = tmp_path / ".pyisolate_venvs" + + save_to_cache(node_dir, venv_root, {"TestNode": {}}, manifest) + + key_file, data_file = get_cache_path(node_dir, venv_root) + assert key_file.parent.exists() + + def test_writes_both_files(self, tmp_path: Path) -> None: + """Writes both cache_key and node_info.json.""" + from comfy.isolation.manifest_loader import get_cache_path, save_to_cache + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + venv_root = tmp_path / ".pyisolate_venvs" + + save_to_cache(node_dir, venv_root, {"TestNode": {"key": "value"}}, manifest) + + key_file, data_file = get_cache_path(node_dir, venv_root) + assert key_file.exists() + assert data_file.exists() + + def test_data_is_valid_json(self, tmp_path: Path) -> None: + """Written data can be parsed as JSON.""" + from comfy.isolation.manifest_loader import get_cache_path, save_to_cache + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + venv_root = tmp_path / ".pyisolate_venvs" + + test_data = {"TestNode": {"inputs": ["IMAGE"], "outputs": ["IMAGE"]}} + save_to_cache(node_dir, venv_root, test_data, manifest) + + _, data_file = get_cache_path(node_dir, venv_root) + loaded = json.loads(data_file.read_text()) + assert loaded == test_data + + def test_roundtrip_with_validation(self, tmp_path: Path) -> None: + """Saved cache is immediately valid.""" + from comfy.isolation.manifest_loader import ( + is_cache_valid, + load_from_cache, + save_to_cache, + ) + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + (node_dir / "nodes.py").write_text("# code") + venv_root = tmp_path / ".pyisolate_venvs" + + test_data = {"TestNode": {"foo": "bar"}} + save_to_cache(node_dir, venv_root, test_data, manifest) + + assert is_cache_valid(node_dir, manifest, venv_root) is True + assert load_from_cache(node_dir, venv_root) == test_data + + def test_cache_not_in_custom_nodes(self, tmp_path: Path) -> None: + """Verify no files written to custom_nodes directory.""" + from comfy.isolation.manifest_loader import save_to_cache + + node_dir = tmp_path / "custom_nodes" / "MyNode" + node_dir.mkdir(parents=True) + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + venv_root = tmp_path / ".pyisolate_venvs" + + save_to_cache(node_dir, venv_root, {"TestNode": {}}, manifest) + + # Check nothing was created under node_dir + for item in node_dir.iterdir(): + assert item.name == "pyisolate.yaml", f"Unexpected file in node_dir: {item}" diff --git a/tests/isolation/test_manifest_loader_discovery.py b/tests/isolation/test_manifest_loader_discovery.py new file mode 100644 index 000000000..101b5d1e2 --- /dev/null +++ b/tests/isolation/test_manifest_loader_discovery.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import importlib +import sys +from pathlib import Path +from types import ModuleType + + +def _write_manifest(path: Path, *, standalone: bool = False) -> None: + lines = [ + "[project]", + 'name = "test-node"', + 'version = "0.1.0"', + "", + "[tool.comfy.isolation]", + "can_isolate = true", + "share_torch = false", + ] + if standalone: + lines.append("standalone = true") + path.write_text("\n".join(lines) + "\n", encoding="utf-8") + + +def _load_manifest_loader(custom_nodes_root: Path): + folder_paths = ModuleType("folder_paths") + folder_paths.base_path = str(custom_nodes_root) + folder_paths.get_folder_paths = lambda kind: [str(custom_nodes_root)] if kind == "custom_nodes" else [] + sys.modules["folder_paths"] = folder_paths + + if "comfy.isolation" not in sys.modules: + iso_mod = ModuleType("comfy.isolation") + iso_mod.__path__ = [ # type: ignore[attr-defined] + str(Path(__file__).resolve().parent.parent.parent / "comfy" / "isolation") + ] + iso_mod.__package__ = "comfy.isolation" + sys.modules["comfy.isolation"] = iso_mod + + sys.modules.pop("comfy.isolation.manifest_loader", None) + + import comfy.isolation.manifest_loader as manifest_loader + + return importlib.reload(manifest_loader) + + +def test_finds_top_level_isolation_manifest(tmp_path: Path) -> None: + node_dir = tmp_path / "TopLevelNode" + node_dir.mkdir(parents=True) + _write_manifest(node_dir / "pyproject.toml") + + manifest_loader = _load_manifest_loader(tmp_path) + manifests = manifest_loader.find_manifest_directories() + + assert manifests == [(node_dir, node_dir / "pyproject.toml")] + + +def test_ignores_nested_manifest_without_standalone_flag(tmp_path: Path) -> None: + toolkit_dir = tmp_path / "ToolkitNode" + toolkit_dir.mkdir(parents=True) + _write_manifest(toolkit_dir / "pyproject.toml") + + nested_dir = toolkit_dir / "packages" / "nested_fixture" + nested_dir.mkdir(parents=True) + _write_manifest(nested_dir / "pyproject.toml", standalone=False) + + manifest_loader = _load_manifest_loader(tmp_path) + manifests = manifest_loader.find_manifest_directories() + + assert manifests == [(toolkit_dir, toolkit_dir / "pyproject.toml")] + + +def test_finds_nested_standalone_manifest(tmp_path: Path) -> None: + toolkit_dir = tmp_path / "ToolkitNode" + toolkit_dir.mkdir(parents=True) + _write_manifest(toolkit_dir / "pyproject.toml") + + nested_dir = toolkit_dir / "packages" / "uv_sealed_worker" + nested_dir.mkdir(parents=True) + _write_manifest(nested_dir / "pyproject.toml", standalone=True) + + manifest_loader = _load_manifest_loader(tmp_path) + manifests = manifest_loader.find_manifest_directories() + + assert manifests == [ + (toolkit_dir, toolkit_dir / "pyproject.toml"), + (nested_dir, nested_dir / "pyproject.toml"), + ] diff --git a/tests/isolation/test_model_management_proxy.py b/tests/isolation/test_model_management_proxy.py new file mode 100644 index 000000000..3a03bd54d --- /dev/null +++ b/tests/isolation/test_model_management_proxy.py @@ -0,0 +1,50 @@ +"""Unit tests for ModelManagementProxy.""" + +import pytest +import torch + +from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy + + +class TestModelManagementProxy: + """Test ModelManagementProxy methods.""" + + @pytest.fixture + def proxy(self): + """Create a ModelManagementProxy instance for testing.""" + return ModelManagementProxy() + + def test_get_torch_device_returns_device(self, proxy): + """Verify get_torch_device returns a torch.device object.""" + result = proxy.get_torch_device() + assert isinstance(result, torch.device), f"Expected torch.device, got {type(result)}" + + def test_get_torch_device_is_valid(self, proxy): + """Verify get_torch_device returns a valid device (cpu or cuda).""" + result = proxy.get_torch_device() + assert result.type in ("cpu", "cuda"), f"Unexpected device type: {result.type}" + + def test_get_torch_device_name_returns_string(self, proxy): + """Verify get_torch_device_name returns a non-empty string.""" + device = proxy.get_torch_device() + result = proxy.get_torch_device_name(device) + assert isinstance(result, str), f"Expected str, got {type(result)}" + assert len(result) > 0, "Device name is empty" + + def test_get_torch_device_name_with_cpu(self, proxy): + """Verify get_torch_device_name works with CPU device.""" + cpu_device = torch.device("cpu") + result = proxy.get_torch_device_name(cpu_device) + assert isinstance(result, str), f"Expected str, got {type(result)}" + assert "cpu" in result.lower(), f"Expected 'cpu' in device name, got: {result}" + + def test_get_torch_device_name_with_cuda_if_available(self, proxy): + """Verify get_torch_device_name works with CUDA device if available.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + cuda_device = torch.device("cuda:0") + result = proxy.get_torch_device_name(cuda_device) + assert isinstance(result, str), f"Expected str, got {type(result)}" + # Should contain device identifier + assert len(result) > 0, "CUDA device name is empty" diff --git a/tests/isolation/test_path_helpers.py b/tests/isolation/test_path_helpers.py new file mode 100644 index 000000000..af96f1fe0 --- /dev/null +++ b/tests/isolation/test_path_helpers.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +import json +import os +import sys +from pathlib import Path + +import pytest + +from pyisolate.path_helpers import build_child_sys_path, serialize_host_snapshot + + +def test_serialize_host_snapshot_includes_expected_keys(tmp_path: Path, monkeypatch) -> None: + output = tmp_path / "snapshot.json" + monkeypatch.setenv("EXTRA_FLAG", "1") + snapshot = serialize_host_snapshot(output_path=output, extra_env_keys=["EXTRA_FLAG"]) + + assert "sys_path" in snapshot + assert "sys_executable" in snapshot + assert "sys_prefix" in snapshot + assert "environment" in snapshot + assert output.exists() + assert snapshot["environment"].get("EXTRA_FLAG") == "1" + + persisted = json.loads(output.read_text(encoding="utf-8")) + assert persisted["sys_path"] == snapshot["sys_path"] + + +def test_build_child_sys_path_preserves_host_order() -> None: + host_paths = ["/host/root", "/host/site-packages"] + extra_paths = ["/node/.venv/lib/python3.12/site-packages"] + result = build_child_sys_path(host_paths, extra_paths, preferred_root=None) + assert result == host_paths + extra_paths + + +def test_build_child_sys_path_inserts_comfy_root_when_missing() -> None: + host_paths = ["/host/site-packages"] + comfy_root = os.environ.get("COMFYUI_ROOT") or str(Path.home() / "ComfyUI") + extra_paths: list[str] = [] + result = build_child_sys_path(host_paths, extra_paths, preferred_root=comfy_root) + assert result[0] == comfy_root + assert result[1:] == host_paths + + +def test_build_child_sys_path_deduplicates_entries(tmp_path: Path) -> None: + path_a = str(tmp_path / "a") + path_b = str(tmp_path / "b") + host_paths = [path_a, path_b] + extra_paths = [path_a, path_b, str(tmp_path / "c")] + result = build_child_sys_path(host_paths, extra_paths) + assert result == [path_a, path_b, str(tmp_path / "c")] + + +def test_build_child_sys_path_skips_duplicate_comfy_root() -> None: + comfy_root = os.environ.get("COMFYUI_ROOT") or str(Path.home() / "ComfyUI") + host_paths = [comfy_root, "/host/other"] + result = build_child_sys_path(host_paths, extra_paths=[], preferred_root=comfy_root) + assert result == host_paths + + +def test_child_import_succeeds_after_path_unification(tmp_path: Path, monkeypatch) -> None: + host_root = tmp_path / "host" + utils_pkg = host_root / "utils" + app_pkg = host_root / "app" + utils_pkg.mkdir(parents=True) + app_pkg.mkdir(parents=True) + + (utils_pkg / "__init__.py").write_text("from . import install_util\n", encoding="utf-8") + (utils_pkg / "install_util.py").write_text("VALUE = 'hello'\n", encoding="utf-8") + (app_pkg / "__init__.py").write_text("", encoding="utf-8") + (app_pkg / "frontend_management.py").write_text( + "from utils import install_util\nVALUE = install_util.VALUE\n", + encoding="utf-8", + ) + + child_only = tmp_path / "child_only" + child_only.mkdir() + + target_module = "app.frontend_management" + for name in [n for n in list(sys.modules) if n.startswith("app") or n.startswith("utils")]: + sys.modules.pop(name) + + monkeypatch.setattr(sys, "path", [str(child_only)]) + with pytest.raises(ModuleNotFoundError): + __import__(target_module) + + for name in [n for n in list(sys.modules) if n.startswith("app") or n.startswith("utils")]: + sys.modules.pop(name) + + unified = build_child_sys_path([], [], preferred_root=str(host_root)) + monkeypatch.setattr(sys, "path", unified) + module = __import__(target_module, fromlist=["VALUE"]) + assert module.VALUE == "hello" diff --git a/tests/isolation/test_runtime_helpers_stub_contract.py b/tests/isolation/test_runtime_helpers_stub_contract.py new file mode 100644 index 000000000..16e47eb06 --- /dev/null +++ b/tests/isolation/test_runtime_helpers_stub_contract.py @@ -0,0 +1,125 @@ +"""Generic runtime-helper stub contract tests.""" + +from __future__ import annotations + +import asyncio +import logging +import os +import subprocess +import sys +from pathlib import Path +from types import SimpleNamespace +from typing import Any, cast + +from comfy.isolation import runtime_helpers +from comfy_api.latest import io as latest_io +from tests.isolation.stage_internal_probe_node import PROBE_NODE_NAME, staged_probe_node + + +class _DummyExtension: + def __init__(self, *, name: str, module_path: str): + self.name = name + self.module_path = module_path + + async def execute_node(self, _node_name: str, **inputs): + return { + "__node_output__": True, + "args": (inputs,), + "ui": {"status": "ok"}, + "expand": False, + "block_execution": False, + } + + +def _install_model_serialization_stub(monkeypatch): + async def deserialize_from_isolation(payload, _extension): + return payload + + monkeypatch.setitem( + sys.modules, + "pyisolate._internal.model_serialization", + SimpleNamespace( + serialize_for_isolation=lambda payload: payload, + deserialize_from_isolation=deserialize_from_isolation, + ), + ) + + +def test_stub_sets_relative_python_module(monkeypatch): + _install_model_serialization_stub(monkeypatch) + monkeypatch.setattr(runtime_helpers, "scan_shm_forensics", lambda *args, **kwargs: None) + monkeypatch.setattr(runtime_helpers, "_relieve_host_vram_pressure", lambda *args, **kwargs: None) + + extension = _DummyExtension(name="internal_probe", module_path=os.getcwd()) + stub = cast(Any, runtime_helpers.build_stub_class( + "ProbeNode", + { + "is_v3": True, + "schema_v1": {}, + "input_types": {}, + }, + extension, + {}, + logging.getLogger("test"), + )) + + info = getattr(stub, "GET_NODE_INFO_V1")() + assert info["python_module"] == "custom_nodes.internal_probe" + + +def test_stub_ui_dispatch_roundtrip(monkeypatch): + _install_model_serialization_stub(monkeypatch) + monkeypatch.setattr(runtime_helpers, "scan_shm_forensics", lambda *args, **kwargs: None) + monkeypatch.setattr(runtime_helpers, "_relieve_host_vram_pressure", lambda *args, **kwargs: None) + + extension = _DummyExtension(name="internal_probe", module_path=os.getcwd()) + stub = runtime_helpers.build_stub_class( + "ProbeNode", + { + "is_v3": True, + "schema_v1": {"python_module": "custom_nodes.internal_probe"}, + "input_types": {}, + }, + extension, + {}, + logging.getLogger("test"), + ) + + result = asyncio.run(getattr(stub, "_pyisolate_execute")(SimpleNamespace(), token="value")) + + assert isinstance(result, latest_io.NodeOutput) + assert result.ui == {"status": "ok"} + + +def test_stub_class_types_align_with_extension(): + extension = SimpleNamespace(name="internal_probe", module_path="/sandbox/probe") + running_extensions = {"internal_probe": extension} + + specs = [ + SimpleNamespace(module_path=Path("/sandbox/probe"), node_name="ProbeImage"), + SimpleNamespace(module_path=Path("/sandbox/probe"), node_name="ProbeAudio"), + SimpleNamespace(module_path=Path("/sandbox/other"), node_name="OtherNode"), + ] + + class_types = runtime_helpers.get_class_types_for_extension( + "internal_probe", running_extensions, specs + ) + + assert class_types == {"ProbeImage", "ProbeAudio"} + + +def test_probe_stage_requires_explicit_root(): + script = Path(__file__).resolve().parent / "stage_internal_probe_node.py" + result = subprocess.run([sys.executable, str(script)], capture_output=True, text=True, check=False) + + assert result.returncode != 0 + assert "--target-root" in result.stderr + + +def test_probe_stage_cleans_up_context(): + with staged_probe_node() as module_path: + staged_root = module_path.parents[1] + assert module_path.name == PROBE_NODE_NAME + assert staged_root.exists() + + assert not staged_root.exists() diff --git a/tests/isolation/test_savedimages_serialization.py b/tests/isolation/test_savedimages_serialization.py new file mode 100644 index 000000000..f2f3df1cc --- /dev/null +++ b/tests/isolation/test_savedimages_serialization.py @@ -0,0 +1,53 @@ +import logging +import socket +import sys +from pathlib import Path + +repo_root = Path(__file__).resolve().parents[2] +pyisolate_root = repo_root.parent / "pyisolate" +if pyisolate_root.exists(): + sys.path.insert(0, str(pyisolate_root)) + +from comfy.isolation.adapter import ComfyUIAdapter +from comfy_api.latest._io import FolderType +from comfy_api.latest._ui import SavedImages, SavedResult +from pyisolate._internal.rpc_transports import JSONSocketTransport +from pyisolate._internal.serialization_registry import SerializerRegistry + + +def test_savedimages_roundtrip(caplog): + registry = SerializerRegistry.get_instance() + registry.clear() + ComfyUIAdapter().register_serializers(registry) + + payload = SavedImages( + results=[SavedResult("issue82.png", "slice2", FolderType.output)], + is_animated=True, + ) + + a, b = socket.socketpair() + sender = JSONSocketTransport(a) + receiver = JSONSocketTransport(b) + try: + with caplog.at_level(logging.WARNING, logger="pyisolate._internal.rpc_transports"): + sender.send({"ui": payload}) + result = receiver.recv() + finally: + sender.close() + receiver.close() + registry.clear() + + ui = result["ui"] + assert isinstance(ui, SavedImages) + assert ui.is_animated is True + assert len(ui.results) == 1 + assert isinstance(ui.results[0], SavedResult) + assert ui.results[0].filename == "issue82.png" + assert ui.results[0].subfolder == "slice2" + assert ui.results[0].type == FolderType.output + assert ui.as_dict() == { + "images": [SavedResult("issue82.png", "slice2", FolderType.output)], + "animated": (True,), + } + assert not any("GENERIC SERIALIZER USED" in record.message for record in caplog.records) + assert not any("GENERIC DESERIALIZER USED" in record.message for record in caplog.records) diff --git a/tests/isolation/test_sealed_worker_contract_matrix.py b/tests/isolation/test_sealed_worker_contract_matrix.py new file mode 100644 index 000000000..7395c334c --- /dev/null +++ b/tests/isolation/test_sealed_worker_contract_matrix.py @@ -0,0 +1,368 @@ +"""Generic sealed-worker loader contract matrix tests.""" + +from __future__ import annotations + +import importlib +import json +import sys +import types +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +COMFYUI_ROOT = Path(__file__).resolve().parents[2] +TEST_WORKFLOW_ROOT = COMFYUI_ROOT / "tests" / "isolation" / "workflows" +SEALED_WORKFLOW_CLASS_TYPES: dict[str, set[str]] = { + "quick_6_uv_sealed_worker.json": { + "EmptyLatentImage", + "ProxyTestSealedWorker", + "UVSealedBoltonsSlugify", + "UVSealedLatentEcho", + "UVSealedRuntimeProbe", + }, + "isolation_7_uv_sealed_worker.json": { + "EmptyLatentImage", + "ProxyTestSealedWorker", + "UVSealedBoltonsSlugify", + "UVSealedLatentEcho", + "UVSealedRuntimeProbe", + }, + "quick_8_conda_sealed_worker.json": { + "CondaSealedLatentEcho", + "CondaSealedOpenWeatherDataset", + "CondaSealedRuntimeProbe", + "EmptyLatentImage", + "ProxyTestCondaSealedWorker", + }, + "isolation_9_conda_sealed_worker.json": { + "CondaSealedLatentEcho", + "CondaSealedOpenWeatherDataset", + "CondaSealedRuntimeProbe", + "EmptyLatentImage", + "ProxyTestCondaSealedWorker", + }, +} + + +def _workflow_class_types(path: Path) -> set[str]: + payload = json.loads(path.read_text(encoding="utf-8")) + return { + node["class_type"] + for node in payload.values() + if isinstance(node, dict) and "class_type" in node + } + + +def _make_manifest( + *, + package_manager: str = "uv", + execution_model: str | None = None, + can_isolate: bool = True, + dependencies: list[str] | None = None, + share_torch: bool = False, + sealed_host_ro_paths: list[str] | None = None, +) -> dict: + isolation: dict[str, object] = { + "can_isolate": can_isolate, + } + if package_manager != "uv": + isolation["package_manager"] = package_manager + if execution_model is not None: + isolation["execution_model"] = execution_model + if share_torch: + isolation["share_torch"] = True + if sealed_host_ro_paths is not None: + isolation["sealed_host_ro_paths"] = sealed_host_ro_paths + + if package_manager == "conda": + isolation["conda_channels"] = ["conda-forge"] + isolation["conda_dependencies"] = ["numpy"] + + return { + "project": { + "name": "contract-extension", + "dependencies": dependencies or ["numpy"], + }, + "tool": {"comfy": {"isolation": isolation}}, + } + + +@pytest.fixture +def manifest_file(tmp_path: Path) -> Path: + path = tmp_path / "pyproject.toml" + path.write_bytes(b"") + return path + + +def _loader_module( + monkeypatch: pytest.MonkeyPatch, *, preload_extension_wrapper: bool +): + mock_wrapper = MagicMock() + mock_wrapper.ComfyNodeExtension = type("ComfyNodeExtension", (), {}) + + iso_mod = types.ModuleType("comfy.isolation") + iso_mod.__path__ = [ + str(Path(__file__).resolve().parent.parent.parent / "comfy" / "isolation") + ] + iso_mod.__package__ = "comfy.isolation" + + manifest_loader = types.SimpleNamespace( + is_cache_valid=lambda *args, **kwargs: False, + load_from_cache=lambda *args, **kwargs: None, + save_to_cache=lambda *args, **kwargs: None, + ) + host_policy = types.SimpleNamespace( + load_host_policy=lambda base_path: { + "sandbox_mode": "required", + "allow_network": False, + "writable_paths": [], + "readonly_paths": [], + "sealed_worker_ro_import_paths": [], + } + ) + folder_paths = types.SimpleNamespace(base_path="/fake/comfyui") + + monkeypatch.setitem(sys.modules, "comfy.isolation", iso_mod) + monkeypatch.setitem(sys.modules, "comfy.isolation.runtime_helpers", MagicMock()) + monkeypatch.setitem(sys.modules, "comfy.isolation.manifest_loader", manifest_loader) + monkeypatch.setitem(sys.modules, "comfy.isolation.host_policy", host_policy) + monkeypatch.setitem(sys.modules, "folder_paths", folder_paths) + if preload_extension_wrapper: + monkeypatch.setitem(sys.modules, "comfy.isolation.extension_wrapper", mock_wrapper) + else: + sys.modules.pop("comfy.isolation.extension_wrapper", None) + sys.modules.pop("comfy.isolation.extension_loader", None) + + module = importlib.import_module("comfy.isolation.extension_loader") + try: + yield module, mock_wrapper + finally: + sys.modules.pop("comfy.isolation.extension_loader", None) + comfy_pkg = sys.modules.get("comfy") + if comfy_pkg is not None and hasattr(comfy_pkg, "isolation"): + delattr(comfy_pkg, "isolation") + + +@pytest.fixture +def loader_module(monkeypatch: pytest.MonkeyPatch): + yield from _loader_module(monkeypatch, preload_extension_wrapper=True) + + +@pytest.fixture +def sealed_loader_module(monkeypatch: pytest.MonkeyPatch): + yield from _loader_module(monkeypatch, preload_extension_wrapper=False) + + +@pytest.fixture +def mocked_loader(loader_module): + module, mock_wrapper = loader_module + mock_ext = AsyncMock() + mock_ext.list_nodes = AsyncMock(return_value={}) + + mock_manager = MagicMock() + mock_manager.load_extension = MagicMock(return_value=mock_ext) + sealed_type = type("SealedNodeExtension", (), {}) + + with patch.object(module, "pyisolate") as mock_pi: + mock_pi.ExtensionManager = MagicMock(return_value=mock_manager) + mock_pi.SealedNodeExtension = sealed_type + yield module, mock_pi, mock_manager, sealed_type, mock_wrapper + + +@pytest.fixture +def sealed_mocked_loader(sealed_loader_module): + module, mock_wrapper = sealed_loader_module + mock_ext = AsyncMock() + mock_ext.list_nodes = AsyncMock(return_value={}) + + mock_manager = MagicMock() + mock_manager.load_extension = MagicMock(return_value=mock_ext) + sealed_type = type("SealedNodeExtension", (), {}) + + with patch.object(module, "pyisolate") as mock_pi: + mock_pi.ExtensionManager = MagicMock(return_value=mock_manager) + mock_pi.SealedNodeExtension = sealed_type + yield module, mock_pi, mock_manager, sealed_type, mock_wrapper + + +async def _load_node(module, manifest: dict, manifest_path: Path, tmp_path: Path) -> dict: + with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib: + mock_tomllib.load.return_value = manifest + await module.load_isolated_node( + node_dir=tmp_path, + manifest_path=manifest_path, + logger=MagicMock(), + build_stub_class=MagicMock(), + venv_root=tmp_path / "venvs", + extension_managers=[], + ) + manager = module.pyisolate.ExtensionManager.return_value + return manager.load_extension.call_args[0][0] + + +@pytest.mark.asyncio +async def test_uv_host_coupled_default(mocked_loader, manifest_file: Path, tmp_path: Path): + module, mock_pi, _mock_manager, sealed_type, _ = mocked_loader + manifest = _make_manifest(package_manager="uv") + + config = await _load_node(module, manifest, manifest_file, tmp_path) + + extension_type = mock_pi.ExtensionManager.call_args[0][0] + assert extension_type is not sealed_type + assert "execution_model" not in config + + +@pytest.mark.asyncio +async def test_uv_sealed_worker_opt_in( + sealed_mocked_loader, manifest_file: Path, tmp_path: Path +): + module, mock_pi, _mock_manager, sealed_type, _ = sealed_mocked_loader + manifest = _make_manifest(package_manager="uv", execution_model="sealed_worker") + + config = await _load_node(module, manifest, manifest_file, tmp_path) + + extension_type = mock_pi.ExtensionManager.call_args[0][0] + assert extension_type is sealed_type + assert config["execution_model"] == "sealed_worker" + assert "apis" not in config + assert "comfy.isolation.extension_wrapper" not in sys.modules + + +@pytest.mark.asyncio +async def test_conda_defaults_to_sealed_worker( + sealed_mocked_loader, manifest_file: Path, tmp_path: Path +): + module, mock_pi, _mock_manager, sealed_type, _ = sealed_mocked_loader + manifest = _make_manifest(package_manager="conda") + + config = await _load_node(module, manifest, manifest_file, tmp_path) + + extension_type = mock_pi.ExtensionManager.call_args[0][0] + assert extension_type is sealed_type + assert config["execution_model"] == "sealed_worker" + assert config["package_manager"] == "conda" + assert "comfy.isolation.extension_wrapper" not in sys.modules + + +@pytest.mark.asyncio +async def test_conda_never_uses_comfy_extension_type( + mocked_loader, manifest_file: Path, tmp_path: Path +): + module, mock_pi, _mock_manager, sealed_type, mock_wrapper = mocked_loader + manifest = _make_manifest(package_manager="conda") + + await _load_node(module, manifest, manifest_file, tmp_path) + + extension_type = mock_pi.ExtensionManager.call_args[0][0] + assert extension_type is sealed_type + assert extension_type is not mock_wrapper.ComfyNodeExtension + + +@pytest.mark.asyncio +async def test_conda_forces_share_torch_false(mocked_loader, manifest_file: Path, tmp_path: Path): + module, _mock_pi, _mock_manager, _sealed_type, _ = mocked_loader + manifest = _make_manifest(package_manager="conda", share_torch=True) + + config = await _load_node(module, manifest, manifest_file, tmp_path) + + assert config["share_torch"] is False + + +@pytest.mark.asyncio +async def test_conda_forces_share_cuda_ipc_false( + mocked_loader, manifest_file: Path, tmp_path: Path +): + module, _mock_pi, _mock_manager, _sealed_type, _ = mocked_loader + manifest = _make_manifest(package_manager="conda", share_torch=True) + + config = await _load_node(module, manifest, manifest_file, tmp_path) + + assert config["share_cuda_ipc"] is False + + +@pytest.mark.asyncio +async def test_conda_sandbox_policy_applied(mocked_loader, manifest_file: Path, tmp_path: Path): + module, _mock_pi, _mock_manager, _sealed_type, _ = mocked_loader + manifest = _make_manifest(package_manager="conda") + + custom_policy = { + "sandbox_mode": "required", + "allow_network": True, + "writable_paths": ["/data/write"], + "readonly_paths": ["/data/read"], + } + + with patch("platform.system", return_value="Linux"): + with patch.object(module, "load_host_policy", return_value=custom_policy): + config = await _load_node(module, manifest, manifest_file, tmp_path) + + assert config["sandbox_mode"] == "required" + assert config["sandbox"] == { + "network": True, + "writable_paths": ["/data/write"], + "readonly_paths": ["/data/read"], + } + + +def test_sealed_worker_workflow_templates_present() -> None: + missing = [ + filename + for filename in SEALED_WORKFLOW_CLASS_TYPES + if not (TEST_WORKFLOW_ROOT / filename).is_file() + ] + assert not missing, f"missing sealed-worker workflow templates: {missing}" + + +@pytest.mark.parametrize( + "workflow_name,expected_class_types", + SEALED_WORKFLOW_CLASS_TYPES.items(), +) +def test_sealed_worker_workflow_class_type_contract( + workflow_name: str, expected_class_types: set[str] +) -> None: + workflow_path = TEST_WORKFLOW_ROOT / workflow_name + assert workflow_path.is_file(), f"workflow missing: {workflow_path}" + + assert _workflow_class_types(workflow_path) == expected_class_types + + +@pytest.mark.asyncio +async def test_sealed_worker_host_policy_ro_import_matrix( + mocked_loader, manifest_file: Path, tmp_path: Path +): + module, _mock_pi, _mock_manager, _sealed_type, _ = mocked_loader + manifest = _make_manifest(package_manager="uv", execution_model="sealed_worker") + + with patch.object( + module, + "load_host_policy", + return_value={ + "sandbox_mode": "required", + "allow_network": False, + "writable_paths": [], + "readonly_paths": [], + "sealed_worker_ro_import_paths": [], + }, + ): + default_config = await _load_node(module, manifest, manifest_file, tmp_path) + + with patch.object( + module, + "load_host_policy", + return_value={ + "sandbox_mode": "required", + "allow_network": False, + "writable_paths": [], + "readonly_paths": [], + "sealed_worker_ro_import_paths": ["/home/johnj/ComfyUI"], + }, + ): + opt_in_config = await _load_node(module, manifest, manifest_file, tmp_path) + + assert default_config["execution_model"] == "sealed_worker" + assert "sealed_host_ro_paths" not in default_config + + assert opt_in_config["execution_model"] == "sealed_worker" + assert opt_in_config["sealed_host_ro_paths"] == ["/home/johnj/ComfyUI"] + assert "apis" not in opt_in_config diff --git a/tests/isolation/test_shared_model_proxy_contract.py b/tests/isolation/test_shared_model_proxy_contract.py new file mode 100644 index 000000000..9e91c74f3 --- /dev/null +++ b/tests/isolation/test_shared_model_proxy_contract.py @@ -0,0 +1,44 @@ +import asyncio +import sys +from pathlib import Path + +repo_root = Path(__file__).resolve().parents[2] +pyisolate_root = repo_root.parent / "pyisolate" +if pyisolate_root.exists(): + sys.path.insert(0, str(pyisolate_root)) + +from comfy.isolation.adapter import ComfyUIAdapter +from comfy.isolation.runtime_helpers import _wrap_remote_handles_as_host_proxies +from pyisolate._internal.model_serialization import deserialize_from_isolation +from pyisolate._internal.remote_handle import RemoteObjectHandle +from pyisolate._internal.serialization_registry import SerializerRegistry + + +def test_shared_model_ksampler_contract(): + registry = SerializerRegistry.get_instance() + registry.clear() + ComfyUIAdapter().register_serializers(registry) + + handle = RemoteObjectHandle("model_0", "ModelPatcher") + + class FakeExtension: + async def call_remote_object_method(self, object_id, method_name, *args, **kwargs): + assert object_id == "model_0" + assert method_name == "get_model_object" + assert args == ("latent_format",) + assert kwargs == {} + return "resolved:latent_format" + + wrapped = (handle,) + assert isinstance(wrapped, tuple) + assert isinstance(wrapped[0], RemoteObjectHandle) + + deserialized = asyncio.run(deserialize_from_isolation(wrapped)) + proxied = _wrap_remote_handles_as_host_proxies(deserialized, FakeExtension()) + model_for_host = proxied[0] + + assert not isinstance(model_for_host, RemoteObjectHandle) + assert hasattr(model_for_host, "get_model_object") + assert model_for_host.get_model_object("latent_format") == "resolved:latent_format" + + registry.clear() diff --git a/tests/isolation/test_singleton_proxy_boundary_matrix.py b/tests/isolation/test_singleton_proxy_boundary_matrix.py new file mode 100644 index 000000000..31cc86e04 --- /dev/null +++ b/tests/isolation/test_singleton_proxy_boundary_matrix.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +import json + +from tests.isolation.singleton_boundary_helpers import ( + capture_minimal_sealed_worker_imports, + capture_sealed_singleton_imports, +) + + +def test_minimal_sealed_worker_forbidden_imports() -> None: + payload = capture_minimal_sealed_worker_imports() + + assert payload["mode"] == "minimal_sealed_worker" + assert payload["runtime_probe_function"] == "inspect" + assert payload["forbidden_matches"] == [] + + +def test_torch_share_subset_scope() -> None: + minimal = capture_minimal_sealed_worker_imports() + + allowed_torch_share_only = { + "torch", + "folder_paths", + "comfy.utils", + "comfy.model_management", + "main", + "comfy.isolation.extension_wrapper", + } + + assert minimal["forbidden_matches"] == [] + assert all( + module_name not in minimal["modules"] for module_name in sorted(allowed_torch_share_only) + ) + + +def test_capture_payload_is_json_serializable() -> None: + payload = capture_minimal_sealed_worker_imports() + + encoded = json.dumps(payload, sort_keys=True) + + assert "\"minimal_sealed_worker\"" in encoded + + +def test_folder_paths_child_safe() -> None: + payload = capture_sealed_singleton_imports() + + assert payload["mode"] == "sealed_singletons" + assert payload["folder_path"] == "/sandbox/input/demo.png" + assert payload["temp_dir"] == "/sandbox/temp" + assert payload["models_dir"] == "/sandbox/models" + assert payload["forbidden_matches"] == [] + + +def test_utils_child_safe() -> None: + payload = capture_sealed_singleton_imports() + + progress_calls = [ + call + for call in payload["rpc_calls"] + if call["object_id"] == "UtilsProxy" and call["method"] == "progress_bar_hook" + ] + + assert progress_calls + assert payload["forbidden_matches"] == [] + + +def test_progress_child_safe() -> None: + payload = capture_sealed_singleton_imports() + + progress_calls = [ + call + for call in payload["rpc_calls"] + if call["object_id"] == "ProgressProxy" and call["method"] == "rpc_set_progress" + ] + + assert progress_calls + assert payload["forbidden_matches"] == [] diff --git a/tests/isolation/test_web_directory_handler.py b/tests/isolation/test_web_directory_handler.py new file mode 100644 index 000000000..f50e01977 --- /dev/null +++ b/tests/isolation/test_web_directory_handler.py @@ -0,0 +1,129 @@ +"""Tests for WebDirectoryProxy host-side cache and aiohttp handler integration.""" + +from __future__ import annotations + +import base64 +import sys +from unittest.mock import MagicMock + +import pytest + +from comfy.isolation.proxies.web_directory_proxy import ( + ALLOWED_EXTENSIONS, + WebDirectoryCache, +) + + +@pytest.fixture() +def mock_proxy() -> MagicMock: + """Create a mock WebDirectoryProxy RPC proxy.""" + proxy = MagicMock() + proxy.list_web_files.return_value = [ + {"relative_path": "js/app.js", "content_type": "application/javascript"}, + {"relative_path": "js/utils.js", "content_type": "application/javascript"}, + {"relative_path": "index.html", "content_type": "text/html"}, + {"relative_path": "style.css", "content_type": "text/css"}, + ] + proxy.get_web_file.return_value = { + "content": base64.b64encode(b"console.log('hello');").decode("ascii"), + "content_type": "application/javascript", + } + return proxy + + +@pytest.fixture() +def cache_with_proxy(mock_proxy: MagicMock) -> WebDirectoryCache: + """Create a WebDirectoryCache with a registered mock proxy.""" + cache = WebDirectoryCache() + cache.register_proxy("test-extension", mock_proxy) + return cache + + +class TestExtensionsListing: + """AC-2: /extensions endpoint lists proxied JS files in URL format.""" + + def test_extensions_listing_produces_url_format_paths( + self, cache_with_proxy: WebDirectoryCache + ) -> None: + """Simulate what server.py does: build /extensions/{name}/{path} URLs.""" + import urllib.parse + + ext_name = "test-extension" + urls = [] + for entry in cache_with_proxy.list_files(ext_name): + if entry["relative_path"].endswith(".js"): + urls.append( + "/extensions/" + urllib.parse.quote(ext_name) + + "/" + entry["relative_path"] + ) + + # Emit the actual URL list so it appears in test log output. + sys.stdout.write(f"\n--- Proxied JS URLs ({len(urls)}) ---\n") + for url in urls: + sys.stdout.write(f" {url}\n") + sys.stdout.write("--- End URLs ---\n") + + # At least one proxied JS URL in /extensions/{name}/{path} format + assert len(urls) >= 1, f"Expected >= 1 proxied JS URL, got {len(urls)}" + assert "/extensions/test-extension/js/app.js" in urls, ( + f"Expected /extensions/test-extension/js/app.js in {urls}" + ) + + +class TestCacheHit: + """AC-3: Cache populated on first request, reused on second.""" + + def test_cache_hit_single_rpc_call( + self, cache_with_proxy: WebDirectoryCache, mock_proxy: MagicMock + ) -> None: + # First call — RPC + result1 = cache_with_proxy.get_file("test-extension", "js/app.js") + assert result1 is not None + assert result1["content"] == b"console.log('hello');" + + # Second call — cache hit + result2 = cache_with_proxy.get_file("test-extension", "js/app.js") + assert result2 is not None + assert result2["content"] == b"console.log('hello');" + + # Proxy was called exactly once + assert mock_proxy.get_web_file.call_count == 1 + + def test_cache_returns_none_for_unknown_extension( + self, cache_with_proxy: WebDirectoryCache + ) -> None: + result = cache_with_proxy.get_file("nonexistent", "js/app.js") + assert result is None + + +class TestForbiddenType: + """AC-4: Disallowed file types return HTTP 403 Forbidden.""" + + @pytest.mark.parametrize( + "disallowed_path,expected_status", + [ + ("backdoor.py", 403), + ("malware.exe", 403), + ("exploit.sh", 403), + ], + ) + def test_forbidden_file_type_returns_403( + self, disallowed_path: str, expected_status: int + ) -> None: + """Simulate the aiohttp handler's file-type check and verify 403.""" + import os + suffix = os.path.splitext(disallowed_path)[1].lower() + + # This mirrors the handler logic in server.py: + # if suffix not in ALLOWED_EXTENSIONS: return web.Response(status=403) + if suffix not in ALLOWED_EXTENSIONS: + status = 403 + else: + status = 200 + + sys.stdout.write( + f"\n--- HTTP status for {disallowed_path} (suffix={suffix}): {status} ---\n" + ) + assert status == expected_status, ( + f"Expected HTTP {expected_status} for {disallowed_path}, got {status}" + ) diff --git a/tests/isolation/test_web_directory_proxy.py b/tests/isolation/test_web_directory_proxy.py new file mode 100644 index 000000000..2922da92d --- /dev/null +++ b/tests/isolation/test_web_directory_proxy.py @@ -0,0 +1,130 @@ +"""Tests for WebDirectoryProxy — allow-list, traversal prevention, content serving.""" + +from __future__ import annotations + +import base64 +from pathlib import Path + +import pytest + +from comfy.isolation.proxies.web_directory_proxy import WebDirectoryProxy + + +@pytest.fixture() +def web_dir_with_mixed_files(tmp_path: Path) -> Path: + """Create a temp web directory with allowed and disallowed file types.""" + web = tmp_path / "web" + js_dir = web / "js" + js_dir.mkdir(parents=True) + + # Allowed types + (js_dir / "app.js").write_text("console.log('hello');") + (web / "index.html").write_text("") + (web / "style.css").write_text("body { margin: 0; }") + + # Disallowed types + (web / "backdoor.py").write_text("import os; os.system('rm -rf /')") + (web / "malware.exe").write_bytes(b"\x00" * 16) + (web / "exploit.sh").write_text("#!/bin/bash\nrm -rf /") + + return web + + +@pytest.fixture() +def proxy_with_web_dir(web_dir_with_mixed_files: Path) -> WebDirectoryProxy: + """Create a WebDirectoryProxy with a registered test web directory.""" + proxy = WebDirectoryProxy() + # Clear class-level state to avoid cross-test pollution + WebDirectoryProxy._web_dirs = {} + WebDirectoryProxy.register_web_dir("test-extension", str(web_dir_with_mixed_files)) + return proxy + + +class TestAllowList: + """AC-2: list_web_files returns only allowed file types.""" + + def test_allowlist_only_safe_types( + self, proxy_with_web_dir: WebDirectoryProxy + ) -> None: + files = proxy_with_web_dir.list_web_files("test-extension") + extensions = {Path(f["relative_path"]).suffix for f in files} + + # Only .js, .html, .css should appear + assert extensions == {".js", ".html", ".css"} + + def test_allowlist_excludes_dangerous_types( + self, proxy_with_web_dir: WebDirectoryProxy + ) -> None: + files = proxy_with_web_dir.list_web_files("test-extension") + paths = [f["relative_path"] for f in files] + + assert not any(p.endswith(".py") for p in paths) + assert not any(p.endswith(".exe") for p in paths) + assert not any(p.endswith(".sh") for p in paths) + + def test_allowlist_correct_count( + self, proxy_with_web_dir: WebDirectoryProxy + ) -> None: + files = proxy_with_web_dir.list_web_files("test-extension") + # 3 allowed files: app.js, index.html, style.css + assert len(files) == 3 + + def test_allowlist_unknown_extension_returns_empty( + self, proxy_with_web_dir: WebDirectoryProxy + ) -> None: + files = proxy_with_web_dir.list_web_files("nonexistent-extension") + assert files == [] + + +class TestTraversal: + """AC-3: get_web_file rejects directory traversal attempts.""" + + @pytest.mark.parametrize( + "malicious_path", + [ + "../../../etc/passwd", + "/etc/passwd", + "../../__init__.py", + ], + ) + def test_traversal_rejected( + self, proxy_with_web_dir: WebDirectoryProxy, malicious_path: str + ) -> None: + with pytest.raises(ValueError): + proxy_with_web_dir.get_web_file("test-extension", malicious_path) + + +class TestContent: + """AC-4: get_web_file returns base64 content with correct MIME types.""" + + def test_content_js_mime_type( + self, proxy_with_web_dir: WebDirectoryProxy + ) -> None: + result = proxy_with_web_dir.get_web_file("test-extension", "js/app.js") + assert result["content_type"] == "application/javascript" + + def test_content_html_mime_type( + self, proxy_with_web_dir: WebDirectoryProxy + ) -> None: + result = proxy_with_web_dir.get_web_file("test-extension", "index.html") + assert result["content_type"] == "text/html" + + def test_content_css_mime_type( + self, proxy_with_web_dir: WebDirectoryProxy + ) -> None: + result = proxy_with_web_dir.get_web_file("test-extension", "style.css") + assert result["content_type"] == "text/css" + + def test_content_base64_roundtrip( + self, proxy_with_web_dir: WebDirectoryProxy, web_dir_with_mixed_files: Path + ) -> None: + result = proxy_with_web_dir.get_web_file("test-extension", "js/app.js") + decoded = base64.b64decode(result["content"]) + source = (web_dir_with_mixed_files / "js" / "app.js").read_bytes() + assert decoded == source + + def test_content_disallowed_type_rejected( + self, proxy_with_web_dir: WebDirectoryProxy + ) -> None: + with pytest.raises(ValueError, match="Disallowed file type"): + proxy_with_web_dir.get_web_file("test-extension", "backdoor.py") diff --git a/tests/isolation/uv_sealed_worker/__init__.py b/tests/isolation/uv_sealed_worker/__init__.py new file mode 100644 index 000000000..453915a93 --- /dev/null +++ b/tests/isolation/uv_sealed_worker/__init__.py @@ -0,0 +1,230 @@ +# pylint: disable=import-outside-toplevel,import-error +from __future__ import annotations + +import os +import sys +import logging +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + + +def _artifact_dir() -> Path | None: + raw = os.environ.get("PYISOLATE_ARTIFACT_DIR") + if not raw: + return None + path = Path(raw) + path.mkdir(parents=True, exist_ok=True) + return path + + +def _write_artifact(name: str, content: str) -> None: + artifact_dir = _artifact_dir() + if artifact_dir is None: + return + (artifact_dir / name).write_text(content, encoding="utf-8") + + +def _contains_tensor_marker(value: Any) -> bool: + if isinstance(value, dict): + if value.get("__type__") == "TensorValue": + return True + return any(_contains_tensor_marker(v) for v in value.values()) + if isinstance(value, (list, tuple)): + return any(_contains_tensor_marker(v) for v in value) + return False + + +class InspectRuntimeNode: + RETURN_TYPES = ( + "STRING", + "STRING", + "BOOLEAN", + "BOOLEAN", + "STRING", + "STRING", + "BOOLEAN", + ) + RETURN_NAMES = ( + "path_dump", + "boltons_origin", + "saw_comfy_root", + "imported_comfy_wrapper", + "comfy_module_dump", + "report", + "saw_user_site", + ) + FUNCTION = "inspect" + CATEGORY = "PyIsolated/SealedWorker" + + @classmethod + def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802 + return {"required": {}} + + def inspect(self) -> tuple[str, str, bool, bool, str, str, bool]: + import boltons + + path_dump = "\n".join(sys.path) + comfy_root = "/home/johnj/ComfyUI" + saw_comfy_root = any( + entry == comfy_root + or entry.startswith(f"{comfy_root}/comfy") + or entry.startswith(f"{comfy_root}/.venv") + for entry in sys.path + ) + imported_comfy_wrapper = "comfy.isolation.extension_wrapper" in sys.modules + comfy_module_dump = "\n".join( + sorted(name for name in sys.modules if name.startswith("comfy")) + ) + saw_user_site = any("/.local/lib/" in entry for entry in sys.path) + boltons_origin = getattr(boltons, "__file__", "") + + report_lines = [ + "UV sealed worker runtime probe", + f"boltons_origin={boltons_origin}", + f"saw_comfy_root={saw_comfy_root}", + f"imported_comfy_wrapper={imported_comfy_wrapper}", + f"saw_user_site={saw_user_site}", + ] + report = "\n".join(report_lines) + + _write_artifact("child_bootstrap_paths.txt", path_dump) + _write_artifact("child_import_trace.txt", comfy_module_dump) + _write_artifact("child_dependency_dump.txt", boltons_origin) + logger.warning("][ UV sealed runtime probe executed") + logger.warning("][ boltons origin: %s", boltons_origin) + + return ( + path_dump, + boltons_origin, + saw_comfy_root, + imported_comfy_wrapper, + comfy_module_dump, + report, + saw_user_site, + ) + + +class BoltonsSlugifyNode: + RETURN_TYPES = ("STRING", "STRING") + RETURN_NAMES = ("slug", "boltons_origin") + FUNCTION = "slugify_text" + CATEGORY = "PyIsolated/SealedWorker" + + @classmethod + def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802 + return {"required": {"text": ("STRING", {"default": "Sealed Worker Rocks"})}} + + def slugify_text(self, text: str) -> tuple[str, str]: + import boltons + from boltons.strutils import slugify + + slug = slugify(text) + origin = getattr(boltons, "__file__", "") + logger.warning("][ boltons slugify: %r -> %r", text, slug) + return slug, origin + + +class FilesystemBarrierNode: + RETURN_TYPES = ("STRING", "BOOLEAN", "BOOLEAN", "BOOLEAN") + RETURN_NAMES = ( + "report", + "outside_blocked", + "module_mutation_blocked", + "artifact_write_ok", + ) + FUNCTION = "probe" + CATEGORY = "PyIsolated/SealedWorker" + + @classmethod + def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802 + return {"required": {}} + + def probe(self) -> tuple[str, bool, bool, bool]: + artifact_dir = _artifact_dir() + artifact_write_ok = False + if artifact_dir is not None: + probe_path = artifact_dir / "filesystem_barrier_probe.txt" + probe_path.write_text("artifact write ok\n", encoding="utf-8") + artifact_write_ok = probe_path.exists() + + module_target = Path(__file__).with_name( + "mutated_from_child_should_not_exist.txt" + ) + module_mutation_blocked = False + try: + module_target.write_text("mutation should fail\n", encoding="utf-8") + except Exception: + module_mutation_blocked = True + else: + module_target.unlink(missing_ok=True) + + outside_target = Path("/home/johnj/mysolate/.uv_sealed_worker_escape_probe") + outside_blocked = False + try: + outside_target.write_text("escape should fail\n", encoding="utf-8") + except Exception: + outside_blocked = True + else: + outside_target.unlink(missing_ok=True) + + report_lines = [ + "UV sealed worker filesystem barrier probe", + f"artifact_write_ok={artifact_write_ok}", + f"module_mutation_blocked={module_mutation_blocked}", + f"outside_blocked={outside_blocked}", + ] + report = "\n".join(report_lines) + _write_artifact("filesystem_barrier_report.txt", report) + logger.warning("][ filesystem barrier probe executed") + return report, outside_blocked, module_mutation_blocked, artifact_write_ok + + +class EchoTensorNode: + RETURN_TYPES = ("TENSOR", "BOOLEAN") + RETURN_NAMES = ("tensor", "saw_json_tensor") + FUNCTION = "echo" + CATEGORY = "PyIsolated/SealedWorker" + + @classmethod + def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802 + return {"required": {"tensor": ("TENSOR",)}} + + def echo(self, tensor: Any) -> tuple[Any, bool]: + saw_json_tensor = _contains_tensor_marker(tensor) + logger.warning("][ tensor echo json_marker=%s", saw_json_tensor) + return tensor, saw_json_tensor + + +class EchoLatentNode: + RETURN_TYPES = ("LATENT", "BOOLEAN") + RETURN_NAMES = ("latent", "saw_json_tensor") + FUNCTION = "echo_latent" + CATEGORY = "PyIsolated/SealedWorker" + + @classmethod + def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802 + return {"required": {"latent": ("LATENT",)}} + + def echo_latent(self, latent: Any) -> tuple[Any, bool]: + saw_json_tensor = _contains_tensor_marker(latent) + logger.warning("][ latent echo json_marker=%s", saw_json_tensor) + return latent, saw_json_tensor + + +NODE_CLASS_MAPPINGS = { + "UVSealedRuntimeProbe": InspectRuntimeNode, + "UVSealedBoltonsSlugify": BoltonsSlugifyNode, + "UVSealedFilesystemBarrier": FilesystemBarrierNode, + "UVSealedTensorEcho": EchoTensorNode, + "UVSealedLatentEcho": EchoLatentNode, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "UVSealedRuntimeProbe": "UV Sealed Runtime Probe", + "UVSealedBoltonsSlugify": "UV Sealed Boltons Slugify", + "UVSealedFilesystemBarrier": "UV Sealed Filesystem Barrier", + "UVSealedTensorEcho": "UV Sealed Tensor Echo", + "UVSealedLatentEcho": "UV Sealed Latent Echo", +} diff --git a/tests/isolation/uv_sealed_worker/pyproject.toml b/tests/isolation/uv_sealed_worker/pyproject.toml new file mode 100644 index 000000000..f50d21eb3 --- /dev/null +++ b/tests/isolation/uv_sealed_worker/pyproject.toml @@ -0,0 +1,11 @@ +[project] +name = "comfyui-toolkit-uv-sealed-worker" +version = "0.1.0" +dependencies = ["boltons"] + +[tool.comfy.isolation] +can_isolate = true +share_torch = false +package_manager = "uv" +execution_model = "sealed_worker" +standalone = true diff --git a/tests/isolation/workflows/internal_probe_preview_image_audio.json b/tests/isolation/workflows/internal_probe_preview_image_audio.json new file mode 100644 index 000000000..69f5f0d2f --- /dev/null +++ b/tests/isolation/workflows/internal_probe_preview_image_audio.json @@ -0,0 +1,10 @@ +{ + "1": { + "class_type": "InternalIsolationProbeImage", + "inputs": {} + }, + "2": { + "class_type": "InternalIsolationProbeAudio", + "inputs": {} + } +} diff --git a/tests/isolation/workflows/internal_probe_ui3d.json b/tests/isolation/workflows/internal_probe_ui3d.json new file mode 100644 index 000000000..fea2dc3e7 --- /dev/null +++ b/tests/isolation/workflows/internal_probe_ui3d.json @@ -0,0 +1,6 @@ +{ + "1": { + "class_type": "InternalIsolationProbeUI3D", + "inputs": {} + } +} diff --git a/tests/isolation/workflows/isolation_7_uv_sealed_worker.json b/tests/isolation/workflows/isolation_7_uv_sealed_worker.json new file mode 100644 index 000000000..3b83fa0db --- /dev/null +++ b/tests/isolation/workflows/isolation_7_uv_sealed_worker.json @@ -0,0 +1,22 @@ +{ + "1": { + "class_type": "EmptyLatentImage", + "inputs": {} + }, + "2": { + "class_type": "ProxyTestSealedWorker", + "inputs": {} + }, + "3": { + "class_type": "UVSealedBoltonsSlugify", + "inputs": {} + }, + "4": { + "class_type": "UVSealedLatentEcho", + "inputs": {} + }, + "5": { + "class_type": "UVSealedRuntimeProbe", + "inputs": {} + } +} diff --git a/tests/isolation/workflows/isolation_9_conda_sealed_worker.json b/tests/isolation/workflows/isolation_9_conda_sealed_worker.json new file mode 100644 index 000000000..acfa2e59b --- /dev/null +++ b/tests/isolation/workflows/isolation_9_conda_sealed_worker.json @@ -0,0 +1,22 @@ +{ + "1": { + "class_type": "CondaSealedLatentEcho", + "inputs": {} + }, + "2": { + "class_type": "CondaSealedOpenWeatherDataset", + "inputs": {} + }, + "3": { + "class_type": "CondaSealedRuntimeProbe", + "inputs": {} + }, + "4": { + "class_type": "EmptyLatentImage", + "inputs": {} + }, + "5": { + "class_type": "ProxyTestCondaSealedWorker", + "inputs": {} + } +} diff --git a/tests/isolation/workflows/quick_6_uv_sealed_worker.json b/tests/isolation/workflows/quick_6_uv_sealed_worker.json new file mode 100644 index 000000000..3b83fa0db --- /dev/null +++ b/tests/isolation/workflows/quick_6_uv_sealed_worker.json @@ -0,0 +1,22 @@ +{ + "1": { + "class_type": "EmptyLatentImage", + "inputs": {} + }, + "2": { + "class_type": "ProxyTestSealedWorker", + "inputs": {} + }, + "3": { + "class_type": "UVSealedBoltonsSlugify", + "inputs": {} + }, + "4": { + "class_type": "UVSealedLatentEcho", + "inputs": {} + }, + "5": { + "class_type": "UVSealedRuntimeProbe", + "inputs": {} + } +} diff --git a/tests/isolation/workflows/quick_8_conda_sealed_worker.json b/tests/isolation/workflows/quick_8_conda_sealed_worker.json new file mode 100644 index 000000000..acfa2e59b --- /dev/null +++ b/tests/isolation/workflows/quick_8_conda_sealed_worker.json @@ -0,0 +1,22 @@ +{ + "1": { + "class_type": "CondaSealedLatentEcho", + "inputs": {} + }, + "2": { + "class_type": "CondaSealedOpenWeatherDataset", + "inputs": {} + }, + "3": { + "class_type": "CondaSealedRuntimeProbe", + "inputs": {} + }, + "4": { + "class_type": "EmptyLatentImage", + "inputs": {} + }, + "5": { + "class_type": "ProxyTestCondaSealedWorker", + "inputs": {} + } +} diff --git a/tests/test_adapter.py b/tests/test_adapter.py new file mode 100644 index 000000000..298bc53f6 --- /dev/null +++ b/tests/test_adapter.py @@ -0,0 +1,124 @@ +import os +import subprocess +import sys +import textwrap +import types +from pathlib import Path + +repo_root = Path(__file__).resolve().parents[1] +pyisolate_root = repo_root.parent / "pyisolate" +if pyisolate_root.exists(): + sys.path.insert(0, str(pyisolate_root)) + +from comfy.isolation.adapter import ComfyUIAdapter +from pyisolate._internal.sandbox import build_bwrap_command +from pyisolate._internal.sandbox_detect import RestrictionModel +from pyisolate._internal.serialization_registry import SerializerRegistry + + +def test_identifier(): + adapter = ComfyUIAdapter() + assert adapter.identifier == "comfyui" + + +def test_get_path_config_valid(): + adapter = ComfyUIAdapter() + path = os.path.join("/opt", "ComfyUI", "custom_nodes", "demo") + cfg = adapter.get_path_config(path) + assert cfg is not None + assert cfg["preferred_root"].endswith("ComfyUI") + assert "custom_nodes" in cfg["additional_paths"][0] + + +def test_get_path_config_invalid(): + adapter = ComfyUIAdapter() + assert adapter.get_path_config("/random/path") is None + + +def test_provide_rpc_services(): + adapter = ComfyUIAdapter() + services = adapter.provide_rpc_services() + names = {s.__name__ for s in services} + assert "PromptServerService" in names + assert "FolderPathsProxy" in names + + +def test_register_serializers(): + adapter = ComfyUIAdapter() + registry = SerializerRegistry.get_instance() + registry.clear() + + adapter.register_serializers(registry) + assert registry.has_handler("ModelPatcher") + assert registry.has_handler("CLIP") + assert registry.has_handler("VAE") + + registry.clear() + + +def test_child_temp_directory_fence_uses_private_tmp(tmp_path): + adapter = ComfyUIAdapter() + child_script = textwrap.dedent( + """ + from pathlib import Path + + child_temp = Path("/tmp/comfyui_temp") + child_temp.mkdir(parents=True, exist_ok=True) + scratch = child_temp / "child_only.txt" + scratch.write_text("child-only", encoding="utf-8") + print(f"CHILD_TEMP={child_temp}") + print(f"CHILD_FILE={scratch}") + """ + ) + fake_folder_paths = types.SimpleNamespace( + temp_directory="/host/tmp/should_not_survive", + folder_names_and_paths={}, + extension_mimetypes_cache={}, + filename_list_cache={}, + ) + + class FolderPathsProxy: + def get_temp_directory(self): + return "/host/tmp/should_not_survive" + + original_folder_paths = sys.modules.get("folder_paths") + sys.modules["folder_paths"] = fake_folder_paths + try: + os.environ["PYISOLATE_CHILD"] = "1" + adapter.handle_api_registration(FolderPathsProxy, rpc=None) + finally: + os.environ.pop("PYISOLATE_CHILD", None) + if original_folder_paths is not None: + sys.modules["folder_paths"] = original_folder_paths + else: + sys.modules.pop("folder_paths", None) + + import tempfile as _tf + expected_temp = os.path.join(_tf.gettempdir(), "comfyui_temp") + assert fake_folder_paths.temp_directory == expected_temp + + host_child_file = Path(expected_temp) / "child_only.txt" + if host_child_file.exists(): + host_child_file.unlink() + + cmd = build_bwrap_command( + python_exe=sys.executable, + module_path=str(repo_root / "custom_nodes" / "ComfyUI-IsolationToolkit"), + venv_path=str(repo_root / ".venv"), + uds_address=str(tmp_path / "adapter.sock"), + allow_gpu=False, + restriction_model=RestrictionModel.NONE, + sandbox_config={"writable_paths": ["/dev/shm"], "readonly_paths": [], "network": False}, + adapter=adapter, + ) + assert "--tmpfs" in cmd and "/tmp" in cmd + assert ["--bind", "/tmp", "/tmp"] not in [cmd[i : i + 3] for i in range(len(cmd) - 2)] + + command_tail = cmd[-3:] + assert command_tail[1:] == ["-m", "pyisolate._internal.uds_client"] + cmd = cmd[:-3] + [sys.executable, "-c", child_script] + + completed = subprocess.run(cmd, check=True, capture_output=True, text=True) + + assert "CHILD_TEMP=/tmp/comfyui_temp" in completed.stdout + assert not host_child_file.exists(), "Child temp file leaked into host /tmp"