mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-12 11:32:37 +08:00
feat: isolation core — adapter, loader, manifest, hooks, runtime helpers
This commit is contained in:
parent
c02372936d
commit
878684d8b2
@ -8,11 +8,21 @@ import time
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Set, TYPE_CHECKING
|
from typing import Dict, List, Optional, Set, TYPE_CHECKING
|
||||||
import folder_paths
|
_IMPORT_TORCH = os.environ.get("PYISOLATE_IMPORT_TORCH", "1") == "1"
|
||||||
from .extension_loader import load_isolated_node
|
|
||||||
from .manifest_loader import find_manifest_directories
|
load_isolated_node = None
|
||||||
from .runtime_helpers import build_stub_class, get_class_types_for_extension
|
find_manifest_directories = None
|
||||||
from .shm_forensics import scan_shm_forensics, start_shm_forensics
|
build_stub_class = None
|
||||||
|
get_class_types_for_extension = None
|
||||||
|
scan_shm_forensics = None
|
||||||
|
start_shm_forensics = None
|
||||||
|
|
||||||
|
if _IMPORT_TORCH:
|
||||||
|
import folder_paths
|
||||||
|
from .extension_loader import load_isolated_node
|
||||||
|
from .manifest_loader import find_manifest_directories
|
||||||
|
from .runtime_helpers import build_stub_class, get_class_types_for_extension
|
||||||
|
from .shm_forensics import scan_shm_forensics, start_shm_forensics
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from pyisolate import ExtensionManager
|
from pyisolate import ExtensionManager
|
||||||
@ -21,8 +31,9 @@ if TYPE_CHECKING:
|
|||||||
LOG_PREFIX = "]["
|
LOG_PREFIX = "]["
|
||||||
isolated_node_timings: List[tuple[float, Path, int]] = []
|
isolated_node_timings: List[tuple[float, Path, int]] = []
|
||||||
|
|
||||||
PYISOLATE_VENV_ROOT = Path(folder_paths.base_path) / ".pyisolate_venvs"
|
if _IMPORT_TORCH:
|
||||||
PYISOLATE_VENV_ROOT.mkdir(parents=True, exist_ok=True)
|
PYISOLATE_VENV_ROOT = Path(folder_paths.base_path) / ".pyisolate_venvs"
|
||||||
|
PYISOLATE_VENV_ROOT.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
|
_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
|
||||||
@ -42,7 +53,8 @@ def initialize_proxies() -> None:
|
|||||||
from .host_hooks import initialize_host_process
|
from .host_hooks import initialize_host_process
|
||||||
|
|
||||||
initialize_host_process()
|
initialize_host_process()
|
||||||
start_shm_forensics()
|
if start_shm_forensics is not None:
|
||||||
|
start_shm_forensics()
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@ -88,6 +100,8 @@ async def initialize_isolation_nodes() -> List[IsolatedNodeSpec]:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
_ISOLATION_SCAN_ATTEMPTED = True
|
_ISOLATION_SCAN_ATTEMPTED = True
|
||||||
|
if find_manifest_directories is None or load_isolated_node is None or build_stub_class is None:
|
||||||
|
return []
|
||||||
manifest_entries = find_manifest_directories()
|
manifest_entries = find_manifest_directories()
|
||||||
_CLAIMED_PATHS = {entry[0].resolve() for entry in manifest_entries}
|
_CLAIMED_PATHS = {entry[0].resolve() for entry in manifest_entries}
|
||||||
|
|
||||||
@ -167,17 +181,27 @@ def _get_class_types_for_extension(extension_name: str) -> Set[str]:
|
|||||||
return class_types
|
return class_types
|
||||||
|
|
||||||
|
|
||||||
async def notify_execution_graph(needed_class_types: Set[str]) -> None:
|
async def notify_execution_graph(needed_class_types: Set[str], caches: list | None = None) -> None:
|
||||||
"""Evict running extensions not needed for current execution."""
|
"""Evict running extensions not needed for current execution.
|
||||||
|
|
||||||
|
When *caches* is provided, cache entries for evicted extensions' node
|
||||||
|
class_types are invalidated to prevent stale ``RemoteObjectHandle``
|
||||||
|
references from surviving in the output cache.
|
||||||
|
"""
|
||||||
await wait_for_model_patcher_quiescence(
|
await wait_for_model_patcher_quiescence(
|
||||||
timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS,
|
timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS,
|
||||||
fail_loud=True,
|
fail_loud=True,
|
||||||
marker="ISO:notify_graph_wait_idle",
|
marker="ISO:notify_graph_wait_idle",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
evicted_class_types: Set[str] = set()
|
||||||
|
|
||||||
async def _stop_extension(
|
async def _stop_extension(
|
||||||
ext_name: str, extension: "ComfyNodeExtension", reason: str
|
ext_name: str, extension: "ComfyNodeExtension", reason: str
|
||||||
) -> None:
|
) -> None:
|
||||||
|
# Collect class_types BEFORE stopping so we can invalidate cache entries.
|
||||||
|
ext_class_types = _get_class_types_for_extension(ext_name)
|
||||||
|
evicted_class_types.update(ext_class_types)
|
||||||
logger.info("%s ISO:eject_start ext=%s reason=%s", LOG_PREFIX, ext_name, reason)
|
logger.info("%s ISO:eject_start ext=%s reason=%s", LOG_PREFIX, ext_name, reason)
|
||||||
logger.debug("%s ISO:stop_start ext=%s", LOG_PREFIX, ext_name)
|
logger.debug("%s ISO:stop_start ext=%s", LOG_PREFIX, ext_name)
|
||||||
stop_result = extension.stop()
|
stop_result = extension.stop()
|
||||||
@ -185,9 +209,11 @@ async def notify_execution_graph(needed_class_types: Set[str]) -> None:
|
|||||||
await stop_result
|
await stop_result
|
||||||
_RUNNING_EXTENSIONS.pop(ext_name, None)
|
_RUNNING_EXTENSIONS.pop(ext_name, None)
|
||||||
logger.debug("%s ISO:stop_done ext=%s", LOG_PREFIX, ext_name)
|
logger.debug("%s ISO:stop_done ext=%s", LOG_PREFIX, ext_name)
|
||||||
scan_shm_forensics("ISO:stop_extension", refresh_model_context=True)
|
if scan_shm_forensics is not None:
|
||||||
|
scan_shm_forensics("ISO:stop_extension", refresh_model_context=True)
|
||||||
|
|
||||||
scan_shm_forensics("ISO:notify_graph_start", refresh_model_context=True)
|
if scan_shm_forensics is not None:
|
||||||
|
scan_shm_forensics("ISO:notify_graph_start", refresh_model_context=True)
|
||||||
isolated_class_types_in_graph = needed_class_types.intersection(
|
isolated_class_types_in_graph = needed_class_types.intersection(
|
||||||
{spec.node_name for spec in _ISOLATED_NODE_SPECS}
|
{spec.node_name for spec in _ISOLATED_NODE_SPECS}
|
||||||
)
|
)
|
||||||
@ -247,6 +273,22 @@ async def notify_execution_graph(needed_class_types: Set[str]) -> None:
|
|||||||
"%s workflow-boundary host VRAM relief failed", LOG_PREFIX, exc_info=True
|
"%s workflow-boundary host VRAM relief failed", LOG_PREFIX, exc_info=True
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
|
# Invalidate cached outputs for evicted extensions so stale
|
||||||
|
# RemoteObjectHandle references are not served from cache.
|
||||||
|
if evicted_class_types and caches:
|
||||||
|
total_invalidated = 0
|
||||||
|
for cache in caches:
|
||||||
|
if hasattr(cache, "invalidate_by_class_types"):
|
||||||
|
total_invalidated += cache.invalidate_by_class_types(
|
||||||
|
evicted_class_types
|
||||||
|
)
|
||||||
|
if total_invalidated > 0:
|
||||||
|
logger.info(
|
||||||
|
"%s ISO:cache_invalidated count=%d class_types=%s",
|
||||||
|
LOG_PREFIX,
|
||||||
|
total_invalidated,
|
||||||
|
evicted_class_types,
|
||||||
|
)
|
||||||
scan_shm_forensics("ISO:notify_graph_done", refresh_model_context=True)
|
scan_shm_forensics("ISO:notify_graph_done", refresh_model_context=True)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"%s ISO:notify_graph_done running=%d", LOG_PREFIX, len(_RUNNING_EXTENSIONS)
|
"%s ISO:notify_graph_done running=%d", LOG_PREFIX, len(_RUNNING_EXTENSIONS)
|
||||||
|
|||||||
@ -3,13 +3,38 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import inspect
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional, cast
|
||||||
|
|
||||||
from pyisolate.interfaces import IsolationAdapter, SerializerRegistryProtocol # type: ignore[import-untyped]
|
from pyisolate.interfaces import IsolationAdapter, SerializerRegistryProtocol # type: ignore[import-untyped]
|
||||||
from pyisolate._internal.rpc_protocol import AsyncRPC, ProxiedSingleton # type: ignore[import-untyped]
|
from pyisolate._internal.rpc_protocol import AsyncRPC, ProxiedSingleton # type: ignore[import-untyped]
|
||||||
|
|
||||||
try:
|
_IMPORT_TORCH = os.environ.get("PYISOLATE_IMPORT_TORCH", "1") == "1"
|
||||||
|
|
||||||
|
# Singleton proxies that do NOT transitively import torch/PIL/psutil/aiohttp.
|
||||||
|
# Safe to import in sealed workers without host framework modules.
|
||||||
|
from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy
|
||||||
|
from comfy.isolation.proxies.helper_proxies import HelperProxiesService
|
||||||
|
from comfy.isolation.proxies.web_directory_proxy import WebDirectoryProxy
|
||||||
|
|
||||||
|
# Singleton proxies that transitively import torch, PIL, or heavy host modules.
|
||||||
|
# Only available when torch/host framework is present.
|
||||||
|
CLIPProxy = None
|
||||||
|
CLIPRegistry = None
|
||||||
|
ModelPatcherProxy = None
|
||||||
|
ModelPatcherRegistry = None
|
||||||
|
ModelSamplingProxy = None
|
||||||
|
ModelSamplingRegistry = None
|
||||||
|
VAEProxy = None
|
||||||
|
VAERegistry = None
|
||||||
|
FirstStageModelRegistry = None
|
||||||
|
ModelManagementProxy = None
|
||||||
|
PromptServerService = None
|
||||||
|
ProgressProxy = None
|
||||||
|
UtilsProxy = None
|
||||||
|
_HAS_TORCH_PROXIES = False
|
||||||
|
if _IMPORT_TORCH:
|
||||||
from comfy.isolation.clip_proxy import CLIPProxy, CLIPRegistry
|
from comfy.isolation.clip_proxy import CLIPProxy, CLIPRegistry
|
||||||
from comfy.isolation.model_patcher_proxy import (
|
from comfy.isolation.model_patcher_proxy import (
|
||||||
ModelPatcherProxy,
|
ModelPatcherProxy,
|
||||||
@ -20,13 +45,11 @@ try:
|
|||||||
ModelSamplingRegistry,
|
ModelSamplingRegistry,
|
||||||
)
|
)
|
||||||
from comfy.isolation.vae_proxy import VAEProxy, VAERegistry, FirstStageModelRegistry
|
from comfy.isolation.vae_proxy import VAEProxy, VAERegistry, FirstStageModelRegistry
|
||||||
from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy
|
|
||||||
from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy
|
from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy
|
||||||
from comfy.isolation.proxies.prompt_server_impl import PromptServerService
|
from comfy.isolation.proxies.prompt_server_impl import PromptServerService
|
||||||
from comfy.isolation.proxies.utils_proxy import UtilsProxy
|
|
||||||
from comfy.isolation.proxies.progress_proxy import ProgressProxy
|
from comfy.isolation.proxies.progress_proxy import ProgressProxy
|
||||||
except ImportError as exc: # Fail loud if Comfy environment is incomplete
|
from comfy.isolation.proxies.utils_proxy import UtilsProxy
|
||||||
raise ImportError(f"ComfyUI environment incomplete: {exc}")
|
_HAS_TORCH_PROXIES = True
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -62,9 +85,20 @@ class ComfyUIAdapter(IsolationAdapter):
|
|||||||
os.path.join(comfy_root, "custom_nodes"),
|
os.path.join(comfy_root, "custom_nodes"),
|
||||||
os.path.join(comfy_root, "comfy"),
|
os.path.join(comfy_root, "comfy"),
|
||||||
],
|
],
|
||||||
|
"filtered_subdirs": ["comfy", "app", "comfy_execution", "utils"],
|
||||||
}
|
}
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_sandbox_system_paths(self) -> Optional[List[str]]:
|
||||||
|
"""Returns required application paths to mount in the sandbox."""
|
||||||
|
# By inspecting where our adapter is loaded from, we can determine the comfy root
|
||||||
|
adapter_file = inspect.getfile(self.__class__)
|
||||||
|
# adapter_file = /home/johnj/ComfyUI/comfy/isolation/adapter.py
|
||||||
|
comfy_root = os.path.dirname(os.path.dirname(os.path.dirname(adapter_file)))
|
||||||
|
if os.path.exists(comfy_root):
|
||||||
|
return [comfy_root]
|
||||||
|
return None
|
||||||
|
|
||||||
def setup_child_environment(self, snapshot: Dict[str, Any]) -> None:
|
def setup_child_environment(self, snapshot: Dict[str, Any]) -> None:
|
||||||
comfy_root = snapshot.get("preferred_root")
|
comfy_root = snapshot.get("preferred_root")
|
||||||
if not comfy_root:
|
if not comfy_root:
|
||||||
@ -83,6 +117,59 @@ class ComfyUIAdapter(IsolationAdapter):
|
|||||||
logging.getLogger(pkg_name).setLevel(logging.ERROR)
|
logging.getLogger(pkg_name).setLevel(logging.ERROR)
|
||||||
|
|
||||||
def register_serializers(self, registry: SerializerRegistryProtocol) -> None:
|
def register_serializers(self, registry: SerializerRegistryProtocol) -> None:
|
||||||
|
if not _IMPORT_TORCH:
|
||||||
|
# Sealed worker without torch — register torch-free TensorValue handler
|
||||||
|
# so IMAGE/MASK/LATENT tensors arrive as numpy arrays, not raw dicts.
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
_TORCH_DTYPE_TO_NUMPY = {
|
||||||
|
"torch.float32": np.float32,
|
||||||
|
"torch.float64": np.float64,
|
||||||
|
"torch.float16": np.float16,
|
||||||
|
"torch.bfloat16": np.float32, # numpy has no bfloat16; upcast
|
||||||
|
"torch.int32": np.int32,
|
||||||
|
"torch.int64": np.int64,
|
||||||
|
"torch.int16": np.int16,
|
||||||
|
"torch.int8": np.int8,
|
||||||
|
"torch.uint8": np.uint8,
|
||||||
|
"torch.bool": np.bool_,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _deserialize_tensor_value(data: Dict[str, Any]) -> Any:
|
||||||
|
dtype_str = data["dtype"]
|
||||||
|
np_dtype = _TORCH_DTYPE_TO_NUMPY.get(dtype_str, np.float32)
|
||||||
|
shape = tuple(data["tensor_size"])
|
||||||
|
arr = np.array(data["data"], dtype=np_dtype).reshape(shape)
|
||||||
|
return arr
|
||||||
|
|
||||||
|
_NUMPY_TO_TORCH_DTYPE = {
|
||||||
|
np.float32: "torch.float32",
|
||||||
|
np.float64: "torch.float64",
|
||||||
|
np.float16: "torch.float16",
|
||||||
|
np.int32: "torch.int32",
|
||||||
|
np.int64: "torch.int64",
|
||||||
|
np.int16: "torch.int16",
|
||||||
|
np.int8: "torch.int8",
|
||||||
|
np.uint8: "torch.uint8",
|
||||||
|
np.bool_: "torch.bool",
|
||||||
|
}
|
||||||
|
|
||||||
|
def _serialize_tensor_value(obj: Any) -> Dict[str, Any]:
|
||||||
|
arr = np.asarray(obj, dtype=np.float32) if obj.dtype not in _NUMPY_TO_TORCH_DTYPE else np.asarray(obj)
|
||||||
|
dtype_str = _NUMPY_TO_TORCH_DTYPE.get(arr.dtype.type, "torch.float32")
|
||||||
|
return {
|
||||||
|
"__type__": "TensorValue",
|
||||||
|
"dtype": dtype_str,
|
||||||
|
"tensor_size": list(arr.shape),
|
||||||
|
"requires_grad": False,
|
||||||
|
"data": arr.tolist(),
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.register("TensorValue", _serialize_tensor_value, _deserialize_tensor_value, data_type=True)
|
||||||
|
# ndarray output from sealed workers serializes as TensorValue for host torch reconstruction
|
||||||
|
registry.register("ndarray", _serialize_tensor_value, _deserialize_tensor_value, data_type=True)
|
||||||
|
return
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
def serialize_device(obj: Any) -> Dict[str, Any]:
|
def serialize_device(obj: Any) -> Dict[str, Any]:
|
||||||
@ -110,6 +197,54 @@ class ComfyUIAdapter(IsolationAdapter):
|
|||||||
|
|
||||||
registry.register("dtype", serialize_dtype, deserialize_dtype)
|
registry.register("dtype", serialize_dtype, deserialize_dtype)
|
||||||
|
|
||||||
|
from comfy_api.latest._io import FolderType
|
||||||
|
from comfy_api.latest._ui import SavedImages, SavedResult
|
||||||
|
|
||||||
|
def serialize_saved_result(obj: Any) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"__type__": "SavedResult",
|
||||||
|
"filename": obj.filename,
|
||||||
|
"subfolder": obj.subfolder,
|
||||||
|
"folder_type": obj.type.value,
|
||||||
|
}
|
||||||
|
|
||||||
|
def deserialize_saved_result(data: Dict[str, Any]) -> Any:
|
||||||
|
if isinstance(data, SavedResult):
|
||||||
|
return data
|
||||||
|
folder_type = data["folder_type"] if "folder_type" in data else data["type"]
|
||||||
|
return SavedResult(
|
||||||
|
filename=data["filename"],
|
||||||
|
subfolder=data["subfolder"],
|
||||||
|
type=FolderType(folder_type),
|
||||||
|
)
|
||||||
|
|
||||||
|
registry.register(
|
||||||
|
"SavedResult",
|
||||||
|
serialize_saved_result,
|
||||||
|
deserialize_saved_result,
|
||||||
|
data_type=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def serialize_saved_images(obj: Any) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"__type__": "SavedImages",
|
||||||
|
"results": [serialize_saved_result(result) for result in obj.results],
|
||||||
|
"is_animated": obj.is_animated,
|
||||||
|
}
|
||||||
|
|
||||||
|
def deserialize_saved_images(data: Dict[str, Any]) -> Any:
|
||||||
|
return SavedImages(
|
||||||
|
results=[deserialize_saved_result(result) for result in data["results"]],
|
||||||
|
is_animated=data.get("is_animated", False),
|
||||||
|
)
|
||||||
|
|
||||||
|
registry.register(
|
||||||
|
"SavedImages",
|
||||||
|
serialize_saved_images,
|
||||||
|
deserialize_saved_images,
|
||||||
|
data_type=True,
|
||||||
|
)
|
||||||
|
|
||||||
def serialize_model_patcher(obj: Any) -> Dict[str, Any]:
|
def serialize_model_patcher(obj: Any) -> Dict[str, Any]:
|
||||||
# Child-side: must already have _instance_id (proxy)
|
# Child-side: must already have _instance_id (proxy)
|
||||||
if os.environ.get("PYISOLATE_CHILD") == "1":
|
if os.environ.get("PYISOLATE_CHILD") == "1":
|
||||||
@ -212,28 +347,93 @@ class ComfyUIAdapter(IsolationAdapter):
|
|||||||
# copyreg removed - no pickle fallback allowed
|
# copyreg removed - no pickle fallback allowed
|
||||||
|
|
||||||
def serialize_model_sampling(obj: Any) -> Dict[str, Any]:
|
def serialize_model_sampling(obj: Any) -> Dict[str, Any]:
|
||||||
# Child-side: must already have _instance_id (proxy)
|
# Proxy with _instance_id — return ref (works from both host and child)
|
||||||
if os.environ.get("PYISOLATE_CHILD") == "1":
|
|
||||||
if hasattr(obj, "_instance_id"):
|
|
||||||
return {"__type__": "ModelSamplingRef", "ms_id": obj._instance_id}
|
|
||||||
raise RuntimeError(
|
|
||||||
f"ModelSampling in child lacks _instance_id: "
|
|
||||||
f"{type(obj).__module__}.{type(obj).__name__}"
|
|
||||||
)
|
|
||||||
# Host-side pass-through for proxies: do not re-register a proxy as a
|
|
||||||
# new ModelSamplingRef, or we create proxy-of-proxy indirection.
|
|
||||||
if hasattr(obj, "_instance_id"):
|
if hasattr(obj, "_instance_id"):
|
||||||
return {"__type__": "ModelSamplingRef", "ms_id": obj._instance_id}
|
return {"__type__": "ModelSamplingRef", "ms_id": obj._instance_id}
|
||||||
|
# Child-side: object created locally in child (e.g. ModelSamplingAdvanced
|
||||||
|
# in nodes_z_image_turbo.py). Serialize as inline data so the host can
|
||||||
|
# reconstruct the real torch.nn.Module.
|
||||||
|
if os.environ.get("PYISOLATE_CHILD") == "1":
|
||||||
|
import base64
|
||||||
|
import io as _io
|
||||||
|
|
||||||
|
# Identify base classes from comfy.model_sampling
|
||||||
|
bases = []
|
||||||
|
for base in type(obj).__mro__:
|
||||||
|
if base.__module__ == "comfy.model_sampling" and base.__name__ != "object":
|
||||||
|
bases.append(base.__name__)
|
||||||
|
# Serialize state_dict as base64 safetensors-like
|
||||||
|
sd = obj.state_dict()
|
||||||
|
sd_serialized = {}
|
||||||
|
for k, v in sd.items():
|
||||||
|
buf = _io.BytesIO()
|
||||||
|
torch.save(v, buf)
|
||||||
|
sd_serialized[k] = base64.b64encode(buf.getvalue()).decode("ascii")
|
||||||
|
# Capture plain attrs (shift, multiplier, sigma_data, etc.)
|
||||||
|
plain_attrs = {}
|
||||||
|
for k, v in obj.__dict__.items():
|
||||||
|
if k.startswith("_"):
|
||||||
|
continue
|
||||||
|
if isinstance(v, (bool, int, float, str)):
|
||||||
|
plain_attrs[k] = v
|
||||||
|
return {
|
||||||
|
"__type__": "ModelSamplingInline",
|
||||||
|
"bases": bases,
|
||||||
|
"state_dict": sd_serialized,
|
||||||
|
"attrs": plain_attrs,
|
||||||
|
}
|
||||||
# Host-side: register with ModelSamplingRegistry and return JSON-safe dict
|
# Host-side: register with ModelSamplingRegistry and return JSON-safe dict
|
||||||
ms_id = ModelSamplingRegistry().register(obj)
|
ms_id = ModelSamplingRegistry().register(obj)
|
||||||
return {"__type__": "ModelSamplingRef", "ms_id": ms_id}
|
return {"__type__": "ModelSamplingRef", "ms_id": ms_id}
|
||||||
|
|
||||||
def deserialize_model_sampling(data: Any) -> Any:
|
def deserialize_model_sampling(data: Any) -> Any:
|
||||||
"""Deserialize ModelSampling refs; pass through already-materialized objects."""
|
"""Deserialize ModelSampling refs or inline data."""
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
|
if data.get("__type__") == "ModelSamplingInline":
|
||||||
|
return _reconstruct_model_sampling_inline(data)
|
||||||
return ModelSamplingProxy(data["ms_id"])
|
return ModelSamplingProxy(data["ms_id"])
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
def _reconstruct_model_sampling_inline(data: Dict[str, Any]) -> Any:
|
||||||
|
"""Reconstruct a ModelSampling object on the host from inline child data."""
|
||||||
|
import comfy.model_sampling as _ms
|
||||||
|
import base64
|
||||||
|
import io as _io
|
||||||
|
|
||||||
|
# Resolve base classes
|
||||||
|
base_classes = []
|
||||||
|
for name in data["bases"]:
|
||||||
|
cls = getattr(_ms, name, None)
|
||||||
|
if cls is not None:
|
||||||
|
base_classes.append(cls)
|
||||||
|
if not base_classes:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot reconstruct ModelSampling: no known bases in {data['bases']}"
|
||||||
|
)
|
||||||
|
# Create dynamic class matching the child's class hierarchy
|
||||||
|
ReconstructedSampling = type("ReconstructedSampling", tuple(base_classes), {})
|
||||||
|
obj = ReconstructedSampling.__new__(ReconstructedSampling)
|
||||||
|
torch.nn.Module.__init__(obj)
|
||||||
|
# Restore plain attributes first
|
||||||
|
for k, v in data.get("attrs", {}).items():
|
||||||
|
setattr(obj, k, v)
|
||||||
|
# Restore state_dict (buffers like sigmas)
|
||||||
|
for k, v_b64 in data.get("state_dict", {}).items():
|
||||||
|
buf = _io.BytesIO(base64.b64decode(v_b64))
|
||||||
|
tensor = torch.load(buf, weights_only=True)
|
||||||
|
# Register as buffer so it's part of state_dict
|
||||||
|
parts = k.split(".")
|
||||||
|
if len(parts) == 1:
|
||||||
|
cast(Any, obj).register_buffer(parts[0], tensor) # pylint: disable=no-member
|
||||||
|
else:
|
||||||
|
setattr(obj, parts[0], tensor)
|
||||||
|
# Register on host so future references use proxy pattern.
|
||||||
|
# Skip in child process — register() is async RPC and cannot be
|
||||||
|
# called synchronously during deserialization.
|
||||||
|
if os.environ.get("PYISOLATE_CHILD") != "1":
|
||||||
|
ModelSamplingRegistry().register(obj)
|
||||||
|
return obj
|
||||||
|
|
||||||
def deserialize_model_sampling_ref(data: Dict[str, Any]) -> Any:
|
def deserialize_model_sampling_ref(data: Dict[str, Any]) -> Any:
|
||||||
"""Context-aware ModelSamplingRef deserializer for both host and child."""
|
"""Context-aware ModelSamplingRef deserializer for both host and child."""
|
||||||
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
|
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
|
||||||
@ -262,6 +462,10 @@ class ComfyUIAdapter(IsolationAdapter):
|
|||||||
)
|
)
|
||||||
# Register ModelSamplingRef for deserialization (context-aware: host or child)
|
# Register ModelSamplingRef for deserialization (context-aware: host or child)
|
||||||
registry.register("ModelSamplingRef", None, deserialize_model_sampling_ref)
|
registry.register("ModelSamplingRef", None, deserialize_model_sampling_ref)
|
||||||
|
# Register ModelSamplingInline for deserialization (child→host inline transfer)
|
||||||
|
registry.register(
|
||||||
|
"ModelSamplingInline", None, lambda data: _reconstruct_model_sampling_inline(data)
|
||||||
|
)
|
||||||
|
|
||||||
def serialize_cond(obj: Any) -> Dict[str, Any]:
|
def serialize_cond(obj: Any) -> Dict[str, Any]:
|
||||||
type_key = f"{type(obj).__module__}.{type(obj).__name__}"
|
type_key = f"{type(obj).__module__}.{type(obj).__name__}"
|
||||||
@ -393,52 +597,20 @@ class ComfyUIAdapter(IsolationAdapter):
|
|||||||
# Fallback for non-numeric arrays (strings, objects, mixes)
|
# Fallback for non-numeric arrays (strings, objects, mixes)
|
||||||
return obj.tolist()
|
return obj.tolist()
|
||||||
|
|
||||||
registry.register("ndarray", serialize_numpy, None)
|
def deserialize_numpy_b64(data: Any) -> Any:
|
||||||
|
"""Deserialize base64-encoded ndarray from sealed worker."""
|
||||||
def serialize_ply(obj: Any) -> Dict[str, Any]:
|
|
||||||
import base64
|
import base64
|
||||||
import torch
|
import numpy as np
|
||||||
if obj.raw_data is not None:
|
if isinstance(data, dict) and "data" in data and "dtype" in data:
|
||||||
return {
|
raw = base64.b64decode(data["data"])
|
||||||
"__type__": "PLY",
|
arr = np.frombuffer(raw, dtype=np.dtype(data["dtype"])).reshape(data["shape"])
|
||||||
"raw_data": base64.b64encode(obj.raw_data).decode("ascii"),
|
return torch.from_numpy(arr.copy())
|
||||||
}
|
return data
|
||||||
result: Dict[str, Any] = {"__type__": "PLY", "points": torch.from_numpy(obj.points)}
|
|
||||||
if obj.colors is not None:
|
|
||||||
result["colors"] = torch.from_numpy(obj.colors)
|
|
||||||
if obj.confidence is not None:
|
|
||||||
result["confidence"] = torch.from_numpy(obj.confidence)
|
|
||||||
if obj.view_id is not None:
|
|
||||||
result["view_id"] = torch.from_numpy(obj.view_id)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def deserialize_ply(data: Any) -> Any:
|
registry.register("ndarray", serialize_numpy, deserialize_numpy_b64)
|
||||||
import base64
|
|
||||||
from comfy_api.latest._util.ply_types import PLY
|
|
||||||
if "raw_data" in data:
|
|
||||||
return PLY(raw_data=base64.b64decode(data["raw_data"]))
|
|
||||||
return PLY(
|
|
||||||
points=data["points"],
|
|
||||||
colors=data.get("colors"),
|
|
||||||
confidence=data.get("confidence"),
|
|
||||||
view_id=data.get("view_id"),
|
|
||||||
)
|
|
||||||
|
|
||||||
registry.register("PLY", serialize_ply, deserialize_ply, data_type=True)
|
# -- File3D (comfy_api.latest._util.geometry_types) ---------------------
|
||||||
|
# Origin: comfy_api by ComfyOrg (Alexander Piskun), PR #12129
|
||||||
def serialize_npz(obj: Any) -> Dict[str, Any]:
|
|
||||||
import base64
|
|
||||||
return {
|
|
||||||
"__type__": "NPZ",
|
|
||||||
"frames": [base64.b64encode(f).decode("ascii") for f in obj.frames],
|
|
||||||
}
|
|
||||||
|
|
||||||
def deserialize_npz(data: Any) -> Any:
|
|
||||||
import base64
|
|
||||||
from comfy_api.latest._util.npz_types import NPZ
|
|
||||||
return NPZ(frames=[base64.b64decode(f) for f in data["frames"]])
|
|
||||||
|
|
||||||
registry.register("NPZ", serialize_npz, deserialize_npz, data_type=True)
|
|
||||||
|
|
||||||
def serialize_file3d(obj: Any) -> Dict[str, Any]:
|
def serialize_file3d(obj: Any) -> Dict[str, Any]:
|
||||||
import base64
|
import base64
|
||||||
@ -456,6 +628,9 @@ class ComfyUIAdapter(IsolationAdapter):
|
|||||||
|
|
||||||
registry.register("File3D", serialize_file3d, deserialize_file3d, data_type=True)
|
registry.register("File3D", serialize_file3d, deserialize_file3d, data_type=True)
|
||||||
|
|
||||||
|
# -- VIDEO (comfy_api.latest._input_impl.video_types) -------------------
|
||||||
|
# Origin: ComfyAPI Core v0.0.2 by ComfyOrg (guill), PR #8962
|
||||||
|
|
||||||
def serialize_video(obj: Any) -> Dict[str, Any]:
|
def serialize_video(obj: Any) -> Dict[str, Any]:
|
||||||
components = obj.get_components()
|
components = obj.get_components()
|
||||||
images = components.images.detach() if components.images.requires_grad else components.images
|
images = components.images.detach() if components.images.requires_grad else components.images
|
||||||
@ -494,19 +669,156 @@ class ComfyUIAdapter(IsolationAdapter):
|
|||||||
registry.register("VideoFromFile", serialize_video, deserialize_video, data_type=True)
|
registry.register("VideoFromFile", serialize_video, deserialize_video, data_type=True)
|
||||||
registry.register("VideoFromComponents", serialize_video, deserialize_video, data_type=True)
|
registry.register("VideoFromComponents", serialize_video, deserialize_video, data_type=True)
|
||||||
|
|
||||||
|
def setup_web_directory(self, module: Any) -> None:
|
||||||
|
"""Detect WEB_DIRECTORY on a module and populate/register it.
|
||||||
|
|
||||||
|
Called by the sealed worker after loading the node module.
|
||||||
|
Mirrors extension_wrapper.py:216-227 for host-coupled nodes.
|
||||||
|
Does NOT import extension_wrapper.py (it has `import torch` at module level).
|
||||||
|
"""
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
web_dir_attr = getattr(module, "WEB_DIRECTORY", None)
|
||||||
|
if web_dir_attr is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
module_dir = os.path.dirname(os.path.abspath(module.__file__))
|
||||||
|
web_dir_path = os.path.abspath(os.path.join(module_dir, web_dir_attr))
|
||||||
|
|
||||||
|
# Read extension name from pyproject.toml
|
||||||
|
ext_name = os.path.basename(module_dir)
|
||||||
|
pyproject = os.path.join(module_dir, "pyproject.toml")
|
||||||
|
if os.path.exists(pyproject):
|
||||||
|
try:
|
||||||
|
import tomllib
|
||||||
|
except ImportError:
|
||||||
|
import tomli as tomllib # type: ignore[no-redef]
|
||||||
|
try:
|
||||||
|
with open(pyproject, "rb") as f:
|
||||||
|
data = tomllib.load(f)
|
||||||
|
name = data.get("project", {}).get("name")
|
||||||
|
if name:
|
||||||
|
ext_name = name
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Populate web dir if empty (mirrors _run_prestartup_web_copy)
|
||||||
|
if not (os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path))):
|
||||||
|
os.makedirs(web_dir_path, exist_ok=True)
|
||||||
|
|
||||||
|
# Module-defined copy spec
|
||||||
|
copy_spec = getattr(module, "_PRESTARTUP_WEB_COPY", None)
|
||||||
|
if copy_spec is not None and callable(copy_spec):
|
||||||
|
try:
|
||||||
|
copy_spec(web_dir_path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("][ _PRESTARTUP_WEB_COPY failed: %s", e)
|
||||||
|
|
||||||
|
# Fallback: comfy_3d_viewers
|
||||||
|
try:
|
||||||
|
from comfy_3d_viewers import copy_viewer, VIEWER_FILES
|
||||||
|
for viewer in VIEWER_FILES:
|
||||||
|
try:
|
||||||
|
copy_viewer(viewer, web_dir_path)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Fallback: comfy_dynamic_widgets
|
||||||
|
try:
|
||||||
|
from comfy_dynamic_widgets import get_js_path
|
||||||
|
src = os.path.realpath(get_js_path())
|
||||||
|
if os.path.exists(src):
|
||||||
|
dst_dir = os.path.join(web_dir_path, "js")
|
||||||
|
os.makedirs(dst_dir, exist_ok=True)
|
||||||
|
shutil.copy2(src, os.path.join(dst_dir, "dynamic_widgets.js"))
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path)):
|
||||||
|
WebDirectoryProxy.register_web_dir(ext_name, web_dir_path)
|
||||||
|
logger.info(
|
||||||
|
"][ Adapter: registered web dir for %s (%d files)",
|
||||||
|
ext_name,
|
||||||
|
sum(1 for _ in Path(web_dir_path).rglob("*") if _.is_file()),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def register_host_event_handlers(extension: Any) -> None:
|
||||||
|
"""Register host-side event handlers for an isolated extension.
|
||||||
|
|
||||||
|
Wires ``"progress"`` events from the child to ``comfy.utils.PROGRESS_BAR_HOOK``
|
||||||
|
so the ComfyUI frontend receives progress bar updates.
|
||||||
|
"""
|
||||||
|
register_event_handler = inspect.getattr_static(
|
||||||
|
extension, "register_event_handler", None
|
||||||
|
)
|
||||||
|
if not callable(register_event_handler):
|
||||||
|
return
|
||||||
|
|
||||||
|
def _host_progress_handler(payload: dict) -> None:
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
|
hook = comfy.utils.PROGRESS_BAR_HOOK
|
||||||
|
if hook is not None:
|
||||||
|
hook(
|
||||||
|
payload.get("value", 0),
|
||||||
|
payload.get("total", 0),
|
||||||
|
payload.get("preview"),
|
||||||
|
payload.get("node_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
extension.register_event_handler("progress", _host_progress_handler)
|
||||||
|
|
||||||
|
def setup_child_event_hooks(self, extension: Any) -> None:
|
||||||
|
"""Wire PROGRESS_BAR_HOOK in the child to emit_event on the extension.
|
||||||
|
|
||||||
|
Host-coupled only — sealed workers do not have comfy.utils (torch).
|
||||||
|
"""
|
||||||
|
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
|
||||||
|
logger.info("][ ISO:setup_child_event_hooks called, PYISOLATE_CHILD=%s", is_child)
|
||||||
|
if not is_child:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not _IMPORT_TORCH:
|
||||||
|
logger.info("][ ISO:setup_child_event_hooks skipped — sealed worker (no torch)")
|
||||||
|
return
|
||||||
|
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
|
def _event_progress_hook(value, total, preview=None, node_id=None):
|
||||||
|
logger.debug("][ ISO:event_progress value=%s/%s node_id=%s", value, total, node_id)
|
||||||
|
extension.emit_event("progress", {
|
||||||
|
"value": value,
|
||||||
|
"total": total,
|
||||||
|
"node_id": node_id,
|
||||||
|
})
|
||||||
|
|
||||||
|
comfy.utils.PROGRESS_BAR_HOOK = _event_progress_hook
|
||||||
|
logger.info("][ ISO:PROGRESS_BAR_HOOK wired to event channel")
|
||||||
|
|
||||||
def provide_rpc_services(self) -> List[type[ProxiedSingleton]]:
|
def provide_rpc_services(self) -> List[type[ProxiedSingleton]]:
|
||||||
return [
|
# Always available — no torch/PIL dependency
|
||||||
PromptServerService,
|
services: List[type[ProxiedSingleton]] = [
|
||||||
FolderPathsProxy,
|
FolderPathsProxy,
|
||||||
ModelManagementProxy,
|
HelperProxiesService,
|
||||||
UtilsProxy,
|
WebDirectoryProxy,
|
||||||
ProgressProxy,
|
|
||||||
VAERegistry,
|
|
||||||
CLIPRegistry,
|
|
||||||
ModelPatcherRegistry,
|
|
||||||
ModelSamplingRegistry,
|
|
||||||
FirstStageModelRegistry,
|
|
||||||
]
|
]
|
||||||
|
# Torch/PIL-dependent proxies
|
||||||
|
if _HAS_TORCH_PROXIES:
|
||||||
|
services.extend([
|
||||||
|
PromptServerService,
|
||||||
|
ModelManagementProxy,
|
||||||
|
UtilsProxy,
|
||||||
|
ProgressProxy,
|
||||||
|
VAERegistry,
|
||||||
|
CLIPRegistry,
|
||||||
|
ModelPatcherRegistry,
|
||||||
|
ModelSamplingRegistry,
|
||||||
|
FirstStageModelRegistry,
|
||||||
|
])
|
||||||
|
return services
|
||||||
|
|
||||||
def handle_api_registration(self, api: ProxiedSingleton, rpc: AsyncRPC) -> None:
|
def handle_api_registration(self, api: ProxiedSingleton, rpc: AsyncRPC) -> None:
|
||||||
# Resolve the real name whether it's an instance or the Singleton class itself
|
# Resolve the real name whether it's an instance or the Singleton class itself
|
||||||
@ -525,33 +837,45 @@ class ComfyUIAdapter(IsolationAdapter):
|
|||||||
|
|
||||||
# Fence: isolated children get writable temp inside sandbox
|
# Fence: isolated children get writable temp inside sandbox
|
||||||
if os.environ.get("PYISOLATE_CHILD") == "1":
|
if os.environ.get("PYISOLATE_CHILD") == "1":
|
||||||
_child_temp = os.path.join("/tmp", "comfyui_temp")
|
import tempfile
|
||||||
|
_child_temp = os.path.join(tempfile.gettempdir(), "comfyui_temp")
|
||||||
os.makedirs(_child_temp, exist_ok=True)
|
os.makedirs(_child_temp, exist_ok=True)
|
||||||
folder_paths.temp_directory = _child_temp
|
folder_paths.temp_directory = _child_temp
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
if api_name == "ModelManagementProxy":
|
if api_name == "ModelManagementProxy":
|
||||||
import comfy.model_management
|
if _IMPORT_TORCH:
|
||||||
|
import comfy.model_management
|
||||||
|
|
||||||
instance = api() if isinstance(api, type) else api
|
instance = api() if isinstance(api, type) else api
|
||||||
# Replace module-level functions with proxy methods
|
# Replace module-level functions with proxy methods
|
||||||
for name in dir(instance):
|
for name in dir(instance):
|
||||||
if not name.startswith("_"):
|
if not name.startswith("_"):
|
||||||
setattr(comfy.model_management, name, getattr(instance, name))
|
setattr(comfy.model_management, name, getattr(instance, name))
|
||||||
return
|
return
|
||||||
|
|
||||||
if api_name == "UtilsProxy":
|
if api_name == "UtilsProxy":
|
||||||
|
if not _IMPORT_TORCH:
|
||||||
|
logger.info("][ ISO:UtilsProxy handle_api_registration skipped — sealed worker (no torch)")
|
||||||
|
return
|
||||||
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
|
||||||
# Static Injection of RPC mechanism to ensure Child can access it
|
# Static Injection of RPC mechanism to ensure Child can access it
|
||||||
# independent of instance lifecycle.
|
# independent of instance lifecycle.
|
||||||
api.set_rpc(rpc)
|
api.set_rpc(rpc)
|
||||||
|
|
||||||
# Don't overwrite host hook (infinite recursion)
|
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
|
||||||
|
logger.info("][ ISO:UtilsProxy handle_api_registration PYISOLATE_CHILD=%s", is_child)
|
||||||
|
|
||||||
|
# Progress hook wiring moved to setup_child_event_hooks via event channel
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
if api_name == "PromptServerProxy":
|
if api_name == "PromptServerProxy":
|
||||||
|
if not _IMPORT_TORCH:
|
||||||
|
return
|
||||||
# Defer heavy import to child context
|
# Defer heavy import to child context
|
||||||
import server
|
import server
|
||||||
|
|
||||||
|
|||||||
@ -11,130 +11,90 @@ def is_child_process() -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def initialize_child_process() -> None:
|
def initialize_child_process() -> None:
|
||||||
|
_setup_child_loop_bridge()
|
||||||
|
|
||||||
# Manual RPC injection
|
# Manual RPC injection
|
||||||
try:
|
try:
|
||||||
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
|
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
|
||||||
|
|
||||||
rpc = get_child_rpc_instance()
|
rpc = get_child_rpc_instance()
|
||||||
if rpc:
|
if rpc:
|
||||||
_setup_prompt_server_stub(rpc)
|
_setup_proxy_callers(rpc)
|
||||||
_setup_utils_proxy(rpc)
|
|
||||||
else:
|
else:
|
||||||
logger.warning("Could not get child RPC instance for manual injection")
|
logger.warning("Could not get child RPC instance for manual injection")
|
||||||
_setup_prompt_server_stub()
|
_setup_proxy_callers()
|
||||||
_setup_utils_proxy()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Manual RPC Injection failed: {e}")
|
logger.error(f"Manual RPC Injection failed: {e}")
|
||||||
_setup_prompt_server_stub()
|
_setup_proxy_callers()
|
||||||
_setup_utils_proxy()
|
|
||||||
|
|
||||||
_setup_logging()
|
_setup_logging()
|
||||||
|
|
||||||
|
|
||||||
|
def _setup_child_loop_bridge() -> None:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
main_loop = None
|
||||||
|
try:
|
||||||
|
main_loop = asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
try:
|
||||||
|
main_loop = asyncio.get_event_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if main_loop is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .proxies.base import set_global_loop
|
||||||
|
|
||||||
|
set_global_loop(main_loop)
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _setup_prompt_server_stub(rpc=None) -> None:
|
def _setup_prompt_server_stub(rpc=None) -> None:
|
||||||
try:
|
try:
|
||||||
from .proxies.prompt_server_impl import PromptServerStub
|
from .proxies.prompt_server_impl import PromptServerStub
|
||||||
import sys
|
|
||||||
import types
|
|
||||||
|
|
||||||
# Mock server module
|
|
||||||
if "server" not in sys.modules:
|
|
||||||
mock_server = types.ModuleType("server")
|
|
||||||
sys.modules["server"] = mock_server
|
|
||||||
|
|
||||||
server = sys.modules["server"]
|
|
||||||
|
|
||||||
if not hasattr(server, "PromptServer"):
|
|
||||||
|
|
||||||
class MockPromptServer:
|
|
||||||
pass
|
|
||||||
|
|
||||||
server.PromptServer = MockPromptServer
|
|
||||||
|
|
||||||
stub = PromptServerStub()
|
|
||||||
|
|
||||||
if rpc:
|
if rpc:
|
||||||
PromptServerStub.set_rpc(rpc)
|
PromptServerStub.set_rpc(rpc)
|
||||||
if hasattr(stub, "set_rpc"):
|
elif hasattr(PromptServerStub, "clear_rpc"):
|
||||||
stub.set_rpc(rpc)
|
PromptServerStub.clear_rpc()
|
||||||
|
else:
|
||||||
server.PromptServer.instance = stub
|
PromptServerStub._rpc = None # type: ignore[attr-defined]
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to setup PromptServerStub: {e}")
|
logger.error(f"Failed to setup PromptServerStub: {e}")
|
||||||
|
|
||||||
|
|
||||||
def _setup_utils_proxy(rpc=None) -> None:
|
def _setup_proxy_callers(rpc=None) -> None:
|
||||||
try:
|
try:
|
||||||
import comfy.utils
|
from .proxies.folder_paths_proxy import FolderPathsProxy
|
||||||
import asyncio
|
from .proxies.helper_proxies import HelperProxiesService
|
||||||
|
from .proxies.model_management_proxy import ModelManagementProxy
|
||||||
|
from .proxies.progress_proxy import ProgressProxy
|
||||||
|
from .proxies.prompt_server_impl import PromptServerStub
|
||||||
|
from .proxies.utils_proxy import UtilsProxy
|
||||||
|
|
||||||
# Capture main loop during initialization (safe context)
|
if rpc is None:
|
||||||
main_loop = None
|
FolderPathsProxy.clear_rpc()
|
||||||
try:
|
HelperProxiesService.clear_rpc()
|
||||||
main_loop = asyncio.get_running_loop()
|
ModelManagementProxy.clear_rpc()
|
||||||
except RuntimeError:
|
ProgressProxy.clear_rpc()
|
||||||
try:
|
PromptServerStub.clear_rpc()
|
||||||
main_loop = asyncio.get_event_loop()
|
UtilsProxy.clear_rpc()
|
||||||
except RuntimeError:
|
return
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
FolderPathsProxy.set_rpc(rpc)
|
||||||
from .proxies.base import set_global_loop
|
HelperProxiesService.set_rpc(rpc)
|
||||||
|
ModelManagementProxy.set_rpc(rpc)
|
||||||
if main_loop:
|
ProgressProxy.set_rpc(rpc)
|
||||||
set_global_loop(main_loop)
|
PromptServerStub.set_rpc(rpc)
|
||||||
except ImportError:
|
UtilsProxy.set_rpc(rpc)
|
||||||
pass
|
|
||||||
|
|
||||||
# Sync hook wrapper for progress updates
|
|
||||||
def sync_hook_wrapper(
|
|
||||||
value: int, total: int, preview: None = None, node_id: None = None
|
|
||||||
) -> None:
|
|
||||||
if node_id is None:
|
|
||||||
try:
|
|
||||||
from comfy_execution.utils import get_executing_context
|
|
||||||
|
|
||||||
ctx = get_executing_context()
|
|
||||||
if ctx:
|
|
||||||
node_id = ctx.node_id
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Bypass blocked event loop by direct outbox injection
|
|
||||||
if rpc:
|
|
||||||
try:
|
|
||||||
# Use captured main loop if available (for threaded execution), or current loop
|
|
||||||
loop = main_loop
|
|
||||||
if loop is None:
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
|
|
||||||
rpc.outbox.put(
|
|
||||||
{
|
|
||||||
"kind": "call",
|
|
||||||
"object_id": "UtilsProxy",
|
|
||||||
"parent_call_id": None, # We are root here usually
|
|
||||||
"calling_loop": loop,
|
|
||||||
"future": loop.create_future(), # Dummy future
|
|
||||||
"method": "progress_bar_hook",
|
|
||||||
"args": (value, total, preview, node_id),
|
|
||||||
"kwargs": {},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logging.getLogger(__name__).error(f"Manual Inject Failed: {e}")
|
|
||||||
else:
|
|
||||||
logging.getLogger(__name__).warning(
|
|
||||||
"No RPC instance available for progress update"
|
|
||||||
)
|
|
||||||
|
|
||||||
comfy.utils.PROGRESS_BAR_HOOK = sync_hook_wrapper
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to setup UtilsProxy hook: {e}")
|
logger.error(f"Failed to setup child singleton proxy callers: {e}")
|
||||||
|
|
||||||
|
|
||||||
def _setup_logging() -> None:
|
def _setup_logging() -> None:
|
||||||
|
|||||||
16
comfy/isolation/custom_node_serializers.py
Normal file
16
comfy/isolation/custom_node_serializers.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
"""Compatibility shim for the indexed serializer path."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def register_custom_node_serializers(_registry: Any) -> None:
|
||||||
|
"""Legacy no-op shim.
|
||||||
|
|
||||||
|
Serializer registration now lives directly in the active isolation adapter.
|
||||||
|
This module remains importable because the isolation index still references it.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
__all__ = ["register_custom_node_serializers"]
|
||||||
@ -8,14 +8,13 @@ import sys
|
|||||||
import types
|
import types
|
||||||
import platform
|
import platform
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Dict, List, Tuple
|
from typing import Any, Callable, Dict, List, Tuple
|
||||||
|
|
||||||
import pyisolate
|
import pyisolate
|
||||||
from pyisolate import ExtensionManager, ExtensionManagerConfig
|
from pyisolate import ExtensionManager, ExtensionManagerConfig
|
||||||
from packaging.requirements import InvalidRequirement, Requirement
|
from packaging.requirements import InvalidRequirement, Requirement
|
||||||
from packaging.utils import canonicalize_name
|
from packaging.utils import canonicalize_name
|
||||||
|
|
||||||
from .extension_wrapper import ComfyNodeExtension
|
|
||||||
from .manifest_loader import is_cache_valid, load_from_cache, save_to_cache
|
from .manifest_loader import is_cache_valid, load_from_cache, save_to_cache
|
||||||
from .host_policy import load_host_policy
|
from .host_policy import load_host_policy
|
||||||
|
|
||||||
@ -42,7 +41,11 @@ def _register_web_directory(extension_name: str, node_dir: Path) -> None:
|
|||||||
web_dir_path = str(node_dir / web_dir_name)
|
web_dir_path = str(node_dir / web_dir_name)
|
||||||
if os.path.isdir(web_dir_path):
|
if os.path.isdir(web_dir_path):
|
||||||
nodes.EXTENSION_WEB_DIRS[extension_name] = web_dir_path
|
nodes.EXTENSION_WEB_DIRS[extension_name] = web_dir_path
|
||||||
logger.debug("][ Registered web dir for isolated %s: %s", extension_name, web_dir_path)
|
logger.debug(
|
||||||
|
"][ Registered web dir for isolated %s: %s",
|
||||||
|
extension_name,
|
||||||
|
web_dir_path,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
@ -62,15 +65,26 @@ def _register_web_directory(extension_name: str, node_dir: Path) -> None:
|
|||||||
web_dir_path = str((node_dir / value).resolve())
|
web_dir_path = str((node_dir / value).resolve())
|
||||||
if os.path.isdir(web_dir_path):
|
if os.path.isdir(web_dir_path):
|
||||||
nodes.EXTENSION_WEB_DIRS[extension_name] = web_dir_path
|
nodes.EXTENSION_WEB_DIRS[extension_name] = web_dir_path
|
||||||
logger.debug("][ Registered web dir for isolated %s: %s", extension_name, web_dir_path)
|
logger.debug(
|
||||||
|
"][ Registered web dir for isolated %s: %s",
|
||||||
|
extension_name,
|
||||||
|
web_dir_path,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
async def _stop_extension_safe(
|
def _get_extension_type(execution_model: str) -> type[Any]:
|
||||||
extension: ComfyNodeExtension, extension_name: str
|
if execution_model == "sealed_worker":
|
||||||
) -> None:
|
return pyisolate.SealedNodeExtension
|
||||||
|
|
||||||
|
from .extension_wrapper import ComfyNodeExtension
|
||||||
|
|
||||||
|
return ComfyNodeExtension
|
||||||
|
|
||||||
|
|
||||||
|
async def _stop_extension_safe(extension: Any, extension_name: str) -> None:
|
||||||
try:
|
try:
|
||||||
stop_result = extension.stop()
|
stop_result = extension.stop()
|
||||||
if inspect.isawaitable(stop_result):
|
if inspect.isawaitable(stop_result):
|
||||||
@ -126,12 +140,18 @@ def _parse_cuda_wheels_config(
|
|||||||
if raw_config is None:
|
if raw_config is None:
|
||||||
return None
|
return None
|
||||||
if not isinstance(raw_config, dict):
|
if not isinstance(raw_config, dict):
|
||||||
raise ExtensionLoadError(
|
raise ExtensionLoadError("[tool.comfy.isolation.cuda_wheels] must be a table")
|
||||||
"[tool.comfy.isolation.cuda_wheels] must be a table"
|
|
||||||
)
|
|
||||||
|
|
||||||
index_url = raw_config.get("index_url")
|
index_url = raw_config.get("index_url")
|
||||||
if not isinstance(index_url, str) or not index_url.strip():
|
index_urls = raw_config.get("index_urls")
|
||||||
|
if index_urls is not None:
|
||||||
|
if not isinstance(index_urls, list) or not all(
|
||||||
|
isinstance(u, str) and u.strip() for u in index_urls
|
||||||
|
):
|
||||||
|
raise ExtensionLoadError(
|
||||||
|
"[tool.comfy.isolation.cuda_wheels.index_urls] must be a list of non-empty strings"
|
||||||
|
)
|
||||||
|
elif not isinstance(index_url, str) or not index_url.strip():
|
||||||
raise ExtensionLoadError(
|
raise ExtensionLoadError(
|
||||||
"[tool.comfy.isolation.cuda_wheels.index_url] must be a non-empty string"
|
"[tool.comfy.isolation.cuda_wheels.index_url] must be a non-empty string"
|
||||||
)
|
)
|
||||||
@ -187,11 +207,15 @@ def _parse_cuda_wheels_config(
|
|||||||
)
|
)
|
||||||
normalized_package_map[canonical_dependency_name] = index_package_name.strip()
|
normalized_package_map[canonical_dependency_name] = index_package_name.strip()
|
||||||
|
|
||||||
return {
|
result: dict = {
|
||||||
"index_url": index_url.rstrip("/") + "/",
|
|
||||||
"packages": normalized_packages,
|
"packages": normalized_packages,
|
||||||
"package_map": normalized_package_map,
|
"package_map": normalized_package_map,
|
||||||
}
|
}
|
||||||
|
if index_urls is not None:
|
||||||
|
result["index_urls"] = [u.rstrip("/") + "/" for u in index_urls]
|
||||||
|
else:
|
||||||
|
result["index_url"] = index_url.rstrip("/") + "/"
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def get_enforcement_policy() -> Dict[str, bool]:
|
def get_enforcement_policy() -> Dict[str, bool]:
|
||||||
@ -228,7 +252,7 @@ async def load_isolated_node(
|
|||||||
node_dir: Path,
|
node_dir: Path,
|
||||||
manifest_path: Path,
|
manifest_path: Path,
|
||||||
logger: logging.Logger,
|
logger: logging.Logger,
|
||||||
build_stub_class: Callable[[str, Dict[str, object], ComfyNodeExtension], type],
|
build_stub_class: Callable[[str, Dict[str, object], Any], type],
|
||||||
venv_root: Path,
|
venv_root: Path,
|
||||||
extension_managers: List[ExtensionManager],
|
extension_managers: List[ExtensionManager],
|
||||||
) -> List[Tuple[str, str, type]]:
|
) -> List[Tuple[str, str, type]]:
|
||||||
@ -243,6 +267,31 @@ async def load_isolated_node(
|
|||||||
tool_config = manifest_data.get("tool", {}).get("comfy", {}).get("isolation", {})
|
tool_config = manifest_data.get("tool", {}).get("comfy", {}).get("isolation", {})
|
||||||
can_isolate = tool_config.get("can_isolate", False)
|
can_isolate = tool_config.get("can_isolate", False)
|
||||||
share_torch = tool_config.get("share_torch", False)
|
share_torch = tool_config.get("share_torch", False)
|
||||||
|
package_manager = tool_config.get("package_manager", "uv")
|
||||||
|
is_conda = package_manager == "conda"
|
||||||
|
execution_model = tool_config.get("execution_model")
|
||||||
|
if execution_model is None:
|
||||||
|
execution_model = "sealed_worker" if is_conda else "host-coupled"
|
||||||
|
|
||||||
|
if "sealed_host_ro_paths" in tool_config:
|
||||||
|
raise ValueError(
|
||||||
|
"Manifest field 'sealed_host_ro_paths' is not allowed. "
|
||||||
|
"Configure [tool.comfy.host].sealed_worker_ro_import_paths in host policy."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Conda-specific manifest fields
|
||||||
|
conda_channels: list[str] = (
|
||||||
|
tool_config.get("conda_channels", []) if is_conda else []
|
||||||
|
)
|
||||||
|
conda_dependencies: list[str] = (
|
||||||
|
tool_config.get("conda_dependencies", []) if is_conda else []
|
||||||
|
)
|
||||||
|
conda_platforms: list[str] = (
|
||||||
|
tool_config.get("conda_platforms", []) if is_conda else []
|
||||||
|
)
|
||||||
|
conda_python: str = (
|
||||||
|
tool_config.get("conda_python", "*") if is_conda else "*"
|
||||||
|
)
|
||||||
|
|
||||||
# Parse [project] dependencies
|
# Parse [project] dependencies
|
||||||
project_config = manifest_data.get("project", {})
|
project_config = manifest_data.get("project", {})
|
||||||
@ -260,8 +309,6 @@ async def load_isolated_node(
|
|||||||
if not isolated:
|
if not isolated:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
logger.info(f"][ Loading isolated node: {extension_name}")
|
|
||||||
|
|
||||||
import folder_paths
|
import folder_paths
|
||||||
|
|
||||||
base_paths = [Path(folder_paths.base_path), node_dir]
|
base_paths = [Path(folder_paths.base_path), node_dir]
|
||||||
@ -272,8 +319,9 @@ async def load_isolated_node(
|
|||||||
cuda_wheels = _parse_cuda_wheels_config(tool_config, dependencies)
|
cuda_wheels = _parse_cuda_wheels_config(tool_config, dependencies)
|
||||||
|
|
||||||
manager_config = ExtensionManagerConfig(venv_root_path=str(venv_root))
|
manager_config = ExtensionManagerConfig(venv_root_path=str(venv_root))
|
||||||
|
extension_type = _get_extension_type(execution_model)
|
||||||
manager: ExtensionManager = pyisolate.ExtensionManager(
|
manager: ExtensionManager = pyisolate.ExtensionManager(
|
||||||
ComfyNodeExtension, manager_config
|
extension_type, manager_config
|
||||||
)
|
)
|
||||||
extension_managers.append(manager)
|
extension_managers.append(manager)
|
||||||
|
|
||||||
@ -281,15 +329,21 @@ async def load_isolated_node(
|
|||||||
|
|
||||||
sandbox_config = {}
|
sandbox_config = {}
|
||||||
is_linux = platform.system() == "Linux"
|
is_linux = platform.system() == "Linux"
|
||||||
|
|
||||||
|
if is_conda:
|
||||||
|
share_torch = False
|
||||||
|
share_cuda_ipc = False
|
||||||
|
else:
|
||||||
|
share_cuda_ipc = share_torch and is_linux
|
||||||
|
|
||||||
if is_linux and isolated:
|
if is_linux and isolated:
|
||||||
sandbox_config = {
|
sandbox_config = {
|
||||||
"network": host_policy["allow_network"],
|
"network": host_policy["allow_network"],
|
||||||
"writable_paths": host_policy["writable_paths"],
|
"writable_paths": host_policy["writable_paths"],
|
||||||
"readonly_paths": host_policy["readonly_paths"],
|
"readonly_paths": host_policy["readonly_paths"],
|
||||||
}
|
}
|
||||||
share_cuda_ipc = share_torch and is_linux
|
|
||||||
|
|
||||||
extension_config = {
|
extension_config: dict = {
|
||||||
"name": extension_name,
|
"name": extension_name,
|
||||||
"module_path": str(node_dir),
|
"module_path": str(node_dir),
|
||||||
"isolated": True,
|
"isolated": True,
|
||||||
@ -299,17 +353,60 @@ async def load_isolated_node(
|
|||||||
"sandbox_mode": host_policy["sandbox_mode"],
|
"sandbox_mode": host_policy["sandbox_mode"],
|
||||||
"sandbox": sandbox_config,
|
"sandbox": sandbox_config,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_is_sealed = execution_model == "sealed_worker"
|
||||||
|
_is_sandboxed = host_policy["sandbox_mode"] != "disabled" and is_linux
|
||||||
|
logger.info(
|
||||||
|
"][ Loading isolated node: %s (torch_share [%s], sealed [%s], sandboxed [%s])",
|
||||||
|
extension_name,
|
||||||
|
"x" if share_torch else " ",
|
||||||
|
"x" if _is_sealed else " ",
|
||||||
|
"x" if _is_sandboxed else " ",
|
||||||
|
)
|
||||||
|
|
||||||
if cuda_wheels is not None:
|
if cuda_wheels is not None:
|
||||||
extension_config["cuda_wheels"] = cuda_wheels
|
extension_config["cuda_wheels"] = cuda_wheels
|
||||||
|
|
||||||
|
# Conda-specific keys
|
||||||
|
if is_conda:
|
||||||
|
extension_config["package_manager"] = "conda"
|
||||||
|
extension_config["conda_channels"] = conda_channels
|
||||||
|
extension_config["conda_dependencies"] = conda_dependencies
|
||||||
|
extension_config["conda_python"] = conda_python
|
||||||
|
find_links = tool_config.get("find_links", [])
|
||||||
|
if find_links:
|
||||||
|
extension_config["find_links"] = find_links
|
||||||
|
if conda_platforms:
|
||||||
|
extension_config["conda_platforms"] = conda_platforms
|
||||||
|
|
||||||
|
if execution_model != "host-coupled":
|
||||||
|
extension_config["execution_model"] = execution_model
|
||||||
|
if execution_model == "sealed_worker":
|
||||||
|
policy_ro_paths = host_policy.get("sealed_worker_ro_import_paths", [])
|
||||||
|
if isinstance(policy_ro_paths, list) and policy_ro_paths:
|
||||||
|
extension_config["sealed_host_ro_paths"] = list(policy_ro_paths)
|
||||||
|
# Sealed workers keep the host RPC service inventory even when the
|
||||||
|
# child resolves no API classes locally.
|
||||||
|
|
||||||
extension = manager.load_extension(extension_config)
|
extension = manager.load_extension(extension_config)
|
||||||
register_dummy_module(extension_name, node_dir)
|
register_dummy_module(extension_name, node_dir)
|
||||||
|
|
||||||
|
# Register host-side event handlers via adapter
|
||||||
|
from .adapter import ComfyUIAdapter
|
||||||
|
ComfyUIAdapter.register_host_event_handlers(extension)
|
||||||
|
|
||||||
# Register web directory on the host — only when sandbox is disabled.
|
# Register web directory on the host — only when sandbox is disabled.
|
||||||
# In sandbox mode, serving untrusted JS to the browser is not safe.
|
# In sandbox mode, serving untrusted JS to the browser is not safe.
|
||||||
if host_policy["sandbox_mode"] == "disabled":
|
if host_policy["sandbox_mode"] == "disabled":
|
||||||
_register_web_directory(extension_name, node_dir)
|
_register_web_directory(extension_name, node_dir)
|
||||||
|
|
||||||
|
# Register for proxied web serving — the child's web dir may have
|
||||||
|
# content that doesn't exist on the host (e.g., pip-installed viewer
|
||||||
|
# bundles). The WebDirectoryCache will lazily fetch via RPC.
|
||||||
|
from .proxies.web_directory_proxy import WebDirectoryProxy, get_web_directory_cache
|
||||||
|
cache = get_web_directory_cache()
|
||||||
|
cache.register_proxy(extension_name, WebDirectoryProxy())
|
||||||
|
|
||||||
# Try cache first (lazy spawn)
|
# Try cache first (lazy spawn)
|
||||||
if is_cache_valid(node_dir, manifest_path, venv_root):
|
if is_cache_valid(node_dir, manifest_path, venv_root):
|
||||||
cached_data = load_from_cache(node_dir, venv_root)
|
cached_data = load_from_cache(node_dir, venv_root)
|
||||||
@ -379,6 +476,10 @@ async def load_isolated_node(
|
|||||||
save_to_cache(node_dir, venv_root, cache_data, manifest_path)
|
save_to_cache(node_dir, venv_root, cache_data, manifest_path)
|
||||||
logger.debug(f"][ {extension_name} metadata cached")
|
logger.debug(f"][ {extension_name} metadata cached")
|
||||||
|
|
||||||
|
# Re-check web directory AFTER child has populated it
|
||||||
|
if host_policy["sandbox_mode"] == "disabled":
|
||||||
|
_register_web_directory(extension_name, node_dir)
|
||||||
|
|
||||||
# EJECT: Kill process after getting metadata (will respawn on first execution)
|
# EJECT: Kill process after getting metadata (will respawn on first execution)
|
||||||
await _stop_extension_safe(extension, extension_name)
|
await _stop_extension_safe(extension, extension_name)
|
||||||
|
|
||||||
|
|||||||
@ -37,6 +37,88 @@ _PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _run_prestartup_web_copy(module: Any, module_dir: str, web_dir_path: str) -> None:
|
||||||
|
"""Run the web asset copy step that prestartup_script.py used to do.
|
||||||
|
|
||||||
|
If the module's web/ directory is empty and the module had a
|
||||||
|
prestartup_script.py that copied assets from pip packages, this
|
||||||
|
function replicates that work inside the child process.
|
||||||
|
|
||||||
|
Generic pattern: reads _PRESTARTUP_WEB_COPY from the module if
|
||||||
|
defined, otherwise falls back to detecting common asset packages.
|
||||||
|
"""
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
# Already populated — nothing to do
|
||||||
|
if os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path)):
|
||||||
|
return
|
||||||
|
|
||||||
|
os.makedirs(web_dir_path, exist_ok=True)
|
||||||
|
|
||||||
|
# Try module-defined copy spec first (generic hook for any node pack)
|
||||||
|
copy_spec = getattr(module, "_PRESTARTUP_WEB_COPY", None)
|
||||||
|
if copy_spec is not None and callable(copy_spec):
|
||||||
|
try:
|
||||||
|
copy_spec(web_dir_path)
|
||||||
|
logger.info(
|
||||||
|
"%s Ran _PRESTARTUP_WEB_COPY for %s", LOG_PREFIX, module_dir
|
||||||
|
)
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
"%s _PRESTARTUP_WEB_COPY failed for %s: %s",
|
||||||
|
LOG_PREFIX, module_dir, e,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fallback: detect comfy_3d_viewers and run copy_viewer()
|
||||||
|
try:
|
||||||
|
from comfy_3d_viewers import copy_viewer, VIEWER_FILES
|
||||||
|
viewers = list(VIEWER_FILES.keys())
|
||||||
|
for viewer in viewers:
|
||||||
|
try:
|
||||||
|
copy_viewer(viewer, web_dir_path)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if any(os.scandir(web_dir_path)):
|
||||||
|
logger.info(
|
||||||
|
"%s Copied %d viewer types from comfy_3d_viewers to %s",
|
||||||
|
LOG_PREFIX, len(viewers), web_dir_path,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Fallback: detect comfy_dynamic_widgets
|
||||||
|
try:
|
||||||
|
from comfy_dynamic_widgets import get_js_path
|
||||||
|
src = os.path.realpath(get_js_path())
|
||||||
|
if os.path.exists(src):
|
||||||
|
dst_dir = os.path.join(web_dir_path, "js")
|
||||||
|
os.makedirs(dst_dir, exist_ok=True)
|
||||||
|
dst = os.path.join(dst_dir, "dynamic_widgets.js")
|
||||||
|
shutil.copy2(src, dst)
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _read_extension_name(module_dir: str) -> str:
|
||||||
|
"""Read extension name from pyproject.toml, falling back to directory name."""
|
||||||
|
pyproject = os.path.join(module_dir, "pyproject.toml")
|
||||||
|
if os.path.exists(pyproject):
|
||||||
|
try:
|
||||||
|
import tomllib
|
||||||
|
except ImportError:
|
||||||
|
import tomli as tomllib # type: ignore[no-redef]
|
||||||
|
try:
|
||||||
|
with open(pyproject, "rb") as f:
|
||||||
|
data = tomllib.load(f)
|
||||||
|
name = data.get("project", {}).get("name")
|
||||||
|
if name:
|
||||||
|
return name
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return os.path.basename(module_dir)
|
||||||
|
|
||||||
|
|
||||||
def _flush_tensor_transport_state(marker: str) -> int:
|
def _flush_tensor_transport_state(marker: str) -> int:
|
||||||
try:
|
try:
|
||||||
from pyisolate import flush_tensor_keeper # type: ignore[attr-defined]
|
from pyisolate import flush_tensor_keeper # type: ignore[attr-defined]
|
||||||
@ -130,6 +212,20 @@ class ComfyNodeExtension(ExtensionBase):
|
|||||||
self.node_classes = getattr(module, "NODE_CLASS_MAPPINGS", {}) or {}
|
self.node_classes = getattr(module, "NODE_CLASS_MAPPINGS", {}) or {}
|
||||||
self.display_names = getattr(module, "NODE_DISPLAY_NAME_MAPPINGS", {}) or {}
|
self.display_names = getattr(module, "NODE_DISPLAY_NAME_MAPPINGS", {}) or {}
|
||||||
|
|
||||||
|
# Register web directory with WebDirectoryProxy (child-side)
|
||||||
|
web_dir_attr = getattr(module, "WEB_DIRECTORY", None)
|
||||||
|
if web_dir_attr is not None:
|
||||||
|
module_dir = os.path.dirname(os.path.abspath(module.__file__))
|
||||||
|
web_dir_path = os.path.abspath(os.path.join(module_dir, web_dir_attr))
|
||||||
|
ext_name = _read_extension_name(module_dir)
|
||||||
|
|
||||||
|
# If web dir is empty, run the copy step that prestartup_script.py did
|
||||||
|
_run_prestartup_web_copy(module, module_dir, web_dir_path)
|
||||||
|
|
||||||
|
if os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path)):
|
||||||
|
from comfy.isolation.proxies.web_directory_proxy import WebDirectoryProxy
|
||||||
|
WebDirectoryProxy.register_web_dir(ext_name, web_dir_path)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from comfy_api.latest import ComfyExtension
|
from comfy_api.latest import ComfyExtension
|
||||||
|
|
||||||
@ -365,6 +461,12 @@ class ComfyNodeExtension(ExtensionBase):
|
|||||||
for key, value in hidden_found.items():
|
for key, value in hidden_found.items():
|
||||||
setattr(node_cls.hidden, key.value.lower(), value)
|
setattr(node_cls.hidden, key.value.lower(), value)
|
||||||
|
|
||||||
|
# INPUT_IS_LIST: ComfyUI's executor passes all inputs as lists when this
|
||||||
|
# flag is set. The isolation RPC delivers unwrapped values, so we must
|
||||||
|
# wrap each input in a single-element list to match the contract.
|
||||||
|
if getattr(node_cls, "INPUT_IS_LIST", False):
|
||||||
|
resolved_inputs = {k: [v] for k, v in resolved_inputs.items()}
|
||||||
|
|
||||||
function_name = getattr(node_cls, "FUNCTION", "execute")
|
function_name = getattr(node_cls, "FUNCTION", "execute")
|
||||||
if not hasattr(instance, function_name):
|
if not hasattr(instance, function_name):
|
||||||
raise AttributeError(f"Node {node_name} missing callable '{function_name}'")
|
raise AttributeError(f"Node {node_name} missing callable '{function_name}'")
|
||||||
@ -402,7 +504,7 @@ class ComfyNodeExtension(ExtensionBase):
|
|||||||
"args": self._wrap_unpicklable_objects(result.args),
|
"args": self._wrap_unpicklable_objects(result.args),
|
||||||
}
|
}
|
||||||
if result.ui is not None:
|
if result.ui is not None:
|
||||||
node_output_dict["ui"] = result.ui
|
node_output_dict["ui"] = self._wrap_unpicklable_objects(result.ui)
|
||||||
if getattr(result, "expand", None) is not None:
|
if getattr(result, "expand", None) is not None:
|
||||||
node_output_dict["expand"] = result.expand
|
node_output_dict["expand"] = result.expand
|
||||||
if getattr(result, "block_execution", None) is not None:
|
if getattr(result, "block_execution", None) is not None:
|
||||||
@ -454,6 +556,85 @@ class ComfyNodeExtension(ExtensionBase):
|
|||||||
|
|
||||||
return self.remote_objects[object_id]
|
return self.remote_objects[object_id]
|
||||||
|
|
||||||
|
def _store_remote_object_handle(self, obj: Any) -> RemoteObjectHandle:
|
||||||
|
object_id = str(uuid.uuid4())
|
||||||
|
self.remote_objects[object_id] = obj
|
||||||
|
return RemoteObjectHandle(object_id, type(obj).__name__)
|
||||||
|
|
||||||
|
async def call_remote_object_method(
|
||||||
|
self,
|
||||||
|
object_id: str,
|
||||||
|
method_name: str,
|
||||||
|
*args: Any,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any:
|
||||||
|
"""Invoke a method or attribute-backed accessor on a child-owned object."""
|
||||||
|
obj = await self.get_remote_object(object_id)
|
||||||
|
|
||||||
|
if method_name == "get_patcher_attr":
|
||||||
|
return getattr(obj, args[0])
|
||||||
|
if method_name == "get_model_options":
|
||||||
|
return getattr(obj, "model_options")
|
||||||
|
if method_name == "set_model_options":
|
||||||
|
setattr(obj, "model_options", args[0])
|
||||||
|
return None
|
||||||
|
if method_name == "get_object_patches":
|
||||||
|
return getattr(obj, "object_patches")
|
||||||
|
if method_name == "get_patches":
|
||||||
|
return getattr(obj, "patches")
|
||||||
|
if method_name == "get_wrappers":
|
||||||
|
return getattr(obj, "wrappers")
|
||||||
|
if method_name == "get_callbacks":
|
||||||
|
return getattr(obj, "callbacks")
|
||||||
|
if method_name == "get_load_device":
|
||||||
|
return getattr(obj, "load_device")
|
||||||
|
if method_name == "get_offload_device":
|
||||||
|
return getattr(obj, "offload_device")
|
||||||
|
if method_name == "get_hook_mode":
|
||||||
|
return getattr(obj, "hook_mode")
|
||||||
|
if method_name == "get_parent":
|
||||||
|
parent = getattr(obj, "parent", None)
|
||||||
|
if parent is None:
|
||||||
|
return None
|
||||||
|
return self._store_remote_object_handle(parent)
|
||||||
|
if method_name == "get_inner_model_attr":
|
||||||
|
attr_name = args[0]
|
||||||
|
if hasattr(obj.model, attr_name):
|
||||||
|
return getattr(obj.model, attr_name)
|
||||||
|
if hasattr(obj, attr_name):
|
||||||
|
return getattr(obj, attr_name)
|
||||||
|
return None
|
||||||
|
if method_name == "inner_model_apply_model":
|
||||||
|
return obj.model.apply_model(*args[0], **args[1])
|
||||||
|
if method_name == "inner_model_extra_conds_shapes":
|
||||||
|
return obj.model.extra_conds_shapes(*args[0], **args[1])
|
||||||
|
if method_name == "inner_model_extra_conds":
|
||||||
|
return obj.model.extra_conds(*args[0], **args[1])
|
||||||
|
if method_name == "inner_model_memory_required":
|
||||||
|
return obj.model.memory_required(*args[0], **args[1])
|
||||||
|
if method_name == "process_latent_in":
|
||||||
|
return obj.model.process_latent_in(*args[0], **args[1])
|
||||||
|
if method_name == "process_latent_out":
|
||||||
|
return obj.model.process_latent_out(*args[0], **args[1])
|
||||||
|
if method_name == "scale_latent_inpaint":
|
||||||
|
return obj.model.scale_latent_inpaint(*args[0], **args[1])
|
||||||
|
if method_name.startswith("get_"):
|
||||||
|
attr_name = method_name[4:]
|
||||||
|
if hasattr(obj, attr_name):
|
||||||
|
return getattr(obj, attr_name)
|
||||||
|
|
||||||
|
target = getattr(obj, method_name)
|
||||||
|
if callable(target):
|
||||||
|
result = target(*args, **kwargs)
|
||||||
|
if inspect.isawaitable(result):
|
||||||
|
result = await result
|
||||||
|
if type(result).__name__ == "ModelPatcher":
|
||||||
|
return self._store_remote_object_handle(result)
|
||||||
|
return result
|
||||||
|
if args or kwargs:
|
||||||
|
raise TypeError(f"{method_name} is not callable on remote object {object_id}")
|
||||||
|
return target
|
||||||
|
|
||||||
def _wrap_unpicklable_objects(self, data: Any) -> Any:
|
def _wrap_unpicklable_objects(self, data: Any) -> Any:
|
||||||
if isinstance(data, (str, int, float, bool, type(None))):
|
if isinstance(data, (str, int, float, bool, type(None))):
|
||||||
return data
|
return data
|
||||||
@ -514,9 +695,7 @@ class ComfyNodeExtension(ExtensionBase):
|
|||||||
if serializer:
|
if serializer:
|
||||||
return serializer(data)
|
return serializer(data)
|
||||||
|
|
||||||
object_id = str(uuid.uuid4())
|
return self._store_remote_object_handle(data)
|
||||||
self.remote_objects[object_id] = data
|
|
||||||
return RemoteObjectHandle(object_id, type(data).__name__)
|
|
||||||
|
|
||||||
def _resolve_remote_objects(self, data: Any) -> Any:
|
def _resolve_remote_objects(self, data: Any) -> Any:
|
||||||
if isinstance(data, RemoteObjectHandle):
|
if isinstance(data, RemoteObjectHandle):
|
||||||
|
|||||||
@ -12,15 +12,19 @@ def initialize_host_process() -> None:
|
|||||||
root.addHandler(logging.NullHandler())
|
root.addHandler(logging.NullHandler())
|
||||||
|
|
||||||
from .proxies.folder_paths_proxy import FolderPathsProxy
|
from .proxies.folder_paths_proxy import FolderPathsProxy
|
||||||
|
from .proxies.helper_proxies import HelperProxiesService
|
||||||
from .proxies.model_management_proxy import ModelManagementProxy
|
from .proxies.model_management_proxy import ModelManagementProxy
|
||||||
from .proxies.progress_proxy import ProgressProxy
|
from .proxies.progress_proxy import ProgressProxy
|
||||||
from .proxies.prompt_server_impl import PromptServerService
|
from .proxies.prompt_server_impl import PromptServerService
|
||||||
from .proxies.utils_proxy import UtilsProxy
|
from .proxies.utils_proxy import UtilsProxy
|
||||||
|
from .proxies.web_directory_proxy import WebDirectoryProxy
|
||||||
from .vae_proxy import VAERegistry
|
from .vae_proxy import VAERegistry
|
||||||
|
|
||||||
FolderPathsProxy()
|
FolderPathsProxy()
|
||||||
|
HelperProxiesService()
|
||||||
ModelManagementProxy()
|
ModelManagementProxy()
|
||||||
ProgressProxy()
|
ProgressProxy()
|
||||||
PromptServerService()
|
PromptServerService()
|
||||||
UtilsProxy()
|
UtilsProxy()
|
||||||
|
WebDirectoryProxy()
|
||||||
VAERegistry()
|
VAERegistry()
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from pathlib import PurePosixPath
|
||||||
from typing import Dict, List, TypedDict
|
from typing import Dict, List, TypedDict
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -15,6 +16,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
HOST_POLICY_PATH_ENV = "COMFY_HOST_POLICY_PATH"
|
HOST_POLICY_PATH_ENV = "COMFY_HOST_POLICY_PATH"
|
||||||
VALID_SANDBOX_MODES = frozenset({"required", "disabled"})
|
VALID_SANDBOX_MODES = frozenset({"required", "disabled"})
|
||||||
|
FORBIDDEN_WRITABLE_PATHS = frozenset({"/tmp"})
|
||||||
|
|
||||||
|
|
||||||
class HostSecurityPolicy(TypedDict):
|
class HostSecurityPolicy(TypedDict):
|
||||||
@ -22,14 +24,16 @@ class HostSecurityPolicy(TypedDict):
|
|||||||
allow_network: bool
|
allow_network: bool
|
||||||
writable_paths: List[str]
|
writable_paths: List[str]
|
||||||
readonly_paths: List[str]
|
readonly_paths: List[str]
|
||||||
|
sealed_worker_ro_import_paths: List[str]
|
||||||
whitelist: Dict[str, str]
|
whitelist: Dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_POLICY: HostSecurityPolicy = {
|
DEFAULT_POLICY: HostSecurityPolicy = {
|
||||||
"sandbox_mode": "required",
|
"sandbox_mode": "required",
|
||||||
"allow_network": False,
|
"allow_network": False,
|
||||||
"writable_paths": ["/dev/shm", "/tmp"],
|
"writable_paths": ["/dev/shm"],
|
||||||
"readonly_paths": [],
|
"readonly_paths": [],
|
||||||
|
"sealed_worker_ro_import_paths": [],
|
||||||
"whitelist": {},
|
"whitelist": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -40,10 +44,68 @@ def _default_policy() -> HostSecurityPolicy:
|
|||||||
"allow_network": DEFAULT_POLICY["allow_network"],
|
"allow_network": DEFAULT_POLICY["allow_network"],
|
||||||
"writable_paths": list(DEFAULT_POLICY["writable_paths"]),
|
"writable_paths": list(DEFAULT_POLICY["writable_paths"]),
|
||||||
"readonly_paths": list(DEFAULT_POLICY["readonly_paths"]),
|
"readonly_paths": list(DEFAULT_POLICY["readonly_paths"]),
|
||||||
|
"sealed_worker_ro_import_paths": list(DEFAULT_POLICY["sealed_worker_ro_import_paths"]),
|
||||||
"whitelist": dict(DEFAULT_POLICY["whitelist"]),
|
"whitelist": dict(DEFAULT_POLICY["whitelist"]),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_writable_paths(paths: list[object]) -> list[str]:
|
||||||
|
normalized_paths: list[str] = []
|
||||||
|
for raw_path in paths:
|
||||||
|
# Host-policy paths are contract-style POSIX paths; keep representation
|
||||||
|
# stable across Windows/Linux so tests and config behavior stay consistent.
|
||||||
|
normalized_path = str(PurePosixPath(str(raw_path).replace("\\", "/")))
|
||||||
|
if normalized_path in FORBIDDEN_WRITABLE_PATHS:
|
||||||
|
continue
|
||||||
|
normalized_paths.append(normalized_path)
|
||||||
|
return normalized_paths
|
||||||
|
|
||||||
|
|
||||||
|
def _load_whitelist_file(file_path: Path, config_path: Path) -> Dict[str, str]:
|
||||||
|
if not file_path.is_absolute():
|
||||||
|
file_path = config_path.parent / file_path
|
||||||
|
if not file_path.exists():
|
||||||
|
logger.warning("whitelist_file %s not found, skipping.", file_path)
|
||||||
|
return {}
|
||||||
|
entries: Dict[str, str] = {}
|
||||||
|
for line in file_path.read_text().splitlines():
|
||||||
|
line = line.strip()
|
||||||
|
if not line or line.startswith("#"):
|
||||||
|
continue
|
||||||
|
entries[line] = "*"
|
||||||
|
logger.debug("Loaded %d whitelist entries from %s", len(entries), file_path)
|
||||||
|
return entries
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_sealed_worker_ro_import_paths(raw_paths: object) -> list[str]:
|
||||||
|
if not isinstance(raw_paths, list):
|
||||||
|
raise ValueError(
|
||||||
|
"tool.comfy.host.sealed_worker_ro_import_paths must be a list of absolute paths."
|
||||||
|
)
|
||||||
|
|
||||||
|
normalized_paths: list[str] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
for raw_path in raw_paths:
|
||||||
|
if not isinstance(raw_path, str) or not raw_path.strip():
|
||||||
|
raise ValueError(
|
||||||
|
"tool.comfy.host.sealed_worker_ro_import_paths entries must be non-empty strings."
|
||||||
|
)
|
||||||
|
normalized_path = str(PurePosixPath(raw_path.replace("\\", "/")))
|
||||||
|
# Accept both POSIX absolute paths (/home/...) and Windows drive-letter paths (D:/...)
|
||||||
|
is_absolute = normalized_path.startswith("/") or (
|
||||||
|
len(normalized_path) >= 3 and normalized_path[1] == ":" and normalized_path[2] == "/"
|
||||||
|
)
|
||||||
|
if not is_absolute:
|
||||||
|
raise ValueError(
|
||||||
|
"tool.comfy.host.sealed_worker_ro_import_paths entries must be absolute paths."
|
||||||
|
)
|
||||||
|
if normalized_path not in seen:
|
||||||
|
seen.add(normalized_path)
|
||||||
|
normalized_paths.append(normalized_path)
|
||||||
|
|
||||||
|
return normalized_paths
|
||||||
|
|
||||||
|
|
||||||
def load_host_policy(comfy_root: Path) -> HostSecurityPolicy:
|
def load_host_policy(comfy_root: Path) -> HostSecurityPolicy:
|
||||||
config_override = os.environ.get(HOST_POLICY_PATH_ENV)
|
config_override = os.environ.get(HOST_POLICY_PATH_ENV)
|
||||||
config_path = Path(config_override) if config_override else comfy_root / "pyproject.toml"
|
config_path = Path(config_override) if config_override else comfy_root / "pyproject.toml"
|
||||||
@ -86,14 +148,23 @@ def load_host_policy(comfy_root: Path) -> HostSecurityPolicy:
|
|||||||
policy["allow_network"] = bool(tool_config["allow_network"])
|
policy["allow_network"] = bool(tool_config["allow_network"])
|
||||||
|
|
||||||
if "writable_paths" in tool_config:
|
if "writable_paths" in tool_config:
|
||||||
policy["writable_paths"] = [str(p) for p in tool_config["writable_paths"]]
|
policy["writable_paths"] = _normalize_writable_paths(tool_config["writable_paths"])
|
||||||
|
|
||||||
if "readonly_paths" in tool_config:
|
if "readonly_paths" in tool_config:
|
||||||
policy["readonly_paths"] = [str(p) for p in tool_config["readonly_paths"]]
|
policy["readonly_paths"] = [str(p) for p in tool_config["readonly_paths"]]
|
||||||
|
|
||||||
|
if "sealed_worker_ro_import_paths" in tool_config:
|
||||||
|
policy["sealed_worker_ro_import_paths"] = _normalize_sealed_worker_ro_import_paths(
|
||||||
|
tool_config["sealed_worker_ro_import_paths"]
|
||||||
|
)
|
||||||
|
|
||||||
|
whitelist_file = tool_config.get("whitelist_file")
|
||||||
|
if isinstance(whitelist_file, str):
|
||||||
|
policy["whitelist"].update(_load_whitelist_file(Path(whitelist_file), config_path))
|
||||||
|
|
||||||
whitelist_raw = tool_config.get("whitelist")
|
whitelist_raw = tool_config.get("whitelist")
|
||||||
if isinstance(whitelist_raw, dict):
|
if isinstance(whitelist_raw, dict):
|
||||||
policy["whitelist"] = {str(k): str(v) for k, v in whitelist_raw.items()}
|
policy["whitelist"].update({str(k): str(v) for k, v in whitelist_raw.items()})
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Loaded Host Policy: %d whitelisted nodes, Sandbox=%s, Network=%s",
|
"Loaded Host Policy: %d whitelisted nodes, Sandbox=%s, Network=%s",
|
||||||
|
|||||||
@ -24,6 +24,49 @@ CACHE_SUBDIR = "cache"
|
|||||||
CACHE_KEY_FILE = "cache_key"
|
CACHE_KEY_FILE = "cache_key"
|
||||||
CACHE_DATA_FILE = "node_info.json"
|
CACHE_DATA_FILE = "node_info.json"
|
||||||
CACHE_KEY_LENGTH = 16
|
CACHE_KEY_LENGTH = 16
|
||||||
|
_NESTED_SCAN_ROOT = "packages"
|
||||||
|
_IGNORED_MANIFEST_DIRS = {".git", ".venv", "__pycache__"}
|
||||||
|
|
||||||
|
|
||||||
|
def _read_manifest(manifest_path: Path) -> dict[str, Any] | None:
|
||||||
|
try:
|
||||||
|
with manifest_path.open("rb") as f:
|
||||||
|
data = tomllib.load(f)
|
||||||
|
if isinstance(data, dict):
|
||||||
|
return data
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _is_isolation_manifest(data: dict[str, Any]) -> bool:
|
||||||
|
return (
|
||||||
|
"tool" in data
|
||||||
|
and "comfy" in data["tool"]
|
||||||
|
and "isolation" in data["tool"]["comfy"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _discover_nested_manifests(entry: Path) -> List[Tuple[Path, Path]]:
|
||||||
|
packages_root = entry / _NESTED_SCAN_ROOT
|
||||||
|
if not packages_root.exists() or not packages_root.is_dir():
|
||||||
|
return []
|
||||||
|
|
||||||
|
nested: List[Tuple[Path, Path]] = []
|
||||||
|
for manifest in sorted(packages_root.rglob("pyproject.toml")):
|
||||||
|
node_dir = manifest.parent
|
||||||
|
if any(part in _IGNORED_MANIFEST_DIRS for part in node_dir.parts):
|
||||||
|
continue
|
||||||
|
|
||||||
|
data = _read_manifest(manifest)
|
||||||
|
if not data or not _is_isolation_manifest(data):
|
||||||
|
continue
|
||||||
|
|
||||||
|
isolation = data["tool"]["comfy"]["isolation"]
|
||||||
|
if isolation.get("standalone") is True:
|
||||||
|
nested.append((node_dir, manifest))
|
||||||
|
|
||||||
|
return nested
|
||||||
|
|
||||||
|
|
||||||
def find_manifest_directories() -> List[Tuple[Path, Path]]:
|
def find_manifest_directories() -> List[Tuple[Path, Path]]:
|
||||||
@ -45,21 +88,13 @@ def find_manifest_directories() -> List[Tuple[Path, Path]]:
|
|||||||
if not manifest.exists():
|
if not manifest.exists():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Validate [tool.comfy.isolation] section existence
|
data = _read_manifest(manifest)
|
||||||
try:
|
if not data or not _is_isolation_manifest(data):
|
||||||
with manifest.open("rb") as f:
|
|
||||||
data = tomllib.load(f)
|
|
||||||
|
|
||||||
if (
|
|
||||||
"tool" in data
|
|
||||||
and "comfy" in data["tool"]
|
|
||||||
and "isolation" in data["tool"]["comfy"]
|
|
||||||
):
|
|
||||||
manifest_dirs.append((entry, manifest))
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
manifest_dirs.append((entry, manifest))
|
||||||
|
manifest_dirs.extend(_discover_nested_manifests(entry))
|
||||||
|
|
||||||
return manifest_dirs
|
return manifest_dirs
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -8,10 +8,17 @@ from pathlib import Path
|
|||||||
from typing import Any, Dict, List, Set, TYPE_CHECKING
|
from typing import Any, Dict, List, Set, TYPE_CHECKING
|
||||||
|
|
||||||
from .proxies.helper_proxies import restore_input_types
|
from .proxies.helper_proxies import restore_input_types
|
||||||
from comfy_api.internal import _ComfyNodeInternal
|
|
||||||
from comfy_api.latest import _io as latest_io
|
|
||||||
from .shm_forensics import scan_shm_forensics
|
from .shm_forensics import scan_shm_forensics
|
||||||
|
|
||||||
|
_IMPORT_TORCH = os.environ.get("PYISOLATE_IMPORT_TORCH", "1") == "1"
|
||||||
|
|
||||||
|
_ComfyNodeInternal = object
|
||||||
|
latest_io = None
|
||||||
|
|
||||||
|
if _IMPORT_TORCH:
|
||||||
|
from comfy_api.internal import _ComfyNodeInternal
|
||||||
|
from comfy_api.latest import _io as latest_io
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .extension_wrapper import ComfyNodeExtension
|
from .extension_wrapper import ComfyNodeExtension
|
||||||
|
|
||||||
@ -19,6 +26,68 @@ LOG_PREFIX = "]["
|
|||||||
_PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
|
_PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
|
||||||
|
|
||||||
|
|
||||||
|
class _RemoteObjectRegistryCaller:
|
||||||
|
def __init__(self, extension: Any) -> None:
|
||||||
|
self._extension = extension
|
||||||
|
|
||||||
|
def __getattr__(self, method_name: str) -> Any:
|
||||||
|
async def _call(instance_id: str, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
return await self._extension.call_remote_object_method(
|
||||||
|
instance_id,
|
||||||
|
method_name,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return _call
|
||||||
|
|
||||||
|
|
||||||
|
def _wrap_remote_handles_as_host_proxies(value: Any, extension: Any) -> Any:
|
||||||
|
from pyisolate._internal.remote_handle import RemoteObjectHandle
|
||||||
|
|
||||||
|
if isinstance(value, RemoteObjectHandle):
|
||||||
|
if value.type_name == "ModelPatcher":
|
||||||
|
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy
|
||||||
|
|
||||||
|
proxy = ModelPatcherProxy(value.object_id, manage_lifecycle=False)
|
||||||
|
proxy._rpc_caller = _RemoteObjectRegistryCaller(extension) # type: ignore[attr-defined]
|
||||||
|
proxy._pyisolate_remote_handle = value # type: ignore[attr-defined]
|
||||||
|
return proxy
|
||||||
|
if value.type_name == "VAE":
|
||||||
|
from comfy.isolation.vae_proxy import VAEProxy
|
||||||
|
|
||||||
|
proxy = VAEProxy(value.object_id, manage_lifecycle=False)
|
||||||
|
proxy._rpc_caller = _RemoteObjectRegistryCaller(extension) # type: ignore[attr-defined]
|
||||||
|
proxy._pyisolate_remote_handle = value # type: ignore[attr-defined]
|
||||||
|
return proxy
|
||||||
|
if value.type_name == "CLIP":
|
||||||
|
from comfy.isolation.clip_proxy import CLIPProxy
|
||||||
|
|
||||||
|
proxy = CLIPProxy(value.object_id, manage_lifecycle=False)
|
||||||
|
proxy._rpc_caller = _RemoteObjectRegistryCaller(extension) # type: ignore[attr-defined]
|
||||||
|
proxy._pyisolate_remote_handle = value # type: ignore[attr-defined]
|
||||||
|
return proxy
|
||||||
|
if value.type_name == "ModelSampling":
|
||||||
|
from comfy.isolation.model_sampling_proxy import ModelSamplingProxy
|
||||||
|
|
||||||
|
proxy = ModelSamplingProxy(value.object_id, manage_lifecycle=False)
|
||||||
|
proxy._rpc_caller = _RemoteObjectRegistryCaller(extension) # type: ignore[attr-defined]
|
||||||
|
proxy._pyisolate_remote_handle = value # type: ignore[attr-defined]
|
||||||
|
return proxy
|
||||||
|
return value
|
||||||
|
|
||||||
|
if isinstance(value, dict):
|
||||||
|
return {
|
||||||
|
k: _wrap_remote_handles_as_host_proxies(v, extension) for k, v in value.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
if isinstance(value, (list, tuple)):
|
||||||
|
wrapped = [_wrap_remote_handles_as_host_proxies(item, extension) for item in value]
|
||||||
|
return type(value)(wrapped)
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
def _resource_snapshot() -> Dict[str, int]:
|
def _resource_snapshot() -> Dict[str, int]:
|
||||||
fd_count = -1
|
fd_count = -1
|
||||||
shm_sender_files = 0
|
shm_sender_files = 0
|
||||||
@ -146,6 +215,8 @@ def build_stub_class(
|
|||||||
running_extensions: Dict[str, "ComfyNodeExtension"],
|
running_extensions: Dict[str, "ComfyNodeExtension"],
|
||||||
logger: logging.Logger,
|
logger: logging.Logger,
|
||||||
) -> type:
|
) -> type:
|
||||||
|
if latest_io is None:
|
||||||
|
raise RuntimeError("comfy_api.latest._io is required to build isolation stubs")
|
||||||
is_v3 = bool(info.get("is_v3", False))
|
is_v3 = bool(info.get("is_v3", False))
|
||||||
function_name = "_pyisolate_execute"
|
function_name = "_pyisolate_execute"
|
||||||
restored_input_types = restore_input_types(info.get("input_types", {}))
|
restored_input_types = restore_input_types(info.get("input_types", {}))
|
||||||
@ -160,6 +231,13 @@ def build_stub_class(
|
|||||||
node_unique_id = _extract_hidden_unique_id(inputs)
|
node_unique_id = _extract_hidden_unique_id(inputs)
|
||||||
summary = _tensor_transport_summary(inputs)
|
summary = _tensor_transport_summary(inputs)
|
||||||
resources = _resource_snapshot()
|
resources = _resource_snapshot()
|
||||||
|
logger.debug(
|
||||||
|
"%s ISO:execute_start ext=%s node=%s uid=%s",
|
||||||
|
LOG_PREFIX,
|
||||||
|
extension.name,
|
||||||
|
node_name,
|
||||||
|
node_unique_id or "-",
|
||||||
|
)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"%s ISO:execute_start ext=%s node=%s uid=%s tensors=%d cpu=%d cuda=%d shared_cpu=%d bytes=%d fds=%d sender_shm=%d",
|
"%s ISO:execute_start ext=%s node=%s uid=%s tensors=%d cpu=%d cuda=%d shared_cpu=%d bytes=%d fds=%d sender_shm=%d",
|
||||||
LOG_PREFIX,
|
LOG_PREFIX,
|
||||||
@ -192,7 +270,20 @@ def build_stub_class(
|
|||||||
node_name,
|
node_name,
|
||||||
node_unique_id or "-",
|
node_unique_id or "-",
|
||||||
)
|
)
|
||||||
serialized = serialize_for_isolation(inputs)
|
# Unwrap NodeOutput-like dicts before serialization.
|
||||||
|
# OUTPUT_NODE nodes return {"ui": {...}, "result": (outputs...)}
|
||||||
|
# and the executor may pass this dict as input to downstream nodes.
|
||||||
|
unwrapped_inputs = {}
|
||||||
|
for k, v in inputs.items():
|
||||||
|
if isinstance(v, dict) and "result" in v and ("ui" in v or "__node_output__" in v):
|
||||||
|
result = v.get("result")
|
||||||
|
if isinstance(result, (tuple, list)) and len(result) > 0:
|
||||||
|
unwrapped_inputs[k] = result[0]
|
||||||
|
else:
|
||||||
|
unwrapped_inputs[k] = result
|
||||||
|
else:
|
||||||
|
unwrapped_inputs[k] = v
|
||||||
|
serialized = serialize_for_isolation(unwrapped_inputs)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"%s ISO:serialize_done ext=%s node=%s uid=%s",
|
"%s ISO:serialize_done ext=%s node=%s uid=%s",
|
||||||
LOG_PREFIX,
|
LOG_PREFIX,
|
||||||
@ -220,15 +311,32 @@ def build_stub_class(
|
|||||||
from comfy_api.latest import io as latest_io
|
from comfy_api.latest import io as latest_io
|
||||||
args_raw = result.get("args", ())
|
args_raw = result.get("args", ())
|
||||||
deserialized_args = await deserialize_from_isolation(args_raw, extension)
|
deserialized_args = await deserialize_from_isolation(args_raw, extension)
|
||||||
|
deserialized_args = _wrap_remote_handles_as_host_proxies(
|
||||||
|
deserialized_args, extension
|
||||||
|
)
|
||||||
deserialized_args = _detach_shared_cpu_tensors(deserialized_args)
|
deserialized_args = _detach_shared_cpu_tensors(deserialized_args)
|
||||||
|
ui_raw = result.get("ui")
|
||||||
|
deserialized_ui = None
|
||||||
|
if ui_raw is not None:
|
||||||
|
deserialized_ui = await deserialize_from_isolation(ui_raw, extension)
|
||||||
|
deserialized_ui = _wrap_remote_handles_as_host_proxies(
|
||||||
|
deserialized_ui, extension
|
||||||
|
)
|
||||||
|
deserialized_ui = _detach_shared_cpu_tensors(deserialized_ui)
|
||||||
scan_shm_forensics("RUNTIME:post_execute", refresh_model_context=True)
|
scan_shm_forensics("RUNTIME:post_execute", refresh_model_context=True)
|
||||||
return latest_io.NodeOutput(
|
return latest_io.NodeOutput(
|
||||||
*deserialized_args,
|
*deserialized_args,
|
||||||
ui=result.get("ui"),
|
ui=deserialized_ui,
|
||||||
expand=result.get("expand"),
|
expand=result.get("expand"),
|
||||||
block_execution=result.get("block_execution"),
|
block_execution=result.get("block_execution"),
|
||||||
)
|
)
|
||||||
|
# OUTPUT_NODE: if sealed worker returned a tuple/list whose first
|
||||||
|
# element is a {"ui": ...} dict, unwrap it for the executor.
|
||||||
|
if (isinstance(result, (tuple, list)) and len(result) == 1
|
||||||
|
and isinstance(result[0], dict) and "ui" in result[0]):
|
||||||
|
return result[0]
|
||||||
deserialized = await deserialize_from_isolation(result, extension)
|
deserialized = await deserialize_from_isolation(result, extension)
|
||||||
|
deserialized = _wrap_remote_handles_as_host_proxies(deserialized, extension)
|
||||||
scan_shm_forensics("RUNTIME:post_execute", refresh_model_context=True)
|
scan_shm_forensics("RUNTIME:post_execute", refresh_model_context=True)
|
||||||
return _detach_shared_cpu_tensors(deserialized)
|
return _detach_shared_cpu_tensors(deserialized)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user