ComfyUI/comfy/isolation/proxies/base.py
John Pollock 26edd5663d
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
isolation+dynamicvram: stabilize ModelPatcher RPC path, add diagnostics; known process_latent_in timeout remains
- harden isolation ModelPatcher proxy/registry behavior for DynamicVRAM-backed patchers
- improve serializer/adapter boundaries (device/dtype/model refs) to reduce pre-inference lockups
- add structured ISO registry/modelsampling telemetry and explicit RPC timeout surfacing
- preserve isolation-first lifecycle handling and boundary cleanup sequencing
- validate isolated workflows: most targeted runs now complete under
  --use-sage-attention --use-process-isolation --disable-cuda-malloc

Known issue (reproducible):
- isolation_99_full_iso_stack still times out at SamplerCustom_ISO path
- failure is explicit RPC timeout:
  ModelPatcherProxy.process_latent_in(instance_id=model_0, timeout_ms=120000)
- this indicates the remaining stall is on process_latent_in RPC path, not generic startup/manager fetch
2026-03-04 10:41:33 -06:00

284 lines
9.3 KiB
Python

# pylint: disable=global-statement,import-outside-toplevel,protected-access
from __future__ import annotations
import asyncio
import concurrent.futures
import logging
import os
import threading
import time
import weakref
from typing import Any, Callable, Dict, Generic, Optional, TypeVar
try:
from pyisolate import ProxiedSingleton
except ImportError:
class ProxiedSingleton: # type: ignore[no-redef]
pass
logger = logging.getLogger(__name__)
IS_CHILD_PROCESS = os.environ.get("PYISOLATE_CHILD") == "1"
_thread_local = threading.local()
T = TypeVar("T")
def get_thread_loop() -> asyncio.AbstractEventLoop:
loop = getattr(_thread_local, "loop", None)
if loop is None or loop.is_closed():
loop = asyncio.new_event_loop()
_thread_local.loop = loop
return loop
def run_coro_in_new_loop(coro: Any) -> Any:
result_box: Dict[str, Any] = {}
exc_box: Dict[str, BaseException] = {}
def runner() -> None:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
result_box["value"] = loop.run_until_complete(coro)
except Exception as exc: # noqa: BLE001
exc_box["exc"] = exc
finally:
loop.close()
t = threading.Thread(target=runner, daemon=True)
t.start()
t.join()
if "exc" in exc_box:
raise exc_box["exc"]
return result_box.get("value")
def detach_if_grad(obj: Any) -> Any:
try:
import torch
except Exception:
return obj
if isinstance(obj, torch.Tensor):
return obj.detach() if obj.requires_grad else obj
if isinstance(obj, (list, tuple)):
return type(obj)(detach_if_grad(x) for x in obj)
if isinstance(obj, dict):
return {k: detach_if_grad(v) for k, v in obj.items()}
return obj
class BaseRegistry(ProxiedSingleton, Generic[T]):
_type_prefix: str = "base"
def __init__(self) -> None:
if hasattr(ProxiedSingleton, "__init__") and ProxiedSingleton is not object:
super().__init__()
self._registry: Dict[str, T] = {}
self._id_map: Dict[int, str] = {}
self._counter = 0
self._lock = threading.Lock()
def register(self, instance: T) -> str:
with self._lock:
obj_id = id(instance)
if obj_id in self._id_map:
return self._id_map[obj_id]
instance_id = f"{self._type_prefix}_{self._counter}"
self._counter += 1
self._registry[instance_id] = instance
self._id_map[obj_id] = instance_id
return instance_id
def unregister_sync(self, instance_id: str) -> None:
with self._lock:
instance = self._registry.pop(instance_id, None)
if instance:
self._id_map.pop(id(instance), None)
def _get_instance(self, instance_id: str) -> T:
if IS_CHILD_PROCESS:
raise RuntimeError(
f"[{self.__class__.__name__}] _get_instance called in child"
)
with self._lock:
instance = self._registry.get(instance_id)
if instance is None:
raise ValueError(f"{instance_id} not found")
return instance
_GLOBAL_LOOP: Optional[asyncio.AbstractEventLoop] = None
def set_global_loop(loop: asyncio.AbstractEventLoop) -> None:
global _GLOBAL_LOOP
_GLOBAL_LOOP = loop
class BaseProxy(Generic[T]):
_registry_class: type = BaseRegistry # type: ignore[type-arg]
__module__: str = "comfy.isolation.proxies.base"
_TIMEOUT_RPC_METHODS = frozenset(
{
"partially_load",
"partially_unload",
"load",
"patch_model",
"unpatch_model",
"inner_model_apply_model",
"memory_required",
"model_dtype",
"inner_model_memory_required",
"inner_model_extra_conds_shapes",
"inner_model_extra_conds",
"process_latent_in",
"process_latent_out",
"scale_latent_inpaint",
}
)
def __init__(
self,
instance_id: str,
registry: Optional[Any] = None,
manage_lifecycle: bool = False,
) -> None:
self._instance_id = instance_id
self._rpc_caller: Optional[Any] = None
self._registry = registry if registry is not None else self._registry_class()
self._manage_lifecycle = manage_lifecycle
self._cleaned_up = False
if manage_lifecycle and not IS_CHILD_PROCESS:
self._finalizer = weakref.finalize(
self, self._registry.unregister_sync, instance_id
)
def _get_rpc(self) -> Any:
if self._rpc_caller is None:
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
rpc = get_child_rpc_instance()
if rpc is None:
raise RuntimeError(f"[{self.__class__.__name__}] No RPC in child")
self._rpc_caller = rpc.create_caller(
self._registry_class, self._registry_class.get_remote_id()
)
return self._rpc_caller
def _rpc_timeout_ms_for_method(self, method_name: str) -> Optional[int]:
if method_name not in self._TIMEOUT_RPC_METHODS:
return None
try:
timeout_ms = int(
os.environ.get("COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS", "120000")
)
except ValueError:
timeout_ms = 120000
return max(1, timeout_ms)
def _call_rpc(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
rpc = self._get_rpc()
method = getattr(rpc, method_name)
timeout_ms = self._rpc_timeout_ms_for_method(method_name)
coro = method(self._instance_id, *args, **kwargs)
if timeout_ms is not None:
coro = asyncio.wait_for(coro, timeout=timeout_ms / 1000.0)
start_epoch = time.time()
start_perf = time.perf_counter()
thread_id = threading.get_ident()
try:
running_loop = asyncio.get_running_loop()
loop_id: Optional[int] = id(running_loop)
except RuntimeError:
loop_id = None
logger.debug(
"ISO:rpc_start proxy=%s method=%s instance_id=%s start_ts=%.6f "
"thread=%s loop=%s timeout_ms=%s",
self.__class__.__name__,
method_name,
self._instance_id,
start_epoch,
thread_id,
loop_id,
timeout_ms,
)
try:
# If we have a global loop (Main Thread Loop), use it for dispatch from worker threads
if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running():
try:
curr_loop = asyncio.get_running_loop()
if curr_loop is _GLOBAL_LOOP:
pass
except RuntimeError:
# No running loop - we are in a worker thread.
future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP)
return future.result(
timeout=(timeout_ms / 1000.0) if timeout_ms is not None else None
)
try:
asyncio.get_running_loop()
return run_coro_in_new_loop(coro)
except RuntimeError:
loop = get_thread_loop()
return loop.run_until_complete(coro)
except asyncio.TimeoutError as exc:
raise TimeoutError(
f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} "
f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})"
) from exc
except concurrent.futures.TimeoutError as exc:
raise TimeoutError(
f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} "
f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})"
) from exc
finally:
end_epoch = time.time()
elapsed_ms = (time.perf_counter() - start_perf) * 1000.0
logger.debug(
"ISO:rpc_end proxy=%s method=%s instance_id=%s end_ts=%.6f "
"elapsed_ms=%.3f thread=%s loop=%s",
self.__class__.__name__,
method_name,
self._instance_id,
end_epoch,
elapsed_ms,
thread_id,
loop_id,
)
def __getstate__(self) -> Dict[str, Any]:
return {"_instance_id": self._instance_id}
def __setstate__(self, state: Dict[str, Any]) -> None:
self._instance_id = state["_instance_id"]
self._rpc_caller = None
self._registry = self._registry_class()
self._manage_lifecycle = False
self._cleaned_up = False
def cleanup(self) -> None:
if self._cleaned_up or IS_CHILD_PROCESS:
return
self._cleaned_up = True
finalizer = getattr(self, "_finalizer", None)
if finalizer is not None:
finalizer.detach()
self._registry.unregister_sync(self._instance_id)
def __repr__(self) -> str:
return f"<{self.__class__.__name__} {self._instance_id}>"
def create_rpc_method(method_name: str) -> Callable[..., Any]:
def method(self: BaseProxy[Any], *args: Any, **kwargs: Any) -> Any:
return self._call_rpc(method_name, *args, **kwargs)
method.__name__ = method_name
return method