mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-07 10:17:31 +08:00
isolation+dynamicvram: stabilize ModelPatcher RPC path, add diagnostics; known process_latent_in timeout remains
- harden isolation ModelPatcher proxy/registry behavior for DynamicVRAM-backed patchers
- improve serializer/adapter boundaries (device/dtype/model refs) to reduce pre-inference lockups
- add structured ISO registry/modelsampling telemetry and explicit RPC timeout surfacing
- preserve isolation-first lifecycle handling and boundary cleanup sequencing
- validate isolated workflows: most targeted runs now complete under --use-sage-attention --use-process-isolation --disable-cuda-malloc

Known issue (reproducible):
- isolation_99_full_iso_stack still times out at SamplerCustom_ISO path
- failure is explicit RPC timeout: ModelPatcherProxy.process_latent_in(instance_id=model_0, timeout_ms=120000)
- this indicates the remaining stall is on process_latent_in RPC path, not generic startup/manager fetch
This commit is contained in:
parent
6ba03a0747
commit
26edd5663d
@ -26,6 +26,7 @@ PYISOLATE_VENV_ROOT.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
|
||||
_MODEL_PATCHER_IDLE_TIMEOUT_MS = 120000
|
||||
|
||||
|
||||
def initialize_proxies() -> None:
|
||||
@ -168,6 +169,11 @@ def _get_class_types_for_extension(extension_name: str) -> Set[str]:
|
||||
|
||||
async def notify_execution_graph(needed_class_types: Set[str]) -> None:
|
||||
"""Evict running extensions not needed for current execution."""
|
||||
await wait_for_model_patcher_quiescence(
|
||||
timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS,
|
||||
fail_loud=True,
|
||||
marker="ISO:notify_graph_wait_idle",
|
||||
)
|
||||
|
||||
async def _stop_extension(
|
||||
ext_name: str, extension: "ComfyNodeExtension", reason: str
|
||||
@ -237,6 +243,11 @@ async def notify_execution_graph(needed_class_types: Set[str]) -> None:
|
||||
|
||||
|
||||
async def flush_running_extensions_transport_state() -> int:
|
||||
await wait_for_model_patcher_quiescence(
|
||||
timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS,
|
||||
fail_loud=True,
|
||||
marker="ISO:flush_transport_wait_idle",
|
||||
)
|
||||
total_flushed = 0
|
||||
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
|
||||
flush_fn = getattr(extension, "flush_transport_state", None)
|
||||
@ -263,6 +274,50 @@ async def flush_running_extensions_transport_state() -> int:
|
||||
return total_flushed
|
||||
|
||||
|
||||
async def wait_for_model_patcher_quiescence(
    timeout_ms: int = _MODEL_PATCHER_IDLE_TIMEOUT_MS,
    *,
    fail_loud: bool = False,
    marker: str = "ISO:wait_model_patcher_idle",
) -> bool:
    """Wait until the ModelPatcherRegistry reports no active operations.

    Args:
        timeout_ms: Budget forwarded to ``ModelPatcherRegistry.wait_all_idle``.
        fail_loud: When true, a timeout raises ``TimeoutError`` and any other
            failure is re-raised instead of being logged and swallowed.
        marker: Tag included in every log line emitted by this call.

    Returns:
        ``True`` when the registry went idle within the budget; ``False`` on
        timeout or on any failure while ``fail_loud`` is false.

    Raises:
        TimeoutError: The registry did not go idle and ``fail_loud`` is true.
        Exception: Any other failure, only when ``fail_loud`` is true.
    """
    try:
        # Imported lazily inside the guarded region so that an import failure
        # is handled by the same best-effort/fail-loud policy as everything else.
        from comfy.isolation.model_patcher_proxy_registry import ModelPatcherRegistry

        patcher_registry = ModelPatcherRegistry()
        began = time.perf_counter()
        became_idle = await patcher_registry.wait_all_idle(timeout_ms)
        elapsed_ms = (time.perf_counter() - began) * 1000.0

        if not became_idle:
            # Capture per-instance state so the log shows what is still running.
            states = await patcher_registry.get_all_operation_states()
            logger.error(
                "%s %s idle_timeout timeout_ms=%d elapsed_ms=%.3f states=%s",
                LOG_PREFIX,
                marker,
                timeout_ms,
                elapsed_ms,
                states,
            )
            if fail_loud:
                raise TimeoutError(
                    f"ModelPatcherRegistry did not quiesce within {timeout_ms} ms"
                )
            return False

        logger.debug(
            "%s %s idle=1 timeout_ms=%d elapsed_ms=%.3f",
            LOG_PREFIX,
            marker,
            timeout_ms,
            elapsed_ms,
        )
        return True
    except Exception:
        if not fail_loud:
            logger.debug("%s %s failed", LOG_PREFIX, marker, exc_info=True)
            return False
        raise
|
||||
|
||||
|
||||
def get_claimed_paths() -> Set[Path]:
|
||||
return _CLAIMED_PATHS
|
||||
|
||||
@ -320,6 +375,7 @@ __all__ = [
|
||||
"await_isolation_loading",
|
||||
"notify_execution_graph",
|
||||
"flush_running_extensions_transport_state",
|
||||
"wait_for_model_patcher_quiescence",
|
||||
"get_claimed_paths",
|
||||
"update_rpc_event_loops",
|
||||
"IsolatedNodeSpec",
|
||||
|
||||
@ -83,6 +83,33 @@ class ComfyUIAdapter(IsolationAdapter):
|
||||
logging.getLogger(pkg_name).setLevel(logging.ERROR)
|
||||
|
||||
def register_serializers(self, registry: SerializerRegistryProtocol) -> None:
|
||||
import torch
|
||||
|
||||
def serialize_device(obj: Any) -> Dict[str, Any]:
|
||||
return {"__type__": "device", "device_str": str(obj)}
|
||||
|
||||
def deserialize_device(data: Dict[str, Any]) -> Any:
|
||||
return torch.device(data["device_str"])
|
||||
|
||||
registry.register("device", serialize_device, deserialize_device)
|
||||
|
||||
_VALID_DTYPES = {
|
||||
"float16", "float32", "float64", "bfloat16",
|
||||
"int8", "int16", "int32", "int64",
|
||||
"uint8", "bool",
|
||||
}
|
||||
|
||||
def serialize_dtype(obj: Any) -> Dict[str, Any]:
|
||||
return {"__type__": "dtype", "dtype_str": str(obj)}
|
||||
|
||||
def deserialize_dtype(data: Dict[str, Any]) -> Any:
|
||||
dtype_name = data["dtype_str"].replace("torch.", "")
|
||||
if dtype_name not in _VALID_DTYPES:
|
||||
raise ValueError(f"Invalid dtype: {data['dtype_str']}")
|
||||
return getattr(torch, dtype_name)
|
||||
|
||||
registry.register("dtype", serialize_dtype, deserialize_dtype)
|
||||
|
||||
def serialize_model_patcher(obj: Any) -> Dict[str, Any]:
|
||||
# Child-side: must already have _instance_id (proxy)
|
||||
if os.environ.get("PYISOLATE_CHILD") == "1":
|
||||
@ -193,6 +220,10 @@ class ComfyUIAdapter(IsolationAdapter):
|
||||
f"ModelSampling in child lacks _instance_id: "
|
||||
f"{type(obj).__module__}.{type(obj).__name__}"
|
||||
)
|
||||
# Host-side pass-through for proxies: do not re-register a proxy as a
|
||||
# new ModelSamplingRef, or we create proxy-of-proxy indirection.
|
||||
if hasattr(obj, "_instance_id"):
|
||||
return {"__type__": "ModelSamplingRef", "ms_id": obj._instance_id}
|
||||
# Host-side: register with ModelSamplingRegistry and return JSON-safe dict
|
||||
ms_id = ModelSamplingRegistry().register(obj)
|
||||
return {"__type__": "ModelSamplingRef", "ms_id": ms_id}
|
||||
@ -211,22 +242,21 @@ class ComfyUIAdapter(IsolationAdapter):
|
||||
else:
|
||||
return ModelSamplingRegistry()._get_instance(data["ms_id"])
|
||||
|
||||
# Register ModelSampling type and proxy
|
||||
registry.register(
|
||||
"ModelSamplingDiscrete",
|
||||
serialize_model_sampling,
|
||||
deserialize_model_sampling,
|
||||
)
|
||||
registry.register(
|
||||
"ModelSamplingContinuousEDM",
|
||||
serialize_model_sampling,
|
||||
deserialize_model_sampling,
|
||||
)
|
||||
registry.register(
|
||||
"ModelSamplingContinuousV",
|
||||
serialize_model_sampling,
|
||||
deserialize_model_sampling,
|
||||
)
|
||||
# Register all ModelSampling* and StableCascadeSampling classes dynamically
|
||||
import comfy.model_sampling
|
||||
|
||||
for ms_cls in vars(comfy.model_sampling).values():
|
||||
if not isinstance(ms_cls, type):
|
||||
continue
|
||||
if not issubclass(ms_cls, torch.nn.Module):
|
||||
continue
|
||||
if not (ms_cls.__name__.startswith("ModelSampling") or ms_cls.__name__ == "StableCascadeSampling"):
|
||||
continue
|
||||
registry.register(
|
||||
ms_cls.__name__,
|
||||
serialize_model_sampling,
|
||||
deserialize_model_sampling,
|
||||
)
|
||||
registry.register(
|
||||
"ModelSamplingProxy", serialize_model_sampling, deserialize_model_sampling
|
||||
)
|
||||
|
||||
@ -382,6 +382,10 @@ class ComfyNodeExtension(ExtensionBase):
|
||||
|
||||
if type(result).__name__ == "NodeOutput":
|
||||
result = result.args
|
||||
print(
|
||||
f"{LOG_PREFIX} ISO:child_result_ready node={node_name} type={type(result).__name__}",
|
||||
flush=True,
|
||||
)
|
||||
if self._is_comfy_protocol_return(result):
|
||||
logger.debug(
|
||||
"%s ISO:child_execute_done ext=%s node=%s protocol_return=1",
|
||||
@ -389,10 +393,17 @@ class ComfyNodeExtension(ExtensionBase):
|
||||
getattr(self, "name", "?"),
|
||||
node_name,
|
||||
)
|
||||
return self._wrap_unpicklable_objects(result)
|
||||
print(f"{LOG_PREFIX} ISO:child_wrap_start node={node_name} protocol=1", flush=True)
|
||||
wrapped = self._wrap_unpicklable_objects(result)
|
||||
print(f"{LOG_PREFIX} ISO:child_wrap_done node={node_name} protocol=1", flush=True)
|
||||
return wrapped
|
||||
|
||||
if not isinstance(result, tuple):
|
||||
result = (result,)
|
||||
print(
|
||||
f"{LOG_PREFIX} ISO:child_result_tuple node={node_name} outputs={len(result)}",
|
||||
flush=True,
|
||||
)
|
||||
logger.debug(
|
||||
"%s ISO:child_execute_done ext=%s node=%s protocol_return=0 outputs=%d",
|
||||
LOG_PREFIX,
|
||||
@ -400,7 +411,10 @@ class ComfyNodeExtension(ExtensionBase):
|
||||
node_name,
|
||||
len(result),
|
||||
)
|
||||
return self._wrap_unpicklable_objects(result)
|
||||
print(f"{LOG_PREFIX} ISO:child_wrap_start node={node_name} protocol=0", flush=True)
|
||||
wrapped = self._wrap_unpicklable_objects(result)
|
||||
print(f"{LOG_PREFIX} ISO:child_wrap_done node={node_name} protocol=0", flush=True)
|
||||
return wrapped
|
||||
|
||||
async def flush_transport_state(self) -> int:
|
||||
if os.environ.get("PYISOLATE_ISOLATION_ACTIVE") != "1":
|
||||
@ -443,7 +457,10 @@ class ComfyNodeExtension(ExtensionBase):
|
||||
if isinstance(data, (str, int, float, bool, type(None))):
|
||||
return data
|
||||
if isinstance(data, torch.Tensor):
|
||||
return data.detach() if data.requires_grad else data
|
||||
tensor = data.detach() if data.requires_grad else data
|
||||
if os.environ.get("PYISOLATE_CHILD") == "1" and tensor.device.type != "cpu":
|
||||
return tensor.cpu()
|
||||
return tensor
|
||||
|
||||
# Special-case clip vision outputs: preserve attribute access by packing fields
|
||||
if hasattr(data, "penultimate_hidden_states") or hasattr(
|
||||
|
||||
@ -338,6 +338,34 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
|
||||
def apply_model(self, *args, **kwargs) -> Any:
|
||||
import torch
|
||||
|
||||
def _preferred_device() -> Any:
|
||||
for value in args:
|
||||
if isinstance(value, torch.Tensor):
|
||||
return value.device
|
||||
for value in kwargs.values():
|
||||
if isinstance(value, torch.Tensor):
|
||||
return value.device
|
||||
return None
|
||||
|
||||
def _move_result_to_device(obj: Any, device: Any) -> Any:
|
||||
if device is None:
|
||||
return obj
|
||||
if isinstance(obj, torch.Tensor):
|
||||
return obj.to(device) if obj.device != device else obj
|
||||
if isinstance(obj, dict):
|
||||
return {k: _move_result_to_device(v, device) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_move_result_to_device(v, device) for v in obj]
|
||||
if isinstance(obj, tuple):
|
||||
return tuple(_move_result_to_device(v, device) for v in obj)
|
||||
return obj
|
||||
|
||||
# DynamicVRAM models must keep load/offload decisions in host process.
|
||||
# Child-side CUDA staging here can deadlock before first inference RPC.
|
||||
if self.is_dynamic():
|
||||
out = self._call_rpc("inner_model_apply_model", args, kwargs)
|
||||
return _move_result_to_device(out, _preferred_device())
|
||||
|
||||
required_bytes = self._cpu_tensor_bytes(args) + self._cpu_tensor_bytes(kwargs)
|
||||
self._ensure_apply_model_headroom(required_bytes)
|
||||
|
||||
@ -360,7 +388,8 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
|
||||
args_cuda = _to_cuda(args)
|
||||
kwargs_cuda = _to_cuda(kwargs)
|
||||
|
||||
return self._call_rpc("inner_model_apply_model", args_cuda, kwargs_cuda)
|
||||
out = self._call_rpc("inner_model_apply_model", args_cuda, kwargs_cuda)
|
||||
return _move_result_to_device(out, _preferred_device())
|
||||
|
||||
def model_state_dict(self, filter_prefix: Optional[str] = None) -> Any:
|
||||
keys = self._call_rpc("model_state_dict", filter_prefix)
|
||||
@ -526,6 +555,13 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
|
||||
def memory_required(self, input_shape: Any) -> Any:
    """Forward ``memory_required`` to the registry over RPC and return its reply."""
    reply = self._call_rpc("memory_required", input_shape)
    return reply
|
||||
|
||||
def get_operation_state(self) -> Dict[str, Any]:
    """Fetch this instance's operation-state snapshot over RPC.

    Returns the RPC reply when it is a ``dict``; any other reply type is
    replaced by an empty dict.
    """
    reply = self._call_rpc("get_operation_state")
    if isinstance(reply, dict):
        return reply
    return {}
|
||||
|
||||
def wait_for_idle(self, timeout_ms: int = 0) -> bool:
    """Issue the ``wait_for_idle`` RPC and coerce its reply to ``bool``."""
    reply = self._call_rpc("wait_for_idle", timeout_ms)
    return bool(reply)
|
||||
|
||||
def is_dynamic(self) -> bool:
|
||||
return bool(self._call_rpc("is_dynamic"))
|
||||
|
||||
@ -771,6 +807,7 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
|
||||
class _InnerModelProxy:
|
||||
def __init__(self, parent: ModelPatcherProxy):
|
||||
self._parent = parent
|
||||
self._model_sampling = None
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
if name.startswith("_"):
|
||||
@ -793,7 +830,11 @@ class _InnerModelProxy:
|
||||
manage_lifecycle=False,
|
||||
)
|
||||
if name == "model_sampling":
|
||||
return self._parent._call_rpc("get_model_object", "model_sampling")
|
||||
if self._model_sampling is None:
|
||||
self._model_sampling = self._parent._call_rpc(
|
||||
"get_model_object", "model_sampling"
|
||||
)
|
||||
return self._model_sampling
|
||||
if name == "extra_conds_shapes":
|
||||
return lambda *a, **k: self._parent._call_rpc(
|
||||
"inner_model_extra_conds_shapes", a, k
|
||||
|
||||
@ -2,8 +2,12 @@
|
||||
# RPC server for ModelPatcher isolation (child process)
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional, List
|
||||
|
||||
try:
|
||||
@ -43,12 +47,191 @@ from comfy.isolation.proxies.base import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _OperationState:
|
||||
lease: threading.Lock = field(default_factory=threading.Lock)
|
||||
active_count: int = 0
|
||||
active_by_method: dict[str, int] = field(default_factory=dict)
|
||||
total_operations: int = 0
|
||||
last_method: Optional[str] = None
|
||||
last_started_ts: Optional[float] = None
|
||||
last_ended_ts: Optional[float] = None
|
||||
last_elapsed_ms: Optional[float] = None
|
||||
last_error: Optional[str] = None
|
||||
last_thread_id: Optional[int] = None
|
||||
last_loop_id: Optional[int] = None
|
||||
|
||||
|
||||
class ModelPatcherRegistry(BaseRegistry[Any]):
|
||||
_type_prefix = "model"
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._pending_cleanup_ids: set[str] = set()
|
||||
self._operation_states: dict[str, _OperationState] = {}
|
||||
self._operation_state_cv = threading.Condition(self._lock)
|
||||
|
||||
def _get_or_create_operation_state(self, instance_id: str) -> _OperationState:
|
||||
state = self._operation_states.get(instance_id)
|
||||
if state is None:
|
||||
state = _OperationState()
|
||||
self._operation_states[instance_id] = state
|
||||
return state
|
||||
|
||||
def _begin_operation(self, instance_id: str, method_name: str) -> tuple[float, float]:
    """Record the start of ``method_name`` on ``instance_id``.

    Increments the in-flight counters, stamps the ``last_*`` diagnostics and
    emits an ``ISO:registry_op_start`` debug line.

    Returns:
        ``(start_epoch, start_perf)``: the ``time.time()`` and
        ``time.perf_counter()`` readings taken on entry.
    """
    start_epoch = time.time()
    start_perf = time.perf_counter()
    thread_id = threading.get_ident()

    # get_running_loop() raises RuntimeError when no loop is running in this
    # thread; that case is recorded as None.
    try:
        loop_id = id(asyncio.get_running_loop())
    except RuntimeError:
        loop_id = None

    with self._operation_state_cv:
        state = self._get_or_create_operation_state(instance_id)
        state.active_count += 1
        per_method = state.active_by_method
        per_method[method_name] = per_method.get(method_name, 0) + 1
        state.total_operations += 1
        state.last_method = method_name
        state.last_started_ts = start_epoch
        state.last_thread_id = thread_id
        state.last_loop_id = loop_id

    logger.debug(
        "ISO:registry_op_start instance_id=%s method=%s start_ts=%.6f thread=%s loop=%s",
        instance_id,
        method_name,
        start_epoch,
        thread_id,
        loop_id,
    )
    return start_epoch, start_perf
|
||||
|
||||
def _end_operation(
    self,
    instance_id: str,
    method_name: str,
    start_perf: float,
    error: Optional[BaseException] = None,
) -> None:
    """Record the end of ``method_name`` on ``instance_id``.

    Decrements the in-flight counters (never below zero), stores the elapsed
    time and ``repr(error)``, wakes condition waiters once the instance has no
    active operations, and emits an ``ISO:registry_op_end`` debug line.

    Args:
        instance_id: Registry id of the instance the operation ran on.
        method_name: Name under which the operation was started.
        start_perf: ``time.perf_counter()`` reading taken when it started.
        error: Exception that terminated the operation, if any.
    """
    end_epoch = time.time()
    elapsed_ms = (time.perf_counter() - start_perf) * 1000.0

    with self._operation_state_cv:
        state = self._get_or_create_operation_state(instance_id)
        state.active_count = max(0, state.active_count - 1)

        per_method = state.active_by_method
        if method_name in per_method:
            if per_method[method_name] > 1:
                per_method[method_name] -= 1
            else:
                # Last in-flight call for this method: drop the key entirely.
                del per_method[method_name]

        state.last_ended_ts = end_epoch
        state.last_elapsed_ms = elapsed_ms
        state.last_error = repr(error) if error is not None else None

        if state.active_count == 0:
            self._operation_state_cv.notify_all()

    logger.debug(
        "ISO:registry_op_end instance_id=%s method=%s end_ts=%.6f elapsed_ms=%.3f error=%s",
        instance_id,
        method_name,
        end_epoch,
        elapsed_ms,
        type(error).__name__ if error is not None else None,
    )
|
||||
|
||||
def _run_operation_with_lease(self, instance_id: str, method_name: str, fn):
|
||||
with self._operation_state_cv:
|
||||
state = self._get_or_create_operation_state(instance_id)
|
||||
lease = state.lease
|
||||
with lease:
|
||||
_, start_perf = self._begin_operation(instance_id, method_name)
|
||||
try:
|
||||
result = fn()
|
||||
except Exception as exc:
|
||||
self._end_operation(instance_id, method_name, start_perf, error=exc)
|
||||
raise
|
||||
self._end_operation(instance_id, method_name, start_perf)
|
||||
return result
|
||||
|
||||
def _snapshot_operation_state(self, instance_id: str) -> dict[str, Any]:
|
||||
with self._operation_state_cv:
|
||||
state = self._operation_states.get(instance_id)
|
||||
if state is None:
|
||||
return {
|
||||
"instance_id": instance_id,
|
||||
"active_count": 0,
|
||||
"active_methods": [],
|
||||
"total_operations": 0,
|
||||
"last_method": None,
|
||||
"last_started_ts": None,
|
||||
"last_ended_ts": None,
|
||||
"last_elapsed_ms": None,
|
||||
"last_error": None,
|
||||
"last_thread_id": None,
|
||||
"last_loop_id": None,
|
||||
}
|
||||
return {
|
||||
"instance_id": instance_id,
|
||||
"active_count": state.active_count,
|
||||
"active_methods": sorted(state.active_by_method.keys()),
|
||||
"total_operations": state.total_operations,
|
||||
"last_method": state.last_method,
|
||||
"last_started_ts": state.last_started_ts,
|
||||
"last_ended_ts": state.last_ended_ts,
|
||||
"last_elapsed_ms": state.last_elapsed_ms,
|
||||
"last_error": state.last_error,
|
||||
"last_thread_id": state.last_thread_id,
|
||||
"last_loop_id": state.last_loop_id,
|
||||
}
|
||||
|
||||
def unregister_sync(self, instance_id: str) -> None:
    """Remove ``instance_id`` and all bookkeeping for it from the registry.

    Unknown ids are tolerated. Condition waiters are always woken afterwards,
    since removing the state can turn a "busy" registry into an idle one.
    """
    with self._operation_state_cv:
        removed = self._registry.pop(instance_id, None)
        if removed is not None:
            # The reverse map is keyed by object identity.
            self._id_map.pop(id(removed), None)
        self._pending_cleanup_ids.discard(instance_id)
        self._operation_states.pop(instance_id, None)
        self._operation_state_cv.notify_all()
|
||||
|
||||
async def get_operation_state(self, instance_id: str) -> dict[str, Any]:
    """Return the operation-state snapshot for a single instance id."""
    snapshot = self._snapshot_operation_state(instance_id)
    return snapshot
|
||||
|
||||
async def get_all_operation_states(self) -> dict[str, dict[str, Any]]:
    """Return snapshots for every tracked instance, keyed by id in sorted order."""
    with self._operation_state_cv:
        tracked_ids = sorted(self._operation_states.keys())

    snapshots: dict[str, dict[str, Any]] = {}
    for tracked_id in tracked_ids:
        snapshots[tracked_id] = self._snapshot_operation_state(tracked_id)
    return snapshots
|
||||
|
||||
async def wait_for_idle(self, instance_id: str, timeout_ms: int = 0) -> bool:
    """Wait until ``instance_id`` has no active operations.

    The state is polled under the registry lock and the coroutine yields with
    ``asyncio.sleep`` between polls. Blocking in ``threading.Condition.wait``
    here would freeze the event loop for the whole timeout, so operations
    scheduled on that loop could never finish.

    Args:
        instance_id: Registry id to watch. An id with no recorded state is
            treated as idle.
        timeout_ms: Maximum wait in milliseconds; ``<= 0`` waits with no
            deadline.

    Returns:
        ``True`` once the instance is idle, ``False`` if the deadline passed.
    """
    poll_interval_s = 0.01
    deadline = None if timeout_ms <= 0 else time.monotonic() + timeout_ms / 1000.0

    while True:
        # Hold the lock only for the check itself, never across the sleep.
        with self._operation_state_cv:
            state = self._operation_states.get(instance_id)
            if state is None or state.active_count == 0:
                return True

        pause = poll_interval_s
        if deadline is not None:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return False
            pause = min(pause, remaining)
        await asyncio.sleep(pause)
|
||||
|
||||
async def wait_all_idle(self, timeout_ms: int = 0) -> bool:
    """Wait until no tracked instance has an active operation.

    The states are polled under the registry lock and the coroutine yields
    with ``asyncio.sleep`` between polls. Blocking in
    ``threading.Condition.wait`` here would freeze the event loop for the
    whole timeout, so operations scheduled on that loop could never finish.

    Args:
        timeout_ms: Maximum wait in milliseconds; ``<= 0`` waits with no
            deadline.

    Returns:
        ``True`` once every instance is idle, ``False`` if the deadline passed.
    """
    poll_interval_s = 0.01
    deadline = None if timeout_ms <= 0 else time.monotonic() + timeout_ms / 1000.0

    while True:
        # Hold the lock only for the check itself, never across the sleep.
        with self._operation_state_cv:
            busy = any(
                state.active_count > 0 for state in self._operation_states.values()
            )
        if not busy:
            return True

        pause = poll_interval_s
        if deadline is not None:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return False
            pause = min(pause, remaining)
        await asyncio.sleep(pause)
|
||||
|
||||
async def clone(self, instance_id: str) -> str:
|
||||
instance = self._get_instance(instance_id)
|
||||
@ -62,8 +245,10 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
|
||||
return False
|
||||
|
||||
async def get_model_object(self, instance_id: str, name: str) -> Any:
|
||||
print(f"GP_DEBUG: get_model_object START for name={name}", flush=True)
|
||||
instance = self._get_instance(instance_id)
|
||||
if name == "model":
|
||||
print(f"GP_DEBUG: get_model_object END for name={name} (ModelObject)", flush=True)
|
||||
return f"<ModelObject: {type(instance.model).__name__}>"
|
||||
result = instance.get_model_object(name)
|
||||
if name == "model_sampling":
|
||||
@ -73,8 +258,16 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
|
||||
)
|
||||
|
||||
registry = ModelSamplingRegistry()
|
||||
sampling_id = registry.register(result)
|
||||
# Preserve identity when upstream already returned a proxy. Re-registering
|
||||
# a proxy object creates proxy-of-proxy call chains.
|
||||
if isinstance(result, ModelSamplingProxy):
|
||||
sampling_id = result._instance_id
|
||||
else:
|
||||
sampling_id = registry.register(result)
|
||||
print(f"GP_DEBUG: get_model_object END for name={name} (model_sampling)", flush=True)
|
||||
return ModelSamplingProxy(sampling_id, registry)
|
||||
|
||||
print(f"GP_DEBUG: get_model_object END for name={name} (fallthrough)", flush=True)
|
||||
return detach_if_grad(result)
|
||||
|
||||
async def get_model_options(self, instance_id: str) -> dict:
|
||||
@ -163,7 +356,11 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
|
||||
return self._get_instance(instance_id).lowvram_patch_counter()
|
||||
|
||||
async def memory_required(self, instance_id: str, input_shape: Any) -> Any:
|
||||
return self._get_instance(instance_id).memory_required(input_shape)
|
||||
return self._run_operation_with_lease(
|
||||
instance_id,
|
||||
"memory_required",
|
||||
lambda: self._get_instance(instance_id).memory_required(input_shape),
|
||||
)
|
||||
|
||||
async def is_dynamic(self, instance_id: str) -> bool:
|
||||
instance = self._get_instance(instance_id)
|
||||
@ -186,7 +383,11 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
|
||||
return None
|
||||
|
||||
async def model_dtype(self, instance_id: str) -> Any:
|
||||
return self._get_instance(instance_id).model_dtype()
|
||||
return self._run_operation_with_lease(
|
||||
instance_id,
|
||||
"model_dtype",
|
||||
lambda: self._get_instance(instance_id).model_dtype(),
|
||||
)
|
||||
|
||||
async def model_patches_to(self, instance_id: str, device: Any) -> Any:
|
||||
return self._get_instance(instance_id).model_patches_to(device)
|
||||
@ -198,8 +399,12 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
|
||||
extra_memory: Any,
|
||||
force_patch_weights: bool = False,
|
||||
) -> Any:
|
||||
return self._get_instance(instance_id).partially_load(
|
||||
device, extra_memory, force_patch_weights=force_patch_weights
|
||||
return self._run_operation_with_lease(
|
||||
instance_id,
|
||||
"partially_load",
|
||||
lambda: self._get_instance(instance_id).partially_load(
|
||||
device, extra_memory, force_patch_weights=force_patch_weights
|
||||
),
|
||||
)
|
||||
|
||||
async def partially_unload(
|
||||
@ -209,8 +414,12 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
|
||||
memory_to_free: int = 0,
|
||||
force_patch_weights: bool = False,
|
||||
) -> int:
|
||||
return self._get_instance(instance_id).partially_unload(
|
||||
device_to, memory_to_free, force_patch_weights
|
||||
return self._run_operation_with_lease(
|
||||
instance_id,
|
||||
"partially_unload",
|
||||
lambda: self._get_instance(instance_id).partially_unload(
|
||||
device_to, memory_to_free, force_patch_weights
|
||||
),
|
||||
)
|
||||
|
||||
async def load(
|
||||
@ -221,8 +430,12 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
|
||||
force_patch_weights: bool = False,
|
||||
full_load: bool = False,
|
||||
) -> None:
|
||||
self._get_instance(instance_id).load(
|
||||
device_to, lowvram_model_memory, force_patch_weights, full_load
|
||||
self._run_operation_with_lease(
|
||||
instance_id,
|
||||
"load",
|
||||
lambda: self._get_instance(instance_id).load(
|
||||
device_to, lowvram_model_memory, force_patch_weights, full_load
|
||||
),
|
||||
)
|
||||
|
||||
async def patch_model(
|
||||
@ -233,20 +446,29 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
|
||||
load_weights: bool = True,
|
||||
force_patch_weights: bool = False,
|
||||
) -> None:
|
||||
try:
|
||||
self._get_instance(instance_id).patch_model(
|
||||
device_to, lowvram_model_memory, load_weights, force_patch_weights
|
||||
)
|
||||
except AttributeError as e:
|
||||
logger.error(
|
||||
f"Isolation Error: Failed to patch model attribute: {e}. Skipping."
|
||||
)
|
||||
return
|
||||
def _invoke() -> None:
|
||||
try:
|
||||
self._get_instance(instance_id).patch_model(
|
||||
device_to, lowvram_model_memory, load_weights, force_patch_weights
|
||||
)
|
||||
except AttributeError as e:
|
||||
logger.error(
|
||||
f"Isolation Error: Failed to patch model attribute: {e}. Skipping."
|
||||
)
|
||||
return
|
||||
|
||||
self._run_operation_with_lease(instance_id, "patch_model", _invoke)
|
||||
|
||||
async def unpatch_model(
|
||||
self, instance_id: str, device_to: Any = None, unpatch_weights: bool = True
|
||||
) -> None:
|
||||
self._get_instance(instance_id).unpatch_model(device_to, unpatch_weights)
|
||||
self._run_operation_with_lease(
|
||||
instance_id,
|
||||
"unpatch_model",
|
||||
lambda: self._get_instance(instance_id).unpatch_model(
|
||||
device_to, unpatch_weights
|
||||
),
|
||||
)
|
||||
|
||||
async def detach(self, instance_id: str, unpatch_all: bool = True) -> None:
|
||||
self._get_instance(instance_id).detach(unpatch_all)
|
||||
@ -262,26 +484,29 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
|
||||
self._get_instance(instance_id).pre_run()
|
||||
|
||||
async def cleanup(self, instance_id: str) -> None:
|
||||
try:
|
||||
instance = self._get_instance(instance_id)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"ModelPatcher cleanup requested for missing instance %s",
|
||||
instance_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return
|
||||
def _invoke() -> None:
|
||||
try:
|
||||
instance = self._get_instance(instance_id)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"ModelPatcher cleanup requested for missing instance %s",
|
||||
instance_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
instance.cleanup()
|
||||
finally:
|
||||
with self._lock:
|
||||
self._pending_cleanup_ids.add(instance_id)
|
||||
gc.collect()
|
||||
try:
|
||||
instance.cleanup()
|
||||
finally:
|
||||
with self._lock:
|
||||
self._pending_cleanup_ids.add(instance_id)
|
||||
gc.collect()
|
||||
|
||||
self._run_operation_with_lease(instance_id, "cleanup", _invoke)
|
||||
|
||||
def sweep_pending_cleanup(self) -> int:
|
||||
removed = 0
|
||||
with self._lock:
|
||||
with self._operation_state_cv:
|
||||
pending_ids = list(self._pending_cleanup_ids)
|
||||
self._pending_cleanup_ids.clear()
|
||||
for instance_id in pending_ids:
|
||||
@ -289,17 +514,21 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
|
||||
if instance is None:
|
||||
continue
|
||||
self._id_map.pop(id(instance), None)
|
||||
self._operation_states.pop(instance_id, None)
|
||||
removed += 1
|
||||
self._operation_state_cv.notify_all()
|
||||
|
||||
gc.collect()
|
||||
return removed
|
||||
|
||||
def purge_all(self) -> int:
|
||||
with self._lock:
|
||||
with self._operation_state_cv:
|
||||
removed = len(self._registry)
|
||||
self._registry.clear()
|
||||
self._id_map.clear()
|
||||
self._pending_cleanup_ids.clear()
|
||||
self._operation_states.clear()
|
||||
self._operation_state_cv.notify_all()
|
||||
gc.collect()
|
||||
return removed
|
||||
|
||||
@ -743,17 +972,52 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
|
||||
async def inner_model_memory_required(
|
||||
self, instance_id: str, args: tuple, kwargs: dict
|
||||
) -> Any:
|
||||
return self._get_instance(instance_id).model.memory_required(*args, **kwargs)
|
||||
return self._run_operation_with_lease(
|
||||
instance_id,
|
||||
"inner_model_memory_required",
|
||||
lambda: self._get_instance(instance_id).model.memory_required(
|
||||
*args, **kwargs
|
||||
),
|
||||
)
|
||||
|
||||
async def inner_model_extra_conds_shapes(
|
||||
self, instance_id: str, args: tuple, kwargs: dict
|
||||
) -> Any:
|
||||
return self._get_instance(instance_id).model.extra_conds_shapes(*args, **kwargs)
|
||||
return self._run_operation_with_lease(
|
||||
instance_id,
|
||||
"inner_model_extra_conds_shapes",
|
||||
lambda: self._get_instance(instance_id).model.extra_conds_shapes(
|
||||
*args, **kwargs
|
||||
),
|
||||
)
|
||||
|
||||
async def inner_model_extra_conds(
|
||||
self, instance_id: str, args: tuple, kwargs: dict
|
||||
) -> Any:
|
||||
return self._get_instance(instance_id).model.extra_conds(*args, **kwargs)
|
||||
def _invoke() -> Any:
|
||||
result = self._get_instance(instance_id).model.extra_conds(*args, **kwargs)
|
||||
try:
|
||||
import torch
|
||||
import comfy.conds
|
||||
except Exception:
|
||||
return result
|
||||
|
||||
def _to_cpu(obj: Any) -> Any:
|
||||
if torch.is_tensor(obj):
|
||||
return obj.detach().cpu() if obj.device.type != "cpu" else obj
|
||||
if isinstance(obj, dict):
|
||||
return {k: _to_cpu(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_to_cpu(v) for v in obj]
|
||||
if isinstance(obj, tuple):
|
||||
return tuple(_to_cpu(v) for v in obj)
|
||||
if isinstance(obj, comfy.conds.CONDRegular):
|
||||
return type(obj)(_to_cpu(obj.cond))
|
||||
return obj
|
||||
|
||||
return _to_cpu(result)
|
||||
|
||||
return self._run_operation_with_lease(instance_id, "inner_model_extra_conds", _invoke)
|
||||
|
||||
async def inner_model_state_dict(
|
||||
self, instance_id: str, args: tuple, kwargs: dict
|
||||
@ -767,82 +1031,160 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
|
||||
async def inner_model_apply_model(
|
||||
self, instance_id: str, args: tuple, kwargs: dict
|
||||
) -> Any:
|
||||
instance = self._get_instance(instance_id)
|
||||
target = getattr(instance, "load_device", None)
|
||||
if target is None and args and hasattr(args[0], "device"):
|
||||
target = args[0].device
|
||||
elif target is None:
|
||||
for v in kwargs.values():
|
||||
if hasattr(v, "device"):
|
||||
target = v.device
|
||||
break
|
||||
def _invoke() -> Any:
|
||||
import torch
|
||||
|
||||
def _move(obj):
|
||||
if target is None:
|
||||
instance = self._get_instance(instance_id)
|
||||
target = getattr(instance, "load_device", None)
|
||||
if target is None and args and hasattr(args[0], "device"):
|
||||
target = args[0].device
|
||||
elif target is None:
|
||||
for v in kwargs.values():
|
||||
if hasattr(v, "device"):
|
||||
target = v.device
|
||||
break
|
||||
|
||||
def _move(obj):
|
||||
if target is None:
|
||||
return obj
|
||||
if isinstance(obj, (tuple, list)):
|
||||
return type(obj)(_move(o) for o in obj)
|
||||
if hasattr(obj, "to"):
|
||||
return obj.to(target)
|
||||
return obj
|
||||
if isinstance(obj, (tuple, list)):
|
||||
return type(obj)(_move(o) for o in obj)
|
||||
if hasattr(obj, "to"):
|
||||
return obj.to(target)
|
||||
return obj
|
||||
|
||||
moved_args = tuple(_move(a) for a in args)
|
||||
moved_kwargs = {k: _move(v) for k, v in kwargs.items()}
|
||||
result = instance.model.apply_model(*moved_args, **moved_kwargs)
|
||||
return detach_if_grad(_move(result))
|
||||
moved_args = tuple(_move(a) for a in args)
|
||||
moved_kwargs = {k: _move(v) for k, v in kwargs.items()}
|
||||
result = instance.model.apply_model(*moved_args, **moved_kwargs)
|
||||
moved_result = detach_if_grad(_move(result))
|
||||
|
||||
# DynamicVRAM + isolation: returning CUDA tensors across RPC can stall
|
||||
# at the transport boundary. Marshal dynamic-path results as CPU and let
|
||||
# the proxy restore device placement in the child process.
|
||||
is_dynamic_fn = getattr(instance, "is_dynamic", None)
|
||||
if callable(is_dynamic_fn) and is_dynamic_fn():
|
||||
def _to_cpu(obj: Any) -> Any:
|
||||
if torch.is_tensor(obj):
|
||||
return obj.detach().cpu() if obj.device.type != "cpu" else obj
|
||||
if isinstance(obj, dict):
|
||||
return {k: _to_cpu(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_to_cpu(v) for v in obj]
|
||||
if isinstance(obj, tuple):
|
||||
return tuple(_to_cpu(v) for v in obj)
|
||||
return obj
|
||||
|
||||
return _to_cpu(moved_result)
|
||||
return moved_result
|
||||
|
||||
return self._run_operation_with_lease(instance_id, "inner_model_apply_model", _invoke)
|
||||
|
||||
async def process_latent_in(
|
||||
self, instance_id: str, args: tuple, kwargs: dict
|
||||
) -> Any:
|
||||
return detach_if_grad(
|
||||
self._get_instance(instance_id).model.process_latent_in(*args, **kwargs)
|
||||
return self._run_operation_with_lease(
|
||||
instance_id,
|
||||
"process_latent_in",
|
||||
lambda: detach_if_grad(
|
||||
self._get_instance(instance_id).model.process_latent_in(
|
||||
*args, **kwargs
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
async def process_latent_out(
|
||||
self, instance_id: str, args: tuple, kwargs: dict
|
||||
) -> Any:
|
||||
instance = self._get_instance(instance_id)
|
||||
result = instance.model.process_latent_out(*args, **kwargs)
|
||||
try:
|
||||
target = None
|
||||
if args and hasattr(args[0], "device"):
|
||||
target = args[0].device
|
||||
elif kwargs:
|
||||
for v in kwargs.values():
|
||||
if hasattr(v, "device"):
|
||||
target = v.device
|
||||
break
|
||||
if target is not None and hasattr(result, "to"):
|
||||
return detach_if_grad(result.to(target))
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"process_latent_out: failed to move result to target device",
|
||||
exc_info=True,
|
||||
)
|
||||
return detach_if_grad(result)
|
||||
import torch
|
||||
|
||||
def _invoke() -> Any:
|
||||
instance = self._get_instance(instance_id)
|
||||
result = instance.model.process_latent_out(*args, **kwargs)
|
||||
moved_result = None
|
||||
try:
|
||||
target = None
|
||||
if args and hasattr(args[0], "device"):
|
||||
target = args[0].device
|
||||
elif kwargs:
|
||||
for v in kwargs.values():
|
||||
if hasattr(v, "device"):
|
||||
target = v.device
|
||||
break
|
||||
if target is not None and hasattr(result, "to"):
|
||||
moved_result = detach_if_grad(result.to(target))
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"process_latent_out: failed to move result to target device",
|
||||
exc_info=True,
|
||||
)
|
||||
if moved_result is None:
|
||||
moved_result = detach_if_grad(result)
|
||||
|
||||
is_dynamic_fn = getattr(instance, "is_dynamic", None)
|
||||
if callable(is_dynamic_fn) and is_dynamic_fn():
|
||||
def _to_cpu(obj: Any) -> Any:
|
||||
if torch.is_tensor(obj):
|
||||
return obj.detach().cpu() if obj.device.type != "cpu" else obj
|
||||
if isinstance(obj, dict):
|
||||
return {k: _to_cpu(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_to_cpu(v) for v in obj]
|
||||
if isinstance(obj, tuple):
|
||||
return tuple(_to_cpu(v) for v in obj)
|
||||
return obj
|
||||
|
||||
return _to_cpu(moved_result)
|
||||
return moved_result
|
||||
|
||||
return self._run_operation_with_lease(instance_id, "process_latent_out", _invoke)
|
||||
|
||||
async def scale_latent_inpaint(
|
||||
self, instance_id: str, args: tuple, kwargs: dict
|
||||
) -> Any:
|
||||
instance = self._get_instance(instance_id)
|
||||
result = instance.model.scale_latent_inpaint(*args, **kwargs)
|
||||
try:
|
||||
target = None
|
||||
if args and hasattr(args[0], "device"):
|
||||
target = args[0].device
|
||||
elif kwargs:
|
||||
for v in kwargs.values():
|
||||
if hasattr(v, "device"):
|
||||
target = v.device
|
||||
break
|
||||
if target is not None and hasattr(result, "to"):
|
||||
return detach_if_grad(result.to(target))
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"scale_latent_inpaint: failed to move result to target device",
|
||||
exc_info=True,
|
||||
)
|
||||
return detach_if_grad(result)
|
||||
import torch
|
||||
|
||||
def _invoke() -> Any:
|
||||
instance = self._get_instance(instance_id)
|
||||
result = instance.model.scale_latent_inpaint(*args, **kwargs)
|
||||
moved_result = None
|
||||
try:
|
||||
target = None
|
||||
if args and hasattr(args[0], "device"):
|
||||
target = args[0].device
|
||||
elif kwargs:
|
||||
for v in kwargs.values():
|
||||
if hasattr(v, "device"):
|
||||
target = v.device
|
||||
break
|
||||
if target is not None and hasattr(result, "to"):
|
||||
moved_result = detach_if_grad(result.to(target))
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"scale_latent_inpaint: failed to move result to target device",
|
||||
exc_info=True,
|
||||
)
|
||||
if moved_result is None:
|
||||
moved_result = detach_if_grad(result)
|
||||
|
||||
is_dynamic_fn = getattr(instance, "is_dynamic", None)
|
||||
if callable(is_dynamic_fn) and is_dynamic_fn():
|
||||
def _to_cpu(obj: Any) -> Any:
|
||||
if torch.is_tensor(obj):
|
||||
return obj.detach().cpu() if obj.device.type != "cpu" else obj
|
||||
if isinstance(obj, dict):
|
||||
return {k: _to_cpu(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_to_cpu(v) for v in obj]
|
||||
if isinstance(obj, tuple):
|
||||
return tuple(_to_cpu(v) for v in obj)
|
||||
return obj
|
||||
|
||||
return _to_cpu(moved_result)
|
||||
return moved_result
|
||||
|
||||
return self._run_operation_with_lease(
|
||||
instance_id, "scale_latent_inpaint", _invoke
|
||||
)
|
||||
|
||||
async def load_lora(
|
||||
self,
|
||||
|
||||
@ -3,6 +3,9 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from comfy.isolation.proxies.base import (
|
||||
@ -16,6 +19,22 @@ from comfy.isolation.proxies.base import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _describe_value(obj: Any) -> str:
    """Return a short diagnostic label for *obj*.

    Tensors are rendered with their shape, dtype, device and ``id``; every
    other object is rendered as ``TypeName(id=...)``. The helper is
    best-effort: if ``torch`` cannot be imported, or inspecting the tensor
    raises, the generic label is returned instead of propagating the error.
    """
    generic_label = "%s(id=%s)" % (type(obj).__name__, id(obj))

    try:
        import torch
    except Exception:
        # Without torch there is nothing tensor-specific to report.
        return generic_label

    try:
        if isinstance(obj, torch.Tensor):
            details = (tuple(obj.shape), obj.dtype, obj.device, id(obj))
            return "Tensor(shape=%s,dtype=%s,device=%s,id=%s)" % details
    except Exception:
        # Describing a value must never break the call being traced.
        pass

    return generic_label
|
||||
|
||||
|
||||
def _prefer_device(*tensors: Any) -> Any:
|
||||
try:
|
||||
import torch
|
||||
@ -49,6 +68,24 @@ def _to_device(obj: Any, device: Any) -> Any:
|
||||
return obj
|
||||
|
||||
|
||||
def _to_cpu_for_rpc(obj: Any) -> Any:
    """Recursively convert *obj* into a form that is safe to send over RPC.

    Tensors are detached when they require grad and copied to the CPU when
    they live on an accelerator. Lists, tuples (including namedtuples) and
    dicts are rebuilt with their elements converted; any other object is
    returned unchanged. If ``torch`` cannot be imported, *obj* is returned
    as-is.

    Args:
        obj: Arbitrary value, possibly containing tensors at any depth.

    Returns:
        A structure of the same layout whose tensors are grad-free and not
        resident on an accelerator device.
    """
    try:
        import torch
    except Exception:
        return obj

    if isinstance(obj, torch.Tensor):
        tensor = obj.detach() if obj.requires_grad else obj
        # Check the device type rather than ``is_cuda`` so tensors on any
        # accelerator (cuda, mps, xpu, ...) are marshalled, matching the
        # registry-side ``_to_cpu`` helpers. Meta tensors carry no data and
        # cannot be copied, so they are passed through untouched.
        if tensor.device.type not in ("cpu", "meta"):
            return tensor.to("cpu")
        return tensor

    if isinstance(obj, tuple):
        items = [_to_cpu_for_rpc(item) for item in obj]
        make = getattr(type(obj), "_make", None)
        if callable(make):
            # namedtuple constructors take positional fields, not one
            # iterable, so they must be rebuilt through ``_make``.
            return make(items)
        return type(obj)(items)

    if isinstance(obj, list):
        return [_to_cpu_for_rpc(item) for item in obj]

    if isinstance(obj, dict):
        return {key: _to_cpu_for_rpc(value) for key, value in obj.items()}

    return obj
|
||||
|
||||
|
||||
class ModelSamplingRegistry(BaseRegistry[Any]):
|
||||
_type_prefix = "modelsampling"
|
||||
|
||||
@ -196,17 +233,93 @@ class ModelSamplingProxy(BaseProxy[ModelSamplingRegistry]):
|
||||
return self._rpc_caller
|
||||
|
||||
def _call(self, method_name: str, *args: Any) -> Any:
|
||||
print(
|
||||
"ISO:modelsampling_call_enter method=%s instance_id=%s pid=%s"
|
||||
% (method_name, self._instance_id, os.getpid()),
|
||||
flush=True,
|
||||
)
|
||||
rpc = self._get_rpc()
|
||||
method = getattr(rpc, method_name)
|
||||
print(
|
||||
"ISO:modelsampling_call_before_dispatch method=%s instance_id=%s pid=%s"
|
||||
% (method_name, self._instance_id, os.getpid()),
|
||||
flush=True,
|
||||
)
|
||||
result = method(self._instance_id, *args)
|
||||
timeout_ms = self._rpc_timeout_ms()
|
||||
start_epoch = time.time()
|
||||
start_perf = time.perf_counter()
|
||||
thread_id = threading.get_ident()
|
||||
call_id = "%s:%s:%s:%.6f" % (
|
||||
self._instance_id,
|
||||
method_name,
|
||||
thread_id,
|
||||
start_perf,
|
||||
)
|
||||
logger.debug(
|
||||
"ISO:modelsampling_rpc_start method=%s instance_id=%s call_id=%s start_ts=%.6f thread=%s timeout_ms=%s",
|
||||
method_name,
|
||||
self._instance_id,
|
||||
call_id,
|
||||
start_epoch,
|
||||
thread_id,
|
||||
timeout_ms,
|
||||
)
|
||||
if asyncio.iscoroutine(result):
|
||||
result = asyncio.wait_for(result, timeout=timeout_ms / 1000.0)
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
return run_coro_in_new_loop(result)
|
||||
out = run_coro_in_new_loop(result)
|
||||
except RuntimeError:
|
||||
loop = get_thread_loop()
|
||||
return loop.run_until_complete(result)
|
||||
return result
|
||||
out = loop.run_until_complete(result)
|
||||
else:
|
||||
out = result
|
||||
print(
|
||||
"ISO:modelsampling_call_after_dispatch method=%s instance_id=%s pid=%s"
|
||||
% (method_name, self._instance_id, os.getpid()),
|
||||
flush=True,
|
||||
)
|
||||
logger.debug(
|
||||
"ISO:modelsampling_rpc_after_await method=%s instance_id=%s call_id=%s out=%s",
|
||||
method_name,
|
||||
self._instance_id,
|
||||
call_id,
|
||||
_describe_value(out),
|
||||
)
|
||||
elapsed_ms = (time.perf_counter() - start_perf) * 1000.0
|
||||
logger.debug(
|
||||
"ISO:modelsampling_rpc_end method=%s instance_id=%s call_id=%s elapsed_ms=%.3f thread=%s",
|
||||
method_name,
|
||||
self._instance_id,
|
||||
call_id,
|
||||
elapsed_ms,
|
||||
thread_id,
|
||||
)
|
||||
logger.debug(
|
||||
"ISO:modelsampling_rpc_return method=%s instance_id=%s call_id=%s",
|
||||
method_name,
|
||||
self._instance_id,
|
||||
call_id,
|
||||
)
|
||||
print(
|
||||
"ISO:modelsampling_call_return method=%s instance_id=%s pid=%s"
|
||||
% (method_name, self._instance_id, os.getpid()),
|
||||
flush=True,
|
||||
)
|
||||
return out
|
||||
|
||||
@staticmethod
def _rpc_timeout_ms() -> int:
    """Return the model-sampling RPC timeout in milliseconds.

    The value is read from ``COMFY_ISOLATION_MODEL_SAMPLING_RPC_TIMEOUT_MS``;
    when that variable is unset, ``COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS`` is
    used, and when both are unset the default is 30000. A value that does
    not parse as an integer also yields 30000. The result is never below 1.
    """
    shared_setting = os.environ.get("COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS", "30000")
    configured = os.environ.get(
        "COMFY_ISOLATION_MODEL_SAMPLING_RPC_TIMEOUT_MS", shared_setting
    )
    try:
        parsed = int(configured)
    except ValueError:
        return 30000
    # Clamp so a zero or negative setting cannot disable the timeout.
    return parsed if parsed > 1 else 1
|
||||
|
||||
@property
|
||||
def sigma_min(self) -> Any:
|
||||
@ -235,10 +348,24 @@ class ModelSamplingProxy(BaseProxy[ModelSamplingRegistry]):
|
||||
def noise_scaling(
|
||||
self, sigma: Any, noise: Any, latent_image: Any, max_denoise: bool = False
|
||||
) -> Any:
|
||||
return self._call("noise_scaling", sigma, noise, latent_image, max_denoise)
|
||||
preferred_device = _prefer_device(noise, latent_image)
|
||||
out = self._call(
|
||||
"noise_scaling",
|
||||
_to_cpu_for_rpc(sigma),
|
||||
_to_cpu_for_rpc(noise),
|
||||
_to_cpu_for_rpc(latent_image),
|
||||
max_denoise,
|
||||
)
|
||||
return _to_device(out, preferred_device)
|
||||
|
||||
def inverse_noise_scaling(self, sigma: Any, latent: Any) -> Any:
|
||||
return self._call("inverse_noise_scaling", sigma, latent)
|
||||
preferred_device = _prefer_device(latent)
|
||||
out = self._call(
|
||||
"inverse_noise_scaling",
|
||||
_to_cpu_for_rpc(sigma),
|
||||
_to_cpu_for_rpc(latent),
|
||||
)
|
||||
return _to_device(out, preferred_device)
|
||||
|
||||
def timestep(self, sigma: Any) -> Any:
|
||||
return self._call("timestep", sigma)
|
||||
|
||||
@ -2,9 +2,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import weakref
|
||||
from typing import Any, Callable, Dict, Generic, Optional, TypeVar
|
||||
|
||||
@ -119,6 +121,24 @@ def set_global_loop(loop: asyncio.AbstractEventLoop) -> None:
|
||||
class BaseProxy(Generic[T]):
|
||||
_registry_class: type = BaseRegistry # type: ignore[type-arg]
|
||||
__module__: str = "comfy.isolation.proxies.base"
|
||||
_TIMEOUT_RPC_METHODS = frozenset(
|
||||
{
|
||||
"partially_load",
|
||||
"partially_unload",
|
||||
"load",
|
||||
"patch_model",
|
||||
"unpatch_model",
|
||||
"inner_model_apply_model",
|
||||
"memory_required",
|
||||
"model_dtype",
|
||||
"inner_model_memory_required",
|
||||
"inner_model_extra_conds_shapes",
|
||||
"inner_model_extra_conds",
|
||||
"process_latent_in",
|
||||
"process_latent_out",
|
||||
"scale_latent_inpaint",
|
||||
}
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -148,39 +168,89 @@ class BaseProxy(Generic[T]):
|
||||
)
|
||||
return self._rpc_caller
|
||||
|
||||
def _rpc_timeout_ms_for_method(self, method_name: str) -> Optional[int]:
    """Return the RPC timeout (ms) for *method_name*, or ``None`` if untimed.

    Only methods listed in ``self._TIMEOUT_RPC_METHODS`` get a timeout. For
    those, the value comes from ``COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS``
    (default 120000); a value that does not parse as an integer falls back
    to 120000, and the result is never below 1.
    """
    if method_name in self._TIMEOUT_RPC_METHODS:
        configured = os.environ.get("COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS", "120000")
        try:
            parsed = int(configured)
        except ValueError:
            parsed = 120000
        # Clamp so a zero or negative setting still yields a usable timeout.
        return parsed if parsed > 1 else 1
    return None
|
||||
|
||||
def _call_rpc(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
|
||||
rpc = self._get_rpc()
|
||||
method = getattr(rpc, method_name)
|
||||
timeout_ms = self._rpc_timeout_ms_for_method(method_name)
|
||||
coro = method(self._instance_id, *args, **kwargs)
|
||||
if timeout_ms is not None:
|
||||
coro = asyncio.wait_for(coro, timeout=timeout_ms / 1000.0)
|
||||
|
||||
# If we have a global loop (Main Thread Loop), use it for dispatch from worker threads
|
||||
if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running():
|
||||
try:
|
||||
# If we are already in the global loop, we can't block on it?
|
||||
# Actually, this method is synchronous (__getattr__ -> lambda).
|
||||
# If called from async context in main loop, we need to handle that.
|
||||
curr_loop = asyncio.get_running_loop()
|
||||
if curr_loop is _GLOBAL_LOOP:
|
||||
# We are in the main loop. We cannot await/block here if we are just a sync function.
|
||||
# But proxies are often called from sync code.
|
||||
# If called from sync code in main loop, creating a new loop is bad.
|
||||
# But we can't await `coro`.
|
||||
# This implies proxies MUST be awaited if called from async context?
|
||||
# Existing code used `run_coro_in_new_loop` which is weird.
|
||||
# Let's trust that if we are in a thread (RuntimeError on get_running_loop),
|
||||
# we use run_coroutine_threadsafe.
|
||||
pass
|
||||
except RuntimeError:
|
||||
# No running loop - we are in a worker thread.
|
||||
future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP)
|
||||
return future.result()
|
||||
start_epoch = time.time()
|
||||
start_perf = time.perf_counter()
|
||||
thread_id = threading.get_ident()
|
||||
try:
|
||||
running_loop = asyncio.get_running_loop()
|
||||
loop_id: Optional[int] = id(running_loop)
|
||||
except RuntimeError:
|
||||
loop_id = None
|
||||
logger.debug(
|
||||
"ISO:rpc_start proxy=%s method=%s instance_id=%s start_ts=%.6f "
|
||||
"thread=%s loop=%s timeout_ms=%s",
|
||||
self.__class__.__name__,
|
||||
method_name,
|
||||
self._instance_id,
|
||||
start_epoch,
|
||||
thread_id,
|
||||
loop_id,
|
||||
timeout_ms,
|
||||
)
|
||||
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
return run_coro_in_new_loop(coro)
|
||||
except RuntimeError:
|
||||
loop = get_thread_loop()
|
||||
return loop.run_until_complete(coro)
|
||||
# If we have a global loop (Main Thread Loop), use it for dispatch from worker threads
|
||||
if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running():
|
||||
try:
|
||||
curr_loop = asyncio.get_running_loop()
|
||||
if curr_loop is _GLOBAL_LOOP:
|
||||
pass
|
||||
except RuntimeError:
|
||||
# No running loop - we are in a worker thread.
|
||||
future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP)
|
||||
return future.result(
|
||||
timeout=(timeout_ms / 1000.0) if timeout_ms is not None else None
|
||||
)
|
||||
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
return run_coro_in_new_loop(coro)
|
||||
except RuntimeError:
|
||||
loop = get_thread_loop()
|
||||
return loop.run_until_complete(coro)
|
||||
except asyncio.TimeoutError as exc:
|
||||
raise TimeoutError(
|
||||
f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} "
|
||||
f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})"
|
||||
) from exc
|
||||
except concurrent.futures.TimeoutError as exc:
|
||||
raise TimeoutError(
|
||||
f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} "
|
||||
f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})"
|
||||
) from exc
|
||||
finally:
|
||||
end_epoch = time.time()
|
||||
elapsed_ms = (time.perf_counter() - start_perf) * 1000.0
|
||||
logger.debug(
|
||||
"ISO:rpc_end proxy=%s method=%s instance_id=%s end_ts=%.6f "
|
||||
"elapsed_ms=%.3f thread=%s loop=%s",
|
||||
self.__class__.__name__,
|
||||
method_name,
|
||||
self._instance_id,
|
||||
end_epoch,
|
||||
elapsed_ms,
|
||||
thread_id,
|
||||
loop_id,
|
||||
)
|
||||
|
||||
def __getstate__(self) -> Dict[str, Any]:
    """Return the state to pickle: only this proxy's ``_instance_id``."""
    state: Dict[str, Any] = {"_instance_id": self._instance_id}
    return state
|
||||
|
||||
@ -208,13 +208,18 @@ def calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tenso
|
||||
return handler.execute(_calc_cond_batch_outer, model, conds, x_in, timestep, model_options)
|
||||
|
||||
def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
print("CC_DEBUG2: enter _calc_cond_batch_outer", flush=True)
|
||||
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
|
||||
_calc_cond_batch,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
|
||||
)
|
||||
return executor.execute(model, conds, x_in, timestep, model_options)
|
||||
print("CC_DEBUG2: before _calc_cond_batch executor.execute", flush=True)
|
||||
result = executor.execute(model, conds, x_in, timestep, model_options)
|
||||
print("CC_DEBUG2: after _calc_cond_batch executor.execute", flush=True)
|
||||
return result
|
||||
|
||||
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
print("CC_DEBUG2: enter _calc_cond_batch", flush=True)
|
||||
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
|
||||
out_conds = []
|
||||
out_counts = []
|
||||
@ -247,7 +252,9 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
if has_default_conds:
|
||||
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
|
||||
|
||||
print("CC_DEBUG: before prepare_state", flush=True)
|
||||
model.current_patcher.prepare_state(timestep)
|
||||
print("CC_DEBUG: after prepare_state", flush=True)
|
||||
|
||||
# run every hooked_to_run separately
|
||||
for hooks, to_run in hooked_to_run.items():
|
||||
@ -262,7 +269,9 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
to_batch_temp.reverse()
|
||||
to_batch = to_batch_temp[:1]
|
||||
|
||||
print("CC_DEBUG: before get_free_memory", flush=True)
|
||||
free_memory = model.current_patcher.get_free_memory(x_in.device)
|
||||
print("CC_DEBUG: after get_free_memory", flush=True)
|
||||
for i in range(1, len(to_batch_temp) + 1):
|
||||
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
|
||||
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
||||
@ -272,7 +281,10 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
for k, v in to_run[tt][0].conditioning.items():
|
||||
cond_shapes[k].append(v.size())
|
||||
|
||||
if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory:
|
||||
print("CC_DEBUG: before memory_required", flush=True)
|
||||
memory_required = model.memory_required(input_shape, cond_shapes=cond_shapes)
|
||||
print("CC_DEBUG: after memory_required", flush=True)
|
||||
if memory_required * 1.5 < free_memory:
|
||||
to_batch = batch_amount
|
||||
break
|
||||
|
||||
@ -309,7 +321,9 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
else:
|
||||
timestep_ = torch.cat([timestep] * batch_chunks)
|
||||
|
||||
print("CC_DEBUG: before apply_hooks", flush=True)
|
||||
transformer_options = model.current_patcher.apply_hooks(hooks=hooks)
|
||||
print("CC_DEBUG: after apply_hooks", flush=True)
|
||||
if 'transformer_options' in model_options:
|
||||
transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
|
||||
model_options['transformer_options'],
|
||||
@ -331,9 +345,13 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
|
||||
|
||||
if 'model_function_wrapper' in model_options:
|
||||
print("CC_DEBUG: before apply_model", flush=True)
|
||||
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
|
||||
print("CC_DEBUG: after apply_model", flush=True)
|
||||
else:
|
||||
print("CC_DEBUG: before apply_model", flush=True)
|
||||
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
|
||||
print("CC_DEBUG: after apply_model", flush=True)
|
||||
|
||||
for o in range(batch_chunks):
|
||||
cond_index = cond_or_uncond[o]
|
||||
@ -768,6 +786,7 @@ class KSAMPLER(Sampler):
|
||||
self.inpaint_options = inpaint_options
|
||||
|
||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||
print("CC_DEBUG3: enter KSAMPLER.sample", flush=True)
|
||||
extra_args["denoise_mask"] = denoise_mask
|
||||
model_k = KSamplerX0Inpaint(model_wrap, sigmas)
|
||||
model_k.latent_image = latent_image
|
||||
@ -777,15 +796,46 @@ class KSAMPLER(Sampler):
|
||||
else:
|
||||
model_k.noise = noise
|
||||
|
||||
noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas[0], noise, latent_image, self.max_denoise(model_wrap, sigmas))
|
||||
print("CC_DEBUG3: before max_denoise", flush=True)
|
||||
max_denoise = self.max_denoise(model_wrap, sigmas)
|
||||
print("CC_DEBUG3: after max_denoise", flush=True)
|
||||
print("CC_DEBUG3: before model_sampling_attr", flush=True)
|
||||
model_sampling = model_wrap.inner_model.model_sampling
|
||||
print(
|
||||
"CC_DEBUG3: after model_sampling_attr type=%s id=%s instance_id=%s"
|
||||
% (
|
||||
type(model_sampling).__name__,
|
||||
id(model_sampling),
|
||||
getattr(model_sampling, "_instance_id", "n/a"),
|
||||
),
|
||||
flush=True,
|
||||
)
|
||||
print("CC_DEBUG3: before noise_scaling_call", flush=True)
|
||||
try:
|
||||
noise_scaled = model_sampling.noise_scaling(
|
||||
sigmas[0], noise, latent_image, max_denoise
|
||||
)
|
||||
print("CC_DEBUG3: after noise_scaling_call", flush=True)
|
||||
except Exception as e:
|
||||
print(
|
||||
"CC_DEBUG3: noise_scaling_exception type=%s msg=%s"
|
||||
% (type(e).__name__, str(e)),
|
||||
flush=True,
|
||||
)
|
||||
raise
|
||||
noise = noise_scaled
|
||||
print("CC_DEBUG3: after noise_assignment", flush=True)
|
||||
|
||||
k_callback = None
|
||||
total_steps = len(sigmas) - 1
|
||||
if callback is not None:
|
||||
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
|
||||
|
||||
print("CC_DEBUG3: before sampler_function", flush=True)
|
||||
samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
|
||||
print("CC_DEBUG3: after sampler_function", flush=True)
|
||||
samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples)
|
||||
print("CC_DEBUG3: after inverse_noise_scaling", flush=True)
|
||||
return samples
|
||||
|
||||
|
||||
@ -825,10 +875,16 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
|
||||
for k in conds:
|
||||
calculate_start_end_timesteps(model, conds[k])
|
||||
|
||||
print('GP_DEBUG: before hasattr extra_conds', flush=True)
|
||||
print('GP_DEBUG: before hasattr extra_conds', flush=True)
|
||||
if hasattr(model, 'extra_conds'):
|
||||
print('GP_DEBUG: has extra_conds!', flush=True)
|
||||
print('GP_DEBUG: has extra_conds!', flush=True)
|
||||
for k in conds:
|
||||
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed, latent_shapes=latent_shapes)
|
||||
|
||||
print('GP_DEBUG: before area make sure loop', flush=True)
|
||||
print('GP_DEBUG: before area make sure loop', flush=True)
|
||||
#make sure each cond area has an opposite one with the same area
|
||||
for k in conds:
|
||||
for c in conds[k]:
|
||||
@ -842,8 +898,11 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
|
||||
for hook in c['hooks'].hooks:
|
||||
hook.initialize_timesteps(model)
|
||||
|
||||
print('GP_DEBUG: before pre_run_control loop', flush=True)
|
||||
for k in conds:
|
||||
print('GP_DEBUG: calling pre_run_control for key:', k, flush=True)
|
||||
pre_run_control(model, conds[k])
|
||||
print('GP_DEBUG: after pre_run_control loop', flush=True)
|
||||
|
||||
if "positive" in conds:
|
||||
positive = conds["positive"]
|
||||
@ -1005,6 +1064,8 @@ class CFGGuider:
|
||||
|
||||
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed, latent_shapes=latent_shapes)
|
||||
|
||||
print("GP_DEBUG: process_conds finished", flush=True)
|
||||
print("GP_DEBUG: process_conds finished", flush=True)
|
||||
extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
|
||||
extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas
|
||||
extra_args = {"model_options": extra_model_options, "seed": seed}
|
||||
@ -1014,6 +1075,7 @@ class CFGGuider:
|
||||
sampler,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, extra_args["model_options"], is_model_options=True)
|
||||
)
|
||||
print("GP_DEBUG: before executor.execute", flush=True)
|
||||
samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
||||
return self.inner_model.process_latent_out(samples.to(torch.float32))
|
||||
|
||||
|
||||
28
execution.py
28
execution.py
@ -696,6 +696,24 @@ class PromptExecutor:
|
||||
except Exception:
|
||||
logging.debug("][ EX:flush_running_extensions_transport_state failed", exc_info=True)
|
||||
|
||||
async def _wait_model_patcher_quiescence_safe(
    self,
    *,
    fail_loud: bool = False,
    timeout_ms: int = 120000,
    marker: str = "EX:wait_model_patcher_idle",
) -> None:
    """Await ``comfy.isolation.wait_for_model_patcher_quiescence`` defensively.

    The helper is imported lazily inside the guarded region, so a failed
    import is handled exactly like a failure of the wait itself.

    Args:
        fail_loud: When true, any exception is re-raised to the caller (the
            flag is also forwarded to the wait helper). When false, the
            exception is logged at debug level and suppressed.
        timeout_ms: Forwarded unchanged to the wait helper.
        marker: Forwarded unchanged to the wait helper.

    Raises:
        Exception: Whatever the import or the wait raised, only when
            ``fail_loud`` is true.
    """
    try:
        from comfy.isolation import wait_for_model_patcher_quiescence as wait_idle

        await wait_idle(marker=marker, fail_loud=fail_loud, timeout_ms=timeout_ms)
    except Exception:
        if not fail_loud:
            logging.debug("][ EX:wait_model_patcher_quiescence failed", exc_info=True)
            return
        raise
|
||||
|
||||
def add_message(self, event, data: dict, broadcast: bool):
|
||||
data = {
|
||||
**data,
|
||||
@ -766,6 +784,11 @@ class PromptExecutor:
|
||||
# Boundary cleanup runs at the start of the next workflow in
|
||||
# isolation mode, matching non-isolated "next prompt" timing.
|
||||
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
|
||||
await self._wait_model_patcher_quiescence_safe(
|
||||
fail_loud=False,
|
||||
timeout_ms=120000,
|
||||
marker="EX:boundary_cleanup_wait_idle",
|
||||
)
|
||||
await self._flush_running_extensions_transport_state_safe()
|
||||
comfy.model_management.unload_all_models()
|
||||
comfy.model_management.cleanup_models_gc()
|
||||
@ -806,6 +829,11 @@ class PromptExecutor:
|
||||
for node_id in execution_list.pendingNodes.keys():
|
||||
class_type = dynamic_prompt.get_node(node_id)["class_type"]
|
||||
pending_class_types.add(class_type)
|
||||
await self._wait_model_patcher_quiescence_safe(
|
||||
fail_loud=True,
|
||||
timeout_ms=120000,
|
||||
marker="EX:notify_graph_wait_idle",
|
||||
)
|
||||
await self._notify_execution_graph_safe(pending_class_types, fail_loud=True)
|
||||
|
||||
while not execution_list.is_empty():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user