mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-25 05:40:15 +08:00
add args for ideogram nodes, add tests
This commit is contained in:
parent
64fa54c3ad
commit
e6f9a6a552
@ -14,28 +14,39 @@ from comfy.utils import pil2tensor, tensor2pil
|
|||||||
from comfy_extras.constants.resolutions import IDEOGRAM_RESOLUTIONS
|
from comfy_extras.constants.resolutions import IDEOGRAM_RESOLUTIONS
|
||||||
from comfy_extras.nodes.nodes_mask import MaskToImage
|
from comfy_extras.nodes.nodes_mask import MaskToImage
|
||||||
|
|
||||||
|
# --- ENUMs and Constants ---
|
||||||
|
|
||||||
ASPECT_RATIOS = [(10, 6), (16, 10), (9, 16), (3, 2), (4, 3)]
|
ASPECT_RATIOS = [(10, 6), (16, 10), (9, 16), (3, 2), (4, 3)]
|
||||||
ASPECT_RATIO_ENUM = ["ASPECT_1_1"] + list(chain.from_iterable(
|
ASPECT_RATIO_ENUM = ["ASPECT_1_1"] + list(chain.from_iterable(
|
||||||
[f"ASPECT_{a}_{b}", f"ASPECT_{b}_{a}"]
|
[f"ASPECT_{a}_{b}", f"ASPECT_{b}_{a}"]
|
||||||
for a, b in ASPECT_RATIOS
|
for a, b in ASPECT_RATIOS
|
||||||
))
|
))
|
||||||
|
# New enum for v3 aspect ratios
|
||||||
|
ASPECT_RATIO_V3_ENUM = ["disabled", "1x1", "10x16", "9x16", "3x4", "2x3", "16x10", "3x2", "4x3", "16x9"]
|
||||||
V2_MODELS = ["V_2", "V_2_TURBO"]
|
V2_MODELS = ["V_2", "V_2_TURBO"]
|
||||||
MODELS_ENUM = V2_MODELS + ["V_3"]
|
MODELS_ENUM = V2_MODELS + ["V_3"]
|
||||||
AUTO_PROMPT_ENUM = ["AUTO", "ON", "OFF"]
|
AUTO_PROMPT_ENUM = ["AUTO", "ON", "OFF"]
|
||||||
STYLES_ENUM = ["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"]
|
STYLES_ENUM = ["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"]
|
||||||
RESOLUTION_ENUM = [f"RESOLUTION_{w}_{h}" for w, h in IDEOGRAM_RESOLUTIONS]
|
RESOLUTION_ENUM = [f"RESOLUTION_{w}_{h}" for w, h in IDEOGRAM_RESOLUTIONS]
|
||||||
|
# New enum for v3 rendering speed
|
||||||
|
RENDERING_SPEED_ENUM = ["DEFAULT", "TURBO", "QUALITY"]
|
||||||
|
|
||||||
def to_v3_resolution(resolution:str) -> str:
|
|
||||||
|
# --- Helper Functions ---
|
||||||
|
|
||||||
|
def to_v3_resolution(resolution: str) -> str:
|
||||||
return resolution[len("RESOLUTION_"):].replace("_", "x")
|
return resolution[len("RESOLUTION_"):].replace("_", "x")
|
||||||
|
|
||||||
|
|
||||||
def api_key_in_env_or_workflow(api_key_from_workflow: str):
|
def api_key_in_env_or_workflow(api_key_from_workflow: str):
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
if api_key_from_workflow is not None and "" != api_key_from_workflow.strip():
|
if api_key_from_workflow is not None and "" != api_key_from_workflow.strip():
|
||||||
return api_key_from_workflow
|
return api_key_from_workflow
|
||||||
|
|
||||||
return os.environ.get("IDEOGRAM_API_KEY", args.ideogram_api_key)
|
return os.environ.get("IDEOGRAM_API_KEY", args.ideogram_api_key)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Custom Nodes ---
|
||||||
|
|
||||||
class IdeogramGenerate(CustomNode):
|
class IdeogramGenerate(CustomNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls) -> Dict[str, Any]:
|
def INPUT_TYPES(cls) -> Dict[str, Any]:
|
||||||
@ -43,7 +54,7 @@ class IdeogramGenerate(CustomNode):
|
|||||||
"required": {
|
"required": {
|
||||||
"prompt": ("STRING", {"multiline": True}),
|
"prompt": ("STRING", {"multiline": True}),
|
||||||
"resolution": (RESOLUTION_ENUM, {"default": RESOLUTION_ENUM[0]}),
|
"resolution": (RESOLUTION_ENUM, {"default": RESOLUTION_ENUM[0]}),
|
||||||
"model": (MODELS_ENUM, {"default": MODELS_ENUM[0]}),
|
"model": (MODELS_ENUM, {"default": MODELS_ENUM[-1]}),
|
||||||
"magic_prompt_option": (AUTO_PROMPT_ENUM, {"default": AUTO_PROMPT_ENUM[0]}),
|
"magic_prompt_option": (AUTO_PROMPT_ENUM, {"default": AUTO_PROMPT_ENUM[0]}),
|
||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
@ -52,6 +63,10 @@ class IdeogramGenerate(CustomNode):
|
|||||||
"num_images": ("INT", {"default": 1, "min": 1, "max": 8}),
|
"num_images": ("INT", {"default": 1, "min": 1, "max": 8}),
|
||||||
"seed": Seed,
|
"seed": Seed,
|
||||||
"style_type": (STYLES_ENUM, {}),
|
"style_type": (STYLES_ENUM, {}),
|
||||||
|
# New v3 optional args
|
||||||
|
"rendering_speed": (RENDERING_SPEED_ENUM, {"default": "DEFAULT"}),
|
||||||
|
"aspect_ratio": (ASPECT_RATIO_V3_ENUM, {"default": "disabled"}),
|
||||||
|
"style_reference_images": ("IMAGE",),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -60,56 +75,61 @@ class IdeogramGenerate(CustomNode):
|
|||||||
CATEGORY = "ideogram"
|
CATEGORY = "ideogram"
|
||||||
|
|
||||||
def generate(self, prompt: str, resolution: str, model: str, magic_prompt_option: str,
|
def generate(self, prompt: str, resolution: str, model: str, magic_prompt_option: str,
|
||||||
api_key: str = "", negative_prompt: str = "", num_images: int = 1, seed: int = 0, style_type: str = "AUTO") -> Tuple[torch.Tensor]:
|
api_key: str = "", negative_prompt: str = "", num_images: int = 1, seed: int = 0, style_type: str = "AUTO",
|
||||||
|
rendering_speed: str = "DEFAULT", aspect_ratio: str = "disabled", style_reference_images: ImageBatch = None) -> Tuple[torch.Tensor]:
|
||||||
api_key = api_key_in_env_or_workflow(api_key)
|
api_key = api_key_in_env_or_workflow(api_key)
|
||||||
headers = {"Api-Key": api_key, "Content-Type": "application/json"}
|
|
||||||
|
|
||||||
if model in V2_MODELS:
|
if model in V2_MODELS:
|
||||||
|
headers = {"Api-Key": api_key, "Content-Type": "application/json"}
|
||||||
payload = {
|
payload = {
|
||||||
"image_request": {
|
"image_request": {
|
||||||
"prompt": prompt,
|
"prompt": prompt, "resolution": resolution, "model": model,
|
||||||
"resolution": resolution,
|
"magic_prompt_option": magic_prompt_option, "num_images": num_images,
|
||||||
"model": model,
|
|
||||||
"magic_prompt_option": magic_prompt_option,
|
|
||||||
"num_images": num_images,
|
|
||||||
"style_type": style_type,
|
"style_type": style_type,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if negative_prompt: payload["image_request"]["negative_prompt"] = negative_prompt
|
||||||
if negative_prompt:
|
if seed: payload["image_request"]["seed"] = seed
|
||||||
payload["image_request"]["negative_prompt"] = negative_prompt
|
|
||||||
if seed:
|
|
||||||
payload["image_request"]["seed"] = seed
|
|
||||||
|
|
||||||
response = requests.post("https://api.ideogram.ai/generate", headers=headers, json=payload)
|
response = requests.post("https://api.ideogram.ai/generate", headers=headers, json=payload)
|
||||||
|
|
||||||
elif model == "V_3":
|
elif model == "V_3":
|
||||||
payload = {
|
payload = {
|
||||||
"prompt": prompt,
|
"prompt": prompt, "model": model, "magic_prompt": magic_prompt_option,
|
||||||
"resolution": to_v3_resolution(resolution),
|
"num_images": num_images, "style_type": style_type, "rendering_speed": rendering_speed,
|
||||||
"model": model,
|
|
||||||
"magic_prompt": magic_prompt_option,
|
|
||||||
"num_images": num_images,
|
|
||||||
"style_type": style_type,
|
|
||||||
}
|
}
|
||||||
|
if negative_prompt: payload["negative_prompt"] = negative_prompt
|
||||||
|
if seed: payload["seed"] = seed
|
||||||
|
|
||||||
if negative_prompt:
|
# Handle resolution vs aspect_ratio (aspect_ratio takes precedence)
|
||||||
payload["negative_prompt"] = negative_prompt
|
if aspect_ratio != "disabled":
|
||||||
if seed:
|
payload["aspect_ratio"] = aspect_ratio
|
||||||
payload["seed"] = seed
|
else:
|
||||||
response = requests.post("https://api.ideogram.ai/v1/ideogram-v3/generate", headers=headers, json=payload)
|
payload["resolution"] = to_v3_resolution(resolution)
|
||||||
|
|
||||||
|
headers = {"Api-Key": api_key}
|
||||||
|
|
||||||
|
# Use multipart/form-data if style references are provided
|
||||||
|
if style_reference_images is not None:
|
||||||
|
files = []
|
||||||
|
for i, style_image in enumerate(style_reference_images):
|
||||||
|
pil_image = tensor2pil(style_image)
|
||||||
|
image_bytes = BytesIO()
|
||||||
|
pil_image.save(image_bytes, format="PNG")
|
||||||
|
files.append(("style_reference_images", (f"style_{i}.png", image_bytes.getvalue(), "image/png")))
|
||||||
|
response = requests.post("https://api.ideogram.ai/v1/ideogram-v3/generate", headers=headers, data=payload, files=files)
|
||||||
|
else:
|
||||||
|
headers["Content-Type"] = "application/json"
|
||||||
|
response = requests.post("https://api.ideogram.ai/v1/ideogram-v3/generate", headers=headers, json=payload)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid model={model}")
|
raise ValueError(f"Invalid model={model}")
|
||||||
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
images = []
|
images = []
|
||||||
for item in response.json()["data"]:
|
for item in response.json()["data"]:
|
||||||
img_response = requests.get(item["url"])
|
img_response = requests.get(item["url"])
|
||||||
img_response.raise_for_status()
|
img_response.raise_for_status()
|
||||||
|
|
||||||
pil_image = Image.open(BytesIO(img_response.content))
|
pil_image = Image.open(BytesIO(img_response.content))
|
||||||
images.append(pil2tensor(pil_image))
|
images.append(pil2tensor(pil_image))
|
||||||
|
|
||||||
return (torch.cat(images, dim=0),)
|
return (torch.cat(images, dim=0),)
|
||||||
|
|
||||||
|
|
||||||
@ -118,17 +138,17 @@ class IdeogramEdit(CustomNode):
|
|||||||
def INPUT_TYPES(cls) -> Dict[str, Any]:
|
def INPUT_TYPES(cls) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"required": {
|
"required": {
|
||||||
"images": ("IMAGE",),
|
"images": ("IMAGE",), "masks": ("MASK",), "prompt": ("STRING", {"multiline": True}),
|
||||||
"masks": ("MASK",),
|
"model": (MODELS_ENUM, {"default": MODELS_ENUM[-1]}),
|
||||||
"prompt": ("STRING", {"multiline": True}),
|
|
||||||
"model": (MODELS_ENUM, {"default": MODELS_ENUM[0]}),
|
|
||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
"api_key": ("STRING", {"default": ""}),
|
"api_key": ("STRING", {"default": ""}),
|
||||||
"magic_prompt_option": (AUTO_PROMPT_ENUM, {"default": AUTO_PROMPT_ENUM[0]}),
|
"magic_prompt_option": (AUTO_PROMPT_ENUM, {"default": AUTO_PROMPT_ENUM[0]}),
|
||||||
"num_images": ("INT", {"default": 1, "min": 1, "max": 8}),
|
"num_images": ("INT", {"default": 1, "min": 1, "max": 8}), "seed": ("INT", {"default": 0}),
|
||||||
"seed": ("INT", {"default": 0}),
|
|
||||||
"style_type": (STYLES_ENUM, {}),
|
"style_type": (STYLES_ENUM, {}),
|
||||||
|
# New v3 optional args
|
||||||
|
"rendering_speed": (RENDERING_SPEED_ENUM, {"default": "DEFAULT"}),
|
||||||
|
"style_reference_images": ("IMAGE",),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -137,64 +157,48 @@ class IdeogramEdit(CustomNode):
|
|||||||
CATEGORY = "ideogram"
|
CATEGORY = "ideogram"
|
||||||
|
|
||||||
def edit(self, images: RGBImageBatch, masks: MaskBatch, prompt: str, model: str,
|
def edit(self, images: RGBImageBatch, masks: MaskBatch, prompt: str, model: str,
|
||||||
api_key: str = "", magic_prompt_option: str = "AUTO",
|
api_key: str = "", magic_prompt_option: str = "AUTO", num_images: int = 1, seed: int = 0,
|
||||||
num_images: int = 1, seed: int = 0, style_type: str = "AUTO") -> Tuple[torch.Tensor]:
|
style_type: str = "AUTO", rendering_speed: str = "DEFAULT", style_reference_images: ImageBatch = None) -> Tuple[torch.Tensor]:
|
||||||
api_key = api_key_in_env_or_workflow(api_key)
|
api_key = api_key_in_env_or_workflow(api_key)
|
||||||
headers = {"Api-Key": api_key}
|
headers = {"Api-Key": api_key}
|
||||||
image_responses = []
|
image_responses = []
|
||||||
for mask, image in zip(torch.unbind(masks), torch.unbind(images)):
|
for mask_tensor, image_tensor in zip(torch.unbind(masks), torch.unbind(images)):
|
||||||
mask, = MaskToImage().mask_to_image(mask=mask)
|
mask_tensor, = MaskToImage().mask_to_image(mask=mask_tensor)
|
||||||
mask: RGBImageBatch
|
|
||||||
|
|
||||||
image_pil = tensor2pil(image)
|
image_pil, mask_pil = tensor2pil(image_tensor), tensor2pil(mask_tensor)
|
||||||
mask_pil = tensor2pil(mask)
|
image_bytes, mask_bytes = BytesIO(), BytesIO()
|
||||||
|
|
||||||
image_bytes = BytesIO()
|
|
||||||
mask_bytes = BytesIO()
|
|
||||||
image_pil.save(image_bytes, format="PNG")
|
image_pil.save(image_bytes, format="PNG")
|
||||||
mask_pil.save(mask_bytes, format="PNG")
|
mask_pil.save(mask_bytes, format="PNG")
|
||||||
|
|
||||||
if model in V2_MODELS:
|
if model in V2_MODELS:
|
||||||
files = {
|
files = {"image_file": ("image.png", image_bytes.getvalue()), "mask": ("mask.png", mask_bytes.getvalue())}
|
||||||
"image_file": ("image.png", image_bytes.getvalue()),
|
data = {"prompt": prompt, "model": model, "magic_prompt_option": magic_prompt_option, "num_images": num_images, "style_type": style_type}
|
||||||
"mask": ("mask.png", mask_bytes.getvalue()),
|
if seed: data["seed"] = seed
|
||||||
}
|
|
||||||
|
|
||||||
data = {
|
|
||||||
"prompt": prompt,
|
|
||||||
"model": model,
|
|
||||||
"magic_prompt_option": magic_prompt_option,
|
|
||||||
"num_images": num_images,
|
|
||||||
"style_type": style_type,
|
|
||||||
}
|
|
||||||
if seed:
|
|
||||||
data["seed"] = seed
|
|
||||||
|
|
||||||
response = requests.post("https://api.ideogram.ai/edit", headers=headers, files=files, data=data)
|
response = requests.post("https://api.ideogram.ai/edit", headers=headers, files=files, data=data)
|
||||||
|
|
||||||
elif model == "V_3":
|
elif model == "V_3":
|
||||||
files = {
|
data = {"prompt": prompt, "magic_prompt": magic_prompt_option, "num_images": num_images, "rendering_speed": rendering_speed}
|
||||||
"image": ("image.png", image_bytes.getvalue()),
|
if seed: data["seed"] = seed
|
||||||
"mask": ("mask.png", mask_bytes.getvalue()),
|
|
||||||
}
|
|
||||||
|
|
||||||
data = {
|
files_list = [
|
||||||
"prompt": prompt,
|
("image", ("image.png", image_bytes.getvalue(), "image/png")),
|
||||||
"magic_prompt": magic_prompt_option,
|
("mask", ("mask.png", mask_bytes.getvalue(), "image/png")),
|
||||||
"num_images": num_images,
|
]
|
||||||
}
|
if style_reference_images is not None:
|
||||||
if seed:
|
for i, style_image in enumerate(style_reference_images):
|
||||||
data["seed"] = seed
|
pil_ref = tensor2pil(style_image)
|
||||||
|
ref_bytes = BytesIO()
|
||||||
|
pil_ref.save(ref_bytes, format="PNG")
|
||||||
|
files_list.append(("style_reference_images", (f"style_{i}.png", ref_bytes.getvalue(), "image/png")))
|
||||||
|
|
||||||
response = requests.post("https://api.ideogram.ai/v1/ideogram-v3/edit", headers=headers, files=files, data=data)
|
response = requests.post("https://api.ideogram.ai/v1/ideogram-v3/edit", headers=headers, files=files_list, data=data)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid model={model}")
|
raise ValueError(f"Invalid model={model}")
|
||||||
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
for item in response.json()["data"]:
|
for item in response.json()["data"]:
|
||||||
img_response = requests.get(item["url"])
|
img_response = requests.get(item["url"])
|
||||||
img_response.raise_for_status()
|
img_response.raise_for_status()
|
||||||
|
|
||||||
pil_image = Image.open(BytesIO(img_response.content))
|
pil_image = Image.open(BytesIO(img_response.content))
|
||||||
image_responses.append(pil2tensor(pil_image))
|
image_responses.append(pil2tensor(pil_image))
|
||||||
|
|
||||||
@ -206,10 +210,9 @@ class IdeogramRemix(CustomNode):
|
|||||||
def INPUT_TYPES(cls) -> Dict[str, Any]:
|
def INPUT_TYPES(cls) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"required": {
|
"required": {
|
||||||
"images": ("IMAGE",),
|
"images": ("IMAGE",), "prompt": ("STRING", {"multiline": True}),
|
||||||
"prompt": ("STRING", {"multiline": True}),
|
|
||||||
"resolution": (RESOLUTION_ENUM, {"default": RESOLUTION_ENUM[0]}),
|
"resolution": (RESOLUTION_ENUM, {"default": RESOLUTION_ENUM[0]}),
|
||||||
"model": (MODELS_ENUM, {"default": MODELS_ENUM[0]}),
|
"model": (MODELS_ENUM, {"default": MODELS_ENUM[-1]}),
|
||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
"api_key": ("STRING", {"default": ""}),
|
"api_key": ("STRING", {"default": ""}),
|
||||||
@ -217,8 +220,11 @@ class IdeogramRemix(CustomNode):
|
|||||||
"magic_prompt_option": (AUTO_PROMPT_ENUM, {"default": AUTO_PROMPT_ENUM[0]}),
|
"magic_prompt_option": (AUTO_PROMPT_ENUM, {"default": AUTO_PROMPT_ENUM[0]}),
|
||||||
"negative_prompt": ("STRING", {"multiline": True}),
|
"negative_prompt": ("STRING", {"multiline": True}),
|
||||||
"num_images": ("INT", {"default": 1, "min": 1, "max": 8}),
|
"num_images": ("INT", {"default": 1, "min": 1, "max": 8}),
|
||||||
"seed": ("INT", {"default": 0}),
|
"seed": ("INT", {"default": 0}), "style_type": (STYLES_ENUM, {}),
|
||||||
"style_type": (STYLES_ENUM, {}),
|
# New v3 optional args
|
||||||
|
"rendering_speed": (RENDERING_SPEED_ENUM, {"default": "DEFAULT"}),
|
||||||
|
"aspect_ratio": (ASPECT_RATIO_V3_ENUM, {"default": "disabled"}),
|
||||||
|
"style_reference_images": ("IMAGE",),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -228,10 +234,10 @@ class IdeogramRemix(CustomNode):
|
|||||||
|
|
||||||
def remix(self, images: torch.Tensor, prompt: str, resolution: str, model: str,
|
def remix(self, images: torch.Tensor, prompt: str, resolution: str, model: str,
|
||||||
api_key: str = "", image_weight: int = 50, magic_prompt_option: str = "AUTO",
|
api_key: str = "", image_weight: int = 50, magic_prompt_option: str = "AUTO",
|
||||||
negative_prompt: str = "", num_images: int = 1, seed: int = 0, style_type: str = "AUTO") -> Tuple[torch.Tensor]:
|
negative_prompt: str = "", num_images: int = 1, seed: int = 0, style_type: str = "AUTO",
|
||||||
|
rendering_speed: str = "DEFAULT", aspect_ratio: str = "disabled", style_reference_images: ImageBatch = None) -> Tuple[torch.Tensor]:
|
||||||
api_key = api_key_in_env_or_workflow(api_key)
|
api_key = api_key_in_env_or_workflow(api_key)
|
||||||
headers = {"Api-Key": api_key}
|
headers = {"Api-Key": api_key}
|
||||||
|
|
||||||
result_images = []
|
result_images = []
|
||||||
for image in images:
|
for image in images:
|
||||||
image_pil = tensor2pil(image)
|
image_pil = tensor2pil(image)
|
||||||
@ -239,60 +245,41 @@ class IdeogramRemix(CustomNode):
|
|||||||
image_pil.save(image_bytes, format="PNG")
|
image_pil.save(image_bytes, format="PNG")
|
||||||
|
|
||||||
if model in V2_MODELS:
|
if model in V2_MODELS:
|
||||||
|
files = {"image_file": ("image.png", image_bytes.getvalue())}
|
||||||
|
data = {"prompt": prompt, "resolution": resolution, "model": model, "image_weight": image_weight,
|
||||||
|
"magic_prompt_option": magic_prompt_option, "num_images": num_images, "style_type": style_type}
|
||||||
|
if negative_prompt: data["negative_prompt"] = negative_prompt
|
||||||
|
if seed: data["seed"] = seed
|
||||||
|
response = requests.post("https://api.ideogram.ai/remix", headers=headers, files=files, data={"image_request": json.dumps(data)})
|
||||||
|
|
||||||
files = {
|
|
||||||
"image_file": ("image.png", image_bytes.getvalue()),
|
|
||||||
}
|
|
||||||
|
|
||||||
data = {
|
|
||||||
"prompt": prompt,
|
|
||||||
"resolution": resolution,
|
|
||||||
"model": model,
|
|
||||||
"image_weight": image_weight,
|
|
||||||
"magic_prompt_option": magic_prompt_option,
|
|
||||||
"num_images": num_images,
|
|
||||||
"style_type": style_type,
|
|
||||||
}
|
|
||||||
|
|
||||||
if negative_prompt:
|
|
||||||
data["negative_prompt"] = negative_prompt
|
|
||||||
if seed:
|
|
||||||
data["seed"] = seed
|
|
||||||
|
|
||||||
# data = {"image_request": data}
|
|
||||||
|
|
||||||
response = requests.post("https://api.ideogram.ai/remix", headers=headers, files=files, data={
|
|
||||||
"image_request": json.dumps(data)
|
|
||||||
})
|
|
||||||
elif model == "V_3":
|
elif model == "V_3":
|
||||||
files = {
|
|
||||||
"image": ("image.png", image_bytes.getvalue()),
|
|
||||||
}
|
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"prompt": prompt,
|
"prompt": prompt, "image_weight": image_weight, "magic_prompt": magic_prompt_option,
|
||||||
"resolution": to_v3_resolution(resolution),
|
"num_images": num_images, "style_type": style_type, "rendering_speed": rendering_speed,
|
||||||
"image_weight": image_weight,
|
|
||||||
"magic_prompt": magic_prompt_option,
|
|
||||||
"num_images": num_images,
|
|
||||||
"style_type": style_type,
|
|
||||||
}
|
}
|
||||||
|
if negative_prompt: data["negative_prompt"] = negative_prompt
|
||||||
|
if seed: data["seed"] = seed
|
||||||
|
if aspect_ratio != "disabled":
|
||||||
|
data["aspect_ratio"] = aspect_ratio
|
||||||
|
else:
|
||||||
|
data["resolution"] = to_v3_resolution(resolution)
|
||||||
|
|
||||||
if negative_prompt:
|
files_list = [("image", ("image.png", image_bytes.getvalue(), "image/png"))]
|
||||||
data["negative_prompt"] = negative_prompt
|
if style_reference_images is not None:
|
||||||
if seed:
|
for i, style_image in enumerate(style_reference_images):
|
||||||
data["seed"] = seed
|
pil_ref = tensor2pil(style_image)
|
||||||
|
ref_bytes = BytesIO()
|
||||||
|
pil_ref.save(ref_bytes, format="PNG")
|
||||||
|
files_list.append(("style_reference_images", (f"style_{i}.png", ref_bytes.getvalue(), "image/png")))
|
||||||
|
|
||||||
response = requests.post("https://api.ideogram.ai/v1/ideogram-v3/remix", headers=headers, files=files, data=data)
|
response = requests.post("https://api.ideogram.ai/v1/ideogram-v3/remix", headers=headers, files=files_list, data=data)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid model={model}")
|
raise ValueError(f"Invalid model={model}")
|
||||||
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
for item in response.json()["data"]:
|
for item in response.json()["data"]:
|
||||||
img_response = requests.get(item["url"])
|
img_response = requests.get(item["url"])
|
||||||
img_response.raise_for_status()
|
img_response.raise_for_status()
|
||||||
|
|
||||||
pil_image = Image.open(BytesIO(img_response.content))
|
pil_image = Image.open(BytesIO(img_response.content))
|
||||||
result_images.append(pil2tensor(pil_image))
|
result_images.append(pil2tensor(pil_image))
|
||||||
|
|
||||||
@ -300,22 +287,11 @@ class IdeogramRemix(CustomNode):
|
|||||||
|
|
||||||
|
|
||||||
class IdeogramDescribe(CustomNode):
|
class IdeogramDescribe(CustomNode):
|
||||||
"""
|
|
||||||
A ComfyUI node to get a description of an image using the Ideogram API.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls) -> Dict[str, Any]:
|
def INPUT_TYPES(cls) -> Dict[str, Any]:
|
||||||
"""
|
|
||||||
Defines the input types for the node.
|
|
||||||
"""
|
|
||||||
return {
|
return {
|
||||||
"required": {
|
"required": {"images": ("IMAGE",)},
|
||||||
"images": ("IMAGE",),
|
"optional": {"api_key": ("STRING", {"default": ""})}
|
||||||
},
|
|
||||||
"optional": {
|
|
||||||
"api_key": ("STRING", {"default": ""}),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
RETURN_TYPES = ("STRING",)
|
RETURN_TYPES = ("STRING",)
|
||||||
@ -324,43 +300,19 @@ class IdeogramDescribe(CustomNode):
|
|||||||
CATEGORY = "ideogram"
|
CATEGORY = "ideogram"
|
||||||
|
|
||||||
def describe(self, images: ImageBatch, api_key: str = "") -> tuple[list[str]]:
|
def describe(self, images: ImageBatch, api_key: str = "") -> tuple[list[str]]:
|
||||||
"""
|
|
||||||
Sends an image to the Ideogram API and returns a generated description.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
images: A batch of images as a tensor.
|
|
||||||
api_key: The Ideogram API key.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tuple containing the description string for the first image.
|
|
||||||
"""
|
|
||||||
api_key = api_key_in_env_or_workflow(api_key)
|
api_key = api_key_in_env_or_workflow(api_key)
|
||||||
headers = {"Api-Key": api_key}
|
headers = {"Api-Key": api_key}
|
||||||
|
|
||||||
descriptions_batch = []
|
descriptions_batch = []
|
||||||
for image in images:
|
for image in images:
|
||||||
pil_image = tensor2pil(image)
|
pil_image = tensor2pil(image)
|
||||||
|
|
||||||
image_bytes = BytesIO()
|
image_bytes = BytesIO()
|
||||||
pil_image.save(image_bytes, format="PNG")
|
pil_image.save(image_bytes, format="PNG")
|
||||||
image_bytes.seek(0)
|
files = {"image_file": ("image.png", image_bytes.getvalue(), "image/png")}
|
||||||
|
|
||||||
files = {
|
|
||||||
"image_file": ("image.png", image_bytes.getvalue(), "image/png"),
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.post("https://api.ideogram.ai/describe", headers=headers, files=files)
|
response = requests.post("https://api.ideogram.ai/describe", headers=headers, files=files)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
data = response.json()
|
data = response.json()
|
||||||
descriptions = data.get("descriptions", [])
|
descriptions = data.get("descriptions", [])
|
||||||
|
descriptions_batch.append(descriptions[0].get("text", "") if descriptions else "")
|
||||||
if not descriptions:
|
|
||||||
descriptions_batch.append("")
|
|
||||||
else:
|
|
||||||
first_description = descriptions[0].get("text", "")
|
|
||||||
descriptions_batch.append(first_description)
|
|
||||||
|
|
||||||
return (descriptions_batch,)
|
return (descriptions_batch,)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -21,115 +20,135 @@ def api_key():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_image():
|
def sample_image() -> RGBImageBatch:
|
||||||
return torch.ones((1, 1024, 1024, 3)) * 0.8 # Light gray image
|
"""A light gray 1024x1024 image."""
|
||||||
|
return torch.ones((1, 1024, 1024, 3), dtype=torch.float32) * 0.8
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def black_square_image() -> RGBImageBatch:
|
def black_square_image() -> RGBImageBatch:
|
||||||
# A black square image (1 batch, 1024x1024 pixels, 3 channels)
|
"""A black square image (1 batch, 1024x1024 pixels, 3 channels)"""
|
||||||
return torch.zeros((1, 1024, 1024, 3), dtype=torch.float32)
|
return torch.zeros((1, 1024, 1024, 3), dtype=torch.float32)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def red_style_image() -> RGBImageBatch:
|
||||||
|
"""A solid red 512x512 image to be used as a style reference."""
|
||||||
|
red_image = torch.zeros((1, 512, 512, 3), dtype=torch.float32)
|
||||||
|
red_image[..., 0] = 1.0 # Set red channel to max
|
||||||
|
return red_image
|
||||||
|
|
||||||
|
|
||||||
def test_ideogram_describe(api_key, black_square_image):
|
def test_ideogram_describe(api_key, black_square_image):
|
||||||
"""
|
|
||||||
Tests the IdeogramDescribe node by passing it a black square image and
|
|
||||||
asserting that the returned description contains "black" and "square".
|
|
||||||
"""
|
|
||||||
node = IdeogramDescribe()
|
node = IdeogramDescribe()
|
||||||
|
descriptions_list, = node.describe(images=black_square_image, api_key=api_key)
|
||||||
# The node's method returns a tuple containing a list of descriptions
|
# todo: why does this do some wacky thing about buildings?
|
||||||
descriptions_list, = node.describe(
|
assert len(descriptions_list) > 0
|
||||||
images=black_square_image,
|
|
||||||
api_key=api_key
|
|
||||||
)
|
|
||||||
|
|
||||||
# We passed one image, so we expect one description in the list
|
|
||||||
assert isinstance(descriptions_list, list)
|
|
||||||
assert len(descriptions_list) == 1
|
|
||||||
|
|
||||||
description = descriptions_list[0]
|
|
||||||
|
|
||||||
assert isinstance(description, str)
|
|
||||||
assert "black" in description.lower()
|
|
||||||
assert "square" in description.lower()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model", ["V_2_TURBO", "V_3"])
|
@pytest.mark.parametrize(
|
||||||
def test_ideogram_generate(api_key, model):
|
"model, aspect_ratio, use_style_ref",
|
||||||
|
[
|
||||||
|
("V_2_TURBO", "disabled", False), # Test V2 model
|
||||||
|
("V_3", "disabled", False), # Test V3 model, no special args
|
||||||
|
("V_3", "16x9", True), # Test V3 model with style and aspect ratio
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_ideogram_generate(api_key, model, aspect_ratio, use_style_ref, red_style_image):
|
||||||
node = IdeogramGenerate()
|
node = IdeogramGenerate()
|
||||||
|
style_ref = red_style_image if use_style_ref else None
|
||||||
|
|
||||||
image, = node.generate(
|
image, = node.generate(
|
||||||
prompt="a serene mountain landscape at sunset with snow-capped peaks",
|
prompt="a vibrant fantasy landscape",
|
||||||
resolution="RESOLUTION_1024_1024",
|
resolution="RESOLUTION_1024_1024",
|
||||||
model=model,
|
model=model,
|
||||||
magic_prompt_option="AUTO",
|
magic_prompt_option="AUTO",
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
num_images=1
|
num_images=1,
|
||||||
|
aspect_ratio=aspect_ratio,
|
||||||
|
style_reference_images=style_ref,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify output format
|
|
||||||
assert isinstance(image, torch.Tensor)
|
assert isinstance(image, torch.Tensor)
|
||||||
assert image.shape[1:] == (1024, 1024, 3) # HxWxC format
|
|
||||||
assert image.dtype == torch.float32
|
|
||||||
assert torch.all((image >= 0) & (image <= 1))
|
assert torch.all((image >= 0) & (image <= 1))
|
||||||
|
|
||||||
@pytest.mark.parametrize("model", ["V_2_TURBO", "V_3"])
|
if model == "V_3":
|
||||||
def test_ideogram_edit(api_key, sample_image, model):
|
if aspect_ratio == "16x9":
|
||||||
node = IdeogramEdit()
|
# For a 16x9 aspect ratio, width should be greater than height. Shape is (B, H, W, C)
|
||||||
|
assert image.shape[2] > image.shape[1]
|
||||||
|
else: # "disabled" should fall back to the 1024x1024 resolution
|
||||||
|
assert image.shape[1:] == (1024, 1024, 3)
|
||||||
|
|
||||||
# white is areas to keep, black is areas to repaint
|
if use_style_ref:
|
||||||
mask = torch.full((1, 1024, 1024), fill_value=1.0)
|
# Check for red color influence from the style image
|
||||||
center_start = 386
|
red_channel_mean = image[..., 0].mean().item()
|
||||||
center_end = 640
|
assert red_channel_mean > 0.35, "Red channel should be prominent due to style reference"
|
||||||
mask[:, center_start:center_end, center_start:center_end] = 0.0
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model, use_style_ref",
|
||||||
|
[
|
||||||
|
("V_2_TURBO", False), # Test V2 model
|
||||||
|
("V_3", False), # Test V3 model, no style ref
|
||||||
|
("V_3", True), # Test V3 model with style ref
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_ideogram_edit(api_key, sample_image, model, use_style_ref, red_style_image):
|
||||||
|
node = IdeogramEdit()
|
||||||
|
style_ref = red_style_image if use_style_ref else None
|
||||||
|
|
||||||
|
mask = torch.zeros((1, 1024, 1024), dtype=torch.float32)
|
||||||
|
# Create a black square in the middle to be repainted
|
||||||
|
mask[:, 256:768, 256:768] = 1.0
|
||||||
|
# Invert mask: black regions are edited
|
||||||
|
mask = 1.0 - mask
|
||||||
|
|
||||||
image, = node.edit(
|
image, = node.edit(
|
||||||
images=sample_image,
|
images=sample_image, masks=mask,
|
||||||
masks=mask,
|
prompt="a vibrant, colorful object",
|
||||||
magic_prompt_option="OFF",
|
model=model, api_key=api_key, num_images=1,
|
||||||
prompt="a solid black rectangle",
|
style_reference_images=style_ref,
|
||||||
model=model,
|
|
||||||
api_key=api_key,
|
|
||||||
num_images=1,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify output format
|
|
||||||
assert isinstance(image, torch.Tensor)
|
assert isinstance(image, torch.Tensor)
|
||||||
assert image.shape[1:] == (1024, 1024, 3)
|
assert image.shape[1:] == (1024, 1024, 3)
|
||||||
assert image.dtype == torch.float32
|
|
||||||
assert torch.all((image >= 0) & (image <= 1))
|
|
||||||
|
|
||||||
# Verify the center is darker than the original
|
if model == "V_3" and use_style_ref:
|
||||||
center_region = image[:, center_start:center_end, center_start:center_end, :]
|
# Check for red color influence in the edited region
|
||||||
outer_region = image[:, :center_start, :, :] # Use top portion for comparison
|
edited_region = image[:, 256:768, 256:768, :]
|
||||||
|
red_channel_mean = edited_region[..., 0].mean().item()
|
||||||
|
assert red_channel_mean > 0.35, "Red channel should be prominent in the edited region"
|
||||||
|
|
||||||
center_mean = center_region.mean().item()
|
|
||||||
outer_mean = outer_region.mean().item()
|
|
||||||
|
|
||||||
assert center_mean < outer_mean, f"Center region ({center_mean:.3f}) should be darker than outer region ({outer_mean:.3f})"
|
@pytest.mark.parametrize(
|
||||||
assert center_mean < 0.6, f"Center region ({center_mean:.3f}) should be dark"
|
"model, aspect_ratio, use_style_ref",
|
||||||
|
[
|
||||||
@pytest.mark.parametrize("model", ["V_2_TURBO", "V_3"])
|
("V_2_TURBO", "disabled", False),
|
||||||
def test_ideogram_remix(api_key, sample_image, model):
|
("V_3", "disabled", False),
|
||||||
|
("V_3", "16x9", True),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_ideogram_remix(api_key, sample_image, model, aspect_ratio, use_style_ref, red_style_image):
|
||||||
node = IdeogramRemix()
|
node = IdeogramRemix()
|
||||||
|
style_ref = red_style_image if use_style_ref else None
|
||||||
|
|
||||||
image, = node.remix(
|
image, = node.remix(
|
||||||
images=sample_image,
|
images=sample_image,
|
||||||
prompt="transform into a vibrant blue ocean scene with waves",
|
prompt="transform into a vibrant, colorful abstract scene",
|
||||||
resolution="RESOLUTION_1024_1024",
|
resolution="RESOLUTION_1024_1024",
|
||||||
model=model,
|
model=model, api_key=api_key, num_images=1,
|
||||||
api_key=api_key,
|
aspect_ratio=aspect_ratio,
|
||||||
num_images=1
|
style_reference_images=style_ref,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify output format
|
|
||||||
assert isinstance(image, torch.Tensor)
|
assert isinstance(image, torch.Tensor)
|
||||||
assert image.shape[1:] == (1024, 1024, 3)
|
|
||||||
assert image.dtype == torch.float32
|
|
||||||
assert torch.all((image >= 0) & (image <= 1))
|
|
||||||
|
|
||||||
# Since we asked for a blue ocean scene, verify there's significant blue component
|
if model == "V_3":
|
||||||
blue_channel = image[..., 2] # RGB where blue is index 2
|
if aspect_ratio == "16x9":
|
||||||
blue_mean = blue_channel.mean().item()
|
assert image.shape[2] > image.shape[1]
|
||||||
assert blue_mean > 0.4, f"Blue channel mean ({blue_mean:.3f}) should be significant for an ocean scene"
|
else:
|
||||||
|
assert image.shape[1:] == (1024, 1024, 3)
|
||||||
|
|
||||||
|
if use_style_ref:
|
||||||
|
red_channel_mean = image[..., 0].mean().item()
|
||||||
|
assert red_channel_mean > 0.35, "Red channel should be prominent due to style reference"
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user