diff --git a/comfy/isolation/proxies/__init__.py b/comfy/isolation/proxies/__init__.py new file mode 100644 index 000000000..30d0089ad --- /dev/null +++ b/comfy/isolation/proxies/__init__.py @@ -0,0 +1,17 @@ +from .base import ( + IS_CHILD_PROCESS, + BaseProxy, + BaseRegistry, + detach_if_grad, + get_thread_loop, + run_coro_in_new_loop, +) + +__all__ = [ + "IS_CHILD_PROCESS", + "BaseRegistry", + "BaseProxy", + "get_thread_loop", + "run_coro_in_new_loop", + "detach_if_grad", +] diff --git a/comfy/isolation/proxies/base.py b/comfy/isolation/proxies/base.py new file mode 100644 index 000000000..498554217 --- /dev/null +++ b/comfy/isolation/proxies/base.py @@ -0,0 +1,301 @@ +# pylint: disable=global-statement,import-outside-toplevel,protected-access +from __future__ import annotations + +import asyncio +import concurrent.futures +import logging +import os +import threading +import time +import weakref +from typing import Any, Callable, Dict, Generic, Optional, TypeVar + +try: + from pyisolate import ProxiedSingleton +except ImportError: + + class ProxiedSingleton: # type: ignore[no-redef] + pass + + +logger = logging.getLogger(__name__) + +IS_CHILD_PROCESS = os.environ.get("PYISOLATE_CHILD") == "1" +_thread_local = threading.local() +T = TypeVar("T") + + +def get_thread_loop() -> asyncio.AbstractEventLoop: + loop = getattr(_thread_local, "loop", None) + if loop is None or loop.is_closed(): + loop = asyncio.new_event_loop() + _thread_local.loop = loop + return loop + + +def run_coro_in_new_loop(coro: Any) -> Any: + result_box: Dict[str, Any] = {} + exc_box: Dict[str, BaseException] = {} + + def runner() -> None: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + result_box["value"] = loop.run_until_complete(coro) + except Exception as exc: # noqa: BLE001 + exc_box["exc"] = exc + finally: + loop.close() + + t = threading.Thread(target=runner, daemon=True) + t.start() + t.join() + if "exc" in exc_box: + raise exc_box["exc"] + return result_box.get("value") + + +def detach_if_grad(obj: Any) -> Any: + try: + import torch + except Exception: + return obj + + if isinstance(obj, torch.Tensor): + return obj.detach() if obj.requires_grad else obj + if isinstance(obj, (list, tuple)): + return type(obj)(detach_if_grad(x) for x in obj) + if isinstance(obj, dict): + return {k: detach_if_grad(v) for k, v in obj.items()} + return obj + + +class BaseRegistry(ProxiedSingleton, Generic[T]): + _type_prefix: str = "base" + + def __init__(self) -> None: + if hasattr(ProxiedSingleton, "__init__") and ProxiedSingleton is not object: + super().__init__() + self._registry: Dict[str, T] = {} + self._id_map: Dict[int, str] = {} + self._counter = 0 + self._lock = threading.Lock() + + def register(self, instance: T) -> str: + with self._lock: + obj_id = id(instance) + if obj_id in self._id_map: + return self._id_map[obj_id] + instance_id = f"{self._type_prefix}_{self._counter}" + self._counter += 1 + self._registry[instance_id] = instance + self._id_map[obj_id] = instance_id + return instance_id + + def unregister_sync(self, instance_id: str) -> None: + with self._lock: + instance = self._registry.pop(instance_id, None) + if instance: + self._id_map.pop(id(instance), None) + + def _get_instance(self, instance_id: str) -> T: + if IS_CHILD_PROCESS: + raise RuntimeError( + f"[{self.__class__.__name__}] _get_instance called in child" + ) + with self._lock: + instance = self._registry.get(instance_id) + if instance is None: + raise ValueError(f"{instance_id} not found") + return instance + + +_GLOBAL_LOOP: Optional[asyncio.AbstractEventLoop] = None + + +def set_global_loop(loop: asyncio.AbstractEventLoop) -> None: + global _GLOBAL_LOOP + _GLOBAL_LOOP = loop + + +def run_sync_rpc_coro(coro: Any, timeout_ms: Optional[int] = None) -> Any: + if timeout_ms is not None: + coro = asyncio.wait_for(coro, timeout=timeout_ms / 1000.0) + + try: + if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running(): + try: + curr_loop = asyncio.get_running_loop() + if curr_loop is _GLOBAL_LOOP: + pass + except RuntimeError: + future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP) + return future.result( + timeout=(timeout_ms / 1000.0) if timeout_ms is not None else None + ) + + try: + asyncio.get_running_loop() + return run_coro_in_new_loop(coro) + except RuntimeError: + loop = get_thread_loop() + return loop.run_until_complete(coro) + except asyncio.TimeoutError as exc: + raise TimeoutError(f"Isolation RPC timeout (timeout_ms={timeout_ms})") from exc + except concurrent.futures.TimeoutError as exc: + raise TimeoutError(f"Isolation RPC timeout (timeout_ms={timeout_ms})") from exc + + +def call_singleton_rpc( + caller: Any, + method_name: str, + *args: Any, + timeout_ms: Optional[int] = None, + **kwargs: Any, +) -> Any: + if caller is None: + raise RuntimeError(f"No RPC caller available for {method_name}") + method = getattr(caller, method_name) + return run_sync_rpc_coro(method(*args, **kwargs), timeout_ms=timeout_ms) + + +class BaseProxy(Generic[T]): + _registry_class: type = BaseRegistry # type: ignore[type-arg] + __module__: str = "comfy.isolation.proxies.base" + _TIMEOUT_RPC_METHODS = frozenset( + { + "partially_load", + "partially_unload", + "load", + "patch_model", + "unpatch_model", + "inner_model_apply_model", + "memory_required", + "model_dtype", + "inner_model_memory_required", + "inner_model_extra_conds_shapes", + "inner_model_extra_conds", + "process_latent_in", + "process_latent_out", + "scale_latent_inpaint", + } + ) + + def __init__( + self, + instance_id: str, + registry: Optional[Any] = None, + manage_lifecycle: bool = False, + ) -> None: + self._instance_id = instance_id + self._rpc_caller: Optional[Any] = None + self._registry = registry if registry is not None else self._registry_class() + self._manage_lifecycle = manage_lifecycle + self._cleaned_up = False + if manage_lifecycle and not IS_CHILD_PROCESS: + self._finalizer = weakref.finalize( + self, self._registry.unregister_sync, instance_id + ) + + def _get_rpc(self) -> Any: + if self._rpc_caller is None: + from pyisolate._internal.rpc_protocol import get_child_rpc_instance + + rpc = get_child_rpc_instance() + if rpc is None: + raise RuntimeError(f"[{self.__class__.__name__}] No RPC in child") + self._rpc_caller = rpc.create_caller( + self._registry_class, self._registry_class.get_remote_id() + ) + return self._rpc_caller + + def _rpc_timeout_ms_for_method(self, method_name: str) -> Optional[int]: + if method_name not in self._TIMEOUT_RPC_METHODS: + return None + try: + timeout_ms = int( + os.environ.get("COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS", "120000") + ) + except ValueError: + timeout_ms = 120000 + return max(1, timeout_ms) + + def _call_rpc(self, method_name: str, *args: Any, **kwargs: Any) -> Any: + rpc = self._get_rpc() + method = getattr(rpc, method_name) + timeout_ms = self._rpc_timeout_ms_for_method(method_name) + coro = method(self._instance_id, *args, **kwargs) + if timeout_ms is not None: + coro = asyncio.wait_for(coro, timeout=timeout_ms / 1000.0) + + start_epoch = time.time() + start_perf = time.perf_counter() + thread_id = threading.get_ident() + try: + running_loop = asyncio.get_running_loop() + loop_id: Optional[int] = id(running_loop) + except RuntimeError: + loop_id = None + logger.debug( + "ISO:rpc_start proxy=%s method=%s instance_id=%s start_ts=%.6f " + "thread=%s loop=%s timeout_ms=%s", + self.__class__.__name__, + method_name, + self._instance_id, + start_epoch, + thread_id, + loop_id, + timeout_ms, + ) + + try: + return run_sync_rpc_coro(coro, timeout_ms=timeout_ms) + except TimeoutError as exc: + raise TimeoutError( + f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} " + f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})" + ) from exc + finally: + end_epoch = time.time() + elapsed_ms = (time.perf_counter() - start_perf) * 1000.0 + logger.debug( + "ISO:rpc_end proxy=%s method=%s instance_id=%s end_ts=%.6f " + "elapsed_ms=%.3f thread=%s loop=%s", + self.__class__.__name__, + method_name, + self._instance_id, + end_epoch, + elapsed_ms, + thread_id, + loop_id, + ) + + def __getstate__(self) -> Dict[str, Any]: + return {"_instance_id": self._instance_id} + + def __setstate__(self, state: Dict[str, Any]) -> None: + self._instance_id = state["_instance_id"] + self._rpc_caller = None + self._registry = self._registry_class() + self._manage_lifecycle = False + self._cleaned_up = False + + def cleanup(self) -> None: + if self._cleaned_up or IS_CHILD_PROCESS: + return + self._cleaned_up = True + finalizer = getattr(self, "_finalizer", None) + if finalizer is not None: + finalizer.detach() + self._registry.unregister_sync(self._instance_id) + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} {self._instance_id}>" + + +def create_rpc_method(method_name: str) -> Callable[..., Any]: + def method(self: BaseProxy[Any], *args: Any, **kwargs: Any) -> Any: + return self._call_rpc(method_name, *args, **kwargs) + + method.__name__ = method_name + return method diff --git a/comfy/isolation/proxies/folder_paths_proxy.py b/comfy/isolation/proxies/folder_paths_proxy.py new file mode 100644 index 000000000..b324da4e5 --- /dev/null +++ b/comfy/isolation/proxies/folder_paths_proxy.py @@ -0,0 +1,221 @@ +from __future__ import annotations +import logging +import os +import traceback +from typing import Any, Dict, Optional + +from pyisolate import ProxiedSingleton + +from .base import call_singleton_rpc + +_fp_logger = logging.getLogger(__name__) + + +def _folder_paths(): + import folder_paths + + return folder_paths + + +def _is_child_process() -> bool: + return os.environ.get("PYISOLATE_CHILD") == "1" + + +def _serialize_folder_names_and_paths(data: dict[str, tuple[list[str], set[str]]]) -> dict[str, dict[str, list[str]]]: + return { + key: {"paths": list(paths), "extensions": sorted(list(extensions))} + for key, (paths, extensions) in data.items() + } + + +def _deserialize_folder_names_and_paths(data: dict[str, dict[str, list[str]]]) -> dict[str, tuple[list[str], set[str]]]: + return { + key: (list(value.get("paths", [])), set(value.get("extensions", []))) + for key, value in data.items() + } + + +class FolderPathsProxy(ProxiedSingleton): + """ + Dynamic proxy for folder_paths. + Uses __getattr__ for most lookups, with explicit handling for + mutable collections to ensure efficient by-value transfer. + """ + + _rpc: Optional[Any] = None + + @classmethod + def set_rpc(cls, rpc: Any) -> None: + cls._rpc = rpc.create_caller(cls, cls.get_remote_id()) + + @classmethod + def clear_rpc(cls) -> None: + cls._rpc = None + + @classmethod + def _get_caller(cls) -> Any: + if cls._rpc is None: + raise RuntimeError("FolderPathsProxy RPC caller is not configured") + return cls._rpc + + def __getattr__(self, name): + if _is_child_process(): + property_rpc = { + "models_dir": "rpc_get_models_dir", + "folder_names_and_paths": "rpc_get_folder_names_and_paths", + "extension_mimetypes_cache": "rpc_get_extension_mimetypes_cache", + "filename_list_cache": "rpc_get_filename_list_cache", + } + rpc_name = property_rpc.get(name) + if rpc_name is not None: + return call_singleton_rpc(self._get_caller(), rpc_name) + raise AttributeError(name) + return getattr(_folder_paths(), name) + + @property + def folder_names_and_paths(self) -> Dict: + if _is_child_process(): + payload = call_singleton_rpc(self._get_caller(), "rpc_get_folder_names_and_paths") + return _deserialize_folder_names_and_paths(payload) + return _folder_paths().folder_names_and_paths + + @property + def extension_mimetypes_cache(self) -> Dict: + if _is_child_process(): + return dict(call_singleton_rpc(self._get_caller(), "rpc_get_extension_mimetypes_cache")) + return dict(_folder_paths().extension_mimetypes_cache) + + @property + def filename_list_cache(self) -> Dict: + if _is_child_process(): + return dict(call_singleton_rpc(self._get_caller(), "rpc_get_filename_list_cache")) + return dict(_folder_paths().filename_list_cache) + + @property + def models_dir(self) -> str: + if _is_child_process(): + return str(call_singleton_rpc(self._get_caller(), "rpc_get_models_dir")) + return _folder_paths().models_dir + + def get_temp_directory(self) -> str: + if _is_child_process(): + return call_singleton_rpc(self._get_caller(), "rpc_get_temp_directory") + return _folder_paths().get_temp_directory() + + def get_input_directory(self) -> str: + if _is_child_process(): + return call_singleton_rpc(self._get_caller(), "rpc_get_input_directory") + return _folder_paths().get_input_directory() + + def get_output_directory(self) -> str: + if _is_child_process(): + return call_singleton_rpc(self._get_caller(), "rpc_get_output_directory") + return _folder_paths().get_output_directory() + + def get_user_directory(self) -> str: + if _is_child_process(): + return call_singleton_rpc(self._get_caller(), "rpc_get_user_directory") + return _folder_paths().get_user_directory() + + def get_annotated_filepath(self, name: str, default_dir: str | None = None) -> str: + if _is_child_process(): + return call_singleton_rpc( + self._get_caller(), "rpc_get_annotated_filepath", name, default_dir + ) + return _folder_paths().get_annotated_filepath(name, default_dir) + + def exists_annotated_filepath(self, name: str) -> bool: + if _is_child_process(): + return bool( + call_singleton_rpc(self._get_caller(), "rpc_exists_annotated_filepath", name) + ) + return bool(_folder_paths().exists_annotated_filepath(name)) + + def add_model_folder_path( + self, folder_name: str, full_folder_path: str, is_default: bool = False + ) -> None: + if _is_child_process(): + call_singleton_rpc( + self._get_caller(), + "rpc_add_model_folder_path", + folder_name, + full_folder_path, + is_default, + ) + return None + _folder_paths().add_model_folder_path(folder_name, full_folder_path, is_default) + return None + + def get_folder_paths(self, folder_name: str) -> list[str]: + if _is_child_process(): + return list(call_singleton_rpc(self._get_caller(), "rpc_get_folder_paths", folder_name)) + return list(_folder_paths().get_folder_paths(folder_name)) + + def get_filename_list(self, folder_name: str) -> list[str]: + caller_stack = "".join(traceback.format_stack()[-4:-1]) + _fp_logger.warning( + "][ DIAG:FolderPathsProxy.get_filename_list called | folder=%s | is_child=%s | rpc_configured=%s\n%s", + folder_name, _is_child_process(), self._rpc is not None, caller_stack, + ) + if _is_child_process(): + result = list(call_singleton_rpc(self._get_caller(), "rpc_get_filename_list", folder_name)) + _fp_logger.warning( + "][ DIAG:FolderPathsProxy.get_filename_list RPC result | folder=%s | count=%d | first=%s", + folder_name, len(result), result[:3] if result else "EMPTY", + ) + return result + result = list(_folder_paths().get_filename_list(folder_name)) + _fp_logger.warning( + "][ DIAG:FolderPathsProxy.get_filename_list LOCAL result | folder=%s | count=%d | first=%s", + folder_name, len(result), result[:3] if result else "EMPTY", + ) + return result + + def get_full_path(self, folder_name: str, filename: str) -> str | None: + if _is_child_process(): + return call_singleton_rpc(self._get_caller(), "rpc_get_full_path", folder_name, filename) + return _folder_paths().get_full_path(folder_name, filename) + + async def rpc_get_models_dir(self) -> str: + return _folder_paths().models_dir + + async def rpc_get_folder_names_and_paths(self) -> dict[str, dict[str, list[str]]]: + return _serialize_folder_names_and_paths(_folder_paths().folder_names_and_paths) + + async def rpc_get_extension_mimetypes_cache(self) -> dict[str, Any]: + return dict(_folder_paths().extension_mimetypes_cache) + + async def rpc_get_filename_list_cache(self) -> dict[str, Any]: + return dict(_folder_paths().filename_list_cache) + + async def rpc_get_temp_directory(self) -> str: + return _folder_paths().get_temp_directory() + + async def rpc_get_input_directory(self) -> str: + return _folder_paths().get_input_directory() + + async def rpc_get_output_directory(self) -> str: + return _folder_paths().get_output_directory() + + async def rpc_get_user_directory(self) -> str: + return _folder_paths().get_user_directory() + + async def rpc_get_annotated_filepath(self, name: str, default_dir: str | None = None) -> str: + return _folder_paths().get_annotated_filepath(name, default_dir) + + async def rpc_exists_annotated_filepath(self, name: str) -> bool: + return _folder_paths().exists_annotated_filepath(name) + + async def rpc_add_model_folder_path( + self, folder_name: str, full_folder_path: str, is_default: bool = False + ) -> None: + _folder_paths().add_model_folder_path(folder_name, full_folder_path, is_default) + + async def rpc_get_folder_paths(self, folder_name: str) -> list[str]: + return _folder_paths().get_folder_paths(folder_name) + + async def rpc_get_filename_list(self, folder_name: str) -> list[str]: + return _folder_paths().get_filename_list(folder_name) + + async def rpc_get_full_path(self, folder_name: str, filename: str) -> str | None: + return _folder_paths().get_full_path(folder_name, filename) diff --git a/comfy/isolation/proxies/helper_proxies.py b/comfy/isolation/proxies/helper_proxies.py new file mode 100644 index 000000000..278c098f1 --- /dev/null +++ b/comfy/isolation/proxies/helper_proxies.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +import os +from typing import Any, Dict, Optional + +from pyisolate import ProxiedSingleton + +from .base import call_singleton_rpc + + +class AnyTypeProxy(str): + """Replacement for custom AnyType objects used by some nodes.""" + + def __new__(cls, value: str = "*"): + return super().__new__(cls, value) + + def __ne__(self, other): # type: ignore[override] + return False + + +class FlexibleOptionalInputProxy(dict): + """Replacement for FlexibleOptionalInputType to allow dynamic inputs.""" + + def __init__(self, flex_type, data: Optional[Dict[str, object]] = None): + super().__init__() + self.type = flex_type + if data: + self.update(data) + + def __getitem__(self, key): # type: ignore[override] + return (self.type,) + + def __contains__(self, key): # type: ignore[override] + return True + + +class ByPassTypeTupleProxy(tuple): + """Replacement for ByPassTypeTuple to mirror wildcard fallback behavior.""" + + def __new__(cls, values): + return super().__new__(cls, values) + + def __getitem__(self, index): # type: ignore[override] + if index >= len(self): + return AnyTypeProxy("*") + return super().__getitem__(index) + + +def _restore_special_value(value: Any) -> Any: + if isinstance(value, dict): + if value.get("__pyisolate_any_type__"): + return AnyTypeProxy(value.get("value", "*")) + if value.get("__pyisolate_flexible_optional__"): + flex_type = _restore_special_value(value.get("type")) + data_raw = value.get("data") + data = ( + {k: _restore_special_value(v) for k, v in data_raw.items()} + if isinstance(data_raw, dict) + else {} + ) + return FlexibleOptionalInputProxy(flex_type, data) + if value.get("__pyisolate_tuple__") is not None: + return tuple( + _restore_special_value(v) for v in value["__pyisolate_tuple__"] + ) + if value.get("__pyisolate_bypass_tuple__") is not None: + return ByPassTypeTupleProxy( + tuple( + _restore_special_value(v) + for v in value["__pyisolate_bypass_tuple__"] + ) + ) + return {k: _restore_special_value(v) for k, v in value.items()} + if isinstance(value, list): + return [_restore_special_value(v) for v in value] + return value + + +def _serialize_special_value(value: Any) -> Any: + if isinstance(value, AnyTypeProxy): + return {"__pyisolate_any_type__": True, "value": str(value)} + if isinstance(value, FlexibleOptionalInputProxy): + return { + "__pyisolate_flexible_optional__": True, + "type": _serialize_special_value(value.type), + "data": {k: _serialize_special_value(v) for k, v in value.items()}, + } + if isinstance(value, ByPassTypeTupleProxy): + return { + "__pyisolate_bypass_tuple__": [_serialize_special_value(v) for v in value] + } + if isinstance(value, tuple): + return {"__pyisolate_tuple__": [_serialize_special_value(v) for v in value]} + if isinstance(value, list): + return [_serialize_special_value(v) for v in value] + if isinstance(value, dict): + return {k: _serialize_special_value(v) for k, v in value.items()} + return value + + +def _restore_input_types_local(raw: Dict[str, object]) -> Dict[str, object]: + if not isinstance(raw, dict): + return raw # type: ignore[return-value] + + restored: Dict[str, object] = {} + for section, entries in raw.items(): + if isinstance(entries, dict) and entries.get("__pyisolate_flexible_optional__"): + restored[section] = _restore_special_value(entries) + elif isinstance(entries, dict): + restored[section] = { + k: _restore_special_value(v) for k, v in entries.items() + } + else: + restored[section] = _restore_special_value(entries) + return restored + + +class HelperProxiesService(ProxiedSingleton): + _rpc: Optional[Any] = None + + @classmethod + def set_rpc(cls, rpc: Any) -> None: + cls._rpc = rpc.create_caller(cls, cls.get_remote_id()) + + @classmethod + def clear_rpc(cls) -> None: + cls._rpc = None + + @classmethod + def _get_caller(cls) -> Any: + if cls._rpc is None: + raise RuntimeError("HelperProxiesService RPC caller is not configured") + return cls._rpc + + async def rpc_restore_input_types(self, raw: Dict[str, object]) -> Dict[str, object]: + restored = _restore_input_types_local(raw) + return _serialize_special_value(restored) + + +def restore_input_types(raw: Dict[str, object]) -> Dict[str, object]: + """Restore serialized INPUT_TYPES payload back into ComfyUI-compatible objects.""" + if os.environ.get("PYISOLATE_CHILD") == "1": + payload = call_singleton_rpc( + HelperProxiesService._get_caller(), + "rpc_restore_input_types", + raw, + ) + return _restore_input_types_local(payload) + return _restore_input_types_local(raw) + + +__all__ = [ + "AnyTypeProxy", + "FlexibleOptionalInputProxy", + "ByPassTypeTupleProxy", + "HelperProxiesService", + "restore_input_types", +] diff --git a/comfy/isolation/proxies/model_management_proxy.py b/comfy/isolation/proxies/model_management_proxy.py new file mode 100644 index 000000000..445210aa4 --- /dev/null +++ b/comfy/isolation/proxies/model_management_proxy.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +import os +from typing import Any, Optional + +from pyisolate import ProxiedSingleton + +from .base import call_singleton_rpc + + +def _mm(): + import comfy.model_management + + return comfy.model_management + + +def _is_child_process() -> bool: + return os.environ.get("PYISOLATE_CHILD") == "1" + + +class TorchDeviceProxy: + def __init__(self, device_str: str): + self._device_str = device_str + if ":" in device_str: + device_type, index = device_str.split(":", 1) + self.type = device_type + self.index = int(index) + else: + self.type = device_str + self.index = None + + def __str__(self) -> str: + return self._device_str + + def __repr__(self) -> str: + return f"TorchDeviceProxy({self._device_str!r})" + + +def _serialize_value(value: Any) -> Any: + value_type = type(value) + if value_type.__module__ == "torch" and value_type.__name__ == "device": + return {"__pyisolate_torch_device__": str(value)} + if isinstance(value, TorchDeviceProxy): + return {"__pyisolate_torch_device__": str(value)} + if isinstance(value, tuple): + return {"__pyisolate_tuple__": [_serialize_value(item) for item in value]} + if isinstance(value, list): + return [_serialize_value(item) for item in value] + if isinstance(value, dict): + return {key: _serialize_value(inner) for key, inner in value.items()} + return value + + +def _deserialize_value(value: Any) -> Any: + if isinstance(value, dict): + if "__pyisolate_torch_device__" in value: + return TorchDeviceProxy(value["__pyisolate_torch_device__"]) + if "__pyisolate_tuple__" in value: + return tuple(_deserialize_value(item) for item in value["__pyisolate_tuple__"]) + return {key: _deserialize_value(inner) for key, inner in value.items()} + if isinstance(value, list): + return [_deserialize_value(item) for item in value] + return value + + +def _normalize_argument(value: Any) -> Any: + if isinstance(value, TorchDeviceProxy): + import torch + + return torch.device(str(value)) + if isinstance(value, dict): + if "__pyisolate_torch_device__" in value: + import torch + + return torch.device(value["__pyisolate_torch_device__"]) + if "__pyisolate_tuple__" in value: + return tuple(_normalize_argument(item) for item in value["__pyisolate_tuple__"]) + return {key: _normalize_argument(inner) for key, inner in value.items()} + if isinstance(value, list): + return [_normalize_argument(item) for item in value] + return value + + +class ModelManagementProxy(ProxiedSingleton): + """ + Exact-relay proxy for comfy.model_management. + Child calls never import comfy.model_management directly; they serialize + arguments, relay to host, and deserialize the host result back. + """ + + _rpc: Optional[Any] = None + + @classmethod + def set_rpc(cls, rpc: Any) -> None: + cls._rpc = rpc.create_caller(cls, cls.get_remote_id()) + + @classmethod + def clear_rpc(cls) -> None: + cls._rpc = None + + @classmethod + def _get_caller(cls) -> Any: + if cls._rpc is None: + raise RuntimeError("ModelManagementProxy RPC caller is not configured") + return cls._rpc + + def _relay_call(self, method_name: str, *args: Any, **kwargs: Any) -> Any: + payload = call_singleton_rpc( + self._get_caller(), + "rpc_call", + method_name, + _serialize_value(args), + _serialize_value(kwargs), + ) + return _deserialize_value(payload) + + @property + def VRAMState(self): + return _mm().VRAMState + + @property + def CPUState(self): + return _mm().CPUState + + @property + def OOM_EXCEPTION(self): + return _mm().OOM_EXCEPTION + + def __getattr__(self, name: str): + if _is_child_process(): + def child_method(*args: Any, **kwargs: Any) -> Any: + return self._relay_call(name, *args, **kwargs) + + return child_method + return getattr(_mm(), name) + + async def rpc_call(self, method_name: str, args: Any, kwargs: Any) -> Any: + normalized_args = _normalize_argument(_deserialize_value(args)) + normalized_kwargs = _normalize_argument(_deserialize_value(kwargs)) + method = getattr(_mm(), method_name) + result = method(*normalized_args, **normalized_kwargs) + return _serialize_value(result) diff --git a/comfy/isolation/proxies/progress_proxy.py b/comfy/isolation/proxies/progress_proxy.py new file mode 100644 index 000000000..8f270afa0 --- /dev/null +++ b/comfy/isolation/proxies/progress_proxy.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +import logging +import os +from typing import Any, Optional + +try: + from pyisolate import ProxiedSingleton +except ImportError: + + class ProxiedSingleton: + pass + +from .base import call_singleton_rpc + + +def _get_progress_state(): + from comfy_execution.progress import get_progress_state + + return get_progress_state() + + +def _is_child_process() -> bool: + return os.environ.get("PYISOLATE_CHILD") == "1" + +logger = logging.getLogger(__name__) + + +class ProgressProxy(ProxiedSingleton): + _rpc: Optional[Any] = None + + @classmethod + def set_rpc(cls, rpc: Any) -> None: + cls._rpc = rpc.create_caller(cls, cls.get_remote_id()) + + @classmethod + def clear_rpc(cls) -> None: + cls._rpc = None + + @classmethod + def _get_caller(cls) -> Any: + if cls._rpc is None: + raise RuntimeError("ProgressProxy RPC caller is not configured") + return cls._rpc + + def set_progress( + self, + value: float, + max_value: float, + node_id: Optional[str] = None, + image: Any = None, + ) -> None: + if _is_child_process(): + call_singleton_rpc( + self._get_caller(), + "rpc_set_progress", + value, + max_value, + node_id, + image, + ) + return None + + _get_progress_state().update_progress( + node_id=node_id, + value=value, + max_value=max_value, + image=image, + ) + return None + + async def rpc_set_progress( + self, + value: float, + max_value: float, + node_id: Optional[str] = None, + image: Any = None, + ) -> None: + _get_progress_state().update_progress( + node_id=node_id, + value=value, + max_value=max_value, + image=image, + ) + + +__all__ = ["ProgressProxy"] diff --git a/comfy/isolation/proxies/prompt_server_impl.py b/comfy/isolation/proxies/prompt_server_impl.py new file mode 100644 index 000000000..3f500522e --- /dev/null +++ b/comfy/isolation/proxies/prompt_server_impl.py @@ -0,0 +1,271 @@ +# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,redefined-outer-name,reimported,super-init-not-called +"""Stateless RPC Implementation for PromptServer. + +Replaces the legacy PromptServerProxy (Singleton) with a clean Service/Stub architecture. +- Host: PromptServerService (RPC Handler) +- Child: PromptServerStub (Interface Implementation) +""" + +from __future__ import annotations + +import asyncio +import os +from typing import Any, Dict, Optional, Callable + +import logging + +# IMPORTS +from pyisolate import ProxiedSingleton +from .base import call_singleton_rpc + +logger = logging.getLogger(__name__) +LOG_PREFIX = "[Isolation:C<->H]" + +# ... + +# ============================================================================= +# CHILD SIDE: PromptServerStub +# ============================================================================= + + +class PromptServerStub: + """Stateless Stub for PromptServer.""" + + # Masquerade as the real server module + __module__ = "server" + + _instance: Optional["PromptServerStub"] = None + _rpc: Optional[Any] = None # This will be the Caller object + _source_file: Optional[str] = None + + def __init__(self): + self.routes = RouteStub(self) + + @classmethod + def set_rpc(cls, rpc: Any) -> None: + """Inject RPC client (called by adapter.py or manually).""" + # Create caller for HOST Service + # Assuming Host Service is registered as "PromptServerService" (class name) + # We target the Host Service Class + target_id = "PromptServerService" + # We need to pass a class to create_caller? Usually yes. + # But we don't have the Service class imported here necessarily (if running on child). + # pyisolate check verify_service type? + # If we pass PromptServerStub as the 'class', it might mismatch if checking types. + # But we can try passing PromptServerStub if it mirrors the service name? No, stub is PromptServerStub. + # We need a dummy class with right name? + # Or just rely on string ID if create_caller supports it? + # Standard: rpc.create_caller(PromptServerStub, target_id) + # But wait, PromptServerStub is the *Local* class. + # We want to call *Remote* class. + # If we use PromptServerStub as the type, returning object will be typed as PromptServerStub? + # The first arg is 'service_cls'. + cls._rpc = rpc.create_caller( + PromptServerService, target_id + ) # We import Service below? + + @classmethod + def clear_rpc(cls) -> None: + cls._rpc = None + + # We need PromptServerService available for the create_caller call? + # Or just use the Stub class if ID matches? + # prompt_server_impl.py defines BOTH. So PromptServerService IS available! + + @property + def instance(self) -> "PromptServerStub": + return self + + # ... Compatibility ... + @classmethod + def _get_source_file(cls) -> str: + if cls._source_file is None: + import folder_paths + + cls._source_file = os.path.join(folder_paths.base_path, "server.py") + return cls._source_file + + @property + def __file__(self) -> str: + return self._get_source_file() + + # --- Properties --- + @property + def client_id(self) -> Optional[str]: + return "isolated_client" + + def supports(self, feature: str) -> bool: + return True + + @property + def app(self): + raise RuntimeError( + "PromptServer.app is not accessible in isolated nodes. Use RPC routes instead." + ) + + @property + def prompt_queue(self): + raise RuntimeError( + "PromptServer.prompt_queue is not accessible in isolated nodes." + ) + + # --- UI Communication (RPC Delegates) --- + async def send_sync( + self, event: str, data: Dict[str, Any], sid: Optional[str] = None + ) -> None: + if self._rpc: + await self._rpc.ui_send_sync(event, data, sid) + + async def send( + self, event: str, data: Dict[str, Any], sid: Optional[str] = None + ) -> None: + if self._rpc: + await self._rpc.ui_send(event, data, sid) + + def send_progress_text(self, text: str, node_id: str, sid=None) -> None: + if self._rpc: + # Fire and forget likely needed. If method is async on host, caller invocation returns coroutine. + # We must schedule it? + # Or use fire_remote equivalent? + # Caller object usually proxies calls. If host method is async, it returns coro. + # If we are sync here (send_progress_text checks imply sync usage), we must background it. + # But UtilsProxy hook wrapper creates task. + # Does send_progress_text need to be sync? Yes, node code calls it sync. + import asyncio + + try: + loop = asyncio.get_running_loop() + loop.create_task(self._rpc.ui_send_progress_text(text, node_id, sid)) + except RuntimeError: + call_singleton_rpc(self._rpc, "ui_send_progress_text", text, node_id, sid) + + # --- Route Registration Logic --- + def register_route(self, method: str, path: str, handler: Callable): + """Register a route handler via RPC.""" + if not self._rpc: + logger.error("RPC not initialized in PromptServerStub") + return + + # Fire registration async + try: + loop = asyncio.get_running_loop() + loop.create_task(self._rpc.register_route_rpc(method, path, handler)) + except RuntimeError: + call_singleton_rpc(self._rpc, "register_route_rpc", method, path, handler) + + +class RouteStub: + """Simulates aiohttp.web.RouteTableDef.""" + + def __init__(self, stub: PromptServerStub): + self._stub = stub + + def get(self, path: str): + def decorator(handler): + self._stub.register_route("GET", path, handler) + return handler + + return decorator + + def post(self, path: str): + def decorator(handler): + self._stub.register_route("POST", path, handler) + return handler + + return decorator + + def patch(self, path: str): + def decorator(handler): + self._stub.register_route("PATCH", path, handler) + return handler + + return decorator + + def put(self, path: str): + def decorator(handler): + self._stub.register_route("PUT", path, handler) + return handler + + return decorator + + def delete(self, path: str): + def decorator(handler): + self._stub.register_route("DELETE", path, handler) + return handler + + return decorator + + +# ============================================================================= +# HOST SIDE: PromptServerService +# ============================================================================= + + +class PromptServerService(ProxiedSingleton): + """Host-side RPC Service for PromptServer.""" + + def __init__(self): + # We will bind to the real server instance lazily or via global import + pass + + @property + def server(self): + from server import PromptServer + + return PromptServer.instance + + async def ui_send_sync( + self, event: str, data: Dict[str, Any], sid: Optional[str] = None + ): + await self.server.send_sync(event, data, sid) + + async def ui_send( + self, event: str, data: Dict[str, Any], sid: Optional[str] = None + ): + await self.server.send(event, data, sid) + + async def ui_send_progress_text(self, text: str, node_id: str, sid=None): + # Made async to be awaitable by RPC layer + self.server.send_progress_text(text, node_id, sid) + + async def register_route_rpc(self, method: str, path: str, child_handler_proxy): + """RPC Target: Register a route that forwards to the Child.""" + from aiohttp import web + logger.debug(f"{LOG_PREFIX} Registering Isolated Route {method} {path}") + + async def route_wrapper(request: web.Request) -> web.Response: + # 1. Capture request data + req_data = { + "method": request.method, + "path": request.path, + "query": dict(request.query), + } + if request.can_read_body: + req_data["text"] = await request.text() + + try: + # 2. Call Child Handler via RPC (child_handler_proxy is async callable) + result = await child_handler_proxy(req_data) + + # 3. Serialize Response + return self._serialize_response(result) + except Exception as e: + logger.error(f"{LOG_PREFIX} Isolated Route Error: {e}") + return web.Response(status=500, text=str(e)) + + # Register loop + self.server.app.router.add_route(method, path, route_wrapper) + + def _serialize_response(self, result: Any) -> Any: + """Helper to convert Child result -> web.Response""" + from aiohttp import web + if isinstance(result, web.Response): + return result + # Handle dict (json) + if isinstance(result, dict): + return web.json_response(result) + # Handle string + if isinstance(result, str): + return web.Response(text=result) + # Fallback + return web.Response(text=str(result)) diff --git a/comfy/isolation/proxies/utils_proxy.py b/comfy/isolation/proxies/utils_proxy.py new file mode 100644 index 000000000..f84727bbb --- /dev/null +++ b/comfy/isolation/proxies/utils_proxy.py @@ -0,0 +1,64 @@ +# pylint: disable=cyclic-import,import-outside-toplevel +from __future__ import annotations + +from typing import Optional, Any +from pyisolate import ProxiedSingleton + +import os + + +def _comfy_utils(): + import comfy.utils + return comfy.utils + + +class UtilsProxy(ProxiedSingleton): + """ + Proxy for comfy.utils. + Primarily handles the PROGRESS_BAR_HOOK to ensure progress updates + from isolated nodes reach the host. + """ + + # _instance and __new__ removed to rely on SingletonMetaclass + _rpc: Optional[Any] = None + + @classmethod + def set_rpc(cls, rpc: Any) -> None: + # Create caller using class name as ID (standard for Singletons) + cls._rpc = rpc.create_caller(cls, "UtilsProxy") + + @classmethod + def clear_rpc(cls) -> None: + cls._rpc = None + + async def progress_bar_hook( + self, + value: int, + total: int, + preview: Optional[bytes] = None, + node_id: Optional[str] = None, + ) -> Any: + """ + Host-side implementation: forwards the call to the real global hook. + Child-side: this method call is intercepted by RPC and sent to host. + """ + if os.environ.get("PYISOLATE_CHILD") == "1": + if UtilsProxy._rpc is None: + raise RuntimeError("UtilsProxy RPC caller is not configured") + return await UtilsProxy._rpc.progress_bar_hook( + value, total, preview, node_id + ) + + # Host Execution + utils = _comfy_utils() + if utils.PROGRESS_BAR_HOOK is not None: + return utils.PROGRESS_BAR_HOOK(value, total, preview, node_id) + return None + + def set_progress_bar_global_hook(self, hook: Any) -> None: + """Forward hook registration (though usually not needed from child).""" + if os.environ.get("PYISOLATE_CHILD") == "1": + raise RuntimeError( + "UtilsProxy.set_progress_bar_global_hook is not available in child without exact relay support" + ) + _comfy_utils().set_progress_bar_global_hook(hook) diff --git a/comfy/isolation/proxies/web_directory_proxy.py b/comfy/isolation/proxies/web_directory_proxy.py new file mode 100644 index 000000000..3acf3f4fc --- /dev/null +++ b/comfy/isolation/proxies/web_directory_proxy.py @@ -0,0 +1,219 @@ +"""WebDirectoryProxy — serves isolated node web assets via RPC. + +Child side: enumerates and reads files from the extension's web/ directory. +Host side: gets an RPC proxy that fetches file listings and contents on demand. + +Only files with allowed extensions (.js, .html, .css) are served. +Directory traversal is rejected. File contents are base64-encoded for +safe JSON-RPC transport. +""" + +from __future__ import annotations + +import base64 +import logging +import os +from pathlib import Path, PurePosixPath +from typing import Any, Dict, List + +from pyisolate import ProxiedSingleton + +logger = logging.getLogger(__name__) + +ALLOWED_EXTENSIONS = frozenset({".js", ".html", ".css"}) + +MIME_TYPES = { + ".js": "application/javascript", + ".html": "text/html", + ".css": "text/css", +} + + +class WebDirectoryProxy(ProxiedSingleton): + """Proxy for serving isolated extension web directories. + + On the child side, this class has direct filesystem access to the + extension's web/ directory. On the host side, callers get an RPC + proxy whose method calls are forwarded to the child. + """ + + # {extension_name: absolute_path_to_web_dir} + _web_dirs: dict[str, str] = {} + + @classmethod + def register_web_dir(cls, extension_name: str, web_dir_path: str) -> None: + """Register an extension's web directory (child-side only).""" + cls._web_dirs[extension_name] = web_dir_path + logger.info( + "][ WebDirectoryProxy: registered %s -> %s", + extension_name, + web_dir_path, + ) + + def list_web_files(self, extension_name: str) -> List[Dict[str, str]]: + """Return a list of servable files in the extension's web directory. + + Each entry is {"relative_path": "js/foo.js", "content_type": "application/javascript"}. + Only files with allowed extensions are included. + """ + web_dir = self._web_dirs.get(extension_name) + if not web_dir: + return [] + + root = Path(web_dir) + if not root.is_dir(): + return [] + + result: List[Dict[str, str]] = [] + for path in sorted(root.rglob("*")): + if not path.is_file(): + continue + ext = path.suffix.lower() + if ext not in ALLOWED_EXTENSIONS: + continue + rel = path.relative_to(root) + result.append({ + "relative_path": str(PurePosixPath(rel)), + "content_type": MIME_TYPES[ext], + }) + return result + + def get_web_file( + self, extension_name: str, relative_path: str + ) -> Dict[str, Any]: + """Return the contents of a single web file as base64. + + Raises ValueError for traversal attempts or disallowed file types. + Returns {"content": , "content_type": }. + """ + _validate_path(relative_path) + + web_dir = self._web_dirs.get(extension_name) + if not web_dir: + raise FileNotFoundError( + f"No web directory registered for {extension_name}" + ) + + root = Path(web_dir) + target = (root / relative_path).resolve() + + # Ensure resolved path is under the web directory + if not str(target).startswith(str(root.resolve())): + raise ValueError(f"Path escapes web directory: {relative_path}") + + if not target.is_file(): + raise FileNotFoundError(f"File not found: {relative_path}") + + ext = target.suffix.lower() + if ext not in ALLOWED_EXTENSIONS: + raise ValueError(f"Disallowed file type: {ext}") + + content_type = MIME_TYPES[ext] + raw = target.read_bytes() + + return { + "content": base64.b64encode(raw).decode("ascii"), + "content_type": content_type, + } + + +def _validate_path(relative_path: str) -> None: + """Reject directory traversal and absolute paths.""" + if os.path.isabs(relative_path): + raise ValueError(f"Absolute paths are not allowed: {relative_path}") + if ".." in PurePosixPath(relative_path).parts: + raise ValueError(f"Directory traversal is not allowed: {relative_path}") + + +# --------------------------------------------------------------------------- +# Host-side cache and aiohttp handler +# --------------------------------------------------------------------------- + + +class WebDirectoryCache: + """Host-side in-memory cache for proxied web directory contents. + + Populated lazily via RPC calls to the child's WebDirectoryProxy. + Once a file is cached, subsequent requests are served from memory. + """ + + def __init__(self) -> None: + # {extension_name: {relative_path: {"content": bytes, "content_type": str}}} + self._file_cache: dict[str, dict[str, dict[str, Any]]] = {} + # {extension_name: [{"relative_path": str, "content_type": str}, ...]} + self._listing_cache: dict[str, list[dict[str, str]]] = {} + # {extension_name: WebDirectoryProxy (RPC proxy instance)} + self._proxies: dict[str, Any] = {} + + def register_proxy(self, extension_name: str, proxy: Any) -> None: + """Register an RPC proxy for an extension's web directory.""" + self._proxies[extension_name] = proxy + logger.info( + "][ WebDirectoryCache: registered proxy for %s", extension_name + ) + + @property + def extension_names(self) -> list[str]: + return list(self._proxies.keys()) + + def list_files(self, extension_name: str) -> list[dict[str, str]]: + """List servable files for an extension (cached after first call).""" + if extension_name not in self._listing_cache: + proxy = self._proxies.get(extension_name) + if proxy is None: + return [] + try: + self._listing_cache[extension_name] = proxy.list_web_files( + extension_name + ) + except Exception: + logger.warning( + "][ WebDirectoryCache: failed to list files for %s", + extension_name, + exc_info=True, + ) + return [] + return self._listing_cache[extension_name] + + def get_file( + self, extension_name: str, relative_path: str + ) -> dict[str, Any] | None: + """Get file content (cached after first fetch). Returns None on miss.""" + ext_cache = self._file_cache.get(extension_name) + if ext_cache and relative_path in ext_cache: + return ext_cache[relative_path] + + proxy = self._proxies.get(extension_name) + if proxy is None: + return None + + try: + result = proxy.get_web_file(extension_name, relative_path) + except (FileNotFoundError, ValueError): + return None + except Exception: + logger.warning( + "][ WebDirectoryCache: failed to fetch %s/%s", + extension_name, + relative_path, + exc_info=True, + ) + return None + + decoded = { + "content": base64.b64decode(result["content"]), + "content_type": result["content_type"], + } + + if extension_name not in self._file_cache: + self._file_cache[extension_name] = {} + self._file_cache[extension_name][relative_path] = decoded + return decoded + + +# Global cache instance — populated during isolation loading +_web_directory_cache = WebDirectoryCache() + + +def get_web_directory_cache() -> WebDirectoryCache: + return _web_directory_cache