Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-11 06:40:48 +08:00)
ideogram nodes update

parent 8e48a6765b
commit 64fa54c3ad
@@ -2,13 +2,13 @@ import json
 import os
 from io import BytesIO
 from itertools import chain
-from typing import Tuple, Dict, Any
+from typing import Tuple, Dict, Any, Literal
 
 import requests
 import torch
 from PIL import Image
 
-from comfy.component_model.tensor_types import RGBImageBatch, MaskBatch
+from comfy.component_model.tensor_types import RGBImageBatch, MaskBatch, ImageBatch
 from comfy.nodes.package_typing import CustomNode, Seed
 from comfy.utils import pil2tensor, tensor2pil
 from comfy_extras.constants.resolutions import IDEOGRAM_RESOLUTIONS
@@ -19,11 +19,14 @@ ASPECT_RATIO_ENUM = ["ASPECT_1_1"] + list(chain.from_iterable(
     [f"ASPECT_{a}_{b}", f"ASPECT_{b}_{a}"]
     for a, b in ASPECT_RATIOS
 ))
-MODELS_ENUM = ["V_2", "V_2_TURBO"]
+V2_MODELS = ["V_2", "V_2_TURBO"]
+MODELS_ENUM = V2_MODELS + ["V_3"]
 AUTO_PROMPT_ENUM = ["AUTO", "ON", "OFF"]
 STYLES_ENUM = ["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"]
 RESOLUTION_ENUM = [f"RESOLUTION_{w}_{h}" for w, h in IDEOGRAM_RESOLUTIONS]
 
+def to_v3_resolution(resolution:str) -> str:
+    return resolution[len("RESOLUTION_"):].replace("_", "x")
 
 def api_key_in_env_or_workflow(api_key_from_workflow: str):
     from comfy.cli_args import args
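For reference (not part of the commit): the new to_v3_resolution helper rewrites the V2-style resolution enum strings into the "WxH" form that the V3 endpoints appear to expect. The input value below is illustrative only.

    # Illustrative use of the helper added above.
    assert to_v3_resolution("RESOLUTION_1024_1024") == "1024x1024"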
@@ -61,21 +64,42 @@ class IdeogramGenerate(CustomNode):
         api_key = api_key_in_env_or_workflow(api_key)
         headers = {"Api-Key": api_key, "Content-Type": "application/json"}
 
-        payload = {
-            "prompt": prompt,
-            "resolution": resolution,
-            "model": model,
-            "magic_prompt": magic_prompt_option,
-            "num_images": num_images,
-            "style_type": style_type,
-        }
+        if model in V2_MODELS:
+            payload = {
+                "image_request": {
+                    "prompt": prompt,
+                    "resolution": resolution,
+                    "model": model,
+                    "magic_prompt_option": magic_prompt_option,
+                    "num_images": num_images,
+                    "style_type": style_type,
+                }
+            }
 
-        if negative_prompt:
-            payload["negative_prompt"] = negative_prompt
-        if seed:
-            payload["seed"] = seed
+            if negative_prompt:
+                payload["image_request"]["negative_prompt"] = negative_prompt
+            if seed:
+                payload["image_request"]["seed"] = seed
 
-        response = requests.post("https://api.ideogram.ai/v1/ideogram-v3/generate", headers=headers, json=payload)
+            response = requests.post("https://api.ideogram.ai/generate", headers=headers, json=payload)
+        elif model == "V_3":
+            payload = {
+                "prompt": prompt,
+                "resolution": to_v3_resolution(resolution),
+                "model": model,
+                "magic_prompt": magic_prompt_option,
+                "num_images": num_images,
+                "style_type": style_type,
+            }
+
+            if negative_prompt:
+                payload["negative_prompt"] = negative_prompt
+            if seed:
+                payload["seed"] = seed
+            response = requests.post("https://api.ideogram.ai/v1/ideogram-v3/generate", headers=headers, json=payload)
+        else:
+            raise ValueError(f"Invalid model={model}")
+
         response.raise_for_status()
 
         images = []
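For reference (not part of the commit): the branch above produces two different request bodies. Field names and endpoints are taken from the diff; the concrete values in this sketch are hypothetical.

    # V_2 / V_2_TURBO: fields wrapped in "image_request", "magic_prompt_option",
    # enum-style resolution, posted to https://api.ideogram.ai/generate
    v2_payload = {
        "image_request": {
            "prompt": "a red bicycle",
            "resolution": "RESOLUTION_1024_1024",
            "model": "V_2_TURBO",
            "magic_prompt_option": "AUTO",
            "num_images": 1,
            "style_type": "AUTO",
        }
    }

    # V_3: flat body, "magic_prompt", resolution converted by to_v3_resolution,
    # posted to https://api.ideogram.ai/v1/ideogram-v3/generate
    v3_payload = {
        "prompt": "a red bicycle",
        "resolution": "1024x1024",
        "model": "V_3",
        "magic_prompt": "AUTO",
        "num_images": 1,
        "style_type": "AUTO",
    }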
@@ -130,20 +154,41 @@ class IdeogramEdit(CustomNode):
         image_pil.save(image_bytes, format="PNG")
         mask_pil.save(mask_bytes, format="PNG")
 
-        files = {
-            "image": ("image.png", image_bytes.getvalue()),
-            "mask": ("mask.png", mask_bytes.getvalue()),
-        }
+        if model in V2_MODELS:
+            files = {
+                "image_file": ("image.png", image_bytes.getvalue()),
+                "mask": ("mask.png", mask_bytes.getvalue()),
+            }
 
-        data = {
-            "prompt": prompt,
-            "magic_prompt": magic_prompt_option,
-            "num_images": num_images,
-        }
-        if seed:
-            data["seed"] = seed
+            data = {
+                "prompt": prompt,
+                "model": model,
+                "magic_prompt_option": magic_prompt_option,
+                "num_images": num_images,
+                "style_type": style_type,
+            }
+            if seed:
+                data["seed"] = seed
 
-        response = requests.post("https://api.ideogram.ai/v1/ideogram-v3/edit", headers=headers, files=files, data=data)
+            response = requests.post("https://api.ideogram.ai/edit", headers=headers, files=files, data=data)
+        elif model == "V_3":
+            files = {
+                "image": ("image.png", image_bytes.getvalue()),
+                "mask": ("mask.png", mask_bytes.getvalue()),
+            }
+
+            data = {
+                "prompt": prompt,
+                "magic_prompt": magic_prompt_option,
+                "num_images": num_images,
+            }
+            if seed:
+                data["seed"] = seed
+
+            response = requests.post("https://api.ideogram.ai/v1/ideogram-v3/edit", headers=headers, files=files, data=data)
+        else:
+            raise ValueError(f"Invalid model={model}")
+
         response.raise_for_status()
 
         for item in response.json()["data"]:
@@ -193,27 +238,55 @@ class IdeogramRemix(CustomNode):
         image_bytes = BytesIO()
         image_pil.save(image_bytes, format="PNG")
 
-        files = {
-            "image": ("image.png", image_bytes.getvalue()),
-        }
+        if model in V2_MODELS:
+            files = {
+                "image_file": ("image.png", image_bytes.getvalue()),
+            }
 
-        data = {
-            "prompt": prompt,
-            "resolution": resolution,
-            "image_weight": image_weight,
-            "magic_prompt": magic_prompt_option,
-            "num_images": num_images,
-            "style_type": style_type,
-        }
+            data = {
+                "prompt": prompt,
+                "resolution": resolution,
+                "model": model,
+                "image_weight": image_weight,
+                "magic_prompt_option": magic_prompt_option,
+                "num_images": num_images,
+                "style_type": style_type,
+            }
 
-        if negative_prompt:
-            data["negative_prompt"] = negative_prompt
-        if seed:
-            data["seed"] = seed
+            if negative_prompt:
+                data["negative_prompt"] = negative_prompt
+            if seed:
+                data["seed"] = seed
 
-        # data = {"image_request": data}
+            # data = {"image_request": data}
 
-        response = requests.post("https://api.ideogram.ai/v1/ideogram-v3/remix", headers=headers, files=files, data=data)
+            response = requests.post("https://api.ideogram.ai/remix", headers=headers, files=files, data={
+                "image_request": json.dumps(data)
+            })
+        elif model == "V_3":
+            files = {
+                "image": ("image.png", image_bytes.getvalue()),
+            }
+
+            data = {
+                "prompt": prompt,
+                "resolution": to_v3_resolution(resolution),
+                "image_weight": image_weight,
+                "magic_prompt": magic_prompt_option,
+                "num_images": num_images,
+                "style_type": style_type,
+            }
+
+            if negative_prompt:
+                data["negative_prompt"] = negative_prompt
+            if seed:
+                data["seed"] = seed
+
+            response = requests.post("https://api.ideogram.ai/v1/ideogram-v3/remix", headers=headers, files=files, data=data)
+        else:
+            raise ValueError(f"Invalid model={model}")
 
         response.raise_for_status()
 
         for item in response.json()["data"]:
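For reference (not part of the commit): unlike the generate path, the V2 remix branch above sends its JSON parameters as a single multipart form field named "image_request" alongside the uploaded image, rather than as the request body. A self-contained sketch of that request shape, with hypothetical parameter values and a placeholder API key:

    import json
    from io import BytesIO

    import requests
    from PIL import Image

    # Build a dummy source image purely for illustration.
    buf = BytesIO()
    Image.new("RGB", (1024, 1024)).save(buf, format="PNG")

    files = {"image_file": ("image.png", buf.getvalue())}
    data = {"image_request": json.dumps({
        "prompt": "a watercolor version of the input",
        "resolution": "RESOLUTION_1024_1024",
        "model": "V_2_TURBO",
        "image_weight": 50,
        "magic_prompt_option": "AUTO",
        "num_images": 1,
        "style_type": "AUTO",
    })}
    response = requests.post("https://api.ideogram.ai/remix",
                             headers={"Api-Key": "YOUR_IDEOGRAM_API_KEY"},
                             files=files, data=data)
    response.raise_for_status()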
@@ -226,14 +299,81 @@ class IdeogramRemix(CustomNode):
         return (torch.cat(result_images, dim=0),)
 
 
+class IdeogramDescribe(CustomNode):
+    """
+    A ComfyUI node to get a description of an image using the Ideogram API.
+    """
+
+    @classmethod
+    def INPUT_TYPES(cls) -> Dict[str, Any]:
+        """
+        Defines the input types for the node.
+        """
+        return {
+            "required": {
+                "images": ("IMAGE",),
+            },
+            "optional": {
+                "api_key": ("STRING", {"default": ""}),
+            }
+        }
+
+    RETURN_TYPES = ("STRING",)
+    OUTPUT_IS_LIST = (True,)
+    FUNCTION = "describe"
+    CATEGORY = "ideogram"
+
+    def describe(self, images: ImageBatch, api_key: str = "") -> tuple[list[str]]:
+        """
+        Sends an image to the Ideogram API and returns a generated description.
+
+        Args:
+            images: A batch of images as a tensor.
+            api_key: The Ideogram API key.
+
+        Returns:
+            A tuple containing the description string for the first image.
+        """
+        api_key = api_key_in_env_or_workflow(api_key)
+        headers = {"Api-Key": api_key}
+
+        descriptions_batch = []
+        for image in images:
+            pil_image = tensor2pil(image)
+
+            image_bytes = BytesIO()
+            pil_image.save(image_bytes, format="PNG")
+            image_bytes.seek(0)
+
+            files = {
+                "image_file": ("image.png", image_bytes.getvalue(), "image/png"),
+            }
+
+            response = requests.post("https://api.ideogram.ai/describe", headers=headers, files=files)
+            response.raise_for_status()
+
+            data = response.json()
+            descriptions = data.get("descriptions", [])
+
+            if not descriptions:
+                descriptions_batch.append("")
+            else:
+                first_description = descriptions[0].get("text", "")
+                descriptions_batch.append(first_description)
+
+        return (descriptions_batch,)
+
+
 NODE_CLASS_MAPPINGS = {
     "IdeogramGenerate": IdeogramGenerate,
     "IdeogramEdit": IdeogramEdit,
     "IdeogramRemix": IdeogramRemix,
+    "IdeogramDescribe": IdeogramDescribe,
 }
 
 NODE_DISPLAY_NAME_MAPPINGS = {
     "Ideogram Generate": "Ideogram Generate",
     "Ideogram Edit": "Ideogram Edit",
     "Ideogram Remix": "Ideogram Remix",
+    "Ideogram Describe": "Ideogram Describe",
 }
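For reference (not part of the commit): a minimal sketch of exercising the new IdeogramDescribe node directly, mirroring the pytest added in the test-module hunks below. The (batch, height, width, channels) tensor layout matches the other Ideogram tests; the API key value is a placeholder.

    import torch

    node = IdeogramDescribe()
    images = torch.zeros((1, 1024, 1024, 3), dtype=torch.float32)  # one all-black frame
    descriptions, = node.describe(images=images, api_key="YOUR_IDEOGRAM_API_KEY")
    print(descriptions[0])  # one description string per input image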
@@ -3,10 +3,12 @@ import os
 import pytest
 import torch
 
+from comfy.component_model.tensor_types import RGBImageBatch
 from comfy_extras.nodes.nodes_ideogram import (
     IdeogramGenerate,
     IdeogramEdit,
-    IdeogramRemix
+    IdeogramRemix,
+    IdeogramDescribe,
 )
 
 
@@ -23,13 +25,44 @@ def sample_image():
     return torch.ones((1, 1024, 1024, 3)) * 0.8  # Light gray image
 
 
-def test_ideogram_generate(api_key):
+@pytest.fixture
+def black_square_image() -> RGBImageBatch:
+    # A black square image (1 batch, 1024x1024 pixels, 3 channels)
+    return torch.zeros((1, 1024, 1024, 3), dtype=torch.float32)
+
+
+def test_ideogram_describe(api_key, black_square_image):
+    """
+    Tests the IdeogramDescribe node by passing it a black square image and
+    asserting that the returned description contains "black" and "square".
+    """
+    node = IdeogramDescribe()
+
+    # The node's method returns a tuple containing a list of descriptions
+    descriptions_list, = node.describe(
+        images=black_square_image,
+        api_key=api_key
+    )
+
+    # We passed one image, so we expect one description in the list
+    assert isinstance(descriptions_list, list)
+    assert len(descriptions_list) == 1
+
+    description = descriptions_list[0]
+
+    assert isinstance(description, str)
+    assert "black" in description.lower()
+    assert "square" in description.lower()
+
+
+@pytest.mark.parametrize("model", ["V_2_TURBO", "V_3"])
+def test_ideogram_generate(api_key, model):
     node = IdeogramGenerate()
 
     image, = node.generate(
         prompt="a serene mountain landscape at sunset with snow-capped peaks",
         resolution="RESOLUTION_1024_1024",
-        model="V_2_TURBO",
+        model=model,
         magic_prompt_option="AUTO",
         api_key=api_key,
         num_images=1
@@ -41,8 +74,8 @@ def test_ideogram_generate(api_key):
     assert image.dtype == torch.float32
     assert torch.all((image >= 0) & (image <= 1))
 
-
-def test_ideogram_edit(api_key, sample_image):
+@pytest.mark.parametrize("model", ["V_2_TURBO", "V_3"])
+def test_ideogram_edit(api_key, sample_image, model):
     node = IdeogramEdit()
 
     # white is areas to keep, black is areas to repaint
@@ -56,7 +89,7 @@ def test_ideogram_edit(api_key, sample_image):
         masks=mask,
         magic_prompt_option="OFF",
         prompt="a solid black rectangle",
-        model="V_2_TURBO",
+        model=model,
         api_key=api_key,
         num_images=1,
     )
@@ -77,15 +110,15 @@ def test_ideogram_edit(api_key, sample_image):
     assert center_mean < outer_mean, f"Center region ({center_mean:.3f}) should be darker than outer region ({outer_mean:.3f})"
     assert center_mean < 0.6, f"Center region ({center_mean:.3f}) should be dark"
 
-
-def test_ideogram_remix(api_key, sample_image):
+@pytest.mark.parametrize("model", ["V_2_TURBO", "V_3"])
+def test_ideogram_remix(api_key, sample_image, model):
     node = IdeogramRemix()
 
     image, = node.remix(
         images=sample_image,
         prompt="transform into a vibrant blue ocean scene with waves",
         resolution="RESOLUTION_1024_1024",
-        model="V_2_TURBO",
+        model=model,
         api_key=api_key,
         num_images=1
     )
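For reference (not part of the commit): the tests above depend on an api_key fixture that is defined elsewhere in the test suite and does not appear in this diff. A hypothetical sketch of such a fixture, assuming the key is supplied through an IDEOGRAM_API_KEY environment variable (the variable name is an assumption, not taken from the diff):

    import os

    import pytest

    @pytest.fixture
    def api_key() -> str:
        # Hypothetical: skip the live API tests when no key is configured.
        key = os.environ.get("IDEOGRAM_API_KEY", "")
        if not key:
            pytest.skip("IDEOGRAM_API_KEY is not set")
        return key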