Merge branch 'master' into dr-support-pip-cm

Dr.Lt.Data 2025-08-01 12:40:30 +09:00
commit 56cff964f2
16 changed files with 2710 additions and 74 deletions


@@ -111,7 +111,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
## Release Process
-ComfyUI follows a weekly release cycle every Friday, with three interconnected repositories:
+ComfyUI follows a weekly release cycle targeting Friday, though this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
   - Releases a new stable version (e.g., v0.7.0)


@@ -5,3 +5,146 @@ from .api_registry import (
    register_versions as register_versions,
    get_all_versions as get_all_versions,
)
import asyncio
from dataclasses import asdict
from typing import Callable, Optional
def first_real_override(cls: type, name: str, *, base: type=None) -> Optional[Callable]:
"""Return the *callable* override of `name` visible on `cls`, or None if every
implementation up to (and including) `base` is the placeholder defined on `base`.
If base is not provided, it will assume cls has a GET_BASE_CLASS
"""
if base is None:
if not hasattr(cls, "GET_BASE_CLASS"):
raise ValueError("base is required if cls does not have a GET_BASE_CLASS; is this a valid ComfyNode subclass?")
base = cls.GET_BASE_CLASS()
base_attr = getattr(base, name, None)
if base_attr is None:
return None
base_func = base_attr.__func__
for c in cls.mro(): # NodeB, NodeA, ComfyNode, object …
        if c is base:             # reached the placeholder; we're done
break
if name in c.__dict__: # first class that *defines* the attr
func = getattr(c, name).__func__
if func is not base_func: # real override
return getattr(cls, name) # bound to *cls*
return None
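As a usage sketch (the class names here are hypothetical, not from the codebase), the helper distinguishes a real override from the base placeholder like so:

class FakeBase:
    @classmethod
    def fingerprint_inputs(cls):  # placeholder implementation
        ...

class PlainNode(FakeBase):
    pass  # inherits only the placeholder

class CustomNode(FakeBase):
    @classmethod
    def fingerprint_inputs(cls):  # genuine override
        return "changed"

assert first_real_override(PlainNode, "fingerprint_inputs", base=FakeBase) is None
assert first_real_override(CustomNode, "fingerprint_inputs", base=FakeBase) is not None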
class _ComfyNodeInternal:
"""Class that all V3-based APIs inherit from for ComfyNode.
This is intended to only be referenced within execution.py, as it has to handle all V3 APIs going forward."""
@classmethod
def GET_NODE_INFO_V1(cls):
...
class _NodeOutputInternal:
"""Class that all V3-based APIs inherit from for NodeOutput.
This is intended to only be referenced within execution.py, as it has to handle all V3 APIs going forward."""
...
def as_pruned_dict(dataclass_obj):
'''Return dict of dataclass object with pruned None values.'''
return prune_dict(asdict(dataclass_obj))
def prune_dict(d: dict):
return {k: v for k,v in d.items() if v is not None}
def is_class(obj):
'''
Returns True if is a class type.
Returns False if is a class instance.
'''
return isinstance(obj, type)
def copy_class(cls: type) -> type:
'''
Copy a class and its attributes.
'''
if cls is None:
return None
cls_dict = {
k: v for k, v in cls.__dict__.items()
if k not in ('__dict__', '__weakref__', '__module__', '__doc__')
}
# new class
new_cls = type(
cls.__name__,
(cls,),
cls_dict
)
# metadata preservation
new_cls.__module__ = cls.__module__
new_cls.__doc__ = cls.__doc__
return new_cls
class classproperty(object):
def __init__(self, f):
self.f = f
def __get__(self, obj, owner):
return self.f(owner)
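A short sketch of the classproperty descriptor in use (the Settings class is hypothetical): it exposes a value computed from class state without requiring an instance.

class Settings:
    _values = {"mode": "fast"}

    @classproperty
    def mode(cls):
        return cls._values["mode"]

assert Settings.mode == "fast"  # no instance needed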
# NOTE: this was ai generated and validated by hand
def shallow_clone_class(cls, new_name=None):
'''
Shallow clone a class while preserving super() functionality.
'''
new_name = new_name or f"{cls.__name__}Clone"
# Include the original class in the bases to maintain proper inheritance
new_bases = (cls,) + cls.__bases__
return type(new_name, new_bases, dict(cls.__dict__))
# NOTE: this was ai generated and validated by hand
def lock_class(cls):
'''
    Lock a class so that its top-level attributes cannot be modified.
'''
# Locked instance __setattr__
def locked_instance_setattr(self, name, value):
raise AttributeError(
f"Cannot set attribute '{name}' on immutable instance of {type(self).__name__}"
)
# Locked metaclass
class LockedMeta(type(cls)):
def __setattr__(cls_, name, value):
raise AttributeError(
f"Cannot modify class attribute '{name}' on locked class '{cls_.__name__}'"
)
# Rebuild class with locked behavior
locked_dict = dict(cls.__dict__)
locked_dict['__setattr__'] = locked_instance_setattr
return LockedMeta(cls.__name__, cls.__bases__, locked_dict)
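A brief demonstration sketch of the locking behavior (the Options class is hypothetical): both class-level and instance-level writes raise AttributeError.

class Options:
    flag = True

LockedOptions = lock_class(Options)
try:
    LockedOptions.flag = False  # rejected by LockedMeta.__setattr__
except AttributeError:
    pass
try:
    LockedOptions().flag = False  # rejected by the locked instance __setattr__
except AttributeError:
    pass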
def make_locked_method_func(type_obj, func, class_clone):
"""
Returns a function that, when called with **inputs, will execute:
getattr(type_obj, func).__func__(lock_class(class_clone), **inputs)
Supports both synchronous and asynchronous methods.
"""
locked_class = lock_class(class_clone)
method = getattr(type_obj, func).__func__
# Check if the original method is async
if asyncio.iscoroutinefunction(method):
async def wrapped_async_func(**inputs):
return await method(locked_class, **inputs)
return wrapped_async_func
else:
def wrapped_func(**inputs):
return method(locked_class, **inputs)
return wrapped_func


@@ -1,5 +1,6 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Type, TYPE_CHECKING
from comfy_api.internal import ComfyAPIBase
from comfy_api.internal.singleton import ProxiedSingleton
@@ -7,6 +8,9 @@ from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
from comfy_api.latest._io import _IO as io #noqa: F401
from comfy_api.latest._ui import _UI as ui #noqa: F401
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
from comfy_execution.utils import get_executing_context
from comfy_execution.progress import get_progress_state, PreviewImageTuple
from PIL import Image
@@ -72,6 +76,19 @@ class ComfyAPI_latest(ComfyAPIBase):
    execution: Execution
class ComfyExtension(ABC):
async def on_load(self) -> None:
"""
Called when an extension is loaded.
        This should be used to initialize any global resources needed by the extension.
"""
@abstractmethod
async def get_node_list(self) -> list[type[io.ComfyNode]]:
"""
Returns a list of nodes that this extension provides.
"""
class Input:
    Image = ImageInput
    Audio = AudioInput
@@ -103,4 +120,5 @@ __all__ = [
    "Input",
    "InputImpl",
    "Types",
    "ComfyExtension",
]

comfy_api/latest/_io.py (new file, 1618 lines)
File diff suppressed because it is too large


@@ -0,0 +1,72 @@
from __future__ import annotations
import comfy.utils
import folder_paths
import logging
from abc import ABC, abstractmethod
from typing import Any
import torch
class ResourceKey(ABC):
Type = Any
def __init__(self):
...
class TorchDictFolderFilename(ResourceKey):
'''Key for requesting a torch file via file_name from a folder category.'''
Type = dict[str, torch.Tensor]
def __init__(self, folder_name: str, file_name: str):
self.folder_name = folder_name
self.file_name = file_name
def __hash__(self):
return hash((self.folder_name, self.file_name))
def __eq__(self, other: object) -> bool:
if not isinstance(other, TorchDictFolderFilename):
return False
return self.folder_name == other.folder_name and self.file_name == other.file_name
def __str__(self):
return f"{self.folder_name} -> {self.file_name}"
class Resources(ABC):
def __init__(self):
...
@abstractmethod
def get(self, key: ResourceKey, default: Any=...) -> Any:
pass
class ResourcesLocal(Resources):
def __init__(self):
super().__init__()
self.local_resources: dict[ResourceKey, Any] = {}
def get(self, key: ResourceKey, default: Any=...) -> Any:
cached = self.local_resources.get(key, None)
if cached is not None:
logging.info(f"Using cached resource '{key}'")
return cached
logging.info(f"Loading resource '{key}'")
to_return = None
if isinstance(key, TorchDictFolderFilename):
if default is ...:
to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True)
else:
full_path = folder_paths.get_full_path(key.folder_name, key.file_name)
if full_path is not None:
to_return = comfy.utils.load_torch_file(full_path, safe_load=True)
if to_return is not None:
self.local_resources[key] = to_return
return to_return
if default is not ...:
return default
raise Exception(f"Unsupported resource key type: {type(key)}")
class _RESOURCES:
ResourceKey = ResourceKey
TorchDictFolderFilename = TorchDictFolderFilename
Resources = Resources
ResourcesLocal = ResourcesLocal
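As a usage sketch (the folder category and file names are hypothetical), a node could fetch a state dict once and reuse the cached copy on later calls:

resources = ResourcesLocal()
key = TorchDictFolderFilename("loras", "example.safetensors")
state_dict = resources.get(key)   # loads via comfy.utils.load_torch_file
state_dict = resources.get(key)   # second call hits the local cache
maybe = resources.get(TorchDictFolderFilename("loras", "missing.safetensors"), default=None)  # returns None instead of raising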

comfy_api/latest/_ui.py (new file, 457 lines)

@@ -0,0 +1,457 @@
from __future__ import annotations
import json
import os
import random
from io import BytesIO
from typing import Type
import av
import numpy as np
import torch
import torchaudio
from PIL import Image as PILImage
from PIL.PngImagePlugin import PngInfo
import folder_paths
# used for image preview
from comfy.cli_args import args
from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput
class SavedResult(dict):
def __init__(self, filename: str, subfolder: str, type: FolderType):
        super().__init__(filename=filename, subfolder=subfolder, type=type.value)
@property
def filename(self) -> str:
return self["filename"]
@property
def subfolder(self) -> str:
return self["subfolder"]
@property
def type(self) -> FolderType:
return FolderType(self["type"])
class SavedImages(_UIOutput):
"""A UI output class to represent one or more saved images, potentially animated."""
def __init__(self, results: list[SavedResult], is_animated: bool = False):
super().__init__()
self.results = results
self.is_animated = is_animated
def as_dict(self) -> dict:
data = {"images": self.results}
if self.is_animated:
data["animated"] = (True,)
return data
class SavedAudios(_UIOutput):
"""UI wrapper around one or more audio files on disk (FLAC / MP3 / Opus)."""
def __init__(self, results: list[SavedResult]):
super().__init__()
self.results = results
def as_dict(self) -> dict:
return {"audio": self.results}
def _get_directory_by_folder_type(folder_type: FolderType) -> str:
if folder_type == FolderType.input:
return folder_paths.get_input_directory()
if folder_type == FolderType.output:
return folder_paths.get_output_directory()
return folder_paths.get_temp_directory()
class ImageSaveHelper:
"""A helper class with static methods to handle image saving and metadata."""
@staticmethod
def _convert_tensor_to_pil(image_tensor: torch.Tensor) -> PILImage.Image:
"""Converts a single torch tensor to a PIL Image."""
return PILImage.fromarray(np.clip(255.0 * image_tensor.cpu().numpy(), 0, 255).astype(np.uint8))
@staticmethod
def _create_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
"""Creates a PngInfo object with prompt and extra_pnginfo."""
if args.disable_metadata or cls is None or not cls.hidden:
return None
metadata = PngInfo()
if cls.hidden.prompt:
metadata.add_text("prompt", json.dumps(cls.hidden.prompt))
if cls.hidden.extra_pnginfo:
for x in cls.hidden.extra_pnginfo:
metadata.add_text(x, json.dumps(cls.hidden.extra_pnginfo[x]))
return metadata
@staticmethod
def _create_animated_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
"""Creates a PngInfo object with prompt and extra_pnginfo for animated PNGs (APNG)."""
if args.disable_metadata or cls is None or not cls.hidden:
return None
metadata = PngInfo()
if cls.hidden.prompt:
metadata.add(
b"comf",
"prompt".encode("latin-1", "strict")
+ b"\0"
+ json.dumps(cls.hidden.prompt).encode("latin-1", "strict"),
after_idat=True,
)
if cls.hidden.extra_pnginfo:
for x in cls.hidden.extra_pnginfo:
metadata.add(
b"comf",
x.encode("latin-1", "strict")
+ b"\0"
+ json.dumps(cls.hidden.extra_pnginfo[x]).encode("latin-1", "strict"),
after_idat=True,
)
return metadata
@staticmethod
def _create_webp_metadata(pil_image: PILImage.Image, cls: Type[ComfyNode] | None) -> PILImage.Exif:
"""Creates EXIF metadata bytes for WebP images."""
exif_data = pil_image.getexif()
if args.disable_metadata or cls is None or cls.hidden is None:
return exif_data
if cls.hidden.prompt is not None:
exif_data[0x0110] = "prompt:{}".format(json.dumps(cls.hidden.prompt)) # EXIF 0x0110 = Model
if cls.hidden.extra_pnginfo is not None:
            initial_exif_tag = 0x010F  # EXIF 0x010f = Make
            for key, value in cls.hidden.extra_pnginfo.items():
                exif_data[initial_exif_tag] = "{}:{}".format(key, json.dumps(value))
                initial_exif_tag -= 1
return exif_data
@staticmethod
def save_images(
        images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, compress_level=4,
) -> list[SavedResult]:
"""Saves a batch of images as individual PNG files."""
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0]
)
results = []
metadata = ImageSaveHelper._create_png_metadata(cls)
for batch_number, image_tensor in enumerate(images):
img = ImageSaveHelper._convert_tensor_to_pil(image_tensor)
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.png"
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level)
results.append(SavedResult(file, subfolder, folder_type))
counter += 1
return results
@staticmethod
def get_save_images_ui(images, filename_prefix: str, cls: Type[ComfyNode] | None, compress_level=4) -> SavedImages:
"""Saves a batch of images and returns a UI object for the node output."""
return SavedImages(
ImageSaveHelper.save_images(
images,
filename_prefix=filename_prefix,
folder_type=FolderType.output,
cls=cls,
compress_level=compress_level,
)
)
@staticmethod
def save_animated_png(
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, fps: float, compress_level: int
) -> SavedResult:
"""Saves a batch of images as a single animated PNG."""
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0]
)
pil_images = [ImageSaveHelper._convert_tensor_to_pil(img) for img in images]
metadata = ImageSaveHelper._create_animated_png_metadata(cls)
file = f"{filename}_{counter:05}_.png"
save_path = os.path.join(full_output_folder, file)
pil_images[0].save(
save_path,
pnginfo=metadata,
compress_level=compress_level,
save_all=True,
duration=int(1000.0 / fps),
append_images=pil_images[1:],
)
return SavedResult(file, subfolder, folder_type)
@staticmethod
def get_save_animated_png_ui(
images, filename_prefix: str, cls: Type[ComfyNode] | None, fps: float, compress_level: int
) -> SavedImages:
"""Saves an animated PNG and returns a UI object for the node output."""
result = ImageSaveHelper.save_animated_png(
images,
filename_prefix=filename_prefix,
folder_type=FolderType.output,
cls=cls,
fps=fps,
compress_level=compress_level,
)
return SavedImages([result], is_animated=len(images) > 1)
@staticmethod
def save_animated_webp(
images,
filename_prefix: str,
folder_type: FolderType,
cls: Type[ComfyNode] | None,
fps: float,
lossless: bool,
quality: int,
method: int,
) -> SavedResult:
"""Saves a batch of images as a single animated WebP."""
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0]
)
pil_images = [ImageSaveHelper._convert_tensor_to_pil(img) for img in images]
pil_exif = ImageSaveHelper._create_webp_metadata(pil_images[0], cls)
file = f"{filename}_{counter:05}_.webp"
pil_images[0].save(
os.path.join(full_output_folder, file),
save_all=True,
duration=int(1000.0 / fps),
append_images=pil_images[1:],
exif=pil_exif,
lossless=lossless,
quality=quality,
method=method,
)
return SavedResult(file, subfolder, folder_type)
@staticmethod
def get_save_animated_webp_ui(
images,
filename_prefix: str,
cls: Type[ComfyNode] | None,
fps: float,
lossless: bool,
quality: int,
method: int,
) -> SavedImages:
"""Saves an animated WebP and returns a UI object for the node output."""
result = ImageSaveHelper.save_animated_webp(
images,
filename_prefix=filename_prefix,
folder_type=FolderType.output,
cls=cls,
fps=fps,
lossless=lossless,
quality=quality,
method=method,
)
return SavedImages([result], is_animated=len(images) > 1)
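A hedged sketch of how a save-image node might call into this helper (the surrounding node and its variables are illustrative only):

# inside a hypothetical node's execute(), with `images` a [B, H, W, C] tensor batch:
ui_output = ImageSaveHelper.get_save_images_ui(
    images,
    filename_prefix="ComfyUI",
    cls=cls,            # passing the node class embeds prompt metadata in the PNG
    compress_level=4,
)
# ui_output.as_dict() -> {"images": [SavedResult, ...]} for the front end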
class AudioSaveHelper:
"""A helper class with static methods to handle audio saving and metadata."""
_OPUS_RATES = [8000, 12000, 16000, 24000, 48000]
@staticmethod
def save_audio(
audio: dict,
filename_prefix: str,
folder_type: FolderType,
cls: Type[ComfyNode] | None,
format: str = "flac",
quality: str = "128k",
) -> list[SavedResult]:
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
filename_prefix, _get_directory_by_folder_type(folder_type)
)
metadata = {}
if not args.disable_metadata and cls is not None:
if cls.hidden.prompt is not None:
metadata["prompt"] = json.dumps(cls.hidden.prompt)
if cls.hidden.extra_pnginfo is not None:
for x in cls.hidden.extra_pnginfo:
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
results = []
for batch_number, waveform in enumerate(audio["waveform"].cpu()):
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.{format}"
output_path = os.path.join(full_output_folder, file)
# Use original sample rate initially
sample_rate = audio["sample_rate"]
# Handle Opus sample rate requirements
if format == "opus":
if sample_rate > 48000:
sample_rate = 48000
elif sample_rate not in AudioSaveHelper._OPUS_RATES:
# Find the next highest supported rate
for rate in sorted(AudioSaveHelper._OPUS_RATES):
if rate > sample_rate:
sample_rate = rate
break
if sample_rate not in AudioSaveHelper._OPUS_RATES: # Fallback if still not supported
sample_rate = 48000
# Resample if necessary
if sample_rate != audio["sample_rate"]:
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
# Create output with specified format
output_buffer = BytesIO()
output_container = av.open(output_buffer, mode="w", format=format)
# Set metadata on the container
for key, value in metadata.items():
output_container.metadata[key] = value
# Set up the output stream with appropriate properties
if format == "opus":
out_stream = output_container.add_stream("libopus", rate=sample_rate)
if quality == "64k":
out_stream.bit_rate = 64000
elif quality == "96k":
out_stream.bit_rate = 96000
elif quality == "128k":
out_stream.bit_rate = 128000
elif quality == "192k":
out_stream.bit_rate = 192000
elif quality == "320k":
out_stream.bit_rate = 320000
elif format == "mp3":
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
if quality == "V0":
                # TODO: it would be great to also support V3 and V5, but there doesn't seem to be a way to set the qscale level; the property below is a bool
out_stream.codec_context.qscale = 1
elif quality == "128k":
out_stream.bit_rate = 128000
elif quality == "320k":
out_stream.bit_rate = 320000
else: # format == "flac":
out_stream = output_container.add_stream("flac", rate=sample_rate)
frame = av.AudioFrame.from_ndarray(
waveform.movedim(0, 1).reshape(1, -1).float().numpy(),
format="flt",
layout="mono" if waveform.shape[0] == 1 else "stereo",
)
frame.sample_rate = sample_rate
frame.pts = 0
output_container.mux(out_stream.encode(frame))
# Flush encoder
output_container.mux(out_stream.encode(None))
# Close containers
output_container.close()
# Write the output to file
output_buffer.seek(0)
with open(output_path, "wb") as f:
f.write(output_buffer.getbuffer())
results.append(SavedResult(file, subfolder, folder_type))
counter += 1
return results
@staticmethod
def get_save_audio_ui(
audio, filename_prefix: str, cls: Type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
) -> SavedAudios:
"""Save and instantly wrap for UI."""
return SavedAudios(
AudioSaveHelper.save_audio(
audio,
filename_prefix=filename_prefix,
folder_type=FolderType.output,
cls=cls,
format=format,
quality=quality,
)
)
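The Opus branch in save_audio only accepts a fixed set of sample rates; the selection rule it implements reduces to this standalone sketch:

def _pick_opus_rate(sample_rate: int, rates=(8000, 12000, 16000, 24000, 48000)) -> int:
    if sample_rate > 48000:
        return 48000
    if sample_rate in rates:
        return sample_rate
    # next supported rate above the input, else fall back to 48 kHz
    return next((r for r in sorted(rates) if r > sample_rate), 48000)

assert _pick_opus_rate(44100) == 48000
assert _pick_opus_rate(22050) == 24000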
class PreviewImage(_UIOutput):
def __init__(self, image: Image.Type, animated: bool = False, cls: Type[ComfyNode] = None, **kwargs):
self.values = ImageSaveHelper.save_images(
image,
filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)),
folder_type=FolderType.temp,
cls=cls,
compress_level=1,
)
self.animated = animated
def as_dict(self):
return {
"images": self.values,
"animated": (self.animated,)
}
class PreviewMask(PreviewImage):
def __init__(self, mask: PreviewMask.Type, animated: bool=False, cls: ComfyNode=None, **kwargs):
preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
super().__init__(preview, animated, cls, **kwargs)
class PreviewAudio(_UIOutput):
def __init__(self, audio: dict, cls: Type[ComfyNode] = None, **kwargs):
self.values = AudioSaveHelper.save_audio(
audio,
filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)),
folder_type=FolderType.temp,
cls=cls,
format="flac",
quality="128k",
)
def as_dict(self) -> dict:
return {"audio": self.values}
class PreviewVideo(_UIOutput):
def __init__(self, values: list[SavedResult | dict], **kwargs):
self.values = values
def as_dict(self):
return {"images": self.values, "animated": (True,)}
class PreviewUI3D(_UIOutput):
def __init__(self, model_file, camera_info, **kwargs):
self.model_file = model_file
self.camera_info = camera_info
def as_dict(self):
return {"result": [self.model_file, self.camera_info]}
class PreviewText(_UIOutput):
def __init__(self, value: str, **kwargs):
self.value = value
def as_dict(self):
return {"text": (self.value,)}
class _UI:
SavedResult = SavedResult
SavedImages = SavedImages
SavedAudios = SavedAudios
ImageSaveHelper = ImageSaveHelper
AudioSaveHelper = AudioSaveHelper
PreviewImage = PreviewImage
PreviewMask = PreviewMask
PreviewAudio = PreviewAudio
PreviewVideo = PreviewVideo
PreviewUI3D = PreviewUI3D
PreviewText = PreviewText


@@ -6,6 +6,7 @@ from comfy_api.latest import (
)
from typing import Type, TYPE_CHECKING
from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest import io, ui, ComfyExtension #noqa: F401

class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest):
@@ -40,4 +41,5 @@ __all__ = [
    "Input",
    "InputImpl",
    "Types",
    "ComfyExtension",
]


@@ -4,9 +4,12 @@ from typing import Type, Literal
import nodes
import asyncio
import inspect
-from comfy_execution.graph_utils import is_link
+from comfy_execution.graph_utils import is_link, ExecutionBlocker
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions

# NOTE: ExecutionBlocker code got moved to graph_utils.py to prevent torch being imported too soon during unit tests
ExecutionBlocker = ExecutionBlocker

class DependencyCycleError(Exception):
    pass
@@ -294,21 +297,3 @@ class ExecutionList(TopologicalSort):
        del blocked_by[node_id]
        to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
        return list(blocked_by.keys())
-class ExecutionBlocker:
-    """
-    Return this from a node and any users will be blocked with the given error message.
-    If the message is None, execution will be blocked silently instead.
-    Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's
-    possible, a lazy input will be more efficient and have a better user experience.
-    This functionality is useful in two cases:
-    1. You want to conditionally prevent an output node from executing. (Particularly a built-in node
-       like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using
-       lazy evaluation to let it conditionally disable itself.)
-    2. You have a node with multiple possible outputs, some of which are invalid and should not be used.
-       (I would recommend not making nodes like this in the future -- instead, make multiple nodes with
-       different outputs. Unfortunately, there are several popular existing nodes using this pattern.)
-    """
-    def __init__(self, message):
-        self.message = message


@@ -137,3 +137,19 @@ def add_graph_prefix(graph, outputs, prefix):
    return new_graph, tuple(new_outputs)
class ExecutionBlocker:
"""
Return this from a node and any users will be blocked with the given error message.
If the message is None, execution will be blocked silently instead.
Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's
possible, a lazy input will be more efficient and have a better user experience.
This functionality is useful in two cases:
1. You want to conditionally prevent an output node from executing. (Particularly a built-in node
like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using
lazy evaluation to let it conditionally disable itself.)
2. You have a node with multiple possible outputs, some of which are invalid and should not be used.
(I would recommend not making nodes like this in the future -- instead, make multiple nodes with
different outputs. Unfortunately, there are several popular existing nodes using this pattern.)
"""
def __init__(self, message):
self.message = message
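As a usage sketch (the node method shape is illustrative), a node can gate one of its outputs like this:

def execute(self, image, enabled):
    if not enabled:
        # downstream consumers of this output are blocked silently
        return (ExecutionBlocker(None),)
    return (image,)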


@@ -149,6 +149,7 @@ class WanFirstLastFrameToVideo:
        positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
        negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
        clip_vision_output = None
        if clip_vision_start_image is not None:
            clip_vision_output = clip_vision_start_image


@@ -7,7 +7,7 @@ import threading
import time
import traceback
from enum import Enum
-from typing import List, Literal, NamedTuple, Optional
+from typing import List, Literal, NamedTuple, Optional, Union
import asyncio
import torch
@@ -32,6 +32,8 @@ from comfy_execution.graph_utils import GraphBuilder, is_link
from comfy_execution.validation import validate_node_input
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
from comfy_execution.utils import CurrentNodeContext
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
from comfy_api.latest import io

class ExecutionResult(Enum):
@@ -56,7 +58,15 @@ class IsChangedCache:
        node = self.dynprompt.get_node(node_id)
        class_type = node["class_type"]
        class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
-        if not hasattr(class_def, "IS_CHANGED"):
+        has_is_changed = False
is_changed_name = None
if issubclass(class_def, _ComfyNodeInternal) and first_real_override(class_def, "fingerprint_inputs") is not None:
has_is_changed = True
is_changed_name = "fingerprint_inputs"
elif hasattr(class_def, "IS_CHANGED"):
has_is_changed = True
is_changed_name = "IS_CHANGED"
if not has_is_changed:
            self.is_changed[node_id] = False
            return self.is_changed[node_id]
@@ -65,9 +75,9 @@ class IsChangedCache:
            return self.is_changed[node_id]

        # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
-        input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None)
+        input_data_all, _, hidden_inputs = get_input_data(node["inputs"], class_def, node_id, None)
        try:
-            is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, "IS_CHANGED")
+            is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name)
            is_changed = await resolve_map_node_over_list_results(is_changed)
            node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
        except Exception as e:
@@ -126,9 +136,14 @@ class CacheSet:
SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")

def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
is_v3 = issubclass(class_def, _ComfyNodeInternal)
if is_v3:
valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True)
else:
        valid_inputs = class_def.INPUT_TYPES()
    input_data_all = {}
    missing_keys = {}
    hidden_inputs_v3 = {}
    for x in inputs:
        input_data = inputs[x]
        _, input_category, input_info = get_input_info(class_def, x, valid_inputs)
@@ -153,6 +168,21 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
        elif input_category is not None:
            input_data_all[x] = [input_data]
if is_v3:
if schema.hidden:
if io.Hidden.prompt in schema.hidden:
hidden_inputs_v3[io.Hidden.prompt] = dynprompt.get_original_prompt() if dynprompt is not None else {}
if io.Hidden.dynprompt in schema.hidden:
hidden_inputs_v3[io.Hidden.dynprompt] = dynprompt
if io.Hidden.extra_pnginfo in schema.hidden:
hidden_inputs_v3[io.Hidden.extra_pnginfo] = extra_data.get('extra_pnginfo', None)
if io.Hidden.unique_id in schema.hidden:
hidden_inputs_v3[io.Hidden.unique_id] = unique_id
if io.Hidden.auth_token_comfy_org in schema.hidden:
hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None)
if io.Hidden.api_key_comfy_org in schema.hidden:
hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None)
else:
if "hidden" in valid_inputs: if "hidden" in valid_inputs:
h = valid_inputs["hidden"] h = valid_inputs["hidden"]
for x in h: for x in h:
@ -168,7 +198,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)] input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
if h[x] == "API_KEY_COMFY_ORG": if h[x] == "API_KEY_COMFY_ORG":
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)] input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
return input_data_all, missing_keys return input_data_all, missing_keys, hidden_inputs_v3
map_node_over_list = None #Don't hook this please map_node_over_list = None #Don't hook this please
@@ -184,7 +214,7 @@ async def resolve_map_node_over_list_results(results):
            raise exc
    return [x.result() if isinstance(x, asyncio.Task) else x for x in results]

-async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
+async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
    # check if node wants the lists
    input_is_list = getattr(obj, "INPUT_IS_LIST", False)
@@ -214,6 +244,21 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
        if execution_block is None:
            if pre_execute_cb is not None and index is not None:
                pre_execute_cb(index)
# V3
if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)):
            # if obj is just a class, assign no resources or state; just create a clone
if is_class(obj):
type_obj = obj
obj.VALIDATE_CLASS()
class_clone = obj.PREPARE_CLASS_CLONE(hidden_inputs)
# otherwise, use class instance to populate/reuse some fields
else:
type_obj = type(obj)
type_obj.VALIDATE_CLASS()
class_clone = type_obj.PREPARE_CLASS_CLONE(hidden_inputs)
f = make_locked_method_func(type_obj, func, class_clone)
# V1
else:
                f = getattr(obj, func)
            if inspect.iscoroutinefunction(f):
                async def async_wrapper(f, prompt_id, unique_id, list_index, args):
@@ -266,8 +311,8 @@ def merge_result_data(results, obj):
            output.append([o[i] for o in results])
    return output

-async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
-    return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
+async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
+    return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
    has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
    if has_pending_task:
        return return_values, {}, False, has_pending_task
@@ -298,6 +343,26 @@ def get_output_from_returns(return_values, obj):
                result = tuple([result] * len(obj.RETURN_TYPES))
            results.append(result)
            subgraph_results.append((None, result))
elif isinstance(r, _NodeOutputInternal):
# V3
if r.ui is not None:
if isinstance(r.ui, dict):
uis.append(r.ui)
else:
uis.append(r.ui.as_dict())
if r.expand is not None:
has_subgraph = True
new_graph = r.expand
result = r.result
if r.block_execution is not None:
result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES))
subgraph_results.append((new_graph, result))
elif r.result is not None:
result = r.result
if r.block_execution is not None:
result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES))
results.append(result)
subgraph_results.append((None, result))
        else:
            if isinstance(r, ExecutionBlocker):
                r = tuple([r] * len(obj.RETURN_TYPES))
@@ -381,7 +446,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
            has_subgraph = False
    else:
        get_progress_state().start_progress(unique_id)
-        input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
+        input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
        if server.client_id is not None:
            server.last_node_id = display_node_id
            server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
@@ -391,8 +456,12 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
            obj = class_def()
            caches.objects.set(unique_id, obj)

-        if hasattr(obj, "check_lazy_status"):
-            required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True)
+        if issubclass(class_def, _ComfyNodeInternal):
+            lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None
else:
lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
if lazy_status_present:
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs)
            required_inputs = await resolve_map_node_over_list_results(required_inputs)
            required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
            required_inputs = [x for x in required_inputs if isinstance(x,str) and (
@@ -424,7 +493,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
        def pre_execute_cb(call_index):
            # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
            GraphBuilder.set_default_prefix(unique_id, call_index, 0)
-        output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
+        output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
        if has_pending_tasks:
            pending_async_nodes[unique_id] = output_data
            unblock = execution_list.add_external_block(unique_id)
@@ -672,8 +741,14 @@ async def validate_inputs(prompt_id, prompt, item, validated):
    validate_function_inputs = []
    validate_has_kwargs = False
-    if hasattr(obj_class, "VALIDATE_INPUTS"):
-        argspec = inspect.getfullargspec(obj_class.VALIDATE_INPUTS)
+    if issubclass(obj_class, _ComfyNodeInternal):
+        validate_function_name = "validate_inputs"
validate_function = first_real_override(obj_class, validate_function_name)
else:
validate_function_name = "VALIDATE_INPUTS"
validate_function = getattr(obj_class, validate_function_name, None)
if validate_function is not None:
argspec = inspect.getfullargspec(validate_function)
        validate_function_inputs = argspec.args
        validate_has_kwargs = argspec.varkw is not None
    received_types = {}
@@ -848,7 +923,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
            continue

    if len(validate_function_inputs) > 0 or validate_has_kwargs:
-        input_data_all, _ = get_input_data(inputs, obj_class, unique_id)
+        input_data_all, _, hidden_inputs = get_input_data(inputs, obj_class, unique_id)
        input_filtered = {}
        for x in input_data_all:
            if x in validate_function_inputs or validate_has_kwargs:
@@ -856,8 +931,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
        if 'input_types' in validate_function_inputs:
            input_filtered['input_types'] = [received_types]

-        #ret = obj_class.VALIDATE_INPUTS(**input_filtered)
-        ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, "VALIDATE_INPUTS")
+        ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, hidden_inputs=hidden_inputs)
        ret = await resolve_map_node_over_list_results(ret)
        for x in input_filtered:
            for i, r in enumerate(ret):
@@ -891,7 +965,7 @@ def full_type_name(klass):
        return klass.__qualname__
    return module + '.' + klass.__qualname__

-async def validate_prompt(prompt_id, prompt):
+async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[str], None]):
    outputs = set()
    for x in prompt:
        if 'class_type' not in prompt[x]:
@@ -915,6 +989,7 @@
            return (False, error, [], {})
        if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
            if partial_execution_list is None or x in partial_execution_list:
                outputs.add(x)
    if len(outputs) == 0:


@@ -6,6 +6,7 @@ import os
import sys
import json
import hashlib
import inspect
import traceback
import math
import time
@@ -29,6 +30,7 @@ import comfy.controlnet
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator
from comfy_api.internal import register_versions, ComfyAPIWithVersion
from comfy_api.version_list import supported_versions
from comfy_api.latest import io, ComfyExtension
import comfy.clip_vision
@@ -2155,6 +2157,7 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
        if os.path.isdir(web_dir):
            EXTENSION_WEB_DIRS[module_name] = web_dir
        # V1 node definition
        if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
            for name, node_cls in module.NODE_CLASS_MAPPINGS.items():
                if name not in ignore:
@@ -2163,8 +2166,38 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
            if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
                NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
            return True
# V3 Extension Definition
elif hasattr(module, "comfy_entrypoint"):
entrypoint = getattr(module, "comfy_entrypoint")
if not callable(entrypoint):
logging.warning(f"comfy_entrypoint in {module_path} is not callable, skipping.")
return False
try:
if inspect.iscoroutinefunction(entrypoint):
extension = await entrypoint()
-        else:
-            logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
+            else:
+                extension = entrypoint()
if not isinstance(extension, ComfyExtension):
logging.warning(f"comfy_entrypoint in {module_path} did not return a ComfyExtension, skipping.")
return False
node_list = await extension.get_node_list()
if not isinstance(node_list, list):
logging.warning(f"comfy_entrypoint in {module_path} did not return a list of nodes, skipping.")
return False
for node_cls in node_list:
node_cls: io.ComfyNode
schema = node_cls.GET_SCHEMA()
if schema.node_id not in ignore:
NODE_CLASS_MAPPINGS[schema.node_id] = node_cls
node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format(module_parent, get_module_name(module_path))
if schema.display_name is not None:
NODE_DISPLAY_NAME_MAPPINGS[schema.node_id] = schema.display_name
return True
except Exception as e:
logging.warning(f"Error while calling comfy_entrypoint in {module_path}: {e}")
return False
else:
logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS or NODES_LIST (need one).")
            return False
    except Exception as e:
        logging.warning(traceback.format_exc())
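Given the loader above, a V3 custom-node package would expose a module-level comfy_entrypoint. A minimal hedged sketch (the package, extension, and node names are hypothetical):

# my_custom_nodes/__init__.py
from comfy_api.latest import ComfyExtension, io

class MyExtension(ComfyExtension):
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [MyNode]  # hypothetical io.ComfyNode subclass

async def comfy_entrypoint() -> ComfyExtension:  # a plain (non-async) function also works
    return MyExtension()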
@@ -2295,7 +2328,7 @@ async def init_builtin_extra_nodes():
        "nodes_string.py",
        "nodes_camera_trajectory.py",
        "nodes_edit_model.py",
-        "nodes_tcfg.py"
+        "nodes_tcfg.py",
    ]
    import_failed = []


@@ -30,6 +30,7 @@ from comfy_api import feature_flags
import node_helpers
from comfyui_version import __version__
from app.frontend_management import FrontendManager
from comfy_api.internal import _ComfyNodeInternal
from app.user_manager import UserManager
from app.model_manager import ModelFileManager
@@ -597,6 +598,8 @@ class PromptServer():
        def node_info(node_class):
            obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
            if issubclass(obj_class, _ComfyNodeInternal):
                return obj_class.GET_NODE_INFO_V1()
            info = {}
            info['input'] = obj_class.INPUT_TYPES()
            info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
@@ -687,7 +690,12 @@ class PromptServer():
            if "prompt" in json_data:
                prompt = json_data["prompt"]
                prompt_id = str(json_data.get("prompt_id", uuid.uuid4()))
-                valid = await execution.validate_prompt(prompt_id, prompt)
partial_execution_targets = None
if "partial_execution_targets" in json_data:
partial_execution_targets = json_data["partial_execution_targets"]
valid = await execution.validate_prompt(prompt_id, prompt, partial_execution_targets)
                extra_data = {}
                if "extra_data" in json_data:
                    extra_data = json_data["extra_data"]
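For reference, a sketch of the request body the test client assembles for partial execution (node ids illustrative); only output nodes named in the list are kept as execution targets by validate_prompt:

p = {
    "prompt": prompt,                           # the full workflow graph
    "client_id": client_id,
    "partial_execution_targets": ["12", "17"],  # only these output nodes are targeted
}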


@@ -7,7 +7,7 @@ import subprocess
from pytest import fixture
from comfy_execution.graph_utils import GraphBuilder
-from tests.inference.test_execution import ComfyClient
+from tests.inference.test_execution import ComfyClient, run_warmup

@pytest.mark.execution
@@ -24,6 +24,7 @@ class TestAsyncNodes:
            '--listen', args_pytest["listen"],
            '--port', str(args_pytest["port"]),
            '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
            '--cpu',
        ]
        use_lru, lru_size = request.param
        if use_lru:
@@ -82,6 +83,9 @@ class TestAsyncNodes:
    def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder):
        """Test that multiple async nodes execute in parallel."""
        # Warmup execution to ensure server is fully initialized
        run_warmup(client)

        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
@@ -148,6 +152,9 @@ class TestAsyncNodes:
    def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder):
        """Test async nodes with lazy evaluation."""
        # Warmup execution to ensure server is fully initialized
        run_warmup(client, prefix="warmup_lazy")

        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
@@ -305,6 +312,9 @@ class TestAsyncNodes:
    def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder):
        """Test that async nodes are properly cached."""
        # Warmup execution to ensure server is fully initialized
        run_warmup(client, prefix="warmup_cache")

        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.2)
@@ -324,6 +334,9 @@ class TestAsyncNodes:
    def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder):
        """Test async nodes within dynamically generated prompts."""
        # Warmup execution to ensure server is fully initialized
        run_warmup(client, prefix="warmup_dynamic")

        g = builder
        image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)


@@ -15,10 +15,18 @@ import urllib.parse
import urllib.error

from comfy_execution.graph_utils import GraphBuilder, Node
def run_warmup(client, prefix="warmup"):
"""Run a simple workflow to warm up the server."""
warmup_g = GraphBuilder(prefix=prefix)
warmup_image = warmup_g.node("StubImage", content="BLACK", height=32, width=32, batch_size=1)
warmup_g.node("PreviewImage", images=warmup_image.out(0))
client.run(warmup_g)
class RunResult:
    def __init__(self, prompt_id: str):
        self.outputs: Dict[str,Dict] = {}
        self.runs: Dict[str,bool] = {}
        self.cached: Dict[str,bool] = {}
        self.prompt_id: str = prompt_id

    def get_output(self, node: Node):
@@ -27,6 +35,13 @@ class RunResult:
    def did_run(self, node: Node):
        return self.runs.get(node.id, False)
def was_cached(self, node: Node):
return self.cached.get(node.id, False)
def was_executed(self, node: Node):
"""Returns True if node was either run or cached"""
return self.did_run(node) or self.was_cached(node)
    def get_images(self, node: Node):
        output = self.get_output(node)
        if output is None:
@@ -51,8 +66,10 @@ class ComfyClient:
        ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id))
        self.ws = ws

-    def queue_prompt(self, prompt):
+    def queue_prompt(self, prompt, partial_execution_targets=None):
        p = {"prompt": prompt, "client_id": self.client_id}
        if partial_execution_targets is not None:
            p["partial_execution_targets"] = partial_execution_targets
        data = json.dumps(p).encode('utf-8')
        req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data)
        return json.loads(urllib.request.urlopen(req).read())
@@ -70,13 +87,13 @@ class ComfyClient:
    def set_test_name(self, name):
        self.test_name = name

-    def run(self, graph):
+    def run(self, graph, partial_execution_targets=None):
        prompt = graph.finalize()
        for node in graph.nodes.values():
            if node.class_type == 'SaveImage':
                node.inputs['filename_prefix'] = self.test_name

-        prompt_id = self.queue_prompt(prompt)['prompt_id']
+        prompt_id = self.queue_prompt(prompt, partial_execution_targets)['prompt_id']
        result = RunResult(prompt_id)
        while True:
            out = self.ws.recv()
@@ -92,7 +109,10 @@ class ComfyClient:
            elif message['type'] == 'execution_error':
                raise Exception(message['data'])
            elif message['type'] == 'execution_cached':
-                pass # Probably want to store this off for testing
+                if message['data']['prompt_id'] == prompt_id:
                    cached_nodes = message['data'].get('nodes', [])
                    for node_id in cached_nodes:
                        result.cached[node_id] = True

        history = self.get_history(prompt_id)[prompt_id]
        for node_id in history['outputs']:
@@ -130,6 +150,7 @@ class TestExecution:
            '--listen', args_pytest["listen"],
            '--port', str(args_pytest["port"]),
            '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
            '--cpu',
        ]
        use_lru, lru_size = request.param
        if use_lru:
@@ -498,12 +519,15 @@ class TestExecution:
        assert not result.did_run(test_node), "The execution should have been cached"

    def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder):
# Warmup execution to ensure server is fully initialized
run_warmup(client)
g = builder g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
# Create sleep nodes for each duration # Create sleep nodes for each duration
sleep_node1 = g.node("TestSleep", value=image.out(0), seconds=2.8) sleep_node1 = g.node("TestSleep", value=image.out(0), seconds=2.9)
sleep_node2 = g.node("TestSleep", value=image.out(0), seconds=2.9) sleep_node2 = g.node("TestSleep", value=image.out(0), seconds=3.1)
sleep_node3 = g.node("TestSleep", value=image.out(0), seconds=3.0) sleep_node3 = g.node("TestSleep", value=image.out(0), seconds=3.0)
# Add outputs to verify the execution # Add outputs to verify the execution
@@ -515,10 +539,9 @@ class TestExecution:
         result = client.run(g)
         elapsed_time = time.time() - start_time
 
-        # The test should take around 0.4 seconds (the longest sleep duration)
-        # plus some overhead, but definitely less than the sum of all sleeps (0.9s)
-        # We'll allow for up to 0.8s total to account for overhead
-        assert elapsed_time < 4.0, f"Parallel execution took {elapsed_time}s, expected less than 0.8s"
+        # The test should take around 3.0 seconds (the longest sleep duration)
+        # plus some overhead, but definitely less than the sum of all sleeps (9.0s)
+        assert elapsed_time < 8.9, f"Parallel execution took {elapsed_time}s, expected less than 8.9s"
 
         # Verify that all nodes executed
         assert result.did_run(sleep_node1), "Sleep node 1 should have run"
@@ -526,6 +549,9 @@ class TestExecution:
         assert result.did_run(sleep_node3), "Sleep node 3 should have run"
 
     def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder):
+        # Warmup execution to ensure server is fully initialized
+        run_warmup(client)
+
         g = builder
         # Create input images with different values
         image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
@ -537,9 +563,9 @@ class TestExecution:
image1=image1.out(0), image1=image1.out(0),
image2=image2.out(0), image2=image2.out(0),
image3=image3.out(0), image3=image3.out(0),
sleep1=0.4, sleep1=4.8,
sleep2=0.5, sleep2=4.9,
sleep3=0.6) sleep3=5.0)
output = g.node("SaveImage", images=parallel_sleep.out(0)) output = g.node("SaveImage", images=parallel_sleep.out(0))
start_time = time.time() start_time = time.time()
@@ -548,7 +574,7 @@ class TestExecution:
 
         # Similar to the previous test, expect parallel execution of the sleep nodes
         # which should complete in less than the sum of all sleeps
-        assert elapsed_time < 0.8, f"Expansion execution took {elapsed_time}s, expected less than 0.8s"
+        assert elapsed_time < 10.0, f"Expansion execution took {elapsed_time}s, expected less than 10.0s"
 
         # Verify the parallel sleep node executed
         assert result.did_run(parallel_sleep), "ParallelSleep node should have run"
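
run_warmup() is defined earlier in this test module and does not appear in this diff; only the call sites above are part of the change. A minimal sketch of what such a helper can look like, under the assumption that a trivial graph is enough to force server initialization:

    # Hypothetical warmup helper: run a tiny graph once so server startup
    # cost does not pollute the timing assertions above.
    from comfy_execution.graph_utils import GraphBuilder  # import path assumed to match this module

    def run_warmup(client, prefix="warmup"):
        g = GraphBuilder(prefix=prefix)
        image = g.node("StubImage", content="BLACK", height=32, width=32, batch_size=1)
        g.node("PreviewImage", images=image.out(0))
        client.run(g)
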
@@ -585,3 +611,151 @@ class TestExecution:
         assert len(images) == 2, "Should have 2 images"
         assert numpy.array(images[0]).min() == 0 and numpy.array(images[0]).max() == 0, "First image should be black"
         assert numpy.array(images[1]).min() == 0 and numpy.array(images[1]).max() == 0, "Second image should also be black"
+
+    # Output nodes included in the partial execution list are executed
+    def test_partial_execution_included_outputs(self, client: ComfyClient, builder: GraphBuilder):
+        g = builder
+        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
+        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
+
+        # Create two separate output nodes
+        output1 = g.node("SaveImage", images=input1.out(0))
+        output2 = g.node("SaveImage", images=input2.out(0))
+
+        # Run with partial execution targeting only output1
+        result = client.run(g, partial_execution_targets=[output1.id])
+
+        assert result.was_executed(input1), "Input1 should have been executed (run or cached)"
+        assert result.was_executed(output1), "Output1 should have been executed (run or cached)"
+        assert not result.did_run(input2), "Input2 should not have run"
+        assert not result.did_run(output2), "Output2 should not have run"
+
+        # Verify only output1 produced results
+        assert len(result.get_images(output1)) == 1, "Output1 should have produced an image"
+        assert len(result.get_images(output2)) == 0, "Output2 should not have produced an image"
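
The expectation these asserts encode is that under partial execution, exactly the nodes reachable backwards from the targeted outputs run. That can be stated as a reverse reachability check; the sketch below is purely illustrative, not the server's implementation:

    # Illustrative only: expected executed set = nodes reachable from the
    # targets by walking input edges backwards.
    def expected_executed(consumes: dict[str, set[str]], targets: list[str]) -> set[str]:
        seen: set[str] = set()
        stack = list(targets)
        while stack:
            node = stack.pop()
            if node not in seen:
                seen.add(node)
                stack.extend(consumes.get(node, ()))
        return seen

    assert expected_executed({"out1": {"in1"}, "out2": {"in2"}}, ["out1"]) == {"out1", "in1"}
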
+    # Output nodes NOT included in the partial execution list are NOT executed
+    def test_partial_execution_excluded_outputs(self, client: ComfyClient, builder: GraphBuilder):
+        g = builder
+        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
+        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
+        input3 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
+
+        # Create three output nodes
+        output1 = g.node("SaveImage", images=input1.out(0))
+        output2 = g.node("SaveImage", images=input2.out(0))
+        output3 = g.node("SaveImage", images=input3.out(0))
+
+        # Run with partial execution targeting only output1 and output3
+        result = client.run(g, partial_execution_targets=[output1.id, output3.id])
+
+        assert result.was_executed(input1), "Input1 should have been executed"
+        assert result.was_executed(input3), "Input3 should have been executed"
+        assert result.was_executed(output1), "Output1 should have been executed"
+        assert result.was_executed(output3), "Output3 should have been executed"
+        assert not result.did_run(input2), "Input2 should not have run"
+        assert not result.did_run(output2), "Output2 should not have run"
+
+    # Output nodes NOT in list ARE executed if necessary for nodes that are in the list
+    def test_partial_execution_dependencies(self, client: ComfyClient, builder: GraphBuilder):
+        g = builder
+        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
+
+        # Create a processing chain with an OUTPUT_NODE that has socket outputs
+        output_with_socket = g.node("TestOutputNodeWithSocketOutput", image=input1.out(0), value=2.0)
+
+        # Create another node that depends on the output_with_socket
+        dependent_node = g.node("TestLazyMixImages",
+                                image1=output_with_socket.out(0),
+                                image2=input1.out(0),
+                                mask=g.node("StubMask", value=0.5, height=512, width=512, batch_size=1).out(0))
+
+        # Create the final output
+        final_output = g.node("SaveImage", images=dependent_node.out(0))
+
+        # Run with partial execution targeting only the final output
+        result = client.run(g, partial_execution_targets=[final_output.id])
+
+        # All nodes should have been executed because they're dependencies
+        assert result.was_executed(input1), "Input1 should have been executed"
+        assert result.was_executed(output_with_socket), "Output with socket should have been executed (dependency)"
+        assert result.was_executed(dependent_node), "Dependent node should have been executed"
+        assert result.was_executed(final_output), "Final output should have been executed"
+
+    # Lazy execution works with partial execution
+    def test_partial_execution_with_lazy_nodes(self, client: ComfyClient, builder: GraphBuilder):
+        g = builder
+        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
+        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
+        input3 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
+
+        # Create masks that will trigger different lazy execution paths
+        mask1 = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1)  # Will only need image1
+        mask2 = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)  # Will need both images
+
+        # Create two lazy mix nodes
+        lazy_mix1 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask1.out(0))
+        lazy_mix2 = g.node("TestLazyMixImages", image1=input2.out(0), image2=input3.out(0), mask=mask2.out(0))
+        output1 = g.node("SaveImage", images=lazy_mix1.out(0))
+        output2 = g.node("SaveImage", images=lazy_mix2.out(0))
+
+        # Run with partial execution targeting only output1
+        result = client.run(g, partial_execution_targets=[output1.id])
+
+        # For output1 path - only input1 should run due to lazy evaluation (mask=0.0)
+        assert result.was_executed(input1), "Input1 should have been executed"
+        assert not result.did_run(input2), "Input2 should not have run (lazy evaluation)"
+        assert result.was_executed(mask1), "Mask1 should have been executed"
+        assert result.was_executed(lazy_mix1), "Lazy mix1 should have been executed"
+        assert result.was_executed(output1), "Output1 should have been executed"
+
+        # Nothing from output2 path should run
+        assert not result.did_run(input3), "Input3 should not have run"
+        assert not result.did_run(mask2), "Mask2 should not have run"
+        assert not result.did_run(lazy_mix2), "Lazy mix2 should not have run"
+        assert not result.did_run(output2), "Output2 should not have run"
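
The lazy behavior this test leans on comes from TestLazyMixImages implementing ComfyUI's check_lazy_status hook: with a uniform 0.0 mask only image1 is requested, so input2 is never evaluated. A sketch of that pattern, following the hook's convention of returning the names of still-needed inputs (not this node's exact code):

    # Sketch of a lazy mix node's check_lazy_status: request only the inputs
    # the mask actually requires.
    class LazyMixSketch:
        def check_lazy_status(self, mask, image1=None, image2=None):
            mask_min, mask_max = mask.min(), mask.max()
            needed = []
            if image1 is None and (mask_min != 1.0 or mask_max != 1.0):
                needed.append("image1")
            if image2 is None and (mask_min != 0.0 or mask_max != 0.0):
                needed.append("image2")
            return needed  # an empty list means every needed input is available
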
+    # Multiple OUTPUT_NODEs with dependencies
+    def test_partial_execution_multiple_output_nodes(self, client: ComfyClient, builder: GraphBuilder):
+        g = builder
+        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
+        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
+
+        # Create a chain of OUTPUT_NODEs
+        output_node1 = g.node("TestOutputNodeWithSocketOutput", image=input1.out(0), value=1.5)
+        output_node2 = g.node("TestOutputNodeWithSocketOutput", image=output_node1.out(0), value=2.0)
+
+        # Create regular output nodes
+        save1 = g.node("SaveImage", images=output_node1.out(0))
+        save2 = g.node("SaveImage", images=output_node2.out(0))
+        save3 = g.node("SaveImage", images=input2.out(0))
+
+        # Run targeting only save2
+        result = client.run(g, partial_execution_targets=[save2.id])
+
+        # Should run: input1, output_node1, output_node2, save2
+        assert result.was_executed(input1), "Input1 should have been executed"
+        assert result.was_executed(output_node1), "Output node 1 should have been executed (dependency)"
+        assert result.was_executed(output_node2), "Output node 2 should have been executed (dependency)"
+        assert result.was_executed(save2), "Save2 should have been executed"
+
+        # Should NOT run: input2, save1, save3
+        assert not result.did_run(input2), "Input2 should not have run"
+        assert not result.did_run(save1), "Save1 should not have run"
+        assert not result.did_run(save3), "Save3 should not have run"
+
+    # Empty partial execution list (should execute nothing)
+    def test_partial_execution_empty_list(self, client: ComfyClient, builder: GraphBuilder):
+        g = builder
+        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
+        _output1 = g.node("SaveImage", images=input1.out(0))
+
+        # Run with empty partial execution list
+        try:
+            _result = client.run(g, partial_execution_targets=[])
+            # Should get an error because no outputs are selected
+            assert False, "Should have raised an error for empty partial execution list"
+        except urllib.error.HTTPError:
+            pass  # Expected behavior
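
An empty target list selects nothing to execute, and the server rejects the prompt, which the client observes as urllib.error.HTTPError. A defensive client-side guard (illustrative, not part of the test suite) could fail fast instead of waiting for the round trip:

    # Hypothetical guard mirroring the server-side rejection exercised above.
    def validate_targets(partial_execution_targets):
        if partial_execution_targets is not None and len(partial_execution_targets) == 0:
            raise ValueError("partial_execution_targets must be None or a non-empty list")
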
View File

@@ -463,6 +463,25 @@ class TestParallelSleep(ComfyNodeABC):
             "expand": g.finalize(),
         }
 
+class TestOutputNodeWithSocketOutput:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+                "value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "process"
+    CATEGORY = "_for_testing"
+    OUTPUT_NODE = True
+
+    def process(self, image, value):
+        # Scale the image and expose it both as an output-node result and a socket output
+        result = image * value
+        return (result,)
+
 TEST_NODE_CLASS_MAPPINGS = {
     "TestLazyMixImages": TestLazyMixImages,
     "TestVariadicAverage": TestVariadicAverage,
@@ -478,6 +497,7 @@ TEST_NODE_CLASS_MAPPINGS = {
     "TestSamplingInExpansion": TestSamplingInExpansion,
     "TestSleep": TestSleep,
     "TestParallelSleep": TestParallelSleep,
+    "TestOutputNodeWithSocketOutput": TestOutputNodeWithSocketOutput,
 }
 
 TEST_NODE_DISPLAY_NAME_MAPPINGS = {
@@ -495,4 +515,5 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
     "TestSamplingInExpansion": "Sampling In Expansion",
     "TestSleep": "Test Sleep",
     "TestParallelSleep": "Test Parallel Sleep",
+    "TestOutputNodeWithSocketOutput": "Test Output Node With Socket Output",
 }