mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-17 18:13:01 +08:00
Refactor hidden_values for v3 in execution.py to be more general v3_data, add helper functions for dynamic behavior, preparing for restructuring dynamic type into object (not finished yet)
This commit is contained in:
parent
ff392b1fb3
commit
159e2d02c9
@ -823,6 +823,10 @@ class DynamicInput(Input, ABC):
|
|||||||
def get_dynamic(self) -> list[Input]:
|
def get_dynamic(self) -> list[Input]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
def add_to_dict_live_inputs(self, d: dict[str], live_inputs: dict[str]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class DynamicOutput(Output, ABC):
|
class DynamicOutput(Output, ABC):
|
||||||
'''
|
'''
|
||||||
Abstract class for dynamic output registration.
|
Abstract class for dynamic output registration.
|
||||||
@ -896,6 +900,19 @@ class DynamicCombo(ComfyTypeI):
|
|||||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
||||||
self.options = options
|
self.options = options
|
||||||
|
|
||||||
|
def add_to_dict_live_inputs(self, d: dict[str], live_inputs: dict[str]):
|
||||||
|
# check if dynamic input's id is in live_inputs
|
||||||
|
if self.id in live_inputs:
|
||||||
|
key = live_inputs[self.id]
|
||||||
|
selected_option = None
|
||||||
|
for option in self.options:
|
||||||
|
if option.key == key:
|
||||||
|
selected_option = option
|
||||||
|
break
|
||||||
|
if selected_option is not None:
|
||||||
|
add_to_input_dict_v1(d, selected_option.inputs, live_inputs)
|
||||||
|
add_dynamic_to_dict_v1(d, self, selected_option.inputs)
|
||||||
|
|
||||||
def get_dynamic(self) -> list[Input]:
|
def get_dynamic(self) -> list[Input]:
|
||||||
return [input for option in self.options for input in option.inputs]
|
return [input for option in self.options for input in option.inputs]
|
||||||
|
|
||||||
@ -958,6 +975,9 @@ class MatchType(ComfyTypeIO):
|
|||||||
"template": self.template.as_dict(),
|
"template": self.template.as_dict(),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
class V3Data(TypedDict):
|
||||||
|
hidden_inputs: dict[str]
|
||||||
|
dynamic_data: dict[str]
|
||||||
|
|
||||||
class HiddenHolder:
|
class HiddenHolder:
|
||||||
def __init__(self, unique_id: str, prompt: Any,
|
def __init__(self, unique_id: str, prompt: Any,
|
||||||
@ -1149,9 +1169,9 @@ class Schema:
|
|||||||
if output.id is None:
|
if output.id is None:
|
||||||
output.id = f"_{i}_{output.io_type}_"
|
output.id = f"_{i}_{output.io_type}_"
|
||||||
|
|
||||||
def get_v1_info(self, cls) -> NodeInfoV1:
|
def get_v1_info(self, cls, live_inputs: dict[str]=None) -> NodeInfoV1:
|
||||||
# get V1 inputs
|
# get V1 inputs
|
||||||
input = create_input_dict_v1(self.inputs)
|
input = create_input_dict_v1(self.inputs, live_inputs)
|
||||||
if self.hidden:
|
if self.hidden:
|
||||||
for hidden in self.hidden:
|
for hidden in self.hidden:
|
||||||
input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
|
input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
|
||||||
@ -1232,21 +1252,23 @@ class Schema:
|
|||||||
return info
|
return info
|
||||||
|
|
||||||
|
|
||||||
def create_input_dict_v1(inputs: list[Input]) -> dict:
|
def create_input_dict_v1(inputs: list[Input], live_inputs: dict[str]=None) -> dict:
|
||||||
input = {
|
input = {
|
||||||
"required": {}
|
"required": {}
|
||||||
}
|
}
|
||||||
for i in inputs:
|
add_to_input_dict_v1(input, inputs, live_inputs)
|
||||||
if isinstance(i, DynamicInput):
|
|
||||||
add_to_dict_v1(i, input)
|
|
||||||
dynamic_inputs = i.get_dynamic()
|
|
||||||
for d in dynamic_inputs:
|
|
||||||
add_dynamic_to_dict_v1(d, input)
|
|
||||||
else:
|
|
||||||
add_to_dict_v1(i, input)
|
|
||||||
return input
|
return input
|
||||||
|
|
||||||
def add_to_dict_v1(i: Input, input: dict, dynamic_dict: dict=None):
|
def add_to_input_dict_v1(d: dict[str], inputs: list[Input], live_inputs: dict[str]=None):
|
||||||
|
for i in inputs:
|
||||||
|
if isinstance(i, DynamicInput):
|
||||||
|
add_to_dict_v1(i, d)
|
||||||
|
if live_inputs is not None:
|
||||||
|
i.add_to_dict_live_inputs(d, live_inputs)
|
||||||
|
else:
|
||||||
|
add_to_dict_v1(i, d)
|
||||||
|
|
||||||
|
def add_to_dict_v1(i: Input, d: dict, dynamic_dict: dict=None):
|
||||||
key = "optional" if i.optional else "required"
|
key = "optional" if i.optional else "required"
|
||||||
as_dict = i.as_dict()
|
as_dict = i.as_dict()
|
||||||
# for v1, we don't want to include the optional key
|
# for v1, we don't want to include the optional key
|
||||||
@ -1255,12 +1277,12 @@ def add_to_dict_v1(i: Input, input: dict, dynamic_dict: dict=None):
|
|||||||
value = (i.get_io_type(), as_dict)
|
value = (i.get_io_type(), as_dict)
|
||||||
else:
|
else:
|
||||||
value = (i.get_io_type(), as_dict, dynamic_dict)
|
value = (i.get_io_type(), as_dict, dynamic_dict)
|
||||||
input.setdefault(key, {})[i.id] = value
|
d.setdefault(key, {})[i.id] = value
|
||||||
|
|
||||||
def add_dynamic_to_dict_v1(d: DynamicInput, i: Input, input: dict):
|
def add_dynamic_to_dict_v1(d: dict[str], parent: DynamicInput, inputs: list[Input]):
|
||||||
dynamic = input.setdefault("_dynamic", {})
|
dynamic = d.setdefault("dynamic_data", {})
|
||||||
dd = {"parent_id": d.id}
|
ids = [i.id for i in inputs]
|
||||||
add_to_dict_v1(input, dynamic, dd)
|
dynamic[parent.id] = {"ids": ids}
|
||||||
|
|
||||||
def add_to_dict_v3(io: Input | Output, d: dict):
|
def add_to_dict_v3(io: Input | Output, d: dict):
|
||||||
d[io.id] = (io.get_io_type(), io.as_dict())
|
d[io.id] = (io.get_io_type(), io.as_dict())
|
||||||
@ -1384,12 +1406,12 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
|||||||
|
|
||||||
@final
|
@final
|
||||||
@classmethod
|
@classmethod
|
||||||
def PREPARE_CLASS_CLONE(cls, hidden_inputs: dict) -> type[ComfyNode]:
|
def PREPARE_CLASS_CLONE(cls, v3_data: V3Data) -> type[ComfyNode]:
|
||||||
"""Creates clone of real node class to prevent monkey-patching."""
|
"""Creates clone of real node class to prevent monkey-patching."""
|
||||||
c_type: type[ComfyNode] = cls if is_class(cls) else type(cls)
|
c_type: type[ComfyNode] = cls if is_class(cls) else type(cls)
|
||||||
type_clone: type[ComfyNode] = shallow_clone_class(c_type)
|
type_clone: type[ComfyNode] = shallow_clone_class(c_type)
|
||||||
# set hidden
|
# set hidden
|
||||||
type_clone.hidden = HiddenHolder.from_dict(hidden_inputs)
|
type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"])
|
||||||
return type_clone
|
return type_clone
|
||||||
|
|
||||||
@final
|
@final
|
||||||
@ -1506,14 +1528,18 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
|||||||
|
|
||||||
@final
|
@final
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls, include_hidden=True, return_schema=False) -> dict[str, dict] | tuple[dict[str, dict], Schema]:
|
def INPUT_TYPES(cls, include_hidden=True, return_schema=False, live_inputs=None) -> dict[str, dict] | tuple[dict[str, dict], Schema, V3Data]:
|
||||||
schema = cls.FINALIZE_SCHEMA()
|
schema = cls.FINALIZE_SCHEMA()
|
||||||
info = schema.get_v1_info(cls)
|
info = schema.get_v1_info(cls, live_inputs)
|
||||||
input = info.input
|
input = info.input
|
||||||
if not include_hidden:
|
if not include_hidden:
|
||||||
input.pop("hidden", None)
|
input.pop("hidden", None)
|
||||||
if return_schema:
|
if return_schema:
|
||||||
return input, schema
|
v3_data: V3Data = {}
|
||||||
|
dynamic = input.pop("dynamic_data", None)
|
||||||
|
if dynamic is not None:
|
||||||
|
v3_data["dynamic_data"] = dynamic
|
||||||
|
return input, schema, v3_data
|
||||||
return input
|
return input
|
||||||
|
|
||||||
@final
|
@final
|
||||||
@ -1737,4 +1763,5 @@ __all__ = [
|
|||||||
"NodeOutput",
|
"NodeOutput",
|
||||||
"add_to_dict_v1",
|
"add_to_dict_v1",
|
||||||
"add_to_dict_v3",
|
"add_to_dict_v3",
|
||||||
|
"V3Data",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -48,12 +48,19 @@ class DCTestNode(io.ComfyNode):
|
|||||||
io.DynamicCombo.Option("option3", [io.Image.Input("image")]),
|
io.DynamicCombo.Option("option3", [io.Image.Input("image")]),
|
||||||
]
|
]
|
||||||
)],
|
)],
|
||||||
outputs=[io.String.Output()],
|
outputs=[io.AnyType.Output()],
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, combo) -> io.NodeOutput:
|
def execute(cls, combo, **kwargs) -> io.NodeOutput:
|
||||||
return io.NodeOutput("haha")
|
if combo == "option1":
|
||||||
|
return io.NodeOutput(kwargs["string"])
|
||||||
|
elif combo == "option2":
|
||||||
|
return io.NodeOutput(kwargs["integer"])
|
||||||
|
elif combo == "option3":
|
||||||
|
return io.NodeOutput(kwargs["image"])
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid combo: {combo}")
|
||||||
|
|
||||||
|
|
||||||
class LogicExtension(ComfyExtension):
|
class LogicExtension(ComfyExtension):
|
||||||
|
|||||||
28
execution.py
28
execution.py
@ -76,7 +76,7 @@ class IsChangedCache:
|
|||||||
return self.is_changed[node_id]
|
return self.is_changed[node_id]
|
||||||
|
|
||||||
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
||||||
input_data_all, _, hidden_inputs = get_input_data(node["inputs"], class_def, node_id, None)
|
input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None)
|
||||||
try:
|
try:
|
||||||
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name)
|
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name)
|
||||||
is_changed = await resolve_map_node_over_list_results(is_changed)
|
is_changed = await resolve_map_node_over_list_results(is_changed)
|
||||||
@ -146,8 +146,9 @@ SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
|
|||||||
|
|
||||||
def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}):
|
def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}):
|
||||||
is_v3 = issubclass(class_def, _ComfyNodeInternal)
|
is_v3 = issubclass(class_def, _ComfyNodeInternal)
|
||||||
|
v3_data: io.V3Data = {}
|
||||||
if is_v3:
|
if is_v3:
|
||||||
valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True)
|
valid_inputs, schema, v3_data = class_def.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs)
|
||||||
else:
|
else:
|
||||||
valid_inputs = class_def.INPUT_TYPES()
|
valid_inputs = class_def.INPUT_TYPES()
|
||||||
input_data_all = {}
|
input_data_all = {}
|
||||||
@ -207,7 +208,8 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
|
|||||||
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
|
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
|
||||||
if h[x] == "API_KEY_COMFY_ORG":
|
if h[x] == "API_KEY_COMFY_ORG":
|
||||||
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
|
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
|
||||||
return input_data_all, missing_keys, hidden_inputs_v3
|
v3_data["hidden_inputs"] = hidden_inputs_v3
|
||||||
|
return input_data_all, missing_keys, v3_data
|
||||||
|
|
||||||
map_node_over_list = None #Don't hook this please
|
map_node_over_list = None #Don't hook this please
|
||||||
|
|
||||||
@ -223,7 +225,7 @@ async def resolve_map_node_over_list_results(results):
|
|||||||
raise exc
|
raise exc
|
||||||
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
|
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
|
||||||
|
|
||||||
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
|
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
|
||||||
# check if node wants the lists
|
# check if node wants the lists
|
||||||
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
|
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
|
||||||
|
|
||||||
@ -259,12 +261,12 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
|
|||||||
if is_class(obj):
|
if is_class(obj):
|
||||||
type_obj = obj
|
type_obj = obj
|
||||||
obj.VALIDATE_CLASS()
|
obj.VALIDATE_CLASS()
|
||||||
class_clone = obj.PREPARE_CLASS_CLONE(hidden_inputs)
|
class_clone = obj.PREPARE_CLASS_CLONE(v3_data)
|
||||||
# otherwise, use class instance to populate/reuse some fields
|
# otherwise, use class instance to populate/reuse some fields
|
||||||
else:
|
else:
|
||||||
type_obj = type(obj)
|
type_obj = type(obj)
|
||||||
type_obj.VALIDATE_CLASS()
|
type_obj.VALIDATE_CLASS()
|
||||||
class_clone = type_obj.PREPARE_CLASS_CLONE(hidden_inputs)
|
class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data)
|
||||||
f = make_locked_method_func(type_obj, func, class_clone)
|
f = make_locked_method_func(type_obj, func, class_clone)
|
||||||
# V1
|
# V1
|
||||||
else:
|
else:
|
||||||
@ -320,8 +322,8 @@ def merge_result_data(results, obj):
|
|||||||
output.append([o[i] for o in results])
|
output.append([o[i] for o in results])
|
||||||
return output
|
return output
|
||||||
|
|
||||||
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
|
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
|
||||||
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
|
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
|
||||||
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
|
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
|
||||||
if has_pending_task:
|
if has_pending_task:
|
||||||
return return_values, {}, False, has_pending_task
|
return return_values, {}, False, has_pending_task
|
||||||
@ -460,7 +462,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
has_subgraph = False
|
has_subgraph = False
|
||||||
else:
|
else:
|
||||||
get_progress_state().start_progress(unique_id)
|
get_progress_state().start_progress(unique_id)
|
||||||
input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
|
input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
|
||||||
if server.client_id is not None:
|
if server.client_id is not None:
|
||||||
server.last_node_id = display_node_id
|
server.last_node_id = display_node_id
|
||||||
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
||||||
@ -475,7 +477,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
else:
|
else:
|
||||||
lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
|
lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
|
||||||
if lazy_status_present:
|
if lazy_status_present:
|
||||||
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs)
|
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data)
|
||||||
required_inputs = await resolve_map_node_over_list_results(required_inputs)
|
required_inputs = await resolve_map_node_over_list_results(required_inputs)
|
||||||
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
|
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
|
||||||
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
|
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
|
||||||
@ -507,7 +509,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
def pre_execute_cb(call_index):
|
def pre_execute_cb(call_index):
|
||||||
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
|
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
|
||||||
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
|
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
|
||||||
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
|
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
|
||||||
if has_pending_tasks:
|
if has_pending_tasks:
|
||||||
pending_async_nodes[unique_id] = output_data
|
pending_async_nodes[unique_id] = output_data
|
||||||
unblock = execution_list.add_external_block(unique_id)
|
unblock = execution_list.add_external_block(unique_id)
|
||||||
@ -935,7 +937,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if len(validate_function_inputs) > 0 or validate_has_kwargs:
|
if len(validate_function_inputs) > 0 or validate_has_kwargs:
|
||||||
input_data_all, _, hidden_inputs = get_input_data(inputs, obj_class, unique_id)
|
input_data_all, _, v3_data = get_input_data(inputs, obj_class, unique_id)
|
||||||
input_filtered = {}
|
input_filtered = {}
|
||||||
for x in input_data_all:
|
for x in input_data_all:
|
||||||
if x in validate_function_inputs or validate_has_kwargs:
|
if x in validate_function_inputs or validate_has_kwargs:
|
||||||
@ -943,7 +945,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
|||||||
if 'input_types' in validate_function_inputs:
|
if 'input_types' in validate_function_inputs:
|
||||||
input_filtered['input_types'] = [received_types]
|
input_filtered['input_types'] = [received_types]
|
||||||
|
|
||||||
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, hidden_inputs=hidden_inputs)
|
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, v3_data=v3_data)
|
||||||
ret = await resolve_map_node_over_list_results(ret)
|
ret = await resolve_map_node_over_list_results(ret)
|
||||||
for x in input_filtered:
|
for x in input_filtered:
|
||||||
for i, r in enumerate(ret):
|
for i, r in enumerate(ret):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user