ideogram nodes update

This commit is contained in:
Benjamin Berman 2025-06-05 18:02:37 -07:00
parent 8e48a6765b
commit 64fa54c3ad
2 changed files with 227 additions and 54 deletions

View File

@ -2,13 +2,13 @@ import json
import os
from io import BytesIO
from itertools import chain
from typing import Tuple, Dict, Any
from typing import Tuple, Dict, Any, Literal
import requests
import torch
from PIL import Image
from comfy.component_model.tensor_types import RGBImageBatch, MaskBatch
from comfy.component_model.tensor_types import RGBImageBatch, MaskBatch, ImageBatch
from comfy.nodes.package_typing import CustomNode, Seed
from comfy.utils import pil2tensor, tensor2pil
from comfy_extras.constants.resolutions import IDEOGRAM_RESOLUTIONS
@ -19,11 +19,14 @@ ASPECT_RATIO_ENUM = ["ASPECT_1_1"] + list(chain.from_iterable(
[f"ASPECT_{a}_{b}", f"ASPECT_{b}_{a}"]
for a, b in ASPECT_RATIOS
))
MODELS_ENUM = ["V_2", "V_2_TURBO"]
V2_MODELS = ["V_2", "V_2_TURBO"]
MODELS_ENUM = V2_MODELS + ["V_3"]
AUTO_PROMPT_ENUM = ["AUTO", "ON", "OFF"]
STYLES_ENUM = ["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"]
RESOLUTION_ENUM = [f"RESOLUTION_{w}_{h}" for w, h in IDEOGRAM_RESOLUTIONS]
def to_v3_resolution(resolution: str) -> str:
    """Convert a ``RESOLUTION_{w}_{h}`` enum value into ``{w}x{h}`` form.

    The V_3 request payloads in this module send the converted value as
    their ``resolution`` field.

    Args:
        resolution: A value such as ``"RESOLUTION_1024_1024"``. A value that
            lacks the ``RESOLUTION_`` prefix (for example ``"1024_1024"`` or
            an already converted ``"1024x1024"``) is accepted as well.

    Returns:
        The resolution with the prefix removed and underscores replaced by
        ``x``, e.g. ``"1024x1024"``.
    """
    # removeprefix only strips when the prefix is really present; slicing by
    # the prefix length would silently eat the first characters of any other
    # input.
    return resolution.removeprefix("RESOLUTION_").replace("_", "x")
def api_key_in_env_or_workflow(api_key_from_workflow: str):
from comfy.cli_args import args
@ -61,21 +64,42 @@ class IdeogramGenerate(CustomNode):
api_key = api_key_in_env_or_workflow(api_key)
headers = {"Api-Key": api_key, "Content-Type": "application/json"}
payload = {
"prompt": prompt,
"resolution": resolution,
"model": model,
"magic_prompt": magic_prompt_option,
"num_images": num_images,
"style_type": style_type,
}
if model in V2_MODELS:
payload = {
"image_request": {
"prompt": prompt,
"resolution": resolution,
"model": model,
"magic_prompt_option": magic_prompt_option,
"num_images": num_images,
"style_type": style_type,
}
}
if negative_prompt:
payload["negative_prompt"] = negative_prompt
if seed:
payload["seed"] = seed
if negative_prompt:
payload["image_request"]["negative_prompt"] = negative_prompt
if seed:
payload["image_request"]["seed"] = seed
response = requests.post("https://api.ideogram.ai/generate", headers=headers, json=payload)
elif model == "V_3":
payload = {
"prompt": prompt,
"resolution": to_v3_resolution(resolution),
"model": model,
"magic_prompt": magic_prompt_option,
"num_images": num_images,
"style_type": style_type,
}
if negative_prompt:
payload["negative_prompt"] = negative_prompt
if seed:
payload["seed"] = seed
response = requests.post("https://api.ideogram.ai/v1/ideogram-v3/generate", headers=headers, json=payload)
else:
raise ValueError(f"Invalid model={model}")
response = requests.post("https://api.ideogram.ai/v1/ideogram-v3/generate", headers=headers, json=payload)
response.raise_for_status()
images = []
@ -130,20 +154,41 @@ class IdeogramEdit(CustomNode):
image_pil.save(image_bytes, format="PNG")
mask_pil.save(mask_bytes, format="PNG")
files = {
"image": ("image.png", image_bytes.getvalue()),
"mask": ("mask.png", mask_bytes.getvalue()),
}
if model in V2_MODELS:
files = {
"image_file": ("image.png", image_bytes.getvalue()),
"mask": ("mask.png", mask_bytes.getvalue()),
}
data = {
"prompt": prompt,
"magic_prompt": magic_prompt_option,
"num_images": num_images,
}
if seed:
data["seed"] = seed
data = {
"prompt": prompt,
"model": model,
"magic_prompt_option": magic_prompt_option,
"num_images": num_images,
"style_type": style_type,
}
if seed:
data["seed"] = seed
response = requests.post("https://api.ideogram.ai/edit", headers=headers, files=files, data=data)
elif model == "V_3":
files = {
"image": ("image.png", image_bytes.getvalue()),
"mask": ("mask.png", mask_bytes.getvalue()),
}
data = {
"prompt": prompt,
"magic_prompt": magic_prompt_option,
"num_images": num_images,
}
if seed:
data["seed"] = seed
response = requests.post("https://api.ideogram.ai/v1/ideogram-v3/edit", headers=headers, files=files, data=data)
else:
raise ValueError(f"Invalid model={model}")
response = requests.post("https://api.ideogram.ai/v1/ideogram-v3/edit", headers=headers, files=files, data=data)
response.raise_for_status()
for item in response.json()["data"]:
@ -193,27 +238,55 @@ class IdeogramRemix(CustomNode):
image_bytes = BytesIO()
image_pil.save(image_bytes, format="PNG")
files = {
"image": ("image.png", image_bytes.getvalue()),
}
if model in V2_MODELS:
data = {
"prompt": prompt,
"resolution": resolution,
"image_weight": image_weight,
"magic_prompt": magic_prompt_option,
"num_images": num_images,
"style_type": style_type,
}
files = {
"image_file": ("image.png", image_bytes.getvalue()),
}
if negative_prompt:
data["negative_prompt"] = negative_prompt
if seed:
data["seed"] = seed
data = {
"prompt": prompt,
"resolution": resolution,
"model": model,
"image_weight": image_weight,
"magic_prompt_option": magic_prompt_option,
"num_images": num_images,
"style_type": style_type,
}
# data = {"image_request": data}
if negative_prompt:
data["negative_prompt"] = negative_prompt
if seed:
data["seed"] = seed
# data = {"image_request": data}
response = requests.post("https://api.ideogram.ai/remix", headers=headers, files=files, data={
"image_request": json.dumps(data)
})
elif model == "V_3":
files = {
"image": ("image.png", image_bytes.getvalue()),
}
data = {
"prompt": prompt,
"resolution": to_v3_resolution(resolution),
"image_weight": image_weight,
"magic_prompt": magic_prompt_option,
"num_images": num_images,
"style_type": style_type,
}
if negative_prompt:
data["negative_prompt"] = negative_prompt
if seed:
data["seed"] = seed
response = requests.post("https://api.ideogram.ai/v1/ideogram-v3/remix", headers=headers, files=files, data=data)
else:
raise ValueError(f"Invalid model={model}")
response = requests.post("https://api.ideogram.ai/v1/ideogram-v3/remix", headers=headers, files=files, data=data)
response.raise_for_status()
for item in response.json()["data"]:
@ -226,14 +299,81 @@ class IdeogramRemix(CustomNode):
return (torch.cat(result_images, dim=0),)
class IdeogramDescribe(CustomNode):
    """
    A ComfyUI node that asks the Ideogram API to describe each image in a batch.
    """

    @classmethod
    def INPUT_TYPES(cls) -> Dict[str, Any]:
        """
        Defines the input types for the node.

        Returns:
            A dict with a required ``images`` input of type ``IMAGE`` and an
            optional ``api_key`` string input that defaults to ``""``.
        """
        return {
            "required": {
                "images": ("IMAGE",),
            },
            "optional": {
                "api_key": ("STRING", {"default": ""}),
            }
        }

    RETURN_TYPES = ("STRING",)
    # The single output is a list: one description per input image.
    OUTPUT_IS_LIST = (True,)
    FUNCTION = "describe"
    CATEGORY = "ideogram"

    def describe(self, images: ImageBatch, api_key: str = "") -> tuple[list[str]]:
        """
        Sends every image of the batch to the Ideogram API and collects one
        description per image.

        Args:
            images: A batch of images as a tensor; it is iterated image by image.
            api_key: The Ideogram API key from the workflow. It is passed through
                ``api_key_in_env_or_workflow`` before use.

        Returns:
            A one-element tuple holding a list of description strings, in the
            same order as ``images``. An image for which the API returns no
            description (or no text) contributes an empty string.

        Raises:
            requests.HTTPError: If the API answers with an error status for
                any image (via ``raise_for_status``).
        """
        api_key = api_key_in_env_or_workflow(api_key)
        headers = {"Api-Key": api_key}

        descriptions_batch = []
        for image in images:
            pil_image = tensor2pil(image)
            image_bytes = BytesIO()
            pil_image.save(image_bytes, format="PNG")

            # getvalue() returns the whole buffer regardless of the stream
            # position, so no rewind is needed before uploading.
            files = {
                "image_file": ("image.png", image_bytes.getvalue(), "image/png"),
            }

            response = requests.post("https://api.ideogram.ai/describe", headers=headers, files=files)
            response.raise_for_status()

            # Treat a missing or null "descriptions"/"text" field as "no
            # description" so the result list only ever contains strings.
            descriptions = response.json().get("descriptions") or []
            first_description = (descriptions[0].get("text") or "") if descriptions else ""
            descriptions_batch.append(first_description)

        return (descriptions_batch,)
# Maps each node identifier to the class implementing it. Every identifier is
# the class's own name, so the table is derived from the classes themselves.
# NOTE(review): presumably read by the ComfyUI node loader — confirm.
NODE_CLASS_MAPPINGS = {
    node_class.__name__: node_class
    for node_class in (
        IdeogramGenerate,
        IdeogramEdit,
        IdeogramRemix,
        IdeogramDescribe,
    )
}
# Human-readable display names, keyed by node identifier. The keys must be the
# same identifiers used in NODE_CLASS_MAPPINGS ("IdeogramGenerate", ...);
# keying by the display string itself would never match a registered node.
NODE_DISPLAY_NAME_MAPPINGS = {
    "IdeogramGenerate": "Ideogram Generate",
    "IdeogramEdit": "Ideogram Edit",
    "IdeogramRemix": "Ideogram Remix",
    "IdeogramDescribe": "Ideogram Describe",
}

View File

@ -3,10 +3,12 @@ import os
import pytest
import torch
from comfy.component_model.tensor_types import RGBImageBatch
from comfy_extras.nodes.nodes_ideogram import (
IdeogramGenerate,
IdeogramEdit,
IdeogramRemix
IdeogramRemix,
IdeogramDescribe,
)
@ -23,13 +25,44 @@ def sample_image():
return torch.ones((1, 1024, 1024, 3)) * 0.8 # Light gray image
def test_ideogram_generate(api_key):
@pytest.fixture
def black_square_image() -> RGBImageBatch:
    """Provide an all-zero (black) float32 batch: 1 image, 1024x1024 pixels, 3 channels."""
    batch_size, height, width, channels = 1, 1024, 1024, 3
    return torch.zeros((batch_size, height, width, channels), dtype=torch.float32)
def test_ideogram_describe(api_key, black_square_image):
    """
    Runs the IdeogramDescribe node on a black square image and checks that the
    single returned description mentions both "black" and "square".
    """
    describe_node = IdeogramDescribe()

    # describe() returns a one-element tuple wrapping the list of descriptions.
    (captions,) = describe_node.describe(images=black_square_image, api_key=api_key)

    # One image went in, so exactly one description is expected back.
    assert isinstance(captions, list)
    assert len(captions) == 1

    caption = captions[0]
    assert isinstance(caption, str)

    lowered_caption = caption.lower()
    assert "black" in lowered_caption
    assert "square" in lowered_caption
@pytest.mark.parametrize("model", ["V_2_TURBO", "V_3"])
def test_ideogram_generate(api_key, model):
node = IdeogramGenerate()
image, = node.generate(
prompt="a serene mountain landscape at sunset with snow-capped peaks",
resolution="RESOLUTION_1024_1024",
model="V_2_TURBO",
model=model,
magic_prompt_option="AUTO",
api_key=api_key,
num_images=1
@ -41,8 +74,8 @@ def test_ideogram_generate(api_key):
assert image.dtype == torch.float32
assert torch.all((image >= 0) & (image <= 1))
def test_ideogram_edit(api_key, sample_image):
@pytest.mark.parametrize("model", ["V_2_TURBO", "V_3"])
def test_ideogram_edit(api_key, sample_image, model):
node = IdeogramEdit()
# white is areas to keep, black is areas to repaint
@ -56,7 +89,7 @@ def test_ideogram_edit(api_key, sample_image):
masks=mask,
magic_prompt_option="OFF",
prompt="a solid black rectangle",
model="V_2_TURBO",
model=model,
api_key=api_key,
num_images=1,
)
@ -77,15 +110,15 @@ def test_ideogram_edit(api_key, sample_image):
assert center_mean < outer_mean, f"Center region ({center_mean:.3f}) should be darker than outer region ({outer_mean:.3f})"
assert center_mean < 0.6, f"Center region ({center_mean:.3f}) should be dark"
def test_ideogram_remix(api_key, sample_image):
@pytest.mark.parametrize("model", ["V_2_TURBO", "V_3"])
def test_ideogram_remix(api_key, sample_image, model):
node = IdeogramRemix()
image, = node.remix(
images=sample_image,
prompt="transform into a vibrant blue ocean scene with waves",
resolution="RESOLUTION_1024_1024",
model="V_2_TURBO",
model=model,
api_key=api_key,
num_images=1
)