isolation+dynamicvram: stabilize ModelPatcher RPC path, add diagnostics; known process_latent_in timeout remains
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run

- harden isolation ModelPatcher proxy/registry behavior for DynamicVRAM-backed patchers
- improve serializer/adapter boundaries (device/dtype/model refs) to reduce pre-inference lockups
- add structured ISO registry/modelsampling telemetry and explicit RPC timeout surfacing
- preserve isolation-first lifecycle handling and boundary cleanup sequencing
- validate isolated workflows: most targeted runs now complete under
  --use-sage-attention --use-process-isolation --disable-cuda-malloc

Known issue (reproducible):
- isolation_99_full_iso_stack still times out at SamplerCustom_ISO path
- failure is explicit RPC timeout:
  ModelPatcherProxy.process_latent_in(instance_id=model_0, timeout_ms=120000)
- this indicates that the remaining stall is on the process_latent_in RPC path, not on generic startup or manager fetch
This commit is contained in:
John Pollock 2026-03-04 10:41:33 -06:00
parent 6ba03a0747
commit 26edd5663d
9 changed files with 927 additions and 154 deletions

View File

@ -26,6 +26,7 @@ PYISOLATE_VENV_ROOT.mkdir(parents=True, exist_ok=True)
logger = logging.getLogger(__name__)
_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
_MODEL_PATCHER_IDLE_TIMEOUT_MS = 120000
def initialize_proxies() -> None:
@ -168,6 +169,11 @@ def _get_class_types_for_extension(extension_name: str) -> Set[str]:
async def notify_execution_graph(needed_class_types: Set[str]) -> None:
"""Evict running extensions not needed for current execution."""
await wait_for_model_patcher_quiescence(
timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS,
fail_loud=True,
marker="ISO:notify_graph_wait_idle",
)
async def _stop_extension(
ext_name: str, extension: "ComfyNodeExtension", reason: str
@ -237,6 +243,11 @@ async def notify_execution_graph(needed_class_types: Set[str]) -> None:
async def flush_running_extensions_transport_state() -> int:
await wait_for_model_patcher_quiescence(
timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS,
fail_loud=True,
marker="ISO:flush_transport_wait_idle",
)
total_flushed = 0
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
flush_fn = getattr(extension, "flush_transport_state", None)
@ -263,6 +274,50 @@ async def flush_running_extensions_transport_state() -> int:
return total_flushed
async def wait_for_model_patcher_quiescence(
    timeout_ms: int = _MODEL_PATCHER_IDLE_TIMEOUT_MS,
    *,
    fail_loud: bool = False,
    marker: str = "ISO:wait_model_patcher_idle",
) -> bool:
    """Wait until every registered ModelPatcher has no in-flight operation.

    Args:
        timeout_ms: Maximum time to wait for all patchers to go idle.
        fail_loud: When True, raise on timeout (and propagate any registry
            failure) instead of degrading to a False return.
        marker: Tag included in log lines so call sites are distinguishable.

    Returns:
        True if quiescence was reached within the deadline; False otherwise
        (only when ``fail_loud`` is False).

    Raises:
        TimeoutError: Quiescence not reached in time and ``fail_loud`` is True.
    """
    try:
        from comfy.isolation.model_patcher_proxy_registry import ModelPatcherRegistry

        registry = ModelPatcherRegistry()
        start = time.perf_counter()
        idle = await registry.wait_all_idle(timeout_ms)
        elapsed_ms = (time.perf_counter() - start) * 1000.0
        if idle:
            logger.debug(
                "%s %s idle=1 timeout_ms=%d elapsed_ms=%.3f",
                LOG_PREFIX,
                marker,
                timeout_ms,
                elapsed_ms,
            )
            return True
        # Timed out: capture per-instance operation state for diagnostics
        # before deciding whether to raise or degrade gracefully.
        states = await registry.get_all_operation_states()
        logger.error(
            "%s %s idle_timeout timeout_ms=%d elapsed_ms=%.3f states=%s",
            LOG_PREFIX,
            marker,
            timeout_ms,
            elapsed_ms,
            states,
        )
    except Exception:
        if fail_loud:
            raise
        logger.debug("%s %s failed", LOG_PREFIX, marker, exc_info=True)
        return False
    # Raise outside the broad except so the timeout error does not depend on
    # the catch-and-reraise path (previously the TimeoutError was raised
    # inside the try and only survived because the except re-checked
    # fail_loud — fragile if that handler ever changes).
    if fail_loud:
        raise TimeoutError(
            f"ModelPatcherRegistry did not quiesce within {timeout_ms} ms"
        )
    return False
def get_claimed_paths() -> Set[Path]:
return _CLAIMED_PATHS
@ -320,6 +375,7 @@ __all__ = [
"await_isolation_loading",
"notify_execution_graph",
"flush_running_extensions_transport_state",
"wait_for_model_patcher_quiescence",
"get_claimed_paths",
"update_rpc_event_loops",
"IsolatedNodeSpec",

View File

@ -83,6 +83,33 @@ class ComfyUIAdapter(IsolationAdapter):
logging.getLogger(pkg_name).setLevel(logging.ERROR)
def register_serializers(self, registry: SerializerRegistryProtocol) -> None:
import torch
def serialize_device(obj: Any) -> Dict[str, Any]:
return {"__type__": "device", "device_str": str(obj)}
def deserialize_device(data: Dict[str, Any]) -> Any:
return torch.device(data["device_str"])
registry.register("device", serialize_device, deserialize_device)
_VALID_DTYPES = {
"float16", "float32", "float64", "bfloat16",
"int8", "int16", "int32", "int64",
"uint8", "bool",
}
def serialize_dtype(obj: Any) -> Dict[str, Any]:
return {"__type__": "dtype", "dtype_str": str(obj)}
def deserialize_dtype(data: Dict[str, Any]) -> Any:
dtype_name = data["dtype_str"].replace("torch.", "")
if dtype_name not in _VALID_DTYPES:
raise ValueError(f"Invalid dtype: {data['dtype_str']}")
return getattr(torch, dtype_name)
registry.register("dtype", serialize_dtype, deserialize_dtype)
def serialize_model_patcher(obj: Any) -> Dict[str, Any]:
# Child-side: must already have _instance_id (proxy)
if os.environ.get("PYISOLATE_CHILD") == "1":
@ -193,6 +220,10 @@ class ComfyUIAdapter(IsolationAdapter):
f"ModelSampling in child lacks _instance_id: "
f"{type(obj).__module__}.{type(obj).__name__}"
)
# Host-side pass-through for proxies: do not re-register a proxy as a
# new ModelSamplingRef, or we create proxy-of-proxy indirection.
if hasattr(obj, "_instance_id"):
return {"__type__": "ModelSamplingRef", "ms_id": obj._instance_id}
# Host-side: register with ModelSamplingRegistry and return JSON-safe dict
ms_id = ModelSamplingRegistry().register(obj)
return {"__type__": "ModelSamplingRef", "ms_id": ms_id}
@ -211,22 +242,21 @@ class ComfyUIAdapter(IsolationAdapter):
else:
return ModelSamplingRegistry()._get_instance(data["ms_id"])
# Register ModelSampling type and proxy
registry.register(
"ModelSamplingDiscrete",
serialize_model_sampling,
deserialize_model_sampling,
)
registry.register(
"ModelSamplingContinuousEDM",
serialize_model_sampling,
deserialize_model_sampling,
)
registry.register(
"ModelSamplingContinuousV",
serialize_model_sampling,
deserialize_model_sampling,
)
# Register all ModelSampling* and StableCascadeSampling classes dynamically
import comfy.model_sampling
for ms_cls in vars(comfy.model_sampling).values():
if not isinstance(ms_cls, type):
continue
if not issubclass(ms_cls, torch.nn.Module):
continue
if not (ms_cls.__name__.startswith("ModelSampling") or ms_cls.__name__ == "StableCascadeSampling"):
continue
registry.register(
ms_cls.__name__,
serialize_model_sampling,
deserialize_model_sampling,
)
registry.register(
"ModelSamplingProxy", serialize_model_sampling, deserialize_model_sampling
)

View File

@ -382,6 +382,10 @@ class ComfyNodeExtension(ExtensionBase):
if type(result).__name__ == "NodeOutput":
result = result.args
print(
f"{LOG_PREFIX} ISO:child_result_ready node={node_name} type={type(result).__name__}",
flush=True,
)
if self._is_comfy_protocol_return(result):
logger.debug(
"%s ISO:child_execute_done ext=%s node=%s protocol_return=1",
@ -389,10 +393,17 @@ class ComfyNodeExtension(ExtensionBase):
getattr(self, "name", "?"),
node_name,
)
return self._wrap_unpicklable_objects(result)
print(f"{LOG_PREFIX} ISO:child_wrap_start node={node_name} protocol=1", flush=True)
wrapped = self._wrap_unpicklable_objects(result)
print(f"{LOG_PREFIX} ISO:child_wrap_done node={node_name} protocol=1", flush=True)
return wrapped
if not isinstance(result, tuple):
result = (result,)
print(
f"{LOG_PREFIX} ISO:child_result_tuple node={node_name} outputs={len(result)}",
flush=True,
)
logger.debug(
"%s ISO:child_execute_done ext=%s node=%s protocol_return=0 outputs=%d",
LOG_PREFIX,
@ -400,7 +411,10 @@ class ComfyNodeExtension(ExtensionBase):
node_name,
len(result),
)
return self._wrap_unpicklable_objects(result)
print(f"{LOG_PREFIX} ISO:child_wrap_start node={node_name} protocol=0", flush=True)
wrapped = self._wrap_unpicklable_objects(result)
print(f"{LOG_PREFIX} ISO:child_wrap_done node={node_name} protocol=0", flush=True)
return wrapped
async def flush_transport_state(self) -> int:
if os.environ.get("PYISOLATE_ISOLATION_ACTIVE") != "1":
@ -443,7 +457,10 @@ class ComfyNodeExtension(ExtensionBase):
if isinstance(data, (str, int, float, bool, type(None))):
return data
if isinstance(data, torch.Tensor):
return data.detach() if data.requires_grad else data
tensor = data.detach() if data.requires_grad else data
if os.environ.get("PYISOLATE_CHILD") == "1" and tensor.device.type != "cpu":
return tensor.cpu()
return tensor
# Special-case clip vision outputs: preserve attribute access by packing fields
if hasattr(data, "penultimate_hidden_states") or hasattr(

View File

@ -338,6 +338,34 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
def apply_model(self, *args, **kwargs) -> Any:
import torch
def _preferred_device() -> Any:
for value in args:
if isinstance(value, torch.Tensor):
return value.device
for value in kwargs.values():
if isinstance(value, torch.Tensor):
return value.device
return None
def _move_result_to_device(obj: Any, device: Any) -> Any:
if device is None:
return obj
if isinstance(obj, torch.Tensor):
return obj.to(device) if obj.device != device else obj
if isinstance(obj, dict):
return {k: _move_result_to_device(v, device) for k, v in obj.items()}
if isinstance(obj, list):
return [_move_result_to_device(v, device) for v in obj]
if isinstance(obj, tuple):
return tuple(_move_result_to_device(v, device) for v in obj)
return obj
# DynamicVRAM models must keep load/offload decisions in host process.
# Child-side CUDA staging here can deadlock before first inference RPC.
if self.is_dynamic():
out = self._call_rpc("inner_model_apply_model", args, kwargs)
return _move_result_to_device(out, _preferred_device())
required_bytes = self._cpu_tensor_bytes(args) + self._cpu_tensor_bytes(kwargs)
self._ensure_apply_model_headroom(required_bytes)
@ -360,7 +388,8 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
args_cuda = _to_cuda(args)
kwargs_cuda = _to_cuda(kwargs)
return self._call_rpc("inner_model_apply_model", args_cuda, kwargs_cuda)
out = self._call_rpc("inner_model_apply_model", args_cuda, kwargs_cuda)
return _move_result_to_device(out, _preferred_device())
def model_state_dict(self, filter_prefix: Optional[str] = None) -> Any:
keys = self._call_rpc("model_state_dict", filter_prefix)
@ -526,6 +555,13 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
def memory_required(self, input_shape: Any) -> Any:
return self._call_rpc("memory_required", input_shape)
    def get_operation_state(self) -> Dict[str, Any]:
        """Fetch this patcher's operation telemetry from the registry via RPC.

        Returns an empty dict when the RPC reply is not a dict, so callers
        never have to guard against malformed responses.
        """
        state = self._call_rpc("get_operation_state")
        return state if isinstance(state, dict) else {}
    def wait_for_idle(self, timeout_ms: int = 0) -> bool:
        """Block (via RPC) until the backing patcher has no active operations.

        Returns True if idle was reached within ``timeout_ms``; a value of 0
        is forwarded as-is (presumably "wait forever" — confirm against the
        registry's wait_for_idle semantics).
        """
        return bool(self._call_rpc("wait_for_idle", timeout_ms))
def is_dynamic(self) -> bool:
return bool(self._call_rpc("is_dynamic"))
@ -771,6 +807,7 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
class _InnerModelProxy:
def __init__(self, parent: ModelPatcherProxy):
self._parent = parent
self._model_sampling = None
def __getattr__(self, name: str) -> Any:
if name.startswith("_"):
@ -793,7 +830,11 @@ class _InnerModelProxy:
manage_lifecycle=False,
)
if name == "model_sampling":
return self._parent._call_rpc("get_model_object", "model_sampling")
if self._model_sampling is None:
self._model_sampling = self._parent._call_rpc(
"get_model_object", "model_sampling"
)
return self._model_sampling
if name == "extra_conds_shapes":
return lambda *a, **k: self._parent._call_rpc(
"inner_model_extra_conds_shapes", a, k

View File

@ -2,8 +2,12 @@
# RPC server for ModelPatcher isolation (child process)
from __future__ import annotations
import asyncio
import gc
import logging
import threading
import time
from dataclasses import dataclass, field
from typing import Any, Optional, List
try:
@ -43,12 +47,191 @@ from comfy.isolation.proxies.base import (
logger = logging.getLogger(__name__)
@dataclass
class _OperationState:
    """Per-instance operation telemetry for one registered ModelPatcher.

    Mutated under the registry's condition variable; ``lease`` is the
    separate lock that serializes actual operation execution.
    """

    lease: threading.Lock = field(default_factory=threading.Lock)
    active_count: int = 0  # operations currently executing
    active_by_method: dict[str, int] = field(default_factory=dict)  # method -> active count
    total_operations: int = 0  # lifetime count of operations started
    last_method: Optional[str] = None  # most recently started method name
    last_started_ts: Optional[float] = None  # epoch seconds of last start
    last_ended_ts: Optional[float] = None  # epoch seconds of last completion
    last_elapsed_ms: Optional[float] = None  # duration of last completed op
    last_error: Optional[str] = None  # repr() of last failure, None on success
    last_thread_id: Optional[int] = None  # thread ident that started the last op
    last_loop_id: Optional[int] = None  # id() of running asyncio loop at start, if any
last_loop_id: Optional[int] = None
class ModelPatcherRegistry(BaseRegistry[Any]):
_type_prefix = "model"
    def __init__(self) -> None:
        super().__init__()
        # IDs whose cleanup() already ran; their registry entries are removed
        # lazily by sweep_pending_cleanup().
        self._pending_cleanup_ids: set[str] = set()
        # Per-instance operation telemetry (see _OperationState).
        self._operation_states: dict[str, _OperationState] = {}
        # Condition built on the base registry lock so telemetry updates and
        # idle waits share a single mutex.
        self._operation_state_cv = threading.Condition(self._lock)
def _get_or_create_operation_state(self, instance_id: str) -> _OperationState:
state = self._operation_states.get(instance_id)
if state is None:
state = _OperationState()
self._operation_states[instance_id] = state
return state
    def _begin_operation(self, instance_id: str, method_name: str) -> tuple[float, float]:
        """Record the start of an operation; returns (epoch_ts, perf_counter_ts).

        The perf_counter value is later fed to _end_operation to compute the
        elapsed time; the epoch value is stored for external telemetry.
        """
        start_epoch = time.time()
        start_perf = time.perf_counter()
        with self._operation_state_cv:
            state = self._get_or_create_operation_state(instance_id)
            state.active_count += 1
            state.active_by_method[method_name] = (
                state.active_by_method.get(method_name, 0) + 1
            )
            state.total_operations += 1
            state.last_method = method_name
            state.last_started_ts = start_epoch
            state.last_thread_id = threading.get_ident()
            # Operations may run on plain worker threads with no event loop;
            # record the loop id only when one is actually running.
            try:
                state.last_loop_id = id(asyncio.get_running_loop())
            except RuntimeError:
                state.last_loop_id = None
        logger.debug(
            "ISO:registry_op_start instance_id=%s method=%s start_ts=%.6f thread=%s loop=%s",
            instance_id,
            method_name,
            start_epoch,
            threading.get_ident(),
            state.last_loop_id,
        )
        return start_epoch, start_perf
    def _end_operation(
        self,
        instance_id: str,
        method_name: str,
        start_perf: float,
        error: Optional[BaseException] = None,
    ) -> None:
        """Record the completion (or failure) of an operation started earlier.

        ``start_perf`` is the perf_counter timestamp returned by
        _begin_operation; ``error`` is the exception that aborted the
        operation, if any.
        """
        end_epoch = time.time()
        elapsed_ms = (time.perf_counter() - start_perf) * 1000.0
        with self._operation_state_cv:
            state = self._get_or_create_operation_state(instance_id)
            # Clamp at zero so a stray double-end never drives the count negative.
            state.active_count = max(0, state.active_count - 1)
            if method_name in state.active_by_method:
                remaining = state.active_by_method[method_name] - 1
                if remaining <= 0:
                    state.active_by_method.pop(method_name, None)
                else:
                    state.active_by_method[method_name] = remaining
            state.last_ended_ts = end_epoch
            state.last_elapsed_ms = elapsed_ms
            state.last_error = None if error is None else repr(error)
            # Wake idle waiters only when this instance has fully drained.
            if state.active_count == 0:
                self._operation_state_cv.notify_all()
        logger.debug(
            "ISO:registry_op_end instance_id=%s method=%s end_ts=%.6f elapsed_ms=%.3f error=%s",
            instance_id,
            method_name,
            end_epoch,
            elapsed_ms,
            None if error is None else type(error).__name__,
        )
    def _run_operation_with_lease(self, instance_id: str, method_name: str, fn):
        """Run ``fn`` under the instance's lease lock with begin/end telemetry.

        The lease serializes operations per instance; the condition variable
        is held only long enough to fetch the lease, never while ``fn`` runs.
        """
        with self._operation_state_cv:
            state = self._get_or_create_operation_state(instance_id)
            lease = state.lease
        with lease:
            _, start_perf = self._begin_operation(instance_id, method_name)
            try:
                result = fn()
            except Exception as exc:
                # End-with-error before re-raising so active_count is released
                # even on failure.
                self._end_operation(instance_id, method_name, start_perf, error=exc)
                raise
            self._end_operation(instance_id, method_name, start_perf)
            return result
def _snapshot_operation_state(self, instance_id: str) -> dict[str, Any]:
with self._operation_state_cv:
state = self._operation_states.get(instance_id)
if state is None:
return {
"instance_id": instance_id,
"active_count": 0,
"active_methods": [],
"total_operations": 0,
"last_method": None,
"last_started_ts": None,
"last_ended_ts": None,
"last_elapsed_ms": None,
"last_error": None,
"last_thread_id": None,
"last_loop_id": None,
}
return {
"instance_id": instance_id,
"active_count": state.active_count,
"active_methods": sorted(state.active_by_method.keys()),
"total_operations": state.total_operations,
"last_method": state.last_method,
"last_started_ts": state.last_started_ts,
"last_ended_ts": state.last_ended_ts,
"last_elapsed_ms": state.last_elapsed_ms,
"last_error": state.last_error,
"last_thread_id": state.last_thread_id,
"last_loop_id": state.last_loop_id,
}
    def unregister_sync(self, instance_id: str) -> None:
        """Synchronously drop an instance and its telemetry from the registry.

        Safe to call for unknown ids (all removals are pop/discard). Waiters
        are woken because a removed id now reads as idle.
        """
        with self._operation_state_cv:
            instance = self._registry.pop(instance_id, None)
            if instance is not None:
                self._id_map.pop(id(instance), None)
            self._pending_cleanup_ids.discard(instance_id)
            self._operation_states.pop(instance_id, None)
            self._operation_state_cv.notify_all()
    async def get_operation_state(self, instance_id: str) -> dict[str, Any]:
        """RPC endpoint: telemetry snapshot for a single instance."""
        return self._snapshot_operation_state(instance_id)
    async def get_all_operation_states(self) -> dict[str, dict[str, Any]]:
        """RPC endpoint: telemetry snapshots for every tracked instance."""
        with self._operation_state_cv:
            ids = sorted(self._operation_states.keys())
        # NOTE(review): the comprehension must run after the lock above is
        # released — _snapshot_operation_state re-acquires the same condition
        # per id, which would deadlock under a non-reentrant base lock.
        # Confirm the base registry lock type if this is ever restructured.
        return {instance_id: self._snapshot_operation_state(instance_id) for instance_id in ids}
    async def wait_for_idle(self, instance_id: str, timeout_ms: int = 0) -> bool:
        """Wait until *instance_id* has no active operations.

        ``timeout_ms <= 0`` waits indefinitely. Returns True when idle (or
        the id is unknown), False on timeout.

        NOTE(review): this blocks the calling thread on a threading.Condition
        inside an ``async def``; if this coroutine ever runs on the RPC event
        loop itself it will stall the loop until idle — confirm it is always
        dispatched to a worker thread.
        """
        timeout_s = None if timeout_ms <= 0 else (timeout_ms / 1000.0)
        deadline = None if timeout_s is None else (time.monotonic() + timeout_s)
        with self._operation_state_cv:
            while True:
                active = self._operation_states.get(instance_id)
                # Missing state means the id never ran anything (or was
                # unregistered) — treat as idle.
                if active is None or active.active_count == 0:
                    return True
                if deadline is None:
                    self._operation_state_cv.wait()
                    continue
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    return False
                self._operation_state_cv.wait(timeout=remaining)
    async def wait_all_idle(self, timeout_ms: int = 0) -> bool:
        """Wait until no tracked instance has any active operation.

        ``timeout_ms <= 0`` waits indefinitely. Returns True on full
        quiescence, False on timeout.

        NOTE(review): like wait_for_idle, this performs a blocking
        Condition.wait inside an ``async def`` — it must not run on the RPC
        event loop thread or it will stall the loop.
        """
        timeout_s = None if timeout_ms <= 0 else (timeout_ms / 1000.0)
        deadline = None if timeout_s is None else (time.monotonic() + timeout_s)
        with self._operation_state_cv:
            while True:
                has_active = any(
                    state.active_count > 0 for state in self._operation_states.values()
                )
                if not has_active:
                    return True
                if deadline is None:
                    self._operation_state_cv.wait()
                    continue
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    return False
                self._operation_state_cv.wait(timeout=remaining)
async def clone(self, instance_id: str) -> str:
instance = self._get_instance(instance_id)
@ -62,8 +245,10 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
return False
async def get_model_object(self, instance_id: str, name: str) -> Any:
print(f"GP_DEBUG: get_model_object START for name={name}", flush=True)
instance = self._get_instance(instance_id)
if name == "model":
print(f"GP_DEBUG: get_model_object END for name={name} (ModelObject)", flush=True)
return f"<ModelObject: {type(instance.model).__name__}>"
result = instance.get_model_object(name)
if name == "model_sampling":
@ -73,8 +258,16 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
)
registry = ModelSamplingRegistry()
sampling_id = registry.register(result)
# Preserve identity when upstream already returned a proxy. Re-registering
# a proxy object creates proxy-of-proxy call chains.
if isinstance(result, ModelSamplingProxy):
sampling_id = result._instance_id
else:
sampling_id = registry.register(result)
print(f"GP_DEBUG: get_model_object END for name={name} (model_sampling)", flush=True)
return ModelSamplingProxy(sampling_id, registry)
print(f"GP_DEBUG: get_model_object END for name={name} (fallthrough)", flush=True)
return detach_if_grad(result)
async def get_model_options(self, instance_id: str) -> dict:
@ -163,7 +356,11 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
return self._get_instance(instance_id).lowvram_patch_counter()
async def memory_required(self, instance_id: str, input_shape: Any) -> Any:
return self._get_instance(instance_id).memory_required(input_shape)
return self._run_operation_with_lease(
instance_id,
"memory_required",
lambda: self._get_instance(instance_id).memory_required(input_shape),
)
async def is_dynamic(self, instance_id: str) -> bool:
instance = self._get_instance(instance_id)
@ -186,7 +383,11 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
return None
async def model_dtype(self, instance_id: str) -> Any:
return self._get_instance(instance_id).model_dtype()
return self._run_operation_with_lease(
instance_id,
"model_dtype",
lambda: self._get_instance(instance_id).model_dtype(),
)
async def model_patches_to(self, instance_id: str, device: Any) -> Any:
return self._get_instance(instance_id).model_patches_to(device)
@ -198,8 +399,12 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
extra_memory: Any,
force_patch_weights: bool = False,
) -> Any:
return self._get_instance(instance_id).partially_load(
device, extra_memory, force_patch_weights=force_patch_weights
return self._run_operation_with_lease(
instance_id,
"partially_load",
lambda: self._get_instance(instance_id).partially_load(
device, extra_memory, force_patch_weights=force_patch_weights
),
)
async def partially_unload(
@ -209,8 +414,12 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
memory_to_free: int = 0,
force_patch_weights: bool = False,
) -> int:
return self._get_instance(instance_id).partially_unload(
device_to, memory_to_free, force_patch_weights
return self._run_operation_with_lease(
instance_id,
"partially_unload",
lambda: self._get_instance(instance_id).partially_unload(
device_to, memory_to_free, force_patch_weights
),
)
async def load(
@ -221,8 +430,12 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
force_patch_weights: bool = False,
full_load: bool = False,
) -> None:
self._get_instance(instance_id).load(
device_to, lowvram_model_memory, force_patch_weights, full_load
self._run_operation_with_lease(
instance_id,
"load",
lambda: self._get_instance(instance_id).load(
device_to, lowvram_model_memory, force_patch_weights, full_load
),
)
async def patch_model(
@ -233,20 +446,29 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
load_weights: bool = True,
force_patch_weights: bool = False,
) -> None:
try:
self._get_instance(instance_id).patch_model(
device_to, lowvram_model_memory, load_weights, force_patch_weights
)
except AttributeError as e:
logger.error(
f"Isolation Error: Failed to patch model attribute: {e}. Skipping."
)
return
def _invoke() -> None:
try:
self._get_instance(instance_id).patch_model(
device_to, lowvram_model_memory, load_weights, force_patch_weights
)
except AttributeError as e:
logger.error(
f"Isolation Error: Failed to patch model attribute: {e}. Skipping."
)
return
self._run_operation_with_lease(instance_id, "patch_model", _invoke)
async def unpatch_model(
self, instance_id: str, device_to: Any = None, unpatch_weights: bool = True
) -> None:
self._get_instance(instance_id).unpatch_model(device_to, unpatch_weights)
self._run_operation_with_lease(
instance_id,
"unpatch_model",
lambda: self._get_instance(instance_id).unpatch_model(
device_to, unpatch_weights
),
)
async def detach(self, instance_id: str, unpatch_all: bool = True) -> None:
self._get_instance(instance_id).detach(unpatch_all)
@ -262,26 +484,29 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
self._get_instance(instance_id).pre_run()
async def cleanup(self, instance_id: str) -> None:
try:
instance = self._get_instance(instance_id)
except Exception:
logger.debug(
"ModelPatcher cleanup requested for missing instance %s",
instance_id,
exc_info=True,
)
return
def _invoke() -> None:
try:
instance = self._get_instance(instance_id)
except Exception:
logger.debug(
"ModelPatcher cleanup requested for missing instance %s",
instance_id,
exc_info=True,
)
return
try:
instance.cleanup()
finally:
with self._lock:
self._pending_cleanup_ids.add(instance_id)
gc.collect()
try:
instance.cleanup()
finally:
with self._lock:
self._pending_cleanup_ids.add(instance_id)
gc.collect()
self._run_operation_with_lease(instance_id, "cleanup", _invoke)
def sweep_pending_cleanup(self) -> int:
removed = 0
with self._lock:
with self._operation_state_cv:
pending_ids = list(self._pending_cleanup_ids)
self._pending_cleanup_ids.clear()
for instance_id in pending_ids:
@ -289,17 +514,21 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
if instance is None:
continue
self._id_map.pop(id(instance), None)
self._operation_states.pop(instance_id, None)
removed += 1
self._operation_state_cv.notify_all()
gc.collect()
return removed
def purge_all(self) -> int:
with self._lock:
with self._operation_state_cv:
removed = len(self._registry)
self._registry.clear()
self._id_map.clear()
self._pending_cleanup_ids.clear()
self._operation_states.clear()
self._operation_state_cv.notify_all()
gc.collect()
return removed
@ -743,17 +972,52 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
async def inner_model_memory_required(
self, instance_id: str, args: tuple, kwargs: dict
) -> Any:
return self._get_instance(instance_id).model.memory_required(*args, **kwargs)
return self._run_operation_with_lease(
instance_id,
"inner_model_memory_required",
lambda: self._get_instance(instance_id).model.memory_required(
*args, **kwargs
),
)
async def inner_model_extra_conds_shapes(
self, instance_id: str, args: tuple, kwargs: dict
) -> Any:
return self._get_instance(instance_id).model.extra_conds_shapes(*args, **kwargs)
return self._run_operation_with_lease(
instance_id,
"inner_model_extra_conds_shapes",
lambda: self._get_instance(instance_id).model.extra_conds_shapes(
*args, **kwargs
),
)
async def inner_model_extra_conds(
self, instance_id: str, args: tuple, kwargs: dict
) -> Any:
return self._get_instance(instance_id).model.extra_conds(*args, **kwargs)
def _invoke() -> Any:
result = self._get_instance(instance_id).model.extra_conds(*args, **kwargs)
try:
import torch
import comfy.conds
except Exception:
return result
def _to_cpu(obj: Any) -> Any:
if torch.is_tensor(obj):
return obj.detach().cpu() if obj.device.type != "cpu" else obj
if isinstance(obj, dict):
return {k: _to_cpu(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_to_cpu(v) for v in obj]
if isinstance(obj, tuple):
return tuple(_to_cpu(v) for v in obj)
if isinstance(obj, comfy.conds.CONDRegular):
return type(obj)(_to_cpu(obj.cond))
return obj
return _to_cpu(result)
return self._run_operation_with_lease(instance_id, "inner_model_extra_conds", _invoke)
async def inner_model_state_dict(
self, instance_id: str, args: tuple, kwargs: dict
@ -767,82 +1031,160 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
async def inner_model_apply_model(
self, instance_id: str, args: tuple, kwargs: dict
) -> Any:
instance = self._get_instance(instance_id)
target = getattr(instance, "load_device", None)
if target is None and args and hasattr(args[0], "device"):
target = args[0].device
elif target is None:
for v in kwargs.values():
if hasattr(v, "device"):
target = v.device
break
def _invoke() -> Any:
import torch
def _move(obj):
if target is None:
instance = self._get_instance(instance_id)
target = getattr(instance, "load_device", None)
if target is None and args and hasattr(args[0], "device"):
target = args[0].device
elif target is None:
for v in kwargs.values():
if hasattr(v, "device"):
target = v.device
break
def _move(obj):
if target is None:
return obj
if isinstance(obj, (tuple, list)):
return type(obj)(_move(o) for o in obj)
if hasattr(obj, "to"):
return obj.to(target)
return obj
if isinstance(obj, (tuple, list)):
return type(obj)(_move(o) for o in obj)
if hasattr(obj, "to"):
return obj.to(target)
return obj
moved_args = tuple(_move(a) for a in args)
moved_kwargs = {k: _move(v) for k, v in kwargs.items()}
result = instance.model.apply_model(*moved_args, **moved_kwargs)
return detach_if_grad(_move(result))
moved_args = tuple(_move(a) for a in args)
moved_kwargs = {k: _move(v) for k, v in kwargs.items()}
result = instance.model.apply_model(*moved_args, **moved_kwargs)
moved_result = detach_if_grad(_move(result))
# DynamicVRAM + isolation: returning CUDA tensors across RPC can stall
# at the transport boundary. Marshal dynamic-path results as CPU and let
# the proxy restore device placement in the child process.
is_dynamic_fn = getattr(instance, "is_dynamic", None)
if callable(is_dynamic_fn) and is_dynamic_fn():
def _to_cpu(obj: Any) -> Any:
if torch.is_tensor(obj):
return obj.detach().cpu() if obj.device.type != "cpu" else obj
if isinstance(obj, dict):
return {k: _to_cpu(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_to_cpu(v) for v in obj]
if isinstance(obj, tuple):
return tuple(_to_cpu(v) for v in obj)
return obj
return _to_cpu(moved_result)
return moved_result
return self._run_operation_with_lease(instance_id, "inner_model_apply_model", _invoke)
async def process_latent_in(
self, instance_id: str, args: tuple, kwargs: dict
) -> Any:
return detach_if_grad(
self._get_instance(instance_id).model.process_latent_in(*args, **kwargs)
return self._run_operation_with_lease(
instance_id,
"process_latent_in",
lambda: detach_if_grad(
self._get_instance(instance_id).model.process_latent_in(
*args, **kwargs
)
),
)
async def process_latent_out(
self, instance_id: str, args: tuple, kwargs: dict
) -> Any:
instance = self._get_instance(instance_id)
result = instance.model.process_latent_out(*args, **kwargs)
try:
target = None
if args and hasattr(args[0], "device"):
target = args[0].device
elif kwargs:
for v in kwargs.values():
if hasattr(v, "device"):
target = v.device
break
if target is not None and hasattr(result, "to"):
return detach_if_grad(result.to(target))
except Exception:
logger.debug(
"process_latent_out: failed to move result to target device",
exc_info=True,
)
return detach_if_grad(result)
import torch
def _invoke() -> Any:
instance = self._get_instance(instance_id)
result = instance.model.process_latent_out(*args, **kwargs)
moved_result = None
try:
target = None
if args and hasattr(args[0], "device"):
target = args[0].device
elif kwargs:
for v in kwargs.values():
if hasattr(v, "device"):
target = v.device
break
if target is not None and hasattr(result, "to"):
moved_result = detach_if_grad(result.to(target))
except Exception:
logger.debug(
"process_latent_out: failed to move result to target device",
exc_info=True,
)
if moved_result is None:
moved_result = detach_if_grad(result)
is_dynamic_fn = getattr(instance, "is_dynamic", None)
if callable(is_dynamic_fn) and is_dynamic_fn():
def _to_cpu(obj: Any) -> Any:
if torch.is_tensor(obj):
return obj.detach().cpu() if obj.device.type != "cpu" else obj
if isinstance(obj, dict):
return {k: _to_cpu(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_to_cpu(v) for v in obj]
if isinstance(obj, tuple):
return tuple(_to_cpu(v) for v in obj)
return obj
return _to_cpu(moved_result)
return moved_result
return self._run_operation_with_lease(instance_id, "process_latent_out", _invoke)
async def scale_latent_inpaint(
self, instance_id: str, args: tuple, kwargs: dict
) -> Any:
instance = self._get_instance(instance_id)
result = instance.model.scale_latent_inpaint(*args, **kwargs)
try:
target = None
if args and hasattr(args[0], "device"):
target = args[0].device
elif kwargs:
for v in kwargs.values():
if hasattr(v, "device"):
target = v.device
break
if target is not None and hasattr(result, "to"):
return detach_if_grad(result.to(target))
except Exception:
logger.debug(
"scale_latent_inpaint: failed to move result to target device",
exc_info=True,
)
return detach_if_grad(result)
import torch
def _invoke() -> Any:
instance = self._get_instance(instance_id)
result = instance.model.scale_latent_inpaint(*args, **kwargs)
moved_result = None
try:
target = None
if args and hasattr(args[0], "device"):
target = args[0].device
elif kwargs:
for v in kwargs.values():
if hasattr(v, "device"):
target = v.device
break
if target is not None and hasattr(result, "to"):
moved_result = detach_if_grad(result.to(target))
except Exception:
logger.debug(
"scale_latent_inpaint: failed to move result to target device",
exc_info=True,
)
if moved_result is None:
moved_result = detach_if_grad(result)
is_dynamic_fn = getattr(instance, "is_dynamic", None)
if callable(is_dynamic_fn) and is_dynamic_fn():
def _to_cpu(obj: Any) -> Any:
if torch.is_tensor(obj):
return obj.detach().cpu() if obj.device.type != "cpu" else obj
if isinstance(obj, dict):
return {k: _to_cpu(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_to_cpu(v) for v in obj]
if isinstance(obj, tuple):
return tuple(_to_cpu(v) for v in obj)
return obj
return _to_cpu(moved_result)
return moved_result
return self._run_operation_with_lease(
instance_id, "scale_latent_inpaint", _invoke
)
async def load_lora(
self,

View File

@ -3,6 +3,9 @@ from __future__ import annotations
import asyncio
import logging
import os
import threading
import time
from typing import Any
from comfy.isolation.proxies.base import (
@ -16,6 +19,22 @@ from comfy.isolation.proxies.base import (
logger = logging.getLogger(__name__)
def _describe_value(obj: Any) -> str:
try:
import torch
except Exception:
torch = None
try:
if torch is not None and isinstance(obj, torch.Tensor):
return (
"Tensor(shape=%s,dtype=%s,device=%s,id=%s)"
% (tuple(obj.shape), obj.dtype, obj.device, id(obj))
)
except Exception:
pass
return "%s(id=%s)" % (type(obj).__name__, id(obj))
def _prefer_device(*tensors: Any) -> Any:
try:
import torch
@ -49,6 +68,24 @@ def _to_device(obj: Any, device: Any) -> Any:
return obj
def _to_cpu_for_rpc(obj: Any) -> Any:
try:
import torch
except Exception:
return obj
if isinstance(obj, torch.Tensor):
t = obj.detach() if obj.requires_grad else obj
if t.is_cuda:
return t.to("cpu")
return t
if isinstance(obj, (list, tuple)):
converted = [_to_cpu_for_rpc(x) for x in obj]
return type(obj)(converted) if isinstance(obj, tuple) else converted
if isinstance(obj, dict):
return {k: _to_cpu_for_rpc(v) for k, v in obj.items()}
return obj
class ModelSamplingRegistry(BaseRegistry[Any]):
    # Prefix used when minting instance ids for registered model-sampling
    # objects (e.g. "modelsampling_<n>").
    _type_prefix = "modelsampling"
@ -196,17 +233,93 @@ class ModelSamplingProxy(BaseProxy[ModelSamplingRegistry]):
return self._rpc_caller
def _call(self, method_name: str, *args: Any) -> Any:
print(
"ISO:modelsampling_call_enter method=%s instance_id=%s pid=%s"
% (method_name, self._instance_id, os.getpid()),
flush=True,
)
rpc = self._get_rpc()
method = getattr(rpc, method_name)
print(
"ISO:modelsampling_call_before_dispatch method=%s instance_id=%s pid=%s"
% (method_name, self._instance_id, os.getpid()),
flush=True,
)
result = method(self._instance_id, *args)
timeout_ms = self._rpc_timeout_ms()
start_epoch = time.time()
start_perf = time.perf_counter()
thread_id = threading.get_ident()
call_id = "%s:%s:%s:%.6f" % (
self._instance_id,
method_name,
thread_id,
start_perf,
)
logger.debug(
"ISO:modelsampling_rpc_start method=%s instance_id=%s call_id=%s start_ts=%.6f thread=%s timeout_ms=%s",
method_name,
self._instance_id,
call_id,
start_epoch,
thread_id,
timeout_ms,
)
if asyncio.iscoroutine(result):
result = asyncio.wait_for(result, timeout=timeout_ms / 1000.0)
try:
asyncio.get_running_loop()
return run_coro_in_new_loop(result)
out = run_coro_in_new_loop(result)
except RuntimeError:
loop = get_thread_loop()
return loop.run_until_complete(result)
return result
out = loop.run_until_complete(result)
else:
out = result
print(
"ISO:modelsampling_call_after_dispatch method=%s instance_id=%s pid=%s"
% (method_name, self._instance_id, os.getpid()),
flush=True,
)
logger.debug(
"ISO:modelsampling_rpc_after_await method=%s instance_id=%s call_id=%s out=%s",
method_name,
self._instance_id,
call_id,
_describe_value(out),
)
elapsed_ms = (time.perf_counter() - start_perf) * 1000.0
logger.debug(
"ISO:modelsampling_rpc_end method=%s instance_id=%s call_id=%s elapsed_ms=%.3f thread=%s",
method_name,
self._instance_id,
call_id,
elapsed_ms,
thread_id,
)
logger.debug(
"ISO:modelsampling_rpc_return method=%s instance_id=%s call_id=%s",
method_name,
self._instance_id,
call_id,
)
print(
"ISO:modelsampling_call_return method=%s instance_id=%s pid=%s"
% (method_name, self._instance_id, os.getpid()),
flush=True,
)
return out
@staticmethod
def _rpc_timeout_ms() -> int:
raw = os.environ.get(
"COMFY_ISOLATION_MODEL_SAMPLING_RPC_TIMEOUT_MS",
os.environ.get("COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS", "30000"),
)
try:
timeout_ms = int(raw)
except ValueError:
timeout_ms = 30000
return max(1, timeout_ms)
@property
def sigma_min(self) -> Any:
@ -235,10 +348,24 @@ class ModelSamplingProxy(BaseProxy[ModelSamplingRegistry]):
def noise_scaling(
self, sigma: Any, noise: Any, latent_image: Any, max_denoise: bool = False
) -> Any:
return self._call("noise_scaling", sigma, noise, latent_image, max_denoise)
preferred_device = _prefer_device(noise, latent_image)
out = self._call(
"noise_scaling",
_to_cpu_for_rpc(sigma),
_to_cpu_for_rpc(noise),
_to_cpu_for_rpc(latent_image),
max_denoise,
)
return _to_device(out, preferred_device)
def inverse_noise_scaling(self, sigma: Any, latent: Any) -> Any:
return self._call("inverse_noise_scaling", sigma, latent)
preferred_device = _prefer_device(latent)
out = self._call(
"inverse_noise_scaling",
_to_cpu_for_rpc(sigma),
_to_cpu_for_rpc(latent),
)
return _to_device(out, preferred_device)
def timestep(self, sigma: Any) -> Any:
    # Thin RPC passthrough: maps sigma -> timestep on the process that owns
    # the registered model-sampling instance.
    return self._call("timestep", sigma)

View File

@ -2,9 +2,11 @@
from __future__ import annotations
import asyncio
import concurrent.futures
import logging
import os
import threading
import time
import weakref
from typing import Any, Callable, Dict, Generic, Optional, TypeVar
@ -119,6 +121,24 @@ def set_global_loop(loop: asyncio.AbstractEventLoop) -> None:
class BaseProxy(Generic[T]):
    # Registry type that backs this proxy; concrete proxies override with
    # their own BaseRegistry specialization.
    _registry_class: type = BaseRegistry  # type: ignore[type-arg]
    # Pinned so pickled proxies resolve to this module on the peer process.
    __module__: str = "comfy.isolation.proxies.base"
    # RPC method names that run under an explicit deadline so a hung
    # isolated worker surfaces as a TimeoutError rather than stalling the
    # caller indefinitely.
    _TIMEOUT_RPC_METHODS = frozenset(
        {
            "partially_load",
            "partially_unload",
            "load",
            "patch_model",
            "unpatch_model",
            "inner_model_apply_model",
            "memory_required",
            "model_dtype",
            "inner_model_memory_required",
            "inner_model_extra_conds_shapes",
            "inner_model_extra_conds",
            "process_latent_in",
            "process_latent_out",
            "scale_latent_inpaint",
        }
    )
def __init__(
self,
@ -148,39 +168,89 @@ class BaseProxy(Generic[T]):
)
return self._rpc_caller
def _rpc_timeout_ms_for_method(self, method_name: str) -> Optional[int]:
if method_name not in self._TIMEOUT_RPC_METHODS:
return None
try:
timeout_ms = int(
os.environ.get("COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS", "120000")
)
except ValueError:
timeout_ms = 120000
return max(1, timeout_ms)
def _call_rpc(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
rpc = self._get_rpc()
method = getattr(rpc, method_name)
timeout_ms = self._rpc_timeout_ms_for_method(method_name)
coro = method(self._instance_id, *args, **kwargs)
if timeout_ms is not None:
coro = asyncio.wait_for(coro, timeout=timeout_ms / 1000.0)
# If we have a global loop (Main Thread Loop), use it for dispatch from worker threads
if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running():
try:
# If we are already in the global loop, we can't block on it?
# Actually, this method is synchronous (__getattr__ -> lambda).
# If called from async context in main loop, we need to handle that.
curr_loop = asyncio.get_running_loop()
if curr_loop is _GLOBAL_LOOP:
# We are in the main loop. We cannot await/block here if we are just a sync function.
# But proxies are often called from sync code.
# If called from sync code in main loop, creating a new loop is bad.
# But we can't await `coro`.
# This implies proxies MUST be awaited if called from async context?
# Existing code used `run_coro_in_new_loop` which is weird.
# Let's trust that if we are in a thread (RuntimeError on get_running_loop),
# we use run_coroutine_threadsafe.
pass
except RuntimeError:
# No running loop - we are in a worker thread.
future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP)
return future.result()
start_epoch = time.time()
start_perf = time.perf_counter()
thread_id = threading.get_ident()
try:
running_loop = asyncio.get_running_loop()
loop_id: Optional[int] = id(running_loop)
except RuntimeError:
loop_id = None
logger.debug(
"ISO:rpc_start proxy=%s method=%s instance_id=%s start_ts=%.6f "
"thread=%s loop=%s timeout_ms=%s",
self.__class__.__name__,
method_name,
self._instance_id,
start_epoch,
thread_id,
loop_id,
timeout_ms,
)
try:
asyncio.get_running_loop()
return run_coro_in_new_loop(coro)
except RuntimeError:
loop = get_thread_loop()
return loop.run_until_complete(coro)
# If we have a global loop (Main Thread Loop), use it for dispatch from worker threads
if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running():
try:
curr_loop = asyncio.get_running_loop()
if curr_loop is _GLOBAL_LOOP:
pass
except RuntimeError:
# No running loop - we are in a worker thread.
future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP)
return future.result(
timeout=(timeout_ms / 1000.0) if timeout_ms is not None else None
)
try:
asyncio.get_running_loop()
return run_coro_in_new_loop(coro)
except RuntimeError:
loop = get_thread_loop()
return loop.run_until_complete(coro)
except asyncio.TimeoutError as exc:
raise TimeoutError(
f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} "
f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})"
) from exc
except concurrent.futures.TimeoutError as exc:
raise TimeoutError(
f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} "
f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})"
) from exc
finally:
end_epoch = time.time()
elapsed_ms = (time.perf_counter() - start_perf) * 1000.0
logger.debug(
"ISO:rpc_end proxy=%s method=%s instance_id=%s end_ts=%.6f "
"elapsed_ms=%.3f thread=%s loop=%s",
self.__class__.__name__,
method_name,
self._instance_id,
end_epoch,
elapsed_ms,
thread_id,
loop_id,
)
def __getstate__(self) -> Dict[str, Any]:
return {"_instance_id": self._instance_id}

View File

@ -208,13 +208,18 @@ def calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tenso
return handler.execute(_calc_cond_batch_outer, model, conds, x_in, timestep, model_options)
def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
print("CC_DEBUG2: enter _calc_cond_batch_outer", flush=True)
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
_calc_cond_batch,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
)
return executor.execute(model, conds, x_in, timestep, model_options)
print("CC_DEBUG2: before _calc_cond_batch executor.execute", flush=True)
result = executor.execute(model, conds, x_in, timestep, model_options)
print("CC_DEBUG2: after _calc_cond_batch executor.execute", flush=True)
return result
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
print("CC_DEBUG2: enter _calc_cond_batch", flush=True)
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
out_conds = []
out_counts = []
@ -247,7 +252,9 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
if has_default_conds:
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
print("CC_DEBUG: before prepare_state", flush=True)
model.current_patcher.prepare_state(timestep)
print("CC_DEBUG: after prepare_state", flush=True)
# run every hooked_to_run separately
for hooks, to_run in hooked_to_run.items():
@ -262,7 +269,9 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
to_batch_temp.reverse()
to_batch = to_batch_temp[:1]
print("CC_DEBUG: before get_free_memory", flush=True)
free_memory = model.current_patcher.get_free_memory(x_in.device)
print("CC_DEBUG: after get_free_memory", flush=True)
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
@ -272,7 +281,10 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
for k, v in to_run[tt][0].conditioning.items():
cond_shapes[k].append(v.size())
if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory:
print("CC_DEBUG: before memory_required", flush=True)
memory_required = model.memory_required(input_shape, cond_shapes=cond_shapes)
print("CC_DEBUG: after memory_required", flush=True)
if memory_required * 1.5 < free_memory:
to_batch = batch_amount
break
@ -309,7 +321,9 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
else:
timestep_ = torch.cat([timestep] * batch_chunks)
print("CC_DEBUG: before apply_hooks", flush=True)
transformer_options = model.current_patcher.apply_hooks(hooks=hooks)
print("CC_DEBUG: after apply_hooks", flush=True)
if 'transformer_options' in model_options:
transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
model_options['transformer_options'],
@ -331,9 +345,13 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
if 'model_function_wrapper' in model_options:
print("CC_DEBUG: before apply_model", flush=True)
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
print("CC_DEBUG: after apply_model", flush=True)
else:
print("CC_DEBUG: before apply_model", flush=True)
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
print("CC_DEBUG: after apply_model", flush=True)
for o in range(batch_chunks):
cond_index = cond_or_uncond[o]
@ -768,6 +786,7 @@ class KSAMPLER(Sampler):
self.inpaint_options = inpaint_options
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
print("CC_DEBUG3: enter KSAMPLER.sample", flush=True)
extra_args["denoise_mask"] = denoise_mask
model_k = KSamplerX0Inpaint(model_wrap, sigmas)
model_k.latent_image = latent_image
@ -777,15 +796,46 @@ class KSAMPLER(Sampler):
else:
model_k.noise = noise
noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas[0], noise, latent_image, self.max_denoise(model_wrap, sigmas))
print("CC_DEBUG3: before max_denoise", flush=True)
max_denoise = self.max_denoise(model_wrap, sigmas)
print("CC_DEBUG3: after max_denoise", flush=True)
print("CC_DEBUG3: before model_sampling_attr", flush=True)
model_sampling = model_wrap.inner_model.model_sampling
print(
"CC_DEBUG3: after model_sampling_attr type=%s id=%s instance_id=%s"
% (
type(model_sampling).__name__,
id(model_sampling),
getattr(model_sampling, "_instance_id", "n/a"),
),
flush=True,
)
print("CC_DEBUG3: before noise_scaling_call", flush=True)
try:
noise_scaled = model_sampling.noise_scaling(
sigmas[0], noise, latent_image, max_denoise
)
print("CC_DEBUG3: after noise_scaling_call", flush=True)
except Exception as e:
print(
"CC_DEBUG3: noise_scaling_exception type=%s msg=%s"
% (type(e).__name__, str(e)),
flush=True,
)
raise
noise = noise_scaled
print("CC_DEBUG3: after noise_assignment", flush=True)
k_callback = None
total_steps = len(sigmas) - 1
if callback is not None:
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
print("CC_DEBUG3: before sampler_function", flush=True)
samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
print("CC_DEBUG3: after sampler_function", flush=True)
samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples)
print("CC_DEBUG3: after inverse_noise_scaling", flush=True)
return samples
@ -825,10 +875,16 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
for k in conds:
calculate_start_end_timesteps(model, conds[k])
print('GP_DEBUG: before hasattr extra_conds', flush=True)
print('GP_DEBUG: before hasattr extra_conds', flush=True)
if hasattr(model, 'extra_conds'):
print('GP_DEBUG: has extra_conds!', flush=True)
print('GP_DEBUG: has extra_conds!', flush=True)
for k in conds:
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed, latent_shapes=latent_shapes)
print('GP_DEBUG: before area make sure loop', flush=True)
print('GP_DEBUG: before area make sure loop', flush=True)
#make sure each cond area has an opposite one with the same area
for k in conds:
for c in conds[k]:
@ -842,8 +898,11 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
for hook in c['hooks'].hooks:
hook.initialize_timesteps(model)
print('GP_DEBUG: before pre_run_control loop', flush=True)
for k in conds:
print('GP_DEBUG: calling pre_run_control for key:', k, flush=True)
pre_run_control(model, conds[k])
print('GP_DEBUG: after pre_run_control loop', flush=True)
if "positive" in conds:
positive = conds["positive"]
@ -1005,6 +1064,8 @@ class CFGGuider:
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed, latent_shapes=latent_shapes)
print("GP_DEBUG: process_conds finished", flush=True)
print("GP_DEBUG: process_conds finished", flush=True)
extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas
extra_args = {"model_options": extra_model_options, "seed": seed}
@ -1014,6 +1075,7 @@ class CFGGuider:
sampler,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, extra_args["model_options"], is_model_options=True)
)
print("GP_DEBUG: before executor.execute", flush=True)
samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return self.inner_model.process_latent_out(samples.to(torch.float32))

View File

@ -696,6 +696,24 @@ class PromptExecutor:
except Exception:
logging.debug("][ EX:flush_running_extensions_transport_state failed", exc_info=True)
async def _wait_model_patcher_quiescence_safe(
self,
*,
fail_loud: bool = False,
timeout_ms: int = 120000,
marker: str = "EX:wait_model_patcher_idle",
) -> None:
try:
from comfy.isolation import wait_for_model_patcher_quiescence
await wait_for_model_patcher_quiescence(
timeout_ms=timeout_ms, fail_loud=fail_loud, marker=marker
)
except Exception:
if fail_loud:
raise
logging.debug("][ EX:wait_model_patcher_quiescence failed", exc_info=True)
def add_message(self, event, data: dict, broadcast: bool):
data = {
**data,
@ -766,6 +784,11 @@ class PromptExecutor:
# Boundary cleanup runs at the start of the next workflow in
# isolation mode, matching non-isolated "next prompt" timing.
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
await self._wait_model_patcher_quiescence_safe(
fail_loud=False,
timeout_ms=120000,
marker="EX:boundary_cleanup_wait_idle",
)
await self._flush_running_extensions_transport_state_safe()
comfy.model_management.unload_all_models()
comfy.model_management.cleanup_models_gc()
@ -806,6 +829,11 @@ class PromptExecutor:
for node_id in execution_list.pendingNodes.keys():
class_type = dynamic_prompt.get_node(node_id)["class_type"]
pending_class_types.add(class_type)
await self._wait_model_patcher_quiescence_safe(
fail_loud=True,
timeout_ms=120000,
marker="EX:notify_graph_wait_idle",
)
await self._notify_execution_graph_safe(pending_class_types, fail_loud=True)
while not execution_list.is_empty():