mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-16 03:57:27 +08:00
Compare commits
9 Commits
ab342e6f51
...
c7896dd75d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c7896dd75d | ||
|
|
0a7d2ffd68 | ||
|
|
20e439419c | ||
|
|
428c323780 | ||
|
|
28d538ddf9 | ||
|
|
e326b41d62 | ||
|
|
aa9e7a84bc | ||
|
|
9b7a2a3248 | ||
|
|
5b913f0377 |
@ -97,12 +97,14 @@ def load_lora(lora, to_load, log_missing=True):
|
||||
|
||||
def model_lora_keys_clip(model, key_map={}):
|
||||
sdk = model.state_dict().keys()
|
||||
prefix_set = set()
|
||||
for k in sdk:
|
||||
if k.endswith(".weight"):
|
||||
key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
||||
tp = k.find(".transformer.") #also map without wrapper prefix for composite text encoder models
|
||||
if tp > 0 and not k.startswith("clip_"):
|
||||
key_map["text_encoders.{}".format(k[tp + 1:-len(".weight")])] = k
|
||||
prefix_set.add(k.split('.')[0])
|
||||
|
||||
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
|
||||
clip_l_present = False
|
||||
@ -163,6 +165,13 @@ def model_lora_keys_clip(model, key_map={}):
|
||||
lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
|
||||
key_map[lora_key] = k
|
||||
|
||||
if len(prefix_set) == 1:
|
||||
full_prefix = "{}.transformer.model.".format(next(iter(prefix_set))) # kohya anima and maybe other single TE models that use a single llama arch based te
|
||||
for k in sdk:
|
||||
if k.endswith(".weight"):
|
||||
if k.startswith(full_prefix):
|
||||
l_key = k[len(full_prefix):-len(".weight")]
|
||||
key_map["lora_te_{}".format(l_key.replace(".", "_"))] = k
|
||||
|
||||
k = "clip_g.transformer.text_projection.weight"
|
||||
if k in sdk:
|
||||
|
||||
@ -242,6 +242,37 @@ class LazyCastingParam(torch.nn.Parameter):
|
||||
return self.model.patch_weight_to_device(self.key, device_to=self.model.load_device, return_weight=True).to("cpu")
|
||||
|
||||
|
||||
class LazyCastingQuantizedParam:
|
||||
def __init__(self, model, key):
|
||||
self.model = model
|
||||
self.key = key
|
||||
self.cpu_state_dict = None
|
||||
|
||||
def state_dict_tensor(self, state_dict_key):
|
||||
if self.cpu_state_dict is None:
|
||||
weight = self.model.patch_weight_to_device(self.key, device_to=self.model.load_device, return_weight=True)
|
||||
self.cpu_state_dict = {k: v.to("cpu") for k, v in weight.state_dict(self.key).items()}
|
||||
return self.cpu_state_dict[state_dict_key]
|
||||
|
||||
|
||||
class LazyCastingParamPiece(torch.nn.Parameter):
|
||||
def __new__(cls, caster, state_dict_key, tensor):
|
||||
return super().__new__(cls, tensor)
|
||||
|
||||
def __init__(self, caster, state_dict_key, tensor):
|
||||
self.caster = caster
|
||||
self.state_dict_key = state_dict_key
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return CustomTorchDevice
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
caster = self.caster
|
||||
del self.caster
|
||||
return caster.state_dict_tensor(self.state_dict_key)
|
||||
|
||||
|
||||
class ModelPatcher:
|
||||
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
||||
self.size = size
|
||||
@ -1463,20 +1494,37 @@ class ModelPatcher:
|
||||
self.clear_cached_hook_weights()
|
||||
|
||||
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
|
||||
unet_state_dict = self.model.diffusion_model.state_dict()
|
||||
for k, v in unet_state_dict.items():
|
||||
original_state_dict = self.model.diffusion_model.state_dict()
|
||||
unet_state_dict = {}
|
||||
keys = list(original_state_dict)
|
||||
while len(keys) > 0:
|
||||
k = keys.pop(0)
|
||||
v = original_state_dict[k]
|
||||
op_keys = k.rsplit('.', 1)
|
||||
if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]:
|
||||
unet_state_dict[k] = v
|
||||
continue
|
||||
try:
|
||||
op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0])
|
||||
except:
|
||||
unet_state_dict[k] = v
|
||||
continue
|
||||
if not op or not hasattr(op, "comfy_cast_weights") or \
|
||||
(hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True):
|
||||
unet_state_dict[k] = v
|
||||
continue
|
||||
key = "diffusion_model." + k
|
||||
unet_state_dict[k] = LazyCastingParam(self, key, comfy.utils.get_attr(self.model, key))
|
||||
weight = comfy.utils.get_attr(self.model, key)
|
||||
if isinstance(weight, QuantizedTensor) and k in original_state_dict:
|
||||
qt_state_dict = weight.state_dict(k)
|
||||
caster = LazyCastingQuantizedParam(self, key)
|
||||
for group_key in (x for x in qt_state_dict if x in original_state_dict):
|
||||
if group_key in keys:
|
||||
keys.remove(group_key)
|
||||
unet_state_dict.pop(group_key, "")
|
||||
unet_state_dict[group_key] = LazyCastingParamPiece(caster, "diffusion_model." + group_key, original_state_dict[group_key])
|
||||
continue
|
||||
unet_state_dict[k] = LazyCastingParam(self, key, weight)
|
||||
return self.model.state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
|
||||
|
||||
def __del__(self):
|
||||
|
||||
@ -327,11 +327,14 @@ class String(ComfyTypeIO):
|
||||
'''String input.'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None,
|
||||
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
||||
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None,
|
||||
min_length: int=None, max_length: int=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
|
||||
self.multiline = multiline
|
||||
self.placeholder = placeholder
|
||||
self.dynamic_prompts = dynamic_prompts
|
||||
self.min_length = min_length
|
||||
self.max_length = max_length
|
||||
self.default: str
|
||||
|
||||
def as_dict(self):
|
||||
@ -339,6 +342,8 @@ class String(ComfyTypeIO):
|
||||
"multiline": self.multiline,
|
||||
"placeholder": self.placeholder,
|
||||
"dynamicPrompts": self.dynamic_prompts,
|
||||
"minLength": self.min_length,
|
||||
"maxLength": self.max_length,
|
||||
})
|
||||
|
||||
@comfytype(io_type="COMBO")
|
||||
|
||||
@ -27,6 +27,7 @@ from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_bytesio,
|
||||
downscale_image_tensor,
|
||||
get_number_of_images,
|
||||
poll_op,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
@ -372,6 +373,7 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
||||
display_name="OpenAI GPT Image 2",
|
||||
category="api node/image/OpenAI",
|
||||
description="Generates images synchronously via OpenAI's GPT Image endpoint.",
|
||||
is_deprecated=True,
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -640,6 +642,316 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
||||
return IO.NodeOutput(await validate_and_cast_response(response))
|
||||
|
||||
|
||||
def _gpt_image_shared_inputs():
|
||||
"""Inputs shared by all GPT Image models (quality + reference images + mask)."""
|
||||
return [
|
||||
IO.Combo.Input(
|
||||
"quality",
|
||||
default="low",
|
||||
options=["low", "medium", "high"],
|
||||
tooltip="Image quality, affects cost and generation time.",
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"images",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Image.Input("image"),
|
||||
names=[f"image_{i}" for i in range(1, 17)],
|
||||
min=0,
|
||||
),
|
||||
tooltip="Optional reference image(s) for image editing. Up to 16 images.",
|
||||
),
|
||||
IO.Mask.Input(
|
||||
"mask",
|
||||
optional=True,
|
||||
tooltip="Optional mask for inpainting (white areas will be replaced). "
|
||||
"Requires exactly one reference image.",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _gpt_image_legacy_model_inputs():
|
||||
"""Per-model widget set for legacy gpt-image-1 / gpt-image-1.5 (4 base sizes, transparent bg allowed)."""
|
||||
return [
|
||||
IO.Combo.Input(
|
||||
"size",
|
||||
default="auto",
|
||||
options=["auto", "1024x1024", "1024x1536", "1536x1024"],
|
||||
tooltip="Image size.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"background",
|
||||
default="auto",
|
||||
options=["auto", "opaque", "transparent"],
|
||||
tooltip="Return image with or without background.",
|
||||
),
|
||||
*_gpt_image_shared_inputs(),
|
||||
]
|
||||
|
||||
|
||||
class OpenAIGPTImageNodeV2(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIGPTImageNodeV2",
|
||||
display_name="OpenAI GPT Image 2",
|
||||
category="api node/image/OpenAI",
|
||||
description="Generates images via OpenAI's GPT Image endpoint.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
default="",
|
||||
multiline=True,
|
||||
tooltip="Text prompt for GPT Image",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"gpt-image-2",
|
||||
[
|
||||
IO.Combo.Input(
|
||||
"size",
|
||||
default="auto",
|
||||
options=[
|
||||
"auto",
|
||||
"1024x1024",
|
||||
"1024x1536",
|
||||
"1536x1024",
|
||||
"2048x2048",
|
||||
"2048x1152",
|
||||
"1152x2048",
|
||||
"3840x2160",
|
||||
"2160x3840",
|
||||
"Custom",
|
||||
],
|
||||
tooltip="Image size. Select 'Custom' to use the custom width and height.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"custom_width",
|
||||
default=1024,
|
||||
min=1024,
|
||||
max=3840,
|
||||
step=16,
|
||||
tooltip="Used only when `size` is 'Custom'. Must be a multiple of 16.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"custom_height",
|
||||
default=1024,
|
||||
min=1024,
|
||||
max=3840,
|
||||
step=16,
|
||||
tooltip="Used only when `size` is 'Custom'. Must be a multiple of 16.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"background",
|
||||
default="auto",
|
||||
options=["auto", "opaque"],
|
||||
tooltip="Return image with or without background.",
|
||||
),
|
||||
*_gpt_image_shared_inputs(),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option("gpt-image-1.5", _gpt_image_legacy_model_inputs()),
|
||||
IO.DynamicCombo.Option("gpt-image-1", _gpt_image_legacy_model_inputs()),
|
||||
],
|
||||
),
|
||||
IO.Int.Input(
|
||||
"n",
|
||||
default=1,
|
||||
min=1,
|
||||
max=8,
|
||||
step=1,
|
||||
tooltip="How many images to generate",
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="not implemented yet in backend",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.quality", "n"]),
|
||||
expr="""
|
||||
(
|
||||
$ranges := {
|
||||
"gpt-image-1": {
|
||||
"low": [0.011, 0.02],
|
||||
"medium": [0.042, 0.07],
|
||||
"high": [0.167, 0.25]
|
||||
},
|
||||
"gpt-image-1.5": {
|
||||
"low": [0.009, 0.02],
|
||||
"medium": [0.034, 0.062],
|
||||
"high": [0.133, 0.22]
|
||||
},
|
||||
"gpt-image-2": {
|
||||
"low": [0.0048, 0.019],
|
||||
"medium": [0.041, 0.168],
|
||||
"high": [0.165, 0.67]
|
||||
}
|
||||
};
|
||||
$range := $lookup($lookup($ranges, widgets.model), $lookup(widgets, "model.quality"));
|
||||
$nRaw := widgets.n;
|
||||
$n := ($nRaw != null and $nRaw != 0) ? $nRaw : 1;
|
||||
($n = 1)
|
||||
? {"type":"range_usd","min_usd": $range[0], "max_usd": $range[1], "format": {"approximate": true}}
|
||||
: {
|
||||
"type":"range_usd",
|
||||
"min_usd": $range[0] * $n,
|
||||
"max_usd": $range[1] * $n,
|
||||
"format": { "suffix": "/Run", "approximate": true }
|
||||
}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
n: int,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
|
||||
model_id = model["model"]
|
||||
size = model["size"]
|
||||
background = model["background"]
|
||||
quality = model["quality"]
|
||||
custom_width = model.get("custom_width", 1024)
|
||||
custom_height = model.get("custom_height", 1024)
|
||||
|
||||
images_dict = model.get("images") or {}
|
||||
image_tensors: list[Input.Image] = [t for t in images_dict.values() if t is not None]
|
||||
n_images = sum(get_number_of_images(t) for t in image_tensors)
|
||||
mask = model.get("mask")
|
||||
|
||||
if mask is not None and n_images == 0:
|
||||
raise ValueError("Cannot use a mask without an input image")
|
||||
|
||||
if size == "Custom":
|
||||
if custom_width % 16 != 0 or custom_height % 16 != 0:
|
||||
raise ValueError(
|
||||
f"Custom width and height must be multiples of 16, got {custom_width}x{custom_height}"
|
||||
)
|
||||
if max(custom_width, custom_height) > 3840:
|
||||
raise ValueError(
|
||||
f"Custom resolution max edge must be <= 3840, got {custom_width}x{custom_height}"
|
||||
)
|
||||
ratio = max(custom_width, custom_height) / min(custom_width, custom_height)
|
||||
if ratio > 3:
|
||||
raise ValueError(
|
||||
f"Custom resolution aspect ratio must not exceed 3:1, got {custom_width}x{custom_height}"
|
||||
)
|
||||
total_pixels = custom_width * custom_height
|
||||
if not 655_360 <= total_pixels <= 8_294_400:
|
||||
raise ValueError(
|
||||
f"Custom resolution total pixels must be between 655,360 and 8,294,400, got {total_pixels}"
|
||||
)
|
||||
size = f"{custom_width}x{custom_height}"
|
||||
|
||||
if model_id == "gpt-image-1":
|
||||
price_extractor = calculate_tokens_price_image_1
|
||||
elif model_id == "gpt-image-1.5":
|
||||
price_extractor = calculate_tokens_price_image_1_5
|
||||
elif model_id == "gpt-image-2":
|
||||
price_extractor = calculate_tokens_price_image_2_0
|
||||
else:
|
||||
raise ValueError(f"Unknown model: {model_id}")
|
||||
|
||||
if image_tensors:
|
||||
flat: list[torch.Tensor] = []
|
||||
for tensor in image_tensors:
|
||||
if len(tensor.shape) == 4:
|
||||
flat.extend(tensor[i : i + 1] for i in range(tensor.shape[0]))
|
||||
else:
|
||||
flat.append(tensor.unsqueeze(0))
|
||||
|
||||
files = []
|
||||
for i, single_image in enumerate(flat):
|
||||
scaled_image = downscale_image_tensor(single_image, total_pixels=2048 * 2048).squeeze()
|
||||
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
|
||||
img = Image.fromarray(image_np)
|
||||
img_byte_arr = BytesIO()
|
||||
img.save(img_byte_arr, format="PNG")
|
||||
img_byte_arr.seek(0)
|
||||
|
||||
if len(flat) == 1:
|
||||
files.append(("image", (f"image_{i}.png", img_byte_arr, "image/png")))
|
||||
else:
|
||||
files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png")))
|
||||
|
||||
if mask is not None:
|
||||
if len(flat) != 1:
|
||||
raise Exception("Cannot use a mask with multiple image")
|
||||
ref_image = flat[0]
|
||||
if mask.shape[1:] != ref_image.shape[1:-1]:
|
||||
raise Exception("Mask and Image must be the same size")
|
||||
_, height, width = mask.shape
|
||||
rgba_mask = torch.zeros(height, width, 4, device="cpu")
|
||||
rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu()
|
||||
scaled_mask = downscale_image_tensor(
|
||||
rgba_mask.unsqueeze(0), total_pixels=2048 * 2048
|
||||
).squeeze()
|
||||
mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
|
||||
mask_img = Image.fromarray(mask_np)
|
||||
mask_img_byte_arr = BytesIO()
|
||||
mask_img.save(mask_img_byte_arr, format="PNG")
|
||||
mask_img_byte_arr.seek(0)
|
||||
files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png")))
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/openai/images/edits", method="POST"),
|
||||
response_model=OpenAIImageGenerationResponse,
|
||||
data=OpenAIImageEditRequest(
|
||||
model=model_id,
|
||||
prompt=prompt,
|
||||
quality=quality,
|
||||
background=background,
|
||||
n=n,
|
||||
size=size,
|
||||
moderation="low",
|
||||
),
|
||||
content_type="multipart/form-data",
|
||||
files=files,
|
||||
price_extractor=price_extractor,
|
||||
)
|
||||
else:
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/openai/images/generations", method="POST"),
|
||||
response_model=OpenAIImageGenerationResponse,
|
||||
data=OpenAIImageGenerationRequest(
|
||||
model=model_id,
|
||||
prompt=prompt,
|
||||
quality=quality,
|
||||
background=background,
|
||||
n=n,
|
||||
size=size,
|
||||
moderation="low",
|
||||
),
|
||||
price_extractor=price_extractor,
|
||||
)
|
||||
return IO.NodeOutput(await validate_and_cast_response(response))
|
||||
|
||||
|
||||
class OpenAIChatNode(IO.ComfyNode):
|
||||
"""
|
||||
Node to generate text responses from an OpenAI model.
|
||||
@ -999,6 +1311,7 @@ class OpenAIExtension(ComfyExtension):
|
||||
OpenAIDalle2,
|
||||
OpenAIDalle3,
|
||||
OpenAIGPTImage1,
|
||||
OpenAIGPTImageNodeV2,
|
||||
OpenAIChatNode,
|
||||
OpenAIInputFiles,
|
||||
OpenAIChatConfig,
|
||||
|
||||
66
execution.py
66
execution.py
@ -83,7 +83,7 @@ class IsChangedCache:
|
||||
return self.is_changed[node_id]
|
||||
|
||||
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
||||
input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None)
|
||||
input_data_all, _, v3_data, _ = get_input_data(node["inputs"], class_def, node_id, None)
|
||||
try:
|
||||
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name, v3_data=v3_data)
|
||||
is_changed = await resolve_map_node_over_list_results(is_changed)
|
||||
@ -215,7 +215,35 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
|
||||
if h[x] == "API_KEY_COMFY_ORG":
|
||||
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
|
||||
v3_data["hidden_inputs"] = hidden_inputs_v3
|
||||
return input_data_all, missing_keys, v3_data
|
||||
return input_data_all, missing_keys, v3_data, valid_inputs
|
||||
|
||||
def validate_resolved_inputs(input_data_all, class_def, valid_inputs):
|
||||
"""Validate resolved input values against schema constraints.
|
||||
|
||||
This is needed because validate_inputs() only sees direct widget values.
|
||||
Linked inputs aren't resolved during validate_inputs(), so this runs after resolution to catch any violations.
|
||||
"""
|
||||
for x, values in input_data_all.items():
|
||||
input_type, input_category, extra_info = get_input_info(class_def, x, valid_inputs)
|
||||
if input_type != "STRING":
|
||||
continue
|
||||
min_length = extra_info.get("minLength")
|
||||
max_length = extra_info.get("maxLength")
|
||||
if min_length is None and max_length is None:
|
||||
continue
|
||||
for val in values:
|
||||
if val is None or not isinstance(val, str):
|
||||
continue
|
||||
if min_length is not None and len(val) < min_length:
|
||||
raise ValueError(
|
||||
f"Input '{x}': value length {len(val)} is shorter than "
|
||||
f"minimum length of {min_length}"
|
||||
)
|
||||
if max_length is not None and len(val) > max_length:
|
||||
raise ValueError(
|
||||
f"Input '{x}': value length {len(val)} is longer than "
|
||||
f"maximum length of {max_length}"
|
||||
)
|
||||
|
||||
map_node_over_list = None #Don't hook this please
|
||||
|
||||
@ -480,7 +508,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
has_subgraph = False
|
||||
else:
|
||||
get_progress_state().start_progress(unique_id)
|
||||
input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
|
||||
input_data_all, missing_keys, v3_data, valid_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
server.last_node_id = display_node_id
|
||||
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
||||
@ -509,6 +537,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
execution_list.make_input_strong_link(unique_id, i)
|
||||
return (ExecutionResult.PENDING, None, None)
|
||||
|
||||
validate_resolved_inputs(input_data_all, class_def, valid_inputs)
|
||||
|
||||
def execution_block_cb(block):
|
||||
if block.message is not None:
|
||||
mes = {
|
||||
@ -1014,6 +1044,34 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
if input_type == "STRING":
|
||||
if "minLength" in extra_info and len(val) < extra_info["minLength"]:
|
||||
error = {
|
||||
"type": "value_shorter_than_min_length",
|
||||
"message": "Value length {} shorter than min length of {}".format(len(val), extra_info["minLength"]),
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val,
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
if "maxLength" in extra_info and len(val) > extra_info["maxLength"]:
|
||||
error = {
|
||||
"type": "value_longer_than_max_length",
|
||||
"message": "Value length {} longer than max length of {}".format(len(val), extra_info["maxLength"]),
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val,
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
if isinstance(input_type, list) or input_type == io.Combo.io_type:
|
||||
if input_type == io.Combo.io_type:
|
||||
combo_options = extra_info.get("options", [])
|
||||
@ -1050,7 +1108,7 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):
|
||||
continue
|
||||
|
||||
if len(validate_function_inputs) > 0 or validate_has_kwargs:
|
||||
input_data_all, _, v3_data = get_input_data(inputs, obj_class, unique_id)
|
||||
input_data_all, _, v3_data, _ = get_input_data(inputs, obj_class, unique_id)
|
||||
input_filtered = {}
|
||||
for x in input_data_all:
|
||||
if x in validate_function_inputs or validate_has_kwargs:
|
||||
|
||||
@ -1011,3 +1011,49 @@ class TestExecution:
|
||||
"""Test getting a non-existent job returns 404"""
|
||||
job = client.get_job("nonexistent-job-id")
|
||||
assert job is None, "Non-existent job should return None"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("text, expect_error", [
|
||||
("hello", False), # 5 chars, within [3, 10]
|
||||
("abc", False), # 3 chars, exact min boundary
|
||||
("abcdefghij", False), # 10 chars, exact max boundary
|
||||
("ab", True), # 2 chars, below min
|
||||
("abcdefghijk", True), # 11 chars, above max
|
||||
("", True), # 0 chars, below min
|
||||
])
|
||||
def test_string_length_widget_validation(self, text, expect_error, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test minLength/maxLength validation for direct widget values (validate_inputs path)."""
|
||||
g = builder
|
||||
node = g.node("StubStringWithLength", text=text)
|
||||
g.node("SaveImage", images=node.out(0))
|
||||
if expect_error:
|
||||
with pytest.raises(urllib.error.HTTPError) as exc_info:
|
||||
client.run(g)
|
||||
assert exc_info.value.code == 400
|
||||
else:
|
||||
client.run(g)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("text, expect_error", [
|
||||
("hello", False), # 5 chars, within [3, 10]
|
||||
("abc", False), # 3 chars, exact min boundary
|
||||
("abcdefghij", False), # 10 chars, exact max boundary
|
||||
("ab", True), # 2 chars, below min
|
||||
("abcdefghijk", True), # 11 chars, above max
|
||||
("", True), # 0 chars, below min
|
||||
])
|
||||
def test_string_length_linked_validation(self, text, expect_error, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test minLength/maxLength validation for linked inputs (validate_resolved_inputs path)."""
|
||||
g = builder
|
||||
str_node = g.node("StubStringOutput", value=text)
|
||||
node = g.node("StubStringWithLength", text=str_node.out(0))
|
||||
g.node("SaveImage", images=node.out(0))
|
||||
|
||||
if expect_error:
|
||||
try:
|
||||
client.run(g)
|
||||
assert False, "Should have raised an error"
|
||||
except Exception as e:
|
||||
assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}"
|
||||
else:
|
||||
client.run(g)
|
||||
|
||||
@ -113,12 +113,48 @@ class StubFloat:
|
||||
def stub_float(self, value):
|
||||
return (value,)
|
||||
|
||||
class StubStringOutput:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": ("STRING", {"default": ""}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING",)
|
||||
FUNCTION = "stub_string"
|
||||
|
||||
CATEGORY = "Testing/Stub Nodes"
|
||||
|
||||
def stub_string(self, value):
|
||||
return (value,)
|
||||
|
||||
class StubStringWithLength:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"text": ("STRING", {"default": "hello", "minLength": 3, "maxLength": 10}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "stub_string_with_length"
|
||||
|
||||
CATEGORY = "Testing/Stub Nodes"
|
||||
|
||||
def stub_string_with_length(self, text):
|
||||
return (torch.zeros(1, 64, 64, 3),)
|
||||
|
||||
TEST_STUB_NODE_CLASS_MAPPINGS = {
|
||||
"StubImage": StubImage,
|
||||
"StubConstantImage": StubConstantImage,
|
||||
"StubMask": StubMask,
|
||||
"StubInt": StubInt,
|
||||
"StubFloat": StubFloat,
|
||||
"StubStringOutput": StubStringOutput,
|
||||
"StubStringWithLength": StubStringWithLength,
|
||||
}
|
||||
TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"StubImage": "Stub Image",
|
||||
@ -126,4 +162,6 @@ TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"StubMask": "Stub Mask",
|
||||
"StubInt": "Stub Int",
|
||||
"StubFloat": "Stub Float",
|
||||
"StubStringOutput": "Stub String Output",
|
||||
"StubStringWithLength": "Stub String With Length",
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user