Refactor hidden_values for v3 in execution.py to be more general v3_data, add helper functions for dynamic behavior, preparing for restructuring dynamic type into object (not finished yet)

This commit is contained in:
Jedrzej Kosinski 2025-11-17 15:50:49 -08:00
parent ff392b1fb3
commit 159e2d02c9
3 changed files with 74 additions and 38 deletions

View File

@ -822,6 +822,10 @@ class DynamicInput(Input, ABC):
'''
def get_dynamic(self) -> list[Input]:
return []
def add_to_dict_live_inputs(self, d: dict[str], live_inputs: dict[str]):
pass
class DynamicOutput(Output, ABC):
'''
@ -896,6 +900,19 @@ class DynamicCombo(ComfyTypeI):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
self.options = options
def add_to_dict_live_inputs(self, d: dict[str], live_inputs: dict[str]):
# check if dynamic input's id is in live_inputs
if self.id in live_inputs:
key = live_inputs[self.id]
selected_option = None
for option in self.options:
if option.key == key:
selected_option = option
break
if selected_option is not None:
add_to_input_dict_v1(d, selected_option.inputs, live_inputs)
add_dynamic_to_dict_v1(d, self, selected_option.inputs)
def get_dynamic(self) -> list[Input]:
return [input for option in self.options for input in option.inputs]
@ -958,6 +975,9 @@ class MatchType(ComfyTypeIO):
"template": self.template.as_dict(),
})
class V3Data(TypedDict):
hidden_inputs: dict[str]
dynamic_data: dict[str]
class HiddenHolder:
def __init__(self, unique_id: str, prompt: Any,
@ -1149,9 +1169,9 @@ class Schema:
if output.id is None:
output.id = f"_{i}_{output.io_type}_"
def get_v1_info(self, cls) -> NodeInfoV1:
def get_v1_info(self, cls, live_inputs: dict[str]=None) -> NodeInfoV1:
# get V1 inputs
input = create_input_dict_v1(self.inputs)
input = create_input_dict_v1(self.inputs, live_inputs)
if self.hidden:
for hidden in self.hidden:
input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
@ -1232,21 +1252,23 @@ class Schema:
return info
def create_input_dict_v1(inputs: list[Input]) -> dict:
def create_input_dict_v1(inputs: list[Input], live_inputs: dict[str]=None) -> dict:
input = {
"required": {}
}
for i in inputs:
if isinstance(i, DynamicInput):
add_to_dict_v1(i, input)
dynamic_inputs = i.get_dynamic()
for d in dynamic_inputs:
add_dynamic_to_dict_v1(d, input)
else:
add_to_dict_v1(i, input)
add_to_input_dict_v1(input, inputs, live_inputs)
return input
def add_to_dict_v1(i: Input, input: dict, dynamic_dict: dict=None):
def add_to_input_dict_v1(d: dict[str], inputs: list[Input], live_inputs: dict[str]=None):
for i in inputs:
if isinstance(i, DynamicInput):
add_to_dict_v1(i, d)
if live_inputs is not None:
i.add_to_dict_live_inputs(d, live_inputs)
else:
add_to_dict_v1(i, d)
def add_to_dict_v1(i: Input, d: dict, dynamic_dict: dict=None):
key = "optional" if i.optional else "required"
as_dict = i.as_dict()
# for v1, we don't want to include the optional key
@ -1255,12 +1277,12 @@ def add_to_dict_v1(i: Input, input: dict, dynamic_dict: dict=None):
value = (i.get_io_type(), as_dict)
else:
value = (i.get_io_type(), as_dict, dynamic_dict)
input.setdefault(key, {})[i.id] = value
d.setdefault(key, {})[i.id] = value
def add_dynamic_to_dict_v1(d: DynamicInput, i: Input, input: dict):
dynamic = input.setdefault("_dynamic", {})
dd = {"parent_id": d.id}
add_to_dict_v1(input, dynamic, dd)
def add_dynamic_to_dict_v1(d: dict[str], parent: DynamicInput, inputs: list[Input]):
dynamic = d.setdefault("dynamic_data", {})
ids = [i.id for i in inputs]
dynamic[parent.id] = {"ids": ids}
def add_to_dict_v3(io: Input | Output, d: dict):
d[io.id] = (io.get_io_type(), io.as_dict())
@ -1384,12 +1406,12 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
@final
@classmethod
def PREPARE_CLASS_CLONE(cls, hidden_inputs: dict) -> type[ComfyNode]:
def PREPARE_CLASS_CLONE(cls, v3_data: V3Data) -> type[ComfyNode]:
"""Creates clone of real node class to prevent monkey-patching."""
c_type: type[ComfyNode] = cls if is_class(cls) else type(cls)
type_clone: type[ComfyNode] = shallow_clone_class(c_type)
# set hidden
type_clone.hidden = HiddenHolder.from_dict(hidden_inputs)
type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"])
return type_clone
@final
@ -1506,14 +1528,18 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
@final
@classmethod
def INPUT_TYPES(cls, include_hidden=True, return_schema=False) -> dict[str, dict] | tuple[dict[str, dict], Schema]:
def INPUT_TYPES(cls, include_hidden=True, return_schema=False, live_inputs=None) -> dict[str, dict] | tuple[dict[str, dict], Schema, V3Data]:
schema = cls.FINALIZE_SCHEMA()
info = schema.get_v1_info(cls)
info = schema.get_v1_info(cls, live_inputs)
input = info.input
if not include_hidden:
input.pop("hidden", None)
if return_schema:
return input, schema
v3_data: V3Data = {}
dynamic = input.pop("dynamic_data", None)
if dynamic is not None:
v3_data["dynamic_data"] = dynamic
return input, schema, v3_data
return input
@final
@ -1737,4 +1763,5 @@ __all__ = [
"NodeOutput",
"add_to_dict_v1",
"add_to_dict_v3",
"V3Data",
]

View File

@ -48,12 +48,19 @@ class DCTestNode(io.ComfyNode):
io.DynamicCombo.Option("option3", [io.Image.Input("image")]),
]
)],
outputs=[io.String.Output()],
outputs=[io.AnyType.Output()],
)
@classmethod
def execute(cls, combo) -> io.NodeOutput:
return io.NodeOutput("haha")
def execute(cls, combo, **kwargs) -> io.NodeOutput:
if combo == "option1":
return io.NodeOutput(kwargs["string"])
elif combo == "option2":
return io.NodeOutput(kwargs["integer"])
elif combo == "option3":
return io.NodeOutput(kwargs["image"])
else:
raise ValueError(f"Invalid combo: {combo}")
class LogicExtension(ComfyExtension):

View File

@ -76,7 +76,7 @@ class IsChangedCache:
return self.is_changed[node_id]
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
input_data_all, _, hidden_inputs = get_input_data(node["inputs"], class_def, node_id, None)
input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None)
try:
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name)
is_changed = await resolve_map_node_over_list_results(is_changed)
@ -146,8 +146,9 @@ SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}):
is_v3 = issubclass(class_def, _ComfyNodeInternal)
v3_data: io.V3Data = {}
if is_v3:
valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True)
valid_inputs, schema, v3_data = class_def.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs)
else:
valid_inputs = class_def.INPUT_TYPES()
input_data_all = {}
@ -207,7 +208,8 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
if h[x] == "API_KEY_COMFY_ORG":
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
return input_data_all, missing_keys, hidden_inputs_v3
v3_data["hidden_inputs"] = hidden_inputs_v3
return input_data_all, missing_keys, v3_data
map_node_over_list = None #Don't hook this please
@ -223,7 +225,7 @@ async def resolve_map_node_over_list_results(results):
raise exc
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
# check if node wants the lists
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
@ -259,12 +261,12 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
if is_class(obj):
type_obj = obj
obj.VALIDATE_CLASS()
class_clone = obj.PREPARE_CLASS_CLONE(hidden_inputs)
class_clone = obj.PREPARE_CLASS_CLONE(v3_data)
# otherwise, use class instance to populate/reuse some fields
else:
type_obj = type(obj)
type_obj.VALIDATE_CLASS()
class_clone = type_obj.PREPARE_CLASS_CLONE(hidden_inputs)
class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data)
f = make_locked_method_func(type_obj, func, class_clone)
# V1
else:
@ -320,8 +322,8 @@ def merge_result_data(results, obj):
output.append([o[i] for o in results])
return output
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
if has_pending_task:
return return_values, {}, False, has_pending_task
@ -460,7 +462,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
has_subgraph = False
else:
get_progress_state().start_progress(unique_id)
input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
if server.client_id is not None:
server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
@ -475,7 +477,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
else:
lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
if lazy_status_present:
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs)
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data)
required_inputs = await resolve_map_node_over_list_results(required_inputs)
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
@ -507,7 +509,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
def pre_execute_cb(call_index):
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
if has_pending_tasks:
pending_async_nodes[unique_id] = output_data
unblock = execution_list.add_external_block(unique_id)
@ -935,7 +937,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
continue
if len(validate_function_inputs) > 0 or validate_has_kwargs:
input_data_all, _, hidden_inputs = get_input_data(inputs, obj_class, unique_id)
input_data_all, _, v3_data = get_input_data(inputs, obj_class, unique_id)
input_filtered = {}
for x in input_data_all:
if x in validate_function_inputs or validate_has_kwargs:
@ -943,7 +945,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
if 'input_types' in validate_function_inputs:
input_filtered['input_types'] = [received_types]
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, hidden_inputs=hidden_inputs)
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, v3_data=v3_data)
ret = await resolve_map_node_over_list_results(ret)
for x in input_filtered:
for i, r in enumerate(ret):