mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-17 01:52:59 +08:00
feat(api-nodes): add new Gemini model (#10789)
This commit is contained in:
parent
d526974576
commit
24fdb92edf
@ -1,22 +1,229 @@
|
|||||||
from typing import Optional
|
from datetime import date
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from comfy_api_nodes.apis import GeminiGenerationConfig, GeminiContent, GeminiSafetySetting, GeminiSystemInstructionContent, GeminiTool, GeminiVideoMetadata
|
from pydantic import BaseModel, Field
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
|
class GeminiSafetyCategory(str, Enum):
|
||||||
|
HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
|
||||||
|
HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
|
||||||
|
HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
|
||||||
|
HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiSafetyThreshold(str, Enum):
|
||||||
|
OFF = "OFF"
|
||||||
|
BLOCK_NONE = "BLOCK_NONE"
|
||||||
|
BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"
|
||||||
|
BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"
|
||||||
|
BLOCK_ONLY_HIGH = "BLOCK_ONLY_HIGH"
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiSafetySetting(BaseModel):
|
||||||
|
category: GeminiSafetyCategory
|
||||||
|
threshold: GeminiSafetyThreshold
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiRole(str, Enum):
|
||||||
|
user = "user"
|
||||||
|
model = "model"
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiMimeType(str, Enum):
|
||||||
|
application_pdf = "application/pdf"
|
||||||
|
audio_mpeg = "audio/mpeg"
|
||||||
|
audio_mp3 = "audio/mp3"
|
||||||
|
audio_wav = "audio/wav"
|
||||||
|
image_png = "image/png"
|
||||||
|
image_jpeg = "image/jpeg"
|
||||||
|
image_webp = "image/webp"
|
||||||
|
text_plain = "text/plain"
|
||||||
|
video_mov = "video/mov"
|
||||||
|
video_mpeg = "video/mpeg"
|
||||||
|
video_mp4 = "video/mp4"
|
||||||
|
video_mpg = "video/mpg"
|
||||||
|
video_avi = "video/avi"
|
||||||
|
video_wmv = "video/wmv"
|
||||||
|
video_mpegps = "video/mpegps"
|
||||||
|
video_flv = "video/flv"
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiInlineData(BaseModel):
|
||||||
|
data: str | None = Field(
|
||||||
|
None,
|
||||||
|
description="The base64 encoding of the image, PDF, or video to include inline in the prompt. "
|
||||||
|
"When including media inline, you must also specify the media type (mimeType) of the data. Size limit: 20MB",
|
||||||
|
)
|
||||||
|
mimeType: GeminiMimeType | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiPart(BaseModel):
|
||||||
|
inlineData: GeminiInlineData | None = Field(None)
|
||||||
|
text: str | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiTextPart(BaseModel):
|
||||||
|
text: str | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiContent(BaseModel):
|
||||||
|
parts: list[GeminiPart] = Field(...)
|
||||||
|
role: GeminiRole = Field(..., examples=["user"])
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiSystemInstructionContent(BaseModel):
|
||||||
|
parts: list[GeminiTextPart] = Field(
|
||||||
|
...,
|
||||||
|
description="A list of ordered parts that make up a single message. "
|
||||||
|
"Different parts may have different IANA MIME types.",
|
||||||
|
)
|
||||||
|
role: GeminiRole = Field(
|
||||||
|
...,
|
||||||
|
description="The identity of the entity that creates the message. "
|
||||||
|
"The following values are supported: "
|
||||||
|
"user: This indicates that the message is sent by a real person, typically a user-generated message. "
|
||||||
|
"model: This indicates that the message is generated by the model. "
|
||||||
|
"The model value is used to insert messages from model into the conversation during multi-turn conversations. "
|
||||||
|
"For non-multi-turn conversations, this field can be left blank or unset.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiFunctionDeclaration(BaseModel):
|
||||||
|
description: str | None = Field(None)
|
||||||
|
name: str = Field(...)
|
||||||
|
parameters: dict[str, Any] = Field(..., description="JSON schema for the function parameters")
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiTool(BaseModel):
|
||||||
|
functionDeclarations: list[GeminiFunctionDeclaration] | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiOffset(BaseModel):
|
||||||
|
nanos: int | None = Field(None, ge=0, le=999999999)
|
||||||
|
seconds: int | None = Field(None, ge=-315576000000, le=315576000000)
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiVideoMetadata(BaseModel):
|
||||||
|
endOffset: GeminiOffset | None = Field(None)
|
||||||
|
startOffset: GeminiOffset | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiGenerationConfig(BaseModel):
|
||||||
|
maxOutputTokens: int | None = Field(None, ge=16, le=8192)
|
||||||
|
seed: int | None = Field(None)
|
||||||
|
stopSequences: list[str] | None = Field(None)
|
||||||
|
temperature: float | None = Field(1, ge=0.0, le=2.0)
|
||||||
|
topK: int | None = Field(40, ge=1)
|
||||||
|
topP: float | None = Field(0.95, ge=0.0, le=1.0)
|
||||||
|
|
||||||
|
|
||||||
class GeminiImageConfig(BaseModel):
|
class GeminiImageConfig(BaseModel):
|
||||||
aspectRatio: Optional[str] = None
|
aspectRatio: str | None = Field(None)
|
||||||
|
resolution: str | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class GeminiImageGenerationConfig(GeminiGenerationConfig):
|
class GeminiImageGenerationConfig(GeminiGenerationConfig):
|
||||||
responseModalities: Optional[list[str]] = None
|
responseModalities: list[str] | None = Field(None)
|
||||||
imageConfig: Optional[GeminiImageConfig] = None
|
imageConfig: GeminiImageConfig | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class GeminiImageGenerateContentRequest(BaseModel):
|
class GeminiImageGenerateContentRequest(BaseModel):
|
||||||
contents: list[GeminiContent]
|
contents: list[GeminiContent] = Field(...)
|
||||||
generationConfig: Optional[GeminiImageGenerationConfig] = None
|
generationConfig: GeminiImageGenerationConfig | None = Field(None)
|
||||||
safetySettings: Optional[list[GeminiSafetySetting]] = None
|
safetySettings: list[GeminiSafetySetting] | None = Field(None)
|
||||||
systemInstruction: Optional[GeminiSystemInstructionContent] = None
|
systemInstruction: GeminiSystemInstructionContent | None = Field(None)
|
||||||
tools: Optional[list[GeminiTool]] = None
|
tools: list[GeminiTool] | None = Field(None)
|
||||||
videoMetadata: Optional[GeminiVideoMetadata] = None
|
videoMetadata: GeminiVideoMetadata | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiGenerateContentRequest(BaseModel):
|
||||||
|
contents: list[GeminiContent] = Field(...)
|
||||||
|
generationConfig: GeminiGenerationConfig | None = Field(None)
|
||||||
|
safetySettings: list[GeminiSafetySetting] | None = Field(None)
|
||||||
|
systemInstruction: GeminiSystemInstructionContent | None = Field(None)
|
||||||
|
tools: list[GeminiTool] | None = Field(None)
|
||||||
|
videoMetadata: GeminiVideoMetadata | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class Modality(str, Enum):
|
||||||
|
MODALITY_UNSPECIFIED = "MODALITY_UNSPECIFIED"
|
||||||
|
TEXT = "TEXT"
|
||||||
|
IMAGE = "IMAGE"
|
||||||
|
VIDEO = "VIDEO"
|
||||||
|
AUDIO = "AUDIO"
|
||||||
|
DOCUMENT = "DOCUMENT"
|
||||||
|
|
||||||
|
|
||||||
|
class ModalityTokenCount(BaseModel):
|
||||||
|
modality: Modality | None = None
|
||||||
|
tokenCount: int | None = Field(None, description="Number of tokens for the given modality.")
|
||||||
|
|
||||||
|
|
||||||
|
class Probability(str, Enum):
|
||||||
|
NEGLIGIBLE = "NEGLIGIBLE"
|
||||||
|
LOW = "LOW"
|
||||||
|
MEDIUM = "MEDIUM"
|
||||||
|
HIGH = "HIGH"
|
||||||
|
UNKNOWN = "UNKNOWN"
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiSafetyRating(BaseModel):
|
||||||
|
category: GeminiSafetyCategory | None = None
|
||||||
|
probability: Probability | None = Field(
|
||||||
|
None,
|
||||||
|
description="The probability that the content violates the specified safety category",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiCitation(BaseModel):
|
||||||
|
authors: list[str] | None = None
|
||||||
|
endIndex: int | None = None
|
||||||
|
license: str | None = None
|
||||||
|
publicationDate: date | None = None
|
||||||
|
startIndex: int | None = None
|
||||||
|
title: str | None = None
|
||||||
|
uri: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiCitationMetadata(BaseModel):
|
||||||
|
citations: list[GeminiCitation] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiCandidate(BaseModel):
|
||||||
|
citationMetadata: GeminiCitationMetadata | None = None
|
||||||
|
content: GeminiContent | None = None
|
||||||
|
finishReason: str | None = None
|
||||||
|
safetyRatings: list[GeminiSafetyRating] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiPromptFeedback(BaseModel):
|
||||||
|
blockReason: str | None = None
|
||||||
|
blockReasonMessage: str | None = None
|
||||||
|
safetyRatings: list[GeminiSafetyRating] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiUsageMetadata(BaseModel):
|
||||||
|
cachedContentTokenCount: int | None = Field(
|
||||||
|
None,
|
||||||
|
description="Output only. Number of tokens in the cached part in the input (the cached content).",
|
||||||
|
)
|
||||||
|
candidatesTokenCount: int | None = Field(None, description="Number of tokens in the response(s).")
|
||||||
|
candidatesTokensDetails: list[ModalityTokenCount] | None = Field(
|
||||||
|
None, description="Breakdown of candidate tokens by modality."
|
||||||
|
)
|
||||||
|
promptTokenCount: int | None = Field(
|
||||||
|
None,
|
||||||
|
description="Number of tokens in the request. When cachedContent is set, this is still the total effective prompt size meaning this includes the number of tokens in the cached content.",
|
||||||
|
)
|
||||||
|
promptTokensDetails: list[ModalityTokenCount] | None = Field(
|
||||||
|
None, description="Breakdown of prompt tokens by modality."
|
||||||
|
)
|
||||||
|
thoughtsTokenCount: int | None = Field(None, description="Number of tokens present in thoughts output.")
|
||||||
|
toolUsePromptTokenCount: int | None = Field(None, description="Number of tokens present in tool-use prompt(s).")
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiGenerateContentResponse(BaseModel):
|
||||||
|
candidates: list[GeminiCandidate] | None = Field(None)
|
||||||
|
promptFeedback: GeminiPromptFeedback | None = Field(None)
|
||||||
|
usageMetadata: GeminiUsageMetadata | None = Field(None)
|
||||||
|
|||||||
@ -3,8 +3,6 @@ API Nodes for Gemini Multimodal LLM Usage via Remote API
|
|||||||
See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
|
See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
@ -12,7 +10,7 @@ import time
|
|||||||
import uuid
|
import uuid
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Literal, Optional
|
from typing import Literal
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
@ -20,18 +18,17 @@ from typing_extensions import override
|
|||||||
import folder_paths
|
import folder_paths
|
||||||
from comfy_api.latest import IO, ComfyExtension, Input
|
from comfy_api.latest import IO, ComfyExtension, Input
|
||||||
from comfy_api.util import VideoCodec, VideoContainer
|
from comfy_api.util import VideoCodec, VideoContainer
|
||||||
from comfy_api_nodes.apis import (
|
from comfy_api_nodes.apis.gemini_api import (
|
||||||
GeminiContent,
|
GeminiContent,
|
||||||
GeminiGenerateContentRequest,
|
GeminiGenerateContentRequest,
|
||||||
GeminiGenerateContentResponse,
|
GeminiGenerateContentResponse,
|
||||||
GeminiInlineData,
|
|
||||||
GeminiMimeType,
|
|
||||||
GeminiPart,
|
|
||||||
)
|
|
||||||
from comfy_api_nodes.apis.gemini_api import (
|
|
||||||
GeminiImageConfig,
|
GeminiImageConfig,
|
||||||
GeminiImageGenerateContentRequest,
|
GeminiImageGenerateContentRequest,
|
||||||
GeminiImageGenerationConfig,
|
GeminiImageGenerationConfig,
|
||||||
|
GeminiInlineData,
|
||||||
|
GeminiMimeType,
|
||||||
|
GeminiPart,
|
||||||
|
GeminiRole,
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.util import (
|
from comfy_api_nodes.util import (
|
||||||
ApiEndpoint,
|
ApiEndpoint,
|
||||||
@ -57,6 +54,7 @@ class GeminiModel(str, Enum):
|
|||||||
gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17"
|
gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17"
|
||||||
gemini_2_5_pro = "gemini-2.5-pro"
|
gemini_2_5_pro = "gemini-2.5-pro"
|
||||||
gemini_2_5_flash = "gemini-2.5-flash"
|
gemini_2_5_flash = "gemini-2.5-flash"
|
||||||
|
gemini_3_0_pro = "gemini-3-pro-preview"
|
||||||
|
|
||||||
|
|
||||||
class GeminiImageModel(str, Enum):
|
class GeminiImageModel(str, Enum):
|
||||||
@ -103,6 +101,16 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
|
|||||||
Returns:
|
Returns:
|
||||||
List of response parts matching the requested type.
|
List of response parts matching the requested type.
|
||||||
"""
|
"""
|
||||||
|
if response.candidates is None:
|
||||||
|
if response.promptFeedback.blockReason:
|
||||||
|
feedback = response.promptFeedback
|
||||||
|
raise ValueError(
|
||||||
|
f"Gemini API blocked the request. Reason: {feedback.blockReason} ({feedback.blockReasonMessage})"
|
||||||
|
)
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Gemini returned no response candidates. "
|
||||||
|
"Please report to ComfyUI repository with the example of workflow to reproduce this."
|
||||||
|
)
|
||||||
parts = []
|
parts = []
|
||||||
for part in response.candidates[0].content.parts:
|
for part in response.candidates[0].content.parts:
|
||||||
if part_type == "text" and hasattr(part, "text") and part.text:
|
if part_type == "text" and hasattr(part, "text") and part.text:
|
||||||
@ -272,10 +280,10 @@ class GeminiNode(IO.ComfyNode):
|
|||||||
prompt: str,
|
prompt: str,
|
||||||
model: str,
|
model: str,
|
||||||
seed: int,
|
seed: int,
|
||||||
images: Optional[torch.Tensor] = None,
|
images: torch.Tensor | None = None,
|
||||||
audio: Optional[Input.Audio] = None,
|
audio: Input.Audio | None = None,
|
||||||
video: Optional[Input.Video] = None,
|
video: Input.Video | None = None,
|
||||||
files: Optional[list[GeminiPart]] = None,
|
files: list[GeminiPart] | None = None,
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, strip_whitespace=False)
|
validate_string(prompt, strip_whitespace=False)
|
||||||
|
|
||||||
@ -300,7 +308,7 @@ class GeminiNode(IO.ComfyNode):
|
|||||||
data=GeminiGenerateContentRequest(
|
data=GeminiGenerateContentRequest(
|
||||||
contents=[
|
contents=[
|
||||||
GeminiContent(
|
GeminiContent(
|
||||||
role="user",
|
role=GeminiRole.user,
|
||||||
parts=parts,
|
parts=parts,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
@ -308,7 +316,6 @@ class GeminiNode(IO.ComfyNode):
|
|||||||
response_model=GeminiGenerateContentResponse,
|
response_model=GeminiGenerateContentResponse,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get result output
|
|
||||||
output_text = get_text_from_response(response)
|
output_text = get_text_from_response(response)
|
||||||
if output_text:
|
if output_text:
|
||||||
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
|
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
|
||||||
@ -406,7 +413,7 @@ class GeminiInputFiles(IO.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, file: str, GEMINI_INPUT_FILES: Optional[list[GeminiPart]] = None) -> IO.NodeOutput:
|
def execute(cls, file: str, GEMINI_INPUT_FILES: list[GeminiPart] | None = None) -> IO.NodeOutput:
|
||||||
"""Loads and formats input files for Gemini API."""
|
"""Loads and formats input files for Gemini API."""
|
||||||
if GEMINI_INPUT_FILES is None:
|
if GEMINI_INPUT_FILES is None:
|
||||||
GEMINI_INPUT_FILES = []
|
GEMINI_INPUT_FILES = []
|
||||||
@ -421,7 +428,7 @@ class GeminiImage(IO.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="GeminiImageNode",
|
node_id="GeminiImageNode",
|
||||||
display_name="Google Gemini Image",
|
display_name="Nano Banana (Google Gemini Image)",
|
||||||
category="api node/image/Gemini",
|
category="api node/image/Gemini",
|
||||||
description="Edit images synchronously via Google API.",
|
description="Edit images synchronously via Google API.",
|
||||||
inputs=[
|
inputs=[
|
||||||
@ -488,8 +495,8 @@ class GeminiImage(IO.ComfyNode):
|
|||||||
prompt: str,
|
prompt: str,
|
||||||
model: str,
|
model: str,
|
||||||
seed: int,
|
seed: int,
|
||||||
images: Optional[torch.Tensor] = None,
|
images: torch.Tensor | None = None,
|
||||||
files: Optional[list[GeminiPart]] = None,
|
files: list[GeminiPart] | None = None,
|
||||||
aspect_ratio: str = "auto",
|
aspect_ratio: str = "auto",
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||||
@ -510,7 +517,7 @@ class GeminiImage(IO.ComfyNode):
|
|||||||
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
|
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
|
||||||
data=GeminiImageGenerateContentRequest(
|
data=GeminiImageGenerateContentRequest(
|
||||||
contents=[
|
contents=[
|
||||||
GeminiContent(role="user", parts=parts),
|
GeminiContent(role=GeminiRole.user, parts=parts),
|
||||||
],
|
],
|
||||||
generationConfig=GeminiImageGenerationConfig(
|
generationConfig=GeminiImageGenerationConfig(
|
||||||
responseModalities=["TEXT", "IMAGE"],
|
responseModalities=["TEXT", "IMAGE"],
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user