mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-31 00:30:21 +08:00
feat(api-nodes): add Grok Imagine nodes (#12136)
This commit is contained in:
parent
4f5bd39b1c
commit
6c5f906bf2
67
comfy_api_nodes/apis/grok.py
Normal file
67
comfy_api_nodes/apis/grok.py
Normal file
@ -0,0 +1,67 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ImageGenerationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
aspect_ratio: str = Field(...)
|
||||
n: int = Field(...)
|
||||
seed: int = Field(...)
|
||||
response_for: str = Field("url")
|
||||
|
||||
|
||||
class InputUrlObject(BaseModel):
|
||||
url: str = Field(...)
|
||||
|
||||
|
||||
class ImageEditRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
image: InputUrlObject = Field(...)
|
||||
prompt: str = Field(...)
|
||||
resolution: str = Field(...)
|
||||
n: int = Field(...)
|
||||
seed: int = Field(...)
|
||||
response_for: str = Field("url")
|
||||
|
||||
|
||||
class VideoGenerationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
image: InputUrlObject | None = Field(...)
|
||||
duration: int = Field(...)
|
||||
aspect_ratio: str | None = Field(...)
|
||||
resolution: str = Field(...)
|
||||
seed: int = Field(...)
|
||||
|
||||
|
||||
class VideoEditRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
video: InputUrlObject = Field(...)
|
||||
seed: int = Field(...)
|
||||
|
||||
|
||||
class ImageResponseObject(BaseModel):
|
||||
url: str | None = Field(None)
|
||||
b64_json: str | None = Field(None)
|
||||
revised_prompt: str | None = Field(None)
|
||||
|
||||
|
||||
class ImageGenerationResponse(BaseModel):
|
||||
data: list[ImageResponseObject] = Field(...)
|
||||
|
||||
|
||||
class VideoGenerationResponse(BaseModel):
|
||||
request_id: str = Field(...)
|
||||
|
||||
|
||||
class VideoResponseObject(BaseModel):
|
||||
url: str = Field(...)
|
||||
upsampled_prompt: str | None = Field(None)
|
||||
duration: int = Field(...)
|
||||
|
||||
|
||||
class VideoStatusResponse(BaseModel):
|
||||
status: str | None = Field(None)
|
||||
video: VideoResponseObject | None = Field(None)
|
||||
model: str | None = Field(None)
|
||||
417
comfy_api_nodes/nodes_grok.py
Normal file
417
comfy_api_nodes/nodes_grok.py
Normal file
@ -0,0 +1,417 @@
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.grok import (
|
||||
ImageEditRequest,
|
||||
ImageGenerationRequest,
|
||||
ImageGenerationResponse,
|
||||
InputUrlObject,
|
||||
VideoEditRequest,
|
||||
VideoGenerationRequest,
|
||||
VideoGenerationResponse,
|
||||
VideoStatusResponse,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_image_tensor,
|
||||
download_url_to_video_output,
|
||||
get_fs_object_size,
|
||||
get_number_of_images,
|
||||
poll_op,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
upload_video_to_comfyapi,
|
||||
validate_string,
|
||||
validate_video_duration,
|
||||
)
|
||||
|
||||
|
||||
class GrokImageNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GrokImageNode",
|
||||
display_name="Grok Image",
|
||||
category="api node/image/Grok",
|
||||
description="Generate images using Grok based on a text prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["grok-imagine-image-beta"]),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="The text prompt used to generate the image",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=[
|
||||
"1:1",
|
||||
"2:3",
|
||||
"3:2",
|
||||
"3:4",
|
||||
"4:3",
|
||||
"9:16",
|
||||
"16:9",
|
||||
"9:19.5",
|
||||
"19.5:9",
|
||||
"9:20",
|
||||
"20:9",
|
||||
"1:2",
|
||||
"2:1",
|
||||
],
|
||||
),
|
||||
IO.Int.Input(
|
||||
"number_of_images",
|
||||
default=1,
|
||||
min=1,
|
||||
max=10,
|
||||
step=1,
|
||||
tooltip="Number of images to generate",
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; "
|
||||
"actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]),
|
||||
expr="""{"type":"usd","usd":0.033 * widgets.number_of_images}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
aspect_ratio: str,
|
||||
number_of_images: int,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/xai/v1/images/generations", method="POST"),
|
||||
data=ImageGenerationRequest(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
n=number_of_images,
|
||||
seed=seed,
|
||||
),
|
||||
response_model=ImageGenerationResponse,
|
||||
)
|
||||
if len(response.data) == 1:
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url))
|
||||
return IO.NodeOutput(
|
||||
torch.cat(
|
||||
[await download_url_to_image_tensor(i) for i in [str(d.url) for d in response.data if d.url]],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class GrokImageEditNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GrokImageEditNode",
|
||||
display_name="Grok Image Edit",
|
||||
category="api node/image/Grok",
|
||||
description="Modify an existing image based on a text prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["grok-imagine-image-beta"]),
|
||||
IO.Image.Input("image"),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="The text prompt used to generate the image",
|
||||
),
|
||||
IO.Combo.Input("resolution", options=["1K"]),
|
||||
IO.Int.Input(
|
||||
"number_of_images",
|
||||
default=1,
|
||||
min=1,
|
||||
max=10,
|
||||
step=1,
|
||||
tooltip="Number of edited images to generate",
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; "
|
||||
"actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]),
|
||||
expr="""{"type":"usd","usd":0.002 + 0.033 * widgets.number_of_images}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
image: Input.Image,
|
||||
prompt: str,
|
||||
resolution: str,
|
||||
number_of_images: int,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
if get_number_of_images(image) != 1:
|
||||
raise ValueError("Only one input image is supported.")
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/xai/v1/images/edits", method="POST"),
|
||||
data=ImageEditRequest(
|
||||
model=model,
|
||||
image=InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(image)}"),
|
||||
prompt=prompt,
|
||||
resolution=resolution.lower(),
|
||||
n=number_of_images,
|
||||
seed=seed,
|
||||
),
|
||||
response_model=ImageGenerationResponse,
|
||||
)
|
||||
if len(response.data) == 1:
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url))
|
||||
return IO.NodeOutput(
|
||||
torch.cat(
|
||||
[await download_url_to_image_tensor(i) for i in [str(d.url) for d in response.data if d.url]],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class GrokVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GrokVideoNode",
|
||||
display_name="Grok Video",
|
||||
category="api node/video/Grok",
|
||||
description="Generate video from a prompt or an image",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["grok-imagine-video-beta"]),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="Text description of the desired video.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["480p", "720p"],
|
||||
tooltip="The resolution of the output video.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=["auto", "16:9", "4:3", "3:2", "1:1", "2:3", "3:4", "9:16"],
|
||||
tooltip="The aspect ratio of the output video.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=6,
|
||||
min=1,
|
||||
max=15,
|
||||
step=1,
|
||||
tooltip="The duration of the output video in seconds.",
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; "
|
||||
"actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
IO.Image.Input("image", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["duration"], inputs=["image"]),
|
||||
expr="""
|
||||
(
|
||||
$base := 0.181 * widgets.duration;
|
||||
{"type":"usd","usd": inputs.image.connected ? $base + 0.002 : $base}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
resolution: str,
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
seed: int,
|
||||
image: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
image_url = None
|
||||
if image is not None:
|
||||
if get_number_of_images(image) != 1:
|
||||
raise ValueError("Only one input image is supported.")
|
||||
image_url = InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(image)}")
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"),
|
||||
data=VideoGenerationRequest(
|
||||
model=model,
|
||||
image=image_url,
|
||||
prompt=prompt,
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
aspect_ratio=None if aspect_ratio == "auto" else aspect_ratio,
|
||||
seed=seed,
|
||||
),
|
||||
response_model=VideoGenerationResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
|
||||
status_extractor=lambda r: r.status if r.status is not None else "complete",
|
||||
response_model=VideoStatusResponse,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||
|
||||
|
||||
class GrokVideoEditNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GrokVideoEditNode",
|
||||
display_name="Grok Video Edit",
|
||||
category="api node/video/Grok",
|
||||
description="Edit an existing video based on a text prompt.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["grok-imagine-video-beta"]),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="Text description of the desired video.",
|
||||
),
|
||||
IO.Video.Input("video", tooltip="Maximum supported duration is 8.7 seconds and 50MB file size."),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; "
|
||||
"actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd": 0.191, "format": {"suffix": "/sec", "approximate": true}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
video: Input.Video,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
validate_video_duration(video, min_duration=1, max_duration=8.7)
|
||||
video_stream = video.get_stream_source()
|
||||
video_size = get_fs_object_size(video_stream)
|
||||
if video_size > 50 * 1024 * 1024:
|
||||
raise ValueError(f"Video size ({video_size / 1024 / 1024:.1f}MB) exceeds 50MB limit.")
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/xai/v1/videos/edits", method="POST"),
|
||||
data=VideoEditRequest(
|
||||
model=model,
|
||||
video=InputUrlObject(url=await upload_video_to_comfyapi(cls, video)),
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
),
|
||||
response_model=VideoGenerationResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
|
||||
status_extractor=lambda r: r.status if r.status is not None else "complete",
|
||||
response_model=VideoStatusResponse,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||
|
||||
|
||||
class GrokExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
GrokImageNode,
|
||||
GrokImageEditNode,
|
||||
GrokVideoNode,
|
||||
GrokVideoEditNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> GrokExtension:
|
||||
return GrokExtension()
|
||||
Loading…
Reference in New Issue
Block a user