diff --git a/comfy_extras/nodes/nodes_ideogram.py b/comfy_extras/nodes/nodes_ideogram.py index d61460e56..382abaa28 100644 --- a/comfy_extras/nodes/nodes_ideogram.py +++ b/comfy_extras/nodes/nodes_ideogram.py @@ -2,13 +2,13 @@ import json import os from io import BytesIO from itertools import chain -from typing import Tuple, Dict, Any +from typing import Tuple, Dict, Any, Literal import requests import torch from PIL import Image -from comfy.component_model.tensor_types import RGBImageBatch, MaskBatch +from comfy.component_model.tensor_types import RGBImageBatch, MaskBatch, ImageBatch from comfy.nodes.package_typing import CustomNode, Seed from comfy.utils import pil2tensor, tensor2pil from comfy_extras.constants.resolutions import IDEOGRAM_RESOLUTIONS @@ -19,11 +19,14 @@ ASPECT_RATIO_ENUM = ["ASPECT_1_1"] + list(chain.from_iterable( [f"ASPECT_{a}_{b}", f"ASPECT_{b}_{a}"] for a, b in ASPECT_RATIOS )) -MODELS_ENUM = ["V_2", "V_2_TURBO"] +V2_MODELS = ["V_2", "V_2_TURBO"] +MODELS_ENUM = V2_MODELS + ["V_3"] AUTO_PROMPT_ENUM = ["AUTO", "ON", "OFF"] STYLES_ENUM = ["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"] RESOLUTION_ENUM = [f"RESOLUTION_{w}_{h}" for w, h in IDEOGRAM_RESOLUTIONS] +def to_v3_resolution(resolution:str) -> str: + return resolution[len("RESOLUTION_"):].replace("_", "x") def api_key_in_env_or_workflow(api_key_from_workflow: str): from comfy.cli_args import args @@ -61,21 +64,42 @@ class IdeogramGenerate(CustomNode): api_key = api_key_in_env_or_workflow(api_key) headers = {"Api-Key": api_key, "Content-Type": "application/json"} - payload = { - "prompt": prompt, - "resolution": resolution, - "model": model, - "magic_prompt": magic_prompt_option, - "num_images": num_images, - "style_type": style_type, - } + if model in V2_MODELS: + payload = { + "image_request": { + "prompt": prompt, + "resolution": resolution, + "model": model, + "magic_prompt_option": magic_prompt_option, + "num_images": num_images, + "style_type": style_type, + } + } - if negative_prompt: - payload["negative_prompt"] = negative_prompt - if seed: - payload["seed"] = seed + if negative_prompt: + payload["image_request"]["negative_prompt"] = negative_prompt + if seed: + payload["image_request"]["seed"] = seed + + response = requests.post("https://api.ideogram.ai/generate", headers=headers, json=payload) + elif model == "V_3": + payload = { + "prompt": prompt, + "resolution": to_v3_resolution(resolution), + "model": model, + "magic_prompt": magic_prompt_option, + "num_images": num_images, + "style_type": style_type, + } + + if negative_prompt: + payload["negative_prompt"] = negative_prompt + if seed: + payload["seed"] = seed + response = requests.post("https://api.ideogram.ai/v1/ideogram-v3/generate", headers=headers, json=payload) + else: + raise ValueError(f"Invalid model={model}") - response = requests.post("https://api.ideogram.ai/v1/ideogram-v3/generate", headers=headers, json=payload) response.raise_for_status() images = [] @@ -130,20 +154,41 @@ class IdeogramEdit(CustomNode): image_pil.save(image_bytes, format="PNG") mask_pil.save(mask_bytes, format="PNG") - files = { - "image": ("image.png", image_bytes.getvalue()), - "mask": ("mask.png", mask_bytes.getvalue()), - } + if model in V2_MODELS: + files = { + "image_file": ("image.png", image_bytes.getvalue()), + "mask": ("mask.png", mask_bytes.getvalue()), + } - data = { - "prompt": prompt, - "magic_prompt": magic_prompt_option, - "num_images": num_images, - } - if seed: - data["seed"] = seed + data = { + "prompt": prompt, + "model": model, + "magic_prompt_option": magic_prompt_option, + "num_images": num_images, + "style_type": style_type, + } + if seed: + data["seed"] = seed + + response = requests.post("https://api.ideogram.ai/edit", headers=headers, files=files, data=data) + elif model == "V_3": + files = { + "image": ("image.png", image_bytes.getvalue()), + "mask": ("mask.png", mask_bytes.getvalue()), + } + + data = { + "prompt": prompt, + "magic_prompt": magic_prompt_option, + "num_images": num_images, + } + if seed: + data["seed"] = seed + + response = requests.post("https://api.ideogram.ai/v1/ideogram-v3/edit", headers=headers, files=files, data=data) + else: + raise ValueError(f"Invalid model={model}") - response = requests.post("https://api.ideogram.ai/v1/ideogram-v3/edit", headers=headers, files=files, data=data) response.raise_for_status() for item in response.json()["data"]: @@ -193,27 +238,55 @@ class IdeogramRemix(CustomNode): image_bytes = BytesIO() image_pil.save(image_bytes, format="PNG") - files = { - "image": ("image.png", image_bytes.getvalue()), - } + if model in V2_MODELS: - data = { - "prompt": prompt, - "resolution": resolution, - "image_weight": image_weight, - "magic_prompt": magic_prompt_option, - "num_images": num_images, - "style_type": style_type, - } + files = { + "image_file": ("image.png", image_bytes.getvalue()), + } - if negative_prompt: - data["negative_prompt"] = negative_prompt - if seed: - data["seed"] = seed + data = { + "prompt": prompt, + "resolution": resolution, + "model": model, + "image_weight": image_weight, + "magic_prompt_option": magic_prompt_option, + "num_images": num_images, + "style_type": style_type, + } - # data = {"image_request": data} + if negative_prompt: + data["negative_prompt"] = negative_prompt + if seed: + data["seed"] = seed + + # data = {"image_request": data} + + response = requests.post("https://api.ideogram.ai/remix", headers=headers, files=files, data={ + "image_request": json.dumps(data) + }) + elif model == "V_3": + files = { + "image": ("image.png", image_bytes.getvalue()), + } + + data = { + "prompt": prompt, + "resolution": to_v3_resolution(resolution), + "image_weight": image_weight, + "magic_prompt": magic_prompt_option, + "num_images": num_images, + "style_type": style_type, + } + + if negative_prompt: + data["negative_prompt"] = negative_prompt + if seed: + data["seed"] = seed + + response = requests.post("https://api.ideogram.ai/v1/ideogram-v3/remix", headers=headers, files=files, data=data) + else: + raise ValueError(f"Invalid model={model}") - response = requests.post("https://api.ideogram.ai/v1/ideogram-v3/remix", headers=headers, files=files, data=data) response.raise_for_status() for item in response.json()["data"]: @@ -226,14 +299,81 @@ class IdeogramRemix(CustomNode): return (torch.cat(result_images, dim=0),) +class IdeogramDescribe(CustomNode): + """ + A ComfyUI node to get a description of an image using the Ideogram API. + """ + + @classmethod + def INPUT_TYPES(cls) -> Dict[str, Any]: + """ + Defines the input types for the node. + """ + return { + "required": { + "images": ("IMAGE",), + }, + "optional": { + "api_key": ("STRING", {"default": ""}), + } + } + + RETURN_TYPES = ("STRING",) + OUTPUT_IS_LIST = (True,) + FUNCTION = "describe" + CATEGORY = "ideogram" + + def describe(self, images: ImageBatch, api_key: str = "") -> tuple[list[str]]: + """ + Sends an image to the Ideogram API and returns a generated description. + + Args: + images: A batch of images as a tensor. + api_key: The Ideogram API key. + + Returns: + A tuple containing the description string for the first image. + """ + api_key = api_key_in_env_or_workflow(api_key) + headers = {"Api-Key": api_key} + + descriptions_batch = [] + for image in images: + pil_image = tensor2pil(image) + + image_bytes = BytesIO() + pil_image.save(image_bytes, format="PNG") + image_bytes.seek(0) + + files = { + "image_file": ("image.png", image_bytes.getvalue(), "image/png"), + } + + response = requests.post("https://api.ideogram.ai/describe", headers=headers, files=files) + response.raise_for_status() + + data = response.json() + descriptions = data.get("descriptions", []) + + if not descriptions: + descriptions_batch.append("") + else: + first_description = descriptions[0].get("text", "") + descriptions_batch.append(first_description) + + return (descriptions_batch,) + + NODE_CLASS_MAPPINGS = { "IdeogramGenerate": IdeogramGenerate, "IdeogramEdit": IdeogramEdit, "IdeogramRemix": IdeogramRemix, + "IdeogramDescribe": IdeogramDescribe, } NODE_DISPLAY_NAME_MAPPINGS = { "Ideogram Generate": "Ideogram Generate", "Ideogram Edit": "Ideogram Edit", "Ideogram Remix": "Ideogram Remix", + "Ideogram Describe": "Ideogram Describe", } diff --git a/tests/unit/test_ideogram_nodes.py b/tests/unit/test_ideogram_nodes.py index b2dc19c29..f6708e48e 100644 --- a/tests/unit/test_ideogram_nodes.py +++ b/tests/unit/test_ideogram_nodes.py @@ -3,10 +3,12 @@ import os import pytest import torch +from comfy.component_model.tensor_types import RGBImageBatch from comfy_extras.nodes.nodes_ideogram import ( IdeogramGenerate, IdeogramEdit, - IdeogramRemix + IdeogramRemix, + IdeogramDescribe, ) @@ -23,13 +25,44 @@ def sample_image(): return torch.ones((1, 1024, 1024, 3)) * 0.8 # Light gray image -def test_ideogram_generate(api_key): +@pytest.fixture +def black_square_image() -> RGBImageBatch: + # A black square image (1 batch, 1024x1024 pixels, 3 channels) + return torch.zeros((1, 1024, 1024, 3), dtype=torch.float32) + + +def test_ideogram_describe(api_key, black_square_image): + """ + Tests the IdeogramDescribe node by passing it a black square image and + asserting that the returned description contains "black" and "square". + """ + node = IdeogramDescribe() + + # The node's method returns a tuple containing a list of descriptions + descriptions_list, = node.describe( + images=black_square_image, + api_key=api_key + ) + + # We passed one image, so we expect one description in the list + assert isinstance(descriptions_list, list) + assert len(descriptions_list) == 1 + + description = descriptions_list[0] + + assert isinstance(description, str) + assert "black" in description.lower() + assert "square" in description.lower() + + +@pytest.mark.parametrize("model", ["V_2_TURBO", "V_3"]) +def test_ideogram_generate(api_key, model): node = IdeogramGenerate() image, = node.generate( prompt="a serene mountain landscape at sunset with snow-capped peaks", resolution="RESOLUTION_1024_1024", - model="V_2_TURBO", + model=model, magic_prompt_option="AUTO", api_key=api_key, num_images=1 @@ -41,8 +74,8 @@ def test_ideogram_generate(api_key): assert image.dtype == torch.float32 assert torch.all((image >= 0) & (image <= 1)) - -def test_ideogram_edit(api_key, sample_image): +@pytest.mark.parametrize("model", ["V_2_TURBO", "V_3"]) +def test_ideogram_edit(api_key, sample_image, model): node = IdeogramEdit() # white is areas to keep, black is areas to repaint @@ -56,7 +89,7 @@ def test_ideogram_edit(api_key, sample_image): masks=mask, magic_prompt_option="OFF", prompt="a solid black rectangle", - model="V_2_TURBO", + model=model, api_key=api_key, num_images=1, ) @@ -77,15 +110,15 @@ def test_ideogram_edit(api_key, sample_image): assert center_mean < outer_mean, f"Center region ({center_mean:.3f}) should be darker than outer region ({outer_mean:.3f})" assert center_mean < 0.6, f"Center region ({center_mean:.3f}) should be dark" - -def test_ideogram_remix(api_key, sample_image): +@pytest.mark.parametrize("model", ["V_2_TURBO", "V_3"]) +def test_ideogram_remix(api_key, sample_image, model): node = IdeogramRemix() image, = node.remix( images=sample_image, prompt="transform into a vibrant blue ocean scene with waves", resolution="RESOLUTION_1024_1024", - model="V_2_TURBO", + model=model, api_key=api_key, num_images=1 )