mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 06:40:48 +08:00
Better integration with logic nodes from EasyUse
- ImageRequestParameter now returns None or a provided default when the value of its path / URL is empty, instead of erroring - Custom nodes which touch nodes.NODE_CLASS_MAPPINGS will once again see all the nodes available during execution, instead of only the base nodes
This commit is contained in:
parent
2d2d625ed0
commit
82bffb7855
@ -46,7 +46,7 @@ from ..execution_context import context_execute_node, context_execute_prompt
|
|||||||
from ..execution_ext import should_panic_on_exception
|
from ..execution_ext import should_panic_on_exception
|
||||||
from ..node_requests_caching import use_requests_caching
|
from ..node_requests_caching import use_requests_caching
|
||||||
from ..nodes.package_typing import InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
|
from ..nodes.package_typing import InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
|
||||||
from ..nodes_context import get_nodes
|
from ..nodes_context import get_nodes, vanilla_node_execution_environment
|
||||||
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler, \
|
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler, \
|
||||||
ProgressRegistry
|
ProgressRegistry
|
||||||
from comfy_execution.validation import validate_node_input
|
from comfy_execution.validation import validate_node_input
|
||||||
@ -456,7 +456,7 @@ def format_value(x) -> FormattedValue:
|
|||||||
|
|
||||||
async def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches, node_id: str, extra_data: dict, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple:
|
async def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches, node_id: str, extra_data: dict, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple:
|
||||||
"""
|
"""
|
||||||
|
Executes a prompt
|
||||||
:param server:
|
:param server:
|
||||||
:param dynprompt:
|
:param dynprompt:
|
||||||
:param caches:
|
:param caches:
|
||||||
@ -468,8 +468,9 @@ async def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, ca
|
|||||||
:param pending_subgraph_results:
|
:param pending_subgraph_results:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
with context_execute_node(node_id), \
|
with (context_execute_node(node_id),
|
||||||
use_requests_caching():
|
vanilla_node_execution_environment(),
|
||||||
|
use_requests_caching()):
|
||||||
return await _execute(server, dynprompt, caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
|
return await _execute(server, dynprompt, caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -61,7 +61,7 @@ BooleanSpec = Tuple[Literal["BOOLEAN"], BoolSpecOptions]
|
|||||||
|
|
||||||
ChoiceSpec = Tuple[Union[List[str], List[float], List[int], Tuple[str, ...], Tuple[float, ...], Tuple[int, ...]]]
|
ChoiceSpec = Tuple[Union[List[str], List[float], List[int], Tuple[str, ...], Tuple[float, ...], Tuple[int, ...]]]
|
||||||
|
|
||||||
NonPrimitiveTypeSpec = Tuple[CommonReturnTypes, Any]
|
NonPrimitiveTypeSpec = Tuple[CommonReturnTypes] | Tuple[CommonReturnTypes, dict]
|
||||||
|
|
||||||
InputTypeSpec = Union[IntSpec, FloatSpec, StringSpec, BooleanSpec, ChoiceSpec, NonPrimitiveTypeSpec]
|
InputTypeSpec = Union[IntSpec, FloatSpec, StringSpec, BooleanSpec, ChoiceSpec, NonPrimitiveTypeSpec]
|
||||||
|
|
||||||
|
|||||||
@ -11,6 +11,7 @@ import types
|
|||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from os.path import join, basename, dirname, isdir, isfile, exists, abspath, split, splitext, realpath
|
from os.path import join, basename, dirname, isdir, isfile, exists, abspath, split, splitext, realpath
|
||||||
from typing import Iterable, Any, Generator
|
from typing import Iterable, Any, Generator
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
from comfy_compatibility.vanilla import prepare_vanilla_environment
|
from comfy_compatibility.vanilla import prepare_vanilla_environment
|
||||||
from . import base_nodes
|
from . import base_nodes
|
||||||
@ -138,6 +139,7 @@ def _exec_mitigations(module: types.ModuleType, module_path: str) -> Generator[E
|
|||||||
if module.__name__.lower() in (
|
if module.__name__.lower() in (
|
||||||
"comfyui-manager",
|
"comfyui-manager",
|
||||||
"comfyui_ryanonyheinside",
|
"comfyui_ryanonyheinside",
|
||||||
|
"comfyui-easy-use",
|
||||||
):
|
):
|
||||||
from ..cmd import folder_paths
|
from ..cmd import folder_paths
|
||||||
old_file = folder_paths.__file__
|
old_file = folder_paths.__file__
|
||||||
@ -145,13 +147,11 @@ def _exec_mitigations(module: types.ModuleType, module_path: str) -> Generator[E
|
|||||||
try:
|
try:
|
||||||
# mitigate path
|
# mitigate path
|
||||||
new_path = join(abspath(join(dirname(old_file), "..", "..")), basename(old_file))
|
new_path = join(abspath(join(dirname(old_file), "..", "..")), basename(old_file))
|
||||||
folder_paths.__file__ = new_path
|
|
||||||
# mitigate JS copy
|
|
||||||
sys.modules['nodes'].EXTENSION_WEB_DIRS = {}
|
|
||||||
|
|
||||||
yield ExportedNodes()
|
with patch.object(folder_paths, "__file__", new_path), \
|
||||||
|
patch.object(sys.modules['nodes'], "EXTENSION_WEB_DIRS", {}, create=True): # mitigate JS copy
|
||||||
|
yield ExportedNodes()
|
||||||
finally:
|
finally:
|
||||||
folder_paths.__file__ = old_file
|
|
||||||
# todo: mitigate "/manager/reboot"
|
# todo: mitigate "/manager/reboot"
|
||||||
# todo: mitigate process_wrap
|
# todo: mitigate process_wrap
|
||||||
# todo: unfortunately, we shouldn't restore the patches here, they will have to be applied forever.
|
# todo: unfortunately, we shouldn't restore the patches here, they will have to be applied forever.
|
||||||
|
|||||||
@ -1,5 +1,9 @@
|
|||||||
# todo: this should be defined in a common place, the fact that the nodes are imported by execution the way that they are is pretty radioactive
|
# todo: this should be defined in a common place, the fact that the nodes are imported by execution the way that they are is pretty radioactive
|
||||||
|
import collections.abc
|
||||||
|
import sys
|
||||||
import threading
|
import threading
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import lazy_object_proxy
|
import lazy_object_proxy
|
||||||
|
|
||||||
@ -24,3 +28,29 @@ def get_nodes() -> ExportedNodes:
|
|||||||
if len(current_ctx.custom_nodes) == 0:
|
if len(current_ctx.custom_nodes) == 0:
|
||||||
return nodes
|
return nodes
|
||||||
return exported_nodes_view(nodes, current_ctx.custom_nodes)
|
return exported_nodes_view(nodes, current_ctx.custom_nodes)
|
||||||
|
|
||||||
|
|
||||||
|
class _NodeClassMappingsShim(collections.abc.Mapping):
|
||||||
|
def __iter__(self):
|
||||||
|
for key in get_nodes().NODE_CLASS_MAPPINGS:
|
||||||
|
yield key
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
return get_nodes().NODE_CLASS_MAPPINGS[item]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(get_nodes().NODE_CLASS_MAPPINGS)
|
||||||
|
|
||||||
|
# todo: does this need to be mutable?
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def vanilla_node_execution_environment():
|
||||||
|
# check if we're running with patched nodes
|
||||||
|
if 'nodes' in sys.modules:
|
||||||
|
# this ensures NODE_CLASS_MAPPINGS is loaded lazily and contains all the nodes loaded so far, not just the base nodes
|
||||||
|
# easy-use and other nodes expect NODE_CLASS_MAPPINGS to contain all the nodes in the environment
|
||||||
|
with patch('nodes.NODE_CLASS_MAPPINGS', _NodeClassMappingsShim()):
|
||||||
|
yield
|
||||||
|
else:
|
||||||
|
yield
|
||||||
|
|||||||
@ -684,7 +684,7 @@ class SaveImagesResponse(CustomNode):
|
|||||||
mut_srgb_to_linear(image_as_numpy_array[:, :, :3])
|
mut_srgb_to_linear(image_as_numpy_array[:, :, :3])
|
||||||
image_scaled = image_as_numpy_array.astype(np.float32)
|
image_scaled = image_as_numpy_array.astype(np.float32)
|
||||||
if bits == 16:
|
if bits == 16:
|
||||||
cv_save_options = [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF]
|
cv_save_options = [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF]
|
||||||
else:
|
else:
|
||||||
image_scaled = np.clip(image_as_numpy_array * 65535, 0, 65535).astype(np.uint16)
|
image_scaled = np.clip(image_as_numpy_array * 65535, 0, 65535).astype(np.uint16)
|
||||||
|
|
||||||
@ -779,7 +779,7 @@ class SaveImagesResponse(CustomNode):
|
|||||||
if save_method == 'pil':
|
if save_method == 'pil':
|
||||||
image_as_pil.save(local_path, format=save_format, **additional_args)
|
image_as_pil.save(local_path, format=save_format, **additional_args)
|
||||||
else:
|
else:
|
||||||
cv2.imwrite(local_path, image_scaled)
|
cv2.imwrite(local_path, image_scaled)
|
||||||
|
|
||||||
img_item: SaveNodeResultWithName = {
|
img_item: SaveNodeResultWithName = {
|
||||||
"abs_path": str(abs_path),
|
"abs_path": str(abs_path),
|
||||||
@ -809,6 +809,7 @@ class ImageRequestParameter(CustomNode):
|
|||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
**_open_api_common_schema,
|
**_open_api_common_schema,
|
||||||
|
"default_if_empty": ("IMAGE",)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -816,7 +817,9 @@ class ImageRequestParameter(CustomNode):
|
|||||||
FUNCTION = "execute"
|
FUNCTION = "execute"
|
||||||
CATEGORY = "api/openapi"
|
CATEGORY = "api/openapi"
|
||||||
|
|
||||||
def execute(self, value: str = "", *args, **kwargs) -> ValidatedNodeResult:
|
def execute(self, value: str = "", default_if_empty=None, *args, **kwargs) -> ValidatedNodeResult:
|
||||||
|
if value.strip() == "":
|
||||||
|
return (default_if_empty,)
|
||||||
output_images = []
|
output_images = []
|
||||||
f: OpenFile
|
f: OpenFile
|
||||||
fsspec_kwargs = {}
|
fsspec_kwargs = {}
|
||||||
|
|||||||
128
tests/unit/test_nodes_context.py
Normal file
128
tests/unit/test_nodes_context.py
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from comfy.client.embedded_comfy_client import Comfy
|
||||||
|
from comfy.component_model.make_mutable import make_mutable
|
||||||
|
from comfy.distributed.process_pool_executor import ProcessPoolExecutor
|
||||||
|
from comfy.execution_context import context_add_custom_nodes
|
||||||
|
from comfy.nodes.package_typing import CustomNode, ExportedNodes
|
||||||
|
from tests.unit.test_panics import ThrowsExceptionNode
|
||||||
|
|
||||||
|
|
||||||
|
def disable_vanilla(*args):
|
||||||
|
patch_disable_vanilla = globals()['prepare_vanilla_environment'] = patch('comfy_compatibility.vanilla.prepare_vanilla_environment', lambda: "patched")
|
||||||
|
patch_disable_vanilla.start()
|
||||||
|
from comfy_compatibility.vanilla import prepare_vanilla_environment
|
||||||
|
assert prepare_vanilla_environment() == "patched"
|
||||||
|
|
||||||
|
|
||||||
|
def enable_vanilla(*args):
|
||||||
|
patch_disable_vanilla = globals()['prepare_vanilla_environment']
|
||||||
|
patch_disable_vanilla.stop()
|
||||||
|
|
||||||
|
|
||||||
|
class AssertVanillaImportFails(CustomNode):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("STRING",)
|
||||||
|
FUNCTION = "assert_import_fails"
|
||||||
|
CATEGORY = "Testing/Nodes"
|
||||||
|
|
||||||
|
def assert_import_fails(self) -> tuple[str]:
|
||||||
|
try:
|
||||||
|
# sometimes, other code like fluxtapoz has a directly called nodes, and for
|
||||||
|
# development purposes, its source directory is added to path, and this
|
||||||
|
# can be imported, so import nodes cannot be used
|
||||||
|
if 'nodes' in sys.modules:
|
||||||
|
assert 'NODE_CLASS_MAPPINGS' not in sys.modules['nodes'] or 'SplitImageWithAlpha' not in sys.modules['nodes'].NODE_CLASS_MAPPINGS
|
||||||
|
del sys.modules['nodes']
|
||||||
|
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
pass
|
||||||
|
return ("dummy",)
|
||||||
|
|
||||||
|
|
||||||
|
class PrepareVanillaEnvironment(CustomNode):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {
|
||||||
|
"input": ("STRING", {}),
|
||||||
|
}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("STRING",)
|
||||||
|
FUNCTION = "prepare"
|
||||||
|
CATEGORY = "Testing/Nodes"
|
||||||
|
|
||||||
|
def prepare(self, input: str) -> tuple[str]:
|
||||||
|
enable_vanilla()
|
||||||
|
from comfy_compatibility import vanilla
|
||||||
|
vanilla.prepare_vanilla_environment()
|
||||||
|
assert "nodes" in sys.modules
|
||||||
|
return ("dummy",)
|
||||||
|
|
||||||
|
|
||||||
|
class AssertVanillaImportSucceeds(CustomNode):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {
|
||||||
|
"input": ("STRING", {}),
|
||||||
|
}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("STRING",)
|
||||||
|
FUNCTION = "assert_import_succeeds"
|
||||||
|
CATEGORY = "Testing/Nodes"
|
||||||
|
|
||||||
|
def assert_import_succeeds(self, input: str) -> tuple[str]:
|
||||||
|
import nodes
|
||||||
|
assert "SplitImageWithAlpha" in nodes.NODE_CLASS_MAPPINGS
|
||||||
|
return ("",)
|
||||||
|
|
||||||
|
|
||||||
|
def create_nodes_context_workflow():
|
||||||
|
"""Create a workflow that uses our test node to raise an exception"""
|
||||||
|
return make_mutable({
|
||||||
|
"1": {"class_type": "AssertVanillaImportFails", "inputs": {}},
|
||||||
|
"2": {"class_type": "PrepareVanillaEnvironment", "inputs": {"input": ["1", 0]}},
|
||||||
|
"3": {"class_type": "AssertVanillaImportSucceeds", "inputs": {"input": ["2", 0]}},
|
||||||
|
"4": {"class_type": "PreviewString", "inputs": {"value": ["3", 0]}},
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"TestExceptionNode": "Test Exception Node",
|
||||||
|
"AssertVanillaImportFails": "Assert Vanilla Import Fails",
|
||||||
|
"PrepareVanillaEnvironment": "Prepare Vanilla Environment",
|
||||||
|
"AssertVanillaImportSucceeds": "Assert Vanilla Import Succeeds",
|
||||||
|
}
|
||||||
|
|
||||||
|
EXECUTOR_FACTORIES = [
|
||||||
|
(ProcessPoolExecutor, {"max_workers": 1}),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_nodes_context_shim():
|
||||||
|
"""Test panic behavior with different executor types"""
|
||||||
|
|
||||||
|
# Initialize the specific executor
|
||||||
|
executor = ProcessPoolExecutor(max_workers=1, initializer=disable_vanilla)
|
||||||
|
|
||||||
|
if 'nodes' in sys.modules:
|
||||||
|
# something else imported it
|
||||||
|
del sys.modules['nodes']
|
||||||
|
assert 'nodes' not in sys.modules
|
||||||
|
with context_add_custom_nodes(ExportedNodes(NODE_CLASS_MAPPINGS={
|
||||||
|
"TestExceptionNode": ThrowsExceptionNode,
|
||||||
|
"AssertVanillaImportFails": AssertVanillaImportFails,
|
||||||
|
"PrepareVanillaEnvironment": PrepareVanillaEnvironment,
|
||||||
|
"AssertVanillaImportSucceeds": AssertVanillaImportSucceeds,
|
||||||
|
}, NODE_DISPLAY_NAME_MAPPINGS=TEST_NODE_DISPLAY_NAME_MAPPINGS)):
|
||||||
|
async with Comfy(executor=executor) as client:
|
||||||
|
# Queue our failing workflow
|
||||||
|
workflow = create_nodes_context_workflow()
|
||||||
|
await client.queue_prompt(workflow)
|
||||||
|
assert 'nodes' not in sys.modules
|
||||||
Loading…
Reference in New Issue
Block a user