From 159e2d02c9d3cfc27c94e0a46f07be68a628bb11 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 17 Nov 2025 15:50:49 -0800 Subject: [PATCH] Refactor hidden_values for v3 in execution.py to be more general v3_data, add helper functions for dynamic behavior, preparing for restructuring dynamic type into object (not finished yet) --- comfy_api/latest/_io.py | 71 +++++++++++++++++++++++++------------ comfy_extras/nodes_logic.py | 13 +++++-- execution.py | 28 ++++++++------- 3 files changed, 74 insertions(+), 38 deletions(-) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 27571e970..2ec109ddb 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -822,6 +822,10 @@ class DynamicInput(Input, ABC): ''' def get_dynamic(self) -> list[Input]: return [] + + def add_to_dict_live_inputs(self, d: dict[str], live_inputs: dict[str]): + pass + class DynamicOutput(Output, ABC): ''' @@ -896,6 +900,19 @@ class DynamicCombo(ComfyTypeI): super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) self.options = options + def add_to_dict_live_inputs(self, d: dict[str], live_inputs: dict[str]): + # check if dynamic input's id is in live_inputs + if self.id in live_inputs: + key = live_inputs[self.id] + selected_option = None + for option in self.options: + if option.key == key: + selected_option = option + break + if selected_option is not None: + add_to_input_dict_v1(d, selected_option.inputs, live_inputs) + add_dynamic_to_dict_v1(d, self, selected_option.inputs) + def get_dynamic(self) -> list[Input]: return [input for option in self.options for input in option.inputs] @@ -958,6 +975,9 @@ class MatchType(ComfyTypeIO): "template": self.template.as_dict(), }) +class V3Data(TypedDict): + hidden_inputs: dict[str] + dynamic_data: dict[str] class HiddenHolder: def __init__(self, unique_id: str, prompt: Any, @@ -1149,9 +1169,9 @@ class Schema: if output.id is None: output.id = f"_{i}_{output.io_type}_" - def get_v1_info(self, cls) -> NodeInfoV1: + def get_v1_info(self, cls, live_inputs: dict[str]=None) -> NodeInfoV1: # get V1 inputs - input = create_input_dict_v1(self.inputs) + input = create_input_dict_v1(self.inputs, live_inputs) if self.hidden: for hidden in self.hidden: input.setdefault("hidden", {})[hidden.name] = (hidden.value,) @@ -1232,21 +1252,23 @@ class Schema: return info -def create_input_dict_v1(inputs: list[Input]) -> dict: +def create_input_dict_v1(inputs: list[Input], live_inputs: dict[str]=None) -> dict: input = { "required": {} } - for i in inputs: - if isinstance(i, DynamicInput): - add_to_dict_v1(i, input) - dynamic_inputs = i.get_dynamic() - for d in dynamic_inputs: - add_dynamic_to_dict_v1(d, input) - else: - add_to_dict_v1(i, input) + add_to_input_dict_v1(input, inputs, live_inputs) return input -def add_to_dict_v1(i: Input, input: dict, dynamic_dict: dict=None): +def add_to_input_dict_v1(d: dict[str], inputs: list[Input], live_inputs: dict[str]=None): + for i in inputs: + if isinstance(i, DynamicInput): + add_to_dict_v1(i, d) + if live_inputs is not None: + i.add_to_dict_live_inputs(d, live_inputs) + else: + add_to_dict_v1(i, d) + +def add_to_dict_v1(i: Input, d: dict, dynamic_dict: dict=None): key = "optional" if i.optional else "required" as_dict = i.as_dict() # for v1, we don't want to include the optional key @@ -1255,12 +1277,12 @@ def add_to_dict_v1(i: Input, input: dict, dynamic_dict: dict=None): value = (i.get_io_type(), as_dict) else: value = (i.get_io_type(), as_dict, dynamic_dict) - input.setdefault(key, {})[i.id] = value + d.setdefault(key, {})[i.id] = value -def add_dynamic_to_dict_v1(d: DynamicInput, i: Input, input: dict): - dynamic = input.setdefault("_dynamic", {}) - dd = {"parent_id": d.id} - add_to_dict_v1(input, dynamic, dd) +def add_dynamic_to_dict_v1(d: dict[str], parent: DynamicInput, inputs: list[Input]): + dynamic = d.setdefault("dynamic_data", {}) + ids = [i.id for i in inputs] + dynamic[parent.id] = {"ids": ids} def add_to_dict_v3(io: Input | Output, d: dict): d[io.id] = (io.get_io_type(), io.as_dict()) @@ -1384,12 +1406,12 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): @final @classmethod - def PREPARE_CLASS_CLONE(cls, hidden_inputs: dict) -> type[ComfyNode]: + def PREPARE_CLASS_CLONE(cls, v3_data: V3Data) -> type[ComfyNode]: """Creates clone of real node class to prevent monkey-patching.""" c_type: type[ComfyNode] = cls if is_class(cls) else type(cls) type_clone: type[ComfyNode] = shallow_clone_class(c_type) # set hidden - type_clone.hidden = HiddenHolder.from_dict(hidden_inputs) + type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"]) return type_clone @final @@ -1506,14 +1528,18 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): @final @classmethod - def INPUT_TYPES(cls, include_hidden=True, return_schema=False) -> dict[str, dict] | tuple[dict[str, dict], Schema]: + def INPUT_TYPES(cls, include_hidden=True, return_schema=False, live_inputs=None) -> dict[str, dict] | tuple[dict[str, dict], Schema, V3Data]: schema = cls.FINALIZE_SCHEMA() - info = schema.get_v1_info(cls) + info = schema.get_v1_info(cls, live_inputs) input = info.input if not include_hidden: input.pop("hidden", None) if return_schema: - return input, schema + v3_data: V3Data = {} + dynamic = input.pop("dynamic_data", None) + if dynamic is not None: + v3_data["dynamic_data"] = dynamic + return input, schema, v3_data return input @final @@ -1737,4 +1763,5 @@ __all__ = [ "NodeOutput", "add_to_dict_v1", "add_to_dict_v3", + "V3Data", ] diff --git a/comfy_extras/nodes_logic.py b/comfy_extras/nodes_logic.py index 4a502d992..ab095f8d3 100644 --- a/comfy_extras/nodes_logic.py +++ b/comfy_extras/nodes_logic.py @@ -48,12 +48,19 @@ class DCTestNode(io.ComfyNode): io.DynamicCombo.Option("option3", [io.Image.Input("image")]), ] )], - outputs=[io.String.Output()], + outputs=[io.AnyType.Output()], ) @classmethod - def execute(cls, combo) -> io.NodeOutput: - return io.NodeOutput("haha") + def execute(cls, combo, **kwargs) -> io.NodeOutput: + if combo == "option1": + return io.NodeOutput(kwargs["string"]) + elif combo == "option2": + return io.NodeOutput(kwargs["integer"]) + elif combo == "option3": + return io.NodeOutput(kwargs["image"]) + else: + raise ValueError(f"Invalid combo: {combo}") class LogicExtension(ComfyExtension): diff --git a/execution.py b/execution.py index 17c77beab..f72ce7d14 100644 --- a/execution.py +++ b/execution.py @@ -76,7 +76,7 @@ class IsChangedCache: return self.is_changed[node_id] # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED - input_data_all, _, hidden_inputs = get_input_data(node["inputs"], class_def, node_id, None) + input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None) try: is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name) is_changed = await resolve_map_node_over_list_results(is_changed) @@ -146,8 +146,9 @@ SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org") def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}): is_v3 = issubclass(class_def, _ComfyNodeInternal) + v3_data: io.V3Data = {} if is_v3: - valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True) + valid_inputs, schema, v3_data = class_def.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs) else: valid_inputs = class_def.INPUT_TYPES() input_data_all = {} @@ -207,7 +208,8 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt= input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)] if h[x] == "API_KEY_COMFY_ORG": input_data_all[x] = [extra_data.get("api_key_comfy_org", None)] - return input_data_all, missing_keys, hidden_inputs_v3 + v3_data["hidden_inputs"] = hidden_inputs_v3 + return input_data_all, missing_keys, v3_data map_node_over_list = None #Don't hook this please @@ -223,7 +225,7 @@ async def resolve_map_node_over_list_results(results): raise exc return [x.result() if isinstance(x, asyncio.Task) else x for x in results] -async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None): +async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, v3_data=None): # check if node wants the lists input_is_list = getattr(obj, "INPUT_IS_LIST", False) @@ -259,12 +261,12 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f if is_class(obj): type_obj = obj obj.VALIDATE_CLASS() - class_clone = obj.PREPARE_CLASS_CLONE(hidden_inputs) + class_clone = obj.PREPARE_CLASS_CLONE(v3_data) # otherwise, use class instance to populate/reuse some fields else: type_obj = type(obj) type_obj.VALIDATE_CLASS() - class_clone = type_obj.PREPARE_CLASS_CLONE(hidden_inputs) + class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data) f = make_locked_method_func(type_obj, func, class_clone) # V1 else: @@ -320,8 +322,8 @@ def merge_result_data(results, obj): output.append([o[i] for o in results]) return output -async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None): - return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs) +async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, v3_data=None): + return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values) if has_pending_task: return return_values, {}, False, has_pending_task @@ -460,7 +462,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, has_subgraph = False else: get_progress_state().start_progress(unique_id) - input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data) + input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data) if server.client_id is not None: server.last_node_id = display_node_id server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) @@ -475,7 +477,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, else: lazy_status_present = getattr(obj, "check_lazy_status", None) is not None if lazy_status_present: - required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs) + required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data) required_inputs = await resolve_map_node_over_list_results(required_inputs) required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], [])) required_inputs = [x for x in required_inputs if isinstance(x,str) and ( @@ -507,7 +509,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, def pre_execute_cb(call_index): # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)? GraphBuilder.set_default_prefix(unique_id, call_index, 0) - output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs) + output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) if has_pending_tasks: pending_async_nodes[unique_id] = output_data unblock = execution_list.add_external_block(unique_id) @@ -935,7 +937,7 @@ async def validate_inputs(prompt_id, prompt, item, validated): continue if len(validate_function_inputs) > 0 or validate_has_kwargs: - input_data_all, _, hidden_inputs = get_input_data(inputs, obj_class, unique_id) + input_data_all, _, v3_data = get_input_data(inputs, obj_class, unique_id) input_filtered = {} for x in input_data_all: if x in validate_function_inputs or validate_has_kwargs: @@ -943,7 +945,7 @@ async def validate_inputs(prompt_id, prompt, item, validated): if 'input_types' in validate_function_inputs: input_filtered['input_types'] = [received_types] - ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, hidden_inputs=hidden_inputs) + ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, v3_data=v3_data) ret = await resolve_map_node_over_list_results(ret) for x in input_filtered: for i, r in enumerate(ret):