feat(api-nodes): add new Gemini model (#10789)

This commit is contained in:
Alexander Piskun 2025-11-19 00:26:44 +02:00 committed by GitHub
parent d526974576
commit 24fdb92edf
2 changed files with 246 additions and 32 deletions

View File

@ -1,22 +1,229 @@
from datetime import date
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field
class GeminiSafetyCategory(str, Enum):
HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"
class GeminiSafetyThreshold(str, Enum):
OFF = "OFF"
BLOCK_NONE = "BLOCK_NONE"
BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"
BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"
BLOCK_ONLY_HIGH = "BLOCK_ONLY_HIGH"
class GeminiSafetySetting(BaseModel):
category: GeminiSafetyCategory
threshold: GeminiSafetyThreshold
class GeminiRole(str, Enum):
user = "user"
model = "model"
class GeminiMimeType(str, Enum):
application_pdf = "application/pdf"
audio_mpeg = "audio/mpeg"
audio_mp3 = "audio/mp3"
audio_wav = "audio/wav"
image_png = "image/png"
image_jpeg = "image/jpeg"
image_webp = "image/webp"
text_plain = "text/plain"
video_mov = "video/mov"
video_mpeg = "video/mpeg"
video_mp4 = "video/mp4"
video_mpg = "video/mpg"
video_avi = "video/avi"
video_wmv = "video/wmv"
video_mpegps = "video/mpegps"
video_flv = "video/flv"
class GeminiInlineData(BaseModel):
data: str | None = Field(
None,
description="The base64 encoding of the image, PDF, or video to include inline in the prompt. "
"When including media inline, you must also specify the media type (mimeType) of the data. Size limit: 20MB",
)
mimeType: GeminiMimeType | None = Field(None)
class GeminiPart(BaseModel):
inlineData: GeminiInlineData | None = Field(None)
text: str | None = Field(None)
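Note (not part of this commit): a minimal sketch of packing an image into a part, following the inlineData description above; the helper name and the PNG bytes are placeholders.

    import base64

    def make_inline_image_part(png_bytes: bytes) -> GeminiPart:
        # Hypothetical helper: base64-encode the media and tag its MIME type,
        # as GeminiInlineData above requires for inline media.
        return GeminiPart(
            inlineData=GeminiInlineData(
                data=base64.b64encode(png_bytes).decode("utf-8"),
                mimeType=GeminiMimeType.image_png,
            )
        )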
class GeminiTextPart(BaseModel):
text: str | None = Field(None)
class GeminiContent(BaseModel):
parts: list[GeminiPart] = Field(...)
role: GeminiRole = Field(..., examples=["user"])
class GeminiSystemInstructionContent(BaseModel):
parts: list[GeminiTextPart] = Field(
...,
description="A list of ordered parts that make up a single message. "
"Different parts may have different IANA MIME types.",
)
role: GeminiRole = Field(
...,
description="The identity of the entity that creates the message. "
"The following values are supported: "
"user: This indicates that the message is sent by a real person, typically a user-generated message. "
"model: This indicates that the message is generated by the model. "
"The model value is used to insert messages from model into the conversation during multi-turn conversations. "
"For non-multi-turn conversations, this field can be left blank or unset.",
)
class GeminiFunctionDeclaration(BaseModel):
description: str | None = Field(None)
name: str = Field(...)
parameters: dict[str, Any] = Field(..., description="JSON schema for the function parameters")
class GeminiTool(BaseModel):
functionDeclarations: list[GeminiFunctionDeclaration] | None = Field(None)
class GeminiOffset(BaseModel):
nanos: int | None = Field(None, ge=0, le=999999999)
seconds: int | None = Field(None, ge=-315576000000, le=315576000000)
class GeminiVideoMetadata(BaseModel):
endOffset: GeminiOffset | None = Field(None)
startOffset: GeminiOffset | None = Field(None)
class GeminiGenerationConfig(BaseModel):
maxOutputTokens: int | None = Field(None, ge=16, le=8192)
seed: int | None = Field(None)
stopSequences: list[str] | None = Field(None)
temperature: float | None = Field(1, ge=0.0, le=2.0)
topK: int | None = Field(40, ge=1)
topP: float | None = Field(0.95, ge=0.0, le=1.0)
class GeminiImageConfig(BaseModel):
    aspectRatio: str | None = Field(None)
    resolution: str | None = Field(None)
class GeminiImageGenerationConfig(GeminiGenerationConfig):
    responseModalities: list[str] | None = Field(None)
    imageConfig: GeminiImageConfig | None = Field(None)
class GeminiImageGenerateContentRequest(BaseModel):
    contents: list[GeminiContent] = Field(...)
    generationConfig: GeminiImageGenerationConfig | None = Field(None)
    safetySettings: list[GeminiSafetySetting] | None = Field(None)
    systemInstruction: GeminiSystemInstructionContent | None = Field(None)
    tools: list[GeminiTool] | None = Field(None)
    videoMetadata: GeminiVideoMetadata | None = Field(None)
class GeminiGenerateContentRequest(BaseModel):
contents: list[GeminiContent] = Field(...)
generationConfig: GeminiGenerationConfig | None = Field(None)
safetySettings: list[GeminiSafetySetting] | None = Field(None)
systemInstruction: GeminiSystemInstructionContent | None = Field(None)
tools: list[GeminiTool] | None = Field(None)
videoMetadata: GeminiVideoMetadata | None = Field(None)
class Modality(str, Enum):
MODALITY_UNSPECIFIED = "MODALITY_UNSPECIFIED"
TEXT = "TEXT"
IMAGE = "IMAGE"
VIDEO = "VIDEO"
AUDIO = "AUDIO"
DOCUMENT = "DOCUMENT"
class ModalityTokenCount(BaseModel):
modality: Modality | None = None
tokenCount: int | None = Field(None, description="Number of tokens for the given modality.")
class Probability(str, Enum):
NEGLIGIBLE = "NEGLIGIBLE"
LOW = "LOW"
MEDIUM = "MEDIUM"
HIGH = "HIGH"
UNKNOWN = "UNKNOWN"
class GeminiSafetyRating(BaseModel):
category: GeminiSafetyCategory | None = None
probability: Probability | None = Field(
None,
description="The probability that the content violates the specified safety category",
)
class GeminiCitation(BaseModel):
authors: list[str] | None = None
endIndex: int | None = None
license: str | None = None
publicationDate: date | None = None
startIndex: int | None = None
title: str | None = None
uri: str | None = None
class GeminiCitationMetadata(BaseModel):
citations: list[GeminiCitation] | None = None
class GeminiCandidate(BaseModel):
citationMetadata: GeminiCitationMetadata | None = None
content: GeminiContent | None = None
finishReason: str | None = None
safetyRatings: list[GeminiSafetyRating] | None = None
class GeminiPromptFeedback(BaseModel):
blockReason: str | None = None
blockReasonMessage: str | None = None
safetyRatings: list[GeminiSafetyRating] | None = None
class GeminiUsageMetadata(BaseModel):
cachedContentTokenCount: int | None = Field(
None,
description="Output only. Number of tokens in the cached part in the input (the cached content).",
)
candidatesTokenCount: int | None = Field(None, description="Number of tokens in the response(s).")
candidatesTokensDetails: list[ModalityTokenCount] | None = Field(
None, description="Breakdown of candidate tokens by modality."
)
promptTokenCount: int | None = Field(
None,
description="Number of tokens in the request. When cachedContent is set, this is still the total effective prompt size meaning this includes the number of tokens in the cached content.",
)
promptTokensDetails: list[ModalityTokenCount] | None = Field(
None, description="Breakdown of prompt tokens by modality."
)
thoughtsTokenCount: int | None = Field(None, description="Number of tokens present in thoughts output.")
toolUsePromptTokenCount: int | None = Field(None, description="Number of tokens present in tool-use prompt(s).")
class GeminiGenerateContentResponse(BaseModel):
candidates: list[GeminiCandidate] | None = Field(None)
promptFeedback: GeminiPromptFeedback | None = Field(None)
usageMetadata: GeminiUsageMetadata | None = Field(None)
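Note (not part of this commit): a rough sketch of how these models compose into a request body and how a reply is read back, assuming pydantic v2's model_dump/model_validate; the prompt text and the raw_reply dict are placeholders.

    request = GeminiGenerateContentRequest(
        contents=[
            GeminiContent(role=GeminiRole.user, parts=[GeminiPart(text="Summarize the attached PDF.")]),
        ],
        generationConfig=GeminiGenerationConfig(maxOutputTokens=1024, temperature=0.7),
        safetySettings=[
            GeminiSafetySetting(
                category=GeminiSafetyCategory.HARM_CATEGORY_HARASSMENT,
                threshold=GeminiSafetyThreshold.BLOCK_ONLY_HIGH,
            )
        ],
    )
    payload = request.model_dump(exclude_none=True)  # JSON-ready dict; unset optional fields are dropped

    response = GeminiGenerateContentResponse.model_validate(raw_reply)  # raw_reply: dict decoded from the HTTP body
    if response.candidates:
        text = "".join(part.text or "" for part in response.candidates[0].content.parts)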

View File

@ -3,8 +3,6 @@ API Nodes for Gemini Multimodal LLM Usage via Remote API
See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
"""
import base64
import json
import os
@ -12,7 +10,7 @@ import time
import uuid
from enum import Enum
from io import BytesIO
from typing import Literal
import torch
from typing_extensions import override
@ -20,18 +18,17 @@ from typing_extensions import override
import folder_paths
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api.util import VideoCodec, VideoContainer
from comfy_api_nodes.apis.gemini_api import (
    GeminiContent,
    GeminiGenerateContentRequest,
    GeminiGenerateContentResponse,
    GeminiImageConfig,
    GeminiImageGenerateContentRequest,
    GeminiImageGenerationConfig,
    GeminiInlineData,
    GeminiMimeType,
    GeminiPart,
    GeminiRole,
)
from comfy_api_nodes.util import (
    ApiEndpoint,
@ -57,6 +54,7 @@ class GeminiModel(str, Enum):
    gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17"
    gemini_2_5_pro = "gemini-2.5-pro"
    gemini_2_5_flash = "gemini-2.5-flash"
    gemini_3_0_pro = "gemini-3-pro-preview"
class GeminiImageModel(str, Enum):
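Note (not part of this commit): the new member's value is what ends up in the request path, mirroring the ApiEndpoint call later in this file; GEMINI_BASE_ENDPOINT is the module's existing constant.

    model = GeminiModel.gemini_3_0_pro.value  # "gemini-3-pro-preview"
    endpoint = ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST")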
@ -103,6 +101,16 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
    Returns:
        List of response parts matching the requested type.
    """
    if response.candidates is None:
        if response.promptFeedback.blockReason:
            feedback = response.promptFeedback
            raise ValueError(
                f"Gemini API blocked the request. Reason: {feedback.blockReason} ({feedback.blockReasonMessage})"
            )
        raise NotImplementedError(
            "Gemini returned no response candidates. "
            "Please report to ComfyUI repository with the example of workflow to reproduce this."
        )
    parts = []
    for part in response.candidates[0].content.parts:
        if part_type == "text" and hasattr(part, "text") and part.text:
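Note (not part of this commit): an illustration of how the new guard surfaces a blocked prompt, using a hand-built response; the blockReason value is a placeholder.

    blocked = GeminiGenerateContentResponse(
        candidates=None,
        promptFeedback=GeminiPromptFeedback(
            blockReason="PROHIBITED_CONTENT",  # placeholder value for illustration
            blockReasonMessage="Blocked by safety filters.",
        ),
    )
    get_parts_by_type(blocked, "text")
    # raises ValueError: Gemini API blocked the request. Reason: PROHIBITED_CONTENT (Blocked by safety filters.)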
@ -272,10 +280,10 @@ class GeminiNode(IO.ComfyNode):
        prompt: str,
        model: str,
        seed: int,
        images: torch.Tensor | None = None,
        audio: Input.Audio | None = None,
        video: Input.Video | None = None,
        files: list[GeminiPart] | None = None,
    ) -> IO.NodeOutput:
        validate_string(prompt, strip_whitespace=False)
@ -300,7 +308,7 @@ class GeminiNode(IO.ComfyNode):
            data=GeminiGenerateContentRequest(
                contents=[
                    GeminiContent(
                        role=GeminiRole.user,
                        parts=parts,
                    )
                ]
@ -308,7 +316,6 @@ class GeminiNode(IO.ComfyNode):
            response_model=GeminiGenerateContentResponse,
        )
        output_text = get_text_from_response(response)
        if output_text:
            # Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
@ -406,7 +413,7 @@ class GeminiInputFiles(IO.ComfyNode):
        )
    @classmethod
    def execute(cls, file: str, GEMINI_INPUT_FILES: list[GeminiPart] | None = None) -> IO.NodeOutput:
        """Loads and formats input files for Gemini API."""
        if GEMINI_INPUT_FILES is None:
            GEMINI_INPUT_FILES = []
@ -421,7 +428,7 @@ class GeminiImage(IO.ComfyNode):
    def define_schema(cls):
        return IO.Schema(
            node_id="GeminiImageNode",
            display_name="Nano Banana (Google Gemini Image)",
            category="api node/image/Gemini",
            description="Edit images synchronously via Google API.",
            inputs=[
@ -488,8 +495,8 @@ class GeminiImage(IO.ComfyNode):
        prompt: str,
        model: str,
        seed: int,
        images: torch.Tensor | None = None,
        files: list[GeminiPart] | None = None,
        aspect_ratio: str = "auto",
    ) -> IO.NodeOutput:
        validate_string(prompt, strip_whitespace=True, min_length=1)
@ -510,7 +517,7 @@ class GeminiImage(IO.ComfyNode):
            endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
            data=GeminiImageGenerateContentRequest(
                contents=[
                    GeminiContent(role=GeminiRole.user, parts=parts),
                ],
                generationConfig=GeminiImageGenerationConfig(
                    responseModalities=["TEXT", "IMAGE"],
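Note (not part of this commit): a sketch of the image request body these changes produce, using the gemini_api.py models above; the prompt text and aspect ratio are placeholders, and whether imageConfig is left unset for aspect_ratio="auto" is an assumption.

    image_request = GeminiImageGenerateContentRequest(
        contents=[GeminiContent(role=GeminiRole.user, parts=[GeminiPart(text="A watercolor fox")])],
        generationConfig=GeminiImageGenerationConfig(
            responseModalities=["TEXT", "IMAGE"],
            imageConfig=GeminiImageConfig(aspectRatio="16:9"),  # assumed to be omitted when aspect_ratio is "auto"
        ),
    )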