Allow specifying which inputs should be used as lists

This commit is contained in:
space-nuko 2023-05-29 14:56:40 -05:00
parent d934119333
commit e4be8e0666
6 changed files with 296 additions and 225 deletions

View File

@ -3,11 +3,10 @@ import torch
class LatentRebatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "latents": ("LATENT",),
return {"required": { "latents": ("LATENT", { "is_list": True }),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 64}),
}}
RETURN_TYPES = ("LATENT",)
INPUT_IS_LIST = True
RETURN_TYPES = ("LATENT", )
OUTPUT_IS_LIST = (True, )
FUNCTION = "rebatch"
@ -54,8 +53,6 @@ class LatentRebatch:
return result
def rebatch(self, latents, batch_size):
batch_size = batch_size[0]
output_list = []
current_batch = (None, None, None)
processed = 0
@ -105,4 +102,4 @@ NODE_CLASS_MAPPINGS = {
NODE_DISPLAY_NAME_MAPPINGS = {
"RebatchLatents": "Rebatch Latents",
}
}

View File

@ -13,21 +13,44 @@ import nodes
import comfy.model_management
def slice_lists_into_dict(d, i):
"""
get a slice of inputs, repeat last input when list isn't long enough
d={ "seed": [ 1, 2, 3 ], "steps": [ 4, 8 ] }, i=2 -> { "seed": 3, "steps": 8 }
"""
d_new = {}
for k, v in d.items():
d_new[k] = v[i if len(v) > i else -1]
return d_new
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
valid_inputs = class_def.INPUT_TYPES()
input_data_all = {}
required = valid_inputs.get("required", {})
optional = valid_inputs.get("optional", {})
for x in inputs:
input_data = inputs[x]
if isinstance(input_data, list):
input_unique_id = input_data[0]
output_index = input_data[1]
input_type = input_data["type"]
input_def = required.get(x)
if input_def is None:
input_def = optional.get(x)
use_value_as_list = input_def is not None and len(input_def) > 1 and input_def[1].get("is_list", False)
if input_type == "link":
input_unique_id = input_data["origin_id"]
output_index = input_data["origin_slot"]
if input_unique_id not in outputs:
return None
obj = outputs[input_unique_id][output_index]
if use_value_as_list:
obj = [obj]
input_data_all[x] = obj
else:
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
input_data_all[x] = [input_data]
value = input_data["value"]
if input_def is not None:
input_data_all[x] = [value]
if "hidden" in valid_inputs:
h = valid_inputs["hidden"]
@ -39,37 +62,23 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
input_data_all[x] = [extra_data['extra_pnginfo']]
if h[x] == "UNIQUE_ID":
input_data_all[x] = [unique_id]
return input_data_all
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
# check if node wants the lists
intput_is_list = False
if hasattr(obj, "INPUT_IS_LIST"):
intput_is_list = obj.INPUT_IS_LIST
max_len_input = max([len(x) for x in input_data_all.values()])
# get a slice of inputs, repeat last input when list isn't long enough
def slice_dict(d, i):
d_new = dict()
for k,v in d.items():
d_new[k] = v[i if len(v) > i else -1]
return d_new
results = []
if intput_is_list:
max_len_input = max([len(x) for x in input_data_all.values()])
for i in range(max_len_input):
if allow_interrupt:
nodes.before_node_execution()
results.append(getattr(obj, func)(**input_data_all))
else:
for i in range(max_len_input):
if allow_interrupt:
nodes.before_node_execution()
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
args = slice_lists_into_dict(input_data_all, i)
results.append(getattr(obj, func)(**args))
return results
def get_output_data(obj, input_data_all):
results = []
uis = []
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
@ -120,10 +129,10 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
for x in inputs:
input_data = inputs[x]
input_type = input_data["type"]
if isinstance(input_data, list):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_type == "link":
input_unique_id = input_data["origin_id"]
if input_unique_id not in outputs:
result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui)
if result[0] is not True:
@ -192,9 +201,9 @@ def recursive_will_execute(prompt, outputs, current_item):
for x in inputs:
input_data = inputs[x]
if isinstance(input_data, list):
input_unique_id = input_data[0]
output_index = input_data[1]
input_type = input_data["type"]
if input_type == "link":
input_unique_id = input_data["origin_id"]
if input_unique_id not in outputs:
will_execute += recursive_will_execute(prompt, outputs, input_unique_id)
@ -235,10 +244,10 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
elif inputs == old_prompt[unique_id]['inputs']:
for x in inputs:
input_data = inputs[x]
input_type = input_data["type"]
if isinstance(input_data, list):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_type == "link":
input_unique_id = input_data["origin_id"]
if input_unique_id in outputs:
to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id)
else:
@ -366,6 +375,150 @@ class PromptExecutor:
comfy.model_management.soft_empty_cache()
def validate_link(prompt, x, val, info, validated):
type_input = info[0]
o_id = val.get("origin_id", None)
o_slot = val.get("origin_slot", None)
if o_id is None or o_slot is None:
error = {
"type": "bad_linked_input",
"message": "Bad linked input, must be a dictionary like { type: 'link', origin_id: 1, origin_slot: 1 }",
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val
}
}
return (False, error)
o_class_type = prompt[o_id]['class_type']
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
if r[o_slot] != type_input:
received_type = r[val[1]]
details = f"{x}, {received_type} != {type_input}"
error = {
"type": "return_type_mismatch",
"message": "Return type mismatch between linked nodes",
"details": details,
"extra_info": {
"input_name": x,
"input_config": info,
"received_type": received_type,
"linked_node": val
}
}
return (False, error)
try:
r = validate_inputs(prompt, o_id, validated)
if r[0] is False:
# `r` will be set in `validated[o_id]` already
return (False, None)
except Exception as ex:
typ, _, tb = sys.exc_info()
exception_type = full_type_name(typ)
reasons = [{
"type": "exception_during_inner_validation",
"message": "Exception when validating inner node",
"details": str(ex),
"extra_info": {
"input_name": x,
"input_config": info,
"exception_message": str(ex),
"exception_type": exception_type,
"traceback": traceback.format_tb(tb),
"linked_node": val,
"linked_node_inputs": prompt[o_id]
}
}]
validated[o_id] = (False, reasons, o_id)
return (False, None)
return (True, val)
def validate_value(inputs, unique_id, x, val, info, obj_class):
type_input = info[0]
result_val = val
try:
if type_input == "INT":
result_val = int(val)
if type_input == "FLOAT":
result_val = float(val)
if type_input == "STRING":
result_val = str(val)
except Exception as ex:
error = {
"type": "invalid_input_type",
"message": f"Failed to convert an input value to a {type_input} value",
"details": f"{x}, {val}, {ex}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
"exception_message": str(ex)
}
}
return (False, error)
if len(info) > 1:
if "min" in info[1] and val < info[1]["min"]:
error = {
"type": "value_smaller_than_min",
"message": "Value {} smaller than min of {}".format(val, info[1]["min"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
return (False, error)
if "max" in info[1] and val > info[1]["max"]:
error = {
"type": "value_bigger_than_max",
"message": "Value {} bigger than max of {}".format(val, info[1]["max"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
return (False, error)
else:
# Validate combo widget
if isinstance(type_input, list):
if val not in type_input:
input_config = info
list_info = ""
# Don't send back gigantic lists like if they're lots of
# scanned model filepaths
if len(type_input) > 20:
list_info = f"(list of length {len(type_input)})"
input_config = None
else:
list_info = str(type_input)
error = {
"type": "value_not_in_list",
"message": "Value not in list",
"details": f"{x}: '{val}' not in {list_info}",
"extra_info": {
"input_name": x,
"input_config": input_config,
"received_value": val,
}
}
return (False, error)
return (True, result_val)
def validate_inputs(prompt, item, validated):
unique_id = item
if unique_id in validated:
@ -396,168 +549,84 @@ def validate_inputs(prompt, item, validated):
val = inputs[x]
info = required_inputs[x]
type_input = info[0]
if isinstance(val, list):
if len(val) != 2:
error = {
"type": "bad_linked_input",
"message": "Bad linked input, must be a length-2 list of [node_id, slot_index]",
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val
}
input_type = None
if isinstance(val, dict):
input_type = val.get("type", None)
if input_type not in ["link", "value"]:
error = {
"type": "bad_input_format",
"message": "Bad input format, must be a dictionary with 'type' set to 'link' or 'value'",
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val
}
errors.append(error)
}
errors.append(error)
continue
if input_type == "link":
result = validate_link(prompt, x, val, info, validated)
if result[0] is False:
valid = False
if result[1] is not None:
errors.append(result[1])
continue
o_id = val[0]
o_class_type = prompt[o_id]['class_type']
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
if r[val[1]] != type_input:
received_type = r[val[1]]
details = f"{x}, {received_type} != {type_input}"
inputs[x] = result[1]
elif input_type == "value":
inner_val = val.get("value", None)
if inner_val is None:
error = {
"type": "return_type_mismatch",
"message": "Return type mismatch between linked nodes",
"details": details,
"extra_info": {
"input_name": x,
"input_config": info,
"received_type": received_type,
"linked_node": val
}
}
errors.append(error)
continue
try:
r = validate_inputs(prompt, o_id, validated)
if r[0] is False:
# `r` will be set in `validated[o_id]` already
valid = False
continue
except Exception as ex:
typ, _, tb = sys.exc_info()
valid = False
exception_type = full_type_name(typ)
reasons = [{
"type": "exception_during_inner_validation",
"message": "Exception when validating inner node",
"details": str(ex),
"extra_info": {
"input_name": x,
"input_config": info,
"exception_message": str(ex),
"exception_type": exception_type,
"traceback": traceback.format_tb(tb),
"linked_node": val
}
}]
validated[o_id] = (False, reasons, o_id)
continue
else:
try:
if type_input == "INT":
val = int(val)
inputs[x] = val
if type_input == "FLOAT":
val = float(val)
inputs[x] = val
if type_input == "STRING":
val = str(val)
inputs[x] = val
except Exception as ex:
error = {
"type": "invalid_input_type",
"message": f"Failed to convert an input value to a {type_input} value",
"details": f"{x}, {val}, {ex}",
"type": "bad_value_input",
"message": "Bad value input, must be a dictionary like { type: 'value', value: 42 }",
"details": f"{x}, {val}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
"exception_message": str(ex)
}
}
return (False, error)
result = validate_value(inputs, unique_id, x, inner_val, info, obj_class)
if result[0] is False:
errors.append(result[1])
continue
inputs[x] = { "type": "value", "value": result[1] }
if hasattr(obj_class, "VALIDATE_INPUTS"):
input_data_all = get_input_data(inputs, obj_class, unique_id)
#ret = obj_class.VALIDATE_INPUTS(**input_data_all)
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
for i, r in enumerate(ret):
if r is not True:
details = ""
if r is not False:
details += str(r)
input_data_formatted = {}
if input_data_all is not None:
input_data_formatted = {}
for name, inputList in input_data_all.items():
input_data_formatted[name] = [format_value(x) for x in inputList]
error = {
"type": "custom_validation_failed",
"message": "Custom validation failed for node",
"details": details,
"extra_info": {
"input_config": info,
"received_inputs": input_data_formatted,
}
}
errors.append(error)
continue
if len(info) > 1:
if "min" in info[1] and val < info[1]["min"]:
error = {
"type": "value_smaller_than_min",
"message": "Value {} smaller than min of {}".format(val, info[1]["min"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
if "max" in info[1] and val > info[1]["max"]:
error = {
"type": "value_bigger_than_max",
"message": "Value {} bigger than max of {}".format(val, info[1]["max"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
if hasattr(obj_class, "VALIDATE_INPUTS"):
input_data_all = get_input_data(inputs, obj_class, unique_id)
#ret = obj_class.VALIDATE_INPUTS(**input_data_all)
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
for i, r in enumerate(ret):
if r is not True:
details = f"{x}"
if r is not False:
details += f" - {str(r)}"
error = {
"type": "custom_validation_failed",
"message": "Custom validation failed for node",
"details": details,
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
else:
if isinstance(type_input, list):
if val not in type_input:
input_config = info
list_info = ""
# Don't send back gigantic lists like if they're lots of
# scanned model filepaths
if len(type_input) > 20:
list_info = f"(list of length {len(type_input)})"
input_config = None
else:
list_info = str(type_input)
error = {
"type": "value_not_in_list",
"message": "Value not in list",
"details": f"{x}: '{val}' not in {list_info}",
"extra_info": {
"input_name": x,
"input_config": input_config,
"received_value": val,
}
}
errors.append(error)
continue
if len(errors) > 0 or valid is not True:
ret = (False, errors, unique_id)
@ -644,7 +713,7 @@ def validate_prompt(prompt):
node_errors[node_id]["dependent_outputs"].append(o)
print("Output will be ignored")
if len(good_outputs) == 0:
if len(good_outputs) == 0 or node_errors:
errors_list = []
for o, errors in errors:
for error in errors:

View File

@ -1085,7 +1085,7 @@ class LoadImage:
input_dir = folder_paths.get_input_directory()
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
return {"required":
{"image": (sorted(files), { "forceInput": True })},
{"image": (sorted(files), )},
}
CATEGORY = "image"
@ -1127,7 +1127,7 @@ class LoadImageBatch:
input_dir = folder_paths.get_input_directory()
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
return {"required":
{"images": (sorted(files), )},
{"images": ("MULTIIMAGEUPLOAD", { "filepaths": sorted(files) } )},
}
CATEGORY = "image"
@ -1135,7 +1135,6 @@ class LoadImageBatch:
RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "load_images"
INPUT_IS_LIST = True
OUTPUT_IS_LIST = (True, True, )
def load_images(self, images):
@ -1437,6 +1436,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"PreviewImage": "Preview Image",
"LoadImage": "Load Image",
"LoadImageMask": "Load Image (as Mask)",
"LoadImageBatch": "Load Image Batch",
"ImageScale": "Upscale Image",
"ImageUpscaleWithModel": "Upscale Image (using Model)",
"ImageInvert": "Invert Image",

View File

@ -10,9 +10,6 @@ app.registerExtension({
case "LoadImageMask":
nodeData.input.required.upload = ["IMAGEUPLOAD"];
break;
case "LoadImageBatch":
nodeData.input.required.upload = ["MULTIIMAGEUPLOAD"];
break;
}
},
});

View File

@ -1110,20 +1110,21 @@ export class ComfyApp {
for (const inputName in inputs) {
const inputData = inputs[inputName];
const type = inputData[0];
const inputShape = nodeData["input_is_list"] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE;
const options = inputData[1] || {};
const inputShape = options.is_list ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE;
if(inputData[1]?.forceInput) {
this.addInput(inputName, type, { shape: inputShape });
} else {
if (Array.isArray(type)) {
// Enums
Object.assign(config, widgets.COMBO(this, inputName, inputData, nodeData, app) || {});
Object.assign(config, widgets.COMBO(this, inputName, inputData, app) || {});
} else if (`${type}:${inputName}` in widgets) {
// Support custom widgets by Type:Name
Object.assign(config, widgets[`${type}:${inputName}`](this, inputName, inputData, nodeData, app) || {});
Object.assign(config, widgets[`${type}:${inputName}`](this, inputName, inputData, app) || {});
} else if (type in widgets) {
// Standard type widgets
Object.assign(config, widgets[type](this, inputName, inputData, nodeData, app) || {});
Object.assign(config, widgets[type](this, inputName, inputData, app) || {});
} else {
// Node connection inputs
this.addInput(inputName, type, { shape: inputShape });
@ -1313,7 +1314,8 @@ export class ComfyApp {
for (const i in widgets) {
const widget = widgets[i];
if (!widget.options || widget.options.serialize !== false) {
inputs[widget.name] = widget.serializeValue ? await widget.serializeValue(n, i) : widget.value;
const value = widget.serializeValue ? await widget.serializeValue(n, i) : widget.value;
inputs[widget.name] = { type: "value", value }
}
}
}
@ -1333,7 +1335,11 @@ export class ComfyApp {
}
if (link) {
inputs[node.inputs[i].name] = [String(link.origin_id), parseInt(link.origin_slot)];
inputs[node.inputs[i].name] = {
type: "link",
origin_id: String(link.origin_id),
origin_slot: parseInt(link.origin_slot)
};
}
}
}
@ -1377,6 +1383,9 @@ export class ComfyApp {
message += "\n" + nodeError.class_type + ":"
for (const errorReason of nodeError.errors) {
message += "\n - " + errorReason.message + ": " + errorReason.details
if (errorReason.extra_info?.traceback) {
message += "\n" + errorReason.extra_info.traceback.join("")
}
}
}
return message

View File

@ -268,7 +268,7 @@ const INT = (node, inputName, inputData) => {
};
}
const STRING = (node, inputName, inputData, nodeData, app) => {
const STRING = (node, inputName, inputData, app) => {
const defaultVal = inputData[1].default || "";
const multiline = !!inputData[1].multiline;
@ -279,14 +279,15 @@ const STRING = (node, inputName, inputData, nodeData, app) => {
}
}
const COMBO = (node, inputName, inputData, nodeData) => {
const COMBO = (node, inputName, inputData) => {
const type = inputData[0];
let defaultValue = type[0];
if (inputData[1] && inputData[1].default) {
defaultValue = inputData[1].default;
let options = inputData[1] || {}
if (options.default) {
defaultValue = options.default
}
if (nodeData["input_is_list"]) {
if (options.is_list) {
defaultValue = [defaultValue]
const widget = node.addWidget("text", inputName, defaultValue, () => {}, { values: type })
widget.disabled = true;
@ -297,7 +298,7 @@ const COMBO = (node, inputName, inputData, nodeData) => {
}
}
const IMAGEUPLOAD = (node, inputName, inputData, nodeData, app) => {
const IMAGEUPLOAD = (node, inputName, inputData, app) => {
const imageWidget = node.widgets.find((w) => w.name === "image");
let uploadWidget;
@ -412,8 +413,7 @@ const IMAGEUPLOAD = (node, inputName, inputData, nodeData, app) => {
uploadWidget = node.addWidget("button", "choose file to upload", "image", () => {
fileInput.value = null;
fileInput.click();
});
uploadWidget.serialize = false;
}, { serialize: false });
// Add handler to check if an image is being dragged over our node
node.onDragOver = function (e) {
@ -442,8 +442,14 @@ const IMAGEUPLOAD = (node, inputName, inputData, nodeData, app) => {
return { widget: uploadWidget };
}
const MULTIIMAGEUPLOAD = (node, inputName, inputData, nodeData, app) => {
const imagesWidget = node.widgets.find((w) => w.name === "images");
const MULTIIMAGEUPLOAD = (node, inputName, inputData, app) => {
const imagesWidget = node.addWidget("text", inputName, inputData, () => {})
imagesWidget._filepaths = []
if (inputData[1] && inputData[1].filepaths) {
imagesWidget._filepaths = inputData[1].filepaths
}
let uploadWidget;
let clearWidget;
@ -534,11 +540,6 @@ const MULTIIMAGEUPLOAD = (node, inputName, inputData, nodeData, app) => {
if (resp.status === 200) {
const data = await resp.json();
// Add the file as an option and update the widget value
if (!imagesWidget.options.values.includes(data.name)) {
imagesWidget.options.values.push(data.name);
}
if (updateNode) {
imagesWidget.value.push(data.name)
}
@ -573,14 +574,12 @@ const MULTIIMAGEUPLOAD = (node, inputName, inputData, nodeData, app) => {
uploadWidget = node.addWidget("button", "choose files to upload", "images", () => {
fileInput.value = null;
fileInput.click();
});
uploadWidget.serialize = false;
}, { serialize: false });
clearWidget = node.addWidget("button", "clear all uploads", "images", () => {
imagesWidget.value = []
showImages(imagesWidget.value);
});
clearWidget.serialize = false;
}, { serialize: false });
// Add handler to check if an image is being dragged over our node
node.onDragOver = function (e) {