mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-07 10:17:31 +08:00
- harden isolation ModelPatcher proxy/registry behavior for DynamicVRAM-backed patchers - improve serializer/adapter boundaries (device/dtype/model refs) to reduce pre-inference lockups - add structured ISO registry/modelsampling telemetry and explicit RPC timeout surfacing - preserve isolation-first lifecycle handling and boundary cleanup sequencing - validate isolated workflows: most targeted runs now complete under --use-sage-attention --use-process-isolation --disable-cuda-malloc Known issue (reproducible): - isolation_99_full_iso_stack still times out at SamplerCustom_ISO path - failure is explicit RPC timeout: ModelPatcherProxy.process_latent_in(instance_id=model_0, timeout_ms=120000) - this indicates the remaining stall is on process_latent_in RPC path, not generic startup/manager fetch
284 lines
9.3 KiB
Python
284 lines
9.3 KiB
Python
# pylint: disable=global-statement,import-outside-toplevel,protected-access
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import concurrent.futures
|
|
import logging
|
|
import os
|
|
import threading
|
|
import time
|
|
import weakref
|
|
from typing import Any, Callable, Dict, Generic, Optional, TypeVar
|
|
|
|
try:
    from pyisolate import ProxiedSingleton
except ImportError:
    # pyisolate is optional at import time: fall back to an inert base class
    # so the registry/proxy classes in this module can still be defined.
    class ProxiedSingleton:  # type: ignore[no-redef]
        """Stand-in base class used when ``pyisolate`` cannot be imported."""
|
|
|
|
|
|
# Module logger; the ISO:rpc_start / ISO:rpc_end telemetry is logged at DEBUG.
logger = logging.getLogger(__name__)

# True when this interpreter runs as a pyisolate child process, as signalled
# by the PYISOLATE_CHILD environment variable being exactly "1".
IS_CHILD_PROCESS = os.getenv("PYISOLATE_CHILD") == "1"

# Per-thread storage; get_thread_loop() caches one event loop per thread here.
_thread_local = threading.local()

# Type parameter for the generic registry/proxy classes.
T = TypeVar("T")
|
|
|
|
|
|
def get_thread_loop() -> asyncio.AbstractEventLoop:
    """Return the calling thread's cached event loop, creating one if needed.

    Each thread keeps its own loop on a ``threading.local``. A cached loop
    that has since been closed is replaced by a freshly created one.
    """
    cached = getattr(_thread_local, "loop", None)
    if cached is not None and not cached.is_closed():
        return cached
    fresh = asyncio.new_event_loop()
    _thread_local.loop = fresh
    return fresh
|
|
|
|
|
|
def run_coro_in_new_loop(coro: Any) -> Any:
    """Run ``coro`` to completion on a brand-new event loop in a helper thread.

    This is for callers that are already inside a running event loop, where
    ``run_until_complete`` cannot be used on the current thread. The calling
    thread blocks until the helper thread finishes.

    Args:
        coro: A coroutine (or other awaitable accepted by
            ``loop.run_until_complete``).

    Returns:
        Whatever ``coro`` returns.

    Raises:
        Any exception raised by ``coro`` is re-raised in the calling thread.
        This includes ``BaseException`` subclasses such as
        ``asyncio.CancelledError``.
    """
    result_box: Dict[str, Any] = {}
    exc_box: Dict[str, BaseException] = {}

    def runner() -> None:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            result_box["value"] = loop.run_until_complete(coro)
        except BaseException as exc:  # noqa: BLE001 - re-raised in caller below
            # Catch BaseException, not Exception: CancelledError (a
            # BaseException since 3.8) would otherwise die inside this thread
            # and the caller would silently receive None.
            exc_box["exc"] = exc
        finally:
            # Do not leave a closed loop registered as this thread's loop.
            asyncio.set_event_loop(None)
            loop.close()

    t = threading.Thread(target=runner, daemon=True)
    t.start()
    t.join()
    if "exc" in exc_box:
        raise exc_box["exc"]
    return result_box.get("value")
|
|
|
|
|
|
def detach_if_grad(obj: Any) -> Any:
    """Recursively detach grad-requiring tensors found inside ``obj``.

    - Tensors with ``requires_grad`` are replaced by ``tensor.detach()``.
      Other tensors are returned as-is.
    - Lists and tuples, including namedtuples, are rebuilt with the same type
      and their elements processed recursively.
    - Dicts are rebuilt as plain ``dict`` with their values processed.
    - Anything else is returned unchanged.

    If ``torch`` cannot be imported, ``obj`` is returned untouched.
    """
    try:
        import torch
    except Exception:  # torch unavailable or broken: nothing can carry grad
        return obj

    # Recurse through a closure so the torch import happens once per call
    # rather than once per nested element.
    def _walk(item: Any) -> Any:
        if isinstance(item, torch.Tensor):
            return item.detach() if item.requires_grad else item
        if isinstance(item, (list, tuple)):
            children = [_walk(x) for x in item]
            if isinstance(item, tuple) and hasattr(item, "_fields"):
                # namedtuple constructors take fields positionally, not a
                # single iterable; type(item)(children) would raise TypeError.
                return type(item)(*children)
            return type(item)(children)
        if isinstance(item, dict):
            return {k: _walk(v) for k, v in item.items()}
        return item

    return _walk(obj)
|
|
|
|
|
|
class BaseRegistry(ProxiedSingleton, Generic[T]):
    """Host-side table mapping string ids to live objects for proxy RPC.

    Objects are registered under ids of the form ``"<_type_prefix>_<n>"``.
    The registry holds a strong reference to every registered object. That
    keeps each object's ``id()`` from being reused while it is registered,
    so the reverse ``_id_map`` lookup stays valid. All table access is
    guarded by ``_lock``.
    """

    # Prefix for generated instance ids; subclasses override it.
    _type_prefix: str = "base"

    def __init__(self) -> None:
        # Every class defines __init__, so always cooperate with the
        # ProxiedSingleton base (the pyisolate one or the local fallback).
        super().__init__()
        self._registry: Dict[str, T] = {}
        self._id_map: Dict[int, str] = {}
        self._counter = 0
        self._lock = threading.Lock()

    def register(self, instance: T) -> str:
        """Register ``instance`` and return its id.

        Registering the same object again returns its existing id.
        """
        with self._lock:
            obj_id = id(instance)
            if obj_id in self._id_map:
                return self._id_map[obj_id]
            instance_id = f"{self._type_prefix}_{self._counter}"
            self._counter += 1
            self._registry[instance_id] = instance
            self._id_map[obj_id] = instance_id
            return instance_id

    def unregister_sync(self, instance_id: str) -> None:
        """Drop ``instance_id`` and its reverse mapping; unknown ids are ignored."""
        with self._lock:
            instance = self._registry.pop(instance_id, None)
            # Compare against None rather than testing truthiness. A falsy
            # object (e.g. an empty container) would otherwise leave a stale
            # id() entry that a later object reusing that id() would hit. An
            # object with an ambiguous truth value (e.g. a multi-element
            # tensor) would raise here.
            if instance is not None:
                self._id_map.pop(id(instance), None)

    def _get_instance(self, instance_id: str) -> T:
        """Return the object registered under ``instance_id``.

        Raises:
            RuntimeError: if called in a pyisolate child process.
            ValueError: if ``instance_id`` is not registered.
        """
        if IS_CHILD_PROCESS:
            raise RuntimeError(
                f"[{self.__class__.__name__}] _get_instance called in child"
            )
        with self._lock:
            instance = self._registry.get(instance_id)
            if instance is None:
                raise ValueError(f"{instance_id} not found")
            return instance
|
|
|
|
|
|
# Loop recorded by set_global_loop(). While it is running, BaseProxy._call_rpc
# called from a thread with no running loop schedules RPC coroutines onto it.
_GLOBAL_LOOP: Optional[asyncio.AbstractEventLoop] = None
|
|
|
|
|
|
def set_global_loop(loop: asyncio.AbstractEventLoop) -> None:
    """Record ``loop`` as the module-wide RPC dispatch loop.

    While this loop is running, ``BaseProxy._call_rpc`` invoked from a thread
    without its own running loop hands RPC coroutines to it.
    """
    global _GLOBAL_LOOP
    _GLOBAL_LOOP = loop
|
|
|
|
|
|
class BaseProxy(Generic[T]):
    """Picklable handle that forwards method calls to a registered object over RPC.

    Only ``_instance_id`` is pickled (see ``__getstate__``). In the child
    process ``_get_rpc`` lazily builds an RPC caller for ``_registry_class``.
    ``_call_rpc`` then invokes registry methods with the instance id as the
    first argument. On the host, ``manage_lifecycle=True`` unregisters the
    instance when the proxy is garbage-collected, unless ``cleanup()`` has
    already done so.
    """

    _registry_class: type = BaseRegistry  # type: ignore[type-arg]
    # NOTE(review): pinned module path, presumably so the class resolves to a
    # stable import location (e.g. for pickling) - confirm before changing.
    __module__: str = "comfy.isolation.proxies.base"

    # Methods whose RPC calls get a client-side timeout (see
    # _rpc_timeout_ms_for_method). Every other method is awaited without one.
    _TIMEOUT_RPC_METHODS = frozenset(
        {
            "partially_load",
            "partially_unload",
            "load",
            "patch_model",
            "unpatch_model",
            "inner_model_apply_model",
            "memory_required",
            "model_dtype",
            "inner_model_memory_required",
            "inner_model_extra_conds_shapes",
            "inner_model_extra_conds",
            "process_latent_in",
            "process_latent_out",
            "scale_latent_inpaint",
        }
    )

    def __init__(
        self,
        instance_id: str,
        registry: Optional[Any] = None,
        manage_lifecycle: bool = False,
    ) -> None:
        """Create a proxy for ``instance_id``.

        Args:
            instance_id: Registry key of the proxied object.
            registry: Registry used for unregistration. Defaults to
                ``self._registry_class()``.
            manage_lifecycle: On the host only, unregister the instance
                automatically when this proxy is garbage-collected.
        """
        self._instance_id = instance_id
        self._rpc_caller: Optional[Any] = None
        self._registry = registry if registry is not None else self._registry_class()
        self._manage_lifecycle = manage_lifecycle
        self._cleaned_up = False
        if manage_lifecycle and not IS_CHILD_PROCESS:
            # The finalizer holds the registry's bound method, not ``self``,
            # so it does not keep the proxy alive.
            self._finalizer = weakref.finalize(
                self, self._registry.unregister_sync, instance_id
            )

    def _get_rpc(self) -> Any:
        """Return the cached RPC caller, creating it on first use.

        Raises:
            RuntimeError: if pyisolate reports no child RPC instance.
        """
        if self._rpc_caller is None:
            from pyisolate._internal.rpc_protocol import get_child_rpc_instance

            rpc = get_child_rpc_instance()
            if rpc is None:
                raise RuntimeError(f"[{self.__class__.__name__}] No RPC in child")
            self._rpc_caller = rpc.create_caller(
                self._registry_class, self._registry_class.get_remote_id()
            )
        return self._rpc_caller

    def _rpc_timeout_ms_for_method(self, method_name: str) -> Optional[int]:
        """Return the timeout in ms for ``method_name``, or None for no timeout.

        Only methods in ``_TIMEOUT_RPC_METHODS`` are bounded. The value comes
        from ``COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS``; the default is 120000,
        and an unparsable value falls back to that default. The result is
        clamped to at least 1.
        """
        if method_name not in self._TIMEOUT_RPC_METHODS:
            return None
        try:
            timeout_ms = int(
                os.environ.get("COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS", "120000")
            )
        except ValueError:
            timeout_ms = 120000
        return max(1, timeout_ms)

    @staticmethod
    def _dispatch_coro(
        coro: Any, in_running_loop: bool, timeout_ms: Optional[int]
    ) -> Any:
        """Drive ``coro`` to completion on a suitable loop and return its result.

        - No running loop on this thread and the global loop is running:
          schedule onto the global loop and block on the result.
        - A loop is already running on this thread: run in a fresh loop on
          a helper thread, because ``run_until_complete`` cannot be called
          here.
        - Otherwise: run on this thread's cached loop.

        Exceptions from ``coro`` propagate unchanged.
        """
        # Read the global once so the is_running() check and the use agree.
        global_loop = _GLOBAL_LOOP
        if (
            not in_running_loop
            and global_loop is not None
            and global_loop.is_running()
        ):
            future = asyncio.run_coroutine_threadsafe(coro, global_loop)
            try:
                return future.result(
                    timeout=(timeout_ms / 1000.0) if timeout_ms is not None else None
                )
            except concurrent.futures.TimeoutError:
                # Timing out in result() does not stop the scheduled
                # coroutine; cancel it so it does not linger on the global loop.
                future.cancel()
                raise
        if in_running_loop:
            # NOTE(review): the current loop stays blocked until the helper
            # thread finishes. If the RPC transport depends on this loop,
            # the call may stall until the timeout - verify.
            return run_coro_in_new_loop(coro)
        return get_thread_loop().run_until_complete(coro)

    def _call_rpc(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
        """Invoke ``method_name`` on the remote registry for this instance.

        The instance id is passed as the first argument, followed by
        ``args``/``kwargs``. Start and end telemetry is logged at DEBUG.

        Raises:
            TimeoutError: if the call times out (wrapping the asyncio or
                concurrent.futures timeout).
        """
        rpc = self._get_rpc()
        method = getattr(rpc, method_name)
        timeout_ms = self._rpc_timeout_ms_for_method(method_name)
        coro = method(self._instance_id, *args, **kwargs)
        if timeout_ms is not None:
            coro = asyncio.wait_for(coro, timeout=timeout_ms / 1000.0)

        start_epoch = time.time()
        start_perf = time.perf_counter()
        thread_id = threading.get_ident()
        try:
            running_loop: Optional[asyncio.AbstractEventLoop] = (
                asyncio.get_running_loop()
            )
        except RuntimeError:
            running_loop = None
        loop_id: Optional[int] = id(running_loop) if running_loop is not None else None
        logger.debug(
            "ISO:rpc_start proxy=%s method=%s instance_id=%s start_ts=%.6f "
            "thread=%s loop=%s timeout_ms=%s",
            self.__class__.__name__,
            method_name,
            self._instance_id,
            start_epoch,
            thread_id,
            loop_id,
            timeout_ms,
        )

        try:
            # Loop detection happens above, outside this try. A RuntimeError
            # raised by the RPC itself therefore propagates instead of
            # triggering a second attempt with the already-consumed coroutine.
            return self._dispatch_coro(coro, running_loop is not None, timeout_ms)
        except (asyncio.TimeoutError, concurrent.futures.TimeoutError) as exc:
            raise TimeoutError(
                f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} "
                f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})"
            ) from exc
        finally:
            end_epoch = time.time()
            elapsed_ms = (time.perf_counter() - start_perf) * 1000.0
            logger.debug(
                "ISO:rpc_end proxy=%s method=%s instance_id=%s end_ts=%.6f "
                "elapsed_ms=%.3f thread=%s loop=%s",
                self.__class__.__name__,
                method_name,
                self._instance_id,
                end_epoch,
                elapsed_ms,
                thread_id,
                loop_id,
            )

    def __getstate__(self) -> Dict[str, Any]:
        """Pickle only the instance id; everything else is rebuilt on unpickle."""
        return {"_instance_id": self._instance_id}

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """Restore from pickled state; a restored proxy never manages lifecycle."""
        self._instance_id = state["_instance_id"]
        self._rpc_caller = None
        self._registry = self._registry_class()
        self._manage_lifecycle = False
        self._cleaned_up = False

    def cleanup(self) -> None:
        """Unregister the instance once (host only) and disarm the finalizer."""
        if self._cleaned_up or IS_CHILD_PROCESS:
            return
        self._cleaned_up = True
        finalizer = getattr(self, "_finalizer", None)
        if finalizer is not None:
            finalizer.detach()
        self._registry.unregister_sync(self._instance_id)

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} {self._instance_id}>"
|
|
|
|
|
|
def create_rpc_method(method_name: str) -> Callable[..., Any]:
    """Build a function that forwards its call to ``self._call_rpc``.

    The returned function takes a proxy as ``self``. Calling it with
    ``(*args, **kwargs)`` performs
    ``self._call_rpc(method_name, *args, **kwargs)`` and returns that result.
    """

    def method(self: BaseProxy[Any], *args: Any, **kwargs: Any) -> Any:
        return self._call_rpc(method_name, *args, **kwargs)

    # Expose the forwarded name in tracebacks, reprs and introspection.
    # Otherwise __qualname__ would read "create_rpc_method.<locals>.method".
    method.__name__ = method_name
    method.__qualname__ = method_name
    method.__doc__ = f"Forward ``{method_name}`` to the remote registry via RPC."
    return method
|