mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 14:50:49 +08:00
Allow specifying which inputs should be used as lists
This commit is contained in:
parent
d934119333
commit
e4be8e0666
@ -3,11 +3,10 @@ import torch
|
||||
class LatentRebatch:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "latents": ("LATENT",),
|
||||
return {"required": { "latents": ("LATENT", { "is_list": True }),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 64}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
INPUT_IS_LIST = True
|
||||
RETURN_TYPES = ("LATENT", )
|
||||
OUTPUT_IS_LIST = (True, )
|
||||
|
||||
FUNCTION = "rebatch"
|
||||
@ -54,8 +53,6 @@ class LatentRebatch:
|
||||
return result
|
||||
|
||||
def rebatch(self, latents, batch_size):
|
||||
batch_size = batch_size[0]
|
||||
|
||||
output_list = []
|
||||
current_batch = (None, None, None)
|
||||
processed = 0
|
||||
@ -105,4 +102,4 @@ NODE_CLASS_MAPPINGS = {
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"RebatchLatents": "Rebatch Latents",
|
||||
}
|
||||
}
|
||||
|
||||
445
execution.py
445
execution.py
@ -13,21 +13,44 @@ import nodes
|
||||
|
||||
import comfy.model_management
|
||||
|
||||
def slice_lists_into_dict(d, i):
|
||||
"""
|
||||
get a slice of inputs, repeat last input when list isn't long enough
|
||||
d={ "seed": [ 1, 2, 3 ], "steps": [ 4, 8 ] }, i=2 -> { "seed": 3, "steps": 8 }
|
||||
"""
|
||||
d_new = {}
|
||||
for k, v in d.items():
|
||||
d_new[k] = v[i if len(v) > i else -1]
|
||||
return d_new
|
||||
|
||||
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
|
||||
valid_inputs = class_def.INPUT_TYPES()
|
||||
input_data_all = {}
|
||||
required = valid_inputs.get("required", {})
|
||||
optional = valid_inputs.get("optional", {})
|
||||
for x in inputs:
|
||||
input_data = inputs[x]
|
||||
if isinstance(input_data, list):
|
||||
input_unique_id = input_data[0]
|
||||
output_index = input_data[1]
|
||||
input_type = input_data["type"]
|
||||
|
||||
input_def = required.get(x)
|
||||
if input_def is None:
|
||||
input_def = optional.get(x)
|
||||
|
||||
use_value_as_list = input_def is not None and len(input_def) > 1 and input_def[1].get("is_list", False)
|
||||
|
||||
if input_type == "link":
|
||||
input_unique_id = input_data["origin_id"]
|
||||
output_index = input_data["origin_slot"]
|
||||
if input_unique_id not in outputs:
|
||||
return None
|
||||
obj = outputs[input_unique_id][output_index]
|
||||
if use_value_as_list:
|
||||
obj = [obj]
|
||||
input_data_all[x] = obj
|
||||
else:
|
||||
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
|
||||
input_data_all[x] = [input_data]
|
||||
value = input_data["value"]
|
||||
if input_def is not None:
|
||||
input_data_all[x] = [value]
|
||||
|
||||
if "hidden" in valid_inputs:
|
||||
h = valid_inputs["hidden"]
|
||||
@ -39,37 +62,23 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
|
||||
input_data_all[x] = [extra_data['extra_pnginfo']]
|
||||
if h[x] == "UNIQUE_ID":
|
||||
input_data_all[x] = [unique_id]
|
||||
|
||||
return input_data_all
|
||||
|
||||
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
|
||||
# check if node wants the lists
|
||||
intput_is_list = False
|
||||
if hasattr(obj, "INPUT_IS_LIST"):
|
||||
intput_is_list = obj.INPUT_IS_LIST
|
||||
|
||||
max_len_input = max([len(x) for x in input_data_all.values()])
|
||||
|
||||
# get a slice of inputs, repeat last input when list isn't long enough
|
||||
def slice_dict(d, i):
|
||||
d_new = dict()
|
||||
for k,v in d.items():
|
||||
d_new[k] = v[i if len(v) > i else -1]
|
||||
return d_new
|
||||
|
||||
results = []
|
||||
if intput_is_list:
|
||||
max_len_input = max([len(x) for x in input_data_all.values()])
|
||||
|
||||
for i in range(max_len_input):
|
||||
if allow_interrupt:
|
||||
nodes.before_node_execution()
|
||||
results.append(getattr(obj, func)(**input_data_all))
|
||||
else:
|
||||
for i in range(max_len_input):
|
||||
if allow_interrupt:
|
||||
nodes.before_node_execution()
|
||||
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
|
||||
|
||||
args = slice_lists_into_dict(input_data_all, i)
|
||||
results.append(getattr(obj, func)(**args))
|
||||
|
||||
return results
|
||||
|
||||
def get_output_data(obj, input_data_all):
|
||||
|
||||
results = []
|
||||
uis = []
|
||||
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
|
||||
@ -120,10 +129,10 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
||||
|
||||
for x in inputs:
|
||||
input_data = inputs[x]
|
||||
input_type = input_data["type"]
|
||||
|
||||
if isinstance(input_data, list):
|
||||
input_unique_id = input_data[0]
|
||||
output_index = input_data[1]
|
||||
if input_type == "link":
|
||||
input_unique_id = input_data["origin_id"]
|
||||
if input_unique_id not in outputs:
|
||||
result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui)
|
||||
if result[0] is not True:
|
||||
@ -192,9 +201,9 @@ def recursive_will_execute(prompt, outputs, current_item):
|
||||
|
||||
for x in inputs:
|
||||
input_data = inputs[x]
|
||||
if isinstance(input_data, list):
|
||||
input_unique_id = input_data[0]
|
||||
output_index = input_data[1]
|
||||
input_type = input_data["type"]
|
||||
if input_type == "link":
|
||||
input_unique_id = input_data["origin_id"]
|
||||
if input_unique_id not in outputs:
|
||||
will_execute += recursive_will_execute(prompt, outputs, input_unique_id)
|
||||
|
||||
@ -235,10 +244,10 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
|
||||
elif inputs == old_prompt[unique_id]['inputs']:
|
||||
for x in inputs:
|
||||
input_data = inputs[x]
|
||||
input_type = input_data["type"]
|
||||
|
||||
if isinstance(input_data, list):
|
||||
input_unique_id = input_data[0]
|
||||
output_index = input_data[1]
|
||||
if input_type == "link":
|
||||
input_unique_id = input_data["origin_id"]
|
||||
if input_unique_id in outputs:
|
||||
to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id)
|
||||
else:
|
||||
@ -366,6 +375,150 @@ class PromptExecutor:
|
||||
comfy.model_management.soft_empty_cache()
|
||||
|
||||
|
||||
def validate_link(prompt, x, val, info, validated):
|
||||
type_input = info[0]
|
||||
|
||||
o_id = val.get("origin_id", None)
|
||||
o_slot = val.get("origin_slot", None)
|
||||
|
||||
if o_id is None or o_slot is None:
|
||||
error = {
|
||||
"type": "bad_linked_input",
|
||||
"message": "Bad linked input, must be a dictionary like { type: 'link', origin_id: 1, origin_slot: 1 }",
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val
|
||||
}
|
||||
}
|
||||
return (False, error)
|
||||
|
||||
o_class_type = prompt[o_id]['class_type']
|
||||
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
|
||||
if r[o_slot] != type_input:
|
||||
received_type = r[val[1]]
|
||||
details = f"{x}, {received_type} != {type_input}"
|
||||
error = {
|
||||
"type": "return_type_mismatch",
|
||||
"message": "Return type mismatch between linked nodes",
|
||||
"details": details,
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_type": received_type,
|
||||
"linked_node": val
|
||||
}
|
||||
}
|
||||
return (False, error)
|
||||
try:
|
||||
r = validate_inputs(prompt, o_id, validated)
|
||||
if r[0] is False:
|
||||
# `r` will be set in `validated[o_id]` already
|
||||
return (False, None)
|
||||
except Exception as ex:
|
||||
typ, _, tb = sys.exc_info()
|
||||
exception_type = full_type_name(typ)
|
||||
reasons = [{
|
||||
"type": "exception_during_inner_validation",
|
||||
"message": "Exception when validating inner node",
|
||||
"details": str(ex),
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"exception_message": str(ex),
|
||||
"exception_type": exception_type,
|
||||
"traceback": traceback.format_tb(tb),
|
||||
"linked_node": val,
|
||||
"linked_node_inputs": prompt[o_id]
|
||||
}
|
||||
}]
|
||||
validated[o_id] = (False, reasons, o_id)
|
||||
return (False, None)
|
||||
|
||||
return (True, val)
|
||||
|
||||
|
||||
def validate_value(inputs, unique_id, x, val, info, obj_class):
|
||||
type_input = info[0]
|
||||
result_val = val
|
||||
|
||||
try:
|
||||
if type_input == "INT":
|
||||
result_val = int(val)
|
||||
if type_input == "FLOAT":
|
||||
result_val = float(val)
|
||||
if type_input == "STRING":
|
||||
result_val = str(val)
|
||||
except Exception as ex:
|
||||
error = {
|
||||
"type": "invalid_input_type",
|
||||
"message": f"Failed to convert an input value to a {type_input} value",
|
||||
"details": f"{x}, {val}, {ex}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val,
|
||||
"exception_message": str(ex)
|
||||
}
|
||||
}
|
||||
return (False, error)
|
||||
|
||||
if len(info) > 1:
|
||||
if "min" in info[1] and val < info[1]["min"]:
|
||||
error = {
|
||||
"type": "value_smaller_than_min",
|
||||
"message": "Value {} smaller than min of {}".format(val, info[1]["min"]),
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val,
|
||||
}
|
||||
}
|
||||
return (False, error)
|
||||
if "max" in info[1] and val > info[1]["max"]:
|
||||
error = {
|
||||
"type": "value_bigger_than_max",
|
||||
"message": "Value {} bigger than max of {}".format(val, info[1]["max"]),
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val,
|
||||
}
|
||||
}
|
||||
return (False, error)
|
||||
else:
|
||||
# Validate combo widget
|
||||
if isinstance(type_input, list):
|
||||
if val not in type_input:
|
||||
input_config = info
|
||||
list_info = ""
|
||||
|
||||
# Don't send back gigantic lists like if they're lots of
|
||||
# scanned model filepaths
|
||||
if len(type_input) > 20:
|
||||
list_info = f"(list of length {len(type_input)})"
|
||||
input_config = None
|
||||
else:
|
||||
list_info = str(type_input)
|
||||
|
||||
error = {
|
||||
"type": "value_not_in_list",
|
||||
"message": "Value not in list",
|
||||
"details": f"{x}: '{val}' not in {list_info}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": input_config,
|
||||
"received_value": val,
|
||||
}
|
||||
}
|
||||
return (False, error)
|
||||
|
||||
return (True, result_val)
|
||||
|
||||
|
||||
def validate_inputs(prompt, item, validated):
|
||||
unique_id = item
|
||||
if unique_id in validated:
|
||||
@ -396,168 +549,84 @@ def validate_inputs(prompt, item, validated):
|
||||
|
||||
val = inputs[x]
|
||||
info = required_inputs[x]
|
||||
type_input = info[0]
|
||||
if isinstance(val, list):
|
||||
if len(val) != 2:
|
||||
error = {
|
||||
"type": "bad_linked_input",
|
||||
"message": "Bad linked input, must be a length-2 list of [node_id, slot_index]",
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val
|
||||
}
|
||||
|
||||
input_type = None
|
||||
if isinstance(val, dict):
|
||||
input_type = val.get("type", None)
|
||||
|
||||
if input_type not in ["link", "value"]:
|
||||
error = {
|
||||
"type": "bad_input_format",
|
||||
"message": "Bad input format, must be a dictionary with 'type' set to 'link' or 'value'",
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val
|
||||
}
|
||||
errors.append(error)
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
if input_type == "link":
|
||||
result = validate_link(prompt, x, val, info, validated)
|
||||
if result[0] is False:
|
||||
valid = False
|
||||
if result[1] is not None:
|
||||
errors.append(result[1])
|
||||
continue
|
||||
|
||||
o_id = val[0]
|
||||
o_class_type = prompt[o_id]['class_type']
|
||||
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
|
||||
if r[val[1]] != type_input:
|
||||
received_type = r[val[1]]
|
||||
details = f"{x}, {received_type} != {type_input}"
|
||||
inputs[x] = result[1]
|
||||
|
||||
elif input_type == "value":
|
||||
inner_val = val.get("value", None)
|
||||
if inner_val is None:
|
||||
error = {
|
||||
"type": "return_type_mismatch",
|
||||
"message": "Return type mismatch between linked nodes",
|
||||
"details": details,
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_type": received_type,
|
||||
"linked_node": val
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
try:
|
||||
r = validate_inputs(prompt, o_id, validated)
|
||||
if r[0] is False:
|
||||
# `r` will be set in `validated[o_id]` already
|
||||
valid = False
|
||||
continue
|
||||
except Exception as ex:
|
||||
typ, _, tb = sys.exc_info()
|
||||
valid = False
|
||||
exception_type = full_type_name(typ)
|
||||
reasons = [{
|
||||
"type": "exception_during_inner_validation",
|
||||
"message": "Exception when validating inner node",
|
||||
"details": str(ex),
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"exception_message": str(ex),
|
||||
"exception_type": exception_type,
|
||||
"traceback": traceback.format_tb(tb),
|
||||
"linked_node": val
|
||||
}
|
||||
}]
|
||||
validated[o_id] = (False, reasons, o_id)
|
||||
continue
|
||||
else:
|
||||
try:
|
||||
if type_input == "INT":
|
||||
val = int(val)
|
||||
inputs[x] = val
|
||||
if type_input == "FLOAT":
|
||||
val = float(val)
|
||||
inputs[x] = val
|
||||
if type_input == "STRING":
|
||||
val = str(val)
|
||||
inputs[x] = val
|
||||
except Exception as ex:
|
||||
error = {
|
||||
"type": "invalid_input_type",
|
||||
"message": f"Failed to convert an input value to a {type_input} value",
|
||||
"details": f"{x}, {val}, {ex}",
|
||||
"type": "bad_value_input",
|
||||
"message": "Bad value input, must be a dictionary like { type: 'value', value: 42 }",
|
||||
"details": f"{x}, {val}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val,
|
||||
"exception_message": str(ex)
|
||||
}
|
||||
}
|
||||
return (False, error)
|
||||
|
||||
result = validate_value(inputs, unique_id, x, inner_val, info, obj_class)
|
||||
|
||||
if result[0] is False:
|
||||
errors.append(result[1])
|
||||
continue
|
||||
|
||||
inputs[x] = { "type": "value", "value": result[1] }
|
||||
|
||||
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
||||
input_data_all = get_input_data(inputs, obj_class, unique_id)
|
||||
#ret = obj_class.VALIDATE_INPUTS(**input_data_all)
|
||||
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
|
||||
for i, r in enumerate(ret):
|
||||
if r is not True:
|
||||
details = ""
|
||||
if r is not False:
|
||||
details += str(r)
|
||||
|
||||
input_data_formatted = {}
|
||||
if input_data_all is not None:
|
||||
input_data_formatted = {}
|
||||
for name, inputList in input_data_all.items():
|
||||
input_data_formatted[name] = [format_value(x) for x in inputList]
|
||||
|
||||
error = {
|
||||
"type": "custom_validation_failed",
|
||||
"message": "Custom validation failed for node",
|
||||
"details": details,
|
||||
"extra_info": {
|
||||
"input_config": info,
|
||||
"received_inputs": input_data_formatted,
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
if len(info) > 1:
|
||||
if "min" in info[1] and val < info[1]["min"]:
|
||||
error = {
|
||||
"type": "value_smaller_than_min",
|
||||
"message": "Value {} smaller than min of {}".format(val, info[1]["min"]),
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val,
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
if "max" in info[1] and val > info[1]["max"]:
|
||||
error = {
|
||||
"type": "value_bigger_than_max",
|
||||
"message": "Value {} bigger than max of {}".format(val, info[1]["max"]),
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val,
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
||||
input_data_all = get_input_data(inputs, obj_class, unique_id)
|
||||
#ret = obj_class.VALIDATE_INPUTS(**input_data_all)
|
||||
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
|
||||
for i, r in enumerate(ret):
|
||||
if r is not True:
|
||||
details = f"{x}"
|
||||
if r is not False:
|
||||
details += f" - {str(r)}"
|
||||
|
||||
error = {
|
||||
"type": "custom_validation_failed",
|
||||
"message": "Custom validation failed for node",
|
||||
"details": details,
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val,
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
else:
|
||||
if isinstance(type_input, list):
|
||||
if val not in type_input:
|
||||
input_config = info
|
||||
list_info = ""
|
||||
|
||||
# Don't send back gigantic lists like if they're lots of
|
||||
# scanned model filepaths
|
||||
if len(type_input) > 20:
|
||||
list_info = f"(list of length {len(type_input)})"
|
||||
input_config = None
|
||||
else:
|
||||
list_info = str(type_input)
|
||||
|
||||
error = {
|
||||
"type": "value_not_in_list",
|
||||
"message": "Value not in list",
|
||||
"details": f"{x}: '{val}' not in {list_info}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": input_config,
|
||||
"received_value": val,
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
if len(errors) > 0 or valid is not True:
|
||||
ret = (False, errors, unique_id)
|
||||
@ -644,7 +713,7 @@ def validate_prompt(prompt):
|
||||
node_errors[node_id]["dependent_outputs"].append(o)
|
||||
print("Output will be ignored")
|
||||
|
||||
if len(good_outputs) == 0:
|
||||
if len(good_outputs) == 0 or node_errors:
|
||||
errors_list = []
|
||||
for o, errors in errors:
|
||||
for error in errors:
|
||||
|
||||
6
nodes.py
6
nodes.py
@ -1085,7 +1085,7 @@ class LoadImage:
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
|
||||
return {"required":
|
||||
{"image": (sorted(files), { "forceInput": True })},
|
||||
{"image": (sorted(files), )},
|
||||
}
|
||||
|
||||
CATEGORY = "image"
|
||||
@ -1127,7 +1127,7 @@ class LoadImageBatch:
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
|
||||
return {"required":
|
||||
{"images": (sorted(files), )},
|
||||
{"images": ("MULTIIMAGEUPLOAD", { "filepaths": sorted(files) } )},
|
||||
}
|
||||
|
||||
CATEGORY = "image"
|
||||
@ -1135,7 +1135,6 @@ class LoadImageBatch:
|
||||
RETURN_TYPES = ("IMAGE", "MASK")
|
||||
FUNCTION = "load_images"
|
||||
|
||||
INPUT_IS_LIST = True
|
||||
OUTPUT_IS_LIST = (True, True, )
|
||||
|
||||
def load_images(self, images):
|
||||
@ -1437,6 +1436,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"PreviewImage": "Preview Image",
|
||||
"LoadImage": "Load Image",
|
||||
"LoadImageMask": "Load Image (as Mask)",
|
||||
"LoadImageBatch": "Load Image Batch",
|
||||
"ImageScale": "Upscale Image",
|
||||
"ImageUpscaleWithModel": "Upscale Image (using Model)",
|
||||
"ImageInvert": "Invert Image",
|
||||
|
||||
@ -10,9 +10,6 @@ app.registerExtension({
|
||||
case "LoadImageMask":
|
||||
nodeData.input.required.upload = ["IMAGEUPLOAD"];
|
||||
break;
|
||||
case "LoadImageBatch":
|
||||
nodeData.input.required.upload = ["MULTIIMAGEUPLOAD"];
|
||||
break;
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
@ -1110,20 +1110,21 @@ export class ComfyApp {
|
||||
for (const inputName in inputs) {
|
||||
const inputData = inputs[inputName];
|
||||
const type = inputData[0];
|
||||
const inputShape = nodeData["input_is_list"] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE;
|
||||
const options = inputData[1] || {};
|
||||
const inputShape = options.is_list ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE;
|
||||
|
||||
if(inputData[1]?.forceInput) {
|
||||
this.addInput(inputName, type, { shape: inputShape });
|
||||
} else {
|
||||
if (Array.isArray(type)) {
|
||||
// Enums
|
||||
Object.assign(config, widgets.COMBO(this, inputName, inputData, nodeData, app) || {});
|
||||
Object.assign(config, widgets.COMBO(this, inputName, inputData, app) || {});
|
||||
} else if (`${type}:${inputName}` in widgets) {
|
||||
// Support custom widgets by Type:Name
|
||||
Object.assign(config, widgets[`${type}:${inputName}`](this, inputName, inputData, nodeData, app) || {});
|
||||
Object.assign(config, widgets[`${type}:${inputName}`](this, inputName, inputData, app) || {});
|
||||
} else if (type in widgets) {
|
||||
// Standard type widgets
|
||||
Object.assign(config, widgets[type](this, inputName, inputData, nodeData, app) || {});
|
||||
Object.assign(config, widgets[type](this, inputName, inputData, app) || {});
|
||||
} else {
|
||||
// Node connection inputs
|
||||
this.addInput(inputName, type, { shape: inputShape });
|
||||
@ -1313,7 +1314,8 @@ export class ComfyApp {
|
||||
for (const i in widgets) {
|
||||
const widget = widgets[i];
|
||||
if (!widget.options || widget.options.serialize !== false) {
|
||||
inputs[widget.name] = widget.serializeValue ? await widget.serializeValue(n, i) : widget.value;
|
||||
const value = widget.serializeValue ? await widget.serializeValue(n, i) : widget.value;
|
||||
inputs[widget.name] = { type: "value", value }
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1333,7 +1335,11 @@ export class ComfyApp {
|
||||
}
|
||||
|
||||
if (link) {
|
||||
inputs[node.inputs[i].name] = [String(link.origin_id), parseInt(link.origin_slot)];
|
||||
inputs[node.inputs[i].name] = {
|
||||
type: "link",
|
||||
origin_id: String(link.origin_id),
|
||||
origin_slot: parseInt(link.origin_slot)
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1377,6 +1383,9 @@ export class ComfyApp {
|
||||
message += "\n" + nodeError.class_type + ":"
|
||||
for (const errorReason of nodeError.errors) {
|
||||
message += "\n - " + errorReason.message + ": " + errorReason.details
|
||||
if (errorReason.extra_info?.traceback) {
|
||||
message += "\n" + errorReason.extra_info.traceback.join("")
|
||||
}
|
||||
}
|
||||
}
|
||||
return message
|
||||
|
||||
@ -268,7 +268,7 @@ const INT = (node, inputName, inputData) => {
|
||||
};
|
||||
}
|
||||
|
||||
const STRING = (node, inputName, inputData, nodeData, app) => {
|
||||
const STRING = (node, inputName, inputData, app) => {
|
||||
const defaultVal = inputData[1].default || "";
|
||||
const multiline = !!inputData[1].multiline;
|
||||
|
||||
@ -279,14 +279,15 @@ const STRING = (node, inputName, inputData, nodeData, app) => {
|
||||
}
|
||||
}
|
||||
|
||||
const COMBO = (node, inputName, inputData, nodeData) => {
|
||||
const COMBO = (node, inputName, inputData) => {
|
||||
const type = inputData[0];
|
||||
let defaultValue = type[0];
|
||||
if (inputData[1] && inputData[1].default) {
|
||||
defaultValue = inputData[1].default;
|
||||
let options = inputData[1] || {}
|
||||
if (options.default) {
|
||||
defaultValue = options.default
|
||||
}
|
||||
|
||||
if (nodeData["input_is_list"]) {
|
||||
if (options.is_list) {
|
||||
defaultValue = [defaultValue]
|
||||
const widget = node.addWidget("text", inputName, defaultValue, () => {}, { values: type })
|
||||
widget.disabled = true;
|
||||
@ -297,7 +298,7 @@ const COMBO = (node, inputName, inputData, nodeData) => {
|
||||
}
|
||||
}
|
||||
|
||||
const IMAGEUPLOAD = (node, inputName, inputData, nodeData, app) => {
|
||||
const IMAGEUPLOAD = (node, inputName, inputData, app) => {
|
||||
const imageWidget = node.widgets.find((w) => w.name === "image");
|
||||
let uploadWidget;
|
||||
|
||||
@ -412,8 +413,7 @@ const IMAGEUPLOAD = (node, inputName, inputData, nodeData, app) => {
|
||||
uploadWidget = node.addWidget("button", "choose file to upload", "image", () => {
|
||||
fileInput.value = null;
|
||||
fileInput.click();
|
||||
});
|
||||
uploadWidget.serialize = false;
|
||||
}, { serialize: false });
|
||||
|
||||
// Add handler to check if an image is being dragged over our node
|
||||
node.onDragOver = function (e) {
|
||||
@ -442,8 +442,14 @@ const IMAGEUPLOAD = (node, inputName, inputData, nodeData, app) => {
|
||||
return { widget: uploadWidget };
|
||||
}
|
||||
|
||||
const MULTIIMAGEUPLOAD = (node, inputName, inputData, nodeData, app) => {
|
||||
const imagesWidget = node.widgets.find((w) => w.name === "images");
|
||||
const MULTIIMAGEUPLOAD = (node, inputName, inputData, app) => {
|
||||
const imagesWidget = node.addWidget("text", inputName, inputData, () => {})
|
||||
|
||||
imagesWidget._filepaths = []
|
||||
if (inputData[1] && inputData[1].filepaths) {
|
||||
imagesWidget._filepaths = inputData[1].filepaths
|
||||
}
|
||||
|
||||
let uploadWidget;
|
||||
let clearWidget;
|
||||
|
||||
@ -534,11 +540,6 @@ const MULTIIMAGEUPLOAD = (node, inputName, inputData, nodeData, app) => {
|
||||
|
||||
if (resp.status === 200) {
|
||||
const data = await resp.json();
|
||||
// Add the file as an option and update the widget value
|
||||
if (!imagesWidget.options.values.includes(data.name)) {
|
||||
imagesWidget.options.values.push(data.name);
|
||||
}
|
||||
|
||||
if (updateNode) {
|
||||
imagesWidget.value.push(data.name)
|
||||
}
|
||||
@ -573,14 +574,12 @@ const MULTIIMAGEUPLOAD = (node, inputName, inputData, nodeData, app) => {
|
||||
uploadWidget = node.addWidget("button", "choose files to upload", "images", () => {
|
||||
fileInput.value = null;
|
||||
fileInput.click();
|
||||
});
|
||||
uploadWidget.serialize = false;
|
||||
}, { serialize: false });
|
||||
|
||||
clearWidget = node.addWidget("button", "clear all uploads", "images", () => {
|
||||
imagesWidget.value = []
|
||||
showImages(imagesWidget.value);
|
||||
});
|
||||
clearWidget.serialize = false;
|
||||
}, { serialize: false });
|
||||
|
||||
// Add handler to check if an image is being dragged over our node
|
||||
node.onDragOver = function (e) {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user