Add nodes for dealing with BOOLs

This commit is contained in:
Jacob Segal 2023-07-20 21:04:08 -07:00
parent 88fc046180
commit 9d9e1e65ab
7 changed files with 258 additions and 28 deletions

View File

@ -2,6 +2,7 @@ from .nodes import GENERAL_NODE_CLASS_MAPPINGS, GENERAL_NODE_DISPLAY_NAME_MAPPIN
from .components import setup_js, COMPONENT_NODE_CLASS_MAPPINGS, COMPONENT_NODE_DISPLAY_NAME_MAPPINGS
from .flow_control import FLOW_CONTROL_NODE_CLASS_MAPPINGS, FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS
from .utility_nodes import UTILITY_NODE_CLASS_MAPPINGS, UTILITY_NODE_DISPLAY_NAME_MAPPINGS
from .conditions import CONDITION_NODE_CLASS_MAPPINGS, CONDITION_NODE_DISPLAY_NAME_MAPPINGS
# NODE_CLASS_MAPPINGS = GENERAL_NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS)
# NODE_DISPLAY_NAME_MAPPINGS = GENERAL_NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS)
@ -11,12 +12,14 @@ NODE_CLASS_MAPPINGS.update(GENERAL_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(FLOW_CONTROL_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(CONDITION_NODE_CLASS_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS.update(GENERAL_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(CONDITION_NODE_DISPLAY_NAME_MAPPINGS)
setup_js()

View File

@ -53,6 +53,7 @@ class ComponentOutput:
"required": {
"index": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}),
"data_type": ("STRING", {"multiline": False, "default": "IMAGE"}),
"name": ("STRING", {"multiline": False}),
"value": ("*",),
},
}
@ -62,7 +63,7 @@ class ComponentOutput:
CATEGORY = "Component Creation"
def component_output(self, index, data_type, value):
def component_output(self, index, data_type, name, value):
return (value,)
class ComponentMetadata:
@ -152,6 +153,7 @@ def LoadComponent(component_file):
elif data["class_type"] == "ComponentOutput":
component_outputs.append({
"node_id": node_id,
"name": data["inputs"]["name"] or data["inputs"]["data_type"],
"index": data["inputs"]["index"],
"data_type": data["inputs"]["data_type"],
})
@ -179,6 +181,7 @@ def LoadComponent(component_file):
}
RETURN_TYPES = tuple([node["data_type"] for node in component_outputs])
RETURN_NAMES = tuple([node["name"] for node in component_outputs])
FUNCTION = "expand_component"
CATEGORY = "Custom Components"

View File

@ -0,0 +1,162 @@
import re
import torch
class IntConditions:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"a": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}),
"b": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}),
"operation": (["==", "!=", "<", ">", "<=", ">="],),
},
}
RETURN_TYPES = ("BOOL",)
FUNCTION = "int_condition"
CATEGORY = "Conditions"
def int_condition(self, a, b, operation):
if operation == "==":
return (a == b,)
elif operation == "!=":
return (a != b,)
elif operation == "<":
return (a < b,)
elif operation == ">":
return (a > b,)
elif operation == "<=":
return (a <= b,)
elif operation == ">=":
return (a >= b,)
class FloatConditions:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"a": ("FLOAT", {"default": 0, "min": -999999999999.0, "max": 999999999999.0, "step": 1}),
"b": ("FLOAT", {"default": 0, "min": -999999999999.0, "max": 999999999999.0, "step": 1}),
"operation": (["==", "!=", "<", ">", "<=", ">="],),
},
}
RETURN_TYPES = ("BOOL",)
FUNCTION = "float_condition"
CATEGORY = "Conditions"
def float_condition(self, a, b, operation):
if operation == "==":
return (a == b,)
elif operation == "!=":
return (a != b,)
elif operation == "<":
return (a < b,)
elif operation == ">":
return (a > b,)
elif operation == "<=":
return (a <= b,)
elif operation == ">=":
return (a >= b,)
class StringConditions:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"a": ("STRING", {"multiline": False}),
"b": ("STRING", {"multiline": False}),
"operation": (["a == b", "a != b", "a IN b", "a MATCH REGEX(b)", "a BEGINSWITH b", "a ENDSWITH b"],),
"case_sensitive": ("BOOL", {"default": True}),
},
}
RETURN_TYPES = ("BOOL",)
FUNCTION = "string_condition"
CATEGORY = "Conditions"
def string_condition(self, a, b, operation, case_sensitive):
if not case_sensitive:
a = a.lower()
b = b.lower()
if operation == "a == b":
return (a == b,)
elif operation == "a != b":
return (a != b,)
elif operation == "a IN b":
return (a in b,)
elif operation == "a MATCH REGEX(b)":
try:
return (re.match(b, a) is not None,)
except:
return (False,)
elif operation == "a BEGINSWITH b":
return (a.startswith(b),)
elif operation == "a ENDSWITH b":
return (a.endswith(b),)
class ToBoolNode:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("*",),
},
"optional": {
"invert": ("BOOL", {"default": False}),
},
}
RETURN_TYPES = ("BOOL",)
FUNCTION = "to_bool"
CATEGORY = "InversionDemo Nodes"
def to_bool(self, value, invert = False):
if isinstance(value, torch.Tensor):
if value.max().item() == 0 and value.min().item() == 0:
result = False
else:
result = True
else:
try:
result = bool(value)
except:
# Can't convert it? Well then it's something or other. I dunno, I'm not a Python programmer.
result = True
if invert:
result = not result
return (result,)
CONDITION_NODE_CLASS_MAPPINGS = {
"IntConditions": IntConditions,
"FloatConditions": FloatConditions,
"StringConditions": StringConditions,
"ToBoolNode": ToBoolNode,
}
CONDITION_NODE_DISPLAY_NAME_MAPPINGS = {
"IntConditions": "Int Condition",
"FloatConditions": "Float Condition",
"StringConditions": "String Condition",
"ToBoolNode": "To Bool",
}

View File

@ -9,7 +9,7 @@ class WhileLoopOpen:
def INPUT_TYPES(cls):
inputs = {
"required": {
"condition": ("INT", {"default": 1, "min": 0, "max": 1, "step": 1}),
"condition": ("BOOL", {"default": True}),
},
"optional": {
},
@ -38,8 +38,8 @@ class WhileLoopClose:
def INPUT_TYPES(cls):
inputs = {
"required": {
"flow_control": ("FLOW_CONTROL", {"raw_link": True}),
"condition": ("INT", {"default": 0, "min": 0, "max": 1, "step": 1}),
"flow_control": ("FLOW_CONTROL", {"rawLink": True}),
"condition": ("BOOL", {"forceInput": True}),
},
"optional": {
},

View File

@ -83,7 +83,7 @@ class InversionDemoLazySwitch:
def INPUT_TYPES(cls):
return {
"required": {
"switch": ([False, True],),
"switch": ("BOOL",),
"on_false": ("*", {"lazy": True}),
"on_true": ("*", {"lazy": True}),
},
@ -103,6 +103,61 @@ class InversionDemoLazySwitch:
def switch(self, switch, on_false = None, on_true = None):
value = on_true if switch else on_false
return (value,)
NUM_IF_ELSE_NODES = 10
class InversionDemoLazyConditional:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
args = {
"value1": ("*", {"lazy": True}),
"condition1": ("BOOL", {"forceInput": True}),
}
for i in range(1,NUM_IF_ELSE_NODES):
args["value%d" % (i + 1)] = ("*", {"lazy": True})
args["condition%d" % (i + 1)] = ("BOOL", {"lazy": True, "forceInput": True})
args["else"] = ("*", {"lazy": True})
return {
"required": {},
"optional": args,
}
RETURN_TYPES = ("*",)
FUNCTION = "conditional"
CATEGORY = "InversionDemo Nodes"
def check_lazy_status(self, **kwargs):
for i in range(0,NUM_IF_ELSE_NODES):
cond = "condition%d" % (i + 1)
if cond not in kwargs:
return [cond]
if kwargs[cond]:
val = "value%d" % (i + 1)
if val not in kwargs:
return [val]
else:
return []
if "else" not in kwargs:
return ["else"]
def conditional(self, **kwargs):
for i in range(0,NUM_IF_ELSE_NODES):
cond = "condition%d" % (i + 1)
if cond not in kwargs:
return [cond]
if kwargs.get(cond, False):
val = "value%d" % (i + 1)
return (kwargs.get(val, None),)
return (kwargs.get("else", None),)
class InversionDemoLazyIndexSwitch:
def __init__(self):
@ -195,6 +250,7 @@ GENERAL_NODE_CLASS_MAPPINGS = {
"InversionDemoLazySwitch": InversionDemoLazySwitch,
"InversionDemoLazyIndexSwitch": InversionDemoLazyIndexSwitch,
"InversionDemoLazyMixImages": InversionDemoLazyMixImages,
"InversionDemoLazyConditional": InversionDemoLazyConditional,
}
GENERAL_NODE_DISPLAY_NAME_MAPPINGS = {
@ -203,4 +259,5 @@ GENERAL_NODE_DISPLAY_NAME_MAPPINGS = {
"InversionDemoLazySwitch": "Lazy Switch",
"InversionDemoLazyIndexSwitch": "Lazy Index Switch",
"InversionDemoLazyMixImages": "Lazy Mix Images",
"InversionDemoLazyConditional": "Lazy Conditional",
}

View File

@ -1,4 +1,3 @@
import torch
from comfy.graph_utils import GraphBuilder
class AccumulateNode:
@ -120,7 +119,7 @@ class ListToAccumulationNode:
def accumulation_to_list(self, list):
return ({"accum": list},)
class IsTruthyNode:
class IntMathOperation:
def __init__(self):
pass
@ -128,26 +127,31 @@ class IsTruthyNode:
def INPUT_TYPES(cls):
return {
"required": {
"value": ("*",),
"a": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}),
"b": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}),
"operation": (["add", "subtract", "multiply", "divide", "modulo", "power"],),
},
}
RETURN_TYPES = ("INT",)
FUNCTION = "is_truthy"
FUNCTION = "int_math_operation"
CATEGORY = "InversionDemo Nodes"
def is_truthy(self, value):
if isinstance(value, torch.Tensor):
if value.max().item() == 0 and value.min().item() == 0:
return (0,)
else:
return (1,)
try:
return (int(bool(value)),)
except:
# Can't convert it? Well then it's something or other. I dunno, I'm not a Python programmer.
return (1,)
def int_math_operation(self, a, b, operation):
if operation == "add":
return (a + b,)
elif operation == "subtract":
return (a - b,)
elif operation == "multiply":
return (a * b,)
elif operation == "divide":
return (a // b,)
elif operation == "modulo":
return (a % b,)
elif operation == "power":
return (a ** b,)
from .flow_control import NUM_FLOW_SOCKETS
class ForLoopOpen:
@ -193,11 +197,11 @@ class ForLoopClose:
def INPUT_TYPES(cls):
return {
"required": {
"flow_control": ("FLOW_CONTROL", {"raw_link": True}),
"old_remaining": ("INT", {"default": 1, "min": 0, "max": 100000, "step": 1}),
"flow_control": ("FLOW_CONTROL", {"rawLink": True}),
"old_remaining": ("INT", {"default": 1, "min": 0, "max": 100000, "step": 1, "forceInput": True}),
},
"optional": {
"initial_value%d" % i: ("*",{"raw_link": True}) for i in range(1, NUM_FLOW_SOCKETS)
"initial_value%d" % i: ("*",{"rawLink": True}) for i in range(1, NUM_FLOW_SOCKETS)
},
}
@ -211,11 +215,12 @@ class ForLoopClose:
graph = GraphBuilder()
while_open = flow_control[0]
# TODO - Requires WAS-ns. Will definitely want to solve before merging
sub = graph.node("Number Operation", operation="subtraction", number_a=[while_open,1], number_b=1)
sub = graph.node("IntMathOperation", operation="subtract", a=[while_open,1], b=1)
cond = graph.node("ToBoolNode", value=sub.out(0))
input_values = {("initial_value%d" % i): kwargs.get("initial_value%d" % i, None) for i in range(1, NUM_FLOW_SOCKETS)}
while_close = graph.node("WhileLoopClose",
flow_control=flow_control,
condition=sub.out(0),
condition=cond.out(0),
initial_value0=sub.out(0),
**input_values)
return {
@ -230,9 +235,9 @@ UTILITY_NODE_CLASS_MAPPINGS = {
"AccumulationTailNode": AccumulationTailNode,
"AccumulationToListNode": AccumulationToListNode,
"ListToAccumulationNode": ListToAccumulationNode,
"IsTruthyNode": IsTruthyNode,
"ForLoopOpen": ForLoopOpen,
"ForLoopClose": ForLoopClose,
"IntMathOperation": IntMathOperation,
}
UTILITY_NODE_DISPLAY_NAME_MAPPINGS = {
"AccumulateNode": "Accumulate",
@ -240,7 +245,7 @@ UTILITY_NODE_DISPLAY_NAME_MAPPINGS = {
"AccumulationTailNode": "Accumulation Tail",
"AccumulationToListNode": "Accumulation to List",
"ListToAccumulationNode": "List to Accumulation",
"IsTruthyNode": "Is Truthy",
"ForLoopOpen": "For Loop Open",
"ForLoopClose": "For Loop Close",
"IntMathOperation": "Int Math Operation",
}

View File

@ -160,7 +160,7 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, dynpromp
for x in inputs:
input_data = inputs[x]
input_type, input_category, input_info = get_input_info(class_def, x)
if isinstance(input_data, list) and not input_info.get("raw_link", False):
if isinstance(input_data, list) and not input_info.get("rawLink", False):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs: