use new API client in Pixverse and Ideogram nodes (#10543)
commit 163b629c70
parent 998bf60beb
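
Editor's note: the commit is a mechanical migration. Hand-assembled SynchronousOperation / PollingOperation objects from comfy_api_nodes.apis.client (with explicit HttpMethod, request_model, and auth_kwargs plumbing) are replaced by the sync_op / poll_op helpers in comfy_api_nodes.util, which take the node class and derive auth from it. A minimal before/after sketch of the pattern, using a hypothetical /proxy/example endpoint and placeholder ExampleRequest/ExampleResponse models (the real calls appear in the hunks below):

# Before: construct an operation object, then execute it.
auth = {
    "auth_token": cls.hidden.auth_token_comfy_org,
    "comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
    endpoint=ApiEndpoint(
        path="/proxy/example/generate",   # hypothetical endpoint
        method=HttpMethod.POST,
        request_model=ExampleRequest,     # placeholder models
        response_model=ExampleResponse,
    ),
    request=ExampleRequest(prompt=prompt),
    auth_kwargs=auth,
)
response = await operation.execute()

# After: one awaited call; auth comes from the node class (cls),
# the HTTP method is a plain string, and request_model is gone.
response = await sync_op(
    cls,
    ApiEndpoint(path="/proxy/example/generate", method="POST"),
    response_model=ExampleResponse,
    data=ExampleRequest(prompt=prompt),
)
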
@@ -1,15 +1,12 @@
 from __future__ import annotations
 import aiohttp
 import mimetypes
-from typing import Optional, Union
-from comfy.utils import common_upscale
+from typing import Union
 from server import PromptServer
-from comfy.cli_args import args

 import numpy as np
 from PIL import Image
 import torch
-import math
 import base64
 from io import BytesIO

@@ -60,85 +57,6 @@ async def validate_and_cast_response(
     return torch.stack(image_tensors, dim=0)


-def validate_aspect_ratio(
-    aspect_ratio: str,
-    minimum_ratio: float,
-    maximum_ratio: float,
-    minimum_ratio_str: str,
-    maximum_ratio_str: str,
-) -> float:
-    """Validates and casts an aspect ratio string to a float.
-
-    Args:
-        aspect_ratio: The aspect ratio string to validate.
-        minimum_ratio: The minimum aspect ratio.
-        maximum_ratio: The maximum aspect ratio.
-        minimum_ratio_str: The minimum aspect ratio string.
-        maximum_ratio_str: The maximum aspect ratio string.
-
-    Returns:
-        The validated and cast aspect ratio.
-
-    Raises:
-        Exception: If the aspect ratio is not valid.
-    """
-    # get ratio values
-    numbers = aspect_ratio.split(":")
-    if len(numbers) != 2:
-        raise TypeError(
-            f"Aspect ratio must be in the format X:Y, such as 16:9, but was {aspect_ratio}."
-        )
-    try:
-        numerator = int(numbers[0])
-        denominator = int(numbers[1])
-    except ValueError as exc:
-        raise TypeError(
-            f"Aspect ratio must contain numbers separated by ':', such as 16:9, but was {aspect_ratio}."
-        ) from exc
-    calculated_ratio = numerator / denominator
-    # if not close to minimum and maximum, check bounds
-    if not math.isclose(calculated_ratio, minimum_ratio) or not math.isclose(
-        calculated_ratio, maximum_ratio
-    ):
-        if calculated_ratio < minimum_ratio:
-            raise TypeError(
-                f"Aspect ratio cannot reduce to any less than {minimum_ratio_str} ({minimum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
-            )
-        if calculated_ratio > maximum_ratio:
-            raise TypeError(
-                f"Aspect ratio cannot reduce to any greater than {maximum_ratio_str} ({maximum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
-            )
-    return aspect_ratio
-
-
-async def download_url_to_bytesio(
-    url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None
-) -> BytesIO:
-    """Downloads content from a URL using requests and returns it as BytesIO.
-
-    Args:
-        url: The URL to download.
-        timeout: Request timeout in seconds. Defaults to None (no timeout).
-
-    Returns:
-        BytesIO object containing the downloaded content.
-    """
-    headers = {}
-    if url.startswith("/proxy/"):
-        url = str(args.comfy_api_base).rstrip("/") + url
-        auth_token = auth_kwargs.get("auth_token")
-        comfy_api_key = auth_kwargs.get("comfy_api_key")
-        if auth_token:
-            headers["Authorization"] = f"Bearer {auth_token}"
-        elif comfy_api_key:
-            headers["X-API-KEY"] = comfy_api_key
-    timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None
-    async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
-        async with session.get(url, headers=headers) as resp:
-            resp.raise_for_status()  # Raises HTTPError for bad responses (4XX or 5XX)
-            return BytesIO(await resp.read())
-
-
 def text_filepath_to_base64_string(filepath: str) -> str:
     """Converts a text file to a base64 string."""
     with open(filepath, "rb") as f:
@@ -153,28 +71,3 @@ def text_filepath_to_data_uri(filepath: str) -> str:
     if mime_type is None:
         mime_type = "application/octet-stream"
     return f"data:{mime_type};base64,{base64_string}"
-
-
-def resize_mask_to_image(
-    mask: torch.Tensor,
-    image: torch.Tensor,
-    upscale_method="nearest-exact",
-    crop="disabled",
-    allow_gradient=True,
-    add_channel_dim=False,
-):
-    """
-    Resize mask to be the same dimensions as an image, while maintaining proper format for API calls.
-    """
-    _, H, W, _ = image.shape
-    mask = mask.unsqueeze(-1)
-    mask = mask.movedim(-1, 1)
-    mask = common_upscale(
-        mask, width=W, height=H, upscale_method=upscale_method, crop=crop
-    )
-    mask = mask.movedim(1, -1)
-    if not add_channel_dim:
-        mask = mask.squeeze(-1)
-    if not allow_gradient:
-        mask = (mask > 0.5).float()
-    return mask
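
Editor's note: the three helpers removed from this utilities file are relocated or replaced rather than dropped. resize_mask_to_image moves into comfy_api_nodes/util (conversions hunk near the end of this commit), string validation is superseded by validate_aspect_ratio_string, and downloads go through the util client's download_url_as_bytesio, which presumably resolves /proxy/ URLs and auth headers itself. The replacement download path, as used by the Ideogram node below:

from comfy_api_nodes.util import bytesio_to_image_tensor, download_url_as_bytesio

# Replaces the deleted download_url_to_bytesio() plus its manual
# Authorization / X-API-KEY header handling.
image_bytesio = await download_url_as_bytesio(image_url)
img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB")
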
@@ -5,10 +5,6 @@ import torch
 from typing_extensions import override

 from comfy_api.latest import IO, ComfyExtension
-from comfy_api_nodes.apinode_utils import (
-    resize_mask_to_image,
-    validate_aspect_ratio,
-)
 from comfy_api_nodes.apis.bfl_api import (
     BFLFluxExpandImageRequest,
     BFLFluxFillImageRequest,
@@ -23,8 +19,10 @@ from comfy_api_nodes.util import (
     ApiEndpoint,
     download_url_to_image_tensor,
     poll_op,
+    resize_mask_to_image,
     sync_op,
     tensor_to_base64_string,
+    validate_aspect_ratio_string,
     validate_string,
 )

@@ -43,11 +41,6 @@ class FluxProUltraImageNode(IO.ComfyNode):
     Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.
     """

-    MINIMUM_RATIO = 1 / 4
-    MAXIMUM_RATIO = 4 / 1
-    MINIMUM_RATIO_STR = "1:4"
-    MAXIMUM_RATIO_STR = "4:1"
-
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
@@ -112,16 +105,7 @@ class FluxProUltraImageNode(IO.ComfyNode):

     @classmethod
     def validate_inputs(cls, aspect_ratio: str):
-        try:
-            validate_aspect_ratio(
-                aspect_ratio,
-                minimum_ratio=cls.MINIMUM_RATIO,
-                maximum_ratio=cls.MAXIMUM_RATIO,
-                minimum_ratio_str=cls.MINIMUM_RATIO_STR,
-                maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
-            )
-        except Exception as e:
-            return str(e)
+        validate_aspect_ratio_string(aspect_ratio, (1, 4), (4, 1))
         return True

     @classmethod
@@ -145,13 +129,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
                 prompt=prompt,
                 prompt_upsampling=prompt_upsampling,
                 seed=seed,
-                aspect_ratio=validate_aspect_ratio(
-                    aspect_ratio,
-                    minimum_ratio=cls.MINIMUM_RATIO,
-                    maximum_ratio=cls.MAXIMUM_RATIO,
-                    minimum_ratio_str=cls.MINIMUM_RATIO_STR,
-                    maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
-                ),
+                aspect_ratio=aspect_ratio,
                 raw=raw,
                 image_prompt=(image_prompt if image_prompt is None else tensor_to_base64_string(image_prompt)),
                 image_prompt_strength=(None if image_prompt is None else round(image_prompt_strength, 2)),
@@ -180,11 +158,6 @@ class FluxKontextProImageNode(IO.ComfyNode):
     Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.
     """

-    MINIMUM_RATIO = 1 / 4
-    MAXIMUM_RATIO = 4 / 1
-    MINIMUM_RATIO_STR = "1:4"
-    MAXIMUM_RATIO_STR = "4:1"
-
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
@@ -261,13 +234,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
         seed=0,
         prompt_upsampling=False,
     ) -> IO.NodeOutput:
-        aspect_ratio = validate_aspect_ratio(
-            aspect_ratio,
-            minimum_ratio=cls.MINIMUM_RATIO,
-            maximum_ratio=cls.MAXIMUM_RATIO,
-            minimum_ratio_str=cls.MINIMUM_RATIO_STR,
-            maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
-        )
+        validate_aspect_ratio_string(aspect_ratio, (1, 4), (4, 1))
         if input_image is None:
             validate_string(prompt, strip_whitespace=False)
         initial_response = await sync_op(
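
Editor's note: the four MINIMUM_/MAXIMUM_RATIO class constants disappear because validate_aspect_ratio_string takes self-describing (numerator, denominator) tuples and no longer needs the parallel *_STR arguments for error messages. A usage sketch mirroring validate_inputs above; the failure behavior is presumed from the validator it replaces:

validate_aspect_ratio_string("16:9", (1, 4), (4, 1))  # ok: 1/4 <= 16/9 <= 4/1
validate_aspect_ratio_string("9:1", (1, 4), (4, 1))   # presumably raises: 9/1 > 4/1
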
@@ -17,7 +17,7 @@ from comfy_api_nodes.util import (
     poll_op,
     sync_op,
     upload_images_to_comfyapi,
-    validate_image_aspect_ratio_range,
+    validate_image_aspect_ratio,
     validate_image_dimensions,
     validate_string,
 )
@@ -403,7 +403,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
         validate_string(prompt, strip_whitespace=True, min_length=1)
         if get_number_of_images(image) != 1:
             raise ValueError("Exactly one input image is required.")
-        validate_image_aspect_ratio_range(image, (1, 3), (3, 1))
+        validate_image_aspect_ratio(image, (1, 3), (3, 1))
         source_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0]
         payload = Image2ImageTaskCreationRequest(
             model=model,
@@ -565,7 +565,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
         reference_images_urls = []
         if n_input_images:
             for i in image:
-                validate_image_aspect_ratio_range(i, (1, 3), (3, 1))
+                validate_image_aspect_ratio(i, (1, 3), (3, 1))
             reference_images_urls = await upload_images_to_comfyapi(
                 cls,
                 image,
@@ -798,7 +798,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
         validate_string(prompt, strip_whitespace=True, min_length=1)
         raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"])
         validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000)
-        validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False)  # 0.4 to 2.5
+        validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False)  # 0.4 to 2.5

         image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0]
         prompt = (
@@ -923,7 +923,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
         raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"])
         for i in (first_frame, last_frame):
             validate_image_dimensions(i, min_width=300, min_height=300, max_width=6000, max_height=6000)
-            validate_image_aspect_ratio_range(i, (2, 5), (5, 2), strict=False)  # 0.4 to 2.5
+            validate_image_aspect_ratio(i, (2, 5), (5, 2), strict=False)  # 0.4 to 2.5

         download_urls = await upload_images_to_comfyapi(
             cls,
@@ -1045,7 +1045,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
         raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "watermark"])
         for image in images:
             validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000)
-            validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False)  # 0.4 to 2.5
+            validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False)  # 0.4 to 2.5

         image_urls = await upload_images_to_comfyapi(cls, images, max_images=4, mime_type="image/png")
         prompt = (
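
Editor's note: every ByteDance call site changes only in name. validate_image_aspect_ratio_range becomes validate_image_aspect_ratio with identical tuple bounds and the same strict keyword (per the validation_utils hunk at the end of the commit, strict=True excludes the endpoints, strict=False includes them). For example:

validate_image_aspect_ratio(image, (1, 3), (3, 1))                # open interval (1/3, 3)
validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False)  # closed interval [0.4, 2.5]
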
@@ -1,6 +1,6 @@
 from io import BytesIO
 from typing_extensions import override
-from comfy_api.latest import ComfyExtension, IO
+from comfy_api.latest import IO, ComfyExtension
 from PIL import Image
 import numpy as np
 import torch
@@ -11,19 +11,13 @@ from comfy_api_nodes.apis import (
     IdeogramV3Request,
     IdeogramV3EditRequest,
 )
-from comfy_api_nodes.apis.client import (
+from comfy_api_nodes.util import (
     ApiEndpoint,
-    HttpMethod,
-    SynchronousOperation,
-)
-
-from comfy_api_nodes.apinode_utils import (
-    download_url_to_bytesio,
+    bytesio_to_image_tensor,
+    download_url_as_bytesio,
     resize_mask_to_image,
+    sync_op,
 )
-from comfy_api_nodes.util import bytesio_to_image_tensor
-from server import PromptServer

 V1_V1_RES_MAP = {
     "Auto":"AUTO",
@@ -220,7 +214,7 @@ async def download_and_process_images(image_urls):

     for image_url in image_urls:
         # Using functions from apinode_utils.py to handle downloading and processing
-        image_bytesio = await download_url_to_bytesio(image_url)  # Download image content to BytesIO
+        image_bytesio = await download_url_as_bytesio(image_url)  # Download image content to BytesIO
         img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB")  # Convert to torch.Tensor with RGB mode
         image_tensors.append(img_tensor)

@@ -233,19 +227,6 @@ async def download_and_process_images(image_urls):
     return stacked_tensors


-def display_image_urls_on_node(image_urls, node_id):
-    if node_id and image_urls:
-        if len(image_urls) == 1:
-            PromptServer.instance.send_progress_text(
-                f"Generated Image URL:\n{image_urls[0]}", node_id
-            )
-        else:
-            urls_text = "Generated Image URLs:\n" + "\n".join(
-                f"{i+1}. {url}" for i, url in enumerate(image_urls)
-            )
-            PromptServer.instance.send_progress_text(urls_text, node_id)
-
-
 class IdeogramV1(IO.ComfyNode):

     @classmethod
@@ -334,44 +315,30 @@ class IdeogramV1(IO.ComfyNode):
         aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
         model = "V_1_TURBO" if turbo else "V_1"

-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/ideogram/generate",
-                method=HttpMethod.POST,
-                request_model=IdeogramGenerateRequest,
-                response_model=IdeogramGenerateResponse,
-            ),
-            request=IdeogramGenerateRequest(
+        response = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
+            response_model=IdeogramGenerateResponse,
+            data=IdeogramGenerateRequest(
                 image_request=ImageRequest(
                     prompt=prompt,
                     model=model,
                     num_images=num_images,
                     seed=seed,
                     aspect_ratio=aspect_ratio if aspect_ratio != "ASPECT_1_1" else None,
-                    magic_prompt_option=(
-                        magic_prompt_option if magic_prompt_option != "AUTO" else None
-                    ),
+                    magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
                     negative_prompt=negative_prompt if negative_prompt else None,
                 )
             ),
-            auth_kwargs=auth,
+            max_retries=1,
         )

-        response = await operation.execute()
-
         if not response.data or len(response.data) == 0:
             raise Exception("No images were generated in the response")

         image_urls = [image_data.url for image_data in response.data if image_data.url]

         if not image_urls:
             raise Exception("No image URLs were generated in the response")

-        display_image_urls_on_node(image_urls, cls.hidden.unique_id)
         return IO.NodeOutput(await download_and_process_images(image_urls))
@@ -500,18 +467,11 @@ class IdeogramV2(IO.ComfyNode):
         else:
             final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None

-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/ideogram/generate",
-                method=HttpMethod.POST,
-                request_model=IdeogramGenerateRequest,
-                response_model=IdeogramGenerateResponse,
-            ),
-            request=IdeogramGenerateRequest(
+        response = await sync_op(
+            cls,
+            endpoint=ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
+            response_model=IdeogramGenerateResponse,
+            data=IdeogramGenerateRequest(
                 image_request=ImageRequest(
                     prompt=prompt,
                     model=model,
@@ -519,28 +479,20 @@ class IdeogramV2(IO.ComfyNode):
                     seed=seed,
                     aspect_ratio=final_aspect_ratio,
                     resolution=final_resolution,
-                    magic_prompt_option=(
-                        magic_prompt_option if magic_prompt_option != "AUTO" else None
-                    ),
+                    magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
                     style_type=style_type if style_type != "NONE" else None,
                     negative_prompt=negative_prompt if negative_prompt else None,
                     color_palette=color_palette if color_palette else None,
                 )
             ),
-            auth_kwargs=auth,
+            max_retries=1,
         )

-        response = await operation.execute()
-
         if not response.data or len(response.data) == 0:
             raise Exception("No images were generated in the response")

         image_urls = [image_data.url for image_data in response.data if image_data.url]

         if not image_urls:
             raise Exception("No image URLs were generated in the response")

-        display_image_urls_on_node(image_urls, cls.hidden.unique_id)
         return IO.NodeOutput(await download_and_process_images(image_urls))
@@ -656,10 +608,6 @@ class IdeogramV3(IO.ComfyNode):
         character_image=None,
         character_mask=None,
     ):
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
         if rendering_speed == "BALANCED":  # for backward compatibility
             rendering_speed = "DEFAULT"

@@ -694,9 +642,6 @@ class IdeogramV3(IO.ComfyNode):

         # Check if both image and mask are provided for editing mode
         if image is not None and mask is not None:
-            # Edit mode
-            path = "/proxy/ideogram/ideogram-v3/edit"
-
             # Process image and mask
             input_tensor = image.squeeze().cpu()
             # Resize mask to match image dimension
@@ -749,27 +694,20 @@ class IdeogramV3(IO.ComfyNode):
             if character_mask_binary:
                 files["character_mask_binary"] = character_mask_binary

-            # Execute the operation for edit mode
-            operation = SynchronousOperation(
-                endpoint=ApiEndpoint(
-                    path=path,
-                    method=HttpMethod.POST,
-                    request_model=IdeogramV3EditRequest,
-                    response_model=IdeogramGenerateResponse,
-                ),
-                request=edit_request,
+            response = await sync_op(
+                cls,
+                ApiEndpoint(path="/proxy/ideogram/ideogram-v3/edit", method="POST"),
+                response_model=IdeogramGenerateResponse,
+                data=edit_request,
                 files=files,
                 content_type="multipart/form-data",
-                auth_kwargs=auth,
+                max_retries=1,
             )

         elif image is not None or mask is not None:
             # If only one of image or mask is provided, raise an error
             raise Exception("Ideogram V3 image editing requires both an image AND a mask")
         else:
-            # Generation mode
-            path = "/proxy/ideogram/ideogram-v3/generate"
-
             # Create generation request
             gen_request = IdeogramV3Request(
                 prompt=prompt,
@@ -800,32 +738,22 @@ class IdeogramV3(IO.ComfyNode):
             if files:
                 gen_request.style_type = "AUTO"

-            # Execute the operation for generation mode
-            operation = SynchronousOperation(
-                endpoint=ApiEndpoint(
-                    path=path,
-                    method=HttpMethod.POST,
-                    request_model=IdeogramV3Request,
-                    response_model=IdeogramGenerateResponse,
-                ),
-                request=gen_request,
+            response = await sync_op(
+                cls,
+                endpoint=ApiEndpoint(path="/proxy/ideogram/ideogram-v3/generate", method="POST"),
+                response_model=IdeogramGenerateResponse,
+                data=gen_request,
                 files=files if files else None,
                 content_type="multipart/form-data",
-                auth_kwargs=auth,
+                max_retries=1,
            )

-        # Execute the operation and process response
-        response = await operation.execute()
-
         if not response.data or len(response.data) == 0:
             raise Exception("No images were generated in the response")

         image_urls = [image_data.url for image_data in response.data if image_data.url]

         if not image_urls:
             raise Exception("No image URLs were generated in the response")

-        display_image_urls_on_node(image_urls, cls.hidden.unique_id)
         return IO.NodeOutput(await download_and_process_images(image_urls))
@@ -838,5 +766,6 @@ class IdeogramExtension(ComfyExtension):
             IdeogramV3,
         ]

+
 async def comfy_entrypoint() -> IdeogramExtension:
     return IdeogramExtension()
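
Editor's note: the Ideogram V3 node shows that sync_op also covers the multipart cases: files= and content_type= carry over unchanged, and only the endpoint/request_model plumbing shrinks. Condensed from the edit branch above (edit_request is the IdeogramV3EditRequest built earlier in the method):

response = await sync_op(
    cls,
    ApiEndpoint(path="/proxy/ideogram/ideogram-v3/edit", method="POST"),
    response_model=IdeogramGenerateResponse,
    data=edit_request,
    files=files,  # image / mask / character reference binaries
    content_type="multipart/form-data",
    max_retries=1,
)

Note also that display_image_urls_on_node and the PromptServer import are deleted outright with no in-file replacement; presumably the new client takes over progress reporting.
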
@@ -282,7 +282,7 @@ def validate_input_image(image: torch.Tensor) -> None:
     See: https://app.klingai.com/global/dev/document-api/apiReference/model/imageToVideo
     """
     validate_image_dimensions(image, min_width=300, min_height=300)
-    validate_image_aspect_ratio(image, min_aspect_ratio=1 / 2.5, max_aspect_ratio=2.5)
+    validate_image_aspect_ratio(image, (1, 2.5), (2.5, 1))


 def get_video_from_response(response) -> KlingVideoResult:
@@ -1,7 +1,6 @@
-from inspect import cleandoc
-from typing import Optional
+import torch
 from typing_extensions import override
-from io import BytesIO
+from comfy_api.latest import IO, ComfyExtension
 from comfy_api_nodes.apis.pixverse_api import (
     PixverseTextVideoRequest,
     PixverseImageVideoRequest,
@@ -17,53 +16,30 @@ from comfy_api_nodes.apis.pixverse_api import (
     PixverseIO,
     pixverse_templates,
 )
-from comfy_api_nodes.apis.client import (
+from comfy_api_nodes.util import (
     ApiEndpoint,
-    HttpMethod,
-    SynchronousOperation,
-    PollingOperation,
-    EmptyRequest,
+    download_url_to_video_output,
+    poll_op,
+    sync_op,
+    tensor_to_bytesio,
+    validate_string,
 )
-from comfy_api_nodes.util import validate_string, tensor_to_bytesio
-from comfy_api.input_impl import VideoFromFile
-from comfy_api.latest import ComfyExtension, IO
-
-import torch
-import aiohttp
-

 AVERAGE_DURATION_T2V = 32
 AVERAGE_DURATION_I2V = 30
 AVERAGE_DURATION_T2T = 52


-def get_video_url_from_response(
-    response: PixverseGenerationStatusResponse,
-) -> Optional[str]:
-    if response.Resp is None or response.Resp.url is None:
-        return None
-    return str(response.Resp.url)
-
-
-async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
-    # first, upload image to Pixverse and get image id to use in actual generation call
-    operation = SynchronousOperation(
-        endpoint=ApiEndpoint(
-            path="/proxy/pixverse/image/upload",
-            method=HttpMethod.POST,
-            request_model=EmptyRequest,
-            response_model=PixverseImageUploadResponse,
-        ),
-        request=EmptyRequest(),
+async def upload_image_to_pixverse(cls: type[IO.ComfyNode], image: torch.Tensor):
+    response_upload = await sync_op(
+        cls,
+        ApiEndpoint(path="/proxy/pixverse/image/upload", method="POST"),
+        response_model=PixverseImageUploadResponse,
         files={"image": tensor_to_bytesio(image)},
         content_type="multipart/form-data",
-        auth_kwargs=auth_kwargs,
     )
-    response_upload: PixverseImageUploadResponse = await operation.execute()
-
     if response_upload.Resp is None:
         raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'")
-
     return response_upload.Resp.img_id


@@ -93,17 +69,13 @@ class PixverseTemplateNode(IO.ComfyNode):


 class PixverseTextToVideoNode(IO.ComfyNode):
-    """
-    Generates videos based on prompt and output_size.
-    """
-
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
             node_id="PixverseTextToVideoNode",
             display_name="PixVerse Text to Video",
             category="api node/video/PixVerse",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Generates videos based on prompt and output_size.",
             inputs=[
                 IO.String.Input(
                     "prompt",
@@ -170,7 +142,7 @@ class PixverseTextToVideoNode(IO.ComfyNode):
         negative_prompt: str = None,
         pixverse_template: int = None,
     ) -> IO.NodeOutput:
-        validate_string(prompt, strip_whitespace=False)
+        validate_string(prompt, strip_whitespace=False, min_length=1)
         # 1080p is limited to 5 seconds duration
         # only normal motion_mode supported for 1080p or for non-5 second duration
         if quality == PixverseQuality.res_1080p:
@@ -179,18 +151,11 @@ class PixverseTextToVideoNode(IO.ComfyNode):
         elif duration_seconds != PixverseDuration.dur_5:
             motion_mode = PixverseMotionMode.normal

-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/pixverse/video/text/generate",
-                method=HttpMethod.POST,
-                request_model=PixverseTextVideoRequest,
-                response_model=PixverseVideoResponse,
-            ),
-            request=PixverseTextVideoRequest(
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/pixverse/video/text/generate", method="POST"),
+            response_model=PixverseVideoResponse,
+            data=PixverseTextVideoRequest(
                 prompt=prompt,
                 aspect_ratio=aspect_ratio,
                 quality=quality,
@@ -200,20 +165,14 @@ class PixverseTextToVideoNode(IO.ComfyNode):
                 template_id=pixverse_template,
                 seed=seed,
             ),
-            auth_kwargs=auth,
         )
-        response_api = await operation.execute()
-
         if response_api.Resp is None:
             raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")

-        operation = PollingOperation(
-            poll_endpoint=ApiEndpoint(
-                path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
-                method=HttpMethod.GET,
-                request_model=EmptyRequest,
-                response_model=PixverseGenerationStatusResponse,
-            ),
+        response_poll = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
+            response_model=PixverseGenerationStatusResponse,
             completed_statuses=[PixverseStatus.successful],
             failed_statuses=[
                 PixverseStatus.contents_moderation,
@@ -221,30 +180,19 @@ class PixverseTextToVideoNode(IO.ComfyNode):
                 PixverseStatus.deleted,
             ],
             status_extractor=lambda x: x.Resp.status,
-            auth_kwargs=auth,
-            node_id=cls.hidden.unique_id,
-            result_url_extractor=get_video_url_from_response,
             estimated_duration=AVERAGE_DURATION_T2V,
         )
-        response_poll = await operation.execute()
-
-        async with aiohttp.ClientSession() as session:
-            async with session.get(response_poll.Resp.url) as vid_response:
-                return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
+        return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))


 class PixverseImageToVideoNode(IO.ComfyNode):
-    """
-    Generates videos based on prompt and output_size.
-    """
-
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
             node_id="PixverseImageToVideoNode",
             display_name="PixVerse Image to Video",
             category="api node/video/PixVerse",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Generates videos based on prompt and output_size.",
             inputs=[
                 IO.Image.Input("image"),
                 IO.String.Input(
@@ -309,11 +257,7 @@ class PixverseImageToVideoNode(IO.ComfyNode):
         pixverse_template: int = None,
     ) -> IO.NodeOutput:
         validate_string(prompt, strip_whitespace=False)
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        img_id = await upload_image_to_pixverse(image, auth_kwargs=auth)
+        img_id = await upload_image_to_pixverse(cls, image)

         # 1080p is limited to 5 seconds duration
         # only normal motion_mode supported for 1080p or for non-5 second duration
@@ -323,14 +267,11 @@ class PixverseImageToVideoNode(IO.ComfyNode):
         elif duration_seconds != PixverseDuration.dur_5:
             motion_mode = PixverseMotionMode.normal

-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/pixverse/video/img/generate",
-                method=HttpMethod.POST,
-                request_model=PixverseImageVideoRequest,
-                response_model=PixverseVideoResponse,
-            ),
-            request=PixverseImageVideoRequest(
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/pixverse/video/img/generate", method="POST"),
+            response_model=PixverseVideoResponse,
+            data=PixverseImageVideoRequest(
                 img_id=img_id,
                 prompt=prompt,
                 quality=quality,
@@ -340,20 +281,15 @@ class PixverseImageToVideoNode(IO.ComfyNode):
                 template_id=pixverse_template,
                 seed=seed,
             ),
-            auth_kwargs=auth,
         )
-        response_api = await operation.execute()
-
         if response_api.Resp is None:
             raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")

-        operation = PollingOperation(
-            poll_endpoint=ApiEndpoint(
-                path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
-                method=HttpMethod.GET,
-                request_model=EmptyRequest,
-                response_model=PixverseGenerationStatusResponse,
-            ),
+        response_poll = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
+            response_model=PixverseGenerationStatusResponse,
             completed_statuses=[PixverseStatus.successful],
             failed_statuses=[
                 PixverseStatus.contents_moderation,
@@ -361,30 +297,19 @@ class PixverseImageToVideoNode(IO.ComfyNode):
                 PixverseStatus.deleted,
             ],
             status_extractor=lambda x: x.Resp.status,
-            auth_kwargs=auth,
-            node_id=cls.hidden.unique_id,
-            result_url_extractor=get_video_url_from_response,
             estimated_duration=AVERAGE_DURATION_I2V,
         )
-        response_poll = await operation.execute()
-
-        async with aiohttp.ClientSession() as session:
-            async with session.get(response_poll.Resp.url) as vid_response:
-                return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
+        return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))


 class PixverseTransitionVideoNode(IO.ComfyNode):
-    """
-    Generates videos based on prompt and output_size.
-    """
-
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
             node_id="PixverseTransitionVideoNode",
             display_name="PixVerse Transition Video",
             category="api node/video/PixVerse",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Generates videos based on prompt and output_size.",
             inputs=[
                 IO.Image.Input("first_frame"),
                 IO.Image.Input("last_frame"),
@@ -445,12 +370,8 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
         negative_prompt: str = None,
     ) -> IO.NodeOutput:
         validate_string(prompt, strip_whitespace=False)
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=auth)
-        last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=auth)
+        first_frame_id = await upload_image_to_pixverse(cls, first_frame)
+        last_frame_id = await upload_image_to_pixverse(cls, last_frame)

         # 1080p is limited to 5 seconds duration
         # only normal motion_mode supported for 1080p or for non-5 second duration
@@ -460,14 +381,11 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
         elif duration_seconds != PixverseDuration.dur_5:
             motion_mode = PixverseMotionMode.normal

-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/pixverse/video/transition/generate",
-                method=HttpMethod.POST,
-                request_model=PixverseTransitionVideoRequest,
-                response_model=PixverseVideoResponse,
-            ),
-            request=PixverseTransitionVideoRequest(
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/pixverse/video/transition/generate", method="POST"),
+            response_model=PixverseVideoResponse,
+            data=PixverseTransitionVideoRequest(
                 first_frame_img=first_frame_id,
                 last_frame_img=last_frame_id,
                 prompt=prompt,
@@ -477,20 +395,15 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
                 negative_prompt=negative_prompt if negative_prompt else None,
                 seed=seed,
             ),
-            auth_kwargs=auth,
         )
-        response_api = await operation.execute()
-
         if response_api.Resp is None:
             raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")

-        operation = PollingOperation(
-            poll_endpoint=ApiEndpoint(
-                path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
-                method=HttpMethod.GET,
-                request_model=EmptyRequest,
-                response_model=PixverseGenerationStatusResponse,
-            ),
+        response_poll = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
+            response_model=PixverseGenerationStatusResponse,
             completed_statuses=[PixverseStatus.successful],
             failed_statuses=[
                 PixverseStatus.contents_moderation,
@@ -498,16 +411,9 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
                 PixverseStatus.deleted,
             ],
             status_extractor=lambda x: x.Resp.status,
-            auth_kwargs=auth,
-            node_id=cls.hidden.unique_id,
-            result_url_extractor=get_video_url_from_response,
             estimated_duration=AVERAGE_DURATION_T2V,
         )
-        response_poll = await operation.execute()
-
-        async with aiohttp.ClientSession() as session:
-            async with session.get(response_poll.Resp.url) as vid_response:
-                return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
+        return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))


 class PixVerseExtension(ComfyExtension):
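
Editor's note: the PixVerse polling half follows the same translation. PollingOperation with a poll_endpoint becomes one awaited poll_op, auth_kwargs / node_id / result_url_extractor drop out (which is what lets get_video_url_from_response be deleted), and the manual aiohttp + VideoFromFile download is replaced by download_url_to_video_output. The resulting tail of each PixVerse node, condensed from the hunks above (failure statuses abbreviated to the ones visible in the diff):

response_poll = await poll_op(
    cls,
    ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
    response_model=PixverseGenerationStatusResponse,
    completed_statuses=[PixverseStatus.successful],
    failed_statuses=[
        PixverseStatus.contents_moderation,
        # ...remaining failure statuses elided by the hunks...
        PixverseStatus.deleted,
    ],
    status_extractor=lambda x: x.Resp.status,
    estimated_duration=AVERAGE_DURATION_T2V,
)
return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
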
@@ -8,9 +8,6 @@ from typing_extensions import override

 from comfy.utils import ProgressBar
 from comfy_api.latest import IO, ComfyExtension
-from comfy_api_nodes.apinode_utils import (
-    resize_mask_to_image,
-)
 from comfy_api_nodes.apis.recraft_api import (
     RecraftColor,
     RecraftColorChain,
@@ -28,6 +25,7 @@ from comfy_api_nodes.util import (
     ApiEndpoint,
     bytesio_to_image_tensor,
     download_url_as_bytesio,
+    resize_mask_to_image,
     sync_op,
     tensor_to_bytesio,
     validate_string,
@@ -200,7 +200,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
     ) -> IO.NodeOutput:
         validate_string(prompt, min_length=1)
         validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
-        validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
+        validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))

         download_urls = await upload_images_to_comfyapi(
             cls,
@@ -290,7 +290,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
     ) -> IO.NodeOutput:
         validate_string(prompt, min_length=1)
         validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
-        validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
+        validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))

         download_urls = await upload_images_to_comfyapi(
             cls,
@@ -390,8 +390,8 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
         validate_string(prompt, min_length=1)
         validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
         validate_image_dimensions(end_frame, max_width=7999, max_height=7999)
-        validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
-        validate_image_aspect_ratio(end_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
+        validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
+        validate_image_aspect_ratio(end_frame, (1, 2), (2, 1))

         stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
         download_urls = await upload_images_to_comfyapi(
@@ -475,7 +475,7 @@ class RunwayTextToImageNode(IO.ComfyNode):
         reference_images = None
         if reference_image is not None:
             validate_image_dimensions(reference_image, max_width=7999, max_height=7999)
-            validate_image_aspect_ratio(reference_image, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
+            validate_image_aspect_ratio(reference_image, (1, 2), (2, 1))
             download_urls = await upload_images_to_comfyapi(
                 cls,
                 reference_image,
@@ -14,9 +14,9 @@ from comfy_api_nodes.util import (
     poll_op,
     sync_op,
     upload_images_to_comfyapi,
-    validate_aspect_ratio_closeness,
-    validate_image_aspect_ratio_range,
+    validate_image_aspect_ratio,
     validate_image_dimensions,
+    validate_images_aspect_ratio_closeness,
 )

 VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video"
@@ -114,7 +114,7 @@ async def execute_task(
         cls,
         ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % response.task_id),
         response_model=TaskStatusResponse,
-        status_extractor=lambda r: r.state.value,
+        status_extractor=lambda r: r.state,
         estimated_duration=estimated_duration,
     )

@@ -307,7 +307,7 @@ class ViduImageToVideoNode(IO.ComfyNode):
     ) -> IO.NodeOutput:
         if get_number_of_images(image) > 1:
             raise ValueError("Only one input image is allowed.")
-        validate_image_aspect_ratio_range(image, (1, 4), (4, 1))
+        validate_image_aspect_ratio(image, (1, 4), (4, 1))
         payload = TaskCreationRequest(
             model_name=model,
             prompt=prompt,
@@ -423,7 +423,7 @@ class ViduReferenceVideoNode(IO.ComfyNode):
         if a > 7:
             raise ValueError("Too many images, maximum allowed is 7.")
         for image in images:
-            validate_image_aspect_ratio_range(image, (1, 4), (4, 1))
+            validate_image_aspect_ratio(image, (1, 4), (4, 1))
             validate_image_dimensions(image, min_width=128, min_height=128)
         payload = TaskCreationRequest(
             model_name=model,
@@ -533,7 +533,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
         resolution: str,
         movement_amplitude: str,
     ) -> IO.NodeOutput:
-        validate_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
+        validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
         payload = TaskCreationRequest(
             model_name=model,
             prompt=prompt,
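
Editor's note: besides the validator renames, note the status_extractor change from r.state.value to r.state, so poll_op evidently accepts the enum member itself now. The renamed closeness check keeps its signature; judging from the keyword names, it requires the two frames' aspect ratios to stay within the given relative factors of each other:

# Presumed semantics: ratio(first_frame) / ratio(end_frame) must lie in
# [min_rel, max_rel], i.e. at most 25% apart either way here.
validate_images_aspect_ratio_closeness(
    first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False
)
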
@ -14,6 +14,7 @@ from .conversions import (
|
|||||||
downscale_image_tensor,
|
downscale_image_tensor,
|
||||||
image_tensor_pair_to_batch,
|
image_tensor_pair_to_batch,
|
||||||
pil_to_bytesio,
|
pil_to_bytesio,
|
||||||
|
resize_mask_to_image,
|
||||||
tensor_to_base64_string,
|
tensor_to_base64_string,
|
||||||
tensor_to_bytesio,
|
tensor_to_bytesio,
|
||||||
tensor_to_pil,
|
tensor_to_pil,
|
||||||
@ -34,12 +35,12 @@ from .upload_helpers import (
|
|||||||
)
|
)
|
||||||
from .validation_utils import (
|
from .validation_utils import (
|
||||||
get_number_of_images,
|
get_number_of_images,
|
||||||
validate_aspect_ratio_closeness,
|
validate_aspect_ratio_string,
|
||||||
validate_audio_duration,
|
validate_audio_duration,
|
||||||
validate_container_format_is_mp4,
|
validate_container_format_is_mp4,
|
||||||
validate_image_aspect_ratio,
|
validate_image_aspect_ratio,
|
||||||
validate_image_aspect_ratio_range,
|
|
||||||
validate_image_dimensions,
|
validate_image_dimensions,
|
||||||
|
validate_images_aspect_ratio_closeness,
|
||||||
validate_string,
|
validate_string,
|
||||||
validate_video_dimensions,
|
validate_video_dimensions,
|
||||||
validate_video_duration,
|
validate_video_duration,
|
||||||
@ -70,6 +71,7 @@ __all__ = [
|
|||||||
"downscale_image_tensor",
|
"downscale_image_tensor",
|
||||||
"image_tensor_pair_to_batch",
|
"image_tensor_pair_to_batch",
|
||||||
"pil_to_bytesio",
|
"pil_to_bytesio",
|
||||||
|
"resize_mask_to_image",
|
||||||
"tensor_to_base64_string",
|
"tensor_to_base64_string",
|
||||||
"tensor_to_bytesio",
|
"tensor_to_bytesio",
|
||||||
"tensor_to_pil",
|
"tensor_to_pil",
|
||||||
@@ -77,12 +79,12 @@ __all__ = [
     "video_to_base64_string",
     # Validation utilities
     "get_number_of_images",
-    "validate_aspect_ratio_closeness",
+    "validate_aspect_ratio_string",
     "validate_audio_duration",
     "validate_container_format_is_mp4",
     "validate_image_aspect_ratio",
-    "validate_image_aspect_ratio_range",
     "validate_image_dimensions",
+    "validate_images_aspect_ratio_closeness",
     "validate_string",
     "validate_video_dimensions",
     "validate_video_duration",
@@ -430,3 +430,24 @@ def audio_bytes_to_audio_input(audio_bytes: bytes) -> dict:
     wav = torch.cat(frames, dim=1)  # [C, T]
     wav = _f32_pcm(wav)
     return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr}
+
+
+def resize_mask_to_image(
+    mask: torch.Tensor,
+    image: torch.Tensor,
+    upscale_method="nearest-exact",
+    crop="disabled",
+    allow_gradient=True,
+    add_channel_dim=False,
+):
+    """Resize mask to be the same dimensions as an image, while maintaining proper format for API calls."""
+    _, height, width, _ = image.shape
+    mask = mask.unsqueeze(-1)
+    mask = mask.movedim(-1, 1)
+    mask = common_upscale(mask, width=width, height=height, upscale_method=upscale_method, crop=crop)
+    mask = mask.movedim(1, -1)
+    if not add_channel_dim:
+        mask = mask.squeeze(-1)
+    if not allow_gradient:
+        mask = (mask > 0.5).float()
+    return mask
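resize_mask_to_image wraps comfy.utils.common_upscale, which expects [B, C, H, W] input; the movedim calls adapt ComfyUI's [B, H, W] MASK and [B, H, W, C] IMAGE layouts around it. A short usage sketch (the shapes are illustrative assumptions, and the import path assumes the comfy_api_nodes.util re-export added above):

    import torch
    from comfy_api_nodes.util import resize_mask_to_image

    mask = torch.rand(1, 64, 64)         # [B, H, W] mask with soft values in [0, 1]
    image = torch.zeros(1, 512, 512, 3)  # [B, H, W, C] target image

    soft = resize_mask_to_image(mask, image)                        # -> [1, 512, 512]
    hard = resize_mask_to_image(mask, image, allow_gradient=False)  # thresholded at 0.5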
@@ -37,63 +37,62 @@ def validate_image_dimensions(
 
 def validate_image_aspect_ratio(
     image: torch.Tensor,
-    min_aspect_ratio: Optional[float] = None,
-    max_aspect_ratio: Optional[float] = None,
-):
-    width, height = get_image_dimensions(image)
-    aspect_ratio = width / height
-
-    if min_aspect_ratio is not None and aspect_ratio < min_aspect_ratio:
-        raise ValueError(f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}")
-    if max_aspect_ratio is not None and aspect_ratio > max_aspect_ratio:
-        raise ValueError(f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}")
-
-
-def validate_image_aspect_ratio_range(
-    image: torch.Tensor,
-    min_ratio: tuple[float, float],  # e.g. (1, 4)
-    max_ratio: tuple[float, float],  # e.g. (4, 1)
+    min_ratio: Optional[tuple[float, float]] = None,  # e.g. (1, 4)
+    max_ratio: Optional[tuple[float, float]] = None,  # e.g. (4, 1)
     *,
     strict: bool = True,  # True -> (min, max); False -> [min, max]
 ) -> float:
-    a1, b1 = min_ratio
-    a2, b2 = max_ratio
-    if a1 <= 0 or b1 <= 0 or a2 <= 0 or b2 <= 0:
-        raise ValueError("Ratios must be positive, like (1, 4) or (4, 1).")
-    lo, hi = (a1 / b1), (a2 / b2)
-    if lo > hi:
-        lo, hi = hi, lo
-        a1, b1, a2, b2 = a2, b2, a1, b1  # swap only for error text
+    """Validates that image aspect ratio is within min and max. If a bound is None, that side is not checked."""
     w, h = get_image_dimensions(image)
     if w <= 0 or h <= 0:
         raise ValueError(f"Invalid image dimensions: {w}x{h}")
     ar = w / h
-    ok = (lo < ar < hi) if strict else (lo <= ar <= hi)
-    if not ok:
-        op = "<" if strict else "≤"
-        raise ValueError(f"Image aspect ratio {ar:.6g} is outside allowed range: {a1}:{b1} {op} ratio {op} {a2}:{b2}")
+    _assert_ratio_bounds(ar, min_ratio=min_ratio, max_ratio=max_ratio, strict=strict)
     return ar
 
 
-def validate_aspect_ratio_closeness(
-    start_img,
-    end_img,
-    min_rel: float,
-    max_rel: float,
+def validate_images_aspect_ratio_closeness(
+    first_image: torch.Tensor,
+    second_image: torch.Tensor,
+    min_rel: float,  # e.g. 0.8
+    max_rel: float,  # e.g. 1.25
     *,
-    strict: bool = False,  # True => exclusive, False => inclusive
-) -> None:
-    w1, h1 = get_image_dimensions(start_img)
-    w2, h2 = get_image_dimensions(end_img)
+    strict: bool = False,  # True -> (min, max); False -> [min, max]
+) -> float:
+    """
+    Validates that the two images' aspect ratios are 'close'.
+    The closeness factor is C = max(ar1, ar2) / min(ar1, ar2) (C >= 1).
+    We require C <= limit, where limit = max(max_rel, 1.0 / min_rel).
+
+    Returns the computed closeness factor C.
+    """
+    w1, h1 = get_image_dimensions(first_image)
+    w2, h2 = get_image_dimensions(second_image)
     if min(w1, h1, w2, h2) <= 0:
         raise ValueError("Invalid image dimensions")
     ar1 = w1 / h1
     ar2 = w2 / h2
-    # Normalize so it is symmetric (no need to check both ar1/ar2 and ar2/ar1)
     closeness = max(ar1, ar2) / min(ar1, ar2)
-    limit = max(max_rel, 1.0 / min_rel)  # for 0.8..1.25 this is 1.25
+    limit = max(max_rel, 1.0 / min_rel)
     if (closeness >= limit) if strict else (closeness > limit):
-        raise ValueError(f"Aspect ratios must be close: start/end={ar1/ar2:.4f}, allowed range {min_rel}–{max_rel}.")
+        raise ValueError(
+            f"Aspect ratios must be close: ar1/ar2={ar1/ar2:.2g}, "
+            f"allowed range {min_rel}–{max_rel} (limit {limit:.2g})."
+        )
+    return closeness
+
+
+def validate_aspect_ratio_string(
+    aspect_ratio: str,
+    min_ratio: Optional[tuple[float, float]] = None,  # e.g. (1, 4)
+    max_ratio: Optional[tuple[float, float]] = None,  # e.g. (4, 1)
+    *,
+    strict: bool = False,  # True -> (min, max); False -> [min, max]
+) -> float:
+    """Parses 'X:Y' and validates it against optional bounds. Returns the numeric ratio."""
+    ar = _parse_aspect_ratio_string(aspect_ratio)
+    _assert_ratio_bounds(ar, min_ratio=min_ratio, max_ratio=max_ratio, strict=strict)
+    return ar
 
 
 def validate_video_dimensions(
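Note that the closeness check is symmetric: for min_rel=0.8 and max_rel=1.25 the limit is max(1.25, 1/0.8) = 1.25, so either image's ratio may exceed the other's by at most 25%, regardless of argument order. A quick sketch, under the same comfy_api_nodes.util import assumption as above:

    import torch
    from comfy_api_nodes.util import validate_images_aspect_ratio_closeness

    first = torch.zeros(1, 512, 512, 3)   # aspect ratio 1.0
    second = torch.zeros(1, 512, 600, 3)  # aspect ratio ~1.17

    # closeness = 1.17 / 1.0 <= 1.25, so this passes and returns the factor;
    # pairing a 2:1 image with a 1:1 image would raise ValueError instead.
    c = validate_images_aspect_ratio_closeness(first, second, min_rel=0.8, max_rel=1.25)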
@@ -183,3 +182,49 @@ def validate_container_format_is_mp4(video: VideoInput) -> None:
     container_format = video.get_container_format()
     if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
         raise ValueError(f"Only MP4 container format supported. Got: {container_format}")
+
+
+def _ratio_from_tuple(r: tuple[float, float]) -> float:
+    a, b = r
+    if a <= 0 or b <= 0:
+        raise ValueError(f"Ratios must be positive, got {a}:{b}.")
+    return a / b
+
+
+def _assert_ratio_bounds(
+    ar: float,
+    *,
+    min_ratio: Optional[tuple[float, float]] = None,
+    max_ratio: Optional[tuple[float, float]] = None,
+    strict: bool = True,
+) -> None:
+    """Validate a numeric aspect ratio against optional min/max ratio bounds."""
+    lo = _ratio_from_tuple(min_ratio) if min_ratio is not None else None
+    hi = _ratio_from_tuple(max_ratio) if max_ratio is not None else None
+
+    if lo is not None and hi is not None and lo > hi:
+        lo, hi = hi, lo  # normalize order if caller swapped them
+
+    if lo is not None:
+        if (ar <= lo) if strict else (ar < lo):
+            op = ">" if strict else "≥"
+            raise ValueError(f"Aspect ratio `{ar:.2g}` must be {op} {lo:.2g}.")
+    if hi is not None:
+        if (ar >= hi) if strict else (ar > hi):
+            op = "<" if strict else "≤"
+            raise ValueError(f"Aspect ratio `{ar:.2g}` must be {op} {hi:.2g}.")
+
+
+def _parse_aspect_ratio_string(ar_str: str) -> float:
+    """Parse 'X:Y' with integer parts into a positive float ratio X/Y."""
+    parts = ar_str.split(":")
+    if len(parts) != 2:
+        raise ValueError(f"Aspect ratio must be 'X:Y' (e.g., 16:9), got '{ar_str}'.")
+    try:
+        a = int(parts[0].strip())
+        b = int(parts[1].strip())
+    except ValueError as exc:
+        raise ValueError(f"Aspect ratio must contain integers separated by ':', got '{ar_str}'.") from exc
+    if a <= 0 or b <= 0:
+        raise ValueError(f"Aspect ratio parts must be positive integers, got {a}:{b}.")
+    return a / b
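validate_aspect_ratio_string simply composes these two helpers: parse 'X:Y' into a float, then check it against the optional bounds. A few illustrative calls, again assuming the comfy_api_nodes.util export:

    from comfy_api_nodes.util import validate_aspect_ratio_string

    print(validate_aspect_ratio_string("16:9"))                  # ~1.778, no bounds checked
    print(validate_aspect_ratio_string("16:9", (1, 4), (4, 1)))  # inside [0.25, 4.0]

    for bad in ("16x9", "0:9", "9:1"):
        try:
            validate_aspect_ratio_string(bad, (1, 4), (4, 1))
        except ValueError as err:
            print(err)  # bad format, non-positive part, or out-of-range ratio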