mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-14 04:22:31 +08:00
feat(isolation): core infrastructure and pyisolate integration
Adds the isolation system foundation: ComfyUIAdapter, extension loader, manifest discovery, child/host process hooks, RPC bridge, runtime helpers, SHM forensics, and the --use-process-isolation CLI flag. pyisolate added to requirements.txt. .pyisolate_venvs/ added to .gitignore.
This commit is contained in:
parent
b615af1c65
commit
7d512fa9c3
1
.gitignore
vendored
1
.gitignore
vendored
@ -24,3 +24,4 @@ web_custom_versions/
|
|||||||
openapi.yaml
|
openapi.yaml
|
||||||
filtered-openapi.yaml
|
filtered-openapi.yaml
|
||||||
uv.lock
|
uv.lock
|
||||||
|
.pyisolate_venvs/
|
||||||
|
|||||||
@ -184,6 +184,8 @@ parser.add_argument("--disable-api-nodes", action="store_true", help="Disable lo
|
|||||||
|
|
||||||
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
||||||
|
|
||||||
|
parser.add_argument("--use-process-isolation", action="store_true", help="Enable process isolation for custom nodes with pyisolate.yaml manifests.")
|
||||||
|
|
||||||
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
|
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
|
||||||
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
|
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
|
||||||
|
|
||||||
|
|||||||
442
comfy/isolation/__init__.py
Normal file
442
comfy/isolation/__init__.py
Normal file
@ -0,0 +1,442 @@
|
|||||||
|
# pylint: disable=consider-using-from-import,cyclic-import,global-statement,global-variable-not-assigned,import-outside-toplevel,logging-fstring-interpolation
|
||||||
|
from __future__ import annotations
|
||||||
|
import asyncio
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Set, TYPE_CHECKING
|
||||||
|
_IMPORT_TORCH = os.environ.get("PYISOLATE_IMPORT_TORCH", "1") == "1"
|
||||||
|
|
||||||
|
load_isolated_node = None
|
||||||
|
find_manifest_directories = None
|
||||||
|
build_stub_class = None
|
||||||
|
get_class_types_for_extension = None
|
||||||
|
scan_shm_forensics = None
|
||||||
|
start_shm_forensics = None
|
||||||
|
|
||||||
|
if _IMPORT_TORCH:
|
||||||
|
import folder_paths
|
||||||
|
from .extension_loader import load_isolated_node
|
||||||
|
from .manifest_loader import find_manifest_directories
|
||||||
|
from .runtime_helpers import build_stub_class, get_class_types_for_extension
|
||||||
|
from .shm_forensics import scan_shm_forensics, start_shm_forensics
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from pyisolate import ExtensionManager
|
||||||
|
from .extension_wrapper import ComfyNodeExtension
|
||||||
|
|
||||||
|
LOG_PREFIX = "]["
|
||||||
|
isolated_node_timings: List[tuple[float, Path, int]] = []
|
||||||
|
|
||||||
|
if _IMPORT_TORCH:
|
||||||
|
PYISOLATE_VENV_ROOT = Path(folder_paths.base_path) / ".pyisolate_venvs"
|
||||||
|
PYISOLATE_VENV_ROOT.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
|
||||||
|
_MODEL_PATCHER_IDLE_TIMEOUT_MS = 120000
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_proxies() -> None:
    """Install host- or child-side proxy hooks for the current process.

    Determines which side of the pyisolate boundary this process is on via
    ``is_child_process()`` and runs the matching one-time initialization.
    Imports are deferred to call time to avoid import cycles at module load.
    """
    from .child_hooks import is_child_process

    is_child = is_child_process()
    # Logged at WARNING so the process role is always visible in logs.
    logger.warning(
        "%s DIAG:initialize_proxies | is_child=%s | PYISOLATE_CHILD=%s",
        LOG_PREFIX, is_child, os.environ.get("PYISOLATE_CHILD"),
    )

    if is_child:
        from .child_hooks import initialize_child_process

        initialize_child_process()
        logger.warning("%s DIAG:initialize_proxies child_process initialized", LOG_PREFIX)
    else:
        from .host_hooks import initialize_host_process

        initialize_host_process()
        logger.warning("%s DIAG:initialize_proxies host_process initialized", LOG_PREFIX)
        # start_shm_forensics is None when PYISOLATE_IMPORT_TORCH disabled
        # the torch-dependent imports at module load.
        # NOTE(review): host-branch placement inferred from statement order —
        # confirm forensics should not also start in child processes.
        if start_shm_forensics is not None:
            start_shm_forensics()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class IsolatedNodeSpec:
    """Immutable record describing one node class loaded from an isolated extension."""

    # ComfyUI node class type (the registry key used for this node).
    node_name: str
    # Human-readable display name.
    display_name: str
    # Host-side stub class that forwards execution into the child process.
    stub_class: type
    # Directory of the custom-node package this spec came from.
    module_path: Path
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level isolation state (mutated only by the functions below).
_ISOLATED_NODE_SPECS: List[IsolatedNodeSpec] = []  # specs produced by the scan
_CLAIMED_PATHS: Set[Path] = set()  # node dirs owned by isolation (skip normal loader)
_ISOLATION_SCAN_ATTEMPTED = False  # guards against repeating a failed/empty scan
_EXTENSION_MANAGERS: List["ExtensionManager"] = []  # pyisolate managers kept alive
_RUNNING_EXTENSIONS: Dict[str, "ComfyNodeExtension"] = {}  # name -> live extension
_ISOLATION_BACKGROUND_TASK: Optional["asyncio.Task[List[IsolatedNodeSpec]]"] = None  # early-start task
_EARLY_START_TIME: Optional[float] = None  # perf_counter() when early start began
|
||||||
|
|
||||||
|
|
||||||
|
def start_isolation_loading_early(loop: "asyncio.AbstractEventLoop") -> None:
    """Kick off the isolation scan as a background task on *loop*.

    Idempotent: a second call while a task already exists is a no-op.
    """
    global _ISOLATION_BACKGROUND_TASK, _EARLY_START_TIME
    if _ISOLATION_BACKGROUND_TASK is None:
        _EARLY_START_TIME = time.perf_counter()
        _ISOLATION_BACKGROUND_TASK = loop.create_task(initialize_isolation_nodes())
|
||||||
|
|
||||||
|
|
||||||
|
async def await_isolation_loading() -> List[IsolatedNodeSpec]:
    """Return the isolated node specs, awaiting the early task when present.

    Falls back to running the scan inline when no background task was started.
    """
    global _ISOLATION_BACKGROUND_TASK, _EARLY_START_TIME
    task = _ISOLATION_BACKGROUND_TASK
    if task is None:
        return await initialize_isolation_nodes()
    return await task
|
||||||
|
|
||||||
|
|
||||||
|
async def initialize_isolation_nodes() -> List[IsolatedNodeSpec]:
    """Scan for pyisolate manifests and load every isolated node package once.

    Returns the cached spec list on repeat calls. Returns an empty list when
    a scan already ran, when the loader machinery is unavailable (torch-free
    worker), or when no manifests exist. Loads run concurrently, bounded by a
    semaphore sized to half the CPU count. Per-extension failures are logged
    and skipped so one broken extension cannot abort startup.
    """
    global _ISOLATED_NODE_SPECS, _ISOLATION_SCAN_ATTEMPTED, _CLAIMED_PATHS

    if _ISOLATED_NODE_SPECS:
        return _ISOLATED_NODE_SPECS

    if _ISOLATION_SCAN_ATTEMPTED:
        return []

    _ISOLATION_SCAN_ATTEMPTED = True
    # These names are None when PYISOLATE_IMPORT_TORCH disabled the imports.
    if find_manifest_directories is None or load_isolated_node is None or build_stub_class is None:
        return []
    manifest_entries = find_manifest_directories()
    # Record which custom-node dirs isolation owns so other loaders skip them.
    _CLAIMED_PATHS = {entry[0].resolve() for entry in manifest_entries}

    if not manifest_entries:
        return []

    os.environ["PYISOLATE_ISOLATION_ACTIVE"] = "1"
    # Bound concurrent extension setup to half the CPUs (at least one).
    concurrency_limit = max(1, (os.cpu_count() or 4) // 2)
    semaphore = asyncio.Semaphore(concurrency_limit)

    async def load_with_semaphore(
        node_dir: Path, manifest: Path
    ) -> List[IsolatedNodeSpec]:
        # Load one extension, then wrap each returned node in a spec record.
        async with semaphore:
            load_start = time.perf_counter()
            spec_list = await load_isolated_node(
                node_dir,
                manifest,
                logger,
                lambda name, info, extension: build_stub_class(
                    name,
                    info,
                    extension,
                    _RUNNING_EXTENSIONS,
                    logger,
                ),
                PYISOLATE_VENV_ROOT,
                _EXTENSION_MANAGERS,
            )
            spec_list = [
                IsolatedNodeSpec(
                    node_name=node_name,
                    display_name=display_name,
                    stub_class=stub_cls,
                    module_path=node_dir,
                )
                for node_name, display_name, stub_cls in spec_list
            ]
            # Record per-extension load duration for startup diagnostics.
            isolated_node_timings.append(
                (time.perf_counter() - load_start, node_dir, len(spec_list))
            )
            return spec_list

    tasks = [
        load_with_semaphore(node_dir, manifest)
        for node_dir, manifest in manifest_entries
    ]
    # return_exceptions=True: collect failures instead of aborting the gather.
    results = await asyncio.gather(*tasks, return_exceptions=True)

    specs: List[IsolatedNodeSpec] = []
    for result in results:
        if isinstance(result, Exception):
            logger.error(
                "%s Isolated node failed during startup; continuing: %s",
                LOG_PREFIX,
                result,
            )
            continue
        specs.extend(result)

    _ISOLATED_NODE_SPECS = specs
    # Return a copy so callers cannot mutate the module-level cache.
    return list(_ISOLATED_NODE_SPECS)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_class_types_for_extension(extension_name: str) -> Set[str]:
    """Get all node class types (node names) belonging to an extension.

    Args:
        extension_name: Key of the extension in ``_RUNNING_EXTENSIONS``.

    Returns:
        Set of node names whose spec's module path matches the extension's
        module path; empty when the extension is not running.
    """
    extension = _RUNNING_EXTENSIONS.get(extension_name)
    if not extension:
        return set()

    # Resolve the extension path once; the original recomputed
    # ext_path.resolve() on every loop iteration.
    ext_path = Path(extension.module_path).resolve()
    return {
        spec.node_name
        for spec in _ISOLATED_NODE_SPECS
        if spec.module_path.resolve() == ext_path
    }
|
||||||
|
|
||||||
|
|
||||||
|
async def notify_execution_graph(needed_class_types: Set[str], caches: list | None = None) -> None:
    """Evict running extensions not needed for current execution.

    When *caches* is provided, cache entries for evicted extensions' node
    class_types are invalidated to prevent stale ``RemoteObjectHandle``
    references from surviving in the output cache.

    Args:
        needed_class_types: Node class types present in the execution graph.
        caches: Optional caches exposing ``invalidate_by_class_types``.

    Raises:
        TimeoutError: If the ModelPatcher registry never quiesces
            (``fail_loud=True`` on the wait below).
    """
    await wait_for_model_patcher_quiescence(
        timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS,
        fail_loud=True,
        marker="ISO:notify_graph_wait_idle",
    )

    evicted_class_types: Set[str] = set()

    async def _stop_extension(
        ext_name: str, extension: "ComfyNodeExtension", reason: str
    ) -> None:
        # Collect class_types BEFORE stopping so we can invalidate cache entries.
        ext_class_types = _get_class_types_for_extension(ext_name)
        evicted_class_types.update(ext_class_types)
        logger.info("%s ISO:eject_start ext=%s reason=%s", LOG_PREFIX, ext_name, reason)
        logger.debug("%s ISO:stop_start ext=%s", LOG_PREFIX, ext_name)
        # stop() may be sync or async depending on the extension implementation.
        stop_result = extension.stop()
        if inspect.isawaitable(stop_result):
            await stop_result
        _RUNNING_EXTENSIONS.pop(ext_name, None)
        logger.debug("%s ISO:stop_done ext=%s", LOG_PREFIX, ext_name)
        if scan_shm_forensics is not None:
            scan_shm_forensics("ISO:stop_extension", refresh_model_context=True)

    if scan_shm_forensics is not None:
        scan_shm_forensics("ISO:notify_graph_start", refresh_model_context=True)
    isolated_class_types_in_graph = needed_class_types.intersection(
        {spec.node_name for spec in _ISOLATED_NODE_SPECS}
    )
    graph_uses_isolation = bool(isolated_class_types_in_graph)
    logger.debug(
        "%s ISO:notify_graph_start running=%d needed=%d",
        LOG_PREFIX,
        len(_RUNNING_EXTENSIONS),
        len(needed_class_types),
    )
    if graph_uses_isolation:
        for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
            ext_class_types = _get_class_types_for_extension(ext_name)

            # If NONE of this extension's nodes are in the execution graph -> evict.
            if not ext_class_types.intersection(needed_class_types):
                await _stop_extension(
                    ext_name,
                    extension,
                    "isolated custom_node not in execution graph, evicting",
                )
    else:
        logger.debug(
            "%s ISO:notify_graph_skip_evict running=%d reason=no isolated nodes in graph",
            LOG_PREFIX,
            len(_RUNNING_EXTENSIONS),
        )

    # Isolated child processes add steady VRAM pressure; reclaim host-side models
    # at workflow boundaries so subsequent host nodes (e.g. CLIP encode) keep headroom.
    try:
        import comfy.model_management as model_management

        device = model_management.get_torch_device()
        if getattr(device, "type", None) == "cuda":
            required = max(
                model_management.minimum_inference_memory(),
                _WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES,
            )
            free_before = model_management.get_free_memory(device)
            if free_before < required and _RUNNING_EXTENSIONS and graph_uses_isolation:
                for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
                    await _stop_extension(
                        ext_name,
                        extension,
                        f"boundary low-vram restart (free={int(free_before)} target={int(required)})",
                    )
            # Escalating reclamation: unload + GC first, then forced free and
            # allocator-cache flush only if still below the target.
            if model_management.get_free_memory(device) < required:
                model_management.unload_all_models()
                model_management.cleanup_models_gc()
                model_management.cleanup_models()
                if model_management.get_free_memory(device) < required:
                    model_management.free_memory(required, device, for_dynamic=False)
                    model_management.soft_empty_cache()
    except Exception:
        # Best-effort relief: never fail the workflow over VRAM housekeeping.
        logger.debug(
            "%s workflow-boundary host VRAM relief failed", LOG_PREFIX, exc_info=True
        )
    finally:
        # Invalidate cached outputs for evicted extensions so stale
        # RemoteObjectHandle references are not served from cache.
        if evicted_class_types and caches:
            total_invalidated = 0
            for cache in caches:
                if hasattr(cache, "invalidate_by_class_types"):
                    total_invalidated += cache.invalidate_by_class_types(
                        evicted_class_types
                    )
            if total_invalidated > 0:
                logger.info(
                    "%s ISO:cache_invalidated count=%d class_types=%s",
                    LOG_PREFIX,
                    total_invalidated,
                    evicted_class_types,
                )
        # FIX: guard against None before calling — scan_shm_forensics is None
        # in torch-free workers; every other call site in this module checks.
        if scan_shm_forensics is not None:
            scan_shm_forensics("ISO:notify_graph_done", refresh_model_context=True)
        logger.debug(
            "%s ISO:notify_graph_done running=%d", LOG_PREFIX, len(_RUNNING_EXTENSIONS)
        )
|
||||||
|
|
||||||
|
|
||||||
|
async def flush_running_extensions_transport_state() -> int:
    """Flush transport state of every running extension at workflow end.

    Waits for ModelPatcher quiescence first (raising TimeoutError if the
    registry never goes idle), then invokes each extension's optional
    ``flush_transport_state`` hook. Per-extension failures are logged and
    skipped so one broken extension cannot block teardown.

    Returns:
        Total count of released entries reported by the extensions.
    """
    await wait_for_model_patcher_quiescence(
        timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS,
        fail_loud=True,
        marker="ISO:flush_transport_wait_idle",
    )
    total_flushed = 0
    for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
        flush_fn = getattr(extension, "flush_transport_state", None)
        if not callable(flush_fn):
            continue
        try:
            flushed = await flush_fn()
            if isinstance(flushed, int):
                total_flushed += flushed
                if flushed > 0:
                    logger.debug(
                        "%s %s workflow-end flush released=%d",
                        LOG_PREFIX,
                        ext_name,
                        flushed,
                    )
        except Exception:
            logger.debug(
                "%s %s workflow-end flush failed", LOG_PREFIX, ext_name, exc_info=True
            )
    # FIX: guard against None — scan_shm_forensics is unavailable in torch-free
    # workers; the original called it unconditionally (TypeError: 'NoneType').
    if scan_shm_forensics is not None:
        scan_shm_forensics(
            "ISO:flush_running_extensions_transport_state", refresh_model_context=True
        )
    return total_flushed
|
||||||
|
|
||||||
|
|
||||||
|
async def wait_for_model_patcher_quiescence(
    timeout_ms: int = _MODEL_PATCHER_IDLE_TIMEOUT_MS,
    *,
    fail_loud: bool = False,
    marker: str = "ISO:wait_model_patcher_idle",
) -> bool:
    """Wait until all registered ModelPatcher proxies report idle.

    Args:
        timeout_ms: Maximum time to wait for the registry to go idle.
        fail_loud: When True, raise on timeout (or any internal failure)
            instead of returning False.
        marker: Log marker identifying the call site.

    Returns:
        True when the registry went idle within the timeout, False otherwise
        (unless ``fail_loud`` raised first).

    Raises:
        TimeoutError: On timeout when ``fail_loud`` is True.
    """
    try:
        # Local import: the registry module transitively imports torch.
        from comfy.isolation.model_patcher_proxy_registry import ModelPatcherRegistry

        registry = ModelPatcherRegistry()
        start = time.perf_counter()
        idle = await registry.wait_all_idle(timeout_ms)
        elapsed_ms = (time.perf_counter() - start) * 1000.0
        if idle:
            logger.debug(
                "%s %s idle=1 timeout_ms=%d elapsed_ms=%.3f",
                LOG_PREFIX,
                marker,
                timeout_ms,
                elapsed_ms,
            )
            return True

        # Timed out: dump per-patcher operation states for diagnosis.
        states = await registry.get_all_operation_states()
        logger.error(
            "%s %s idle_timeout timeout_ms=%d elapsed_ms=%.3f states=%s",
            LOG_PREFIX,
            marker,
            timeout_ms,
            elapsed_ms,
            states,
        )
        if fail_loud:
            raise TimeoutError(
                f"ModelPatcherRegistry did not quiesce within {timeout_ms} ms"
            )
        return False
    except Exception:
        # Also catches the TimeoutError raised above; re-raised when fail_loud.
        if fail_loud:
            raise
        logger.debug("%s %s failed", LOG_PREFIX, marker, exc_info=True)
        return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_claimed_paths() -> Set[Path]:
    """Return the set of node directories claimed by isolation.

    NOTE: returns the live module-level set, not a copy — callers must not
    mutate it.
    """
    return _CLAIMED_PATHS
|
||||||
|
|
||||||
|
|
||||||
|
def update_rpc_event_loops(loop: "asyncio.AbstractEventLoop | None" = None) -> None:
    """Update all active RPC instances with the current event loop.

    This MUST be called at the start of each workflow execution to ensure
    RPC calls are scheduled on the correct event loop: asyncio.run() creates
    a fresh loop for each workflow.

    Args:
        loop: The event loop to use. If None, uses asyncio.get_running_loop().
    """
    if loop is None:
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = asyncio.get_event_loop()

    def _push_loop(extension) -> bool:
        # Apply the loop to the extension's RPC channel, if it has one.
        rpc = getattr(extension, "rpc", None)
        if rpc is None or not hasattr(rpc, "update_event_loop"):
            return False
        rpc.update_event_loop(loop)
        return True

    update_count = 0

    # RPC instances owned by the ExtensionManagers.
    for manager in _EXTENSION_MANAGERS:
        for name, extension in getattr(manager, "extensions", {}).items():
            if _push_loop(extension):
                update_count += 1
                logger.debug(f"{LOG_PREFIX}Updated loop on extension '{name}'")

    # RPC instances held directly by running extensions.
    for name, extension in _RUNNING_EXTENSIONS.items():
        if _push_loop(extension):
            update_count += 1
            logger.debug(f"{LOG_PREFIX}Updated loop on running extension '{name}'")

    if update_count > 0:
        logger.debug(f"{LOG_PREFIX}Updated event loop on {update_count} RPC instances")
    else:
        logger.debug(
            f"{LOG_PREFIX}No RPC instances found to update (managers={len(_EXTENSION_MANAGERS)}, running={len(_RUNNING_EXTENSIONS)})"
        )
|
||||||
|
|
||||||
|
|
||||||
|
# Public API of the isolation package.  Note: ``get_class_types_for_extension``
# is the name imported from runtime_helpers (None in torch-free workers), not
# the private module-local ``_get_class_types_for_extension``.
__all__ = [
    "LOG_PREFIX",
    "initialize_proxies",
    "initialize_isolation_nodes",
    "start_isolation_loading_early",
    "await_isolation_loading",
    "notify_execution_graph",
    "flush_running_extensions_transport_state",
    "wait_for_model_patcher_quiescence",
    "get_claimed_paths",
    "update_rpc_event_loops",
    "IsolatedNodeSpec",
    "get_class_types_for_extension",
]
|
||||||
965
comfy/isolation/adapter.py
Normal file
965
comfy/isolation/adapter.py
Normal file
@ -0,0 +1,965 @@
|
|||||||
|
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,protected-access,raise-missing-from,useless-return,wrong-import-position
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import inspect
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, cast
|
||||||
|
|
||||||
|
from pyisolate.interfaces import IsolationAdapter, SerializerRegistryProtocol # type: ignore[import-untyped]
|
||||||
|
from pyisolate._internal.rpc_protocol import AsyncRPC, ProxiedSingleton # type: ignore[import-untyped]
|
||||||
|
|
||||||
|
_IMPORT_TORCH = os.environ.get("PYISOLATE_IMPORT_TORCH", "1") == "1"
|
||||||
|
|
||||||
|
# Singleton proxies that do NOT transitively import torch/PIL/psutil/aiohttp.
|
||||||
|
# Safe to import in sealed workers without host framework modules.
|
||||||
|
from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy
|
||||||
|
from comfy.isolation.proxies.helper_proxies import HelperProxiesService
|
||||||
|
from comfy.isolation.proxies.web_directory_proxy import WebDirectoryProxy
|
||||||
|
|
||||||
|
# Singleton proxies that transitively import torch, PIL, or heavy host modules.
|
||||||
|
# Only available when torch/host framework is present.
|
||||||
|
CLIPProxy = None
|
||||||
|
CLIPRegistry = None
|
||||||
|
ModelPatcherProxy = None
|
||||||
|
ModelPatcherRegistry = None
|
||||||
|
ModelSamplingProxy = None
|
||||||
|
ModelSamplingRegistry = None
|
||||||
|
VAEProxy = None
|
||||||
|
VAERegistry = None
|
||||||
|
FirstStageModelRegistry = None
|
||||||
|
ModelManagementProxy = None
|
||||||
|
PromptServerService = None
|
||||||
|
ProgressProxy = None
|
||||||
|
UtilsProxy = None
|
||||||
|
_HAS_TORCH_PROXIES = False
|
||||||
|
if _IMPORT_TORCH:
|
||||||
|
from comfy.isolation.clip_proxy import CLIPProxy, CLIPRegistry
|
||||||
|
from comfy.isolation.model_patcher_proxy import (
|
||||||
|
ModelPatcherProxy,
|
||||||
|
ModelPatcherRegistry,
|
||||||
|
)
|
||||||
|
from comfy.isolation.model_sampling_proxy import (
|
||||||
|
ModelSamplingProxy,
|
||||||
|
ModelSamplingRegistry,
|
||||||
|
)
|
||||||
|
from comfy.isolation.vae_proxy import VAEProxy, VAERegistry, FirstStageModelRegistry
|
||||||
|
from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy
|
||||||
|
from comfy.isolation.proxies.prompt_server_impl import PromptServerService
|
||||||
|
from comfy.isolation.proxies.progress_proxy import ProgressProxy
|
||||||
|
from comfy.isolation.proxies.utils_proxy import UtilsProxy
|
||||||
|
_HAS_TORCH_PROXIES = True
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Force /dev/shm for shared memory (bwrap makes /tmp private)
import tempfile

if os.path.exists("/dev/shm"):
    # Only override if not already set or if default is not /dev/shm
    current_tmp = tempfile.gettempdir()
    if not current_tmp.startswith("/dev/shm"):
        logger.debug(
            f"Configuring shared memory: Changing TMPDIR from {current_tmp} to /dev/shm"
        )
        os.environ["TMPDIR"] = "/dev/shm"
        # tempfile caches its choice on first use; resetting to None forces
        # gettempdir() to re-read TMPDIR on the next call.
        tempfile.tempdir = None  # Clear cache to force re-evaluation
|
||||||
|
|
||||||
|
|
||||||
|
class ComfyUIAdapter(IsolationAdapter):
|
||||||
|
# ComfyUI-specific IsolationAdapter implementation
|
||||||
|
|
||||||
|
    @property
    def identifier(self) -> str:
        """Stable adapter identifier reported to pyisolate."""
        return "comfyui"
|
||||||
|
|
||||||
|
def get_path_config(self, module_path: str) -> Optional[Dict[str, Any]]:
|
||||||
|
if "ComfyUI" in module_path and "custom_nodes" in module_path:
|
||||||
|
parts = module_path.split("ComfyUI")
|
||||||
|
if len(parts) > 1:
|
||||||
|
comfy_root = parts[0] + "ComfyUI"
|
||||||
|
return {
|
||||||
|
"preferred_root": comfy_root,
|
||||||
|
"additional_paths": [
|
||||||
|
os.path.join(comfy_root, "custom_nodes"),
|
||||||
|
os.path.join(comfy_root, "comfy"),
|
||||||
|
],
|
||||||
|
"filtered_subdirs": ["comfy", "app", "comfy_execution", "utils"],
|
||||||
|
}
|
||||||
|
return None
|
||||||
|
|
||||||
|
    def get_sandbox_system_paths(self) -> Optional[List[str]]:
        """Returns required application paths to mount in the sandbox."""
        # By inspecting where our adapter is loaded from, we can determine the comfy root
        adapter_file = inspect.getfile(self.__class__)
        # adapter_file is <comfy_root>/comfy/isolation/adapter.py, so the
        # ComfyUI root is three directory levels up.
        comfy_root = os.path.dirname(os.path.dirname(os.path.dirname(adapter_file)))
        if os.path.exists(comfy_root):
            return [comfy_root]
        return None
|
||||||
|
|
||||||
|
def setup_child_environment(self, snapshot: Dict[str, Any]) -> None:
|
||||||
|
comfy_root = snapshot.get("preferred_root")
|
||||||
|
if not comfy_root:
|
||||||
|
return
|
||||||
|
|
||||||
|
requirements_path = Path(comfy_root) / "requirements.txt"
|
||||||
|
if requirements_path.exists():
|
||||||
|
import re
|
||||||
|
|
||||||
|
for line in requirements_path.read_text().splitlines():
|
||||||
|
line = line.strip()
|
||||||
|
if not line or line.startswith("#"):
|
||||||
|
continue
|
||||||
|
pkg_name = re.split(r"[<>=!~\[]", line)[0].strip()
|
||||||
|
if pkg_name:
|
||||||
|
logging.getLogger(pkg_name).setLevel(logging.ERROR)
|
||||||
|
|
||||||
|
def register_serializers(self, registry: SerializerRegistryProtocol) -> None:
|
||||||
|
if not _IMPORT_TORCH:
|
||||||
|
# Sealed worker without torch — register torch-free TensorValue handler
|
||||||
|
# so IMAGE/MASK/LATENT tensors arrive as numpy arrays, not raw dicts.
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
_TORCH_DTYPE_TO_NUMPY = {
|
||||||
|
"torch.float32": np.float32,
|
||||||
|
"torch.float64": np.float64,
|
||||||
|
"torch.float16": np.float16,
|
||||||
|
"torch.bfloat16": np.float32, # numpy has no bfloat16; upcast
|
||||||
|
"torch.int32": np.int32,
|
||||||
|
"torch.int64": np.int64,
|
||||||
|
"torch.int16": np.int16,
|
||||||
|
"torch.int8": np.int8,
|
||||||
|
"torch.uint8": np.uint8,
|
||||||
|
"torch.bool": np.bool_,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _deserialize_tensor_value(data: Dict[str, Any]) -> Any:
|
||||||
|
dtype_str = data["dtype"]
|
||||||
|
np_dtype = _TORCH_DTYPE_TO_NUMPY.get(dtype_str, np.float32)
|
||||||
|
shape = tuple(data["tensor_size"])
|
||||||
|
arr = np.array(data["data"], dtype=np_dtype).reshape(shape)
|
||||||
|
return arr
|
||||||
|
|
||||||
|
_NUMPY_TO_TORCH_DTYPE = {
|
||||||
|
np.float32: "torch.float32",
|
||||||
|
np.float64: "torch.float64",
|
||||||
|
np.float16: "torch.float16",
|
||||||
|
np.int32: "torch.int32",
|
||||||
|
np.int64: "torch.int64",
|
||||||
|
np.int16: "torch.int16",
|
||||||
|
np.int8: "torch.int8",
|
||||||
|
np.uint8: "torch.uint8",
|
||||||
|
np.bool_: "torch.bool",
|
||||||
|
}
|
||||||
|
|
||||||
|
def _serialize_tensor_value(obj: Any) -> Dict[str, Any]:
|
||||||
|
arr = np.asarray(obj, dtype=np.float32) if obj.dtype not in _NUMPY_TO_TORCH_DTYPE else np.asarray(obj)
|
||||||
|
dtype_str = _NUMPY_TO_TORCH_DTYPE.get(arr.dtype.type, "torch.float32")
|
||||||
|
return {
|
||||||
|
"__type__": "TensorValue",
|
||||||
|
"dtype": dtype_str,
|
||||||
|
"tensor_size": list(arr.shape),
|
||||||
|
"requires_grad": False,
|
||||||
|
"data": arr.tolist(),
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.register("TensorValue", _serialize_tensor_value, _deserialize_tensor_value, data_type=True)
|
||||||
|
# ndarray output from sealed workers serializes as TensorValue for host torch reconstruction
|
||||||
|
registry.register("ndarray", _serialize_tensor_value, _deserialize_tensor_value, data_type=True)
|
||||||
|
return
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
def serialize_device(obj: Any) -> Dict[str, Any]:
|
||||||
|
return {"__type__": "device", "device_str": str(obj)}
|
||||||
|
|
||||||
|
def deserialize_device(data: Dict[str, Any]) -> Any:
|
||||||
|
return torch.device(data["device_str"])
|
||||||
|
|
||||||
|
registry.register("device", serialize_device, deserialize_device)
|
||||||
|
|
||||||
|
_VALID_DTYPES = {
|
||||||
|
"float16", "float32", "float64", "bfloat16",
|
||||||
|
"int8", "int16", "int32", "int64",
|
||||||
|
"uint8", "bool",
|
||||||
|
}
|
||||||
|
|
||||||
|
def serialize_dtype(obj: Any) -> Dict[str, Any]:
|
||||||
|
return {"__type__": "dtype", "dtype_str": str(obj)}
|
||||||
|
|
||||||
|
def deserialize_dtype(data: Dict[str, Any]) -> Any:
|
||||||
|
dtype_name = data["dtype_str"].replace("torch.", "")
|
||||||
|
if dtype_name not in _VALID_DTYPES:
|
||||||
|
raise ValueError(f"Invalid dtype: {data['dtype_str']}")
|
||||||
|
return getattr(torch, dtype_name)
|
||||||
|
|
||||||
|
registry.register("dtype", serialize_dtype, deserialize_dtype)
|
||||||
|
|
||||||
|
from comfy_api.latest._io import FolderType
|
||||||
|
from comfy_api.latest._ui import SavedImages, SavedResult
|
||||||
|
|
||||||
|
def serialize_saved_result(obj: Any) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"__type__": "SavedResult",
|
||||||
|
"filename": obj.filename,
|
||||||
|
"subfolder": obj.subfolder,
|
||||||
|
"folder_type": obj.type.value,
|
||||||
|
}
|
||||||
|
|
||||||
|
def deserialize_saved_result(data: Dict[str, Any]) -> Any:
|
||||||
|
if isinstance(data, SavedResult):
|
||||||
|
return data
|
||||||
|
folder_type = data["folder_type"] if "folder_type" in data else data["type"]
|
||||||
|
return SavedResult(
|
||||||
|
filename=data["filename"],
|
||||||
|
subfolder=data["subfolder"],
|
||||||
|
type=FolderType(folder_type),
|
||||||
|
)
|
||||||
|
|
||||||
|
registry.register(
|
||||||
|
"SavedResult",
|
||||||
|
serialize_saved_result,
|
||||||
|
deserialize_saved_result,
|
||||||
|
data_type=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def serialize_saved_images(obj: Any) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"__type__": "SavedImages",
|
||||||
|
"results": [serialize_saved_result(result) for result in obj.results],
|
||||||
|
"is_animated": obj.is_animated,
|
||||||
|
}
|
||||||
|
|
||||||
|
def deserialize_saved_images(data: Dict[str, Any]) -> Any:
|
||||||
|
return SavedImages(
|
||||||
|
results=[deserialize_saved_result(result) for result in data["results"]],
|
||||||
|
is_animated=data.get("is_animated", False),
|
||||||
|
)
|
||||||
|
|
||||||
|
registry.register(
|
||||||
|
"SavedImages",
|
||||||
|
serialize_saved_images,
|
||||||
|
deserialize_saved_images,
|
||||||
|
data_type=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def serialize_model_patcher(obj: Any) -> Dict[str, Any]:
|
||||||
|
# Child-side: must already have _instance_id (proxy)
|
||||||
|
if os.environ.get("PYISOLATE_CHILD") == "1":
|
||||||
|
if hasattr(obj, "_instance_id"):
|
||||||
|
return {"__type__": "ModelPatcherRef", "model_id": obj._instance_id}
|
||||||
|
raise RuntimeError(
|
||||||
|
f"ModelPatcher in child lacks _instance_id: "
|
||||||
|
f"{type(obj).__module__}.{type(obj).__name__}"
|
||||||
|
)
|
||||||
|
# Host-side: register with registry
|
||||||
|
if hasattr(obj, "_instance_id"):
|
||||||
|
return {"__type__": "ModelPatcherRef", "model_id": obj._instance_id}
|
||||||
|
model_id = ModelPatcherRegistry().register(obj)
|
||||||
|
return {"__type__": "ModelPatcherRef", "model_id": model_id}
|
||||||
|
|
||||||
|
def deserialize_model_patcher(data: Any) -> Any:
|
||||||
|
"""Deserialize ModelPatcher refs; pass through already-materialized objects."""
|
||||||
|
if isinstance(data, dict):
|
||||||
|
return ModelPatcherProxy(
|
||||||
|
data["model_id"], registry=None, manage_lifecycle=False
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def deserialize_model_patcher_ref(data: Dict[str, Any]) -> Any:
    """Context-aware ModelPatcherRef deserializer for both host and child."""
    if os.environ.get("PYISOLATE_CHILD") == "1":
        # Child: wrap the id in a proxy that RPCs back to the host.
        return ModelPatcherProxy(
            data["model_id"], registry=None, manage_lifecycle=False
        )
    # Host: resolve the id back to the live ModelPatcher.
    return ModelPatcherRegistry()._get_instance(data["model_id"])
|
||||||
|
|
||||||
|
# ModelPatcher crosses by reference, never by value: serializers emit a
# ModelPatcherRef id; the ref deserializer resolves it per-side.
# Register ModelPatcher type for serialization
registry.register(
    "ModelPatcher", serialize_model_patcher, deserialize_model_patcher
)
# Register ModelPatcherProxy type (already a proxy, just return ref)
registry.register(
    "ModelPatcherProxy", serialize_model_patcher, deserialize_model_patcher
)
# Register ModelPatcherRef for deserialization (context-aware: host or child)
registry.register("ModelPatcherRef", None, deserialize_model_patcher_ref)
|
||||||
|
|
||||||
|
def serialize_clip(obj: Any) -> Dict[str, Any]:
    """Serialize a CLIP (or proxy) as a CLIPRef registry reference."""
    if not hasattr(obj, "_instance_id"):
        # First crossing: hand the live object to the host-side registry.
        return {"__type__": "CLIPRef", "clip_id": CLIPRegistry().register(obj)}
    return {"__type__": "CLIPRef", "clip_id": obj._instance_id}
|
||||||
|
|
||||||
|
def deserialize_clip(data: Any) -> Any:
    """Turn a CLIPRef dict into a proxy; pass anything else through unchanged."""
    if not isinstance(data, dict):
        return data
    return CLIPProxy(data["clip_id"], registry=None, manage_lifecycle=False)
|
||||||
|
|
||||||
|
def deserialize_clip_ref(data: Dict[str, Any]) -> Any:
    """Context-aware CLIPRef deserializer for both host and child."""
    if os.environ.get("PYISOLATE_CHILD") == "1":
        # Child: wrap the id in a proxy that RPCs back to the host.
        return CLIPProxy(data["clip_id"], registry=None, manage_lifecycle=False)
    # Host: resolve the id to the real CLIP object.
    return CLIPRegistry()._get_instance(data["clip_id"])
|
||||||
|
|
||||||
|
# CLIP crosses by reference, mirroring the ModelPatcher scheme above.
# Register CLIP type for serialization
registry.register("CLIP", serialize_clip, deserialize_clip)
# Register CLIPProxy type (already a proxy, just return ref)
registry.register("CLIPProxy", serialize_clip, deserialize_clip)
# Register CLIPRef for deserialization (context-aware: host or child)
registry.register("CLIPRef", None, deserialize_clip_ref)
|
||||||
|
|
||||||
|
def serialize_vae(obj: Any) -> Dict[str, Any]:
    """Serialize a VAE (or proxy) as a VAERef registry reference."""
    if not hasattr(obj, "_instance_id"):
        # First crossing: hand the live object to the host-side registry.
        return {"__type__": "VAERef", "vae_id": VAERegistry().register(obj)}
    return {"__type__": "VAERef", "vae_id": obj._instance_id}
|
||||||
|
|
||||||
|
def deserialize_vae(data: Any) -> Any:
    """Turn a VAERef dict into a proxy; pass anything else through unchanged."""
    if not isinstance(data, dict):
        return data
    return VAEProxy(data["vae_id"])
|
||||||
|
|
||||||
|
def deserialize_vae_ref(data: Dict[str, Any]) -> Any:
    """Context-aware VAERef deserializer for both host and child."""
    if os.environ.get("PYISOLATE_CHILD") == "1":
        # Child: create a proxy that RPCs back to the host.
        return VAEProxy(data["vae_id"])
    # Host: lookup real VAE from registry
    return VAERegistry()._get_instance(data["vae_id"])
|
||||||
|
|
||||||
|
# VAE crosses by reference, mirroring the CLIP/ModelPatcher scheme above.
# Register VAE type for serialization
registry.register("VAE", serialize_vae, deserialize_vae)
# Register VAEProxy type (already a proxy, just return ref)
registry.register("VAEProxy", serialize_vae, deserialize_vae)
# Register VAERef for deserialization (context-aware: host or child)
registry.register("VAERef", None, deserialize_vae_ref)
|
||||||
|
|
||||||
|
# ModelSampling serialization - handles ModelSampling* types
|
||||||
|
# copyreg removed - no pickle fallback allowed
|
||||||
|
|
||||||
|
def serialize_model_sampling(obj: Any) -> Dict[str, Any]:
    """Serialize a ModelSampling object as a ref or, child-side, inline data.

    Three paths:
      * object already carries ``_instance_id`` (a proxy) -> ModelSamplingRef;
      * child process, locally-created object -> ModelSamplingInline with
        base-class names, base64 state_dict tensors, and plain attrs;
      * host process -> register in ModelSamplingRegistry, return a ref.
    """
    # Proxy with _instance_id — return ref (works from both host and child)
    if hasattr(obj, "_instance_id"):
        return {"__type__": "ModelSamplingRef", "ms_id": obj._instance_id}
    # Child-side: object created locally in child (e.g. ModelSamplingAdvanced
    # in nodes_z_image_turbo.py). Serialize as inline data so the host can
    # reconstruct the real torch.nn.Module.
    if os.environ.get("PYISOLATE_CHILD") == "1":
        import base64
        import io as _io

        # Identify base classes from comfy.model_sampling
        bases = []
        for base in type(obj).__mro__:
            if base.__module__ == "comfy.model_sampling" and base.__name__ != "object":
                bases.append(base.__name__)
        # Serialize state_dict as base64 safetensors-like
        sd = obj.state_dict()
        sd_serialized = {}
        for k, v in sd.items():
            # torch.save each tensor individually, then base64 it so the
            # payload stays JSON-safe.
            buf = _io.BytesIO()
            torch.save(v, buf)
            sd_serialized[k] = base64.b64encode(buf.getvalue()).decode("ascii")
        # Capture plain attrs (shift, multiplier, sigma_data, etc.)
        plain_attrs = {}
        for k, v in obj.__dict__.items():
            if k.startswith("_"):
                continue
            if isinstance(v, (bool, int, float, str)):
                plain_attrs[k] = v
        return {
            "__type__": "ModelSamplingInline",
            "bases": bases,
            "state_dict": sd_serialized,
            "attrs": plain_attrs,
        }
    # Host-side: register with ModelSamplingRegistry and return JSON-safe dict
    ms_id = ModelSamplingRegistry().register(obj)
    return {"__type__": "ModelSamplingRef", "ms_id": ms_id}
|
||||||
|
|
||||||
|
def deserialize_model_sampling(data: Any) -> Any:
    """Deserialize ModelSampling refs or inline data."""
    if not isinstance(data, dict):
        # Already materialized — pass through.
        return data
    if data.get("__type__") == "ModelSamplingInline":
        # Inline payload from a child: rebuild the real nn.Module.
        return _reconstruct_model_sampling_inline(data)
    return ModelSamplingProxy(data["ms_id"])
|
||||||
|
|
||||||
|
def _reconstruct_model_sampling_inline(data: Dict[str, Any]) -> Any:
    """Reconstruct a ModelSampling object on the host from inline child data.

    Builds a dynamic class from the child's base-class names (resolved in
    comfy.model_sampling), restores plain attributes, then restores the
    base64-encoded state_dict tensors as buffers.
    """
    import comfy.model_sampling as _ms
    import base64
    import io as _io

    # Resolve base classes
    base_classes = []
    for name in data["bases"]:
        cls = getattr(_ms, name, None)
        if cls is not None:
            base_classes.append(cls)
    if not base_classes:
        raise RuntimeError(
            f"Cannot reconstruct ModelSampling: no known bases in {data['bases']}"
        )
    # Create dynamic class matching the child's class hierarchy
    ReconstructedSampling = type("ReconstructedSampling", tuple(base_classes), {})
    # __new__ + explicit nn.Module.__init__ skips the bases' __init__ chain,
    # which may require a model config we do not have here.
    obj = ReconstructedSampling.__new__(ReconstructedSampling)
    torch.nn.Module.__init__(obj)
    # Restore plain attributes first
    for k, v in data.get("attrs", {}).items():
        setattr(obj, k, v)
    # Restore state_dict (buffers like sigmas)
    for k, v_b64 in data.get("state_dict", {}).items():
        buf = _io.BytesIO(base64.b64decode(v_b64))
        tensor = torch.load(buf, weights_only=True)
        # Register as buffer so it's part of state_dict
        parts = k.split(".")
        if len(parts) == 1:
            cast(Any, obj).register_buffer(parts[0], tensor)  # pylint: disable=no-member
        else:
            # NOTE(review): for dotted keys this binds the whole tensor to the
            # first path segment, discarding the nesting — looks lossy for
            # nested submodules; confirm child state_dicts are always flat.
            setattr(obj, parts[0], tensor)
    # Register on host so future references use proxy pattern.
    # Skip in child process — register() is async RPC and cannot be
    # called synchronously during deserialization.
    if os.environ.get("PYISOLATE_CHILD") != "1":
        ModelSamplingRegistry().register(obj)
    return obj
|
||||||
|
|
||||||
|
def deserialize_model_sampling_ref(data: Dict[str, Any]) -> Any:
    """Context-aware ModelSamplingRef deserializer for both host and child."""
    if os.environ.get("PYISOLATE_CHILD") == "1":
        # Child: wrap the id in a proxy that RPCs back to the host.
        return ModelSamplingProxy(data["ms_id"])
    # Host: resolve the id to the real object.
    return ModelSamplingRegistry()._get_instance(data["ms_id"])
|
||||||
|
|
||||||
|
# Register all ModelSampling* and StableCascadeSampling classes dynamically
import comfy.model_sampling

# Every nn.Module subclass in comfy.model_sampling whose name matches is
# registered under its bare class name with the shared serializer pair.
for ms_cls in vars(comfy.model_sampling).values():
    if not isinstance(ms_cls, type):
        continue
    if not issubclass(ms_cls, torch.nn.Module):
        continue
    if not (ms_cls.__name__.startswith("ModelSampling") or ms_cls.__name__ == "StableCascadeSampling"):
        continue
    registry.register(
        ms_cls.__name__,
        serialize_model_sampling,
        deserialize_model_sampling,
    )
registry.register(
    "ModelSamplingProxy", serialize_model_sampling, deserialize_model_sampling
)
# Register ModelSamplingRef for deserialization (context-aware: host or child)
registry.register("ModelSamplingRef", None, deserialize_model_sampling_ref)
# Register ModelSamplingInline for deserialization (child→host inline transfer)
registry.register(
    "ModelSamplingInline", None, lambda data: _reconstruct_model_sampling_inline(data)
)
|
||||||
|
|
||||||
|
def serialize_cond(obj: Any) -> Dict[str, Any]:
    """Serialize a COND* object as its fully-qualified class name plus payload."""
    cond_cls = type(obj)
    return {
        "__type__": f"{cond_cls.__module__}.{cond_cls.__name__}",
        "cond": obj.cond,
    }
|
||||||
|
|
||||||
|
def deserialize_cond(data: Dict[str, Any]) -> Any:
    """Re-instantiate a COND* object from its fully-qualified class name."""
    import importlib

    module_name, _, class_name = data["__type__"].rpartition(".")
    cond_cls = getattr(importlib.import_module(module_name), class_name)
    return cond_cls(data["cond"])
|
||||||
|
|
||||||
|
def _serialize_public_state(obj: Any) -> Dict[str, Any]:
|
||||||
|
state: Dict[str, Any] = {}
|
||||||
|
for key, value in obj.__dict__.items():
|
||||||
|
if key.startswith("_"):
|
||||||
|
continue
|
||||||
|
if callable(value):
|
||||||
|
continue
|
||||||
|
state[key] = value
|
||||||
|
return state
|
||||||
|
|
||||||
|
def serialize_latent_format(obj: Any) -> Dict[str, Any]:
    """Serialize a LatentFormat as its class name plus public attribute state."""
    fmt_cls = type(obj)
    return {
        "__type__": f"{fmt_cls.__module__}.{fmt_cls.__name__}",
        "state": _serialize_public_state(obj),
    }
|
||||||
|
|
||||||
|
def deserialize_latent_format(data: Dict[str, Any]) -> Any:
    """Re-instantiate a LatentFormat and restore its public attribute state.

    Read-only properties (no setter) on the class are skipped rather than
    letting setattr raise.
    """
    import importlib

    module_name, _, class_name = data["__type__"].rpartition(".")
    fmt_cls = getattr(importlib.import_module(module_name), class_name)
    obj = fmt_cls()
    for key, value in data.get("state", {}).items():
        descriptor = getattr(fmt_cls, key, None)
        if isinstance(descriptor, property) and descriptor.fset is None:
            continue  # read-only property — cannot restore
        setattr(obj, key, value)
    return obj
|
||||||
|
|
||||||
|
import comfy.conds

# Register every CONDRegular subclass under BOTH its fully-qualified and its
# bare class name, so either key form found in a payload resolves.
for cond_cls in vars(comfy.conds).values():
    if not isinstance(cond_cls, type):
        continue
    if not issubclass(cond_cls, comfy.conds.CONDRegular):
        continue
    type_key = f"{cond_cls.__module__}.{cond_cls.__name__}"
    registry.register(type_key, serialize_cond, deserialize_cond)
    registry.register(cond_cls.__name__, serialize_cond, deserialize_cond)


import comfy.latent_formats

# Same dual registration (qualified + bare name) for LatentFormat subclasses.
for latent_cls in vars(comfy.latent_formats).values():
    if not isinstance(latent_cls, type):
        continue
    if not issubclass(latent_cls, comfy.latent_formats.LatentFormat):
        continue
    type_key = f"{latent_cls.__module__}.{latent_cls.__name__}"
    registry.register(
        type_key, serialize_latent_format, deserialize_latent_format
    )
    registry.register(
        latent_cls.__name__, serialize_latent_format, deserialize_latent_format
    )
|
||||||
|
|
||||||
|
# V3 API: unwrap NodeOutput.args
def deserialize_node_output(data: Any) -> Any:
    """Unwrap a V3 NodeOutput into its raw args; pass other values through."""
    if hasattr(data, "args"):
        return data.args
    return data
|
||||||
|
|
||||||
|
# NodeOutput only needs a deserializer (unwrap); no serializer is registered.
registry.register("NodeOutput", None, deserialize_node_output)
|
||||||
|
|
||||||
|
# KSAMPLER serializer: stores sampler name instead of function object
# sampler_function is a callable which gets filtered out by JSONSocketTransport

# Function names that do not follow the plain "sample_<name>" convention.
_KSAMPLER_FUNC_NAME_OVERRIDES = {
    "sample_unipc": "uni_pc",
    "sample_unipc_bh2": "uni_pc_bh2",
    "dpm_fast_function": "dpm_fast",
    "dpm_adaptive_function": "dpm_adaptive",
}


def serialize_ksampler(obj: Any) -> Dict[str, Any]:
    """Serialize a KSAMPLER by sampler name (the function is not JSON-safe).

    Maps the sampler function's __name__ back to the public sampler name:
    known irregular names via the override table, "sample_<name>" by prefix
    stripping, anything else passed through verbatim.
    """
    func_name = obj.sampler_function.__name__
    sampler_name = _KSAMPLER_FUNC_NAME_OVERRIDES.get(func_name)
    if sampler_name is None:
        if func_name.startswith("sample_"):
            sampler_name = func_name[len("sample_"):]  # "sample_euler" -> "euler"
        else:
            sampler_name = func_name
    return {
        "__type__": "KSAMPLER",
        "sampler_name": sampler_name,
        "extra_options": obj.extra_options,
        "inpaint_options": obj.inpaint_options,
    }
|
||||||
|
|
||||||
|
def deserialize_ksampler(data: Dict[str, Any]) -> Any:
    """Rebuild a KSAMPLER from its serialized name and option dicts."""
    import comfy.samplers

    sampler_name = data["sampler_name"]
    extra_options = data.get("extra_options", {})
    inpaint_options = data.get("inpaint_options", {})
    return comfy.samplers.ksampler(sampler_name, extra_options, inpaint_options)
|
||||||
|
|
||||||
|
# KSAMPLER crosses by sampler name; rebuilt host/child-side via comfy.samplers.ksampler().
registry.register("KSAMPLER", serialize_ksampler, deserialize_ksampler)
|
||||||
|
|
||||||
|
from comfy.isolation.model_patcher_proxy_utils import register_hooks_serializers

# Hook-related serializers live in a separate module; register them here.
register_hooks_serializers(registry)
|
||||||
|
|
||||||
|
# Generic Numpy Serializer
def serialize_numpy(obj: Any) -> Any:
    """Convert an ndarray to a torch.Tensor (zero-copy); fall back to a list."""
    import torch

    try:
        tensor = torch.from_numpy(obj)
    except Exception:
        # Non-numeric dtypes (strings, objects, mixed) cannot become tensors.
        return obj.tolist()
    return tensor
|
||||||
|
|
||||||
|
def deserialize_numpy_b64(data: Any) -> Any:
    """Deserialize base64-encoded ndarray from sealed worker."""
    import base64

    import numpy as np

    if not (isinstance(data, dict) and "data" in data and "dtype" in data):
        # Not the expected payload shape — pass through untouched.
        return data
    raw_bytes = base64.b64decode(data["data"])
    array = np.frombuffer(raw_bytes, dtype=np.dtype(data["dtype"]))
    # copy(): frombuffer yields a read-only view over the decoded bytes.
    return torch.from_numpy(array.reshape(data["shape"]).copy())
|
||||||
|
|
||||||
|
# ndarray: converted to a tensor on serialize; base64 payloads accepted on deserialize.
registry.register("ndarray", serialize_numpy, deserialize_numpy_b64)
|
||||||
|
|
||||||
|
# -- File3D (comfy_api.latest._util.geometry_types) ---------------------
# Origin: comfy_api by ComfyOrg (Alexander Piskun), PR #12129

def serialize_file3d(obj: Any) -> Dict[str, Any]:
    """Serialize a File3D as base64-encoded bytes plus its format tag."""
    import base64

    encoded = base64.b64encode(obj.get_bytes()).decode("ascii")
    return {"__type__": "File3D", "format": obj.format, "data": encoded}
|
||||||
|
|
||||||
|
def deserialize_file3d(data: Any) -> Any:
    """Rebuild a File3D from its base64 payload and format tag."""
    import base64
    from io import BytesIO
    from comfy_api.latest._util.geometry_types import File3D

    raw = base64.b64decode(data["data"])
    return File3D(BytesIO(raw), file_format=data["format"])
|
||||||
|
|
||||||
|
# File3D crosses by value (data_type=True) as base64-encoded bytes.
registry.register("File3D", serialize_file3d, deserialize_file3d, data_type=True)
|
||||||
|
|
||||||
|
# -- VIDEO (comfy_api.latest._input_impl.video_types) -------------------
# Origin: ComfyAPI Core v0.0.2 by ComfyOrg (guill), PR #8962

def serialize_video(obj: Any) -> Dict[str, Any]:
    """Flatten a video object into frames, frame rate, optional audio/metadata."""
    components = obj.get_components()

    frames = components.images
    if frames.requires_grad:
        # Detach so the tensor can cross the boundary without autograd state.
        frames = frames.detach()

    payload: Dict[str, Any] = {
        "__type__": "VIDEO",
        "images": frames,
        "frame_rate_num": components.frame_rate.numerator,
        "frame_rate_den": components.frame_rate.denominator,
    }

    audio = components.audio
    if audio is not None:
        waveform = audio["waveform"]
        if waveform.requires_grad:
            waveform = waveform.detach()
        payload["audio_waveform"] = waveform
        payload["audio_sample_rate"] = audio["sample_rate"]

    if components.metadata is not None:
        payload["metadata"] = components.metadata

    return payload
|
||||||
|
|
||||||
|
def deserialize_video(data: Any) -> Any:
    """Rebuild a VideoFromComponents object from a serialized VIDEO dict."""
    from fractions import Fraction
    from comfy_api.latest._input_impl.video_types import VideoFromComponents
    from comfy_api.latest._util.video_types import VideoComponents

    if "audio_waveform" in data:
        audio = {
            "waveform": data["audio_waveform"],
            "sample_rate": data["audio_sample_rate"],
        }
    else:
        audio = None

    frame_rate = Fraction(data["frame_rate_num"], data["frame_rate_den"])
    return VideoFromComponents(
        VideoComponents(
            images=data["images"],
            frame_rate=frame_rate,
            audio=audio,
            metadata=data.get("metadata"),
        )
    )
|
||||||
|
|
||||||
|
# All video flavors share the component-based serializer pair; by value.
registry.register("VIDEO", serialize_video, deserialize_video, data_type=True)
registry.register("VideoFromFile", serialize_video, deserialize_video, data_type=True)
registry.register("VideoFromComponents", serialize_video, deserialize_video, data_type=True)
|
||||||
|
|
||||||
|
def setup_web_directory(self, module: Any) -> None:
    """Detect WEB_DIRECTORY on a module and populate/register it.

    Called by the sealed worker after loading the node module.
    Mirrors extension_wrapper.py:216-227 for host-coupled nodes.
    Does NOT import extension_wrapper.py (it has `import torch` at module level).
    """
    import shutil

    web_dir_attr = getattr(module, "WEB_DIRECTORY", None)
    if web_dir_attr is None:
        return

    # WEB_DIRECTORY is relative to the module's own directory.
    module_dir = os.path.dirname(os.path.abspath(module.__file__))
    web_dir_path = os.path.abspath(os.path.join(module_dir, web_dir_attr))

    # Read extension name from pyproject.toml
    ext_name = os.path.basename(module_dir)
    pyproject = os.path.join(module_dir, "pyproject.toml")
    if os.path.exists(pyproject):
        # tomllib is stdlib from 3.11; tomli is the pre-3.11 backport.
        try:
            import tomllib
        except ImportError:
            import tomli as tomllib  # type: ignore[no-redef]
        try:
            with open(pyproject, "rb") as f:
                data = tomllib.load(f)
            name = data.get("project", {}).get("name")
            if name:
                ext_name = name
        except Exception:
            # Best-effort: fall back to the directory name on any parse error.
            pass

    # Populate web dir if empty (mirrors _run_prestartup_web_copy)
    if not (os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path))):
        os.makedirs(web_dir_path, exist_ok=True)

        # Module-defined copy spec
        copy_spec = getattr(module, "_PRESTARTUP_WEB_COPY", None)
        if copy_spec is not None and callable(copy_spec):
            try:
                copy_spec(web_dir_path)
            except Exception as e:
                logger.warning("][ _PRESTARTUP_WEB_COPY failed: %s", e)

        # Fallback: comfy_3d_viewers
        try:
            from comfy_3d_viewers import copy_viewer, VIEWER_FILES
            for viewer in VIEWER_FILES:
                try:
                    copy_viewer(viewer, web_dir_path)
                except Exception:
                    pass
        except ImportError:
            pass

        # Fallback: comfy_dynamic_widgets
        try:
            from comfy_dynamic_widgets import get_js_path
            src = os.path.realpath(get_js_path())
            if os.path.exists(src):
                dst_dir = os.path.join(web_dir_path, "js")
                os.makedirs(dst_dir, exist_ok=True)
                shutil.copy2(src, os.path.join(dst_dir, "dynamic_widgets.js"))
        except ImportError:
            pass

    # Only register when the directory actually contains files.
    if os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path)):
        WebDirectoryProxy.register_web_dir(ext_name, web_dir_path)
        logger.info(
            "][ Adapter: registered web dir for %s (%d files)",
            ext_name,
            sum(1 for _ in Path(web_dir_path).rglob("*") if _.is_file()),
        )
|
||||||
|
|
||||||
|
@staticmethod
def register_host_event_handlers(extension: Any) -> None:
    """Register host-side event handlers for an isolated extension.

    Wires ``"progress"`` events from the child to ``comfy.utils.PROGRESS_BAR_HOOK``
    so the ComfyUI frontend receives progress bar updates.
    """
    # getattr_static avoids triggering any __getattr__-based RPC proxying
    # just to probe for the method's existence.
    register_event_handler = inspect.getattr_static(
        extension, "register_event_handler", None
    )
    if not callable(register_event_handler):
        return

    def _host_progress_handler(payload: dict) -> None:
        import comfy.utils

        hook = comfy.utils.PROGRESS_BAR_HOOK
        if hook is not None:
            # NOTE(review): the child-side payload (setup_child_event_hooks)
            # never includes "preview", so this forwards None — confirm
            # previews are intentionally dropped over the event channel.
            hook(
                payload.get("value", 0),
                payload.get("total", 0),
                payload.get("preview"),
                payload.get("node_id"),
            )

    extension.register_event_handler("progress", _host_progress_handler)
|
||||||
|
|
||||||
|
def setup_child_event_hooks(self, extension: Any) -> None:
    """Wire PROGRESS_BAR_HOOK in the child to emit_event on the extension.

    Host-coupled only — sealed workers do not have comfy.utils (torch).
    """
    is_child = os.environ.get("PYISOLATE_CHILD") == "1"
    logger.info("][ ISO:setup_child_event_hooks called, PYISOLATE_CHILD=%s", is_child)
    if not is_child:
        return

    if not _IMPORT_TORCH:
        logger.info("][ ISO:setup_child_event_hooks skipped — sealed worker (no torch)")
        return

    import comfy.utils

    def _event_progress_hook(value, total, preview=None, node_id=None):
        # Forward progress over the extension's event channel to the host.
        # NOTE(review): `preview` is accepted but not included in the payload;
        # the host handler reads payload.get("preview") — confirm this is
        # intentional (previews may be too large for the event channel).
        logger.debug("][ ISO:event_progress value=%s/%s node_id=%s", value, total, node_id)
        extension.emit_event("progress", {
            "value": value,
            "total": total,
            "node_id": node_id,
        })

    # Global hook: all ProgressBar updates in this child now emit events.
    comfy.utils.PROGRESS_BAR_HOOK = _event_progress_hook
    logger.info("][ ISO:PROGRESS_BAR_HOOK wired to event channel")
|
||||||
|
|
||||||
|
def provide_rpc_services(self) -> List[type[ProxiedSingleton]]:
    """Return the RPC service singletons this adapter exposes to children."""
    # Torch/PIL-free services are always exported.
    services: List[type[ProxiedSingleton]] = [
        FolderPathsProxy,
        HelperProxiesService,
        WebDirectoryProxy,
    ]
    if _HAS_TORCH_PROXIES:
        # Heavy proxies are only exported when torch/PIL imported cleanly.
        services += [
            PromptServerService,
            ModelManagementProxy,
            UtilsProxy,
            ProgressProxy,
            VAERegistry,
            CLIPRegistry,
            ModelPatcherRegistry,
            ModelSamplingRegistry,
            FirstStageModelRegistry,
        ]
    return services
|
||||||
|
|
||||||
|
def handle_api_registration(self, api: ProxiedSingleton, rpc: AsyncRPC) -> None:
    """Patch child-side modules so project APIs transparently route over RPC.

    Dispatches on the singleton's class name and monkey-patches the matching
    module (folder_paths, comfy.model_management, comfy.utils, server) so
    unmodified node code in the child reaches the host through *api*/*rpc*.
    """
    # Resolve the real name whether it's an instance or the Singleton class itself
    api_name = api.__name__ if isinstance(api, type) else api.__class__.__name__

    if api_name == "FolderPathsProxy":
        import folder_paths

        # Replace module-level functions with proxy methods
        # This is aggressive but necessary for transparent proxying
        # Handle both instance and class cases
        instance = api() if isinstance(api, type) else api
        for name in dir(instance):
            if not name.startswith("_"):
                setattr(folder_paths, name, getattr(instance, name))

        # Fence: isolated children get writable temp inside sandbox
        if os.environ.get("PYISOLATE_CHILD") == "1":
            import tempfile
            _child_temp = os.path.join(tempfile.gettempdir(), "comfyui_temp")
            os.makedirs(_child_temp, exist_ok=True)
            folder_paths.temp_directory = _child_temp

        return

    if api_name == "ModelManagementProxy":
        if _IMPORT_TORCH:
            import comfy.model_management

            instance = api() if isinstance(api, type) else api
            # Replace module-level functions with proxy methods
            for name in dir(instance):
                if not name.startswith("_"):
                    setattr(comfy.model_management, name, getattr(instance, name))
        return

    if api_name == "UtilsProxy":
        if not _IMPORT_TORCH:
            logger.info("][ ISO:UtilsProxy handle_api_registration skipped — sealed worker (no torch)")
            return

        import comfy.utils

        # Static Injection of RPC mechanism to ensure Child can access it
        # independent of instance lifecycle.
        api.set_rpc(rpc)

        is_child = os.environ.get("PYISOLATE_CHILD") == "1"
        logger.info("][ ISO:UtilsProxy handle_api_registration PYISOLATE_CHILD=%s", is_child)

        # Progress hook wiring moved to setup_child_event_hooks via event channel

        return

    if api_name == "PromptServerProxy":
        if not _IMPORT_TORCH:
            return
        # Defer heavy import to child context
        import server

        instance = api() if isinstance(api, type) else api
        proxy = (
            instance.instance
        )  # PromptServerProxy instance has .instance property returning self

        original_register_route = proxy.register_route

        # Route registration must hop over RPC: the child-side handler is
        # registered as a callback id the host can invoke later.
        def register_route_wrapper(
            method: str, path: str, handler: Callable[..., Any]
        ) -> None:
            callback_id = rpc.register_callback(handler)
            loop = getattr(rpc, "loop", None)
            if loop and loop.is_running():
                import asyncio

                # Inside a running loop we cannot block on the async call;
                # schedule it instead.
                asyncio.create_task(
                    original_register_route(
                        method, path, handler=callback_id, is_callback=True
                    )
                )
            else:
                original_register_route(
                    method, path, handler=callback_id, is_callback=True
                )
            return None

        proxy.register_route = register_route_wrapper

        # Minimal stand-in for aiohttp's RouteTableDef decorator API so
        # `@routes.get(...)`-style code in isolated nodes keeps working.
        class RouteTableDefProxy:
            def __init__(self, proxy_instance: Any):
                self.proxy = proxy_instance

            def get(
                self, path: str, **kwargs: Any
            ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
                def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
                    self.proxy.register_route("GET", path, handler)
                    return handler

                return decorator

            def post(
                self, path: str, **kwargs: Any
            ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
                def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
                    self.proxy.register_route("POST", path, handler)
                    return handler

                return decorator

            def patch(
                self, path: str, **kwargs: Any
            ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
                def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
                    self.proxy.register_route("PATCH", path, handler)
                    return handler

                return decorator

            def put(
                self, path: str, **kwargs: Any
            ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
                def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
                    self.proxy.register_route("PUT", path, handler)
                    return handler

                return decorator

            def delete(
                self, path: str, **kwargs: Any
            ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
                def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
                    self.proxy.register_route("DELETE", path, handler)
                    return handler

                return decorator

        proxy.routes = RouteTableDefProxy(proxy)

        # Make the stub discoverable as PromptServer.instance in the child.
        if (
            hasattr(server, "PromptServer")
            and getattr(server.PromptServer, "instance", None) != proxy
        ):
            server.PromptServer.instance = proxy
|
||||||
126
comfy/isolation/child_hooks.py
Normal file
126
comfy/isolation/child_hooks.py
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation
|
||||||
|
# Child process initialization for PyIsolate
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def is_child_process() -> bool:
    """True when running inside a pyisolate child (PYISOLATE_CHILD=1)."""
    return "1" == os.environ.get("PYISOLATE_CHILD")
|
||||||
|
|
||||||
|
|
||||||
|
def _load_extra_model_paths() -> None:
    """Load extra_model_paths.yaml so the child's folder_paths has the same search paths as the host.

    The host loads this in main.py:143-145. The child is spawned by
    pyisolate's uds_client.py and never runs main.py, so folder_paths
    only has the base model directories. Any isolated node calling
    folder_paths.get_filename_list() in define_schema() would get empty
    results for folders whose files live in extra_model_paths locations.
    """
    import folder_paths  # noqa: F401 — side-effect import; load_extra_path_config writes to folder_paths internals
    from utils.extra_config import load_extra_path_config

    # Three levels up from comfy/isolation/child_hooks.py is the repo root.
    repo_root = os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    )
    extra_config_path = os.path.join(repo_root, "extra_model_paths.yaml")
    if os.path.isfile(extra_config_path):
        load_extra_path_config(extra_config_path)
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_child_process() -> None:
    """One-shot child-process bootstrap: model paths, loop bridge, RPC wiring, logging."""
    logger.warning("][ DIAG:child_hooks initialize_child_process START")
    # Sealed workers (PYISOLATE_IMPORT_TORCH=0) skip host path loading.
    if os.environ.get("PYISOLATE_IMPORT_TORCH", "1") != "0":
        _load_extra_model_paths()
    _setup_child_loop_bridge()

    # Manual RPC injection
    try:
        from pyisolate._internal.rpc_protocol import get_child_rpc_instance

        rpc = get_child_rpc_instance()
        logger.warning("][ DIAG:child_hooks RPC instance: %s", rpc is not None)
        if rpc:
            _setup_proxy_callers(rpc)
            logger.warning("][ DIAG:child_hooks proxy callers configured with RPC")
        else:
            # No RPC yet: clear all proxies so stale handles are not used.
            logger.warning("][ DIAG:child_hooks NO RPC — proxy callers cleared")
            _setup_proxy_callers()
    except Exception as e:
        logger.error(f"][ DIAG:child_hooks Manual RPC Injection failed: {e}")
        _setup_proxy_callers()

    _setup_logging()
|
||||||
|
|
||||||
|
|
||||||
|
def _setup_child_loop_bridge() -> None:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
main_loop = None
|
||||||
|
try:
|
||||||
|
main_loop = asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
try:
|
||||||
|
main_loop = asyncio.get_event_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if main_loop is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .proxies.base import set_global_loop
|
||||||
|
|
||||||
|
set_global_loop(main_loop)
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _setup_prompt_server_stub(rpc=None) -> None:
    """Point PromptServerStub at *rpc*, or detach it when *rpc* is falsy."""
    try:
        from .proxies.prompt_server_impl import PromptServerStub

        if rpc:
            PromptServerStub.set_rpc(rpc)
        else:
            # Older stubs may predate clear_rpc(); null the slot directly then.
            clear = getattr(PromptServerStub, "clear_rpc", None)
            if clear is not None:
                clear()
            else:
                PromptServerStub._rpc = None  # type: ignore[attr-defined]
    except Exception as e:
        logger.error(f"Failed to setup PromptServerStub: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def _setup_proxy_callers(rpc=None) -> None:
    """Attach *rpc* to every child-side singleton proxy, or detach them all.

    With ``rpc=None`` each proxy's channel is cleared so later calls fail
    fast instead of using a dead connection.
    """
    try:
        from .proxies.folder_paths_proxy import FolderPathsProxy
        from .proxies.helper_proxies import HelperProxiesService
        from .proxies.model_management_proxy import ModelManagementProxy
        from .proxies.progress_proxy import ProgressProxy
        from .proxies.prompt_server_impl import PromptServerStub
        from .proxies.utils_proxy import UtilsProxy

        # Same order as the original per-class calls.
        proxy_classes = (
            FolderPathsProxy,
            HelperProxiesService,
            ModelManagementProxy,
            ProgressProxy,
            PromptServerStub,
            UtilsProxy,
        )
        for proxy_cls in proxy_classes:
            if rpc is None:
                proxy_cls.clear_rpc()
            else:
                proxy_cls.set_rpc(rpc)
    except Exception as e:
        logger.error(f"Failed to setup child singleton proxy callers: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def _setup_logging() -> None:
|
||||||
|
logging.getLogger().setLevel(logging.INFO)
|
||||||
521
comfy/isolation/extension_loader.py
Normal file
521
comfy/isolation/extension_loader.py
Normal file
@ -0,0 +1,521 @@
|
|||||||
|
# pylint: disable=cyclic-import,import-outside-toplevel,redefined-outer-name
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import inspect
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
import platform
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Callable, Dict, List, Tuple
|
||||||
|
|
||||||
|
import pyisolate
|
||||||
|
from pyisolate import ExtensionManager, ExtensionManagerConfig
|
||||||
|
from packaging.requirements import InvalidRequirement, Requirement
|
||||||
|
from packaging.utils import canonicalize_name
|
||||||
|
|
||||||
|
from .manifest_loader import is_cache_valid, load_from_cache, save_to_cache
|
||||||
|
from .host_policy import load_host_policy
|
||||||
|
|
||||||
|
try:
|
||||||
|
import tomllib
|
||||||
|
except ImportError:
|
||||||
|
import tomli as tomllib # type: ignore[no-redef]
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _register_web_directory(extension_name: str, node_dir: Path) -> None:
    """Register an isolated extension's web directory on the host side.

    Records the directory in ``nodes.EXTENSION_WEB_DIRS`` so the host can
    serve the pack's frontend assets without importing the pack itself.
    Two discovery strategies, first hit wins; any parse failure is silently
    ignored and the next strategy is tried.
    """
    import nodes

    # Method 1: pyproject.toml [tool.comfy] web field
    pyproject = node_dir / "pyproject.toml"
    if pyproject.exists():
        try:
            with pyproject.open("rb") as f:
                data = tomllib.load(f)
            web_dir_name = data.get("tool", {}).get("comfy", {}).get("web")
            if web_dir_name:
                web_dir_path = str(node_dir / web_dir_name)
                if os.path.isdir(web_dir_path):
                    nodes.EXTENSION_WEB_DIRS[extension_name] = web_dir_path
                    logger.debug(
                        "][ Registered web dir for isolated %s: %s",
                        extension_name,
                        web_dir_path,
                    )
                    return
        except Exception:
            # Best-effort: a malformed pyproject must not break loading.
            pass

    # Method 2: __init__.py WEB_DIRECTORY constant (parse without importing)
    init_file = node_dir / "__init__.py"
    if init_file.exists():
        try:
            source = init_file.read_text()
            for line in source.splitlines():
                stripped = line.strip()
                if stripped.startswith("WEB_DIRECTORY"):
                    # Parse: WEB_DIRECTORY = "./web" or WEB_DIRECTORY = "web"
                    _, _, value = stripped.partition("=")
                    value = value.strip().strip("\"'")
                    if value:
                        # NOTE(review): this path is resolve()d while Method 1's
                        # is not — presumably intentional; confirm.
                        web_dir_path = str((node_dir / value).resolve())
                        if os.path.isdir(web_dir_path):
                            nodes.EXTENSION_WEB_DIRS[extension_name] = web_dir_path
                            logger.debug(
                                "][ Registered web dir for isolated %s: %s",
                                extension_name,
                                web_dir_path,
                            )
                            return
        except Exception:
            pass
|
||||||
|
|
||||||
|
|
||||||
|
def _get_extension_type(execution_model: str) -> type[Any]:
    """Map an execution model name to the pyisolate extension class to spawn."""
    if execution_model != "sealed_worker":
        # Host-coupled nodes get the full Comfy wrapper with RPC proxies.
        # Imported lazily to avoid pulling torch into sealed-only paths.
        from .extension_wrapper import ComfyNodeExtension

        return ComfyNodeExtension
    return pyisolate.SealedNodeExtension
|
||||||
|
|
||||||
|
|
||||||
|
async def _stop_extension_safe(extension: Any, extension_name: str) -> None:
|
||||||
|
try:
|
||||||
|
stop_result = extension.stop()
|
||||||
|
if inspect.isawaitable(stop_result):
|
||||||
|
await stop_result
|
||||||
|
except Exception:
|
||||||
|
logger.debug("][ %s stop failed", extension_name, exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_dependency_spec(dep: str, base_paths: list[Path]) -> str:
|
||||||
|
req, sep, marker = dep.partition(";")
|
||||||
|
req = req.strip()
|
||||||
|
marker_suffix = f";{marker}" if sep else ""
|
||||||
|
|
||||||
|
def _resolve_local_path(local_path: str) -> Path | None:
|
||||||
|
for base in base_paths:
|
||||||
|
candidate = (base / local_path).resolve()
|
||||||
|
if candidate.exists():
|
||||||
|
return candidate
|
||||||
|
return None
|
||||||
|
|
||||||
|
if req.startswith("./") or req.startswith("../"):
|
||||||
|
resolved = _resolve_local_path(req)
|
||||||
|
if resolved is not None:
|
||||||
|
return f"{resolved}{marker_suffix}"
|
||||||
|
|
||||||
|
if req.startswith("file://"):
|
||||||
|
raw = req[len("file://") :]
|
||||||
|
if raw.startswith("./") or raw.startswith("../"):
|
||||||
|
resolved = _resolve_local_path(raw)
|
||||||
|
if resolved is not None:
|
||||||
|
return f"file://{resolved}{marker_suffix}"
|
||||||
|
|
||||||
|
return dep
|
||||||
|
|
||||||
|
|
||||||
|
def _dependency_name_from_spec(dep: str) -> str | None:
    """Extract the canonical package name from a requirement string.

    Returns ``None`` for editable flags (``-e``), bare/relative paths,
    ``file://`` URLs, empty specs, and anything ``packaging`` cannot parse.
    """
    spec = dep.strip()
    is_editable_flag = spec == "-e" or spec.startswith("-e ")
    if not spec or is_editable_flag:
        return None
    if spec.startswith(("/", "./", "../", "file://")):
        return None

    try:
        parsed = Requirement(spec)
    except InvalidRequirement:
        return None
    return canonicalize_name(parsed.name)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_cuda_wheels_config(
    tool_config: dict[str, object], dependencies: list[str]
) -> dict[str, object] | None:
    """Validate and normalize ``[tool.comfy.isolation.cuda_wheels]``.

    Returns ``None`` when the table is absent, otherwise a dict with
    canonicalized ``packages``, a canonical-key ``package_map``, and either
    ``index_url`` or ``index_urls`` (each normalized to end with ``/``).
    Raises :class:`ExtensionLoadError` for any malformed field or for a
    package not declared in the extension's ``[project].dependencies``.
    """
    raw_config = tool_config.get("cuda_wheels")
    if raw_config is None:
        return None
    if not isinstance(raw_config, dict):
        raise ExtensionLoadError("[tool.comfy.isolation.cuda_wheels] must be a table")

    # Either a list of index URLs or a single index URL must be present.
    index_url = raw_config.get("index_url")
    index_urls = raw_config.get("index_urls")
    if index_urls is not None:
        if not isinstance(index_urls, list) or not all(
            isinstance(u, str) and u.strip() for u in index_urls
        ):
            raise ExtensionLoadError(
                "[tool.comfy.isolation.cuda_wheels.index_urls] must be a list of non-empty strings"
            )
    elif not isinstance(index_url, str) or not index_url.strip():
        raise ExtensionLoadError(
            "[tool.comfy.isolation.cuda_wheels.index_url] must be a non-empty string"
        )

    packages = raw_config.get("packages")
    if not isinstance(packages, list) or not all(
        isinstance(package_name, str) and package_name.strip()
        for package_name in packages
    ):
        raise ExtensionLoadError(
            "[tool.comfy.isolation.cuda_wheels.packages] must be a list of non-empty strings"
        )

    # Every CUDA-wheel package must also appear in [project].dependencies
    # (compared by canonical PEP 503 name).
    declared_dependencies = {
        dependency_name
        for dep in dependencies
        if (dependency_name := _dependency_name_from_spec(dep)) is not None
    }
    normalized_packages = [canonicalize_name(package_name) for package_name in packages]
    missing = [
        package_name
        for package_name in normalized_packages
        if package_name not in declared_dependencies
    ]
    if missing:
        missing_joined = ", ".join(sorted(missing))
        raise ExtensionLoadError(
            "[tool.comfy.isolation.cuda_wheels.packages] references undeclared dependencies: "
            f"{missing_joined}"
        )

    # Optional rename table: dependency name -> name on the CUDA index.
    package_map = raw_config.get("package_map", {})
    if not isinstance(package_map, dict):
        raise ExtensionLoadError(
            "[tool.comfy.isolation.cuda_wheels.package_map] must be a table"
        )

    normalized_package_map: dict[str, str] = {}
    for dependency_name, index_package_name in package_map.items():
        if not isinstance(dependency_name, str) or not dependency_name.strip():
            raise ExtensionLoadError(
                "[tool.comfy.isolation.cuda_wheels.package_map] keys must be non-empty strings"
            )
        if not isinstance(index_package_name, str) or not index_package_name.strip():
            raise ExtensionLoadError(
                "[tool.comfy.isolation.cuda_wheels.package_map] values must be non-empty strings"
            )
        canonical_dependency_name = canonicalize_name(dependency_name)
        # Only packages already opted into cuda_wheels may be remapped.
        if canonical_dependency_name not in normalized_packages:
            raise ExtensionLoadError(
                "[tool.comfy.isolation.cuda_wheels.package_map] can only override packages listed in "
                "[tool.comfy.isolation.cuda_wheels.packages]"
            )
        normalized_package_map[canonical_dependency_name] = index_package_name.strip()

    result: dict = {
        "packages": normalized_packages,
        "package_map": normalized_package_map,
    }
    # Normalize index URLs to a trailing slash so later joins are uniform.
    if index_urls is not None:
        result["index_urls"] = [u.rstrip("/") + "/" for u in index_urls]
    else:
        result["index_url"] = index_url.rstrip("/") + "/"
    return result
|
||||||
|
|
||||||
|
|
||||||
|
def get_enforcement_policy() -> Dict[str, bool]:
    """Read isolation-enforcement toggles from the environment.

    ``PYISOLATE_ENFORCE_ISOLATED=1`` forces isolation for manifest-bearing
    nodes; ``PYISOLATE_ENFORCE_SANDBOX=1`` forces sandboxing.
    """
    env = os.environ
    return {
        "force_isolated": env.get("PYISOLATE_ENFORCE_ISOLATED") == "1",
        "force_sandbox": env.get("PYISOLATE_ENFORCE_SANDBOX") == "1",
    }
|
||||||
|
|
||||||
|
|
||||||
|
class ExtensionLoadError(RuntimeError):
    """Raised when an isolated extension's manifest or configuration is invalid."""
|
||||||
|
|
||||||
|
|
||||||
|
def register_dummy_module(extension_name: str, node_dir: Path) -> None:
    """Install a placeholder module in ``sys.modules`` for an isolated pack.

    Host code that imports the pack by name gets this empty stand-in instead
    of executing the pack on the host.  Dashes/dots in the name are mapped
    to underscores; an existing entry is never overwritten.
    """
    module_name = extension_name.replace("-", "_").replace(".", "_")
    if module_name in sys.modules:
        return
    placeholder = types.ModuleType(module_name)
    placeholder.__file__ = str(node_dir / "__init__.py")
    placeholder.__path__ = [str(node_dir)]
    placeholder.__package__ = module_name
    sys.modules[module_name] = placeholder
|
||||||
|
|
||||||
|
|
||||||
|
def _is_stale_node_cache(cached_data: Dict[str, Dict]) -> bool:
|
||||||
|
for details in cached_data.values():
|
||||||
|
if not isinstance(details, dict):
|
||||||
|
return True
|
||||||
|
if details.get("is_v3") and "schema_v1" not in details:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def load_isolated_node(
    node_dir: Path,
    manifest_path: Path,
    logger: logging.Logger,
    build_stub_class: Callable[[str, Dict[str, object], Any], type],
    venv_root: Path,
    extension_managers: List[ExtensionManager],
) -> List[Tuple[str, str, type]]:
    """Load one custom-node pack as an isolated pyisolate extension.

    Parses the pack's pyproject manifest, decides whether to isolate, builds
    the pyisolate extension config (venv, sandbox, conda, CUDA wheels),
    spawns the child only when cached node metadata is missing or stale, and
    returns ``(node_name, display_name, stub_class)`` specs for registration.

    Returns ``[]`` when the pack opts out of isolation or metadata discovery
    fails.  The created manager is appended to *extension_managers* for
    later shutdown.  Raises ``ValueError`` for the forbidden
    ``sealed_host_ro_paths`` manifest field.
    """
    try:
        with manifest_path.open("rb") as handle:
            manifest_data = tomllib.load(handle)
    except Exception as e:
        logger.warning(f"][ Failed to parse {manifest_path}: {e}")
        return []

    # Parse [tool.comfy.isolation]
    tool_config = manifest_data.get("tool", {}).get("comfy", {}).get("isolation", {})
    can_isolate = tool_config.get("can_isolate", False)
    share_torch = tool_config.get("share_torch", False)
    package_manager = tool_config.get("package_manager", "uv")
    is_conda = package_manager == "conda"
    execution_model = tool_config.get("execution_model")
    if execution_model is None:
        # Conda packs default to fully sealed workers; uv packs stay coupled.
        execution_model = "sealed_worker" if is_conda else "host-coupled"

    # Security: RO import paths are host policy, never manifest-controlled.
    if "sealed_host_ro_paths" in tool_config:
        raise ValueError(
            "Manifest field 'sealed_host_ro_paths' is not allowed. "
            "Configure [tool.comfy.host].sealed_worker_ro_import_paths in host policy."
        )

    # Conda-specific manifest fields
    conda_channels: list[str] = (
        tool_config.get("conda_channels", []) if is_conda else []
    )
    conda_dependencies: list[str] = (
        tool_config.get("conda_dependencies", []) if is_conda else []
    )
    conda_platforms: list[str] = (
        tool_config.get("conda_platforms", []) if is_conda else []
    )
    conda_python: str = (
        tool_config.get("conda_python", "*") if is_conda else "*"
    )

    # Parse [project] dependencies
    project_config = manifest_data.get("project", {})
    dependencies = project_config.get("dependencies", [])
    if not isinstance(dependencies, list):
        dependencies = []

    # Get extension name (default to folder name if not in project.name)
    extension_name = project_config.get("name", node_dir.name)

    # LOGIC: Isolation Decision
    policy = get_enforcement_policy()
    isolated = can_isolate or policy["force_isolated"]

    if not isolated:
        return []

    import folder_paths

    # Relative dependency paths resolve against the Comfy base dir, then the
    # pack's own directory.
    base_paths = [Path(folder_paths.base_path), node_dir]
    dependencies = [
        _normalize_dependency_spec(dep, base_paths) if isinstance(dep, str) else dep
        for dep in dependencies
    ]
    cuda_wheels = _parse_cuda_wheels_config(tool_config, dependencies)

    manager_config = ExtensionManagerConfig(venv_root_path=str(venv_root))
    extension_type = _get_extension_type(execution_model)
    manager: ExtensionManager = pyisolate.ExtensionManager(
        extension_type, manager_config
    )
    extension_managers.append(manager)

    host_policy = load_host_policy(Path(folder_paths.base_path))

    sandbox_config = {}
    is_linux = platform.system() == "Linux"

    # Conda workers never share torch (separate ABI); CUDA IPC sharing is
    # Linux-only and implies share_torch.
    if is_conda:
        share_torch = False
        share_cuda_ipc = False
    else:
        share_cuda_ipc = share_torch and is_linux

    if is_linux and isolated:
        sandbox_config = {
            "network": host_policy["allow_network"],
            "writable_paths": host_policy["writable_paths"],
            "readonly_paths": host_policy["readonly_paths"],
        }

    extension_config: dict = {
        "name": extension_name,
        "module_path": str(node_dir),
        "isolated": True,
        "dependencies": dependencies,
        "share_torch": share_torch,
        "share_cuda_ipc": share_cuda_ipc,
        "sandbox_mode": host_policy["sandbox_mode"],
        "sandbox": sandbox_config,
    }

    _is_sealed = execution_model == "sealed_worker"
    _is_sandboxed = host_policy["sandbox_mode"] != "disabled" and is_linux
    logger.info(
        "][ Loading isolated node: %s (torch_share [%s], sealed [%s], sandboxed [%s])",
        extension_name,
        "x" if share_torch else " ",
        "x" if _is_sealed else " ",
        "x" if _is_sandboxed else " ",
    )

    if cuda_wheels is not None:
        extension_config["cuda_wheels"] = cuda_wheels

    # Conda-specific keys
    if is_conda:
        extension_config["package_manager"] = "conda"
        extension_config["conda_channels"] = conda_channels
        extension_config["conda_dependencies"] = conda_dependencies
        extension_config["conda_python"] = conda_python
        find_links = tool_config.get("find_links", [])
        if find_links:
            extension_config["find_links"] = find_links
        if conda_platforms:
            extension_config["conda_platforms"] = conda_platforms

    if execution_model != "host-coupled":
        extension_config["execution_model"] = execution_model
    if execution_model == "sealed_worker":
        policy_ro_paths = host_policy.get("sealed_worker_ro_import_paths", [])
        if isinstance(policy_ro_paths, list) and policy_ro_paths:
            extension_config["sealed_host_ro_paths"] = list(policy_ro_paths)
        # Sealed workers keep the host RPC service inventory even when the
        # child resolves no API classes locally.

    extension = manager.load_extension(extension_config)
    register_dummy_module(extension_name, node_dir)

    # Register host-side event handlers via adapter
    from .adapter import ComfyUIAdapter
    ComfyUIAdapter.register_host_event_handlers(extension)

    # Register web directory on the host — only when sandbox is disabled.
    # In sandbox mode, serving untrusted JS to the browser is not safe.
    if host_policy["sandbox_mode"] == "disabled":
        _register_web_directory(extension_name, node_dir)

    # Register for proxied web serving — the child's web dir may have
    # content that doesn't exist on the host (e.g., pip-installed viewer
    # bundles). The WebDirectoryCache will lazily fetch via RPC.
    from .proxies.web_directory_proxy import WebDirectoryProxy, get_web_directory_cache
    cache = get_web_directory_cache()
    cache.register_proxy(extension_name, WebDirectoryProxy())

    # Try cache first (lazy spawn)
    logger.warning("][ DIAG:ext_loader cache_valid_check for %s", extension_name)
    if is_cache_valid(node_dir, manifest_path, venv_root):
        cached_data = load_from_cache(node_dir, venv_root)
        if cached_data:
            if _is_stale_node_cache(cached_data):
                logger.warning(
                    "][ DIAG:ext_loader %s cache is stale/incompatible; rebuilding metadata",
                    extension_name,
                )
            else:
                # Cache hit: build stubs without spawning the child process.
                logger.warning("][ DIAG:ext_loader %s USING CACHE — dumping combo options:", extension_name)
                for node_name, details in cached_data.items():
                    schema_v1 = details.get("schema_v1", {})
                    inp = schema_v1.get("input", {}) if schema_v1 else {}
                    for section_name, section in inp.items():
                        if isinstance(section, dict):
                            for field_name, field_def in section.items():
                                if isinstance(field_def, (list, tuple)) and len(field_def) >= 2 and isinstance(field_def[1], dict) and "options" in field_def[1]:
                                    opts = field_def[1]["options"]
                                    logger.warning(
                                        "][ DIAG:ext_loader CACHE %s.%s.%s options=%d first=%s",
                                        node_name, section_name, field_name,
                                        len(opts),
                                        opts[:3] if opts else "EMPTY",
                                    )
                specs: List[Tuple[str, str, type]] = []
                for node_name, details in cached_data.items():
                    stub_cls = build_stub_class(node_name, details, extension)
                    specs.append(
                        (node_name, details.get("display_name", node_name), stub_cls)
                    )
                return specs
        else:
            logger.warning("][ DIAG:ext_loader %s cache INVALID or MISSING", extension_name)

    # Cache miss - spawn process and get metadata
    logger.warning("][ DIAG:ext_loader %s cache miss, spawning process for metadata", extension_name)

    try:
        remote_nodes: Dict[str, str] = await extension.list_nodes()
    except Exception as exc:
        logger.warning(
            "][ %s metadata discovery failed, skipping isolated load: %s",
            extension_name,
            exc,
        )
        await _stop_extension_safe(extension, extension_name)
        return []

    if not remote_nodes:
        logger.debug("][ %s exposed no isolated nodes; skipping", extension_name)
        await _stop_extension_safe(extension, extension_name)
        return []

    specs: List[Tuple[str, str, type]] = []
    cache_data: Dict[str, Dict] = {}

    for node_name, display_name in remote_nodes.items():
        logger.warning("][ DIAG:ext_loader calling get_node_details for %s.%s", extension_name, node_name)
        try:
            details = await extension.get_node_details(node_name)
        except Exception as exc:
            # One bad node should not sink the whole pack.
            logger.warning(
                "][ %s failed to load metadata for %s, skipping node: %s",
                extension_name,
                node_name,
                exc,
            )
            continue
        # DIAG: dump combo options from freshly-fetched details
        schema_v1 = details.get("schema_v1", {})
        inp = schema_v1.get("input", {}) if schema_v1 else {}
        for section_name, section in inp.items():
            if isinstance(section, dict):
                for field_name, field_def in section.items():
                    if isinstance(field_def, (list, tuple)) and len(field_def) >= 2 and isinstance(field_def[1], dict) and "options" in field_def[1]:
                        opts = field_def[1]["options"]
                        logger.warning(
                            "][ DIAG:ext_loader FRESH %s.%s.%s options=%d first=%s",
                            node_name, section_name, field_name,
                            len(opts),
                            opts[:3] if opts else "EMPTY",
                        )
        details["display_name"] = display_name
        cache_data[node_name] = details
        stub_cls = build_stub_class(node_name, details, extension)
        specs.append((node_name, display_name, stub_cls))

    if not specs:
        logger.warning(
            "][ %s produced no usable nodes after metadata scan; skipping",
            extension_name,
        )
        await _stop_extension_safe(extension, extension_name)
        return []

    # Save metadata to cache for future runs
    save_to_cache(node_dir, venv_root, cache_data, manifest_path)
    logger.debug(f"][ {extension_name} metadata cached")

    # Re-check web directory AFTER child has populated it
    if host_policy["sandbox_mode"] == "disabled":
        _register_web_directory(extension_name, node_dir)

    # EJECT: Kill process after getting metadata (will respawn on first execution)
    await _stop_extension_safe(extension, extension_name)

    return specs
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["ExtensionLoadError", "register_dummy_module", "load_isolated_node"]
|
||||||
896
comfy/isolation/extension_wrapper.py
Normal file
896
comfy/isolation/extension_wrapper.py
Normal file
@ -0,0 +1,896 @@
|
|||||||
|
# pylint: disable=consider-using-from-import,cyclic-import,import-outside-toplevel,logging-fstring-interpolation,protected-access,wrong-import-position
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class AttrDict(dict):
    """Dict whose keys are also readable as attributes (``d.key`` == ``d["key"]``).

    Missing keys raise :class:`AttributeError` on attribute access, so
    ``getattr``/``hasattr`` behave normally.
    """

    def __getattr__(self, item):
        try:
            return self[item]
        except KeyError as missing:
            raise AttributeError(item) from missing

    def copy(self):
        # Preserve the AttrDict type; plain dict.copy() would return dict.
        return AttrDict(super().copy())
|
||||||
|
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import inspect
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import uuid
|
||||||
|
from dataclasses import asdict
|
||||||
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
|
from pyisolate import ExtensionBase
|
||||||
|
|
||||||
|
from comfy_api.internal import _ComfyNodeInternal
|
||||||
|
|
||||||
|
LOG_PREFIX = "]["
|
||||||
|
V3_DISCOVERY_TIMEOUT = 30
|
||||||
|
_PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _run_prestartup_web_copy(module: Any, module_dir: str, web_dir_path: str) -> None:
    """Run the web asset copy step that prestartup_script.py used to do.

    If the module's web/ directory is empty and the module had a
    prestartup_script.py that copied assets from pip packages, this
    function replicates that work inside the child process.

    Generic pattern: reads _PRESTARTUP_WEB_COPY from the module if
    defined, otherwise falls back to detecting common asset packages.
    """
    import shutil

    # Already populated — nothing to do
    if os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path)):
        return

    os.makedirs(web_dir_path, exist_ok=True)

    # Try module-defined copy spec first (generic hook for any node pack)
    copy_spec = getattr(module, "_PRESTARTUP_WEB_COPY", None)
    if copy_spec is not None and callable(copy_spec):
        try:
            copy_spec(web_dir_path)
            logger.info(
                "%s Ran _PRESTARTUP_WEB_COPY for %s", LOG_PREFIX, module_dir
            )
            return
        except Exception as e:
            # Fall through to the package-detection fallbacks below.
            logger.warning(
                "%s _PRESTARTUP_WEB_COPY failed for %s: %s",
                LOG_PREFIX, module_dir, e,
            )

    # Fallback: detect comfy_3d_viewers and run copy_viewer()
    try:
        from comfy_3d_viewers import copy_viewer, VIEWER_FILES
        viewers = list(VIEWER_FILES.keys())
        for viewer in viewers:
            try:
                # Per-viewer failures are tolerated; copy what we can.
                copy_viewer(viewer, web_dir_path)
            except Exception:
                pass
        if any(os.scandir(web_dir_path)):
            logger.info(
                "%s Copied %d viewer types from comfy_3d_viewers to %s",
                LOG_PREFIX, len(viewers), web_dir_path,
            )
    except ImportError:
        pass

    # Fallback: detect comfy_dynamic_widgets
    try:
        from comfy_dynamic_widgets import get_js_path
        src = os.path.realpath(get_js_path())
        if os.path.exists(src):
            dst_dir = os.path.join(web_dir_path, "js")
            os.makedirs(dst_dir, exist_ok=True)
            dst = os.path.join(dst_dir, "dynamic_widgets.js")
            shutil.copy2(src, dst)
    except ImportError:
        pass
|
||||||
|
|
||||||
|
|
||||||
|
def _read_extension_name(module_dir: str) -> str:
|
||||||
|
"""Read extension name from pyproject.toml, falling back to directory name."""
|
||||||
|
pyproject = os.path.join(module_dir, "pyproject.toml")
|
||||||
|
if os.path.exists(pyproject):
|
||||||
|
try:
|
||||||
|
import tomllib
|
||||||
|
except ImportError:
|
||||||
|
import tomli as tomllib # type: ignore[no-redef]
|
||||||
|
try:
|
||||||
|
with open(pyproject, "rb") as f:
|
||||||
|
data = tomllib.load(f)
|
||||||
|
name = data.get("project", {}).get("name")
|
||||||
|
if name:
|
||||||
|
return name
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return os.path.basename(module_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def _flush_tensor_transport_state(marker: str) -> int:
|
||||||
|
try:
|
||||||
|
from pyisolate import flush_tensor_keeper # type: ignore[attr-defined]
|
||||||
|
except Exception:
|
||||||
|
return 0
|
||||||
|
if not callable(flush_tensor_keeper):
|
||||||
|
return 0
|
||||||
|
flushed = flush_tensor_keeper()
|
||||||
|
if flushed > 0:
|
||||||
|
logger.debug(
|
||||||
|
"%s %s flush_tensor_keeper released=%d", LOG_PREFIX, marker, flushed
|
||||||
|
)
|
||||||
|
return flushed
|
||||||
|
|
||||||
|
|
||||||
|
def _relieve_child_vram_pressure(marker: str) -> None:
    """Free host-side model memory before a child node executes.

    Runs the model-management GC, then — on non-CPU devices — ensures at
    least ``max(minimum_inference_memory(), _PRE_EXEC_MIN_FREE_VRAM_BYTES)``
    of free device memory, escalating from dynamic to non-dynamic frees.
    *marker* only tags the debug log line.
    """
    import comfy.model_management as model_management

    model_management.cleanup_models_gc()
    model_management.cleanup_models()

    device = model_management.get_torch_device()
    # CPU (or device objects without a .type) never need VRAM relief.
    if not hasattr(device, "type") or device.type == "cpu":
        return

    required = max(
        model_management.minimum_inference_memory(),
        _PRE_EXEC_MIN_FREE_VRAM_BYTES,
    )
    if model_management.get_free_memory(device) < required:
        # First pass: evict only dynamically-loaded models.
        model_management.free_memory(required, device, for_dynamic=True)
        if model_management.get_free_memory(device) < required:
            # Still short: evict everything evictable, then clean up.
            model_management.free_memory(required, device, for_dynamic=False)
            model_management.cleanup_models()
    model_management.soft_empty_cache()
    logger.debug("%s %s free_memory target=%d", LOG_PREFIX, marker, required)
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_for_transport(value):
|
||||||
|
primitives = (str, int, float, bool, type(None))
|
||||||
|
if isinstance(value, primitives):
|
||||||
|
return value
|
||||||
|
|
||||||
|
cls_name = value.__class__.__name__
|
||||||
|
if cls_name == "FlexibleOptionalInputType":
|
||||||
|
return {
|
||||||
|
"__pyisolate_flexible_optional__": True,
|
||||||
|
"type": _sanitize_for_transport(getattr(value, "type", "*")),
|
||||||
|
}
|
||||||
|
if cls_name == "AnyType":
|
||||||
|
return {"__pyisolate_any_type__": True, "value": str(value)}
|
||||||
|
if cls_name == "ByPassTypeTuple":
|
||||||
|
return {
|
||||||
|
"__pyisolate_bypass_tuple__": [
|
||||||
|
_sanitize_for_transport(v) for v in tuple(value)
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
if isinstance(value, dict):
|
||||||
|
return {k: _sanitize_for_transport(v) for k, v in value.items()}
|
||||||
|
if isinstance(value, tuple):
|
||||||
|
return {"__pyisolate_tuple__": [_sanitize_for_transport(v) for v in value]}
|
||||||
|
if isinstance(value, list):
|
||||||
|
return [_sanitize_for_transport(v) for v in value]
|
||||||
|
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
|
||||||
|
# Re-export RemoteObjectHandle from pyisolate for backward compatibility
|
||||||
|
# The canonical definition is now in pyisolate._internal.remote_handle
|
||||||
|
from pyisolate._internal.remote_handle import RemoteObjectHandle # noqa: E402,F401
|
||||||
|
|
||||||
|
|
||||||
|
class ComfyNodeExtension(ExtensionBase):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.node_classes: Dict[str, type] = {}
|
||||||
|
self.display_names: Dict[str, str] = {}
|
||||||
|
self.node_instances: Dict[str, Any] = {}
|
||||||
|
self.remote_objects: Dict[str, Any] = {}
|
||||||
|
self._route_handlers: Dict[str, Any] = {}
|
||||||
|
self._module: Any = None
|
||||||
|
|
||||||
|
    async def on_module_loaded(self, module: Any) -> None:
        """Discover and register the custom node module's node classes.

        Collects legacy V1 mappings (NODE_CLASS_MAPPINGS /
        NODE_DISPLAY_NAME_MAPPINGS), registers any non-empty WEB_DIRECTORY
        with the child-side WebDirectoryProxy, and — when comfy_api is
        importable — instantiates each ComfyExtension subclass defined by the
        module to harvest its V3 node list (bounded by V3_DISCOVERY_TIMEOUT).
        """
        self._module = module

        # Registries are initialized in host_hooks.py initialize_host_process()
        # They auto-register via ProxiedSingleton when instantiated
        # NO additional setup required here - if a registry is missing from host_hooks, it WILL fail

        self.node_classes = getattr(module, "NODE_CLASS_MAPPINGS", {}) or {}
        self.display_names = getattr(module, "NODE_DISPLAY_NAME_MAPPINGS", {}) or {}

        # Register web directory with WebDirectoryProxy (child-side)
        web_dir_attr = getattr(module, "WEB_DIRECTORY", None)
        if web_dir_attr is not None:
            module_dir = os.path.dirname(os.path.abspath(module.__file__))
            web_dir_path = os.path.abspath(os.path.join(module_dir, web_dir_attr))
            ext_name = _read_extension_name(module_dir)

            # If web dir is empty, run the copy step that prestartup_script.py did
            _run_prestartup_web_copy(module, module_dir, web_dir_path)

            # Only register directories that actually contain files.
            if os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path)):
                from comfy.isolation.proxies.web_directory_proxy import WebDirectoryProxy
                WebDirectoryProxy.register_web_dir(ext_name, web_dir_path)

        try:
            from comfy_api.latest import ComfyExtension

            for name, obj in inspect.getmembers(module):
                # Only concrete ComfyExtension subclasses count.
                if not (
                    inspect.isclass(obj)
                    and issubclass(obj, ComfyExtension)
                    and obj is not ComfyExtension
                ):
                    continue
                # Skip classes re-exported from other packages.
                if not obj.__module__.startswith(module.__name__):
                    continue
                try:
                    ext_instance = obj()
                    try:
                        # Bound so a hung extension cannot stall child startup.
                        await asyncio.wait_for(
                            ext_instance.on_load(), timeout=V3_DISCOVERY_TIMEOUT
                        )
                    except asyncio.TimeoutError:
                        logger.error(
                            "%s V3 Extension %s timed out in on_load()",
                            LOG_PREFIX,
                            name,
                        )
                        continue
                    try:
                        v3_nodes = await asyncio.wait_for(
                            ext_instance.get_node_list(), timeout=V3_DISCOVERY_TIMEOUT
                        )
                    except asyncio.TimeoutError:
                        logger.error(
                            "%s V3 Extension %s timed out in get_node_list()",
                            LOG_PREFIX,
                            name,
                        )
                        continue
                    for node_cls in v3_nodes:
                        if hasattr(node_cls, "GET_SCHEMA"):
                            schema = node_cls.GET_SCHEMA()
                            self.node_classes[schema.node_id] = node_cls
                            if schema.display_name:
                                self.display_names[schema.node_id] = schema.display_name
                except Exception as e:
                    # One broken extension must not abort discovery of the rest.
                    logger.error("%s V3 Extension %s failed: %s", LOG_PREFIX, name, e)
        except ImportError:
            # comfy_api not available: V1 mappings alone are used.
            pass

        # Normalize path-like __module__ values (containing "/") to the real
        # module name so downstream lookups work.
        module_name = getattr(module, "__name__", "isolated_nodes")
        for node_cls in self.node_classes.values():
            if hasattr(node_cls, "__module__") and "/" in str(node_cls.__module__):
                node_cls.__module__ = module_name

        self.node_instances = {}
|
||||||
|
|
||||||
|
async def list_nodes(self) -> Dict[str, str]:
|
||||||
|
return {name: self.display_names.get(name, name) for name in self.node_classes}
|
||||||
|
|
||||||
|
async def get_node_info(self, node_name: str) -> Dict[str, Any]:
|
||||||
|
return await self.get_node_details(node_name)
|
||||||
|
|
||||||
|
    async def get_node_details(self, node_name: str) -> Dict[str, Any]:
        """Collect everything the host needs to mirror this node's definition.

        Returns a transport-safe dict of input/output typing and execution
        metadata; V3 nodes additionally carry their serialized v1/v3 schemas.
        Schema export is best-effort — on failure the basic details are still
        returned. Raises KeyError if *node_name* is unknown.
        """
        node_cls = self._get_node_class(node_name)
        # V3 nodes derive from the internal base class and carry a schema.
        is_v3 = issubclass(node_cls, _ComfyNodeInternal)
        logger.warning(
            "%s DIAG:get_node_details START | node=%s | is_v3=%s | cls=%s",
            LOG_PREFIX, node_name, is_v3, node_cls,
        )

        input_types_raw = (
            node_cls.INPUT_TYPES() if hasattr(node_cls, "INPUT_TYPES") else {}
        )
        output_is_list = getattr(node_cls, "OUTPUT_IS_LIST", None)
        if output_is_list is not None:
            # Normalize to a tuple of plain bools for transport.
            output_is_list = tuple(bool(x) for x in output_is_list)

        details: Dict[str, Any] = {
            "input_types": _sanitize_for_transport(input_types_raw),
            "return_types": tuple(
                str(t) for t in getattr(node_cls, "RETURN_TYPES", ())
            ),
            "return_names": getattr(node_cls, "RETURN_NAMES", None),
            "function": str(getattr(node_cls, "FUNCTION", "execute")),
            "category": str(getattr(node_cls, "CATEGORY", "")),
            "output_node": bool(getattr(node_cls, "OUTPUT_NODE", False)),
            "output_is_list": output_is_list,
            "is_v3": is_v3,
        }

        if is_v3:
            try:
                logger.warning(
                    "%s DIAG:get_node_details calling GET_SCHEMA for %s",
                    LOG_PREFIX, node_name,
                )
                schema = node_cls.GET_SCHEMA()
                logger.warning(
                    "%s DIAG:get_node_details GET_SCHEMA returned for %s | schema_inputs=%s",
                    LOG_PREFIX, node_name,
                    [getattr(i, 'id', '?') for i in (schema.inputs or [])],
                )
                schema_v1 = asdict(schema.get_v1_info(node_cls))
                try:
                    schema_v3 = asdict(schema.get_v3_info(node_cls))
                except (AttributeError, TypeError):
                    # Older schema objects lack get_v3_info(); build by hand.
                    schema_v3 = self._build_schema_v3_fallback(schema)
                details.update(
                    {
                        "schema_v1": schema_v1,
                        "schema_v3": schema_v3,
                        "hidden": [h.value for h in (schema.hidden or [])],
                        "description": getattr(schema, "description", ""),
                        "deprecated": bool(getattr(node_cls, "DEPRECATED", False)),
                        "experimental": bool(getattr(node_cls, "EXPERIMENTAL", False)),
                        "api_node": bool(getattr(node_cls, "API_NODE", False)),
                        "input_is_list": bool(
                            getattr(node_cls, "INPUT_IS_LIST", False)
                        ),
                        "not_idempotent": bool(
                            getattr(node_cls, "NOT_IDEMPOTENT", False)
                        ),
                        "accept_all_inputs": bool(
                            getattr(node_cls, "ACCEPT_ALL_INPUTS", False)
                        ),
                    }
                )
            except Exception as exc:
                # Best-effort: ship the basic details even if schema export fails.
                logger.warning(
                    "%s V3 schema serialization failed for %s: %s",
                    LOG_PREFIX,
                    node_name,
                    exc,
                )
        return details
|
||||||
|
|
||||||
|
def _build_schema_v3_fallback(self, schema) -> Dict[str, Any]:
|
||||||
|
input_dict: Dict[str, Any] = {}
|
||||||
|
output_dict: Dict[str, Any] = {}
|
||||||
|
hidden_list: List[str] = []
|
||||||
|
|
||||||
|
if getattr(schema, "inputs", None):
|
||||||
|
for inp in schema.inputs:
|
||||||
|
self._add_schema_io_v3(inp, input_dict)
|
||||||
|
if getattr(schema, "outputs", None):
|
||||||
|
for out in schema.outputs:
|
||||||
|
self._add_schema_io_v3(out, output_dict)
|
||||||
|
if getattr(schema, "hidden", None):
|
||||||
|
for h in schema.hidden:
|
||||||
|
hidden_list.append(getattr(h, "value", str(h)))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input": input_dict,
|
||||||
|
"output": output_dict,
|
||||||
|
"hidden": hidden_list,
|
||||||
|
"name": getattr(schema, "node_id", None),
|
||||||
|
"display_name": getattr(schema, "display_name", None),
|
||||||
|
"description": getattr(schema, "description", None),
|
||||||
|
"category": getattr(schema, "category", None),
|
||||||
|
"output_node": getattr(schema, "is_output_node", False),
|
||||||
|
"deprecated": getattr(schema, "is_deprecated", False),
|
||||||
|
"experimental": getattr(schema, "is_experimental", False),
|
||||||
|
"api_node": getattr(schema, "is_api_node", False),
|
||||||
|
}
|
||||||
|
|
||||||
|
def _add_schema_io_v3(self, io_obj: Any, target: Dict[str, Any]) -> None:
|
||||||
|
io_id = getattr(io_obj, "id", None)
|
||||||
|
if io_id is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
io_type_fn = getattr(io_obj, "get_io_type", None)
|
||||||
|
io_type = (
|
||||||
|
io_type_fn() if callable(io_type_fn) else getattr(io_obj, "io_type", None)
|
||||||
|
)
|
||||||
|
|
||||||
|
as_dict_fn = getattr(io_obj, "as_dict", None)
|
||||||
|
payload = as_dict_fn() if callable(as_dict_fn) else {}
|
||||||
|
|
||||||
|
target[str(io_id)] = (io_type, payload)
|
||||||
|
|
||||||
|
async def get_input_types(self, node_name: str) -> Dict[str, Any]:
|
||||||
|
node_cls = self._get_node_class(node_name)
|
||||||
|
if hasattr(node_cls, "INPUT_TYPES"):
|
||||||
|
return node_cls.INPUT_TYPES()
|
||||||
|
return {}
|
||||||
|
|
||||||
|
    async def execute_node(self, node_name: str, **inputs: Any) -> Tuple[Any, ...]:
        """Run one node in the child process and return transport-safe results.

        Inputs arrive as RPC-deserialized kwargs; remote handles and tagged
        proxy refs are resolved to live objects first. V3 hidden parameters
        are peeled out of the kwargs into the node class's ``hidden`` holder.
        The node function executes under torch.inference_mode() (sync
        handlers run in an executor so the event loop stays responsive), and
        the result is wrapped so unpicklable objects cross the transport as
        RemoteObjectHandles.
        """
        logger.debug(
            "%s ISO:child_execute_start ext=%s node=%s input_keys=%d",
            LOG_PREFIX,
            getattr(self, "name", "?"),
            node_name,
            len(inputs),
        )
        if os.environ.get("PYISOLATE_CHILD") == "1":
            # Free VRAM headroom before running user node code in the child.
            _relieve_child_vram_pressure("EXT:pre_execute")

        resolved_inputs = self._resolve_remote_objects(inputs)

        instance = self._get_node_instance(node_name)
        node_cls = self._get_node_class(node_name)

        # V3 API nodes expect hidden parameters in cls.hidden, not as kwargs
        # Hidden params come through RPC as string keys like "Hidden.prompt"
        from comfy_api.latest._io import Hidden, HiddenHolder

        # Map string representations back to Hidden enum keys
        hidden_string_map = {
            "Hidden.unique_id": Hidden.unique_id,
            "Hidden.prompt": Hidden.prompt,
            "Hidden.extra_pnginfo": Hidden.extra_pnginfo,
            "Hidden.dynprompt": Hidden.dynprompt,
            "Hidden.auth_token_comfy_org": Hidden.auth_token_comfy_org,
            "Hidden.api_key_comfy_org": Hidden.api_key_comfy_org,
            # Uppercase enum VALUE forms — V3 execution engine passes these
            "UNIQUE_ID": Hidden.unique_id,
            "PROMPT": Hidden.prompt,
            "EXTRA_PNGINFO": Hidden.extra_pnginfo,
            "DYNPROMPT": Hidden.dynprompt,
            "AUTH_TOKEN_COMFY_ORG": Hidden.auth_token_comfy_org,
            "API_KEY_COMFY_ORG": Hidden.api_key_comfy_org,
        }

        # Find and extract hidden parameters (both enum and string form)
        hidden_found = {}
        keys_to_remove = []

        for key in list(resolved_inputs.keys()):
            # Check string form first (from RPC serialization)
            if key in hidden_string_map:
                hidden_found[hidden_string_map[key]] = resolved_inputs[key]
                keys_to_remove.append(key)
            # Also check enum form (direct calls)
            elif isinstance(key, Hidden):
                hidden_found[key] = resolved_inputs[key]
                keys_to_remove.append(key)

        # Remove hidden params from kwargs
        for key in keys_to_remove:
            resolved_inputs.pop(key)

        # Set hidden on node class if any hidden params found
        if hidden_found:
            if not hasattr(node_cls, "hidden") or node_cls.hidden is None:
                node_cls.hidden = HiddenHolder.from_dict(hidden_found)
            else:
                # Update existing hidden holder
                for key, value in hidden_found.items():
                    setattr(node_cls.hidden, key.value.lower(), value)

        # INPUT_IS_LIST: ComfyUI's executor passes all inputs as lists when this
        # flag is set. The isolation RPC delivers unwrapped values, so we must
        # wrap each input in a single-element list to match the contract.
        if getattr(node_cls, "INPUT_IS_LIST", False):
            resolved_inputs = {k: [v] for k, v in resolved_inputs.items()}

        function_name = getattr(node_cls, "FUNCTION", "execute")
        if not hasattr(instance, function_name):
            raise AttributeError(f"Node {node_name} missing callable '{function_name}'")

        handler = getattr(instance, function_name)

        try:
            import torch
            if asyncio.iscoroutinefunction(handler):
                with torch.inference_mode():
                    result = await handler(**resolved_inputs)
            else:
                import functools

                # Sync handlers run off the event loop; inference_mode must be
                # entered inside the worker thread to take effect there.
                def _run_with_inference_mode(**kwargs):
                    with torch.inference_mode():
                        return handler(**kwargs)

                loop = asyncio.get_running_loop()
                result = await loop.run_in_executor(
                    None, functools.partial(_run_with_inference_mode, **resolved_inputs)
                )
        except Exception:
            logger.exception(
                "%s ISO:child_execute_error ext=%s node=%s",
                LOG_PREFIX,
                getattr(self, "name", "?"),
                node_name,
            )
            raise

        # V3 NodeOutput objects (matched by name to avoid a hard import) are
        # flattened into a tagged dict the host can reconstruct.
        if type(result).__name__ == "NodeOutput":
            node_output_dict = {
                "__node_output__": True,
                "args": self._wrap_unpicklable_objects(result.args),
            }
            if result.ui is not None:
                node_output_dict["ui"] = self._wrap_unpicklable_objects(result.ui)
            if getattr(result, "expand", None) is not None:
                node_output_dict["expand"] = result.expand
            if getattr(result, "block_execution", None) is not None:
                node_output_dict["block_execution"] = result.block_execution
            return node_output_dict
        # Dicts using ComfyUI's reserved keys ("ui"/"result"/"expand") pass
        # through wrapped but un-tupled.
        if self._is_comfy_protocol_return(result):
            wrapped = self._wrap_unpicklable_objects(result)
            return wrapped

        # Plain returns are normalized to a tuple, matching node conventions.
        if not isinstance(result, tuple):
            result = (result,)
        wrapped = self._wrap_unpicklable_objects(result)
        return wrapped
|
||||||
|
|
||||||
|
    async def flush_transport_state(self) -> int:
        """End-of-workflow cleanup in the child: flush tensor-transport state
        and sweep the ModelPatcher registry for pending removals.

        No-op (returns 0) when not running as a pyisolate child. Returns the
        number of transport entries flushed. Registry sweeping is best-effort
        and never raises.
        """
        if os.environ.get("PYISOLATE_CHILD") != "1":
            return 0
        logger.debug(
            "%s ISO:child_flush_start ext=%s", LOG_PREFIX, getattr(self, "name", "?")
        )
        flushed = _flush_tensor_transport_state("EXT:workflow_end")
        try:
            from comfy.isolation.model_patcher_proxy_registry import (
                ModelPatcherRegistry,
            )

            registry = ModelPatcherRegistry()
            removed = registry.sweep_pending_cleanup()
            if removed > 0:
                logger.debug(
                    "%s EXT:workflow_end registry sweep removed=%d", LOG_PREFIX, removed
                )
        except Exception:
            # Sweep failures must not break workflow teardown; log and continue.
            logger.debug(
                "%s EXT:workflow_end registry sweep failed", LOG_PREFIX, exc_info=True
            )
        logger.debug(
            "%s ISO:child_flush_done ext=%s flushed=%d",
            LOG_PREFIX,
            getattr(self, "name", "?"),
            flushed,
        )
        return flushed
|
||||||
|
|
||||||
|
async def get_remote_object(self, object_id: str) -> Any:
|
||||||
|
"""Retrieve a remote object by ID for host-side deserialization."""
|
||||||
|
if object_id not in self.remote_objects:
|
||||||
|
raise KeyError(f"Remote object {object_id} not found")
|
||||||
|
|
||||||
|
return self.remote_objects[object_id]
|
||||||
|
|
||||||
|
def _store_remote_object_handle(self, obj: Any) -> RemoteObjectHandle:
|
||||||
|
object_id = str(uuid.uuid4())
|
||||||
|
self.remote_objects[object_id] = obj
|
||||||
|
return RemoteObjectHandle(object_id, type(obj).__name__)
|
||||||
|
|
||||||
|
    async def call_remote_object_method(
        self,
        object_id: str,
        method_name: str,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """Invoke a method or attribute-backed accessor on a child-owned object.

        Dispatch order matters: explicit accessor names are handled first,
        then inner-model forwarders (which receive positional args packed in
        args[0] and keyword args in args[1]), then a generic ``get_<attr>``
        fallback, and finally plain attribute/method access on the object.

        Raises:
            KeyError: if *object_id* is unknown.
            TypeError: if args/kwargs are supplied for a non-callable attribute.
        """
        obj = await self.get_remote_object(object_id)

        # --- Explicit attribute accessors -------------------------------
        if method_name == "get_patcher_attr":
            # Dynamic accessor: the attribute name travels as args[0].
            return getattr(obj, args[0])
        if method_name == "get_model_options":
            return getattr(obj, "model_options")
        if method_name == "set_model_options":
            setattr(obj, "model_options", args[0])
            return None
        if method_name == "get_object_patches":
            return getattr(obj, "object_patches")
        if method_name == "get_patches":
            return getattr(obj, "patches")
        if method_name == "get_wrappers":
            return getattr(obj, "wrappers")
        if method_name == "get_callbacks":
            return getattr(obj, "callbacks")
        if method_name == "get_load_device":
            return getattr(obj, "load_device")
        if method_name == "get_offload_device":
            return getattr(obj, "offload_device")
        if method_name == "get_hook_mode":
            return getattr(obj, "hook_mode")
        if method_name == "get_parent":
            # Parents stay child-side; hand back a handle instead.
            parent = getattr(obj, "parent", None)
            if parent is None:
                return None
            return self._store_remote_object_handle(parent)
        if method_name == "get_inner_model_attr":
            # Prefer the inner model's attribute, then the patcher's own.
            attr_name = args[0]
            if hasattr(obj.model, attr_name):
                return getattr(obj.model, attr_name)
            if hasattr(obj, attr_name):
                return getattr(obj, attr_name)
            return None
        # --- Inner-model forwarders (args packed as [positional, keyword]) --
        if method_name == "inner_model_apply_model":
            return obj.model.apply_model(*args[0], **args[1])
        if method_name == "inner_model_extra_conds_shapes":
            return obj.model.extra_conds_shapes(*args[0], **args[1])
        if method_name == "inner_model_extra_conds":
            return obj.model.extra_conds(*args[0], **args[1])
        if method_name == "inner_model_memory_required":
            return obj.model.memory_required(*args[0], **args[1])
        if method_name == "process_latent_in":
            return obj.model.process_latent_in(*args[0], **args[1])
        if method_name == "process_latent_out":
            return obj.model.process_latent_out(*args[0], **args[1])
        if method_name == "scale_latent_inpaint":
            return obj.model.scale_latent_inpaint(*args[0], **args[1])
        # --- Generic get_<attr> fallback (falls through when attr missing) --
        if method_name.startswith("get_"):
            attr_name = method_name[4:]
            if hasattr(obj, attr_name):
                return getattr(obj, attr_name)

        # --- Plain attribute / method access -----------------------------
        target = getattr(obj, method_name)
        if callable(target):
            result = target(*args, **kwargs)
            if inspect.isawaitable(result):
                result = await result
            # New ModelPatcher results stay child-side behind a handle.
            if type(result).__name__ == "ModelPatcher":
                return self._store_remote_object_handle(result)
            return result
        if args or kwargs:
            raise TypeError(f"{method_name} is not callable on remote object {object_id}")
        return target
|
||||||
|
|
||||||
|
    def _wrap_unpicklable_objects(self, data: Any) -> Any:
        """Recursively convert node outputs into transport-safe values.

        Primitives pass through; tensors are detached and (in the child)
        moved to CPU; known proxy wrappers collapse to tagged reference
        dicts; containers recurse; registered serializers run; anything else
        stays child-side behind a RemoteObjectHandle.
        """
        if isinstance(data, (str, int, float, bool, type(None))):
            return data
        if isinstance(data, torch.Tensor):
            tensor = data.detach() if data.requires_grad else data
            # Child-side GPU tensors must cross the transport as CPU tensors.
            if os.environ.get("PYISOLATE_CHILD") == "1" and tensor.device.type != "cpu":
                return tensor.cpu()
            return tensor

        # Special-case clip vision outputs: preserve attribute access by packing fields
        if hasattr(data, "penultimate_hidden_states") or hasattr(
            data, "last_hidden_state"
        ):
            fields = {}
            for attr in (
                "penultimate_hidden_states",
                "last_hidden_state",
                "image_embeds",
                "text_embeds",
            ):
                if hasattr(data, attr):
                    try:
                        fields[attr] = self._wrap_unpicklable_objects(
                            getattr(data, attr)
                        )
                    except Exception:
                        # Best-effort: skip fields that cannot be wrapped.
                        pass
            if fields:
                return {"__pyisolate_attribute_container__": True, "data": fields}

        # Avoid converting arbitrary objects with stateful methods (models, etc.)
        # They will be handled via RemoteObjectHandle below.

        # Proxies are matched by class name to avoid importing their modules.
        type_name = type(data).__name__
        if type_name == "ModelPatcherProxy":
            return {"__type__": "ModelPatcherRef", "model_id": data._instance_id}
        if type_name == "CLIPProxy":
            return {"__type__": "CLIPRef", "clip_id": data._instance_id}
        if type_name == "VAEProxy":
            return {"__type__": "VAERef", "vae_id": data._instance_id}
        if type_name == "ModelSamplingProxy":
            return {"__type__": "ModelSamplingRef", "ms_id": data._instance_id}

        if isinstance(data, (list, tuple)):
            wrapped = [self._wrap_unpicklable_objects(item) for item in data]
            return tuple(wrapped) if isinstance(data, tuple) else wrapped
        if isinstance(data, dict):
            converted_dict = {
                k: self._wrap_unpicklable_objects(v) for k, v in data.items()
            }
            # Tag dicts so the receiving side can rebuild attribute-style access.
            return {"__pyisolate_attrdict__": True, "data": converted_dict}

        from pyisolate._internal.serialization_registry import SerializerRegistry

        # Custom serializers registered for this type take precedence over
        # the generic remote-handle fallback.
        registry = SerializerRegistry.get_instance()
        if registry.is_data_type(type_name):
            serializer = registry.get_serializer(type_name)
            if serializer:
                return serializer(data)

        # Fall back to keeping the object child-side behind a handle.
        return self._store_remote_object_handle(data)
|
||||||
|
|
||||||
|
def _resolve_remote_objects(self, data: Any) -> Any:
|
||||||
|
if isinstance(data, RemoteObjectHandle):
|
||||||
|
if data.object_id not in self.remote_objects:
|
||||||
|
raise KeyError(f"Remote object {data.object_id} not found")
|
||||||
|
return self.remote_objects[data.object_id]
|
||||||
|
|
||||||
|
if isinstance(data, dict):
|
||||||
|
ref_type = data.get("__type__")
|
||||||
|
if ref_type in ("CLIPRef", "ModelPatcherRef", "VAERef"):
|
||||||
|
from pyisolate._internal.model_serialization import (
|
||||||
|
deserialize_proxy_result,
|
||||||
|
)
|
||||||
|
|
||||||
|
return deserialize_proxy_result(data)
|
||||||
|
if ref_type == "ModelSamplingRef":
|
||||||
|
from pyisolate._internal.model_serialization import (
|
||||||
|
deserialize_proxy_result,
|
||||||
|
)
|
||||||
|
|
||||||
|
return deserialize_proxy_result(data)
|
||||||
|
return {k: self._resolve_remote_objects(v) for k, v in data.items()}
|
||||||
|
|
||||||
|
if isinstance(data, (list, tuple)):
|
||||||
|
resolved = [self._resolve_remote_objects(item) for item in data]
|
||||||
|
return tuple(resolved) if isinstance(data, tuple) else resolved
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _get_node_class(self, node_name: str) -> type:
|
||||||
|
if node_name not in self.node_classes:
|
||||||
|
raise KeyError(f"Unknown node: {node_name}")
|
||||||
|
return self.node_classes[node_name]
|
||||||
|
|
||||||
|
def _get_node_instance(self, node_name: str) -> Any:
|
||||||
|
if node_name not in self.node_instances:
|
||||||
|
if node_name not in self.node_classes:
|
||||||
|
raise KeyError(f"Unknown node: {node_name}")
|
||||||
|
self.node_instances[node_name] = self.node_classes[node_name]()
|
||||||
|
return self.node_instances[node_name]
|
||||||
|
|
||||||
|
    async def before_module_loaded(self) -> None:
        """Child-side setup run before the custom node module is imported.

        Installs the proxy layer (best-effort) and swaps the V3 execution API
        to the progress proxy so node progress reaches the host.
        """
        # Inject initialization here if we think this is the child
        logger.warning(
            "%s DIAG:before_module_loaded START | is_child=%s",
            LOG_PREFIX, os.environ.get("PYISOLATE_CHILD"),
        )
        try:
            from comfy.isolation import initialize_proxies

            initialize_proxies()
            logger.warning("%s DIAG:before_module_loaded initialize_proxies OK", LOG_PREFIX)
        except Exception as e:
            # Proxy setup is best-effort; failures are logged, not fatal here.
            logger.error(
                "%s DIAG:before_module_loaded initialize_proxies FAILED: %s", LOG_PREFIX, e
            )

        await super().before_module_loaded()
        try:
            from comfy_api.latest import ComfyAPI_latest
            from .proxies.progress_proxy import ProgressProxy

            # Route V3 execution/progress calls through the proxy class.
            ComfyAPI_latest.Execution = ProgressProxy
            # ComfyAPI_latest.execution = ProgressProxy() # Eliminated to avoid Singleton collision
            # fp_proxy = FolderPathsProxy() # Eliminated to avoid Singleton collision
            # latest_ui.folder_paths = fp_proxy
            # latest_resources.folder_paths = fp_proxy
        except Exception:
            # comfy_api may be absent in minimal environments; ignore.
            pass
|
||||||
|
|
||||||
|
async def call_route_handler(
|
||||||
|
self,
|
||||||
|
handler_module: str,
|
||||||
|
handler_func: str,
|
||||||
|
request_data: Dict[str, Any],
|
||||||
|
) -> Any:
|
||||||
|
cache_key = f"{handler_module}.{handler_func}"
|
||||||
|
if cache_key not in self._route_handlers:
|
||||||
|
if self._module is not None and hasattr(self._module, "__file__"):
|
||||||
|
node_dir = os.path.dirname(self._module.__file__)
|
||||||
|
if node_dir not in sys.path:
|
||||||
|
sys.path.insert(0, node_dir)
|
||||||
|
try:
|
||||||
|
module = importlib.import_module(handler_module)
|
||||||
|
self._route_handlers[cache_key] = getattr(module, handler_func)
|
||||||
|
except (ImportError, AttributeError) as e:
|
||||||
|
raise ValueError(f"Route handler not found: {cache_key}") from e
|
||||||
|
|
||||||
|
handler = self._route_handlers[cache_key]
|
||||||
|
mock_request = MockRequest(request_data)
|
||||||
|
|
||||||
|
if asyncio.iscoroutinefunction(handler):
|
||||||
|
result = await handler(mock_request)
|
||||||
|
else:
|
||||||
|
result = handler(mock_request)
|
||||||
|
return self._serialize_response(result)
|
||||||
|
|
||||||
|
def _is_comfy_protocol_return(self, result: Any) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the result matches the ComfyUI 'Protocol Return' schema.
|
||||||
|
|
||||||
|
A Protocol Return is a dictionary containing specific reserved keys that
|
||||||
|
ComfyUI's execution engine interprets as instructions (UI updates,
|
||||||
|
Workflow expansion, etc.) rather than purely data outputs.
|
||||||
|
|
||||||
|
Schema:
|
||||||
|
- Must be a dict
|
||||||
|
- Must contain at least one of: 'ui', 'result', 'expand'
|
||||||
|
"""
|
||||||
|
if not isinstance(result, dict):
|
||||||
|
return False
|
||||||
|
return any(key in result for key in ("ui", "result", "expand"))
|
||||||
|
|
||||||
|
def _serialize_response(self, response: Any) -> Dict[str, Any]:
|
||||||
|
if response is None:
|
||||||
|
return {"type": "text", "body": "", "status": 204}
|
||||||
|
if isinstance(response, dict):
|
||||||
|
return {"type": "json", "body": response, "status": 200}
|
||||||
|
if isinstance(response, str):
|
||||||
|
return {"type": "text", "body": response, "status": 200}
|
||||||
|
if hasattr(response, "text") and hasattr(response, "status"):
|
||||||
|
return {
|
||||||
|
"type": "text",
|
||||||
|
"body": response.text
|
||||||
|
if hasattr(response, "text")
|
||||||
|
else str(response.body),
|
||||||
|
"status": response.status,
|
||||||
|
"headers": dict(response.headers)
|
||||||
|
if hasattr(response, "headers")
|
||||||
|
else {},
|
||||||
|
}
|
||||||
|
if hasattr(response, "body") and hasattr(response, "status"):
|
||||||
|
body = response.body
|
||||||
|
if isinstance(body, bytes):
|
||||||
|
try:
|
||||||
|
return {
|
||||||
|
"type": "text",
|
||||||
|
"body": body.decode("utf-8"),
|
||||||
|
"status": response.status,
|
||||||
|
}
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
return {
|
||||||
|
"type": "binary",
|
||||||
|
"body": body.hex(),
|
||||||
|
"status": response.status,
|
||||||
|
}
|
||||||
|
return {"type": "json", "body": body, "status": response.status}
|
||||||
|
return {"type": "text", "body": str(response), "status": 200}
|
||||||
|
|
||||||
|
|
||||||
|
class MockRequest:
    """Minimal aiohttp-like request object rebuilt from an RPC payload.

    Route handlers written against aiohttp's web.Request can run in the
    child process by receiving this stand-in, which exposes the small subset
    of the request API (method/path/query/headers plus async body accessors)
    that handlers typically use.
    """

    def __init__(self, data: Dict[str, Any]):
        self.method = data.get("method", "GET")
        self.path = data.get("path", "/")
        self.query = data.get("query", {})
        self.headers = data.get("headers", {})
        self.match_info = data.get("match_info", {})
        self._body = data.get("body", {})
        self._text = data.get("text", "")
        # Prefer an explicit content_type; fall back to the header, then JSON.
        header_content_type = self.headers.get("Content-Type", "application/json")
        self.content_type = data.get("content_type", header_content_type)

    async def json(self) -> Any:
        """Return the body as JSON: dicts pass through, strings are parsed."""
        body = self._body
        if isinstance(body, dict):
            return body
        if isinstance(body, str):
            return json.loads(body)
        return {}

    async def post(self) -> Dict[str, Any]:
        """Return form-style data; only dict bodies are supported."""
        body = self._body
        return body if isinstance(body, dict) else {}

    async def text(self) -> str:
        """Return the raw body text, preferring the explicit 'text' field."""
        if self._text:
            return self._text
        if isinstance(self._body, str):
            return self._body
        if isinstance(self._body, dict):
            return json.dumps(self._body)
        return ""

    async def read(self) -> bytes:
        """Return the body as bytes (UTF-8 encoding of text())."""
        return (await self.text()).encode("utf-8")
|
||||||
30
comfy/isolation/host_hooks.py
Normal file
30
comfy/isolation/host_hooks.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
# pylint: disable=import-outside-toplevel
|
||||||
|
# Host process initialization for PyIsolate
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_host_process() -> None:
    """Prepare the host process for pyisolate.

    Silences root logging (NullHandler only) so host-side log output cannot
    interfere with child RPC streams, then instantiates every proxy service
    once — each registers itself as a ProxiedSingleton on construction, so
    instantiation is all the setup required.
    """
    # Replace whatever handlers are installed with a single NullHandler.
    root_logger = logging.getLogger()
    for existing in list(root_logger.handlers):
        root_logger.removeHandler(existing)
    root_logger.addHandler(logging.NullHandler())

    # Imported lazily to avoid circular imports at module load time.
    from .proxies.folder_paths_proxy import FolderPathsProxy
    from .proxies.helper_proxies import HelperProxiesService
    from .proxies.model_management_proxy import ModelManagementProxy
    from .proxies.progress_proxy import ProgressProxy
    from .proxies.prompt_server_impl import PromptServerService
    from .proxies.utils_proxy import UtilsProxy
    from .proxies.web_directory_proxy import WebDirectoryProxy
    from .vae_proxy import VAERegistry

    # Construction order mirrors the original registration order.
    for service_cls in (
        FolderPathsProxy,
        HelperProxiesService,
        ModelManagementProxy,
        ProgressProxy,
        PromptServerService,
        UtilsProxy,
        WebDirectoryProxy,
        VAERegistry,
    ):
        service_cls()
|
||||||
221
comfy/isolation/manifest_loader.py
Normal file
221
comfy/isolation/manifest_loader.py
Normal file
@ -0,0 +1,221 @@
|
|||||||
|
# pylint: disable=import-outside-toplevel
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
|
try:
|
||||||
|
import tomllib
|
||||||
|
except ImportError:
|
||||||
|
import tomli as tomllib # type: ignore[no-redef]
|
||||||
|
|
||||||
|
LOG_PREFIX = "]["
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
CACHE_SUBDIR = "cache"
|
||||||
|
CACHE_KEY_FILE = "cache_key"
|
||||||
|
CACHE_DATA_FILE = "node_info.json"
|
||||||
|
CACHE_KEY_LENGTH = 16
|
||||||
|
_NESTED_SCAN_ROOT = "packages"
|
||||||
|
_IGNORED_MANIFEST_DIRS = {".git", ".venv", "__pycache__"}
|
||||||
|
|
||||||
|
|
||||||
|
def _read_manifest(manifest_path: Path) -> dict[str, Any] | None:
|
||||||
|
try:
|
||||||
|
with manifest_path.open("rb") as f:
|
||||||
|
data = tomllib.load(f)
|
||||||
|
if isinstance(data, dict):
|
||||||
|
return data
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _is_isolation_manifest(data: dict[str, Any]) -> bool:
|
||||||
|
return (
|
||||||
|
"tool" in data
|
||||||
|
and "comfy" in data["tool"]
|
||||||
|
and "isolation" in data["tool"]["comfy"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _discover_nested_manifests(entry: Path) -> List[Tuple[Path, Path]]:
    """Scan ``entry/packages`` for standalone nested isolation manifests.

    Returns ``(node_dir, manifest_path)`` pairs for every nested
    pyproject.toml that declares ``[tool.comfy.isolation]`` with
    ``standalone = true``, skipping anything under ignored directories
    (.git, .venv, __pycache__).
    """
    scan_root = entry / _NESTED_SCAN_ROOT
    if not (scan_root.exists() and scan_root.is_dir()):
        return []

    found: List[Tuple[Path, Path]] = []
    for manifest_file in sorted(scan_root.rglob("pyproject.toml")):
        package_dir = manifest_file.parent
        # Never descend into VCS metadata, virtualenvs, or bytecode caches.
        if _IGNORED_MANIFEST_DIRS.intersection(package_dir.parts):
            continue

        parsed = _read_manifest(manifest_file)
        if not parsed or not _is_isolation_manifest(parsed):
            continue

        # Only explicitly standalone packages are surfaced from nested scans.
        if parsed["tool"]["comfy"]["isolation"].get("standalone") is True:
            found.append((package_dir, manifest_file))

    return found
|
||||||
|
|
||||||
|
|
||||||
|
def find_manifest_directories() -> List[Tuple[Path, Path]]:
    """Find custom node directories containing a valid pyproject.toml with [tool.comfy.isolation].

    Walks every registered ``custom_nodes`` base path, collecting
    ``(node_dir, manifest_path)`` pairs, and also pulls in standalone nested
    manifests discovered under each node's ``packages/`` subtree.
    """
    results: List[Tuple[Path, Path]] = []

    for raw_base in folder_paths.get_folder_paths("custom_nodes"):
        base_dir = Path(raw_base)
        if not (base_dir.exists() and base_dir.is_dir()):
            continue

        for candidate in base_dir.iterdir():
            if not candidate.is_dir():
                continue

            manifest_file = candidate / "pyproject.toml"
            if not manifest_file.exists():
                continue

            parsed = _read_manifest(manifest_file)
            if not parsed or not _is_isolation_manifest(parsed):
                continue

            results.append((candidate, manifest_file))
            results.extend(_discover_nested_manifests(candidate))

    return results
|
||||||
|
|
||||||
|
|
||||||
|
def compute_cache_key(node_dir: Path, manifest_path: Path) -> str:
    """Hash manifest + .py mtimes + Python version + PyIsolate version.

    Any of the inputs changing (manifest content, a source file's path or
    mtime, the interpreter, the pyisolate release) produces a different key,
    which invalidates the cached node metadata. I/O failures are folded into
    the hash as sentinel bytes rather than raised.
    """
    digest = hashlib.sha256()

    # Manifest content: config edits must invalidate the cache.
    try:
        digest.update(manifest_path.read_bytes())
    except OSError:
        digest.update(b"__manifest_read_error__")

    # Source files: relative path + mtime of every .py outside caches/venvs.
    try:
        for source in sorted(node_dir.rglob("*.py")):
            rel_text = str(source.relative_to(node_dir))
            if "__pycache__" in rel_text or ".venv" in rel_text:
                continue
            digest.update(rel_text.encode("utf-8"))
            try:
                digest.update(str(source.stat().st_mtime).encode("utf-8"))
            except OSError:
                digest.update(b"__file_stat_error__")
    except OSError:
        digest.update(b"__dir_scan_error__")

    digest.update(sys.version.encode("utf-8"))

    try:
        import pyisolate

        digest.update(pyisolate.__version__.encode("utf-8"))
    except (ImportError, AttributeError):
        digest.update(b"__pyisolate_unknown__")

    return digest.hexdigest()[:CACHE_KEY_LENGTH]
|
||||||
|
|
||||||
|
|
||||||
|
def get_cache_path(node_dir: Path, venv_root: Path) -> Tuple[Path, Path]:
    """Return (cache_key_file, cache_data_file) in venv_root/{node}/cache/."""
    cache_dir = venv_root / node_dir.name / CACHE_SUBDIR
    key_file = cache_dir / CACHE_KEY_FILE
    data_file = cache_dir / CACHE_DATA_FILE
    return (key_file, data_file)
|
||||||
|
|
||||||
|
|
||||||
|
def is_cache_valid(node_dir: Path, manifest_path: Path, venv_root: Path) -> bool:
    """Return True only if stored cache key matches current computed key.

    Any error (missing files, unreadable key, recompute failure) is logged at
    debug level and treated as a cache miss.
    """
    try:
        key_file, data_file = get_cache_path(node_dir, venv_root)
        if not (key_file.exists() and data_file.exists()):
            return False
        expected = compute_cache_key(node_dir, manifest_path)
        stored = key_file.read_text(encoding="utf-8").strip()
        return expected == stored
    except Exception as exc:
        logger.debug(
            "%s Cache validation error for %s: %s", LOG_PREFIX, node_dir.name, exc
        )
        return False
|
||||||
|
|
||||||
|
|
||||||
|
def load_from_cache(node_dir: Path, venv_root: Path) -> Optional[Dict[str, Any]]:
    """Load node metadata from cache, return None on any error.

    Only a JSON object (dict) is accepted; any other top-level JSON value is
    treated as a cache miss.
    """
    try:
        data_file = get_cache_path(node_dir, venv_root)[1]
        if not data_file.exists():
            return None
        payload = json.loads(data_file.read_text(encoding="utf-8"))
    except Exception:
        return None
    return payload if isinstance(payload, dict) else None
|
||||||
|
|
||||||
|
|
||||||
|
def _atomic_write(cache_dir: Path, destination: Path, text: str) -> None:
    """Write *text* to *destination* atomically via a temp file + os.replace.

    The temp file lives in *cache_dir* so the rename stays on one filesystem;
    on any failure the temp file is best-effort unlinked and the error re-raised.
    """
    tmp_fd, tmp_path = tempfile.mkstemp(dir=str(cache_dir), suffix=".tmp")
    try:
        with os.fdopen(tmp_fd, "w", encoding="utf-8") as f:
            f.write(text)
        os.replace(tmp_path, destination)
    except Exception:
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise


def save_to_cache(
    node_dir: Path, venv_root: Path, node_data: Dict[str, Any], manifest_path: Path
) -> None:
    """Save node metadata and cache key atomically.

    Best-effort: any failure is logged as a warning and swallowed, since a
    missing cache only costs a re-scan on the next startup. The data file is
    written before the key file so a crash in between leaves a key that fails
    validation rather than a key pointing at absent data.
    (Refactored: the two hand-rolled atomic-write sequences are now one
    ``_atomic_write`` helper.)
    """
    try:
        cache_key_file, cache_data_file = get_cache_path(node_dir, venv_root)
        cache_dir = cache_key_file.parent
        cache_dir.mkdir(parents=True, exist_ok=True)
        cache_key = compute_cache_key(node_dir, manifest_path)

        _atomic_write(cache_dir, cache_data_file, json.dumps(node_data, indent=2))
        _atomic_write(cache_dir, cache_key_file, cache_key)
    except Exception as e:
        logger.warning("%s Cache save failed for %s: %s", LOG_PREFIX, node_dir.name, e)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"LOG_PREFIX",
|
||||||
|
"find_manifest_directories",
|
||||||
|
"compute_cache_key",
|
||||||
|
"get_cache_path",
|
||||||
|
"is_cache_valid",
|
||||||
|
"load_from_cache",
|
||||||
|
"save_to_cache",
|
||||||
|
]
|
||||||
49
comfy/isolation/rpc_bridge.py
Normal file
49
comfy/isolation/rpc_bridge.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RpcBridge:
    """Minimal helper to run coroutines synchronously inside isolated processes.

    If an event loop is already running, the coroutine is executed on a fresh
    thread with its own loop to avoid nested run_until_complete errors.
    """

    def run_sync(self, maybe_coro):
        # Non-coroutines pass straight through untouched.
        if not asyncio.iscoroutine(maybe_coro):
            return maybe_coro

        try:
            current_loop = asyncio.get_running_loop()
        except RuntimeError:
            current_loop = None

        # No active loop on this thread: asyncio.run handles setup/teardown.
        if current_loop is None or not current_loop.is_running():
            return asyncio.run(maybe_coro)

        # A loop is already running here — drive the coroutine to completion
        # on a throwaway daemon thread with its own private loop.
        outcome = {}

        def _drive():
            try:
                worker_loop = asyncio.new_event_loop()
                asyncio.set_event_loop(worker_loop)
                outcome["value"] = worker_loop.run_until_complete(maybe_coro)
            except Exception as exc:  # pragma: no cover
                outcome["error"] = exc
            finally:
                try:
                    worker_loop.close()
                except Exception:
                    pass

        worker = threading.Thread(target=_drive, daemon=True)
        worker.start()
        worker.join()

        if "error" in outcome:
            raise outcome["error"]
        return outcome.get("value")
|
||||||
471
comfy/isolation/runtime_helpers.py
Normal file
471
comfy/isolation/runtime_helpers.py
Normal file
@ -0,0 +1,471 @@
|
|||||||
|
# pylint: disable=consider-using-from-import,import-outside-toplevel,no-member
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Set, TYPE_CHECKING
|
||||||
|
|
||||||
|
from .proxies.helper_proxies import restore_input_types
|
||||||
|
from .shm_forensics import scan_shm_forensics
|
||||||
|
|
||||||
|
_IMPORT_TORCH = os.environ.get("PYISOLATE_IMPORT_TORCH", "1") == "1"
|
||||||
|
|
||||||
|
_ComfyNodeInternal = object
|
||||||
|
latest_io = None
|
||||||
|
|
||||||
|
if _IMPORT_TORCH:
|
||||||
|
from comfy_api.internal import _ComfyNodeInternal
|
||||||
|
from comfy_api.latest import _io as latest_io
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .extension_wrapper import ComfyNodeExtension
|
||||||
|
|
||||||
|
LOG_PREFIX = "]["
|
||||||
|
_PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
|
||||||
|
|
||||||
|
|
||||||
|
class _RemoteObjectRegistryCaller:
|
||||||
|
def __init__(self, extension: Any) -> None:
|
||||||
|
self._extension = extension
|
||||||
|
|
||||||
|
def __getattr__(self, method_name: str) -> Any:
|
||||||
|
async def _call(instance_id: str, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
return await self._extension.call_remote_object_method(
|
||||||
|
instance_id,
|
||||||
|
method_name,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return _call
|
||||||
|
|
||||||
|
|
||||||
|
def _load_proxy_class(type_name: str):
    """Map a RemoteObjectHandle's type_name to its host-side proxy class.

    Returns None for unrecognized type names so callers can pass the handle
    through unchanged. Imports are local to avoid import cycles at module load.
    """
    if type_name == "ModelPatcher":
        from comfy.isolation.model_patcher_proxy import ModelPatcherProxy

        return ModelPatcherProxy
    if type_name == "VAE":
        from comfy.isolation.vae_proxy import VAEProxy

        return VAEProxy
    if type_name == "CLIP":
        from comfy.isolation.clip_proxy import CLIPProxy

        return CLIPProxy
    if type_name == "ModelSampling":
        from comfy.isolation.model_sampling_proxy import ModelSamplingProxy

        return ModelSamplingProxy
    return None


def _wrap_remote_handles_as_host_proxies(value: Any, extension: Any) -> Any:
    """Recursively replace pyisolate RemoteObjectHandles with host-side proxies.

    Walks dicts, lists, and tuples (preserving container type). Each handle of
    a known type becomes a proxy wired to the extension's remote-object RPC;
    ``manage_lifecycle=False`` because the child process owns the real object.
    (Refactored: the four copy-pasted per-type branches now share one wiring
    path via ``_load_proxy_class``.)
    """
    from pyisolate._internal.remote_handle import RemoteObjectHandle

    if isinstance(value, RemoteObjectHandle):
        proxy_cls = _load_proxy_class(value.type_name)
        if proxy_cls is None:
            # Unknown remote type: leave the raw handle for the caller.
            return value
        proxy = proxy_cls(value.object_id, manage_lifecycle=False)
        proxy._rpc_caller = _RemoteObjectRegistryCaller(extension)  # type: ignore[attr-defined]
        proxy._pyisolate_remote_handle = value  # type: ignore[attr-defined]
        return proxy

    if isinstance(value, dict):
        return {
            k: _wrap_remote_handles_as_host_proxies(v, extension) for k, v in value.items()
        }

    if isinstance(value, (list, tuple)):
        wrapped = [_wrap_remote_handles_as_host_proxies(item, extension) for item in value]
        return type(value)(wrapped)

    return value
|
||||||
|
|
||||||
|
|
||||||
|
def _resource_snapshot() -> Dict[str, int]:
|
||||||
|
fd_count = -1
|
||||||
|
shm_sender_files = 0
|
||||||
|
try:
|
||||||
|
fd_count = len(os.listdir("/proc/self/fd"))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
shm_root = Path("/dev/shm")
|
||||||
|
if shm_root.exists():
|
||||||
|
prefix = f"torch_{os.getpid()}_"
|
||||||
|
shm_sender_files = sum(1 for _ in shm_root.glob(f"{prefix}*"))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return {"fd_count": fd_count, "shm_sender_files": shm_sender_files}
|
||||||
|
|
||||||
|
|
||||||
|
def _tensor_transport_summary(value: Any) -> Dict[str, int]:
|
||||||
|
summary: Dict[str, int] = {
|
||||||
|
"tensor_count": 0,
|
||||||
|
"cpu_tensors": 0,
|
||||||
|
"cuda_tensors": 0,
|
||||||
|
"shared_cpu_tensors": 0,
|
||||||
|
"tensor_bytes": 0,
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
except Exception:
|
||||||
|
return summary
|
||||||
|
|
||||||
|
def visit(node: Any) -> None:
|
||||||
|
if isinstance(node, torch.Tensor):
|
||||||
|
summary["tensor_count"] += 1
|
||||||
|
summary["tensor_bytes"] += int(node.numel() * node.element_size())
|
||||||
|
if node.device.type == "cpu":
|
||||||
|
summary["cpu_tensors"] += 1
|
||||||
|
if node.is_shared():
|
||||||
|
summary["shared_cpu_tensors"] += 1
|
||||||
|
elif node.device.type == "cuda":
|
||||||
|
summary["cuda_tensors"] += 1
|
||||||
|
return
|
||||||
|
if isinstance(node, dict):
|
||||||
|
for v in node.values():
|
||||||
|
visit(v)
|
||||||
|
return
|
||||||
|
if isinstance(node, (list, tuple)):
|
||||||
|
for v in node:
|
||||||
|
visit(v)
|
||||||
|
|
||||||
|
visit(value)
|
||||||
|
return summary
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_hidden_unique_id(inputs: Dict[str, Any]) -> str | None:
|
||||||
|
for key, value in inputs.items():
|
||||||
|
key_text = str(key)
|
||||||
|
if "unique_id" in key_text:
|
||||||
|
return str(value)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _flush_tensor_transport_state(marker: str, logger: logging.Logger) -> None:
|
||||||
|
try:
|
||||||
|
from pyisolate import flush_tensor_keeper # type: ignore[attr-defined]
|
||||||
|
except Exception:
|
||||||
|
return
|
||||||
|
if not callable(flush_tensor_keeper):
|
||||||
|
return
|
||||||
|
flushed = flush_tensor_keeper()
|
||||||
|
if flushed > 0:
|
||||||
|
logger.debug(
|
||||||
|
"%s %s flush_tensor_keeper released=%d", LOG_PREFIX, marker, flushed
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _relieve_host_vram_pressure(marker: str, logger: logging.Logger) -> None:
    """Free host VRAM before dispatching work to an isolated child process.

    Runs ComfyUI's model GC, then — on non-CPU devices only — ensures at least
    ``max(minimum_inference_memory, _PRE_EXEC_MIN_FREE_VRAM_BYTES)`` is free,
    escalating from dynamic to non-dynamic eviction if the first pass was not
    enough.
    """
    import comfy.model_management as model_management

    model_management.cleanup_models_gc()
    model_management.cleanup_models()

    device = model_management.get_torch_device()
    # CPU (or device objects without a .type) never need VRAM relief.
    if not hasattr(device, "type") or device.type == "cpu":
        return

    target = max(
        model_management.minimum_inference_memory(),
        _PRE_EXEC_MIN_FREE_VRAM_BYTES,
    )
    if model_management.get_free_memory(device) < target:
        # First pass: evict dynamically-managed models only.
        model_management.free_memory(target, device, for_dynamic=True)
        if model_management.get_free_memory(device) < target:
            # Still short: evict more aggressively and GC again.
            model_management.free_memory(target, device, for_dynamic=False)
            model_management.cleanup_models()
        model_management.soft_empty_cache()
        logger.debug("%s %s free_memory target=%d", LOG_PREFIX, marker, target)
|
||||||
|
|
||||||
|
|
||||||
|
def _detach_shared_cpu_tensors(value: Any) -> Any:
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
except Exception:
|
||||||
|
return value
|
||||||
|
|
||||||
|
if isinstance(value, torch.Tensor):
|
||||||
|
if value.device.type == "cpu" and value.is_shared():
|
||||||
|
clone = value.clone()
|
||||||
|
if value.requires_grad:
|
||||||
|
clone.requires_grad_(True)
|
||||||
|
return clone
|
||||||
|
return value
|
||||||
|
if isinstance(value, list):
|
||||||
|
return [_detach_shared_cpu_tensors(v) for v in value]
|
||||||
|
if isinstance(value, tuple):
|
||||||
|
return tuple(_detach_shared_cpu_tensors(v) for v in value)
|
||||||
|
if isinstance(value, dict):
|
||||||
|
return {k: _detach_shared_cpu_tensors(v) for k, v in value.items()}
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def build_stub_class(
    node_name: str,
    info: Dict[str, object],
    extension: "ComfyNodeExtension",
    running_extensions: Dict[str, "ComfyNodeExtension"],
    logger: logging.Logger,
) -> type:
    """Build a host-side stand-in node class for an isolated custom node.

    The returned class mimics a regular ComfyUI node (INPUT_TYPES, FUNCTION,
    RETURN_TYPES, ...) but its execute function serializes the inputs, ships
    them to the child process via *extension*, and deserializes the result
    back into host objects/proxies. For v3 nodes the stub also carries the
    v3 schema classmethods and inherits from _ComfyNodeInternal.

    NOTE: comments below reconstruct intent from the visible code; the exact
    nesting was recovered from a whitespace-mangled source — verify against
    the repository before relying on edge-case ordering.
    """
    if latest_io is None:
        # latest_io is only imported when PYISOLATE_IMPORT_TORCH is enabled;
        # without it the v3 stub machinery cannot be constructed at all.
        raise RuntimeError("comfy_api.latest._io is required to build isolation stubs")
    is_v3 = bool(info.get("is_v3", False))
    function_name = "_pyisolate_execute"
    # Input types were serialized by the child; rebuild the live structures.
    restored_input_types = restore_input_types(info.get("input_types", {}))

    async def _execute(self, **inputs):
        """Host-side dispatch: serialize inputs, run in child, rebuild outputs."""
        from comfy.isolation import _RUNNING_EXTENSIONS

        # Update BOTH the local dict AND the module-level dict
        running_extensions[extension.name] = extension
        _RUNNING_EXTENSIONS[extension.name] = extension
        prev_child = None
        node_unique_id = _extract_hidden_unique_id(inputs)
        summary = _tensor_transport_summary(inputs)
        resources = _resource_snapshot()
        # NOTE(review): two consecutive ISO:execute_start debug lines — the
        # first appears redundant with the detailed one below.
        logger.debug(
            "%s ISO:execute_start ext=%s node=%s uid=%s",
            LOG_PREFIX,
            extension.name,
            node_name,
            node_unique_id or "-",
        )
        logger.debug(
            "%s ISO:execute_start ext=%s node=%s uid=%s tensors=%d cpu=%d cuda=%d shared_cpu=%d bytes=%d fds=%d sender_shm=%d",
            LOG_PREFIX,
            extension.name,
            node_name,
            node_unique_id or "-",
            summary["tensor_count"],
            summary["cpu_tensors"],
            summary["cuda_tensors"],
            summary["shared_cpu_tensors"],
            summary["tensor_bytes"],
            resources["fd_count"],
            resources["shm_sender_files"],
        )
        scan_shm_forensics("RUNTIME:execute_start", refresh_model_context=True)
        try:
            # Only the host process relieves VRAM pressure; a child (env var
            # set to "1") must not touch the host's model cache.
            if os.environ.get("PYISOLATE_CHILD") != "1":
                _relieve_host_vram_pressure("RUNTIME:pre_execute", logger)
                scan_shm_forensics("RUNTIME:pre_execute", refresh_model_context=True)
            # ImportError here falls through to the plain execute fallback.
            from pyisolate._internal.model_serialization import (
                serialize_for_isolation,
                deserialize_from_isolation,
            )

            # Temporarily clear the child marker while serializing on the
            # host side; restored in the finally block.
            prev_child = os.environ.pop("PYISOLATE_CHILD", None)
            logger.debug(
                "%s ISO:serialize_start ext=%s node=%s uid=%s",
                LOG_PREFIX,
                extension.name,
                node_name,
                node_unique_id or "-",
            )
            # Unwrap NodeOutput-like dicts before serialization.
            # OUTPUT_NODE nodes return {"ui": {...}, "result": (outputs...)}
            # and the executor may pass this dict as input to downstream nodes.
            unwrapped_inputs = {}
            for k, v in inputs.items():
                if isinstance(v, dict) and "result" in v and ("ui" in v or "__node_output__" in v):
                    result = v.get("result")
                    if isinstance(result, (tuple, list)) and len(result) > 0:
                        unwrapped_inputs[k] = result[0]
                    else:
                        unwrapped_inputs[k] = result
                else:
                    unwrapped_inputs[k] = v
            serialized = serialize_for_isolation(unwrapped_inputs)
            logger.debug(
                "%s ISO:serialize_done ext=%s node=%s uid=%s",
                LOG_PREFIX,
                extension.name,
                node_name,
                node_unique_id or "-",
            )
            logger.debug(
                "%s ISO:dispatch_start ext=%s node=%s uid=%s",
                LOG_PREFIX,
                extension.name,
                node_name,
                node_unique_id or "-",
            )
            # RPC into the child process; this is the actual node execution.
            result = await extension.execute_node(node_name, **serialized)
            logger.debug(
                "%s ISO:dispatch_done ext=%s node=%s uid=%s",
                LOG_PREFIX,
                extension.name,
                node_name,
                node_unique_id or "-",
            )
            # Reconstruct NodeOutput if the child serialized one
            if isinstance(result, dict) and result.get("__node_output__"):
                # Local import intentionally shadows the module-level
                # `latest_io` alias (public `io` vs private `_io`).
                from comfy_api.latest import io as latest_io
                args_raw = result.get("args", ())
                deserialized_args = await deserialize_from_isolation(args_raw, extension)
                deserialized_args = _wrap_remote_handles_as_host_proxies(
                    deserialized_args, extension
                )
                # Copy tensors out of shared segments so they can be freed.
                deserialized_args = _detach_shared_cpu_tensors(deserialized_args)
                ui_raw = result.get("ui")
                deserialized_ui = None
                if ui_raw is not None:
                    deserialized_ui = await deserialize_from_isolation(ui_raw, extension)
                    deserialized_ui = _wrap_remote_handles_as_host_proxies(
                        deserialized_ui, extension
                    )
                    deserialized_ui = _detach_shared_cpu_tensors(deserialized_ui)
                scan_shm_forensics("RUNTIME:post_execute", refresh_model_context=True)
                return latest_io.NodeOutput(
                    *deserialized_args,
                    ui=deserialized_ui,
                    expand=result.get("expand"),
                    block_execution=result.get("block_execution"),
                )
            # OUTPUT_NODE: if sealed worker returned a tuple/list whose first
            # element is a {"ui": ...} dict, unwrap it for the executor.
            if (isinstance(result, (tuple, list)) and len(result) == 1
                    and isinstance(result[0], dict) and "ui" in result[0]):
                return result[0]
            deserialized = await deserialize_from_isolation(result, extension)
            deserialized = _wrap_remote_handles_as_host_proxies(deserialized, extension)
            scan_shm_forensics("RUNTIME:post_execute", refresh_model_context=True)
            return _detach_shared_cpu_tensors(deserialized)
        except ImportError:
            # pyisolate serialization unavailable: execute with raw inputs.
            return await extension.execute_node(node_name, **inputs)
        except Exception:
            logger.exception(
                "%s ISO:execute_error ext=%s node=%s uid=%s",
                LOG_PREFIX,
                extension.name,
                node_name,
                node_unique_id or "-",
            )
            raise
        finally:
            # Restore the child marker removed before serialization.
            if prev_child is not None:
                os.environ["PYISOLATE_CHILD"] = prev_child
            logger.debug(
                "%s ISO:execute_end ext=%s node=%s uid=%s",
                LOG_PREFIX,
                extension.name,
                node_name,
                node_unique_id or "-",
            )
            scan_shm_forensics("RUNTIME:execute_end", refresh_model_context=True)

    def _input_types(
        cls,
        include_hidden: bool = True,
        return_schema: bool = False,
        live_inputs: Any = None,  # accepted for signature compatibility; unused here
    ):
        """Stub INPUT_TYPES: v1 returns the restored dict as-is; v3 adapts it."""
        if not is_v3:
            return restored_input_types

        # Deep copy so per-call mutations never leak into the shared template.
        inputs_copy = copy.deepcopy(restored_input_types)
        if not include_hidden:
            inputs_copy.pop("hidden", None)

        v3_data: Dict[str, Any] = {"hidden_inputs": {}}
        dynamic = inputs_copy.pop("dynamic_paths", None)
        if dynamic is not None:
            v3_data["dynamic_paths"] = dynamic

        if return_schema:
            # Rebuild Hidden enum members from their serialized values;
            # fall back to the raw value when it is not a known enum.
            hidden_vals = info.get("hidden", []) or []
            hidden_enums = []
            for h in hidden_vals:
                try:
                    hidden_enums.append(latest_io.Hidden(h))
                except Exception:
                    hidden_enums.append(h)

            class SchemaProxy:
                # Minimal stand-in for the v3 schema object: only .hidden is read.
                hidden = hidden_enums

            return inputs_copy, SchemaProxy, v3_data
        return inputs_copy

    def _validate_class(cls):
        # Validation already happened in the child; the stub is always valid.
        return True

    def _get_node_info_v1(cls):
        """Return the child-provided v1 schema, defaulting python_module."""
        node_info = copy.deepcopy(info.get("schema_v1", {}))
        relative_python_module = node_info.get("python_module")
        if not isinstance(relative_python_module, str) or not relative_python_module:
            relative_python_module = f"custom_nodes.{extension.name}"
        node_info["python_module"] = relative_python_module
        return node_info

    def _get_base_class(cls):
        # Present the stub as a regular v3 ComfyNode to the executor.
        return latest_io.ComfyNode

    # Class dict for the dynamically built stub type.
    attributes: Dict[str, object] = {
        "FUNCTION": function_name,
        "CATEGORY": info.get("category", ""),
        "OUTPUT_NODE": info.get("output_node", False),
        "RETURN_TYPES": tuple(info.get("return_types", ()) or ()),
        "RETURN_NAMES": info.get("return_names"),
        function_name: _execute,
        "_pyisolate_extension": extension,
        "_pyisolate_node_name": node_name,
        "INPUT_TYPES": classmethod(_input_types),
    }

    output_is_list = info.get("output_is_list")
    if output_is_list is not None:
        attributes["OUTPUT_IS_LIST"] = tuple(output_is_list)

    if is_v3:
        # v3-only surface: schema classmethods plus metadata flags mirrored
        # from the child's node definition.
        attributes["VALIDATE_CLASS"] = classmethod(_validate_class)
        attributes["GET_NODE_INFO_V1"] = classmethod(_get_node_info_v1)
        attributes["GET_BASE_CLASS"] = classmethod(_get_base_class)
        attributes["DESCRIPTION"] = info.get("description", "")
        attributes["EXPERIMENTAL"] = info.get("experimental", False)
        attributes["DEPRECATED"] = info.get("deprecated", False)
        attributes["API_NODE"] = info.get("api_node", False)
        attributes["NOT_IDEMPOTENT"] = info.get("not_idempotent", False)
        attributes["ACCEPT_ALL_INPUTS"] = info.get("accept_all_inputs", False)
        attributes["_ACCEPT_ALL_INPUTS"] = info.get("accept_all_inputs", False)
        attributes["INPUT_IS_LIST"] = info.get("input_is_list", False)

    class_name = f"PyIsolate_{node_name}".replace(" ", "_")
    bases = (_ComfyNodeInternal,) if is_v3 else ()
    stub_cls = type(class_name, bases, attributes)

    if is_v3:
        # Smoke-check the stub; failures are logged but non-fatal.
        try:
            stub_cls.VALIDATE_CLASS()
        except Exception as e:
            logger.error("%s VALIDATE_CLASS failed: %s - %s", LOG_PREFIX, node_name, e)

    return stub_cls
|
||||||
|
|
||||||
|
|
||||||
|
def get_class_types_for_extension(
    extension_name: str,
    running_extensions: Dict[str, "ComfyNodeExtension"],
    specs: List[Any],
) -> Set[str]:
    """Return the node names whose spec module path matches the named extension.

    Looks up *extension_name* in *running_extensions* and collects the
    ``node_name`` of every spec whose resolved ``module_path`` equals the
    extension's resolved module path. Unknown extensions yield an empty set.
    """
    extension = running_extensions.get(extension_name)
    if not extension:
        return set()

    extension_path = Path(extension.module_path).resolve()
    return {
        spec.node_name
        for spec in specs
        if spec.module_path.resolve() == extension_path
    }
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["build_stub_class", "get_class_types_for_extension"]
|
||||||
217
comfy/isolation/shm_forensics.py
Normal file
217
comfy/isolation/shm_forensics.py
Normal file
@ -0,0 +1,217 @@
|
|||||||
|
# pylint: disable=consider-using-from-import,import-outside-toplevel
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import atexit
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Set
|
||||||
|
|
||||||
|
LOG_PREFIX = "]["
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _shm_debug_enabled() -> bool:
|
||||||
|
return os.environ.get("COMFY_ISO_SHM_DEBUG") == "1"
|
||||||
|
|
||||||
|
|
||||||
|
class _SHMForensicsTracker:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._started = False
|
||||||
|
self._tracked_files: Set[str] = set()
|
||||||
|
self._current_model_context: Dict[str, str] = {
|
||||||
|
"id": "unknown",
|
||||||
|
"name": "unknown",
|
||||||
|
"hash": "????",
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _snapshot_shm() -> Set[str]:
|
||||||
|
shm_path = Path("/dev/shm")
|
||||||
|
if not shm_path.exists():
|
||||||
|
return set()
|
||||||
|
return {f.name for f in shm_path.glob("torch_*")}
|
||||||
|
|
||||||
|
def start(self) -> None:
|
||||||
|
if self._started or not _shm_debug_enabled():
|
||||||
|
return
|
||||||
|
self._tracked_files = self._snapshot_shm()
|
||||||
|
self._started = True
|
||||||
|
logger.debug(
|
||||||
|
"%s SHM:forensics_enabled tracked=%d", LOG_PREFIX, len(self._tracked_files)
|
||||||
|
)
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
if not self._started:
|
||||||
|
return
|
||||||
|
self.scan("shutdown", refresh_model_context=True)
|
||||||
|
self._started = False
|
||||||
|
logger.debug("%s SHM:forensics_disabled", LOG_PREFIX)
|
||||||
|
|
||||||
|
    def _compute_model_hash(self, model_patcher: Any) -> str:
        """Derive a short (4-char) identifier for *model_patcher*.

        Prefers an explicit ``_instance_id`` attribute (last 4 chars of its
        string form). Otherwise samples three parameter values from the first
        non-empty tensor, mixes in the object identity and (when available)
        ``model_size()``, and hashes the result. Returns "0000" when no
        tensor parameter exists and "err!" on any failure — never raises.
        """
        try:
            model_instance_id = getattr(model_patcher, "_instance_id", None)
            if model_instance_id is not None:
                # Fast path: an explicit id was assigned elsewhere; use its tail.
                model_id_text = str(model_instance_id)
                return model_id_text[-4:] if len(model_id_text) >= 4 else model_id_text

            import torch

            # Unwrap a ModelPatcher-style wrapper to the underlying module.
            real_model = (
                model_patcher.model
                if hasattr(model_patcher, "model")
                else model_patcher
            )
            tensor = None
            if hasattr(real_model, "parameters"):
                # First non-empty tensor parameter is a stable sample source.
                for p in real_model.parameters():
                    if torch.is_tensor(p) and p.numel() > 0:
                        tensor = p
                        break

            if tensor is None:
                return "0000"

            # Sample first / middle / last elements (guarding short tensors).
            flat = tensor.flatten()
            values = []
            indices = [0, flat.shape[0] // 2, flat.shape[0] - 1]
            for i in indices:
                if i < flat.shape[0]:
                    values.append(flat[i].item())

            size = 0
            if hasattr(model_patcher, "model_size"):
                size = model_patcher.model_size()
            # NOTE: id(model_patcher) makes this hash unique per live object,
            # not stable across processes — fine for forensic correlation.
            sample_str = f"{values}_{id(model_patcher):016x}_{size}"
            return hashlib.sha256(sample_str.encode()).hexdigest()[-4:]
        except Exception:
            return "err!"
|
||||||
|
|
||||||
|
    def _get_models_snapshot(self) -> List[Dict[str, Any]]:
        """Summarize currently loaded models for SHM attribution.

        Returns a list of {"name", "id", "hash", "used"} dicts, one per
        loaded model on device "cuda:0" (other devices are skipped — NOTE
        (review): hard-coded "cuda:0" ignores multi-GPU setups; confirm this
        is intentional). Returns [] when comfy.model_management is
        unavailable or any error occurs mid-scan — never raises.
        """
        try:
            import comfy.model_management as model_management
        except Exception:
            return []

        snapshot: List[Dict[str, Any]] = []
        try:
            for loaded_model in model_management.current_loaded_models:
                model = loaded_model.model
                if model is None:
                    continue
                # Only models resident on the primary CUDA device matter here.
                if str(getattr(loaded_model, "device", "")) != "cuda:0":
                    continue

                # Prefer the wrapped module's class name when present.
                name = (
                    model.model.__class__.__name__
                    if hasattr(model, "model")
                    else type(model).__name__
                )
                model_hash = self._compute_model_hash(model)
                model_instance_id = getattr(model, "_instance_id", None)
                if model_instance_id is None:
                    # Fall back to the content hash as the identifier.
                    model_instance_id = model_hash
                snapshot.append(
                    {
                        "name": str(name),
                        "id": str(model_instance_id),
                        "hash": str(model_hash or "????"),
                        "used": bool(getattr(loaded_model, "currently_used", False)),
                    }
                )
        except Exception:
            # A partial snapshot could mis-attribute files; drop it entirely.
            return []

        return snapshot
|
||||||
|
|
||||||
|
def _update_model_context(self) -> None:
|
||||||
|
snapshot = self._get_models_snapshot()
|
||||||
|
selected = None
|
||||||
|
|
||||||
|
used_models = [m for m in snapshot if m.get("used") and m.get("id")]
|
||||||
|
if used_models:
|
||||||
|
selected = used_models[-1]
|
||||||
|
else:
|
||||||
|
live_models = [m for m in snapshot if m.get("id")]
|
||||||
|
if live_models:
|
||||||
|
selected = live_models[-1]
|
||||||
|
|
||||||
|
if selected is None:
|
||||||
|
self._current_model_context = {
|
||||||
|
"id": "unknown",
|
||||||
|
"name": "unknown",
|
||||||
|
"hash": "????",
|
||||||
|
}
|
||||||
|
return
|
||||||
|
|
||||||
|
self._current_model_context = {
|
||||||
|
"id": str(selected.get("id", "unknown")),
|
||||||
|
"name": str(selected.get("name", "unknown")),
|
||||||
|
"hash": str(selected.get("hash", "????") or "????"),
|
||||||
|
}
|
||||||
|
|
||||||
|
def scan(self, marker: str, refresh_model_context: bool = True) -> None:
    """Diff the current SHM file set against the last scan and log changes.

    ``marker`` tags the log lines so scans can be correlated with the call
    site.  When ``refresh_model_context`` is true the active model context
    is re-resolved first, so newly created files can be associated with a
    model.  No-op unless the tracker is started and SHM debugging is on.
    """
    if not self._started or not _shm_debug_enabled():
        return

    if refresh_model_context:
        self._update_model_context()

    # Compute the delta against the previous scan, then commit the new set.
    latest = self._snapshot_shm()
    created = latest - self._tracked_files
    deleted = self._tracked_files - latest
    self._tracked_files = latest

    if not created and not deleted:
        logger.debug("%s SHM:scan marker=%s changes=0", LOG_PREFIX, marker)
        return

    context = self._current_model_context
    for shm_file in sorted(created):
        logger.info("%s SHM:created | %s", LOG_PREFIX, shm_file)
        if context["id"] == "unknown":
            # A new SHM file with no active model is a forensics gap worth flagging.
            logger.error(
                "%s SHM:model_association_missing | file=%s | reason=no_active_model_context",
                LOG_PREFIX,
                shm_file,
            )
        else:
            logger.info(
                "%s SHM:model_association | model=%s | file=%s | name=%s | hash=%s",
                LOG_PREFIX,
                context["id"],
                shm_file,
                context["name"],
                context["hash"],
            )

    for shm_file in sorted(deleted):
        logger.info("%s SHM:deleted | %s", LOG_PREFIX, shm_file)

    logger.debug(
        "%s SHM:scan marker=%s created=%d deleted=%d active=%d",
        LOG_PREFIX,
        marker,
        len(created),
        len(deleted),
        len(self._tracked_files),
    )
# Module-level singleton tracker; the start/scan/stop helpers below delegate to it.
_TRACKER = _SHMForensicsTracker()
def start_shm_forensics() -> None:
    """Start SHM forensics tracking on the module-level tracker."""
    _TRACKER.start()
def scan_shm_forensics(marker: str, refresh_model_context: bool = True) -> None:
    """Run one SHM change scan on the module-level tracker.

    ``marker`` tags the resulting log lines; ``refresh_model_context``
    controls whether the active model context is re-resolved first.
    """
    _TRACKER.scan(marker, refresh_model_context=refresh_model_context)
def stop_shm_forensics() -> None:
    """Stop SHM forensics tracking on the module-level tracker."""
    _TRACKER.stop()
# Best-effort cleanup: stop forensics tracking when the interpreter exits.
atexit.register(stop_shm_forensics)
@@ -35,3 +35,5 @@ pydantic~=2.0
 pydantic-settings~=2.0
 PyOpenGL
 glfw
+
+pyisolate==0.10.0