diff --git a/README.md b/README.md index b0f62695b..6d09758c0 100644 --- a/README.md +++ b/README.md @@ -212,6 +212,8 @@ Python 3.14 works but you may encounter issues with the torch compile node. The Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12 +torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch unless it is less than 2 weeks old. + ### Instructions: Git clone this repo. diff --git a/app/model_manager.py b/app/model_manager.py index ab36bca74..f124d1117 100644 --- a/app/model_manager.py +++ b/app/model_manager.py @@ -44,7 +44,7 @@ class ModelFileManager: @routes.get("/experiment/models/{folder}") async def get_all_models(request): folder = request.match_info.get("folder", None) - if not folder in folder_paths.folder_names_and_paths: + if folder not in folder_paths.folder_names_and_paths: return web.Response(status=404) files = self.get_model_file_list(folder) return web.json_response(files) @@ -55,7 +55,7 @@ class ModelFileManager: path_index = int(request.match_info.get("path_index", None)) filename = request.match_info.get("filename", None) - if not folder_name in folder_paths.folder_names_and_paths: + if folder_name not in folder_paths.folder_names_and_paths: return web.Response(status=404) folders = folder_paths.folder_names_and_paths[folder_name] diff --git a/comfy/clip_model.py b/comfy/clip_model.py index 7c0cadab5..e88872728 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -2,6 +2,25 @@ import torch from comfy.ldm.modules.attention import optimized_attention_for_device import comfy.ops +def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True): + image = image[:, :, :, :3] if image.shape[3] > 3 else image + mean = torch.tensor(mean, device=image.device, dtype=image.dtype) + std = torch.tensor(std, device=image.device, dtype=image.dtype) + image = image.movedim(-1, 1) + if not (image.shape[2] == size and image.shape[3] == size): + if crop: + scale = (size / min(image.shape[2], image.shape[3])) + scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3])) + else: + scale_size = (size, size) + + image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True) + h = (image.shape[2] - size)//2 + w = (image.shape[3] - size)//2 + image = image[:,:,h:h+size,w:w+size] + image = torch.clip((255. * image), 0, 255).round() / 255.0 + return (image - mean.view([3,1,1])) / std.view([3,1,1]) + class CLIPAttention(torch.nn.Module): def __init__(self, embed_dim, heads, dtype, device, operations): super().__init__() diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 447b1ce4a..d5fc53497 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -1,6 +1,5 @@ from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace import os -import torch import json import logging @@ -17,24 +16,7 @@ class Output: def __setitem__(self, key, item): setattr(self, key, item) -def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True): - image = image[:, :, :, :3] if image.shape[3] > 3 else image - mean = torch.tensor(mean, device=image.device, dtype=image.dtype) - std = torch.tensor(std, device=image.device, dtype=image.dtype) - image = image.movedim(-1, 1) - if not (image.shape[2] == size and image.shape[3] == size): - if crop: - scale = (size / min(image.shape[2], image.shape[3])) - scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3])) - else: - scale_size = (size, size) - - image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True) - h = (image.shape[2] - size)//2 - w = (image.shape[3] - size)//2 - image = image[:,:,h:h+size,w:w+size] - image = torch.clip((255. * image), 0, 255).round() / 255.0 - return (image - mean.view([3,1,1])) / std.view([3,1,1]) +clip_preprocess = comfy.clip_model.clip_preprocess # Prevent some stuff from breaking, TODO: remove eventually IMAGE_ENCODERS = { "clip_vision_model": comfy.clip_model.CLIPVisionModelProjection, @@ -73,7 +55,7 @@ class ClipVisionModel(): def encode_image(self, image, crop=True): comfy.model_management.load_model_gpu(self.patcher) - pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float() + pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float() out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2) outputs = Output() diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 1e0f86026..2f82d51da 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -188,6 +188,12 @@ class IndexListContextHandler(ContextHandlerABC): audio_cond = cond_value.cond if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim): new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1)) + # Handle vace_context (temporal dim is 3) + elif cond_key == "vace_context" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): + vace_cond = cond_value.cond + if vace_cond.ndim >= 4 and vace_cond.size(3) == x_in.size(self.dim): + sliced_vace = window.get_tensor(vace_cond, device, dim=3, retain_index_list=self.cond_retain_index_list) + new_cond_item[cond_key] = cond_value._copy_with(sliced_vace) # if has cond that is a Tensor, check if needs to be subset elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \ diff --git a/comfy/hooks.py b/comfy/hooks.py index 9d0731072..1a76c7ba4 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -527,7 +527,8 @@ class HookKeyframeGroup: if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0: break # if eval_c is outside the percent range, stop looking further - else: break + else: + break # update steps current context is used self._current_used_steps += 1 # update current timestep this was performed on diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 1ba9edad7..0949dee44 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -74,6 +74,9 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.): def default_noise_sampler(x, seed=None): if seed is not None: + if x.device == torch.device("cpu"): + seed += 1 + generator = torch.Generator(device=x.device) generator.manual_seed(seed) else: diff --git a/comfy/ldm/chroma_radiance/model.py b/comfy/ldm/chroma_radiance/model.py index 70d173889..4fb56165e 100644 --- a/comfy/ldm/chroma_radiance/model.py +++ b/comfy/ldm/chroma_radiance/model.py @@ -270,7 +270,7 @@ class ChromaRadiance(Chroma): bad_keys = tuple( k for k, v in overrides.items() - if type(v) != type(getattr(params, k)) and (v is not None or k not in nullable_keys) + if not isinstance(v, type(getattr(params, k))) and (v is not None or k not in nullable_keys) ) if bad_keys: e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}" diff --git a/comfy/ldm/hunyuan_video/upsampler.py b/comfy/ldm/hunyuan_video/upsampler.py index 85f515f67..d9e76922f 100644 --- a/comfy/ldm/hunyuan_video/upsampler.py +++ b/comfy/ldm/hunyuan_video/upsampler.py @@ -3,7 +3,8 @@ import torch.nn as nn import torch.nn.functional as F from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm -import model_management, model_patcher +import model_management +import model_patcher class SRResidualCausalBlock3D(nn.Module): def __init__(self, channels: int): diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index e80b1c138..afbab2ac7 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -491,7 +491,8 @@ class NextDiT(nn.Module): for layer_id in range(n_layers) ] ) - self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + # This norm final is in the lumina 2.0 code but isn't actually used for anything. + # self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings) if self.pad_tokens_multiple is not None: diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index a8800ded0..ccf690945 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -30,6 +30,13 @@ except ImportError as e: raise e exit(-1) +SAGE_ATTENTION3_IS_AVAILABLE = False +try: + from sageattn3 import sageattn3_blackwell + SAGE_ATTENTION3_IS_AVAILABLE = True +except ImportError: + pass + FLASH_ATTENTION_IS_AVAILABLE = False try: from flash_attn import flash_attn_func @@ -563,6 +570,93 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape= out = out.reshape(b, -1, heads * dim_head) return out +@wrap_attn +def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): + exception_fallback = False + if (q.device.type != "cuda" or + q.dtype not in (torch.float16, torch.bfloat16) or + mask is not None): + return attention_pytorch( + q, k, v, heads, + mask=mask, + attn_precision=attn_precision, + skip_reshape=skip_reshape, + skip_output_reshape=skip_output_reshape, + **kwargs + ) + + if skip_reshape: + B, H, L, D = q.shape + if H != heads: + return attention_pytorch( + q, k, v, heads, + mask=mask, + attn_precision=attn_precision, + skip_reshape=True, + skip_output_reshape=skip_output_reshape, + **kwargs + ) + q_s, k_s, v_s = q, k, v + N = q.shape[2] + dim_head = D + else: + B, N, inner_dim = q.shape + if inner_dim % heads != 0: + return attention_pytorch( + q, k, v, heads, + mask=mask, + attn_precision=attn_precision, + skip_reshape=False, + skip_output_reshape=skip_output_reshape, + **kwargs + ) + dim_head = inner_dim // heads + + if dim_head >= 256 or N <= 1024: + return attention_pytorch( + q, k, v, heads, + mask=mask, + attn_precision=attn_precision, + skip_reshape=skip_reshape, + skip_output_reshape=skip_output_reshape, + **kwargs + ) + + if not skip_reshape: + q_s, k_s, v_s = map( + lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(), + (q, k, v), + ) + B, H, L, D = q_s.shape + + try: + out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False) + except Exception as e: + exception_fallback = True + logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e) + + if exception_fallback: + if not skip_reshape: + del q_s, k_s, v_s + return attention_pytorch( + q, k, v, heads, + mask=mask, + attn_precision=attn_precision, + skip_reshape=False, + skip_output_reshape=skip_output_reshape, + **kwargs + ) + + if skip_reshape: + if not skip_output_reshape: + out = out.permute(0, 2, 1, 3).reshape(B, L, H * D) + else: + if skip_output_reshape: + pass + else: + out = out.permute(0, 2, 1, 3).reshape(B, L, H * D) + + return out try: @torch.library.custom_op("flash_attention::flash_attn", mutates_args=()) @@ -650,6 +744,8 @@ optimized_attention_masked = optimized_attention # register core-supported attention functions if SAGE_ATTENTION_IS_AVAILABLE: register_attention_function("sage", attention_sage) +if SAGE_ATTENTION3_IS_AVAILABLE: + register_attention_function("sage3", attention3_sage) if FLASH_ATTENTION_IS_AVAILABLE: register_attention_function("flash", attention_flash) if model_management.xformers_enabled(): diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 681a55db5..1ae3ef034 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -394,7 +394,8 @@ class Model(nn.Module): attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): super().__init__() - if use_linear_attn: attn_type = "linear" + if use_linear_attn: + attn_type = "linear" self.ch = ch self.temb_ch = self.ch*4 self.num_resolutions = len(ch_mult) @@ -548,7 +549,8 @@ class Encoder(nn.Module): conv3d=False, time_compress=None, **ignore_kwargs): super().__init__() - if use_linear_attn: attn_type = "linear" + if use_linear_attn: + attn_type = "linear" self.ch = ch self.temb_ch = 0 self.num_resolutions = len(ch_mult) diff --git a/comfy/ldm/modules/ema.py b/comfy/ldm/modules/ema.py index bded25019..96ee6e895 100644 --- a/comfy/ldm/modules/ema.py +++ b/comfy/ldm/modules/ema.py @@ -45,7 +45,7 @@ class LitEma(nn.Module): shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) else: - assert not key in self.m_name2s_name + assert key not in self.m_name2s_name def copy_to(self, model): m_param = dict(model.named_parameters()) @@ -54,7 +54,7 @@ class LitEma(nn.Module): if m_param[key].requires_grad: m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) else: - assert not key in self.m_name2s_name + assert key not in self.m_name2s_name def store(self, parameters): """ diff --git a/comfy/ldm/util.py b/comfy/ldm/util.py index 30b4b4721..304936ff4 100644 --- a/comfy/ldm/util.py +++ b/comfy/ldm/util.py @@ -71,7 +71,7 @@ def count_params(model, verbose=False): def instantiate_from_config(config): - if not "target" in config: + if "target" not in config: if config == '__is_first_stage__': return None elif config == "__is_unconditional__": diff --git a/comfy/model_management.py b/comfy/model_management.py index 1889ab0ac..2501cecb7 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1019,8 +1019,8 @@ NUM_STREAMS = 0 if args.async_offload is not None: NUM_STREAMS = args.async_offload else: - # Enable by default on Nvidia - if is_nvidia(): + # Enable by default on Nvidia and AMD + if is_nvidia() or is_amd(): NUM_STREAMS = 2 if args.disable_async_offload: @@ -1126,6 +1126,16 @@ if not args.disable_pinned_memory: PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"]) +def discard_cuda_async_error(): + try: + a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device()) + b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device()) + _ = a + b + torch.cuda.synchronize() + except torch.AcceleratorError: + #Dump it! We already know about it from the synchronous return + pass + def pin_memory(tensor): global TOTAL_PINNED_MEMORY if MAX_PINNED_MEMORY <= 0: @@ -1158,6 +1168,9 @@ def pin_memory(tensor): PINNED_MEMORY[ptr] = size TOTAL_PINNED_MEMORY += size return True + else: + logging.warning("Pin error.") + discard_cuda_async_error() return False @@ -1186,6 +1199,9 @@ def unpin_memory(tensor): if len(PINNED_MEMORY) == 0: TOTAL_PINNED_MEMORY = 0 return True + else: + logging.warning("Unpin error.") + discard_cuda_async_error() return False @@ -1526,6 +1542,10 @@ def soft_empty_cache(force=False): def unload_all_models(): free_memory(1e30, get_torch_device()) +def debug_memory_summary(): + if is_amd() or is_nvidia(): + return torch.cuda.memory.memory_summary() + return "" #TODO: might be cleaner to put this somewhere else import threading diff --git a/comfy/taesd/taehv.py b/comfy/taesd/taehv.py index 3dfe1e4d4..0e5f9a378 100644 --- a/comfy/taesd/taehv.py +++ b/comfy/taesd/taehv.py @@ -154,7 +154,8 @@ class TAEHV(nn.Module): self._show_progress_bar = value def encode(self, x, **kwargs): - if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size) + if self.patch_size > 1: + x = F.pixel_unshuffle(x, self.patch_size) x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] if x.shape[1] % 4 != 0: # pad at end to multiple of 4 @@ -167,5 +168,6 @@ class TAEHV(nn.Module): def decode(self, x, **kwargs): x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar) - if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size) + if self.patch_size > 1: + x = F.pixel_shuffle(x, self.patch_size) return x[:, self.frames_to_trim:].movedim(2, 1) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index ed29e014d..faa4e1de8 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -8,7 +8,6 @@ from comfy.ldm.modules.attention import optimized_attention_for_device import comfy.model_management import comfy.ldm.common_dit -import comfy.model_management from . import qwen_vl @dataclass diff --git a/comfy/utils.py b/comfy/utils.py index 8d4e2b445..e4162d7ac 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1230,6 +1230,8 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}): out_sd = {} layers = {} for k in list(state_dict.keys()): + if k == scaled_fp8_key: + continue if not k.startswith(model_prefix): out_sd[k] = state_dict[k] continue diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index fab63c7df..b0fa14ff6 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -10,7 +10,6 @@ from ._input_impl import VideoFromFile, VideoFromComponents from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL from . import _io_public as io from . import _ui_public as ui -# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401 from comfy_execution.utils import get_executing_context from comfy_execution.progress import get_progress_state, PreviewImageTuple from PIL import Image diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index a0f506279..b04b43ff6 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -26,11 +26,9 @@ if TYPE_CHECKING: from comfy_api.input import VideoInput from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class, prune_dict, shallow_clone_class) -from ._resources import Resources, ResourcesLocal from comfy_execution.graph_utils import ExecutionBlocker -from ._util import MESH, VOXEL +from ._util import MESH, VOXEL, SVG as _SVG -# from comfy_extras.nodes_images import SVG as SVG_ # NOTE: needs to be moved before can be imported due to circular reference class FolderType(str, Enum): input = "input" @@ -77,16 +75,6 @@ class NumberDisplay(str, Enum): slider = "slider" -class _StringIOType(str): - def __ne__(self, value: object) -> bool: - if self == "*" or value == "*": - return False - if not isinstance(value, str): - return True - a = frozenset(self.split(",")) - b = frozenset(value.split(",")) - return not (b.issubset(a) or a.issubset(b)) - class _ComfyType(ABC): Type = Any io_type: str = None @@ -126,8 +114,7 @@ def comfytype(io_type: str, **kwargs): new_cls.__module__ = cls.__module__ new_cls.__doc__ = cls.__doc__ # assign ComfyType attributes, if needed - # NOTE: use __ne__ trick for io_type (see node_typing.IO.__ne__ for details) - new_cls.io_type = _StringIOType(io_type) + new_cls.io_type = io_type if hasattr(new_cls, "Input") and new_cls.Input is not None: new_cls.Input.Parent = new_cls if hasattr(new_cls, "Output") and new_cls.Output is not None: @@ -166,7 +153,7 @@ class Input(_IO_V3): ''' Base class for a V3 Input. ''' - def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None): super().__init__() self.id = id self.display_name = display_name @@ -174,6 +161,7 @@ class Input(_IO_V3): self.tooltip = tooltip self.lazy = lazy self.extra_dict = extra_dict if extra_dict is not None else {} + self.rawLink = raw_link def as_dict(self): return prune_dict({ @@ -181,10 +169,11 @@ class Input(_IO_V3): "optional": self.optional, "tooltip": self.tooltip, "lazy": self.lazy, + "rawLink": self.rawLink, }) | prune_dict(self.extra_dict) def get_io_type(self): - return _StringIOType(self.io_type) + return self.io_type def get_all(self) -> list[Input]: return [self] @@ -195,8 +184,8 @@ class WidgetInput(Input): ''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: Any=None, - socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None): - super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link) self.default = default self.socketless = socketless self.widget_type = widget_type @@ -218,13 +207,14 @@ class Output(_IO_V3): def __init__(self, id: str=None, display_name: str=None, tooltip: str=None, is_output_list=False): self.id = id - self.display_name = display_name + self.display_name = display_name if display_name else id self.tooltip = tooltip self.is_output_list = is_output_list def as_dict(self): + display_name = self.display_name if self.display_name else self.id return prune_dict({ - "display_name": self.display_name, + "display_name": display_name, "tooltip": self.tooltip, "is_output_list": self.is_output_list, }) @@ -252,8 +242,8 @@ class Boolean(ComfyTypeIO): '''Boolean input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: bool=None, label_on: str=None, label_off: str=None, - socketless: bool=None, force_input: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) + socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link) self.label_on = label_on self.label_off = label_off self.default: bool @@ -272,8 +262,8 @@ class Int(ComfyTypeIO): '''Integer input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None, - display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) + display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link) self.min = min self.max = max self.step = step @@ -298,8 +288,8 @@ class Float(ComfyTypeIO): '''Float input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: float=None, min: float=None, max: float=None, step: float=None, round: float=None, - display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) + display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link) self.min = min self.max = max self.step = step @@ -324,8 +314,8 @@ class String(ComfyTypeIO): '''String input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None, - socketless: bool=None, force_input: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) + socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link) self.multiline = multiline self.placeholder = placeholder self.dynamic_prompts = dynamic_prompts @@ -358,12 +348,14 @@ class Combo(ComfyTypeIO): image_folder: FolderType=None, remote: RemoteOptions=None, socketless: bool=None, + extra_dict=None, + raw_link: bool=None, ): if isinstance(options, type) and issubclass(options, Enum): options = [v.value for v in options] if isinstance(default, Enum): default = default.value - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless) + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link) self.multiselect = False self.options = options self.control_after_generate = control_after_generate @@ -387,10 +379,6 @@ class Combo(ComfyTypeIO): super().__init__(id, display_name, tooltip, is_output_list) self.options = options if options is not None else [] - @property - def io_type(self): - return self.options - @comfytype(io_type="COMBO") class MultiCombo(ComfyTypeI): '''Multiselect Combo input (dropdown for selecting potentially more than one value).''' @@ -399,8 +387,8 @@ class MultiCombo(ComfyTypeI): class Input(Combo.Input): def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None, - socketless: bool=None): - super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless) + socketless: bool=None, extra_dict=None, raw_link: bool=None): + super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link) self.multiselect = True self.placeholder = placeholder self.chip = chip @@ -433,9 +421,9 @@ class Webcam(ComfyTypeIO): Type = str def __init__( self, id: str, display_name: str=None, optional=False, - tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None + tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None, extra_dict=None, raw_link: bool=None ): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless) + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link) @comfytype(io_type="MASK") @@ -656,7 +644,7 @@ class Video(ComfyTypeIO): @comfytype(io_type="SVG") class SVG(ComfyTypeIO): - Type = Any # TODO: SVG class is defined in comfy_extras/nodes_images.py, causing circular reference; should be moved to somewhere else before referenced directly in v3 + Type = _SVG @comfytype(io_type="LORA_MODEL") class LoraModel(ComfyTypeIO): @@ -788,7 +776,7 @@ class MultiType: ''' Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values. ''' - def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): + def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None): # if id is an Input, then use that Input with overridden values self.input_override = None if isinstance(id, Input): @@ -801,7 +789,7 @@ class MultiType: # if is a widget input, make sure widget_type is set appropriately if isinstance(self.input_override, WidgetInput): self.input_override.widget_type = self.input_override.get_io_type() - super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link) self._io_types = types @property @@ -855,8 +843,8 @@ class MatchType(ComfyTypeIO): class Input(Input): def __init__(self, id: str, template: MatchType.Template, - display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): - super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link) self.template = template def as_dict(self): @@ -867,6 +855,8 @@ class MatchType(ComfyTypeIO): class Output(Output): def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None, is_output_list=False): + if not id and not display_name: + display_name = "MATCHTYPE" super().__init__(id, display_name, tooltip, is_output_list) self.template = template @@ -879,24 +869,30 @@ class DynamicInput(Input, ABC): ''' Abstract class for dynamic input registration. ''' - def get_dynamic(self) -> list[Input]: - return [] - - def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): - pass + pass class DynamicOutput(Output, ABC): ''' Abstract class for dynamic output registration. ''' - def __init__(self, id: str=None, display_name: str=None, tooltip: str=None, - is_output_list=False): - super().__init__(id, display_name, tooltip, is_output_list) + pass - def get_dynamic(self) -> list[Output]: - return [] +def handle_prefix(prefix_list: list[str] | None, id: str | None = None) -> list[str]: + if prefix_list is None: + prefix_list = [] + if id is not None: + prefix_list = prefix_list + [id] + return prefix_list + +def finalize_prefix(prefix_list: list[str] | None, id: str | None = None) -> str: + assert not (prefix_list is None and id is None) + if prefix_list is None: + return id + elif id is not None: + prefix_list = prefix_list + [id] + return ".".join(prefix_list) @comfytype(io_type="COMFY_AUTOGROW_V3") class Autogrow(ComfyTypeI): @@ -933,14 +929,6 @@ class Autogrow(ComfyTypeI): def validate(self): self.input.validate() - def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): - real_inputs = [] - for name, input in self.cached_inputs.items(): - if name in live_inputs: - real_inputs.append(input) - add_to_input_dict_v1(d, real_inputs, live_inputs, curr_prefix) - add_dynamic_id_mapping(d, real_inputs, curr_prefix) - class TemplatePrefix(_AutogrowTemplate): def __init__(self, input: Input, prefix: str, min: int=1, max: int=10): super().__init__(input) @@ -985,22 +973,45 @@ class Autogrow(ComfyTypeI): "template": self.template.as_dict(), }) - def get_dynamic(self) -> list[Input]: - return self.template.get_all() - def get_all(self) -> list[Input]: return [self] + self.template.get_all() def validate(self): self.template.validate() - def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): - curr_prefix = f"{curr_prefix}{self.id}." - # need to remove self from expected inputs dictionary; replaced by template inputs in frontend - for inner_dict in d.values(): - if self.id in inner_dict: - del inner_dict[self.id] - self.template.expand_schema_for_dynamic(d, live_inputs, curr_prefix) + @staticmethod + def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None): + # NOTE: purposely do not include self in out_dict; instead use only the template inputs + # need to figure out names based on template type + is_names = ("names" in value[1]["template"]) + is_prefix = ("prefix" in value[1]["template"]) + input = value[1]["template"]["input"] + if is_names: + min = value[1]["template"]["min"] + names = value[1]["template"]["names"] + max = len(names) + elif is_prefix: + prefix = value[1]["template"]["prefix"] + min = value[1]["template"]["min"] + max = value[1]["template"]["max"] + names = [f"{prefix}{i}" for i in range(max)] + # need to create a new input based on the contents of input + template_input = None + for _, dict_input in input.items(): + # for now, get just the first value from dict_input + template_input = list(dict_input.values())[0] + new_dict = {} + for i, name in enumerate(names): + expected_id = finalize_prefix(curr_prefix, name) + if expected_id in live_inputs: + # required + if i < min: + type_dict = new_dict.setdefault("required", {}) + # optional + else: + type_dict = new_dict.setdefault("optional", {}) + type_dict[name] = template_input + parse_class_inputs(out_dict, live_inputs, new_dict, curr_prefix) @comfytype(io_type="COMFY_DYNAMICCOMBO_V3") class DynamicCombo(ComfyTypeI): @@ -1023,23 +1034,6 @@ class DynamicCombo(ComfyTypeI): super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) self.options = options - def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): - # check if dynamic input's id is in live_inputs - if self.id in live_inputs: - curr_prefix = f"{curr_prefix}{self.id}." - key = live_inputs[self.id] - selected_option = None - for option in self.options: - if option.key == key: - selected_option = option - break - if selected_option is not None: - add_to_input_dict_v1(d, selected_option.inputs, live_inputs, curr_prefix) - add_dynamic_id_mapping(d, selected_option.inputs, curr_prefix, self) - - def get_dynamic(self) -> list[Input]: - return [input for option in self.options for input in option.inputs] - def get_all(self) -> list[Input]: return [self] + [input for option in self.options for input in option.inputs] @@ -1054,6 +1048,24 @@ class DynamicCombo(ComfyTypeI): for input in option.inputs: input.validate() + @staticmethod + def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None): + finalized_id = finalize_prefix(curr_prefix) + if finalized_id in live_inputs: + key = live_inputs[finalized_id] + selected_option = None + # get options from dict + options: list[dict[str, str | dict[str, Any]]] = value[1]["options"] + for option in options: + if option["key"] == key: + selected_option = option + break + if selected_option is not None: + parse_class_inputs(out_dict, live_inputs, selected_option["inputs"], curr_prefix) + # add self to inputs + out_dict[input_type][finalized_id] = value + out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1]) + @comfytype(io_type="COMFY_DYNAMICSLOT_V3") class DynamicSlot(ComfyTypeI): Type = dict[str, Any] @@ -1076,17 +1088,8 @@ class DynamicSlot(ComfyTypeI): self.force_input = True self.slot.force_input = True - def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): - if self.id in live_inputs: - curr_prefix = f"{curr_prefix}{self.id}." - add_to_input_dict_v1(d, self.inputs, live_inputs, curr_prefix) - add_dynamic_id_mapping(d, [self.slot] + self.inputs, curr_prefix) - - def get_dynamic(self) -> list[Input]: - return [self.slot] + self.inputs - def get_all(self) -> list[Input]: - return [self] + [self.slot] + self.inputs + return [self.slot] + self.inputs def as_dict(self): return super().as_dict() | prune_dict({ @@ -1100,17 +1103,41 @@ class DynamicSlot(ComfyTypeI): for input in self.inputs: input.validate() -def add_dynamic_id_mapping(d: dict[str, Any], inputs: list[Input], curr_prefix: str, self: DynamicInput=None): - dynamic = d.setdefault("dynamic_paths", {}) - if self is not None: - dynamic[self.id] = f"{curr_prefix}{self.id}" - for i in inputs: - if not isinstance(i, DynamicInput): - dynamic[f"{i.id}"] = f"{curr_prefix}{i.id}" + @staticmethod + def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None): + finalized_id = finalize_prefix(curr_prefix) + if finalized_id in live_inputs: + inputs = value[1]["inputs"] + parse_class_inputs(out_dict, live_inputs, inputs, curr_prefix) + # add self to inputs + out_dict[input_type][finalized_id] = value + out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1]) + +DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {} +def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]): + DYNAMIC_INPUT_LOOKUP[io_type] = func + +def get_dynamic_input_func(io_type: str) -> Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]: + return DYNAMIC_INPUT_LOOKUP[io_type] + +def setup_dynamic_input_funcs(): + # Autogrow.Input + register_dynamic_input_func(Autogrow.io_type, Autogrow._expand_schema_for_dynamic) + # DynamicCombo.Input + register_dynamic_input_func(DynamicCombo.io_type, DynamicCombo._expand_schema_for_dynamic) + # DynamicSlot.Input + register_dynamic_input_func(DynamicSlot.io_type, DynamicSlot._expand_schema_for_dynamic) + +if len(DYNAMIC_INPUT_LOOKUP) == 0: + setup_dynamic_input_funcs() class V3Data(TypedDict): hidden_inputs: dict[str, Any] + 'Dictionary where the keys are the hidden input ids and the values are the values of the hidden inputs.' dynamic_paths: dict[str, Any] + 'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.' + create_dynamic_tuple: bool + 'When True, the value of the dynamic input will be in the format (value, path_key).' class HiddenHolder: def __init__(self, unique_id: str, prompt: Any, @@ -1146,6 +1173,10 @@ class HiddenHolder: api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None), ) + @classmethod + def from_v3_data(cls, v3_data: V3Data | None) -> HiddenHolder: + return cls.from_dict(v3_data["hidden_inputs"] if v3_data else None) + class Hidden(str, Enum): ''' Enumerator for requesting hidden variables in nodes. @@ -1251,61 +1282,56 @@ class Schema: - verify ids on inputs and outputs are unique - both internally and in relation to each other ''' nested_inputs: list[Input] = [] - if self.inputs is not None: - for input in self.inputs: + for input in self.inputs: + if not isinstance(input, DynamicInput): nested_inputs.extend(input.get_all()) - input_ids = [i.id for i in nested_inputs] if nested_inputs is not None else [] - output_ids = [o.id for o in self.outputs] if self.outputs is not None else [] + input_ids = [i.id for i in nested_inputs] + output_ids = [o.id for o in self.outputs] input_set = set(input_ids) output_set = set(output_ids) - issues = [] + issues: list[str] = [] # verify ids are unique per list if len(input_set) != len(input_ids): issues.append(f"Input ids must be unique, but {[item for item, count in Counter(input_ids).items() if count > 1]} are not.") if len(output_set) != len(output_ids): issues.append(f"Output ids must be unique, but {[item for item, count in Counter(output_ids).items() if count > 1]} are not.") - # verify ids are unique between lists - intersection = input_set & output_set - if len(intersection) > 0: - issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.") if len(issues) > 0: raise ValueError("\n".join(issues)) # validate inputs and outputs - if self.inputs is not None: - for input in self.inputs: - input.validate() - if self.outputs is not None: - for output in self.outputs: - output.validate() + for input in self.inputs: + input.validate() + for output in self.outputs: + output.validate() def finalize(self): """Add hidden based on selected schema options, and give outputs without ids default ids.""" + # ensure inputs, outputs, and hidden are lists + if self.inputs is None: + self.inputs = [] + if self.outputs is None: + self.outputs = [] + if self.hidden is None: + self.hidden = [] # if is an api_node, will need key-related hidden if self.is_api_node: - if self.hidden is None: - self.hidden = [] if Hidden.auth_token_comfy_org not in self.hidden: self.hidden.append(Hidden.auth_token_comfy_org) if Hidden.api_key_comfy_org not in self.hidden: self.hidden.append(Hidden.api_key_comfy_org) # if is an output_node, will need prompt and extra_pnginfo if self.is_output_node: - if self.hidden is None: - self.hidden = [] if Hidden.prompt not in self.hidden: self.hidden.append(Hidden.prompt) if Hidden.extra_pnginfo not in self.hidden: self.hidden.append(Hidden.extra_pnginfo) # give outputs without ids default ids - if self.outputs is not None: - for i, output in enumerate(self.outputs): - if output.id is None: - output.id = f"_{i}_{output.io_type}_" + for i, output in enumerate(self.outputs): + if output.id is None: + output.id = f"_{i}_{output.io_type}_" - def get_v1_info(self, cls, live_inputs: dict[str, Any]=None) -> NodeInfoV1: - # NOTE: live_inputs will not be used anymore very soon and this will be done another way + def get_v1_info(self, cls) -> NodeInfoV1: # get V1 inputs - input = create_input_dict_v1(self.inputs, live_inputs) + input = create_input_dict_v1(self.inputs) if self.hidden: for hidden in self.hidden: input.setdefault("hidden", {})[hidden.name] = (hidden.value,) @@ -1385,33 +1411,54 @@ class Schema: ) return info +def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], include_hidden=False) -> tuple[dict[str, Any], V3Data]: + out_dict = { + "required": {}, + "optional": {}, + "dynamic_paths": {}, + } + d = d.copy() + # ignore hidden for parsing + hidden = d.pop("hidden", None) + parse_class_inputs(out_dict, live_inputs, d) + if hidden is not None and include_hidden: + out_dict["hidden"] = hidden + v3_data = {} + dynamic_paths = out_dict.pop("dynamic_paths", None) + if dynamic_paths is not None: + v3_data["dynamic_paths"] = dynamic_paths + return out_dict, hidden, v3_data -def create_input_dict_v1(inputs: list[Input], live_inputs: dict[str, Any]=None) -> dict: +def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None: + for input_type, inner_d in curr_dict.items(): + for id, value in inner_d.items(): + io_type = value[0] + if io_type in DYNAMIC_INPUT_LOOKUP: + # dynamic inputs need to be handled with lookup functions + dynamic_input_func = get_dynamic_input_func(io_type) + new_prefix = handle_prefix(curr_prefix, id) + dynamic_input_func(out_dict, live_inputs, value, input_type, new_prefix) + else: + # non-dynamic inputs get directly transferred + finalized_id = finalize_prefix(curr_prefix, id) + out_dict[input_type][finalized_id] = value + if curr_prefix: + out_dict["dynamic_paths"][finalized_id] = finalized_id + +def create_input_dict_v1(inputs: list[Input]) -> dict: input = { "required": {} } - add_to_input_dict_v1(input, inputs, live_inputs) + for i in inputs: + add_to_dict_v1(i, input) return input -def add_to_input_dict_v1(d: dict[str, Any], inputs: list[Input], live_inputs: dict[str, Any]=None, curr_prefix=''): - for i in inputs: - if isinstance(i, DynamicInput): - add_to_dict_v1(i, d) - if live_inputs is not None: - i.expand_schema_for_dynamic(d, live_inputs, curr_prefix) - else: - add_to_dict_v1(i, d) - -def add_to_dict_v1(i: Input, d: dict, dynamic_dict: dict=None): +def add_to_dict_v1(i: Input, d: dict): key = "optional" if i.optional else "required" as_dict = i.as_dict() # for v1, we don't want to include the optional key as_dict.pop("optional", None) - if dynamic_dict is None: - value = (i.get_io_type(), as_dict) - else: - value = (i.get_io_type(), as_dict, dynamic_dict) - d.setdefault(key, {})[i.id] = value + d.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict) def add_to_dict_v3(io: Input | Output, d: dict): d[io.id] = (io.get_io_type(), io.as_dict()) @@ -1423,6 +1470,8 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data): values = values.copy() result = {} + create_tuple = v3_data.get("create_dynamic_tuple", False) + for key, path in paths.items(): parts = path.split(".") current = result @@ -1431,7 +1480,10 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data): is_last = (i == len(parts) - 1) if is_last: - current[p] = values.pop(key, None) + value = values.pop(key, None) + if create_tuple: + value = (value, key) + current[p] = value else: current = current.setdefault(p, {}) @@ -1446,7 +1498,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): SCHEMA = None # filled in during execution - resources: Resources = None hidden: HiddenHolder = None @classmethod @@ -1493,7 +1544,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): return [name for name in kwargs if kwargs[name] is None] def __init__(self): - self.local_resources: ResourcesLocal = None self.__class__.VALIDATE_CLASS() @classmethod @@ -1561,7 +1611,7 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): c_type: type[ComfyNode] = cls if is_class(cls) else type(cls) type_clone: type[ComfyNode] = shallow_clone_class(c_type) # set hidden - type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"] if v3_data else None) + type_clone.hidden = HiddenHolder.from_v3_data(v3_data) return type_clone @final @@ -1678,19 +1728,10 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): @final @classmethod - def INPUT_TYPES(cls, include_hidden=True, return_schema=False, live_inputs=None) -> dict[str, dict] | tuple[dict[str, dict], Schema, V3Data]: + def INPUT_TYPES(cls) -> dict[str, dict]: schema = cls.FINALIZE_SCHEMA() - info = schema.get_v1_info(cls, live_inputs) - input = info.input - if not include_hidden: - input.pop("hidden", None) - if return_schema: - v3_data: V3Data = {} - dynamic = input.pop("dynamic_paths", None) - if dynamic is not None: - v3_data["dynamic_paths"] = dynamic - return input, schema, v3_data - return input + info = schema.get_v1_info(cls) + return info.input @final @classmethod @@ -1809,7 +1850,7 @@ class NodeOutput(_NodeOutputInternal): return self.args if len(self.args) > 0 else None @classmethod - def from_dict(cls, data: dict[str, Any]) -> "NodeOutput": + def from_dict(cls, data: dict[str, Any]) -> NodeOutput: args = () ui = None expand = None @@ -1905,8 +1946,8 @@ __all__ = [ "Tracks", # Dynamic Types "MatchType", - # "DynamicCombo", - # "Autogrow", + "DynamicCombo", + "Autogrow", # Other classes "HiddenHolder", "Hidden", diff --git a/comfy_api/latest/_resources.py b/comfy_api/latest/_resources.py deleted file mode 100644 index a6bdda972..000000000 --- a/comfy_api/latest/_resources.py +++ /dev/null @@ -1,72 +0,0 @@ -from __future__ import annotations -import comfy.utils -import folder_paths -import logging -from abc import ABC, abstractmethod -from typing import Any -import torch - -class ResourceKey(ABC): - Type = Any - def __init__(self): - ... - -class TorchDictFolderFilename(ResourceKey): - '''Key for requesting a torch file via file_name from a folder category.''' - Type = dict[str, torch.Tensor] - def __init__(self, folder_name: str, file_name: str): - self.folder_name = folder_name - self.file_name = file_name - - def __hash__(self): - return hash((self.folder_name, self.file_name)) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, TorchDictFolderFilename): - return False - return self.folder_name == other.folder_name and self.file_name == other.file_name - - def __str__(self): - return f"{self.folder_name} -> {self.file_name}" - -class Resources(ABC): - def __init__(self): - ... - - @abstractmethod - def get(self, key: ResourceKey, default: Any=...) -> Any: - pass - -class ResourcesLocal(Resources): - def __init__(self): - super().__init__() - self.local_resources: dict[ResourceKey, Any] = {} - - def get(self, key: ResourceKey, default: Any=...) -> Any: - cached = self.local_resources.get(key, None) - if cached is not None: - logging.info(f"Using cached resource '{key}'") - return cached - logging.info(f"Loading resource '{key}'") - to_return = None - if isinstance(key, TorchDictFolderFilename): - if default is ...: - to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True) - else: - full_path = folder_paths.get_full_path(key.folder_name, key.file_name) - if full_path is not None: - to_return = comfy.utils.load_torch_file(full_path, safe_load=True) - - if to_return is not None: - self.local_resources[key] = to_return - return to_return - if default is not ...: - return default - raise Exception(f"Unsupported resource key type: {type(key)}") - - -class _RESOURCES: - ResourceKey = ResourceKey - TorchDictFolderFilename = TorchDictFolderFilename - Resources = Resources - ResourcesLocal = ResourcesLocal diff --git a/comfy_api/latest/_util/__init__.py b/comfy_api/latest/_util/__init__.py index fc5431dda..6313eb01b 100644 --- a/comfy_api/latest/_util/__init__.py +++ b/comfy_api/latest/_util/__init__.py @@ -1,5 +1,6 @@ from .video_types import VideoContainer, VideoCodec, VideoComponents from .geometry_types import VOXEL, MESH +from .image_types import SVG __all__ = [ # Utility Types @@ -8,4 +9,5 @@ __all__ = [ "VideoComponents", "VOXEL", "MESH", + "SVG", ] diff --git a/comfy_api/latest/_util/image_types.py b/comfy_api/latest/_util/image_types.py new file mode 100644 index 000000000..f031ed426 --- /dev/null +++ b/comfy_api/latest/_util/image_types.py @@ -0,0 +1,18 @@ +from io import BytesIO + + +class SVG: + """Stores SVG representations via a list of BytesIO objects.""" + + def __init__(self, data: list[BytesIO]): + self.data = data + + def combine(self, other: 'SVG') -> 'SVG': + return SVG(self.data + other.data) + + @staticmethod + def combine_all(svgs: list['SVG']) -> 'SVG': + all_svgs_list: list[BytesIO] = [] + for svg_item in svgs: + all_svgs_list.extend(svg_item.data) + return SVG(all_svgs_list) diff --git a/comfy_api_nodes/apis/gemini_api.py b/comfy_api_nodes/apis/gemini_api.py index f8edc38c9..d81337dae 100644 --- a/comfy_api_nodes/apis/gemini_api.py +++ b/comfy_api_nodes/apis/gemini_api.py @@ -133,6 +133,7 @@ class GeminiImageGenerateContentRequest(BaseModel): systemInstruction: GeminiSystemInstructionContent | None = Field(None) tools: list[GeminiTool] | None = Field(None) videoMetadata: GeminiVideoMetadata | None = Field(None) + uploadImagesToStorage: bool = Field(True) class GeminiGenerateContentRequest(BaseModel): diff --git a/comfy_api_nodes/apis/kling_api.py b/comfy_api_nodes/apis/kling_api.py index 80a758466..bf54ede3e 100644 --- a/comfy_api_nodes/apis/kling_api.py +++ b/comfy_api_nodes/apis/kling_api.py @@ -102,3 +102,12 @@ class ImageToVideoWithAudioRequest(BaseModel): prompt: str = Field(...) mode: str = Field("pro") sound: str = Field(..., description="'on' or 'off'") + + +class MotionControlRequest(BaseModel): + prompt: str = Field(...) + image_url: str = Field(...) + video_url: str = Field(...) + keep_original_sound: str = Field(...) + character_orientation: str = Field(...) + mode: str = Field(..., description="'pro' or 'std'") diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index 636cc1265..d4a2cfae6 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -229,6 +229,7 @@ class ByteDanceImageEditNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod @@ -269,7 +270,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ByteDanceSeedreamNode", - display_name="ByteDance Seedream 4", + display_name="ByteDance Seedream 4.5", category="api node/image/ByteDance", description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.", inputs=[ diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index ad0f4b4d1..e8ed7e797 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -34,6 +34,7 @@ from comfy_api_nodes.util import ( ApiEndpoint, audio_to_base64_string, bytesio_to_image_tensor, + download_url_to_image_tensor, get_number_of_images, sync_op, tensor_to_base64_string, @@ -141,9 +142,11 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera ) parts = [] for part in response.candidates[0].content.parts: - if part_type == "text" and hasattr(part, "text") and part.text: + if part_type == "text" and part.text: parts.append(part) - elif hasattr(part, "inlineData") and part.inlineData and part.inlineData.mimeType == part_type: + elif part.inlineData and part.inlineData.mimeType == part_type: + parts.append(part) + elif part.fileData and part.fileData.mimeType == part_type: parts.append(part) # Skip parts that don't match the requested type return parts @@ -163,12 +166,15 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str: return "\n".join([part.text for part in parts]) -def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image: +async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image: image_tensors: list[Input.Image] = [] parts = get_parts_by_type(response, "image/png") for part in parts: - image_data = base64.b64decode(part.inlineData.data) - returned_image = bytesio_to_image_tensor(BytesIO(image_data)) + if part.inlineData: + image_data = base64.b64decode(part.inlineData.data) + returned_image = bytesio_to_image_tensor(BytesIO(image_data)) + else: + returned_image = await download_url_to_image_tensor(part.fileData.fileUri) image_tensors.append(returned_image) if len(image_tensors) == 0: return torch.zeros((1, 1024, 1024, 4)) @@ -596,7 +602,7 @@ class GeminiImage(IO.ComfyNode): response = await sync_op( cls, - endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), + ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"), data=GeminiImageGenerateContentRequest( contents=[ GeminiContent(role=GeminiRole.user, parts=parts), @@ -610,7 +616,7 @@ class GeminiImage(IO.ComfyNode): response_model=GeminiGenerateContentResponse, price_extractor=calculate_tokens_price, ) - return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response)) + return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response)) class GeminiImage2(IO.ComfyNode): @@ -729,7 +735,7 @@ class GeminiImage2(IO.ComfyNode): response = await sync_op( cls, - ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), + ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"), data=GeminiImageGenerateContentRequest( contents=[ GeminiContent(role=GeminiRole.user, parts=parts), @@ -743,7 +749,7 @@ class GeminiImage2(IO.ComfyNode): response_model=GeminiGenerateContentResponse, price_extractor=calculate_tokens_price, ) - return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response)) + return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response)) class GeminiExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 5294b10d4..9c707a339 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -51,6 +51,7 @@ from comfy_api_nodes.apis import ( ) from comfy_api_nodes.apis.kling_api import ( ImageToVideoWithAudioRequest, + MotionControlRequest, OmniImageParamImage, OmniParamImage, OmniParamVideo, @@ -806,6 +807,7 @@ class OmniProTextToVideoNode(IO.ComfyNode): ), IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), IO.Combo.Input("duration", options=[5, 10]), + IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), ], outputs=[ IO.Video.Output(), @@ -825,6 +827,7 @@ class OmniProTextToVideoNode(IO.ComfyNode): prompt: str, aspect_ratio: str, duration: int, + resolution: str = "1080p", ) -> IO.NodeOutput: validate_string(prompt, min_length=1, max_length=2500) response = await sync_op( @@ -836,6 +839,7 @@ class OmniProTextToVideoNode(IO.ComfyNode): prompt=prompt, aspect_ratio=aspect_ratio, duration=str(duration), + mode="pro" if resolution == "1080p" else "std", ), ) return await finish_omni_video_task(cls, response) @@ -871,6 +875,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): optional=True, tooltip="Up to 6 additional reference images.", ), + IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), ], outputs=[ IO.Video.Output(), @@ -892,6 +897,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): first_frame: Input.Image, end_frame: Input.Image | None = None, reference_images: Input.Image | None = None, + resolution: str = "1080p", ) -> IO.NodeOutput: prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) @@ -935,6 +941,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): prompt=prompt, duration=str(duration), image_list=image_list, + mode="pro" if resolution == "1080p" else "std", ), ) return await finish_omni_video_task(cls, response) @@ -963,6 +970,7 @@ class OmniProImageToVideoNode(IO.ComfyNode): "reference_images", tooltip="Up to 7 reference images.", ), + IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), ], outputs=[ IO.Video.Output(), @@ -983,6 +991,7 @@ class OmniProImageToVideoNode(IO.ComfyNode): aspect_ratio: str, duration: int, reference_images: Input.Image, + resolution: str = "1080p", ) -> IO.NodeOutput: prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) @@ -1004,6 +1013,7 @@ class OmniProImageToVideoNode(IO.ComfyNode): aspect_ratio=aspect_ratio, duration=str(duration), image_list=image_list, + mode="pro" if resolution == "1080p" else "std", ), ) return await finish_omni_video_task(cls, response) @@ -1035,6 +1045,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode): tooltip="Up to 4 additional reference images.", optional=True, ), + IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), ], outputs=[ IO.Video.Output(), @@ -1057,6 +1068,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode): reference_video: Input.Video, keep_original_sound: bool, reference_images: Input.Image | None = None, + resolution: str = "1080p", ) -> IO.NodeOutput: prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) @@ -1089,6 +1101,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode): duration=str(duration), image_list=image_list if image_list else None, video_list=video_list, + mode="pro" if resolution == "1080p" else "std", ), ) return await finish_omni_video_task(cls, response) @@ -1118,6 +1131,7 @@ class OmniProEditVideoNode(IO.ComfyNode): tooltip="Up to 4 additional reference images.", optional=True, ), + IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), ], outputs=[ IO.Video.Output(), @@ -1138,6 +1152,7 @@ class OmniProEditVideoNode(IO.ComfyNode): video: Input.Video, keep_original_sound: bool, reference_images: Input.Image | None = None, + resolution: str = "1080p", ) -> IO.NodeOutput: prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) @@ -1170,6 +1185,7 @@ class OmniProEditVideoNode(IO.ComfyNode): duration=None, image_list=image_list if image_list else None, video_list=video_list, + mode="pro" if resolution == "1080p" else "std", ), ) return await finish_omni_video_task(cls, response) @@ -2163,6 +2179,91 @@ class ImageToVideoWithAudio(IO.ComfyNode): return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) +class MotionControl(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingMotionControl", + display_name="Kling Motion Control", + category="api node/video/Kling", + inputs=[ + IO.String.Input("prompt", multiline=True), + IO.Image.Input("reference_image"), + IO.Video.Input( + "reference_video", + tooltip="Motion reference video used to drive movement/expression.\n" + "Duration limits depend on character_orientation:\n" + " - image: 3–10s (max 10s)\n" + " - video: 3–30s (max 30s)", + ), + IO.Boolean.Input("keep_original_sound", default=True), + IO.Combo.Input( + "character_orientation", + options=["video", "image"], + tooltip="Controls where the character's facing/orientation comes from.\n" + "video: movements, expressions, camera moves, and orientation " + "follow the motion reference video (other details via prompt).\n" + "image: movements and expressions still follow the motion reference video, " + "but the character orientation matches the reference image (camera/other details via prompt).", + ), + IO.Combo.Input("mode", options=["pro", "std"]), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + prompt: str, + reference_image: Input.Image, + reference_video: Input.Video, + keep_original_sound: bool, + character_orientation: str, + mode: str, + ) -> IO.NodeOutput: + validate_string(prompt, max_length=2500) + validate_image_dimensions(reference_image, min_width=340, min_height=340) + validate_image_aspect_ratio(reference_image, (1, 2.5), (2.5, 1)) + if character_orientation == "image": + validate_video_duration(reference_video, min_duration=3, max_duration=10) + else: + validate_video_duration(reference_video, min_duration=3, max_duration=30) + validate_video_dimensions(reference_video, min_width=340, min_height=340, max_width=3850, max_height=3850) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/motion-control", method="POST"), + response_model=TaskStatusResponse, + data=MotionControlRequest( + prompt=prompt, + image_url=(await upload_images_to_comfyapi(cls, reference_image))[0], + video_url=await upload_video_to_comfyapi(cls, reference_video), + keep_original_sound="yes" if keep_original_sound else "no", + character_orientation=character_orientation, + mode=mode, + ), + ) + if response.code: + raise RuntimeError( + f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}" + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/kling/v1/videos/motion-control/{response.data.task_id}"), + response_model=TaskStatusResponse, + status_extractor=lambda r: (r.data.task_status if r.data else None), + ) + return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) + + class KlingExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -2188,6 +2289,7 @@ class KlingExtension(ComfyExtension): OmniProImageNode, TextToVideoWithAudio, ImageToVideoWithAudio, + MotionControl, ] diff --git a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py index bd3c24fb3..e72f8e96a 100644 --- a/comfy_api_nodes/nodes_tripo.py +++ b/comfy_api_nodes/nodes_tripo.py @@ -155,7 +155,7 @@ class TripoTextToModelNode(IO.ComfyNode): model_seed=model_seed, texture_seed=texture_seed, texture_quality=texture_quality, - face_limit=face_limit, + face_limit=face_limit if face_limit != -1 else None, geometry_quality=geometry_quality, auto_size=True, quad=quad, @@ -255,7 +255,7 @@ class TripoImageToModelNode(IO.ComfyNode): texture_alignment=texture_alignment, texture_seed=texture_seed, texture_quality=texture_quality, - face_limit=face_limit, + face_limit=face_limit if face_limit != -1 else None, auto_size=True, quad=quad, ), @@ -369,7 +369,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode): texture_quality=texture_quality, geometry_quality=geometry_quality, texture_alignment=texture_alignment, - face_limit=face_limit, + face_limit=face_limit if face_limit != -1 else None, quad=quad, ), ) diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index e165b8380..13a6bfd91 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -168,6 +168,8 @@ class VeoVideoGenerationNode(IO.ComfyNode): # Only add generateAudio for Veo 3 models if model.find("veo-2.0") == -1: parameters["generateAudio"] = generate_audio + # force "enhance_prompt" to True for Veo3 models + parameters["enhancePrompt"] = True initial_response = await sync_op( cls, @@ -291,7 +293,7 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode): IO.Boolean.Input( "enhance_prompt", default=True, - tooltip="Whether to enhance the prompt with AI assistance", + tooltip="This parameter is deprecated and ignored.", optional=True, ), IO.Combo.Input( diff --git a/comfy_api_nodes/util/_helpers.py b/comfy_api_nodes/util/_helpers.py index 491e6b6a8..648defe3d 100644 --- a/comfy_api_nodes/util/_helpers.py +++ b/comfy_api_nodes/util/_helpers.py @@ -1,16 +1,22 @@ import asyncio import contextlib import os +import re import time from collections.abc import Callable from io import BytesIO +from yarl import URL + from comfy.cli_args import args from comfy.model_management import processing_interrupted from comfy_api.latest import IO from .common_exceptions import ProcessingInterrupted +_HAS_PCT_ESC = re.compile(r"%[0-9A-Fa-f]{2}") # any % followed by 2 hex digits +_HAS_BAD_PCT = re.compile(r"%(?![0-9A-Fa-f]{2})") # any % not followed by 2 hex digits + def is_processing_interrupted() -> bool: """Return True if user/runtime requested interruption.""" @@ -69,3 +75,17 @@ def get_fs_object_size(path_or_object: str | BytesIO) -> int: if isinstance(path_or_object, str): return os.path.getsize(path_or_object) return len(path_or_object.getvalue()) + + +def to_aiohttp_url(url: str) -> URL: + """If `url` appears to be already percent-encoded (contains at least one valid %HH + escape and no malformed '%' sequences) and contains no raw whitespace/control + characters preserve the original encoding byte-for-byte (important for signed/presigned URLs). + Otherwise, return `URL(url)` and allow yarl to normalize/quote as needed.""" + if any(c.isspace() for c in url) or any(ord(c) < 0x20 for c in url): + # Avoid encoded=True if URL contains raw whitespace/control chars + return URL(url) + if _HAS_PCT_ESC.search(url) and not _HAS_BAD_PCT.search(url): + # Preserve encoding only if it appears pre-encoded AND has no invalid % sequences + return URL(url, encoded=True) + return URL(url) diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index bf37cba5f..f372ec7b5 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -430,9 +430,9 @@ def _display_text( if status: display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}") if price is not None: - p = f"{float(price):,.4f}".rstrip("0").rstrip(".") + p = f"{float(price) * 211:,.1f}".rstrip("0").rstrip(".") if p != "0": - display_lines.append(f"Price: ${p}") + display_lines.append(f"Price: {p} credits") if text is not None: display_lines.append(text) if display_lines: diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py index 3e0d0352d..4668d14a9 100644 --- a/comfy_api_nodes/util/download_helpers.py +++ b/comfy_api_nodes/util/download_helpers.py @@ -19,6 +19,7 @@ from ._helpers import ( get_auth_header, is_processing_interrupted, sleep_with_interrupt, + to_aiohttp_url, ) from .client import _diagnose_connectivity from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted @@ -94,7 +95,7 @@ async def download_url_to_bytesio( monitor_task = asyncio.create_task(_monitor()) - req_task = asyncio.create_task(session.get(url, headers=headers)) + req_task = asyncio.create_task(session.get(to_aiohttp_url(url), headers=headers)) done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED) if monitor_task in done and req_task in pending: diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index 0d811e354..9d170b16e 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -97,6 +97,11 @@ def get_input_info( extra_info = input_info[1] else: extra_info = {} + # if input_type is a list, it is a Combo defined in outdated format; convert it. + # NOTE: uncomment this when we are confident old format going away won't cause too much trouble. + # if isinstance(input_type, list): + # extra_info["options"] = input_type + # input_type = IO.Combo.io_type return input_type, input_category, extra_info class TopologicalSort: @@ -202,15 +207,15 @@ class ExecutionList(TopologicalSort): return self.output_cache.get(node_id) is not None def cache_link(self, from_node_id, to_node_id): - if not to_node_id in self.execution_cache: + if to_node_id not in self.execution_cache: self.execution_cache[to_node_id] = {} self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id) - if not from_node_id in self.execution_cache_listeners: + if from_node_id not in self.execution_cache_listeners: self.execution_cache_listeners[from_node_id] = set() self.execution_cache_listeners[from_node_id].add(to_node_id) def get_cache(self, from_node_id, to_node_id): - if not to_node_id in self.execution_cache: + if to_node_id not in self.execution_cache: return None value = self.execution_cache[to_node_id].get(from_node_id) if value is None: diff --git a/comfy_execution/validation.py b/comfy_execution/validation.py index 24c0b4ed7..e73624bd1 100644 --- a/comfy_execution/validation.py +++ b/comfy_execution/validation.py @@ -21,14 +21,24 @@ def validate_node_input( """ # If the types are exactly the same, we can return immediately # Use pre-union behaviour: inverse of `__ne__` + # NOTE: this lets legacy '*' Any types work that override the __ne__ method of the str class. if not received_type != input_type: return True + # If one of the types is '*', we can return True immediately; this is the 'Any' type. + if received_type == IO.AnyType.io_type or input_type == IO.AnyType.io_type: + return True + # If the received type or input_type is a MatchType, we can return True immediately; # validation for this is handled by the frontend if received_type == IO.MatchType.io_type or input_type == IO.MatchType.io_type: return True + # This accounts for some custom nodes that output lists of options as the type; + # if we ever want to break them on purpose, this can be removed + if isinstance(received_type, list) and input_type == IO.Combo.io_type: + return True + # Not equal, and not strings if not isinstance(received_type, str) or not isinstance(input_type, str): return False @@ -37,6 +47,10 @@ def validate_node_input( received_types = set(t.strip() for t in received_type.split(",")) input_types = set(t.strip() for t in input_type.split(",")) + # If any of the types is '*', we can return True immediately; this is the 'Any' type. + if IO.AnyType.io_type in received_types or IO.AnyType.io_type in input_types: + return True + if strict: # In strict mode, all received types must be in the input types return received_types.issubset(input_types) diff --git a/comfy_extras/nodes_apg.py b/comfy_extras/nodes_apg.py index f27ae7da8..b9df2dcc9 100644 --- a/comfy_extras/nodes_apg.py +++ b/comfy_extras/nodes_apg.py @@ -55,7 +55,8 @@ class APG(io.ComfyNode): def pre_cfg_function(args): nonlocal running_avg, prev_sigma - if len(args["conds_out"]) == 1: return args["conds_out"] + if len(args["conds_out"]) == 1: + return args["conds_out"] cond = args["conds_out"][0] uncond = args["conds_out"][1] diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 7ee4caac1..f19adf4b9 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -9,6 +9,7 @@ import comfy.utils import node_helpers from typing_extensions import override from comfy_api.latest import ComfyExtension, io +import re class BasicScheduler(io.ComfyNode): @@ -760,8 +761,12 @@ class SamplerCustom(io.ComfyNode): out = latent.copy() out["samples"] = samples if "x0" in x0_output: + x0_out = model.model.process_latent_out(x0_output["x0"].cpu()) + if samples.is_nested: + latent_shapes = [x.shape for x in samples.unbind()] + x0_out = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(x0_out, latent_shapes)) out_denoised = latent.copy() - out_denoised["samples"] = model.model.process_latent_out(x0_output["x0"].cpu()) + out_denoised["samples"] = x0_out else: out_denoised = out return io.NodeOutput(out, out_denoised) @@ -948,8 +953,12 @@ class SamplerCustomAdvanced(io.ComfyNode): out = latent.copy() out["samples"] = samples if "x0" in x0_output: + x0_out = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu()) + if samples.is_nested: + latent_shapes = [x.shape for x in samples.unbind()] + x0_out = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(x0_out, latent_shapes)) out_denoised = latent.copy() - out_denoised["samples"] = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu()) + out_denoised["samples"] = x0_out else: out_denoised = out return io.NodeOutput(out, out_denoised) @@ -1005,6 +1014,25 @@ class AddNoise(io.ComfyNode): add_noise = execute +class ManualSigmas(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ManualSigmas", + category="_for_testing/custom_sampling", + is_experimental=True, + inputs=[ + io.String.Input("sigmas", default="1, 0.5", multiline=False) + ], + outputs=[io.Sigmas.Output()] + ) + + @classmethod + def execute(cls, sigmas) -> io.NodeOutput: + sigmas = re.findall(r"[-+]?(?:\d*\.*\d+)", sigmas) + sigmas = [float(i) for i in sigmas] + sigmas = torch.FloatTensor(sigmas) + return io.NodeOutput(sigmas) class CustomSamplersExtension(ComfyExtension): @override @@ -1044,6 +1072,7 @@ class CustomSamplersExtension(ComfyExtension): DisableNoise, AddNoise, SamplerCustomAdvanced, + ManualSigmas, ] diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py index 513aecf3a..5ef851bd0 100644 --- a/comfy_extras/nodes_dataset.py +++ b/comfy_extras/nodes_dataset.py @@ -667,16 +667,19 @@ class ResizeImagesByLongerEdgeNode(ImageProcessingNode): @classmethod def _process(cls, image, longer_edge): - img = tensor_to_pil(image) - w, h = img.size - if w > h: - new_w = longer_edge - new_h = int(h * (longer_edge / w)) - else: - new_h = longer_edge - new_w = int(w * (longer_edge / h)) - img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) - return pil_to_tensor(img) + resized_images = [] + for image_i in image: + img = tensor_to_pil(image_i) + w, h = img.size + if w > h: + new_w = longer_edge + new_h = int(h * (longer_edge / w)) + else: + new_h = longer_edge + new_w = int(w * (longer_edge / h)) + img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) + resized_images.append(pil_to_tensor(img)) + return torch.cat(resized_images, dim=0) class CenterCropImagesNode(ImageProcessingNode): diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 392aea32c..ce21caade 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -2,280 +2,231 @@ from __future__ import annotations import nodes import folder_paths -from comfy.cli_args import args -from PIL import Image -from PIL.PngImagePlugin import PngInfo - -import numpy as np import json import os import re -from io import BytesIO -from inspect import cleandoc import torch import comfy.utils -from comfy.comfy_types import FileLocator, IO from server import PromptServer +from comfy_api.latest import ComfyExtension, IO, UI +from typing_extensions import override + +SVG = IO.SVG.Type # TODO: temporary solution for backward compatibility, will be removed later. MAX_RESOLUTION = nodes.MAX_RESOLUTION -class ImageCrop: +class ImageCrop(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), - "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - }} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "crop" + def define_schema(cls): + return IO.Schema( + node_id="ImageCrop", + display_name="Image Crop", + category="image/transform", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image/transform" - - def crop(self, image, width, height, x, y): + @classmethod + def execute(cls, image, width, height, x, y) -> IO.NodeOutput: x = min(x, image.shape[2] - 1) y = min(y, image.shape[1] - 1) to_x = width + x to_y = height + y img = image[:,y:to_y, x:to_x, :] - return (img,) + return IO.NodeOutput(img) -class RepeatImageBatch: + crop = execute # TODO: remove + + +class RepeatImageBatch(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), - "amount": ("INT", {"default": 1, "min": 1, "max": 4096}), - }} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "repeat" + def define_schema(cls): + return IO.Schema( + node_id="RepeatImageBatch", + category="image/batch", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input("amount", default=1, min=1, max=4096), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image/batch" - - def repeat(self, image, amount): + @classmethod + def execute(cls, image, amount) -> IO.NodeOutput: s = image.repeat((amount, 1,1,1)) - return (s,) + return IO.NodeOutput(s) -class ImageFromBatch: + repeat = execute # TODO: remove + + +class ImageFromBatch(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), - "batch_index": ("INT", {"default": 0, "min": 0, "max": 4095}), - "length": ("INT", {"default": 1, "min": 1, "max": 4096}), - }} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "frombatch" + def define_schema(cls): + return IO.Schema( + node_id="ImageFromBatch", + category="image/batch", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input("batch_index", default=0, min=0, max=4095), + IO.Int.Input("length", default=1, min=1, max=4096), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image/batch" - - def frombatch(self, image, batch_index, length): + @classmethod + def execute(cls, image, batch_index, length) -> IO.NodeOutput: s_in = image batch_index = min(s_in.shape[0] - 1, batch_index) length = min(s_in.shape[0] - batch_index, length) s = s_in[batch_index:batch_index + length].clone() - return (s,) + return IO.NodeOutput(s) + + frombatch = execute # TODO: remove -class ImageAddNoise: +class ImageAddNoise(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), - "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True, "tooltip": "The random seed used for creating the noise."}), - "strength": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), - }} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "repeat" + def define_schema(cls): + return IO.Schema( + node_id="ImageAddNoise", + category="image", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + ), + IO.Float.Input("strength", default=0.5, min=0.0, max=1.0, step=0.01), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image" - - def repeat(self, image, seed, strength): + @classmethod + def execute(cls, image, seed, strength) -> IO.NodeOutput: generator = torch.manual_seed(seed) s = torch.clip((image + strength * torch.randn(image.size(), generator=generator, device="cpu").to(image)), min=0.0, max=1.0) - return (s,) + return IO.NodeOutput(s) -class SaveAnimatedWEBP: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" + repeat = execute # TODO: remove - methods = {"default": 4, "fastest": 0, "slowest": 6} - @classmethod - def INPUT_TYPES(s): - return {"required": - {"images": ("IMAGE", ), - "filename_prefix": ("STRING", {"default": "ComfyUI"}), - "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}), - "lossless": ("BOOLEAN", {"default": True}), - "quality": ("INT", {"default": 80, "min": 0, "max": 100}), - "method": (list(s.methods.keys()),), - # "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}), - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } - RETURN_TYPES = () - FUNCTION = "save_images" - - OUTPUT_NODE = True - - CATEGORY = "image/animation" - - def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None): - method = self.methods.get(method) - filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) - results: list[FileLocator] = [] - pil_images = [] - for image in images: - i = 255. * image.cpu().numpy() - img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) - pil_images.append(img) - - metadata = pil_images[0].getexif() - if not args.disable_metadata: - if prompt is not None: - metadata[0x0110] = "prompt:{}".format(json.dumps(prompt)) - if extra_pnginfo is not None: - inital_exif = 0x010f - for x in extra_pnginfo: - metadata[inital_exif] = "{}:{}".format(x, json.dumps(extra_pnginfo[x])) - inital_exif -= 1 - - if num_frames == 0: - num_frames = len(pil_images) - - c = len(pil_images) - for i in range(0, c, num_frames): - file = f"{filename}_{counter:05}_.webp" - pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0/fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method) - results.append({ - "filename": file, - "subfolder": subfolder, - "type": self.type - }) - counter += 1 - - animated = num_frames != 1 - return { "ui": { "images": results, "animated": (animated,) } } - -class SaveAnimatedPNG: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" +class SaveAnimatedWEBP(IO.ComfyNode): + COMPRESS_METHODS = {"default": 4, "fastest": 0, "slowest": 6} @classmethod - def INPUT_TYPES(s): - return {"required": - {"images": ("IMAGE", ), - "filename_prefix": ("STRING", {"default": "ComfyUI"}), - "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}), - "compress_level": ("INT", {"default": 4, "min": 0, "max": 9}) - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } + def define_schema(cls): + return IO.Schema( + node_id="SaveAnimatedWEBP", + category="image/animation", + inputs=[ + IO.Image.Input("images"), + IO.String.Input("filename_prefix", default="ComfyUI"), + IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01), + IO.Boolean.Input("lossless", default=True), + IO.Int.Input("quality", default=80, min=0, max=100), + IO.Combo.Input("method", options=list(cls.COMPRESS_METHODS.keys())), + # "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) - RETURN_TYPES = () - FUNCTION = "save_images" + @classmethod + def execute(cls, images, fps, filename_prefix, lossless, quality, method, num_frames=0) -> IO.NodeOutput: + return IO.NodeOutput( + ui=UI.ImageSaveHelper.get_save_animated_webp_ui( + images=images, + filename_prefix=filename_prefix, + cls=cls, + fps=fps, + lossless=lossless, + quality=quality, + method=cls.COMPRESS_METHODS.get(method) + ) + ) - OUTPUT_NODE = True - - CATEGORY = "image/animation" - - def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): - filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) - results = list() - pil_images = [] - for image in images: - i = 255. * image.cpu().numpy() - img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) - pil_images.append(img) - - metadata = None - if not args.disable_metadata: - metadata = PngInfo() - if prompt is not None: - metadata.add(b"comf", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(prompt).encode("latin-1", "strict"), after_idat=True) - if extra_pnginfo is not None: - for x in extra_pnginfo: - metadata.add(b"comf", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True) - - file = f"{filename}_{counter:05}_.png" - pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:]) - results.append({ - "filename": file, - "subfolder": subfolder, - "type": self.type - }) - - return { "ui": { "images": results, "animated": (True,)} } - -class SVG: - """ - Stores SVG representations via a list of BytesIO objects. - """ - def __init__(self, data: list[BytesIO]): - self.data = data - - def combine(self, other: 'SVG') -> 'SVG': - return SVG(self.data + other.data) - - @staticmethod - def combine_all(svgs: list['SVG']) -> 'SVG': - all_svgs_list: list[BytesIO] = [] - for svg_item in svgs: - all_svgs_list.extend(svg_item.data) - return SVG(all_svgs_list) + save_images = execute # TODO: remove -class ImageStitch: +class SaveAnimatedPNG(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SaveAnimatedPNG", + category="image/animation", + inputs=[ + IO.Image.Input("images"), + IO.String.Input("filename_prefix", default="ComfyUI"), + IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01), + IO.Int.Input("compress_level", default=4, min=0, max=9), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) + + @classmethod + def execute(cls, images, fps, compress_level, filename_prefix="ComfyUI") -> IO.NodeOutput: + return IO.NodeOutput( + ui=UI.ImageSaveHelper.get_save_animated_png_ui( + images=images, + filename_prefix=filename_prefix, + cls=cls, + fps=fps, + compress_level=compress_level, + ) + ) + + save_images = execute # TODO: remove + + +class ImageStitch(IO.ComfyNode): """Upstreamed from https://github.com/kijai/ComfyUI-KJNodes""" @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image1": ("IMAGE",), - "direction": (["right", "down", "left", "up"], {"default": "right"}), - "match_image_size": ("BOOLEAN", {"default": True}), - "spacing_width": ( - "INT", - {"default": 0, "min": 0, "max": 1024, "step": 2}, - ), - "spacing_color": ( - ["white", "black", "red", "green", "blue"], - {"default": "white"}, - ), - }, - "optional": { - "image2": ("IMAGE",), - }, - } + def define_schema(cls): + return IO.Schema( + node_id="ImageStitch", + display_name="Image Stitch", + description="Stitches image2 to image1 in the specified direction.\n" + "If image2 is not provided, returns image1 unchanged.\n" + "Optional spacing can be added between images.", + category="image/transform", + inputs=[ + IO.Image.Input("image1"), + IO.Combo.Input("direction", options=["right", "down", "left", "up"], default="right"), + IO.Boolean.Input("match_image_size", default=True), + IO.Int.Input("spacing_width", default=0, min=0, max=1024, step=2), + IO.Combo.Input("spacing_color", options=["white", "black", "red", "green", "blue"], default="white"), + IO.Image.Input("image2", optional=True), + ], + outputs=[IO.Image.Output()], + ) - RETURN_TYPES = ("IMAGE",) - FUNCTION = "stitch" - CATEGORY = "image/transform" - DESCRIPTION = """ -Stitches image2 to image1 in the specified direction. -If image2 is not provided, returns image1 unchanged. -Optional spacing can be added between images. -""" - - def stitch( - self, + @classmethod + def execute( + cls, image1, direction, match_image_size, spacing_width, spacing_color, image2=None, - ): + ) -> IO.NodeOutput: if image2 is None: - return (image1,) + return IO.NodeOutput(image1) # Handle batch size differences if image1.shape[0] != image2.shape[0]: @@ -412,36 +363,30 @@ Optional spacing can be added between images. images.insert(1, spacing) concat_dim = 2 if direction in ["left", "right"] else 1 - return (torch.cat(images, dim=concat_dim),) + return IO.NodeOutput(torch.cat(images, dim=concat_dim)) + + stitch = execute # TODO: remove + + +class ResizeAndPadImage(IO.ComfyNode): -class ResizeAndPadImage: @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "image": ("IMAGE",), - "target_width": ("INT", { - "default": 512, - "min": 1, - "max": MAX_RESOLUTION, - "step": 1 - }), - "target_height": ("INT", { - "default": 512, - "min": 1, - "max": MAX_RESOLUTION, - "step": 1 - }), - "padding_color": (["white", "black"],), - "interpolation": (["area", "bicubic", "nearest-exact", "bilinear", "lanczos"],), - } - } + def define_schema(cls): + return IO.Schema( + node_id="ResizeAndPadImage", + category="image/transform", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input("target_width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("target_height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Combo.Input("padding_color", options=["white", "black"]), + IO.Combo.Input("interpolation", options=["area", "bicubic", "nearest-exact", "bilinear", "lanczos"]), + ], + outputs=[IO.Image.Output()], + ) - RETURN_TYPES = ("IMAGE",) - FUNCTION = "resize_and_pad" - CATEGORY = "image/transform" - - def resize_and_pad(self, image, target_width, target_height, padding_color, interpolation): + @classmethod + def execute(cls, image, target_width, target_height, padding_color, interpolation) -> IO.NodeOutput: batch_size, orig_height, orig_width, channels = image.shape scale_w = target_width / orig_width @@ -469,52 +414,47 @@ class ResizeAndPadImage: padded[:, :, y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized output = padded.permute(0, 2, 3, 1) - return (output,) + return IO.NodeOutput(output) -class SaveSVGNode: - """ - Save SVG files on disk. - """ + resize_and_pad = execute # TODO: remove - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" - RETURN_TYPES = () - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "save_svg" - CATEGORY = "image/save" # Changed - OUTPUT_NODE = True +class SaveSVGNode(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "svg": ("SVG",), # Changed - "filename_prefix": ("STRING", {"default": "svg/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}) - }, - "hidden": { - "prompt": "PROMPT", - "extra_pnginfo": "EXTRA_PNGINFO" - } - } + def define_schema(cls): + return IO.Schema( + node_id="SaveSVGNode", + description="Save SVG files on disk.", + category="image/save", + inputs=[ + IO.SVG.Input("svg"), + IO.String.Input( + "filename_prefix", + default="svg/ComfyUI", + tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes.", + ), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) - def save_svg(self, svg: SVG, filename_prefix="svg/ComfyUI", prompt=None, extra_pnginfo=None): - filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) - results = list() + @classmethod + def execute(cls, svg: IO.SVG.Type, filename_prefix="svg/ComfyUI") -> IO.NodeOutput: + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory()) + results: list[UI.SavedResult] = [] # Prepare metadata JSON metadata_dict = {} - if prompt is not None: - metadata_dict["prompt"] = prompt - if extra_pnginfo is not None: - metadata_dict.update(extra_pnginfo) + if cls.hidden.prompt is not None: + metadata_dict["prompt"] = cls.hidden.prompt + if cls.hidden.extra_pnginfo is not None: + metadata_dict.update(cls.hidden.extra_pnginfo) # Convert metadata to JSON string metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None + for batch_number, svg_bytes in enumerate(svg.data): filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) file = f"{filename_with_batch_num}_{counter:05}_.svg" @@ -544,57 +484,64 @@ class SaveSVGNode: with open(os.path.join(full_output_folder, file), 'wb') as svg_file: svg_file.write(svg_content.encode('utf-8')) - results.append({ - "filename": file, - "subfolder": subfolder, - "type": self.type - }) + results.append(UI.SavedResult(filename=file, subfolder=subfolder, type=IO.FolderType.output)) counter += 1 - return { "ui": { "images": results } } + return IO.NodeOutput(ui={"images": results}) -class GetImageSize: + save_svg = execute # TODO: remove + + +class GetImageSize(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE,), - }, - "hidden": { - "unique_id": "UNIQUE_ID", - } - } + def define_schema(cls): + return IO.Schema( + node_id="GetImageSize", + display_name="Get Image Size", + description="Returns width and height of the image, and passes it through unchanged.", + category="image", + inputs=[ + IO.Image.Input("image"), + ], + outputs=[ + IO.Int.Output(display_name="width"), + IO.Int.Output(display_name="height"), + IO.Int.Output(display_name="batch_size"), + ], + hidden=[IO.Hidden.unique_id], + ) - RETURN_TYPES = (IO.INT, IO.INT, IO.INT) - RETURN_NAMES = ("width", "height", "batch_size") - FUNCTION = "get_size" - - CATEGORY = "image" - DESCRIPTION = """Returns width and height of the image, and passes it through unchanged.""" - - def get_size(self, image, unique_id=None) -> tuple[int, int]: + @classmethod + def execute(cls, image) -> IO.NodeOutput: height = image.shape[1] width = image.shape[2] batch_size = image.shape[0] # Send progress text to display size on the node - if unique_id: - PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", unique_id) + if cls.hidden.unique_id: + PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", cls.hidden.unique_id) - return width, height, batch_size + return IO.NodeOutput(width, height, batch_size) + + get_size = execute # TODO: remove + + +class ImageRotate(IO.ComfyNode): -class ImageRotate: @classmethod - def INPUT_TYPES(s): - return {"required": { "image": (IO.IMAGE,), - "rotation": (["none", "90 degrees", "180 degrees", "270 degrees"],), - }} - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "rotate" + def define_schema(cls): + return IO.Schema( + node_id="ImageRotate", + category="image/transform", + inputs=[ + IO.Image.Input("image"), + IO.Combo.Input("rotation", options=["none", "90 degrees", "180 degrees", "270 degrees"]), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image/transform" - - def rotate(self, image, rotation): + @classmethod + def execute(cls, image, rotation) -> IO.NodeOutput: rotate_by = 0 if rotation.startswith("90"): rotate_by = 1 @@ -604,41 +551,57 @@ class ImageRotate: rotate_by = 3 image = torch.rot90(image, k=rotate_by, dims=[2, 1]) - return (image,) + return IO.NodeOutput(image) + + rotate = execute # TODO: remove + + +class ImageFlip(IO.ComfyNode): -class ImageFlip: @classmethod - def INPUT_TYPES(s): - return {"required": { "image": (IO.IMAGE,), - "flip_method": (["x-axis: vertically", "y-axis: horizontally"],), - }} - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "flip" + def define_schema(cls): + return IO.Schema( + node_id="ImageFlip", + category="image/transform", + inputs=[ + IO.Image.Input("image"), + IO.Combo.Input("flip_method", options=["x-axis: vertically", "y-axis: horizontally"]), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image/transform" - - def flip(self, image, flip_method): + @classmethod + def execute(cls, image, flip_method) -> IO.NodeOutput: if flip_method.startswith("x"): image = torch.flip(image, dims=[1]) elif flip_method.startswith("y"): image = torch.flip(image, dims=[2]) - return (image,) + return IO.NodeOutput(image) -class ImageScaleToMaxDimension: - upscale_methods = ["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"] + flip = execute # TODO: remove + + +class ImageScaleToMaxDimension(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"image": ("IMAGE",), - "upscale_method": (s.upscale_methods,), - "largest_size": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1})}} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "upscale" + def define_schema(cls): + return IO.Schema( + node_id="ImageScaleToMaxDimension", + category="image/upscaling", + inputs=[ + IO.Image.Input("image"), + IO.Combo.Input( + "upscale_method", + options=["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"], + ), + IO.Int.Input("largest_size", default=512, min=0, max=MAX_RESOLUTION, step=1), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image/upscaling" - - def upscale(self, image, upscale_method, largest_size): + @classmethod + def execute(cls, image, upscale_method, largest_size) -> IO.NodeOutput: height = image.shape[1] width = image.shape[2] @@ -655,20 +618,30 @@ class ImageScaleToMaxDimension: samples = image.movedim(-1, 1) s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled") s = s.movedim(1, -1) - return (s,) + return IO.NodeOutput(s) -NODE_CLASS_MAPPINGS = { - "ImageCrop": ImageCrop, - "RepeatImageBatch": RepeatImageBatch, - "ImageFromBatch": ImageFromBatch, - "ImageAddNoise": ImageAddNoise, - "SaveAnimatedWEBP": SaveAnimatedWEBP, - "SaveAnimatedPNG": SaveAnimatedPNG, - "SaveSVGNode": SaveSVGNode, - "ImageStitch": ImageStitch, - "ResizeAndPadImage": ResizeAndPadImage, - "GetImageSize": GetImageSize, - "ImageRotate": ImageRotate, - "ImageFlip": ImageFlip, - "ImageScaleToMaxDimension": ImageScaleToMaxDimension, -} + upscale = execute # TODO: remove + + +class ImagesExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + ImageCrop, + RepeatImageBatch, + ImageFromBatch, + ImageAddNoise, + SaveAnimatedWEBP, + SaveAnimatedPNG, + SaveSVGNode, + ImageStitch, + ResizeAndPadImage, + GetImageSize, + ImageRotate, + ImageFlip, + ImageScaleToMaxDimension, + ] + + +async def comfy_entrypoint() -> ImagesExtension: + return ImagesExtension() diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index 2815c5ffc..9ba1c4ba8 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -255,6 +255,7 @@ class LatentBatch(io.ComfyNode): return io.Schema( node_id="LatentBatch", category="latent/batch", + is_deprecated=True, inputs=[ io.Latent.Input("samples1"), io.Latent.Input("samples2"), diff --git a/comfy_extras/nodes_logic.py b/comfy_extras/nodes_logic.py index 95a6ba788..eb888316a 100644 --- a/comfy_extras/nodes_logic.py +++ b/comfy_extras/nodes_logic.py @@ -1,8 +1,11 @@ +from __future__ import annotations from typing import TypedDict from typing_extensions import override from comfy_api.latest import ComfyExtension, io from comfy_api.latest import _io +# sentinel for missing inputs +MISSING = object() class SwitchNode(io.ComfyNode): @@ -14,6 +17,37 @@ class SwitchNode(io.ComfyNode): display_name="Switch", category="logic", is_experimental=True, + inputs=[ + io.Boolean.Input("switch"), + io.MatchType.Input("on_false", template=template, lazy=True), + io.MatchType.Input("on_true", template=template, lazy=True), + ], + outputs=[ + io.MatchType.Output(template=template, display_name="output"), + ], + ) + + @classmethod + def check_lazy_status(cls, switch, on_false=None, on_true=None): + if switch and on_true is None: + return ["on_true"] + if not switch and on_false is None: + return ["on_false"] + + @classmethod + def execute(cls, switch, on_true, on_false) -> io.NodeOutput: + return io.NodeOutput(on_true if switch else on_false) + + +class SoftSwitchNode(io.ComfyNode): + @classmethod + def define_schema(cls): + template = io.MatchType.Template("switch") + return io.Schema( + node_id="ComfySoftSwitchNode", + display_name="Soft Switch", + category="logic", + is_experimental=True, inputs=[ io.Boolean.Input("switch"), io.MatchType.Input("on_false", template=template, lazy=True, optional=True), @@ -25,14 +59,14 @@ class SwitchNode(io.ComfyNode): ) @classmethod - def check_lazy_status(cls, switch, on_false=..., on_true=...): - # We use ... instead of None, as None is passed for connected-but-unevaluated inputs. + def check_lazy_status(cls, switch, on_false=MISSING, on_true=MISSING): + # We use MISSING instead of None, as None is passed for connected-but-unevaluated inputs. # This trick allows us to ignore the value of the switch and still be able to run execute(). # One of the inputs may be missing, in which case we need to evaluate the other input - if on_false is ...: + if on_false is MISSING: return ["on_true"] - if on_true is ...: + if on_true is MISSING: return ["on_false"] # Normal lazy switch operation if switch and on_true is None: @@ -41,22 +75,50 @@ class SwitchNode(io.ComfyNode): return ["on_false"] @classmethod - def validate_inputs(cls, switch, on_false=..., on_true=...): + def validate_inputs(cls, switch, on_false=MISSING, on_true=MISSING): # This check happens before check_lazy_status(), so we can eliminate the case where # both inputs are missing. - if on_false is ... and on_true is ...: + if on_false is MISSING and on_true is MISSING: return "At least one of on_false or on_true must be connected to Switch node" return True @classmethod - def execute(cls, switch, on_true=..., on_false=...) -> io.NodeOutput: - if on_true is ...: + def execute(cls, switch, on_true=MISSING, on_false=MISSING) -> io.NodeOutput: + if on_true is MISSING: return io.NodeOutput(on_false) - if on_false is ...: + if on_false is MISSING: return io.NodeOutput(on_true) return io.NodeOutput(on_true if switch else on_false) +class CustomComboNode(io.ComfyNode): + """ + Frontend node that allows user to write their own options for a combo. + This is here to make sure the node has a backend-representation to avoid some annoyances. + """ + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CustomCombo", + display_name="Custom Combo", + category="utils", + is_experimental=True, + inputs=[io.Combo.Input("choice", options=[])], + outputs=[io.String.Output()] + ) + + @classmethod + def validate_inputs(cls, choice: io.Combo.Type) -> bool: + # NOTE: DO NOT DO THIS unless you want to skip validation entirely on the node's inputs. + # I am doing that here because the widgets (besides the combo dropdown) on this node are fully frontend defined. + # I need to skip checking that the chosen combo option is in the options list, since those are defined by the user. + return True + + @classmethod + def execute(cls, choice: io.Combo.Type) -> io.NodeOutput: + return io.NodeOutput(choice) + + class DCTestNode(io.ComfyNode): class DCValues(TypedDict): combo: str @@ -72,14 +134,14 @@ class DCTestNode(io.ComfyNode): display_name="DCTest", category="logic", is_output_node=True, - inputs=[_io.DynamicCombo.Input("combo", options=[ - _io.DynamicCombo.Option("option1", [io.String.Input("string")]), - _io.DynamicCombo.Option("option2", [io.Int.Input("integer")]), - _io.DynamicCombo.Option("option3", [io.Image.Input("image")]), - _io.DynamicCombo.Option("option4", [ - _io.DynamicCombo.Input("subcombo", options=[ - _io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]), - _io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]), + inputs=[io.DynamicCombo.Input("combo", options=[ + io.DynamicCombo.Option("option1", [io.String.Input("string")]), + io.DynamicCombo.Option("option2", [io.Int.Input("integer")]), + io.DynamicCombo.Option("option3", [io.Image.Input("image")]), + io.DynamicCombo.Option("option4", [ + io.DynamicCombo.Input("subcombo", options=[ + io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]), + io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]), ]) ])] )], @@ -141,14 +203,65 @@ class AutogrowPrefixTestNode(io.ComfyNode): combined = ",".join([str(x) for x in vals]) return io.NodeOutput(combined) +class ComboOutputTestNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ComboOptionTestNode", + display_name="ComboOptionTest", + category="logic", + inputs=[io.Combo.Input("combo", options=["option1", "option2", "option3"]), + io.Combo.Input("combo2", options=["option4", "option5", "option6"])], + outputs=[io.Combo.Output(), io.Combo.Output()], + ) + + @classmethod + def execute(cls, combo: io.Combo.Type, combo2: io.Combo.Type) -> io.NodeOutput: + return io.NodeOutput(combo, combo2) + +class ConvertStringToComboNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ConvertStringToComboNode", + display_name="Convert String to Combo", + category="logic", + inputs=[io.String.Input("string")], + outputs=[io.Combo.Output()], + ) + + @classmethod + def execute(cls, string: str) -> io.NodeOutput: + return io.NodeOutput(string) + +class InvertBooleanNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="InvertBooleanNode", + display_name="Invert Boolean", + category="logic", + inputs=[io.Boolean.Input("boolean")], + outputs=[io.Boolean.Output()], + ) + + @classmethod + def execute(cls, boolean: bool) -> io.NodeOutput: + return io.NodeOutput(not boolean) + class LogicExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ - # SwitchNode, + SwitchNode, + CustomComboNode, + # SoftSwitchNode, + # ConvertStringToComboNode, # DCTestNode, # AutogrowNamesTestNode, # AutogrowPrefixTestNode, + # ComboOutputTestNode, + # InvertBooleanNode, ] async def comfy_entrypoint() -> LogicExtension: diff --git a/comfy_extras/nodes_mahiro.py b/comfy_extras/nodes_mahiro.py index 07b3353f4..6459ca8c1 100644 --- a/comfy_extras/nodes_mahiro.py +++ b/comfy_extras/nodes_mahiro.py @@ -10,7 +10,7 @@ class Mahiro(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="Mahiro", - display_name="Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)", + display_name="Mahiro CFG", category="_for_testing", description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.", inputs=[ diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index ca2cdeb50..01afa13a1 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -4,11 +4,15 @@ import torch import torch.nn.functional as F from PIL import Image import math +from enum import Enum +from typing import TypedDict, Literal import comfy.utils import comfy.model_management +from comfy_extras.nodes_latent import reshape_latent_to import node_helpers from comfy_api.latest import ComfyExtension, io +from nodes import MAX_RESOLUTION class Blend(io.ComfyNode): @classmethod @@ -241,6 +245,353 @@ class ImageScaleToTotalPixels(io.ComfyNode): s = s.movedim(1,-1) return io.NodeOutput(s) +class ResizeType(str, Enum): + SCALE_BY = "scale by multiplier" + SCALE_DIMENSIONS = "scale dimensions" + SCALE_LONGER_DIMENSION = "scale longer dimension" + SCALE_SHORTER_DIMENSION = "scale shorter dimension" + SCALE_WIDTH = "scale width" + SCALE_HEIGHT = "scale height" + SCALE_TOTAL_PIXELS = "scale total pixels" + MATCH_SIZE = "match size" + +def is_image(input: torch.Tensor) -> bool: + # images have 4 dimensions: [batch, height, width, channels] + # masks have 3 dimensions: [batch, height, width] + return len(input.shape) == 4 + +def init_image_mask_input(input: torch.Tensor, is_type_image: bool) -> torch.Tensor: + if is_type_image: + input = input.movedim(-1, 1) + else: + input = input.unsqueeze(1) + return input + +def finalize_image_mask_input(input: torch.Tensor, is_type_image: bool) -> torch.Tensor: + if is_type_image: + input = input.movedim(1, -1) + else: + input = input.squeeze(1) + return input + +def scale_by(input: torch.Tensor, multiplier: float, scale_method: str) -> torch.Tensor: + is_type_image = is_image(input) + input = init_image_mask_input(input, is_type_image) + width = round(input.shape[-1] * multiplier) + height = round(input.shape[-2] * multiplier) + + input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled") + input = finalize_image_mask_input(input, is_type_image) + return input + +def scale_dimensions(input: torch.Tensor, width: int, height: int, scale_method: str, crop: str="disabled") -> torch.Tensor: + if width == 0 and height == 0: + return input + is_type_image = is_image(input) + input = init_image_mask_input(input, is_type_image) + + if width == 0: + width = max(1, round(input.shape[-1] * height / input.shape[-2])) + elif height == 0: + height = max(1, round(input.shape[-2] * width / input.shape[-1])) + + input = comfy.utils.common_upscale(input, width, height, scale_method, crop) + input = finalize_image_mask_input(input, is_type_image) + return input + +def scale_longer_dimension(input: torch.Tensor, longer_size: int, scale_method: str) -> torch.Tensor: + is_type_image = is_image(input) + input = init_image_mask_input(input, is_type_image) + width = input.shape[-1] + height = input.shape[-2] + + if height > width: + width = round((width / height) * longer_size) + height = longer_size + elif width > height: + height = round((height / width) * longer_size) + width = longer_size + else: + height = longer_size + width = longer_size + + input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled") + input = finalize_image_mask_input(input, is_type_image) + return input + +def scale_shorter_dimension(input: torch.Tensor, shorter_size: int, scale_method: str) -> torch.Tensor: + is_type_image = is_image(input) + input = init_image_mask_input(input, is_type_image) + width = input.shape[-1] + height = input.shape[-2] + + if height < width: + width = round((width / height) * shorter_size) + height = shorter_size + elif width > height: + height = round((height / width) * shorter_size) + width = shorter_size + else: + height = shorter_size + width = shorter_size + + input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled") + input = finalize_image_mask_input(input, is_type_image) + return input + +def scale_total_pixels(input: torch.Tensor, megapixels: float, scale_method: str) -> torch.Tensor: + is_type_image = is_image(input) + input = init_image_mask_input(input, is_type_image) + total = int(megapixels * 1024 * 1024) + + scale_by = math.sqrt(total / (input.shape[-1] * input.shape[-2])) + width = round(input.shape[-1] * scale_by) + height = round(input.shape[-2] * scale_by) + + input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled") + input = finalize_image_mask_input(input, is_type_image) + return input + +def scale_match_size(input: torch.Tensor, match: torch.Tensor, scale_method: str, crop: str) -> torch.Tensor: + is_type_image = is_image(input) + input = init_image_mask_input(input, is_type_image) + match = init_image_mask_input(match, is_image(match)) + + width = match.shape[-1] + height = match.shape[-2] + input = comfy.utils.common_upscale(input, width, height, scale_method, crop) + input = finalize_image_mask_input(input, is_type_image) + return input + +class ResizeImageMaskNode(io.ComfyNode): + + scale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] + crop_methods = ["disabled", "center"] + + class ResizeTypedDict(TypedDict): + resize_type: ResizeType + scale_method: Literal["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] + crop: Literal["disabled", "center"] + multiplier: float + width: int + height: int + longer_size: int + shorter_size: int + megapixels: float + + @classmethod + def define_schema(cls): + template = io.MatchType.Template("input_type", [io.Image, io.Mask]) + crop_combo = io.Combo.Input("crop", options=cls.crop_methods, default="center") + return io.Schema( + node_id="ResizeImageMaskNode", + display_name="Resize Image/Mask", + category="transform", + inputs=[ + io.MatchType.Input("input", template=template), + io.DynamicCombo.Input("resize_type", options=[ + io.DynamicCombo.Option(ResizeType.SCALE_BY, [ + io.Float.Input("multiplier", default=1.00, min=0.01, max=8.0, step=0.01), + ]), + io.DynamicCombo.Option(ResizeType.SCALE_DIMENSIONS, [ + io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1), + io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1), + crop_combo, + ]), + io.DynamicCombo.Option(ResizeType.SCALE_LONGER_DIMENSION, [ + io.Int.Input("longer_size", default=512, min=0, max=MAX_RESOLUTION, step=1), + ]), + io.DynamicCombo.Option(ResizeType.SCALE_SHORTER_DIMENSION, [ + io.Int.Input("shorter_size", default=512, min=0, max=MAX_RESOLUTION, step=1), + ]), + io.DynamicCombo.Option(ResizeType.SCALE_WIDTH, [ + io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1), + ]), + io.DynamicCombo.Option(ResizeType.SCALE_HEIGHT, [ + io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1), + ]), + io.DynamicCombo.Option(ResizeType.SCALE_TOTAL_PIXELS, [ + io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01), + ]), + io.DynamicCombo.Option(ResizeType.MATCH_SIZE, [ + io.MultiType.Input("match", [io.Image, io.Mask]), + crop_combo, + ]), + ]), + io.Combo.Input("scale_method", options=cls.scale_methods, default="area"), + ], + outputs=[io.MatchType.Output(template=template, display_name="resized")] + ) + + @classmethod + def execute(cls, input: io.Image.Type | io.Mask.Type, scale_method: io.Combo.Type, resize_type: ResizeTypedDict) -> io.NodeOutput: + selected_type = resize_type["resize_type"] + if selected_type == ResizeType.SCALE_BY: + return io.NodeOutput(scale_by(input, resize_type["multiplier"], scale_method)) + elif selected_type == ResizeType.SCALE_DIMENSIONS: + return io.NodeOutput(scale_dimensions(input, resize_type["width"], resize_type["height"], scale_method, resize_type["crop"])) + elif selected_type == ResizeType.SCALE_LONGER_DIMENSION: + return io.NodeOutput(scale_longer_dimension(input, resize_type["longer_size"], scale_method)) + elif selected_type == ResizeType.SCALE_SHORTER_DIMENSION: + return io.NodeOutput(scale_shorter_dimension(input, resize_type["shorter_size"], scale_method)) + elif selected_type == ResizeType.SCALE_WIDTH: + return io.NodeOutput(scale_dimensions(input, resize_type["width"], 0, scale_method)) + elif selected_type == ResizeType.SCALE_HEIGHT: + return io.NodeOutput(scale_dimensions(input, 0, resize_type["height"], scale_method)) + elif selected_type == ResizeType.SCALE_TOTAL_PIXELS: + return io.NodeOutput(scale_total_pixels(input, resize_type["megapixels"], scale_method)) + elif selected_type == ResizeType.MATCH_SIZE: + return io.NodeOutput(scale_match_size(input, resize_type["match"], scale_method, resize_type["crop"])) + raise ValueError(f"Unsupported resize type: {selected_type}") + +def batch_images(images: list[torch.Tensor]) -> torch.Tensor | None: + if len(images) == 0: + return None + # first, get the max channels count + max_channels = max(image.shape[-1] for image in images) + # then, pad all images to have the same channels count + padded_images: list[torch.Tensor] = [] + for image in images: + if image.shape[-1] < max_channels: + padded_images.append(torch.nn.functional.pad(image, (0,1), mode='constant', value=1.0)) + else: + padded_images.append(image) + # resize all images to be the same size as the first image + resized_images: list[torch.Tensor] = [] + first_image_shape = padded_images[0].shape + for image in padded_images: + if image.shape[1:] != first_image_shape[1:]: + resized_images.append(comfy.utils.common_upscale(image.movedim(-1,1), first_image_shape[2], first_image_shape[1], "bilinear", "center").movedim(1,-1)) + else: + resized_images.append(image) + # batch the images in the format [b, h, w, c] + return torch.cat(resized_images, dim=0) + +def batch_masks(masks: list[torch.Tensor]) -> torch.Tensor | None: + if len(masks) == 0: + return None + # resize all masks to be the same size as the first mask + resized_masks: list[torch.Tensor] = [] + first_mask_shape = masks[0].shape + for mask in masks: + if mask.shape[1:] != first_mask_shape[1:]: + mask = init_image_mask_input(mask, is_type_image=False) + mask = comfy.utils.common_upscale(mask, first_mask_shape[2], first_mask_shape[1], "bilinear", "center") + resized_masks.append(finalize_image_mask_input(mask, is_type_image=False)) + else: + resized_masks.append(mask) + # batch the masks in the format [b, h, w] + return torch.cat(resized_masks, dim=0) + +def batch_latents(latents: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor] | None: + if len(latents) == 0: + return None + samples_out = latents[0].copy() + samples_out["batch_index"] = [] + first_samples = latents[0]["samples"] + tensors: list[torch.Tensor] = [] + for latent in latents: + # first, deal with latent tensors + tensors.append(reshape_latent_to(first_samples.shape, latent["samples"], repeat_batch=False)) + # next, deal with batch_index + samples_out["batch_index"].extend(latent.get("batch_index", [x for x in range(0, latent["samples"].shape[0])])) + samples_out["samples"] = torch.cat(tensors, dim=0) + return samples_out + +class BatchImagesNode(io.ComfyNode): + @classmethod + def define_schema(cls): + autogrow_template = io.Autogrow.TemplatePrefix(io.Image.Input("image"), prefix="image", min=2, max=50) + return io.Schema( + node_id="BatchImagesNode", + display_name="Batch Images", + category="image", + inputs=[ + io.Autogrow.Input("images", template=autogrow_template) + ], + outputs=[ + io.Image.Output() + ] + ) + + @classmethod + def execute(cls, images: io.Autogrow.Type) -> io.NodeOutput: + return io.NodeOutput(batch_images(list(images.values()))) + +class BatchMasksNode(io.ComfyNode): + @classmethod + def define_schema(cls): + autogrow_template = io.Autogrow.TemplatePrefix(io.Mask.Input("mask"), prefix="mask", min=2, max=50) + return io.Schema( + node_id="BatchMasksNode", + display_name="Batch Masks", + category="mask", + inputs=[ + io.Autogrow.Input("masks", template=autogrow_template) + ], + outputs=[ + io.Mask.Output() + ] + ) + + @classmethod + def execute(cls, masks: io.Autogrow.Type) -> io.NodeOutput: + return io.NodeOutput(batch_masks(list(masks.values()))) + +class BatchLatentsNode(io.ComfyNode): + @classmethod + def define_schema(cls): + autogrow_template = io.Autogrow.TemplatePrefix(io.Latent.Input("latent"), prefix="latent", min=2, max=50) + return io.Schema( + node_id="BatchLatentsNode", + display_name="Batch Latents", + category="latent", + inputs=[ + io.Autogrow.Input("latents", template=autogrow_template) + ], + outputs=[ + io.Latent.Output() + ] + ) + + @classmethod + def execute(cls, latents: io.Autogrow.Type) -> io.NodeOutput: + return io.NodeOutput(batch_latents(list(latents.values()))) + +class BatchImagesMasksLatentsNode(io.ComfyNode): + @classmethod + def define_schema(cls): + matchtype_template = io.MatchType.Template("input", allowed_types=[io.Image, io.Mask, io.Latent]) + autogrow_template = io.Autogrow.TemplatePrefix( + io.MatchType.Input("input", matchtype_template), + prefix="input", min=1, max=50) + return io.Schema( + node_id="BatchImagesMasksLatentsNode", + display_name="Batch Images/Masks/Latents", + category="util", + inputs=[ + io.Autogrow.Input("inputs", template=autogrow_template) + ], + outputs=[ + io.MatchType.Output(id=None, template=matchtype_template) + ] + ) + + @classmethod + def execute(cls, inputs: io.Autogrow.Type) -> io.NodeOutput: + batched = None + values = list(inputs.values()) + # latents + if isinstance(values[0], dict): + batched = batch_latents(values) + # images + elif is_image(values[0]): + batched = batch_images(values) + # masks + else: + batched = batch_masks(values) + return io.NodeOutput(batched) + class PostProcessingExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -250,6 +601,11 @@ class PostProcessingExtension(ComfyExtension): Quantize, Sharpen, ImageScaleToTotalPixels, + ResizeImageMaskNode, + BatchImagesNode, + BatchMasksNode, + BatchLatentsNode, + # BatchImagesMasksLatentsNode, ] async def comfy_entrypoint() -> PostProcessingExtension: diff --git a/comfy_extras/nodes_primitive.py b/comfy_extras/nodes_primitive.py index 5a1aeba80..937321800 100644 --- a/comfy_extras/nodes_primitive.py +++ b/comfy_extras/nodes_primitive.py @@ -66,7 +66,7 @@ class Float(io.ComfyNode): display_name="Float", category="utils/primitive", inputs=[ - io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize), + io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize, step=0.1), ], outputs=[io.Float.Output()], ) diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 8e0f8287b..9016bb3be 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -818,7 +818,7 @@ def get_sample_indices(original_fps, if required_duration > total_frames / original_fps: raise ValueError("required_duration must be less than video length") - if not fixed_start is None and fixed_start >= 0: + if fixed_start is not None and fixed_start >= 0: start_frame = fixed_start else: max_start = total_frames - required_origin_frames diff --git a/comfyui_version.py b/comfyui_version.py index b45309198..1ed60fe5c 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.5.1" +__version__ = "0.7.0" diff --git a/execution.py b/execution.py index 0c239efd7..648f204ec 100644 --- a/execution.py +++ b/execution.py @@ -79,7 +79,7 @@ class IsChangedCache: # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None) try: - is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name) + is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name, v3_data=v3_data) is_changed = await resolve_map_node_over_list_results(is_changed) node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] except Exception as e: @@ -148,13 +148,12 @@ SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org") def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}): is_v3 = issubclass(class_def, _ComfyNodeInternal) v3_data: io.V3Data = {} + hidden_inputs_v3 = {} + valid_inputs = class_def.INPUT_TYPES() if is_v3: - valid_inputs, schema, v3_data = class_def.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs) - else: - valid_inputs = class_def.INPUT_TYPES() + valid_inputs, hidden, v3_data = _io.get_finalized_class_inputs(valid_inputs, inputs) input_data_all = {} missing_keys = {} - hidden_inputs_v3 = {} for x in inputs: input_data = inputs[x] _, input_category, input_info = get_input_info(class_def, x, valid_inputs) @@ -180,18 +179,18 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt= input_data_all[x] = [input_data] if is_v3: - if schema.hidden: - if io.Hidden.prompt in schema.hidden: + if hidden is not None: + if io.Hidden.prompt.name in hidden: hidden_inputs_v3[io.Hidden.prompt] = dynprompt.get_original_prompt() if dynprompt is not None else {} - if io.Hidden.dynprompt in schema.hidden: + if io.Hidden.dynprompt.name in hidden: hidden_inputs_v3[io.Hidden.dynprompt] = dynprompt - if io.Hidden.extra_pnginfo in schema.hidden: + if io.Hidden.extra_pnginfo.name in hidden: hidden_inputs_v3[io.Hidden.extra_pnginfo] = extra_data.get('extra_pnginfo', None) - if io.Hidden.unique_id in schema.hidden: + if io.Hidden.unique_id.name in hidden: hidden_inputs_v3[io.Hidden.unique_id] = unique_id - if io.Hidden.auth_token_comfy_org in schema.hidden: + if io.Hidden.auth_token_comfy_org.name in hidden: hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None) - if io.Hidden.api_key_comfy_org in schema.hidden: + if io.Hidden.api_key_comfy_org.name in hidden: hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None) else: if "hidden" in valid_inputs: @@ -258,7 +257,7 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f pre_execute_cb(index) # V3 if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)): - # if is just a class, then assign no resources or state, just create clone + # if is just a class, then assign no state, just create clone if is_class(obj): type_obj = obj obj.VALIDATE_CLASS() @@ -481,7 +480,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, else: lazy_status_present = getattr(obj, "check_lazy_status", None) is not None if lazy_status_present: - required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data) + # for check_lazy_status, the returned data should include the original key of the input + v3_data_lazy = v3_data.copy() + v3_data_lazy["create_dynamic_tuple"] = True + required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data_lazy) required_inputs = await resolve_map_node_over_list_results(required_inputs) required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], [])) required_inputs = [x for x in required_inputs if isinstance(x,str) and ( @@ -599,6 +601,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, if isinstance(ex, comfy.model_management.OOM_EXCEPTION): tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number." + logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary())) logging.error("Got an OOM, unloading all loaded models.") comfy.model_management.unload_all_models() @@ -756,10 +759,13 @@ async def validate_inputs(prompt_id, prompt, item, validated): errors = [] valid = True + v3_data = None validate_function_inputs = [] validate_has_kwargs = False if issubclass(obj_class, _ComfyNodeInternal): - class_inputs, _, _ = obj_class.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs) + obj_class: _io._ComfyNodeBaseInternal + class_inputs = obj_class.INPUT_TYPES() + class_inputs, _, v3_data = _io.get_finalized_class_inputs(class_inputs, inputs) validate_function_name = "validate_inputs" validate_function = first_real_override(obj_class, validate_function_name) else: @@ -779,10 +785,11 @@ async def validate_inputs(prompt_id, prompt, item, validated): assert extra_info is not None if x not in inputs: if input_category == "required": + details = f"{x}" if not v3_data else x.split(".")[-1] error = { "type": "required_input_missing", "message": "Required input is missing", - "details": f"{x}", + "details": details, "extra_info": { "input_name": x } @@ -916,8 +923,11 @@ async def validate_inputs(prompt_id, prompt, item, validated): errors.append(error) continue - if isinstance(input_type, list): - combo_options = input_type + if isinstance(input_type, list) or input_type == io.Combo.io_type: + if input_type == io.Combo.io_type: + combo_options = extra_info.get("options", []) + else: + combo_options = input_type if val not in combo_options: input_config = info list_info = "" diff --git a/manager_requirements.txt b/manager_requirements.txt index 2300f0c70..6585b0c19 100644 --- a/manager_requirements.txt +++ b/manager_requirements.txt @@ -1 +1 @@ -comfyui_manager==4.0.3b7 +comfyui_manager==4.0.4 diff --git a/nodes.py b/nodes.py index 7d83ecb21..662907ae6 100644 --- a/nodes.py +++ b/nodes.py @@ -1663,8 +1663,6 @@ class LoadImage: output_masks = [] w, h = None, None - excluded_formats = ['MPO'] - for i in ImageSequence.Iterator(img): i = node_helpers.pillow(ImageOps.exif_transpose, i) @@ -1692,7 +1690,10 @@ class LoadImage: output_images.append(image) output_masks.append(mask.unsqueeze(0)) - if len(output_images) > 1 and img.format not in excluded_formats: + if img.format == "MPO": + break # ignore all frames except the first one for MPO format + + if len(output_images) > 1: output_image = torch.cat(output_images, dim=0) output_mask = torch.cat(output_masks, dim=0) else: @@ -1863,6 +1864,7 @@ class ImageBatch: FUNCTION = "batch" CATEGORY = "image" + DEPRECATED = True def batch(self, image1, image2): if image1.shape[-1] != image2.shape[-1]: @@ -2241,8 +2243,10 @@ async def init_external_custom_nodes(): for possible_module in possible_modules: module_path = os.path.join(custom_node_path, possible_module) - if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue - if module_path.endswith(".disabled"): continue + if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": + continue + if module_path.endswith(".disabled"): + continue if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes: logging.info(f"Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes") continue diff --git a/pyproject.toml b/pyproject.toml index 3a6960811..60378de1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.5.1" +version = "0.7.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" @@ -15,12 +15,16 @@ lint.select = [ "N805", # invalid-first-argument-name-for-method "S307", # suspicious-eval-usage "S102", # exec + "E", "T", # print-usage "W", # The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names. # See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f "F", ] + +lint.ignore = ["E501", "E722", "E731", "E712", "E402", "E741"] + exclude = ["*.ipynb", "**/generated/*.pyi"] [tool.pylint] diff --git a/requirements.txt b/requirements.txt index 54696395f..3a05799eb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -comfyui-frontend-package==1.34.9 -comfyui-workflow-templates==0.7.60 +comfyui-frontend-package==1.35.9 +comfyui-workflow-templates==0.7.65 comfyui-embedded-docs==0.3.1 torch torchsde diff --git a/server.py b/server.py index c27f8be7d..70c8b5e3b 100644 --- a/server.py +++ b/server.py @@ -324,7 +324,7 @@ class PromptServer(): @routes.get("/models/{folder}") async def get_models(request): folder = request.match_info.get("folder", None) - if not folder in folder_paths.folder_names_and_paths: + if folder not in folder_paths.folder_names_and_paths: return web.Response(status=404) files = folder_paths.get_filename_list(folder) return web.json_response(files) @@ -579,7 +579,7 @@ class PromptServer(): folder_name = request.match_info.get("folder_name", None) if folder_name is None: return web.Response(status=404) - if not "filename" in request.rel_url.query: + if "filename" not in request.rel_url.query: return web.Response(status=404) filename = request.rel_url.query["filename"] @@ -593,7 +593,7 @@ class PromptServer(): if out is None: return web.Response(status=404) dt = json.loads(out) - if not "__metadata__" in dt: + if "__metadata__" not in dt: return web.Response(status=404) return web.json_response(dt["__metadata__"]) diff --git a/tests-unit/comfy_extras_test/image_stitch_test.py b/tests-unit/comfy_extras_test/image_stitch_test.py index b5a0f022c..5c6a15ac4 100644 --- a/tests-unit/comfy_extras_test/image_stitch_test.py +++ b/tests-unit/comfy_extras_test/image_stitch_test.py @@ -25,7 +25,7 @@ class TestImageStitch: result = node.stitch(image1, "right", True, 0, "white", image2=None) - assert len(result) == 1 + assert len(result.result) == 1 assert torch.equal(result[0], image1) def test_basic_horizontal_stitch_right(self):