Refactor hidden_values for v3 in execution.py to be more general v3_data, add helper functions for dynamic behavior, preparing for restructuring dynamic type into object (not finished yet)

This commit is contained in:
Jedrzej Kosinski 2025-11-17 15:50:49 -08:00
parent ff392b1fb3
commit 159e2d02c9
3 changed files with 74 additions and 38 deletions

View File

@ -823,6 +823,10 @@ class DynamicInput(Input, ABC):
def get_dynamic(self) -> list[Input]: def get_dynamic(self) -> list[Input]:
return [] return []
def add_to_dict_live_inputs(self, d: dict[str], live_inputs: dict[str]):
pass
class DynamicOutput(Output, ABC): class DynamicOutput(Output, ABC):
''' '''
Abstract class for dynamic output registration. Abstract class for dynamic output registration.
@ -896,6 +900,19 @@ class DynamicCombo(ComfyTypeI):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
self.options = options self.options = options
def add_to_dict_live_inputs(self, d: dict[str], live_inputs: dict[str]):
# check if dynamic input's id is in live_inputs
if self.id in live_inputs:
key = live_inputs[self.id]
selected_option = None
for option in self.options:
if option.key == key:
selected_option = option
break
if selected_option is not None:
add_to_input_dict_v1(d, selected_option.inputs, live_inputs)
add_dynamic_to_dict_v1(d, self, selected_option.inputs)
def get_dynamic(self) -> list[Input]: def get_dynamic(self) -> list[Input]:
return [input for option in self.options for input in option.inputs] return [input for option in self.options for input in option.inputs]
@ -958,6 +975,9 @@ class MatchType(ComfyTypeIO):
"template": self.template.as_dict(), "template": self.template.as_dict(),
}) })
class V3Data(TypedDict):
hidden_inputs: dict[str]
dynamic_data: dict[str]
class HiddenHolder: class HiddenHolder:
def __init__(self, unique_id: str, prompt: Any, def __init__(self, unique_id: str, prompt: Any,
@ -1149,9 +1169,9 @@ class Schema:
if output.id is None: if output.id is None:
output.id = f"_{i}_{output.io_type}_" output.id = f"_{i}_{output.io_type}_"
def get_v1_info(self, cls) -> NodeInfoV1: def get_v1_info(self, cls, live_inputs: dict[str]=None) -> NodeInfoV1:
# get V1 inputs # get V1 inputs
input = create_input_dict_v1(self.inputs) input = create_input_dict_v1(self.inputs, live_inputs)
if self.hidden: if self.hidden:
for hidden in self.hidden: for hidden in self.hidden:
input.setdefault("hidden", {})[hidden.name] = (hidden.value,) input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
@ -1232,21 +1252,23 @@ class Schema:
return info return info
def create_input_dict_v1(inputs: list[Input]) -> dict: def create_input_dict_v1(inputs: list[Input], live_inputs: dict[str]=None) -> dict:
input = { input = {
"required": {} "required": {}
} }
for i in inputs: add_to_input_dict_v1(input, inputs, live_inputs)
if isinstance(i, DynamicInput):
add_to_dict_v1(i, input)
dynamic_inputs = i.get_dynamic()
for d in dynamic_inputs:
add_dynamic_to_dict_v1(d, input)
else:
add_to_dict_v1(i, input)
return input return input
def add_to_dict_v1(i: Input, input: dict, dynamic_dict: dict=None): def add_to_input_dict_v1(d: dict[str], inputs: list[Input], live_inputs: dict[str]=None):
for i in inputs:
if isinstance(i, DynamicInput):
add_to_dict_v1(i, d)
if live_inputs is not None:
i.add_to_dict_live_inputs(d, live_inputs)
else:
add_to_dict_v1(i, d)
def add_to_dict_v1(i: Input, d: dict, dynamic_dict: dict=None):
key = "optional" if i.optional else "required" key = "optional" if i.optional else "required"
as_dict = i.as_dict() as_dict = i.as_dict()
# for v1, we don't want to include the optional key # for v1, we don't want to include the optional key
@ -1255,12 +1277,12 @@ def add_to_dict_v1(i: Input, input: dict, dynamic_dict: dict=None):
value = (i.get_io_type(), as_dict) value = (i.get_io_type(), as_dict)
else: else:
value = (i.get_io_type(), as_dict, dynamic_dict) value = (i.get_io_type(), as_dict, dynamic_dict)
input.setdefault(key, {})[i.id] = value d.setdefault(key, {})[i.id] = value
def add_dynamic_to_dict_v1(d: DynamicInput, i: Input, input: dict): def add_dynamic_to_dict_v1(d: dict[str], parent: DynamicInput, inputs: list[Input]):
dynamic = input.setdefault("_dynamic", {}) dynamic = d.setdefault("dynamic_data", {})
dd = {"parent_id": d.id} ids = [i.id for i in inputs]
add_to_dict_v1(input, dynamic, dd) dynamic[parent.id] = {"ids": ids}
def add_to_dict_v3(io: Input | Output, d: dict): def add_to_dict_v3(io: Input | Output, d: dict):
d[io.id] = (io.get_io_type(), io.as_dict()) d[io.id] = (io.get_io_type(), io.as_dict())
@ -1384,12 +1406,12 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
@final @final
@classmethod @classmethod
def PREPARE_CLASS_CLONE(cls, hidden_inputs: dict) -> type[ComfyNode]: def PREPARE_CLASS_CLONE(cls, v3_data: V3Data) -> type[ComfyNode]:
"""Creates clone of real node class to prevent monkey-patching.""" """Creates clone of real node class to prevent monkey-patching."""
c_type: type[ComfyNode] = cls if is_class(cls) else type(cls) c_type: type[ComfyNode] = cls if is_class(cls) else type(cls)
type_clone: type[ComfyNode] = shallow_clone_class(c_type) type_clone: type[ComfyNode] = shallow_clone_class(c_type)
# set hidden # set hidden
type_clone.hidden = HiddenHolder.from_dict(hidden_inputs) type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"])
return type_clone return type_clone
@final @final
@ -1506,14 +1528,18 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
@final @final
@classmethod @classmethod
def INPUT_TYPES(cls, include_hidden=True, return_schema=False) -> dict[str, dict] | tuple[dict[str, dict], Schema]: def INPUT_TYPES(cls, include_hidden=True, return_schema=False, live_inputs=None) -> dict[str, dict] | tuple[dict[str, dict], Schema, V3Data]:
schema = cls.FINALIZE_SCHEMA() schema = cls.FINALIZE_SCHEMA()
info = schema.get_v1_info(cls) info = schema.get_v1_info(cls, live_inputs)
input = info.input input = info.input
if not include_hidden: if not include_hidden:
input.pop("hidden", None) input.pop("hidden", None)
if return_schema: if return_schema:
return input, schema v3_data: V3Data = {}
dynamic = input.pop("dynamic_data", None)
if dynamic is not None:
v3_data["dynamic_data"] = dynamic
return input, schema, v3_data
return input return input
@final @final
@ -1737,4 +1763,5 @@ __all__ = [
"NodeOutput", "NodeOutput",
"add_to_dict_v1", "add_to_dict_v1",
"add_to_dict_v3", "add_to_dict_v3",
"V3Data",
] ]

View File

@ -48,12 +48,19 @@ class DCTestNode(io.ComfyNode):
io.DynamicCombo.Option("option3", [io.Image.Input("image")]), io.DynamicCombo.Option("option3", [io.Image.Input("image")]),
] ]
)], )],
outputs=[io.String.Output()], outputs=[io.AnyType.Output()],
) )
@classmethod @classmethod
def execute(cls, combo) -> io.NodeOutput: def execute(cls, combo, **kwargs) -> io.NodeOutput:
return io.NodeOutput("haha") if combo == "option1":
return io.NodeOutput(kwargs["string"])
elif combo == "option2":
return io.NodeOutput(kwargs["integer"])
elif combo == "option3":
return io.NodeOutput(kwargs["image"])
else:
raise ValueError(f"Invalid combo: {combo}")
class LogicExtension(ComfyExtension): class LogicExtension(ComfyExtension):

View File

@ -76,7 +76,7 @@ class IsChangedCache:
return self.is_changed[node_id] return self.is_changed[node_id]
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
input_data_all, _, hidden_inputs = get_input_data(node["inputs"], class_def, node_id, None) input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None)
try: try:
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name) is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name)
is_changed = await resolve_map_node_over_list_results(is_changed) is_changed = await resolve_map_node_over_list_results(is_changed)
@ -146,8 +146,9 @@ SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}): def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}):
is_v3 = issubclass(class_def, _ComfyNodeInternal) is_v3 = issubclass(class_def, _ComfyNodeInternal)
v3_data: io.V3Data = {}
if is_v3: if is_v3:
valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True) valid_inputs, schema, v3_data = class_def.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs)
else: else:
valid_inputs = class_def.INPUT_TYPES() valid_inputs = class_def.INPUT_TYPES()
input_data_all = {} input_data_all = {}
@ -207,7 +208,8 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)] input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
if h[x] == "API_KEY_COMFY_ORG": if h[x] == "API_KEY_COMFY_ORG":
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)] input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
return input_data_all, missing_keys, hidden_inputs_v3 v3_data["hidden_inputs"] = hidden_inputs_v3
return input_data_all, missing_keys, v3_data
map_node_over_list = None #Don't hook this please map_node_over_list = None #Don't hook this please
@ -223,7 +225,7 @@ async def resolve_map_node_over_list_results(results):
raise exc raise exc
return [x.result() if isinstance(x, asyncio.Task) else x for x in results] return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None): async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
# check if node wants the lists # check if node wants the lists
input_is_list = getattr(obj, "INPUT_IS_LIST", False) input_is_list = getattr(obj, "INPUT_IS_LIST", False)
@ -259,12 +261,12 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
if is_class(obj): if is_class(obj):
type_obj = obj type_obj = obj
obj.VALIDATE_CLASS() obj.VALIDATE_CLASS()
class_clone = obj.PREPARE_CLASS_CLONE(hidden_inputs) class_clone = obj.PREPARE_CLASS_CLONE(v3_data)
# otherwise, use class instance to populate/reuse some fields # otherwise, use class instance to populate/reuse some fields
else: else:
type_obj = type(obj) type_obj = type(obj)
type_obj.VALIDATE_CLASS() type_obj.VALIDATE_CLASS()
class_clone = type_obj.PREPARE_CLASS_CLONE(hidden_inputs) class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data)
f = make_locked_method_func(type_obj, func, class_clone) f = make_locked_method_func(type_obj, func, class_clone)
# V1 # V1
else: else:
@ -320,8 +322,8 @@ def merge_result_data(results, obj):
output.append([o[i] for o in results]) output.append([o[i] for o in results])
return output return output
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None): async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs) return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values) has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
if has_pending_task: if has_pending_task:
return return_values, {}, False, has_pending_task return return_values, {}, False, has_pending_task
@ -460,7 +462,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
has_subgraph = False has_subgraph = False
else: else:
get_progress_state().start_progress(unique_id) get_progress_state().start_progress(unique_id)
input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data) input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
if server.client_id is not None: if server.client_id is not None:
server.last_node_id = display_node_id server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
@ -475,7 +477,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
else: else:
lazy_status_present = getattr(obj, "check_lazy_status", None) is not None lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
if lazy_status_present: if lazy_status_present:
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs) required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data)
required_inputs = await resolve_map_node_over_list_results(required_inputs) required_inputs = await resolve_map_node_over_list_results(required_inputs)
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], [])) required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
required_inputs = [x for x in required_inputs if isinstance(x,str) and ( required_inputs = [x for x in required_inputs if isinstance(x,str) and (
@ -507,7 +509,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
def pre_execute_cb(call_index): def pre_execute_cb(call_index):
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)? # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
GraphBuilder.set_default_prefix(unique_id, call_index, 0) GraphBuilder.set_default_prefix(unique_id, call_index, 0)
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs) output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
if has_pending_tasks: if has_pending_tasks:
pending_async_nodes[unique_id] = output_data pending_async_nodes[unique_id] = output_data
unblock = execution_list.add_external_block(unique_id) unblock = execution_list.add_external_block(unique_id)
@ -935,7 +937,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
continue continue
if len(validate_function_inputs) > 0 or validate_has_kwargs: if len(validate_function_inputs) > 0 or validate_has_kwargs:
input_data_all, _, hidden_inputs = get_input_data(inputs, obj_class, unique_id) input_data_all, _, v3_data = get_input_data(inputs, obj_class, unique_id)
input_filtered = {} input_filtered = {}
for x in input_data_all: for x in input_data_all:
if x in validate_function_inputs or validate_has_kwargs: if x in validate_function_inputs or validate_has_kwargs:
@ -943,7 +945,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
if 'input_types' in validate_function_inputs: if 'input_types' in validate_function_inputs:
input_filtered['input_types'] = [received_types] input_filtered['input_types'] = [received_types]
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, hidden_inputs=hidden_inputs) ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, v3_data=v3_data)
ret = await resolve_map_node_over_list_results(ret) ret = await resolve_map_node_over_list_results(ret)
for x in input_filtered: for x in input_filtered:
for i, r in enumerate(ret): for i, r in enumerate(ret):