mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-27 23:00:20 +08:00
ideogram nodes update
This commit is contained in:
parent
8e48a6765b
commit
64fa54c3ad
@ -2,13 +2,13 @@ import json
|
|||||||
import os
|
import os
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import Tuple, Dict, Any
|
from typing import Tuple, Dict, Any, Literal
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from comfy.component_model.tensor_types import RGBImageBatch, MaskBatch
|
from comfy.component_model.tensor_types import RGBImageBatch, MaskBatch, ImageBatch
|
||||||
from comfy.nodes.package_typing import CustomNode, Seed
|
from comfy.nodes.package_typing import CustomNode, Seed
|
||||||
from comfy.utils import pil2tensor, tensor2pil
|
from comfy.utils import pil2tensor, tensor2pil
|
||||||
from comfy_extras.constants.resolutions import IDEOGRAM_RESOLUTIONS
|
from comfy_extras.constants.resolutions import IDEOGRAM_RESOLUTIONS
|
||||||
@ -19,11 +19,14 @@ ASPECT_RATIO_ENUM = ["ASPECT_1_1"] + list(chain.from_iterable(
|
|||||||
[f"ASPECT_{a}_{b}", f"ASPECT_{b}_{a}"]
|
[f"ASPECT_{a}_{b}", f"ASPECT_{b}_{a}"]
|
||||||
for a, b in ASPECT_RATIOS
|
for a, b in ASPECT_RATIOS
|
||||||
))
|
))
|
||||||
MODELS_ENUM = ["V_2", "V_2_TURBO"]
|
V2_MODELS = ["V_2", "V_2_TURBO"]
|
||||||
|
MODELS_ENUM = V2_MODELS + ["V_3"]
|
||||||
AUTO_PROMPT_ENUM = ["AUTO", "ON", "OFF"]
|
AUTO_PROMPT_ENUM = ["AUTO", "ON", "OFF"]
|
||||||
STYLES_ENUM = ["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"]
|
STYLES_ENUM = ["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"]
|
||||||
RESOLUTION_ENUM = [f"RESOLUTION_{w}_{h}" for w, h in IDEOGRAM_RESOLUTIONS]
|
RESOLUTION_ENUM = [f"RESOLUTION_{w}_{h}" for w, h in IDEOGRAM_RESOLUTIONS]
|
||||||
|
|
||||||
|
def to_v3_resolution(resolution:str) -> str:
|
||||||
|
return resolution[len("RESOLUTION_"):].replace("_", "x")
|
||||||
|
|
||||||
def api_key_in_env_or_workflow(api_key_from_workflow: str):
|
def api_key_in_env_or_workflow(api_key_from_workflow: str):
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
@ -61,21 +64,42 @@ class IdeogramGenerate(CustomNode):
|
|||||||
api_key = api_key_in_env_or_workflow(api_key)
|
api_key = api_key_in_env_or_workflow(api_key)
|
||||||
headers = {"Api-Key": api_key, "Content-Type": "application/json"}
|
headers = {"Api-Key": api_key, "Content-Type": "application/json"}
|
||||||
|
|
||||||
payload = {
|
if model in V2_MODELS:
|
||||||
"prompt": prompt,
|
payload = {
|
||||||
"resolution": resolution,
|
"image_request": {
|
||||||
"model": model,
|
"prompt": prompt,
|
||||||
"magic_prompt": magic_prompt_option,
|
"resolution": resolution,
|
||||||
"num_images": num_images,
|
"model": model,
|
||||||
"style_type": style_type,
|
"magic_prompt_option": magic_prompt_option,
|
||||||
}
|
"num_images": num_images,
|
||||||
|
"style_type": style_type,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if negative_prompt:
|
if negative_prompt:
|
||||||
payload["negative_prompt"] = negative_prompt
|
payload["image_request"]["negative_prompt"] = negative_prompt
|
||||||
if seed:
|
if seed:
|
||||||
payload["seed"] = seed
|
payload["image_request"]["seed"] = seed
|
||||||
|
|
||||||
|
response = requests.post("https://api.ideogram.ai/generate", headers=headers, json=payload)
|
||||||
|
elif model == "V_3":
|
||||||
|
payload = {
|
||||||
|
"prompt": prompt,
|
||||||
|
"resolution": to_v3_resolution(resolution),
|
||||||
|
"model": model,
|
||||||
|
"magic_prompt": magic_prompt_option,
|
||||||
|
"num_images": num_images,
|
||||||
|
"style_type": style_type,
|
||||||
|
}
|
||||||
|
|
||||||
|
if negative_prompt:
|
||||||
|
payload["negative_prompt"] = negative_prompt
|
||||||
|
if seed:
|
||||||
|
payload["seed"] = seed
|
||||||
|
response = requests.post("https://api.ideogram.ai/v1/ideogram-v3/generate", headers=headers, json=payload)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid model={model}")
|
||||||
|
|
||||||
response = requests.post("https://api.ideogram.ai/v1/ideogram-v3/generate", headers=headers, json=payload)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
images = []
|
images = []
|
||||||
@ -130,20 +154,41 @@ class IdeogramEdit(CustomNode):
|
|||||||
image_pil.save(image_bytes, format="PNG")
|
image_pil.save(image_bytes, format="PNG")
|
||||||
mask_pil.save(mask_bytes, format="PNG")
|
mask_pil.save(mask_bytes, format="PNG")
|
||||||
|
|
||||||
files = {
|
if model in V2_MODELS:
|
||||||
"image": ("image.png", image_bytes.getvalue()),
|
files = {
|
||||||
"mask": ("mask.png", mask_bytes.getvalue()),
|
"image_file": ("image.png", image_bytes.getvalue()),
|
||||||
}
|
"mask": ("mask.png", mask_bytes.getvalue()),
|
||||||
|
}
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
"magic_prompt": magic_prompt_option,
|
"model": model,
|
||||||
"num_images": num_images,
|
"magic_prompt_option": magic_prompt_option,
|
||||||
}
|
"num_images": num_images,
|
||||||
if seed:
|
"style_type": style_type,
|
||||||
data["seed"] = seed
|
}
|
||||||
|
if seed:
|
||||||
|
data["seed"] = seed
|
||||||
|
|
||||||
|
response = requests.post("https://api.ideogram.ai/edit", headers=headers, files=files, data=data)
|
||||||
|
elif model == "V_3":
|
||||||
|
files = {
|
||||||
|
"image": ("image.png", image_bytes.getvalue()),
|
||||||
|
"mask": ("mask.png", mask_bytes.getvalue()),
|
||||||
|
}
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"prompt": prompt,
|
||||||
|
"magic_prompt": magic_prompt_option,
|
||||||
|
"num_images": num_images,
|
||||||
|
}
|
||||||
|
if seed:
|
||||||
|
data["seed"] = seed
|
||||||
|
|
||||||
|
response = requests.post("https://api.ideogram.ai/v1/ideogram-v3/edit", headers=headers, files=files, data=data)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid model={model}")
|
||||||
|
|
||||||
response = requests.post("https://api.ideogram.ai/v1/ideogram-v3/edit", headers=headers, files=files, data=data)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
for item in response.json()["data"]:
|
for item in response.json()["data"]:
|
||||||
@ -193,27 +238,55 @@ class IdeogramRemix(CustomNode):
|
|||||||
image_bytes = BytesIO()
|
image_bytes = BytesIO()
|
||||||
image_pil.save(image_bytes, format="PNG")
|
image_pil.save(image_bytes, format="PNG")
|
||||||
|
|
||||||
files = {
|
if model in V2_MODELS:
|
||||||
"image": ("image.png", image_bytes.getvalue()),
|
|
||||||
}
|
|
||||||
|
|
||||||
data = {
|
files = {
|
||||||
"prompt": prompt,
|
"image_file": ("image.png", image_bytes.getvalue()),
|
||||||
"resolution": resolution,
|
}
|
||||||
"image_weight": image_weight,
|
|
||||||
"magic_prompt": magic_prompt_option,
|
|
||||||
"num_images": num_images,
|
|
||||||
"style_type": style_type,
|
|
||||||
}
|
|
||||||
|
|
||||||
if negative_prompt:
|
data = {
|
||||||
data["negative_prompt"] = negative_prompt
|
"prompt": prompt,
|
||||||
if seed:
|
"resolution": resolution,
|
||||||
data["seed"] = seed
|
"model": model,
|
||||||
|
"image_weight": image_weight,
|
||||||
|
"magic_prompt_option": magic_prompt_option,
|
||||||
|
"num_images": num_images,
|
||||||
|
"style_type": style_type,
|
||||||
|
}
|
||||||
|
|
||||||
# data = {"image_request": data}
|
if negative_prompt:
|
||||||
|
data["negative_prompt"] = negative_prompt
|
||||||
|
if seed:
|
||||||
|
data["seed"] = seed
|
||||||
|
|
||||||
|
# data = {"image_request": data}
|
||||||
|
|
||||||
|
response = requests.post("https://api.ideogram.ai/remix", headers=headers, files=files, data={
|
||||||
|
"image_request": json.dumps(data)
|
||||||
|
})
|
||||||
|
elif model == "V_3":
|
||||||
|
files = {
|
||||||
|
"image": ("image.png", image_bytes.getvalue()),
|
||||||
|
}
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"prompt": prompt,
|
||||||
|
"resolution": to_v3_resolution(resolution),
|
||||||
|
"image_weight": image_weight,
|
||||||
|
"magic_prompt": magic_prompt_option,
|
||||||
|
"num_images": num_images,
|
||||||
|
"style_type": style_type,
|
||||||
|
}
|
||||||
|
|
||||||
|
if negative_prompt:
|
||||||
|
data["negative_prompt"] = negative_prompt
|
||||||
|
if seed:
|
||||||
|
data["seed"] = seed
|
||||||
|
|
||||||
|
response = requests.post("https://api.ideogram.ai/v1/ideogram-v3/remix", headers=headers, files=files, data=data)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid model={model}")
|
||||||
|
|
||||||
response = requests.post("https://api.ideogram.ai/v1/ideogram-v3/remix", headers=headers, files=files, data=data)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
for item in response.json()["data"]:
|
for item in response.json()["data"]:
|
||||||
@ -226,14 +299,81 @@ class IdeogramRemix(CustomNode):
|
|||||||
return (torch.cat(result_images, dim=0),)
|
return (torch.cat(result_images, dim=0),)
|
||||||
|
|
||||||
|
|
||||||
|
class IdeogramDescribe(CustomNode):
|
||||||
|
"""
|
||||||
|
A ComfyUI node to get a description of an image using the Ideogram API.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Defines the input types for the node.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"images": ("IMAGE",),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"api_key": ("STRING", {"default": ""}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("STRING",)
|
||||||
|
OUTPUT_IS_LIST = (True,)
|
||||||
|
FUNCTION = "describe"
|
||||||
|
CATEGORY = "ideogram"
|
||||||
|
|
||||||
|
def describe(self, images: ImageBatch, api_key: str = "") -> tuple[list[str]]:
|
||||||
|
"""
|
||||||
|
Sends an image to the Ideogram API and returns a generated description.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images: A batch of images as a tensor.
|
||||||
|
api_key: The Ideogram API key.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple containing the description string for the first image.
|
||||||
|
"""
|
||||||
|
api_key = api_key_in_env_or_workflow(api_key)
|
||||||
|
headers = {"Api-Key": api_key}
|
||||||
|
|
||||||
|
descriptions_batch = []
|
||||||
|
for image in images:
|
||||||
|
pil_image = tensor2pil(image)
|
||||||
|
|
||||||
|
image_bytes = BytesIO()
|
||||||
|
pil_image.save(image_bytes, format="PNG")
|
||||||
|
image_bytes.seek(0)
|
||||||
|
|
||||||
|
files = {
|
||||||
|
"image_file": ("image.png", image_bytes.getvalue(), "image/png"),
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post("https://api.ideogram.ai/describe", headers=headers, files=files)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
descriptions = data.get("descriptions", [])
|
||||||
|
|
||||||
|
if not descriptions:
|
||||||
|
descriptions_batch.append("")
|
||||||
|
else:
|
||||||
|
first_description = descriptions[0].get("text", "")
|
||||||
|
descriptions_batch.append(first_description)
|
||||||
|
|
||||||
|
return (descriptions_batch,)
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"IdeogramGenerate": IdeogramGenerate,
|
"IdeogramGenerate": IdeogramGenerate,
|
||||||
"IdeogramEdit": IdeogramEdit,
|
"IdeogramEdit": IdeogramEdit,
|
||||||
"IdeogramRemix": IdeogramRemix,
|
"IdeogramRemix": IdeogramRemix,
|
||||||
|
"IdeogramDescribe": IdeogramDescribe,
|
||||||
}
|
}
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
"Ideogram Generate": "Ideogram Generate",
|
"Ideogram Generate": "Ideogram Generate",
|
||||||
"Ideogram Edit": "Ideogram Edit",
|
"Ideogram Edit": "Ideogram Edit",
|
||||||
"Ideogram Remix": "Ideogram Remix",
|
"Ideogram Remix": "Ideogram Remix",
|
||||||
|
"Ideogram Describe": "Ideogram Describe",
|
||||||
}
|
}
|
||||||
|
|||||||
@ -3,10 +3,12 @@ import os
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from comfy.component_model.tensor_types import RGBImageBatch
|
||||||
from comfy_extras.nodes.nodes_ideogram import (
|
from comfy_extras.nodes.nodes_ideogram import (
|
||||||
IdeogramGenerate,
|
IdeogramGenerate,
|
||||||
IdeogramEdit,
|
IdeogramEdit,
|
||||||
IdeogramRemix
|
IdeogramRemix,
|
||||||
|
IdeogramDescribe,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -23,13 +25,44 @@ def sample_image():
|
|||||||
return torch.ones((1, 1024, 1024, 3)) * 0.8 # Light gray image
|
return torch.ones((1, 1024, 1024, 3)) * 0.8 # Light gray image
|
||||||
|
|
||||||
|
|
||||||
def test_ideogram_generate(api_key):
|
@pytest.fixture
|
||||||
|
def black_square_image() -> RGBImageBatch:
|
||||||
|
# A black square image (1 batch, 1024x1024 pixels, 3 channels)
|
||||||
|
return torch.zeros((1, 1024, 1024, 3), dtype=torch.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ideogram_describe(api_key, black_square_image):
|
||||||
|
"""
|
||||||
|
Tests the IdeogramDescribe node by passing it a black square image and
|
||||||
|
asserting that the returned description contains "black" and "square".
|
||||||
|
"""
|
||||||
|
node = IdeogramDescribe()
|
||||||
|
|
||||||
|
# The node's method returns a tuple containing a list of descriptions
|
||||||
|
descriptions_list, = node.describe(
|
||||||
|
images=black_square_image,
|
||||||
|
api_key=api_key
|
||||||
|
)
|
||||||
|
|
||||||
|
# We passed one image, so we expect one description in the list
|
||||||
|
assert isinstance(descriptions_list, list)
|
||||||
|
assert len(descriptions_list) == 1
|
||||||
|
|
||||||
|
description = descriptions_list[0]
|
||||||
|
|
||||||
|
assert isinstance(description, str)
|
||||||
|
assert "black" in description.lower()
|
||||||
|
assert "square" in description.lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", ["V_2_TURBO", "V_3"])
|
||||||
|
def test_ideogram_generate(api_key, model):
|
||||||
node = IdeogramGenerate()
|
node = IdeogramGenerate()
|
||||||
|
|
||||||
image, = node.generate(
|
image, = node.generate(
|
||||||
prompt="a serene mountain landscape at sunset with snow-capped peaks",
|
prompt="a serene mountain landscape at sunset with snow-capped peaks",
|
||||||
resolution="RESOLUTION_1024_1024",
|
resolution="RESOLUTION_1024_1024",
|
||||||
model="V_2_TURBO",
|
model=model,
|
||||||
magic_prompt_option="AUTO",
|
magic_prompt_option="AUTO",
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
num_images=1
|
num_images=1
|
||||||
@ -41,8 +74,8 @@ def test_ideogram_generate(api_key):
|
|||||||
assert image.dtype == torch.float32
|
assert image.dtype == torch.float32
|
||||||
assert torch.all((image >= 0) & (image <= 1))
|
assert torch.all((image >= 0) & (image <= 1))
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", ["V_2_TURBO", "V_3"])
|
||||||
def test_ideogram_edit(api_key, sample_image):
|
def test_ideogram_edit(api_key, sample_image, model):
|
||||||
node = IdeogramEdit()
|
node = IdeogramEdit()
|
||||||
|
|
||||||
# white is areas to keep, black is areas to repaint
|
# white is areas to keep, black is areas to repaint
|
||||||
@ -56,7 +89,7 @@ def test_ideogram_edit(api_key, sample_image):
|
|||||||
masks=mask,
|
masks=mask,
|
||||||
magic_prompt_option="OFF",
|
magic_prompt_option="OFF",
|
||||||
prompt="a solid black rectangle",
|
prompt="a solid black rectangle",
|
||||||
model="V_2_TURBO",
|
model=model,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
num_images=1,
|
num_images=1,
|
||||||
)
|
)
|
||||||
@ -77,15 +110,15 @@ def test_ideogram_edit(api_key, sample_image):
|
|||||||
assert center_mean < outer_mean, f"Center region ({center_mean:.3f}) should be darker than outer region ({outer_mean:.3f})"
|
assert center_mean < outer_mean, f"Center region ({center_mean:.3f}) should be darker than outer region ({outer_mean:.3f})"
|
||||||
assert center_mean < 0.6, f"Center region ({center_mean:.3f}) should be dark"
|
assert center_mean < 0.6, f"Center region ({center_mean:.3f}) should be dark"
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", ["V_2_TURBO", "V_3"])
|
||||||
def test_ideogram_remix(api_key, sample_image):
|
def test_ideogram_remix(api_key, sample_image, model):
|
||||||
node = IdeogramRemix()
|
node = IdeogramRemix()
|
||||||
|
|
||||||
image, = node.remix(
|
image, = node.remix(
|
||||||
images=sample_image,
|
images=sample_image,
|
||||||
prompt="transform into a vibrant blue ocean scene with waves",
|
prompt="transform into a vibrant blue ocean scene with waves",
|
||||||
resolution="RESOLUTION_1024_1024",
|
resolution="RESOLUTION_1024_1024",
|
||||||
model="V_2_TURBO",
|
model=model,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
num_images=1
|
num_images=1
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user