feat: model and service proxies for isolated custom nodes

This commit is contained in:
John Pollock 2026-03-29 19:02:43 -05:00
parent 878684d8b2
commit 683e2d6a73
10 changed files with 858 additions and 107 deletions

View File

@ -22,6 +22,16 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
__module__ = "comfy.model_patcher"
_APPLY_MODEL_GUARD_PADDING_BYTES = 32 * 1024 * 1024
def _spawn_related_proxy(self, instance_id: str) -> "ModelPatcherProxy":
proxy = ModelPatcherProxy(
instance_id,
self._registry,
manage_lifecycle=not IS_CHILD_PROCESS,
)
if getattr(self, "_rpc_caller", None) is not None:
proxy._rpc_caller = self._rpc_caller
return proxy
def _get_rpc(self) -> Any:
if self._rpc_caller is None:
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
@ -164,9 +174,7 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
def clone(self) -> ModelPatcherProxy:
new_id = self._call_rpc("clone")
return ModelPatcherProxy(
new_id, self._registry, manage_lifecycle=not IS_CHILD_PROCESS
)
return self._spawn_related_proxy(new_id)
def clone_has_same_weights(self, clone: Any) -> bool:
if isinstance(clone, ModelPatcherProxy):
@ -509,11 +517,7 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
)
new_model = None
if result.get("model_id"):
new_model = ModelPatcherProxy(
result["model_id"],
self._registry,
manage_lifecycle=not IS_CHILD_PROCESS,
)
new_model = self._spawn_related_proxy(result["model_id"])
new_clip = None
if result.get("clip_id"):
from comfy.isolation.clip_proxy import CLIPProxy
@ -789,12 +793,7 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
def get_additional_models(self) -> List[ModelPatcherProxy]:
ids = self._call_rpc("get_additional_models")
return [
ModelPatcherProxy(
mid, self._registry, manage_lifecycle=not IS_CHILD_PROCESS
)
for mid in ids
]
return [self._spawn_related_proxy(mid) for mid in ids]
def model_patches_models(self) -> Any:
return self._call_rpc("model_patches_models")
@ -803,6 +802,25 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
def parent(self) -> Any:
return self._call_rpc("get_parent")
def model_mmap_residency(self, free: bool = False) -> tuple:
result = self._call_rpc("model_mmap_residency", free)
if isinstance(result, list):
return tuple(result)
return result
def pinned_memory_size(self) -> int:
return self._call_rpc("pinned_memory_size")
def get_non_dynamic_delegate(self) -> ModelPatcherProxy:
new_id = self._call_rpc("get_non_dynamic_delegate")
return self._spawn_related_proxy(new_id)
def disable_model_cfg1_optimization(self) -> None:
self._call_rpc("disable_model_cfg1_optimization")
def set_model_noise_refiner_patch(self, patch: Any) -> None:
self.set_model_patch(patch, "noise_refiner")
class _InnerModelProxy:
def __init__(self, parent: ModelPatcherProxy):
@ -812,8 +830,14 @@ class _InnerModelProxy:
def __getattr__(self, name: str) -> Any:
if name.startswith("_"):
raise AttributeError(name)
if name == "model_config":
from types import SimpleNamespace
data = self._parent._call_rpc("get_inner_model_attr", name)
if isinstance(data, dict):
return SimpleNamespace(**data)
return data
if name in (
"model_config",
"latent_format",
"model_type",
"current_weight_patches_uuid",
@ -824,11 +848,14 @@ class _InnerModelProxy:
if name == "device":
return self._parent._call_rpc("get_inner_model_attr", "device")
if name == "current_patcher":
return ModelPatcherProxy(
proxy = ModelPatcherProxy(
self._parent._instance_id,
self._parent._registry,
manage_lifecycle=False,
)
if getattr(self._parent, "_rpc_caller", None) is not None:
proxy._rpc_caller = self._parent._rpc_caller
return proxy
if name == "model_sampling":
if self._model_sampling is None:
self._model_sampling = self._parent._call_rpc(

View File

@ -250,22 +250,47 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
return f"<ModelObject: {type(instance.model).__name__}>"
result = instance.get_model_object(name)
if name == "model_sampling":
from comfy.isolation.model_sampling_proxy import (
ModelSamplingRegistry,
ModelSamplingProxy,
)
registry = ModelSamplingRegistry()
# Preserve identity when upstream already returned a proxy. Re-registering
# a proxy object creates proxy-of-proxy call chains.
if isinstance(result, ModelSamplingProxy):
sampling_id = result._instance_id
else:
sampling_id = registry.register(result)
return ModelSamplingProxy(sampling_id, registry)
# Return inline serialization so the child reconstructs the real
# class with correct isinstance behavior. Returning a
# ModelSamplingProxy breaks isinstance checks (e.g.
# offset_first_sigma_for_snr in k_diffusion/sampling.py:173).
return self._serialize_model_sampling_inline(result)
return detach_if_grad(result)
@staticmethod
def _serialize_model_sampling_inline(obj: Any) -> dict:
"""Serialize a ModelSampling object as inline data for the child to reconstruct."""
import torch
import base64
import io as _io
bases = []
for base in type(obj).__mro__:
if base.__module__ == "comfy.model_sampling" and base.__name__ != "object":
bases.append(base.__name__)
sd = obj.state_dict()
sd_serialized = {}
for k, v in sd.items():
buf = _io.BytesIO()
torch.save(v, buf)
sd_serialized[k] = base64.b64encode(buf.getvalue()).decode("ascii")
plain_attrs = {}
for k, v in obj.__dict__.items():
if k.startswith("_"):
continue
if isinstance(v, (bool, int, float, str)):
plain_attrs[k] = v
return {
"__type__": "ModelSamplingInline",
"bases": bases,
"state_dict": sd_serialized,
"attrs": plain_attrs,
}
async def get_model_options(self, instance_id: str) -> dict:
instance = self._get_instance(instance_id)
import copy
@ -348,6 +373,20 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
async def get_ram_usage(self, instance_id: str) -> int:
return self._get_instance(instance_id).get_ram_usage()
async def model_mmap_residency(self, instance_id: str, free: bool = False) -> tuple:
return self._get_instance(instance_id).model_mmap_residency(free=free)
async def pinned_memory_size(self, instance_id: str) -> int:
return self._get_instance(instance_id).pinned_memory_size()
async def get_non_dynamic_delegate(self, instance_id: str) -> str:
instance = self._get_instance(instance_id)
delegate = instance.get_non_dynamic_delegate()
return self.register(delegate)
async def disable_model_cfg1_optimization(self, instance_id: str) -> None:
self._get_instance(instance_id).disable_model_cfg1_optimization()
async def lowvram_patch_counter(self, instance_id: str) -> int:
return self._get_instance(instance_id).lowvram_patch_counter()
@ -959,12 +998,54 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
async def get_inner_model_attr(self, instance_id: str, name: str) -> Any:
try:
return self._sanitize_rpc_result(
getattr(self._get_instance(instance_id).model, name)
)
value = getattr(self._get_instance(instance_id).model, name)
if name == "model_config":
value = self._extract_model_config(value)
return self._sanitize_rpc_result(value)
except AttributeError:
return None
@staticmethod
def _extract_model_config(config: Any) -> dict:
"""Extract JSON-safe attributes from a model config object.
ComfyUI model config classes (supported_models_base.BASE subclasses)
have a permissive __getattr__ that returns None for any unknown
attribute instead of raising AttributeError. This defeats hasattr-based
duck-typing in _sanitize_rpc_result, causing TypeError when it tries
to call obj.items() (which resolves to None). We extract the real
class-level and instance-level attributes into a plain dict.
"""
# Attributes consumed by ModelSampling*.__init__ and other callers
_CONFIG_KEYS = (
"sampling_settings",
"unet_config",
"unet_extra_config",
"latent_format",
"manual_cast_dtype",
"custom_operations",
"optimizations",
"memory_usage_factor",
"supported_inference_dtypes",
)
result: dict = {}
for key in _CONFIG_KEYS:
# Use type(config).__dict__ first (class attrs), then instance __dict__
# to avoid triggering the permissive __getattr__
if key in type(config).__dict__:
val = type(config).__dict__[key]
# Skip classmethods/staticmethods/descriptors
if not callable(val) or isinstance(val, (dict, list, tuple)):
result[key] = val
elif hasattr(config, "__dict__") and key in config.__dict__:
result[key] = config.__dict__[key]
# Also include instance overrides (e.g. set_inference_dtype sets unet_config['dtype'])
if hasattr(config, "__dict__"):
for key, val in config.__dict__.items():
if key in _CONFIG_KEYS:
result[key] = val
return result
async def inner_model_memory_required(
self, instance_id: str, args: tuple, kwargs: dict
) -> Any:

View File

@ -118,6 +118,47 @@ def set_global_loop(loop: asyncio.AbstractEventLoop) -> None:
_GLOBAL_LOOP = loop
def run_sync_rpc_coro(coro: Any, timeout_ms: Optional[int] = None) -> Any:
if timeout_ms is not None:
coro = asyncio.wait_for(coro, timeout=timeout_ms / 1000.0)
try:
if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running():
try:
curr_loop = asyncio.get_running_loop()
if curr_loop is _GLOBAL_LOOP:
pass
except RuntimeError:
future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP)
return future.result(
timeout=(timeout_ms / 1000.0) if timeout_ms is not None else None
)
try:
asyncio.get_running_loop()
return run_coro_in_new_loop(coro)
except RuntimeError:
loop = get_thread_loop()
return loop.run_until_complete(coro)
except asyncio.TimeoutError as exc:
raise TimeoutError(f"Isolation RPC timeout (timeout_ms={timeout_ms})") from exc
except concurrent.futures.TimeoutError as exc:
raise TimeoutError(f"Isolation RPC timeout (timeout_ms={timeout_ms})") from exc
def call_singleton_rpc(
caller: Any,
method_name: str,
*args: Any,
timeout_ms: Optional[int] = None,
**kwargs: Any,
) -> Any:
if caller is None:
raise RuntimeError(f"No RPC caller available for {method_name}")
method = getattr(caller, method_name)
return run_sync_rpc_coro(method(*args, **kwargs), timeout_ms=timeout_ms)
class BaseProxy(Generic[T]):
_registry_class: type = BaseRegistry # type: ignore[type-arg]
__module__: str = "comfy.isolation.proxies.base"
@ -208,31 +249,8 @@ class BaseProxy(Generic[T]):
)
try:
# If we have a global loop (Main Thread Loop), use it for dispatch from worker threads
if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running():
try:
curr_loop = asyncio.get_running_loop()
if curr_loop is _GLOBAL_LOOP:
pass
except RuntimeError:
# No running loop - we are in a worker thread.
future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP)
return future.result(
timeout=(timeout_ms / 1000.0) if timeout_ms is not None else None
)
try:
asyncio.get_running_loop()
return run_coro_in_new_loop(coro)
except RuntimeError:
loop = get_thread_loop()
return loop.run_until_complete(coro)
except asyncio.TimeoutError as exc:
raise TimeoutError(
f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} "
f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})"
) from exc
except concurrent.futures.TimeoutError as exc:
return run_sync_rpc_coro(coro, timeout_ms=timeout_ms)
except TimeoutError as exc:
raise TimeoutError(
f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} "
f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})"

View File

@ -1,9 +1,35 @@
from __future__ import annotations
from typing import Dict
import os
from typing import Any, Dict, Optional
import folder_paths
from pyisolate import ProxiedSingleton
from .base import call_singleton_rpc
def _folder_paths():
import folder_paths
return folder_paths
def _is_child_process() -> bool:
return os.environ.get("PYISOLATE_CHILD") == "1"
def _serialize_folder_names_and_paths(data: dict[str, tuple[list[str], set[str]]]) -> dict[str, dict[str, list[str]]]:
return {
key: {"paths": list(paths), "extensions": sorted(list(extensions))}
for key, (paths, extensions) in data.items()
}
def _deserialize_folder_names_and_paths(data: dict[str, dict[str, list[str]]]) -> dict[str, tuple[list[str], set[str]]]:
return {
key: (list(value.get("paths", [])), set(value.get("extensions", [])))
for key, value in data.items()
}
class FolderPathsProxy(ProxiedSingleton):
"""
@ -12,18 +38,165 @@ class FolderPathsProxy(ProxiedSingleton):
mutable collections to ensure efficient by-value transfer.
"""
def __getattr__(self, name):
return getattr(folder_paths, name)
_rpc: Optional[Any] = None
@classmethod
def set_rpc(cls, rpc: Any) -> None:
cls._rpc = rpc.create_caller(cls, cls.get_remote_id())
@classmethod
def clear_rpc(cls) -> None:
cls._rpc = None
@classmethod
def _get_caller(cls) -> Any:
if cls._rpc is None:
raise RuntimeError("FolderPathsProxy RPC caller is not configured")
return cls._rpc
def __getattr__(self, name):
if _is_child_process():
property_rpc = {
"models_dir": "rpc_get_models_dir",
"folder_names_and_paths": "rpc_get_folder_names_and_paths",
"extension_mimetypes_cache": "rpc_get_extension_mimetypes_cache",
"filename_list_cache": "rpc_get_filename_list_cache",
}
rpc_name = property_rpc.get(name)
if rpc_name is not None:
return call_singleton_rpc(self._get_caller(), rpc_name)
raise AttributeError(name)
return getattr(_folder_paths(), name)
# Return dict snapshots (avoid RPC chatter)
@property
def folder_names_and_paths(self) -> Dict:
return dict(folder_paths.folder_names_and_paths)
if _is_child_process():
payload = call_singleton_rpc(self._get_caller(), "rpc_get_folder_names_and_paths")
return _deserialize_folder_names_and_paths(payload)
return _folder_paths().folder_names_and_paths
@property
def extension_mimetypes_cache(self) -> Dict:
return dict(folder_paths.extension_mimetypes_cache)
if _is_child_process():
return dict(call_singleton_rpc(self._get_caller(), "rpc_get_extension_mimetypes_cache"))
return dict(_folder_paths().extension_mimetypes_cache)
@property
def filename_list_cache(self) -> Dict:
return dict(folder_paths.filename_list_cache)
if _is_child_process():
return dict(call_singleton_rpc(self._get_caller(), "rpc_get_filename_list_cache"))
return dict(_folder_paths().filename_list_cache)
@property
def models_dir(self) -> str:
if _is_child_process():
return str(call_singleton_rpc(self._get_caller(), "rpc_get_models_dir"))
return _folder_paths().models_dir
def get_temp_directory(self) -> str:
if _is_child_process():
return call_singleton_rpc(self._get_caller(), "rpc_get_temp_directory")
return _folder_paths().get_temp_directory()
def get_input_directory(self) -> str:
if _is_child_process():
return call_singleton_rpc(self._get_caller(), "rpc_get_input_directory")
return _folder_paths().get_input_directory()
def get_output_directory(self) -> str:
if _is_child_process():
return call_singleton_rpc(self._get_caller(), "rpc_get_output_directory")
return _folder_paths().get_output_directory()
def get_user_directory(self) -> str:
if _is_child_process():
return call_singleton_rpc(self._get_caller(), "rpc_get_user_directory")
return _folder_paths().get_user_directory()
def get_annotated_filepath(self, name: str, default_dir: str | None = None) -> str:
if _is_child_process():
return call_singleton_rpc(
self._get_caller(), "rpc_get_annotated_filepath", name, default_dir
)
return _folder_paths().get_annotated_filepath(name, default_dir)
def exists_annotated_filepath(self, name: str) -> bool:
if _is_child_process():
return bool(
call_singleton_rpc(self._get_caller(), "rpc_exists_annotated_filepath", name)
)
return bool(_folder_paths().exists_annotated_filepath(name))
def add_model_folder_path(
self, folder_name: str, full_folder_path: str, is_default: bool = False
) -> None:
if _is_child_process():
call_singleton_rpc(
self._get_caller(),
"rpc_add_model_folder_path",
folder_name,
full_folder_path,
is_default,
)
return None
_folder_paths().add_model_folder_path(folder_name, full_folder_path, is_default)
return None
def get_folder_paths(self, folder_name: str) -> list[str]:
if _is_child_process():
return list(call_singleton_rpc(self._get_caller(), "rpc_get_folder_paths", folder_name))
return list(_folder_paths().get_folder_paths(folder_name))
def get_filename_list(self, folder_name: str) -> list[str]:
if _is_child_process():
return list(call_singleton_rpc(self._get_caller(), "rpc_get_filename_list", folder_name))
return list(_folder_paths().get_filename_list(folder_name))
def get_full_path(self, folder_name: str, filename: str) -> str | None:
if _is_child_process():
return call_singleton_rpc(self._get_caller(), "rpc_get_full_path", folder_name, filename)
return _folder_paths().get_full_path(folder_name, filename)
async def rpc_get_models_dir(self) -> str:
return _folder_paths().models_dir
async def rpc_get_folder_names_and_paths(self) -> dict[str, dict[str, list[str]]]:
return _serialize_folder_names_and_paths(_folder_paths().folder_names_and_paths)
async def rpc_get_extension_mimetypes_cache(self) -> dict[str, Any]:
return dict(_folder_paths().extension_mimetypes_cache)
async def rpc_get_filename_list_cache(self) -> dict[str, Any]:
return dict(_folder_paths().filename_list_cache)
async def rpc_get_temp_directory(self) -> str:
return _folder_paths().get_temp_directory()
async def rpc_get_input_directory(self) -> str:
return _folder_paths().get_input_directory()
async def rpc_get_output_directory(self) -> str:
return _folder_paths().get_output_directory()
async def rpc_get_user_directory(self) -> str:
return _folder_paths().get_user_directory()
async def rpc_get_annotated_filepath(self, name: str, default_dir: str | None = None) -> str:
return _folder_paths().get_annotated_filepath(name, default_dir)
async def rpc_exists_annotated_filepath(self, name: str) -> bool:
return _folder_paths().exists_annotated_filepath(name)
async def rpc_add_model_folder_path(
self, folder_name: str, full_folder_path: str, is_default: bool = False
) -> None:
_folder_paths().add_model_folder_path(folder_name, full_folder_path, is_default)
async def rpc_get_folder_paths(self, folder_name: str) -> list[str]:
return _folder_paths().get_folder_paths(folder_name)
async def rpc_get_filename_list(self, folder_name: str) -> list[str]:
return _folder_paths().get_filename_list(folder_name)
async def rpc_get_full_path(self, folder_name: str, filename: str) -> str | None:
return _folder_paths().get_full_path(folder_name, filename)

View File

@ -1,7 +1,12 @@
from __future__ import annotations
import os
from typing import Any, Dict, Optional
from pyisolate import ProxiedSingleton
from .base import call_singleton_rpc
class AnyTypeProxy(str):
"""Replacement for custom AnyType objects used by some nodes."""
@ -71,9 +76,29 @@ def _restore_special_value(value: Any) -> Any:
return value
def restore_input_types(raw: Dict[str, object]) -> Dict[str, object]:
"""Restore serialized INPUT_TYPES payload back into ComfyUI-compatible objects."""
def _serialize_special_value(value: Any) -> Any:
if isinstance(value, AnyTypeProxy):
return {"__pyisolate_any_type__": True, "value": str(value)}
if isinstance(value, FlexibleOptionalInputProxy):
return {
"__pyisolate_flexible_optional__": True,
"type": _serialize_special_value(value.type),
"data": {k: _serialize_special_value(v) for k, v in value.items()},
}
if isinstance(value, ByPassTypeTupleProxy):
return {
"__pyisolate_bypass_tuple__": [_serialize_special_value(v) for v in value]
}
if isinstance(value, tuple):
return {"__pyisolate_tuple__": [_serialize_special_value(v) for v in value]}
if isinstance(value, list):
return [_serialize_special_value(v) for v in value]
if isinstance(value, dict):
return {k: _serialize_special_value(v) for k, v in value.items()}
return value
def _restore_input_types_local(raw: Dict[str, object]) -> Dict[str, object]:
if not isinstance(raw, dict):
return raw # type: ignore[return-value]
@ -90,9 +115,44 @@ def restore_input_types(raw: Dict[str, object]) -> Dict[str, object]:
return restored
class HelperProxiesService(ProxiedSingleton):
_rpc: Optional[Any] = None
@classmethod
def set_rpc(cls, rpc: Any) -> None:
cls._rpc = rpc.create_caller(cls, cls.get_remote_id())
@classmethod
def clear_rpc(cls) -> None:
cls._rpc = None
@classmethod
def _get_caller(cls) -> Any:
if cls._rpc is None:
raise RuntimeError("HelperProxiesService RPC caller is not configured")
return cls._rpc
async def rpc_restore_input_types(self, raw: Dict[str, object]) -> Dict[str, object]:
restored = _restore_input_types_local(raw)
return _serialize_special_value(restored)
def restore_input_types(raw: Dict[str, object]) -> Dict[str, object]:
"""Restore serialized INPUT_TYPES payload back into ComfyUI-compatible objects."""
if os.environ.get("PYISOLATE_CHILD") == "1":
payload = call_singleton_rpc(
HelperProxiesService._get_caller(),
"rpc_restore_input_types",
raw,
)
return _restore_input_types_local(payload)
return _restore_input_types_local(raw)
__all__ = [
"AnyTypeProxy",
"FlexibleOptionalInputProxy",
"ByPassTypeTupleProxy",
"HelperProxiesService",
"restore_input_types",
]

View File

@ -1,27 +1,142 @@
import comfy.model_management as mm
from __future__ import annotations
import os
from typing import Any, Optional
from pyisolate import ProxiedSingleton
from .base import call_singleton_rpc
def _mm():
import comfy.model_management
return comfy.model_management
def _is_child_process() -> bool:
return os.environ.get("PYISOLATE_CHILD") == "1"
class TorchDeviceProxy:
def __init__(self, device_str: str):
self._device_str = device_str
if ":" in device_str:
device_type, index = device_str.split(":", 1)
self.type = device_type
self.index = int(index)
else:
self.type = device_str
self.index = None
def __str__(self) -> str:
return self._device_str
def __repr__(self) -> str:
return f"TorchDeviceProxy({self._device_str!r})"
def _serialize_value(value: Any) -> Any:
value_type = type(value)
if value_type.__module__ == "torch" and value_type.__name__ == "device":
return {"__pyisolate_torch_device__": str(value)}
if isinstance(value, TorchDeviceProxy):
return {"__pyisolate_torch_device__": str(value)}
if isinstance(value, tuple):
return {"__pyisolate_tuple__": [_serialize_value(item) for item in value]}
if isinstance(value, list):
return [_serialize_value(item) for item in value]
if isinstance(value, dict):
return {key: _serialize_value(inner) for key, inner in value.items()}
return value
def _deserialize_value(value: Any) -> Any:
if isinstance(value, dict):
if "__pyisolate_torch_device__" in value:
return TorchDeviceProxy(value["__pyisolate_torch_device__"])
if "__pyisolate_tuple__" in value:
return tuple(_deserialize_value(item) for item in value["__pyisolate_tuple__"])
return {key: _deserialize_value(inner) for key, inner in value.items()}
if isinstance(value, list):
return [_deserialize_value(item) for item in value]
return value
def _normalize_argument(value: Any) -> Any:
if isinstance(value, TorchDeviceProxy):
import torch
return torch.device(str(value))
if isinstance(value, dict):
if "__pyisolate_torch_device__" in value:
import torch
return torch.device(value["__pyisolate_torch_device__"])
if "__pyisolate_tuple__" in value:
return tuple(_normalize_argument(item) for item in value["__pyisolate_tuple__"])
return {key: _normalize_argument(inner) for key, inner in value.items()}
if isinstance(value, list):
return [_normalize_argument(item) for item in value]
return value
class ModelManagementProxy(ProxiedSingleton):
"""
Dynamic proxy for comfy.model_management.
Uses __getattr__ to forward all calls to the underlying module,
reducing maintenance burden.
Exact-relay proxy for comfy.model_management.
Child calls never import comfy.model_management directly; they serialize
arguments, relay to host, and deserialize the host result back.
"""
# Explicitly expose Enums/Classes as properties
_rpc: Optional[Any] = None
@classmethod
def set_rpc(cls, rpc: Any) -> None:
cls._rpc = rpc.create_caller(cls, cls.get_remote_id())
@classmethod
def clear_rpc(cls) -> None:
cls._rpc = None
@classmethod
def _get_caller(cls) -> Any:
if cls._rpc is None:
raise RuntimeError("ModelManagementProxy RPC caller is not configured")
return cls._rpc
def _relay_call(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
payload = call_singleton_rpc(
self._get_caller(),
"rpc_call",
method_name,
_serialize_value(args),
_serialize_value(kwargs),
)
return _deserialize_value(payload)
@property
def VRAMState(self):
return mm.VRAMState
return _mm().VRAMState
@property
def CPUState(self):
return mm.CPUState
return _mm().CPUState
@property
def OOM_EXCEPTION(self):
return mm.OOM_EXCEPTION
return _mm().OOM_EXCEPTION
def __getattr__(self, name):
"""Forward all other attribute access to the module."""
return getattr(mm, name)
def __getattr__(self, name: str):
if _is_child_process():
def child_method(*args: Any, **kwargs: Any) -> Any:
return self._relay_call(name, *args, **kwargs)
return child_method
return getattr(_mm(), name)
async def rpc_call(self, method_name: str, args: Any, kwargs: Any) -> Any:
normalized_args = _normalize_argument(_deserialize_value(args))
normalized_kwargs = _normalize_argument(_deserialize_value(kwargs))
method = getattr(_mm(), method_name)
result = method(*normalized_args, **normalized_kwargs)
return _serialize_value(result)

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import logging
import os
from typing import Any, Optional
try:
@ -10,13 +11,38 @@ except ImportError:
class ProxiedSingleton:
pass
from .base import call_singleton_rpc
from comfy_execution.progress import get_progress_state
def _get_progress_state():
from comfy_execution.progress import get_progress_state
return get_progress_state()
def _is_child_process() -> bool:
return os.environ.get("PYISOLATE_CHILD") == "1"
logger = logging.getLogger(__name__)
class ProgressProxy(ProxiedSingleton):
_rpc: Optional[Any] = None
@classmethod
def set_rpc(cls, rpc: Any) -> None:
cls._rpc = rpc.create_caller(cls, cls.get_remote_id())
@classmethod
def clear_rpc(cls) -> None:
cls._rpc = None
@classmethod
def _get_caller(cls) -> Any:
if cls._rpc is None:
raise RuntimeError("ProgressProxy RPC caller is not configured")
return cls._rpc
def set_progress(
self,
value: float,
@ -24,7 +50,33 @@ class ProgressProxy(ProxiedSingleton):
node_id: Optional[str] = None,
image: Any = None,
) -> None:
get_progress_state().update_progress(
if _is_child_process():
call_singleton_rpc(
self._get_caller(),
"rpc_set_progress",
value,
max_value,
node_id,
image,
)
return None
_get_progress_state().update_progress(
node_id=node_id,
value=value,
max_value=max_value,
image=image,
)
return None
async def rpc_set_progress(
self,
value: float,
max_value: float,
node_id: Optional[str] = None,
image: Any = None,
) -> None:
_get_progress_state().update_progress(
node_id=node_id,
value=value,
max_value=max_value,

View File

@ -13,10 +13,10 @@ import os
from typing import Any, Dict, Optional, Callable
import logging
from aiohttp import web
# IMPORTS
from pyisolate import ProxiedSingleton
from .base import call_singleton_rpc
logger = logging.getLogger(__name__)
LOG_PREFIX = "[Isolation:C<->H]"
@ -64,6 +64,10 @@ class PromptServerStub:
PromptServerService, target_id
) # We import Service below?
@classmethod
def clear_rpc(cls) -> None:
cls._rpc = None
# We need PromptServerService available for the create_caller call?
# Or just use the Stub class if ID matches?
# prompt_server_impl.py defines BOTH. So PromptServerService IS available!
@ -133,7 +137,7 @@ class PromptServerStub:
loop = asyncio.get_running_loop()
loop.create_task(self._rpc.ui_send_progress_text(text, node_id, sid))
except RuntimeError:
pass # Sync context without loop?
call_singleton_rpc(self._rpc, "ui_send_progress_text", text, node_id, sid)
# --- Route Registration Logic ---
def register_route(self, method: str, path: str, handler: Callable):
@ -147,7 +151,7 @@ class PromptServerStub:
loop = asyncio.get_running_loop()
loop.create_task(self._rpc.register_route_rpc(method, path, handler))
except RuntimeError:
pass
call_singleton_rpc(self._rpc, "register_route_rpc", method, path, handler)
class RouteStub:
@ -226,6 +230,7 @@ class PromptServerService(ProxiedSingleton):
async def register_route_rpc(self, method: str, path: str, child_handler_proxy):
"""RPC Target: Register a route that forwards to the Child."""
from aiohttp import web
logger.debug(f"{LOG_PREFIX} Registering Isolated Route {method} {path}")
async def route_wrapper(request: web.Request) -> web.Response:
@ -251,8 +256,9 @@ class PromptServerService(ProxiedSingleton):
# Register loop
self.server.app.router.add_route(method, path, route_wrapper)
def _serialize_response(self, result: Any) -> web.Response:
def _serialize_response(self, result: Any) -> Any:
"""Helper to convert Child result -> web.Response"""
from aiohttp import web
if isinstance(result, web.Response):
return result
# Handle dict (json)

View File

@ -2,12 +2,16 @@
from __future__ import annotations
from typing import Optional, Any
import comfy.utils
from pyisolate import ProxiedSingleton
import os
def _comfy_utils():
import comfy.utils
return comfy.utils
class UtilsProxy(ProxiedSingleton):
"""
Proxy for comfy.utils.
@ -23,6 +27,10 @@ class UtilsProxy(ProxiedSingleton):
# Create caller using class name as ID (standard for Singletons)
cls._rpc = rpc.create_caller(cls, "UtilsProxy")
@classmethod
def clear_rpc(cls) -> None:
cls._rpc = None
async def progress_bar_hook(
self,
value: int,
@ -35,30 +43,22 @@ class UtilsProxy(ProxiedSingleton):
Child-side: this method call is intercepted by RPC and sent to host.
"""
if os.environ.get("PYISOLATE_CHILD") == "1":
# Manual RPC dispatch for Child process
# Use class-level RPC storage (Static Injection)
if UtilsProxy._rpc:
return await UtilsProxy._rpc.progress_bar_hook(
value, total, preview, node_id
)
# Fallback channel: global child rpc
try:
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
get_child_rpc_instance()
# If we have an RPC instance but no UtilsProxy._rpc, we *could* try to use it,
# but we need a caller. For now, just pass to avoid crashing.
pass
except (ImportError, LookupError):
pass
return None
if UtilsProxy._rpc is None:
raise RuntimeError("UtilsProxy RPC caller is not configured")
return await UtilsProxy._rpc.progress_bar_hook(
value, total, preview, node_id
)
# Host Execution
if comfy.utils.PROGRESS_BAR_HOOK is not None:
comfy.utils.PROGRESS_BAR_HOOK(value, total, preview, node_id)
utils = _comfy_utils()
if utils.PROGRESS_BAR_HOOK is not None:
return utils.PROGRESS_BAR_HOOK(value, total, preview, node_id)
return None
def set_progress_bar_global_hook(self, hook: Any) -> None:
"""Forward hook registration (though usually not needed from child)."""
comfy.utils.set_progress_bar_global_hook(hook)
if os.environ.get("PYISOLATE_CHILD") == "1":
raise RuntimeError(
"UtilsProxy.set_progress_bar_global_hook is not available in child without exact relay support"
)
_comfy_utils().set_progress_bar_global_hook(hook)

View File

@ -0,0 +1,219 @@
"""WebDirectoryProxy — serves isolated node web assets via RPC.
Child side: enumerates and reads files from the extension's web/ directory.
Host side: gets an RPC proxy that fetches file listings and contents on demand.
Only files with allowed extensions (.js, .html, .css) are served.
Directory traversal is rejected. File contents are base64-encoded for
safe JSON-RPC transport.
"""
from __future__ import annotations
import base64
import logging
import os
from pathlib import Path, PurePosixPath
from typing import Any, Dict, List
from pyisolate import ProxiedSingleton
logger = logging.getLogger(__name__)
ALLOWED_EXTENSIONS = frozenset({".js", ".html", ".css"})
MIME_TYPES = {
".js": "application/javascript",
".html": "text/html",
".css": "text/css",
}
class WebDirectoryProxy(ProxiedSingleton):
"""Proxy for serving isolated extension web directories.
On the child side, this class has direct filesystem access to the
extension's web/ directory. On the host side, callers get an RPC
proxy whose method calls are forwarded to the child.
"""
# {extension_name: absolute_path_to_web_dir}
_web_dirs: dict[str, str] = {}
@classmethod
def register_web_dir(cls, extension_name: str, web_dir_path: str) -> None:
"""Register an extension's web directory (child-side only)."""
cls._web_dirs[extension_name] = web_dir_path
logger.info(
"][ WebDirectoryProxy: registered %s -> %s",
extension_name,
web_dir_path,
)
def list_web_files(self, extension_name: str) -> List[Dict[str, str]]:
"""Return a list of servable files in the extension's web directory.
Each entry is {"relative_path": "js/foo.js", "content_type": "application/javascript"}.
Only files with allowed extensions are included.
"""
web_dir = self._web_dirs.get(extension_name)
if not web_dir:
return []
root = Path(web_dir)
if not root.is_dir():
return []
result: List[Dict[str, str]] = []
for path in sorted(root.rglob("*")):
if not path.is_file():
continue
ext = path.suffix.lower()
if ext not in ALLOWED_EXTENSIONS:
continue
rel = path.relative_to(root)
result.append({
"relative_path": str(PurePosixPath(rel)),
"content_type": MIME_TYPES[ext],
})
return result
def get_web_file(
self, extension_name: str, relative_path: str
) -> Dict[str, Any]:
"""Return the contents of a single web file as base64.
Raises ValueError for traversal attempts or disallowed file types.
Returns {"content": <base64 str>, "content_type": <MIME str>}.
"""
_validate_path(relative_path)
web_dir = self._web_dirs.get(extension_name)
if not web_dir:
raise FileNotFoundError(
f"No web directory registered for {extension_name}"
)
root = Path(web_dir)
target = (root / relative_path).resolve()
# Ensure resolved path is under the web directory
if not str(target).startswith(str(root.resolve())):
raise ValueError(f"Path escapes web directory: {relative_path}")
if not target.is_file():
raise FileNotFoundError(f"File not found: {relative_path}")
ext = target.suffix.lower()
if ext not in ALLOWED_EXTENSIONS:
raise ValueError(f"Disallowed file type: {ext}")
content_type = MIME_TYPES[ext]
raw = target.read_bytes()
return {
"content": base64.b64encode(raw).decode("ascii"),
"content_type": content_type,
}
def _validate_path(relative_path: str) -> None:
"""Reject directory traversal and absolute paths."""
if os.path.isabs(relative_path):
raise ValueError(f"Absolute paths are not allowed: {relative_path}")
if ".." in PurePosixPath(relative_path).parts:
raise ValueError(f"Directory traversal is not allowed: {relative_path}")
# ---------------------------------------------------------------------------
# Host-side cache and aiohttp handler
# ---------------------------------------------------------------------------
class WebDirectoryCache:
"""Host-side in-memory cache for proxied web directory contents.
Populated lazily via RPC calls to the child's WebDirectoryProxy.
Once a file is cached, subsequent requests are served from memory.
"""
def __init__(self) -> None:
# {extension_name: {relative_path: {"content": bytes, "content_type": str}}}
self._file_cache: dict[str, dict[str, dict[str, Any]]] = {}
# {extension_name: [{"relative_path": str, "content_type": str}, ...]}
self._listing_cache: dict[str, list[dict[str, str]]] = {}
# {extension_name: WebDirectoryProxy (RPC proxy instance)}
self._proxies: dict[str, Any] = {}
def register_proxy(self, extension_name: str, proxy: Any) -> None:
"""Register an RPC proxy for an extension's web directory."""
self._proxies[extension_name] = proxy
logger.info(
"][ WebDirectoryCache: registered proxy for %s", extension_name
)
@property
def extension_names(self) -> list[str]:
return list(self._proxies.keys())
def list_files(self, extension_name: str) -> list[dict[str, str]]:
"""List servable files for an extension (cached after first call)."""
if extension_name not in self._listing_cache:
proxy = self._proxies.get(extension_name)
if proxy is None:
return []
try:
self._listing_cache[extension_name] = proxy.list_web_files(
extension_name
)
except Exception:
logger.warning(
"][ WebDirectoryCache: failed to list files for %s",
extension_name,
exc_info=True,
)
return []
return self._listing_cache[extension_name]
def get_file(
self, extension_name: str, relative_path: str
) -> dict[str, Any] | None:
"""Get file content (cached after first fetch). Returns None on miss."""
ext_cache = self._file_cache.get(extension_name)
if ext_cache and relative_path in ext_cache:
return ext_cache[relative_path]
proxy = self._proxies.get(extension_name)
if proxy is None:
return None
try:
result = proxy.get_web_file(extension_name, relative_path)
except (FileNotFoundError, ValueError):
return None
except Exception:
logger.warning(
"][ WebDirectoryCache: failed to fetch %s/%s",
extension_name,
relative_path,
exc_info=True,
)
return None
decoded = {
"content": base64.b64decode(result["content"]),
"content_type": result["content_type"],
}
if extension_name not in self._file_cache:
self._file_cache[extension_name] = {}
self._file_cache[extension_name][relative_path] = decoded
return decoded
# Global cache instance — populated during isolation loading
_web_directory_cache = WebDirectoryCache()
def get_web_directory_cache() -> WebDirectoryCache:
return _web_directory_cache