From b1bcf082afc63da109c0a549a8f209141f82eb34 Mon Sep 17 00:00:00 2001 From: doctorpangloss <@hiddenswitch.com> Date: Wed, 22 Jan 2025 10:32:04 -0800 Subject: [PATCH] Native Ideogram support --- README.md | 10 ++ comfy/cli_args.py | 9 + comfy/cli_args_types.py | 2 + comfy/model_management_types.py | 4 +- comfy/model_patcher.py | 9 +- comfy_extras/constants/resolutions.py | 44 +++++ comfy_extras/nodes/nodes_ideogram.py | 239 ++++++++++++++++++++++++++ comfy_extras/nodes/nodes_images.py | 28 ++- comfy_extras/nodes/nodes_mask.py | 3 +- tests/unit/test_ideogram_nodes.py | 102 +++++++++++ 10 files changed, 428 insertions(+), 22 deletions(-) create mode 100644 comfy_extras/constants/resolutions.py create mode 100644 comfy_extras/nodes/nodes_ideogram.py create mode 100644 tests/unit/test_ideogram_nodes.py diff --git a/README.md b/README.md index bc69cce1c..ff013ddea 100644 --- a/README.md +++ b/README.md @@ -407,6 +407,16 @@ In this example, a raster image is converted to SVG, potentially modified, and t You can try the [SVG Conversion Workflow](tests/inference/workflows/svg-0.json) to explore these features. +# Ideogram + +First class support for Ideogram, currently the best still images model. + +Visit [API key management](https://ideogram.ai/manage-api) and set the environment variable `IDEOGRAM_API_KEY` to it. + +The `IdeogramEdit` node expects the white areas of the mask to be kept, and the black areas of the mask to be inpainted. + +Use the **Fit Image to Diffusion Size** with the **Ideogram** resolution set to correctly fit images for inpainting. + # Video Workflows ComfyUI LTS supports video workflows with AnimateDiff Evolved. diff --git a/comfy/cli_args.py b/comfy/cli_args.py index be3902af2..99f2d2673 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -222,6 +222,15 @@ def _create_parser() -> EnhancedConfigArgParser: default=None ) + parser.add_argument( + "--ideogram-api-key", + required=False, + type=str, + help="Configures the Ideogram API Key for the Ideogram nodes. Visit https://ideogram.ai/manage-api to create this key.", + env_var="IDEOGRAM_API_KEY", + default=None + ) + parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path.") # now give plugins a chance to add configuration diff --git a/comfy/cli_args_types.py b/comfy/cli_args_types.py index baa86ef09..3bce0dc00 100644 --- a/comfy/cli_args_types.py +++ b/comfy/cli_args_types.py @@ -118,6 +118,7 @@ class Configuration(dict): executor_factory (str): Either ThreadPoolExecutor or ProcessPoolExecutor, defaulting to ThreadPoolExecutor preview_size (int): Sets the maximum preview size for sampler nodes. Defaults to 512. openai_api_key (str): Configures the OpenAI API Key for the OpenAI nodes + ideogram_api_key (str): Configures the Ideogram API Key for the Ideogram nodes. Visit https://ideogram.ai/manage-api to create this key. user_directory (Optional[str]): Set the ComfyUI user directory with an absolute path. log_stdout (bool): Send normal process output to stdout instead of stderr (default) """ @@ -215,6 +216,7 @@ class Configuration(dict): self.executor_factory: str = "ThreadPoolExecutor" self.openai_api_key: Optional[str] = None + self.ideogram_api_key: Optional[str] = None self.user_directory: Optional[str] = None def __getattr__(self, item): diff --git a/comfy/model_management_types.py b/comfy/model_management_types.py index 5291bbfa2..41add063e 100644 --- a/comfy/model_management_types.py +++ b/comfy/model_management_types.py @@ -7,8 +7,10 @@ import torch import torch.nn from typing_extensions import TypedDict, NotRequired -ModelManageableT = TypeVar('ModelManageableT', bound='ModelManageable') +from comfy.latent_formats import LatentFormat +ModelManageableT = TypeVar('ModelManageableT', bound='ModelManageable') +LatentFormatT = TypeVar('LatentFormatT', bound=LatentFormat) @runtime_checkable class DeviceSettable(Protocol): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index bedb6b2af..c7d31ddd1 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -21,6 +21,7 @@ import collections import copy import inspect import logging +import typing import uuid from math import isclose from typing import Callable, Optional @@ -38,7 +39,7 @@ from .float import stochastic_rounding from .hooks import EnumHookMode, _HookRef, HookGroup, EnumHookType, WeightHook, create_transformer_options_from_hooks from .lora_types import PatchDict, PatchDictKey, PatchTuple, PatchWeightTuple, ModelPatchesDictValue from .model_base import BaseModel -from .model_management_types import ModelManageable, MemoryMeasurements, ModelOptions +from .model_management_types import ModelManageable, MemoryMeasurements, ModelOptions, LatentFormatT from .patcher_extension import CallbacksMP, WrappersMP, PatcherInjection logger = logging.getLogger(__name__) @@ -437,7 +438,7 @@ class ModelPatcher(ModelManageable): def add_object_patch(self, name, obj): self.object_patches[name] = obj - def get_model_object(self, name: str) -> torch.nn.Module: + def get_model_object(self, name: str) -> torch.nn.Module | typing.Any: """Retrieves a nested attribute from an object using dot notation considering object patches. @@ -467,6 +468,10 @@ class ModelPatcher(ModelManageable): def diffusion_model(self, value: torch.nn.Module): self.add_object_patch("diffusion_model", value) + @property + def latent_format(self) -> LatentFormatT: + return self.get_model_object("latent_format") + def model_patches_to(self, device): to = self.model_options["transformer_options"] if "patches" in to: diff --git a/comfy_extras/constants/resolutions.py b/comfy_extras/constants/resolutions.py new file mode 100644 index 000000000..d71da576d --- /dev/null +++ b/comfy_extras/constants/resolutions.py @@ -0,0 +1,44 @@ +IDEOGRAM_RESOLUTIONS = [ + (512, 1536), (576, 1408), (576, 1472), (576, 1536), + (640, 1024), (640, 1344), (640, 1408), (640, 1472), (640, 1536), + (704, 1152), (704, 1216), (704, 1280), (704, 1344), (704, 1408), (704, 1472), + (720, 1280), (736, 1312), + (768, 1024), (768, 1088), (768, 1152), (768, 1216), (768, 1232), (768, 1280), (768, 1344), + (832, 960), (832, 1024), (832, 1088), (832, 1152), (832, 1216), (832, 1248), + (864, 1152), + (896, 960), (896, 1024), (896, 1088), (896, 1120), (896, 1152), + (960, 832), (960, 896), (960, 1024), (960, 1088), + (1024, 640), (1024, 768), (1024, 832), (1024, 896), (1024, 960), (1024, 1024), + (1088, 768), (1088, 832), (1088, 896), (1088, 960), + (1120, 896), + (1152, 704), (1152, 768), (1152, 832), (1152, 864), (1152, 896), + (1216, 704), (1216, 768), (1216, 832), + (1232, 768), + (1248, 832), + (1280, 704), (1280, 720), (1280, 768), (1280, 800), + (1312, 736), + (1344, 640), (1344, 704), (1344, 768), + (1408, 576), (1408, 640), (1408, 704), + (1472, 576), (1472, 640), (1472, 704), + (1536, 512), (1536, 576), (1536, 640) +] + +SDXL_SD3_FLUX_RESOLUTIONS = [ + (640, 1536), + (768, 1344), + (832, 1216), + (896, 1152), + (1024, 1024), + (1152, 896), + (1216, 832), + (1344, 768), + (1536, 640), +] + +LTVX_RESOLUTIONS = [ + (768, 512) +] + +SD_RESOLUTIONS = [ + (512, 512), +] diff --git a/comfy_extras/nodes/nodes_ideogram.py b/comfy_extras/nodes/nodes_ideogram.py new file mode 100644 index 000000000..f1aad747c --- /dev/null +++ b/comfy_extras/nodes/nodes_ideogram.py @@ -0,0 +1,239 @@ +import json +from io import BytesIO +from itertools import chain +from typing import Tuple, Dict, Any + +import requests +import torch +from PIL import Image + +from comfy.component_model.tensor_types import RGBImageBatch, MaskBatch +from comfy.nodes.package_typing import CustomNode +from comfy.utils import pil2tensor, tensor2pil +from comfy_extras.constants.resolutions import IDEOGRAM_RESOLUTIONS +from comfy_extras.nodes.nodes_mask import MaskToImage +from comfy.cli_args import args + +ASPECT_RATIOS = [(10, 6), (16, 10), (9, 16), (3, 2), (4, 3)] +ASPECT_RATIO_ENUM = ["ASPECT_1_1"] + list(chain.from_iterable( + [f"ASPECT_{a}_{b}", f"ASPECT_{b}_{a}"] + for a, b in ASPECT_RATIOS +)) +MODELS_ENUM = ["V_2", "V_2_TURBO"] +AUTO_PROMPT_ENUM = ["AUTO", "ON", "OFF"] +RESOLUTION_ENUM = [f"RESOLUTION_{w}_{h}" for w, h in IDEOGRAM_RESOLUTIONS] + + + +def api_key_in_env_or_workflow(api_key_from_workflow: str): + if api_key_from_workflow is not None and "" != api_key_from_workflow: + return api_key_from_workflow + + return args.ideogram_api_key + + +class IdeogramGenerate(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> Dict[str, Any]: + return { + "required": { + "prompt": ("STRING", {"multiline": True}), + "resolution": (RESOLUTION_ENUM, {"default": RESOLUTION_ENUM[0]}), + "model": (MODELS_ENUM, {"default": MODELS_ENUM[0]}), + "magic_prompt_option": (AUTO_PROMPT_ENUM, {"default": AUTO_PROMPT_ENUM[0]}), + }, + "optional": { + "api_key": ("STRING", {"default": ""}), + "negative_prompt": ("STRING", {"multiline": True}), + "num_images": ("INT", {"default": 1, "min": 1, "max": 8}), + "seed": ("INT", {"default": 0}), + } + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "generate" + CATEGORY = "ideogram" + + def generate(self, prompt: str, resolution: str, model: str, magic_prompt_option: str, + api_key: str = "", negative_prompt: str = "", num_images: int = 1, seed: int = 0) -> Tuple[torch.Tensor]: + api_key = api_key_in_env_or_workflow(api_key) + headers = {"Api-Key": api_key, "Content-Type": "application/json"} + + payload = { + "image_request": { + "prompt": prompt, + "resolution": resolution, + "model": model, + "magic_prompt_option": magic_prompt_option, + "num_images": num_images + } + } + + if negative_prompt: + payload["image_request"]["negative_prompt"] = negative_prompt + if seed: + payload["image_request"]["seed"] = seed + + response = requests.post("https://api.ideogram.ai/generate", headers=headers, json=payload) + response.raise_for_status() + + images = [] + for item in response.json()["data"]: + img_response = requests.get(item["url"]) + img_response.raise_for_status() + + pil_image = Image.open(BytesIO(img_response.content)) + images.append(pil2tensor(pil_image)) + + return (torch.cat(images, dim=0),) + + +class IdeogramEdit(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> Dict[str, Any]: + return { + "required": { + "images": ("IMAGE",), + "masks": ("MASK",), + "prompt": ("STRING", {"multiline": True}), + "model": (MODELS_ENUM, {"default": MODELS_ENUM[0]}), + }, + "optional": { + "api_key": ("STRING", {"default": ""}), + "magic_prompt_option": (AUTO_PROMPT_ENUM, {"default": AUTO_PROMPT_ENUM[0]}), + "num_images": ("INT", {"default": 1, "min": 1, "max": 8}), + "seed": ("INT", {"default": 0}), + } + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "edit" + CATEGORY = "ideogram" + + def edit(self, images: RGBImageBatch, masks: MaskBatch, prompt: str, model: str, + api_key: str = "", magic_prompt_option: str = "AUTO", + num_images: int = 1, seed: int = 0) -> Tuple[torch.Tensor]: + api_key = api_key_in_env_or_workflow(api_key) + headers = {"Api-Key": api_key} + image_responses = [] + for mask, image in zip(torch.unbind(masks), torch.unbind(images)): + mask, = MaskToImage().mask_to_image(mask=mask) + mask: RGBImageBatch + + image_pil = tensor2pil(image) + mask_pil = tensor2pil(mask) + + image_bytes = BytesIO() + mask_bytes = BytesIO() + image_pil.save(image_bytes, format="PNG") + mask_pil.save(mask_bytes, format="PNG") + + files = { + "image_file": ("image.png", image_bytes.getvalue()), + "mask": ("mask.png", mask_bytes.getvalue()), + } + + data = { + "prompt": prompt, + "model": model, + "magic_prompt_option": magic_prompt_option, + "num_images": num_images + } + if seed: + data["seed"] = seed + + response = requests.post("https://api.ideogram.ai/edit", headers=headers, files=files, data=data) + response.raise_for_status() + + for item in response.json()["data"]: + img_response = requests.get(item["url"]) + img_response.raise_for_status() + + pil_image = Image.open(BytesIO(img_response.content)) + image_responses.append(pil2tensor(pil_image)) + + return (torch.cat(image_responses, dim=0),) + + +class IdeogramRemix(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> Dict[str, Any]: + return { + "required": { + "images": ("IMAGE",), + "prompt": ("STRING", {"multiline": True}), + "resolution": (RESOLUTION_ENUM, {"default": RESOLUTION_ENUM[0]}), + "model": (MODELS_ENUM, {"default": MODELS_ENUM[0]}), + }, + "optional": { + "api_key": ("STRING", {"default": ""}), + "image_weight": ("INT", {"default": 50, "min": 1, "max": 100}), + "magic_prompt_option": (AUTO_PROMPT_ENUM, {"default": AUTO_PROMPT_ENUM[0]}), + "negative_prompt": ("STRING", {"multiline": True}), + "num_images": ("INT", {"default": 1, "min": 1, "max": 8}), + "seed": ("INT", {"default": 0}), + } + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "remix" + CATEGORY = "ideogram" + + def remix(self, images: torch.Tensor, prompt: str, resolution: str, model: str, + api_key: str = "", image_weight: int = 50, magic_prompt_option: str = "AUTO", + negative_prompt: str = "", num_images: int = 1, seed: int = 0) -> Tuple[torch.Tensor]: + api_key = api_key_in_env_or_workflow(api_key) + headers = {"Api-Key": api_key} + + result_images = [] + for image in images: + image_pil = tensor2pil(image) + image_bytes = BytesIO() + image_pil.save(image_bytes, format="PNG") + + files = { + "image_file": ("image.png", image_bytes.getvalue()), + } + + data = { + "prompt": prompt, + "resolution": resolution, + "model": model, + "image_weight": image_weight, + "magic_prompt_option": magic_prompt_option, + "num_images": num_images + } + + if negative_prompt: + data["negative_prompt"] = negative_prompt + if seed: + data["seed"] = seed + + # data = {"image_request": data} + + response = requests.post("https://api.ideogram.ai/remix", headers=headers, files=files, data={ + "image_request": json.dumps(data) + }) + response.raise_for_status() + + for item in response.json()["data"]: + img_response = requests.get(item["url"]) + img_response.raise_for_status() + + pil_image = Image.open(BytesIO(img_response.content)) + result_images.append(pil2tensor(pil_image)) + + return (torch.cat(result_images, dim=0),) + + +NODE_CLASS_MAPPINGS = { + "IdeogramGenerate": IdeogramGenerate, + "IdeogramEdit": IdeogramEdit, + "IdeogramRemix": IdeogramRemix, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "Ideogram Generate": "Ideogram Generate", + "Ideogram Edit": "Ideogram Edit", + "Ideogram Remix": "Ideogram Remix", +} diff --git a/comfy_extras/nodes/nodes_images.py b/comfy_extras/nodes/nodes_images.py index 151be6a84..3ea4d037c 100644 --- a/comfy_extras/nodes/nodes_images.py +++ b/comfy_extras/nodes/nodes_images.py @@ -14,6 +14,8 @@ from comfy.component_model.tensor_types import ImageBatch, RGBImageBatch from comfy.nodes.base_nodes import ImageScale from comfy.nodes.common import MAX_RESOLUTION from comfy.nodes.package_typing import CustomNode +from comfy_extras.constants.resolutions import SDXL_SD3_FLUX_RESOLUTIONS, LTVX_RESOLUTIONS, SD_RESOLUTIONS, \ + IDEOGRAM_RESOLUTIONS def levels_adjustment(image: ImageBatch, black_level: float = 0.0, mid_level: float = 0.5, white_level: float = 1.0, clip: bool = True) -> ImageBatch: @@ -271,7 +273,7 @@ class ImageResize: "required": { "image": ("IMAGE",), "resize_mode": (["cover", "contain", "auto"], {"default": "cover"}), - "resolutions": (["SDXL/SD3/Flux", "SD1.5", "LTXV"], {"default": "SDXL/SD3/Flux"}), + "resolutions": (["SDXL/SD3/Flux", "SD1.5", "LTXV", "Ideogram"], {"default": "SDXL/SD3/Flux"}), "interpolation": (ImageScale.upscale_methods, {"default": "bilinear"}), } } @@ -282,26 +284,16 @@ class ImageResize: def resize_image(self, image: RGBImageBatch, resize_mode: Literal["cover", "contain", "auto"], resolutions: Literal["SDXL/SD3/Flux", "SD1.5"], interpolation: str) -> Tuple[RGBImageBatch]: if resolutions == "SDXL/SD3/Flux": - supported_resolutions = [ - (640, 1536), - (768, 1344), - (832, 1216), - (896, 1152), - (1024, 1024), - (1152, 896), - (1216, 832), - (1344, 768), - (1536, 640), - ] + supported_resolutions = SDXL_SD3_FLUX_RESOLUTIONS elif resolutions == "ltxv": - supported_resolutions = [ - (768, 512) - ] + supported_resolutions = LTVX_RESOLUTIONS + elif resolutions == "ideogram": + supported_resolutions = IDEOGRAM_RESOLUTIONS else: - supported_resolutions = [ - (512, 512), - ] + supported_resolutions = SD_RESOLUTIONS + return self.resize_image_with_supported_resolutions(image, resize_mode, supported_resolutions, interpolation) + def resize_image_with_supported_resolutions(self, image: RGBImageBatch, resize_mode: Literal["cover", "contain", "auto"], supported_resolutions: list[tuple[int, int]], interpolation: str): resized_images = [] for img in image: h, w = img.shape[:2] diff --git a/comfy_extras/nodes/nodes_mask.py b/comfy_extras/nodes/nodes_mask.py index 12185989b..5dd6fdad0 100644 --- a/comfy_extras/nodes/nodes_mask.py +++ b/comfy_extras/nodes/nodes_mask.py @@ -2,6 +2,7 @@ import numpy as np import scipy.ndimage import torch from comfy import utils +from comfy.component_model.tensor_types import MaskBatch, RGBImageBatch from comfy.nodes.common import MAX_RESOLUTION @@ -106,7 +107,7 @@ class MaskToImage: RETURN_TYPES = ("IMAGE",) FUNCTION = "mask_to_image" - def mask_to_image(self, mask): + def mask_to_image(self, mask: MaskBatch) -> tuple[RGBImageBatch]: result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) return (result,) diff --git a/tests/unit/test_ideogram_nodes.py b/tests/unit/test_ideogram_nodes.py new file mode 100644 index 000000000..b2dc19c29 --- /dev/null +++ b/tests/unit/test_ideogram_nodes.py @@ -0,0 +1,102 @@ +import os + +import pytest +import torch + +from comfy_extras.nodes.nodes_ideogram import ( + IdeogramGenerate, + IdeogramEdit, + IdeogramRemix +) + + +@pytest.fixture +def api_key(): + key = os.environ.get('IDEOGRAM_API_KEY') + if not key: + pytest.skip("IDEOGRAM_API_KEY environment variable not set") + return key + + +@pytest.fixture +def sample_image(): + return torch.ones((1, 1024, 1024, 3)) * 0.8 # Light gray image + + +def test_ideogram_generate(api_key): + node = IdeogramGenerate() + + image, = node.generate( + prompt="a serene mountain landscape at sunset with snow-capped peaks", + resolution="RESOLUTION_1024_1024", + model="V_2_TURBO", + magic_prompt_option="AUTO", + api_key=api_key, + num_images=1 + ) + + # Verify output format + assert isinstance(image, torch.Tensor) + assert image.shape[1:] == (1024, 1024, 3) # HxWxC format + assert image.dtype == torch.float32 + assert torch.all((image >= 0) & (image <= 1)) + + +def test_ideogram_edit(api_key, sample_image): + node = IdeogramEdit() + + # white is areas to keep, black is areas to repaint + mask = torch.full((1, 1024, 1024), fill_value=1.0) + center_start = 386 + center_end = 640 + mask[:, center_start:center_end, center_start:center_end] = 0.0 + + image, = node.edit( + images=sample_image, + masks=mask, + magic_prompt_option="OFF", + prompt="a solid black rectangle", + model="V_2_TURBO", + api_key=api_key, + num_images=1, + ) + + # Verify output format + assert isinstance(image, torch.Tensor) + assert image.shape[1:] == (1024, 1024, 3) + assert image.dtype == torch.float32 + assert torch.all((image >= 0) & (image <= 1)) + + # Verify the center is darker than the original + center_region = image[:, center_start:center_end, center_start:center_end, :] + outer_region = image[:, :center_start, :, :] # Use top portion for comparison + + center_mean = center_region.mean().item() + outer_mean = outer_region.mean().item() + + assert center_mean < outer_mean, f"Center region ({center_mean:.3f}) should be darker than outer region ({outer_mean:.3f})" + assert center_mean < 0.6, f"Center region ({center_mean:.3f}) should be dark" + + +def test_ideogram_remix(api_key, sample_image): + node = IdeogramRemix() + + image, = node.remix( + images=sample_image, + prompt="transform into a vibrant blue ocean scene with waves", + resolution="RESOLUTION_1024_1024", + model="V_2_TURBO", + api_key=api_key, + num_images=1 + ) + + # Verify output format + assert isinstance(image, torch.Tensor) + assert image.shape[1:] == (1024, 1024, 3) + assert image.dtype == torch.float32 + assert torch.all((image >= 0) & (image <= 1)) + + # Since we asked for a blue ocean scene, verify there's significant blue component + blue_channel = image[..., 2] # RGB where blue is index 2 + blue_mean = blue_channel.mean().item() + assert blue_mean > 0.4, f"Blue channel mean ({blue_mean:.3f}) should be significant for an ocean scene"