mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-06 11:32:31 +08:00
Compare commits
14 Commits
7d45e94858
...
a7ea4dc7b8
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a7ea4dc7b8 | ||
|
|
0167653781 | ||
|
|
0a7993729c | ||
|
|
c6b6128b2b | ||
|
|
3770dc0ec4 | ||
|
|
96e9a81cdf | ||
|
|
5cf4115f50 | ||
|
|
b951181123 | ||
|
|
af4d691d1f | ||
|
|
f511703343 | ||
|
|
1107f4322b | ||
|
|
4683136740 | ||
|
|
38ab4e3c76 | ||
|
|
232995856e |
@ -1252,23 +1252,6 @@ class NodeInfoV1:
|
||||
price_badge: dict | None = None
|
||||
search_aliases: list[str]=None
|
||||
|
||||
@dataclass
|
||||
class NodeInfoV3:
|
||||
input: dict=None
|
||||
output: dict=None
|
||||
hidden: list[str]=None
|
||||
name: str=None
|
||||
display_name: str=None
|
||||
description: str=None
|
||||
python_module: Any = None
|
||||
category: str=None
|
||||
output_node: bool=None
|
||||
deprecated: bool=None
|
||||
experimental: bool=None
|
||||
dev_only: bool=None
|
||||
api_node: bool=None
|
||||
price_badge: dict | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class PriceBadgeDepends:
|
||||
@ -1497,40 +1480,6 @@ class Schema:
|
||||
)
|
||||
return info
|
||||
|
||||
|
||||
def get_v3_info(self, cls) -> NodeInfoV3:
|
||||
input_dict = {}
|
||||
output_dict = {}
|
||||
hidden_list = []
|
||||
# TODO: make sure dynamic types will be handled correctly
|
||||
if self.inputs:
|
||||
for input in self.inputs:
|
||||
add_to_dict_v3(input, input_dict)
|
||||
if self.outputs:
|
||||
for output in self.outputs:
|
||||
add_to_dict_v3(output, output_dict)
|
||||
if self.hidden:
|
||||
for hidden in self.hidden:
|
||||
hidden_list.append(hidden.value)
|
||||
|
||||
info = NodeInfoV3(
|
||||
input=input_dict,
|
||||
output=output_dict,
|
||||
hidden=hidden_list,
|
||||
name=self.node_id,
|
||||
display_name=self.display_name,
|
||||
description=self.description,
|
||||
category=self.category,
|
||||
output_node=self.is_output_node,
|
||||
deprecated=self.is_deprecated,
|
||||
experimental=self.is_experimental,
|
||||
dev_only=self.is_dev_only,
|
||||
api_node=self.is_api_node,
|
||||
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"),
|
||||
price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None,
|
||||
)
|
||||
return info
|
||||
|
||||
def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], include_hidden=False) -> tuple[dict[str, Any], V3Data]:
|
||||
out_dict = {
|
||||
"required": {},
|
||||
@ -1585,9 +1534,6 @@ def add_to_dict_v1(i: Input, d: dict):
|
||||
as_dict.pop("optional", None)
|
||||
d.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict)
|
||||
|
||||
def add_to_dict_v3(io: Input | Output, d: dict):
|
||||
d[io.id] = (io.get_io_type(), io.as_dict())
|
||||
|
||||
class DynamicPathsDefaultValue:
|
||||
EMPTY_DICT = "empty_dict"
|
||||
|
||||
@ -1748,13 +1694,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||
# set hidden
|
||||
type_clone.hidden = HiddenHolder.from_v3_data(v3_data)
|
||||
return type_clone
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
def GET_NODE_INFO_V3(cls) -> dict[str, Any]:
|
||||
schema = cls.GET_SCHEMA()
|
||||
info = schema.get_v3_info(cls)
|
||||
return asdict(info)
|
||||
#############################################
|
||||
# V1 Backwards Compatibility code
|
||||
#--------------------------------------------
|
||||
@ -2107,12 +2046,10 @@ __all__ = [
|
||||
"HiddenHolder",
|
||||
"Hidden",
|
||||
"NodeInfoV1",
|
||||
"NodeInfoV3",
|
||||
"Schema",
|
||||
"ComfyNode",
|
||||
"NodeOutput",
|
||||
"add_to_dict_v1",
|
||||
"add_to_dict_v3",
|
||||
"V3Data",
|
||||
"ImageCompare",
|
||||
"PriceBadgeDepends",
|
||||
|
||||
@ -1,11 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, conint, confloat
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class RecraftColor:
|
||||
@ -229,24 +226,24 @@ class RecraftColorObject(BaseModel):
|
||||
|
||||
|
||||
class RecraftControlsObject(BaseModel):
|
||||
colors: Optional[list[RecraftColorObject]] = Field(None, description='An array of preferable colors')
|
||||
background_color: Optional[RecraftColorObject] = Field(None, description='Use given color as a desired background color')
|
||||
no_text: Optional[bool] = Field(None, description='Do not embed text layouts')
|
||||
artistic_level: Optional[conint(ge=0, le=5)] = Field(None, description='Defines artistic tone of your image. At a simple level, the person looks straight at the camera in a static and clean style. Dynamic and eccentric levels introduce movement and creativity. The value should be in range [0..5].')
|
||||
colors: list[RecraftColorObject] | None = Field(None, description='An array of preferable colors')
|
||||
background_color: RecraftColorObject | None = Field(None, description='Use given color as a desired background color')
|
||||
no_text: bool | None = Field(None, description='Do not embed text layouts')
|
||||
artistic_level: int | None = Field(None, description='Defines artistic tone of your image. At a simple level, the person looks straight at the camera in a static and clean style. Dynamic and eccentric levels introduce movement and creativity. The value should be in range [0..5].')
|
||||
|
||||
|
||||
class RecraftImageGenerationRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The text prompt describing the image to generate')
|
||||
size: Optional[RecraftImageSize] = Field(None, description='The size of the generated image (e.g., "1024x1024")')
|
||||
n: conint(ge=1, le=6) = Field(..., description='The number of images to generate')
|
||||
negative_prompt: Optional[str] = Field(None, description='A text description of undesired elements on an image')
|
||||
model: Optional[RecraftModel] = Field(RecraftModel.recraftv3, description='The model to use for generation (e.g., "recraftv3")')
|
||||
style: Optional[str] = Field(None, description='The style to apply to the generated image (e.g., "digital_illustration")')
|
||||
substyle: Optional[str] = Field(None, description='The substyle to apply to the generated image, depending on the style input')
|
||||
controls: Optional[RecraftControlsObject] = Field(None, description='A set of custom parameters to tweak generation process')
|
||||
style_id: Optional[str] = Field(None, description='Use a previously uploaded style as a reference; UUID')
|
||||
strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None, description='Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity')
|
||||
random_seed: Optional[int] = Field(None, description="Seed for video generation")
|
||||
size: RecraftImageSize | None = Field(None, description='The size of the generated image (e.g., "1024x1024")')
|
||||
n: int = Field(..., description='The number of images to generate')
|
||||
negative_prompt: str | None = Field(None, description='A text description of undesired elements on an image')
|
||||
model: RecraftModel | None = Field(RecraftModel.recraftv3, description='The model to use for generation (e.g., "recraftv3")')
|
||||
style: str | None = Field(None, description='The style to apply to the generated image (e.g., "digital_illustration")')
|
||||
substyle: str | None = Field(None, description='The substyle to apply to the generated image, depending on the style input')
|
||||
controls: RecraftControlsObject | None = Field(None, description='A set of custom parameters to tweak generation process')
|
||||
style_id: str | None = Field(None, description='Use a previously uploaded style as a reference; UUID')
|
||||
strength: float | None = Field(None, description='Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity')
|
||||
random_seed: int | None = Field(None, description="Seed for video generation")
|
||||
# text_layout
|
||||
|
||||
|
||||
@ -258,5 +255,13 @@ class RecraftReturnedObject(BaseModel):
|
||||
class RecraftImageGenerationResponse(BaseModel):
|
||||
created: int = Field(..., description='Unix timestamp when the generation was created')
|
||||
credits: int = Field(..., description='Number of credits used for the generation')
|
||||
data: Optional[list[RecraftReturnedObject]] = Field(None, description='Array of generated image information')
|
||||
image: Optional[RecraftReturnedObject] = Field(None, description='Single generated image')
|
||||
data: list[RecraftReturnedObject] | None = Field(None, description='Array of generated image information')
|
||||
image: RecraftReturnedObject | None = Field(None, description='Single generated image')
|
||||
|
||||
|
||||
class RecraftCreateStyleRequest(BaseModel):
|
||||
style: str = Field(..., description="realistic_image, digital_illustration, vector_illustration, or icon")
|
||||
|
||||
|
||||
class RecraftCreateStyleResponse(BaseModel):
|
||||
id: str = Field(..., description="UUID of the created style")
|
||||
|
||||
@ -12,6 +12,8 @@ from comfy_api_nodes.apis.recraft import (
|
||||
RecraftColor,
|
||||
RecraftColorChain,
|
||||
RecraftControls,
|
||||
RecraftCreateStyleRequest,
|
||||
RecraftCreateStyleResponse,
|
||||
RecraftImageGenerationRequest,
|
||||
RecraftImageGenerationResponse,
|
||||
RecraftImageSize,
|
||||
@ -323,6 +325,75 @@ class RecraftStyleInfiniteStyleLibrary(IO.ComfyNode):
|
||||
return IO.NodeOutput(RecraftStyle(style_id=style_id))
|
||||
|
||||
|
||||
class RecraftCreateStyleNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="RecraftCreateStyleNode",
|
||||
display_name="Recraft Create Style",
|
||||
category="api node/image/Recraft",
|
||||
description="Create a custom style from reference images. "
|
||||
"Upload 1-5 images to use as style references. "
|
||||
"Total size of all images is limited to 5 MB.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"style",
|
||||
options=["realistic_image", "digital_illustration"],
|
||||
tooltip="The base style of the generated images.",
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"images",
|
||||
template=IO.Autogrow.TemplatePrefix(
|
||||
IO.Image.Input("image"),
|
||||
prefix="image",
|
||||
min=1,
|
||||
max=5,
|
||||
),
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="style_id"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd": 0.04}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
style: str,
|
||||
images: IO.Autogrow.Type,
|
||||
) -> IO.NodeOutput:
|
||||
files = []
|
||||
total_size = 0
|
||||
max_total_size = 5 * 1024 * 1024 # 5 MB limit
|
||||
for i, img in enumerate(list(images.values())):
|
||||
file_bytes = tensor_to_bytesio(img, total_pixels=2048 * 2048, mime_type="image/webp").read()
|
||||
total_size += len(file_bytes)
|
||||
if total_size > max_total_size:
|
||||
raise Exception("Total size of all images exceeds 5 MB limit.")
|
||||
files.append((f"file{i + 1}", file_bytes))
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/proxy/recraft/styles", method="POST"),
|
||||
response_model=RecraftCreateStyleResponse,
|
||||
files=files,
|
||||
data=RecraftCreateStyleRequest(style=style),
|
||||
content_type="multipart/form-data",
|
||||
max_retries=1,
|
||||
)
|
||||
|
||||
return IO.NodeOutput(response.id)
|
||||
|
||||
|
||||
class RecraftTextToImageNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -395,7 +466,7 @@ class RecraftTextToImageNode(IO.ComfyNode):
|
||||
negative_prompt: str = None,
|
||||
recraft_controls: RecraftControls = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False, max_length=1000)
|
||||
validate_string(prompt, strip_whitespace=False, min_length=1, max_length=1000)
|
||||
default_style = RecraftStyle(RecraftStyleV3.realistic_image)
|
||||
if recraft_style is None:
|
||||
recraft_style = default_style
|
||||
@ -1024,6 +1095,7 @@ class RecraftExtension(ComfyExtension):
|
||||
RecraftStyleV3DigitalIllustrationNode,
|
||||
RecraftStyleV3LogoRasterNode,
|
||||
RecraftStyleInfiniteStyleLibrary,
|
||||
RecraftCreateStyleNode,
|
||||
RecraftColorRGBNode,
|
||||
RecraftControlsNode,
|
||||
]
|
||||
|
||||
@ -3,6 +3,7 @@ import gc
|
||||
import itertools
|
||||
import psutil
|
||||
import time
|
||||
import logging
|
||||
import torch
|
||||
from typing import Sequence, Mapping, Dict
|
||||
from comfy_execution.graph import DynamicPrompt
|
||||
@ -14,7 +15,6 @@ from comfy_execution.graph_utils import is_link
|
||||
|
||||
NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}
|
||||
|
||||
|
||||
def include_unique_id_in_input(class_type: str) -> bool:
|
||||
if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID:
|
||||
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
|
||||
@ -23,7 +23,7 @@ def include_unique_id_in_input(class_type: str) -> bool:
|
||||
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
|
||||
|
||||
class CacheKeySet(ABC):
|
||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||
def __init__(self, dynprompt, node_ids, is_changed):
|
||||
self.keys = {}
|
||||
self.subcache_keys = {}
|
||||
|
||||
@ -45,6 +45,12 @@ class CacheKeySet(ABC):
|
||||
|
||||
def get_subcache_key(self, node_id):
|
||||
return self.subcache_keys.get(node_id, None)
|
||||
|
||||
async def update_cache_key(self, node_id) -> None:
|
||||
pass
|
||||
|
||||
def is_key_updated(self, node_id) -> bool:
|
||||
return True
|
||||
|
||||
class Unhashable:
|
||||
def __init__(self):
|
||||
@ -62,10 +68,21 @@ def to_hashable(obj):
|
||||
else:
|
||||
# TODO - Support other objects like tensors?
|
||||
return Unhashable()
|
||||
|
||||
def throw_on_unhashable(obj):
|
||||
# Same as to_hashable except throwing for unhashables instead.
|
||||
if isinstance(obj, (int, float, str, bool, bytes, type(None))):
|
||||
return obj
|
||||
elif isinstance(obj, Mapping):
|
||||
return frozenset([(throw_on_unhashable(k), throw_on_unhashable(v)) for k, v in sorted(obj.items())])
|
||||
elif isinstance(obj, Sequence):
|
||||
return frozenset(zip(itertools.count(), [throw_on_unhashable(i) for i in obj]))
|
||||
else:
|
||||
raise Exception("Object unhashable.")
|
||||
|
||||
class CacheKeySetID(CacheKeySet):
|
||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||
super().__init__(dynprompt, node_ids, is_changed_cache)
|
||||
def __init__(self, dynprompt, node_ids, is_changed):
|
||||
super().__init__(dynprompt, node_ids, is_changed)
|
||||
self.dynprompt = dynprompt
|
||||
|
||||
async def add_keys(self, node_ids):
|
||||
@ -79,13 +96,28 @@ class CacheKeySetID(CacheKeySet):
|
||||
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
||||
|
||||
class CacheKeySetInputSignature(CacheKeySet):
|
||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||
super().__init__(dynprompt, node_ids, is_changed_cache)
|
||||
self.dynprompt = dynprompt
|
||||
self.is_changed_cache = is_changed_cache
|
||||
def __init__(self, dynprompt, node_ids, is_changed):
|
||||
super().__init__(dynprompt, node_ids, is_changed)
|
||||
self.dynprompt: DynamicPrompt = dynprompt
|
||||
self.is_changed = is_changed
|
||||
|
||||
self.updated_node_ids = set()
|
||||
self.node_sig_cache = {}
|
||||
self.ancestry_cache = {}
|
||||
|
||||
def include_node_id_in_input(self) -> bool:
|
||||
return False
|
||||
|
||||
async def update_cache_key(self, node_id):
|
||||
if node_id in self.updated_node_ids:
|
||||
return
|
||||
if node_id not in self.keys:
|
||||
return
|
||||
self.updated_node_ids.add(node_id)
|
||||
self.keys[node_id] = await self.get_node_signature(node_id)
|
||||
|
||||
def is_key_updated(self, node_id):
|
||||
return node_id in self.updated_node_ids
|
||||
|
||||
async def add_keys(self, node_ids):
|
||||
for node_id in node_ids:
|
||||
@ -94,57 +126,105 @@ class CacheKeySetInputSignature(CacheKeySet):
|
||||
if not self.dynprompt.has_node(node_id):
|
||||
continue
|
||||
node = self.dynprompt.get_node(node_id)
|
||||
self.keys[node_id] = await self.get_node_signature(self.dynprompt, node_id)
|
||||
self.keys[node_id] = None
|
||||
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
||||
|
||||
async def get_node_signature(self, dynprompt, node_id):
|
||||
signature = []
|
||||
ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
|
||||
signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
|
||||
for ancestor_id in ancestors:
|
||||
signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
|
||||
return to_hashable(signature)
|
||||
async def get_node_signature(self, node_id):
|
||||
signatures = []
|
||||
ancestors, order_mapping, node_inputs = self.get_ordered_ancestry(node_id)
|
||||
self.node_sig_cache[node_id] = to_hashable(await self.get_immediate_node_signature(node_id, order_mapping, node_inputs))
|
||||
signatures.append(self.node_sig_cache[node_id])
|
||||
|
||||
async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
|
||||
if not dynprompt.has_node(node_id):
|
||||
for ancestor_id in ancestors:
|
||||
assert ancestor_id in self.node_sig_cache
|
||||
signatures.append(self.node_sig_cache[ancestor_id])
|
||||
|
||||
signatures = frozenset(zip(itertools.count(), signatures))
|
||||
logging.debug(f"signature for node {node_id}: {signatures}")
|
||||
return signatures
|
||||
|
||||
async def get_immediate_node_signature(self, node_id, ancestor_order_mapping: dict, inputs: dict):
|
||||
if not self.dynprompt.has_node(node_id):
|
||||
# This node doesn't exist -- we can't cache it.
|
||||
return [float("NaN")]
|
||||
node = dynprompt.get_node(node_id)
|
||||
node = self.dynprompt.get_node(node_id)
|
||||
class_type = node["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
signature = [class_type, await self.is_changed_cache.get(node_id)]
|
||||
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
|
||||
signature.append(node_id)
|
||||
inputs = node["inputs"]
|
||||
|
||||
signature = [class_type, await self.is_changed.get(node_id)]
|
||||
|
||||
for key in sorted(inputs.keys()):
|
||||
if is_link(inputs[key]):
|
||||
(ancestor_id, ancestor_socket) = inputs[key]
|
||||
input = inputs[key]
|
||||
if is_link(input):
|
||||
(ancestor_id, ancestor_socket) = input
|
||||
ancestor_index = ancestor_order_mapping[ancestor_id]
|
||||
signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))
|
||||
else:
|
||||
signature.append((key, inputs[key]))
|
||||
signature.append((key, input))
|
||||
|
||||
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
|
||||
signature.append(node_id)
|
||||
|
||||
return signature
|
||||
|
||||
def get_ordered_ancestry(self, node_id):
|
||||
def get_ancestors(ancestors, ret: list=[]):
|
||||
for ancestor_id in ancestors:
|
||||
if ancestor_id not in ret:
|
||||
ret.append(ancestor_id)
|
||||
get_ancestors(self.ancestry_cache[ancestor_id], ret)
|
||||
return ret
|
||||
|
||||
ancestors, node_inputs = self.get_ordered_ancestry_internal(node_id)
|
||||
ancestors = get_ancestors(ancestors)
|
||||
|
||||
# This function returns a list of all ancestors of the given node. The order of the list is
|
||||
# deterministic based on which specific inputs the ancestor is connected by.
|
||||
def get_ordered_ancestry(self, dynprompt, node_id):
|
||||
ancestors = []
|
||||
order_mapping = {}
|
||||
self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping)
|
||||
return ancestors, order_mapping
|
||||
for i, ancestor_id in enumerate(ancestors):
|
||||
order_mapping[ancestor_id] = i
|
||||
|
||||
return ancestors, order_mapping, node_inputs
|
||||
|
||||
def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
|
||||
if not dynprompt.has_node(node_id):
|
||||
return
|
||||
inputs = dynprompt.get_node(node_id)["inputs"]
|
||||
input_keys = sorted(inputs.keys())
|
||||
for key in input_keys:
|
||||
if is_link(inputs[key]):
|
||||
ancestor_id = inputs[key][0]
|
||||
if ancestor_id not in order_mapping:
|
||||
ancestors.append(ancestor_id)
|
||||
order_mapping[ancestor_id] = len(ancestors) - 1
|
||||
self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
|
||||
def get_ordered_ancestry_internal(self, node_id):
|
||||
def get_hashable(obj):
|
||||
try:
|
||||
return throw_on_unhashable(obj)
|
||||
except:
|
||||
return Unhashable
|
||||
|
||||
ancestors = []
|
||||
node_inputs = {}
|
||||
|
||||
if node_id in self.ancestry_cache:
|
||||
return self.ancestry_cache[node_id], node_inputs
|
||||
|
||||
if not self.dynprompt.has_node(node_id):
|
||||
return ancestors, node_inputs
|
||||
|
||||
input_data_all, _, _ = self.is_changed.get_input_data(node_id)
|
||||
inputs = self.dynprompt.get_node(node_id)["inputs"]
|
||||
for key in sorted(inputs.keys()):
|
||||
if key in input_data_all:
|
||||
if is_link(inputs[key]):
|
||||
ancestor_id = inputs[key][0]
|
||||
hashable = get_hashable(input_data_all[key])
|
||||
if hashable is Unhashable or is_link(input_data_all[key][0]):
|
||||
# Link still needed
|
||||
node_inputs[key] = inputs[key]
|
||||
if ancestor_id not in ancestors:
|
||||
ancestors.append(ancestor_id)
|
||||
else:
|
||||
# Replace link
|
||||
node_inputs[key] = input_data_all[key]
|
||||
else:
|
||||
hashable = get_hashable(inputs[key])
|
||||
if hashable is Unhashable:
|
||||
logging.warning(f"Node {node_id} cannot be cached due to whatever this thing is: {inputs[key]}")
|
||||
node_inputs[key] = Unhashable()
|
||||
else:
|
||||
node_inputs[key] = inputs[key]
|
||||
|
||||
self.ancestry_cache[node_id] = ancestors
|
||||
return self.ancestry_cache[node_id], node_inputs
|
||||
|
||||
class BasicCache:
|
||||
def __init__(self, key_class):
|
||||
@ -155,11 +235,16 @@ class BasicCache:
|
||||
self.cache = {}
|
||||
self.subcaches = {}
|
||||
|
||||
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
self.node_sig_cache = {}
|
||||
self.ancestry_cache = {}
|
||||
|
||||
async def set_prompt(self, dynprompt, node_ids, is_changed):
|
||||
self.dynprompt = dynprompt
|
||||
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
|
||||
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed)
|
||||
self.cache_key_set.node_sig_cache = self.node_sig_cache
|
||||
self.cache_key_set.ancestry_cache = self.ancestry_cache
|
||||
await self.cache_key_set.add_keys(node_ids)
|
||||
self.is_changed_cache = is_changed_cache
|
||||
self.is_changed = is_changed
|
||||
self.initialized = True
|
||||
|
||||
def all_node_ids(self):
|
||||
@ -190,22 +275,31 @@ class BasicCache:
|
||||
|
||||
def clean_unused(self):
|
||||
assert self.initialized
|
||||
self.node_sig_cache.clear()
|
||||
self.ancestry_cache.clear()
|
||||
self._clean_cache()
|
||||
self._clean_subcaches()
|
||||
|
||||
def poll(self, **kwargs):
|
||||
pass
|
||||
|
||||
async def _update_cache_key_immediate(self, node_id):
|
||||
await self.cache_key_set.update_cache_key(node_id)
|
||||
|
||||
def _is_key_updated_immediate(self, node_id):
|
||||
return self.cache_key_set.is_key_updated(node_id)
|
||||
|
||||
def _set_immediate(self, node_id, value):
|
||||
assert self.initialized
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
self.cache[cache_key] = value
|
||||
if cache_key is not None:
|
||||
self.cache[cache_key] = value
|
||||
|
||||
def _get_immediate(self, node_id):
|
||||
if not self.initialized:
|
||||
return None
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
if cache_key in self.cache:
|
||||
if cache_key is not None and cache_key in self.cache:
|
||||
return self.cache[cache_key]
|
||||
else:
|
||||
return None
|
||||
@ -215,8 +309,10 @@ class BasicCache:
|
||||
subcache = self.subcaches.get(subcache_key, None)
|
||||
if subcache is None:
|
||||
subcache = BasicCache(self.key_class)
|
||||
subcache.node_sig_cache = self.node_sig_cache
|
||||
subcache.ancestry_cache = self.ancestry_cache
|
||||
self.subcaches[subcache_key] = subcache
|
||||
await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
|
||||
await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed)
|
||||
return subcache
|
||||
|
||||
def _get_subcache(self, node_id):
|
||||
@ -272,10 +368,19 @@ class HierarchicalCache(BasicCache):
|
||||
cache = self._get_cache_for(node_id)
|
||||
assert cache is not None
|
||||
return await cache._ensure_subcache(node_id, children_ids)
|
||||
|
||||
async def update_cache_key(self, node_id):
|
||||
cache = self._get_cache_for(node_id)
|
||||
assert cache is not None
|
||||
await cache._update_cache_key_immediate(node_id)
|
||||
|
||||
def is_key_updated(self, node_id):
|
||||
cache = self._get_cache_for(node_id)
|
||||
assert cache is not None
|
||||
return cache._is_key_updated_immediate(node_id)
|
||||
|
||||
class NullCache:
|
||||
|
||||
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
async def set_prompt(self, dynprompt, node_ids, is_changed):
|
||||
pass
|
||||
|
||||
def all_node_ids(self):
|
||||
@ -295,6 +400,12 @@ class NullCache:
|
||||
|
||||
async def ensure_subcache_for(self, node_id, children_ids):
|
||||
return self
|
||||
|
||||
async def update_cache_key(self, node_id):
|
||||
pass
|
||||
|
||||
def is_key_updated(self, node_id):
|
||||
return True
|
||||
|
||||
class LRUCache(BasicCache):
|
||||
def __init__(self, key_class, max_size=100):
|
||||
@ -305,13 +416,15 @@ class LRUCache(BasicCache):
|
||||
self.used_generation = {}
|
||||
self.children = {}
|
||||
|
||||
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
await super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
||||
async def set_prompt(self, dynprompt, node_ids, is_changed):
|
||||
await super().set_prompt(dynprompt, node_ids, is_changed)
|
||||
self.generation += 1
|
||||
for node_id in node_ids:
|
||||
self._mark_used(node_id)
|
||||
|
||||
def clean_unused(self):
|
||||
self.node_sig_cache.clear()
|
||||
self.ancestry_cache.clear()
|
||||
while len(self.cache) > self.max_size and self.min_generation < self.generation:
|
||||
self.min_generation += 1
|
||||
to_remove = [key for key in self.cache if self.used_generation[key] < self.min_generation]
|
||||
@ -347,6 +460,14 @@ class LRUCache(BasicCache):
|
||||
self._mark_used(child_id)
|
||||
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
|
||||
return self
|
||||
|
||||
async def update_cache_key(self, node_id):
|
||||
self._mark_used(node_id)
|
||||
await self._update_cache_key_immediate(node_id)
|
||||
|
||||
def is_key_updated(self, node_id):
|
||||
self._mark_used(node_id)
|
||||
return self._is_key_updated_immediate(node_id)
|
||||
|
||||
|
||||
#Iterating the cache for usage analysis might be expensive, so if we trigger make sure
|
||||
@ -365,12 +486,13 @@ RAM_CACHE_DEFAULT_RAM_USAGE = 0.1
|
||||
RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
|
||||
|
||||
class RAMPressureCache(LRUCache):
|
||||
|
||||
def __init__(self, key_class):
|
||||
super().__init__(key_class, 0)
|
||||
self.timestamps = {}
|
||||
|
||||
def clean_unused(self):
|
||||
self.node_sig_cache.clear()
|
||||
self.ancestry_cache.clear()
|
||||
self._clean_subcaches()
|
||||
|
||||
def set(self, node_id, value):
|
||||
|
||||
119
execution.py
119
execution.py
@ -36,7 +36,7 @@ from comfy_execution.progress import get_progress_state, reset_progress_state, a
|
||||
from comfy_execution.utils import CurrentNodeContext
|
||||
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
|
||||
from comfy_api.latest import io, _io
|
||||
|
||||
from server import PromptServer
|
||||
|
||||
class ExecutionResult(Enum):
|
||||
SUCCESS = 0
|
||||
@ -46,49 +46,40 @@ class ExecutionResult(Enum):
|
||||
class DuplicateNodeError(Exception):
|
||||
pass
|
||||
|
||||
class IsChangedCache:
|
||||
def __init__(self, prompt_id: str, dynprompt: DynamicPrompt, outputs_cache: BasicCache):
|
||||
class IsChanged:
|
||||
def __init__(self, prompt_id: str, dynprompt: DynamicPrompt, execution_list: ExecutionList|None=None, extra_data: dict={}):
|
||||
self.prompt_id = prompt_id
|
||||
self.dynprompt = dynprompt
|
||||
self.outputs_cache = outputs_cache
|
||||
self.is_changed = {}
|
||||
|
||||
async def get(self, node_id):
|
||||
if node_id in self.is_changed:
|
||||
return self.is_changed[node_id]
|
||||
self.execution_list = execution_list
|
||||
self.extra_data = extra_data
|
||||
|
||||
def get_input_data(self, node_id):
|
||||
node = self.dynprompt.get_node(node_id)
|
||||
class_type = node["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
return get_input_data(node["inputs"], class_def, node_id, self.execution_list, self.dynprompt, self.extra_data)
|
||||
|
||||
async def get(self, node_id):
|
||||
node = self.dynprompt.get_node(node_id)
|
||||
class_type = node["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
has_is_changed = False
|
||||
is_changed_name = None
|
||||
if issubclass(class_def, _ComfyNodeInternal) and first_real_override(class_def, "fingerprint_inputs") is not None:
|
||||
has_is_changed = True
|
||||
is_changed_name = "fingerprint_inputs"
|
||||
elif hasattr(class_def, "IS_CHANGED"):
|
||||
has_is_changed = True
|
||||
is_changed_name = "IS_CHANGED"
|
||||
if not has_is_changed:
|
||||
self.is_changed[node_id] = False
|
||||
return self.is_changed[node_id]
|
||||
if is_changed_name is None:
|
||||
return False
|
||||
|
||||
if "is_changed" in node:
|
||||
self.is_changed[node_id] = node["is_changed"]
|
||||
return self.is_changed[node_id]
|
||||
|
||||
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
||||
input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None)
|
||||
input_data_all, _, v3_data = self.get_input_data(node_id)
|
||||
try:
|
||||
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name, v3_data=v3_data)
|
||||
is_changed = await resolve_map_node_over_list_results(is_changed)
|
||||
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
||||
is_changed = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
||||
except Exception as e:
|
||||
logging.warning("WARNING: {}".format(e))
|
||||
node["is_changed"] = float("NaN")
|
||||
finally:
|
||||
self.is_changed[node_id] = node["is_changed"]
|
||||
return self.is_changed[node_id]
|
||||
|
||||
is_changed = float("NaN")
|
||||
return is_changed
|
||||
|
||||
class CacheEntry(NamedTuple):
|
||||
ui: dict
|
||||
@ -118,7 +109,7 @@ class CacheSet:
|
||||
else:
|
||||
self.init_classic_cache()
|
||||
|
||||
self.all = [self.outputs, self.objects]
|
||||
self.all: list[BasicCache, BasicCache] = [self.outputs, self.objects]
|
||||
|
||||
# Performs like the old cache -- dump data ASAP
|
||||
def init_classic_cache(self):
|
||||
@ -406,7 +397,10 @@ def format_value(x):
|
||||
else:
|
||||
return str(x)
|
||||
|
||||
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
|
||||
async def execute(server: PromptServer, dynprompt: DynamicPrompt, caches: CacheSet,
|
||||
current_item: str, extra_data: dict, executed: set, prompt_id: str,
|
||||
execution_list: ExecutionList, pending_subgraph_results: dict,
|
||||
pending_async_nodes: dict, ui_outputs: dict):
|
||||
unique_id = current_item
|
||||
real_node_id = dynprompt.get_real_node_id(unique_id)
|
||||
display_node_id = dynprompt.get_display_node_id(unique_id)
|
||||
@ -414,16 +408,21 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
inputs = dynprompt.get_node(unique_id)['inputs']
|
||||
class_type = dynprompt.get_node(unique_id)['class_type']
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
cached = caches.outputs.get(unique_id)
|
||||
if cached is not None:
|
||||
if server.client_id is not None:
|
||||
cached_ui = cached.ui or {}
|
||||
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id)
|
||||
if cached.ui is not None:
|
||||
ui_outputs[unique_id] = cached.ui
|
||||
get_progress_state().finish_progress(unique_id)
|
||||
execution_list.cache_update(unique_id, cached)
|
||||
return (ExecutionResult.SUCCESS, None, None)
|
||||
|
||||
if caches.outputs.is_key_updated(unique_id):
|
||||
# Key is updated, the cache can be checked.
|
||||
cached = caches.outputs.get(unique_id)
|
||||
logging.debug(f"execute: {unique_id} cached: {cached is not None}")
|
||||
if cached is not None:
|
||||
if server.client_id is not None:
|
||||
cached_ui = cached.ui or {}
|
||||
server.send_sync("execution_cached", { "nodes": [unique_id], "prompt_id": prompt_id}, server.client_id)
|
||||
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id)
|
||||
if cached.ui is not None:
|
||||
ui_outputs[unique_id] = cached.ui
|
||||
get_progress_state().finish_progress(unique_id)
|
||||
execution_list.cache_update(unique_id, cached)
|
||||
return (ExecutionResult.SUCCESS, None, None)
|
||||
|
||||
input_data_all = None
|
||||
try:
|
||||
@ -464,11 +463,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
del pending_subgraph_results[unique_id]
|
||||
has_subgraph = False
|
||||
else:
|
||||
get_progress_state().start_progress(unique_id)
|
||||
if caches.outputs.is_key_updated(unique_id):
|
||||
# The key is updated, the node is executing.
|
||||
get_progress_state().start_progress(unique_id)
|
||||
if server.client_id is not None:
|
||||
server.last_node_id = display_node_id
|
||||
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
||||
|
||||
input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
server.last_node_id = display_node_id
|
||||
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
||||
|
||||
obj = caches.objects.get(unique_id)
|
||||
if obj is None:
|
||||
@ -479,6 +481,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None
|
||||
else:
|
||||
lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
|
||||
|
||||
if lazy_status_present:
|
||||
# for check_lazy_status, the returned data should include the original key of the input
|
||||
v3_data_lazy = v3_data.copy()
|
||||
@ -494,6 +497,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
execution_list.make_input_strong_link(unique_id, i)
|
||||
return (ExecutionResult.PENDING, None, None)
|
||||
|
||||
if not caches.outputs.is_key_updated(unique_id):
|
||||
# Update the cache key after any lazy inputs are evaluated.
|
||||
async def update_cache_key(node_id, unblock):
|
||||
await caches.outputs.update_cache_key(node_id)
|
||||
unblock()
|
||||
asyncio.create_task(update_cache_key(unique_id, execution_list.add_external_block(unique_id)))
|
||||
return (ExecutionResult.PENDING, None, None)
|
||||
|
||||
def execution_block_cb(block):
|
||||
if block.message is not None:
|
||||
mes = {
|
||||
@ -563,8 +574,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
cached_outputs.append((True, node_outputs))
|
||||
new_node_ids = set(new_node_ids)
|
||||
for cache in caches.all:
|
||||
subcache = await cache.ensure_subcache_for(unique_id, new_node_ids)
|
||||
subcache.clean_unused()
|
||||
await cache.ensure_subcache_for(unique_id, new_node_ids)
|
||||
for node_id in new_output_ids:
|
||||
execution_list.add_node(node_id)
|
||||
execution_list.cache_link(node_id, unique_id)
|
||||
@ -689,25 +699,16 @@ class PromptExecutor:
|
||||
dynamic_prompt = DynamicPrompt(prompt)
|
||||
reset_progress_state(prompt_id, dynamic_prompt)
|
||||
add_progress_handler(WebUIProgressHandler(self.server))
|
||||
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
|
||||
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
|
||||
is_changed = IsChanged(prompt_id, dynamic_prompt, execution_list, extra_data)
|
||||
for cache in self.caches.all:
|
||||
await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
|
||||
cache.clean_unused()
|
||||
|
||||
cached_nodes = []
|
||||
for node_id in prompt:
|
||||
if self.caches.outputs.get(node_id) is not None:
|
||||
cached_nodes.append(node_id)
|
||||
await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed)
|
||||
|
||||
comfy.model_management.cleanup_models_gc()
|
||||
self.add_message("execution_cached",
|
||||
{ "nodes": cached_nodes, "prompt_id": prompt_id},
|
||||
broadcast=False)
|
||||
pending_subgraph_results = {}
|
||||
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
|
||||
ui_node_outputs = {}
|
||||
executed = set()
|
||||
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
|
||||
current_outputs = self.caches.outputs.all_node_ids()
|
||||
for node_id in list(execute_outputs):
|
||||
execution_list.add_node(node_id)
|
||||
@ -745,7 +746,9 @@ class PromptExecutor:
|
||||
self.server.last_node_id = None
|
||||
if comfy.model_management.DISABLE_SMART_MEMORY:
|
||||
comfy.model_management.unload_all_models()
|
||||
|
||||
|
||||
for cache in self.caches.all:
|
||||
cache.clean_unused()
|
||||
|
||||
async def validate_inputs(prompt_id, prompt, item, validated):
|
||||
unique_id = item
|
||||
|
||||
@ -552,27 +552,50 @@ class TestExecution:
|
||||
assert len(images1) == 1, "Should have 1 image"
|
||||
assert len(images2) == 1, "Should have 1 image"
|
||||
|
||||
# This tests that only constant outputs are used in the call to `IS_CHANGED`
|
||||
def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder, server):
|
||||
def test_is_changed_passed_cached_outputs(self, client: ComfyClient, builder: GraphBuilder, server):
|
||||
g = builder
|
||||
input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1)
|
||||
test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5)
|
||||
|
||||
test_node = g.node("TestIsChangedWithAllInputs", image=input1.out(0), value=0.5)
|
||||
output = g.node("PreviewImage", images=test_node.out(0))
|
||||
|
||||
result = client.run(g)
|
||||
images = result.get_images(output)
|
||||
result1 = client.run(g)
|
||||
images = result1.get_images(output)
|
||||
assert len(images) == 1, "Should have 1 image"
|
||||
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
|
||||
|
||||
result = client.run(g)
|
||||
images = result.get_images(output)
|
||||
result2 = client.run(g)
|
||||
images = result2.get_images(output)
|
||||
assert len(images) == 1, "Should have 1 image"
|
||||
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
|
||||
|
||||
if server["should_cache_results"]:
|
||||
assert not result.did_run(test_node), "The execution should have been cached"
|
||||
assert not result2.did_run(test_node), "Test node should not have run again"
|
||||
else:
|
||||
assert result.did_run(test_node), "The execution should have been re-run"
|
||||
assert result2.did_run(test_node), "Test node should always run here"
|
||||
|
||||
def test_dont_always_run_downstream(self, client: ComfyClient, builder: GraphBuilder, server):
|
||||
g = builder
|
||||
float1 = g.node("TestDontAlwaysRunDownstream", float=0.5) # IS_CHANGED returns float("NaN")
|
||||
image1 = g.node("StubConstantImage", value=float1.out(0), height=512, width=512, batch_size=1)
|
||||
output = g.node("PreviewImage", images=image1.out(0))
|
||||
|
||||
result1 = client.run(g)
|
||||
images = result1.get_images(output)
|
||||
assert len(images) == 1, "Should have 1 image"
|
||||
assert numpy.array(images[0]).min() == 127 and numpy.array(images[0]).max() == 127, "Image should have value 0.50"
|
||||
|
||||
result2 = client.run(g)
|
||||
images = result2.get_images(output)
|
||||
assert len(images) == 1, "Should have 1 image"
|
||||
assert numpy.array(images[0]).min() == 127 and numpy.array(images[0]).max() == 127, "Image should have value 0.50"
|
||||
|
||||
assert result2.did_run(float1), "Float node should always run"
|
||||
if server["should_cache_results"]:
|
||||
assert not result2.did_run(image1), "Image node should not have run again"
|
||||
assert not result2.did_run(output), "Output node should not have run again"
|
||||
else:
|
||||
assert result2.did_run(image1), "Image node should have run again"
|
||||
assert result2.did_run(output), "Output node should have run again"
|
||||
|
||||
|
||||
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
|
||||
|
||||
@ -100,7 +100,7 @@ class TestCustomIsChanged:
|
||||
else:
|
||||
return False
|
||||
|
||||
class TestIsChangedWithConstants:
|
||||
class TestIsChangedWithAllInputs:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
@ -120,10 +120,29 @@ class TestIsChangedWithConstants:
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(cls, image, value):
|
||||
if image is None:
|
||||
return value
|
||||
else:
|
||||
return image.mean().item() * value
|
||||
# if image is None then an exception is thrown and is_changed becomes float("NaN")
|
||||
return image.mean().item() * value
|
||||
|
||||
class TestDontAlwaysRunDownstream:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"float": ("FLOAT",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("FLOAT",)
|
||||
FUNCTION = "always_run"
|
||||
|
||||
CATEGORY = "Testing/Nodes"
|
||||
|
||||
def always_run(self, float):
|
||||
return (float,)
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(cls, *args, **kwargs):
|
||||
return float("NaN")
|
||||
|
||||
class TestCustomValidation1:
|
||||
@classmethod
|
||||
@ -486,7 +505,8 @@ TEST_NODE_CLASS_MAPPINGS = {
|
||||
"TestLazyMixImages": TestLazyMixImages,
|
||||
"TestVariadicAverage": TestVariadicAverage,
|
||||
"TestCustomIsChanged": TestCustomIsChanged,
|
||||
"TestIsChangedWithConstants": TestIsChangedWithConstants,
|
||||
"TestIsChangedWithAllInputs": TestIsChangedWithAllInputs,
|
||||
"TestDontAlwaysRunDownstream": TestDontAlwaysRunDownstream,
|
||||
"TestCustomValidation1": TestCustomValidation1,
|
||||
"TestCustomValidation2": TestCustomValidation2,
|
||||
"TestCustomValidation3": TestCustomValidation3,
|
||||
@ -504,7 +524,8 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"TestLazyMixImages": "Lazy Mix Images",
|
||||
"TestVariadicAverage": "Variadic Average",
|
||||
"TestCustomIsChanged": "Custom IsChanged",
|
||||
"TestIsChangedWithConstants": "IsChanged With Constants",
|
||||
"TestIsChangedWithAllInputs": "IsChanged With All Inputs",
|
||||
"TestDontAlwaysRunDownstream": "Dont Always Run Downstream",
|
||||
"TestCustomValidation1": "Custom Validation 1",
|
||||
"TestCustomValidation2": "Custom Validation 2",
|
||||
"TestCustomValidation3": "Custom Validation 3",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user