Merge pull request #1 from austin1997/add-grok-api-key-support-9519594855872412057

feat: add official xAI API key support to Grok nodes
This commit is contained in:
austin1997 2026-04-05 17:26:37 +08:00 committed by GitHub
commit 7e1e4a182d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -100,6 +100,13 @@ class GrokImageNode(IO.ComfyNode):
"actual results are nondeterministic regardless of seed.", "actual results are nondeterministic regardless of seed.",
), ),
IO.Combo.Input("resolution", options=["1K", "2K"], optional=True), IO.Combo.Input("resolution", options=["1K", "2K"], optional=True),
IO.String.Input(
"xai_api_key",
default="",
tooltip="Your xAI API Key (optional). If provided, it will bypass Comfy org limits.",
optional=True,
advanced=True,
),
], ],
outputs=[ outputs=[
IO.Image.Output(), IO.Image.Output(),
@ -130,11 +137,23 @@ class GrokImageNode(IO.ComfyNode):
number_of_images: int, number_of_images: int,
seed: int, seed: int,
resolution: str = "1K", resolution: str = "1K",
xai_api_key: str = "",
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1) validate_string(prompt, strip_whitespace=True, min_length=1)
xai_api_key = xai_api_key.strip()
if xai_api_key.lower().startswith("bearer "):
xai_api_key = xai_api_key[7:].strip()
path = "/proxy/xai/v1/images/generations"
headers = None
if xai_api_key:
path = "https://api.x.ai/v1/images/generations"
headers = {"Authorization": f"Bearer {xai_api_key}"}
response = await sync_op( response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/xai/v1/images/generations", method="POST"), ApiEndpoint(path=path, method="POST", headers=headers),
data=ImageGenerationRequest( data=ImageGenerationRequest(
model=model, model=model,
prompt=prompt, prompt=prompt,
@ -217,6 +236,13 @@ class GrokImageEditNode(IO.ComfyNode):
optional=True, optional=True,
tooltip="Only allowed when multiple images are connected to the image input.", tooltip="Only allowed when multiple images are connected to the image input.",
), ),
IO.String.Input(
"xai_api_key",
default="",
tooltip="Your xAI API Key (optional). If provided, it will bypass Comfy org limits.",
optional=True,
advanced=True,
),
], ],
outputs=[ outputs=[
IO.Image.Output(), IO.Image.Output(),
@ -248,6 +274,7 @@ class GrokImageEditNode(IO.ComfyNode):
number_of_images: int, number_of_images: int,
seed: int, seed: int,
aspect_ratio: str = "auto", aspect_ratio: str = "auto",
xai_api_key: str = "",
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1) validate_string(prompt, strip_whitespace=True, min_length=1)
if model == "grok-imagine-image-pro": if model == "grok-imagine-image-pro":
@ -259,9 +286,20 @@ class GrokImageEditNode(IO.ComfyNode):
raise ValueError( raise ValueError(
"Custom aspect ratio is only allowed when multiple images are connected to the image input." "Custom aspect ratio is only allowed when multiple images are connected to the image input."
) )
xai_api_key = xai_api_key.strip()
if xai_api_key.lower().startswith("bearer "):
xai_api_key = xai_api_key[7:].strip()
path = "/proxy/xai/v1/images/edits"
headers = None
if xai_api_key:
path = "https://api.x.ai/v1/images/edits"
headers = {"Authorization": f"Bearer {xai_api_key}"}
response = await sync_op( response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/xai/v1/images/edits", method="POST"), ApiEndpoint(path=path, method="POST", headers=headers),
data=ImageEditRequest( data=ImageEditRequest(
model=model, model=model,
images=[InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(i)}") for i in image], images=[InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(i)}") for i in image],
@ -330,6 +368,13 @@ class GrokVideoNode(IO.ComfyNode):
"actual results are nondeterministic regardless of seed.", "actual results are nondeterministic regardless of seed.",
), ),
IO.Image.Input("image", optional=True), IO.Image.Input("image", optional=True),
IO.String.Input(
"xai_api_key",
default="",
tooltip="Your xAI API Key (optional). If provided, it will bypass Comfy org limits.",
optional=True,
advanced=True,
),
], ],
outputs=[ outputs=[
IO.Video.Output(), IO.Video.Output(),
@ -362,6 +407,7 @@ class GrokVideoNode(IO.ComfyNode):
duration: int, duration: int,
seed: int, seed: int,
image: Input.Image | None = None, image: Input.Image | None = None,
xai_api_key: str = "",
) -> IO.NodeOutput: ) -> IO.NodeOutput:
if model == "grok-imagine-video-beta": if model == "grok-imagine-video-beta":
model = "grok-imagine-video" model = "grok-imagine-video"
@ -371,9 +417,20 @@ class GrokVideoNode(IO.ComfyNode):
raise ValueError("Only one input image is supported.") raise ValueError("Only one input image is supported.")
image_url = InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(image)}") image_url = InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(image)}")
validate_string(prompt, strip_whitespace=True, min_length=1) validate_string(prompt, strip_whitespace=True, min_length=1)
xai_api_key = xai_api_key.strip()
if xai_api_key.lower().startswith("bearer "):
xai_api_key = xai_api_key[7:].strip()
path = "/proxy/xai/v1/videos/generations"
headers = None
if xai_api_key:
path = "https://api.x.ai/v1/videos/generations"
headers = {"Authorization": f"Bearer {xai_api_key}"}
initial_response = await sync_op( initial_response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"), ApiEndpoint(path=path, method="POST", headers=headers),
data=VideoGenerationRequest( data=VideoGenerationRequest(
model=model, model=model,
image=image_url, image=image_url,
@ -385,9 +442,13 @@ class GrokVideoNode(IO.ComfyNode):
), ),
response_model=VideoGenerationResponse, response_model=VideoGenerationResponse,
) )
poll_path = f"/proxy/xai/v1/videos/{initial_response.request_id}"
if xai_api_key:
poll_path = f"https://api.x.ai/v1/videos/{initial_response.request_id}"
response = await poll_op( response = await poll_op(
cls, cls,
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), ApiEndpoint(path=poll_path, headers=headers),
status_extractor=lambda r: r.status if r.status is not None else "complete", status_extractor=lambda r: r.status if r.status is not None else "complete",
response_model=VideoStatusResponse, response_model=VideoStatusResponse,
price_extractor=_extract_grok_price, price_extractor=_extract_grok_price,
@ -423,6 +484,13 @@ class GrokVideoEditNode(IO.ComfyNode):
tooltip="Seed to determine if node should re-run; " tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.", "actual results are nondeterministic regardless of seed.",
), ),
IO.String.Input(
"xai_api_key",
default="",
tooltip="Your xAI API Key (optional). If provided, it will bypass Comfy org limits.",
optional=True,
advanced=True,
),
], ],
outputs=[ outputs=[
IO.Video.Output(), IO.Video.Output(),
@ -445,6 +513,7 @@ class GrokVideoEditNode(IO.ComfyNode):
prompt: str, prompt: str,
video: Input.Video, video: Input.Video,
seed: int, seed: int,
xai_api_key: str = "",
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1) validate_string(prompt, strip_whitespace=True, min_length=1)
validate_video_duration(video, min_duration=1, max_duration=8.7) validate_video_duration(video, min_duration=1, max_duration=8.7)
@ -452,9 +521,20 @@ class GrokVideoEditNode(IO.ComfyNode):
video_size = get_fs_object_size(video_stream) video_size = get_fs_object_size(video_stream)
if video_size > 50 * 1024 * 1024: if video_size > 50 * 1024 * 1024:
raise ValueError(f"Video size ({video_size / 1024 / 1024:.1f}MB) exceeds 50MB limit.") raise ValueError(f"Video size ({video_size / 1024 / 1024:.1f}MB) exceeds 50MB limit.")
xai_api_key = xai_api_key.strip()
if xai_api_key.lower().startswith("bearer "):
xai_api_key = xai_api_key[7:].strip()
path = "/proxy/xai/v1/videos/edits"
headers = None
if xai_api_key:
path = "https://api.x.ai/v1/videos/edits"
headers = {"Authorization": f"Bearer {xai_api_key}"}
initial_response = await sync_op( initial_response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/xai/v1/videos/edits", method="POST"), ApiEndpoint(path=path, method="POST", headers=headers),
data=VideoEditRequest( data=VideoEditRequest(
model=model, model=model,
video=InputUrlObject(url=await upload_video_to_comfyapi(cls, video)), video=InputUrlObject(url=await upload_video_to_comfyapi(cls, video)),
@ -463,9 +543,13 @@ class GrokVideoEditNode(IO.ComfyNode):
), ),
response_model=VideoGenerationResponse, response_model=VideoGenerationResponse,
) )
poll_path = f"/proxy/xai/v1/videos/{initial_response.request_id}"
if xai_api_key:
poll_path = f"https://api.x.ai/v1/videos/{initial_response.request_id}"
response = await poll_op( response = await poll_op(
cls, cls,
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), ApiEndpoint(path=poll_path, headers=headers),
status_extractor=lambda r: r.status if r.status is not None else "complete", status_extractor=lambda r: r.status if r.status is not None else "complete",
response_model=VideoStatusResponse, response_model=VideoStatusResponse,
price_extractor=_extract_grok_price, price_extractor=_extract_grok_price,
@ -539,6 +623,13 @@ class GrokVideoReferenceNode(IO.ComfyNode):
tooltip="Seed to determine if node should re-run; " tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.", "actual results are nondeterministic regardless of seed.",
), ),
IO.String.Input(
"xai_api_key",
default="",
tooltip="Your xAI API Key (optional). If provided, it will bypass Comfy org limits.",
optional=True,
advanced=True,
),
], ],
outputs=[ outputs=[
IO.Video.Output(), IO.Video.Output(),
@ -573,8 +664,18 @@ class GrokVideoReferenceNode(IO.ComfyNode):
prompt: str, prompt: str,
model: dict, model: dict,
seed: int, seed: int,
xai_api_key: str = "",
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1) validate_string(prompt, strip_whitespace=True, min_length=1)
xai_api_key = xai_api_key.strip()
if xai_api_key.lower().startswith("bearer "):
xai_api_key = xai_api_key[7:].strip()
# We must use proxy to upload images temporarily even if they provide their own key for video generation
# because the API requires URLs and we use our proxy for image hosting during the request.
# Wait, if they are providing their own key to our backend for generation,
# `upload_images_to_comfyapi` relies on `comfyapi`. This is fine.
ref_image_urls = await upload_images_to_comfyapi( ref_image_urls = await upload_images_to_comfyapi(
cls, cls,
list(model["reference_images"].values()), list(model["reference_images"].values()),
@ -582,9 +683,16 @@ class GrokVideoReferenceNode(IO.ComfyNode):
wait_label="Uploading base images", wait_label="Uploading base images",
max_images=7, max_images=7,
) )
path = "/proxy/xai/v1/videos/generations"
headers = None
if xai_api_key:
path = "https://api.x.ai/v1/videos/generations"
headers = {"Authorization": f"Bearer {xai_api_key}"}
initial_response = await sync_op( initial_response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"), ApiEndpoint(path=path, method="POST", headers=headers),
data=VideoGenerationRequest( data=VideoGenerationRequest(
model=model["model"], model=model["model"],
reference_images=[InputUrlObject(url=i) for i in ref_image_urls], reference_images=[InputUrlObject(url=i) for i in ref_image_urls],
@ -596,9 +704,13 @@ class GrokVideoReferenceNode(IO.ComfyNode):
), ),
response_model=VideoGenerationResponse, response_model=VideoGenerationResponse,
) )
poll_path = f"/proxy/xai/v1/videos/{initial_response.request_id}"
if xai_api_key:
poll_path = f"https://api.x.ai/v1/videos/{initial_response.request_id}"
response = await poll_op( response = await poll_op(
cls, cls,
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), ApiEndpoint(path=poll_path, headers=headers),
status_extractor=lambda r: r.status if r.status is not None else "complete", status_extractor=lambda r: r.status if r.status is not None else "complete",
response_model=VideoStatusResponse, response_model=VideoStatusResponse,
price_extractor=_extract_grok_video_price, price_extractor=_extract_grok_video_price,
@ -653,6 +765,13 @@ class GrokVideoExtendNode(IO.ComfyNode):
tooltip="Seed to determine if node should re-run; " tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.", "actual results are nondeterministic regardless of seed.",
), ),
IO.String.Input(
"xai_api_key",
default="",
tooltip="Your xAI API Key (optional). If provided, it will bypass Comfy org limits.",
optional=True,
advanced=True,
),
], ],
outputs=[ outputs=[
IO.Video.Output(), IO.Video.Output(),
@ -685,15 +804,27 @@ class GrokVideoExtendNode(IO.ComfyNode):
video: Input.Video, video: Input.Video,
model: dict, model: dict,
seed: int, seed: int,
xai_api_key: str = "",
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1) validate_string(prompt, strip_whitespace=True, min_length=1)
validate_video_duration(video, min_duration=2, max_duration=15) validate_video_duration(video, min_duration=2, max_duration=15)
video_size = get_fs_object_size(video.get_stream_source()) video_size = get_fs_object_size(video.get_stream_source())
if video_size > 50 * 1024 * 1024: if video_size > 50 * 1024 * 1024:
raise ValueError(f"Video size ({video_size / 1024 / 1024:.1f}MB) exceeds 50MB limit.") raise ValueError(f"Video size ({video_size / 1024 / 1024:.1f}MB) exceeds 50MB limit.")
xai_api_key = xai_api_key.strip()
if xai_api_key.lower().startswith("bearer "):
xai_api_key = xai_api_key[7:].strip()
path = "/proxy/xai/v1/videos/extensions"
headers = None
if xai_api_key:
path = "https://api.x.ai/v1/videos/extensions"
headers = {"Authorization": f"Bearer {xai_api_key}"}
initial_response = await sync_op( initial_response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/xai/v1/videos/extensions", method="POST"), ApiEndpoint(path=path, method="POST", headers=headers),
data=VideoExtensionRequest( data=VideoExtensionRequest(
prompt=prompt, prompt=prompt,
video=InputUrlObject(url=await upload_video_to_comfyapi(cls, video)), video=InputUrlObject(url=await upload_video_to_comfyapi(cls, video)),
@ -701,9 +832,13 @@ class GrokVideoExtendNode(IO.ComfyNode):
), ),
response_model=VideoGenerationResponse, response_model=VideoGenerationResponse,
) )
poll_path = f"/proxy/xai/v1/videos/{initial_response.request_id}"
if xai_api_key:
poll_path = f"https://api.x.ai/v1/videos/{initial_response.request_id}"
response = await poll_op( response = await poll_op(
cls, cls,
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), ApiEndpoint(path=poll_path, headers=headers),
status_extractor=lambda r: r.status if r.status is not None else "complete", status_extractor=lambda r: r.status if r.status is not None else "complete",
response_model=VideoStatusResponse, response_model=VideoStatusResponse,
price_extractor=_extract_grok_video_price, price_extractor=_extract_grok_video_price,