use new API client in Pixverse and Ideogram nodes (#10543)

Author: Alexander Piskun, 2025-10-30 08:49:03 +02:00 (committed by GitHub)
parent 998bf60beb
commit 163b629c70
12 changed files with 220 additions and 459 deletions

View File

@@ -1,15 +1,12 @@
 from __future__ import annotations
 import aiohttp
 import mimetypes
-from typing import Optional, Union
-from comfy.utils import common_upscale
+from typing import Union
 from server import PromptServer
-from comfy.cli_args import args
 import numpy as np
 from PIL import Image
 import torch
-import math
 import base64
 from io import BytesIO
@@ -60,85 +57,6 @@ async def validate_and_cast_response(
     return torch.stack(image_tensors, dim=0)
 
 
-def validate_aspect_ratio(
-    aspect_ratio: str,
-    minimum_ratio: float,
-    maximum_ratio: float,
-    minimum_ratio_str: str,
-    maximum_ratio_str: str,
-) -> float:
-    """Validates and casts an aspect ratio string to a float.
-
-    Args:
-        aspect_ratio: The aspect ratio string to validate.
-        minimum_ratio: The minimum aspect ratio.
-        maximum_ratio: The maximum aspect ratio.
-        minimum_ratio_str: The minimum aspect ratio string.
-        maximum_ratio_str: The maximum aspect ratio string.
-
-    Returns:
-        The validated and cast aspect ratio.
-
-    Raises:
-        Exception: If the aspect ratio is not valid.
-    """
-    # get ratio values
-    numbers = aspect_ratio.split(":")
-    if len(numbers) != 2:
-        raise TypeError(
-            f"Aspect ratio must be in the format X:Y, such as 16:9, but was {aspect_ratio}."
-        )
-    try:
-        numerator = int(numbers[0])
-        denominator = int(numbers[1])
-    except ValueError as exc:
-        raise TypeError(
-            f"Aspect ratio must contain numbers separated by ':', such as 16:9, but was {aspect_ratio}."
-        ) from exc
-    calculated_ratio = numerator / denominator
-    # if not close to minimum and maximum, check bounds
-    if not math.isclose(calculated_ratio, minimum_ratio) or not math.isclose(
-        calculated_ratio, maximum_ratio
-    ):
-        if calculated_ratio < minimum_ratio:
-            raise TypeError(
-                f"Aspect ratio cannot reduce to any less than {minimum_ratio_str} ({minimum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
-            )
-        if calculated_ratio > maximum_ratio:
-            raise TypeError(
-                f"Aspect ratio cannot reduce to any greater than {maximum_ratio_str} ({maximum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
-            )
-    return aspect_ratio
-
-
-async def download_url_to_bytesio(
-    url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None
-) -> BytesIO:
-    """Downloads content from a URL using requests and returns it as BytesIO.
-
-    Args:
-        url: The URL to download.
-        timeout: Request timeout in seconds. Defaults to None (no timeout).
-
-    Returns:
-        BytesIO object containing the downloaded content.
-    """
-    headers = {}
-    if url.startswith("/proxy/"):
-        url = str(args.comfy_api_base).rstrip("/") + url
-        auth_token = auth_kwargs.get("auth_token")
-        comfy_api_key = auth_kwargs.get("comfy_api_key")
-        if auth_token:
-            headers["Authorization"] = f"Bearer {auth_token}"
-        elif comfy_api_key:
-            headers["X-API-KEY"] = comfy_api_key
-    timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None
-    async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
-        async with session.get(url, headers=headers) as resp:
-            resp.raise_for_status()  # Raises HTTPError for bad responses (4XX or 5XX)
-            return BytesIO(await resp.read())
-
-
 def text_filepath_to_base64_string(filepath: str) -> str:
     """Converts a text file to a base64 string."""
     with open(filepath, "rb") as f:
@@ -153,28 +71,3 @@ def text_filepath_to_data_uri(filepath: str) -> str:
     if mime_type is None:
         mime_type = "application/octet-stream"
     return f"data:{mime_type};base64,{base64_string}"
-
-
-def resize_mask_to_image(
-    mask: torch.Tensor,
-    image: torch.Tensor,
-    upscale_method="nearest-exact",
-    crop="disabled",
-    allow_gradient=True,
-    add_channel_dim=False,
-):
-    """
-    Resize mask to be the same dimensions as an image, while maintaining proper format for API calls.
-    """
-    _, H, W, _ = image.shape
-    mask = mask.unsqueeze(-1)
-    mask = mask.movedim(-1, 1)
-    mask = common_upscale(
-        mask, width=W, height=H, upscale_method=upscale_method, crop=crop
-    )
-    mask = mask.movedim(1, -1)
-    if not add_channel_dim:
-        mask = mask.squeeze(-1)
-    if not allow_gradient:
-        mask = (mask > 0.5).float()
-    return mask
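Note: custom nodes that imported these helpers from comfy_api_nodes.apinode_utils need to switch to comfy_api_nodes.util. A minimal migration sketch (the new names are taken from the util/__init__.py hunk later in this commit; treat it as illustrative, not as the commit's own code):

    from comfy_api_nodes.util import (
        download_url_as_bytesio,       # replaces download_url_to_bytesio (same BytesIO result)
        resize_mask_to_image,          # moved, behavior unchanged
        validate_aspect_ratio_string,  # replaces validate_aspect_ratio; bounds are (int, int) tuples
    )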

View File

@@ -5,10 +5,6 @@ import torch
 from typing_extensions import override
 
 from comfy_api.latest import IO, ComfyExtension
-from comfy_api_nodes.apinode_utils import (
-    resize_mask_to_image,
-    validate_aspect_ratio,
-)
 from comfy_api_nodes.apis.bfl_api import (
     BFLFluxExpandImageRequest,
     BFLFluxFillImageRequest,
@@ -23,8 +19,10 @@ from comfy_api_nodes.util import (
     ApiEndpoint,
     download_url_to_image_tensor,
     poll_op,
+    resize_mask_to_image,
     sync_op,
     tensor_to_base64_string,
+    validate_aspect_ratio_string,
     validate_string,
 )
@@ -43,11 +41,6 @@ class FluxProUltraImageNode(IO.ComfyNode):
     Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.
     """
 
-    MINIMUM_RATIO = 1 / 4
-    MAXIMUM_RATIO = 4 / 1
-    MINIMUM_RATIO_STR = "1:4"
-    MAXIMUM_RATIO_STR = "4:1"
-
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
@@ -112,16 +105,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
     @classmethod
     def validate_inputs(cls, aspect_ratio: str):
-        try:
-            validate_aspect_ratio(
-                aspect_ratio,
-                minimum_ratio=cls.MINIMUM_RATIO,
-                maximum_ratio=cls.MAXIMUM_RATIO,
-                minimum_ratio_str=cls.MINIMUM_RATIO_STR,
-                maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
-            )
-        except Exception as e:
-            return str(e)
+        validate_aspect_ratio_string(aspect_ratio, (1, 4), (4, 1))
         return True
 
     @classmethod
@@ -145,13 +129,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
                 prompt=prompt,
                 prompt_upsampling=prompt_upsampling,
                 seed=seed,
-                aspect_ratio=validate_aspect_ratio(
-                    aspect_ratio,
-                    minimum_ratio=cls.MINIMUM_RATIO,
-                    maximum_ratio=cls.MAXIMUM_RATIO,
-                    minimum_ratio_str=cls.MINIMUM_RATIO_STR,
-                    maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
-                ),
+                aspect_ratio=aspect_ratio,
                 raw=raw,
                 image_prompt=(image_prompt if image_prompt is None else tensor_to_base64_string(image_prompt)),
                 image_prompt_strength=(None if image_prompt is None else round(image_prompt_strength, 2)),
@@ -180,11 +158,6 @@ class FluxKontextProImageNode(IO.ComfyNode):
     Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.
     """
 
-    MINIMUM_RATIO = 1 / 4
-    MAXIMUM_RATIO = 4 / 1
-    MINIMUM_RATIO_STR = "1:4"
-    MAXIMUM_RATIO_STR = "4:1"
-
    @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
@@ -261,13 +234,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
         seed=0,
         prompt_upsampling=False,
     ) -> IO.NodeOutput:
-        aspect_ratio = validate_aspect_ratio(
-            aspect_ratio,
-            minimum_ratio=cls.MINIMUM_RATIO,
-            maximum_ratio=cls.MAXIMUM_RATIO,
-            minimum_ratio_str=cls.MINIMUM_RATIO_STR,
-            maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
-        )
+        validate_aspect_ratio_string(aspect_ratio, (1, 4), (4, 1))
         if input_image is None:
             validate_string(prompt, strip_whitespace=False)
         initial_response = await sync_op(
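Note: besides being shorter, this changes error behavior. The removed validate_aspect_ratio raised TypeError and returned the original "X:Y" string, while validate_aspect_ratio_string raises ValueError and returns the numeric ratio, so the request now passes the untouched string through and validate_inputs no longer converts exceptions to strings. A quick sketch of the new contract (based on the validation_utils.py hunk at the end of this commit):

    from comfy_api_nodes.util import validate_aspect_ratio_string

    ratio = validate_aspect_ratio_string("16:9", (1, 4), (4, 1))  # returns 16/9 ≈ 1.78
    validate_aspect_ratio_string("9:1", (1, 4), (4, 1))           # raises ValueError (out of range)
    validate_aspect_ratio_string("16x9", (1, 4), (4, 1))          # raises ValueError (bad format)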

View File

@@ -17,7 +17,7 @@ from comfy_api_nodes.util import (
     poll_op,
     sync_op,
     upload_images_to_comfyapi,
-    validate_image_aspect_ratio_range,
+    validate_image_aspect_ratio,
     validate_image_dimensions,
     validate_string,
 )
@@ -403,7 +403,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
         validate_string(prompt, strip_whitespace=True, min_length=1)
         if get_number_of_images(image) != 1:
             raise ValueError("Exactly one input image is required.")
-        validate_image_aspect_ratio_range(image, (1, 3), (3, 1))
+        validate_image_aspect_ratio(image, (1, 3), (3, 1))
         source_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0]
         payload = Image2ImageTaskCreationRequest(
             model=model,
@@ -565,7 +565,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
         reference_images_urls = []
         if n_input_images:
             for i in image:
-                validate_image_aspect_ratio_range(i, (1, 3), (3, 1))
+                validate_image_aspect_ratio(i, (1, 3), (3, 1))
             reference_images_urls = await upload_images_to_comfyapi(
                 cls,
                 image,
@@ -798,7 +798,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
         validate_string(prompt, strip_whitespace=True, min_length=1)
         raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"])
         validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000)
-        validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False)  # 0.4 to 2.5
+        validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False)  # 0.4 to 2.5
 
         image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0]
         prompt = (
@@ -923,7 +923,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
         raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"])
         for i in (first_frame, last_frame):
             validate_image_dimensions(i, min_width=300, min_height=300, max_width=6000, max_height=6000)
-            validate_image_aspect_ratio_range(i, (2, 5), (5, 2), strict=False)  # 0.4 to 2.5
+            validate_image_aspect_ratio(i, (2, 5), (5, 2), strict=False)  # 0.4 to 2.5
 
         download_urls = await upload_images_to_comfyapi(
             cls,
@@ -1045,7 +1045,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
         raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "watermark"])
         for image in images:
             validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000)
-            validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False)  # 0.4 to 2.5
+            validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False)  # 0.4 to 2.5
 
         image_urls = await upload_images_to_comfyapi(cls, images, max_images=4, mime_type="image/png")
         prompt = (
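Note: only the name changes in this file; the tuple-based signature and the strict flag carry over from the old validate_image_aspect_ratio_range. A small self-contained sketch (assumes ComfyUI's [B, H, W, C] image layout):

    import torch
    from comfy_api_nodes.util import validate_image_aspect_ratio

    image = torch.zeros(1, 512, 768, 3)  # 768x512 -> aspect ratio 1.5
    validate_image_aspect_ratio(image, (1, 3), (3, 1))                # ok, strict bounds (1/3, 3)
    validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False)  # ok, inclusive [0.4, 2.5]
    validate_image_aspect_ratio(image, (2, 1), (5, 2))                # raises ValueError (1.5 is below the 2:1 minimum)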

View File

@@ -1,6 +1,6 @@
 from io import BytesIO
 from typing_extensions import override
-from comfy_api.latest import ComfyExtension, IO
+from comfy_api.latest import IO, ComfyExtension
 from PIL import Image
 import numpy as np
 import torch
@@ -11,19 +11,13 @@ from comfy_api_nodes.apis import (
     IdeogramV3Request,
     IdeogramV3EditRequest,
 )
-from comfy_api_nodes.apis.client import (
+from comfy_api_nodes.util import (
     ApiEndpoint,
-    HttpMethod,
-    SynchronousOperation,
-)
-from comfy_api_nodes.apinode_utils import (
-    download_url_to_bytesio,
+    bytesio_to_image_tensor,
+    download_url_as_bytesio,
     resize_mask_to_image,
+    sync_op,
 )
-from comfy_api_nodes.util import bytesio_to_image_tensor
-from server import PromptServer
 
 V1_V1_RES_MAP = {
     "Auto":"AUTO",
@@ -220,7 +214,7 @@ async def download_and_process_images(image_urls):
     for image_url in image_urls:
         # Using functions from apinode_utils.py to handle downloading and processing
-        image_bytesio = await download_url_to_bytesio(image_url)  # Download image content to BytesIO
+        image_bytesio = await download_url_as_bytesio(image_url)  # Download image content to BytesIO
         img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB")  # Convert to torch.Tensor with RGB mode
         image_tensors.append(img_tensor)
 
@@ -233,19 +227,6 @@ async def download_and_process_images(image_urls):
     return stacked_tensors
 
 
-def display_image_urls_on_node(image_urls, node_id):
-    if node_id and image_urls:
-        if len(image_urls) == 1:
-            PromptServer.instance.send_progress_text(
-                f"Generated Image URL:\n{image_urls[0]}", node_id
-            )
-        else:
-            urls_text = "Generated Image URLs:\n" + "\n".join(
-                f"{i+1}. {url}" for i, url in enumerate(image_urls)
-            )
-            PromptServer.instance.send_progress_text(urls_text, node_id)
-
-
 class IdeogramV1(IO.ComfyNode):
     @classmethod
@@ -334,44 +315,30 @@ class IdeogramV1(IO.ComfyNode):
         aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
         model = "V_1_TURBO" if turbo else "V_1"
 
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/ideogram/generate",
-                method=HttpMethod.POST,
-                request_model=IdeogramGenerateRequest,
-                response_model=IdeogramGenerateResponse,
-            ),
-            request=IdeogramGenerateRequest(
+        response = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
+            response_model=IdeogramGenerateResponse,
+            data=IdeogramGenerateRequest(
                 image_request=ImageRequest(
                     prompt=prompt,
                     model=model,
                     num_images=num_images,
                     seed=seed,
                     aspect_ratio=aspect_ratio if aspect_ratio != "ASPECT_1_1" else None,
-                    magic_prompt_option=(
-                        magic_prompt_option if magic_prompt_option != "AUTO" else None
-                    ),
+                    magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
                    negative_prompt=negative_prompt if negative_prompt else None,
                 )
             ),
-            auth_kwargs=auth,
+            max_retries=1,
        )
-        response = await operation.execute()
 
         if not response.data or len(response.data) == 0:
             raise Exception("No images were generated in the response")
 
         image_urls = [image_data.url for image_data in response.data if image_data.url]
 
         if not image_urls:
             raise Exception("No image URLs were generated in the response")
 
-        display_image_urls_on_node(image_urls, cls.hidden.unique_id)
         return IO.NodeOutput(await download_and_process_images(image_urls))
@@ -500,18 +467,11 @@ class IdeogramV2(IO.ComfyNode):
         else:
             final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None
 
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/ideogram/generate",
-                method=HttpMethod.POST,
-                request_model=IdeogramGenerateRequest,
-                response_model=IdeogramGenerateResponse,
-            ),
-            request=IdeogramGenerateRequest(
+        response = await sync_op(
+            cls,
+            endpoint=ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
+            response_model=IdeogramGenerateResponse,
+            data=IdeogramGenerateRequest(
                 image_request=ImageRequest(
                     prompt=prompt,
                     model=model,
@@ -519,28 +479,20 @@ class IdeogramV2(IO.ComfyNode):
                     seed=seed,
                     aspect_ratio=final_aspect_ratio,
                     resolution=final_resolution,
-                    magic_prompt_option=(
-                        magic_prompt_option if magic_prompt_option != "AUTO" else None
-                    ),
+                    magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
                     style_type=style_type if style_type != "NONE" else None,
                     negative_prompt=negative_prompt if negative_prompt else None,
                     color_palette=color_palette if color_palette else None,
                 )
             ),
-            auth_kwargs=auth,
+            max_retries=1,
         )
-        response = await operation.execute()
 
         if not response.data or len(response.data) == 0:
             raise Exception("No images were generated in the response")
 
         image_urls = [image_data.url for image_data in response.data if image_data.url]
 
         if not image_urls:
             raise Exception("No image URLs were generated in the response")
 
-        display_image_urls_on_node(image_urls, cls.hidden.unique_id)
         return IO.NodeOutput(await download_and_process_images(image_urls))
@@ -656,10 +608,6 @@ class IdeogramV3(IO.ComfyNode):
         character_image=None,
         character_mask=None,
     ):
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
         if rendering_speed == "BALANCED":  # for backward compatibility
             rendering_speed = "DEFAULT"
 
@@ -694,9 +642,6 @@ class IdeogramV3(IO.ComfyNode):
         # Check if both image and mask are provided for editing mode
         if image is not None and mask is not None:
-            # Edit mode
-            path = "/proxy/ideogram/ideogram-v3/edit"
-
             # Process image and mask
             input_tensor = image.squeeze().cpu()
             # Resize mask to match image dimension
@@ -749,27 +694,20 @@ class IdeogramV3(IO.ComfyNode):
             if character_mask_binary:
                 files["character_mask_binary"] = character_mask_binary
 
-            # Execute the operation for edit mode
-            operation = SynchronousOperation(
-                endpoint=ApiEndpoint(
-                    path=path,
-                    method=HttpMethod.POST,
-                    request_model=IdeogramV3EditRequest,
-                    response_model=IdeogramGenerateResponse,
-                ),
-                request=edit_request,
+            response = await sync_op(
+                cls,
+                ApiEndpoint(path="/proxy/ideogram/ideogram-v3/edit", method="POST"),
+                response_model=IdeogramGenerateResponse,
+                data=edit_request,
                 files=files,
                 content_type="multipart/form-data",
-                auth_kwargs=auth,
+                max_retries=1,
             )
 
         elif image is not None or mask is not None:
             # If only one of image or mask is provided, raise an error
             raise Exception("Ideogram V3 image editing requires both an image AND a mask")
         else:
-            # Generation mode
-            path = "/proxy/ideogram/ideogram-v3/generate"
-
             # Create generation request
             gen_request = IdeogramV3Request(
                 prompt=prompt,
@@ -800,32 +738,22 @@ class IdeogramV3(IO.ComfyNode):
             if files:
                 gen_request.style_type = "AUTO"
 
-            # Execute the operation for generation mode
-            operation = SynchronousOperation(
-                endpoint=ApiEndpoint(
-                    path=path,
-                    method=HttpMethod.POST,
-                    request_model=IdeogramV3Request,
-                    response_model=IdeogramGenerateResponse,
-                ),
-                request=gen_request,
+            response = await sync_op(
+                cls,
+                endpoint=ApiEndpoint(path="/proxy/ideogram/ideogram-v3/generate", method="POST"),
+                response_model=IdeogramGenerateResponse,
+                data=gen_request,
                 files=files if files else None,
                 content_type="multipart/form-data",
-                auth_kwargs=auth,
+                max_retries=1,
             )
 
-        # Execute the operation and process response
-        response = await operation.execute()
-
         if not response.data or len(response.data) == 0:
             raise Exception("No images were generated in the response")
 
         image_urls = [image_data.url for image_data in response.data if image_data.url]
 
         if not image_urls:
             raise Exception("No image URLs were generated in the response")
 
-        display_image_urls_on_node(image_urls, cls.hidden.unique_id)
         return IO.NodeOutput(await download_and_process_images(image_urls))
@@ -838,5 +766,6 @@ class IdeogramExtension(ComfyExtension):
             IdeogramV3,
         ]
 
+
 async def comfy_entrypoint() -> IdeogramExtension:
     return IdeogramExtension()
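Note: across all three Ideogram nodes the pattern is identical: sync_op replaces the SynchronousOperation/ApiEndpoint/auth_kwargs boilerplate and takes the node class itself, so credentials and the node id no longer need to be threaded through by hand (presumably also why display_image_urls_on_node could be dropped). A condensed sketch of the multipart edit call, with the surrounding node code elided (edit_request and files are built as in the hunks above; the keyword shapes are inferred from this diff's call sites, not from the util module's documentation):

    async def ideogram_v3_edit(cls, edit_request, files: dict):
        return await sync_op(
            cls,  # supplies auth credentials and the node id for progress reporting
            ApiEndpoint(path="/proxy/ideogram/ideogram-v3/edit", method="POST"),
            response_model=IdeogramGenerateResponse,
            data=edit_request,
            files=files,
            content_type="multipart/form-data",
            max_retries=1,
        )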

View File

@@ -282,7 +282,7 @@ def validate_input_image(image: torch.Tensor) -> None:
     See: https://app.klingai.com/global/dev/document-api/apiReference/model/imageToVideo
     """
     validate_image_dimensions(image, min_width=300, min_height=300)
-    validate_image_aspect_ratio(image, min_aspect_ratio=1 / 2.5, max_aspect_ratio=2.5)
+    validate_image_aspect_ratio(image, (1, 2.5), (2.5, 1))
 
 
 def get_video_from_response(response) -> KlingVideoResult:

View File

@@ -1,7 +1,6 @@
-from inspect import cleandoc
-from typing import Optional
+import torch
 from typing_extensions import override
-from io import BytesIO
+from comfy_api.latest import IO, ComfyExtension
 from comfy_api_nodes.apis.pixverse_api import (
     PixverseTextVideoRequest,
     PixverseImageVideoRequest,
@@ -17,53 +16,30 @@ from comfy_api_nodes.apis.pixverse_api import (
     PixverseIO,
     pixverse_templates,
 )
-from comfy_api_nodes.apis.client import (
+from comfy_api_nodes.util import (
     ApiEndpoint,
-    HttpMethod,
-    SynchronousOperation,
-    PollingOperation,
-    EmptyRequest,
+    download_url_to_video_output,
+    poll_op,
+    sync_op,
+    tensor_to_bytesio,
+    validate_string,
 )
-from comfy_api_nodes.util import validate_string, tensor_to_bytesio
-from comfy_api.input_impl import VideoFromFile
-from comfy_api.latest import ComfyExtension, IO
-import torch
-import aiohttp
 
 AVERAGE_DURATION_T2V = 32
 AVERAGE_DURATION_I2V = 30
 AVERAGE_DURATION_T2T = 52
 
 
-def get_video_url_from_response(
-    response: PixverseGenerationStatusResponse,
-) -> Optional[str]:
-    if response.Resp is None or response.Resp.url is None:
-        return None
-    return str(response.Resp.url)
-
-
-async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
-    # first, upload image to Pixverse and get image id to use in actual generation call
-    operation = SynchronousOperation(
-        endpoint=ApiEndpoint(
-            path="/proxy/pixverse/image/upload",
-            method=HttpMethod.POST,
-            request_model=EmptyRequest,
-            response_model=PixverseImageUploadResponse,
-        ),
-        request=EmptyRequest(),
+async def upload_image_to_pixverse(cls: type[IO.ComfyNode], image: torch.Tensor):
+    response_upload = await sync_op(
+        cls,
+        ApiEndpoint(path="/proxy/pixverse/image/upload", method="POST"),
+        response_model=PixverseImageUploadResponse,
         files={"image": tensor_to_bytesio(image)},
         content_type="multipart/form-data",
-        auth_kwargs=auth_kwargs,
     )
-    response_upload: PixverseImageUploadResponse = await operation.execute()
     if response_upload.Resp is None:
         raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'")
     return response_upload.Resp.img_id
@@ -93,17 +69,13 @@ class PixverseTemplateNode(IO.ComfyNode):
 
 class PixverseTextToVideoNode(IO.ComfyNode):
-    """
-    Generates videos based on prompt and output_size.
-    """
-
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
             node_id="PixverseTextToVideoNode",
             display_name="PixVerse Text to Video",
             category="api node/video/PixVerse",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Generates videos based on prompt and output_size.",
             inputs=[
                 IO.String.Input(
                     "prompt",
@@ -170,7 +142,7 @@ class PixverseTextToVideoNode(IO.ComfyNode):
         negative_prompt: str = None,
         pixverse_template: int = None,
     ) -> IO.NodeOutput:
-        validate_string(prompt, strip_whitespace=False)
+        validate_string(prompt, strip_whitespace=False, min_length=1)
         # 1080p is limited to 5 seconds duration
         # only normal motion_mode supported for 1080p or for non-5 second duration
         if quality == PixverseQuality.res_1080p:
@@ -179,18 +151,11 @@ class PixverseTextToVideoNode(IO.ComfyNode):
         elif duration_seconds != PixverseDuration.dur_5:
             motion_mode = PixverseMotionMode.normal
 
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/pixverse/video/text/generate",
-                method=HttpMethod.POST,
-                request_model=PixverseTextVideoRequest,
-                response_model=PixverseVideoResponse,
-            ),
-            request=PixverseTextVideoRequest(
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/pixverse/video/text/generate", method="POST"),
+            response_model=PixverseVideoResponse,
+            data=PixverseTextVideoRequest(
                 prompt=prompt,
                 aspect_ratio=aspect_ratio,
                 quality=quality,
@@ -200,20 +165,14 @@ class PixverseTextToVideoNode(IO.ComfyNode):
                 template_id=pixverse_template,
                 seed=seed,
             ),
-            auth_kwargs=auth,
         )
-        response_api = await operation.execute()
 
         if response_api.Resp is None:
             raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
 
-        operation = PollingOperation(
-            poll_endpoint=ApiEndpoint(
-                path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
-                method=HttpMethod.GET,
-                request_model=EmptyRequest,
-                response_model=PixverseGenerationStatusResponse,
-            ),
+        response_poll = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
+            response_model=PixverseGenerationStatusResponse,
             completed_statuses=[PixverseStatus.successful],
             failed_statuses=[
                 PixverseStatus.contents_moderation,
@@ -221,30 +180,19 @@ class PixverseTextToVideoNode(IO.ComfyNode):
                 PixverseStatus.deleted,
             ],
             status_extractor=lambda x: x.Resp.status,
-            auth_kwargs=auth,
-            node_id=cls.hidden.unique_id,
-            result_url_extractor=get_video_url_from_response,
             estimated_duration=AVERAGE_DURATION_T2V,
         )
-        response_poll = await operation.execute()
-
-        async with aiohttp.ClientSession() as session:
-            async with session.get(response_poll.Resp.url) as vid_response:
-                return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
+        return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
 
 
 class PixverseImageToVideoNode(IO.ComfyNode):
-    """
-    Generates videos based on prompt and output_size.
-    """
-
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
             node_id="PixverseImageToVideoNode",
             display_name="PixVerse Image to Video",
             category="api node/video/PixVerse",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Generates videos based on prompt and output_size.",
             inputs=[
                 IO.Image.Input("image"),
                 IO.String.Input(
@@ -309,11 +257,7 @@ class PixverseImageToVideoNode(IO.ComfyNode):
         pixverse_template: int = None,
     ) -> IO.NodeOutput:
         validate_string(prompt, strip_whitespace=False)
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        img_id = await upload_image_to_pixverse(image, auth_kwargs=auth)
+        img_id = await upload_image_to_pixverse(cls, image)
 
         # 1080p is limited to 5 seconds duration
         # only normal motion_mode supported for 1080p or for non-5 second duration
@@ -323,14 +267,11 @@ class PixverseImageToVideoNode(IO.ComfyNode):
         elif duration_seconds != PixverseDuration.dur_5:
             motion_mode = PixverseMotionMode.normal
 
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/pixverse/video/img/generate",
-                method=HttpMethod.POST,
-                request_model=PixverseImageVideoRequest,
-                response_model=PixverseVideoResponse,
-            ),
-            request=PixverseImageVideoRequest(
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/pixverse/video/img/generate", method="POST"),
+            response_model=PixverseVideoResponse,
+            data=PixverseImageVideoRequest(
                 img_id=img_id,
                 prompt=prompt,
                 quality=quality,
@@ -340,20 +281,15 @@ class PixverseImageToVideoNode(IO.ComfyNode):
                 template_id=pixverse_template,
                 seed=seed,
             ),
-            auth_kwargs=auth,
         )
-        response_api = await operation.execute()
 
         if response_api.Resp is None:
             raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
 
-        operation = PollingOperation(
-            poll_endpoint=ApiEndpoint(
-                path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
-                method=HttpMethod.GET,
-                request_model=EmptyRequest,
-                response_model=PixverseGenerationStatusResponse,
-            ),
+        response_poll = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
+            response_model=PixverseGenerationStatusResponse,
             completed_statuses=[PixverseStatus.successful],
             failed_statuses=[
                 PixverseStatus.contents_moderation,
@@ -361,30 +297,19 @@ class PixverseImageToVideoNode(IO.ComfyNode):
                 PixverseStatus.deleted,
             ],
             status_extractor=lambda x: x.Resp.status,
-            auth_kwargs=auth,
-            node_id=cls.hidden.unique_id,
-            result_url_extractor=get_video_url_from_response,
             estimated_duration=AVERAGE_DURATION_I2V,
         )
-        response_poll = await operation.execute()
-
-        async with aiohttp.ClientSession() as session:
-            async with session.get(response_poll.Resp.url) as vid_response:
-                return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
+        return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
 
 
 class PixverseTransitionVideoNode(IO.ComfyNode):
-    """
-    Generates videos based on prompt and output_size.
-    """
-
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
             node_id="PixverseTransitionVideoNode",
             display_name="PixVerse Transition Video",
             category="api node/video/PixVerse",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Generates videos based on prompt and output_size.",
             inputs=[
                 IO.Image.Input("first_frame"),
                 IO.Image.Input("last_frame"),
@@ -445,12 +370,8 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
         negative_prompt: str = None,
     ) -> IO.NodeOutput:
         validate_string(prompt, strip_whitespace=False)
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=auth)
-        last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=auth)
+        first_frame_id = await upload_image_to_pixverse(cls, first_frame)
+        last_frame_id = await upload_image_to_pixverse(cls, last_frame)
 
         # 1080p is limited to 5 seconds duration
         # only normal motion_mode supported for 1080p or for non-5 second duration
@@ -460,14 +381,11 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
         elif duration_seconds != PixverseDuration.dur_5:
             motion_mode = PixverseMotionMode.normal
 
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/pixverse/video/transition/generate",
-                method=HttpMethod.POST,
-                request_model=PixverseTransitionVideoRequest,
-                response_model=PixverseVideoResponse,
-            ),
-            request=PixverseTransitionVideoRequest(
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/pixverse/video/transition/generate", method="POST"),
+            response_model=PixverseVideoResponse,
+            data=PixverseTransitionVideoRequest(
                 first_frame_img=first_frame_id,
                 last_frame_img=last_frame_id,
                 prompt=prompt,
@@ -477,20 +395,15 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
                 negative_prompt=negative_prompt if negative_prompt else None,
                 seed=seed,
             ),
-            auth_kwargs=auth,
         )
-        response_api = await operation.execute()
 
         if response_api.Resp is None:
             raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
 
-        operation = PollingOperation(
-            poll_endpoint=ApiEndpoint(
-                path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
-                method=HttpMethod.GET,
-                request_model=EmptyRequest,
-                response_model=PixverseGenerationStatusResponse,
-            ),
+        response_poll = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
+            response_model=PixverseGenerationStatusResponse,
             completed_statuses=[PixverseStatus.successful],
             failed_statuses=[
                 PixverseStatus.contents_moderation,
@@ -498,16 +411,9 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
                 PixverseStatus.deleted,
            ],
             status_extractor=lambda x: x.Resp.status,
-            auth_kwargs=auth,
-            node_id=cls.hidden.unique_id,
-            result_url_extractor=get_video_url_from_response,
             estimated_duration=AVERAGE_DURATION_T2V,
         )
-        response_poll = await operation.execute()
-
-        async with aiohttp.ClientSession() as session:
-            async with session.get(response_poll.Resp.url) as vid_response:
-                return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
+        return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
 
 
 class PixVerseExtension(ComfyExtension):
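Note: every PixVerse node now follows the same three-step flow: sync_op creates the task, poll_op waits for a terminal status, and download_url_to_video_output fetches the result, replacing the hand-rolled aiohttp download into VideoFromFile. A condensed sketch with request fields trimmed (names come from this diff; the real nodes above pass more parameters and list more failure statuses):

    async def pixverse_text_to_video(cls, prompt: str, seed: int):
        response_api = await sync_op(
            cls,
            ApiEndpoint(path="/proxy/pixverse/video/text/generate", method="POST"),
            response_model=PixverseVideoResponse,
            data=PixverseTextVideoRequest(prompt=prompt, seed=seed),
        )
        if response_api.Resp is None:
            raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
        response_poll = await poll_op(
            cls,
            ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
            response_model=PixverseGenerationStatusResponse,
            completed_statuses=[PixverseStatus.successful],
            failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.deleted],  # trimmed
            status_extractor=lambda x: x.Resp.status,
        )
        return await download_url_to_video_output(response_poll.Resp.url)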

View File

@@ -8,9 +8,6 @@ from typing_extensions import override
 from comfy.utils import ProgressBar
 from comfy_api.latest import IO, ComfyExtension
-from comfy_api_nodes.apinode_utils import (
-    resize_mask_to_image,
-)
 from comfy_api_nodes.apis.recraft_api import (
     RecraftColor,
     RecraftColorChain,
@@ -28,6 +25,7 @@ from comfy_api_nodes.util import (
     ApiEndpoint,
     bytesio_to_image_tensor,
     download_url_as_bytesio,
+    resize_mask_to_image,
     sync_op,
     tensor_to_bytesio,
     validate_string,

View File

@@ -200,7 +200,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
     ) -> IO.NodeOutput:
         validate_string(prompt, min_length=1)
         validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
-        validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
+        validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
 
         download_urls = await upload_images_to_comfyapi(
             cls,
@@ -290,7 +290,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
     ) -> IO.NodeOutput:
         validate_string(prompt, min_length=1)
         validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
-        validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
+        validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
 
         download_urls = await upload_images_to_comfyapi(
             cls,
@@ -390,8 +390,8 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
         validate_string(prompt, min_length=1)
         validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
         validate_image_dimensions(end_frame, max_width=7999, max_height=7999)
-        validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
-        validate_image_aspect_ratio(end_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
+        validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
+        validate_image_aspect_ratio(end_frame, (1, 2), (2, 1))
 
         stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
         download_urls = await upload_images_to_comfyapi(
@@ -475,7 +475,7 @@ class RunwayTextToImageNode(IO.ComfyNode):
         reference_images = None
         if reference_image is not None:
             validate_image_dimensions(reference_image, max_width=7999, max_height=7999)
-            validate_image_aspect_ratio(reference_image, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
+            validate_image_aspect_ratio(reference_image, (1, 2), (2, 1))
             download_urls = await upload_images_to_comfyapi(
                 cls,
                 reference_image,

View File

@@ -14,9 +14,9 @@ from comfy_api_nodes.util import (
     poll_op,
     sync_op,
     upload_images_to_comfyapi,
-    validate_aspect_ratio_closeness,
-    validate_image_aspect_ratio_range,
+    validate_image_aspect_ratio,
     validate_image_dimensions,
+    validate_images_aspect_ratio_closeness,
 )
 
 VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video"
@@ -114,7 +114,7 @@ async def execute_task(
         cls,
         ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % response.task_id),
         response_model=TaskStatusResponse,
-        status_extractor=lambda r: r.state.value,
+        status_extractor=lambda r: r.state,
         estimated_duration=estimated_duration,
     )
@@ -307,7 +307,7 @@ class ViduImageToVideoNode(IO.ComfyNode):
     ) -> IO.NodeOutput:
         if get_number_of_images(image) > 1:
             raise ValueError("Only one input image is allowed.")
-        validate_image_aspect_ratio_range(image, (1, 4), (4, 1))
+        validate_image_aspect_ratio(image, (1, 4), (4, 1))
         payload = TaskCreationRequest(
             model_name=model,
             prompt=prompt,
@@ -423,7 +423,7 @@ class ViduReferenceVideoNode(IO.ComfyNode):
         if a > 7:
             raise ValueError("Too many images, maximum allowed is 7.")
         for image in images:
-            validate_image_aspect_ratio_range(image, (1, 4), (4, 1))
+            validate_image_aspect_ratio(image, (1, 4), (4, 1))
             validate_image_dimensions(image, min_width=128, min_height=128)
         payload = TaskCreationRequest(
             model_name=model,
@@ -533,7 +533,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
         resolution: str,
         movement_amplitude: str,
     ) -> IO.NodeOutput:
-        validate_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
+        validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
         payload = TaskCreationRequest(
             model_name=model,
             prompt=prompt,

View File

@@ -14,6 +14,7 @@ from .conversions import (
     downscale_image_tensor,
     image_tensor_pair_to_batch,
     pil_to_bytesio,
+    resize_mask_to_image,
     tensor_to_base64_string,
     tensor_to_bytesio,
     tensor_to_pil,
@@ -34,12 +35,12 @@ from .upload_helpers import (
 )
 from .validation_utils import (
     get_number_of_images,
-    validate_aspect_ratio_closeness,
+    validate_aspect_ratio_string,
     validate_audio_duration,
     validate_container_format_is_mp4,
     validate_image_aspect_ratio,
-    validate_image_aspect_ratio_range,
     validate_image_dimensions,
+    validate_images_aspect_ratio_closeness,
     validate_string,
     validate_video_dimensions,
     validate_video_duration,
@@ -70,6 +71,7 @@ __all__ = [
     "downscale_image_tensor",
     "image_tensor_pair_to_batch",
     "pil_to_bytesio",
+    "resize_mask_to_image",
     "tensor_to_base64_string",
     "tensor_to_bytesio",
     "tensor_to_pil",
@@ -77,12 +79,12 @@ __all__ = [
     "video_to_base64_string",
     # Validation utilities
     "get_number_of_images",
-    "validate_aspect_ratio_closeness",
+    "validate_aspect_ratio_string",
     "validate_audio_duration",
     "validate_container_format_is_mp4",
     "validate_image_aspect_ratio",
-    "validate_image_aspect_ratio_range",
     "validate_image_dimensions",
+    "validate_images_aspect_ratio_closeness",
     "validate_string",
     "validate_video_dimensions",
     "validate_video_duration",

View File

@@ -430,3 +430,24 @@ def audio_bytes_to_audio_input(audio_bytes: bytes) -> dict:
     wav = torch.cat(frames, dim=1)  # [C, T]
     wav = _f32_pcm(wav)
     return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr}
+
+
+def resize_mask_to_image(
+    mask: torch.Tensor,
+    image: torch.Tensor,
+    upscale_method="nearest-exact",
+    crop="disabled",
+    allow_gradient=True,
+    add_channel_dim=False,
+):
+    """Resize mask to be the same dimensions as an image, while maintaining proper format for API calls."""
+    _, height, width, _ = image.shape
+    mask = mask.unsqueeze(-1)
+    mask = mask.movedim(-1, 1)
+    mask = common_upscale(mask, width=width, height=height, upscale_method=upscale_method, crop=crop)
+    mask = mask.movedim(1, -1)
+    if not add_channel_dim:
+        mask = mask.squeeze(-1)
+    if not allow_gradient:
+        mask = (mask > 0.5).float()
+    return mask
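Note: the function moves essentially verbatim (only H/W renamed to height/width); it still expects a [B, H, W, C] image and a [B, H, W] mask. A usage sketch:

    import torch
    from comfy_api_nodes.util import resize_mask_to_image

    image = torch.zeros(1, 512, 768, 3)  # [B, H, W, C]
    mask = torch.rand(1, 64, 64)         # [B, H, W] at a different resolution

    soft = resize_mask_to_image(mask, image)                        # [1, 512, 768], gradients kept
    hard = resize_mask_to_image(mask, image, allow_gradient=False)  # thresholded at 0.5
    nhwc = resize_mask_to_image(mask, image, add_channel_dim=True)  # [1, 512, 768, 1]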

View File

@@ -37,63 +37,62 @@ def validate_image_dimensions(
 def validate_image_aspect_ratio(
     image: torch.Tensor,
-    min_aspect_ratio: Optional[float] = None,
-    max_aspect_ratio: Optional[float] = None,
-):
-    width, height = get_image_dimensions(image)
-    aspect_ratio = width / height
-    if min_aspect_ratio is not None and aspect_ratio < min_aspect_ratio:
-        raise ValueError(f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}")
-    if max_aspect_ratio is not None and aspect_ratio > max_aspect_ratio:
-        raise ValueError(f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}")
-
-
-def validate_image_aspect_ratio_range(
-    image: torch.Tensor,
-    min_ratio: tuple[float, float],  # e.g. (1, 4)
-    max_ratio: tuple[float, float],  # e.g. (4, 1)
+    min_ratio: Optional[tuple[float, float]] = None,  # e.g. (1, 4)
+    max_ratio: Optional[tuple[float, float]] = None,  # e.g. (4, 1)
     *,
     strict: bool = True,  # True -> (min, max); False -> [min, max]
 ) -> float:
-    a1, b1 = min_ratio
-    a2, b2 = max_ratio
-    if a1 <= 0 or b1 <= 0 or a2 <= 0 or b2 <= 0:
-        raise ValueError("Ratios must be positive, like (1, 4) or (4, 1).")
-    lo, hi = (a1 / b1), (a2 / b2)
-    if lo > hi:
-        lo, hi = hi, lo
-        a1, b1, a2, b2 = a2, b2, a1, b1  # swap only for error text
+    """Validates that image aspect ratio is within min and max. If a bound is None, that side is not checked."""
     w, h = get_image_dimensions(image)
     if w <= 0 or h <= 0:
         raise ValueError(f"Invalid image dimensions: {w}x{h}")
     ar = w / h
-    ok = (lo < ar < hi) if strict else (lo <= ar <= hi)
-    if not ok:
-        op = "<" if strict else "≤"
-        raise ValueError(f"Image aspect ratio {ar:.6g} is outside allowed range: {a1}:{b1} {op} ratio {op} {a2}:{b2}")
+    _assert_ratio_bounds(ar, min_ratio=min_ratio, max_ratio=max_ratio, strict=strict)
     return ar
 
 
-def validate_aspect_ratio_closeness(
-    start_img,
-    end_img,
-    min_rel: float,
-    max_rel: float,
+def validate_images_aspect_ratio_closeness(
+    first_image: torch.Tensor,
+    second_image: torch.Tensor,
+    min_rel: float,  # e.g. 0.8
+    max_rel: float,  # e.g. 1.25
     *,
-    strict: bool = False,  # True => exclusive, False => inclusive
-) -> None:
-    w1, h1 = get_image_dimensions(start_img)
-    w2, h2 = get_image_dimensions(end_img)
+    strict: bool = False,  # True -> (min, max); False -> [min, max]
+) -> float:
+    """
+    Validates that the two images' aspect ratios are 'close'.
+
+    The closeness factor is C = max(ar1, ar2) / min(ar1, ar2) (C >= 1).
+    We require C <= limit, where limit = max(max_rel, 1.0 / min_rel).
+
+    Returns the computed closeness factor C.
+    """
+    w1, h1 = get_image_dimensions(first_image)
+    w2, h2 = get_image_dimensions(second_image)
     if min(w1, h1, w2, h2) <= 0:
         raise ValueError("Invalid image dimensions")
     ar1 = w1 / h1
     ar2 = w2 / h2
-    # Normalize so it is symmetric (no need to check both ar1/ar2 and ar2/ar1)
     closeness = max(ar1, ar2) / min(ar1, ar2)
-    limit = max(max_rel, 1.0 / min_rel)  # for 0.8..1.25 this is 1.25
+    limit = max(max_rel, 1.0 / min_rel)
     if (closeness >= limit) if strict else (closeness > limit):
-        raise ValueError(f"Aspect ratios must be close: start/end={ar1/ar2:.4f}, allowed range {min_rel}–{max_rel}.")
+        raise ValueError(
+            f"Aspect ratios must be close: ar1/ar2={ar1/ar2:.2g}, "
+            f"allowed range {min_rel}–{max_rel} (limit {limit:.2g})."
+        )
+    return closeness
+
+
+def validate_aspect_ratio_string(
+    aspect_ratio: str,
+    min_ratio: Optional[tuple[float, float]] = None,  # e.g. (1, 4)
+    max_ratio: Optional[tuple[float, float]] = None,  # e.g. (4, 1)
+    *,
+    strict: bool = False,  # True -> (min, max); False -> [min, max]
+) -> float:
+    """Parses 'X:Y' and validates it against optional bounds. Returns the numeric ratio."""
+    ar = _parse_aspect_ratio_string(aspect_ratio)
+    _assert_ratio_bounds(ar, min_ratio=min_ratio, max_ratio=max_ratio, strict=strict)
+    return ar
 
 
 def validate_video_dimensions(
@@ -183,3 +182,49 @@ def validate_container_format_is_mp4(video: VideoInput) -> None:
     container_format = video.get_container_format()
     if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
         raise ValueError(f"Only MP4 container format supported. Got: {container_format}")
+
+
+def _ratio_from_tuple(r: tuple[float, float]) -> float:
+    a, b = r
+    if a <= 0 or b <= 0:
+        raise ValueError(f"Ratios must be positive, got {a}:{b}.")
+    return a / b
+
+
+def _assert_ratio_bounds(
+    ar: float,
+    *,
+    min_ratio: Optional[tuple[float, float]] = None,
+    max_ratio: Optional[tuple[float, float]] = None,
+    strict: bool = True,
+) -> None:
+    """Validate a numeric aspect ratio against optional min/max ratio bounds."""
+    lo = _ratio_from_tuple(min_ratio) if min_ratio is not None else None
+    hi = _ratio_from_tuple(max_ratio) if max_ratio is not None else None
+    if lo is not None and hi is not None and lo > hi:
+        lo, hi = hi, lo  # normalize order if caller swapped them
+    if lo is not None:
+        if (ar <= lo) if strict else (ar < lo):
+            op = ">" if strict else "≥"
+            raise ValueError(f"Aspect ratio `{ar:.2g}` must be {op} {lo:.2g}.")
+    if hi is not None:
+        if (ar >= hi) if strict else (ar > hi):
+            op = "<" if strict else "≤"
+            raise ValueError(f"Aspect ratio `{ar:.2g}` must be {op} {hi:.2g}.")
+
+
+def _parse_aspect_ratio_string(ar_str: str) -> float:
+    """Parse 'X:Y' with integer parts into a positive float ratio X/Y."""
+    parts = ar_str.split(":")
+    if len(parts) != 2:
+        raise ValueError(f"Aspect ratio must be 'X:Y' (e.g., 16:9), got '{ar_str}'.")
+    try:
+        a = int(parts[0].strip())
+        b = int(parts[1].strip())
+    except ValueError as exc:
+        raise ValueError(f"Aspect ratio must contain integers separated by ':', got '{ar_str}'.") from exc
+    if a <= 0 or b <= 0:
+        raise ValueError(f"Aspect ratio parts must be positive integers, got {a}:{b}.")
+    return a / b
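Note: validate_images_aspect_ratio_closeness now returns the closeness factor instead of None, which makes it usable in assertions and logs. A worked example (assumes [B, H, W, C] tensors):

    import torch
    from comfy_api_nodes.util import validate_images_aspect_ratio_closeness

    first = torch.zeros(1, 720, 1280, 3)   # 16:9  -> ar ≈ 1.78
    second = torch.zeros(1, 640, 1024, 3)  # 16:10 -> ar = 1.60
    c = validate_images_aspect_ratio_closeness(first, second, min_rel=0.8, max_rel=1.25)
    print(c)  # ≈ 1.11, inside the limit max(1.25, 1/0.8) = 1.25

    wide = torch.zeros(1, 480, 1280, 3)    # ar ≈ 2.67 -> closeness ≈ 1.5, raises ValueError
    validate_images_aspect_ratio_closeness(first, wide, min_rel=0.8, max_rel=1.25)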