From 683e2d6a733e4bdf731b70115059ef867e09572c Mon Sep 17 00:00:00 2001 From: John Pollock Date: Sun, 29 Mar 2026 19:02:43 -0500 Subject: [PATCH] feat: model and service proxies for isolated custom nodes --- comfy/isolation/model_patcher_proxy.py | 59 +++-- .../isolation/model_patcher_proxy_registry.py | 113 +++++++-- comfy/isolation/proxies/base.py | 68 ++++-- comfy/isolation/proxies/folder_paths_proxy.py | 189 ++++++++++++++- comfy/isolation/proxies/helper_proxies.py | 64 ++++- .../proxies/model_management_proxy.py | 137 ++++++++++- comfy/isolation/proxies/progress_proxy.py | 56 ++++- comfy/isolation/proxies/prompt_server_impl.py | 14 +- comfy/isolation/proxies/utils_proxy.py | 46 ++-- .../isolation/proxies/web_directory_proxy.py | 219 ++++++++++++++++++ 10 files changed, 858 insertions(+), 107 deletions(-) create mode 100644 comfy/isolation/proxies/web_directory_proxy.py diff --git a/comfy/isolation/model_patcher_proxy.py b/comfy/isolation/model_patcher_proxy.py index e1c513933..f44de1d5a 100644 --- a/comfy/isolation/model_patcher_proxy.py +++ b/comfy/isolation/model_patcher_proxy.py @@ -22,6 +22,16 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]): __module__ = "comfy.model_patcher" _APPLY_MODEL_GUARD_PADDING_BYTES = 32 * 1024 * 1024 + def _spawn_related_proxy(self, instance_id: str) -> "ModelPatcherProxy": + proxy = ModelPatcherProxy( + instance_id, + self._registry, + manage_lifecycle=not IS_CHILD_PROCESS, + ) + if getattr(self, "_rpc_caller", None) is not None: + proxy._rpc_caller = self._rpc_caller + return proxy + def _get_rpc(self) -> Any: if self._rpc_caller is None: from pyisolate._internal.rpc_protocol import get_child_rpc_instance @@ -164,9 +174,7 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]): def clone(self) -> ModelPatcherProxy: new_id = self._call_rpc("clone") - return ModelPatcherProxy( - new_id, self._registry, manage_lifecycle=not IS_CHILD_PROCESS - ) + return self._spawn_related_proxy(new_id) def clone_has_same_weights(self, clone: Any) -> bool: if isinstance(clone, ModelPatcherProxy): @@ -509,11 +517,7 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]): ) new_model = None if result.get("model_id"): - new_model = ModelPatcherProxy( - result["model_id"], - self._registry, - manage_lifecycle=not IS_CHILD_PROCESS, - ) + new_model = self._spawn_related_proxy(result["model_id"]) new_clip = None if result.get("clip_id"): from comfy.isolation.clip_proxy import CLIPProxy @@ -789,12 +793,7 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]): def get_additional_models(self) -> List[ModelPatcherProxy]: ids = self._call_rpc("get_additional_models") - return [ - ModelPatcherProxy( - mid, self._registry, manage_lifecycle=not IS_CHILD_PROCESS - ) - for mid in ids - ] + return [self._spawn_related_proxy(mid) for mid in ids] def model_patches_models(self) -> Any: return self._call_rpc("model_patches_models") @@ -803,6 +802,25 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]): def parent(self) -> Any: return self._call_rpc("get_parent") + def model_mmap_residency(self, free: bool = False) -> tuple: + result = self._call_rpc("model_mmap_residency", free) + if isinstance(result, list): + return tuple(result) + return result + + def pinned_memory_size(self) -> int: + return self._call_rpc("pinned_memory_size") + + def get_non_dynamic_delegate(self) -> ModelPatcherProxy: + new_id = self._call_rpc("get_non_dynamic_delegate") + return self._spawn_related_proxy(new_id) + + def disable_model_cfg1_optimization(self) -> None: + self._call_rpc("disable_model_cfg1_optimization") + + def set_model_noise_refiner_patch(self, patch: Any) -> None: + self.set_model_patch(patch, "noise_refiner") + class _InnerModelProxy: def __init__(self, parent: ModelPatcherProxy): @@ -812,8 +830,14 @@ class _InnerModelProxy: def __getattr__(self, name: str) -> Any: if name.startswith("_"): raise AttributeError(name) + if name == "model_config": + from types import SimpleNamespace + + data = self._parent._call_rpc("get_inner_model_attr", name) + if isinstance(data, dict): + return SimpleNamespace(**data) + return data if name in ( - "model_config", "latent_format", "model_type", "current_weight_patches_uuid", @@ -824,11 +848,14 @@ class _InnerModelProxy: if name == "device": return self._parent._call_rpc("get_inner_model_attr", "device") if name == "current_patcher": - return ModelPatcherProxy( + proxy = ModelPatcherProxy( self._parent._instance_id, self._parent._registry, manage_lifecycle=False, ) + if getattr(self._parent, "_rpc_caller", None) is not None: + proxy._rpc_caller = self._parent._rpc_caller + return proxy if name == "model_sampling": if self._model_sampling is None: self._model_sampling = self._parent._call_rpc( diff --git a/comfy/isolation/model_patcher_proxy_registry.py b/comfy/isolation/model_patcher_proxy_registry.py index c696f6a0a..b657121eb 100644 --- a/comfy/isolation/model_patcher_proxy_registry.py +++ b/comfy/isolation/model_patcher_proxy_registry.py @@ -250,22 +250,47 @@ class ModelPatcherRegistry(BaseRegistry[Any]): return f"" result = instance.get_model_object(name) if name == "model_sampling": - from comfy.isolation.model_sampling_proxy import ( - ModelSamplingRegistry, - ModelSamplingProxy, - ) - - registry = ModelSamplingRegistry() - # Preserve identity when upstream already returned a proxy. Re-registering - # a proxy object creates proxy-of-proxy call chains. - if isinstance(result, ModelSamplingProxy): - sampling_id = result._instance_id - else: - sampling_id = registry.register(result) - return ModelSamplingProxy(sampling_id, registry) + # Return inline serialization so the child reconstructs the real + # class with correct isinstance behavior. Returning a + # ModelSamplingProxy breaks isinstance checks (e.g. + # offset_first_sigma_for_snr in k_diffusion/sampling.py:173). + return self._serialize_model_sampling_inline(result) return detach_if_grad(result) + @staticmethod + def _serialize_model_sampling_inline(obj: Any) -> dict: + """Serialize a ModelSampling object as inline data for the child to reconstruct.""" + import torch + import base64 + import io as _io + + bases = [] + for base in type(obj).__mro__: + if base.__module__ == "comfy.model_sampling" and base.__name__ != "object": + bases.append(base.__name__) + + sd = obj.state_dict() + sd_serialized = {} + for k, v in sd.items(): + buf = _io.BytesIO() + torch.save(v, buf) + sd_serialized[k] = base64.b64encode(buf.getvalue()).decode("ascii") + + plain_attrs = {} + for k, v in obj.__dict__.items(): + if k.startswith("_"): + continue + if isinstance(v, (bool, int, float, str)): + plain_attrs[k] = v + + return { + "__type__": "ModelSamplingInline", + "bases": bases, + "state_dict": sd_serialized, + "attrs": plain_attrs, + } + async def get_model_options(self, instance_id: str) -> dict: instance = self._get_instance(instance_id) import copy @@ -348,6 +373,20 @@ class ModelPatcherRegistry(BaseRegistry[Any]): async def get_ram_usage(self, instance_id: str) -> int: return self._get_instance(instance_id).get_ram_usage() + async def model_mmap_residency(self, instance_id: str, free: bool = False) -> tuple: + return self._get_instance(instance_id).model_mmap_residency(free=free) + + async def pinned_memory_size(self, instance_id: str) -> int: + return self._get_instance(instance_id).pinned_memory_size() + + async def get_non_dynamic_delegate(self, instance_id: str) -> str: + instance = self._get_instance(instance_id) + delegate = instance.get_non_dynamic_delegate() + return self.register(delegate) + + async def disable_model_cfg1_optimization(self, instance_id: str) -> None: + self._get_instance(instance_id).disable_model_cfg1_optimization() + async def lowvram_patch_counter(self, instance_id: str) -> int: return self._get_instance(instance_id).lowvram_patch_counter() @@ -959,12 +998,54 @@ class ModelPatcherRegistry(BaseRegistry[Any]): async def get_inner_model_attr(self, instance_id: str, name: str) -> Any: try: - return self._sanitize_rpc_result( - getattr(self._get_instance(instance_id).model, name) - ) + value = getattr(self._get_instance(instance_id).model, name) + if name == "model_config": + value = self._extract_model_config(value) + return self._sanitize_rpc_result(value) except AttributeError: return None + @staticmethod + def _extract_model_config(config: Any) -> dict: + """Extract JSON-safe attributes from a model config object. + + ComfyUI model config classes (supported_models_base.BASE subclasses) + have a permissive __getattr__ that returns None for any unknown + attribute instead of raising AttributeError. This defeats hasattr-based + duck-typing in _sanitize_rpc_result, causing TypeError when it tries + to call obj.items() (which resolves to None). We extract the real + class-level and instance-level attributes into a plain dict. + """ + # Attributes consumed by ModelSampling*.__init__ and other callers + _CONFIG_KEYS = ( + "sampling_settings", + "unet_config", + "unet_extra_config", + "latent_format", + "manual_cast_dtype", + "custom_operations", + "optimizations", + "memory_usage_factor", + "supported_inference_dtypes", + ) + result: dict = {} + for key in _CONFIG_KEYS: + # Use type(config).__dict__ first (class attrs), then instance __dict__ + # to avoid triggering the permissive __getattr__ + if key in type(config).__dict__: + val = type(config).__dict__[key] + # Skip classmethods/staticmethods/descriptors + if not callable(val) or isinstance(val, (dict, list, tuple)): + result[key] = val + elif hasattr(config, "__dict__") and key in config.__dict__: + result[key] = config.__dict__[key] + # Also include instance overrides (e.g. set_inference_dtype sets unet_config['dtype']) + if hasattr(config, "__dict__"): + for key, val in config.__dict__.items(): + if key in _CONFIG_KEYS: + result[key] = val + return result + async def inner_model_memory_required( self, instance_id: str, args: tuple, kwargs: dict ) -> Any: diff --git a/comfy/isolation/proxies/base.py b/comfy/isolation/proxies/base.py index 71cc1943c..498554217 100644 --- a/comfy/isolation/proxies/base.py +++ b/comfy/isolation/proxies/base.py @@ -118,6 +118,47 @@ def set_global_loop(loop: asyncio.AbstractEventLoop) -> None: _GLOBAL_LOOP = loop +def run_sync_rpc_coro(coro: Any, timeout_ms: Optional[int] = None) -> Any: + if timeout_ms is not None: + coro = asyncio.wait_for(coro, timeout=timeout_ms / 1000.0) + + try: + if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running(): + try: + curr_loop = asyncio.get_running_loop() + if curr_loop is _GLOBAL_LOOP: + pass + except RuntimeError: + future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP) + return future.result( + timeout=(timeout_ms / 1000.0) if timeout_ms is not None else None + ) + + try: + asyncio.get_running_loop() + return run_coro_in_new_loop(coro) + except RuntimeError: + loop = get_thread_loop() + return loop.run_until_complete(coro) + except asyncio.TimeoutError as exc: + raise TimeoutError(f"Isolation RPC timeout (timeout_ms={timeout_ms})") from exc + except concurrent.futures.TimeoutError as exc: + raise TimeoutError(f"Isolation RPC timeout (timeout_ms={timeout_ms})") from exc + + +def call_singleton_rpc( + caller: Any, + method_name: str, + *args: Any, + timeout_ms: Optional[int] = None, + **kwargs: Any, +) -> Any: + if caller is None: + raise RuntimeError(f"No RPC caller available for {method_name}") + method = getattr(caller, method_name) + return run_sync_rpc_coro(method(*args, **kwargs), timeout_ms=timeout_ms) + + class BaseProxy(Generic[T]): _registry_class: type = BaseRegistry # type: ignore[type-arg] __module__: str = "comfy.isolation.proxies.base" @@ -208,31 +249,8 @@ class BaseProxy(Generic[T]): ) try: - # If we have a global loop (Main Thread Loop), use it for dispatch from worker threads - if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running(): - try: - curr_loop = asyncio.get_running_loop() - if curr_loop is _GLOBAL_LOOP: - pass - except RuntimeError: - # No running loop - we are in a worker thread. - future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP) - return future.result( - timeout=(timeout_ms / 1000.0) if timeout_ms is not None else None - ) - - try: - asyncio.get_running_loop() - return run_coro_in_new_loop(coro) - except RuntimeError: - loop = get_thread_loop() - return loop.run_until_complete(coro) - except asyncio.TimeoutError as exc: - raise TimeoutError( - f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} " - f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})" - ) from exc - except concurrent.futures.TimeoutError as exc: + return run_sync_rpc_coro(coro, timeout_ms=timeout_ms) + except TimeoutError as exc: raise TimeoutError( f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} " f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})" diff --git a/comfy/isolation/proxies/folder_paths_proxy.py b/comfy/isolation/proxies/folder_paths_proxy.py index a2996ec24..eb077c817 100644 --- a/comfy/isolation/proxies/folder_paths_proxy.py +++ b/comfy/isolation/proxies/folder_paths_proxy.py @@ -1,9 +1,35 @@ from __future__ import annotations -from typing import Dict +import os +from typing import Any, Dict, Optional -import folder_paths from pyisolate import ProxiedSingleton +from .base import call_singleton_rpc + + +def _folder_paths(): + import folder_paths + + return folder_paths + + +def _is_child_process() -> bool: + return os.environ.get("PYISOLATE_CHILD") == "1" + + +def _serialize_folder_names_and_paths(data: dict[str, tuple[list[str], set[str]]]) -> dict[str, dict[str, list[str]]]: + return { + key: {"paths": list(paths), "extensions": sorted(list(extensions))} + for key, (paths, extensions) in data.items() + } + + +def _deserialize_folder_names_and_paths(data: dict[str, dict[str, list[str]]]) -> dict[str, tuple[list[str], set[str]]]: + return { + key: (list(value.get("paths", [])), set(value.get("extensions", []))) + for key, value in data.items() + } + class FolderPathsProxy(ProxiedSingleton): """ @@ -12,18 +38,165 @@ class FolderPathsProxy(ProxiedSingleton): mutable collections to ensure efficient by-value transfer. """ - def __getattr__(self, name): - return getattr(folder_paths, name) + _rpc: Optional[Any] = None + + @classmethod + def set_rpc(cls, rpc: Any) -> None: + cls._rpc = rpc.create_caller(cls, cls.get_remote_id()) + + @classmethod + def clear_rpc(cls) -> None: + cls._rpc = None + + @classmethod + def _get_caller(cls) -> Any: + if cls._rpc is None: + raise RuntimeError("FolderPathsProxy RPC caller is not configured") + return cls._rpc + + def __getattr__(self, name): + if _is_child_process(): + property_rpc = { + "models_dir": "rpc_get_models_dir", + "folder_names_and_paths": "rpc_get_folder_names_and_paths", + "extension_mimetypes_cache": "rpc_get_extension_mimetypes_cache", + "filename_list_cache": "rpc_get_filename_list_cache", + } + rpc_name = property_rpc.get(name) + if rpc_name is not None: + return call_singleton_rpc(self._get_caller(), rpc_name) + raise AttributeError(name) + return getattr(_folder_paths(), name) - # Return dict snapshots (avoid RPC chatter) @property def folder_names_and_paths(self) -> Dict: - return dict(folder_paths.folder_names_and_paths) + if _is_child_process(): + payload = call_singleton_rpc(self._get_caller(), "rpc_get_folder_names_and_paths") + return _deserialize_folder_names_and_paths(payload) + return _folder_paths().folder_names_and_paths @property def extension_mimetypes_cache(self) -> Dict: - return dict(folder_paths.extension_mimetypes_cache) + if _is_child_process(): + return dict(call_singleton_rpc(self._get_caller(), "rpc_get_extension_mimetypes_cache")) + return dict(_folder_paths().extension_mimetypes_cache) @property def filename_list_cache(self) -> Dict: - return dict(folder_paths.filename_list_cache) + if _is_child_process(): + return dict(call_singleton_rpc(self._get_caller(), "rpc_get_filename_list_cache")) + return dict(_folder_paths().filename_list_cache) + + @property + def models_dir(self) -> str: + if _is_child_process(): + return str(call_singleton_rpc(self._get_caller(), "rpc_get_models_dir")) + return _folder_paths().models_dir + + def get_temp_directory(self) -> str: + if _is_child_process(): + return call_singleton_rpc(self._get_caller(), "rpc_get_temp_directory") + return _folder_paths().get_temp_directory() + + def get_input_directory(self) -> str: + if _is_child_process(): + return call_singleton_rpc(self._get_caller(), "rpc_get_input_directory") + return _folder_paths().get_input_directory() + + def get_output_directory(self) -> str: + if _is_child_process(): + return call_singleton_rpc(self._get_caller(), "rpc_get_output_directory") + return _folder_paths().get_output_directory() + + def get_user_directory(self) -> str: + if _is_child_process(): + return call_singleton_rpc(self._get_caller(), "rpc_get_user_directory") + return _folder_paths().get_user_directory() + + def get_annotated_filepath(self, name: str, default_dir: str | None = None) -> str: + if _is_child_process(): + return call_singleton_rpc( + self._get_caller(), "rpc_get_annotated_filepath", name, default_dir + ) + return _folder_paths().get_annotated_filepath(name, default_dir) + + def exists_annotated_filepath(self, name: str) -> bool: + if _is_child_process(): + return bool( + call_singleton_rpc(self._get_caller(), "rpc_exists_annotated_filepath", name) + ) + return bool(_folder_paths().exists_annotated_filepath(name)) + + def add_model_folder_path( + self, folder_name: str, full_folder_path: str, is_default: bool = False + ) -> None: + if _is_child_process(): + call_singleton_rpc( + self._get_caller(), + "rpc_add_model_folder_path", + folder_name, + full_folder_path, + is_default, + ) + return None + _folder_paths().add_model_folder_path(folder_name, full_folder_path, is_default) + return None + + def get_folder_paths(self, folder_name: str) -> list[str]: + if _is_child_process(): + return list(call_singleton_rpc(self._get_caller(), "rpc_get_folder_paths", folder_name)) + return list(_folder_paths().get_folder_paths(folder_name)) + + def get_filename_list(self, folder_name: str) -> list[str]: + if _is_child_process(): + return list(call_singleton_rpc(self._get_caller(), "rpc_get_filename_list", folder_name)) + return list(_folder_paths().get_filename_list(folder_name)) + + def get_full_path(self, folder_name: str, filename: str) -> str | None: + if _is_child_process(): + return call_singleton_rpc(self._get_caller(), "rpc_get_full_path", folder_name, filename) + return _folder_paths().get_full_path(folder_name, filename) + + async def rpc_get_models_dir(self) -> str: + return _folder_paths().models_dir + + async def rpc_get_folder_names_and_paths(self) -> dict[str, dict[str, list[str]]]: + return _serialize_folder_names_and_paths(_folder_paths().folder_names_and_paths) + + async def rpc_get_extension_mimetypes_cache(self) -> dict[str, Any]: + return dict(_folder_paths().extension_mimetypes_cache) + + async def rpc_get_filename_list_cache(self) -> dict[str, Any]: + return dict(_folder_paths().filename_list_cache) + + async def rpc_get_temp_directory(self) -> str: + return _folder_paths().get_temp_directory() + + async def rpc_get_input_directory(self) -> str: + return _folder_paths().get_input_directory() + + async def rpc_get_output_directory(self) -> str: + return _folder_paths().get_output_directory() + + async def rpc_get_user_directory(self) -> str: + return _folder_paths().get_user_directory() + + async def rpc_get_annotated_filepath(self, name: str, default_dir: str | None = None) -> str: + return _folder_paths().get_annotated_filepath(name, default_dir) + + async def rpc_exists_annotated_filepath(self, name: str) -> bool: + return _folder_paths().exists_annotated_filepath(name) + + async def rpc_add_model_folder_path( + self, folder_name: str, full_folder_path: str, is_default: bool = False + ) -> None: + _folder_paths().add_model_folder_path(folder_name, full_folder_path, is_default) + + async def rpc_get_folder_paths(self, folder_name: str) -> list[str]: + return _folder_paths().get_folder_paths(folder_name) + + async def rpc_get_filename_list(self, folder_name: str) -> list[str]: + return _folder_paths().get_filename_list(folder_name) + + async def rpc_get_full_path(self, folder_name: str, filename: str) -> str | None: + return _folder_paths().get_full_path(folder_name, filename) diff --git a/comfy/isolation/proxies/helper_proxies.py b/comfy/isolation/proxies/helper_proxies.py index a50b9e4c4..278c098f1 100644 --- a/comfy/isolation/proxies/helper_proxies.py +++ b/comfy/isolation/proxies/helper_proxies.py @@ -1,7 +1,12 @@ from __future__ import annotations +import os from typing import Any, Dict, Optional +from pyisolate import ProxiedSingleton + +from .base import call_singleton_rpc + class AnyTypeProxy(str): """Replacement for custom AnyType objects used by some nodes.""" @@ -71,9 +76,29 @@ def _restore_special_value(value: Any) -> Any: return value -def restore_input_types(raw: Dict[str, object]) -> Dict[str, object]: - """Restore serialized INPUT_TYPES payload back into ComfyUI-compatible objects.""" +def _serialize_special_value(value: Any) -> Any: + if isinstance(value, AnyTypeProxy): + return {"__pyisolate_any_type__": True, "value": str(value)} + if isinstance(value, FlexibleOptionalInputProxy): + return { + "__pyisolate_flexible_optional__": True, + "type": _serialize_special_value(value.type), + "data": {k: _serialize_special_value(v) for k, v in value.items()}, + } + if isinstance(value, ByPassTypeTupleProxy): + return { + "__pyisolate_bypass_tuple__": [_serialize_special_value(v) for v in value] + } + if isinstance(value, tuple): + return {"__pyisolate_tuple__": [_serialize_special_value(v) for v in value]} + if isinstance(value, list): + return [_serialize_special_value(v) for v in value] + if isinstance(value, dict): + return {k: _serialize_special_value(v) for k, v in value.items()} + return value + +def _restore_input_types_local(raw: Dict[str, object]) -> Dict[str, object]: if not isinstance(raw, dict): return raw # type: ignore[return-value] @@ -90,9 +115,44 @@ def restore_input_types(raw: Dict[str, object]) -> Dict[str, object]: return restored +class HelperProxiesService(ProxiedSingleton): + _rpc: Optional[Any] = None + + @classmethod + def set_rpc(cls, rpc: Any) -> None: + cls._rpc = rpc.create_caller(cls, cls.get_remote_id()) + + @classmethod + def clear_rpc(cls) -> None: + cls._rpc = None + + @classmethod + def _get_caller(cls) -> Any: + if cls._rpc is None: + raise RuntimeError("HelperProxiesService RPC caller is not configured") + return cls._rpc + + async def rpc_restore_input_types(self, raw: Dict[str, object]) -> Dict[str, object]: + restored = _restore_input_types_local(raw) + return _serialize_special_value(restored) + + +def restore_input_types(raw: Dict[str, object]) -> Dict[str, object]: + """Restore serialized INPUT_TYPES payload back into ComfyUI-compatible objects.""" + if os.environ.get("PYISOLATE_CHILD") == "1": + payload = call_singleton_rpc( + HelperProxiesService._get_caller(), + "rpc_restore_input_types", + raw, + ) + return _restore_input_types_local(payload) + return _restore_input_types_local(raw) + + __all__ = [ "AnyTypeProxy", "FlexibleOptionalInputProxy", "ByPassTypeTupleProxy", + "HelperProxiesService", "restore_input_types", ] diff --git a/comfy/isolation/proxies/model_management_proxy.py b/comfy/isolation/proxies/model_management_proxy.py index 00e14d9b4..445210aa4 100644 --- a/comfy/isolation/proxies/model_management_proxy.py +++ b/comfy/isolation/proxies/model_management_proxy.py @@ -1,27 +1,142 @@ -import comfy.model_management as mm +from __future__ import annotations + +import os +from typing import Any, Optional + from pyisolate import ProxiedSingleton +from .base import call_singleton_rpc + + +def _mm(): + import comfy.model_management + + return comfy.model_management + + +def _is_child_process() -> bool: + return os.environ.get("PYISOLATE_CHILD") == "1" + + +class TorchDeviceProxy: + def __init__(self, device_str: str): + self._device_str = device_str + if ":" in device_str: + device_type, index = device_str.split(":", 1) + self.type = device_type + self.index = int(index) + else: + self.type = device_str + self.index = None + + def __str__(self) -> str: + return self._device_str + + def __repr__(self) -> str: + return f"TorchDeviceProxy({self._device_str!r})" + + +def _serialize_value(value: Any) -> Any: + value_type = type(value) + if value_type.__module__ == "torch" and value_type.__name__ == "device": + return {"__pyisolate_torch_device__": str(value)} + if isinstance(value, TorchDeviceProxy): + return {"__pyisolate_torch_device__": str(value)} + if isinstance(value, tuple): + return {"__pyisolate_tuple__": [_serialize_value(item) for item in value]} + if isinstance(value, list): + return [_serialize_value(item) for item in value] + if isinstance(value, dict): + return {key: _serialize_value(inner) for key, inner in value.items()} + return value + + +def _deserialize_value(value: Any) -> Any: + if isinstance(value, dict): + if "__pyisolate_torch_device__" in value: + return TorchDeviceProxy(value["__pyisolate_torch_device__"]) + if "__pyisolate_tuple__" in value: + return tuple(_deserialize_value(item) for item in value["__pyisolate_tuple__"]) + return {key: _deserialize_value(inner) for key, inner in value.items()} + if isinstance(value, list): + return [_deserialize_value(item) for item in value] + return value + + +def _normalize_argument(value: Any) -> Any: + if isinstance(value, TorchDeviceProxy): + import torch + + return torch.device(str(value)) + if isinstance(value, dict): + if "__pyisolate_torch_device__" in value: + import torch + + return torch.device(value["__pyisolate_torch_device__"]) + if "__pyisolate_tuple__" in value: + return tuple(_normalize_argument(item) for item in value["__pyisolate_tuple__"]) + return {key: _normalize_argument(inner) for key, inner in value.items()} + if isinstance(value, list): + return [_normalize_argument(item) for item in value] + return value + class ModelManagementProxy(ProxiedSingleton): """ - Dynamic proxy for comfy.model_management. - Uses __getattr__ to forward all calls to the underlying module, - reducing maintenance burden. + Exact-relay proxy for comfy.model_management. + Child calls never import comfy.model_management directly; they serialize + arguments, relay to host, and deserialize the host result back. """ - # Explicitly expose Enums/Classes as properties + _rpc: Optional[Any] = None + + @classmethod + def set_rpc(cls, rpc: Any) -> None: + cls._rpc = rpc.create_caller(cls, cls.get_remote_id()) + + @classmethod + def clear_rpc(cls) -> None: + cls._rpc = None + + @classmethod + def _get_caller(cls) -> Any: + if cls._rpc is None: + raise RuntimeError("ModelManagementProxy RPC caller is not configured") + return cls._rpc + + def _relay_call(self, method_name: str, *args: Any, **kwargs: Any) -> Any: + payload = call_singleton_rpc( + self._get_caller(), + "rpc_call", + method_name, + _serialize_value(args), + _serialize_value(kwargs), + ) + return _deserialize_value(payload) + @property def VRAMState(self): - return mm.VRAMState + return _mm().VRAMState @property def CPUState(self): - return mm.CPUState + return _mm().CPUState @property def OOM_EXCEPTION(self): - return mm.OOM_EXCEPTION + return _mm().OOM_EXCEPTION - def __getattr__(self, name): - """Forward all other attribute access to the module.""" - return getattr(mm, name) + def __getattr__(self, name: str): + if _is_child_process(): + def child_method(*args: Any, **kwargs: Any) -> Any: + return self._relay_call(name, *args, **kwargs) + + return child_method + return getattr(_mm(), name) + + async def rpc_call(self, method_name: str, args: Any, kwargs: Any) -> Any: + normalized_args = _normalize_argument(_deserialize_value(args)) + normalized_kwargs = _normalize_argument(_deserialize_value(kwargs)) + method = getattr(_mm(), method_name) + result = method(*normalized_args, **normalized_kwargs) + return _serialize_value(result) diff --git a/comfy/isolation/proxies/progress_proxy.py b/comfy/isolation/proxies/progress_proxy.py index 44494ea31..8f270afa0 100644 --- a/comfy/isolation/proxies/progress_proxy.py +++ b/comfy/isolation/proxies/progress_proxy.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import os from typing import Any, Optional try: @@ -10,13 +11,38 @@ except ImportError: class ProxiedSingleton: pass +from .base import call_singleton_rpc -from comfy_execution.progress import get_progress_state + +def _get_progress_state(): + from comfy_execution.progress import get_progress_state + + return get_progress_state() + + +def _is_child_process() -> bool: + return os.environ.get("PYISOLATE_CHILD") == "1" logger = logging.getLogger(__name__) class ProgressProxy(ProxiedSingleton): + _rpc: Optional[Any] = None + + @classmethod + def set_rpc(cls, rpc: Any) -> None: + cls._rpc = rpc.create_caller(cls, cls.get_remote_id()) + + @classmethod + def clear_rpc(cls) -> None: + cls._rpc = None + + @classmethod + def _get_caller(cls) -> Any: + if cls._rpc is None: + raise RuntimeError("ProgressProxy RPC caller is not configured") + return cls._rpc + def set_progress( self, value: float, @@ -24,7 +50,33 @@ class ProgressProxy(ProxiedSingleton): node_id: Optional[str] = None, image: Any = None, ) -> None: - get_progress_state().update_progress( + if _is_child_process(): + call_singleton_rpc( + self._get_caller(), + "rpc_set_progress", + value, + max_value, + node_id, + image, + ) + return None + + _get_progress_state().update_progress( + node_id=node_id, + value=value, + max_value=max_value, + image=image, + ) + return None + + async def rpc_set_progress( + self, + value: float, + max_value: float, + node_id: Optional[str] = None, + image: Any = None, + ) -> None: + _get_progress_state().update_progress( node_id=node_id, value=value, max_value=max_value, diff --git a/comfy/isolation/proxies/prompt_server_impl.py b/comfy/isolation/proxies/prompt_server_impl.py index 2a775e097..3f500522e 100644 --- a/comfy/isolation/proxies/prompt_server_impl.py +++ b/comfy/isolation/proxies/prompt_server_impl.py @@ -13,10 +13,10 @@ import os from typing import Any, Dict, Optional, Callable import logging -from aiohttp import web # IMPORTS from pyisolate import ProxiedSingleton +from .base import call_singleton_rpc logger = logging.getLogger(__name__) LOG_PREFIX = "[Isolation:C<->H]" @@ -64,6 +64,10 @@ class PromptServerStub: PromptServerService, target_id ) # We import Service below? + @classmethod + def clear_rpc(cls) -> None: + cls._rpc = None + # We need PromptServerService available for the create_caller call? # Or just use the Stub class if ID matches? # prompt_server_impl.py defines BOTH. So PromptServerService IS available! @@ -133,7 +137,7 @@ class PromptServerStub: loop = asyncio.get_running_loop() loop.create_task(self._rpc.ui_send_progress_text(text, node_id, sid)) except RuntimeError: - pass # Sync context without loop? + call_singleton_rpc(self._rpc, "ui_send_progress_text", text, node_id, sid) # --- Route Registration Logic --- def register_route(self, method: str, path: str, handler: Callable): @@ -147,7 +151,7 @@ class PromptServerStub: loop = asyncio.get_running_loop() loop.create_task(self._rpc.register_route_rpc(method, path, handler)) except RuntimeError: - pass + call_singleton_rpc(self._rpc, "register_route_rpc", method, path, handler) class RouteStub: @@ -226,6 +230,7 @@ class PromptServerService(ProxiedSingleton): async def register_route_rpc(self, method: str, path: str, child_handler_proxy): """RPC Target: Register a route that forwards to the Child.""" + from aiohttp import web logger.debug(f"{LOG_PREFIX} Registering Isolated Route {method} {path}") async def route_wrapper(request: web.Request) -> web.Response: @@ -251,8 +256,9 @@ class PromptServerService(ProxiedSingleton): # Register loop self.server.app.router.add_route(method, path, route_wrapper) - def _serialize_response(self, result: Any) -> web.Response: + def _serialize_response(self, result: Any) -> Any: """Helper to convert Child result -> web.Response""" + from aiohttp import web if isinstance(result, web.Response): return result # Handle dict (json) diff --git a/comfy/isolation/proxies/utils_proxy.py b/comfy/isolation/proxies/utils_proxy.py index 432f7ec90..f84727bbb 100644 --- a/comfy/isolation/proxies/utils_proxy.py +++ b/comfy/isolation/proxies/utils_proxy.py @@ -2,12 +2,16 @@ from __future__ import annotations from typing import Optional, Any -import comfy.utils from pyisolate import ProxiedSingleton import os +def _comfy_utils(): + import comfy.utils + return comfy.utils + + class UtilsProxy(ProxiedSingleton): """ Proxy for comfy.utils. @@ -23,6 +27,10 @@ class UtilsProxy(ProxiedSingleton): # Create caller using class name as ID (standard for Singletons) cls._rpc = rpc.create_caller(cls, "UtilsProxy") + @classmethod + def clear_rpc(cls) -> None: + cls._rpc = None + async def progress_bar_hook( self, value: int, @@ -35,30 +43,22 @@ class UtilsProxy(ProxiedSingleton): Child-side: this method call is intercepted by RPC and sent to host. """ if os.environ.get("PYISOLATE_CHILD") == "1": - # Manual RPC dispatch for Child process - # Use class-level RPC storage (Static Injection) - if UtilsProxy._rpc: - return await UtilsProxy._rpc.progress_bar_hook( - value, total, preview, node_id - ) - - # Fallback channel: global child rpc - try: - from pyisolate._internal.rpc_protocol import get_child_rpc_instance - - get_child_rpc_instance() - # If we have an RPC instance but no UtilsProxy._rpc, we *could* try to use it, - # but we need a caller. For now, just pass to avoid crashing. - pass - except (ImportError, LookupError): - pass - - return None + if UtilsProxy._rpc is None: + raise RuntimeError("UtilsProxy RPC caller is not configured") + return await UtilsProxy._rpc.progress_bar_hook( + value, total, preview, node_id + ) # Host Execution - if comfy.utils.PROGRESS_BAR_HOOK is not None: - comfy.utils.PROGRESS_BAR_HOOK(value, total, preview, node_id) + utils = _comfy_utils() + if utils.PROGRESS_BAR_HOOK is not None: + return utils.PROGRESS_BAR_HOOK(value, total, preview, node_id) + return None def set_progress_bar_global_hook(self, hook: Any) -> None: """Forward hook registration (though usually not needed from child).""" - comfy.utils.set_progress_bar_global_hook(hook) + if os.environ.get("PYISOLATE_CHILD") == "1": + raise RuntimeError( + "UtilsProxy.set_progress_bar_global_hook is not available in child without exact relay support" + ) + _comfy_utils().set_progress_bar_global_hook(hook) diff --git a/comfy/isolation/proxies/web_directory_proxy.py b/comfy/isolation/proxies/web_directory_proxy.py new file mode 100644 index 000000000..3acf3f4fc --- /dev/null +++ b/comfy/isolation/proxies/web_directory_proxy.py @@ -0,0 +1,219 @@ +"""WebDirectoryProxy — serves isolated node web assets via RPC. + +Child side: enumerates and reads files from the extension's web/ directory. +Host side: gets an RPC proxy that fetches file listings and contents on demand. + +Only files with allowed extensions (.js, .html, .css) are served. +Directory traversal is rejected. File contents are base64-encoded for +safe JSON-RPC transport. +""" + +from __future__ import annotations + +import base64 +import logging +import os +from pathlib import Path, PurePosixPath +from typing import Any, Dict, List + +from pyisolate import ProxiedSingleton + +logger = logging.getLogger(__name__) + +ALLOWED_EXTENSIONS = frozenset({".js", ".html", ".css"}) + +MIME_TYPES = { + ".js": "application/javascript", + ".html": "text/html", + ".css": "text/css", +} + + +class WebDirectoryProxy(ProxiedSingleton): + """Proxy for serving isolated extension web directories. + + On the child side, this class has direct filesystem access to the + extension's web/ directory. On the host side, callers get an RPC + proxy whose method calls are forwarded to the child. + """ + + # {extension_name: absolute_path_to_web_dir} + _web_dirs: dict[str, str] = {} + + @classmethod + def register_web_dir(cls, extension_name: str, web_dir_path: str) -> None: + """Register an extension's web directory (child-side only).""" + cls._web_dirs[extension_name] = web_dir_path + logger.info( + "][ WebDirectoryProxy: registered %s -> %s", + extension_name, + web_dir_path, + ) + + def list_web_files(self, extension_name: str) -> List[Dict[str, str]]: + """Return a list of servable files in the extension's web directory. + + Each entry is {"relative_path": "js/foo.js", "content_type": "application/javascript"}. + Only files with allowed extensions are included. + """ + web_dir = self._web_dirs.get(extension_name) + if not web_dir: + return [] + + root = Path(web_dir) + if not root.is_dir(): + return [] + + result: List[Dict[str, str]] = [] + for path in sorted(root.rglob("*")): + if not path.is_file(): + continue + ext = path.suffix.lower() + if ext not in ALLOWED_EXTENSIONS: + continue + rel = path.relative_to(root) + result.append({ + "relative_path": str(PurePosixPath(rel)), + "content_type": MIME_TYPES[ext], + }) + return result + + def get_web_file( + self, extension_name: str, relative_path: str + ) -> Dict[str, Any]: + """Return the contents of a single web file as base64. + + Raises ValueError for traversal attempts or disallowed file types. + Returns {"content": , "content_type": }. + """ + _validate_path(relative_path) + + web_dir = self._web_dirs.get(extension_name) + if not web_dir: + raise FileNotFoundError( + f"No web directory registered for {extension_name}" + ) + + root = Path(web_dir) + target = (root / relative_path).resolve() + + # Ensure resolved path is under the web directory + if not str(target).startswith(str(root.resolve())): + raise ValueError(f"Path escapes web directory: {relative_path}") + + if not target.is_file(): + raise FileNotFoundError(f"File not found: {relative_path}") + + ext = target.suffix.lower() + if ext not in ALLOWED_EXTENSIONS: + raise ValueError(f"Disallowed file type: {ext}") + + content_type = MIME_TYPES[ext] + raw = target.read_bytes() + + return { + "content": base64.b64encode(raw).decode("ascii"), + "content_type": content_type, + } + + +def _validate_path(relative_path: str) -> None: + """Reject directory traversal and absolute paths.""" + if os.path.isabs(relative_path): + raise ValueError(f"Absolute paths are not allowed: {relative_path}") + if ".." in PurePosixPath(relative_path).parts: + raise ValueError(f"Directory traversal is not allowed: {relative_path}") + + +# --------------------------------------------------------------------------- +# Host-side cache and aiohttp handler +# --------------------------------------------------------------------------- + + +class WebDirectoryCache: + """Host-side in-memory cache for proxied web directory contents. + + Populated lazily via RPC calls to the child's WebDirectoryProxy. + Once a file is cached, subsequent requests are served from memory. + """ + + def __init__(self) -> None: + # {extension_name: {relative_path: {"content": bytes, "content_type": str}}} + self._file_cache: dict[str, dict[str, dict[str, Any]]] = {} + # {extension_name: [{"relative_path": str, "content_type": str}, ...]} + self._listing_cache: dict[str, list[dict[str, str]]] = {} + # {extension_name: WebDirectoryProxy (RPC proxy instance)} + self._proxies: dict[str, Any] = {} + + def register_proxy(self, extension_name: str, proxy: Any) -> None: + """Register an RPC proxy for an extension's web directory.""" + self._proxies[extension_name] = proxy + logger.info( + "][ WebDirectoryCache: registered proxy for %s", extension_name + ) + + @property + def extension_names(self) -> list[str]: + return list(self._proxies.keys()) + + def list_files(self, extension_name: str) -> list[dict[str, str]]: + """List servable files for an extension (cached after first call).""" + if extension_name not in self._listing_cache: + proxy = self._proxies.get(extension_name) + if proxy is None: + return [] + try: + self._listing_cache[extension_name] = proxy.list_web_files( + extension_name + ) + except Exception: + logger.warning( + "][ WebDirectoryCache: failed to list files for %s", + extension_name, + exc_info=True, + ) + return [] + return self._listing_cache[extension_name] + + def get_file( + self, extension_name: str, relative_path: str + ) -> dict[str, Any] | None: + """Get file content (cached after first fetch). Returns None on miss.""" + ext_cache = self._file_cache.get(extension_name) + if ext_cache and relative_path in ext_cache: + return ext_cache[relative_path] + + proxy = self._proxies.get(extension_name) + if proxy is None: + return None + + try: + result = proxy.get_web_file(extension_name, relative_path) + except (FileNotFoundError, ValueError): + return None + except Exception: + logger.warning( + "][ WebDirectoryCache: failed to fetch %s/%s", + extension_name, + relative_path, + exc_info=True, + ) + return None + + decoded = { + "content": base64.b64decode(result["content"]), + "content_type": result["content_type"], + } + + if extension_name not in self._file_cache: + self._file_cache[extension_name] = {} + self._file_cache[extension_name][relative_path] = decoded + return decoded + + +# Global cache instance — populated during isolation loading +_web_directory_cache = WebDirectoryCache() + + +def get_web_directory_cache() -> WebDirectoryCache: + return _web_directory_cache