diff --git a/comfy_extras/nodes_rebatch.py b/comfy_extras/nodes_rebatch.py index 0a9daf272..141801691 100644 --- a/comfy_extras/nodes_rebatch.py +++ b/comfy_extras/nodes_rebatch.py @@ -7,7 +7,7 @@ class LatentRebatch: "batch_size": ("INT", {"default": 1, "min": 1, "max": 64}), }} RETURN_TYPES = ("LATENT",) - INPUT_IS_LIST = True + INPUTS_ARE_LISTS = True OUTPUT_IS_LIST = (True, ) FUNCTION = "rebatch" @@ -105,4 +105,4 @@ NODE_CLASS_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = { "RebatchLatents": "Rebatch Latents", -} \ No newline at end of file +} diff --git a/execution.py b/execution.py index 0e2cc15c1..a41f4eea6 100644 --- a/execution.py +++ b/execution.py @@ -7,26 +7,89 @@ import heapq import traceback import gc import time +import itertools import torch import nodes import comfy.model_management +def get_input_data_batches(input_data_all): + """Given input data that can contain combinatorial input values, returns all + the possible batches that can be made by combining the different input + values together.""" + + input_to_index = {} + index_to_values = [] + + i = 0 + for input_name, value in input_data_all.items(): + if isinstance(value, dict) and "combinatorial" in value: + input_to_index[input_name] = i + index_to_values.append(value["values"]) + i += 1 + + if len(index_to_values) == 0: + # No combinatorial options. + return [input_data_all] + + batches = [] + + print("GET" + str(input_data_all)) + print("ALL " + str(index_to_values)) + print("INPS " + str(input_to_index)) + + for combination in list(itertools.product(*index_to_values)): + print("COMBO " + str(combination)) + batch = {} + for input_name, value in input_data_all.items(): + if isinstance(value, dict) and "combinatorial" in value: + combination_index = input_to_index[input_name] + batch[input_name] = [combination[combination_index]] # + else: + # already made into a list by get_input_data + batch[input_name] = value + batches.append(batch) + + return batches + + def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}): + """Given input data from the prompt, returns a list of input data dicts for + each combinatorial batch.""" valid_inputs = class_def.INPUT_TYPES() input_data_all = {} for x in inputs: input_data = inputs[x] + required_or_optional = ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]) + if isinstance(input_data, list): input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: return None - obj = outputs[input_unique_id][output_index] - input_data_all[x] = obj + + # This is a list of outputs for each batch of combinatorial inputs. + # Without any combinatorial inputs, it's a list of length 1. + outputs_for_all_batches = outputs[input_unique_id] + + def flatten(list_of_lists): + return list(itertools.chain.from_iterable(list_of_lists)) + + if len(outputs_for_all_batches) == 1: + # Single batch, no combinatorial stuff + input_data_all[x] = outputs_for_all_batches[0][output_index] + else: + # Make the outputs into a list for map-over-list use + # (they are themselves lists so flatten them afterwards) + input_values = [batch_output[output_index] for batch_output in outputs_for_all_batches] + input_values = flatten(input_values) + input_data_all[x] = input_values + elif is_combinatorial_input(input_data): + if required_or_optional: + input_data_all[x] = { "combinatorial": True, "values": input_data["values"] } else: - if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]): + if required_or_optional: input_data_all[x] = [input_data] if "hidden" in valid_inputs: @@ -39,25 +102,51 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da input_data_all[x] = [extra_data['extra_pnginfo']] if h[x] == "UNIQUE_ID": input_data_all[x] = [unique_id] - return input_data_all + print("=== GetInputData: " + str(inputs)) + + input_data_all_batches = get_input_data_batches(input_data_all) + + return input_data_all_batches + +def slice_lists_into_dict(d, i): + """ + get a slice of inputs, repeat last input when list isn't long enough + d={ "seed": [ 1, 2, 3 ], "steps": [ 4, 8 ] }, i=2 -> { "seed": 3, "steps": 8 } + """ + d_new = {} + for k, v in d.items(): + d_new[k] = v[i if len(v) > i else -1] + return d_new def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): # check if node wants the lists - intput_is_list = False - if hasattr(obj, "INPUT_IS_LIST"): - intput_is_list = obj.INPUT_IS_LIST + inputs_are_lists = False + if hasattr(obj, "INPUTS_ARE_LISTS"): + inputs_are_lists = obj.INPUTS_ARE_LISTS - max_len_input = max([len(x) for x in input_data_all.values()]) - - # get a slice of inputs, repeat last input when list isn't long enough - def slice_dict(d, i): - d_new = dict() + def format_dict(d): + s = [] for k,v in d.items(): - d_new[k] = v[i if len(v) > i else -1] - return d_new - + st = f"{k}: " + if isinstance(v, list): + st += f"list[len: {len(v)}][" + i = [] + for v2 in v: + i.append(v2.__class__.__name__) + st += ",".join(i) + "]" + else: + st += str(type(v)) + s.append(st) + return "( " + ", ".join(s) + " )" + + print("+++ Obj: " + str(obj)) + print("+++ Inputs: " + format_dict(input_data_all)) + max_len_input = max(len(x) for x in input_data_all.values()) + print("MaxLen " + str(max_len_input)) + print("0 " + str(slice_lists_into_dict(input_data_all, 0))) + results = [] - if intput_is_list: + if inputs_are_lists: if allow_interrupt: nodes.before_node_execution() results.append(getattr(obj, func)(**input_data_all)) @@ -65,42 +154,65 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): for i in range(max_len_input): if allow_interrupt: nodes.before_node_execution() - results.append(getattr(obj, func)(**slice_dict(input_data_all, i))) + results.append(getattr(obj, func)(**slice_lists_into_dict(input_data_all, i))) return results -def get_output_data(obj, input_data_all): - - results = [] - uis = [] - return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True) +def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id): + all_outputs = [] + all_outputs_ui = [] + total_batches = len(input_data_all_batches) + print("TOTAL: " + str(total_batches)) - for r in return_values: - if isinstance(r, dict): - if 'ui' in r: - uis.append(r['ui']) - if 'result' in r: - results.append(r['result']) - else: - results.append(r) - - output = [] - if len(results) > 0: - # check which outputs need concatenating - output_is_list = [False] * len(results[0]) - if hasattr(obj, "OUTPUT_IS_LIST"): - output_is_list = obj.OUTPUT_IS_LIST + for i, batch in enumerate(input_data_all_batches): + print("***** BATCH: " + str(i)) + return_values = map_node_over_list(obj, batch, obj.FUNCTION, allow_interrupt=True) - # merge node execution results - for i, is_list in zip(range(len(results[0])), output_is_list): - if is_list: - output.append([x for o in results for x in o[i]]) + uis = [] + results = [] + + for r in return_values: + if isinstance(r, dict): + if 'ui' in r: + uis.append(r['ui']) + if 'result' in r: + results.append(r['result']) else: - output.append([o[i] for o in results]) + results.append(r) - ui = dict() - if len(uis) > 0: - ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} - return output, ui + output = [] + if len(results) > 0: + # check which outputs need concatenating + output_is_list = [False] * len(results[0]) + if hasattr(obj, "OUTPUT_IS_LIST"): + output_is_list = obj.OUTPUT_IS_LIST + + # merge node execution results + for i, is_list in zip(range(len(results[0])), output_is_list): + if is_list: + output.append([x for o in results for x in o[i]]) + else: + output.append([o[i] for o in results]) + + output_ui = dict() + if len(uis) > 0: + output_ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} + + # update the UI after each batch finishes + if len(output_ui) > 0: + if server.client_id is not None: + message = { + "node": unique_id, + "output": output_ui, + "prompt_id": prompt_id, + "batch": i, + "total_batches": total_batches + } + server.send_sync("executed", message, server.client_id) + + all_outputs.append(output) + all_outputs_ui.append(output_ui) + + return all_outputs, all_outputs_ui def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui): unique_id = current_item @@ -119,18 +231,15 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute if input_unique_id not in outputs: recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui) - input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) + input_data_all_batches = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) if server.client_id is not None: server.last_node_id = unique_id server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id) obj = class_def() - output_data, output_ui = get_output_data(obj, input_data_all) - outputs[unique_id] = output_data - if len(output_ui) > 0: - outputs_ui[unique_id] = output_ui - if server.client_id is not None: - server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + output_data_from_batches, output_ui_from_batches = get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id) + outputs[unique_id] = output_data_from_batches + outputs_ui[unique_id] = output_ui_from_batches executed.add(unique_id) def recursive_will_execute(prompt, outputs, current_item): @@ -163,11 +272,14 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]: is_changed_old = old_prompt[unique_id]['is_changed'] if 'is_changed' not in prompt[unique_id]: - input_data_all = get_input_data(inputs, class_def, unique_id, outputs) - if input_data_all is not None: + input_data_all_batches = get_input_data(inputs, class_def, unique_id, outputs) + if input_data_all_batches is not None: try: #is_changed = class_def.IS_CHANGED(**input_data_all) - is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED") + for batch in input_data_all_batches: + if map_node_over_list(class_def, batch, "IS_CHANGED"): + is_changed = True + break prompt[unique_id]['is_changed'] = is_changed except: to_delete = True @@ -286,6 +398,45 @@ class PromptExecutor: comfy.model_management.soft_empty_cache() +def is_combinatorial_input(val): + return isinstance(val, dict) and "__inputType__" in val + + +def get_raw_inputs(raw_val): + if isinstance(raw_val, list): + # link to another node + return [raw_val] + elif is_combinatorial_input(raw_val): + return raw_val["values"] + return [raw_val] + + +def clamp_input(val, info, class_type, obj_class, x): + if is_combinatorial_input(val): + for i, val_choice in enumerate(val["values"]): + r = clamp_input(val_choice, info, class_type, obj_class, x) + if r[0] == False: + return r + val["values"][i] = r[1] + return (True, val) + + type_input = info[0] + + if type_input == "INT": + val = int(val) + if type_input == "FLOAT": + val = float(val) + if type_input == "STRING": + val = str(val) + + if len(info) > 1: + if "min" in info[1] and val < info[1]["min"]: + return (False, "Value smaller than min. {}, {}".format(class_type, x)) + if "max" in info[1] and val > info[1]["max"]: + return (False, "Value bigger than max. {}, {}".format(class_type, x)) + + return (True, val) + def validate_inputs(prompt, item, validated): unique_id = item if unique_id in validated: @@ -300,9 +451,12 @@ def validate_inputs(prompt, item, validated): for x in required_inputs: if x not in inputs: return (False, "Required input is missing. {}, {}".format(class_type, x)) + val = inputs[x] + info = required_inputs[x] type_input = info[0] + if isinstance(val, list): if len(val) != 2: return (False, "Bad Input. {}, {}".format(class_type, x)) @@ -316,33 +470,27 @@ def validate_inputs(prompt, item, validated): validated[o_id] = r return r else: - if type_input == "INT": - val = int(val) - inputs[x] = val - if type_input == "FLOAT": - val = float(val) - inputs[x] = val - if type_input == "STRING": - val = str(val) - inputs[x] = val + r = clamp_input(val, info, class_type, obj_class, x) + if r[0] == False: + return r - if len(info) > 1: - if "min" in info[1] and val < info[1]["min"]: - return (False, "Value smaller than min. {}, {}".format(class_type, x)) - if "max" in info[1] and val > info[1]["max"]: - return (False, "Value bigger than max. {}, {}".format(class_type, x)) + inputs[x] = r[1] if hasattr(obj_class, "VALIDATE_INPUTS"): - input_data_all = get_input_data(inputs, obj_class, unique_id) + input_data_all_batches = get_input_data(inputs, obj_class, unique_id) #ret = obj_class.VALIDATE_INPUTS(**input_data_all) - ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") - for r in ret: - if r != True: - return (False, "{}, {}".format(class_type, r)) + for batch in input_data_all_batches: + ret = map_node_over_list(obj_class, batch, "VALIDATE_INPUTS") + for r in ret: + if r != True: + return (False, "{}, {}".format(class_type, r)) else: if isinstance(type_input, list): - if val not in type_input: - return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input)) + # Account for more than one combinatorial value + raw_vals = get_raw_inputs(val) + for raw_val in raw_vals: + if raw_val not in type_input: + return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, raw_val, type_input)) ret = (True, "") validated[unique_id] = ret diff --git a/folder_paths.py b/folder_paths.py index e5b89492c..f0fa06292 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -139,12 +139,21 @@ def get_full_path(folder_name, filename): return full_path +path_cache_dict = {} + +def clear_cache(): + global path_cache_dict + path_cache_dict = {} + + def get_filename_list(folder_name): - global folder_names_and_paths - output_list = set() - folders = folder_names_and_paths[folder_name] - for x in folders[0]: - output_list.update(filter_files_extensions(recursive_search(x), folders[1])) - return sorted(list(output_list)) - + global folder_names_and_paths, path_cache_dict + print("RecursiveWalk! " + folder_name) + if folder_name not in path_cache_dict: + output_list = set() + folders = folder_names_and_paths[folder_name] + for x in folders[0]: + output_list.update(filter_files_extensions(recursive_search(x), folders[1])) + path_cache_dict[folder_name] = sorted(list(output_list)) + return path_cache_dict[folder_name] diff --git a/server.py b/server.py index f52117f10..46508c70c 100644 --- a/server.py +++ b/server.py @@ -264,6 +264,7 @@ class PromptServer(): @routes.get("/object_info") async def get_object_info(request): out = {} + folder_paths.clear_cache() for x in nodes.NODE_CLASS_MAPPINGS: obj_class = nodes.NODE_CLASS_MAPPINGS[x] info = {} diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index df7d8f071..051e71420 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -195,6 +195,8 @@ app.registerExtension({ this.addOutput("connect to widget input", "*"); this.serialize_widgets = true; this.isVirtualNode = true; + this.properties ||= {} + this.properties.isRange = false; } applyToGraph() { @@ -210,6 +212,13 @@ app.registerExtension({ const widget = node.widgets.find((w) => w.name === widgetName); if (widget) { widget.value = this.widgets[0].value; + if (this.properties.isRange) { + console.error("RANGE") + widget.__rangeData = { __inputType__: "list", values: [widget.value, widget.value + 256] } + } + else { + widget.__rangeData = undefined + } if (widget.callback) { widget.callback(widget.value, app.canvas, node, app.canvas.graph_mouse, {}); } @@ -304,6 +313,8 @@ app.registerExtension({ addValueControlWidget(this, widget, "fixed"); } + const isRangeWidget = this.addWidget("toggle", "isRange", this.properties.isRange, "isRange"); + // When our value changes, update other widgets to reflect our changes // e.g. so LoadImage shows correct image const callback = widget.callback; diff --git a/web/scripts/app.js b/web/scripts/app.js index 87c5e30ca..cbf3fb0ad 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1195,7 +1195,14 @@ export class ComfyApp { for (const i in widgets) { const widget = widgets[i]; if (!widget.options || widget.options.serialize !== false) { - inputs[widget.name] = widget.serializeValue ? await widget.serializeValue(n, i) : widget.value; + let widgetValue = widget.serializeValue ? await widget.serializeValue(n, i) : widget.value; + + if (widget.__rangeData) { + console.error("SETRANGE", widget.name, widget.__rangeData) + widgetValue = widget.__rangeData; + } + + inputs[widget.name] = widgetValue } } }