async def wait_for_model_patcher_quiescence(
    timeout_ms: int = _MODEL_PATCHER_IDLE_TIMEOUT_MS,
    *,
    fail_loud: bool = False,
    marker: str = "ISO:wait_model_patcher_idle",
) -> bool:
    """Wait until the ModelPatcherRegistry reports no active operations.

    Args:
        timeout_ms: Budget handed to ``ModelPatcherRegistry.wait_all_idle``.
        fail_loud: When true, a timeout raises ``TimeoutError`` and any other
            failure is re-raised; when false, both cases return ``False``.
        marker: Tag inserted into every log line emitted by this call.

    Returns:
        ``True`` if the registry became idle, ``False`` on timeout or (when
        ``fail_loud`` is false) on any error.

    Raises:
        TimeoutError: The registry stayed busy and ``fail_loud`` is true.
        Exception: Any other failure, only when ``fail_loud`` is true.
    """
    try:
        # Imported lazily, inside the try, so an import failure follows the
        # same fail_loud policy as every other error.
        from comfy.isolation.model_patcher_proxy_registry import ModelPatcherRegistry

        patcher_registry = ModelPatcherRegistry()
        began = time.perf_counter()
        went_idle = await patcher_registry.wait_all_idle(timeout_ms)
        elapsed_ms = (time.perf_counter() - began) * 1000.0

        if not went_idle:
            # Snapshot the per-instance state so the log shows what was busy.
            snapshot = await patcher_registry.get_all_operation_states()
            logger.error(
                "%s %s idle_timeout timeout_ms=%d elapsed_ms=%.3f states=%s",
                LOG_PREFIX,
                marker,
                timeout_ms,
                elapsed_ms,
                snapshot,
            )
            if not fail_loud:
                return False
            raise TimeoutError(
                f"ModelPatcherRegistry did not quiesce within {timeout_ms} ms"
            )

        logger.debug(
            "%s %s idle=1 timeout_ms=%d elapsed_ms=%.3f",
            LOG_PREFIX,
            marker,
            timeout_ms,
            elapsed_ms,
        )
        return True
    except Exception:
        if not fail_loud:
            logger.debug("%s %s failed", LOG_PREFIX, marker, exc_info=True)
            return False
        raise
serialize_model_patcher(obj: Any) -> Dict[str, Any]: # Child-side: must already have _instance_id (proxy) if os.environ.get("PYISOLATE_CHILD") == "1": @@ -193,6 +220,10 @@ class ComfyUIAdapter(IsolationAdapter): f"ModelSampling in child lacks _instance_id: " f"{type(obj).__module__}.{type(obj).__name__}" ) + # Host-side pass-through for proxies: do not re-register a proxy as a + # new ModelSamplingRef, or we create proxy-of-proxy indirection. + if hasattr(obj, "_instance_id"): + return {"__type__": "ModelSamplingRef", "ms_id": obj._instance_id} # Host-side: register with ModelSamplingRegistry and return JSON-safe dict ms_id = ModelSamplingRegistry().register(obj) return {"__type__": "ModelSamplingRef", "ms_id": ms_id} @@ -211,22 +242,21 @@ class ComfyUIAdapter(IsolationAdapter): else: return ModelSamplingRegistry()._get_instance(data["ms_id"]) - # Register ModelSampling type and proxy - registry.register( - "ModelSamplingDiscrete", - serialize_model_sampling, - deserialize_model_sampling, - ) - registry.register( - "ModelSamplingContinuousEDM", - serialize_model_sampling, - deserialize_model_sampling, - ) - registry.register( - "ModelSamplingContinuousV", - serialize_model_sampling, - deserialize_model_sampling, - ) + # Register all ModelSampling* and StableCascadeSampling classes dynamically + import comfy.model_sampling + + for ms_cls in vars(comfy.model_sampling).values(): + if not isinstance(ms_cls, type): + continue + if not issubclass(ms_cls, torch.nn.Module): + continue + if not (ms_cls.__name__.startswith("ModelSampling") or ms_cls.__name__ == "StableCascadeSampling"): + continue + registry.register( + ms_cls.__name__, + serialize_model_sampling, + deserialize_model_sampling, + ) registry.register( "ModelSamplingProxy", serialize_model_sampling, deserialize_model_sampling ) diff --git a/comfy/isolation/extension_wrapper.py b/comfy/isolation/extension_wrapper.py index 23148e470..473d27bd4 100644 --- a/comfy/isolation/extension_wrapper.py +++ 
b/comfy/isolation/extension_wrapper.py @@ -382,6 +382,10 @@ class ComfyNodeExtension(ExtensionBase): if type(result).__name__ == "NodeOutput": result = result.args + print( + f"{LOG_PREFIX} ISO:child_result_ready node={node_name} type={type(result).__name__}", + flush=True, + ) if self._is_comfy_protocol_return(result): logger.debug( "%s ISO:child_execute_done ext=%s node=%s protocol_return=1", @@ -389,10 +393,17 @@ class ComfyNodeExtension(ExtensionBase): getattr(self, "name", "?"), node_name, ) - return self._wrap_unpicklable_objects(result) + print(f"{LOG_PREFIX} ISO:child_wrap_start node={node_name} protocol=1", flush=True) + wrapped = self._wrap_unpicklable_objects(result) + print(f"{LOG_PREFIX} ISO:child_wrap_done node={node_name} protocol=1", flush=True) + return wrapped if not isinstance(result, tuple): result = (result,) + print( + f"{LOG_PREFIX} ISO:child_result_tuple node={node_name} outputs={len(result)}", + flush=True, + ) logger.debug( "%s ISO:child_execute_done ext=%s node=%s protocol_return=0 outputs=%d", LOG_PREFIX, @@ -400,7 +411,10 @@ class ComfyNodeExtension(ExtensionBase): node_name, len(result), ) - return self._wrap_unpicklable_objects(result) + print(f"{LOG_PREFIX} ISO:child_wrap_start node={node_name} protocol=0", flush=True) + wrapped = self._wrap_unpicklable_objects(result) + print(f"{LOG_PREFIX} ISO:child_wrap_done node={node_name} protocol=0", flush=True) + return wrapped async def flush_transport_state(self) -> int: if os.environ.get("PYISOLATE_ISOLATION_ACTIVE") != "1": @@ -443,7 +457,10 @@ class ComfyNodeExtension(ExtensionBase): if isinstance(data, (str, int, float, bool, type(None))): return data if isinstance(data, torch.Tensor): - return data.detach() if data.requires_grad else data + tensor = data.detach() if data.requires_grad else data + if os.environ.get("PYISOLATE_CHILD") == "1" and tensor.device.type != "cpu": + return tensor.cpu() + return tensor # Special-case clip vision outputs: preserve attribute access by packing 
fields if hasattr(data, "penultimate_hidden_states") or hasattr( diff --git a/comfy/isolation/model_patcher_proxy.py b/comfy/isolation/model_patcher_proxy.py index 6769996b8..e1c513933 100644 --- a/comfy/isolation/model_patcher_proxy.py +++ b/comfy/isolation/model_patcher_proxy.py @@ -338,6 +338,34 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]): def apply_model(self, *args, **kwargs) -> Any: import torch + def _preferred_device() -> Any: + for value in args: + if isinstance(value, torch.Tensor): + return value.device + for value in kwargs.values(): + if isinstance(value, torch.Tensor): + return value.device + return None + + def _move_result_to_device(obj: Any, device: Any) -> Any: + if device is None: + return obj + if isinstance(obj, torch.Tensor): + return obj.to(device) if obj.device != device else obj + if isinstance(obj, dict): + return {k: _move_result_to_device(v, device) for k, v in obj.items()} + if isinstance(obj, list): + return [_move_result_to_device(v, device) for v in obj] + if isinstance(obj, tuple): + return tuple(_move_result_to_device(v, device) for v in obj) + return obj + + # DynamicVRAM models must keep load/offload decisions in host process. + # Child-side CUDA staging here can deadlock before first inference RPC. 
+ if self.is_dynamic(): + out = self._call_rpc("inner_model_apply_model", args, kwargs) + return _move_result_to_device(out, _preferred_device()) + required_bytes = self._cpu_tensor_bytes(args) + self._cpu_tensor_bytes(kwargs) self._ensure_apply_model_headroom(required_bytes) @@ -360,7 +388,8 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]): args_cuda = _to_cuda(args) kwargs_cuda = _to_cuda(kwargs) - return self._call_rpc("inner_model_apply_model", args_cuda, kwargs_cuda) + out = self._call_rpc("inner_model_apply_model", args_cuda, kwargs_cuda) + return _move_result_to_device(out, _preferred_device()) def model_state_dict(self, filter_prefix: Optional[str] = None) -> Any: keys = self._call_rpc("model_state_dict", filter_prefix) @@ -526,6 +555,13 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]): def memory_required(self, input_shape: Any) -> Any: return self._call_rpc("memory_required", input_shape) + def get_operation_state(self) -> Dict[str, Any]: + state = self._call_rpc("get_operation_state") + return state if isinstance(state, dict) else {} + + def wait_for_idle(self, timeout_ms: int = 0) -> bool: + return bool(self._call_rpc("wait_for_idle", timeout_ms)) + def is_dynamic(self) -> bool: return bool(self._call_rpc("is_dynamic")) @@ -771,6 +807,7 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]): class _InnerModelProxy: def __init__(self, parent: ModelPatcherProxy): self._parent = parent + self._model_sampling = None def __getattr__(self, name: str) -> Any: if name.startswith("_"): @@ -793,7 +830,11 @@ class _InnerModelProxy: manage_lifecycle=False, ) if name == "model_sampling": - return self._parent._call_rpc("get_model_object", "model_sampling") + if self._model_sampling is None: + self._model_sampling = self._parent._call_rpc( + "get_model_object", "model_sampling" + ) + return self._model_sampling if name == "extra_conds_shapes": return lambda *a, **k: self._parent._call_rpc( "inner_model_extra_conds_shapes", a, k diff --git 
a/comfy/isolation/model_patcher_proxy_registry.py b/comfy/isolation/model_patcher_proxy_registry.py index 7224c3233..cb4558d2f 100644 --- a/comfy/isolation/model_patcher_proxy_registry.py +++ b/comfy/isolation/model_patcher_proxy_registry.py @@ -2,8 +2,12 @@ # RPC server for ModelPatcher isolation (child process) from __future__ import annotations +import asyncio import gc import logging +import threading +import time +from dataclasses import dataclass, field from typing import Any, Optional, List try: @@ -43,12 +47,191 @@ from comfy.isolation.proxies.base import ( logger = logging.getLogger(__name__) +@dataclass +class _OperationState: + lease: threading.Lock = field(default_factory=threading.Lock) + active_count: int = 0 + active_by_method: dict[str, int] = field(default_factory=dict) + total_operations: int = 0 + last_method: Optional[str] = None + last_started_ts: Optional[float] = None + last_ended_ts: Optional[float] = None + last_elapsed_ms: Optional[float] = None + last_error: Optional[str] = None + last_thread_id: Optional[int] = None + last_loop_id: Optional[int] = None + + class ModelPatcherRegistry(BaseRegistry[Any]): _type_prefix = "model" def __init__(self) -> None: super().__init__() self._pending_cleanup_ids: set[str] = set() + self._operation_states: dict[str, _OperationState] = {} + self._operation_state_cv = threading.Condition(self._lock) + + def _get_or_create_operation_state(self, instance_id: str) -> _OperationState: + state = self._operation_states.get(instance_id) + if state is None: + state = _OperationState() + self._operation_states[instance_id] = state + return state + + def _begin_operation(self, instance_id: str, method_name: str) -> tuple[float, float]: + start_epoch = time.time() + start_perf = time.perf_counter() + with self._operation_state_cv: + state = self._get_or_create_operation_state(instance_id) + state.active_count += 1 + state.active_by_method[method_name] = ( + state.active_by_method.get(method_name, 0) + 1 + ) + 
def _run_operation_with_lease(self, instance_id: str, method_name: str, fn):
    """Run ``fn()`` while holding the per-instance lease and tracking activity.

    The lease is looked up under the registry condition lock, then held for
    the whole call so operations on one instance run one at a time.
    ``_begin_operation`` / ``_end_operation`` bracket the call so idle waiters
    see an accurate ``active_count``.

    Args:
        instance_id: Registry id of the instance being operated on.
        method_name: Label recorded in the operation state and debug logs.
        fn: Zero-argument callable performing the actual work.

    Returns:
        Whatever ``fn()`` returns.

    Raises:
        BaseException: Anything raised by ``fn()`` is re-raised unchanged
            after the operation has been marked as ended.
    """
    with self._operation_state_cv:
        lease = self._get_or_create_operation_state(instance_id).lease
    with lease:
        _, start_perf = self._begin_operation(instance_id, method_name)
        try:
            result = fn()
        except BaseException as exc:
            # Catch BaseException, not just Exception: if KeyboardInterrupt,
            # SystemExit or GeneratorExit skipped _end_operation the instance
            # would look busy forever and idle waits would always time out.
            self._end_operation(instance_id, method_name, start_perf, error=exc)
            raise
        self._end_operation(instance_id, method_name, start_perf)
        return result
async def wait_for_idle(self, instance_id: str, timeout_ms: int = 0) -> bool:
    """Wait until one instance has no active operations.

    Args:
        instance_id: Registry id to watch. An id with no recorded operation
            state counts as idle.
        timeout_ms: Maximum wait in milliseconds; ``<= 0`` waits indefinitely.

    Returns:
        ``True`` once the instance is idle, ``False`` if the timeout expired.
    """
    timeout_s = None if timeout_ms <= 0 else timeout_ms / 1000.0
    cv = self._operation_state_cv

    def _instance_idle() -> bool:
        state = self._operation_states.get(instance_id)
        return state is None or state.active_count == 0

    def _block_until_idle() -> bool:
        with cv:
            return bool(cv.wait_for(_instance_idle, timeout=timeout_s))

    # Fast path: already idle, no thread hop needed.
    with cv:
        if _instance_idle():
            return True
    # threading.Condition waits block the calling thread. Doing that on the
    # event-loop thread would stall every other coroutine for the whole
    # timeout, so the wait runs in the default executor instead.
    return await asyncio.get_running_loop().run_in_executor(None, _block_until_idle)


async def wait_all_idle(self, timeout_ms: int = 0) -> bool:
    """Wait until no tracked instance has an active operation.

    Args:
        timeout_ms: Maximum wait in milliseconds; ``<= 0`` waits indefinitely.

    Returns:
        ``True`` once every tracked instance is idle, ``False`` if the
        timeout expired first.
    """
    timeout_s = None if timeout_ms <= 0 else timeout_ms / 1000.0
    cv = self._operation_state_cv

    def _all_idle() -> bool:
        return not any(
            state.active_count > 0 for state in self._operation_states.values()
        )

    def _block_until_idle() -> bool:
        with cv:
            return bool(cv.wait_for(_all_idle, timeout=timeout_s))

    # Fast path: nothing active, no thread hop needed.
    with cv:
        if _all_idle():
            return True
    # Same reasoning as wait_for_idle: keep the blocking wait off the loop.
    return await asyncio.get_running_loop().run_in_executor(None, _block_until_idle)
+ if isinstance(result, ModelSamplingProxy): + sampling_id = result._instance_id + else: + sampling_id = registry.register(result) + print(f"GP_DEBUG: get_model_object END for name={name} (model_sampling)", flush=True) return ModelSamplingProxy(sampling_id, registry) + + print(f"GP_DEBUG: get_model_object END for name={name} (fallthrough)", flush=True) return detach_if_grad(result) async def get_model_options(self, instance_id: str) -> dict: @@ -163,7 +356,11 @@ class ModelPatcherRegistry(BaseRegistry[Any]): return self._get_instance(instance_id).lowvram_patch_counter() async def memory_required(self, instance_id: str, input_shape: Any) -> Any: - return self._get_instance(instance_id).memory_required(input_shape) + return self._run_operation_with_lease( + instance_id, + "memory_required", + lambda: self._get_instance(instance_id).memory_required(input_shape), + ) async def is_dynamic(self, instance_id: str) -> bool: instance = self._get_instance(instance_id) @@ -186,7 +383,11 @@ class ModelPatcherRegistry(BaseRegistry[Any]): return None async def model_dtype(self, instance_id: str) -> Any: - return self._get_instance(instance_id).model_dtype() + return self._run_operation_with_lease( + instance_id, + "model_dtype", + lambda: self._get_instance(instance_id).model_dtype(), + ) async def model_patches_to(self, instance_id: str, device: Any) -> Any: return self._get_instance(instance_id).model_patches_to(device) @@ -198,8 +399,12 @@ class ModelPatcherRegistry(BaseRegistry[Any]): extra_memory: Any, force_patch_weights: bool = False, ) -> Any: - return self._get_instance(instance_id).partially_load( - device, extra_memory, force_patch_weights=force_patch_weights + return self._run_operation_with_lease( + instance_id, + "partially_load", + lambda: self._get_instance(instance_id).partially_load( + device, extra_memory, force_patch_weights=force_patch_weights + ), ) async def partially_unload( @@ -209,8 +414,12 @@ class ModelPatcherRegistry(BaseRegistry[Any]): 
memory_to_free: int = 0, force_patch_weights: bool = False, ) -> int: - return self._get_instance(instance_id).partially_unload( - device_to, memory_to_free, force_patch_weights + return self._run_operation_with_lease( + instance_id, + "partially_unload", + lambda: self._get_instance(instance_id).partially_unload( + device_to, memory_to_free, force_patch_weights + ), ) async def load( @@ -221,8 +430,12 @@ class ModelPatcherRegistry(BaseRegistry[Any]): force_patch_weights: bool = False, full_load: bool = False, ) -> None: - self._get_instance(instance_id).load( - device_to, lowvram_model_memory, force_patch_weights, full_load + self._run_operation_with_lease( + instance_id, + "load", + lambda: self._get_instance(instance_id).load( + device_to, lowvram_model_memory, force_patch_weights, full_load + ), ) async def patch_model( @@ -233,20 +446,29 @@ class ModelPatcherRegistry(BaseRegistry[Any]): load_weights: bool = True, force_patch_weights: bool = False, ) -> None: - try: - self._get_instance(instance_id).patch_model( - device_to, lowvram_model_memory, load_weights, force_patch_weights - ) - except AttributeError as e: - logger.error( - f"Isolation Error: Failed to patch model attribute: {e}. Skipping." - ) - return + def _invoke() -> None: + try: + self._get_instance(instance_id).patch_model( + device_to, lowvram_model_memory, load_weights, force_patch_weights + ) + except AttributeError as e: + logger.error( + f"Isolation Error: Failed to patch model attribute: {e}. Skipping." 
+ ) + return + + self._run_operation_with_lease(instance_id, "patch_model", _invoke) async def unpatch_model( self, instance_id: str, device_to: Any = None, unpatch_weights: bool = True ) -> None: - self._get_instance(instance_id).unpatch_model(device_to, unpatch_weights) + self._run_operation_with_lease( + instance_id, + "unpatch_model", + lambda: self._get_instance(instance_id).unpatch_model( + device_to, unpatch_weights + ), + ) async def detach(self, instance_id: str, unpatch_all: bool = True) -> None: self._get_instance(instance_id).detach(unpatch_all) @@ -262,26 +484,29 @@ class ModelPatcherRegistry(BaseRegistry[Any]): self._get_instance(instance_id).pre_run() async def cleanup(self, instance_id: str) -> None: - try: - instance = self._get_instance(instance_id) - except Exception: - logger.debug( - "ModelPatcher cleanup requested for missing instance %s", - instance_id, - exc_info=True, - ) - return + def _invoke() -> None: + try: + instance = self._get_instance(instance_id) + except Exception: + logger.debug( + "ModelPatcher cleanup requested for missing instance %s", + instance_id, + exc_info=True, + ) + return - try: - instance.cleanup() - finally: - with self._lock: - self._pending_cleanup_ids.add(instance_id) - gc.collect() + try: + instance.cleanup() + finally: + with self._lock: + self._pending_cleanup_ids.add(instance_id) + gc.collect() + + self._run_operation_with_lease(instance_id, "cleanup", _invoke) def sweep_pending_cleanup(self) -> int: removed = 0 - with self._lock: + with self._operation_state_cv: pending_ids = list(self._pending_cleanup_ids) self._pending_cleanup_ids.clear() for instance_id in pending_ids: @@ -289,17 +514,21 @@ class ModelPatcherRegistry(BaseRegistry[Any]): if instance is None: continue self._id_map.pop(id(instance), None) + self._operation_states.pop(instance_id, None) removed += 1 + self._operation_state_cv.notify_all() gc.collect() return removed def purge_all(self) -> int: - with self._lock: + with 
self._operation_state_cv: removed = len(self._registry) self._registry.clear() self._id_map.clear() self._pending_cleanup_ids.clear() + self._operation_states.clear() + self._operation_state_cv.notify_all() gc.collect() return removed @@ -743,17 +972,52 @@ class ModelPatcherRegistry(BaseRegistry[Any]): async def inner_model_memory_required( self, instance_id: str, args: tuple, kwargs: dict ) -> Any: - return self._get_instance(instance_id).model.memory_required(*args, **kwargs) + return self._run_operation_with_lease( + instance_id, + "inner_model_memory_required", + lambda: self._get_instance(instance_id).model.memory_required( + *args, **kwargs + ), + ) async def inner_model_extra_conds_shapes( self, instance_id: str, args: tuple, kwargs: dict ) -> Any: - return self._get_instance(instance_id).model.extra_conds_shapes(*args, **kwargs) + return self._run_operation_with_lease( + instance_id, + "inner_model_extra_conds_shapes", + lambda: self._get_instance(instance_id).model.extra_conds_shapes( + *args, **kwargs + ), + ) async def inner_model_extra_conds( self, instance_id: str, args: tuple, kwargs: dict ) -> Any: - return self._get_instance(instance_id).model.extra_conds(*args, **kwargs) + def _invoke() -> Any: + result = self._get_instance(instance_id).model.extra_conds(*args, **kwargs) + try: + import torch + import comfy.conds + except Exception: + return result + + def _to_cpu(obj: Any) -> Any: + if torch.is_tensor(obj): + return obj.detach().cpu() if obj.device.type != "cpu" else obj + if isinstance(obj, dict): + return {k: _to_cpu(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_to_cpu(v) for v in obj] + if isinstance(obj, tuple): + return tuple(_to_cpu(v) for v in obj) + if isinstance(obj, comfy.conds.CONDRegular): + return type(obj)(_to_cpu(obj.cond)) + return obj + + return _to_cpu(result) + + return self._run_operation_with_lease(instance_id, "inner_model_extra_conds", _invoke) async def inner_model_state_dict( self, instance_id: str, 
args: tuple, kwargs: dict @@ -767,82 +1031,160 @@ class ModelPatcherRegistry(BaseRegistry[Any]): async def inner_model_apply_model( self, instance_id: str, args: tuple, kwargs: dict ) -> Any: - instance = self._get_instance(instance_id) - target = getattr(instance, "load_device", None) - if target is None and args and hasattr(args[0], "device"): - target = args[0].device - elif target is None: - for v in kwargs.values(): - if hasattr(v, "device"): - target = v.device - break + def _invoke() -> Any: + import torch - def _move(obj): - if target is None: + instance = self._get_instance(instance_id) + target = getattr(instance, "load_device", None) + if target is None and args and hasattr(args[0], "device"): + target = args[0].device + elif target is None: + for v in kwargs.values(): + if hasattr(v, "device"): + target = v.device + break + + def _move(obj): + if target is None: + return obj + if isinstance(obj, (tuple, list)): + return type(obj)(_move(o) for o in obj) + if hasattr(obj, "to"): + return obj.to(target) return obj - if isinstance(obj, (tuple, list)): - return type(obj)(_move(o) for o in obj) - if hasattr(obj, "to"): - return obj.to(target) - return obj - moved_args = tuple(_move(a) for a in args) - moved_kwargs = {k: _move(v) for k, v in kwargs.items()} - result = instance.model.apply_model(*moved_args, **moved_kwargs) - return detach_if_grad(_move(result)) + moved_args = tuple(_move(a) for a in args) + moved_kwargs = {k: _move(v) for k, v in kwargs.items()} + result = instance.model.apply_model(*moved_args, **moved_kwargs) + moved_result = detach_if_grad(_move(result)) + + # DynamicVRAM + isolation: returning CUDA tensors across RPC can stall + # at the transport boundary. Marshal dynamic-path results as CPU and let + # the proxy restore device placement in the child process. 
+ is_dynamic_fn = getattr(instance, "is_dynamic", None) + if callable(is_dynamic_fn) and is_dynamic_fn(): + def _to_cpu(obj: Any) -> Any: + if torch.is_tensor(obj): + return obj.detach().cpu() if obj.device.type != "cpu" else obj + if isinstance(obj, dict): + return {k: _to_cpu(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_to_cpu(v) for v in obj] + if isinstance(obj, tuple): + return tuple(_to_cpu(v) for v in obj) + return obj + + return _to_cpu(moved_result) + return moved_result + + return self._run_operation_with_lease(instance_id, "inner_model_apply_model", _invoke) async def process_latent_in( self, instance_id: str, args: tuple, kwargs: dict ) -> Any: - return detach_if_grad( - self._get_instance(instance_id).model.process_latent_in(*args, **kwargs) + return self._run_operation_with_lease( + instance_id, + "process_latent_in", + lambda: detach_if_grad( + self._get_instance(instance_id).model.process_latent_in( + *args, **kwargs + ) + ), ) async def process_latent_out( self, instance_id: str, args: tuple, kwargs: dict ) -> Any: - instance = self._get_instance(instance_id) - result = instance.model.process_latent_out(*args, **kwargs) - try: - target = None - if args and hasattr(args[0], "device"): - target = args[0].device - elif kwargs: - for v in kwargs.values(): - if hasattr(v, "device"): - target = v.device - break - if target is not None and hasattr(result, "to"): - return detach_if_grad(result.to(target)) - except Exception: - logger.debug( - "process_latent_out: failed to move result to target device", - exc_info=True, - ) - return detach_if_grad(result) + import torch + + def _invoke() -> Any: + instance = self._get_instance(instance_id) + result = instance.model.process_latent_out(*args, **kwargs) + moved_result = None + try: + target = None + if args and hasattr(args[0], "device"): + target = args[0].device + elif kwargs: + for v in kwargs.values(): + if hasattr(v, "device"): + target = v.device + break + if target is not None and 
hasattr(result, "to"): + moved_result = detach_if_grad(result.to(target)) + except Exception: + logger.debug( + "process_latent_out: failed to move result to target device", + exc_info=True, + ) + if moved_result is None: + moved_result = detach_if_grad(result) + + is_dynamic_fn = getattr(instance, "is_dynamic", None) + if callable(is_dynamic_fn) and is_dynamic_fn(): + def _to_cpu(obj: Any) -> Any: + if torch.is_tensor(obj): + return obj.detach().cpu() if obj.device.type != "cpu" else obj + if isinstance(obj, dict): + return {k: _to_cpu(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_to_cpu(v) for v in obj] + if isinstance(obj, tuple): + return tuple(_to_cpu(v) for v in obj) + return obj + + return _to_cpu(moved_result) + return moved_result + + return self._run_operation_with_lease(instance_id, "process_latent_out", _invoke) async def scale_latent_inpaint( self, instance_id: str, args: tuple, kwargs: dict ) -> Any: - instance = self._get_instance(instance_id) - result = instance.model.scale_latent_inpaint(*args, **kwargs) - try: - target = None - if args and hasattr(args[0], "device"): - target = args[0].device - elif kwargs: - for v in kwargs.values(): - if hasattr(v, "device"): - target = v.device - break - if target is not None and hasattr(result, "to"): - return detach_if_grad(result.to(target)) - except Exception: - logger.debug( - "scale_latent_inpaint: failed to move result to target device", - exc_info=True, - ) - return detach_if_grad(result) + import torch + + def _invoke() -> Any: + instance = self._get_instance(instance_id) + result = instance.model.scale_latent_inpaint(*args, **kwargs) + moved_result = None + try: + target = None + if args and hasattr(args[0], "device"): + target = args[0].device + elif kwargs: + for v in kwargs.values(): + if hasattr(v, "device"): + target = v.device + break + if target is not None and hasattr(result, "to"): + moved_result = detach_if_grad(result.to(target)) + except Exception: + logger.debug( + 
"scale_latent_inpaint: failed to move result to target device", + exc_info=True, + ) + if moved_result is None: + moved_result = detach_if_grad(result) + + is_dynamic_fn = getattr(instance, "is_dynamic", None) + if callable(is_dynamic_fn) and is_dynamic_fn(): + def _to_cpu(obj: Any) -> Any: + if torch.is_tensor(obj): + return obj.detach().cpu() if obj.device.type != "cpu" else obj + if isinstance(obj, dict): + return {k: _to_cpu(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_to_cpu(v) for v in obj] + if isinstance(obj, tuple): + return tuple(_to_cpu(v) for v in obj) + return obj + + return _to_cpu(moved_result) + return moved_result + + return self._run_operation_with_lease( + instance_id, "scale_latent_inpaint", _invoke + ) async def load_lora( self, diff --git a/comfy/isolation/model_sampling_proxy.py b/comfy/isolation/model_sampling_proxy.py index 886c60409..8831ff573 100644 --- a/comfy/isolation/model_sampling_proxy.py +++ b/comfy/isolation/model_sampling_proxy.py @@ -3,6 +3,9 @@ from __future__ import annotations import asyncio import logging +import os +import threading +import time from typing import Any from comfy.isolation.proxies.base import ( @@ -16,6 +19,22 @@ from comfy.isolation.proxies.base import ( logger = logging.getLogger(__name__) +def _describe_value(obj: Any) -> str: + try: + import torch + except Exception: + torch = None + try: + if torch is not None and isinstance(obj, torch.Tensor): + return ( + "Tensor(shape=%s,dtype=%s,device=%s,id=%s)" + % (tuple(obj.shape), obj.dtype, obj.device, id(obj)) + ) + except Exception: + pass + return "%s(id=%s)" % (type(obj).__name__, id(obj)) + + def _prefer_device(*tensors: Any) -> Any: try: import torch @@ -49,6 +68,24 @@ def _to_device(obj: Any, device: Any) -> Any: return obj +def _to_cpu_for_rpc(obj: Any) -> Any: + try: + import torch + except Exception: + return obj + if isinstance(obj, torch.Tensor): + t = obj.detach() if obj.requires_grad else obj + if t.is_cuda: + return 
t.to("cpu") + return t + if isinstance(obj, (list, tuple)): + converted = [_to_cpu_for_rpc(x) for x in obj] + return type(obj)(converted) if isinstance(obj, tuple) else converted + if isinstance(obj, dict): + return {k: _to_cpu_for_rpc(v) for k, v in obj.items()} + return obj + + class ModelSamplingRegistry(BaseRegistry[Any]): _type_prefix = "modelsampling" @@ -196,17 +233,93 @@ class ModelSamplingProxy(BaseProxy[ModelSamplingRegistry]): return self._rpc_caller def _call(self, method_name: str, *args: Any) -> Any: + print( + "ISO:modelsampling_call_enter method=%s instance_id=%s pid=%s" + % (method_name, self._instance_id, os.getpid()), + flush=True, + ) rpc = self._get_rpc() method = getattr(rpc, method_name) + print( + "ISO:modelsampling_call_before_dispatch method=%s instance_id=%s pid=%s" + % (method_name, self._instance_id, os.getpid()), + flush=True, + ) result = method(self._instance_id, *args) + timeout_ms = self._rpc_timeout_ms() + start_epoch = time.time() + start_perf = time.perf_counter() + thread_id = threading.get_ident() + call_id = "%s:%s:%s:%.6f" % ( + self._instance_id, + method_name, + thread_id, + start_perf, + ) + logger.debug( + "ISO:modelsampling_rpc_start method=%s instance_id=%s call_id=%s start_ts=%.6f thread=%s timeout_ms=%s", + method_name, + self._instance_id, + call_id, + start_epoch, + thread_id, + timeout_ms, + ) if asyncio.iscoroutine(result): + result = asyncio.wait_for(result, timeout=timeout_ms / 1000.0) try: asyncio.get_running_loop() - return run_coro_in_new_loop(result) + out = run_coro_in_new_loop(result) except RuntimeError: loop = get_thread_loop() - return loop.run_until_complete(result) - return result + out = loop.run_until_complete(result) + else: + out = result + print( + "ISO:modelsampling_call_after_dispatch method=%s instance_id=%s pid=%s" + % (method_name, self._instance_id, os.getpid()), + flush=True, + ) + logger.debug( + "ISO:modelsampling_rpc_after_await method=%s instance_id=%s call_id=%s out=%s", + 
method_name, + self._instance_id, + call_id, + _describe_value(out), + ) + elapsed_ms = (time.perf_counter() - start_perf) * 1000.0 + logger.debug( + "ISO:modelsampling_rpc_end method=%s instance_id=%s call_id=%s elapsed_ms=%.3f thread=%s", + method_name, + self._instance_id, + call_id, + elapsed_ms, + thread_id, + ) + logger.debug( + "ISO:modelsampling_rpc_return method=%s instance_id=%s call_id=%s", + method_name, + self._instance_id, + call_id, + ) + print( + "ISO:modelsampling_call_return method=%s instance_id=%s pid=%s" + % (method_name, self._instance_id, os.getpid()), + flush=True, + ) + return out + + @staticmethod + def _rpc_timeout_ms() -> int: + raw = os.environ.get( + "COMFY_ISOLATION_MODEL_SAMPLING_RPC_TIMEOUT_MS", + os.environ.get("COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS", "30000"), + ) + try: + timeout_ms = int(raw) + except ValueError: + timeout_ms = 30000 + return max(1, timeout_ms) @property def sigma_min(self) -> Any: @@ -235,10 +348,24 @@ class ModelSamplingProxy(BaseProxy[ModelSamplingRegistry]): def noise_scaling( self, sigma: Any, noise: Any, latent_image: Any, max_denoise: bool = False ) -> Any: - return self._call("noise_scaling", sigma, noise, latent_image, max_denoise) + preferred_device = _prefer_device(noise, latent_image) + out = self._call( + "noise_scaling", + _to_cpu_for_rpc(sigma), + _to_cpu_for_rpc(noise), + _to_cpu_for_rpc(latent_image), + max_denoise, + ) + return _to_device(out, preferred_device) def inverse_noise_scaling(self, sigma: Any, latent: Any) -> Any: - return self._call("inverse_noise_scaling", sigma, latent) + preferred_device = _prefer_device(latent) + out = self._call( + "inverse_noise_scaling", + _to_cpu_for_rpc(sigma), + _to_cpu_for_rpc(latent), + ) + return _to_device(out, preferred_device) def timestep(self, sigma: Any) -> Any: return self._call("timestep", sigma) diff --git a/comfy/isolation/proxies/base.py b/comfy/isolation/proxies/base.py index 90f75b1e3..71cc1943c 100644 --- a/comfy/isolation/proxies/base.py 
+++ b/comfy/isolation/proxies/base.py @@ -2,9 +2,11 @@ from __future__ import annotations import asyncio +import concurrent.futures import logging import os import threading +import time import weakref from typing import Any, Callable, Dict, Generic, Optional, TypeVar @@ -119,6 +121,24 @@ def set_global_loop(loop: asyncio.AbstractEventLoop) -> None: class BaseProxy(Generic[T]): _registry_class: type = BaseRegistry # type: ignore[type-arg] __module__: str = "comfy.isolation.proxies.base" + _TIMEOUT_RPC_METHODS = frozenset( + { + "partially_load", + "partially_unload", + "load", + "patch_model", + "unpatch_model", + "inner_model_apply_model", + "memory_required", + "model_dtype", + "inner_model_memory_required", + "inner_model_extra_conds_shapes", + "inner_model_extra_conds", + "process_latent_in", + "process_latent_out", + "scale_latent_inpaint", + } + ) def __init__( self, @@ -148,39 +168,89 @@ class BaseProxy(Generic[T]): ) return self._rpc_caller + def _rpc_timeout_ms_for_method(self, method_name: str) -> Optional[int]: + if method_name not in self._TIMEOUT_RPC_METHODS: + return None + try: + timeout_ms = int( + os.environ.get("COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS", "120000") + ) + except ValueError: + timeout_ms = 120000 + return max(1, timeout_ms) + def _call_rpc(self, method_name: str, *args: Any, **kwargs: Any) -> Any: rpc = self._get_rpc() method = getattr(rpc, method_name) + timeout_ms = self._rpc_timeout_ms_for_method(method_name) coro = method(self._instance_id, *args, **kwargs) + if timeout_ms is not None: + coro = asyncio.wait_for(coro, timeout=timeout_ms / 1000.0) - # If we have a global loop (Main Thread Loop), use it for dispatch from worker threads - if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running(): - try: - # If we are already in the global loop, we can't block on it? - # Actually, this method is synchronous (__getattr__ -> lambda). - # If called from async context in main loop, we need to handle that. 
- curr_loop = asyncio.get_running_loop() - if curr_loop is _GLOBAL_LOOP: - # We are in the main loop. We cannot await/block here if we are just a sync function. - # But proxies are often called from sync code. - # If called from sync code in main loop, creating a new loop is bad. - # But we can't await `coro`. - # This implies proxies MUST be awaited if called from async context? - # Existing code used `run_coro_in_new_loop` which is weird. - # Let's trust that if we are in a thread (RuntimeError on get_running_loop), - # we use run_coroutine_threadsafe. - pass - except RuntimeError: - # No running loop - we are in a worker thread. - future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP) - return future.result() + start_epoch = time.time() + start_perf = time.perf_counter() + thread_id = threading.get_ident() + try: + running_loop = asyncio.get_running_loop() + loop_id: Optional[int] = id(running_loop) + except RuntimeError: + loop_id = None + logger.debug( + "ISO:rpc_start proxy=%s method=%s instance_id=%s start_ts=%.6f " + "thread=%s loop=%s timeout_ms=%s", + self.__class__.__name__, + method_name, + self._instance_id, + start_epoch, + thread_id, + loop_id, + timeout_ms, + ) try: - asyncio.get_running_loop() - return run_coro_in_new_loop(coro) - except RuntimeError: - loop = get_thread_loop() - return loop.run_until_complete(coro) + # If we have a global loop (Main Thread Loop), use it for dispatch from worker threads + if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running(): + try: + curr_loop = asyncio.get_running_loop() + if curr_loop is _GLOBAL_LOOP: + pass + except RuntimeError: + # No running loop - we are in a worker thread. 
+ future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP) + return future.result( + timeout=(timeout_ms / 1000.0) if timeout_ms is not None else None + ) + + try: + asyncio.get_running_loop() + return run_coro_in_new_loop(coro) + except RuntimeError: + loop = get_thread_loop() + return loop.run_until_complete(coro) + except asyncio.TimeoutError as exc: + raise TimeoutError( + f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} " + f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})" + ) from exc + except concurrent.futures.TimeoutError as exc: + raise TimeoutError( + f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} " + f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})" + ) from exc + finally: + end_epoch = time.time() + elapsed_ms = (time.perf_counter() - start_perf) * 1000.0 + logger.debug( + "ISO:rpc_end proxy=%s method=%s instance_id=%s end_ts=%.6f " + "elapsed_ms=%.3f thread=%s loop=%s", + self.__class__.__name__, + method_name, + self._instance_id, + end_epoch, + elapsed_ms, + thread_id, + loop_id, + ) def __getstate__(self) -> Dict[str, Any]: return {"_instance_id": self._instance_id} diff --git a/comfy/samplers.py b/comfy/samplers.py index 8c1d4cc61..6daf13ede 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -208,13 +208,18 @@ def calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tenso return handler.execute(_calc_cond_batch_outer, model, conds, x_in, timestep, model_options) def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options): + print("CC_DEBUG2: enter _calc_cond_batch_outer", flush=True) executor = comfy.patcher_extension.WrapperExecutor.new_executor( _calc_cond_batch, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True) ) - return executor.execute(model, conds, x_in, timestep, model_options) + print("CC_DEBUG2: before 
_calc_cond_batch executor.execute", flush=True) + result = executor.execute(model, conds, x_in, timestep, model_options) + print("CC_DEBUG2: after _calc_cond_batch executor.execute", flush=True) + return result def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options): + print("CC_DEBUG2: enter _calc_cond_batch", flush=True) isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1" out_conds = [] out_counts = [] @@ -247,7 +252,9 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens if has_default_conds: finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options) + print("CC_DEBUG: before prepare_state", flush=True) model.current_patcher.prepare_state(timestep) + print("CC_DEBUG: after prepare_state", flush=True) # run every hooked_to_run separately for hooks, to_run in hooked_to_run.items(): @@ -262,7 +269,9 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens to_batch_temp.reverse() to_batch = to_batch_temp[:1] + print("CC_DEBUG: before get_free_memory", flush=True) free_memory = model.current_patcher.get_free_memory(x_in.device) + print("CC_DEBUG: after get_free_memory", flush=True) for i in range(1, len(to_batch_temp) + 1): batch_amount = to_batch_temp[:len(to_batch_temp)//i] input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] @@ -272,7 +281,10 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens for k, v in to_run[tt][0].conditioning.items(): cond_shapes[k].append(v.size()) - if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory: + print("CC_DEBUG: before memory_required", flush=True) + memory_required = model.memory_required(input_shape, cond_shapes=cond_shapes) + print("CC_DEBUG: after memory_required", flush=True) + if memory_required * 1.5 < free_memory: to_batch = batch_amount break @@ -309,7 
+321,9 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens else: timestep_ = torch.cat([timestep] * batch_chunks) + print("CC_DEBUG: before apply_hooks", flush=True) transformer_options = model.current_patcher.apply_hooks(hooks=hooks) + print("CC_DEBUG: after apply_hooks", flush=True) if 'transformer_options' in model_options: transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options, model_options['transformer_options'], @@ -331,9 +345,13 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options) if 'model_function_wrapper' in model_options: + print("CC_DEBUG: before apply_model", flush=True) output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) + print("CC_DEBUG: after apply_model", flush=True) else: + print("CC_DEBUG: before apply_model", flush=True) output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks) + print("CC_DEBUG: after apply_model", flush=True) for o in range(batch_chunks): cond_index = cond_or_uncond[o] @@ -768,6 +786,7 @@ class KSAMPLER(Sampler): self.inpaint_options = inpaint_options def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): + print("CC_DEBUG3: enter KSAMPLER.sample", flush=True) extra_args["denoise_mask"] = denoise_mask model_k = KSamplerX0Inpaint(model_wrap, sigmas) model_k.latent_image = latent_image @@ -777,15 +796,46 @@ class KSAMPLER(Sampler): else: model_k.noise = noise - noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas[0], noise, latent_image, self.max_denoise(model_wrap, sigmas)) + print("CC_DEBUG3: before max_denoise", flush=True) + max_denoise = self.max_denoise(model_wrap, sigmas) + print("CC_DEBUG3: after max_denoise", 
flush=True) + print("CC_DEBUG3: before model_sampling_attr", flush=True) + model_sampling = model_wrap.inner_model.model_sampling + print( + "CC_DEBUG3: after model_sampling_attr type=%s id=%s instance_id=%s" + % ( + type(model_sampling).__name__, + id(model_sampling), + getattr(model_sampling, "_instance_id", "n/a"), + ), + flush=True, + ) + print("CC_DEBUG3: before noise_scaling_call", flush=True) + try: + noise_scaled = model_sampling.noise_scaling( + sigmas[0], noise, latent_image, max_denoise + ) + print("CC_DEBUG3: after noise_scaling_call", flush=True) + except Exception as e: + print( + "CC_DEBUG3: noise_scaling_exception type=%s msg=%s" + % (type(e).__name__, str(e)), + flush=True, + ) + raise + noise = noise_scaled + print("CC_DEBUG3: after noise_assignment", flush=True) k_callback = None total_steps = len(sigmas) - 1 if callback is not None: k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps) + print("CC_DEBUG3: before sampler_function", flush=True) samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options) + print("CC_DEBUG3: after sampler_function", flush=True) samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples) + print("CC_DEBUG3: after inverse_noise_scaling", flush=True) return samples @@ -825,10 +875,16 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N for k in conds: calculate_start_end_timesteps(model, conds[k]) + print('GP_DEBUG: before hasattr extra_conds', flush=True) + print('GP_DEBUG: before hasattr extra_conds', flush=True) if hasattr(model, 'extra_conds'): + print('GP_DEBUG: has extra_conds!', flush=True) + print('GP_DEBUG: has extra_conds!', flush=True) for k in conds: conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed, latent_shapes=latent_shapes) + print('GP_DEBUG: 
before area make sure loop', flush=True) + print('GP_DEBUG: before area make sure loop', flush=True) #make sure each cond area has an opposite one with the same area for k in conds: for c in conds[k]: @@ -842,8 +898,11 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N for hook in c['hooks'].hooks: hook.initialize_timesteps(model) + print('GP_DEBUG: before pre_run_control loop', flush=True) for k in conds: + print('GP_DEBUG: calling pre_run_control for key:', k, flush=True) pre_run_control(model, conds[k]) + print('GP_DEBUG: after pre_run_control loop', flush=True) if "positive" in conds: positive = conds["positive"] @@ -1005,6 +1064,8 @@ class CFGGuider: self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed, latent_shapes=latent_shapes) + print("GP_DEBUG: process_conds finished", flush=True) + print("GP_DEBUG: process_conds finished", flush=True) extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options) extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas extra_args = {"model_options": extra_model_options, "seed": seed} @@ -1014,6 +1075,7 @@ class CFGGuider: sampler, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, extra_args["model_options"], is_model_options=True) ) + print("GP_DEBUG: before executor.execute", flush=True) samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) return self.inner_model.process_latent_out(samples.to(torch.float32)) diff --git a/execution.py b/execution.py index aceb75340..d04498a0d 100644 --- a/execution.py +++ b/execution.py @@ -696,6 +696,24 @@ class PromptExecutor: except Exception: logging.debug("][ EX:flush_running_extensions_transport_state failed", exc_info=True) + async def _wait_model_patcher_quiescence_safe( + self, + *, + fail_loud: bool = False, + timeout_ms: int = 120000, + 
marker: str = "EX:wait_model_patcher_idle", + ) -> None: + try: + from comfy.isolation import wait_for_model_patcher_quiescence + + await wait_for_model_patcher_quiescence( + timeout_ms=timeout_ms, fail_loud=fail_loud, marker=marker + ) + except Exception: + if fail_loud: + raise + logging.debug("][ EX:wait_model_patcher_quiescence failed", exc_info=True) + def add_message(self, event, data: dict, broadcast: bool): data = { **data, @@ -766,6 +784,11 @@ class PromptExecutor: # Boundary cleanup runs at the start of the next workflow in # isolation mode, matching non-isolated "next prompt" timing. self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args) + await self._wait_model_patcher_quiescence_safe( + fail_loud=False, + timeout_ms=120000, + marker="EX:boundary_cleanup_wait_idle", + ) await self._flush_running_extensions_transport_state_safe() comfy.model_management.unload_all_models() comfy.model_management.cleanup_models_gc() @@ -806,6 +829,11 @@ class PromptExecutor: for node_id in execution_list.pendingNodes.keys(): class_type = dynamic_prompt.get_node(node_id)["class_type"] pending_class_types.add(class_type) + await self._wait_model_patcher_quiescence_safe( + fail_loud=True, + timeout_ms=120000, + marker="EX:notify_graph_wait_idle", + ) await self._notify_execution_graph_safe(pending_class_types, fail_loud=True) while not execution_list.is_empty():