ComfyUI/comfy/isolation/extension_loader.py
John Pollock 6aa0b838a0 feat(isolation): wheel support for isolated custom nodes
Extends pyisolate process isolation with wheel-based dependency
management, sandbox mode policy, and compatibility fixes validated
against DA3 as the first complex isolated custom node.

- Add sandbox_mode policy (required/disabled) with COMFY_HOST_POLICY_PATH
  env override for host security configuration
- Plumb cuda_wheels config and standardize child environment detection
- Add PLY, NPZ, File3D, VIDEO serializers for core save nodes
- Register isolated extension web directories on host side
  (sandbox_mode=disabled) for frontend JS widget serving
- Capture ACCEPT_ALL_INPUTS from child node class to prevent
  @classproperty trigger on proxy class
- Serialize NodeOutput with ui/expand/block_execution through JSON-RPC
- Cherry-pick V3 python_module metadata fix (e4592190)
- Remove improperly committed cache artifacts
2026-03-15 01:25:40 -05:00

389 lines
14 KiB
Python

# pylint: disable=cyclic-import,import-outside-toplevel,redefined-outer-name
from __future__ import annotations
import logging
import os
import inspect
import sys
import types
import platform
from pathlib import Path
from typing import Callable, Dict, List, Tuple
import pyisolate
from pyisolate import ExtensionManager, ExtensionManagerConfig
from packaging.requirements import InvalidRequirement, Requirement
from packaging.utils import canonicalize_name
from .extension_wrapper import ComfyNodeExtension
from .manifest_loader import is_cache_valid, load_from_cache, save_to_cache
from .host_policy import load_host_policy
try:
import tomllib
except ImportError:
import tomli as tomllib # type: ignore[no-redef]
logger = logging.getLogger(__name__)
def _register_web_directory(extension_name: str, node_dir: Path) -> None:
"""Register an isolated extension's web directory on the host side."""
import nodes
# Method 1: pyproject.toml [tool.comfy] web field
pyproject = node_dir / "pyproject.toml"
if pyproject.exists():
try:
with pyproject.open("rb") as f:
data = tomllib.load(f)
web_dir_name = data.get("tool", {}).get("comfy", {}).get("web")
if web_dir_name:
web_dir_path = str(node_dir / web_dir_name)
if os.path.isdir(web_dir_path):
nodes.EXTENSION_WEB_DIRS[extension_name] = web_dir_path
logger.debug("][ Registered web dir for isolated %s: %s", extension_name, web_dir_path)
return
except Exception:
pass
# Method 2: __init__.py WEB_DIRECTORY constant (parse without importing)
init_file = node_dir / "__init__.py"
if init_file.exists():
try:
source = init_file.read_text()
for line in source.splitlines():
stripped = line.strip()
if stripped.startswith("WEB_DIRECTORY"):
# Parse: WEB_DIRECTORY = "./web" or WEB_DIRECTORY = "web"
_, _, value = stripped.partition("=")
value = value.strip().strip("\"'")
if value:
web_dir_path = str((node_dir / value).resolve())
if os.path.isdir(web_dir_path):
nodes.EXTENSION_WEB_DIRS[extension_name] = web_dir_path
logger.debug("][ Registered web dir for isolated %s: %s", extension_name, web_dir_path)
return
except Exception:
pass
async def _stop_extension_safe(
extension: ComfyNodeExtension, extension_name: str
) -> None:
try:
stop_result = extension.stop()
if inspect.isawaitable(stop_result):
await stop_result
except Exception:
logger.debug("][ %s stop failed", extension_name, exc_info=True)
def _normalize_dependency_spec(dep: str, base_paths: list[Path]) -> str:
req, sep, marker = dep.partition(";")
req = req.strip()
marker_suffix = f";{marker}" if sep else ""
def _resolve_local_path(local_path: str) -> Path | None:
for base in base_paths:
candidate = (base / local_path).resolve()
if candidate.exists():
return candidate
return None
if req.startswith("./") or req.startswith("../"):
resolved = _resolve_local_path(req)
if resolved is not None:
return f"{resolved}{marker_suffix}"
if req.startswith("file://"):
raw = req[len("file://") :]
if raw.startswith("./") or raw.startswith("../"):
resolved = _resolve_local_path(raw)
if resolved is not None:
return f"file://{resolved}{marker_suffix}"
return dep
def _dependency_name_from_spec(dep: str) -> str | None:
stripped = dep.strip()
if not stripped or stripped == "-e" or stripped.startswith("-e "):
return None
if stripped.startswith(("/", "./", "../", "file://")):
return None
try:
return canonicalize_name(Requirement(stripped).name)
except InvalidRequirement:
return None
def _parse_cuda_wheels_config(
tool_config: dict[str, object], dependencies: list[str]
) -> dict[str, object] | None:
raw_config = tool_config.get("cuda_wheels")
if raw_config is None:
return None
if not isinstance(raw_config, dict):
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels] must be a table"
)
index_url = raw_config.get("index_url")
if not isinstance(index_url, str) or not index_url.strip():
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels.index_url] must be a non-empty string"
)
packages = raw_config.get("packages")
if not isinstance(packages, list) or not all(
isinstance(package_name, str) and package_name.strip()
for package_name in packages
):
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels.packages] must be a list of non-empty strings"
)
declared_dependencies = {
dependency_name
for dep in dependencies
if (dependency_name := _dependency_name_from_spec(dep)) is not None
}
normalized_packages = [canonicalize_name(package_name) for package_name in packages]
missing = [
package_name
for package_name in normalized_packages
if package_name not in declared_dependencies
]
if missing:
missing_joined = ", ".join(sorted(missing))
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels.packages] references undeclared dependencies: "
f"{missing_joined}"
)
package_map = raw_config.get("package_map", {})
if not isinstance(package_map, dict):
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels.package_map] must be a table"
)
normalized_package_map: dict[str, str] = {}
for dependency_name, index_package_name in package_map.items():
if not isinstance(dependency_name, str) or not dependency_name.strip():
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels.package_map] keys must be non-empty strings"
)
if not isinstance(index_package_name, str) or not index_package_name.strip():
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels.package_map] values must be non-empty strings"
)
canonical_dependency_name = canonicalize_name(dependency_name)
if canonical_dependency_name not in normalized_packages:
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels.package_map] can only override packages listed in "
"[tool.comfy.isolation.cuda_wheels.packages]"
)
normalized_package_map[canonical_dependency_name] = index_package_name.strip()
return {
"index_url": index_url.rstrip("/") + "/",
"packages": normalized_packages,
"package_map": normalized_package_map,
}
def get_enforcement_policy() -> Dict[str, bool]:
return {
"force_isolated": os.environ.get("PYISOLATE_ENFORCE_ISOLATED") == "1",
"force_sandbox": os.environ.get("PYISOLATE_ENFORCE_SANDBOX") == "1",
}
class ExtensionLoadError(RuntimeError):
pass
def register_dummy_module(extension_name: str, node_dir: Path) -> None:
normalized_name = extension_name.replace("-", "_").replace(".", "_")
if normalized_name not in sys.modules:
dummy_module = types.ModuleType(normalized_name)
dummy_module.__file__ = str(node_dir / "__init__.py")
dummy_module.__path__ = [str(node_dir)]
dummy_module.__package__ = normalized_name
sys.modules[normalized_name] = dummy_module
def _is_stale_node_cache(cached_data: Dict[str, Dict]) -> bool:
for details in cached_data.values():
if not isinstance(details, dict):
return True
if details.get("is_v3") and "schema_v1" not in details:
return True
return False
async def load_isolated_node(
node_dir: Path,
manifest_path: Path,
logger: logging.Logger,
build_stub_class: Callable[[str, Dict[str, object], ComfyNodeExtension], type],
venv_root: Path,
extension_managers: List[ExtensionManager],
) -> List[Tuple[str, str, type]]:
try:
with manifest_path.open("rb") as handle:
manifest_data = tomllib.load(handle)
except Exception as e:
logger.warning(f"][ Failed to parse {manifest_path}: {e}")
return []
# Parse [tool.comfy.isolation]
tool_config = manifest_data.get("tool", {}).get("comfy", {}).get("isolation", {})
can_isolate = tool_config.get("can_isolate", False)
share_torch = tool_config.get("share_torch", False)
# Parse [project] dependencies
project_config = manifest_data.get("project", {})
dependencies = project_config.get("dependencies", [])
if not isinstance(dependencies, list):
dependencies = []
# Get extension name (default to folder name if not in project.name)
extension_name = project_config.get("name", node_dir.name)
# LOGIC: Isolation Decision
policy = get_enforcement_policy()
isolated = can_isolate or policy["force_isolated"]
if not isolated:
return []
logger.info(f"][ Loading isolated node: {extension_name}")
import folder_paths
base_paths = [Path(folder_paths.base_path), node_dir]
dependencies = [
_normalize_dependency_spec(dep, base_paths) if isinstance(dep, str) else dep
for dep in dependencies
]
cuda_wheels = _parse_cuda_wheels_config(tool_config, dependencies)
manager_config = ExtensionManagerConfig(venv_root_path=str(venv_root))
manager: ExtensionManager = pyisolate.ExtensionManager(
ComfyNodeExtension, manager_config
)
extension_managers.append(manager)
host_policy = load_host_policy(Path(folder_paths.base_path))
sandbox_config = {}
is_linux = platform.system() == "Linux"
if is_linux and isolated:
sandbox_config = {
"network": host_policy["allow_network"],
"writable_paths": host_policy["writable_paths"],
"readonly_paths": host_policy["readonly_paths"],
}
share_cuda_ipc = share_torch and is_linux
extension_config = {
"name": extension_name,
"module_path": str(node_dir),
"isolated": True,
"dependencies": dependencies,
"share_torch": share_torch,
"share_cuda_ipc": share_cuda_ipc,
"sandbox_mode": host_policy["sandbox_mode"],
"sandbox": sandbox_config,
}
if cuda_wheels is not None:
extension_config["cuda_wheels"] = cuda_wheels
extension = manager.load_extension(extension_config)
register_dummy_module(extension_name, node_dir)
# Register web directory on the host — only when sandbox is disabled.
# In sandbox mode, serving untrusted JS to the browser is not safe.
if host_policy["sandbox_mode"] == "disabled":
_register_web_directory(extension_name, node_dir)
# Try cache first (lazy spawn)
if is_cache_valid(node_dir, manifest_path, venv_root):
cached_data = load_from_cache(node_dir, venv_root)
if cached_data:
if _is_stale_node_cache(cached_data):
logger.debug(
"][ %s cache is stale/incompatible; rebuilding metadata",
extension_name,
)
else:
logger.debug(f"][ {extension_name} loaded from cache")
specs: List[Tuple[str, str, type]] = []
for node_name, details in cached_data.items():
stub_cls = build_stub_class(node_name, details, extension)
specs.append(
(node_name, details.get("display_name", node_name), stub_cls)
)
return specs
# Cache miss - spawn process and get metadata
logger.debug(f"][ {extension_name} cache miss, spawning process for metadata")
try:
remote_nodes: Dict[str, str] = await extension.list_nodes()
except Exception as exc:
logger.warning(
"][ %s metadata discovery failed, skipping isolated load: %s",
extension_name,
exc,
)
await _stop_extension_safe(extension, extension_name)
return []
if not remote_nodes:
logger.debug("][ %s exposed no isolated nodes; skipping", extension_name)
await _stop_extension_safe(extension, extension_name)
return []
specs: List[Tuple[str, str, type]] = []
cache_data: Dict[str, Dict] = {}
for node_name, display_name in remote_nodes.items():
try:
details = await extension.get_node_details(node_name)
except Exception as exc:
logger.warning(
"][ %s failed to load metadata for %s, skipping node: %s",
extension_name,
node_name,
exc,
)
continue
details["display_name"] = display_name
cache_data[node_name] = details
stub_cls = build_stub_class(node_name, details, extension)
specs.append((node_name, display_name, stub_cls))
if not specs:
logger.warning(
"][ %s produced no usable nodes after metadata scan; skipping",
extension_name,
)
await _stop_extension_safe(extension, extension_name)
return []
# Save metadata to cache for future runs
save_to_cache(node_dir, venv_root, cache_data, manifest_path)
logger.debug(f"][ {extension_name} metadata cached")
# EJECT: Kill process after getting metadata (will respawn on first execution)
await _stop_extension_safe(extension, extension_name)
return specs
__all__ = ["ExtensionLoadError", "register_dummy_module", "load_isolated_node"]