Merge branch 'comfyanonymous:master' into pbar-throttle-priority

Silver 2026-01-01 00:39:11 +01:00 committed by silveroxides
commit 2c98bcad80
36 changed files with 1398 additions and 738 deletions

View File

@ -212,6 +212,8 @@ Python 3.14 works but you may encounter issues with the torch compile node. The
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12 Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch unless it is less than 2 weeks old.
### Instructions: ### Instructions:
Git clone this repo. Git clone this repo.

View File

@ -2,6 +2,25 @@ import torch
from comfy.ldm.modules.attention import optimized_attention_for_device from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.ops import comfy.ops
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
image = image[:, :, :, :3] if image.shape[3] > 3 else image
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
std = torch.tensor(std, device=image.device, dtype=image.dtype)
image = image.movedim(-1, 1)
if not (image.shape[2] == size and image.shape[3] == size):
if crop:
scale = (size / min(image.shape[2], image.shape[3]))
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
else:
scale_size = (size, size)
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
h = (image.shape[2] - size)//2
w = (image.shape[3] - size)//2
image = image[:,:,h:h+size,w:w+size]
image = torch.clip((255. * image), 0, 255).round() / 255.0
return (image - mean.view([3,1,1])) / std.view([3,1,1])
class CLIPAttention(torch.nn.Module): class CLIPAttention(torch.nn.Module):
def __init__(self, embed_dim, heads, dtype, device, operations): def __init__(self, embed_dim, heads, dtype, device, operations):
super().__init__() super().__init__()
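The preprocessing helper now lives in comfy/clip_model.py rather than comfy/clip_vision.py. As a quick illustration (not part of the diff), here is a minimal sketch of calling it on a ComfyUI-style image batch laid out as (B, H, W, C) with values in [0, 1]; the tensor shape and the import path are assumptions based on the hunks above.

```python
# Minimal sketch; assumes a ComfyUI environment where comfy.clip_model is importable.
import torch
import comfy.clip_model

image = torch.rand(1, 512, 768, 3)  # hypothetical (B, H, W, C) batch in [0, 1]
pixels = comfy.clip_model.clip_preprocess(image, size=224, crop=True)
print(pixels.shape)  # torch.Size([1, 3, 224, 224]), normalized with the CLIP mean/std
```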

View File

@ -1,6 +1,5 @@
from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
import os import os
import torch
import json import json
import logging import logging
@ -17,24 +16,7 @@ class Output:
def __setitem__(self, key, item): def __setitem__(self, key, item):
setattr(self, key, item) setattr(self, key, item)
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True): clip_preprocess = comfy.clip_model.clip_preprocess # Prevent some stuff from breaking, TODO: remove eventually
image = image[:, :, :, :3] if image.shape[3] > 3 else image
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
std = torch.tensor(std, device=image.device, dtype=image.dtype)
image = image.movedim(-1, 1)
if not (image.shape[2] == size and image.shape[3] == size):
if crop:
scale = (size / min(image.shape[2], image.shape[3]))
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
else:
scale_size = (size, size)
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
h = (image.shape[2] - size)//2
w = (image.shape[3] - size)//2
image = image[:,:,h:h+size,w:w+size]
image = torch.clip((255. * image), 0, 255).round() / 255.0
return (image - mean.view([3,1,1])) / std.view([3,1,1])
IMAGE_ENCODERS = { IMAGE_ENCODERS = {
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection, "clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
@ -73,7 +55,7 @@ class ClipVisionModel():
def encode_image(self, image, crop=True): def encode_image(self, image, crop=True):
comfy.model_management.load_model_gpu(self.patcher) comfy.model_management.load_model_gpu(self.patcher)
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float() pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2) out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
outputs = Output() outputs = Output()

View File

@ -188,6 +188,12 @@ class IndexListContextHandler(ContextHandlerABC):
audio_cond = cond_value.cond audio_cond = cond_value.cond
if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim): if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim):
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1)) new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1))
# Handle vace_context (temporal dim is 3)
elif cond_key == "vace_context" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
vace_cond = cond_value.cond
if vace_cond.ndim >= 4 and vace_cond.size(3) == x_in.size(self.dim):
sliced_vace = window.get_tensor(vace_cond, device, dim=3, retain_index_list=self.cond_retain_index_list)
new_cond_item[cond_key] = cond_value._copy_with(sliced_vace)
# if has cond that is a Tensor, check if needs to be subset # if has cond that is a Tensor, check if needs to be subset
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \ if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \

View File

@ -74,6 +74,9 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.):
def default_noise_sampler(x, seed=None): def default_noise_sampler(x, seed=None):
if seed is not None: if seed is not None:
if x.device == torch.device("cpu"):
seed += 1
generator = torch.Generator(device=x.device) generator = torch.Generator(device=x.device)
generator.manual_seed(seed) generator.manual_seed(seed)
else: else:

View File

@ -491,7 +491,8 @@ class NextDiT(nn.Module):
for layer_id in range(n_layers) for layer_id in range(n_layers)
] ]
) )
self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) # This norm final is in the lumina 2.0 code but isn't actually used for anything.
# self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings) self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)
if self.pad_tokens_multiple is not None: if self.pad_tokens_multiple is not None:

View File

@ -30,6 +30,13 @@ except ImportError as e:
raise e raise e
exit(-1) exit(-1)
SAGE_ATTENTION3_IS_AVAILABLE = False
try:
from sageattn3 import sageattn3_blackwell
SAGE_ATTENTION3_IS_AVAILABLE = True
except ImportError:
pass
FLASH_ATTENTION_IS_AVAILABLE = False FLASH_ATTENTION_IS_AVAILABLE = False
try: try:
from flash_attn import flash_attn_func from flash_attn import flash_attn_func
@ -563,6 +570,93 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
out = out.reshape(b, -1, heads * dim_head) out = out.reshape(b, -1, heads * dim_head)
return out return out
@wrap_attn
def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
exception_fallback = False
if (q.device.type != "cuda" or
q.dtype not in (torch.float16, torch.bfloat16) or
mask is not None):
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
**kwargs
)
if skip_reshape:
B, H, L, D = q.shape
if H != heads:
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=True,
skip_output_reshape=skip_output_reshape,
**kwargs
)
q_s, k_s, v_s = q, k, v
N = q.shape[2]
dim_head = D
else:
B, N, inner_dim = q.shape
if inner_dim % heads != 0:
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=False,
skip_output_reshape=skip_output_reshape,
**kwargs
)
dim_head = inner_dim // heads
if dim_head >= 256 or N <= 1024:
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
**kwargs
)
if not skip_reshape:
q_s, k_s, v_s = map(
lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
(q, k, v),
)
B, H, L, D = q_s.shape
try:
out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False)
except Exception as e:
exception_fallback = True
logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e)
if exception_fallback:
if not skip_reshape:
del q_s, k_s, v_s
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=False,
skip_output_reshape=skip_output_reshape,
**kwargs
)
if skip_reshape:
if not skip_output_reshape:
out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
else:
if skip_output_reshape:
pass
else:
out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
return out
try: try:
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=()) @torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
@ -650,6 +744,8 @@ optimized_attention_masked = optimized_attention
# register core-supported attention functions # register core-supported attention functions
if SAGE_ATTENTION_IS_AVAILABLE: if SAGE_ATTENTION_IS_AVAILABLE:
register_attention_function("sage", attention_sage) register_attention_function("sage", attention_sage)
if SAGE_ATTENTION3_IS_AVAILABLE:
register_attention_function("sage3", attention3_sage)
if FLASH_ATTENTION_IS_AVAILABLE: if FLASH_ATTENTION_IS_AVAILABLE:
register_attention_function("flash", attention_flash) register_attention_function("flash", attention_flash)
if model_management.xformers_enabled(): if model_management.xformers_enabled():
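For context (not part of the diff): once sageattn3 imports successfully, the new "sage3" function becomes selectable alongside the other registered backends. The sketch below calls attention3_sage directly; the shapes follow the flat (batch, tokens, heads * dim_head) convention used in the non-skip_reshape path, and the device/dtype values are assumptions. On unsupported inputs the function simply falls back to attention_pytorch as coded above.

```python
# Sketch only; assumes sageattn3 is installed and a supported CUDA GPU is present.
import torch
from comfy.ldm.modules.attention import attention3_sage

q = torch.randn(1, 4096, 8 * 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Falls back to pytorch attention when the tensors are not CUDA fp16/bf16, a mask
# is given, dim_head >= 256, or the sequence length is <= 1024 tokens.
out = attention3_sage(q, k, v, heads=8)
print(out.shape)  # torch.Size([1, 4096, 512])
```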

View File

@ -1019,8 +1019,8 @@ NUM_STREAMS = 0
if args.async_offload is not None: if args.async_offload is not None:
NUM_STREAMS = args.async_offload NUM_STREAMS = args.async_offload
else: else:
# Enable by default on Nvidia # Enable by default on Nvidia and AMD
if is_nvidia(): if is_nvidia() or is_amd():
NUM_STREAMS = 2 NUM_STREAMS = 2
if args.disable_async_offload: if args.disable_async_offload:
@ -1126,6 +1126,16 @@ if not args.disable_pinned_memory:
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"]) PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
def discard_cuda_async_error():
try:
a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
_ = a + b
torch.cuda.synchronize()
except torch.AcceleratorError:
#Dump it! We already know about it from the synchronous return
pass
def pin_memory(tensor): def pin_memory(tensor):
global TOTAL_PINNED_MEMORY global TOTAL_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0: if MAX_PINNED_MEMORY <= 0:
@ -1158,6 +1168,9 @@ def pin_memory(tensor):
PINNED_MEMORY[ptr] = size PINNED_MEMORY[ptr] = size
TOTAL_PINNED_MEMORY += size TOTAL_PINNED_MEMORY += size
return True return True
else:
logging.warning("Pin error.")
discard_cuda_async_error()
return False return False
@ -1186,6 +1199,9 @@ def unpin_memory(tensor):
if len(PINNED_MEMORY) == 0: if len(PINNED_MEMORY) == 0:
TOTAL_PINNED_MEMORY = 0 TOTAL_PINNED_MEMORY = 0
return True return True
else:
logging.warning("Unpin error.")
discard_cuda_async_error()
return False return False

View File

@ -31,8 +31,6 @@ from einops import rearrange
from comfy.cli_args import args from comfy.cli_args import args
import json import json
import time import time
import inspect
import os
MMAP_TORCH_FILES = args.mmap_torch_files MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap DISABLE_MMAP = args.disable_mmap
@ -1100,54 +1098,12 @@ def set_progress_bar_global_hook(function):
global PROGRESS_BAR_HOOK global PROGRESS_BAR_HOOK
PROGRESS_BAR_HOOK = function PROGRESS_BAR_HOOK = function
class ProgressPriority: # Throttle settings for progress bar updates to reduce WebSocket flooding
CRITICAL = 0 # No throttling (previews) PROGRESS_THROTTLE_MIN_INTERVAL = 0.1 # 100ms minimum between updates
HIGH = 1 # Core ComfyUI - light throttling PROGRESS_THROTTLE_MIN_PERCENT = 0.5 # 0.5% minimum progress change
LOW = 2 # Custom nodes - aggressive throttling
PROGRESS_THROTTLE_SETTINGS = {
ProgressPriority.CRITICAL: {"min_interval": 0.0, "min_percent": 0.0},
ProgressPriority.HIGH: {"min_interval": 0.1, "min_percent": 0.5},
ProgressPriority.LOW: {"min_interval": 0.2, "min_percent": 2.0},
}
_COMFY_CORE_PATHS = None
def _get_comfy_core_paths():
global _COMFY_CORE_PATHS
if _COMFY_CORE_PATHS is None:
utils_dir = os.path.dirname(os.path.abspath(__file__))
comfy_dir = utils_dir
comfyui_root = os.path.dirname(comfy_dir)
_COMFY_CORE_PATHS = {
os.path.normpath(comfy_dir),
os.path.normpath(os.path.join(comfyui_root, "comfy_extras")),
os.path.normpath(os.path.join(comfyui_root, "comfy_api")),
os.path.normpath(os.path.join(comfyui_root, "comfy_api_nodes")),
os.path.normpath(os.path.join(comfyui_root, "comfy_execution")),
os.path.normpath(os.path.join(comfyui_root, "latent_preview.py")),
os.path.normpath(os.path.join(comfyui_root, "nodes.py")),
}
return _COMFY_CORE_PATHS
def _is_core_comfyui_caller():
core_paths = _get_comfy_core_paths()
for frame_info in inspect.stack():
filepath = os.path.normpath(os.path.abspath(frame_info.filename))
if "utils.py" in filepath and "comfy" in filepath:
continue
if "site-packages" in filepath or "lib" + os.sep + "python" in filepath:
continue
if "custom_nodes" in filepath:
return False
for core_path in core_paths:
if filepath.startswith(core_path) or filepath == core_path:
return True
break
return False
class ProgressBar: class ProgressBar:
def __init__(self, total, node_id=None, priority=None): def __init__(self, total, node_id=None):
global PROGRESS_BAR_HOOK global PROGRESS_BAR_HOOK
self.total = total self.total = total
self.current = 0 self.current = 0
@ -1155,12 +1111,6 @@ class ProgressBar:
self.node_id = node_id self.node_id = node_id
self._last_update_time = 0.0 self._last_update_time = 0.0
self._last_sent_value = -1 self._last_sent_value = -1
if priority is not None:
self._priority = priority
elif _is_core_comfyui_caller():
self._priority = ProgressPriority.HIGH
else:
self._priority = ProgressPriority.LOW
def update_absolute(self, value, total=None, preview=None): def update_absolute(self, value, total=None, preview=None):
if total is not None: if total is not None:
@ -1173,20 +1123,22 @@ class ProgressBar:
is_first = (self._last_sent_value < 0) is_first = (self._last_sent_value < 0)
is_final = (value >= self.total) is_final = (value >= self.total)
has_preview = (preview is not None) has_preview = (preview is not None)
if has_preview:
priority = ProgressPriority.CRITICAL # Always send immediately for previews, first update, or final update
else: if has_preview or is_first or is_final:
priority = self._priority self.hook(self.current, self.total, preview, node_id=self.node_id)
settings = PROGRESS_THROTTLE_SETTINGS[priority] self._last_update_time = current_time
min_interval = settings["min_interval"] self._last_sent_value = value
min_percent = settings["min_percent"] return
# Apply throttling for regular progress updates
if self.total > 0: if self.total > 0:
percent_changed = ((value - max(0, self._last_sent_value)) / self.total) * 100 percent_changed = ((value - max(0, self._last_sent_value)) / self.total) * 100
else: else:
percent_changed = 100 percent_changed = 100
time_elapsed = current_time - self._last_update_time time_elapsed = current_time - self._last_update_time
should_send = (is_first or is_final or (time_elapsed >= min_interval and percent_changed >= min_percent))
if should_send: if time_elapsed >= PROGRESS_THROTTLE_MIN_INTERVAL and percent_changed >= PROGRESS_THROTTLE_MIN_PERCENT:
self.hook(self.current, self.total, preview, node_id=self.node_id) self.hook(self.current, self.total, preview, node_id=self.node_id)
self._last_update_time = current_time self._last_update_time = current_time
self._last_sent_value = value self._last_sent_value = value
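To make the simplified throttle concrete, here is an illustrative sketch (mine, not from the diff): with the two constants above, a ProgressBar over 1000 steps only forwards an update when at least 100 ms have elapsed and at least 0.5% (5 steps) of progress has accumulated since the last send; previews, the first update, and the final update always pass through.

```python
# Illustrative only; assumes a running ComfyUI process with a progress hook registered.
from comfy.utils import ProgressBar

pbar = ProgressBar(1000, node_id="123")  # node_id value is hypothetical
for step in range(1000):
    pbar.update_absolute(step + 1)  # most intermediate calls are coalesced by the throttle
```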

View File

@ -10,7 +10,6 @@ from ._input_impl import VideoFromFile, VideoFromComponents
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
from . import _io_public as io from . import _io_public as io
from . import _ui_public as ui from . import _ui_public as ui
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
from comfy_execution.utils import get_executing_context from comfy_execution.utils import get_executing_context
from comfy_execution.progress import get_progress_state, PreviewImageTuple from comfy_execution.progress import get_progress_state, PreviewImageTuple
from PIL import Image from PIL import Image

View File

@ -26,11 +26,9 @@ if TYPE_CHECKING:
from comfy_api.input import VideoInput from comfy_api.input import VideoInput
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class, from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
prune_dict, shallow_clone_class) prune_dict, shallow_clone_class)
from ._resources import Resources, ResourcesLocal
from comfy_execution.graph_utils import ExecutionBlocker from comfy_execution.graph_utils import ExecutionBlocker
from ._util import MESH, VOXEL from ._util import MESH, VOXEL, SVG as _SVG
# from comfy_extras.nodes_images import SVG as SVG_ # NOTE: needs to be moved before can be imported due to circular reference
class FolderType(str, Enum): class FolderType(str, Enum):
input = "input" input = "input"
@ -77,16 +75,6 @@ class NumberDisplay(str, Enum):
slider = "slider" slider = "slider"
class _StringIOType(str):
def __ne__(self, value: object) -> bool:
if self == "*" or value == "*":
return False
if not isinstance(value, str):
return True
a = frozenset(self.split(","))
b = frozenset(value.split(","))
return not (b.issubset(a) or a.issubset(b))
class _ComfyType(ABC): class _ComfyType(ABC):
Type = Any Type = Any
io_type: str = None io_type: str = None
@ -126,8 +114,7 @@ def comfytype(io_type: str, **kwargs):
new_cls.__module__ = cls.__module__ new_cls.__module__ = cls.__module__
new_cls.__doc__ = cls.__doc__ new_cls.__doc__ = cls.__doc__
# assign ComfyType attributes, if needed # assign ComfyType attributes, if needed
# NOTE: use __ne__ trick for io_type (see node_typing.IO.__ne__ for details) new_cls.io_type = io_type
new_cls.io_type = _StringIOType(io_type)
if hasattr(new_cls, "Input") and new_cls.Input is not None: if hasattr(new_cls, "Input") and new_cls.Input is not None:
new_cls.Input.Parent = new_cls new_cls.Input.Parent = new_cls
if hasattr(new_cls, "Output") and new_cls.Output is not None: if hasattr(new_cls, "Output") and new_cls.Output is not None:
@ -166,7 +153,7 @@ class Input(_IO_V3):
''' '''
Base class for a V3 Input. Base class for a V3 Input.
''' '''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None):
super().__init__() super().__init__()
self.id = id self.id = id
self.display_name = display_name self.display_name = display_name
@ -174,6 +161,7 @@ class Input(_IO_V3):
self.tooltip = tooltip self.tooltip = tooltip
self.lazy = lazy self.lazy = lazy
self.extra_dict = extra_dict if extra_dict is not None else {} self.extra_dict = extra_dict if extra_dict is not None else {}
self.rawLink = raw_link
def as_dict(self): def as_dict(self):
return prune_dict({ return prune_dict({
@ -181,10 +169,11 @@ class Input(_IO_V3):
"optional": self.optional, "optional": self.optional,
"tooltip": self.tooltip, "tooltip": self.tooltip,
"lazy": self.lazy, "lazy": self.lazy,
"rawLink": self.rawLink,
}) | prune_dict(self.extra_dict) }) | prune_dict(self.extra_dict)
def get_io_type(self): def get_io_type(self):
return _StringIOType(self.io_type) return self.io_type
def get_all(self) -> list[Input]: def get_all(self) -> list[Input]:
return [self] return [self]
@ -195,8 +184,8 @@ class WidgetInput(Input):
''' '''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: Any=None, default: Any=None,
socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None): socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link)
self.default = default self.default = default
self.socketless = socketless self.socketless = socketless
self.widget_type = widget_type self.widget_type = widget_type
@ -218,13 +207,14 @@ class Output(_IO_V3):
def __init__(self, id: str=None, display_name: str=None, tooltip: str=None, def __init__(self, id: str=None, display_name: str=None, tooltip: str=None,
is_output_list=False): is_output_list=False):
self.id = id self.id = id
self.display_name = display_name self.display_name = display_name if display_name else id
self.tooltip = tooltip self.tooltip = tooltip
self.is_output_list = is_output_list self.is_output_list = is_output_list
def as_dict(self): def as_dict(self):
display_name = self.display_name if self.display_name else self.id
return prune_dict({ return prune_dict({
"display_name": self.display_name, "display_name": display_name,
"tooltip": self.tooltip, "tooltip": self.tooltip,
"is_output_list": self.is_output_list, "is_output_list": self.is_output_list,
}) })
@ -252,8 +242,8 @@ class Boolean(ComfyTypeIO):
'''Boolean input.''' '''Boolean input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: bool=None, label_on: str=None, label_off: str=None, default: bool=None, label_on: str=None, label_off: str=None,
socketless: bool=None, force_input: bool=None): socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
self.label_on = label_on self.label_on = label_on
self.label_off = label_off self.label_off = label_off
self.default: bool self.default: bool
@ -272,8 +262,8 @@ class Int(ComfyTypeIO):
'''Integer input.''' '''Integer input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None, default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None): display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
self.min = min self.min = min
self.max = max self.max = max
self.step = step self.step = step
@ -298,8 +288,8 @@ class Float(ComfyTypeIO):
'''Float input.''' '''Float input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None, default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None): display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
self.min = min self.min = min
self.max = max self.max = max
self.step = step self.step = step
@ -324,8 +314,8 @@ class String(ComfyTypeIO):
'''String input.''' '''String input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None, multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None,
socketless: bool=None, force_input: bool=None): socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
self.multiline = multiline self.multiline = multiline
self.placeholder = placeholder self.placeholder = placeholder
self.dynamic_prompts = dynamic_prompts self.dynamic_prompts = dynamic_prompts
@ -358,12 +348,14 @@ class Combo(ComfyTypeIO):
image_folder: FolderType=None, image_folder: FolderType=None,
remote: RemoteOptions=None, remote: RemoteOptions=None,
socketless: bool=None, socketless: bool=None,
extra_dict=None,
raw_link: bool=None,
): ):
if isinstance(options, type) and issubclass(options, Enum): if isinstance(options, type) and issubclass(options, Enum):
options = [v.value for v in options] options = [v.value for v in options]
if isinstance(default, Enum): if isinstance(default, Enum):
default = default.value default = default.value
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless) super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link)
self.multiselect = False self.multiselect = False
self.options = options self.options = options
self.control_after_generate = control_after_generate self.control_after_generate = control_after_generate
@ -387,10 +379,6 @@ class Combo(ComfyTypeIO):
super().__init__(id, display_name, tooltip, is_output_list) super().__init__(id, display_name, tooltip, is_output_list)
self.options = options if options is not None else [] self.options = options if options is not None else []
@property
def io_type(self):
return self.options
@comfytype(io_type="COMBO") @comfytype(io_type="COMBO")
class MultiCombo(ComfyTypeI): class MultiCombo(ComfyTypeI):
'''Multiselect Combo input (dropdown for selecting potentially more than one value).''' '''Multiselect Combo input (dropdown for selecting potentially more than one value).'''
@ -399,8 +387,8 @@ class MultiCombo(ComfyTypeI):
class Input(Combo.Input): class Input(Combo.Input):
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None, default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
socketless: bool=None): socketless: bool=None, extra_dict=None, raw_link: bool=None):
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless) super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link)
self.multiselect = True self.multiselect = True
self.placeholder = placeholder self.placeholder = placeholder
self.chip = chip self.chip = chip
@ -433,9 +421,9 @@ class Webcam(ComfyTypeIO):
Type = str Type = str
def __init__( def __init__(
self, id: str, display_name: str=None, optional=False, self, id: str, display_name: str=None, optional=False,
tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None, extra_dict=None, raw_link: bool=None
): ):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless) super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link)
@comfytype(io_type="MASK") @comfytype(io_type="MASK")
@ -656,7 +644,7 @@ class Video(ComfyTypeIO):
@comfytype(io_type="SVG") @comfytype(io_type="SVG")
class SVG(ComfyTypeIO): class SVG(ComfyTypeIO):
Type = Any # TODO: SVG class is defined in comfy_extras/nodes_images.py, causing circular reference; should be moved to somewhere else before referenced directly in v3 Type = _SVG
@comfytype(io_type="LORA_MODEL") @comfytype(io_type="LORA_MODEL")
class LoraModel(ComfyTypeIO): class LoraModel(ComfyTypeIO):
@ -788,7 +776,7 @@ class MultiType:
''' '''
Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values. Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values.
''' '''
def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None):
# if id is an Input, then use that Input with overridden values # if id is an Input, then use that Input with overridden values
self.input_override = None self.input_override = None
if isinstance(id, Input): if isinstance(id, Input):
@ -801,7 +789,7 @@ class MultiType:
# if is a widget input, make sure widget_type is set appropriately # if is a widget input, make sure widget_type is set appropriately
if isinstance(self.input_override, WidgetInput): if isinstance(self.input_override, WidgetInput):
self.input_override.widget_type = self.input_override.get_io_type() self.input_override.widget_type = self.input_override.get_io_type()
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link)
self._io_types = types self._io_types = types
@property @property
@ -855,8 +843,8 @@ class MatchType(ComfyTypeIO):
class Input(Input): class Input(Input):
def __init__(self, id: str, template: MatchType.Template, def __init__(self, id: str, template: MatchType.Template,
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link)
self.template = template self.template = template
def as_dict(self): def as_dict(self):
@ -867,6 +855,8 @@ class MatchType(ComfyTypeIO):
class Output(Output): class Output(Output):
def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None, def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None,
is_output_list=False): is_output_list=False):
if not id and not display_name:
display_name = "MATCHTYPE"
super().__init__(id, display_name, tooltip, is_output_list) super().__init__(id, display_name, tooltip, is_output_list)
self.template = template self.template = template
@ -879,24 +869,30 @@ class DynamicInput(Input, ABC):
''' '''
Abstract class for dynamic input registration. Abstract class for dynamic input registration.
''' '''
def get_dynamic(self) -> list[Input]: pass
return []
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
pass
class DynamicOutput(Output, ABC): class DynamicOutput(Output, ABC):
''' '''
Abstract class for dynamic output registration. Abstract class for dynamic output registration.
''' '''
def __init__(self, id: str=None, display_name: str=None, tooltip: str=None, pass
is_output_list=False):
super().__init__(id, display_name, tooltip, is_output_list)
def get_dynamic(self) -> list[Output]:
return []
def handle_prefix(prefix_list: list[str] | None, id: str | None = None) -> list[str]:
if prefix_list is None:
prefix_list = []
if id is not None:
prefix_list = prefix_list + [id]
return prefix_list
def finalize_prefix(prefix_list: list[str] | None, id: str | None = None) -> str:
assert not (prefix_list is None and id is None)
if prefix_list is None:
return id
elif id is not None:
prefix_list = prefix_list + [id]
return ".".join(prefix_list)
@comfytype(io_type="COMFY_AUTOGROW_V3") @comfytype(io_type="COMFY_AUTOGROW_V3")
class Autogrow(ComfyTypeI): class Autogrow(ComfyTypeI):
@ -933,14 +929,6 @@ class Autogrow(ComfyTypeI):
def validate(self): def validate(self):
self.input.validate() self.input.validate()
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
real_inputs = []
for name, input in self.cached_inputs.items():
if name in live_inputs:
real_inputs.append(input)
add_to_input_dict_v1(d, real_inputs, live_inputs, curr_prefix)
add_dynamic_id_mapping(d, real_inputs, curr_prefix)
class TemplatePrefix(_AutogrowTemplate): class TemplatePrefix(_AutogrowTemplate):
def __init__(self, input: Input, prefix: str, min: int=1, max: int=10): def __init__(self, input: Input, prefix: str, min: int=1, max: int=10):
super().__init__(input) super().__init__(input)
@ -985,22 +973,45 @@ class Autogrow(ComfyTypeI):
"template": self.template.as_dict(), "template": self.template.as_dict(),
}) })
def get_dynamic(self) -> list[Input]:
return self.template.get_all()
def get_all(self) -> list[Input]: def get_all(self) -> list[Input]:
return [self] + self.template.get_all() return [self] + self.template.get_all()
def validate(self): def validate(self):
self.template.validate() self.template.validate()
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): @staticmethod
curr_prefix = f"{curr_prefix}{self.id}." def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None):
# need to remove self from expected inputs dictionary; replaced by template inputs in frontend # NOTE: purposely do not include self in out_dict; instead use only the template inputs
for inner_dict in d.values(): # need to figure out names based on template type
if self.id in inner_dict: is_names = ("names" in value[1]["template"])
del inner_dict[self.id] is_prefix = ("prefix" in value[1]["template"])
self.template.expand_schema_for_dynamic(d, live_inputs, curr_prefix) input = value[1]["template"]["input"]
if is_names:
min = value[1]["template"]["min"]
names = value[1]["template"]["names"]
max = len(names)
elif is_prefix:
prefix = value[1]["template"]["prefix"]
min = value[1]["template"]["min"]
max = value[1]["template"]["max"]
names = [f"{prefix}{i}" for i in range(max)]
# need to create a new input based on the contents of input
template_input = None
for _, dict_input in input.items():
# for now, get just the first value from dict_input
template_input = list(dict_input.values())[0]
new_dict = {}
for i, name in enumerate(names):
expected_id = finalize_prefix(curr_prefix, name)
if expected_id in live_inputs:
# required
if i < min:
type_dict = new_dict.setdefault("required", {})
# optional
else:
type_dict = new_dict.setdefault("optional", {})
type_dict[name] = template_input
parse_class_inputs(out_dict, live_inputs, new_dict, curr_prefix)
@comfytype(io_type="COMFY_DYNAMICCOMBO_V3") @comfytype(io_type="COMFY_DYNAMICCOMBO_V3")
class DynamicCombo(ComfyTypeI): class DynamicCombo(ComfyTypeI):
@ -1023,23 +1034,6 @@ class DynamicCombo(ComfyTypeI):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
self.options = options self.options = options
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
# check if dynamic input's id is in live_inputs
if self.id in live_inputs:
curr_prefix = f"{curr_prefix}{self.id}."
key = live_inputs[self.id]
selected_option = None
for option in self.options:
if option.key == key:
selected_option = option
break
if selected_option is not None:
add_to_input_dict_v1(d, selected_option.inputs, live_inputs, curr_prefix)
add_dynamic_id_mapping(d, selected_option.inputs, curr_prefix, self)
def get_dynamic(self) -> list[Input]:
return [input for option in self.options for input in option.inputs]
def get_all(self) -> list[Input]: def get_all(self) -> list[Input]:
return [self] + [input for option in self.options for input in option.inputs] return [self] + [input for option in self.options for input in option.inputs]
@ -1054,6 +1048,24 @@ class DynamicCombo(ComfyTypeI):
for input in option.inputs: for input in option.inputs:
input.validate() input.validate()
@staticmethod
def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None):
finalized_id = finalize_prefix(curr_prefix)
if finalized_id in live_inputs:
key = live_inputs[finalized_id]
selected_option = None
# get options from dict
options: list[dict[str, str | dict[str, Any]]] = value[1]["options"]
for option in options:
if option["key"] == key:
selected_option = option
break
if selected_option is not None:
parse_class_inputs(out_dict, live_inputs, selected_option["inputs"], curr_prefix)
# add self to inputs
out_dict[input_type][finalized_id] = value
out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1])
@comfytype(io_type="COMFY_DYNAMICSLOT_V3") @comfytype(io_type="COMFY_DYNAMICSLOT_V3")
class DynamicSlot(ComfyTypeI): class DynamicSlot(ComfyTypeI):
Type = dict[str, Any] Type = dict[str, Any]
@ -1076,17 +1088,8 @@ class DynamicSlot(ComfyTypeI):
self.force_input = True self.force_input = True
self.slot.force_input = True self.slot.force_input = True
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
if self.id in live_inputs:
curr_prefix = f"{curr_prefix}{self.id}."
add_to_input_dict_v1(d, self.inputs, live_inputs, curr_prefix)
add_dynamic_id_mapping(d, [self.slot] + self.inputs, curr_prefix)
def get_dynamic(self) -> list[Input]:
return [self.slot] + self.inputs
def get_all(self) -> list[Input]: def get_all(self) -> list[Input]:
return [self] + [self.slot] + self.inputs return [self.slot] + self.inputs
def as_dict(self): def as_dict(self):
return super().as_dict() | prune_dict({ return super().as_dict() | prune_dict({
@ -1100,17 +1103,41 @@ class DynamicSlot(ComfyTypeI):
for input in self.inputs: for input in self.inputs:
input.validate() input.validate()
def add_dynamic_id_mapping(d: dict[str, Any], inputs: list[Input], curr_prefix: str, self: DynamicInput=None): @staticmethod
dynamic = d.setdefault("dynamic_paths", {}) def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None):
if self is not None: finalized_id = finalize_prefix(curr_prefix)
dynamic[self.id] = f"{curr_prefix}{self.id}" if finalized_id in live_inputs:
for i in inputs: inputs = value[1]["inputs"]
if not isinstance(i, DynamicInput): parse_class_inputs(out_dict, live_inputs, inputs, curr_prefix)
dynamic[f"{i.id}"] = f"{curr_prefix}{i.id}" # add self to inputs
out_dict[input_type][finalized_id] = value
out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1])
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
DYNAMIC_INPUT_LOOKUP[io_type] = func
def get_dynamic_input_func(io_type: str) -> Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]:
return DYNAMIC_INPUT_LOOKUP[io_type]
def setup_dynamic_input_funcs():
# Autogrow.Input
register_dynamic_input_func(Autogrow.io_type, Autogrow._expand_schema_for_dynamic)
# DynamicCombo.Input
register_dynamic_input_func(DynamicCombo.io_type, DynamicCombo._expand_schema_for_dynamic)
# DynamicSlot.Input
register_dynamic_input_func(DynamicSlot.io_type, DynamicSlot._expand_schema_for_dynamic)
if len(DYNAMIC_INPUT_LOOKUP) == 0:
setup_dynamic_input_funcs()
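For orientation (not part of the diff): the lookup table decouples schema expansion from the input classes, so a dynamic ComfyType only needs to register a callable with the (out_dict, live_inputs, value, input_type, curr_prefix) signature. A hypothetical registration might look like this:

```python
# Hypothetical example; the io_type string and expander are made up for illustration.
def _expand_my_dynamic(out_dict, live_inputs, value, input_type, curr_prefix):
    # value is (io_type, input_dict); a real expander would select nested inputs
    # and hand them to parse_class_inputs, as the built-in dynamic types do above.
    pass

register_dynamic_input_func("COMFY_MYDYNAMIC_V3", _expand_my_dynamic)
assert get_dynamic_input_func("COMFY_MYDYNAMIC_V3") is _expand_my_dynamic
```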
class V3Data(TypedDict): class V3Data(TypedDict):
hidden_inputs: dict[str, Any] hidden_inputs: dict[str, Any]
'Dictionary where the keys are the hidden input ids and the values are the values of the hidden inputs.'
dynamic_paths: dict[str, Any] dynamic_paths: dict[str, Any]
'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.'
create_dynamic_tuple: bool
'When True, the value of the dynamic input will be in the format (value, path_key).'
class HiddenHolder: class HiddenHolder:
def __init__(self, unique_id: str, prompt: Any, def __init__(self, unique_id: str, prompt: Any,
@ -1146,6 +1173,10 @@ class HiddenHolder:
api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None), api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None),
) )
@classmethod
def from_v3_data(cls, v3_data: V3Data | None) -> HiddenHolder:
return cls.from_dict(v3_data["hidden_inputs"] if v3_data else None)
class Hidden(str, Enum): class Hidden(str, Enum):
''' '''
Enumerator for requesting hidden variables in nodes. Enumerator for requesting hidden variables in nodes.
@ -1251,61 +1282,56 @@ class Schema:
- verify ids on inputs and outputs are unique - both internally and in relation to each other - verify ids on inputs and outputs are unique - both internally and in relation to each other
''' '''
nested_inputs: list[Input] = [] nested_inputs: list[Input] = []
if self.inputs is not None: for input in self.inputs:
for input in self.inputs: if not isinstance(input, DynamicInput):
nested_inputs.extend(input.get_all()) nested_inputs.extend(input.get_all())
input_ids = [i.id for i in nested_inputs] if nested_inputs is not None else [] input_ids = [i.id for i in nested_inputs]
output_ids = [o.id for o in self.outputs] if self.outputs is not None else [] output_ids = [o.id for o in self.outputs]
input_set = set(input_ids) input_set = set(input_ids)
output_set = set(output_ids) output_set = set(output_ids)
issues = [] issues: list[str] = []
# verify ids are unique per list # verify ids are unique per list
if len(input_set) != len(input_ids): if len(input_set) != len(input_ids):
issues.append(f"Input ids must be unique, but {[item for item, count in Counter(input_ids).items() if count > 1]} are not.") issues.append(f"Input ids must be unique, but {[item for item, count in Counter(input_ids).items() if count > 1]} are not.")
if len(output_set) != len(output_ids): if len(output_set) != len(output_ids):
issues.append(f"Output ids must be unique, but {[item for item, count in Counter(output_ids).items() if count > 1]} are not.") issues.append(f"Output ids must be unique, but {[item for item, count in Counter(output_ids).items() if count > 1]} are not.")
# verify ids are unique between lists
intersection = input_set & output_set
if len(intersection) > 0:
issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.")
if len(issues) > 0: if len(issues) > 0:
raise ValueError("\n".join(issues)) raise ValueError("\n".join(issues))
# validate inputs and outputs # validate inputs and outputs
if self.inputs is not None: for input in self.inputs:
for input in self.inputs: input.validate()
input.validate() for output in self.outputs:
if self.outputs is not None: output.validate()
for output in self.outputs:
output.validate()
def finalize(self): def finalize(self):
"""Add hidden based on selected schema options, and give outputs without ids default ids.""" """Add hidden based on selected schema options, and give outputs without ids default ids."""
# ensure inputs, outputs, and hidden are lists
if self.inputs is None:
self.inputs = []
if self.outputs is None:
self.outputs = []
if self.hidden is None:
self.hidden = []
# if is an api_node, will need key-related hidden # if is an api_node, will need key-related hidden
if self.is_api_node: if self.is_api_node:
if self.hidden is None:
self.hidden = []
if Hidden.auth_token_comfy_org not in self.hidden: if Hidden.auth_token_comfy_org not in self.hidden:
self.hidden.append(Hidden.auth_token_comfy_org) self.hidden.append(Hidden.auth_token_comfy_org)
if Hidden.api_key_comfy_org not in self.hidden: if Hidden.api_key_comfy_org not in self.hidden:
self.hidden.append(Hidden.api_key_comfy_org) self.hidden.append(Hidden.api_key_comfy_org)
# if is an output_node, will need prompt and extra_pnginfo # if is an output_node, will need prompt and extra_pnginfo
if self.is_output_node: if self.is_output_node:
if self.hidden is None:
self.hidden = []
if Hidden.prompt not in self.hidden: if Hidden.prompt not in self.hidden:
self.hidden.append(Hidden.prompt) self.hidden.append(Hidden.prompt)
if Hidden.extra_pnginfo not in self.hidden: if Hidden.extra_pnginfo not in self.hidden:
self.hidden.append(Hidden.extra_pnginfo) self.hidden.append(Hidden.extra_pnginfo)
# give outputs without ids default ids # give outputs without ids default ids
if self.outputs is not None: for i, output in enumerate(self.outputs):
for i, output in enumerate(self.outputs): if output.id is None:
if output.id is None: output.id = f"_{i}_{output.io_type}_"
output.id = f"_{i}_{output.io_type}_"
def get_v1_info(self, cls, live_inputs: dict[str, Any]=None) -> NodeInfoV1: def get_v1_info(self, cls) -> NodeInfoV1:
# NOTE: live_inputs will not be used anymore very soon and this will be done another way
# get V1 inputs # get V1 inputs
input = create_input_dict_v1(self.inputs, live_inputs) input = create_input_dict_v1(self.inputs)
if self.hidden: if self.hidden:
for hidden in self.hidden: for hidden in self.hidden:
input.setdefault("hidden", {})[hidden.name] = (hidden.value,) input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
@ -1385,33 +1411,54 @@ class Schema:
) )
return info return info
def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], include_hidden=False) -> tuple[dict[str, Any], V3Data]:
out_dict = {
"required": {},
"optional": {},
"dynamic_paths": {},
}
d = d.copy()
# ignore hidden for parsing
hidden = d.pop("hidden", None)
parse_class_inputs(out_dict, live_inputs, d)
if hidden is not None and include_hidden:
out_dict["hidden"] = hidden
v3_data = {}
dynamic_paths = out_dict.pop("dynamic_paths", None)
if dynamic_paths is not None:
v3_data["dynamic_paths"] = dynamic_paths
return out_dict, hidden, v3_data
def create_input_dict_v1(inputs: list[Input], live_inputs: dict[str, Any]=None) -> dict: def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None:
for input_type, inner_d in curr_dict.items():
for id, value in inner_d.items():
io_type = value[0]
if io_type in DYNAMIC_INPUT_LOOKUP:
# dynamic inputs need to be handled with lookup functions
dynamic_input_func = get_dynamic_input_func(io_type)
new_prefix = handle_prefix(curr_prefix, id)
dynamic_input_func(out_dict, live_inputs, value, input_type, new_prefix)
else:
# non-dynamic inputs get directly transferred
finalized_id = finalize_prefix(curr_prefix, id)
out_dict[input_type][finalized_id] = value
if curr_prefix:
out_dict["dynamic_paths"][finalized_id] = finalized_id
def create_input_dict_v1(inputs: list[Input]) -> dict:
input = { input = {
"required": {} "required": {}
} }
add_to_input_dict_v1(input, inputs, live_inputs) for i in inputs:
add_to_dict_v1(i, input)
return input return input
def add_to_input_dict_v1(d: dict[str, Any], inputs: list[Input], live_inputs: dict[str, Any]=None, curr_prefix=''): def add_to_dict_v1(i: Input, d: dict):
for i in inputs:
if isinstance(i, DynamicInput):
add_to_dict_v1(i, d)
if live_inputs is not None:
i.expand_schema_for_dynamic(d, live_inputs, curr_prefix)
else:
add_to_dict_v1(i, d)
def add_to_dict_v1(i: Input, d: dict, dynamic_dict: dict=None):
key = "optional" if i.optional else "required" key = "optional" if i.optional else "required"
as_dict = i.as_dict() as_dict = i.as_dict()
# for v1, we don't want to include the optional key # for v1, we don't want to include the optional key
as_dict.pop("optional", None) as_dict.pop("optional", None)
if dynamic_dict is None: d.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict)
value = (i.get_io_type(), as_dict)
else:
value = (i.get_io_type(), as_dict, dynamic_dict)
d.setdefault(key, {})[i.id] = value
def add_to_dict_v3(io: Input | Output, d: dict): def add_to_dict_v3(io: Input | Output, d: dict):
d[io.id] = (io.get_io_type(), io.as_dict()) d[io.id] = (io.get_io_type(), io.as_dict())
@ -1423,6 +1470,8 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
values = values.copy() values = values.copy()
result = {} result = {}
create_tuple = v3_data.get("create_dynamic_tuple", False)
for key, path in paths.items(): for key, path in paths.items():
parts = path.split(".") parts = path.split(".")
current = result current = result
@ -1431,7 +1480,10 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
is_last = (i == len(parts) - 1) is_last = (i == len(parts) - 1)
if is_last: if is_last:
current[p] = values.pop(key, None) value = values.pop(key, None)
if create_tuple:
value = (value, key)
current[p] = value
else: else:
current = current.setdefault(p, {}) current = current.setdefault(p, {})
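A small illustration of what the new create_dynamic_tuple flag changes (ids and values are hypothetical; this is my reading of the code above, not output captured from the diff):

```python
# Hypothetical inputs to build_nested_inputs:
values = {"combo.width": 512}
v3_data = {"dynamic_paths": {"combo.width": "combo.width"},
           "create_dynamic_tuple": True}
# With the flag set, the nested result holds (value, flat_key) tuples:
#   {"combo": {"width": (512, "combo.width")}}
# Without it (or with False), the plain value is stored:
#   {"combo": {"width": 512}}
```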
@ -1446,7 +1498,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
SCHEMA = None SCHEMA = None
# filled in during execution # filled in during execution
resources: Resources = None
hidden: HiddenHolder = None hidden: HiddenHolder = None
@classmethod @classmethod
@ -1493,7 +1544,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
return [name for name in kwargs if kwargs[name] is None] return [name for name in kwargs if kwargs[name] is None]
def __init__(self): def __init__(self):
self.local_resources: ResourcesLocal = None
self.__class__.VALIDATE_CLASS() self.__class__.VALIDATE_CLASS()
@classmethod @classmethod
@ -1561,7 +1611,7 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
c_type: type[ComfyNode] = cls if is_class(cls) else type(cls) c_type: type[ComfyNode] = cls if is_class(cls) else type(cls)
type_clone: type[ComfyNode] = shallow_clone_class(c_type) type_clone: type[ComfyNode] = shallow_clone_class(c_type)
# set hidden # set hidden
type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"] if v3_data else None) type_clone.hidden = HiddenHolder.from_v3_data(v3_data)
return type_clone return type_clone
@final @final
@ -1678,19 +1728,10 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
@final @final
@classmethod @classmethod
def INPUT_TYPES(cls, include_hidden=True, return_schema=False, live_inputs=None) -> dict[str, dict] | tuple[dict[str, dict], Schema, V3Data]: def INPUT_TYPES(cls) -> dict[str, dict]:
schema = cls.FINALIZE_SCHEMA() schema = cls.FINALIZE_SCHEMA()
info = schema.get_v1_info(cls, live_inputs) info = schema.get_v1_info(cls)
input = info.input return info.input
if not include_hidden:
input.pop("hidden", None)
if return_schema:
v3_data: V3Data = {}
dynamic = input.pop("dynamic_paths", None)
if dynamic is not None:
v3_data["dynamic_paths"] = dynamic
return input, schema, v3_data
return input
@final @final
@classmethod @classmethod
@ -1809,7 +1850,7 @@ class NodeOutput(_NodeOutputInternal):
return self.args if len(self.args) > 0 else None return self.args if len(self.args) > 0 else None
@classmethod @classmethod
def from_dict(cls, data: dict[str, Any]) -> "NodeOutput": def from_dict(cls, data: dict[str, Any]) -> NodeOutput:
args = () args = ()
ui = None ui = None
expand = None expand = None
@ -1904,8 +1945,8 @@ __all__ = [
"Tracks", "Tracks",
# Dynamic Types # Dynamic Types
"MatchType", "MatchType",
# "DynamicCombo", "DynamicCombo",
# "Autogrow", "Autogrow",
# Other classes # Other classes
"HiddenHolder", "HiddenHolder",
"Hidden", "Hidden",

View File

@ -1,72 +0,0 @@
from __future__ import annotations
import comfy.utils
import folder_paths
import logging
from abc import ABC, abstractmethod
from typing import Any
import torch
class ResourceKey(ABC):
Type = Any
def __init__(self):
...
class TorchDictFolderFilename(ResourceKey):
'''Key for requesting a torch file via file_name from a folder category.'''
Type = dict[str, torch.Tensor]
def __init__(self, folder_name: str, file_name: str):
self.folder_name = folder_name
self.file_name = file_name
def __hash__(self):
return hash((self.folder_name, self.file_name))
def __eq__(self, other: object) -> bool:
if not isinstance(other, TorchDictFolderFilename):
return False
return self.folder_name == other.folder_name and self.file_name == other.file_name
def __str__(self):
return f"{self.folder_name} -> {self.file_name}"
class Resources(ABC):
def __init__(self):
...
@abstractmethod
def get(self, key: ResourceKey, default: Any=...) -> Any:
pass
class ResourcesLocal(Resources):
def __init__(self):
super().__init__()
self.local_resources: dict[ResourceKey, Any] = {}
def get(self, key: ResourceKey, default: Any=...) -> Any:
cached = self.local_resources.get(key, None)
if cached is not None:
logging.info(f"Using cached resource '{key}'")
return cached
logging.info(f"Loading resource '{key}'")
to_return = None
if isinstance(key, TorchDictFolderFilename):
if default is ...:
to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True)
else:
full_path = folder_paths.get_full_path(key.folder_name, key.file_name)
if full_path is not None:
to_return = comfy.utils.load_torch_file(full_path, safe_load=True)
if to_return is not None:
self.local_resources[key] = to_return
return to_return
if default is not ...:
return default
raise Exception(f"Unsupported resource key type: {type(key)}")
class _RESOURCES:
ResourceKey = ResourceKey
TorchDictFolderFilename = TorchDictFolderFilename
Resources = Resources
ResourcesLocal = ResourcesLocal

View File

@ -1,5 +1,6 @@
from .video_types import VideoContainer, VideoCodec, VideoComponents from .video_types import VideoContainer, VideoCodec, VideoComponents
from .geometry_types import VOXEL, MESH from .geometry_types import VOXEL, MESH
from .image_types import SVG
__all__ = [ __all__ = [
# Utility Types # Utility Types
@ -8,4 +9,5 @@ __all__ = [
"VideoComponents", "VideoComponents",
"VOXEL", "VOXEL",
"MESH", "MESH",
"SVG",
] ]

View File

@ -0,0 +1,18 @@
from io import BytesIO
class SVG:
"""Stores SVG representations via a list of BytesIO objects."""
def __init__(self, data: list[BytesIO]):
self.data = data
def combine(self, other: 'SVG') -> 'SVG':
return SVG(self.data + other.data)
@staticmethod
def combine_all(svgs: list['SVG']) -> 'SVG':
all_svgs_list: list[BytesIO] = []
for svg_item in svgs:
all_svgs_list.extend(svg_item.data)
return SVG(all_svgs_list)
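A minimal usage sketch of the new SVG container added above; the byte strings below are placeholders:

from io import BytesIO

# wrap two stand-in SVG documents and merge them with the helpers above
svg_a = SVG([BytesIO(b'<svg xmlns="http://www.w3.org/2000/svg"></svg>')])
svg_b = SVG([BytesIO(b'<svg xmlns="http://www.w3.org/2000/svg"></svg>')])
merged = SVG.combine_all([svg_a, svg_b])
assert len(merged.data) == 2  # combine()/combine_all() simply concatenate the BytesIO lists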

View File

@ -102,3 +102,12 @@ class ImageToVideoWithAudioRequest(BaseModel):
prompt: str = Field(...) prompt: str = Field(...)
mode: str = Field("pro") mode: str = Field("pro")
sound: str = Field(..., description="'on' or 'off'") sound: str = Field(..., description="'on' or 'off'")
class MotionControlRequest(BaseModel):
prompt: str = Field(...)
image_url: str = Field(...)
video_url: str = Field(...)
keep_original_sound: str = Field(...)
character_orientation: str = Field(...)
mode: str = Field(..., description="'pro' or 'std'")
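An illustrative payload for the new request model; every field value below is a placeholder:

req = MotionControlRequest(
    prompt="a dancer on a rooftop",                  # placeholder prompt
    image_url="https://example.com/reference.png",   # placeholder reference image URL
    video_url="https://example.com/motion.mp4",      # placeholder motion reference video URL
    keep_original_sound="yes",                       # the node maps its boolean input to "yes"/"no"
    character_orientation="video",                   # "video" or "image"
    mode="pro",                                      # "pro" or "std"
)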

View File

@ -229,6 +229,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
IO.Hidden.unique_id, IO.Hidden.unique_id,
], ],
is_api_node=True, is_api_node=True,
is_deprecated=True,
) )
@classmethod @classmethod
@ -269,7 +270,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
def define_schema(cls): def define_schema(cls):
return IO.Schema( return IO.Schema(
node_id="ByteDanceSeedreamNode", node_id="ByteDanceSeedreamNode",
display_name="ByteDance Seedream 4", display_name="ByteDance Seedream 4.5",
category="api node/image/ByteDance", category="api node/image/ByteDance",
description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.", description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
inputs=[ inputs=[

View File

@ -51,6 +51,7 @@ from comfy_api_nodes.apis import (
) )
from comfy_api_nodes.apis.kling_api import ( from comfy_api_nodes.apis.kling_api import (
ImageToVideoWithAudioRequest, ImageToVideoWithAudioRequest,
MotionControlRequest,
OmniImageParamImage, OmniImageParamImage,
OmniParamImage, OmniParamImage,
OmniParamVideo, OmniParamVideo,
@ -2163,6 +2164,91 @@ class ImageToVideoWithAudio(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
class MotionControl(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingMotionControl",
display_name="Kling Motion Control",
category="api node/video/Kling",
inputs=[
IO.String.Input("prompt", multiline=True),
IO.Image.Input("reference_image"),
IO.Video.Input(
"reference_video",
tooltip="Motion reference video used to drive movement/expression.\n"
"Duration limits depend on character_orientation:\n"
" - image: 310s (max 10s)\n"
" - video: 330s (max 30s)",
),
IO.Boolean.Input("keep_original_sound", default=True),
IO.Combo.Input(
"character_orientation",
options=["video", "image"],
tooltip="Controls where the character's facing/orientation comes from.\n"
"video: movements, expressions, camera moves, and orientation "
"follow the motion reference video (other details via prompt).\n"
"image: movements and expressions still follow the motion reference video, "
"but the character orientation matches the reference image (camera/other details via prompt).",
),
IO.Combo.Input("mode", options=["pro", "std"]),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
prompt: str,
reference_image: Input.Image,
reference_video: Input.Video,
keep_original_sound: bool,
character_orientation: str,
mode: str,
) -> IO.NodeOutput:
validate_string(prompt, max_length=2500)
validate_image_dimensions(reference_image, min_width=340, min_height=340)
validate_image_aspect_ratio(reference_image, (1, 2.5), (2.5, 1))
if character_orientation == "image":
validate_video_duration(reference_video, min_duration=3, max_duration=10)
else:
validate_video_duration(reference_video, min_duration=3, max_duration=30)
validate_video_dimensions(reference_video, min_width=340, min_height=340, max_width=3850, max_height=3850)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/motion-control", method="POST"),
response_model=TaskStatusResponse,
data=MotionControlRequest(
prompt=prompt,
image_url=(await upload_images_to_comfyapi(cls, reference_image))[0],
video_url=await upload_video_to_comfyapi(cls, reference_video),
keep_original_sound="yes" if keep_original_sound else "no",
character_orientation=character_orientation,
mode=mode,
),
)
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/motion-control/{response.data.task_id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
class KlingExtension(ComfyExtension): class KlingExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[IO.ComfyNode]]: async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -2188,6 +2274,7 @@ class KlingExtension(ComfyExtension):
OmniProImageNode, OmniProImageNode,
TextToVideoWithAudio, TextToVideoWithAudio,
ImageToVideoWithAudio, ImageToVideoWithAudio,
MotionControl,
] ]
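A small sketch of the duration rule MotionControl enforces on the reference video, taken from the validate_video_duration calls above:

def motion_control_duration_limits(character_orientation: str) -> tuple[int, int]:
    # "image" orientation allows a 3-10s reference video, "video" orientation allows 3-30s
    return (3, 10) if character_orientation == "image" else (3, 30)

assert motion_control_duration_limits("image") == (3, 10)
assert motion_control_duration_limits("video") == (3, 30)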

View File

@ -168,6 +168,8 @@ class VeoVideoGenerationNode(IO.ComfyNode):
# Only add generateAudio for Veo 3 models # Only add generateAudio for Veo 3 models
if model.find("veo-2.0") == -1: if model.find("veo-2.0") == -1:
parameters["generateAudio"] = generate_audio parameters["generateAudio"] = generate_audio
# force "enhance_prompt" to True for Veo3 models
parameters["enhancePrompt"] = True
initial_response = await sync_op( initial_response = await sync_op(
cls, cls,
@ -291,7 +293,7 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
IO.Boolean.Input( IO.Boolean.Input(
"enhance_prompt", "enhance_prompt",
default=True, default=True,
tooltip="Whether to enhance the prompt with AI assistance", tooltip="This parameter is deprecated and ignored.",
optional=True, optional=True,
), ),
IO.Combo.Input( IO.Combo.Input(

View File

@ -1,16 +1,22 @@
import asyncio import asyncio
import contextlib import contextlib
import os import os
import re
import time import time
from collections.abc import Callable from collections.abc import Callable
from io import BytesIO from io import BytesIO
from yarl import URL
from comfy.cli_args import args from comfy.cli_args import args
from comfy.model_management import processing_interrupted from comfy.model_management import processing_interrupted
from comfy_api.latest import IO from comfy_api.latest import IO
from .common_exceptions import ProcessingInterrupted from .common_exceptions import ProcessingInterrupted
_HAS_PCT_ESC = re.compile(r"%[0-9A-Fa-f]{2}") # any % followed by 2 hex digits
_HAS_BAD_PCT = re.compile(r"%(?![0-9A-Fa-f]{2})") # any % not followed by 2 hex digits
def is_processing_interrupted() -> bool: def is_processing_interrupted() -> bool:
"""Return True if user/runtime requested interruption.""" """Return True if user/runtime requested interruption."""
@ -69,3 +75,17 @@ def get_fs_object_size(path_or_object: str | BytesIO) -> int:
if isinstance(path_or_object, str): if isinstance(path_or_object, str):
return os.path.getsize(path_or_object) return os.path.getsize(path_or_object)
return len(path_or_object.getvalue()) return len(path_or_object.getvalue())
def to_aiohttp_url(url: str) -> URL:
"""If `url` appears to be already percent-encoded (contains at least one valid %HH
escape and no malformed '%' sequences) and contains no raw whitespace/control
characters, preserve the original encoding byte-for-byte (important for signed/presigned URLs).
Otherwise, return `URL(url)` and allow yarl to normalize/quote as needed."""
if any(c.isspace() for c in url) or any(ord(c) < 0x20 for c in url):
# Avoid encoded=True if URL contains raw whitespace/control chars
return URL(url)
if _HAS_PCT_ESC.search(url) and not _HAS_BAD_PCT.search(url):
# Preserve encoding only if it appears pre-encoded AND has no invalid % sequences
return URL(url, encoded=True)
return URL(url)
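A short sketch of how the helper handles the two URL classes it distinguishes; the URLs below are made up, and the results assume yarl's default requoting behavior:

from yarl import URL

presigned = "https://example.com/bucket/a%2Fb?sig=abc"  # valid %HH escape, no bad '%' -> encoded=True
raw = "https://example.com/a b"                         # raw space -> let yarl quote it

print(to_aiohttp_url(presigned))  # expected to stay byte-for-byte identical
print(to_aiohttp_url(raw))        # expected: https://example.com/a%20b after yarl's normalization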

View File

@ -430,9 +430,9 @@ def _display_text(
if status: if status:
display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}") display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}")
if price is not None: if price is not None:
p = f"{float(price):,.4f}".rstrip("0").rstrip(".") p = f"{float(price) * 211:,.1f}".rstrip("0").rstrip(".")
if p != "0": if p != "0":
display_lines.append(f"Price: ${p}") display_lines.append(f"Price: {p} credits")
if text is not None: if text is not None:
display_lines.append(text) display_lines.append(text)
if display_lines: if display_lines:
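A quick check of the new credits formatting, assuming the incoming price is a dollar amount that the 211 multiplier converts to credits:

price = 0.02
p = f"{float(price) * 211:,.1f}".rstrip("0").rstrip(".")
print(f"Price: {p} credits")  # -> "Price: 4.2 credits"; a zero price collapses to "0" and is skipped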

View File

@ -19,6 +19,7 @@ from ._helpers import (
get_auth_header, get_auth_header,
is_processing_interrupted, is_processing_interrupted,
sleep_with_interrupt, sleep_with_interrupt,
to_aiohttp_url,
) )
from .client import _diagnose_connectivity from .client import _diagnose_connectivity
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
@ -94,7 +95,7 @@ async def download_url_to_bytesio(
monitor_task = asyncio.create_task(_monitor()) monitor_task = asyncio.create_task(_monitor())
req_task = asyncio.create_task(session.get(url, headers=headers)) req_task = asyncio.create_task(session.get(to_aiohttp_url(url), headers=headers))
done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED) done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)
if monitor_task in done and req_task in pending: if monitor_task in done and req_task in pending:

View File

@ -97,6 +97,11 @@ def get_input_info(
extra_info = input_info[1] extra_info = input_info[1]
else: else:
extra_info = {} extra_info = {}
# if input_type is a list, it is a Combo defined in outdated format; convert it.
# NOTE: uncomment this when we are confident the old format going away won't cause too much trouble.
# if isinstance(input_type, list):
# extra_info["options"] = input_type
# input_type = IO.Combo.io_type
return input_type, input_category, extra_info return input_type, input_category, extra_info
class TopologicalSort: class TopologicalSort:

View File

@ -21,14 +21,24 @@ def validate_node_input(
""" """
# If the types are exactly the same, we can return immediately # If the types are exactly the same, we can return immediately
# Use pre-union behaviour: inverse of `__ne__` # Use pre-union behaviour: inverse of `__ne__`
# NOTE: this lets legacy '*' Any types work that override the __ne__ method of the str class.
if not received_type != input_type: if not received_type != input_type:
return True return True
# If one of the types is '*', we can return True immediately; this is the 'Any' type.
if received_type == IO.AnyType.io_type or input_type == IO.AnyType.io_type:
return True
# If the received type or input_type is a MatchType, we can return True immediately; # If the received type or input_type is a MatchType, we can return True immediately;
# validation for this is handled by the frontend # validation for this is handled by the frontend
if received_type == IO.MatchType.io_type or input_type == IO.MatchType.io_type: if received_type == IO.MatchType.io_type or input_type == IO.MatchType.io_type:
return True return True
# This accounts for some custom nodes that output lists of options as the type;
# if we ever want to break them on purpose, this can be removed
if isinstance(received_type, list) and input_type == IO.Combo.io_type:
return True
# Not equal, and not strings # Not equal, and not strings
if not isinstance(received_type, str) or not isinstance(input_type, str): if not isinstance(received_type, str) or not isinstance(input_type, str):
return False return False
@ -37,6 +47,10 @@ def validate_node_input(
received_types = set(t.strip() for t in received_type.split(",")) received_types = set(t.strip() for t in received_type.split(","))
input_types = set(t.strip() for t in input_type.split(",")) input_types = set(t.strip() for t in input_type.split(","))
# If any of the types is '*', we can return True immediately; this is the 'Any' type.
if IO.AnyType.io_type in received_types or IO.AnyType.io_type in input_types:
return True
if strict: if strict:
# In strict mode, all received types must be in the input types # In strict mode, all received types must be in the input types
return received_types.issubset(input_types) return received_types.issubset(input_types)
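A few illustrative calls, assuming the signature validate_node_input(received_type, input_type, strict=False) used in the hunk above:

validate_node_input("IMAGE", "IMAGE")                     # exact match -> True
validate_node_input("*", "LATENT")                        # '*' (IO.AnyType) on either side -> True
validate_node_input(["a", "b"], "COMBO")                  # legacy list-of-options vs Combo -> True
validate_node_input("IMAGE,MASK", "IMAGE", strict=True)   # strict: received must be a subset -> False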

View File

@ -667,16 +667,19 @@ class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
@classmethod @classmethod
def _process(cls, image, longer_edge): def _process(cls, image, longer_edge):
img = tensor_to_pil(image) resized_images = []
w, h = img.size for image_i in image:
if w > h: img = tensor_to_pil(image_i)
new_w = longer_edge w, h = img.size
new_h = int(h * (longer_edge / w)) if w > h:
else: new_w = longer_edge
new_h = longer_edge new_h = int(h * (longer_edge / w))
new_w = int(w * (longer_edge / h)) else:
img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) new_h = longer_edge
return pil_to_tensor(img) new_w = int(w * (longer_edge / h))
img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
resized_images.append(pil_to_tensor(img))
return torch.cat(resized_images, dim=0)
class CenterCropImagesNode(ImageProcessingNode): class CenterCropImagesNode(ImageProcessingNode):
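As a worked example of the batched _process above: for a batch of four 800x600 images with longer_edge=400, each frame resizes to 400x300 (600 * 400/800 = 300), and the per-image results are concatenated back into a single (4, 300, 400, 3) tensor, so the whole batch is preserved instead of collapsing to one image.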

View File

@ -2,280 +2,231 @@ from __future__ import annotations
import nodes import nodes
import folder_paths import folder_paths
from comfy.cli_args import args
from PIL import Image
from PIL.PngImagePlugin import PngInfo
import numpy as np
import json import json
import os import os
import re import re
from io import BytesIO
from inspect import cleandoc
import torch import torch
import comfy.utils import comfy.utils
from comfy.comfy_types import FileLocator, IO
from server import PromptServer from server import PromptServer
from comfy_api.latest import ComfyExtension, IO, UI
from typing_extensions import override
SVG = IO.SVG.Type # TODO: temporary solution for backward compatibility, will be removed later.
MAX_RESOLUTION = nodes.MAX_RESOLUTION MAX_RESOLUTION = nodes.MAX_RESOLUTION
class ImageCrop: class ImageCrop(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "image": ("IMAGE",), return IO.Schema(
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), node_id="ImageCrop",
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), display_name="Image Crop",
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), category="image/transform",
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), inputs=[
}} IO.Image.Input("image"),
RETURN_TYPES = ("IMAGE",) IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
FUNCTION = "crop" IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
],
outputs=[IO.Image.Output()],
)
CATEGORY = "image/transform" @classmethod
def execute(cls, image, width, height, x, y) -> IO.NodeOutput:
def crop(self, image, width, height, x, y):
x = min(x, image.shape[2] - 1) x = min(x, image.shape[2] - 1)
y = min(y, image.shape[1] - 1) y = min(y, image.shape[1] - 1)
to_x = width + x to_x = width + x
to_y = height + y to_y = height + y
img = image[:,y:to_y, x:to_x, :] img = image[:,y:to_y, x:to_x, :]
return (img,) return IO.NodeOutput(img)
class RepeatImageBatch: crop = execute # TODO: remove
class RepeatImageBatch(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "image": ("IMAGE",), return IO.Schema(
"amount": ("INT", {"default": 1, "min": 1, "max": 4096}), node_id="RepeatImageBatch",
}} category="image/batch",
RETURN_TYPES = ("IMAGE",) inputs=[
FUNCTION = "repeat" IO.Image.Input("image"),
IO.Int.Input("amount", default=1, min=1, max=4096),
],
outputs=[IO.Image.Output()],
)
CATEGORY = "image/batch" @classmethod
def execute(cls, image, amount) -> IO.NodeOutput:
def repeat(self, image, amount):
s = image.repeat((amount, 1,1,1)) s = image.repeat((amount, 1,1,1))
return (s,) return IO.NodeOutput(s)
class ImageFromBatch: repeat = execute # TODO: remove
class ImageFromBatch(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "image": ("IMAGE",), return IO.Schema(
"batch_index": ("INT", {"default": 0, "min": 0, "max": 4095}), node_id="ImageFromBatch",
"length": ("INT", {"default": 1, "min": 1, "max": 4096}), category="image/batch",
}} inputs=[
RETURN_TYPES = ("IMAGE",) IO.Image.Input("image"),
FUNCTION = "frombatch" IO.Int.Input("batch_index", default=0, min=0, max=4095),
IO.Int.Input("length", default=1, min=1, max=4096),
],
outputs=[IO.Image.Output()],
)
CATEGORY = "image/batch" @classmethod
def execute(cls, image, batch_index, length) -> IO.NodeOutput:
def frombatch(self, image, batch_index, length):
s_in = image s_in = image
batch_index = min(s_in.shape[0] - 1, batch_index) batch_index = min(s_in.shape[0] - 1, batch_index)
length = min(s_in.shape[0] - batch_index, length) length = min(s_in.shape[0] - batch_index, length)
s = s_in[batch_index:batch_index + length].clone() s = s_in[batch_index:batch_index + length].clone()
return (s,) return IO.NodeOutput(s)
frombatch = execute # TODO: remove
class ImageAddNoise: class ImageAddNoise(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "image": ("IMAGE",), return IO.Schema(
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True, "tooltip": "The random seed used for creating the noise."}), node_id="ImageAddNoise",
"strength": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), category="image",
}} inputs=[
RETURN_TYPES = ("IMAGE",) IO.Image.Input("image"),
FUNCTION = "repeat" IO.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
IO.Float.Input("strength", default=0.5, min=0.0, max=1.0, step=0.01),
],
outputs=[IO.Image.Output()],
)
CATEGORY = "image" @classmethod
def execute(cls, image, seed, strength) -> IO.NodeOutput:
def repeat(self, image, seed, strength):
generator = torch.manual_seed(seed) generator = torch.manual_seed(seed)
s = torch.clip((image + strength * torch.randn(image.size(), generator=generator, device="cpu").to(image)), min=0.0, max=1.0) s = torch.clip((image + strength * torch.randn(image.size(), generator=generator, device="cpu").to(image)), min=0.0, max=1.0)
return (s,) return IO.NodeOutput(s)
class SaveAnimatedWEBP: repeat = execute # TODO: remove
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
methods = {"default": 4, "fastest": 0, "slowest": 6}
@classmethod
def INPUT_TYPES(s):
return {"required":
{"images": ("IMAGE", ),
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
"fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
"lossless": ("BOOLEAN", {"default": True}),
"quality": ("INT", {"default": 80, "min": 0, "max": 100}),
"method": (list(s.methods.keys()),),
# "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = () class SaveAnimatedWEBP(IO.ComfyNode):
FUNCTION = "save_images" COMPRESS_METHODS = {"default": 4, "fastest": 0, "slowest": 6}
OUTPUT_NODE = True
CATEGORY = "image/animation"
def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None):
method = self.methods.get(method)
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
results: list[FileLocator] = []
pil_images = []
for image in images:
i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
pil_images.append(img)
metadata = pil_images[0].getexif()
if not args.disable_metadata:
if prompt is not None:
metadata[0x0110] = "prompt:{}".format(json.dumps(prompt))
if extra_pnginfo is not None:
inital_exif = 0x010f
for x in extra_pnginfo:
metadata[inital_exif] = "{}:{}".format(x, json.dumps(extra_pnginfo[x]))
inital_exif -= 1
if num_frames == 0:
num_frames = len(pil_images)
c = len(pil_images)
for i in range(0, c, num_frames):
file = f"{filename}_{counter:05}_.webp"
pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0/fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method)
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
counter += 1
animated = num_frames != 1
return { "ui": { "images": results, "animated": (animated,) } }
class SaveAnimatedPNG:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": return IO.Schema(
{"images": ("IMAGE", ), node_id="SaveAnimatedWEBP",
"filename_prefix": ("STRING", {"default": "ComfyUI"}), category="image/animation",
"fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}), inputs=[
"compress_level": ("INT", {"default": 4, "min": 0, "max": 9}) IO.Image.Input("images"),
}, IO.String.Input("filename_prefix", default="ComfyUI"),
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01),
} IO.Boolean.Input("lossless", default=True),
IO.Int.Input("quality", default=80, min=0, max=100),
IO.Combo.Input("method", options=list(cls.COMPRESS_METHODS.keys())),
# "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
)
RETURN_TYPES = () @classmethod
FUNCTION = "save_images" def execute(cls, images, fps, filename_prefix, lossless, quality, method, num_frames=0) -> IO.NodeOutput:
return IO.NodeOutput(
ui=UI.ImageSaveHelper.get_save_animated_webp_ui(
images=images,
filename_prefix=filename_prefix,
cls=cls,
fps=fps,
lossless=lossless,
quality=quality,
method=cls.COMPRESS_METHODS.get(method)
)
)
OUTPUT_NODE = True save_images = execute # TODO: remove
CATEGORY = "image/animation"
def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
results = list()
pil_images = []
for image in images:
i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
pil_images.append(img)
metadata = None
if not args.disable_metadata:
metadata = PngInfo()
if prompt is not None:
metadata.add(b"comf", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(prompt).encode("latin-1", "strict"), after_idat=True)
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata.add(b"comf", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True)
file = f"{filename}_{counter:05}_.png"
pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:])
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
return { "ui": { "images": results, "animated": (True,)} }
class SVG:
"""
Stores SVG representations via a list of BytesIO objects.
"""
def __init__(self, data: list[BytesIO]):
self.data = data
def combine(self, other: 'SVG') -> 'SVG':
return SVG(self.data + other.data)
@staticmethod
def combine_all(svgs: list['SVG']) -> 'SVG':
all_svgs_list: list[BytesIO] = []
for svg_item in svgs:
all_svgs_list.extend(svg_item.data)
return SVG(all_svgs_list)
class ImageStitch: class SaveAnimatedPNG(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="SaveAnimatedPNG",
category="image/animation",
inputs=[
IO.Image.Input("images"),
IO.String.Input("filename_prefix", default="ComfyUI"),
IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01),
IO.Int.Input("compress_level", default=4, min=0, max=9),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def execute(cls, images, fps, compress_level, filename_prefix="ComfyUI") -> IO.NodeOutput:
return IO.NodeOutput(
ui=UI.ImageSaveHelper.get_save_animated_png_ui(
images=images,
filename_prefix=filename_prefix,
cls=cls,
fps=fps,
compress_level=compress_level,
)
)
save_images = execute # TODO: remove
class ImageStitch(IO.ComfyNode):
"""Upstreamed from https://github.com/kijai/ComfyUI-KJNodes""" """Upstreamed from https://github.com/kijai/ComfyUI-KJNodes"""
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return { return IO.Schema(
"required": { node_id="ImageStitch",
"image1": ("IMAGE",), display_name="Image Stitch",
"direction": (["right", "down", "left", "up"], {"default": "right"}), description="Stitches image2 to image1 in the specified direction.\n"
"match_image_size": ("BOOLEAN", {"default": True}), "If image2 is not provided, returns image1 unchanged.\n"
"spacing_width": ( "Optional spacing can be added between images.",
"INT", category="image/transform",
{"default": 0, "min": 0, "max": 1024, "step": 2}, inputs=[
), IO.Image.Input("image1"),
"spacing_color": ( IO.Combo.Input("direction", options=["right", "down", "left", "up"], default="right"),
["white", "black", "red", "green", "blue"], IO.Boolean.Input("match_image_size", default=True),
{"default": "white"}, IO.Int.Input("spacing_width", default=0, min=0, max=1024, step=2),
), IO.Combo.Input("spacing_color", options=["white", "black", "red", "green", "blue"], default="white"),
}, IO.Image.Input("image2", optional=True),
"optional": { ],
"image2": ("IMAGE",), outputs=[IO.Image.Output()],
}, )
}
RETURN_TYPES = ("IMAGE",) @classmethod
FUNCTION = "stitch" def execute(
CATEGORY = "image/transform" cls,
DESCRIPTION = """
Stitches image2 to image1 in the specified direction.
If image2 is not provided, returns image1 unchanged.
Optional spacing can be added between images.
"""
def stitch(
self,
image1, image1,
direction, direction,
match_image_size, match_image_size,
spacing_width, spacing_width,
spacing_color, spacing_color,
image2=None, image2=None,
): ) -> IO.NodeOutput:
if image2 is None: if image2 is None:
return (image1,) return IO.NodeOutput(image1)
# Handle batch size differences # Handle batch size differences
if image1.shape[0] != image2.shape[0]: if image1.shape[0] != image2.shape[0]:
@ -412,36 +363,30 @@ Optional spacing can be added between images.
images.insert(1, spacing) images.insert(1, spacing)
concat_dim = 2 if direction in ["left", "right"] else 1 concat_dim = 2 if direction in ["left", "right"] else 1
return (torch.cat(images, dim=concat_dim),) return IO.NodeOutput(torch.cat(images, dim=concat_dim))
stitch = execute # TODO: remove
class ResizeAndPadImage(IO.ComfyNode):
class ResizeAndPadImage:
@classmethod @classmethod
def INPUT_TYPES(cls): def define_schema(cls):
return { return IO.Schema(
"required": { node_id="ResizeAndPadImage",
"image": ("IMAGE",), category="image/transform",
"target_width": ("INT", { inputs=[
"default": 512, IO.Image.Input("image"),
"min": 1, IO.Int.Input("target_width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
"max": MAX_RESOLUTION, IO.Int.Input("target_height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
"step": 1 IO.Combo.Input("padding_color", options=["white", "black"]),
}), IO.Combo.Input("interpolation", options=["area", "bicubic", "nearest-exact", "bilinear", "lanczos"]),
"target_height": ("INT", { ],
"default": 512, outputs=[IO.Image.Output()],
"min": 1, )
"max": MAX_RESOLUTION,
"step": 1
}),
"padding_color": (["white", "black"],),
"interpolation": (["area", "bicubic", "nearest-exact", "bilinear", "lanczos"],),
}
}
RETURN_TYPES = ("IMAGE",) @classmethod
FUNCTION = "resize_and_pad" def execute(cls, image, target_width, target_height, padding_color, interpolation) -> IO.NodeOutput:
CATEGORY = "image/transform"
def resize_and_pad(self, image, target_width, target_height, padding_color, interpolation):
batch_size, orig_height, orig_width, channels = image.shape batch_size, orig_height, orig_width, channels = image.shape
scale_w = target_width / orig_width scale_w = target_width / orig_width
@ -469,52 +414,47 @@ class ResizeAndPadImage:
padded[:, :, y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized padded[:, :, y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized
output = padded.permute(0, 2, 3, 1) output = padded.permute(0, 2, 3, 1)
return (output,) return IO.NodeOutput(output)
class SaveSVGNode: resize_and_pad = execute # TODO: remove
"""
Save SVG files on disk.
"""
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
RETURN_TYPES = () class SaveSVGNode(IO.ComfyNode):
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "save_svg"
CATEGORY = "image/save" # Changed
OUTPUT_NODE = True
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return { return IO.Schema(
"required": { node_id="SaveSVGNode",
"svg": ("SVG",), # Changed description="Save SVG files on disk.",
"filename_prefix": ("STRING", {"default": "svg/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}) category="image/save",
}, inputs=[
"hidden": { IO.SVG.Input("svg"),
"prompt": "PROMPT", IO.String.Input(
"extra_pnginfo": "EXTRA_PNGINFO" "filename_prefix",
} default="svg/ComfyUI",
} tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes.",
),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
)
def save_svg(self, svg: SVG, filename_prefix="svg/ComfyUI", prompt=None, extra_pnginfo=None): @classmethod
filename_prefix += self.prefix_append def execute(cls, svg: IO.SVG.Type, filename_prefix="svg/ComfyUI") -> IO.NodeOutput:
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
results = list() results: list[UI.SavedResult] = []
# Prepare metadata JSON # Prepare metadata JSON
metadata_dict = {} metadata_dict = {}
if prompt is not None: if cls.hidden.prompt is not None:
metadata_dict["prompt"] = prompt metadata_dict["prompt"] = cls.hidden.prompt
if extra_pnginfo is not None: if cls.hidden.extra_pnginfo is not None:
metadata_dict.update(extra_pnginfo) metadata_dict.update(cls.hidden.extra_pnginfo)
# Convert metadata to JSON string # Convert metadata to JSON string
metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None
for batch_number, svg_bytes in enumerate(svg.data): for batch_number, svg_bytes in enumerate(svg.data):
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.svg" file = f"{filename_with_batch_num}_{counter:05}_.svg"
@ -544,57 +484,64 @@ class SaveSVGNode:
with open(os.path.join(full_output_folder, file), 'wb') as svg_file: with open(os.path.join(full_output_folder, file), 'wb') as svg_file:
svg_file.write(svg_content.encode('utf-8')) svg_file.write(svg_content.encode('utf-8'))
results.append({ results.append(UI.SavedResult(filename=file, subfolder=subfolder, type=IO.FolderType.output))
"filename": file,
"subfolder": subfolder,
"type": self.type
})
counter += 1 counter += 1
return { "ui": { "images": results } } return IO.NodeOutput(ui={"images": results})
class GetImageSize: save_svg = execute # TODO: remove
class GetImageSize(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return { return IO.Schema(
"required": { node_id="GetImageSize",
"image": (IO.IMAGE,), display_name="Get Image Size",
}, description="Returns width and height of the image, and passes it through unchanged.",
"hidden": { category="image",
"unique_id": "UNIQUE_ID", inputs=[
} IO.Image.Input("image"),
} ],
outputs=[
IO.Int.Output(display_name="width"),
IO.Int.Output(display_name="height"),
IO.Int.Output(display_name="batch_size"),
],
hidden=[IO.Hidden.unique_id],
)
RETURN_TYPES = (IO.INT, IO.INT, IO.INT) @classmethod
RETURN_NAMES = ("width", "height", "batch_size") def execute(cls, image) -> IO.NodeOutput:
FUNCTION = "get_size"
CATEGORY = "image"
DESCRIPTION = """Returns width and height of the image, and passes it through unchanged."""
def get_size(self, image, unique_id=None) -> tuple[int, int]:
height = image.shape[1] height = image.shape[1]
width = image.shape[2] width = image.shape[2]
batch_size = image.shape[0] batch_size = image.shape[0]
# Send progress text to display size on the node # Send progress text to display size on the node
if unique_id: if cls.hidden.unique_id:
PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", unique_id) PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", cls.hidden.unique_id)
return width, height, batch_size return IO.NodeOutput(width, height, batch_size)
get_size = execute # TODO: remove
class ImageRotate(IO.ComfyNode):
class ImageRotate:
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "image": (IO.IMAGE,), return IO.Schema(
"rotation": (["none", "90 degrees", "180 degrees", "270 degrees"],), node_id="ImageRotate",
}} category="image/transform",
RETURN_TYPES = (IO.IMAGE,) inputs=[
FUNCTION = "rotate" IO.Image.Input("image"),
IO.Combo.Input("rotation", options=["none", "90 degrees", "180 degrees", "270 degrees"]),
],
outputs=[IO.Image.Output()],
)
CATEGORY = "image/transform" @classmethod
def execute(cls, image, rotation) -> IO.NodeOutput:
def rotate(self, image, rotation):
rotate_by = 0 rotate_by = 0
if rotation.startswith("90"): if rotation.startswith("90"):
rotate_by = 1 rotate_by = 1
@ -604,41 +551,57 @@ class ImageRotate:
rotate_by = 3 rotate_by = 3
image = torch.rot90(image, k=rotate_by, dims=[2, 1]) image = torch.rot90(image, k=rotate_by, dims=[2, 1])
return (image,) return IO.NodeOutput(image)
rotate = execute # TODO: remove
class ImageFlip(IO.ComfyNode):
class ImageFlip:
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "image": (IO.IMAGE,), return IO.Schema(
"flip_method": (["x-axis: vertically", "y-axis: horizontally"],), node_id="ImageFlip",
}} category="image/transform",
RETURN_TYPES = (IO.IMAGE,) inputs=[
FUNCTION = "flip" IO.Image.Input("image"),
IO.Combo.Input("flip_method", options=["x-axis: vertically", "y-axis: horizontally"]),
],
outputs=[IO.Image.Output()],
)
CATEGORY = "image/transform" @classmethod
def execute(cls, image, flip_method) -> IO.NodeOutput:
def flip(self, image, flip_method):
if flip_method.startswith("x"): if flip_method.startswith("x"):
image = torch.flip(image, dims=[1]) image = torch.flip(image, dims=[1])
elif flip_method.startswith("y"): elif flip_method.startswith("y"):
image = torch.flip(image, dims=[2]) image = torch.flip(image, dims=[2])
return (image,) return IO.NodeOutput(image)
class ImageScaleToMaxDimension: flip = execute # TODO: remove
upscale_methods = ["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"]
class ImageScaleToMaxDimension(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"image": ("IMAGE",), return IO.Schema(
"upscale_method": (s.upscale_methods,), node_id="ImageScaleToMaxDimension",
"largest_size": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1})}} category="image/upscaling",
RETURN_TYPES = ("IMAGE",) inputs=[
FUNCTION = "upscale" IO.Image.Input("image"),
IO.Combo.Input(
"upscale_method",
options=["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"],
),
IO.Int.Input("largest_size", default=512, min=0, max=MAX_RESOLUTION, step=1),
],
outputs=[IO.Image.Output()],
)
CATEGORY = "image/upscaling" @classmethod
def execute(cls, image, upscale_method, largest_size) -> IO.NodeOutput:
def upscale(self, image, upscale_method, largest_size):
height = image.shape[1] height = image.shape[1]
width = image.shape[2] width = image.shape[2]
@ -655,20 +618,30 @@ class ImageScaleToMaxDimension:
samples = image.movedim(-1, 1) samples = image.movedim(-1, 1)
s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled") s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
s = s.movedim(1, -1) s = s.movedim(1, -1)
return (s,) return IO.NodeOutput(s)
NODE_CLASS_MAPPINGS = { upscale = execute # TODO: remove
"ImageCrop": ImageCrop,
"RepeatImageBatch": RepeatImageBatch,
"ImageFromBatch": ImageFromBatch, class ImagesExtension(ComfyExtension):
"ImageAddNoise": ImageAddNoise, @override
"SaveAnimatedWEBP": SaveAnimatedWEBP, async def get_node_list(self) -> list[type[IO.ComfyNode]]:
"SaveAnimatedPNG": SaveAnimatedPNG, return [
"SaveSVGNode": SaveSVGNode, ImageCrop,
"ImageStitch": ImageStitch, RepeatImageBatch,
"ResizeAndPadImage": ResizeAndPadImage, ImageFromBatch,
"GetImageSize": GetImageSize, ImageAddNoise,
"ImageRotate": ImageRotate, SaveAnimatedWEBP,
"ImageFlip": ImageFlip, SaveAnimatedPNG,
"ImageScaleToMaxDimension": ImageScaleToMaxDimension, SaveSVGNode,
} ImageStitch,
ResizeAndPadImage,
GetImageSize,
ImageRotate,
ImageFlip,
ImageScaleToMaxDimension,
]
async def comfy_entrypoint() -> ImagesExtension:
return ImagesExtension()

View File

@ -255,6 +255,7 @@ class LatentBatch(io.ComfyNode):
return io.Schema( return io.Schema(
node_id="LatentBatch", node_id="LatentBatch",
category="latent/batch", category="latent/batch",
is_deprecated=True,
inputs=[ inputs=[
io.Latent.Input("samples1"), io.Latent.Input("samples1"),
io.Latent.Input("samples2"), io.Latent.Input("samples2"),

View File

@ -1,8 +1,11 @@
from __future__ import annotations
from typing import TypedDict from typing import TypedDict
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import ComfyExtension, io from comfy_api.latest import ComfyExtension, io
from comfy_api.latest import _io from comfy_api.latest import _io
# sentinel for missing inputs
MISSING = object()
class SwitchNode(io.ComfyNode): class SwitchNode(io.ComfyNode):
@ -14,6 +17,37 @@ class SwitchNode(io.ComfyNode):
display_name="Switch", display_name="Switch",
category="logic", category="logic",
is_experimental=True, is_experimental=True,
inputs=[
io.Boolean.Input("switch"),
io.MatchType.Input("on_false", template=template, lazy=True),
io.MatchType.Input("on_true", template=template, lazy=True),
],
outputs=[
io.MatchType.Output(template=template, display_name="output"),
],
)
@classmethod
def check_lazy_status(cls, switch, on_false=None, on_true=None):
if switch and on_true is None:
return ["on_true"]
if not switch and on_false is None:
return ["on_false"]
@classmethod
def execute(cls, switch, on_true, on_false) -> io.NodeOutput:
return io.NodeOutput(on_true if switch else on_false)
class SoftSwitchNode(io.ComfyNode):
@classmethod
def define_schema(cls):
template = io.MatchType.Template("switch")
return io.Schema(
node_id="ComfySoftSwitchNode",
display_name="Soft Switch",
category="logic",
is_experimental=True,
inputs=[ inputs=[
io.Boolean.Input("switch"), io.Boolean.Input("switch"),
io.MatchType.Input("on_false", template=template, lazy=True, optional=True), io.MatchType.Input("on_false", template=template, lazy=True, optional=True),
@ -25,14 +59,14 @@ class SwitchNode(io.ComfyNode):
) )
@classmethod @classmethod
def check_lazy_status(cls, switch, on_false=..., on_true=...): def check_lazy_status(cls, switch, on_false=MISSING, on_true=MISSING):
# We use ... instead of None, as None is passed for connected-but-unevaluated inputs. # We use MISSING instead of None, as None is passed for connected-but-unevaluated inputs.
# This trick allows us to ignore the value of the switch and still be able to run execute(). # This trick allows us to ignore the value of the switch and still be able to run execute().
# One of the inputs may be missing, in which case we need to evaluate the other input # One of the inputs may be missing, in which case we need to evaluate the other input
if on_false is ...: if on_false is MISSING:
return ["on_true"] return ["on_true"]
if on_true is ...: if on_true is MISSING:
return ["on_false"] return ["on_false"]
# Normal lazy switch operation # Normal lazy switch operation
if switch and on_true is None: if switch and on_true is None:
@ -41,22 +75,50 @@ class SwitchNode(io.ComfyNode):
return ["on_false"] return ["on_false"]
@classmethod @classmethod
def validate_inputs(cls, switch, on_false=..., on_true=...): def validate_inputs(cls, switch, on_false=MISSING, on_true=MISSING):
# This check happens before check_lazy_status(), so we can eliminate the case where # This check happens before check_lazy_status(), so we can eliminate the case where
# both inputs are missing. # both inputs are missing.
if on_false is ... and on_true is ...: if on_false is MISSING and on_true is MISSING:
return "At least one of on_false or on_true must be connected to Switch node" return "At least one of on_false or on_true must be connected to Switch node"
return True return True
@classmethod @classmethod
def execute(cls, switch, on_true=..., on_false=...) -> io.NodeOutput: def execute(cls, switch, on_true=MISSING, on_false=MISSING) -> io.NodeOutput:
if on_true is ...: if on_true is MISSING:
return io.NodeOutput(on_false) return io.NodeOutput(on_false)
if on_false is ...: if on_false is MISSING:
return io.NodeOutput(on_true) return io.NodeOutput(on_true)
return io.NodeOutput(on_true if switch else on_false) return io.NodeOutput(on_true if switch else on_false)
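A minimal sketch of why a module-level MISSING sentinel is used instead of None for the optional switch inputs: None means connected but lazily unevaluated, while MISSING means not connected at all.

MISSING = object()

def pick(switch, on_true=MISSING, on_false=MISSING):
    # mirrors the execute() fallbacks above
    if on_true is MISSING:
        return on_false
    if on_false is MISSING:
        return on_true
    return on_true if switch else on_false

print(pick(True, on_true=None, on_false=5))  # -> None: the input is connected, just not evaluated yet
print(pick(True, on_false=5))                # -> 5: on_true is simply not connected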
class CustomComboNode(io.ComfyNode):
"""
Frontend node that allows the user to write their own options for a combo.
This exists so the node has a backend representation, avoiding some annoyances.
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CustomCombo",
display_name="Custom Combo",
category="utils",
is_experimental=True,
inputs=[io.Combo.Input("choice", options=[])],
outputs=[io.String.Output()]
)
@classmethod
def validate_inputs(cls, choice: io.Combo.Type) -> bool:
# NOTE: DO NOT DO THIS unless you want to skip validation entirely on the node's inputs.
# I am doing that here because the widgets (besides the combo dropdown) on this node are fully frontend defined.
# I need to skip checking that the chosen combo option is in the options list, since those are defined by the user.
return True
@classmethod
def execute(cls, choice: io.Combo.Type) -> io.NodeOutput:
return io.NodeOutput(choice)
class DCTestNode(io.ComfyNode): class DCTestNode(io.ComfyNode):
class DCValues(TypedDict): class DCValues(TypedDict):
combo: str combo: str
@ -72,14 +134,14 @@ class DCTestNode(io.ComfyNode):
display_name="DCTest", display_name="DCTest",
category="logic", category="logic",
is_output_node=True, is_output_node=True,
inputs=[_io.DynamicCombo.Input("combo", options=[ inputs=[io.DynamicCombo.Input("combo", options=[
_io.DynamicCombo.Option("option1", [io.String.Input("string")]), io.DynamicCombo.Option("option1", [io.String.Input("string")]),
_io.DynamicCombo.Option("option2", [io.Int.Input("integer")]), io.DynamicCombo.Option("option2", [io.Int.Input("integer")]),
_io.DynamicCombo.Option("option3", [io.Image.Input("image")]), io.DynamicCombo.Option("option3", [io.Image.Input("image")]),
_io.DynamicCombo.Option("option4", [ io.DynamicCombo.Option("option4", [
_io.DynamicCombo.Input("subcombo", options=[ io.DynamicCombo.Input("subcombo", options=[
_io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]), io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]),
_io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]), io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]),
]) ])
])] ])]
)], )],
@ -141,14 +203,65 @@ class AutogrowPrefixTestNode(io.ComfyNode):
combined = ",".join([str(x) for x in vals]) combined = ",".join([str(x) for x in vals])
return io.NodeOutput(combined) return io.NodeOutput(combined)
class ComboOutputTestNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ComboOptionTestNode",
display_name="ComboOptionTest",
category="logic",
inputs=[io.Combo.Input("combo", options=["option1", "option2", "option3"]),
io.Combo.Input("combo2", options=["option4", "option5", "option6"])],
outputs=[io.Combo.Output(), io.Combo.Output()],
)
@classmethod
def execute(cls, combo: io.Combo.Type, combo2: io.Combo.Type) -> io.NodeOutput:
return io.NodeOutput(combo, combo2)
class ConvertStringToComboNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ConvertStringToComboNode",
display_name="Convert String to Combo",
category="logic",
inputs=[io.String.Input("string")],
outputs=[io.Combo.Output()],
)
@classmethod
def execute(cls, string: str) -> io.NodeOutput:
return io.NodeOutput(string)
class InvertBooleanNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="InvertBooleanNode",
display_name="Invert Boolean",
category="logic",
inputs=[io.Boolean.Input("boolean")],
outputs=[io.Boolean.Output()],
)
@classmethod
def execute(cls, boolean: bool) -> io.NodeOutput:
return io.NodeOutput(not boolean)
class LogicExtension(ComfyExtension): class LogicExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[io.ComfyNode]]: async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [ return [
# SwitchNode, SwitchNode,
CustomComboNode,
# SoftSwitchNode,
# ConvertStringToComboNode,
# DCTestNode, # DCTestNode,
# AutogrowNamesTestNode, # AutogrowNamesTestNode,
# AutogrowPrefixTestNode, # AutogrowPrefixTestNode,
# ComboOutputTestNode,
# InvertBooleanNode,
] ]
async def comfy_entrypoint() -> LogicExtension: async def comfy_entrypoint() -> LogicExtension:

View File

@ -4,11 +4,15 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from PIL import Image from PIL import Image
import math import math
from enum import Enum
from typing import TypedDict, Literal
import comfy.utils import comfy.utils
import comfy.model_management import comfy.model_management
from comfy_extras.nodes_latent import reshape_latent_to
import node_helpers import node_helpers
from comfy_api.latest import ComfyExtension, io from comfy_api.latest import ComfyExtension, io
from nodes import MAX_RESOLUTION
class Blend(io.ComfyNode): class Blend(io.ComfyNode):
@classmethod @classmethod
@ -241,6 +245,353 @@ class ImageScaleToTotalPixels(io.ComfyNode):
s = s.movedim(1,-1) s = s.movedim(1,-1)
return io.NodeOutput(s) return io.NodeOutput(s)
class ResizeType(str, Enum):
SCALE_BY = "scale by multiplier"
SCALE_DIMENSIONS = "scale dimensions"
SCALE_LONGER_DIMENSION = "scale longer dimension"
SCALE_SHORTER_DIMENSION = "scale shorter dimension"
SCALE_WIDTH = "scale width"
SCALE_HEIGHT = "scale height"
SCALE_TOTAL_PIXELS = "scale total pixels"
MATCH_SIZE = "match size"
def is_image(input: torch.Tensor) -> bool:
# images have 4 dimensions: [batch, height, width, channels]
# masks have 3 dimensions: [batch, height, width]
return len(input.shape) == 4
def init_image_mask_input(input: torch.Tensor, is_type_image: bool) -> torch.Tensor:
if is_type_image:
input = input.movedim(-1, 1)
else:
input = input.unsqueeze(1)
return input
def finalize_image_mask_input(input: torch.Tensor, is_type_image: bool) -> torch.Tensor:
if is_type_image:
input = input.movedim(1, -1)
else:
input = input.squeeze(1)
return input
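A quick shape check of the image/mask helpers above, following the shape convention stated in is_image:

import torch

img = torch.zeros(1, 64, 64, 3)    # [batch, height, width, channels]
mask = torch.zeros(1, 64, 64)      # [batch, height, width]
assert is_image(img) and not is_image(mask)
assert init_image_mask_input(img, True).shape == (1, 3, 64, 64)    # channels moved next to batch
assert init_image_mask_input(mask, False).shape == (1, 1, 64, 64)  # a singleton channel dim is added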
def scale_by(input: torch.Tensor, multiplier: float, scale_method: str) -> torch.Tensor:
is_type_image = is_image(input)
input = init_image_mask_input(input, is_type_image)
width = round(input.shape[-1] * multiplier)
height = round(input.shape[-2] * multiplier)
input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
input = finalize_image_mask_input(input, is_type_image)
return input
def scale_dimensions(input: torch.Tensor, width: int, height: int, scale_method: str, crop: str="disabled") -> torch.Tensor:
if width == 0 and height == 0:
return input
is_type_image = is_image(input)
input = init_image_mask_input(input, is_type_image)
if width == 0:
width = max(1, round(input.shape[-1] * height / input.shape[-2]))
elif height == 0:
height = max(1, round(input.shape[-2] * width / input.shape[-1]))
input = comfy.utils.common_upscale(input, width, height, scale_method, crop)
input = finalize_image_mask_input(input, is_type_image)
return input
def scale_longer_dimension(input: torch.Tensor, longer_size: int, scale_method: str) -> torch.Tensor:
is_type_image = is_image(input)
input = init_image_mask_input(input, is_type_image)
width = input.shape[-1]
height = input.shape[-2]
if height > width:
width = round((width / height) * longer_size)
height = longer_size
elif width > height:
height = round((height / width) * longer_size)
width = longer_size
else:
height = longer_size
width = longer_size
input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
input = finalize_image_mask_input(input, is_type_image)
return input
def scale_shorter_dimension(input: torch.Tensor, shorter_size: int, scale_method: str) -> torch.Tensor:
is_type_image = is_image(input)
input = init_image_mask_input(input, is_type_image)
width = input.shape[-1]
height = input.shape[-2]
if height < width:
width = round((width / height) * shorter_size)
height = shorter_size
elif width < height:
height = round((height / width) * shorter_size)
width = shorter_size
else:
height = shorter_size
width = shorter_size
input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
input = finalize_image_mask_input(input, is_type_image)
return input
def scale_total_pixels(input: torch.Tensor, megapixels: float, scale_method: str) -> torch.Tensor:
is_type_image = is_image(input)
input = init_image_mask_input(input, is_type_image)
total = int(megapixels * 1024 * 1024)
scale_by = math.sqrt(total / (input.shape[-1] * input.shape[-2]))
width = round(input.shape[-1] * scale_by)
height = round(input.shape[-2] * scale_by)
input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
input = finalize_image_mask_input(input, is_type_image)
return input
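A worked example for scale_total_pixels, where 1.0 "megapixels" means 1024*1024 target pixels:

import math

total = int(1.0 * 1024 * 1024)            # 1,048,576 target pixels
scale = math.sqrt(total / (512 * 2048))   # = 1.0 for a 512x2048 input, so its size is unchanged
print(round(2048 * scale), round(512 * scale))  # -> 2048 512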
def scale_match_size(input: torch.Tensor, match: torch.Tensor, scale_method: str, crop: str) -> torch.Tensor:
is_type_image = is_image(input)
input = init_image_mask_input(input, is_type_image)
match = init_image_mask_input(match, is_image(match))
width = match.shape[-1]
height = match.shape[-2]
input = comfy.utils.common_upscale(input, width, height, scale_method, crop)
input = finalize_image_mask_input(input, is_type_image)
return input
class ResizeImageMaskNode(io.ComfyNode):
scale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
crop_methods = ["disabled", "center"]
class ResizeTypedDict(TypedDict):
resize_type: ResizeType
scale_method: Literal["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
crop: Literal["disabled", "center"]
multiplier: float
width: int
height: int
longer_size: int
shorter_size: int
megapixels: float
@classmethod
def define_schema(cls):
template = io.MatchType.Template("input_type", [io.Image, io.Mask])
crop_combo = io.Combo.Input("crop", options=cls.crop_methods, default="center")
return io.Schema(
node_id="ResizeImageMaskNode",
display_name="Resize Image/Mask",
category="transform",
inputs=[
io.MatchType.Input("input", template=template),
io.DynamicCombo.Input("resize_type", options=[
io.DynamicCombo.Option(ResizeType.SCALE_BY, [
io.Float.Input("multiplier", default=1.00, min=0.01, max=8.0, step=0.01),
]),
io.DynamicCombo.Option(ResizeType.SCALE_DIMENSIONS, [
io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1),
io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1),
crop_combo,
]),
io.DynamicCombo.Option(ResizeType.SCALE_LONGER_DIMENSION, [
io.Int.Input("longer_size", default=512, min=0, max=MAX_RESOLUTION, step=1),
]),
io.DynamicCombo.Option(ResizeType.SCALE_SHORTER_DIMENSION, [
io.Int.Input("shorter_size", default=512, min=0, max=MAX_RESOLUTION, step=1),
]),
io.DynamicCombo.Option(ResizeType.SCALE_WIDTH, [
io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1),
]),
io.DynamicCombo.Option(ResizeType.SCALE_HEIGHT, [
io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1),
]),
io.DynamicCombo.Option(ResizeType.SCALE_TOTAL_PIXELS, [
io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
]),
io.DynamicCombo.Option(ResizeType.MATCH_SIZE, [
io.MultiType.Input("match", [io.Image, io.Mask]),
crop_combo,
]),
]),
io.Combo.Input("scale_method", options=cls.scale_methods, default="area"),
],
outputs=[io.MatchType.Output(template=template, display_name="resized")]
)
@classmethod
def execute(cls, input: io.Image.Type | io.Mask.Type, scale_method: io.Combo.Type, resize_type: ResizeTypedDict) -> io.NodeOutput:
selected_type = resize_type["resize_type"]
if selected_type == ResizeType.SCALE_BY:
return io.NodeOutput(scale_by(input, resize_type["multiplier"], scale_method))
elif selected_type == ResizeType.SCALE_DIMENSIONS:
return io.NodeOutput(scale_dimensions(input, resize_type["width"], resize_type["height"], scale_method, resize_type["crop"]))
elif selected_type == ResizeType.SCALE_LONGER_DIMENSION:
return io.NodeOutput(scale_longer_dimension(input, resize_type["longer_size"], scale_method))
elif selected_type == ResizeType.SCALE_SHORTER_DIMENSION:
return io.NodeOutput(scale_shorter_dimension(input, resize_type["shorter_size"], scale_method))
elif selected_type == ResizeType.SCALE_WIDTH:
return io.NodeOutput(scale_dimensions(input, resize_type["width"], 0, scale_method))
elif selected_type == ResizeType.SCALE_HEIGHT:
return io.NodeOutput(scale_dimensions(input, 0, resize_type["height"], scale_method))
elif selected_type == ResizeType.SCALE_TOTAL_PIXELS:
return io.NodeOutput(scale_total_pixels(input, resize_type["megapixels"], scale_method))
elif selected_type == ResizeType.MATCH_SIZE:
return io.NodeOutput(scale_match_size(input, resize_type["match"], scale_method, resize_type["crop"]))
raise ValueError(f"Unsupported resize type: {selected_type}")
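An illustrative resize_type payload as execute() consumes it, assuming the DynamicCombo delivers the selected option name plus that option's input values keyed per the TypedDict above:

resize_type = {
    "resize_type": ResizeType.SCALE_DIMENSIONS,
    "width": 768,
    "height": 0,          # height == 0 -> derived from width to keep the aspect ratio
    "crop": "disabled",
}
# execute() would dispatch this to scale_dimensions(input, 768, 0, scale_method, "disabled")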
def batch_images(images: list[torch.Tensor]) -> torch.Tensor | None:
if len(images) == 0:
return None
# first, get the max channels count
max_channels = max(image.shape[-1] for image in images)
# then, pad all images to have the same channels count
padded_images: list[torch.Tensor] = []
for image in images:
if image.shape[-1] < max_channels:
padded_images.append(torch.nn.functional.pad(image, (0,1), mode='constant', value=1.0))
else:
padded_images.append(image)
# resize all images to be the same size as the first image
resized_images: list[torch.Tensor] = []
first_image_shape = padded_images[0].shape
for image in padded_images:
if image.shape[1:] != first_image_shape[1:]:
resized_images.append(comfy.utils.common_upscale(image.movedim(-1,1), first_image_shape[2], first_image_shape[1], "bilinear", "center").movedim(1,-1))
else:
resized_images.append(image)
# batch the images in the format [b, h, w, c]
return torch.cat(resized_images, dim=0)
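A minimal sketch of batch_images with mismatched channel counts and sizes; the tensors below are placeholders:

import torch

rgb = torch.zeros(1, 64, 64, 3)
rgba = torch.zeros(2, 32, 32, 4)
out = batch_images([rgb, rgba])
print(out.shape)  # -> torch.Size([3, 64, 64, 4]); rgb gains an alpha=1.0 channel, rgba is resized to 64x64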
def batch_masks(masks: list[torch.Tensor]) -> torch.Tensor | None:
if len(masks) == 0:
return None
# resize all masks to be the same size as the first mask
resized_masks: list[torch.Tensor] = []
first_mask_shape = masks[0].shape
for mask in masks:
if mask.shape[1:] != first_mask_shape[1:]:
mask = init_image_mask_input(mask, is_type_image=False)
mask = comfy.utils.common_upscale(mask, first_mask_shape[2], first_mask_shape[1], "bilinear", "center")
resized_masks.append(finalize_image_mask_input(mask, is_type_image=False))
else:
resized_masks.append(mask)
# batch the masks in the format [b, h, w]
return torch.cat(resized_masks, dim=0)
def batch_latents(latents: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor] | None:
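    """Batch LATENT dicts into one: each 'samples' tensor is reshaped to match the first latent,
    the tensors are concatenated on dim 0, and the 'batch_index' lists are merged."""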
if len(latents) == 0:
return None
samples_out = latents[0].copy()
samples_out["batch_index"] = []
first_samples = latents[0]["samples"]
tensors: list[torch.Tensor] = []
for latent in latents:
# first, deal with latent tensors
tensors.append(reshape_latent_to(first_samples.shape, latent["samples"], repeat_batch=False))
# next, deal with batch_index
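        # latents without an explicit batch_index fall back to sequential indices 0..batch_size-1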
samples_out["batch_index"].extend(latent.get("batch_index", [x for x in range(0, latent["samples"].shape[0])]))
samples_out["samples"] = torch.cat(tensors, dim=0)
return samples_out
class BatchImagesNode(io.ComfyNode):
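    """Node that concatenates two or more IMAGE inputs (autogrow) into a single batched IMAGE."""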
@classmethod
def define_schema(cls):
autogrow_template = io.Autogrow.TemplatePrefix(io.Image.Input("image"), prefix="image", min=2, max=50)
return io.Schema(
node_id="BatchImagesNode",
display_name="Batch Images",
category="image",
inputs=[
io.Autogrow.Input("images", template=autogrow_template)
],
outputs=[
io.Image.Output()
]
)
@classmethod
def execute(cls, images: io.Autogrow.Type) -> io.NodeOutput:
return io.NodeOutput(batch_images(list(images.values())))
class BatchMasksNode(io.ComfyNode):
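    """Node that concatenates two or more MASK inputs (autogrow) into a single batched MASK."""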
@classmethod
def define_schema(cls):
autogrow_template = io.Autogrow.TemplatePrefix(io.Mask.Input("mask"), prefix="mask", min=2, max=50)
return io.Schema(
node_id="BatchMasksNode",
display_name="Batch Masks",
category="mask",
inputs=[
io.Autogrow.Input("masks", template=autogrow_template)
],
outputs=[
io.Mask.Output()
]
)
@classmethod
def execute(cls, masks: io.Autogrow.Type) -> io.NodeOutput:
return io.NodeOutput(batch_masks(list(masks.values())))
class BatchLatentsNode(io.ComfyNode):
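    """Node that concatenates two or more LATENT inputs (autogrow) into a single batched LATENT."""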
@classmethod
def define_schema(cls):
autogrow_template = io.Autogrow.TemplatePrefix(io.Latent.Input("latent"), prefix="latent", min=2, max=50)
return io.Schema(
node_id="BatchLatentsNode",
display_name="Batch Latents",
category="latent",
inputs=[
io.Autogrow.Input("latents", template=autogrow_template)
],
outputs=[
io.Latent.Output()
]
)
@classmethod
def execute(cls, latents: io.Autogrow.Type) -> io.NodeOutput:
return io.NodeOutput(batch_latents(list(latents.values())))
class BatchImagesMasksLatentsNode(io.ComfyNode):
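    """Type-matched batching node: all autogrow inputs share one MatchType (Image, Mask, or Latent)
    and are batched with the helper that matches the connected type."""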
@classmethod
def define_schema(cls):
matchtype_template = io.MatchType.Template("input", allowed_types=[io.Image, io.Mask, io.Latent])
autogrow_template = io.Autogrow.TemplatePrefix(
io.MatchType.Input("input", matchtype_template),
prefix="input", min=1, max=50)
return io.Schema(
node_id="BatchImagesMasksLatentsNode",
display_name="Batch Images/Masks/Latents",
category="util",
inputs=[
io.Autogrow.Input("inputs", template=autogrow_template)
],
outputs=[
io.MatchType.Output(id=None, template=matchtype_template)
]
)
@classmethod
def execute(cls, inputs: io.Autogrow.Type) -> io.NodeOutput:
batched = None
values = list(inputs.values())
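        # all inputs share one MatchType, so the first value tells us which batching helper to use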
# latents
if isinstance(values[0], dict):
batched = batch_latents(values)
# images
elif is_image(values[0]):
batched = batch_images(values)
# masks
else:
batched = batch_masks(values)
return io.NodeOutput(batched)
class PostProcessingExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
@@ -250,6 +601,11 @@ class PostProcessingExtension(ComfyExtension):
            Quantize,
            Sharpen,
            ImageScaleToTotalPixels,
+           ResizeImageMaskNode,
+           BatchImagesNode,
+           BatchMasksNode,
+           BatchLatentsNode,
+           # BatchImagesMasksLatentsNode,
        ]

async def comfy_entrypoint() -> PostProcessingExtension:

View File

@@ -66,7 +66,7 @@ class Float(io.ComfyNode):
            display_name="Float",
            category="utils/primitive",
            inputs=[
-               io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize),
+               io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize, step=0.1),
            ],
            outputs=[io.Float.Output()],
        )

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
-__version__ = "0.6.0"
+__version__ = "0.7.0"

View File

@@ -79,7 +79,7 @@ class IsChangedCache:
            # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
            input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None)
            try:
-               is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name)
+               is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name, v3_data=v3_data)
                is_changed = await resolve_map_node_over_list_results(is_changed)
                node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
            except Exception as e:
@@ -148,13 +148,12 @@ SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}):
    is_v3 = issubclass(class_def, _ComfyNodeInternal)
    v3_data: io.V3Data = {}
+   hidden_inputs_v3 = {}
+   valid_inputs = class_def.INPUT_TYPES()
    if is_v3:
-       valid_inputs, schema, v3_data = class_def.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs)
-   else:
-       valid_inputs = class_def.INPUT_TYPES()
+       valid_inputs, hidden, v3_data = _io.get_finalized_class_inputs(valid_inputs, inputs)
    input_data_all = {}
    missing_keys = {}
-   hidden_inputs_v3 = {}
    for x in inputs:
        input_data = inputs[x]
        _, input_category, input_info = get_input_info(class_def, x, valid_inputs)
@@ -180,18 +179,18 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
            input_data_all[x] = [input_data]

    if is_v3:
-       if schema.hidden:
-           if io.Hidden.prompt in schema.hidden:
+       if hidden is not None:
+           if io.Hidden.prompt.name in hidden:
                hidden_inputs_v3[io.Hidden.prompt] = dynprompt.get_original_prompt() if dynprompt is not None else {}
-           if io.Hidden.dynprompt in schema.hidden:
+           if io.Hidden.dynprompt.name in hidden:
                hidden_inputs_v3[io.Hidden.dynprompt] = dynprompt
-           if io.Hidden.extra_pnginfo in schema.hidden:
+           if io.Hidden.extra_pnginfo.name in hidden:
                hidden_inputs_v3[io.Hidden.extra_pnginfo] = extra_data.get('extra_pnginfo', None)
-           if io.Hidden.unique_id in schema.hidden:
+           if io.Hidden.unique_id.name in hidden:
                hidden_inputs_v3[io.Hidden.unique_id] = unique_id
-           if io.Hidden.auth_token_comfy_org in schema.hidden:
+           if io.Hidden.auth_token_comfy_org.name in hidden:
                hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None)
-           if io.Hidden.api_key_comfy_org in schema.hidden:
+           if io.Hidden.api_key_comfy_org.name in hidden:
                hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None)
    else:
        if "hidden" in valid_inputs:
@@ -258,7 +257,7 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
        pre_execute_cb(index)
    # V3
    if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)):
-       # if is just a class, then assign no resources or state, just create clone
+       # if is just a class, then assign no state, just create clone
        if is_class(obj):
            type_obj = obj
            obj.VALIDATE_CLASS()
@@ -481,7 +480,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
            else:
                lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
            if lazy_status_present:
-               required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data)
+               # for check_lazy_status, the returned data should include the original key of the input
+               v3_data_lazy = v3_data.copy()
+               v3_data_lazy["create_dynamic_tuple"] = True
+               required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data_lazy)
                required_inputs = await resolve_map_node_over_list_results(required_inputs)
                required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
                required_inputs = [x for x in required_inputs if isinstance(x,str) and (
@@ -756,10 +758,13 @@ async def validate_inputs(prompt_id, prompt, item, validated):
    errors = []
    valid = True

+   v3_data = None
    validate_function_inputs = []
    validate_has_kwargs = False
    if issubclass(obj_class, _ComfyNodeInternal):
-       class_inputs, _, _ = obj_class.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs)
+       obj_class: _io._ComfyNodeBaseInternal
+       class_inputs = obj_class.INPUT_TYPES()
+       class_inputs, _, v3_data = _io.get_finalized_class_inputs(class_inputs, inputs)
        validate_function_name = "validate_inputs"
        validate_function = first_real_override(obj_class, validate_function_name)
    else:
@@ -779,10 +784,11 @@ async def validate_inputs(prompt_id, prompt, item, validated):
        assert extra_info is not None
        if x not in inputs:
            if input_category == "required":
+               details = f"{x}" if not v3_data else x.split(".")[-1]
                error = {
                    "type": "required_input_missing",
                    "message": "Required input is missing",
-                   "details": f"{x}",
+                   "details": details,
                    "extra_info": {
                        "input_name": x
                    }
@@ -916,8 +922,11 @@ async def validate_inputs(prompt_id, prompt, item, validated):
                errors.append(error)
                continue

-           if isinstance(input_type, list):
-               combo_options = input_type
+           if isinstance(input_type, list) or input_type == io.Combo.io_type:
+               if input_type == io.Combo.io_type:
+                   combo_options = extra_info.get("options", [])
+               else:
+                   combo_options = input_type
                if val not in combo_options:
                    input_config = info
                    list_info = ""

View File

@@ -1 +1 @@
-comfyui_manager==4.0.3b7
+comfyui_manager==4.0.4

View File

@@ -1863,6 +1863,7 @@ class ImageBatch:
    FUNCTION = "batch"
    CATEGORY = "image"
+   DEPRECATED = True

    def batch(self, image1, image2):
        if image1.shape[-1] != image2.shape[-1]:

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
-version = "0.6.0"
+version = "0.7.0"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

View File

@@ -1,5 +1,5 @@
comfyui-frontend-package==1.35.9
-comfyui-workflow-templates==0.7.64
+comfyui-workflow-templates==0.7.65
comfyui-embedded-docs==0.3.1
torch
torchsde

View File

@@ -25,7 +25,7 @@ class TestImageStitch:
        result = node.stitch(image1, "right", True, 0, "white", image2=None)

-       assert len(result) == 1
+       assert len(result.result) == 1
        assert torch.equal(result[0], image1)

    def test_basic_horizontal_stitch_right(self):