From 0c7bc74e82c787ac229582d68200037527e16f6d Mon Sep 17 00:00:00 2001 From: John Pollock Date: Tue, 7 Apr 2026 05:58:31 -0500 Subject: [PATCH] feat(isolation): execution engine integration for isolated workers Wires isolation into ComfyUI's execution pipeline: child process startup in main.py, isolated node dispatch in execution.py with boundary cleanup, graph notification, quiescence waits, and RPC event loop coordination. Integrates with master's try/finally and RAM pressure structures. --- execution.py | 144 ++++++++++++++++++++++++++++++++++++++++++++++----- main.py | 109 ++++++++++++++++++++++++++++---------- 2 files changed, 212 insertions(+), 41 deletions(-) diff --git a/execution.py b/execution.py index 5e02dffb2..fc54c3e66 100644 --- a/execution.py +++ b/execution.py @@ -1,7 +1,9 @@ import copy +import gc import heapq import inspect import logging +import os import sys import threading import time @@ -42,6 +44,8 @@ from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_re from comfy_api.latest import io, _io from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger +_AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED = False + class ExecutionResult(Enum): SUCCESS = 0 @@ -262,20 +266,31 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f pre_execute_cb(index) # V3 if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)): - # if is just a class, then assign no state, just create clone - if is_class(obj): - type_obj = obj - obj.VALIDATE_CLASS() - class_clone = obj.PREPARE_CLASS_CLONE(v3_data) - # otherwise, use class instance to populate/reuse some fields + # Check for isolated node - skip validation and class cloning + if hasattr(obj, "_pyisolate_extension"): + # Isolated Node: The stub is just a proxy; real validation happens in child process + if v3_data is not None: + inputs = _io.build_nested_inputs(inputs, v3_data) + # Inject hidden inputs so they're available in the isolated child process + inputs.update(v3_data.get("hidden_inputs", {})) + f = getattr(obj, func) + # Standard V3 Node (Existing Logic) + else: - type_obj = type(obj) - type_obj.VALIDATE_CLASS() - class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data) - f = make_locked_method_func(type_obj, func, class_clone) - # in case of dynamic inputs, restructure inputs to expected nested dict - if v3_data is not None: - inputs = _io.build_nested_inputs(inputs, v3_data) + # if is just a class, then assign no resources or state, just create clone + if is_class(obj): + type_obj = obj + obj.VALIDATE_CLASS() + class_clone = obj.PREPARE_CLASS_CLONE(v3_data) + # otherwise, use class instance to populate/reuse some fields + else: + type_obj = type(obj) + type_obj.VALIDATE_CLASS() + class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data) + f = make_locked_method_func(type_obj, func, class_clone) + # in case of dynamic inputs, restructure inputs to expected nested dict + if v3_data is not None: + inputs = _io.build_nested_inputs(inputs, v3_data) # V1 else: f = getattr(obj, func) @@ -537,7 +552,17 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, if args.verbose == "DEBUG": comfy_aimdo.control.analyze() comfy.model_management.reset_cast_buffers() - comfy_aimdo.model_vbar.vbars_reset_watermark_limits() + vbar_lib = getattr(comfy_aimdo.model_vbar, "lib", None) + if vbar_lib is not None: + comfy_aimdo.model_vbar.vbars_reset_watermark_limits() + else: + global _AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED + if not _AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED: + logging.warning( + "DynamicVRAM backend unavailable for watermark reset; " + "skipping vbar reset for this process." + ) + _AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED = True if has_pending_tasks: pending_async_nodes[unique_id] = output_data @@ -546,6 +571,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, tasks = [x for x in output_data if isinstance(x, asyncio.Task)] await asyncio.gather(*tasks, return_exceptions=True) unblock() + + # Keep isolation node execution deterministic by default, but allow + # opt-out for diagnostics. + isolation_sequential = os.environ.get("COMFY_ISOLATE_SEQUENTIAL", "1").lower() in ("1", "true", "yes") + if args.use_process_isolation and isolation_sequential: + await await_completion() + return await execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs) + asyncio.create_task(await_completion()) return (ExecutionResult.PENDING, None, None) if len(output_ui) > 0: @@ -657,6 +690,46 @@ class PromptExecutor: self.status_messages = [] self.success = True + async def _notify_execution_graph_safe(self, class_types: set[str], *, fail_loud: bool = False) -> None: + if not args.use_process_isolation: + return + try: + from comfy.isolation import notify_execution_graph + await notify_execution_graph(class_types, caches=self.caches.all) + except Exception: + if fail_loud: + raise + logging.debug("][ EX:notify_execution_graph failed", exc_info=True) + + async def _flush_running_extensions_transport_state_safe(self) -> None: + if not args.use_process_isolation: + return + try: + from comfy.isolation import flush_running_extensions_transport_state + await flush_running_extensions_transport_state() + except Exception: + logging.debug("][ EX:flush_running_extensions_transport_state failed", exc_info=True) + + async def _wait_model_patcher_quiescence_safe( + self, + *, + fail_loud: bool = False, + timeout_ms: int = 120000, + marker: str = "EX:wait_model_patcher_idle", + ) -> None: + if not args.use_process_isolation: + return + try: + from comfy.isolation import wait_for_model_patcher_quiescence + + await wait_for_model_patcher_quiescence( + timeout_ms=timeout_ms, fail_loud=fail_loud, marker=marker + ) + except Exception: + if fail_loud: + raise + logging.debug("][ EX:wait_model_patcher_quiescence failed", exc_info=True) + def add_message(self, event, data: dict, broadcast: bool): data = { **data, @@ -711,6 +784,18 @@ class PromptExecutor: asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs)) async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): + if args.use_process_isolation: + # Update RPC event loops for all isolated extensions. + # This is critical for serial workflow execution - each asyncio.run() creates + # a new event loop, and RPC instances must be updated to use it. + try: + from comfy.isolation import update_rpc_event_loops + update_rpc_event_loops() + except ImportError: + pass # Isolation not available + except Exception as e: + logging.getLogger(__name__).warning(f"Failed to update RPC event loops: {e}") + set_preview_method(extra_data.get("preview_method")) nodes.interrupt_processing(False) @@ -723,6 +808,25 @@ class PromptExecutor: self.status_messages = [] self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) + if args.use_process_isolation: + try: + # Boundary cleanup runs at the start of the next workflow in + # isolation mode, matching non-isolated "next prompt" timing. + self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args) + await self._wait_model_patcher_quiescence_safe( + fail_loud=False, + timeout_ms=120000, + marker="EX:boundary_cleanup_wait_idle", + ) + await self._flush_running_extensions_transport_state_safe() + comfy.model_management.unload_all_models() + comfy.model_management.cleanup_models_gc() + comfy.model_management.cleanup_models() + gc.collect() + comfy.model_management.soft_empty_cache() + except Exception: + logging.debug("][ EX:isolation_boundary_cleanup_start failed", exc_info=True) + self._notify_prompt_lifecycle("start", prompt_id) ram_headroom = int(self.cache_args["ram"] * (1024 ** 3)) ram_release_callback = self.caches.outputs.ram_release if self.cache_type == CacheType.RAM_PRESSURE else None @@ -760,6 +864,18 @@ class PromptExecutor: for node_id in list(execute_outputs): execution_list.add_node(node_id) + if args.use_process_isolation: + pending_class_types = set() + for node_id in execution_list.pendingNodes.keys(): + class_type = dynamic_prompt.get_node(node_id)["class_type"] + pending_class_types.add(class_type) + await self._wait_model_patcher_quiescence_safe( + fail_loud=True, + timeout_ms=120000, + marker="EX:notify_graph_wait_idle", + ) + await self._notify_execution_graph_safe(pending_class_types, fail_loud=True) + while not execution_list.is_empty(): node_id, error, ex = await execution_list.stage_node_execution() if error is not None: diff --git a/main.py b/main.py index 12b04719d..2f97afb74 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,21 @@ +import os +import sys + +IS_PYISOLATE_CHILD = os.environ.get("PYISOLATE_CHILD") == "1" + +if __name__ == "__main__" and IS_PYISOLATE_CHILD: + del os.environ["PYISOLATE_CHILD"] + IS_PYISOLATE_CHILD = False + +CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) +if CURRENT_DIR not in sys.path: + sys.path.insert(0, CURRENT_DIR) + +IS_PRIMARY_PROCESS = (not IS_PYISOLATE_CHILD) and __name__ == "__main__" + import comfy.options comfy.options.enable_args_parsing() -import os import importlib.util import shutil import importlib.metadata @@ -12,7 +26,7 @@ from app.logger import setup_logger from app.assets.seeder import asset_seeder from app.assets.services import register_output_files import itertools -import utils.extra_config +import utils.extra_config # noqa: F401 from utils.mime_types import init_mime_types import faulthandler import logging @@ -22,12 +36,45 @@ from comfy_execution.utils import get_executing_context from comfy_api import feature_flags from app.database.db import init_db, dependencies_available -if __name__ == "__main__": - #NOTE: These do not do anything on core ComfyUI, they are for custom nodes. +import comfy_aimdo.control + +if enables_dynamic_vram(): + if not comfy_aimdo.control.init(): + logging.warning( + "DynamicVRAM requested, but comfy-aimdo failed to initialize early. " + "Will fall back to legacy model loading if device init fails." + ) + +if '--use-process-isolation' in sys.argv: + from comfy.isolation import initialize_proxies + initialize_proxies() + + # Explicitly register the ComfyUI adapter for pyisolate (v1.0 architecture) + try: + import pyisolate + from comfy.isolation.adapter import ComfyUIAdapter + pyisolate.register_adapter(ComfyUIAdapter()) + logging.info("PyIsolate adapter registered: comfyui") + except ImportError: + logging.warning("PyIsolate not installed or version too old for explicit registration") + except Exception as e: + logging.error(f"Failed to register PyIsolate adapter: {e}") + + if not IS_PYISOLATE_CHILD: + if 'PYTORCH_CUDA_ALLOC_CONF' not in os.environ: + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'backend:native' + +if not IS_PYISOLATE_CHILD: + from comfy_execution.progress import get_progress_state + from comfy_execution.utils import get_executing_context + from comfy_api import feature_flags + +if IS_PRIMARY_PROCESS: os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1' os.environ['DO_NOT_TRACK'] = '1' -setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) +if not IS_PYISOLATE_CHILD: + setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) faulthandler.enable(file=sys.stderr, all_threads=False) @@ -93,14 +140,15 @@ if args.enable_manager: def apply_custom_paths(): + from utils import extra_config # Deferred import - spawn re-runs main.py # extra model paths extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") if os.path.isfile(extra_model_paths_config_path): - utils.extra_config.load_extra_path_config(extra_model_paths_config_path) + extra_config.load_extra_path_config(extra_model_paths_config_path) if args.extra_model_paths_config: for config_path in itertools.chain(*args.extra_model_paths_config): - utils.extra_config.load_extra_path_config(config_path) + extra_config.load_extra_path_config(config_path) # --output-directory, --input-directory, --user-directory if args.output_directory: @@ -173,15 +221,17 @@ def execute_prestartup_script(): else: import_message = " (PRESTARTUP FAILED)" logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1])) - logging.info("") + logging.info("") -apply_custom_paths() -init_mime_types() +if not IS_PYISOLATE_CHILD: + apply_custom_paths() + init_mime_types() -if args.enable_manager: +if args.enable_manager and not IS_PYISOLATE_CHILD: comfyui_manager.prestartup() -execute_prestartup_script() +if not IS_PYISOLATE_CHILD: + execute_prestartup_script() # Main code @@ -192,17 +242,17 @@ import gc if 'torch' in sys.modules: logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") - import comfy.utils -import execution -import server -from protocol import BinaryEventTypes -import nodes -import comfy.model_management -import comfyui_version -import app.logger -import hook_breaker_ac10a0 +if not IS_PYISOLATE_CHILD: + import execution + import server + from protocol import BinaryEventTypes + import nodes + import comfy.model_management + import comfyui_version + import app.logger + import hook_breaker_ac10a0 import comfy.memory_management import comfy.model_patcher @@ -462,6 +512,10 @@ def start_comfyui(asyncio_loop=None): asyncio.set_event_loop(asyncio_loop) prompt_server = server.PromptServer(asyncio_loop) + if args.use_process_isolation: + from comfy.isolation import start_isolation_loading_early + start_isolation_loading_early(asyncio_loop) + if args.enable_manager and not args.disable_manager_ui: comfyui_manager.start() @@ -506,12 +560,13 @@ def start_comfyui(asyncio_loop=None): if __name__ == "__main__": # Running directly, just start ComfyUI. logging.info("Python version: {}".format(sys.version)) - logging.info("ComfyUI version: {}".format(comfyui_version.__version__)) - for package in ("comfy-aimdo", "comfy-kitchen"): - try: - logging.info("{} version: {}".format(package, importlib.metadata.version(package))) - except: - pass + if not IS_PYISOLATE_CHILD: + logging.info("ComfyUI version: {}".format(comfyui_version.__version__)) + for package in ("comfy-aimdo", "comfy-kitchen"): + try: + logging.info("{} version: {}".format(package, importlib.metadata.version(package))) + except: + pass if sys.version_info.major == 3 and sys.version_info.minor < 10: logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.")