mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-08 10:47:32 +08:00
214 lines
7.2 KiB
Python
214 lines
7.2 KiB
Python
# pylint: disable=global-statement,import-outside-toplevel,protected-access
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
import threading
|
|
import weakref
|
|
from typing import Any, Callable, Dict, Generic, Optional, TypeVar
|
|
|
|
try:
|
|
from pyisolate import ProxiedSingleton
|
|
except ImportError:
|
|
|
|
class ProxiedSingleton: # type: ignore[no-redef]
|
|
pass
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
IS_CHILD_PROCESS = os.environ.get("PYISOLATE_CHILD") == "1"
|
|
_thread_local = threading.local()
|
|
T = TypeVar("T")
|
|
|
|
|
|
def get_thread_loop() -> asyncio.AbstractEventLoop:
|
|
loop = getattr(_thread_local, "loop", None)
|
|
if loop is None or loop.is_closed():
|
|
loop = asyncio.new_event_loop()
|
|
_thread_local.loop = loop
|
|
return loop
|
|
|
|
|
|
def run_coro_in_new_loop(coro: Any) -> Any:
|
|
result_box: Dict[str, Any] = {}
|
|
exc_box: Dict[str, BaseException] = {}
|
|
|
|
def runner() -> None:
|
|
loop = asyncio.new_event_loop()
|
|
asyncio.set_event_loop(loop)
|
|
try:
|
|
result_box["value"] = loop.run_until_complete(coro)
|
|
except Exception as exc: # noqa: BLE001
|
|
exc_box["exc"] = exc
|
|
finally:
|
|
loop.close()
|
|
|
|
t = threading.Thread(target=runner, daemon=True)
|
|
t.start()
|
|
t.join()
|
|
if "exc" in exc_box:
|
|
raise exc_box["exc"]
|
|
return result_box.get("value")
|
|
|
|
|
|
def detach_if_grad(obj: Any) -> Any:
|
|
try:
|
|
import torch
|
|
except Exception:
|
|
return obj
|
|
|
|
if isinstance(obj, torch.Tensor):
|
|
return obj.detach() if obj.requires_grad else obj
|
|
if isinstance(obj, (list, tuple)):
|
|
return type(obj)(detach_if_grad(x) for x in obj)
|
|
if isinstance(obj, dict):
|
|
return {k: detach_if_grad(v) for k, v in obj.items()}
|
|
return obj
|
|
|
|
|
|
class BaseRegistry(ProxiedSingleton, Generic[T]):
|
|
_type_prefix: str = "base"
|
|
|
|
def __init__(self) -> None:
|
|
if hasattr(ProxiedSingleton, "__init__") and ProxiedSingleton is not object:
|
|
super().__init__()
|
|
self._registry: Dict[str, T] = {}
|
|
self._id_map: Dict[int, str] = {}
|
|
self._counter = 0
|
|
self._lock = threading.Lock()
|
|
|
|
def register(self, instance: T) -> str:
|
|
with self._lock:
|
|
obj_id = id(instance)
|
|
if obj_id in self._id_map:
|
|
return self._id_map[obj_id]
|
|
instance_id = f"{self._type_prefix}_{self._counter}"
|
|
self._counter += 1
|
|
self._registry[instance_id] = instance
|
|
self._id_map[obj_id] = instance_id
|
|
return instance_id
|
|
|
|
def unregister_sync(self, instance_id: str) -> None:
|
|
with self._lock:
|
|
instance = self._registry.pop(instance_id, None)
|
|
if instance:
|
|
self._id_map.pop(id(instance), None)
|
|
|
|
def _get_instance(self, instance_id: str) -> T:
|
|
if IS_CHILD_PROCESS:
|
|
raise RuntimeError(
|
|
f"[{self.__class__.__name__}] _get_instance called in child"
|
|
)
|
|
with self._lock:
|
|
instance = self._registry.get(instance_id)
|
|
if instance is None:
|
|
raise ValueError(f"{instance_id} not found")
|
|
return instance
|
|
|
|
|
|
_GLOBAL_LOOP: Optional[asyncio.AbstractEventLoop] = None
|
|
|
|
|
|
def set_global_loop(loop: asyncio.AbstractEventLoop) -> None:
|
|
global _GLOBAL_LOOP
|
|
_GLOBAL_LOOP = loop
|
|
|
|
|
|
class BaseProxy(Generic[T]):
|
|
_registry_class: type = BaseRegistry # type: ignore[type-arg]
|
|
__module__: str = "comfy.isolation.proxies.base"
|
|
|
|
def __init__(
|
|
self,
|
|
instance_id: str,
|
|
registry: Optional[Any] = None,
|
|
manage_lifecycle: bool = False,
|
|
) -> None:
|
|
self._instance_id = instance_id
|
|
self._rpc_caller: Optional[Any] = None
|
|
self._registry = registry if registry is not None else self._registry_class()
|
|
self._manage_lifecycle = manage_lifecycle
|
|
self._cleaned_up = False
|
|
if manage_lifecycle and not IS_CHILD_PROCESS:
|
|
self._finalizer = weakref.finalize(
|
|
self, self._registry.unregister_sync, instance_id
|
|
)
|
|
|
|
def _get_rpc(self) -> Any:
|
|
if self._rpc_caller is None:
|
|
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
|
|
|
|
rpc = get_child_rpc_instance()
|
|
if rpc is None:
|
|
raise RuntimeError(f"[{self.__class__.__name__}] No RPC in child")
|
|
self._rpc_caller = rpc.create_caller(
|
|
self._registry_class, self._registry_class.get_remote_id()
|
|
)
|
|
return self._rpc_caller
|
|
|
|
def _call_rpc(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
|
|
rpc = self._get_rpc()
|
|
method = getattr(rpc, method_name)
|
|
coro = method(self._instance_id, *args, **kwargs)
|
|
|
|
# If we have a global loop (Main Thread Loop), use it for dispatch from worker threads
|
|
if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running():
|
|
try:
|
|
# If we are already in the global loop, we can't block on it?
|
|
# Actually, this method is synchronous (__getattr__ -> lambda).
|
|
# If called from async context in main loop, we need to handle that.
|
|
curr_loop = asyncio.get_running_loop()
|
|
if curr_loop is _GLOBAL_LOOP:
|
|
# We are in the main loop. We cannot await/block here if we are just a sync function.
|
|
# But proxies are often called from sync code.
|
|
# If called from sync code in main loop, creating a new loop is bad.
|
|
# But we can't await `coro`.
|
|
# This implies proxies MUST be awaited if called from async context?
|
|
# Existing code used `run_coro_in_new_loop` which is weird.
|
|
# Let's trust that if we are in a thread (RuntimeError on get_running_loop),
|
|
# we use run_coroutine_threadsafe.
|
|
pass
|
|
except RuntimeError:
|
|
# No running loop - we are in a worker thread.
|
|
future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP)
|
|
return future.result()
|
|
|
|
try:
|
|
asyncio.get_running_loop()
|
|
return run_coro_in_new_loop(coro)
|
|
except RuntimeError:
|
|
loop = get_thread_loop()
|
|
return loop.run_until_complete(coro)
|
|
|
|
def __getstate__(self) -> Dict[str, Any]:
|
|
return {"_instance_id": self._instance_id}
|
|
|
|
def __setstate__(self, state: Dict[str, Any]) -> None:
|
|
self._instance_id = state["_instance_id"]
|
|
self._rpc_caller = None
|
|
self._registry = self._registry_class()
|
|
self._manage_lifecycle = False
|
|
self._cleaned_up = False
|
|
|
|
def cleanup(self) -> None:
|
|
if self._cleaned_up or IS_CHILD_PROCESS:
|
|
return
|
|
self._cleaned_up = True
|
|
finalizer = getattr(self, "_finalizer", None)
|
|
if finalizer is not None:
|
|
finalizer.detach()
|
|
self._registry.unregister_sync(self._instance_id)
|
|
|
|
def __repr__(self) -> str:
|
|
return f"<{self.__class__.__name__} {self._instance_id}>"
|
|
|
|
|
|
def create_rpc_method(method_name: str) -> Callable[..., Any]:
|
|
def method(self: BaseProxy[Any], *args: Any, **kwargs: Any) -> Any:
|
|
return self._call_rpc(method_name, *args, **kwargs)
|
|
|
|
method.__name__ = method_name
|
|
return method
|