diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py index 0f9700995..fb048c95d 100644 --- a/comfy/cmd/execution.py +++ b/comfy/cmd/execution.py @@ -46,7 +46,7 @@ from ..execution_context import context_execute_node, context_execute_prompt from ..execution_ext import should_panic_on_exception from ..node_requests_caching import use_requests_caching from ..nodes.package_typing import InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode -from ..nodes_context import get_nodes +from ..nodes_context import get_nodes, vanilla_node_execution_environment from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler, \ ProgressRegistry from comfy_execution.validation import validate_node_input @@ -456,7 +456,7 @@ def format_value(x) -> FormattedValue: async def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches, node_id: str, extra_data: dict, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple: """ - + Executes a prompt :param server: :param dynprompt: :param caches: @@ -468,8 +468,9 @@ async def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, ca :param pending_subgraph_results: :return: """ - with context_execute_node(node_id), \ - use_requests_caching(): + with (context_execute_node(node_id), + vanilla_node_execution_environment(), + use_requests_caching()): return await _execute(server, dynprompt, caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) diff --git a/comfy/nodes/package_typing.py b/comfy/nodes/package_typing.py index 7b51fdf8a..dfae1ce6f 100644 --- a/comfy/nodes/package_typing.py +++ b/comfy/nodes/package_typing.py @@ -61,7 +61,7 @@ BooleanSpec = Tuple[Literal["BOOLEAN"], BoolSpecOptions] ChoiceSpec = Tuple[Union[List[str], List[float], List[int], Tuple[str, ...], Tuple[float, ...], Tuple[int, ...]]] -NonPrimitiveTypeSpec = Tuple[CommonReturnTypes, Any] +NonPrimitiveTypeSpec = Tuple[CommonReturnTypes] | Tuple[CommonReturnTypes, dict] InputTypeSpec = Union[IntSpec, FloatSpec, StringSpec, BooleanSpec, ChoiceSpec, NonPrimitiveTypeSpec] diff --git a/comfy/nodes/vanilla_node_importing.py b/comfy/nodes/vanilla_node_importing.py index cadf437c2..3ccb5ad9a 100644 --- a/comfy/nodes/vanilla_node_importing.py +++ b/comfy/nodes/vanilla_node_importing.py @@ -11,6 +11,7 @@ import types from contextlib import contextmanager from os.path import join, basename, dirname, isdir, isfile, exists, abspath, split, splitext, realpath from typing import Iterable, Any, Generator +from unittest.mock import patch from comfy_compatibility.vanilla import prepare_vanilla_environment from . import base_nodes @@ -138,6 +139,7 @@ def _exec_mitigations(module: types.ModuleType, module_path: str) -> Generator[E if module.__name__.lower() in ( "comfyui-manager", "comfyui_ryanonyheinside", + "comfyui-easy-use", ): from ..cmd import folder_paths old_file = folder_paths.__file__ @@ -145,13 +147,11 @@ def _exec_mitigations(module: types.ModuleType, module_path: str) -> Generator[E try: # mitigate path new_path = join(abspath(join(dirname(old_file), "..", "..")), basename(old_file)) - folder_paths.__file__ = new_path - # mitigate JS copy - sys.modules['nodes'].EXTENSION_WEB_DIRS = {} - yield ExportedNodes() + with patch.object(folder_paths, "__file__", new_path), \ + patch.object(sys.modules['nodes'], "EXTENSION_WEB_DIRS", {}, create=True): # mitigate JS copy + yield ExportedNodes() finally: - folder_paths.__file__ = old_file # todo: mitigate "/manager/reboot" # todo: mitigate process_wrap # todo: unfortunately, we shouldn't restore the patches here, they will have to be applied forever. diff --git a/comfy/nodes_context.py b/comfy/nodes_context.py index 800f9fb2d..3f01a39ec 100644 --- a/comfy/nodes_context.py +++ b/comfy/nodes_context.py @@ -1,5 +1,9 @@ # todo: this should be defined in a common place, the fact that the nodes are imported by execution the way that they are is pretty radioactive +import collections.abc +import sys import threading +from contextlib import contextmanager +from unittest.mock import patch import lazy_object_proxy @@ -24,3 +28,29 @@ def get_nodes() -> ExportedNodes: if len(current_ctx.custom_nodes) == 0: return nodes return exported_nodes_view(nodes, current_ctx.custom_nodes) + + +class _NodeClassMappingsShim(collections.abc.Mapping): + def __iter__(self): + for key in get_nodes().NODE_CLASS_MAPPINGS: + yield key + + def __getitem__(self, item): + return get_nodes().NODE_CLASS_MAPPINGS[item] + + def __len__(self): + return len(get_nodes().NODE_CLASS_MAPPINGS) + + # todo: does this need to be mutable? + + +@contextmanager +def vanilla_node_execution_environment(): + # check if we're running with patched nodes + if 'nodes' in sys.modules: + # this ensures NODE_CLASS_MAPPINGS is loaded lazily and contains all the nodes loaded so far, not just the base nodes + # easy-use and other nodes expect NODE_CLASS_MAPPINGS to contain all the nodes in the environment + with patch('nodes.NODE_CLASS_MAPPINGS', _NodeClassMappingsShim()): + yield + else: + yield diff --git a/comfy_extras/nodes/nodes_open_api.py b/comfy_extras/nodes/nodes_open_api.py index fe32b3b80..bce7242de 100644 --- a/comfy_extras/nodes/nodes_open_api.py +++ b/comfy_extras/nodes/nodes_open_api.py @@ -684,7 +684,7 @@ class SaveImagesResponse(CustomNode): mut_srgb_to_linear(image_as_numpy_array[:, :, :3]) image_scaled = image_as_numpy_array.astype(np.float32) if bits == 16: - cv_save_options = [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF] + cv_save_options = [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF] else: image_scaled = np.clip(image_as_numpy_array * 65535, 0, 65535).astype(np.uint16) @@ -779,7 +779,7 @@ class SaveImagesResponse(CustomNode): if save_method == 'pil': image_as_pil.save(local_path, format=save_format, **additional_args) else: - cv2.imwrite(local_path, image_scaled) + cv2.imwrite(local_path, image_scaled) img_item: SaveNodeResultWithName = { "abs_path": str(abs_path), @@ -809,6 +809,7 @@ class ImageRequestParameter(CustomNode): }, "optional": { **_open_api_common_schema, + "default_if_empty": ("IMAGE",) } } @@ -816,7 +817,9 @@ class ImageRequestParameter(CustomNode): FUNCTION = "execute" CATEGORY = "api/openapi" - def execute(self, value: str = "", *args, **kwargs) -> ValidatedNodeResult: + def execute(self, value: str = "", default_if_empty=None, *args, **kwargs) -> ValidatedNodeResult: + if value.strip() == "": + return (default_if_empty,) output_images = [] f: OpenFile fsspec_kwargs = {} diff --git a/tests/unit/test_nodes_context.py b/tests/unit/test_nodes_context.py new file mode 100644 index 000000000..abfd7f78c --- /dev/null +++ b/tests/unit/test_nodes_context.py @@ -0,0 +1,128 @@ +from unittest.mock import patch + +import pytest +import sys + +from comfy.client.embedded_comfy_client import Comfy +from comfy.component_model.make_mutable import make_mutable +from comfy.distributed.process_pool_executor import ProcessPoolExecutor +from comfy.execution_context import context_add_custom_nodes +from comfy.nodes.package_typing import CustomNode, ExportedNodes +from tests.unit.test_panics import ThrowsExceptionNode + + +def disable_vanilla(*args): + patch_disable_vanilla = globals()['prepare_vanilla_environment'] = patch('comfy_compatibility.vanilla.prepare_vanilla_environment', lambda: "patched") + patch_disable_vanilla.start() + from comfy_compatibility.vanilla import prepare_vanilla_environment + assert prepare_vanilla_environment() == "patched" + + +def enable_vanilla(*args): + patch_disable_vanilla = globals()['prepare_vanilla_environment'] + patch_disable_vanilla.stop() + + +class AssertVanillaImportFails(CustomNode): + @classmethod + def INPUT_TYPES(s): + return {"required": {}} + + RETURN_TYPES = ("STRING",) + FUNCTION = "assert_import_fails" + CATEGORY = "Testing/Nodes" + + def assert_import_fails(self) -> tuple[str]: + try: + # sometimes, other code like fluxtapoz has a directly called nodes, and for + # development purposes, its source directory is added to path, and this + # can be imported, so import nodes cannot be used + if 'nodes' in sys.modules: + assert 'NODE_CLASS_MAPPINGS' not in sys.modules['nodes'] or 'SplitImageWithAlpha' not in sys.modules['nodes'].NODE_CLASS_MAPPINGS + del sys.modules['nodes'] + + except ModuleNotFoundError: + pass + return ("dummy",) + + +class PrepareVanillaEnvironment(CustomNode): + @classmethod + def INPUT_TYPES(s): + return {"required": { + "input": ("STRING", {}), + }} + + RETURN_TYPES = ("STRING",) + FUNCTION = "prepare" + CATEGORY = "Testing/Nodes" + + def prepare(self, input: str) -> tuple[str]: + enable_vanilla() + from comfy_compatibility import vanilla + vanilla.prepare_vanilla_environment() + assert "nodes" in sys.modules + return ("dummy",) + + +class AssertVanillaImportSucceeds(CustomNode): + @classmethod + def INPUT_TYPES(s): + return {"required": { + "input": ("STRING", {}), + }} + + RETURN_TYPES = ("STRING",) + FUNCTION = "assert_import_succeeds" + CATEGORY = "Testing/Nodes" + + def assert_import_succeeds(self, input: str) -> tuple[str]: + import nodes + assert "SplitImageWithAlpha" in nodes.NODE_CLASS_MAPPINGS + return ("",) + + +def create_nodes_context_workflow(): + """Create a workflow that uses our test node to raise an exception""" + return make_mutable({ + "1": {"class_type": "AssertVanillaImportFails", "inputs": {}}, + "2": {"class_type": "PrepareVanillaEnvironment", "inputs": {"input": ["1", 0]}}, + "3": {"class_type": "AssertVanillaImportSucceeds", "inputs": {"input": ["2", 0]}}, + "4": {"class_type": "PreviewString", "inputs": {"value": ["3", 0]}}, + }) + + +TEST_NODE_DISPLAY_NAME_MAPPINGS = { + "TestExceptionNode": "Test Exception Node", + "AssertVanillaImportFails": "Assert Vanilla Import Fails", + "PrepareVanillaEnvironment": "Prepare Vanilla Environment", + "AssertVanillaImportSucceeds": "Assert Vanilla Import Succeeds", +} + +EXECUTOR_FACTORIES = [ + (ProcessPoolExecutor, {"max_workers": 1}), +] + + +@pytest.mark.asyncio +async def test_nodes_context_shim(): + """Test panic behavior with different executor types""" + + # Initialize the specific executor + executor = ProcessPoolExecutor(max_workers=1, initializer=disable_vanilla) + + if 'nodes' in sys.modules: + # something else imported it + del sys.modules['nodes'] + assert 'nodes' not in sys.modules + with context_add_custom_nodes(ExportedNodes(NODE_CLASS_MAPPINGS={ + "TestExceptionNode": ThrowsExceptionNode, + "AssertVanillaImportFails": AssertVanillaImportFails, + "PrepareVanillaEnvironment": PrepareVanillaEnvironment, + "AssertVanillaImportSucceeds": AssertVanillaImportSucceeds, + }, NODE_DISPLAY_NAME_MAPPINGS=TEST_NODE_DISPLAY_NAME_MAPPINGS)): + async with Comfy(executor=executor) as client: + # Queue our failing workflow + workflow = create_nodes_context_workflow() + await client.queue_prompt(workflow) + assert 'nodes' not in sys.modules