Merge branch 'master' into humo_i2v

This commit is contained in:
drozbay 2025-12-31 07:37:07 -07:00 committed by GitHub
commit 9ee905bc47
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 862 additions and 277 deletions

View File

@ -30,6 +30,13 @@ except ImportError as e:
raise e raise e
exit(-1) exit(-1)
SAGE_ATTENTION3_IS_AVAILABLE = False
try:
from sageattn3 import sageattn3_blackwell
SAGE_ATTENTION3_IS_AVAILABLE = True
except ImportError:
pass
FLASH_ATTENTION_IS_AVAILABLE = False FLASH_ATTENTION_IS_AVAILABLE = False
try: try:
from flash_attn import flash_attn_func from flash_attn import flash_attn_func
@ -563,6 +570,93 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
out = out.reshape(b, -1, heads * dim_head) out = out.reshape(b, -1, heads * dim_head)
return out return out
@wrap_attn
def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
exception_fallback = False
if (q.device.type != "cuda" or
q.dtype not in (torch.float16, torch.bfloat16) or
mask is not None):
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
**kwargs
)
if skip_reshape:
B, H, L, D = q.shape
if H != heads:
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=True,
skip_output_reshape=skip_output_reshape,
**kwargs
)
q_s, k_s, v_s = q, k, v
N = q.shape[2]
dim_head = D
else:
B, N, inner_dim = q.shape
if inner_dim % heads != 0:
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=False,
skip_output_reshape=skip_output_reshape,
**kwargs
)
dim_head = inner_dim // heads
if dim_head >= 256 or N <= 1024:
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
**kwargs
)
if not skip_reshape:
q_s, k_s, v_s = map(
lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
(q, k, v),
)
B, H, L, D = q_s.shape
try:
out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False)
except Exception as e:
exception_fallback = True
logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e)
if exception_fallback:
if not skip_reshape:
del q_s, k_s, v_s
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=False,
skip_output_reshape=skip_output_reshape,
**kwargs
)
if skip_reshape:
if not skip_output_reshape:
out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
else:
if skip_output_reshape:
pass
else:
out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
return out
try: try:
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=()) @torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
@ -650,6 +744,8 @@ optimized_attention_masked = optimized_attention
# register core-supported attention functions # register core-supported attention functions
if SAGE_ATTENTION_IS_AVAILABLE: if SAGE_ATTENTION_IS_AVAILABLE:
register_attention_function("sage", attention_sage) register_attention_function("sage", attention_sage)
if SAGE_ATTENTION3_IS_AVAILABLE:
register_attention_function("sage3", attention3_sage)
if FLASH_ATTENTION_IS_AVAILABLE: if FLASH_ATTENTION_IS_AVAILABLE:
register_attention_function("flash", attention_flash) register_attention_function("flash", attention_flash)
if model_management.xformers_enabled(): if model_management.xformers_enabled():

View File

@ -10,7 +10,6 @@ from ._input_impl import VideoFromFile, VideoFromComponents
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
from . import _io_public as io from . import _io_public as io
from . import _ui_public as ui from . import _ui_public as ui
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
from comfy_execution.utils import get_executing_context from comfy_execution.utils import get_executing_context
from comfy_execution.progress import get_progress_state, PreviewImageTuple from comfy_execution.progress import get_progress_state, PreviewImageTuple
from PIL import Image from PIL import Image

View File

@ -26,7 +26,6 @@ if TYPE_CHECKING:
from comfy_api.input import VideoInput from comfy_api.input import VideoInput
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class, from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
prune_dict, shallow_clone_class) prune_dict, shallow_clone_class)
from ._resources import Resources, ResourcesLocal
from comfy_execution.graph_utils import ExecutionBlocker from comfy_execution.graph_utils import ExecutionBlocker
from ._util import MESH, VOXEL, SVG as _SVG from ._util import MESH, VOXEL, SVG as _SVG
@ -76,16 +75,6 @@ class NumberDisplay(str, Enum):
slider = "slider" slider = "slider"
class _StringIOType(str):
def __ne__(self, value: object) -> bool:
if self == "*" or value == "*":
return False
if not isinstance(value, str):
return True
a = frozenset(self.split(","))
b = frozenset(value.split(","))
return not (b.issubset(a) or a.issubset(b))
class _ComfyType(ABC): class _ComfyType(ABC):
Type = Any Type = Any
io_type: str = None io_type: str = None
@ -125,8 +114,7 @@ def comfytype(io_type: str, **kwargs):
new_cls.__module__ = cls.__module__ new_cls.__module__ = cls.__module__
new_cls.__doc__ = cls.__doc__ new_cls.__doc__ = cls.__doc__
# assign ComfyType attributes, if needed # assign ComfyType attributes, if needed
# NOTE: use __ne__ trick for io_type (see node_typing.IO.__ne__ for details) new_cls.io_type = io_type
new_cls.io_type = _StringIOType(io_type)
if hasattr(new_cls, "Input") and new_cls.Input is not None: if hasattr(new_cls, "Input") and new_cls.Input is not None:
new_cls.Input.Parent = new_cls new_cls.Input.Parent = new_cls
if hasattr(new_cls, "Output") and new_cls.Output is not None: if hasattr(new_cls, "Output") and new_cls.Output is not None:
@ -165,7 +153,7 @@ class Input(_IO_V3):
''' '''
Base class for a V3 Input. Base class for a V3 Input.
''' '''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None):
super().__init__() super().__init__()
self.id = id self.id = id
self.display_name = display_name self.display_name = display_name
@ -173,6 +161,7 @@ class Input(_IO_V3):
self.tooltip = tooltip self.tooltip = tooltip
self.lazy = lazy self.lazy = lazy
self.extra_dict = extra_dict if extra_dict is not None else {} self.extra_dict = extra_dict if extra_dict is not None else {}
self.rawLink = raw_link
def as_dict(self): def as_dict(self):
return prune_dict({ return prune_dict({
@ -180,10 +169,11 @@ class Input(_IO_V3):
"optional": self.optional, "optional": self.optional,
"tooltip": self.tooltip, "tooltip": self.tooltip,
"lazy": self.lazy, "lazy": self.lazy,
"rawLink": self.rawLink,
}) | prune_dict(self.extra_dict) }) | prune_dict(self.extra_dict)
def get_io_type(self): def get_io_type(self):
return _StringIOType(self.io_type) return self.io_type
def get_all(self) -> list[Input]: def get_all(self) -> list[Input]:
return [self] return [self]
@ -194,8 +184,8 @@ class WidgetInput(Input):
''' '''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: Any=None, default: Any=None,
socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None): socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link)
self.default = default self.default = default
self.socketless = socketless self.socketless = socketless
self.widget_type = widget_type self.widget_type = widget_type
@ -217,13 +207,14 @@ class Output(_IO_V3):
def __init__(self, id: str=None, display_name: str=None, tooltip: str=None, def __init__(self, id: str=None, display_name: str=None, tooltip: str=None,
is_output_list=False): is_output_list=False):
self.id = id self.id = id
self.display_name = display_name self.display_name = display_name if display_name else id
self.tooltip = tooltip self.tooltip = tooltip
self.is_output_list = is_output_list self.is_output_list = is_output_list
def as_dict(self): def as_dict(self):
display_name = self.display_name if self.display_name else self.id
return prune_dict({ return prune_dict({
"display_name": self.display_name, "display_name": display_name,
"tooltip": self.tooltip, "tooltip": self.tooltip,
"is_output_list": self.is_output_list, "is_output_list": self.is_output_list,
}) })
@ -251,8 +242,8 @@ class Boolean(ComfyTypeIO):
'''Boolean input.''' '''Boolean input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: bool=None, label_on: str=None, label_off: str=None, default: bool=None, label_on: str=None, label_off: str=None,
socketless: bool=None, force_input: bool=None): socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
self.label_on = label_on self.label_on = label_on
self.label_off = label_off self.label_off = label_off
self.default: bool self.default: bool
@ -271,8 +262,8 @@ class Int(ComfyTypeIO):
'''Integer input.''' '''Integer input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None, default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None): display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
self.min = min self.min = min
self.max = max self.max = max
self.step = step self.step = step
@ -297,8 +288,8 @@ class Float(ComfyTypeIO):
'''Float input.''' '''Float input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None, default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None): display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
self.min = min self.min = min
self.max = max self.max = max
self.step = step self.step = step
@ -323,8 +314,8 @@ class String(ComfyTypeIO):
'''String input.''' '''String input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None, multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None,
socketless: bool=None, force_input: bool=None): socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
self.multiline = multiline self.multiline = multiline
self.placeholder = placeholder self.placeholder = placeholder
self.dynamic_prompts = dynamic_prompts self.dynamic_prompts = dynamic_prompts
@ -357,12 +348,14 @@ class Combo(ComfyTypeIO):
image_folder: FolderType=None, image_folder: FolderType=None,
remote: RemoteOptions=None, remote: RemoteOptions=None,
socketless: bool=None, socketless: bool=None,
extra_dict=None,
raw_link: bool=None,
): ):
if isinstance(options, type) and issubclass(options, Enum): if isinstance(options, type) and issubclass(options, Enum):
options = [v.value for v in options] options = [v.value for v in options]
if isinstance(default, Enum): if isinstance(default, Enum):
default = default.value default = default.value
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless) super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link)
self.multiselect = False self.multiselect = False
self.options = options self.options = options
self.control_after_generate = control_after_generate self.control_after_generate = control_after_generate
@ -386,10 +379,6 @@ class Combo(ComfyTypeIO):
super().__init__(id, display_name, tooltip, is_output_list) super().__init__(id, display_name, tooltip, is_output_list)
self.options = options if options is not None else [] self.options = options if options is not None else []
@property
def io_type(self):
return self.options
@comfytype(io_type="COMBO") @comfytype(io_type="COMBO")
class MultiCombo(ComfyTypeI): class MultiCombo(ComfyTypeI):
'''Multiselect Combo input (dropdown for selecting potentially more than one value).''' '''Multiselect Combo input (dropdown for selecting potentially more than one value).'''
@ -398,8 +387,8 @@ class MultiCombo(ComfyTypeI):
class Input(Combo.Input): class Input(Combo.Input):
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None, default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
socketless: bool=None): socketless: bool=None, extra_dict=None, raw_link: bool=None):
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless) super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link)
self.multiselect = True self.multiselect = True
self.placeholder = placeholder self.placeholder = placeholder
self.chip = chip self.chip = chip
@ -432,9 +421,9 @@ class Webcam(ComfyTypeIO):
Type = str Type = str
def __init__( def __init__(
self, id: str, display_name: str=None, optional=False, self, id: str, display_name: str=None, optional=False,
tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None, extra_dict=None, raw_link: bool=None
): ):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless) super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link)
@comfytype(io_type="MASK") @comfytype(io_type="MASK")
@ -787,7 +776,7 @@ class MultiType:
''' '''
Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values. Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values.
''' '''
def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None):
# if id is an Input, then use that Input with overridden values # if id is an Input, then use that Input with overridden values
self.input_override = None self.input_override = None
if isinstance(id, Input): if isinstance(id, Input):
@ -800,7 +789,7 @@ class MultiType:
# if is a widget input, make sure widget_type is set appropriately # if is a widget input, make sure widget_type is set appropriately
if isinstance(self.input_override, WidgetInput): if isinstance(self.input_override, WidgetInput):
self.input_override.widget_type = self.input_override.get_io_type() self.input_override.widget_type = self.input_override.get_io_type()
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link)
self._io_types = types self._io_types = types
@property @property
@ -854,8 +843,8 @@ class MatchType(ComfyTypeIO):
class Input(Input): class Input(Input):
def __init__(self, id: str, template: MatchType.Template, def __init__(self, id: str, template: MatchType.Template,
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link)
self.template = template self.template = template
def as_dict(self): def as_dict(self):
@ -866,6 +855,8 @@ class MatchType(ComfyTypeIO):
class Output(Output): class Output(Output):
def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None, def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None,
is_output_list=False): is_output_list=False):
if not id and not display_name:
display_name = "MATCHTYPE"
super().__init__(id, display_name, tooltip, is_output_list) super().__init__(id, display_name, tooltip, is_output_list)
self.template = template self.template = template
@ -878,24 +869,30 @@ class DynamicInput(Input, ABC):
''' '''
Abstract class for dynamic input registration. Abstract class for dynamic input registration.
''' '''
def get_dynamic(self) -> list[Input]: pass
return []
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
pass
class DynamicOutput(Output, ABC): class DynamicOutput(Output, ABC):
''' '''
Abstract class for dynamic output registration. Abstract class for dynamic output registration.
''' '''
def __init__(self, id: str=None, display_name: str=None, tooltip: str=None, pass
is_output_list=False):
super().__init__(id, display_name, tooltip, is_output_list)
def get_dynamic(self) -> list[Output]:
return []
def handle_prefix(prefix_list: list[str] | None, id: str | None = None) -> list[str]:
if prefix_list is None:
prefix_list = []
if id is not None:
prefix_list = prefix_list + [id]
return prefix_list
def finalize_prefix(prefix_list: list[str] | None, id: str | None = None) -> str:
assert not (prefix_list is None and id is None)
if prefix_list is None:
return id
elif id is not None:
prefix_list = prefix_list + [id]
return ".".join(prefix_list)
@comfytype(io_type="COMFY_AUTOGROW_V3") @comfytype(io_type="COMFY_AUTOGROW_V3")
class Autogrow(ComfyTypeI): class Autogrow(ComfyTypeI):
@ -932,14 +929,6 @@ class Autogrow(ComfyTypeI):
def validate(self): def validate(self):
self.input.validate() self.input.validate()
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
real_inputs = []
for name, input in self.cached_inputs.items():
if name in live_inputs:
real_inputs.append(input)
add_to_input_dict_v1(d, real_inputs, live_inputs, curr_prefix)
add_dynamic_id_mapping(d, real_inputs, curr_prefix)
class TemplatePrefix(_AutogrowTemplate): class TemplatePrefix(_AutogrowTemplate):
def __init__(self, input: Input, prefix: str, min: int=1, max: int=10): def __init__(self, input: Input, prefix: str, min: int=1, max: int=10):
super().__init__(input) super().__init__(input)
@ -984,22 +973,45 @@ class Autogrow(ComfyTypeI):
"template": self.template.as_dict(), "template": self.template.as_dict(),
}) })
def get_dynamic(self) -> list[Input]:
return self.template.get_all()
def get_all(self) -> list[Input]: def get_all(self) -> list[Input]:
return [self] + self.template.get_all() return [self] + self.template.get_all()
def validate(self): def validate(self):
self.template.validate() self.template.validate()
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): @staticmethod
curr_prefix = f"{curr_prefix}{self.id}." def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None):
# need to remove self from expected inputs dictionary; replaced by template inputs in frontend # NOTE: purposely do not include self in out_dict; instead use only the template inputs
for inner_dict in d.values(): # need to figure out names based on template type
if self.id in inner_dict: is_names = ("names" in value[1]["template"])
del inner_dict[self.id] is_prefix = ("prefix" in value[1]["template"])
self.template.expand_schema_for_dynamic(d, live_inputs, curr_prefix) input = value[1]["template"]["input"]
if is_names:
min = value[1]["template"]["min"]
names = value[1]["template"]["names"]
max = len(names)
elif is_prefix:
prefix = value[1]["template"]["prefix"]
min = value[1]["template"]["min"]
max = value[1]["template"]["max"]
names = [f"{prefix}{i}" for i in range(max)]
# need to create a new input based on the contents of input
template_input = None
for _, dict_input in input.items():
# for now, get just the first value from dict_input
template_input = list(dict_input.values())[0]
new_dict = {}
for i, name in enumerate(names):
expected_id = finalize_prefix(curr_prefix, name)
if expected_id in live_inputs:
# required
if i < min:
type_dict = new_dict.setdefault("required", {})
# optional
else:
type_dict = new_dict.setdefault("optional", {})
type_dict[name] = template_input
parse_class_inputs(out_dict, live_inputs, new_dict, curr_prefix)
@comfytype(io_type="COMFY_DYNAMICCOMBO_V3") @comfytype(io_type="COMFY_DYNAMICCOMBO_V3")
class DynamicCombo(ComfyTypeI): class DynamicCombo(ComfyTypeI):
@ -1022,23 +1034,6 @@ class DynamicCombo(ComfyTypeI):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
self.options = options self.options = options
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
# check if dynamic input's id is in live_inputs
if self.id in live_inputs:
curr_prefix = f"{curr_prefix}{self.id}."
key = live_inputs[self.id]
selected_option = None
for option in self.options:
if option.key == key:
selected_option = option
break
if selected_option is not None:
add_to_input_dict_v1(d, selected_option.inputs, live_inputs, curr_prefix)
add_dynamic_id_mapping(d, selected_option.inputs, curr_prefix, self)
def get_dynamic(self) -> list[Input]:
return [input for option in self.options for input in option.inputs]
def get_all(self) -> list[Input]: def get_all(self) -> list[Input]:
return [self] + [input for option in self.options for input in option.inputs] return [self] + [input for option in self.options for input in option.inputs]
@ -1053,6 +1048,24 @@ class DynamicCombo(ComfyTypeI):
for input in option.inputs: for input in option.inputs:
input.validate() input.validate()
@staticmethod
def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None):
finalized_id = finalize_prefix(curr_prefix)
if finalized_id in live_inputs:
key = live_inputs[finalized_id]
selected_option = None
# get options from dict
options: list[dict[str, str | dict[str, Any]]] = value[1]["options"]
for option in options:
if option["key"] == key:
selected_option = option
break
if selected_option is not None:
parse_class_inputs(out_dict, live_inputs, selected_option["inputs"], curr_prefix)
# add self to inputs
out_dict[input_type][finalized_id] = value
out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1])
@comfytype(io_type="COMFY_DYNAMICSLOT_V3") @comfytype(io_type="COMFY_DYNAMICSLOT_V3")
class DynamicSlot(ComfyTypeI): class DynamicSlot(ComfyTypeI):
Type = dict[str, Any] Type = dict[str, Any]
@ -1075,17 +1088,8 @@ class DynamicSlot(ComfyTypeI):
self.force_input = True self.force_input = True
self.slot.force_input = True self.slot.force_input = True
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
if self.id in live_inputs:
curr_prefix = f"{curr_prefix}{self.id}."
add_to_input_dict_v1(d, self.inputs, live_inputs, curr_prefix)
add_dynamic_id_mapping(d, [self.slot] + self.inputs, curr_prefix)
def get_dynamic(self) -> list[Input]:
return [self.slot] + self.inputs
def get_all(self) -> list[Input]: def get_all(self) -> list[Input]:
return [self] + [self.slot] + self.inputs return [self.slot] + self.inputs
def as_dict(self): def as_dict(self):
return super().as_dict() | prune_dict({ return super().as_dict() | prune_dict({
@ -1099,17 +1103,41 @@ class DynamicSlot(ComfyTypeI):
for input in self.inputs: for input in self.inputs:
input.validate() input.validate()
def add_dynamic_id_mapping(d: dict[str, Any], inputs: list[Input], curr_prefix: str, self: DynamicInput=None): @staticmethod
dynamic = d.setdefault("dynamic_paths", {}) def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None):
if self is not None: finalized_id = finalize_prefix(curr_prefix)
dynamic[self.id] = f"{curr_prefix}{self.id}" if finalized_id in live_inputs:
for i in inputs: inputs = value[1]["inputs"]
if not isinstance(i, DynamicInput): parse_class_inputs(out_dict, live_inputs, inputs, curr_prefix)
dynamic[f"{i.id}"] = f"{curr_prefix}{i.id}" # add self to inputs
out_dict[input_type][finalized_id] = value
out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1])
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
DYNAMIC_INPUT_LOOKUP[io_type] = func
def get_dynamic_input_func(io_type: str) -> Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]:
return DYNAMIC_INPUT_LOOKUP[io_type]
def setup_dynamic_input_funcs():
# Autogrow.Input
register_dynamic_input_func(Autogrow.io_type, Autogrow._expand_schema_for_dynamic)
# DynamicCombo.Input
register_dynamic_input_func(DynamicCombo.io_type, DynamicCombo._expand_schema_for_dynamic)
# DynamicSlot.Input
register_dynamic_input_func(DynamicSlot.io_type, DynamicSlot._expand_schema_for_dynamic)
if len(DYNAMIC_INPUT_LOOKUP) == 0:
setup_dynamic_input_funcs()
class V3Data(TypedDict): class V3Data(TypedDict):
hidden_inputs: dict[str, Any] hidden_inputs: dict[str, Any]
'Dictionary where the keys are the hidden input ids and the values are the values of the hidden inputs.'
dynamic_paths: dict[str, Any] dynamic_paths: dict[str, Any]
'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.'
create_dynamic_tuple: bool
'When True, the value of the dynamic input will be in the format (value, path_key).'
class HiddenHolder: class HiddenHolder:
def __init__(self, unique_id: str, prompt: Any, def __init__(self, unique_id: str, prompt: Any,
@ -1145,6 +1173,10 @@ class HiddenHolder:
api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None), api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None),
) )
@classmethod
def from_v3_data(cls, v3_data: V3Data | None) -> HiddenHolder:
return cls.from_dict(v3_data["hidden_inputs"] if v3_data else None)
class Hidden(str, Enum): class Hidden(str, Enum):
''' '''
Enumerator for requesting hidden variables in nodes. Enumerator for requesting hidden variables in nodes.
@ -1250,61 +1282,56 @@ class Schema:
- verify ids on inputs and outputs are unique - both internally and in relation to each other - verify ids on inputs and outputs are unique - both internally and in relation to each other
''' '''
nested_inputs: list[Input] = [] nested_inputs: list[Input] = []
if self.inputs is not None: for input in self.inputs:
for input in self.inputs: if not isinstance(input, DynamicInput):
nested_inputs.extend(input.get_all()) nested_inputs.extend(input.get_all())
input_ids = [i.id for i in nested_inputs] if nested_inputs is not None else [] input_ids = [i.id for i in nested_inputs]
output_ids = [o.id for o in self.outputs] if self.outputs is not None else [] output_ids = [o.id for o in self.outputs]
input_set = set(input_ids) input_set = set(input_ids)
output_set = set(output_ids) output_set = set(output_ids)
issues = [] issues: list[str] = []
# verify ids are unique per list # verify ids are unique per list
if len(input_set) != len(input_ids): if len(input_set) != len(input_ids):
issues.append(f"Input ids must be unique, but {[item for item, count in Counter(input_ids).items() if count > 1]} are not.") issues.append(f"Input ids must be unique, but {[item for item, count in Counter(input_ids).items() if count > 1]} are not.")
if len(output_set) != len(output_ids): if len(output_set) != len(output_ids):
issues.append(f"Output ids must be unique, but {[item for item, count in Counter(output_ids).items() if count > 1]} are not.") issues.append(f"Output ids must be unique, but {[item for item, count in Counter(output_ids).items() if count > 1]} are not.")
# verify ids are unique between lists
intersection = input_set & output_set
if len(intersection) > 0:
issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.")
if len(issues) > 0: if len(issues) > 0:
raise ValueError("\n".join(issues)) raise ValueError("\n".join(issues))
# validate inputs and outputs # validate inputs and outputs
if self.inputs is not None: for input in self.inputs:
for input in self.inputs: input.validate()
input.validate() for output in self.outputs:
if self.outputs is not None: output.validate()
for output in self.outputs:
output.validate()
def finalize(self): def finalize(self):
"""Add hidden based on selected schema options, and give outputs without ids default ids.""" """Add hidden based on selected schema options, and give outputs without ids default ids."""
# ensure inputs, outputs, and hidden are lists
if self.inputs is None:
self.inputs = []
if self.outputs is None:
self.outputs = []
if self.hidden is None:
self.hidden = []
# if is an api_node, will need key-related hidden # if is an api_node, will need key-related hidden
if self.is_api_node: if self.is_api_node:
if self.hidden is None:
self.hidden = []
if Hidden.auth_token_comfy_org not in self.hidden: if Hidden.auth_token_comfy_org not in self.hidden:
self.hidden.append(Hidden.auth_token_comfy_org) self.hidden.append(Hidden.auth_token_comfy_org)
if Hidden.api_key_comfy_org not in self.hidden: if Hidden.api_key_comfy_org not in self.hidden:
self.hidden.append(Hidden.api_key_comfy_org) self.hidden.append(Hidden.api_key_comfy_org)
# if is an output_node, will need prompt and extra_pnginfo # if is an output_node, will need prompt and extra_pnginfo
if self.is_output_node: if self.is_output_node:
if self.hidden is None:
self.hidden = []
if Hidden.prompt not in self.hidden: if Hidden.prompt not in self.hidden:
self.hidden.append(Hidden.prompt) self.hidden.append(Hidden.prompt)
if Hidden.extra_pnginfo not in self.hidden: if Hidden.extra_pnginfo not in self.hidden:
self.hidden.append(Hidden.extra_pnginfo) self.hidden.append(Hidden.extra_pnginfo)
# give outputs without ids default ids # give outputs without ids default ids
if self.outputs is not None: for i, output in enumerate(self.outputs):
for i, output in enumerate(self.outputs): if output.id is None:
if output.id is None: output.id = f"_{i}_{output.io_type}_"
output.id = f"_{i}_{output.io_type}_"
def get_v1_info(self, cls, live_inputs: dict[str, Any]=None) -> NodeInfoV1: def get_v1_info(self, cls) -> NodeInfoV1:
# NOTE: live_inputs will not be used anymore very soon and this will be done another way
# get V1 inputs # get V1 inputs
input = create_input_dict_v1(self.inputs, live_inputs) input = create_input_dict_v1(self.inputs)
if self.hidden: if self.hidden:
for hidden in self.hidden: for hidden in self.hidden:
input.setdefault("hidden", {})[hidden.name] = (hidden.value,) input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
@ -1384,33 +1411,54 @@ class Schema:
) )
return info return info
def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], include_hidden=False) -> tuple[dict[str, Any], V3Data]:
out_dict = {
"required": {},
"optional": {},
"dynamic_paths": {},
}
d = d.copy()
# ignore hidden for parsing
hidden = d.pop("hidden", None)
parse_class_inputs(out_dict, live_inputs, d)
if hidden is not None and include_hidden:
out_dict["hidden"] = hidden
v3_data = {}
dynamic_paths = out_dict.pop("dynamic_paths", None)
if dynamic_paths is not None:
v3_data["dynamic_paths"] = dynamic_paths
return out_dict, hidden, v3_data
def create_input_dict_v1(inputs: list[Input], live_inputs: dict[str, Any]=None) -> dict: def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None:
for input_type, inner_d in curr_dict.items():
for id, value in inner_d.items():
io_type = value[0]
if io_type in DYNAMIC_INPUT_LOOKUP:
# dynamic inputs need to be handled with lookup functions
dynamic_input_func = get_dynamic_input_func(io_type)
new_prefix = handle_prefix(curr_prefix, id)
dynamic_input_func(out_dict, live_inputs, value, input_type, new_prefix)
else:
# non-dynamic inputs get directly transferred
finalized_id = finalize_prefix(curr_prefix, id)
out_dict[input_type][finalized_id] = value
if curr_prefix:
out_dict["dynamic_paths"][finalized_id] = finalized_id
def create_input_dict_v1(inputs: list[Input]) -> dict:
input = { input = {
"required": {} "required": {}
} }
add_to_input_dict_v1(input, inputs, live_inputs) for i in inputs:
add_to_dict_v1(i, input)
return input return input
def add_to_input_dict_v1(d: dict[str, Any], inputs: list[Input], live_inputs: dict[str, Any]=None, curr_prefix=''): def add_to_dict_v1(i: Input, d: dict):
for i in inputs:
if isinstance(i, DynamicInput):
add_to_dict_v1(i, d)
if live_inputs is not None:
i.expand_schema_for_dynamic(d, live_inputs, curr_prefix)
else:
add_to_dict_v1(i, d)
def add_to_dict_v1(i: Input, d: dict, dynamic_dict: dict=None):
key = "optional" if i.optional else "required" key = "optional" if i.optional else "required"
as_dict = i.as_dict() as_dict = i.as_dict()
# for v1, we don't want to include the optional key # for v1, we don't want to include the optional key
as_dict.pop("optional", None) as_dict.pop("optional", None)
if dynamic_dict is None: d.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict)
value = (i.get_io_type(), as_dict)
else:
value = (i.get_io_type(), as_dict, dynamic_dict)
d.setdefault(key, {})[i.id] = value
def add_to_dict_v3(io: Input | Output, d: dict): def add_to_dict_v3(io: Input | Output, d: dict):
d[io.id] = (io.get_io_type(), io.as_dict()) d[io.id] = (io.get_io_type(), io.as_dict())
@ -1422,6 +1470,8 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
values = values.copy() values = values.copy()
result = {} result = {}
create_tuple = v3_data.get("create_dynamic_tuple", False)
for key, path in paths.items(): for key, path in paths.items():
parts = path.split(".") parts = path.split(".")
current = result current = result
@ -1430,7 +1480,10 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
is_last = (i == len(parts) - 1) is_last = (i == len(parts) - 1)
if is_last: if is_last:
current[p] = values.pop(key, None) value = values.pop(key, None)
if create_tuple:
value = (value, key)
current[p] = value
else: else:
current = current.setdefault(p, {}) current = current.setdefault(p, {})
@ -1445,7 +1498,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
SCHEMA = None SCHEMA = None
# filled in during execution # filled in during execution
resources: Resources = None
hidden: HiddenHolder = None hidden: HiddenHolder = None
@classmethod @classmethod
@ -1492,7 +1544,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
return [name for name in kwargs if kwargs[name] is None] return [name for name in kwargs if kwargs[name] is None]
def __init__(self): def __init__(self):
self.local_resources: ResourcesLocal = None
self.__class__.VALIDATE_CLASS() self.__class__.VALIDATE_CLASS()
@classmethod @classmethod
@ -1560,7 +1611,7 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
c_type: type[ComfyNode] = cls if is_class(cls) else type(cls) c_type: type[ComfyNode] = cls if is_class(cls) else type(cls)
type_clone: type[ComfyNode] = shallow_clone_class(c_type) type_clone: type[ComfyNode] = shallow_clone_class(c_type)
# set hidden # set hidden
type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"] if v3_data else None) type_clone.hidden = HiddenHolder.from_v3_data(v3_data)
return type_clone return type_clone
@final @final
@ -1677,19 +1728,10 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
@final @final
@classmethod @classmethod
def INPUT_TYPES(cls, include_hidden=True, return_schema=False, live_inputs=None) -> dict[str, dict] | tuple[dict[str, dict], Schema, V3Data]: def INPUT_TYPES(cls) -> dict[str, dict]:
schema = cls.FINALIZE_SCHEMA() schema = cls.FINALIZE_SCHEMA()
info = schema.get_v1_info(cls, live_inputs) info = schema.get_v1_info(cls)
input = info.input return info.input
if not include_hidden:
input.pop("hidden", None)
if return_schema:
v3_data: V3Data = {}
dynamic = input.pop("dynamic_paths", None)
if dynamic is not None:
v3_data["dynamic_paths"] = dynamic
return input, schema, v3_data
return input
@final @final
@classmethod @classmethod
@ -1808,7 +1850,7 @@ class NodeOutput(_NodeOutputInternal):
return self.args if len(self.args) > 0 else None return self.args if len(self.args) > 0 else None
@classmethod @classmethod
def from_dict(cls, data: dict[str, Any]) -> "NodeOutput": def from_dict(cls, data: dict[str, Any]) -> NodeOutput:
args = () args = ()
ui = None ui = None
expand = None expand = None
@ -1903,8 +1945,8 @@ __all__ = [
"Tracks", "Tracks",
# Dynamic Types # Dynamic Types
"MatchType", "MatchType",
# "DynamicCombo", "DynamicCombo",
# "Autogrow", "Autogrow",
# Other classes # Other classes
"HiddenHolder", "HiddenHolder",
"Hidden", "Hidden",

View File

@ -1,72 +0,0 @@
from __future__ import annotations
import comfy.utils
import folder_paths
import logging
from abc import ABC, abstractmethod
from typing import Any
import torch
class ResourceKey(ABC):
Type = Any
def __init__(self):
...
class TorchDictFolderFilename(ResourceKey):
'''Key for requesting a torch file via file_name from a folder category.'''
Type = dict[str, torch.Tensor]
def __init__(self, folder_name: str, file_name: str):
self.folder_name = folder_name
self.file_name = file_name
def __hash__(self):
return hash((self.folder_name, self.file_name))
def __eq__(self, other: object) -> bool:
if not isinstance(other, TorchDictFolderFilename):
return False
return self.folder_name == other.folder_name and self.file_name == other.file_name
def __str__(self):
return f"{self.folder_name} -> {self.file_name}"
class Resources(ABC):
def __init__(self):
...
@abstractmethod
def get(self, key: ResourceKey, default: Any=...) -> Any:
pass
class ResourcesLocal(Resources):
def __init__(self):
super().__init__()
self.local_resources: dict[ResourceKey, Any] = {}
def get(self, key: ResourceKey, default: Any=...) -> Any:
cached = self.local_resources.get(key, None)
if cached is not None:
logging.info(f"Using cached resource '{key}'")
return cached
logging.info(f"Loading resource '{key}'")
to_return = None
if isinstance(key, TorchDictFolderFilename):
if default is ...:
to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True)
else:
full_path = folder_paths.get_full_path(key.folder_name, key.file_name)
if full_path is not None:
to_return = comfy.utils.load_torch_file(full_path, safe_load=True)
if to_return is not None:
self.local_resources[key] = to_return
return to_return
if default is not ...:
return default
raise Exception(f"Unsupported resource key type: {type(key)}")
class _RESOURCES:
ResourceKey = ResourceKey
TorchDictFolderFilename = TorchDictFolderFilename
Resources = Resources
ResourcesLocal = ResourcesLocal

View File

@ -1,16 +1,22 @@
import asyncio import asyncio
import contextlib import contextlib
import os import os
import re
import time import time
from collections.abc import Callable from collections.abc import Callable
from io import BytesIO from io import BytesIO
from yarl import URL
from comfy.cli_args import args from comfy.cli_args import args
from comfy.model_management import processing_interrupted from comfy.model_management import processing_interrupted
from comfy_api.latest import IO from comfy_api.latest import IO
from .common_exceptions import ProcessingInterrupted from .common_exceptions import ProcessingInterrupted
_HAS_PCT_ESC = re.compile(r"%[0-9A-Fa-f]{2}") # any % followed by 2 hex digits
_HAS_BAD_PCT = re.compile(r"%(?![0-9A-Fa-f]{2})") # any % not followed by 2 hex digits
def is_processing_interrupted() -> bool: def is_processing_interrupted() -> bool:
"""Return True if user/runtime requested interruption.""" """Return True if user/runtime requested interruption."""
@ -69,3 +75,17 @@ def get_fs_object_size(path_or_object: str | BytesIO) -> int:
if isinstance(path_or_object, str): if isinstance(path_or_object, str):
return os.path.getsize(path_or_object) return os.path.getsize(path_or_object)
return len(path_or_object.getvalue()) return len(path_or_object.getvalue())
def to_aiohttp_url(url: str) -> URL:
"""If `url` appears to be already percent-encoded (contains at least one valid %HH
escape and no malformed '%' sequences) and contains no raw whitespace/control
characters preserve the original encoding byte-for-byte (important for signed/presigned URLs).
Otherwise, return `URL(url)` and allow yarl to normalize/quote as needed."""
if any(c.isspace() for c in url) or any(ord(c) < 0x20 for c in url):
# Avoid encoded=True if URL contains raw whitespace/control chars
return URL(url)
if _HAS_PCT_ESC.search(url) and not _HAS_BAD_PCT.search(url):
# Preserve encoding only if it appears pre-encoded AND has no invalid % sequences
return URL(url, encoded=True)
return URL(url)

View File

@ -19,6 +19,7 @@ from ._helpers import (
get_auth_header, get_auth_header,
is_processing_interrupted, is_processing_interrupted,
sleep_with_interrupt, sleep_with_interrupt,
to_aiohttp_url,
) )
from .client import _diagnose_connectivity from .client import _diagnose_connectivity
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
@ -94,7 +95,7 @@ async def download_url_to_bytesio(
monitor_task = asyncio.create_task(_monitor()) monitor_task = asyncio.create_task(_monitor())
req_task = asyncio.create_task(session.get(url, headers=headers)) req_task = asyncio.create_task(session.get(to_aiohttp_url(url), headers=headers))
done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED) done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)
if monitor_task in done and req_task in pending: if monitor_task in done and req_task in pending:

View File

@ -97,6 +97,11 @@ def get_input_info(
extra_info = input_info[1] extra_info = input_info[1]
else: else:
extra_info = {} extra_info = {}
# if input_type is a list, it is a Combo defined in outdated format; convert it.
# NOTE: uncomment this when we are confident old format going away won't cause too much trouble.
# if isinstance(input_type, list):
# extra_info["options"] = input_type
# input_type = IO.Combo.io_type
return input_type, input_category, extra_info return input_type, input_category, extra_info
class TopologicalSort: class TopologicalSort:

View File

@ -21,14 +21,24 @@ def validate_node_input(
""" """
# If the types are exactly the same, we can return immediately # If the types are exactly the same, we can return immediately
# Use pre-union behaviour: inverse of `__ne__` # Use pre-union behaviour: inverse of `__ne__`
# NOTE: this lets legacy '*' Any types work that override the __ne__ method of the str class.
if not received_type != input_type: if not received_type != input_type:
return True return True
# If one of the types is '*', we can return True immediately; this is the 'Any' type.
if received_type == IO.AnyType.io_type or input_type == IO.AnyType.io_type:
return True
# If the received type or input_type is a MatchType, we can return True immediately; # If the received type or input_type is a MatchType, we can return True immediately;
# validation for this is handled by the frontend # validation for this is handled by the frontend
if received_type == IO.MatchType.io_type or input_type == IO.MatchType.io_type: if received_type == IO.MatchType.io_type or input_type == IO.MatchType.io_type:
return True return True
# This accounts for some custom nodes that output lists of options as the type;
# if we ever want to break them on purpose, this can be removed
if isinstance(received_type, list) and input_type == IO.Combo.io_type:
return True
# Not equal, and not strings # Not equal, and not strings
if not isinstance(received_type, str) or not isinstance(input_type, str): if not isinstance(received_type, str) or not isinstance(input_type, str):
return False return False
@ -37,6 +47,10 @@ def validate_node_input(
received_types = set(t.strip() for t in received_type.split(",")) received_types = set(t.strip() for t in received_type.split(","))
input_types = set(t.strip() for t in input_type.split(",")) input_types = set(t.strip() for t in input_type.split(","))
# If any of the types is '*', we can return True immediately; this is the 'Any' type.
if IO.AnyType.io_type in received_types or IO.AnyType.io_type in input_types:
return True
if strict: if strict:
# In strict mode, all received types must be in the input types # In strict mode, all received types must be in the input types
return received_types.issubset(input_types) return received_types.issubset(input_types)

View File

@ -255,6 +255,7 @@ class LatentBatch(io.ComfyNode):
return io.Schema( return io.Schema(
node_id="LatentBatch", node_id="LatentBatch",
category="latent/batch", category="latent/batch",
is_deprecated=True,
inputs=[ inputs=[
io.Latent.Input("samples1"), io.Latent.Input("samples1"),
io.Latent.Input("samples2"), io.Latent.Input("samples2"),

View File

@ -1,8 +1,11 @@
from __future__ import annotations
from typing import TypedDict from typing import TypedDict
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import ComfyExtension, io from comfy_api.latest import ComfyExtension, io
from comfy_api.latest import _io from comfy_api.latest import _io
# sentinel for missing inputs
MISSING = object()
class SwitchNode(io.ComfyNode): class SwitchNode(io.ComfyNode):
@ -14,6 +17,37 @@ class SwitchNode(io.ComfyNode):
display_name="Switch", display_name="Switch",
category="logic", category="logic",
is_experimental=True, is_experimental=True,
inputs=[
io.Boolean.Input("switch"),
io.MatchType.Input("on_false", template=template, lazy=True),
io.MatchType.Input("on_true", template=template, lazy=True),
],
outputs=[
io.MatchType.Output(template=template, display_name="output"),
],
)
@classmethod
def check_lazy_status(cls, switch, on_false=None, on_true=None):
if switch and on_true is None:
return ["on_true"]
if not switch and on_false is None:
return ["on_false"]
@classmethod
def execute(cls, switch, on_true, on_false) -> io.NodeOutput:
return io.NodeOutput(on_true if switch else on_false)
class SoftSwitchNode(io.ComfyNode):
@classmethod
def define_schema(cls):
template = io.MatchType.Template("switch")
return io.Schema(
node_id="ComfySoftSwitchNode",
display_name="Soft Switch",
category="logic",
is_experimental=True,
inputs=[ inputs=[
io.Boolean.Input("switch"), io.Boolean.Input("switch"),
io.MatchType.Input("on_false", template=template, lazy=True, optional=True), io.MatchType.Input("on_false", template=template, lazy=True, optional=True),
@ -25,14 +59,14 @@ class SwitchNode(io.ComfyNode):
) )
@classmethod @classmethod
def check_lazy_status(cls, switch, on_false=..., on_true=...): def check_lazy_status(cls, switch, on_false=MISSING, on_true=MISSING):
# We use ... instead of None, as None is passed for connected-but-unevaluated inputs. # We use MISSING instead of None, as None is passed for connected-but-unevaluated inputs.
# This trick allows us to ignore the value of the switch and still be able to run execute(). # This trick allows us to ignore the value of the switch and still be able to run execute().
# One of the inputs may be missing, in which case we need to evaluate the other input # One of the inputs may be missing, in which case we need to evaluate the other input
if on_false is ...: if on_false is MISSING:
return ["on_true"] return ["on_true"]
if on_true is ...: if on_true is MISSING:
return ["on_false"] return ["on_false"]
# Normal lazy switch operation # Normal lazy switch operation
if switch and on_true is None: if switch and on_true is None:
@ -41,22 +75,50 @@ class SwitchNode(io.ComfyNode):
return ["on_false"] return ["on_false"]
@classmethod @classmethod
def validate_inputs(cls, switch, on_false=..., on_true=...): def validate_inputs(cls, switch, on_false=MISSING, on_true=MISSING):
# This check happens before check_lazy_status(), so we can eliminate the case where # This check happens before check_lazy_status(), so we can eliminate the case where
# both inputs are missing. # both inputs are missing.
if on_false is ... and on_true is ...: if on_false is MISSING and on_true is MISSING:
return "At least one of on_false or on_true must be connected to Switch node" return "At least one of on_false or on_true must be connected to Switch node"
return True return True
@classmethod @classmethod
def execute(cls, switch, on_true=..., on_false=...) -> io.NodeOutput: def execute(cls, switch, on_true=MISSING, on_false=MISSING) -> io.NodeOutput:
if on_true is ...: if on_true is MISSING:
return io.NodeOutput(on_false) return io.NodeOutput(on_false)
if on_false is ...: if on_false is MISSING:
return io.NodeOutput(on_true) return io.NodeOutput(on_true)
return io.NodeOutput(on_true if switch else on_false) return io.NodeOutput(on_true if switch else on_false)
class CustomComboNode(io.ComfyNode):
"""
Frontend node that allows user to write their own options for a combo.
This is here to make sure the node has a backend-representation to avoid some annoyances.
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CustomCombo",
display_name="Custom Combo",
category="utils",
is_experimental=True,
inputs=[io.Combo.Input("choice", options=[])],
outputs=[io.String.Output()]
)
@classmethod
def validate_inputs(cls, choice: io.Combo.Type) -> bool:
# NOTE: DO NOT DO THIS unless you want to skip validation entirely on the node's inputs.
# I am doing that here because the widgets (besides the combo dropdown) on this node are fully frontend defined.
# I need to skip checking that the chosen combo option is in the options list, since those are defined by the user.
return True
@classmethod
def execute(cls, choice: io.Combo.Type) -> io.NodeOutput:
return io.NodeOutput(choice)
class DCTestNode(io.ComfyNode): class DCTestNode(io.ComfyNode):
class DCValues(TypedDict): class DCValues(TypedDict):
combo: str combo: str
@ -72,14 +134,14 @@ class DCTestNode(io.ComfyNode):
display_name="DCTest", display_name="DCTest",
category="logic", category="logic",
is_output_node=True, is_output_node=True,
inputs=[_io.DynamicCombo.Input("combo", options=[ inputs=[io.DynamicCombo.Input("combo", options=[
_io.DynamicCombo.Option("option1", [io.String.Input("string")]), io.DynamicCombo.Option("option1", [io.String.Input("string")]),
_io.DynamicCombo.Option("option2", [io.Int.Input("integer")]), io.DynamicCombo.Option("option2", [io.Int.Input("integer")]),
_io.DynamicCombo.Option("option3", [io.Image.Input("image")]), io.DynamicCombo.Option("option3", [io.Image.Input("image")]),
_io.DynamicCombo.Option("option4", [ io.DynamicCombo.Option("option4", [
_io.DynamicCombo.Input("subcombo", options=[ io.DynamicCombo.Input("subcombo", options=[
_io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]), io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]),
_io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]), io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]),
]) ])
])] ])]
)], )],
@ -141,14 +203,65 @@ class AutogrowPrefixTestNode(io.ComfyNode):
combined = ",".join([str(x) for x in vals]) combined = ",".join([str(x) for x in vals])
return io.NodeOutput(combined) return io.NodeOutput(combined)
class ComboOutputTestNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ComboOptionTestNode",
display_name="ComboOptionTest",
category="logic",
inputs=[io.Combo.Input("combo", options=["option1", "option2", "option3"]),
io.Combo.Input("combo2", options=["option4", "option5", "option6"])],
outputs=[io.Combo.Output(), io.Combo.Output()],
)
@classmethod
def execute(cls, combo: io.Combo.Type, combo2: io.Combo.Type) -> io.NodeOutput:
return io.NodeOutput(combo, combo2)
class ConvertStringToComboNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ConvertStringToComboNode",
display_name="Convert String to Combo",
category="logic",
inputs=[io.String.Input("string")],
outputs=[io.Combo.Output()],
)
@classmethod
def execute(cls, string: str) -> io.NodeOutput:
return io.NodeOutput(string)
class InvertBooleanNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="InvertBooleanNode",
display_name="Invert Boolean",
category="logic",
inputs=[io.Boolean.Input("boolean")],
outputs=[io.Boolean.Output()],
)
@classmethod
def execute(cls, boolean: bool) -> io.NodeOutput:
return io.NodeOutput(not boolean)
class LogicExtension(ComfyExtension): class LogicExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[io.ComfyNode]]: async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [ return [
# SwitchNode, SwitchNode,
CustomComboNode,
# SoftSwitchNode,
# ConvertStringToComboNode,
# DCTestNode, # DCTestNode,
# AutogrowNamesTestNode, # AutogrowNamesTestNode,
# AutogrowPrefixTestNode, # AutogrowPrefixTestNode,
# ComboOutputTestNode,
# InvertBooleanNode,
] ]
async def comfy_entrypoint() -> LogicExtension: async def comfy_entrypoint() -> LogicExtension:

View File

@ -4,11 +4,15 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from PIL import Image from PIL import Image
import math import math
from enum import Enum
from typing import TypedDict, Literal
import comfy.utils import comfy.utils
import comfy.model_management import comfy.model_management
from comfy_extras.nodes_latent import reshape_latent_to
import node_helpers import node_helpers
from comfy_api.latest import ComfyExtension, io from comfy_api.latest import ComfyExtension, io
from nodes import MAX_RESOLUTION
class Blend(io.ComfyNode): class Blend(io.ComfyNode):
@classmethod @classmethod
@ -241,6 +245,353 @@ class ImageScaleToTotalPixels(io.ComfyNode):
s = s.movedim(1,-1) s = s.movedim(1,-1)
return io.NodeOutput(s) return io.NodeOutput(s)
class ResizeType(str, Enum):
SCALE_BY = "scale by multiplier"
SCALE_DIMENSIONS = "scale dimensions"
SCALE_LONGER_DIMENSION = "scale longer dimension"
SCALE_SHORTER_DIMENSION = "scale shorter dimension"
SCALE_WIDTH = "scale width"
SCALE_HEIGHT = "scale height"
SCALE_TOTAL_PIXELS = "scale total pixels"
MATCH_SIZE = "match size"
def is_image(input: torch.Tensor) -> bool:
# images have 4 dimensions: [batch, height, width, channels]
# masks have 3 dimensions: [batch, height, width]
return len(input.shape) == 4
def init_image_mask_input(input: torch.Tensor, is_type_image: bool) -> torch.Tensor:
if is_type_image:
input = input.movedim(-1, 1)
else:
input = input.unsqueeze(1)
return input
def finalize_image_mask_input(input: torch.Tensor, is_type_image: bool) -> torch.Tensor:
if is_type_image:
input = input.movedim(1, -1)
else:
input = input.squeeze(1)
return input
def scale_by(input: torch.Tensor, multiplier: float, scale_method: str) -> torch.Tensor:
is_type_image = is_image(input)
input = init_image_mask_input(input, is_type_image)
width = round(input.shape[-1] * multiplier)
height = round(input.shape[-2] * multiplier)
input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
input = finalize_image_mask_input(input, is_type_image)
return input
def scale_dimensions(input: torch.Tensor, width: int, height: int, scale_method: str, crop: str="disabled") -> torch.Tensor:
if width == 0 and height == 0:
return input
is_type_image = is_image(input)
input = init_image_mask_input(input, is_type_image)
if width == 0:
width = max(1, round(input.shape[-1] * height / input.shape[-2]))
elif height == 0:
height = max(1, round(input.shape[-2] * width / input.shape[-1]))
input = comfy.utils.common_upscale(input, width, height, scale_method, crop)
input = finalize_image_mask_input(input, is_type_image)
return input
def scale_longer_dimension(input: torch.Tensor, longer_size: int, scale_method: str) -> torch.Tensor:
is_type_image = is_image(input)
input = init_image_mask_input(input, is_type_image)
width = input.shape[-1]
height = input.shape[-2]
if height > width:
width = round((width / height) * longer_size)
height = longer_size
elif width > height:
height = round((height / width) * longer_size)
width = longer_size
else:
height = longer_size
width = longer_size
input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
input = finalize_image_mask_input(input, is_type_image)
return input
def scale_shorter_dimension(input: torch.Tensor, shorter_size: int, scale_method: str) -> torch.Tensor:
is_type_image = is_image(input)
input = init_image_mask_input(input, is_type_image)
width = input.shape[-1]
height = input.shape[-2]
if height < width:
width = round((width / height) * shorter_size)
height = shorter_size
elif width > height:
height = round((height / width) * shorter_size)
width = shorter_size
else:
height = shorter_size
width = shorter_size
input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
input = finalize_image_mask_input(input, is_type_image)
return input
def scale_total_pixels(input: torch.Tensor, megapixels: float, scale_method: str) -> torch.Tensor:
is_type_image = is_image(input)
input = init_image_mask_input(input, is_type_image)
total = int(megapixels * 1024 * 1024)
scale_by = math.sqrt(total / (input.shape[-1] * input.shape[-2]))
width = round(input.shape[-1] * scale_by)
height = round(input.shape[-2] * scale_by)
input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
input = finalize_image_mask_input(input, is_type_image)
return input
def scale_match_size(input: torch.Tensor, match: torch.Tensor, scale_method: str, crop: str) -> torch.Tensor:
is_type_image = is_image(input)
input = init_image_mask_input(input, is_type_image)
match = init_image_mask_input(match, is_image(match))
width = match.shape[-1]
height = match.shape[-2]
input = comfy.utils.common_upscale(input, width, height, scale_method, crop)
input = finalize_image_mask_input(input, is_type_image)
return input
class ResizeImageMaskNode(io.ComfyNode):
scale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
crop_methods = ["disabled", "center"]
class ResizeTypedDict(TypedDict):
resize_type: ResizeType
scale_method: Literal["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
crop: Literal["disabled", "center"]
multiplier: float
width: int
height: int
longer_size: int
shorter_size: int
megapixels: float
@classmethod
def define_schema(cls):
template = io.MatchType.Template("input_type", [io.Image, io.Mask])
crop_combo = io.Combo.Input("crop", options=cls.crop_methods, default="center")
return io.Schema(
node_id="ResizeImageMaskNode",
display_name="Resize Image/Mask",
category="transform",
inputs=[
io.MatchType.Input("input", template=template),
io.DynamicCombo.Input("resize_type", options=[
io.DynamicCombo.Option(ResizeType.SCALE_BY, [
io.Float.Input("multiplier", default=1.00, min=0.01, max=8.0, step=0.01),
]),
io.DynamicCombo.Option(ResizeType.SCALE_DIMENSIONS, [
io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1),
io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1),
crop_combo,
]),
io.DynamicCombo.Option(ResizeType.SCALE_LONGER_DIMENSION, [
io.Int.Input("longer_size", default=512, min=0, max=MAX_RESOLUTION, step=1),
]),
io.DynamicCombo.Option(ResizeType.SCALE_SHORTER_DIMENSION, [
io.Int.Input("shorter_size", default=512, min=0, max=MAX_RESOLUTION, step=1),
]),
io.DynamicCombo.Option(ResizeType.SCALE_WIDTH, [
io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1),
]),
io.DynamicCombo.Option(ResizeType.SCALE_HEIGHT, [
io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1),
]),
io.DynamicCombo.Option(ResizeType.SCALE_TOTAL_PIXELS, [
io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
]),
io.DynamicCombo.Option(ResizeType.MATCH_SIZE, [
io.MultiType.Input("match", [io.Image, io.Mask]),
crop_combo,
]),
]),
io.Combo.Input("scale_method", options=cls.scale_methods, default="area"),
],
outputs=[io.MatchType.Output(template=template, display_name="resized")]
)
@classmethod
def execute(cls, input: io.Image.Type | io.Mask.Type, scale_method: io.Combo.Type, resize_type: ResizeTypedDict) -> io.NodeOutput:
selected_type = resize_type["resize_type"]
if selected_type == ResizeType.SCALE_BY:
return io.NodeOutput(scale_by(input, resize_type["multiplier"], scale_method))
elif selected_type == ResizeType.SCALE_DIMENSIONS:
return io.NodeOutput(scale_dimensions(input, resize_type["width"], resize_type["height"], scale_method, resize_type["crop"]))
elif selected_type == ResizeType.SCALE_LONGER_DIMENSION:
return io.NodeOutput(scale_longer_dimension(input, resize_type["longer_size"], scale_method))
elif selected_type == ResizeType.SCALE_SHORTER_DIMENSION:
return io.NodeOutput(scale_shorter_dimension(input, resize_type["shorter_size"], scale_method))
elif selected_type == ResizeType.SCALE_WIDTH:
return io.NodeOutput(scale_dimensions(input, resize_type["width"], 0, scale_method))
elif selected_type == ResizeType.SCALE_HEIGHT:
return io.NodeOutput(scale_dimensions(input, 0, resize_type["height"], scale_method))
elif selected_type == ResizeType.SCALE_TOTAL_PIXELS:
return io.NodeOutput(scale_total_pixels(input, resize_type["megapixels"], scale_method))
elif selected_type == ResizeType.MATCH_SIZE:
return io.NodeOutput(scale_match_size(input, resize_type["match"], scale_method, resize_type["crop"]))
raise ValueError(f"Unsupported resize type: {selected_type}")
def batch_images(images: list[torch.Tensor]) -> torch.Tensor | None:
if len(images) == 0:
return None
# first, get the max channels count
max_channels = max(image.shape[-1] for image in images)
# then, pad all images to have the same channels count
padded_images: list[torch.Tensor] = []
for image in images:
if image.shape[-1] < max_channels:
padded_images.append(torch.nn.functional.pad(image, (0,1), mode='constant', value=1.0))
else:
padded_images.append(image)
# resize all images to be the same size as the first image
resized_images: list[torch.Tensor] = []
first_image_shape = padded_images[0].shape
for image in padded_images:
if image.shape[1:] != first_image_shape[1:]:
resized_images.append(comfy.utils.common_upscale(image.movedim(-1,1), first_image_shape[2], first_image_shape[1], "bilinear", "center").movedim(1,-1))
else:
resized_images.append(image)
# batch the images in the format [b, h, w, c]
return torch.cat(resized_images, dim=0)
def batch_masks(masks: list[torch.Tensor]) -> torch.Tensor | None:
if len(masks) == 0:
return None
# resize all masks to be the same size as the first mask
resized_masks: list[torch.Tensor] = []
first_mask_shape = masks[0].shape
for mask in masks:
if mask.shape[1:] != first_mask_shape[1:]:
mask = init_image_mask_input(mask, is_type_image=False)
mask = comfy.utils.common_upscale(mask, first_mask_shape[2], first_mask_shape[1], "bilinear", "center")
resized_masks.append(finalize_image_mask_input(mask, is_type_image=False))
else:
resized_masks.append(mask)
# batch the masks in the format [b, h, w]
return torch.cat(resized_masks, dim=0)
def batch_latents(latents: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor] | None:
if len(latents) == 0:
return None
samples_out = latents[0].copy()
samples_out["batch_index"] = []
first_samples = latents[0]["samples"]
tensors: list[torch.Tensor] = []
for latent in latents:
# first, deal with latent tensors
tensors.append(reshape_latent_to(first_samples.shape, latent["samples"], repeat_batch=False))
# next, deal with batch_index
samples_out["batch_index"].extend(latent.get("batch_index", [x for x in range(0, latent["samples"].shape[0])]))
samples_out["samples"] = torch.cat(tensors, dim=0)
return samples_out
class BatchImagesNode(io.ComfyNode):
@classmethod
def define_schema(cls):
autogrow_template = io.Autogrow.TemplatePrefix(io.Image.Input("image"), prefix="image", min=2, max=50)
return io.Schema(
node_id="BatchImagesNode",
display_name="Batch Images",
category="image",
inputs=[
io.Autogrow.Input("images", template=autogrow_template)
],
outputs=[
io.Image.Output()
]
)
@classmethod
def execute(cls, images: io.Autogrow.Type) -> io.NodeOutput:
return io.NodeOutput(batch_images(list(images.values())))
class BatchMasksNode(io.ComfyNode):
@classmethod
def define_schema(cls):
autogrow_template = io.Autogrow.TemplatePrefix(io.Mask.Input("mask"), prefix="mask", min=2, max=50)
return io.Schema(
node_id="BatchMasksNode",
display_name="Batch Masks",
category="mask",
inputs=[
io.Autogrow.Input("masks", template=autogrow_template)
],
outputs=[
io.Mask.Output()
]
)
@classmethod
def execute(cls, masks: io.Autogrow.Type) -> io.NodeOutput:
return io.NodeOutput(batch_masks(list(masks.values())))
class BatchLatentsNode(io.ComfyNode):
@classmethod
def define_schema(cls):
autogrow_template = io.Autogrow.TemplatePrefix(io.Latent.Input("latent"), prefix="latent", min=2, max=50)
return io.Schema(
node_id="BatchLatentsNode",
display_name="Batch Latents",
category="latent",
inputs=[
io.Autogrow.Input("latents", template=autogrow_template)
],
outputs=[
io.Latent.Output()
]
)
@classmethod
def execute(cls, latents: io.Autogrow.Type) -> io.NodeOutput:
return io.NodeOutput(batch_latents(list(latents.values())))
class BatchImagesMasksLatentsNode(io.ComfyNode):
@classmethod
def define_schema(cls):
matchtype_template = io.MatchType.Template("input", allowed_types=[io.Image, io.Mask, io.Latent])
autogrow_template = io.Autogrow.TemplatePrefix(
io.MatchType.Input("input", matchtype_template),
prefix="input", min=1, max=50)
return io.Schema(
node_id="BatchImagesMasksLatentsNode",
display_name="Batch Images/Masks/Latents",
category="util",
inputs=[
io.Autogrow.Input("inputs", template=autogrow_template)
],
outputs=[
io.MatchType.Output(id=None, template=matchtype_template)
]
)
@classmethod
def execute(cls, inputs: io.Autogrow.Type) -> io.NodeOutput:
batched = None
values = list(inputs.values())
# latents
if isinstance(values[0], dict):
batched = batch_latents(values)
# images
elif is_image(values[0]):
batched = batch_images(values)
# masks
else:
batched = batch_masks(values)
return io.NodeOutput(batched)
class PostProcessingExtension(ComfyExtension): class PostProcessingExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[io.ComfyNode]]: async def get_node_list(self) -> list[type[io.ComfyNode]]:
@ -250,6 +601,11 @@ class PostProcessingExtension(ComfyExtension):
Quantize, Quantize,
Sharpen, Sharpen,
ImageScaleToTotalPixels, ImageScaleToTotalPixels,
ResizeImageMaskNode,
BatchImagesNode,
BatchMasksNode,
BatchLatentsNode,
# BatchImagesMasksLatentsNode,
] ]
async def comfy_entrypoint() -> PostProcessingExtension: async def comfy_entrypoint() -> PostProcessingExtension:

View File

@ -66,7 +66,7 @@ class Float(io.ComfyNode):
display_name="Float", display_name="Float",
category="utils/primitive", category="utils/primitive",
inputs=[ inputs=[
io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize), io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize, step=0.1),
], ],
outputs=[io.Float.Output()], outputs=[io.Float.Output()],
) )

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is # This file is automatically generated by the build process when version is
# updated in pyproject.toml. # updated in pyproject.toml.
__version__ = "0.6.0" __version__ = "0.7.0"

View File

@ -79,7 +79,7 @@ class IsChangedCache:
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None) input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None)
try: try:
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name) is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name, v3_data=v3_data)
is_changed = await resolve_map_node_over_list_results(is_changed) is_changed = await resolve_map_node_over_list_results(is_changed)
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
except Exception as e: except Exception as e:
@ -148,13 +148,12 @@ SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}): def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}):
is_v3 = issubclass(class_def, _ComfyNodeInternal) is_v3 = issubclass(class_def, _ComfyNodeInternal)
v3_data: io.V3Data = {} v3_data: io.V3Data = {}
hidden_inputs_v3 = {}
valid_inputs = class_def.INPUT_TYPES()
if is_v3: if is_v3:
valid_inputs, schema, v3_data = class_def.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs) valid_inputs, hidden, v3_data = _io.get_finalized_class_inputs(valid_inputs, inputs)
else:
valid_inputs = class_def.INPUT_TYPES()
input_data_all = {} input_data_all = {}
missing_keys = {} missing_keys = {}
hidden_inputs_v3 = {}
for x in inputs: for x in inputs:
input_data = inputs[x] input_data = inputs[x]
_, input_category, input_info = get_input_info(class_def, x, valid_inputs) _, input_category, input_info = get_input_info(class_def, x, valid_inputs)
@ -180,18 +179,18 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
input_data_all[x] = [input_data] input_data_all[x] = [input_data]
if is_v3: if is_v3:
if schema.hidden: if hidden is not None:
if io.Hidden.prompt in schema.hidden: if io.Hidden.prompt.name in hidden:
hidden_inputs_v3[io.Hidden.prompt] = dynprompt.get_original_prompt() if dynprompt is not None else {} hidden_inputs_v3[io.Hidden.prompt] = dynprompt.get_original_prompt() if dynprompt is not None else {}
if io.Hidden.dynprompt in schema.hidden: if io.Hidden.dynprompt.name in hidden:
hidden_inputs_v3[io.Hidden.dynprompt] = dynprompt hidden_inputs_v3[io.Hidden.dynprompt] = dynprompt
if io.Hidden.extra_pnginfo in schema.hidden: if io.Hidden.extra_pnginfo.name in hidden:
hidden_inputs_v3[io.Hidden.extra_pnginfo] = extra_data.get('extra_pnginfo', None) hidden_inputs_v3[io.Hidden.extra_pnginfo] = extra_data.get('extra_pnginfo', None)
if io.Hidden.unique_id in schema.hidden: if io.Hidden.unique_id.name in hidden:
hidden_inputs_v3[io.Hidden.unique_id] = unique_id hidden_inputs_v3[io.Hidden.unique_id] = unique_id
if io.Hidden.auth_token_comfy_org in schema.hidden: if io.Hidden.auth_token_comfy_org.name in hidden:
hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None) hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None)
if io.Hidden.api_key_comfy_org in schema.hidden: if io.Hidden.api_key_comfy_org.name in hidden:
hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None) hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None)
else: else:
if "hidden" in valid_inputs: if "hidden" in valid_inputs:
@ -258,7 +257,7 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
pre_execute_cb(index) pre_execute_cb(index)
# V3 # V3
if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)): if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)):
# if is just a class, then assign no resources or state, just create clone # if is just a class, then assign no state, just create clone
if is_class(obj): if is_class(obj):
type_obj = obj type_obj = obj
obj.VALIDATE_CLASS() obj.VALIDATE_CLASS()
@ -481,7 +480,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
else: else:
lazy_status_present = getattr(obj, "check_lazy_status", None) is not None lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
if lazy_status_present: if lazy_status_present:
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data) # for check_lazy_status, the returned data should include the original key of the input
v3_data_lazy = v3_data.copy()
v3_data_lazy["create_dynamic_tuple"] = True
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data_lazy)
required_inputs = await resolve_map_node_over_list_results(required_inputs) required_inputs = await resolve_map_node_over_list_results(required_inputs)
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], [])) required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
required_inputs = [x for x in required_inputs if isinstance(x,str) and ( required_inputs = [x for x in required_inputs if isinstance(x,str) and (
@ -756,10 +758,13 @@ async def validate_inputs(prompt_id, prompt, item, validated):
errors = [] errors = []
valid = True valid = True
v3_data = None
validate_function_inputs = [] validate_function_inputs = []
validate_has_kwargs = False validate_has_kwargs = False
if issubclass(obj_class, _ComfyNodeInternal): if issubclass(obj_class, _ComfyNodeInternal):
class_inputs, _, _ = obj_class.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs) obj_class: _io._ComfyNodeBaseInternal
class_inputs = obj_class.INPUT_TYPES()
class_inputs, _, v3_data = _io.get_finalized_class_inputs(class_inputs, inputs)
validate_function_name = "validate_inputs" validate_function_name = "validate_inputs"
validate_function = first_real_override(obj_class, validate_function_name) validate_function = first_real_override(obj_class, validate_function_name)
else: else:
@ -779,10 +784,11 @@ async def validate_inputs(prompt_id, prompt, item, validated):
assert extra_info is not None assert extra_info is not None
if x not in inputs: if x not in inputs:
if input_category == "required": if input_category == "required":
details = f"{x}" if not v3_data else x.split(".")[-1]
error = { error = {
"type": "required_input_missing", "type": "required_input_missing",
"message": "Required input is missing", "message": "Required input is missing",
"details": f"{x}", "details": details,
"extra_info": { "extra_info": {
"input_name": x "input_name": x
} }
@ -916,8 +922,11 @@ async def validate_inputs(prompt_id, prompt, item, validated):
errors.append(error) errors.append(error)
continue continue
if isinstance(input_type, list): if isinstance(input_type, list) or input_type == io.Combo.io_type:
combo_options = input_type if input_type == io.Combo.io_type:
combo_options = extra_info.get("options", [])
else:
combo_options = input_type
if val not in combo_options: if val not in combo_options:
input_config = info input_config = info
list_info = "" list_info = ""

View File

@ -1863,6 +1863,7 @@ class ImageBatch:
FUNCTION = "batch" FUNCTION = "batch"
CATEGORY = "image" CATEGORY = "image"
DEPRECATED = True
def batch(self, image1, image2): def batch(self, image1, image2):
if image1.shape[-1] != image2.shape[-1]: if image1.shape[-1] != image2.shape[-1]:

View File

@ -1,6 +1,6 @@
[project] [project]
name = "ComfyUI" name = "ComfyUI"
version = "0.6.0" version = "0.7.0"
readme = "README.md" readme = "README.md"
license = { file = "LICENSE" } license = { file = "LICENSE" }
requires-python = ">=3.9" requires-python = ">=3.9"