mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 14:20:49 +08:00
Native Ideogram support
This commit is contained in:
parent
a9347c6713
commit
b1bcf082af
10
README.md
10
README.md
@ -407,6 +407,16 @@ In this example, a raster image is converted to SVG, potentially modified, and t
|
||||
|
||||
You can try the [SVG Conversion Workflow](tests/inference/workflows/svg-0.json) to explore these features.
|
||||
|
||||
# Ideogram
|
||||
|
||||
First-class support for Ideogram, currently the best still-image model.
|
||||
|
||||
Visit [API key management](https://ideogram.ai/manage-api) to create an API key, then set the environment variable `IDEOGRAM_API_KEY` to that key.
|
||||
|
||||
The `IdeogramEdit` node expects the white areas of the mask to be kept, and the black areas of the mask to be inpainted.
|
||||
|
||||
Use the **Fit Image to Diffusion Size** node with its resolutions option set to **Ideogram** to correctly fit images for inpainting.
|
||||
|
||||
# Video Workflows
|
||||
|
||||
ComfyUI LTS supports video workflows with AnimateDiff Evolved.
|
||||
|
||||
@ -222,6 +222,15 @@ def _create_parser() -> EnhancedConfigArgParser:
|
||||
default=None
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ideogram-api-key",
|
||||
required=False,
|
||||
type=str,
|
||||
help="Configures the Ideogram API Key for the Ideogram nodes. Visit https://ideogram.ai/manage-api to create this key.",
|
||||
env_var="IDEOGRAM_API_KEY",
|
||||
default=None
|
||||
)
|
||||
|
||||
parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path.")
|
||||
|
||||
# now give plugins a chance to add configuration
|
||||
|
||||
@ -118,6 +118,7 @@ class Configuration(dict):
|
||||
executor_factory (str): Either ThreadPoolExecutor or ProcessPoolExecutor, defaulting to ThreadPoolExecutor
|
||||
preview_size (int): Sets the maximum preview size for sampler nodes. Defaults to 512.
|
||||
openai_api_key (str): Configures the OpenAI API Key for the OpenAI nodes
|
||||
ideogram_api_key (str): Configures the Ideogram API Key for the Ideogram nodes. Visit https://ideogram.ai/manage-api to create this key.
|
||||
user_directory (Optional[str]): Set the ComfyUI user directory with an absolute path.
|
||||
log_stdout (bool): Send normal process output to stdout instead of stderr (default)
|
||||
"""
|
||||
@ -215,6 +216,7 @@ class Configuration(dict):
|
||||
|
||||
self.executor_factory: str = "ThreadPoolExecutor"
|
||||
self.openai_api_key: Optional[str] = None
|
||||
self.ideogram_api_key: Optional[str] = None
|
||||
self.user_directory: Optional[str] = None
|
||||
|
||||
def __getattr__(self, item):
|
||||
|
||||
@ -7,8 +7,10 @@ import torch
|
||||
import torch.nn
|
||||
from typing_extensions import TypedDict, NotRequired
|
||||
|
||||
ModelManageableT = TypeVar('ModelManageableT', bound='ModelManageable')
|
||||
from comfy.latent_formats import LatentFormat
|
||||
|
||||
ModelManageableT = TypeVar('ModelManageableT', bound='ModelManageable')
|
||||
LatentFormatT = TypeVar('LatentFormatT', bound=LatentFormat)
|
||||
|
||||
@runtime_checkable
|
||||
class DeviceSettable(Protocol):
|
||||
|
||||
@ -21,6 +21,7 @@ import collections
|
||||
import copy
|
||||
import inspect
|
||||
import logging
|
||||
import typing
|
||||
import uuid
|
||||
from math import isclose
|
||||
from typing import Callable, Optional
|
||||
@ -38,7 +39,7 @@ from .float import stochastic_rounding
|
||||
from .hooks import EnumHookMode, _HookRef, HookGroup, EnumHookType, WeightHook, create_transformer_options_from_hooks
|
||||
from .lora_types import PatchDict, PatchDictKey, PatchTuple, PatchWeightTuple, ModelPatchesDictValue
|
||||
from .model_base import BaseModel
|
||||
from .model_management_types import ModelManageable, MemoryMeasurements, ModelOptions
|
||||
from .model_management_types import ModelManageable, MemoryMeasurements, ModelOptions, LatentFormatT
|
||||
from .patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -437,7 +438,7 @@ class ModelPatcher(ModelManageable):
|
||||
def add_object_patch(self, name, obj):
|
||||
self.object_patches[name] = obj
|
||||
|
||||
def get_model_object(self, name: str) -> torch.nn.Module:
|
||||
def get_model_object(self, name: str) -> torch.nn.Module | typing.Any:
|
||||
"""Retrieves a nested attribute from an object using dot notation considering
|
||||
object patches.
|
||||
|
||||
@ -467,6 +468,10 @@ class ModelPatcher(ModelManageable):
|
||||
def diffusion_model(self, value: torch.nn.Module):
|
||||
self.add_object_patch("diffusion_model", value)
|
||||
|
||||
@property
def latent_format(self) -> LatentFormatT:
    """Return the object registered under the name ``latent_format``.

    The lookup is delegated to ``get_model_object``, which resolves the
    name while taking object patches into account.
    """
    resolved = self.get_model_object("latent_format")
    return resolved
|
||||
|
||||
def model_patches_to(self, device):
|
||||
to = self.model_options["transformer_options"]
|
||||
if "patches" in to:
|
||||
|
||||
44
comfy_extras/constants/resolutions.py
Normal file
44
comfy_extras/constants/resolutions.py
Normal file
@ -0,0 +1,44 @@
|
||||
IDEOGRAM_RESOLUTIONS = [
|
||||
(512, 1536), (576, 1408), (576, 1472), (576, 1536),
|
||||
(640, 1024), (640, 1344), (640, 1408), (640, 1472), (640, 1536),
|
||||
(704, 1152), (704, 1216), (704, 1280), (704, 1344), (704, 1408), (704, 1472),
|
||||
(720, 1280), (736, 1312),
|
||||
(768, 1024), (768, 1088), (768, 1152), (768, 1216), (768, 1232), (768, 1280), (768, 1344),
|
||||
(832, 960), (832, 1024), (832, 1088), (832, 1152), (832, 1216), (832, 1248),
|
||||
(864, 1152),
|
||||
(896, 960), (896, 1024), (896, 1088), (896, 1120), (896, 1152),
|
||||
(960, 832), (960, 896), (960, 1024), (960, 1088),
|
||||
(1024, 640), (1024, 768), (1024, 832), (1024, 896), (1024, 960), (1024, 1024),
|
||||
(1088, 768), (1088, 832), (1088, 896), (1088, 960),
|
||||
(1120, 896),
|
||||
(1152, 704), (1152, 768), (1152, 832), (1152, 864), (1152, 896),
|
||||
(1216, 704), (1216, 768), (1216, 832),
|
||||
(1232, 768),
|
||||
(1248, 832),
|
||||
(1280, 704), (1280, 720), (1280, 768), (1280, 800),
|
||||
(1312, 736),
|
||||
(1344, 640), (1344, 704), (1344, 768),
|
||||
(1408, 576), (1408, 640), (1408, 704),
|
||||
(1472, 576), (1472, 640), (1472, 704),
|
||||
(1536, 512), (1536, 576), (1536, 640)
|
||||
]
|
||||
|
||||
SDXL_SD3_FLUX_RESOLUTIONS = [
|
||||
(640, 1536),
|
||||
(768, 1344),
|
||||
(832, 1216),
|
||||
(896, 1152),
|
||||
(1024, 1024),
|
||||
(1152, 896),
|
||||
(1216, 832),
|
||||
(1344, 768),
|
||||
(1536, 640),
|
||||
]
|
||||
|
||||
# Resolutions offered for the "LTXV" option of the image-fitting node.
LTXV_RESOLUTIONS = [
    (768, 512),
]

# Backward-compatible alias: the table was first published under this
# misspelled name and existing imports still refer to it.
LTVX_RESOLUTIONS = LTXV_RESOLUTIONS
|
||||
|
||||
SD_RESOLUTIONS = [
|
||||
(512, 512),
|
||||
]
|
||||
239
comfy_extras/nodes/nodes_ideogram.py
Normal file
239
comfy_extras/nodes/nodes_ideogram.py
Normal file
@ -0,0 +1,239 @@
|
||||
import json
|
||||
from io import BytesIO
|
||||
from itertools import chain
|
||||
from typing import Tuple, Dict, Any
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from comfy.component_model.tensor_types import RGBImageBatch, MaskBatch
|
||||
from comfy.nodes.package_typing import CustomNode
|
||||
from comfy.utils import pil2tensor, tensor2pil
|
||||
from comfy_extras.constants.resolutions import IDEOGRAM_RESOLUTIONS
|
||||
from comfy_extras.nodes.nodes_mask import MaskToImage
|
||||
from comfy.cli_args import args
|
||||
|
||||
# Aspect-ratio pairs. Each (a, b) pair contributes both "ASPECT_a_b" and
# "ASPECT_b_a" to ASPECT_RATIO_ENUM, so only one orientation is listed here.
# The previous (10, 6) entry produced ASPECT_10_6 / ASPECT_6_10, which are not
# Ideogram aspect ratios; 10:16 is already covered by (16, 10).
# NOTE(review): the list is taken from the Ideogram API reference - confirm
# against the current API before relying on ASPECT_1_3 / ASPECT_3_1.
ASPECT_RATIOS = [(16, 10), (9, 16), (3, 2), (4, 3), (3, 1)]
ASPECT_RATIO_ENUM = ["ASPECT_1_1"] + list(chain.from_iterable(
    [f"ASPECT_{a}_{b}", f"ASPECT_{b}_{a}"]
    for a, b in ASPECT_RATIOS
))

# Model identifiers offered by the nodes' "model" input.
MODELS_ENUM = ["V_2", "V_2_TURBO"]

# Choices offered by the nodes' "magic_prompt_option" input.
AUTO_PROMPT_ENUM = ["AUTO", "ON", "OFF"]

# One "RESOLUTION_<w>_<h>" string per entry of IDEOGRAM_RESOLUTIONS.
RESOLUTION_ENUM = [f"RESOLUTION_{w}_{h}" for w, h in IDEOGRAM_RESOLUTIONS]
|
||||
|
||||
|
||||
|
||||
def api_key_in_env_or_workflow(api_key_from_workflow: str):
|
||||
if api_key_from_workflow is not None and "" != api_key_from_workflow:
|
||||
return api_key_from_workflow
|
||||
|
||||
return args.ideogram_api_key
|
||||
|
||||
|
||||
class IdeogramGenerate(CustomNode):
    """Node that generates images from a text prompt via ``https://api.ideogram.ai/generate``."""

    # (connect, read) timeouts in seconds. Without a timeout a stalled
    # connection would block the calling worker indefinitely.
    _REQUEST_TIMEOUT = (10, 600)

    @classmethod
    def INPUT_TYPES(cls) -> Dict[str, Any]:
        """Describe the node's required and optional inputs."""
        return {
            "required": {
                "prompt": ("STRING", {"multiline": True}),
                "resolution": (RESOLUTION_ENUM, {"default": RESOLUTION_ENUM[0]}),
                "model": (MODELS_ENUM, {"default": MODELS_ENUM[0]}),
                "magic_prompt_option": (AUTO_PROMPT_ENUM, {"default": AUTO_PROMPT_ENUM[0]}),
            },
            "optional": {
                "api_key": ("STRING", {"default": ""}),
                "negative_prompt": ("STRING", {"multiline": True}),
                "num_images": ("INT", {"default": 1, "min": 1, "max": 8}),
                "seed": ("INT", {"default": 0}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "generate"
    CATEGORY = "ideogram"

    def generate(self, prompt: str, resolution: str, model: str, magic_prompt_option: str,
                 api_key: str = "", negative_prompt: str = "", num_images: int = 1, seed: int = 0) -> Tuple[torch.Tensor]:
        """Request images for ``prompt`` and return them as one concatenated tensor.

        :param prompt: text prompt sent to the service
        :param resolution: one of the ``RESOLUTION_<w>_<h>`` strings
        :param model: model identifier, e.g. ``"V_2"``
        :param magic_prompt_option: ``"AUTO"``, ``"ON"`` or ``"OFF"``
        :param api_key: key from the workflow; when empty the configured key is used
        :param negative_prompt: only sent when non-empty
        :param num_images: number of images requested
        :param seed: only sent when non-zero (0 means "not specified")
        :return: one-element tuple holding the downloaded images concatenated along dim 0
        :raises ValueError: when no API key is available
        :raises requests.HTTPError: when the service or an image download returns an error status
        """
        api_key = api_key_in_env_or_workflow(api_key)
        if not api_key:
            # Fail before any network traffic instead of sending an
            # unauthenticated request and surfacing an opaque HTTP error.
            raise ValueError(
                "No Ideogram API key was provided. Set the api_key input, pass --ideogram-api-key, "
                "or set the IDEOGRAM_API_KEY environment variable."
            )
        headers = {"Api-Key": api_key, "Content-Type": "application/json"}

        image_request = {
            "prompt": prompt,
            "resolution": resolution,
            "model": model,
            "magic_prompt_option": magic_prompt_option,
            "num_images": num_images
        }
        if negative_prompt:
            image_request["negative_prompt"] = negative_prompt
        if seed:
            image_request["seed"] = seed
        payload = {"image_request": image_request}

        response = requests.post("https://api.ideogram.ai/generate", headers=headers, json=payload,
                                 timeout=self._REQUEST_TIMEOUT)
        response.raise_for_status()

        images = []
        for item in response.json()["data"]:
            # Each result is delivered as a URL that has to be fetched separately.
            img_response = requests.get(item["url"], timeout=self._REQUEST_TIMEOUT)
            img_response.raise_for_status()

            pil_image = Image.open(BytesIO(img_response.content))
            images.append(pil2tensor(pil_image))

        return (torch.cat(images, dim=0),)
|
||||
|
||||
|
||||
class IdeogramEdit(CustomNode):
    """Node that inpaints images via ``https://api.ideogram.ai/edit``.

    Images and masks are paired positionally and one request is sent per
    pair. White mask areas are kept and black mask areas are inpainted.
    """

    # (connect, read) timeouts in seconds. Without a timeout a stalled
    # connection would block the calling worker indefinitely.
    _REQUEST_TIMEOUT = (10, 600)

    @classmethod
    def INPUT_TYPES(cls) -> Dict[str, Any]:
        """Describe the node's required and optional inputs."""
        return {
            "required": {
                "images": ("IMAGE",),
                "masks": ("MASK",),
                "prompt": ("STRING", {"multiline": True}),
                "model": (MODELS_ENUM, {"default": MODELS_ENUM[0]}),
            },
            "optional": {
                "api_key": ("STRING", {"default": ""}),
                "magic_prompt_option": (AUTO_PROMPT_ENUM, {"default": AUTO_PROMPT_ENUM[0]}),
                "num_images": ("INT", {"default": 1, "min": 1, "max": 8}),
                "seed": ("INT", {"default": 0}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "edit"
    CATEGORY = "ideogram"

    def edit(self, images: RGBImageBatch, masks: MaskBatch, prompt: str, model: str,
             api_key: str = "", magic_prompt_option: str = "AUTO",
             num_images: int = 1, seed: int = 0) -> Tuple[torch.Tensor]:
        """Inpaint every (image, mask) pair and return all results as one tensor.

        :param images: batch of images to edit
        :param masks: batch of masks, paired with ``images`` by position
        :param prompt: text prompt sent to the service
        :param model: model identifier, e.g. ``"V_2"``
        :param api_key: key from the workflow; when empty the configured key is used
        :param magic_prompt_option: ``"AUTO"``, ``"ON"`` or ``"OFF"``
        :param num_images: number of images requested per input pair
        :param seed: only sent when non-zero (0 means "not specified")
        :return: one-element tuple holding the downloaded images concatenated along dim 0
        :raises ValueError: when no API key is available
        :raises requests.HTTPError: when the service or an image download returns an error status
        """
        api_key = api_key_in_env_or_workflow(api_key)
        if not api_key:
            # Fail before any network traffic instead of sending an
            # unauthenticated request and surfacing an opaque HTTP error.
            raise ValueError(
                "No Ideogram API key was provided. Set the api_key input, pass --ideogram-api-key, "
                "or set the IDEOGRAM_API_KEY environment variable."
            )
        headers = {"Api-Key": api_key}

        image_responses = []
        # zip() stops at the shorter batch, so surplus images or masks are ignored.
        for mask, image in zip(torch.unbind(masks), torch.unbind(images)):
            # The mask is uploaded as an ordinary image, so expand it to RGB first.
            mask_as_image, = MaskToImage().mask_to_image(mask=mask)

            image_pil = tensor2pil(image)
            mask_pil = tensor2pil(mask_as_image)

            image_bytes = BytesIO()
            mask_bytes = BytesIO()
            image_pil.save(image_bytes, format="PNG")
            mask_pil.save(mask_bytes, format="PNG")

            files = {
                "image_file": ("image.png", image_bytes.getvalue()),
                "mask": ("mask.png", mask_bytes.getvalue()),
            }

            data = {
                "prompt": prompt,
                "model": model,
                "magic_prompt_option": magic_prompt_option,
                "num_images": num_images
            }
            if seed:
                data["seed"] = seed

            response = requests.post("https://api.ideogram.ai/edit", headers=headers, files=files, data=data,
                                     timeout=self._REQUEST_TIMEOUT)
            response.raise_for_status()

            for item in response.json()["data"]:
                # Each result is delivered as a URL that has to be fetched separately.
                img_response = requests.get(item["url"], timeout=self._REQUEST_TIMEOUT)
                img_response.raise_for_status()

                pil_image = Image.open(BytesIO(img_response.content))
                image_responses.append(pil2tensor(pil_image))

        return (torch.cat(image_responses, dim=0),)
|
||||
|
||||
|
||||
class IdeogramRemix(CustomNode):
    """Node that remixes input images via ``https://api.ideogram.ai/remix``.

    One request is sent per input image.
    """

    # (connect, read) timeouts in seconds. Without a timeout a stalled
    # connection would block the calling worker indefinitely.
    _REQUEST_TIMEOUT = (10, 600)

    @classmethod
    def INPUT_TYPES(cls) -> Dict[str, Any]:
        """Describe the node's required and optional inputs."""
        return {
            "required": {
                "images": ("IMAGE",),
                "prompt": ("STRING", {"multiline": True}),
                "resolution": (RESOLUTION_ENUM, {"default": RESOLUTION_ENUM[0]}),
                "model": (MODELS_ENUM, {"default": MODELS_ENUM[0]}),
            },
            "optional": {
                "api_key": ("STRING", {"default": ""}),
                "image_weight": ("INT", {"default": 50, "min": 1, "max": 100}),
                "magic_prompt_option": (AUTO_PROMPT_ENUM, {"default": AUTO_PROMPT_ENUM[0]}),
                "negative_prompt": ("STRING", {"multiline": True}),
                "num_images": ("INT", {"default": 1, "min": 1, "max": 8}),
                "seed": ("INT", {"default": 0}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "remix"
    CATEGORY = "ideogram"

    def remix(self, images: torch.Tensor, prompt: str, resolution: str, model: str,
              api_key: str = "", image_weight: int = 50, magic_prompt_option: str = "AUTO",
              negative_prompt: str = "", num_images: int = 1, seed: int = 0) -> Tuple[torch.Tensor]:
        """Remix every input image and return all results as one tensor.

        :param images: batch of images; iterated along its first dimension
        :param prompt: text prompt sent to the service
        :param resolution: one of the ``RESOLUTION_<w>_<h>`` strings
        :param model: model identifier, e.g. ``"V_2"``
        :param api_key: key from the workflow; when empty the configured key is used
        :param image_weight: value forwarded as ``image_weight`` (the node limits it to 1-100)
        :param magic_prompt_option: ``"AUTO"``, ``"ON"`` or ``"OFF"``
        :param negative_prompt: only sent when non-empty
        :param num_images: number of images requested per input image
        :param seed: only sent when non-zero (0 means "not specified")
        :return: one-element tuple holding the downloaded images concatenated along dim 0
        :raises ValueError: when no API key is available
        :raises requests.HTTPError: when the service or an image download returns an error status
        """
        api_key = api_key_in_env_or_workflow(api_key)
        if not api_key:
            # Fail before any network traffic instead of sending an
            # unauthenticated request and surfacing an opaque HTTP error.
            raise ValueError(
                "No Ideogram API key was provided. Set the api_key input, pass --ideogram-api-key, "
                "or set the IDEOGRAM_API_KEY environment variable."
            )
        headers = {"Api-Key": api_key}

        result_images = []
        for image in images:
            image_pil = tensor2pil(image)
            image_bytes = BytesIO()
            image_pil.save(image_bytes, format="PNG")

            files = {
                "image_file": ("image.png", image_bytes.getvalue()),
            }

            data = {
                "prompt": prompt,
                "resolution": resolution,
                "model": model,
                "image_weight": image_weight,
                "magic_prompt_option": magic_prompt_option,
                "num_images": num_images
            }
            if negative_prompt:
                data["negative_prompt"] = negative_prompt
            if seed:
                data["seed"] = seed

            # The request parameters travel as a JSON string in the
            # "image_request" form field, next to the uploaded image file.
            response = requests.post("https://api.ideogram.ai/remix", headers=headers, files=files, data={
                "image_request": json.dumps(data)
            }, timeout=self._REQUEST_TIMEOUT)
            response.raise_for_status()

            for item in response.json()["data"]:
                # Each result is delivered as a URL that has to be fetched separately.
                img_response = requests.get(item["url"], timeout=self._REQUEST_TIMEOUT)
                img_response.raise_for_status()

                pil_image = Image.open(BytesIO(img_response.content))
                result_images.append(pil2tensor(pil_image))

        return (torch.cat(result_images, dim=0),)
|
||||
|
||||
|
||||
# Node identifier -> implementing class. Each node is registered under its
# own class name.
NODE_CLASS_MAPPINGS = {
    node_class.__name__: node_class
    for node_class in (IdeogramGenerate, IdeogramEdit, IdeogramRemix)
}
|
||||
|
||||
# Node identifier -> human-readable title. The keys must be the identifiers
# used in NODE_CLASS_MAPPINGS ("IdeogramGenerate", ...); keyed by the display
# text itself, the titles would never be matched to a node.
NODE_DISPLAY_NAME_MAPPINGS = {
    "IdeogramGenerate": "Ideogram Generate",
    "IdeogramEdit": "Ideogram Edit",
    "IdeogramRemix": "Ideogram Remix",
}
|
||||
@ -14,6 +14,8 @@ from comfy.component_model.tensor_types import ImageBatch, RGBImageBatch
|
||||
from comfy.nodes.base_nodes import ImageScale
|
||||
from comfy.nodes.common import MAX_RESOLUTION
|
||||
from comfy.nodes.package_typing import CustomNode
|
||||
from comfy_extras.constants.resolutions import SDXL_SD3_FLUX_RESOLUTIONS, LTVX_RESOLUTIONS, SD_RESOLUTIONS, \
|
||||
IDEOGRAM_RESOLUTIONS
|
||||
|
||||
|
||||
def levels_adjustment(image: ImageBatch, black_level: float = 0.0, mid_level: float = 0.5, white_level: float = 1.0, clip: bool = True) -> ImageBatch:
|
||||
@ -271,7 +273,7 @@ class ImageResize:
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"resize_mode": (["cover", "contain", "auto"], {"default": "cover"}),
|
||||
"resolutions": (["SDXL/SD3/Flux", "SD1.5", "LTXV"], {"default": "SDXL/SD3/Flux"}),
|
||||
"resolutions": (["SDXL/SD3/Flux", "SD1.5", "LTXV", "Ideogram"], {"default": "SDXL/SD3/Flux"}),
|
||||
"interpolation": (ImageScale.upscale_methods, {"default": "bilinear"}),
|
||||
}
|
||||
}
|
||||
@ -282,26 +284,16 @@ class ImageResize:
|
||||
|
||||
def resize_image(self, image: RGBImageBatch, resize_mode: Literal["cover", "contain", "auto"], resolutions: Literal["SDXL/SD3/Flux", "SD1.5", "LTXV", "Ideogram"], interpolation: str) -> Tuple[RGBImageBatch]:
    """Fit ``image`` to one of the resolutions of the selected model family.

    :param image: batch of images to resize
    :param resize_mode: forwarded unchanged to ``resize_image_with_supported_resolutions``
    :param resolutions: name of the resolution set; any unrecognised name
        (including "SD1.5") selects ``SD_RESOLUTIONS``
    :param interpolation: forwarded unchanged to ``resize_image_with_supported_resolutions``
    :return: whatever ``resize_image_with_supported_resolutions`` returns
    """
    # The node offers "LTXV" and "Ideogram" with capitals. Comparing against
    # lower-case literals made both options silently fall through to the
    # SD1.5 list, so the name is matched case-insensitively.
    selected = resolutions.lower()
    if selected == "sdxl/sd3/flux":
        supported_resolutions = SDXL_SD3_FLUX_RESOLUTIONS
    elif selected == "ltxv":
        supported_resolutions = LTVX_RESOLUTIONS
    elif selected == "ideogram":
        supported_resolutions = IDEOGRAM_RESOLUTIONS
    else:
        supported_resolutions = SD_RESOLUTIONS
    return self.resize_image_with_supported_resolutions(image, resize_mode, supported_resolutions, interpolation)
|
||||
|
||||
def resize_image_with_supported_resolutions(self, image: RGBImageBatch, resize_mode: Literal["cover", "contain", "auto"], supported_resolutions: list[tuple[int, int]], interpolation: str):
|
||||
resized_images = []
|
||||
for img in image:
|
||||
h, w = img.shape[:2]
|
||||
|
||||
@ -2,6 +2,7 @@ import numpy as np
|
||||
import scipy.ndimage
|
||||
import torch
|
||||
from comfy import utils
|
||||
from comfy.component_model.tensor_types import MaskBatch, RGBImageBatch
|
||||
|
||||
from comfy.nodes.common import MAX_RESOLUTION
|
||||
|
||||
@ -106,7 +107,7 @@ class MaskToImage:
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "mask_to_image"
|
||||
|
||||
def mask_to_image(self, mask: MaskBatch) -> tuple[RGBImageBatch]:
    """Turn a mask into a three-channel, channels-last image tensor.

    The mask is reshaped to ``(-1, 1, H, W)`` using its last two dimensions,
    the singleton dimension is moved to the end, and that dimension is then
    expanded to size 3.

    :param mask: mask tensor whose last two dimensions are used as H and W
    :return: one-element tuple containing the expanded tensor
    """
    height, width = mask.shape[-2], mask.shape[-1]
    single_channel = mask.reshape((-1, 1, height, width))
    channels_last = single_channel.movedim(1, -1)
    rgb = channels_last.expand(-1, -1, -1, 3)
    return (rgb,)
|
||||
|
||||
|
||||
102
tests/unit/test_ideogram_nodes.py
Normal file
102
tests/unit/test_ideogram_nodes.py
Normal file
@ -0,0 +1,102 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from comfy_extras.nodes.nodes_ideogram import (
|
||||
IdeogramGenerate,
|
||||
IdeogramEdit,
|
||||
IdeogramRemix
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def api_key():
    """Provide the Ideogram API key from the environment, skipping the test when it is unset or empty."""
    value = os.environ.get('IDEOGRAM_API_KEY')
    if value:
        return value
    pytest.skip("IDEOGRAM_API_KEY environment variable not set")
    return value
|
||||
|
||||
|
||||
@pytest.fixture
def sample_image():
    """Provide a single 1024x1024 three-channel image filled with 0.8 (light gray)."""
    shape = (1, 1024, 1024, 3)
    white = torch.ones(shape)
    return white * 0.8
|
||||
|
||||
|
||||
def test_ideogram_generate(api_key):
    """Generate one 1024x1024 image and check the format of the returned tensor."""
    request = {
        "prompt": "a serene mountain landscape at sunset with snow-capped peaks",
        "resolution": "RESOLUTION_1024_1024",
        "model": "V_2_TURBO",
        "magic_prompt_option": "AUTO",
        "api_key": api_key,
        "num_images": 1,
    }
    result = IdeogramGenerate().generate(**request)
    image, = result

    # Verify output format
    assert isinstance(image, torch.Tensor)
    assert image.shape[1:] == (1024, 1024, 3)  # HxWxC format
    assert image.dtype == torch.float32
    in_unit_range = (image >= 0) & (image <= 1)
    assert torch.all(in_unit_range)
|
||||
|
||||
|
||||
def test_ideogram_edit(api_key, sample_image):
    """Inpaint a dark rectangle into the masked square of a light-gray image and check the result."""
    # white is areas to keep, black is areas to repaint
    lo, hi = 386, 640
    mask = torch.full((1, 1024, 1024), fill_value=1.0)
    mask[:, lo:hi, lo:hi] = 0.0

    request = {
        "images": sample_image,
        "masks": mask,
        "magic_prompt_option": "OFF",
        "prompt": "a solid black rectangle",
        "model": "V_2_TURBO",
        "api_key": api_key,
        "num_images": 1,
    }
    image, = IdeogramEdit().edit(**request)

    # Verify output format
    assert isinstance(image, torch.Tensor)
    assert image.shape[1:] == (1024, 1024, 3)
    assert image.dtype == torch.float32
    in_unit_range = (image >= 0) & (image <= 1)
    assert torch.all(in_unit_range)

    # Verify the center is darker than the original.
    # The rows above the masked square serve as the comparison region.
    center_mean = image[:, lo:hi, lo:hi, :].mean().item()
    outer_mean = image[:, :lo, :, :].mean().item()

    assert center_mean < outer_mean, f"Center region ({center_mean:.3f}) should be darker than outer region ({outer_mean:.3f})"
    assert center_mean < 0.6, f"Center region ({center_mean:.3f}) should be dark"
|
||||
|
||||
|
||||
def test_ideogram_remix(api_key, sample_image):
    """Remix the sample image into an ocean scene and check the format and blue content of the result."""
    request = {
        "images": sample_image,
        "prompt": "transform into a vibrant blue ocean scene with waves",
        "resolution": "RESOLUTION_1024_1024",
        "model": "V_2_TURBO",
        "api_key": api_key,
        "num_images": 1,
    }
    image, = IdeogramRemix().remix(**request)

    # Verify output format
    assert isinstance(image, torch.Tensor)
    assert image.shape[1:] == (1024, 1024, 3)
    assert image.dtype == torch.float32
    in_unit_range = (image >= 0) & (image <= 1)
    assert torch.all(in_unit_range)

    # Since we asked for a blue ocean scene, verify there's significant blue component.
    # The last dimension is RGB, so index 2 selects blue.
    blue_mean = image[..., 2].mean().item()
    assert blue_mean > 0.4, f"Blue channel mean ({blue_mean:.3f}) should be significant for an ocean scene"
|
||||
Loading…
Reference in New Issue
Block a user