isolation+dynamicvram: stabilize ModelPatcher RPC path, add diagnostics; known process_latent_in timeout remains
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run

- harden isolation ModelPatcher proxy/registry behavior for DynamicVRAM-backed patchers
- improve serializer/adapter boundaries (device/dtype/model refs) to reduce pre-inference lockups
- add structured ISO registry/modelsampling telemetry and explicit RPC timeout surfacing
- preserve isolation-first lifecycle handling and boundary cleanup sequencing
- validate isolated workflows: most targeted runs now complete under
  --use-sage-attention --use-process-isolation --disable-cuda-malloc

Known issue (reproducible):
- isolation_99_full_iso_stack still times out at SamplerCustom_ISO path
- the failure is an explicit RPC timeout:
  ModelPatcherProxy.process_latent_in(instance_id=model_0, timeout_ms=120000)
- this indicates the remaining stall is in the process_latent_in RPC path, not in generic startup or manager fetch
This commit is contained in:
John Pollock 2026-03-04 10:41:33 -06:00
parent 6ba03a0747
commit 26edd5663d
9 changed files with 927 additions and 154 deletions

View File

@ -26,6 +26,7 @@ PYISOLATE_VENV_ROOT.mkdir(parents=True, exist_ok=True)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024 _WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
_MODEL_PATCHER_IDLE_TIMEOUT_MS = 120000
def initialize_proxies() -> None: def initialize_proxies() -> None:
@ -168,6 +169,11 @@ def _get_class_types_for_extension(extension_name: str) -> Set[str]:
async def notify_execution_graph(needed_class_types: Set[str]) -> None: async def notify_execution_graph(needed_class_types: Set[str]) -> None:
"""Evict running extensions not needed for current execution.""" """Evict running extensions not needed for current execution."""
await wait_for_model_patcher_quiescence(
timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS,
fail_loud=True,
marker="ISO:notify_graph_wait_idle",
)
async def _stop_extension( async def _stop_extension(
ext_name: str, extension: "ComfyNodeExtension", reason: str ext_name: str, extension: "ComfyNodeExtension", reason: str
@ -237,6 +243,11 @@ async def notify_execution_graph(needed_class_types: Set[str]) -> None:
async def flush_running_extensions_transport_state() -> int: async def flush_running_extensions_transport_state() -> int:
await wait_for_model_patcher_quiescence(
timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS,
fail_loud=True,
marker="ISO:flush_transport_wait_idle",
)
total_flushed = 0 total_flushed = 0
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()): for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
flush_fn = getattr(extension, "flush_transport_state", None) flush_fn = getattr(extension, "flush_transport_state", None)
@ -263,6 +274,50 @@ async def flush_running_extensions_transport_state() -> int:
return total_flushed return total_flushed
async def wait_for_model_patcher_quiescence(
    timeout_ms: int = _MODEL_PATCHER_IDLE_TIMEOUT_MS,
    *,
    fail_loud: bool = False,
    marker: str = "ISO:wait_model_patcher_idle",
) -> bool:
    """Wait until the ModelPatcher registry reports no in-flight operations.

    Returns True when the registry went idle within ``timeout_ms``. On
    timeout, the per-instance operation states are logged for diagnostics;
    with ``fail_loud`` a ``TimeoutError`` is raised instead of returning
    False, and unexpected errors propagate rather than being swallowed.
    """
    try:
        from comfy.isolation.model_patcher_proxy_registry import ModelPatcherRegistry

        patcher_registry = ModelPatcherRegistry()
        started = time.perf_counter()
        became_idle = await patcher_registry.wait_all_idle(timeout_ms)
        waited_ms = (time.perf_counter() - started) * 1000.0

        if became_idle:
            logger.debug(
                "%s %s idle=1 timeout_ms=%d elapsed_ms=%.3f",
                LOG_PREFIX, marker, timeout_ms, waited_ms,
            )
            return True

        # Timed out: surface which instances/methods are still busy.
        busy_states = await patcher_registry.get_all_operation_states()
        logger.error(
            "%s %s idle_timeout timeout_ms=%d elapsed_ms=%.3f states=%s",
            LOG_PREFIX, marker, timeout_ms, waited_ms, busy_states,
        )
        if fail_loud:
            raise TimeoutError(
                f"ModelPatcherRegistry did not quiesce within {timeout_ms} ms"
            )
        return False
    except Exception:
        # Best-effort by default: quiescence checking must not break callers.
        if fail_loud:
            raise
        logger.debug("%s %s failed", LOG_PREFIX, marker, exc_info=True)
        return False
def get_claimed_paths() -> Set[Path]: def get_claimed_paths() -> Set[Path]:
return _CLAIMED_PATHS return _CLAIMED_PATHS
@ -320,6 +375,7 @@ __all__ = [
"await_isolation_loading", "await_isolation_loading",
"notify_execution_graph", "notify_execution_graph",
"flush_running_extensions_transport_state", "flush_running_extensions_transport_state",
"wait_for_model_patcher_quiescence",
"get_claimed_paths", "get_claimed_paths",
"update_rpc_event_loops", "update_rpc_event_loops",
"IsolatedNodeSpec", "IsolatedNodeSpec",

View File

@ -83,6 +83,33 @@ class ComfyUIAdapter(IsolationAdapter):
logging.getLogger(pkg_name).setLevel(logging.ERROR) logging.getLogger(pkg_name).setLevel(logging.ERROR)
def register_serializers(self, registry: SerializerRegistryProtocol) -> None: def register_serializers(self, registry: SerializerRegistryProtocol) -> None:
import torch
def serialize_device(obj: Any) -> Dict[str, Any]:
return {"__type__": "device", "device_str": str(obj)}
def deserialize_device(data: Dict[str, Any]) -> Any:
return torch.device(data["device_str"])
registry.register("device", serialize_device, deserialize_device)
_VALID_DTYPES = {
"float16", "float32", "float64", "bfloat16",
"int8", "int16", "int32", "int64",
"uint8", "bool",
}
def serialize_dtype(obj: Any) -> Dict[str, Any]:
return {"__type__": "dtype", "dtype_str": str(obj)}
def deserialize_dtype(data: Dict[str, Any]) -> Any:
dtype_name = data["dtype_str"].replace("torch.", "")
if dtype_name not in _VALID_DTYPES:
raise ValueError(f"Invalid dtype: {data['dtype_str']}")
return getattr(torch, dtype_name)
registry.register("dtype", serialize_dtype, deserialize_dtype)
def serialize_model_patcher(obj: Any) -> Dict[str, Any]: def serialize_model_patcher(obj: Any) -> Dict[str, Any]:
# Child-side: must already have _instance_id (proxy) # Child-side: must already have _instance_id (proxy)
if os.environ.get("PYISOLATE_CHILD") == "1": if os.environ.get("PYISOLATE_CHILD") == "1":
@ -193,6 +220,10 @@ class ComfyUIAdapter(IsolationAdapter):
f"ModelSampling in child lacks _instance_id: " f"ModelSampling in child lacks _instance_id: "
f"{type(obj).__module__}.{type(obj).__name__}" f"{type(obj).__module__}.{type(obj).__name__}"
) )
# Host-side pass-through for proxies: do not re-register a proxy as a
# new ModelSamplingRef, or we create proxy-of-proxy indirection.
if hasattr(obj, "_instance_id"):
return {"__type__": "ModelSamplingRef", "ms_id": obj._instance_id}
# Host-side: register with ModelSamplingRegistry and return JSON-safe dict # Host-side: register with ModelSamplingRegistry and return JSON-safe dict
ms_id = ModelSamplingRegistry().register(obj) ms_id = ModelSamplingRegistry().register(obj)
return {"__type__": "ModelSamplingRef", "ms_id": ms_id} return {"__type__": "ModelSamplingRef", "ms_id": ms_id}
@ -211,22 +242,21 @@ class ComfyUIAdapter(IsolationAdapter):
else: else:
return ModelSamplingRegistry()._get_instance(data["ms_id"]) return ModelSamplingRegistry()._get_instance(data["ms_id"])
# Register ModelSampling type and proxy # Register all ModelSampling* and StableCascadeSampling classes dynamically
registry.register( import comfy.model_sampling
"ModelSamplingDiscrete",
serialize_model_sampling, for ms_cls in vars(comfy.model_sampling).values():
deserialize_model_sampling, if not isinstance(ms_cls, type):
) continue
registry.register( if not issubclass(ms_cls, torch.nn.Module):
"ModelSamplingContinuousEDM", continue
serialize_model_sampling, if not (ms_cls.__name__.startswith("ModelSampling") or ms_cls.__name__ == "StableCascadeSampling"):
deserialize_model_sampling, continue
) registry.register(
registry.register( ms_cls.__name__,
"ModelSamplingContinuousV", serialize_model_sampling,
serialize_model_sampling, deserialize_model_sampling,
deserialize_model_sampling, )
)
registry.register( registry.register(
"ModelSamplingProxy", serialize_model_sampling, deserialize_model_sampling "ModelSamplingProxy", serialize_model_sampling, deserialize_model_sampling
) )

View File

@ -382,6 +382,10 @@ class ComfyNodeExtension(ExtensionBase):
if type(result).__name__ == "NodeOutput": if type(result).__name__ == "NodeOutput":
result = result.args result = result.args
print(
f"{LOG_PREFIX} ISO:child_result_ready node={node_name} type={type(result).__name__}",
flush=True,
)
if self._is_comfy_protocol_return(result): if self._is_comfy_protocol_return(result):
logger.debug( logger.debug(
"%s ISO:child_execute_done ext=%s node=%s protocol_return=1", "%s ISO:child_execute_done ext=%s node=%s protocol_return=1",
@ -389,10 +393,17 @@ class ComfyNodeExtension(ExtensionBase):
getattr(self, "name", "?"), getattr(self, "name", "?"),
node_name, node_name,
) )
return self._wrap_unpicklable_objects(result) print(f"{LOG_PREFIX} ISO:child_wrap_start node={node_name} protocol=1", flush=True)
wrapped = self._wrap_unpicklable_objects(result)
print(f"{LOG_PREFIX} ISO:child_wrap_done node={node_name} protocol=1", flush=True)
return wrapped
if not isinstance(result, tuple): if not isinstance(result, tuple):
result = (result,) result = (result,)
print(
f"{LOG_PREFIX} ISO:child_result_tuple node={node_name} outputs={len(result)}",
flush=True,
)
logger.debug( logger.debug(
"%s ISO:child_execute_done ext=%s node=%s protocol_return=0 outputs=%d", "%s ISO:child_execute_done ext=%s node=%s protocol_return=0 outputs=%d",
LOG_PREFIX, LOG_PREFIX,
@ -400,7 +411,10 @@ class ComfyNodeExtension(ExtensionBase):
node_name, node_name,
len(result), len(result),
) )
return self._wrap_unpicklable_objects(result) print(f"{LOG_PREFIX} ISO:child_wrap_start node={node_name} protocol=0", flush=True)
wrapped = self._wrap_unpicklable_objects(result)
print(f"{LOG_PREFIX} ISO:child_wrap_done node={node_name} protocol=0", flush=True)
return wrapped
async def flush_transport_state(self) -> int: async def flush_transport_state(self) -> int:
if os.environ.get("PYISOLATE_ISOLATION_ACTIVE") != "1": if os.environ.get("PYISOLATE_ISOLATION_ACTIVE") != "1":
@ -443,7 +457,10 @@ class ComfyNodeExtension(ExtensionBase):
if isinstance(data, (str, int, float, bool, type(None))): if isinstance(data, (str, int, float, bool, type(None))):
return data return data
if isinstance(data, torch.Tensor): if isinstance(data, torch.Tensor):
return data.detach() if data.requires_grad else data tensor = data.detach() if data.requires_grad else data
if os.environ.get("PYISOLATE_CHILD") == "1" and tensor.device.type != "cpu":
return tensor.cpu()
return tensor
# Special-case clip vision outputs: preserve attribute access by packing fields # Special-case clip vision outputs: preserve attribute access by packing fields
if hasattr(data, "penultimate_hidden_states") or hasattr( if hasattr(data, "penultimate_hidden_states") or hasattr(

View File

@ -338,6 +338,34 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
def apply_model(self, *args, **kwargs) -> Any: def apply_model(self, *args, **kwargs) -> Any:
import torch import torch
def _preferred_device() -> Any:
for value in args:
if isinstance(value, torch.Tensor):
return value.device
for value in kwargs.values():
if isinstance(value, torch.Tensor):
return value.device
return None
def _move_result_to_device(obj: Any, device: Any) -> Any:
if device is None:
return obj
if isinstance(obj, torch.Tensor):
return obj.to(device) if obj.device != device else obj
if isinstance(obj, dict):
return {k: _move_result_to_device(v, device) for k, v in obj.items()}
if isinstance(obj, list):
return [_move_result_to_device(v, device) for v in obj]
if isinstance(obj, tuple):
return tuple(_move_result_to_device(v, device) for v in obj)
return obj
# DynamicVRAM models must keep load/offload decisions in host process.
# Child-side CUDA staging here can deadlock before first inference RPC.
if self.is_dynamic():
out = self._call_rpc("inner_model_apply_model", args, kwargs)
return _move_result_to_device(out, _preferred_device())
required_bytes = self._cpu_tensor_bytes(args) + self._cpu_tensor_bytes(kwargs) required_bytes = self._cpu_tensor_bytes(args) + self._cpu_tensor_bytes(kwargs)
self._ensure_apply_model_headroom(required_bytes) self._ensure_apply_model_headroom(required_bytes)
@ -360,7 +388,8 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
args_cuda = _to_cuda(args) args_cuda = _to_cuda(args)
kwargs_cuda = _to_cuda(kwargs) kwargs_cuda = _to_cuda(kwargs)
return self._call_rpc("inner_model_apply_model", args_cuda, kwargs_cuda) out = self._call_rpc("inner_model_apply_model", args_cuda, kwargs_cuda)
return _move_result_to_device(out, _preferred_device())
def model_state_dict(self, filter_prefix: Optional[str] = None) -> Any: def model_state_dict(self, filter_prefix: Optional[str] = None) -> Any:
keys = self._call_rpc("model_state_dict", filter_prefix) keys = self._call_rpc("model_state_dict", filter_prefix)
@ -526,6 +555,13 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
def memory_required(self, input_shape: Any) -> Any: def memory_required(self, input_shape: Any) -> Any:
return self._call_rpc("memory_required", input_shape) return self._call_rpc("memory_required", input_shape)
def get_operation_state(self) -> Dict[str, Any]:
state = self._call_rpc("get_operation_state")
return state if isinstance(state, dict) else {}
def wait_for_idle(self, timeout_ms: int = 0) -> bool:
return bool(self._call_rpc("wait_for_idle", timeout_ms))
def is_dynamic(self) -> bool: def is_dynamic(self) -> bool:
return bool(self._call_rpc("is_dynamic")) return bool(self._call_rpc("is_dynamic"))
@ -771,6 +807,7 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
class _InnerModelProxy: class _InnerModelProxy:
def __init__(self, parent: ModelPatcherProxy): def __init__(self, parent: ModelPatcherProxy):
self._parent = parent self._parent = parent
self._model_sampling = None
def __getattr__(self, name: str) -> Any: def __getattr__(self, name: str) -> Any:
if name.startswith("_"): if name.startswith("_"):
@ -793,7 +830,11 @@ class _InnerModelProxy:
manage_lifecycle=False, manage_lifecycle=False,
) )
if name == "model_sampling": if name == "model_sampling":
return self._parent._call_rpc("get_model_object", "model_sampling") if self._model_sampling is None:
self._model_sampling = self._parent._call_rpc(
"get_model_object", "model_sampling"
)
return self._model_sampling
if name == "extra_conds_shapes": if name == "extra_conds_shapes":
return lambda *a, **k: self._parent._call_rpc( return lambda *a, **k: self._parent._call_rpc(
"inner_model_extra_conds_shapes", a, k "inner_model_extra_conds_shapes", a, k

View File

@ -2,8 +2,12 @@
# RPC server for ModelPatcher isolation (child process) # RPC server for ModelPatcher isolation (child process)
from __future__ import annotations from __future__ import annotations
import asyncio
import gc import gc
import logging import logging
import threading
import time
from dataclasses import dataclass, field
from typing import Any, Optional, List from typing import Any, Optional, List
try: try:
@ -43,12 +47,191 @@ from comfy.isolation.proxies.base import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@dataclass
class _OperationState:
    """Per-instance bookkeeping for in-flight registry operations.

    Used by the begin/end operation hooks for quiescence waits
    (wait_for_idle / wait_all_idle) and for telemetry snapshots.
    """

    # Serializes heavyweight operations for one instance (held for the
    # duration of _run_operation_with_lease).
    lease: threading.Lock = field(default_factory=threading.Lock)
    # Number of operations currently in flight for this instance.
    active_count: int = 0
    # Method name -> count of in-flight calls of that method.
    active_by_method: dict[str, int] = field(default_factory=dict)
    # Lifetime count of operations ever started for this instance.
    total_operations: int = 0
    # Telemetry about the most recently started/finished operation.
    last_method: Optional[str] = None
    last_started_ts: Optional[float] = None
    last_ended_ts: Optional[float] = None
    last_elapsed_ms: Optional[float] = None
    # repr() of the last failure, or None if the last operation succeeded.
    last_error: Optional[str] = None
    last_thread_id: Optional[int] = None
    # id() of the running asyncio loop at operation start, if any.
    last_loop_id: Optional[int] = None
class ModelPatcherRegistry(BaseRegistry[Any]): class ModelPatcherRegistry(BaseRegistry[Any]):
_type_prefix = "model" _type_prefix = "model"
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self._pending_cleanup_ids: set[str] = set() self._pending_cleanup_ids: set[str] = set()
self._operation_states: dict[str, _OperationState] = {}
self._operation_state_cv = threading.Condition(self._lock)
def _get_or_create_operation_state(self, instance_id: str) -> _OperationState:
state = self._operation_states.get(instance_id)
if state is None:
state = _OperationState()
self._operation_states[instance_id] = state
return state
def _begin_operation(self, instance_id: str, method_name: str) -> tuple[float, float]:
start_epoch = time.time()
start_perf = time.perf_counter()
with self._operation_state_cv:
state = self._get_or_create_operation_state(instance_id)
state.active_count += 1
state.active_by_method[method_name] = (
state.active_by_method.get(method_name, 0) + 1
)
state.total_operations += 1
state.last_method = method_name
state.last_started_ts = start_epoch
state.last_thread_id = threading.get_ident()
try:
state.last_loop_id = id(asyncio.get_running_loop())
except RuntimeError:
state.last_loop_id = None
logger.debug(
"ISO:registry_op_start instance_id=%s method=%s start_ts=%.6f thread=%s loop=%s",
instance_id,
method_name,
start_epoch,
threading.get_ident(),
state.last_loop_id,
)
return start_epoch, start_perf
def _end_operation(
self,
instance_id: str,
method_name: str,
start_perf: float,
error: Optional[BaseException] = None,
) -> None:
end_epoch = time.time()
elapsed_ms = (time.perf_counter() - start_perf) * 1000.0
with self._operation_state_cv:
state = self._get_or_create_operation_state(instance_id)
state.active_count = max(0, state.active_count - 1)
if method_name in state.active_by_method:
remaining = state.active_by_method[method_name] - 1
if remaining <= 0:
state.active_by_method.pop(method_name, None)
else:
state.active_by_method[method_name] = remaining
state.last_ended_ts = end_epoch
state.last_elapsed_ms = elapsed_ms
state.last_error = None if error is None else repr(error)
if state.active_count == 0:
self._operation_state_cv.notify_all()
logger.debug(
"ISO:registry_op_end instance_id=%s method=%s end_ts=%.6f elapsed_ms=%.3f error=%s",
instance_id,
method_name,
end_epoch,
elapsed_ms,
None if error is None else type(error).__name__,
)
def _run_operation_with_lease(self, instance_id: str, method_name: str, fn):
with self._operation_state_cv:
state = self._get_or_create_operation_state(instance_id)
lease = state.lease
with lease:
_, start_perf = self._begin_operation(instance_id, method_name)
try:
result = fn()
except Exception as exc:
self._end_operation(instance_id, method_name, start_perf, error=exc)
raise
self._end_operation(instance_id, method_name, start_perf)
return result
def _snapshot_operation_state(self, instance_id: str) -> dict[str, Any]:
with self._operation_state_cv:
state = self._operation_states.get(instance_id)
if state is None:
return {
"instance_id": instance_id,
"active_count": 0,
"active_methods": [],
"total_operations": 0,
"last_method": None,
"last_started_ts": None,
"last_ended_ts": None,
"last_elapsed_ms": None,
"last_error": None,
"last_thread_id": None,
"last_loop_id": None,
}
return {
"instance_id": instance_id,
"active_count": state.active_count,
"active_methods": sorted(state.active_by_method.keys()),
"total_operations": state.total_operations,
"last_method": state.last_method,
"last_started_ts": state.last_started_ts,
"last_ended_ts": state.last_ended_ts,
"last_elapsed_ms": state.last_elapsed_ms,
"last_error": state.last_error,
"last_thread_id": state.last_thread_id,
"last_loop_id": state.last_loop_id,
}
def unregister_sync(self, instance_id: str) -> None:
with self._operation_state_cv:
instance = self._registry.pop(instance_id, None)
if instance is not None:
self._id_map.pop(id(instance), None)
self._pending_cleanup_ids.discard(instance_id)
self._operation_states.pop(instance_id, None)
self._operation_state_cv.notify_all()
    async def get_operation_state(self, instance_id: str) -> dict[str, Any]:
        """RPC handler: telemetry snapshot for a single ModelPatcher instance."""
        return self._snapshot_operation_state(instance_id)
async def get_all_operation_states(self) -> dict[str, dict[str, Any]]:
with self._operation_state_cv:
ids = sorted(self._operation_states.keys())
return {instance_id: self._snapshot_operation_state(instance_id) for instance_id in ids}
    async def wait_for_idle(self, instance_id: str, timeout_ms: int = 0) -> bool:
        """RPC handler: wait until *instance_id* has no in-flight operations.

        ``timeout_ms <= 0`` waits indefinitely. Returns True when the
        instance is idle (or unknown — no record means no operations),
        False when the deadline expires first.

        NOTE(review): this async method blocks on threading.Condition.wait,
        which stalls the event loop of the thread it runs on for the whole
        wait — confirm the RPC server dispatches these handlers off the main
        loop, otherwise this wait can itself manifest as RPC timeouts.
        """
        timeout_s = None if timeout_ms <= 0 else (timeout_ms / 1000.0)
        deadline = None if timeout_s is None else (time.monotonic() + timeout_s)
        with self._operation_state_cv:
            while True:
                # Missing record => never had (or no longer has) operations.
                active = self._operation_states.get(instance_id)
                if active is None or active.active_count == 0:
                    return True
                if deadline is None:
                    # Unbounded wait; woken by _end_operation/unregister_sync.
                    self._operation_state_cv.wait()
                    continue
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    return False
                self._operation_state_cv.wait(timeout=remaining)
    async def wait_all_idle(self, timeout_ms: int = 0) -> bool:
        """RPC handler: wait until no tracked instance has in-flight operations.

        ``timeout_ms <= 0`` waits indefinitely. Returns True once every
        instance is idle, False when the deadline expires first.

        NOTE(review): like wait_for_idle, this blocks the running event loop
        via threading.Condition.wait while awaited — confirm the serving
        threading model before relying on long timeouts here.
        """
        timeout_s = None if timeout_ms <= 0 else (timeout_ms / 1000.0)
        deadline = None if timeout_s is None else (time.monotonic() + timeout_s)
        with self._operation_state_cv:
            while True:
                has_active = any(
                    state.active_count > 0 for state in self._operation_states.values()
                )
                if not has_active:
                    return True
                if deadline is None:
                    # Unbounded wait; woken whenever an instance goes idle.
                    self._operation_state_cv.wait()
                    continue
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    return False
                self._operation_state_cv.wait(timeout=remaining)
async def clone(self, instance_id: str) -> str: async def clone(self, instance_id: str) -> str:
instance = self._get_instance(instance_id) instance = self._get_instance(instance_id)
@ -62,8 +245,10 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
return False return False
async def get_model_object(self, instance_id: str, name: str) -> Any: async def get_model_object(self, instance_id: str, name: str) -> Any:
print(f"GP_DEBUG: get_model_object START for name={name}", flush=True)
instance = self._get_instance(instance_id) instance = self._get_instance(instance_id)
if name == "model": if name == "model":
print(f"GP_DEBUG: get_model_object END for name={name} (ModelObject)", flush=True)
return f"<ModelObject: {type(instance.model).__name__}>" return f"<ModelObject: {type(instance.model).__name__}>"
result = instance.get_model_object(name) result = instance.get_model_object(name)
if name == "model_sampling": if name == "model_sampling":
@ -73,8 +258,16 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
) )
registry = ModelSamplingRegistry() registry = ModelSamplingRegistry()
sampling_id = registry.register(result) # Preserve identity when upstream already returned a proxy. Re-registering
# a proxy object creates proxy-of-proxy call chains.
if isinstance(result, ModelSamplingProxy):
sampling_id = result._instance_id
else:
sampling_id = registry.register(result)
print(f"GP_DEBUG: get_model_object END for name={name} (model_sampling)", flush=True)
return ModelSamplingProxy(sampling_id, registry) return ModelSamplingProxy(sampling_id, registry)
print(f"GP_DEBUG: get_model_object END for name={name} (fallthrough)", flush=True)
return detach_if_grad(result) return detach_if_grad(result)
async def get_model_options(self, instance_id: str) -> dict: async def get_model_options(self, instance_id: str) -> dict:
@ -163,7 +356,11 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
return self._get_instance(instance_id).lowvram_patch_counter() return self._get_instance(instance_id).lowvram_patch_counter()
async def memory_required(self, instance_id: str, input_shape: Any) -> Any: async def memory_required(self, instance_id: str, input_shape: Any) -> Any:
return self._get_instance(instance_id).memory_required(input_shape) return self._run_operation_with_lease(
instance_id,
"memory_required",
lambda: self._get_instance(instance_id).memory_required(input_shape),
)
async def is_dynamic(self, instance_id: str) -> bool: async def is_dynamic(self, instance_id: str) -> bool:
instance = self._get_instance(instance_id) instance = self._get_instance(instance_id)
@ -186,7 +383,11 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
return None return None
async def model_dtype(self, instance_id: str) -> Any: async def model_dtype(self, instance_id: str) -> Any:
return self._get_instance(instance_id).model_dtype() return self._run_operation_with_lease(
instance_id,
"model_dtype",
lambda: self._get_instance(instance_id).model_dtype(),
)
async def model_patches_to(self, instance_id: str, device: Any) -> Any: async def model_patches_to(self, instance_id: str, device: Any) -> Any:
return self._get_instance(instance_id).model_patches_to(device) return self._get_instance(instance_id).model_patches_to(device)
@ -198,8 +399,12 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
extra_memory: Any, extra_memory: Any,
force_patch_weights: bool = False, force_patch_weights: bool = False,
) -> Any: ) -> Any:
return self._get_instance(instance_id).partially_load( return self._run_operation_with_lease(
device, extra_memory, force_patch_weights=force_patch_weights instance_id,
"partially_load",
lambda: self._get_instance(instance_id).partially_load(
device, extra_memory, force_patch_weights=force_patch_weights
),
) )
async def partially_unload( async def partially_unload(
@ -209,8 +414,12 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
memory_to_free: int = 0, memory_to_free: int = 0,
force_patch_weights: bool = False, force_patch_weights: bool = False,
) -> int: ) -> int:
return self._get_instance(instance_id).partially_unload( return self._run_operation_with_lease(
device_to, memory_to_free, force_patch_weights instance_id,
"partially_unload",
lambda: self._get_instance(instance_id).partially_unload(
device_to, memory_to_free, force_patch_weights
),
) )
async def load( async def load(
@ -221,8 +430,12 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
force_patch_weights: bool = False, force_patch_weights: bool = False,
full_load: bool = False, full_load: bool = False,
) -> None: ) -> None:
self._get_instance(instance_id).load( self._run_operation_with_lease(
device_to, lowvram_model_memory, force_patch_weights, full_load instance_id,
"load",
lambda: self._get_instance(instance_id).load(
device_to, lowvram_model_memory, force_patch_weights, full_load
),
) )
async def patch_model( async def patch_model(
@ -233,20 +446,29 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
load_weights: bool = True, load_weights: bool = True,
force_patch_weights: bool = False, force_patch_weights: bool = False,
) -> None: ) -> None:
try: def _invoke() -> None:
self._get_instance(instance_id).patch_model( try:
device_to, lowvram_model_memory, load_weights, force_patch_weights self._get_instance(instance_id).patch_model(
) device_to, lowvram_model_memory, load_weights, force_patch_weights
except AttributeError as e: )
logger.error( except AttributeError as e:
f"Isolation Error: Failed to patch model attribute: {e}. Skipping." logger.error(
) f"Isolation Error: Failed to patch model attribute: {e}. Skipping."
return )
return
self._run_operation_with_lease(instance_id, "patch_model", _invoke)
async def unpatch_model( async def unpatch_model(
self, instance_id: str, device_to: Any = None, unpatch_weights: bool = True self, instance_id: str, device_to: Any = None, unpatch_weights: bool = True
) -> None: ) -> None:
self._get_instance(instance_id).unpatch_model(device_to, unpatch_weights) self._run_operation_with_lease(
instance_id,
"unpatch_model",
lambda: self._get_instance(instance_id).unpatch_model(
device_to, unpatch_weights
),
)
async def detach(self, instance_id: str, unpatch_all: bool = True) -> None: async def detach(self, instance_id: str, unpatch_all: bool = True) -> None:
self._get_instance(instance_id).detach(unpatch_all) self._get_instance(instance_id).detach(unpatch_all)
@ -262,26 +484,29 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
self._get_instance(instance_id).pre_run() self._get_instance(instance_id).pre_run()
async def cleanup(self, instance_id: str) -> None: async def cleanup(self, instance_id: str) -> None:
try: def _invoke() -> None:
instance = self._get_instance(instance_id) try:
except Exception: instance = self._get_instance(instance_id)
logger.debug( except Exception:
"ModelPatcher cleanup requested for missing instance %s", logger.debug(
instance_id, "ModelPatcher cleanup requested for missing instance %s",
exc_info=True, instance_id,
) exc_info=True,
return )
return
try: try:
instance.cleanup() instance.cleanup()
finally: finally:
with self._lock: with self._lock:
self._pending_cleanup_ids.add(instance_id) self._pending_cleanup_ids.add(instance_id)
gc.collect() gc.collect()
self._run_operation_with_lease(instance_id, "cleanup", _invoke)
def sweep_pending_cleanup(self) -> int: def sweep_pending_cleanup(self) -> int:
removed = 0 removed = 0
with self._lock: with self._operation_state_cv:
pending_ids = list(self._pending_cleanup_ids) pending_ids = list(self._pending_cleanup_ids)
self._pending_cleanup_ids.clear() self._pending_cleanup_ids.clear()
for instance_id in pending_ids: for instance_id in pending_ids:
@ -289,17 +514,21 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
if instance is None: if instance is None:
continue continue
self._id_map.pop(id(instance), None) self._id_map.pop(id(instance), None)
self._operation_states.pop(instance_id, None)
removed += 1 removed += 1
self._operation_state_cv.notify_all()
gc.collect() gc.collect()
return removed return removed
def purge_all(self) -> int: def purge_all(self) -> int:
with self._lock: with self._operation_state_cv:
removed = len(self._registry) removed = len(self._registry)
self._registry.clear() self._registry.clear()
self._id_map.clear() self._id_map.clear()
self._pending_cleanup_ids.clear() self._pending_cleanup_ids.clear()
self._operation_states.clear()
self._operation_state_cv.notify_all()
gc.collect() gc.collect()
return removed return removed
@ -743,17 +972,52 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
async def inner_model_memory_required( async def inner_model_memory_required(
self, instance_id: str, args: tuple, kwargs: dict self, instance_id: str, args: tuple, kwargs: dict
) -> Any: ) -> Any:
return self._get_instance(instance_id).model.memory_required(*args, **kwargs) return self._run_operation_with_lease(
instance_id,
"inner_model_memory_required",
lambda: self._get_instance(instance_id).model.memory_required(
*args, **kwargs
),
)
async def inner_model_extra_conds_shapes( async def inner_model_extra_conds_shapes(
self, instance_id: str, args: tuple, kwargs: dict self, instance_id: str, args: tuple, kwargs: dict
) -> Any: ) -> Any:
return self._get_instance(instance_id).model.extra_conds_shapes(*args, **kwargs) return self._run_operation_with_lease(
instance_id,
"inner_model_extra_conds_shapes",
lambda: self._get_instance(instance_id).model.extra_conds_shapes(
*args, **kwargs
),
)
async def inner_model_extra_conds( async def inner_model_extra_conds(
self, instance_id: str, args: tuple, kwargs: dict self, instance_id: str, args: tuple, kwargs: dict
) -> Any: ) -> Any:
return self._get_instance(instance_id).model.extra_conds(*args, **kwargs) def _invoke() -> Any:
result = self._get_instance(instance_id).model.extra_conds(*args, **kwargs)
try:
import torch
import comfy.conds
except Exception:
return result
def _to_cpu(obj: Any) -> Any:
if torch.is_tensor(obj):
return obj.detach().cpu() if obj.device.type != "cpu" else obj
if isinstance(obj, dict):
return {k: _to_cpu(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_to_cpu(v) for v in obj]
if isinstance(obj, tuple):
return tuple(_to_cpu(v) for v in obj)
if isinstance(obj, comfy.conds.CONDRegular):
return type(obj)(_to_cpu(obj.cond))
return obj
return _to_cpu(result)
return self._run_operation_with_lease(instance_id, "inner_model_extra_conds", _invoke)
async def inner_model_state_dict( async def inner_model_state_dict(
self, instance_id: str, args: tuple, kwargs: dict self, instance_id: str, args: tuple, kwargs: dict
@ -767,82 +1031,160 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
async def inner_model_apply_model( async def inner_model_apply_model(
self, instance_id: str, args: tuple, kwargs: dict self, instance_id: str, args: tuple, kwargs: dict
) -> Any: ) -> Any:
instance = self._get_instance(instance_id) def _invoke() -> Any:
target = getattr(instance, "load_device", None) import torch
if target is None and args and hasattr(args[0], "device"):
target = args[0].device
elif target is None:
for v in kwargs.values():
if hasattr(v, "device"):
target = v.device
break
def _move(obj): instance = self._get_instance(instance_id)
if target is None: target = getattr(instance, "load_device", None)
if target is None and args and hasattr(args[0], "device"):
target = args[0].device
elif target is None:
for v in kwargs.values():
if hasattr(v, "device"):
target = v.device
break
def _move(obj):
if target is None:
return obj
if isinstance(obj, (tuple, list)):
return type(obj)(_move(o) for o in obj)
if hasattr(obj, "to"):
return obj.to(target)
return obj return obj
if isinstance(obj, (tuple, list)):
return type(obj)(_move(o) for o in obj)
if hasattr(obj, "to"):
return obj.to(target)
return obj
moved_args = tuple(_move(a) for a in args) moved_args = tuple(_move(a) for a in args)
moved_kwargs = {k: _move(v) for k, v in kwargs.items()} moved_kwargs = {k: _move(v) for k, v in kwargs.items()}
result = instance.model.apply_model(*moved_args, **moved_kwargs) result = instance.model.apply_model(*moved_args, **moved_kwargs)
return detach_if_grad(_move(result)) moved_result = detach_if_grad(_move(result))
# DynamicVRAM + isolation: returning CUDA tensors across RPC can stall
# at the transport boundary. Marshal dynamic-path results as CPU and let
# the proxy restore device placement in the child process.
is_dynamic_fn = getattr(instance, "is_dynamic", None)
if callable(is_dynamic_fn) and is_dynamic_fn():
def _to_cpu(obj: Any) -> Any:
if torch.is_tensor(obj):
return obj.detach().cpu() if obj.device.type != "cpu" else obj
if isinstance(obj, dict):
return {k: _to_cpu(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_to_cpu(v) for v in obj]
if isinstance(obj, tuple):
return tuple(_to_cpu(v) for v in obj)
return obj
return _to_cpu(moved_result)
return moved_result
return self._run_operation_with_lease(instance_id, "inner_model_apply_model", _invoke)
async def process_latent_in( async def process_latent_in(
self, instance_id: str, args: tuple, kwargs: dict self, instance_id: str, args: tuple, kwargs: dict
) -> Any: ) -> Any:
return detach_if_grad( return self._run_operation_with_lease(
self._get_instance(instance_id).model.process_latent_in(*args, **kwargs) instance_id,
"process_latent_in",
lambda: detach_if_grad(
self._get_instance(instance_id).model.process_latent_in(
*args, **kwargs
)
),
) )
async def process_latent_out( async def process_latent_out(
self, instance_id: str, args: tuple, kwargs: dict self, instance_id: str, args: tuple, kwargs: dict
) -> Any: ) -> Any:
instance = self._get_instance(instance_id) import torch
result = instance.model.process_latent_out(*args, **kwargs)
try: def _invoke() -> Any:
target = None instance = self._get_instance(instance_id)
if args and hasattr(args[0], "device"): result = instance.model.process_latent_out(*args, **kwargs)
target = args[0].device moved_result = None
elif kwargs: try:
for v in kwargs.values(): target = None
if hasattr(v, "device"): if args and hasattr(args[0], "device"):
target = v.device target = args[0].device
break elif kwargs:
if target is not None and hasattr(result, "to"): for v in kwargs.values():
return detach_if_grad(result.to(target)) if hasattr(v, "device"):
except Exception: target = v.device
logger.debug( break
"process_latent_out: failed to move result to target device", if target is not None and hasattr(result, "to"):
exc_info=True, moved_result = detach_if_grad(result.to(target))
) except Exception:
return detach_if_grad(result) logger.debug(
"process_latent_out: failed to move result to target device",
exc_info=True,
)
if moved_result is None:
moved_result = detach_if_grad(result)
is_dynamic_fn = getattr(instance, "is_dynamic", None)
if callable(is_dynamic_fn) and is_dynamic_fn():
def _to_cpu(obj: Any) -> Any:
if torch.is_tensor(obj):
return obj.detach().cpu() if obj.device.type != "cpu" else obj
if isinstance(obj, dict):
return {k: _to_cpu(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_to_cpu(v) for v in obj]
if isinstance(obj, tuple):
return tuple(_to_cpu(v) for v in obj)
return obj
return _to_cpu(moved_result)
return moved_result
return self._run_operation_with_lease(instance_id, "process_latent_out", _invoke)
async def scale_latent_inpaint( async def scale_latent_inpaint(
self, instance_id: str, args: tuple, kwargs: dict self, instance_id: str, args: tuple, kwargs: dict
) -> Any: ) -> Any:
instance = self._get_instance(instance_id) import torch
result = instance.model.scale_latent_inpaint(*args, **kwargs)
try: def _invoke() -> Any:
target = None instance = self._get_instance(instance_id)
if args and hasattr(args[0], "device"): result = instance.model.scale_latent_inpaint(*args, **kwargs)
target = args[0].device moved_result = None
elif kwargs: try:
for v in kwargs.values(): target = None
if hasattr(v, "device"): if args and hasattr(args[0], "device"):
target = v.device target = args[0].device
break elif kwargs:
if target is not None and hasattr(result, "to"): for v in kwargs.values():
return detach_if_grad(result.to(target)) if hasattr(v, "device"):
except Exception: target = v.device
logger.debug( break
"scale_latent_inpaint: failed to move result to target device", if target is not None and hasattr(result, "to"):
exc_info=True, moved_result = detach_if_grad(result.to(target))
) except Exception:
return detach_if_grad(result) logger.debug(
"scale_latent_inpaint: failed to move result to target device",
exc_info=True,
)
if moved_result is None:
moved_result = detach_if_grad(result)
is_dynamic_fn = getattr(instance, "is_dynamic", None)
if callable(is_dynamic_fn) and is_dynamic_fn():
def _to_cpu(obj: Any) -> Any:
if torch.is_tensor(obj):
return obj.detach().cpu() if obj.device.type != "cpu" else obj
if isinstance(obj, dict):
return {k: _to_cpu(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_to_cpu(v) for v in obj]
if isinstance(obj, tuple):
return tuple(_to_cpu(v) for v in obj)
return obj
return _to_cpu(moved_result)
return moved_result
return self._run_operation_with_lease(
instance_id, "scale_latent_inpaint", _invoke
)
async def load_lora( async def load_lora(
self, self,

View File

@ -3,6 +3,9 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import os
import threading
import time
from typing import Any from typing import Any
from comfy.isolation.proxies.base import ( from comfy.isolation.proxies.base import (
@ -16,6 +19,22 @@ from comfy.isolation.proxies.base import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _describe_value(obj: Any) -> str:
try:
import torch
except Exception:
torch = None
try:
if torch is not None and isinstance(obj, torch.Tensor):
return (
"Tensor(shape=%s,dtype=%s,device=%s,id=%s)"
% (tuple(obj.shape), obj.dtype, obj.device, id(obj))
)
except Exception:
pass
return "%s(id=%s)" % (type(obj).__name__, id(obj))
def _prefer_device(*tensors: Any) -> Any: def _prefer_device(*tensors: Any) -> Any:
try: try:
import torch import torch
@ -49,6 +68,24 @@ def _to_device(obj: Any, device: Any) -> Any:
return obj return obj
def _to_cpu_for_rpc(obj: Any) -> Any:
try:
import torch
except Exception:
return obj
if isinstance(obj, torch.Tensor):
t = obj.detach() if obj.requires_grad else obj
if t.is_cuda:
return t.to("cpu")
return t
if isinstance(obj, (list, tuple)):
converted = [_to_cpu_for_rpc(x) for x in obj]
return type(obj)(converted) if isinstance(obj, tuple) else converted
if isinstance(obj, dict):
return {k: _to_cpu_for_rpc(v) for k, v in obj.items()}
return obj
class ModelSamplingRegistry(BaseRegistry[Any]): class ModelSamplingRegistry(BaseRegistry[Any]):
_type_prefix = "modelsampling" _type_prefix = "modelsampling"
@ -196,17 +233,93 @@ class ModelSamplingProxy(BaseProxy[ModelSamplingRegistry]):
return self._rpc_caller return self._rpc_caller
def _call(self, method_name: str, *args: Any) -> Any: def _call(self, method_name: str, *args: Any) -> Any:
print(
"ISO:modelsampling_call_enter method=%s instance_id=%s pid=%s"
% (method_name, self._instance_id, os.getpid()),
flush=True,
)
rpc = self._get_rpc() rpc = self._get_rpc()
method = getattr(rpc, method_name) method = getattr(rpc, method_name)
print(
"ISO:modelsampling_call_before_dispatch method=%s instance_id=%s pid=%s"
% (method_name, self._instance_id, os.getpid()),
flush=True,
)
result = method(self._instance_id, *args) result = method(self._instance_id, *args)
timeout_ms = self._rpc_timeout_ms()
start_epoch = time.time()
start_perf = time.perf_counter()
thread_id = threading.get_ident()
call_id = "%s:%s:%s:%.6f" % (
self._instance_id,
method_name,
thread_id,
start_perf,
)
logger.debug(
"ISO:modelsampling_rpc_start method=%s instance_id=%s call_id=%s start_ts=%.6f thread=%s timeout_ms=%s",
method_name,
self._instance_id,
call_id,
start_epoch,
thread_id,
timeout_ms,
)
if asyncio.iscoroutine(result): if asyncio.iscoroutine(result):
result = asyncio.wait_for(result, timeout=timeout_ms / 1000.0)
try: try:
asyncio.get_running_loop() asyncio.get_running_loop()
return run_coro_in_new_loop(result) out = run_coro_in_new_loop(result)
except RuntimeError: except RuntimeError:
loop = get_thread_loop() loop = get_thread_loop()
return loop.run_until_complete(result) out = loop.run_until_complete(result)
return result else:
out = result
print(
"ISO:modelsampling_call_after_dispatch method=%s instance_id=%s pid=%s"
% (method_name, self._instance_id, os.getpid()),
flush=True,
)
logger.debug(
"ISO:modelsampling_rpc_after_await method=%s instance_id=%s call_id=%s out=%s",
method_name,
self._instance_id,
call_id,
_describe_value(out),
)
elapsed_ms = (time.perf_counter() - start_perf) * 1000.0
logger.debug(
"ISO:modelsampling_rpc_end method=%s instance_id=%s call_id=%s elapsed_ms=%.3f thread=%s",
method_name,
self._instance_id,
call_id,
elapsed_ms,
thread_id,
)
logger.debug(
"ISO:modelsampling_rpc_return method=%s instance_id=%s call_id=%s",
method_name,
self._instance_id,
call_id,
)
print(
"ISO:modelsampling_call_return method=%s instance_id=%s pid=%s"
% (method_name, self._instance_id, os.getpid()),
flush=True,
)
return out
@staticmethod
def _rpc_timeout_ms() -> int:
raw = os.environ.get(
"COMFY_ISOLATION_MODEL_SAMPLING_RPC_TIMEOUT_MS",
os.environ.get("COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS", "30000"),
)
try:
timeout_ms = int(raw)
except ValueError:
timeout_ms = 30000
return max(1, timeout_ms)
@property @property
def sigma_min(self) -> Any: def sigma_min(self) -> Any:
@ -235,10 +348,24 @@ class ModelSamplingProxy(BaseProxy[ModelSamplingRegistry]):
def noise_scaling( def noise_scaling(
self, sigma: Any, noise: Any, latent_image: Any, max_denoise: bool = False self, sigma: Any, noise: Any, latent_image: Any, max_denoise: bool = False
) -> Any: ) -> Any:
return self._call("noise_scaling", sigma, noise, latent_image, max_denoise) preferred_device = _prefer_device(noise, latent_image)
out = self._call(
"noise_scaling",
_to_cpu_for_rpc(sigma),
_to_cpu_for_rpc(noise),
_to_cpu_for_rpc(latent_image),
max_denoise,
)
return _to_device(out, preferred_device)
def inverse_noise_scaling(self, sigma: Any, latent: Any) -> Any: def inverse_noise_scaling(self, sigma: Any, latent: Any) -> Any:
return self._call("inverse_noise_scaling", sigma, latent) preferred_device = _prefer_device(latent)
out = self._call(
"inverse_noise_scaling",
_to_cpu_for_rpc(sigma),
_to_cpu_for_rpc(latent),
)
return _to_device(out, preferred_device)
def timestep(self, sigma: Any) -> Any: def timestep(self, sigma: Any) -> Any:
return self._call("timestep", sigma) return self._call("timestep", sigma)

View File

@ -2,9 +2,11 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import concurrent.futures
import logging import logging
import os import os
import threading import threading
import time
import weakref import weakref
from typing import Any, Callable, Dict, Generic, Optional, TypeVar from typing import Any, Callable, Dict, Generic, Optional, TypeVar
@ -119,6 +121,24 @@ def set_global_loop(loop: asyncio.AbstractEventLoop) -> None:
class BaseProxy(Generic[T]): class BaseProxy(Generic[T]):
_registry_class: type = BaseRegistry # type: ignore[type-arg] _registry_class: type = BaseRegistry # type: ignore[type-arg]
__module__: str = "comfy.isolation.proxies.base" __module__: str = "comfy.isolation.proxies.base"
_TIMEOUT_RPC_METHODS = frozenset(
{
"partially_load",
"partially_unload",
"load",
"patch_model",
"unpatch_model",
"inner_model_apply_model",
"memory_required",
"model_dtype",
"inner_model_memory_required",
"inner_model_extra_conds_shapes",
"inner_model_extra_conds",
"process_latent_in",
"process_latent_out",
"scale_latent_inpaint",
}
)
def __init__( def __init__(
self, self,
@ -148,39 +168,89 @@ class BaseProxy(Generic[T]):
) )
return self._rpc_caller return self._rpc_caller
def _rpc_timeout_ms_for_method(self, method_name: str) -> Optional[int]:
if method_name not in self._TIMEOUT_RPC_METHODS:
return None
try:
timeout_ms = int(
os.environ.get("COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS", "120000")
)
except ValueError:
timeout_ms = 120000
return max(1, timeout_ms)
def _call_rpc(self, method_name: str, *args: Any, **kwargs: Any) -> Any: def _call_rpc(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
rpc = self._get_rpc() rpc = self._get_rpc()
method = getattr(rpc, method_name) method = getattr(rpc, method_name)
timeout_ms = self._rpc_timeout_ms_for_method(method_name)
coro = method(self._instance_id, *args, **kwargs) coro = method(self._instance_id, *args, **kwargs)
if timeout_ms is not None:
coro = asyncio.wait_for(coro, timeout=timeout_ms / 1000.0)
# If we have a global loop (Main Thread Loop), use it for dispatch from worker threads start_epoch = time.time()
if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running(): start_perf = time.perf_counter()
try: thread_id = threading.get_ident()
# If we are already in the global loop, we can't block on it? try:
# Actually, this method is synchronous (__getattr__ -> lambda). running_loop = asyncio.get_running_loop()
# If called from async context in main loop, we need to handle that. loop_id: Optional[int] = id(running_loop)
curr_loop = asyncio.get_running_loop() except RuntimeError:
if curr_loop is _GLOBAL_LOOP: loop_id = None
# We are in the main loop. We cannot await/block here if we are just a sync function. logger.debug(
# But proxies are often called from sync code. "ISO:rpc_start proxy=%s method=%s instance_id=%s start_ts=%.6f "
# If called from sync code in main loop, creating a new loop is bad. "thread=%s loop=%s timeout_ms=%s",
# But we can't await `coro`. self.__class__.__name__,
# This implies proxies MUST be awaited if called from async context? method_name,
# Existing code used `run_coro_in_new_loop` which is weird. self._instance_id,
# Let's trust that if we are in a thread (RuntimeError on get_running_loop), start_epoch,
# we use run_coroutine_threadsafe. thread_id,
pass loop_id,
except RuntimeError: timeout_ms,
# No running loop - we are in a worker thread. )
future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP)
return future.result()
try: try:
asyncio.get_running_loop() # If we have a global loop (Main Thread Loop), use it for dispatch from worker threads
return run_coro_in_new_loop(coro) if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running():
except RuntimeError: try:
loop = get_thread_loop() curr_loop = asyncio.get_running_loop()
return loop.run_until_complete(coro) if curr_loop is _GLOBAL_LOOP:
pass
except RuntimeError:
# No running loop - we are in a worker thread.
future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP)
return future.result(
timeout=(timeout_ms / 1000.0) if timeout_ms is not None else None
)
try:
asyncio.get_running_loop()
return run_coro_in_new_loop(coro)
except RuntimeError:
loop = get_thread_loop()
return loop.run_until_complete(coro)
except asyncio.TimeoutError as exc:
raise TimeoutError(
f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} "
f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})"
) from exc
except concurrent.futures.TimeoutError as exc:
raise TimeoutError(
f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} "
f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})"
) from exc
finally:
end_epoch = time.time()
elapsed_ms = (time.perf_counter() - start_perf) * 1000.0
logger.debug(
"ISO:rpc_end proxy=%s method=%s instance_id=%s end_ts=%.6f "
"elapsed_ms=%.3f thread=%s loop=%s",
self.__class__.__name__,
method_name,
self._instance_id,
end_epoch,
elapsed_ms,
thread_id,
loop_id,
)
def __getstate__(self) -> Dict[str, Any]: def __getstate__(self) -> Dict[str, Any]:
return {"_instance_id": self._instance_id} return {"_instance_id": self._instance_id}

View File

@ -208,13 +208,18 @@ def calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tenso
return handler.execute(_calc_cond_batch_outer, model, conds, x_in, timestep, model_options) return handler.execute(_calc_cond_batch_outer, model, conds, x_in, timestep, model_options)
def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
    # Wrap _calc_cond_batch with any registered CALC_COND_BATCH wrappers and
    # run it. The CC_DEBUG2 prints are temporary stall diagnostics around the
    # executor boundary.
    print("CC_DEBUG2: enter _calc_cond_batch_outer", flush=True)
    wrappers = comfy.patcher_extension.get_all_wrappers(
        comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True
    )
    executor = comfy.patcher_extension.WrapperExecutor.new_executor(_calc_cond_batch, wrappers)
    print("CC_DEBUG2: before _calc_cond_batch executor.execute", flush=True)
    out = executor.execute(model, conds, x_in, timestep, model_options)
    print("CC_DEBUG2: after _calc_cond_batch executor.execute", flush=True)
    return out
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options): def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
print("CC_DEBUG2: enter _calc_cond_batch", flush=True)
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1" isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
out_conds = [] out_conds = []
out_counts = [] out_counts = []
@ -247,7 +252,9 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
if has_default_conds: if has_default_conds:
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options) finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
print("CC_DEBUG: before prepare_state", flush=True)
model.current_patcher.prepare_state(timestep) model.current_patcher.prepare_state(timestep)
print("CC_DEBUG: after prepare_state", flush=True)
# run every hooked_to_run separately # run every hooked_to_run separately
for hooks, to_run in hooked_to_run.items(): for hooks, to_run in hooked_to_run.items():
@ -262,7 +269,9 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
to_batch_temp.reverse() to_batch_temp.reverse()
to_batch = to_batch_temp[:1] to_batch = to_batch_temp[:1]
print("CC_DEBUG: before get_free_memory", flush=True)
free_memory = model.current_patcher.get_free_memory(x_in.device) free_memory = model.current_patcher.get_free_memory(x_in.device)
print("CC_DEBUG: after get_free_memory", flush=True)
for i in range(1, len(to_batch_temp) + 1): for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i] batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
@ -272,7 +281,10 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
for k, v in to_run[tt][0].conditioning.items(): for k, v in to_run[tt][0].conditioning.items():
cond_shapes[k].append(v.size()) cond_shapes[k].append(v.size())
if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory: print("CC_DEBUG: before memory_required", flush=True)
memory_required = model.memory_required(input_shape, cond_shapes=cond_shapes)
print("CC_DEBUG: after memory_required", flush=True)
if memory_required * 1.5 < free_memory:
to_batch = batch_amount to_batch = batch_amount
break break
@ -309,7 +321,9 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
else: else:
timestep_ = torch.cat([timestep] * batch_chunks) timestep_ = torch.cat([timestep] * batch_chunks)
print("CC_DEBUG: before apply_hooks", flush=True)
transformer_options = model.current_patcher.apply_hooks(hooks=hooks) transformer_options = model.current_patcher.apply_hooks(hooks=hooks)
print("CC_DEBUG: after apply_hooks", flush=True)
if 'transformer_options' in model_options: if 'transformer_options' in model_options:
transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options, transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
model_options['transformer_options'], model_options['transformer_options'],
@ -331,9 +345,13 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options) c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
if 'model_function_wrapper' in model_options: if 'model_function_wrapper' in model_options:
print("CC_DEBUG: before apply_model", flush=True)
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
print("CC_DEBUG: after apply_model", flush=True)
else: else:
print("CC_DEBUG: before apply_model", flush=True)
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks) output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
print("CC_DEBUG: after apply_model", flush=True)
for o in range(batch_chunks): for o in range(batch_chunks):
cond_index = cond_or_uncond[o] cond_index = cond_or_uncond[o]
@ -768,6 +786,7 @@ class KSAMPLER(Sampler):
self.inpaint_options = inpaint_options self.inpaint_options = inpaint_options
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
print("CC_DEBUG3: enter KSAMPLER.sample", flush=True)
extra_args["denoise_mask"] = denoise_mask extra_args["denoise_mask"] = denoise_mask
model_k = KSamplerX0Inpaint(model_wrap, sigmas) model_k = KSamplerX0Inpaint(model_wrap, sigmas)
model_k.latent_image = latent_image model_k.latent_image = latent_image
@ -777,15 +796,46 @@ class KSAMPLER(Sampler):
else: else:
model_k.noise = noise model_k.noise = noise
noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas[0], noise, latent_image, self.max_denoise(model_wrap, sigmas)) print("CC_DEBUG3: before max_denoise", flush=True)
max_denoise = self.max_denoise(model_wrap, sigmas)
print("CC_DEBUG3: after max_denoise", flush=True)
print("CC_DEBUG3: before model_sampling_attr", flush=True)
model_sampling = model_wrap.inner_model.model_sampling
print(
"CC_DEBUG3: after model_sampling_attr type=%s id=%s instance_id=%s"
% (
type(model_sampling).__name__,
id(model_sampling),
getattr(model_sampling, "_instance_id", "n/a"),
),
flush=True,
)
print("CC_DEBUG3: before noise_scaling_call", flush=True)
try:
noise_scaled = model_sampling.noise_scaling(
sigmas[0], noise, latent_image, max_denoise
)
print("CC_DEBUG3: after noise_scaling_call", flush=True)
except Exception as e:
print(
"CC_DEBUG3: noise_scaling_exception type=%s msg=%s"
% (type(e).__name__, str(e)),
flush=True,
)
raise
noise = noise_scaled
print("CC_DEBUG3: after noise_assignment", flush=True)
k_callback = None k_callback = None
total_steps = len(sigmas) - 1 total_steps = len(sigmas) - 1
if callback is not None: if callback is not None:
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps) k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
print("CC_DEBUG3: before sampler_function", flush=True)
samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options) samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
print("CC_DEBUG3: after sampler_function", flush=True)
samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples) samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples)
print("CC_DEBUG3: after inverse_noise_scaling", flush=True)
return samples return samples
@ -825,10 +875,16 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
for k in conds: for k in conds:
calculate_start_end_timesteps(model, conds[k]) calculate_start_end_timesteps(model, conds[k])
print('GP_DEBUG: before hasattr extra_conds', flush=True)
print('GP_DEBUG: before hasattr extra_conds', flush=True)
if hasattr(model, 'extra_conds'): if hasattr(model, 'extra_conds'):
print('GP_DEBUG: has extra_conds!', flush=True)
print('GP_DEBUG: has extra_conds!', flush=True)
for k in conds: for k in conds:
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed, latent_shapes=latent_shapes) conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed, latent_shapes=latent_shapes)
print('GP_DEBUG: before area make sure loop', flush=True)
print('GP_DEBUG: before area make sure loop', flush=True)
#make sure each cond area has an opposite one with the same area #make sure each cond area has an opposite one with the same area
for k in conds: for k in conds:
for c in conds[k]: for c in conds[k]:
@ -842,8 +898,11 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
for hook in c['hooks'].hooks: for hook in c['hooks'].hooks:
hook.initialize_timesteps(model) hook.initialize_timesteps(model)
print('GP_DEBUG: before pre_run_control loop', flush=True)
for k in conds: for k in conds:
print('GP_DEBUG: calling pre_run_control for key:', k, flush=True)
pre_run_control(model, conds[k]) pre_run_control(model, conds[k])
print('GP_DEBUG: after pre_run_control loop', flush=True)
if "positive" in conds: if "positive" in conds:
positive = conds["positive"] positive = conds["positive"]
@ -1005,6 +1064,8 @@ class CFGGuider:
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed, latent_shapes=latent_shapes) self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed, latent_shapes=latent_shapes)
print("GP_DEBUG: process_conds finished", flush=True)
print("GP_DEBUG: process_conds finished", flush=True)
extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options) extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas
extra_args = {"model_options": extra_model_options, "seed": seed} extra_args = {"model_options": extra_model_options, "seed": seed}
@ -1014,6 +1075,7 @@ class CFGGuider:
sampler, sampler,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, extra_args["model_options"], is_model_options=True) comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, extra_args["model_options"], is_model_options=True)
) )
print("GP_DEBUG: before executor.execute", flush=True)
samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return self.inner_model.process_latent_out(samples.to(torch.float32)) return self.inner_model.process_latent_out(samples.to(torch.float32))

View File

@ -696,6 +696,24 @@ class PromptExecutor:
except Exception: except Exception:
logging.debug("][ EX:flush_running_extensions_transport_state failed", exc_info=True) logging.debug("][ EX:flush_running_extensions_transport_state failed", exc_info=True)
async def _wait_model_patcher_quiescence_safe(
self,
*,
fail_loud: bool = False,
timeout_ms: int = 120000,
marker: str = "EX:wait_model_patcher_idle",
) -> None:
try:
from comfy.isolation import wait_for_model_patcher_quiescence
await wait_for_model_patcher_quiescence(
timeout_ms=timeout_ms, fail_loud=fail_loud, marker=marker
)
except Exception:
if fail_loud:
raise
logging.debug("][ EX:wait_model_patcher_quiescence failed", exc_info=True)
def add_message(self, event, data: dict, broadcast: bool): def add_message(self, event, data: dict, broadcast: bool):
data = { data = {
**data, **data,
@ -766,6 +784,11 @@ class PromptExecutor:
# Boundary cleanup runs at the start of the next workflow in # Boundary cleanup runs at the start of the next workflow in
# isolation mode, matching non-isolated "next prompt" timing. # isolation mode, matching non-isolated "next prompt" timing.
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args) self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
await self._wait_model_patcher_quiescence_safe(
fail_loud=False,
timeout_ms=120000,
marker="EX:boundary_cleanup_wait_idle",
)
await self._flush_running_extensions_transport_state_safe() await self._flush_running_extensions_transport_state_safe()
comfy.model_management.unload_all_models() comfy.model_management.unload_all_models()
comfy.model_management.cleanup_models_gc() comfy.model_management.cleanup_models_gc()
@ -806,6 +829,11 @@ class PromptExecutor:
for node_id in execution_list.pendingNodes.keys(): for node_id in execution_list.pendingNodes.keys():
class_type = dynamic_prompt.get_node(node_id)["class_type"] class_type = dynamic_prompt.get_node(node_id)["class_type"]
pending_class_types.add(class_type) pending_class_types.add(class_type)
await self._wait_model_patcher_quiescence_safe(
fail_loud=True,
timeout_ms=120000,
marker="EX:notify_graph_wait_idle",
)
await self._notify_execution_graph_safe(pending_class_types, fail_loud=True) await self._notify_execution_graph_safe(pending_class_types, fail_loud=True)
while not execution_list.is_empty(): while not execution_list.is_empty():