easyai-ai-gateway/apps/api/internal/clients/gemini.go

400 lines
11 KiB
Go

package clients
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"time"
)
// GeminiClient calls the Google Gemini generateContent REST API.
// HTTPClient is passed to httpClient together with the per-request
// Request.HTTPClient; which of the two is used is decided by httpClient.
type GeminiClient struct {
	HTTPClient *http.Client
}

// Run performs one non-streaming generateContent call for request and maps
// the upstream JSON into the gateway Response (OpenAI-style result, usage,
// request ID and timing).
//
// Errors:
//   - *ClientError{Code: "missing_credentials"} when no API key is configured.
//   - an error wrapping the json.Marshal failure if the body cannot be encoded.
//   - *ClientError{Code: "network", Retryable: true} when the HTTP round trip
//     fails.
//   - errors from decodeHTTPResponse, annotated with the request ID and
//     timing by annotateResponseError.
//
// The API key is sent as the "key" query parameter (see geminiURL). Errors
// produced by URL parsing and by http.Client.Do embed the full request URL in
// their text, so the key is redacted from those messages before they leave
// Run; otherwise the credential would end up in logs and client responses.
func (c GeminiClient) Run(ctx context.Context, request Request) (Response, error) {
	apiKey := credential(request.Candidate.Credentials, "apiKey", "api_key", "key", "token")
	if apiKey == "" {
		return Response{}, &ClientError{Code: "missing_credentials", Message: "gemini api key is required", Retryable: false}
	}
	// redactKey removes the credential, both raw and query-escaped (the form
	// it takes inside the URL), from error text.
	redactKey := func(msg string) string {
		msg = strings.ReplaceAll(msg, url.QueryEscape(apiKey), "[REDACTED]")
		return strings.ReplaceAll(msg, apiKey, "[REDACTED]")
	}
	raw, err := json.Marshal(geminiBody(request))
	if err != nil {
		// Never send an empty/partial body upstream when encoding fails.
		return Response{}, fmt.Errorf("encode gemini request body: %w", err)
	}
	endpoint := geminiURL(request.Candidate.BaseURL, upstreamModelName(request.Candidate), apiKey)
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(raw))
	if err != nil {
		return Response{}, fmt.Errorf("build gemini request: %s", redactKey(err.Error()))
	}
	req.Header.Set("Content-Type", "application/json")
	responseStartedAt := time.Now()
	resp, err := httpClient(request.HTTPClient, c.HTTPClient).Do(req)
	if err != nil {
		return Response{}, &ClientError{Code: "network", Message: redactKey(err.Error()), Retryable: true}
	}
	requestID := requestIDFromHTTPResponse(resp)
	result, err := decodeHTTPResponse(resp)
	responseFinishedAt := time.Now()
	if err != nil {
		return Response{}, annotateResponseError(err, requestID, responseStartedAt, responseFinishedAt)
	}
	output := geminiResult(request, result)
	if requestID == "" {
		// Fall back to an ID carried in the converted result body.
		requestID = requestIDFromResult(output)
	}
	return Response{
		Result:             output,
		RequestID:          requestID,
		Usage:              geminiUsage(result),
		Progress:           providerProgress(request),
		ResponseStartedAt:  responseStartedAt,
		ResponseFinishedAt: responseFinishedAt,
		ResponseDurationMS: responseDurationMS(responseStartedAt, responseFinishedAt),
	}, nil
}
// geminiURL builds the v1beta generateContent endpoint for model.
//
// baseURL defaults to the public Generative Language API host. Surrounding
// whitespace, trailing slashes and a trailing "/v1beta" are stripped, so the
// host may be configured with or without the API version.
//
// A leading "models/" on model is dropped because the path already contains
// that collection segment. Without this, url.PathEscape would encode the
// slash as %2F and produce ".../models/models%2F<id>", which the API does not
// route.
//
// apiKey is query-escaped and sent as the "key" query parameter.
func geminiURL(baseURL string, model string, apiKey string) string {
	base := strings.TrimRight(strings.TrimSpace(baseURL), "/")
	if base == "" {
		base = "https://generativelanguage.googleapis.com"
	}
	base = strings.TrimSuffix(base, "/v1beta")
	model = strings.TrimPrefix(model, "models/")
	return fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s", base, url.PathEscape(model), url.QueryEscape(apiKey))
}
// geminiBody builds the generateContent request payload from request.Body.
//
// The first matching rule wins:
//  1. a body that already has "contents" is forwarded, but only that key is
//     copied (any other native fields are not sent);
//  2. a non-empty prompt (firstNonEmptyPrompt) becomes a single user turn;
//  3. chat "messages" are converted with geminiContentsFromMessages, and any
//     OpenAI "tools" are attached as Gemini functionDeclarations;
//  4. if step 3 yields no contents, all message text is flattened into a
//     single user turn (tools are not attached in this case).
func geminiBody(request Request) map[string]any {
	if native, found := request.Body["contents"]; found {
		return map[string]any{"contents": native}
	}

	singleUserTurn := func(text string) map[string]any {
		return map[string]any{"contents": []any{map[string]any{
			"role":  "user",
			"parts": []any{map[string]any{"text": text}},
		}}}
	}

	if prompt := firstNonEmptyPrompt(request.Body, ""); prompt != "" {
		return singleUserTurn(prompt)
	}

	converted := geminiContentsFromMessages(request.Body)
	if len(converted) == 0 {
		return singleUserTurn(textFromMessages(request.Body))
	}

	payload := map[string]any{"contents": converted}
	if declarations := geminiToolsFromOpenAITools(request.Body["tools"]); len(declarations) > 0 {
		payload["tools"] = declarations
	}
	return payload
}
// geminiContentsFromMessages converts OpenAI chat messages, after
// NormalizeChatCompletionRequestBody, into Gemini "contents" entries.
//
//   - A "tool" message becomes a user-role entry with one functionResponse
//     part. Its function name is taken from an earlier assistant tool call
//     with the same id. If there is none, the tool_call_id itself is used,
//     and failing that, "tool".
//   - Any other message contributes its text parts. Assistant messages also
//     add one functionCall part per tool call that has a name, and record
//     that call's id for the lookup above.
//   - Non-object or empty messages, and messages that end up with no parts,
//     are skipped.
//
// Roles go through geminiRole, so everything except "assistant" (including
// "system") is sent as "user".
func geminiContentsFromMessages(body map[string]any) []any {
	messages, _ := NormalizeChatCompletionRequestBody(body)["messages"].([]any)
	out := make([]any, 0, len(messages))
	callNames := make(map[string]string)

	for _, entry := range messages {
		msg, _ := entry.(map[string]any)
		if len(msg) == 0 {
			continue
		}
		role := stringFromAny(msg["role"])

		if role == "tool" {
			callID := stringFromAny(msg["tool_call_id"])
			fnName := callNames[callID]
			if fnName == "" {
				fnName = callID
				if fnName == "" {
					fnName = "tool"
				}
			}
			response := map[string]any{
				"name":     fnName,
				"response": geminiFunctionResponsePayload(msg["content"]),
			}
			out = append(out, map[string]any{
				"role":  "user",
				"parts": []any{map[string]any{"functionResponse": response}},
			})
			continue
		}

		segments := geminiTextParts(msg["content"])
		if role == "assistant" {
			for _, rawCall := range toolCallsSlice(msg["tool_calls"]) {
				call, _ := rawCall.(map[string]any)
				fn, _ := call["function"].(map[string]any)
				fnName := stringFromAny(fn["name"])
				if fnName == "" {
					continue
				}
				if id := stringFromAny(call["id"]); id != "" {
					callNames[id] = fnName
				}
				invocation := map[string]any{
					"name": fnName,
					"args": geminiFunctionArgs(fn["arguments"]),
				}
				segments = append(segments, map[string]any{"functionCall": invocation})
			}
		}

		if len(segments) > 0 {
			out = append(out, map[string]any{
				"role":  geminiRole(role),
				"parts": segments,
			})
		}
	}
	return out
}
// geminiRole maps an OpenAI chat role to a Gemini content role. "assistant"
// becomes "model"; every other value (user, system, tool, unknown) is "user".
func geminiRole(role string) string {
	switch role {
	case "assistant":
		return "model"
	default:
		return "user"
	}
}
// geminiTextParts turns OpenAI message content into Gemini {"text": ...}
// parts. A string becomes one part. For an array, each object element is read
// from its "text" field, or else its "content" field. Whitespace-only text is
// dropped, as is any other content shape. The result is never nil.
func geminiTextParts(content any) []any {
	collected := []any{}
	addText := func(text string) {
		if strings.TrimSpace(text) == "" {
			return
		}
		collected = append(collected, map[string]any{"text": text})
	}

	if text, isString := content.(string); isString {
		addText(text)
		return collected
	}
	if items, isList := content.([]any); isList {
		for _, item := range items {
			fields, _ := item.(map[string]any)
			addText(stringFromAny(firstPresent(fields["text"], fields["content"])))
		}
	}
	return collected
}
// toolCallsSlice normalizes a "tool_calls" value to a slice. A slice is
// returned unchanged, and a single object is wrapped in a one-element slice.
// Anything else yields nil.
func toolCallsSlice(value any) []any {
	if list, isList := value.([]any); isList {
		return list
	}
	if single, isObject := value.(map[string]any); isObject {
		return []any{single}
	}
	return nil
}
// geminiFunctionArgs converts an OpenAI tool-call "arguments" value into the
// object Gemini expects for functionCall.args.
//
// The rules are:
//   - maps pass through unchanged;
//   - strings holding a JSON object are decoded;
//   - nil and blank strings become an empty object;
//   - anything that is not a JSON object (invalid JSON, arrays, scalars,
//     other Go types) is wrapped as {"arguments": value}, so no data is lost.
//
// The result is never nil. A nil map would marshal to JSON null, which is not
// a valid args object; this happens for the literal string "null" or for a
// typed-nil map.
func geminiFunctionArgs(value any) map[string]any {
	switch typed := value.(type) {
	case nil:
		return map[string]any{}
	case map[string]any:
		if typed == nil {
			return map[string]any{}
		}
		return typed
	case string:
		if strings.TrimSpace(typed) == "" {
			return map[string]any{}
		}
		var args map[string]any
		if err := json.Unmarshal([]byte(typed), &args); err == nil {
			if args == nil {
				// The string was JSON null.
				return map[string]any{}
			}
			return args
		}
		return map[string]any{"arguments": typed}
	default:
		return map[string]any{"arguments": value}
	}
}
// geminiFunctionResponsePayload converts a tool message's content into the
// object Gemini expects for functionResponse.response.
//
// The rules are:
//   - maps pass through unchanged;
//   - strings holding a JSON object are decoded;
//   - any other string (including "" and non-object JSON) is wrapped as
//     {"content": text};
//   - nil becomes an empty object;
//   - other values are wrapped as {"content": value}.
//
// The result is never nil. JSON null or a typed-nil map would otherwise
// marshal as null instead of an object.
func geminiFunctionResponsePayload(value any) map[string]any {
	switch typed := value.(type) {
	case nil:
		return map[string]any{}
	case map[string]any:
		if typed == nil {
			return map[string]any{}
		}
		return typed
	case string:
		var payload map[string]any
		if err := json.Unmarshal([]byte(typed), &payload); err == nil {
			if payload == nil {
				// The string was JSON null.
				return map[string]any{}
			}
			return payload
		}
		return map[string]any{"content": typed}
	default:
		return map[string]any{"content": value}
	}
}
// geminiToolsFromOpenAITools converts an OpenAI "tools" array into the Gemini
// form: a single tool object holding "functionDeclarations".
//
// Each entry's "function" object contributes a declaration with its name, its
// description when non-empty, and its "parameters" when the key is present.
// The parameters are copied verbatim, with no schema translation. Entries
// without a function name are skipped. It returns nil when value is not an
// array or when no declaration survives.
func geminiToolsFromOpenAITools(value any) []any {
	entries, _ := value.([]any)
	var declarations []any
	for _, entry := range entries {
		spec, _ := entry.(map[string]any)
		fn, _ := spec["function"].(map[string]any)
		fnName := stringFromAny(fn["name"])
		if fnName == "" {
			continue
		}
		decl := map[string]any{"name": fnName}
		if summary := stringFromAny(fn["description"]); summary != "" {
			decl["description"] = summary
		}
		if schema, present := fn["parameters"]; present {
			decl["parameters"] = schema
		}
		declarations = append(declarations, decl)
	}
	if len(declarations) == 0 {
		return nil
	}
	return []any{map[string]any{"functionDeclarations": declarations}}
}
// geminiResult wraps a raw Gemini response in an OpenAI-compatible result.
// The raw upstream payload is always kept under "raw".
//
// For request.ModelType "image", the result is an images-style object whose
// "data" comes from geminiImageData. When no inline image data is present,
// "data" holds a single static placeholder URL instead of an error.
// NOTE(review): this can hide upstream failures such as blocked generations;
// confirm it is intended.
//
// Every other model type yields a chat.completion object with one choice
// built by geminiChatMessage, plus usage from geminiUsageMap.
func geminiResult(request Request, raw map[string]any) map[string]any {
	if request.ModelType != "image" {
		message, finishReason := geminiChatMessage(raw)
		choice := map[string]any{
			"index":         0,
			"finish_reason": finishReason,
			"message":       message,
		}
		return map[string]any{
			"id":      "gemini-chat",
			"object":  "chat.completion",
			"created": nowUnix(),
			"model":   request.Model,
			"choices": []any{choice},
			"usage":   geminiUsageMap(raw),
			"raw":     raw,
		}
	}

	images := geminiImageData(raw)
	if len(images) == 0 {
		images = []any{map[string]any{"url": "/static/provider/gemini-image-placeholder.png"}}
	}
	return map[string]any{
		"id":      "gemini-image",
		"created": nowUnix(),
		"model":   request.Model,
		"data":    images,
		"raw":     raw,
	}
}
// textFromMessages flattens body["messages"] into one string. A message whose
// content is a string contributes it whole. For array content, only elements
// with a string "text" field are used. The pieces are joined with "\n" and
// the result is trimmed. Messages that are not objects, and any other content
// shapes, are ignored.
func textFromMessages(body map[string]any) string {
	messages, _ := body["messages"].([]any)
	pieces := make([]string, 0, len(messages))

	collect := func(content any) {
		if whole, isString := content.(string); isString {
			pieces = append(pieces, whole)
			return
		}
		blocks, _ := content.([]any)
		for _, block := range blocks {
			fields, _ := block.(map[string]any)
			if text, hasText := fields["text"].(string); hasText {
				pieces = append(pieces, text)
			}
		}
	}

	for _, rawMessage := range messages {
		message, _ := rawMessage.(map[string]any)
		collect(message["content"])
	}
	return strings.TrimSpace(strings.Join(pieces, "\n"))
}
// geminiText returns the assistant text from the first candidate in raw. It
// returns "" when there is no text, including when the message carries only
// tool calls (geminiChatMessage sets its content to nil then).
func geminiText(raw map[string]any) string {
	message, _ := geminiChatMessage(raw)
	if text, isString := message["content"].(string); isString {
		return text
	}
	return ""
}
// geminiChatMessage builds an OpenAI-style assistant message, plus its
// finish_reason, from the first candidate of a Gemini generateContent
// response. Later candidates are ignored.
//
// Parts are processed in order:
//   - text parts flagged "thought": true (model thoughts in the Gemini Part
//     schema) are concatenated into "reasoning_content" rather than leaking
//     into the visible answer;
//   - other non-empty text parts are concatenated into "content";
//   - functionCall / function_call parts become "tool_calls" entries via
//     normalizeGeminiFunctionCall.
//
// When there are tool calls but no visible text, "content" is nil.
//
// With no candidates, an empty message is returned. Its finish reason is
// "content_filter" when promptFeedback.blockReason is set (the prompt itself
// was blocked), and "stop" otherwise.
func geminiChatMessage(raw map[string]any) (map[string]any, string) {
	candidates, _ := raw["candidates"].([]any)
	for _, candidate := range candidates {
		candidateMap, _ := candidate.(map[string]any)
		content, _ := candidateMap["content"].(map[string]any)
		parts, _ := content["parts"].([]any)
		textParts := make([]string, 0, len(parts))
		var thoughtParts []string
		toolCalls := make([]any, 0)
		for _, part := range parts {
			partMap, _ := part.(map[string]any)
			if text, ok := partMap["text"].(string); ok && text != "" {
				// Thought parts are the model's reasoning, not its answer.
				if thought, _ := partMap["thought"].(bool); thought {
					thoughtParts = append(thoughtParts, text)
				} else {
					textParts = append(textParts, text)
				}
			}
			functionCall := mapFromAny(firstPresent(partMap["functionCall"], partMap["function_call"]))
			if len(functionCall) == 0 {
				continue
			}
			if toolCall := normalizeGeminiFunctionCall(functionCall, len(toolCalls), false); toolCall != nil {
				toolCalls = append(toolCalls, toolCall)
			}
		}
		message := map[string]any{
			"role":    "assistant",
			"content": strings.Join(textParts, ""),
		}
		if len(thoughtParts) > 0 {
			message["reasoning_content"] = strings.Join(thoughtParts, "")
		}
		if len(toolCalls) > 0 {
			message["tool_calls"] = toolCalls
			if len(textParts) == 0 {
				message["content"] = nil
			}
		}
		return message, geminiFinishReason(candidateMap, len(toolCalls) > 0)
	}
	// No candidates: report a blocked prompt as a content-filter stop rather
	// than as a normal, successful empty completion.
	finishReason := "stop"
	feedback, _ := raw["promptFeedback"].(map[string]any)
	if blockReason, _ := feedback["blockReason"].(string); blockReason != "" {
		finishReason = "content_filter"
	}
	return map[string]any{"role": "assistant", "content": ""}, finishReason
}
// geminiFinishReason maps a Gemini candidate's finishReason onto the OpenAI
// finish_reason vocabulary.
//
// A candidate that produced tool calls always reports "tool_calls".
// Otherwise, compared case-insensitively:
//   - MAX_TOKENS becomes "length";
//   - the text and image safety/policy block reasons become "content_filter";
//   - everything else (STOP, missing, unknown) becomes "stop".
//
// finishReason is a string enum in the JSON response, so it is read by
// direct type assertion.
func geminiFinishReason(candidate map[string]any, hasToolCalls bool) string {
	if hasToolCalls {
		return "tool_calls"
	}
	reason, _ := candidate["finishReason"].(string)
	switch strings.ToUpper(reason) {
	case "MAX_TOKENS":
		return "length"
	case "SAFETY", "RECITATION", "BLOCKLIST", "PROHIBITED_CONTENT", "SPII",
		"IMAGE_SAFETY", "IMAGE_PROHIBITED_CONTENT", "IMAGE_RECITATION":
		return "content_filter"
	default:
		return "stop"
	}
}
// geminiImageData collects inline images from every candidate's parts as
// images-API entries of the form {"b64_json": data, "mime_type": mime}.
//
// Both the camelCase (inlineData/mimeType) and snake_case
// (inline_data/mime_type) spellings are accepted. The MIME type is read from
// whichever spelling is present, so snake_case parts no longer lose it.
// Parts without non-empty string data are skipped. The result is never nil.
func geminiImageData(raw map[string]any) []any {
	candidates, _ := raw["candidates"].([]any)
	out := []any{}
	for _, candidate := range candidates {
		candidateMap, _ := candidate.(map[string]any)
		content, _ := candidateMap["content"].(map[string]any)
		parts, _ := content["parts"].([]any)
		for _, part := range parts {
			partMap, _ := part.(map[string]any)
			inline, _ := partMap["inlineData"].(map[string]any)
			if inline == nil {
				inline, _ = partMap["inline_data"].(map[string]any)
			}
			data, _ := inline["data"].(string)
			if data == "" {
				continue
			}
			mimeType := inline["mimeType"]
			if mimeType == nil {
				mimeType = inline["mime_type"]
			}
			out = append(out, map[string]any{"b64_json": data, "mime_type": mimeType})
		}
	}
	return out
}
// geminiUsage converts the usage counters derived by geminiUsageMap
// (prompt/completion/total tokens) into the gateway's Usage struct.
func geminiUsage(raw map[string]any) Usage {
	counts := geminiUsageMap(raw)
	return Usage{
		InputTokens:  intFromAny(counts["prompt_tokens"]),
		OutputTokens: intFromAny(counts["completion_tokens"]),
		TotalTokens:  intFromAny(counts["total_tokens"]),
	}
}
// geminiUsageMap converts Gemini usageMetadata into OpenAI-style usage.
//
// prompt_tokens is promptTokenCount. completion_tokens is
// candidatesTokenCount plus thoughtsTokenCount: Gemini reports thinking
// tokens separately from candidate tokens, but they are generated (and
// billed) as output, so omitting them under-reports output usage for
// thinking models. When thinking tokens are present they are also exposed as
// completion_tokens_details.reasoning_tokens. total_tokens is
// totalTokenCount, or prompt+completion when that is absent or zero.
func geminiUsageMap(raw map[string]any) map[string]any {
	meta, _ := raw["usageMetadata"].(map[string]any)
	input := intFromAny(meta["promptTokenCount"])
	reasoning := intFromAny(meta["thoughtsTokenCount"])
	output := intFromAny(meta["candidatesTokenCount"]) + reasoning
	total := intFromAny(meta["totalTokenCount"])
	if total == 0 {
		total = input + output
	}
	usage := map[string]any{"prompt_tokens": input, "completion_tokens": output, "total_tokens": total}
	if reasoning > 0 {
		usage["completion_tokens_details"] = map[string]any{"reasoning_tokens": reasoning}
	}
	return usage
}