mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-07 18:27:40 +08:00
feat(isolation-proxies): proxy base + host service proxies
This commit is contained in:
parent
22f5e43c12
commit
9ca799362d
17
comfy/isolation/proxies/__init__.py
Normal file
17
comfy/isolation/proxies/__init__.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
from .base import (
|
||||||
|
IS_CHILD_PROCESS,
|
||||||
|
BaseProxy,
|
||||||
|
BaseRegistry,
|
||||||
|
detach_if_grad,
|
||||||
|
get_thread_loop,
|
||||||
|
run_coro_in_new_loop,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"IS_CHILD_PROCESS",
|
||||||
|
"BaseRegistry",
|
||||||
|
"BaseProxy",
|
||||||
|
"get_thread_loop",
|
||||||
|
"run_coro_in_new_loop",
|
||||||
|
"detach_if_grad",
|
||||||
|
]
|
||||||
213
comfy/isolation/proxies/base.py
Normal file
213
comfy/isolation/proxies/base.py
Normal file
@ -0,0 +1,213 @@
|
|||||||
|
# pylint: disable=global-statement,import-outside-toplevel,protected-access
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
import weakref
|
||||||
|
from typing import Any, Callable, Dict, Generic, Optional, TypeVar
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pyisolate import ProxiedSingleton
|
||||||
|
except ImportError:
|
||||||
|
|
||||||
|
class ProxiedSingleton: # type: ignore[no-redef]
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# True when this module is executing inside a pyisolate child process.
IS_CHILD_PROCESS = os.environ.get("PYISOLATE_CHILD") == "1"

# Per-thread storage for lazily created event loops.
_thread_local = threading.local()

T = TypeVar("T")


def get_thread_loop() -> asyncio.AbstractEventLoop:
    """Return this thread's cached event loop, creating a fresh one if it is
    missing or has been closed."""
    cached = getattr(_thread_local, "loop", None)
    if cached is not None and not cached.is_closed():
        return cached
    fresh = asyncio.new_event_loop()
    _thread_local.loop = fresh
    return fresh
|
||||||
|
|
||||||
|
|
||||||
|
def run_coro_in_new_loop(coro: Any) -> Any:
    """Run *coro* to completion on a brand-new event loop in a helper thread.

    Used when the caller is already inside a running loop and therefore cannot
    block on the coroutine directly. Blocks until the helper thread finishes;
    exceptions raised by the coroutine are re-raised in the caller.
    """
    outcome: Dict[str, Any] = {}

    def _worker() -> None:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            outcome["value"] = loop.run_until_complete(coro)
        except Exception as exc:  # noqa: BLE001
            outcome["error"] = exc
        finally:
            loop.close()

    worker = threading.Thread(target=_worker, daemon=True)
    worker.start()
    worker.join()

    if "error" in outcome:
        raise outcome["error"]
    return outcome.get("value")
|
||||||
|
|
||||||
|
|
||||||
|
def detach_if_grad(obj: Any) -> Any:
    """Recursively detach gradient-tracking tensors in *obj*.

    Walks lists, tuples (including namedtuples) and dicts, replacing any
    ``torch.Tensor`` that has ``requires_grad`` set with a detached view.
    Objects that carry no grad, and non-tensor leaves, are returned as-is.
    If torch is not importable, *obj* is returned unchanged.
    """
    try:
        import torch
    except Exception:  # torch unavailable: nothing to detach
        return obj

    if isinstance(obj, torch.Tensor):
        return obj.detach() if obj.requires_grad else obj
    if isinstance(obj, tuple):
        converted = [detach_if_grad(x) for x in obj]
        # namedtuples require positional construction; a plain tuple (or
        # tuple subclass without _fields) accepts a single iterable.
        if hasattr(obj, "_fields"):
            return type(obj)(*converted)
        return type(obj)(converted)
    if isinstance(obj, list):
        return type(obj)(detach_if_grad(x) for x in obj)
    if isinstance(obj, dict):
        return {k: detach_if_grad(v) for k, v in obj.items()}
    return obj
|
||||||
|
|
||||||
|
|
||||||
|
class BaseRegistry(ProxiedSingleton, Generic[T]):
    """Host-side store that maps string instance ids to live objects.

    Proxies in the child process carry only the id; the host resolves it back
    to the real instance through this registry. All access is lock-protected.
    """

    # Prefix baked into every generated id (subclasses override).
    _type_prefix: str = "base"

    def __init__(self) -> None:
        # Chain to ProxiedSingleton only when it is a real pyisolate base,
        # not the bare fallback class defined when pyisolate is absent.
        if hasattr(ProxiedSingleton, "__init__") and ProxiedSingleton is not object:
            super().__init__()
        self._registry: Dict[str, T] = {}
        self._id_map: Dict[int, str] = {}
        self._counter = 0
        self._lock = threading.Lock()

    def register(self, instance: T) -> str:
        """Store *instance* and return its id; re-registering is a no-op."""
        with self._lock:
            key = id(instance)
            existing = self._id_map.get(key)
            if existing is not None:
                return existing
            assigned = f"{self._type_prefix}_{self._counter}"
            self._counter += 1
            self._registry[assigned] = instance
            self._id_map[key] = assigned
            return assigned

    def unregister_sync(self, instance_id: str) -> None:
        """Drop the mapping for *instance_id*; unknown ids are ignored."""
        with self._lock:
            removed = self._registry.pop(instance_id, None)
            if removed:
                self._id_map.pop(id(removed), None)

    def _get_instance(self, instance_id: str) -> T:
        """Resolve an id back to the live object (host side only).

        Raises RuntimeError when called in a child process and ValueError
        for unknown ids.
        """
        if IS_CHILD_PROCESS:
            raise RuntimeError(
                f"[{self.__class__.__name__}] _get_instance called in child"
            )
        with self._lock:
            found = self._registry.get(instance_id)
            if found is None:
                raise ValueError(f"{instance_id} not found")
            return found
|
||||||
|
|
||||||
|
|
||||||
|
# Event loop owned by the host main thread, if one was registered; proxies
# use it to dispatch RPC coroutines from worker threads.
_GLOBAL_LOOP: Optional[asyncio.AbstractEventLoop] = None


def set_global_loop(loop: asyncio.AbstractEventLoop) -> None:
    """Register the main-thread event loop used for cross-thread RPC dispatch."""
    global _GLOBAL_LOOP
    _GLOBAL_LOOP = loop
|
||||||
|
|
||||||
|
|
||||||
|
class BaseProxy(Generic[T]):
    """Child-side handle for an object that lives in the host process.

    A proxy stores only a string ``instance_id``; calls are forwarded over
    pyisolate RPC to the host-side registry, which resolves the id back to
    the real object. Proxies pickle down to just the id (see __getstate__).
    """

    # Registry singleton class used to resolve ids on the host side.
    _registry_class: type = BaseRegistry  # type: ignore[type-arg]
    # Pin the module so pickled proxies resolve consistently across processes.
    __module__: str = "comfy.isolation.proxies.base"

    def __init__(
        self,
        instance_id: str,
        registry: Optional[Any] = None,
        manage_lifecycle: bool = False,
    ) -> None:
        self._instance_id = instance_id
        self._rpc_caller: Optional[Any] = None  # lazily created in _get_rpc()
        self._registry = registry if registry is not None else self._registry_class()
        self._manage_lifecycle = manage_lifecycle
        self._cleaned_up = False
        # On the host, optionally tie registry cleanup to this proxy's GC so
        # the registered instance does not leak once the proxy is collected.
        if manage_lifecycle and not IS_CHILD_PROCESS:
            self._finalizer = weakref.finalize(
                self, self._registry.unregister_sync, instance_id
            )

    def _get_rpc(self) -> Any:
        """Return (and cache) the RPC caller bound to the host registry.

        Raises RuntimeError when no child RPC channel is available.
        """
        if self._rpc_caller is None:
            from pyisolate._internal.rpc_protocol import get_child_rpc_instance

            rpc = get_child_rpc_instance()
            if rpc is None:
                raise RuntimeError(f"[{self.__class__.__name__}] No RPC in child")
            self._rpc_caller = rpc.create_caller(
                self._registry_class, self._registry_class.get_remote_id()
            )
        return self._rpc_caller

    def _call_rpc(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
        """Synchronously invoke ``method_name`` on the host-side instance.

        Dispatch strategy, in order:
        1. worker thread while the registered global loop is running ->
           ``run_coroutine_threadsafe`` onto the global loop, block on result;
        2. called from inside any running loop (including the global loop
           itself) -> run the coroutine on a throwaway loop in a helper thread;
        3. plain sync context with no loop -> run on this thread's cached loop.
        """
        rpc = self._get_rpc()
        method = getattr(rpc, method_name)
        coro = method(self._instance_id, *args, **kwargs)

        if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running():
            try:
                curr_loop = asyncio.get_running_loop()
                if curr_loop is _GLOBAL_LOOP:
                    # Already on the global loop: we cannot block it here, so
                    # fall through to the helper-thread path below.
                    # NOTE(review): blocking this loop's thread on work that
                    # may need this same loop to make progress risks deadlock;
                    # confirm against the RPC transport's threading model.
                    pass
            except RuntimeError:
                # No running loop here -> worker thread: dispatch onto the
                # global loop and block for the result.
                future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP)
                return future.result()

        try:
            asyncio.get_running_loop()
            # Inside some running loop: execute on a fresh loop in a helper
            # thread so we can block without stalling the current loop.
            return run_coro_in_new_loop(coro)
        except RuntimeError:
            # No loop at all: reuse this thread's cached loop.
            loop = get_thread_loop()
            return loop.run_until_complete(coro)

    def __getstate__(self) -> Dict[str, Any]:
        # Only the id crosses process boundaries; everything else is rebuilt.
        return {"_instance_id": self._instance_id}

    def __setstate__(self, state: Dict[str, Any]) -> None:
        # Rebuilt (unpickled) proxies never own the instance lifecycle.
        self._instance_id = state["_instance_id"]
        self._rpc_caller = None
        self._registry = self._registry_class()
        self._manage_lifecycle = False
        self._cleaned_up = False

    def cleanup(self) -> None:
        """Explicitly release the host-side registration (idempotent, host only)."""
        if self._cleaned_up or IS_CHILD_PROCESS:
            return
        self._cleaned_up = True
        # Detach the GC finalizer so unregistration does not run a second time.
        finalizer = getattr(self, "_finalizer", None)
        if finalizer is not None:
            finalizer.detach()
        self._registry.unregister_sync(self._instance_id)

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} {self._instance_id}>"
|
||||||
|
|
||||||
|
|
||||||
|
def create_rpc_method(method_name: str) -> Callable[..., Any]:
    """Build a proxy method that forwards *method_name* via BaseProxy._call_rpc.

    The returned function is suitable for attaching to a BaseProxy subclass;
    its ``__name__`` is set to *method_name* for nicer introspection.
    """

    def forwarder(self: "BaseProxy[Any]", *call_args: Any, **call_kwargs: Any) -> Any:
        return self._call_rpc(method_name, *call_args, **call_kwargs)

    forwarder.__name__ = method_name
    return forwarder
|
||||||
29
comfy/isolation/proxies/folder_paths_proxy.py
Normal file
29
comfy/isolation/proxies/folder_paths_proxy.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import folder_paths
|
||||||
|
from pyisolate import ProxiedSingleton
|
||||||
|
|
||||||
|
|
||||||
|
class FolderPathsProxy(ProxiedSingleton):
    """
    Dynamic proxy for folder_paths.

    Uses __getattr__ for most lookups, with explicit handling for
    mutable collections to ensure efficient by-value transfer.
    """

    def __getattr__(self, name):
        # Fallback for anything not defined explicitly; unknown names raise
        # AttributeError exactly as direct module access would.
        return getattr(folder_paths, name)

    # Return dict snapshots (avoid RPC chatter). NOTE(review): these are
    # shallow copies -- mutations made by the child will NOT propagate back
    # to folder_paths on the host; confirm that is intended.
    @property
    def folder_names_and_paths(self) -> Dict:
        return dict(folder_paths.folder_names_and_paths)

    @property
    def extension_mimetypes_cache(self) -> Dict:
        return dict(folder_paths.extension_mimetypes_cache)

    @property
    def filename_list_cache(self) -> Dict:
        return dict(folder_paths.filename_list_cache)
|
||||||
98
comfy/isolation/proxies/helper_proxies.py
Normal file
98
comfy/isolation/proxies/helper_proxies.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class AnyTypeProxy(str):
    """Replacement for custom AnyType objects used by some nodes.

    A string (default ``"*"``) whose inequality check always reports a match,
    mirroring the wildcard-type trick used by those nodes.
    """

    def __new__(cls, value: str = "*"):
        return str.__new__(cls, value)

    def __ne__(self, other):  # type: ignore[override]
        # A wildcard never differs from anything.
        return False
|
||||||
|
|
||||||
|
|
||||||
|
class FlexibleOptionalInputProxy(dict):
    """Replacement for FlexibleOptionalInputType to allow dynamic inputs.

    Every key lookup succeeds and yields ``(flex_type,)``, and membership
    tests are always true, so arbitrary optional inputs validate.
    """

    def __init__(self, flex_type, data: Optional[Dict[str, object]] = None):
        super().__init__(data or {})
        self.type = flex_type

    def __getitem__(self, key):  # type: ignore[override]
        # Every key maps to the single flexible type.
        return (self.type,)

    def __contains__(self, key):  # type: ignore[override]
        return True
|
||||||
|
|
||||||
|
|
||||||
|
class ByPassTypeTupleProxy(tuple):
    """Replacement for ByPassTypeTuple to mirror wildcard fallback behavior.

    Indexing past the end with an integer yields a wildcard AnyTypeProxy
    instead of raising IndexError; slices and negative indices behave like
    a normal tuple.
    """

    def __new__(cls, values):
        return super().__new__(cls, values)

    def __getitem__(self, index):  # type: ignore[override]
        # Only out-of-range *integer* indices trigger the wildcard fallback;
        # comparing a slice against len() would raise TypeError.
        if isinstance(index, int) and index >= len(self):
            return AnyTypeProxy("*")
        return super().__getitem__(index)
|
||||||
|
|
||||||
|
|
||||||
|
def _restore_special_value(value: Any) -> Any:
    """Recursively rebuild proxy objects from their tagged wire representation.

    The serializer encodes special ComfyUI typing objects as dicts carrying a
    ``__pyisolate_*__`` marker key; everything else passes through, with
    containers walked recursively.
    """
    if isinstance(value, dict):
        if value.get("__pyisolate_any_type__"):
            return AnyTypeProxy(value.get("value", "*"))
        if value.get("__pyisolate_flexible_optional__"):
            flex_type = _restore_special_value(value.get("type"))
            data_raw = value.get("data")
            # Restore nested entries; anything non-dict degrades to empty data.
            data = (
                {k: _restore_special_value(v) for k, v in data_raw.items()}
                if isinstance(data_raw, dict)
                else {}
            )
            return FlexibleOptionalInputProxy(flex_type, data)
        # Tuples travel as a list under a marker key; the ``is not None``
        # check keeps empty tuples restorable.
        if value.get("__pyisolate_tuple__") is not None:
            return tuple(
                _restore_special_value(v) for v in value["__pyisolate_tuple__"]
            )
        if value.get("__pyisolate_bypass_tuple__") is not None:
            return ByPassTypeTupleProxy(
                tuple(
                    _restore_special_value(v)
                    for v in value["__pyisolate_bypass_tuple__"]
                )
            )
        # Ordinary dict: restore each value, keep keys as-is.
        return {k: _restore_special_value(v) for k, v in value.items()}
    if isinstance(value, list):
        return [_restore_special_value(v) for v in value]
    return value
|
||||||
|
|
||||||
|
|
||||||
|
def restore_input_types(raw: Dict[str, object]) -> Dict[str, object]:
    """Restore serialized INPUT_TYPES payload back into ComfyUI-compatible objects."""
    if not isinstance(raw, dict):
        return raw  # type: ignore[return-value]

    rebuilt: Dict[str, object] = {}
    for section, entries in raw.items():
        if not isinstance(entries, dict):
            # Non-dict section payload: restore directly.
            rebuilt[section] = _restore_special_value(entries)
        elif entries.get("__pyisolate_flexible_optional__"):
            # A whole section serialized as a flexible-optional marker dict.
            rebuilt[section] = _restore_special_value(entries)
        else:
            # Ordinary section: restore each input spec individually.
            rebuilt[section] = {
                name: _restore_special_value(spec) for name, spec in entries.items()
            }
    return rebuilt
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AnyTypeProxy",
|
||||||
|
"FlexibleOptionalInputProxy",
|
||||||
|
"ByPassTypeTupleProxy",
|
||||||
|
"restore_input_types",
|
||||||
|
]
|
||||||
27
comfy/isolation/proxies/model_management_proxy.py
Normal file
27
comfy/isolation/proxies/model_management_proxy.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
import comfy.model_management as mm
|
||||||
|
from pyisolate import ProxiedSingleton
|
||||||
|
|
||||||
|
|
||||||
|
class ModelManagementProxy(ProxiedSingleton):
    """
    Dynamic proxy for comfy.model_management.

    Uses __getattr__ to forward all calls to the underlying module,
    reducing maintenance burden.
    """

    # Explicitly expose Enums/Classes as properties so they are resolved as
    # class attributes instead of going through the dynamic fallback.
    @property
    def VRAMState(self):
        return mm.VRAMState

    @property
    def CPUState(self):
        return mm.CPUState

    @property
    def OOM_EXCEPTION(self):
        return mm.OOM_EXCEPTION

    def __getattr__(self, name):
        """Forward all other attribute access to the module."""
        # Unknown names raise AttributeError exactly as module access would.
        return getattr(mm, name)
|
||||||
35
comfy/isolation/proxies/progress_proxy.py
Normal file
35
comfy/isolation/proxies/progress_proxy.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pyisolate import ProxiedSingleton
|
||||||
|
except ImportError:
|
||||||
|
|
||||||
|
class ProxiedSingleton:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
from comfy_execution.progress import get_progress_state
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ProgressProxy(ProxiedSingleton):
    """Host-side relay: isolated nodes report per-node progress through RPC,
    and this service feeds it into the host's global progress registry."""

    def set_progress(
        self,
        value: float,
        max_value: float,
        node_id: Optional[str] = None,
        image: Any = None,
    ) -> None:
        """Forward one progress tick (optionally with a preview image)."""
        tracker = get_progress_state()
        tracker.update_progress(
            node_id=node_id,
            value=value,
            max_value=max_value,
            image=image,
        )
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["ProgressProxy"]
|
||||||
265
comfy/isolation/proxies/prompt_server_impl.py
Normal file
265
comfy/isolation/proxies/prompt_server_impl.py
Normal file
@ -0,0 +1,265 @@
|
|||||||
|
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,redefined-outer-name,reimported,super-init-not-called
|
||||||
|
"""Stateless RPC Implementation for PromptServer.
|
||||||
|
|
||||||
|
Replaces the legacy PromptServerProxy (Singleton) with a clean Service/Stub architecture.
|
||||||
|
- Host: PromptServerService (RPC Handler)
|
||||||
|
- Child: PromptServerStub (Interface Implementation)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from typing import Any, Dict, Optional, Callable
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
# IMPORTS
|
||||||
|
from pyisolate import ProxiedSingleton
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
LOG_PREFIX = "[Isolation:C<->H]"
|
||||||
|
|
||||||
|
# ...
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# CHILD SIDE: PromptServerStub
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class PromptServerStub:
    """Stateless Stub for PromptServer (child-process side).

    Mimics the subset of ``server.PromptServer`` that custom nodes touch and
    forwards UI messaging and route registration to the host over RPC.
    """

    # Masquerade as the real server module so introspection inside isolated
    # nodes sees the expected origin.
    __module__ = "server"

    _instance: Optional["PromptServerStub"] = None
    _rpc: Optional[Any] = None  # RPC caller targeting the host-side service
    _source_file: Optional[str] = None  # cached fake __file__ path

    def __init__(self):
        self.routes = RouteStub(self)

    @classmethod
    def set_rpc(cls, rpc: Any) -> None:
        """Inject RPC client (called by adapter.py or manually)."""
        # The caller is typed against the host-side PromptServerService
        # (defined later in this module, so available by call time) and
        # addressed by its class name.
        target_id = "PromptServerService"
        cls._rpc = rpc.create_caller(
            PromptServerService, target_id
        )

    @property
    def instance(self) -> "PromptServerStub":
        # PromptServer.instance compatibility: the stub is its own singleton.
        return self

    # --- Compatibility helpers ---
    @classmethod
    def _get_source_file(cls) -> str:
        # Lazily compute a plausible path for __file__ compatibility.
        if cls._source_file is None:
            import folder_paths

            cls._source_file = os.path.join(folder_paths.base_path, "server.py")
        return cls._source_file

    @property
    def __file__(self) -> str:
        return self._get_source_file()

    # --- Properties ---
    @property
    def client_id(self) -> Optional[str]:
        # NOTE(review): fixed placeholder id -- confirm downstream code does
        # not depend on the real websocket client id.
        return "isolated_client"

    def supports(self, feature: str) -> bool:
        # Optimistically claim support for every feature flag.
        return True

    @property
    def app(self):
        raise RuntimeError(
            "PromptServer.app is not accessible in isolated nodes. Use RPC routes instead."
        )

    @property
    def prompt_queue(self):
        raise RuntimeError(
            "PromptServer.prompt_queue is not accessible in isolated nodes."
        )

    # --- UI Communication (RPC Delegates) ---
    async def send_sync(
        self, event: str, data: Dict[str, Any], sid: Optional[str] = None
    ) -> None:
        # Silently no-ops when RPC has not been wired up yet.
        if self._rpc:
            await self._rpc.ui_send_sync(event, data, sid)

    async def send(
        self, event: str, data: Dict[str, Any], sid: Optional[str] = None
    ) -> None:
        if self._rpc:
            await self._rpc.ui_send(event, data, sid)

    def send_progress_text(self, text: str, node_id: str, sid=None) -> None:
        """Fire-and-forget progress text update (node code calls this sync)."""
        if self._rpc:
            import asyncio

            try:
                # Schedule on the current loop; the host-side target is async.
                loop = asyncio.get_running_loop()
                loop.create_task(self._rpc.ui_send_progress_text(text, node_id, sid))
            except RuntimeError:
                # No running loop: the update is silently dropped.
                # NOTE(review): consider a threadsafe dispatch fallback here.
                pass

    # --- Route Registration Logic ---
    def register_route(self, method: str, path: str, handler: Callable):
        """Register a route handler via RPC."""
        if not self._rpc:
            logger.error("RPC not initialized in PromptServerStub")
            return

        # Fire registration async; outside a running loop it is dropped.
        try:
            loop = asyncio.get_running_loop()
            loop.create_task(self._rpc.register_route_rpc(method, path, handler))
        except RuntimeError:
            pass
|
||||||
|
|
||||||
|
|
||||||
|
class RouteStub:
    """Simulates aiohttp.web.RouteTableDef for isolated nodes.

    Each HTTP-verb method returns a decorator that registers the wrapped
    handler with the owning PromptServerStub (which forwards it over RPC)
    and returns the handler unchanged, matching RouteTableDef semantics.
    """

    def __init__(self, stub: "PromptServerStub"):
        self._stub = stub

    def _verb(self, method: str, path: str):
        # Shared decorator factory -- keeps the per-verb methods trivial.
        def decorator(handler):
            self._stub.register_route(method, path, handler)
            return handler

        return decorator

    def get(self, path: str):
        return self._verb("GET", path)

    def post(self, path: str):
        return self._verb("POST", path)

    def patch(self, path: str):
        return self._verb("PATCH", path)

    def put(self, path: str):
        return self._verb("PUT", path)

    def delete(self, path: str):
        return self._verb("DELETE", path)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# HOST SIDE: PromptServerService
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class PromptServerService(ProxiedSingleton):
    """Host-side RPC Service for PromptServer.

    Receives UI-messaging and route-registration calls from isolated child
    processes and applies them to the real ``server.PromptServer.instance``.
    """

    def __init__(self):
        # The real server instance is resolved lazily via the ``server``
        # property, so construction order relative to PromptServer does
        # not matter.
        pass

    @property
    def server(self):
        # Imported lazily to avoid a circular import at module load time.
        from server import PromptServer

        return PromptServer.instance

    async def ui_send_sync(
        self, event: str, data: Dict[str, Any], sid: Optional[str] = None
    ):
        """Relay a child's send_sync to the real server."""
        await self.server.send_sync(event, data, sid)

    async def ui_send(
        self, event: str, data: Dict[str, Any], sid: Optional[str] = None
    ):
        """Relay a child's send to the real server."""
        await self.server.send(event, data, sid)

    async def ui_send_progress_text(self, text: str, node_id: str, sid=None):
        # Made async to be awaitable by RPC layer; the underlying call is sync.
        self.server.send_progress_text(text, node_id, sid)

    async def register_route_rpc(self, method: str, path: str, child_handler_proxy):
        """RPC Target: Register a route that forwards to the Child."""
        logger.debug(f"{LOG_PREFIX} Registering Isolated Route {method} {path}")

        async def route_wrapper(request: web.Request) -> web.Response:
            # 1. Capture request data into a plain, serializable dict.
            req_data = {
                "method": request.method,
                "path": request.path,
                "query": dict(request.query),
            }
            if request.can_read_body:
                req_data["text"] = await request.text()

            try:
                # 2. Call Child Handler via RPC (child_handler_proxy is async callable)
                result = await child_handler_proxy(req_data)

                # 3. Serialize Response
                return self._serialize_response(result)
            except Exception as e:
                # Surface child-side failures as a 500 rather than crashing
                # the host's request loop.
                logger.error(f"{LOG_PREFIX} Isolated Route Error: {e}")
                return web.Response(status=500, text=str(e))

        # Register the wrapper on the live aiohttp router.
        self.server.app.router.add_route(method, path, route_wrapper)

    def _serialize_response(self, result: Any) -> web.Response:
        """Helper to convert Child result -> web.Response"""
        if isinstance(result, web.Response):
            return result
        # Handle dict (json)
        if isinstance(result, dict):
            return web.json_response(result)
        # Handle string
        if isinstance(result, str):
            return web.Response(text=result)
        # Fallback: stringify anything else.
        return web.Response(text=str(result))
|
||||||
64
comfy/isolation/proxies/utils_proxy.py
Normal file
64
comfy/isolation/proxies/utils_proxy.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
# pylint: disable=cyclic-import,import-outside-toplevel
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional, Any
|
||||||
|
import comfy.utils
|
||||||
|
from pyisolate import ProxiedSingleton
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
class UtilsProxy(ProxiedSingleton):
    """
    Proxy for comfy.utils.

    Primarily handles the PROGRESS_BAR_HOOK to ensure progress updates
    from isolated nodes reach the host.
    """

    # _instance and __new__ removed to rely on SingletonMetaclass
    _rpc: Optional[Any] = None  # child-side RPC caller, injected via set_rpc

    @classmethod
    def set_rpc(cls, rpc: Any) -> None:
        # Create caller using class name as ID (standard for Singletons)
        cls._rpc = rpc.create_caller(cls, "UtilsProxy")

    async def progress_bar_hook(
        self,
        value: int,
        total: int,
        preview: Optional[bytes] = None,
        node_id: Optional[str] = None,
    ) -> Any:
        """
        Host-side implementation: forwards the call to the real global hook.
        Child-side: this method call is intercepted by RPC and sent to host.
        """
        if os.environ.get("PYISOLATE_CHILD") == "1":
            # Child process: forward through the injected class-level caller.
            if UtilsProxy._rpc:
                return await UtilsProxy._rpc.progress_bar_hook(
                    value, total, preview, node_id
                )

            # Fallback channel: global child rpc. Currently a deliberate
            # no-op -- without a caller object the update cannot be
            # dispatched, so the event is swallowed rather than crashing
            # node execution.
            try:
                from pyisolate._internal.rpc_protocol import get_child_rpc_instance

                get_child_rpc_instance()
                pass
            except (ImportError, LookupError):
                pass

            return None

        # Host Execution: invoke the registered global hook, if any.
        if comfy.utils.PROGRESS_BAR_HOOK is not None:
            comfy.utils.PROGRESS_BAR_HOOK(value, total, preview, node_id)

    def set_progress_bar_global_hook(self, hook: Any) -> None:
        """Forward hook registration (though usually not needed from child)."""
        comfy.utils.set_progress_bar_global_hook(hook)
|
||||||
Loading…
Reference in New Issue
Block a user