mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-29 18:37:23 +08:00
Compare commits
23 Commits
1f0611a393
...
19e26e8f8d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
19e26e8f8d | ||
|
|
c168960a12 | ||
|
|
e5369c0eec | ||
|
|
1655f8089a | ||
|
|
89014792c9 | ||
|
|
431fadb520 | ||
|
|
1ac60da2c9 | ||
|
|
41d73ad180 | ||
|
|
ea6880b04b | ||
|
|
639f631a08 | ||
|
|
d794b62939 | ||
|
|
6917bce128 | ||
|
|
c55ff85243 | ||
|
|
de97192962 | ||
|
|
d56a093800 | ||
|
|
8dd41ef82e | ||
|
|
b715186140 | ||
|
|
ae54d7a987 | ||
|
|
f3aebfa2b0 | ||
|
|
b3a066559b | ||
|
|
0b7d56070d | ||
|
|
c3cd2a4e75 | ||
|
|
800bf842a5 |
31
.github/workflows/openapi-lint.yml
vendored
Normal file
31
.github/workflows/openapi-lint.yml
vendored
Normal file
@ -0,0 +1,31 @@
|
||||
name: OpenAPI Lint
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- 'openapi.yaml'
|
||||
- '.spectral.yaml'
|
||||
- '.github/workflows/openapi-lint.yml'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
spectral:
|
||||
name: Run Spectral
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20'
|
||||
|
||||
- name: Install Spectral
|
||||
run: npm install -g @stoplight/spectral-cli@6
|
||||
|
||||
- name: Lint openapi.yaml
|
||||
run: spectral lint openapi.yaml --ruleset .spectral.yaml --fail-severity=error
|
||||
91
.spectral.yaml
Normal file
91
.spectral.yaml
Normal file
@ -0,0 +1,91 @@
|
||||
extends:
|
||||
- spectral:oas
|
||||
|
||||
# Severity levels: error, warn, info, hint, off
|
||||
# Rules from the built-in "spectral:oas" ruleset are active by default.
|
||||
# Below we tune severity and add custom rules for our conventions.
|
||||
#
|
||||
# This ruleset mirrors Comfy-Org/cloud/.spectral.yaml so specs across the
|
||||
# organization are linted against a single consistent standard.
|
||||
|
||||
rules:
|
||||
# -----------------------------------------------------------------------
|
||||
# Built-in rule severity overrides
|
||||
# -----------------------------------------------------------------------
|
||||
operation-operationId: error
|
||||
operation-description: warn
|
||||
operation-tag-defined: error
|
||||
info-contact: off
|
||||
info-description: warn
|
||||
no-eval-in-markdown: error
|
||||
no-$ref-siblings: error
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Custom rules: naming conventions
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
# Property names should be snake_case
|
||||
property-name-snake-case:
|
||||
description: Property names must be snake_case
|
||||
severity: warn
|
||||
given: "$.components.schemas.*.properties[*]~"
|
||||
then:
|
||||
function: pattern
|
||||
functionOptions:
|
||||
match: "^[a-z][a-z0-9]*(_[a-z0-9]+)*$"
|
||||
|
||||
# Operation IDs should be camelCase
|
||||
operation-id-camel-case:
|
||||
description: Operation IDs must be camelCase
|
||||
severity: warn
|
||||
given: "$.paths.*.*.operationId"
|
||||
then:
|
||||
function: pattern
|
||||
functionOptions:
|
||||
match: "^[a-z][a-zA-Z0-9]*$"
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Custom rules: response conventions
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
# Error responses (4xx, 5xx) should use a consistent shape
|
||||
error-response-schema:
|
||||
description: Error responses should reference a standard error schema
|
||||
severity: hint
|
||||
given: "$.paths.*.*.responses[?(@property >= '400' && @property < '600')].content['application/json'].schema"
|
||||
then:
|
||||
field: "$ref"
|
||||
function: truthy
|
||||
|
||||
# All 2xx responses with JSON body should have a schema
|
||||
response-schema-defined:
|
||||
description: Success responses with JSON content should define a schema
|
||||
severity: warn
|
||||
given: "$.paths.*.*.responses[?(@property >= '200' && @property < '300')].content['application/json']"
|
||||
then:
|
||||
field: schema
|
||||
function: truthy
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Custom rules: best practices
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
# Path parameters must have a description
|
||||
path-param-description:
|
||||
description: Path parameters should have a description
|
||||
severity: warn
|
||||
given:
|
||||
- "$.paths.*.parameters[?(@.in == 'path')]"
|
||||
- "$.paths.*.*.parameters[?(@.in == 'path')]"
|
||||
then:
|
||||
field: description
|
||||
function: truthy
|
||||
|
||||
# Schemas should have a description
|
||||
schema-description:
|
||||
description: Component schemas should have a description
|
||||
severity: hint
|
||||
given: "$.components.schemas.*"
|
||||
then:
|
||||
field: description
|
||||
function: truthy
|
||||
@ -63,7 +63,11 @@ class IndexListContextWindow(ContextWindowABC):
|
||||
dim = self.dim
|
||||
if dim == 0 and full.shape[dim] == 1:
|
||||
return full
|
||||
idx = tuple([slice(None)] * dim + [self.index_list])
|
||||
indices = self.index_list
|
||||
anchor_idx = getattr(self, 'causal_anchor_index', None)
|
||||
if anchor_idx is not None and anchor_idx >= 0:
|
||||
indices = [anchor_idx] + list(indices)
|
||||
idx = tuple([slice(None)] * dim + [indices])
|
||||
window = full[idx]
|
||||
if retain_index_list:
|
||||
idx = tuple([slice(None)] * dim + [retain_index_list])
|
||||
@ -113,7 +117,14 @@ def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, d
|
||||
|
||||
# skip leading latent positions that have no corresponding conditioning (e.g. reference frames)
|
||||
if temporal_offset > 0:
|
||||
indices = [i - temporal_offset for i in window.index_list[temporal_offset:]]
|
||||
anchor_idx = getattr(window, 'causal_anchor_index', None)
|
||||
if anchor_idx is not None and anchor_idx >= 0:
|
||||
# anchor occupies one of the no-cond positions, so skip one fewer from window.index_list
|
||||
skip_count = temporal_offset - 1
|
||||
else:
|
||||
skip_count = temporal_offset
|
||||
|
||||
indices = [i - temporal_offset for i in window.index_list[skip_count:]]
|
||||
indices = [i for i in indices if 0 <= i]
|
||||
else:
|
||||
indices = list(window.index_list)
|
||||
@ -150,7 +161,8 @@ class ContextFuseMethod:
|
||||
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
|
||||
class IndexListContextHandler(ContextHandlerABC):
|
||||
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
|
||||
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False):
|
||||
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False,
|
||||
causal_window_fix: bool=True):
|
||||
self.context_schedule = context_schedule
|
||||
self.fuse_method = fuse_method
|
||||
self.context_length = context_length
|
||||
@ -162,6 +174,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
self.freenoise = freenoise
|
||||
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
|
||||
self.split_conds_to_windows = split_conds_to_windows
|
||||
self.causal_window_fix = causal_window_fix
|
||||
|
||||
self.callbacks = {}
|
||||
|
||||
@ -318,6 +331,14 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
# allow processing to end between context window executions for faster Cancel
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
|
||||
# causal_window_fix: prepend a pre-window frame that will be stripped post-forward
|
||||
anchor_applied = False
|
||||
if self.causal_window_fix:
|
||||
anchor_idx = window.index_list[0] - 1
|
||||
if 0 <= anchor_idx < x_in.size(self.dim):
|
||||
window.causal_anchor_index = anchor_idx
|
||||
anchor_applied = True
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
|
||||
|
||||
@ -332,6 +353,12 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
if device is not None:
|
||||
for i in range(len(sub_conds_out)):
|
||||
sub_conds_out[i] = sub_conds_out[i].to(x_in.device)
|
||||
|
||||
# strip causal_window_fix anchor if applied
|
||||
if anchor_applied:
|
||||
for i in range(len(sub_conds_out)):
|
||||
sub_conds_out[i] = sub_conds_out[i].narrow(self.dim, 1, sub_conds_out[i].shape[self.dim] - 1)
|
||||
|
||||
results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
|
||||
return results
|
||||
|
||||
|
||||
@ -9,6 +9,7 @@ class LatentFormat:
|
||||
latent_rgb_factors_reshape = None
|
||||
taesd_decoder_name = None
|
||||
spacial_downscale_ratio = 8
|
||||
temporal_downscale_ratio = 1
|
||||
|
||||
def process_in(self, latent):
|
||||
return latent * self.scale_factor
|
||||
@ -235,6 +236,7 @@ class Flux2(LatentFormat):
|
||||
class Mochi(LatentFormat):
|
||||
latent_channels = 12
|
||||
latent_dimensions = 3
|
||||
temporal_downscale_ratio = 6
|
||||
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0
|
||||
@ -278,6 +280,7 @@ class LTXV(LatentFormat):
|
||||
latent_channels = 128
|
||||
latent_dimensions = 3
|
||||
spacial_downscale_ratio = 32
|
||||
temporal_downscale_ratio = 8
|
||||
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors = [
|
||||
@ -421,6 +424,7 @@ class LTXAV(LTXV):
|
||||
class HunyuanVideo(LatentFormat):
|
||||
latent_channels = 16
|
||||
latent_dimensions = 3
|
||||
temporal_downscale_ratio = 4
|
||||
scale_factor = 0.476986
|
||||
latent_rgb_factors = [
|
||||
[-0.0395, -0.0331, 0.0445],
|
||||
@ -447,6 +451,7 @@ class HunyuanVideo(LatentFormat):
|
||||
class Cosmos1CV8x8x8(LatentFormat):
|
||||
latent_channels = 16
|
||||
latent_dimensions = 3
|
||||
temporal_downscale_ratio = 8
|
||||
|
||||
latent_rgb_factors = [
|
||||
[ 0.1817, 0.2284, 0.2423],
|
||||
@ -472,6 +477,7 @@ class Cosmos1CV8x8x8(LatentFormat):
|
||||
class Wan21(LatentFormat):
|
||||
latent_channels = 16
|
||||
latent_dimensions = 3
|
||||
temporal_downscale_ratio = 4
|
||||
|
||||
latent_rgb_factors = [
|
||||
[-0.1299, -0.1692, 0.2932],
|
||||
@ -734,6 +740,7 @@ class HunyuanVideo15(LatentFormat):
|
||||
latent_channels = 32
|
||||
latent_dimensions = 3
|
||||
spacial_downscale_ratio = 16
|
||||
temporal_downscale_ratio = 4
|
||||
scale_factor = 1.03682
|
||||
taesd_decoder_name = "lighttaehy1_5"
|
||||
|
||||
@ -788,6 +795,7 @@ class ZImagePixelSpace(ChromaRadiance):
|
||||
class CogVideoX(LatentFormat):
|
||||
latent_channels = 16
|
||||
latent_dimensions = 3
|
||||
temporal_downscale_ratio = 4
|
||||
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.15258426
|
||||
|
||||
@ -395,7 +395,6 @@ class Combo(ComfyTypeIO):
|
||||
@comfytype(io_type="COMBO")
|
||||
class MultiCombo(ComfyTypeI):
|
||||
'''Multiselect Combo input (dropdown for selecting potentially more than one value).'''
|
||||
# TODO: something is wrong with the serialization, frontend does not recognize it as multiselect
|
||||
Type = list[str]
|
||||
class Input(Combo.Input):
|
||||
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
@ -408,12 +407,14 @@ class MultiCombo(ComfyTypeI):
|
||||
self.default: list[str]
|
||||
|
||||
def as_dict(self):
|
||||
to_return = super().as_dict() | prune_dict({
|
||||
"multi_select": self.multiselect,
|
||||
"placeholder": self.placeholder,
|
||||
"chip": self.chip,
|
||||
# Frontend expects `multi_select` to be an object config (not a boolean).
|
||||
# Keep top-level `multiselect` from Combo.Input for backwards compatibility.
|
||||
return super().as_dict() | prune_dict({
|
||||
"multi_select": prune_dict({
|
||||
"placeholder": self.placeholder,
|
||||
"chip": self.chip,
|
||||
}),
|
||||
})
|
||||
return to_return
|
||||
|
||||
@comfytype(io_type="IMAGE")
|
||||
class Image(ComfyTypeIO):
|
||||
|
||||
@ -1,15 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field, confloat
|
||||
|
||||
|
||||
|
||||
class LumaIO:
|
||||
LUMA_REF = "LUMA_REF"
|
||||
LUMA_CONCEPTS = "LUMA_CONCEPTS"
|
||||
@ -183,13 +180,13 @@ class LumaAssets(BaseModel):
|
||||
|
||||
|
||||
class LumaImageRef(BaseModel):
|
||||
'''Used for image gen'''
|
||||
"""Used for image gen"""
|
||||
url: str = Field(..., description='The URL of the image reference')
|
||||
weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference')
|
||||
|
||||
|
||||
class LumaImageReference(BaseModel):
|
||||
'''Used for video gen'''
|
||||
"""Used for video gen"""
|
||||
type: Optional[str] = Field('image', description='Input type, defaults to image')
|
||||
url: str = Field(..., description='The URL of the image')
|
||||
|
||||
@ -251,3 +248,32 @@ class LumaGeneration(BaseModel):
|
||||
assets: Optional[LumaAssets] = Field(None, description='The assets of the generation')
|
||||
model: str = Field(..., description='The model used for the generation')
|
||||
request: Union[LumaGenerationRequest, LumaImageGenerationRequest] = Field(..., description="The request used for the generation")
|
||||
|
||||
|
||||
class Luma2ImageRef(BaseModel):
|
||||
url: str | None = None
|
||||
data: str | None = None
|
||||
media_type: str | None = None
|
||||
|
||||
|
||||
class Luma2GenerationRequest(BaseModel):
|
||||
prompt: str = Field(..., min_length=1, max_length=6000)
|
||||
model: str | None = None
|
||||
type: str | None = None
|
||||
aspect_ratio: str | None = None
|
||||
style: str | None = None
|
||||
output_format: str | None = None
|
||||
web_search: bool | None = None
|
||||
image_ref: list[Luma2ImageRef] | None = None
|
||||
source: Luma2ImageRef | None = None
|
||||
|
||||
|
||||
class Luma2Generation(BaseModel):
|
||||
id: str | None = None
|
||||
type: str | None = None
|
||||
state: str | None = None
|
||||
model: str | None = None
|
||||
created_at: str | None = None
|
||||
output: list[LumaImageReference] | None = None
|
||||
failure_reason: str | None = None
|
||||
failure_code: str | None = None
|
||||
|
||||
@ -56,14 +56,14 @@ class ModelResponseProperties(BaseModel):
|
||||
instructions: str | None = Field(None)
|
||||
max_output_tokens: int | None = Field(None)
|
||||
model: str | None = Field(None)
|
||||
temperature: float | None = Field(1, description="Controls randomness in the response", ge=0.0, le=2.0)
|
||||
temperature: float | None = Field(None, description="Controls randomness in the response", ge=0.0, le=2.0)
|
||||
top_p: float | None = Field(
|
||||
1,
|
||||
None,
|
||||
description="Controls diversity of the response via nucleus sampling",
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
truncation: str | None = Field("disabled", description="Allowed values: 'auto' or 'disabled'")
|
||||
truncation: str | None = Field(None, description="Allowed values: 'auto' or 'disabled'")
|
||||
|
||||
|
||||
class ResponseProperties(BaseModel):
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.luma import (
|
||||
Luma2Generation,
|
||||
Luma2GenerationRequest,
|
||||
Luma2ImageRef,
|
||||
LumaAspectRatio,
|
||||
LumaCharacterRef,
|
||||
LumaConceptChain,
|
||||
@ -30,6 +31,7 @@ from comfy_api_nodes.util import (
|
||||
download_url_to_video_output,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_image_to_comfyapi,
|
||||
upload_images_to_comfyapi,
|
||||
validate_string,
|
||||
)
|
||||
@ -212,9 +214,9 @@ class LumaImageGenerationNode(IO.ComfyNode):
|
||||
aspect_ratio: str,
|
||||
seed,
|
||||
style_image_weight: float,
|
||||
image_luma_ref: Optional[LumaReferenceChain] = None,
|
||||
style_image: Optional[torch.Tensor] = None,
|
||||
character_image: Optional[torch.Tensor] = None,
|
||||
image_luma_ref: LumaReferenceChain | None = None,
|
||||
style_image: torch.Tensor | None = None,
|
||||
character_image: torch.Tensor | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=3)
|
||||
# handle image_luma_ref
|
||||
@ -434,7 +436,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
|
||||
duration: str,
|
||||
loop: bool,
|
||||
seed,
|
||||
luma_concepts: Optional[LumaConceptChain] = None,
|
||||
luma_concepts: LumaConceptChain | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False, min_length=3)
|
||||
duration = duration if model != LumaVideoModel.ray_1_6 else None
|
||||
@ -533,7 +535,6 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=PRICE_BADGE_VIDEO,
|
||||
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -644,6 +645,293 @@ PRICE_BADGE_VIDEO = IO.PriceBadge(
|
||||
)
|
||||
|
||||
|
||||
def _luma2_uni1_common_inputs(max_image_refs: int) -> list:
|
||||
return [
|
||||
IO.Combo.Input(
|
||||
"style",
|
||||
options=["auto", "manga"],
|
||||
default="auto",
|
||||
tooltip="Style preset. 'auto' picks based on the prompt; "
|
||||
"'manga' applies a manga/anime aesthetic and requires a portrait "
|
||||
"aspect ratio (2:3, 9:16, 1:2, 1:3).",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"web_search",
|
||||
default=False,
|
||||
tooltip="Search the web for visual references before generating.",
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"image_ref",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Image.Input("image"),
|
||||
names=[f"image_{i}" for i in range(1, max_image_refs + 1)],
|
||||
min=0,
|
||||
),
|
||||
optional=True,
|
||||
tooltip=f"Up to {max_image_refs} reference images for style/content guidance.",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
async def _luma2_upload_image_refs(
|
||||
cls: type[IO.ComfyNode],
|
||||
refs: dict | None,
|
||||
max_count: int,
|
||||
) -> list[Luma2ImageRef] | None:
|
||||
if not refs:
|
||||
return None
|
||||
out: list[Luma2ImageRef] = []
|
||||
for key in refs:
|
||||
url = await upload_image_to_comfyapi(cls, refs[key])
|
||||
out.append(Luma2ImageRef(url=url))
|
||||
if len(out) > max_count:
|
||||
raise ValueError(f"Maximum {max_count} reference images are allowed.")
|
||||
return out or None
|
||||
|
||||
|
||||
async def _luma2_submit_and_poll(
|
||||
cls: type[IO.ComfyNode],
|
||||
request: Luma2GenerationRequest,
|
||||
) -> Input.Image:
|
||||
initial = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/luma_2/generations", method="POST"),
|
||||
response_model=Luma2Generation,
|
||||
data=request,
|
||||
)
|
||||
if not initial.id:
|
||||
raise RuntimeError("Luma 2 API did not return a generation id.")
|
||||
final = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/luma_2/generations/{initial.id}", method="GET"),
|
||||
response_model=Luma2Generation,
|
||||
status_extractor=lambda r: r.state,
|
||||
progress_extractor=lambda r: None,
|
||||
)
|
||||
if not final.output:
|
||||
msg = final.failure_reason or "no output returned"
|
||||
raise RuntimeError(f"Luma 2 generation failed: {msg}")
|
||||
url = final.output[0].url
|
||||
if not url:
|
||||
raise RuntimeError("Luma 2 generation completed without an output URL.")
|
||||
return await download_url_to_image_tensor(url)
|
||||
|
||||
|
||||
class LumaImageNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="LumaImageNode2",
|
||||
display_name="Luma UNI-1 Image",
|
||||
category="api node/image/Luma",
|
||||
description="Generate images from text using the Luma UNI-1 model.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text description of the desired image. 1–6000 characters.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"uni-1",
|
||||
[
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=[
|
||||
"auto",
|
||||
"3:1",
|
||||
"2:1",
|
||||
"16:9",
|
||||
"3:2",
|
||||
"1:1",
|
||||
"2:3",
|
||||
"9:16",
|
||||
"1:2",
|
||||
"1:3",
|
||||
],
|
||||
default="auto",
|
||||
tooltip="Output image aspect ratio. 'auto' lets "
|
||||
"the model pick based on the prompt.",
|
||||
),
|
||||
*_luma2_uni1_common_inputs(max_image_refs=9),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"uni-1-max",
|
||||
[
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=[
|
||||
"auto",
|
||||
"3:1",
|
||||
"2:1",
|
||||
"16:9",
|
||||
"3:2",
|
||||
"1:1",
|
||||
"2:3",
|
||||
"9:16",
|
||||
"1:2",
|
||||
"1:3",
|
||||
],
|
||||
default="auto",
|
||||
tooltip="Output image aspect ratio. 'auto' lets "
|
||||
"the model pick based on the prompt.",
|
||||
),
|
||||
*_luma2_uni1_common_inputs(max_image_refs=9),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="Model to use for generation.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model"], input_groups=["model.image_ref"]),
|
||||
expr="""
|
||||
(
|
||||
$m := widgets.model;
|
||||
$refs := $lookup(inputGroups, "model.image_ref");
|
||||
$base := $m = "uni-1-max" ? 0.1 : 0.0404;
|
||||
{"type":"usd","usd": $round($base + 0.003 * $refs, 4)}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=6000)
|
||||
aspect_ratio = model["aspect_ratio"]
|
||||
style = model["style"]
|
||||
allowed_manga_ratios = {"2:3", "9:16", "1:2", "1:3"}
|
||||
if style == "manga" and aspect_ratio != "auto" and aspect_ratio not in allowed_manga_ratios:
|
||||
raise ValueError(
|
||||
f"'manga' style requires a portrait aspect ratio "
|
||||
f"({', '.join(sorted(allowed_manga_ratios))}) or 'auto'; got '{aspect_ratio}'."
|
||||
)
|
||||
request = Luma2GenerationRequest(
|
||||
prompt=prompt,
|
||||
model=model["model"],
|
||||
type="image",
|
||||
aspect_ratio=aspect_ratio if aspect_ratio != "auto" else None,
|
||||
style=style if style != "auto" else None,
|
||||
output_format="png",
|
||||
web_search=model["web_search"],
|
||||
image_ref=await _luma2_upload_image_refs(cls, model.get("image_ref"), max_count=9),
|
||||
)
|
||||
return IO.NodeOutput(await _luma2_submit_and_poll(cls, request))
|
||||
|
||||
|
||||
class LumaImageEditNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="LumaImageEditNode2",
|
||||
display_name="Luma UNI-1 Image Edit",
|
||||
category="api node/image/Luma",
|
||||
description="Edit an existing image with a text prompt using the Luma UNI-1 model.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
"source",
|
||||
tooltip="Source image to edit.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Description of the desired edit. 1–6000 characters.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"uni-1",
|
||||
_luma2_uni1_common_inputs(max_image_refs=8),
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"uni-1-max",
|
||||
_luma2_uni1_common_inputs(max_image_refs=8),
|
||||
),
|
||||
],
|
||||
tooltip="Model to use for editing.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model"], input_groups=["model.image_ref"]),
|
||||
expr="""
|
||||
(
|
||||
$m := widgets.model;
|
||||
$refs := $lookup(inputGroups, "model.image_ref");
|
||||
$base := $m = "uni-1-max" ? 0.103 : 0.0434;
|
||||
{"type":"usd","usd": $round($base + 0.003 * $refs, 4)}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
source: Input.Image,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=6000)
|
||||
request = Luma2GenerationRequest(
|
||||
prompt=prompt,
|
||||
model=model["model"],
|
||||
type="image_edit",
|
||||
source=Luma2ImageRef(url=await upload_image_to_comfyapi(cls, source)),
|
||||
style=model["style"] if model["style"] != "auto" else None,
|
||||
output_format="png",
|
||||
web_search=model["web_search"],
|
||||
image_ref=await _luma2_upload_image_refs(cls, model.get("image_ref"), max_count=8),
|
||||
)
|
||||
return IO.NodeOutput(await _luma2_submit_and_poll(cls, request))
|
||||
|
||||
|
||||
class LumaExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -654,6 +942,8 @@ class LumaExtension(ComfyExtension):
|
||||
LumaImageToVideoGenerationNode,
|
||||
LumaReferenceNode,
|
||||
LumaConceptsNode,
|
||||
LumaImageNode,
|
||||
LumaImageEditNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -39,16 +39,18 @@ STARTING_POINT_ID_PATTERN = r"<starting_point_id:(.*)>"
|
||||
|
||||
|
||||
class SupportedOpenAIModel(str, Enum):
|
||||
o4_mini = "o4-mini"
|
||||
o1 = "o1"
|
||||
o3 = "o3"
|
||||
o1_pro = "o1-pro"
|
||||
gpt_4_1 = "gpt-4.1"
|
||||
gpt_4_1_mini = "gpt-4.1-mini"
|
||||
gpt_4_1_nano = "gpt-4.1-nano"
|
||||
gpt_5_5_pro = "gpt-5.5-pro"
|
||||
gpt_5_5 = "gpt-5.5"
|
||||
gpt_5 = "gpt-5"
|
||||
gpt_5_mini = "gpt-5-mini"
|
||||
gpt_5_nano = "gpt-5-nano"
|
||||
gpt_4_1 = "gpt-4.1"
|
||||
gpt_4_1_mini = "gpt-4.1-mini"
|
||||
gpt_4_1_nano = "gpt-4.1-nano"
|
||||
o4_mini = "o4-mini"
|
||||
o3 = "o3"
|
||||
o1_pro = "o1-pro"
|
||||
o1 = "o1"
|
||||
|
||||
|
||||
async def validate_and_cast_response(response, timeout: int = None) -> torch.Tensor:
|
||||
@ -739,6 +741,16 @@ class OpenAIChatNode(IO.ComfyNode):
|
||||
"usd": [0.002, 0.008],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
}
|
||||
: $contains($m, "gpt-5.5-pro") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.03, 0.18],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
}
|
||||
: $contains($m, "gpt-5.5") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.005, 0.03],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
}
|
||||
: $contains($m, "gpt-5-nano") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.00005, 0.0004],
|
||||
|
||||
@ -29,6 +29,7 @@ class ContextWindowsManualNode(io.ComfyNode):
|
||||
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
||||
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
||||
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
||||
io.Boolean.Input("causal_window_fix", default=True, tooltip="Whether to add a causal fix frame to non-0-indexed context windows."),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(tooltip="The model with context windows applied during sampling."),
|
||||
@ -38,7 +39,7 @@ class ContextWindowsManualNode(io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
|
||||
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
|
||||
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, causal_window_fix: bool=True) -> io.Model:
|
||||
model = model.clone()
|
||||
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
|
||||
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
|
||||
@ -50,7 +51,8 @@ class ContextWindowsManualNode(io.ComfyNode):
|
||||
dim=dim,
|
||||
freenoise=freenoise,
|
||||
cond_retain_index_list=cond_retain_index_list,
|
||||
split_conds_to_windows=split_conds_to_windows
|
||||
split_conds_to_windows=split_conds_to_windows,
|
||||
causal_window_fix=causal_window_fix,
|
||||
)
|
||||
# make memory usage calculation only take into account the context window latents
|
||||
comfy.context_windows.create_prepare_sampling_wrapper(model)
|
||||
|
||||
@ -11,6 +11,142 @@ from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
import re
|
||||
|
||||
def video_latent_composite(destination, source, x, y, mask=None, multiplier=8, resize_source=False):
|
||||
# destination/source shape: [B, C, F, H, W]
|
||||
source = source.to(destination.device)
|
||||
|
||||
if resize_source:
|
||||
target_size = (source.shape[2], destination.shape[3], destination.shape[4])
|
||||
source = torch.nn.functional.interpolate(
|
||||
source,
|
||||
size=target_size,
|
||||
mode="trilinear",
|
||||
align_corners=False
|
||||
)
|
||||
|
||||
x_latent = x // multiplier
|
||||
y_latent = y // multiplier
|
||||
|
||||
if mask is None:
|
||||
mask = torch.ones_like(source)
|
||||
else:
|
||||
mask = mask.to(destination.device, copy=True)
|
||||
mask = mask.unsqueeze(0).unsqueeze(0)
|
||||
mask_target_size = (mask.shape[2], source.shape[3], source.shape[4])
|
||||
mask = torch.nn.functional.interpolate(
|
||||
mask,
|
||||
size=mask_target_size,
|
||||
mode="trilinear",
|
||||
align_corners=False
|
||||
)
|
||||
|
||||
dst_h, dst_w = destination.shape[3], destination.shape[4]
|
||||
src_h, src_w = source.shape[3], source.shape[4]
|
||||
|
||||
visible_h = max(0, min(y_latent + src_h, dst_h) - max(0, y_latent))
|
||||
visible_w = max(0, min(x_latent + src_w, dst_w) - max(0, x_latent))
|
||||
|
||||
if visible_h <= 0 or visible_w <= 0:
|
||||
return destination
|
||||
|
||||
src_top = max(0, -y_latent)
|
||||
src_left = max(0, -x_latent)
|
||||
dst_top = max(0, y_latent)
|
||||
dst_left = max(0, x_latent)
|
||||
|
||||
m = mask[:, :, :, src_top:src_top+visible_h, src_left:src_left+visible_w]
|
||||
s = source[:, :, :, src_top:src_top+visible_h, src_left:src_left+visible_w]
|
||||
d = destination[:, :, :, dst_top:dst_top+visible_h, dst_left:dst_left+visible_w]
|
||||
|
||||
destination[:, :, :, dst_top:dst_top+visible_h, dst_left:dst_left+visible_w] = (m * s) + ((1.0 - m) * d)
|
||||
|
||||
return destination
|
||||
|
||||
def time_to_move_sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, latent_mask, denoise=1.0, start_step=None, time_to_move_last_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
|
||||
|
||||
sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
||||
model_sampling = model.get_model_object("model_sampling")
|
||||
process_latent_out = model.get_model_object("process_latent_out")
|
||||
process_latent_in = model.get_model_object("process_latent_in")
|
||||
|
||||
reference_latent_image = latent_image.clone()
|
||||
|
||||
reference_sigmas = sampler.sigmas
|
||||
reference_noise = noise.clone()
|
||||
|
||||
if last_step == None or last_step > steps:
|
||||
last_step = steps
|
||||
|
||||
if time_to_move_last_step == None or time_to_move_last_step > last_step:
|
||||
time_to_move_last_step = last_step
|
||||
|
||||
if start_step == None:
|
||||
start_step = 0
|
||||
|
||||
total_iterations = min(last_step, steps) - start_step
|
||||
if total_iterations <= 0:
|
||||
return latent_image.to(
|
||||
device=comfy.model_management.intermediate_device(),
|
||||
dtype=comfy.model_management.intermediate_dtype(),
|
||||
)
|
||||
|
||||
for i in range(total_iterations):
|
||||
if i > 0:
|
||||
#don't add new noise to samples after first step taken
|
||||
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
|
||||
|
||||
temp_start = start_step + i
|
||||
|
||||
if temp_start < last_step - 1:
|
||||
temp_force_full_denoise = False
|
||||
else:
|
||||
temp_force_full_denoise = force_full_denoise
|
||||
|
||||
samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=temp_start, last_step=temp_start + 1, force_full_denoise=temp_force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
|
||||
if temp_start < time_to_move_last_step:
|
||||
scale = reference_sigmas[temp_start + 1].to(noise.device)
|
||||
|
||||
if torch.count_nonzero(reference_latent_image) > 0: #Don't shift the empty latent image.
|
||||
noisy = model_sampling.noise_scaling(scale, reference_noise, process_latent_in(reference_latent_image))
|
||||
noisy = model_sampling.inverse_noise_scaling(scale, noisy)
|
||||
noisy = process_latent_out(noisy)
|
||||
else:
|
||||
noisy = reference_latent_image
|
||||
|
||||
noisy.to(samples.device)
|
||||
|
||||
samples = video_latent_composite(samples, noisy, 0, 0, latent_mask, multiplier=1, resize_source=True)
|
||||
|
||||
latent_image = samples
|
||||
|
||||
samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||
return samples
|
||||
|
||||
|
||||
def time_to_move_common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, latent_mask, denoise=1.0, disable_noise=False, start_step=None, time_to_move_last_step = None, last_step=None, force_full_denoise=False):
|
||||
latent_image = latent["samples"]
|
||||
latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image, latent.get("downscale_ratio_spacial", None))
|
||||
|
||||
if disable_noise:
|
||||
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
|
||||
else:
|
||||
batch_inds = latent["batch_index"] if "batch_index" in latent else None
|
||||
noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds)
|
||||
|
||||
noise_mask = None
|
||||
if "noise_mask" in latent:
|
||||
noise_mask = latent["noise_mask"]
|
||||
|
||||
callback = latent_preview.prepare_callback(model, steps)
|
||||
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
|
||||
samples = time_to_move_sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, latent_mask,
|
||||
denoise=denoise, start_step=start_step, time_to_move_last_step = time_to_move_last_step, last_step=last_step,
|
||||
force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
out = latent.copy()
|
||||
out.pop("downscale_ratio_spacial", None)
|
||||
out["samples"] = samples
|
||||
return (out, )
|
||||
|
||||
class BasicScheduler(io.ComfyNode):
|
||||
@classmethod
|
||||
@ -978,6 +1114,46 @@ class SamplerCustomAdvanced(io.ComfyNode):
|
||||
return io.NodeOutput(out, out_denoised)
|
||||
|
||||
sample = execute
|
||||
|
||||
class TimeToMoveKSamplerAdvanced(io.ComfyNode):
|
||||
@classmethod
|
||||
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="TimeToMoveKSamplerAdvanced",
|
||||
category="sampling/time_to_move",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Combo.Input("add_noise", options=["enable", "disable"], advanced=True),
|
||||
io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True),
|
||||
io.Int.Input("steps", default=20, min=1, max=10000),
|
||||
io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01),
|
||||
io.Combo.Input("sampler_name", options = comfy.samplers.KSampler.SAMPLERS),
|
||||
io.Combo.Input("scheduler", options = comfy.samplers.KSampler.SCHEDULERS),
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Latent.Input("latent_image"),
|
||||
io.Mask.Input("latent_mask", tooltip = "Make sure mask is the same length as the latents rather than the original video."),
|
||||
io.Int.Input("start_at_step", default = 0, min = 0, max = 10000, advanced = True, tooltip = "Generally should set at a step greater than 0."),
|
||||
io.Int.Input("time_to_move_end_at_step", default = 0, min = 0, max = 10000, advanced = True, tooltip = "Generally should set at a step greater than 0 and less than total number of steps."),
|
||||
io.Int.Input("end_at_step", default = 10000, min = 0, max = 10000, advanced = True, tooltip = "Use just like typical end_at_step with normal KSamplerAdvanced"),
|
||||
io.Combo.Input("return_with_leftover_noise", options=["disable", "enable"], advanced = True),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(display_name="latent"),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, latent_mask, start_at_step, time_to_move_end_at_step, end_at_step, return_with_leftover_noise, denoise=1.0) -> io.NodeOutput:
|
||||
force_full_denoise = True
|
||||
if return_with_leftover_noise == "enable":
|
||||
force_full_denoise = False
|
||||
disable_noise = False
|
||||
if add_noise == "disable":
|
||||
disable_noise = True
|
||||
|
||||
return time_to_move_common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, latent_mask, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, time_to_move_last_step = time_to_move_end_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise)
|
||||
|
||||
class AddNoise(io.ComfyNode):
|
||||
@classmethod
|
||||
@ -1087,6 +1263,7 @@ class CustomSamplersExtension(ComfyExtension):
|
||||
DisableNoise,
|
||||
AddNoise,
|
||||
SamplerCustomAdvanced,
|
||||
TimeToMoveKSamplerAdvanced,
|
||||
ManualSigmas,
|
||||
]
|
||||
|
||||
|
||||
@ -147,7 +147,6 @@ class LTXVEmptyLatentAudio(io.ComfyNode):
|
||||
|
||||
z_channels = audio_vae.latent_channels
|
||||
audio_freq = audio_vae.first_stage_model.latent_frequency_bins
|
||||
sampling_rate = int(audio_vae.first_stage_model.sample_rate)
|
||||
|
||||
num_audio_latents = audio_vae.first_stage_model.num_of_latents_from_frames(frames_number, frame_rate)
|
||||
|
||||
@ -159,7 +158,6 @@ class LTXVEmptyLatentAudio(io.ComfyNode):
|
||||
return io.NodeOutput(
|
||||
{
|
||||
"samples": audio_latents,
|
||||
"sample_rate": sampling_rate,
|
||||
"type": "audio",
|
||||
}
|
||||
)
|
||||
|
||||
@ -46,6 +46,42 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
|
||||
destination[..., top:bottom, left:right] = source_portion + destination_portion
|
||||
return destination
|
||||
|
||||
def convert_rgb_mask_to_latent_mask(
|
||||
mask: torch.Tensor,
|
||||
k: int,
|
||||
spatial_downsample_h: int,
|
||||
spatial_downsample_w: int
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Converts [T, H, W] mask to [T_latent, H_latent, W_latent].
|
||||
Handles non-square spatial downsampling.
|
||||
"""
|
||||
# 1. Temporal Sampling
|
||||
# Select first frame and every k-th frame thereafter
|
||||
mask0 = mask[0:1]
|
||||
mask1 = mask[1::k]
|
||||
sampled = torch.cat([mask0, mask1], dim=0) # [T_latent, H, W]
|
||||
|
||||
# 2. Prepare for Spatial Interpolation
|
||||
# Shape: [Batch=1, Channels=1, Depth=T_latent, Height=H, Width=W]
|
||||
sampled = sampled.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
# 3. Calculate New Spatial Dimensions
|
||||
h_latent = sampled.shape[-2] // spatial_downsample_h
|
||||
w_latent = sampled.shape[-1] // spatial_downsample_w
|
||||
|
||||
# 4. Interpolate
|
||||
# We maintain the temporal count (sampled.shape[2])
|
||||
# but resize H and W independently
|
||||
pooled = torch.nn.functional.interpolate(
|
||||
sampled,
|
||||
size=(sampled.shape[2], h_latent, w_latent),
|
||||
mode="nearest"
|
||||
)
|
||||
|
||||
# 5. Return to [T_latent, H_latent, W_latent]
|
||||
return pooled.squeeze(0).squeeze(0)
|
||||
|
||||
class LatentCompositeMasked(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -73,8 +109,7 @@ class LatentCompositeMasked(IO.ComfyNode):
|
||||
return IO.NodeOutput(output)
|
||||
|
||||
composite = execute # TODO: remove
|
||||
|
||||
|
||||
|
||||
class ImageCompositeMasked(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -403,6 +438,30 @@ class ThresholdMask(IO.ComfyNode):
|
||||
|
||||
image_to_mask = execute # TODO: remove
|
||||
|
||||
class RGBMaskToLatentMask(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="RGBMasktoLatentMask",
|
||||
search_aliases=["rgb mask to latent mask", "rgb mask", "latent mask"],
|
||||
description="Converts an RGB mask to a latent-space mask for use with causal Video VAEs (e.g., Wan).",
|
||||
category="latent",
|
||||
inputs=[
|
||||
IO.Mask.Input("mask", optional=False),
|
||||
IO.Vae.Input("vae", optional=False),
|
||||
],
|
||||
outputs=[IO.Mask.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mask, vae) -> IO.NodeOutput:
|
||||
# Ensure we work on a copy of the mask to remain non-destructive
|
||||
mask_copy = mask.clone()
|
||||
downscale_ratio = vae.downscale_ratio
|
||||
if not isinstance(downscale_ratio, tuple) or len(downscale_ratio) < 3:
|
||||
raise ValueError("RGBMaskToLatentMask requires a causal Video VAE (e.g., Wan). The provided VAE does not have a compatible downscale_ratio.")
|
||||
k = (mask.shape[0] - 1) // (downscale_ratio[0](mask.shape[0]) - 1) if (downscale_ratio[0](mask.shape[0]) - 1) > 1 else 1
|
||||
return IO.NodeOutput(convert_rgb_mask_to_latent_mask(mask_copy, k, spatial_downsample_h = downscale_ratio[1], spatial_downsample_w = downscale_ratio[2]))
|
||||
|
||||
# Mask Preview - original implement from
|
||||
# https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81
|
||||
@ -444,6 +503,7 @@ class MaskExtension(ComfyExtension):
|
||||
FeatherMask,
|
||||
GrowMask,
|
||||
ThresholdMask,
|
||||
RGBMaskToLatentMask,
|
||||
MaskPreview,
|
||||
]
|
||||
|
||||
|
||||
@ -9,7 +9,8 @@ class String(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="PrimitiveString",
|
||||
display_name="String",
|
||||
search_aliases=["text", "string", "text box", "prompt"],
|
||||
display_name="Text String",
|
||||
category="utils/primitive",
|
||||
inputs=[
|
||||
io.String.Input("value"),
|
||||
@ -27,7 +28,8 @@ class StringMultiline(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="PrimitiveStringMultiline",
|
||||
display_name="String (Multiline)",
|
||||
search_aliases=["text", "string", "text multiline", "string multiline", "text box", "prompt"],
|
||||
display_name="Text String (Multiline)",
|
||||
category="utils/primitive",
|
||||
essentials_category="Basics",
|
||||
inputs=[
|
||||
|
||||
@ -10,9 +10,9 @@ class StringConcatenate(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringConcatenate",
|
||||
display_name="Text Concatenate",
|
||||
category="utils/string",
|
||||
search_aliases=["Concatenate", "text concat", "join text", "merge text", "combine strings", "concat", "concatenate", "append text", "combine text", "string"],
|
||||
search_aliases=["concatenate", "text concat", "join text", "merge text", "combine strings", "string concat", "append text", "combine text"],
|
||||
display_name="Concatenate Text",
|
||||
category="text",
|
||||
inputs=[
|
||||
io.String.Input("string_a", multiline=True),
|
||||
io.String.Input("string_b", multiline=True),
|
||||
@ -33,9 +33,9 @@ class StringSubstring(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringSubstring",
|
||||
search_aliases=["Substring", "extract text", "text portion"],
|
||||
display_name="Text Substring",
|
||||
category="utils/string",
|
||||
search_aliases=["substring", "extract text", "text portion"],
|
||||
display_name="Substring",
|
||||
category="text",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
io.Int.Input("start"),
|
||||
@ -58,7 +58,7 @@ class StringLength(io.ComfyNode):
|
||||
node_id="StringLength",
|
||||
search_aliases=["character count", "text size", "string length"],
|
||||
display_name="Text Length",
|
||||
category="utils/string",
|
||||
category="text",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
],
|
||||
@ -77,9 +77,9 @@ class CaseConverter(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CaseConverter",
|
||||
search_aliases=["Case Converter", "text case", "uppercase", "lowercase", "capitalize"],
|
||||
display_name="Text Case Converter",
|
||||
category="utils/string",
|
||||
search_aliases=["case converter", "text case", "uppercase", "lowercase", "capitalize"],
|
||||
display_name="Convert Text Case",
|
||||
category="text",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
io.Combo.Input("mode", options=["UPPERCASE", "lowercase", "Capitalize", "Title Case"]),
|
||||
@ -110,9 +110,9 @@ class StringTrim(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringTrim",
|
||||
search_aliases=["Trim", "clean whitespace", "remove whitespace", "strip"],
|
||||
display_name="Text Trim",
|
||||
category="utils/string",
|
||||
search_aliases=["trim", "clean whitespace", "remove whitespace", "remove spaces","strip"],
|
||||
display_name="Trim Text",
|
||||
category="text",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
io.Combo.Input("mode", options=["Both", "Left", "Right"]),
|
||||
@ -141,9 +141,9 @@ class StringReplace(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringReplace",
|
||||
search_aliases=["Replace", "find and replace", "substitute", "swap text"],
|
||||
display_name="Text Replace",
|
||||
category="utils/string",
|
||||
search_aliases=["replace", "find and replace", "substitute", "swap text"],
|
||||
display_name="Replace Text",
|
||||
category="text",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
io.String.Input("find", multiline=True),
|
||||
@ -164,9 +164,9 @@ class StringContains(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringContains",
|
||||
search_aliases=["Contains", "text includes", "string includes"],
|
||||
display_name="Text Contains",
|
||||
category="utils/string",
|
||||
search_aliases=["contains", "text includes", "string includes"],
|
||||
display_name="Contains Text",
|
||||
category="text",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
io.String.Input("substring", multiline=True),
|
||||
@ -192,9 +192,9 @@ class StringCompare(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringCompare",
|
||||
search_aliases=["Compare", "text match", "string equals", "starts with", "ends with"],
|
||||
display_name="Text Compare",
|
||||
category="utils/string",
|
||||
search_aliases=["compare", "text match", "string equals", "starts with", "ends with"],
|
||||
display_name="Compare Text",
|
||||
category="text",
|
||||
inputs=[
|
||||
io.String.Input("string_a", multiline=True),
|
||||
io.String.Input("string_b", multiline=True),
|
||||
@ -228,9 +228,9 @@ class RegexMatch(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="RegexMatch",
|
||||
search_aliases=["Regex Match", "regex", "pattern match", "text contains", "string match"],
|
||||
display_name="Text Match",
|
||||
category="utils/string",
|
||||
search_aliases=["regex match", "regex", "pattern match", "text contains", "string match"],
|
||||
display_name="Match Text",
|
||||
category="text",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
io.String.Input("regex_pattern", multiline=True),
|
||||
@ -269,9 +269,9 @@ class RegexExtract(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="RegexExtract",
|
||||
search_aliases=["Regex Extract", "regex", "pattern extract", "text parser", "parse text"],
|
||||
display_name="Text Extract Substring",
|
||||
category="utils/string",
|
||||
search_aliases=["regex extract", "regex", "pattern extract", "text parser", "parse text"],
|
||||
display_name="Extract Text",
|
||||
category="text",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
io.String.Input("regex_pattern", multiline=True),
|
||||
@ -344,9 +344,9 @@ class RegexReplace(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="RegexReplace",
|
||||
search_aliases=["Regex Replace", "regex", "pattern replace", "regex replace", "substitution"],
|
||||
display_name="Text Replace (Regex)",
|
||||
category="utils/string",
|
||||
search_aliases=["regex replace", "regex", "pattern replace", "substitution"],
|
||||
display_name="Replace Text (Regex)",
|
||||
category="text",
|
||||
description="Find and replace text using regex patterns.",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
@ -381,8 +381,8 @@ class JsonExtractString(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="JsonExtractString",
|
||||
display_name="Extract String from JSON",
|
||||
category="utils/string",
|
||||
display_name="Extract Text from JSON",
|
||||
category="text",
|
||||
search_aliases=["json", "extract json", "parse json", "json value", "read json"],
|
||||
inputs=[
|
||||
io.String.Input("json_string", multiline=True),
|
||||
|
||||
@ -1019,7 +1019,12 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):
|
||||
combo_options = extra_info.get("options", [])
|
||||
else:
|
||||
combo_options = input_type
|
||||
if val not in combo_options:
|
||||
is_multiselect = extra_info.get("multiselect", False)
|
||||
if is_multiselect and isinstance(val, list):
|
||||
invalid_vals = [v for v in val if v not in combo_options]
|
||||
else:
|
||||
invalid_vals = [val] if val not in combo_options else []
|
||||
if invalid_vals:
|
||||
input_config = info
|
||||
list_info = ""
|
||||
|
||||
@ -1034,7 +1039,7 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):
|
||||
error = {
|
||||
"type": "value_not_in_list",
|
||||
"message": "Value not in list",
|
||||
"details": f"{x}: '{val}' not in {list_info}",
|
||||
"details": f"{x}: {', '.join(repr(v) for v in invalid_vals)} not in {list_info}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": input_config,
|
||||
|
||||
@ -432,7 +432,9 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im
|
||||
prefix_len = len(os.path.basename(filename_prefix))
|
||||
prefix = filename[:prefix_len + 1]
|
||||
try:
|
||||
digits = int(filename[prefix_len + 1:].split('_')[0])
|
||||
remainder = filename[prefix_len + 1:]
|
||||
base_remainder = remainder.split('.')[0]
|
||||
digits = int(base_remainder.split('_')[0])
|
||||
except:
|
||||
digits = 0
|
||||
return digits, prefix
|
||||
|
||||
2
nodes.py
2
nodes.py
@ -2262,7 +2262,7 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
|
||||
logging.warning(f"Error while calling comfy_entrypoint in {module_path}: {e}")
|
||||
return False
|
||||
else:
|
||||
logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS or NODES_LIST (need one).")
|
||||
logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS or comfy_entrypoint (need one).")
|
||||
return False
|
||||
except Exception as e:
|
||||
logging.warning(traceback.format_exc())
|
||||
|
||||
122
openapi.yaml
122
openapi.yaml
@ -631,7 +631,7 @@ paths:
|
||||
operationId: getFeatures
|
||||
tags: [system]
|
||||
summary: Get enabled feature flags
|
||||
description: Returns a dictionary of feature flag names to their enabled state.
|
||||
description: Returns a dictionary of feature flag names to their enabled state. Cloud deployments may include additional typed fields alongside the boolean flags.
|
||||
responses:
|
||||
"200":
|
||||
description: Feature flags
|
||||
@ -641,6 +641,43 @@ paths:
|
||||
type: object
|
||||
additionalProperties:
|
||||
type: boolean
|
||||
properties:
|
||||
max_upload_size:
|
||||
type: integer
|
||||
format: int64
|
||||
minimum: 0
|
||||
description: "Maximum file upload size in bytes."
|
||||
free_tier_credits:
|
||||
type: integer
|
||||
format: int32
|
||||
minimum: 0
|
||||
nullable: true
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Credits available to free-tier users. Local ComfyUI returns null."
|
||||
posthog_api_host:
|
||||
type: string
|
||||
format: uri
|
||||
nullable: true
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] PostHog analytics proxy URL for frontend telemetry. Local ComfyUI returns null."
|
||||
max_concurrent_jobs:
|
||||
type: integer
|
||||
format: int32
|
||||
minimum: 0
|
||||
nullable: true
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Maximum concurrent jobs the authenticated user can run. Local ComfyUI returns null."
|
||||
workflow_templates_version:
|
||||
type: string
|
||||
nullable: true
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Version identifier for the workflow templates bundle. Local ComfyUI returns null."
|
||||
workflow_templates_source:
|
||||
type: string
|
||||
nullable: true
|
||||
enum: [dynamic_config_override, workflow_templates_version_json]
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] How the templates version was resolved. Local ComfyUI returns null."
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Node / Object Info
|
||||
@ -1497,6 +1534,24 @@ paths:
|
||||
type: string
|
||||
enum: [asc, desc]
|
||||
description: Sort direction
|
||||
- name: job_ids
|
||||
in: query
|
||||
schema:
|
||||
type: string
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Comma-separated UUIDs to filter assets by associated job."
|
||||
- name: include_public
|
||||
in: query
|
||||
schema:
|
||||
type: boolean
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Include workspace-public assets in addition to the caller's own."
|
||||
- name: asset_hash
|
||||
in: query
|
||||
schema:
|
||||
type: string
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Filter by exact content hash."
|
||||
responses:
|
||||
"200":
|
||||
description: Asset list
|
||||
@ -1542,6 +1597,49 @@ paths:
|
||||
type: string
|
||||
format: uuid
|
||||
description: ID of an existing asset to use as the preview image
|
||||
id:
|
||||
type: string
|
||||
format: uuid
|
||||
nullable: true
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Client-supplied asset ID for idempotent creation. If an asset with this ID already exists, the existing asset is returned."
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] URL-based asset upload. Caller supplies a URL instead of a file body; the server fetches the content."
|
||||
required:
|
||||
- url
|
||||
properties:
|
||||
url:
|
||||
type: string
|
||||
format: uri
|
||||
description: "[cloud-only] URL of the file to import as an asset"
|
||||
name:
|
||||
type: string
|
||||
description: Display name for the asset
|
||||
tags:
|
||||
type: string
|
||||
description: Comma-separated tags
|
||||
user_metadata:
|
||||
type: string
|
||||
description: JSON-encoded user metadata
|
||||
hash:
|
||||
type: string
|
||||
description: "Blake3 hash of the file content (e.g. blake3:abc123...)"
|
||||
mime_type:
|
||||
type: string
|
||||
description: MIME type of the file (overrides auto-detected type)
|
||||
preview_id:
|
||||
type: string
|
||||
format: uuid
|
||||
description: ID of an existing asset to use as the preview image
|
||||
id:
|
||||
type: string
|
||||
format: uuid
|
||||
nullable: true
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Client-supplied asset ID for idempotent creation. If an asset with this ID already exists, the existing asset is returned."
|
||||
responses:
|
||||
"201":
|
||||
description: Asset created
|
||||
@ -1580,6 +1678,11 @@ paths:
|
||||
user_metadata:
|
||||
type: object
|
||||
additionalProperties: true
|
||||
mime_type:
|
||||
type: string
|
||||
nullable: true
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] MIME type of the content, so the type is preserved without re-inspecting content. Ignored by local ComfyUI."
|
||||
responses:
|
||||
"201":
|
||||
description: Asset created from hash
|
||||
@ -1644,6 +1747,11 @@ paths:
|
||||
type: string
|
||||
format: uuid
|
||||
description: ID of the asset to use as the preview
|
||||
mime_type:
|
||||
type: string
|
||||
nullable: true
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] MIME type override when auto-detection was wrong. Ignored by local ComfyUI."
|
||||
responses:
|
||||
"200":
|
||||
description: Asset updated
|
||||
@ -2004,21 +2112,13 @@ components:
|
||||
format: uuid
|
||||
nullable: true
|
||||
x-runtime: [cloud]
|
||||
description: |
|
||||
UUID identifying a hosted-cloud workflow entity to associate with this
|
||||
job. Local ComfyUI doesn't track workflow entities and returns `null`
|
||||
(or omits the field). The `x-runtime: [cloud]` extension marks this
|
||||
as populated only by the hosted-cloud runtime; absence of the tag
|
||||
means a field is populated by all runtimes.
|
||||
description: "[cloud-only] Cloud workflow entity ID for tracking and gallery association. Ignored by local ComfyUI."
|
||||
workflow_version_id:
|
||||
type: string
|
||||
format: uuid
|
||||
nullable: true
|
||||
x-runtime: [cloud]
|
||||
description: |
|
||||
UUID identifying a hosted-cloud workflow version to associate with
|
||||
this job. Local ComfyUI returns `null` (or omits the field). See
|
||||
`workflow_id` above for `x-runtime` semantics.
|
||||
description: "[cloud-only] Cloud workflow version ID for pinning execution to a specific version. Ignored by local ComfyUI."
|
||||
|
||||
PromptResponse:
|
||||
type: object
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.42.15
|
||||
comfyui-workflow-templates==0.9.68
|
||||
comfyui-workflow-templates==0.9.69
|
||||
comfyui-embedded-docs==0.4.4
|
||||
torch
|
||||
torchsde
|
||||
|
||||
@ -560,7 +560,7 @@ class PromptServer():
|
||||
buffer.seek(0)
|
||||
|
||||
return web.Response(body=buffer.read(), content_type=f'image/{image_format}',
|
||||
headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||
headers={"Content-Disposition": f"attachment; filename=\"{filename}\""})
|
||||
|
||||
if 'channel' not in request.rel_url.query:
|
||||
channel = 'rgba'
|
||||
@ -580,7 +580,7 @@ class PromptServer():
|
||||
buffer.seek(0)
|
||||
|
||||
return web.Response(body=buffer.read(), content_type='image/png',
|
||||
headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||
headers={"Content-Disposition": f"attachment; filename=\"{filename}\""})
|
||||
|
||||
elif channel == 'a':
|
||||
with Image.open(file) as img:
|
||||
@ -597,7 +597,7 @@ class PromptServer():
|
||||
alpha_buffer.seek(0)
|
||||
|
||||
return web.Response(body=alpha_buffer.read(), content_type='image/png',
|
||||
headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||
headers={"Content-Disposition": f"attachment; filename=\"{filename}\""})
|
||||
else:
|
||||
# Use the content type from asset resolution if available,
|
||||
# otherwise guess from the filename.
|
||||
@ -614,7 +614,7 @@ class PromptServer():
|
||||
return web.FileResponse(
|
||||
file,
|
||||
headers={
|
||||
"Content-Disposition": f"filename=\"{filename}\"",
|
||||
"Content-Disposition": f"attachment; filename=\"{filename}\"",
|
||||
"Content-Type": content_type
|
||||
}
|
||||
)
|
||||
|
||||
78
tests-unit/comfy_api_test/multicombo_serialization_test.py
Normal file
78
tests-unit/comfy_api_test/multicombo_serialization_test.py
Normal file
@ -0,0 +1,78 @@
|
||||
from comfy_api.latest._io import Combo, MultiCombo
|
||||
|
||||
|
||||
def test_multicombo_serializes_multi_select_as_object():
|
||||
multi_combo = MultiCombo.Input(
|
||||
id="providers",
|
||||
options=["a", "b", "c"],
|
||||
default=["a"],
|
||||
)
|
||||
|
||||
serialized = multi_combo.as_dict()
|
||||
|
||||
assert serialized["multiselect"] is True
|
||||
assert "multi_select" in serialized
|
||||
assert serialized["multi_select"] == {}
|
||||
|
||||
|
||||
def test_multicombo_serializes_multi_select_with_placeholder_and_chip():
|
||||
multi_combo = MultiCombo.Input(
|
||||
id="providers",
|
||||
options=["a", "b", "c"],
|
||||
default=["a"],
|
||||
placeholder="Select providers",
|
||||
chip=True,
|
||||
)
|
||||
|
||||
serialized = multi_combo.as_dict()
|
||||
|
||||
assert serialized["multiselect"] is True
|
||||
assert serialized["multi_select"] == {
|
||||
"placeholder": "Select providers",
|
||||
"chip": True,
|
||||
}
|
||||
|
||||
|
||||
def test_combo_does_not_serialize_multiselect():
|
||||
"""Regular Combo should not have multiselect in its serialized output."""
|
||||
combo = Combo.Input(
|
||||
id="choice",
|
||||
options=["a", "b", "c"],
|
||||
)
|
||||
|
||||
serialized = combo.as_dict()
|
||||
|
||||
# Combo sets multiselect=False, but prune_dict keeps False (not None),
|
||||
# so it should be present but False
|
||||
assert serialized.get("multiselect") is False
|
||||
assert "multi_select" not in serialized
|
||||
|
||||
|
||||
def _validate_combo_values(val, combo_options, is_multiselect):
|
||||
"""Reproduce the validation logic from execution.py for testing."""
|
||||
if is_multiselect and isinstance(val, list):
|
||||
return [v for v in val if v not in combo_options]
|
||||
else:
|
||||
return [val] if val not in combo_options else []
|
||||
|
||||
|
||||
def test_multicombo_validation_accepts_valid_list():
|
||||
options = ["a", "b", "c"]
|
||||
assert _validate_combo_values(["a", "b"], options, True) == []
|
||||
|
||||
|
||||
def test_multicombo_validation_rejects_invalid_values():
|
||||
options = ["a", "b", "c"]
|
||||
assert _validate_combo_values(["a", "x"], options, True) == ["x"]
|
||||
|
||||
|
||||
def test_multicombo_validation_accepts_empty_list():
|
||||
options = ["a", "b", "c"]
|
||||
assert _validate_combo_values([], options, True) == []
|
||||
|
||||
|
||||
def test_combo_validation_rejects_list_even_with_valid_items():
|
||||
"""A regular Combo should not accept a list value."""
|
||||
options = ["a", "b", "c"]
|
||||
invalid = _validate_combo_values(["a", "b"], options, False)
|
||||
assert len(invalid) > 0
|
||||
Loading…
Reference in New Issue
Block a user