diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 37614a4c3..2ef9f32c0 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1253,6 +1253,140 @@ class DynamicSlot(ComfyTypeI): out_dict[input_type][finalized_id] = value out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1]) +@comfytype(io_type="COMFY_LIST_V3") +class List(ComfyTypeI): + """A repeatable group of widget inputs (e.g. lora_name + strength stacked into N rows). + + At execution time the node receives a ``list[dict]`` where each element is a row. + + Example:: + + io.List.Input( + "loras", + template=[ + io.Combo.Input("lora_name", options=folder_paths.get_filename_list("loras")), + io.Float.Input("strength", default=1.0, min=-100, max=100, step=0.01), + ], + min=0, + max=50, + ) + # execute receives: loras: list[dict] = [{"lora_name": "x.safetensors", "strength": 1.0}, ...] + """ + + Type = list[dict[str, Any]] + _MaxRows = 100 + + class Input(DynamicInput): + def __init__( + self, + id: str, + template: list["Input"], + min: int = 0, + max: int = 50, + display_name: str = None, + optional: bool = False, + tooltip: str = None, + lazy: bool = None, + extra_dict=None, + ): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + # Validate template entries: only WidgetInput subclasses, no nesting + assert len(template) > 0, "List template must have at least one field." + for t in template: + assert isinstance(t, WidgetInput), ( + f"List template field '{t.id}' must be a WidgetInput subclass " + f"(Combo, Float, Int, String, Boolean, Color). Got {type(t).__name__}." + ) + assert not isinstance(t, DynamicInput), ( + f"List template field '{t.id}' must not be a DynamicInput. " + "Nesting dynamic inputs inside List is not supported." + ) + # Enforce unique field ids within template + field_ids = [t.id for t in template] + assert len(field_ids) == len(set(field_ids)), ( + f"List template field ids must be unique within a row. Got: {field_ids}" + ) + assert min >= 0, "List min must be >= 0." + assert max >= 1, "List max must be >= 1." + assert max <= List._MaxRows, f"List max must be <= {List._MaxRows}." + assert min <= max, "List min must be <= max." + self.template = template + self.min = min + self.max = max + + def get_all(self) -> list["Input"]: + return [self] + list(self.template) + + def as_dict(self): + return super().as_dict() | prune_dict({ + "template": create_input_dict_v1(self.template), + "min": self.min, + "max": self.max, + }) + + def validate(self): + for t in self.template: + t.validate() + + @staticmethod + def _expand_schema_for_dynamic( + out_dict: dict[str, Any], + live_inputs: dict[str, Any], + value: tuple[str, dict[str, Any]], + input_type: str, + curr_prefix: list[str] | None, + ): + info = value[1] + min_rows: int = info.get("min", 0) + template: dict[str, Any] = info.get("template", {}) + + # Collect all template field specs across required/optional sections + field_specs: list[tuple[str, tuple[str, dict[str, Any]], bool]] = [] + for field_required_key in ("required", "optional"): + section = template.get(field_required_key, {}) + is_required_field = field_required_key == "required" + for field_id, field_value in section.items(): + field_specs.append((field_id, field_value, is_required_field)) + + # Determine how many rows are currently present by scanning live_inputs + finalized_prefix = finalize_prefix(curr_prefix) + present_rows = 0 + for live_key in live_inputs: + # Keys look like ".." + if live_key.startswith(finalized_prefix + "."): + remainder = live_key[len(finalized_prefix) + 1:] + parts = remainder.split(".", 1) + if len(parts) >= 1: + try: + row_idx = int(parts[0]) + present_rows = max(present_rows, row_idx + 1) + except ValueError: + pass + + row_count = max(min_rows, present_rows) + + for row in range(row_count): + for field_id, field_value, is_required_field in field_specs: + slot_id = f"{finalized_prefix}.{row}.{field_id}" + # The first `min_rows` rows are required if the field itself is required + if row < min_rows and is_required_field: + out_dict["required"][slot_id] = field_value + else: + out_dict["optional"][slot_id] = field_value + # Register into dynamic_paths so build_nested_inputs places value at the right path + out_dict["dynamic_paths"][slot_id] = slot_id + + # Track the list root path so build_nested_inputs can convert the index dict to a list + out_dict.setdefault("list_paths", set()).add(finalized_prefix) + + # Handle the empty case (0 rows) – emit an empty-list default for the parent. + # This must only fire when there are genuinely no rows; otherwise the parent + # path would clobber the per-row dict built from the slot ids above. + if row_count == 0: + out_dict["dynamic_paths"][finalized_prefix] = finalized_prefix + out_dict["dynamic_paths_default_value"][finalized_prefix] = DynamicPathsDefaultValue.EMPTY_LIST + + @comfytype(io_type="IMAGECOMPARE") class ImageCompare(ComfyTypeI): Type = dict @@ -1383,6 +1517,8 @@ def setup_dynamic_input_funcs(): register_dynamic_input_func(DynamicCombo.io_type, DynamicCombo._expand_schema_for_dynamic) # DynamicSlot.Input register_dynamic_input_func(DynamicSlot.io_type, DynamicSlot._expand_schema_for_dynamic) + # List.Input + register_dynamic_input_func(List.io_type, List._expand_schema_for_dynamic) if len(DYNAMIC_INPUT_LOOKUP) == 0: setup_dynamic_input_funcs() @@ -1394,6 +1530,8 @@ class V3Data(TypedDict): 'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.' dynamic_paths_default_value: dict[str, Any] 'Dictionary where the keys are the input ids and the values are a string from DynamicPathsDefaultValue for the inputs if value is None.' + list_paths: set[str] + 'Set of top-level keys whose index-keyed dict values should be converted to a sorted list[dict] after build_nested_inputs runs.' create_dynamic_tuple: bool 'When True, the value of the dynamic input will be in the format (value, path_key).' @@ -1727,6 +1865,7 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i "optional": {}, "dynamic_paths": {}, "dynamic_paths_default_value": {}, + "list_paths": set(), } d = d.copy() # ignore hidden for parsing @@ -1742,6 +1881,10 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i dynamic_paths_default_value = out_dict.pop("dynamic_paths_default_value", None) if dynamic_paths_default_value is not None and len(dynamic_paths_default_value) > 0: v3_data["dynamic_paths_default_value"] = dynamic_paths_default_value + # list_paths: keys whose nested dict should be post-converted to a sorted list[dict] + list_paths = out_dict.pop("list_paths", None) + if list_paths: + v3_data["list_paths"] = list_paths return out_dict, hidden, v3_data def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None: @@ -1777,10 +1920,12 @@ def add_to_dict_v1(i: Input, d: dict): class DynamicPathsDefaultValue: EMPTY_DICT = "empty_dict" + EMPTY_LIST = "empty_list" def build_nested_inputs(values: dict[str, Any], v3_data: V3Data): paths = v3_data.get("dynamic_paths", None) default_value_dict = v3_data.get("dynamic_paths_default_value", {}) + list_paths: set[str] = v3_data.get("list_paths", set()) or set() if paths is None: return values values = values.copy() @@ -1803,6 +1948,8 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data): default_option = default_value_dict.get(key, None) if default_option == DynamicPathsDefaultValue.EMPTY_DICT: value = {} + elif default_option == DynamicPathsDefaultValue.EMPTY_LIST: + value = [] if create_tuple: value = (value, key) current[p] = value @@ -1810,6 +1957,34 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data): current = current.setdefault(p, {}) values.update(result) + + # Post-pass: convert index-keyed dicts to sorted lists for io.List fields + for list_path in list_paths: + parts = list_path.split(".") + # Navigate to the parent container, then convert the leaf + container = values + for part in parts[:-1]: + if not isinstance(container, dict) or part not in container: + container = None + break + container = container[part] + if container is None: + continue + leaf_key = parts[-1] + leaf = container.get(leaf_key, None) + if isinstance(leaf, dict): + try: + sorted_rows = [leaf[k] for k in sorted(leaf.keys(), key=int)] + container[leaf_key] = sorted_rows + except (ValueError, TypeError): + # Keys are not all integers; leave as-is + pass + elif isinstance(leaf, list): + # Already a list (e.g. the EMPTY_LIST default was applied above) + pass + elif leaf is None: + container[leaf_key] = [] + return values @@ -2372,7 +2547,9 @@ __all__ = [ # Dynamic Types "MatchType", "DynamicCombo", + "DynamicSlot", "Autogrow", + "List", # Other classes "HiddenHolder", "Hidden", diff --git a/tests-unit/comfy_api_test/io_list_test.py b/tests-unit/comfy_api_test/io_list_test.py new file mode 100644 index 000000000..fe0a7c6d4 --- /dev/null +++ b/tests-unit/comfy_api_test/io_list_test.py @@ -0,0 +1,204 @@ +"""Unit tests for io.List: expansion/reconstruction (0-row and N-row cases).""" +import sys +import types +import pytest + +# Stub torch (type-hint only in _io.py; real torch not available in unit-test env) +if "torch" not in sys.modules: + _torch_stub = types.ModuleType("torch") + _torch_stub.Tensor = object # type: ignore[attr-defined] + sys.modules["torch"] = _torch_stub + +from comfy_api.latest._io import ( # noqa: E402 + List, + Float, + Int, + String, + Boolean, + get_finalized_class_inputs, + build_nested_inputs, + create_input_dict_v1, + setup_dynamic_input_funcs, +) + +# Make sure dynamic input funcs are registered (may already be done at import time) +setup_dynamic_input_funcs() + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_class_inputs(list_input: List.Input) -> dict: + """Wrap a List.Input into the required/optional dict structure.""" + return create_input_dict_v1([list_input]) + + +def _run(list_input: List.Input, live_values: dict) -> dict: + """End-to-end helper: expand schema + reconstruct values. + + Mirrors the production split in execution.py: + 1. get_finalized_class_inputs (schema expansion, line 162) + 2. build_nested_inputs (value reconstruction, line 281) + + The two steps are separate in production because the engine resolves + linked node outputs between them, but in tests we supply values directly. + """ + class_inputs = _make_class_inputs(list_input) + _, _, v3_data = get_finalized_class_inputs(class_inputs, live_values) + return build_nested_inputs(dict(live_values), v3_data) + + +# --------------------------------------------------------------------------- +# Schema construction +# --------------------------------------------------------------------------- + +class TestListInputConstruction: + def test_basic_construction(self): + inp = List.Input( + "loras", + template=[ + Float.Input("strength", default=1.0), + String.Input("name"), + ], + min=0, + max=10, + ) + assert inp.id == "loras" + assert inp.min == 0 + assert inp.max == 10 + assert len(inp.template) == 2 + + def test_get_all_includes_self_and_template(self): + inp = List.Input( + "items", + template=[Float.Input("value")], + ) + all_inputs = inp.get_all() + assert all_inputs[0] is inp + assert all_inputs[1].id == "value" + + def test_as_dict_has_template_min_max(self): + inp = List.Input( + "items", + template=[Float.Input("val", default=0.5)], + min=1, + max=5, + ) + d = inp.as_dict() + assert "template" in d + assert d["min"] == 1 + assert d["max"] == 5 + + def test_duplicate_field_ids_raises(self): + with pytest.raises(AssertionError): + List.Input( + "bad", + template=[Float.Input("x"), Float.Input("x")], + ) + + def test_empty_template_raises(self): + with pytest.raises(AssertionError): + List.Input("bad", template=[]) + + def test_min_gt_max_raises(self): + with pytest.raises(AssertionError): + List.Input("bad", template=[Float.Input("x")], min=5, max=3) + + def test_max_exceeds_limit_raises(self): + with pytest.raises(AssertionError): + List.Input("bad", template=[Float.Input("x")], max=101) + + def test_dynamic_input_in_template_raises(self): + with pytest.raises(AssertionError): + List.Input( + "bad", + template=[List.Input("nested", template=[Float.Input("x")])], + ) + + def test_validate_calls_through(self): + inp = List.Input("items", template=[Float.Input("val", min=-1.0, max=1.0)]) + inp.validate() # should not raise + + +# --------------------------------------------------------------------------- +# 0-row case +# --------------------------------------------------------------------------- + +class TestZeroRows: + def test_empty_live_inputs_produces_empty_list(self): + """With min=0 and no live values, the result should be an empty list.""" + inp = List.Input("loras", template=[Float.Input("strength", default=1.0)], min=0, max=10) + assert _run(inp, {}).get("loras") == [] + + def test_min_zero_with_values(self): + """min=0 but 2 rows of live data.""" + inp = List.Input("loras", template=[Float.Input("strength", default=1.0)], min=0, max=10) + result = _run(inp, {"loras.0.strength": 0.8, "loras.1.strength": 0.5}) + assert result["loras"] == [{"strength": 0.8}, {"strength": 0.5}] + + +# --------------------------------------------------------------------------- +# N-row case +# --------------------------------------------------------------------------- + +class TestNRows: + def test_two_rows_two_fields(self): + """Two rows with two fields each produce a list[dict].""" + inp = List.Input( + "loras", + template=[String.Input("lora_name"), Float.Input("strength", default=1.0)], + min=0, max=50, + ) + result = _run(inp, { + "loras.0.lora_name": "model_a.safetensors", "loras.0.strength": 0.9, + "loras.1.lora_name": "model_b.safetensors", "loras.1.strength": 0.4, + }) + assert result["loras"] == [ + {"lora_name": "model_a.safetensors", "strength": 0.9}, + {"lora_name": "model_b.safetensors", "strength": 0.4}, + ] + + def test_rows_are_sorted_by_index(self): + """Rows must be in ascending index order even if dict iteration is unordered.""" + inp = List.Input("items", template=[Int.Input("v", default=0)], min=0, max=10) + result = _run(inp, {"items.0.v": 10, "items.2.v": 30, "items.1.v": 20}) + assert [row["v"] for row in result["items"]] == [10, 20, 30] + + def test_min_rows_schema_slots(self): + """With min=2 and no live data, 2 slots must appear in the expanded schema.""" + inp = List.Input("items", template=[Float.Input("val", default=0.0)], min=2, max=5) + out, _, _ = get_finalized_class_inputs(_make_class_inputs(inp), {}) + all_slots = {**out.get("required", {}), **out.get("optional", {})} + assert "items.0.val" in all_slots + assert "items.1.val" in all_slots + + def test_min_rows_reconstructs_when_no_values(self): + """min=2 with NO live values must still yield a 2-element list, + not collapse to [] (regression: parent-path clobber).""" + inp = List.Input("items", template=[Float.Input("val", default=0.0)], min=2, max=5) + result = _run(inp, {}) + assert len(result["items"]) == 2 + assert all("val" in row for row in result["items"]) + + def test_min_rows_reconstructs_with_partial_values(self): + """min=2 with only the first row's value present still yields 2 rows.""" + inp = List.Input("items", template=[Float.Input("val", default=0.0)], min=2, max=5) + result = _run(inp, {"items.0.val": 0.7}) + assert len(result["items"]) == 2 + assert result["items"][0]["val"] == 0.7 + assert result["items"][1]["val"] is None + + def test_list_paths_in_v3_data(self): + """list_paths must contain the list id so build_nested_inputs knows to convert.""" + inp = List.Input("things", template=[Boolean.Input("flag")], min=0, max=5) + _, _, v3_data = get_finalized_class_inputs(_make_class_inputs(inp), {}) + assert "things" in v3_data.get("list_paths", set()) + + def test_no_leftover_flat_keys(self): + """Flat keys must be consumed; only the reconstructed list remains.""" + inp = List.Input("rows", template=[Float.Input("x", default=0.0)], min=0, max=5) + result = _run(inp, {"rows.0.x": 1.0, "rows.1.x": 2.0}) + assert "rows.0.x" not in result + assert "rows.1.x" not in result + assert isinstance(result["rows"], list)