400 lines
11 KiB
Go
400 lines
11 KiB
Go
package clients
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// GeminiClient issues generateContent requests against a Gemini-style API.
type GeminiClient struct {
	// HTTPClient is handed to httpClient together with the per-request client
	// when Run performs the call; presumably it acts as the fallback when the
	// request does not carry its own client (confirm against httpClient).
	HTTPClient *http.Client
}
|
|
|
|
func (c GeminiClient) Run(ctx context.Context, request Request) (Response, error) {
|
|
apiKey := credential(request.Candidate.Credentials, "apiKey", "api_key", "key", "token")
|
|
if apiKey == "" {
|
|
return Response{}, &ClientError{Code: "missing_credentials", Message: "gemini api key is required", Retryable: false}
|
|
}
|
|
body := geminiBody(request)
|
|
raw, _ := json.Marshal(body)
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, geminiURL(request.Candidate.BaseURL, upstreamModelName(request.Candidate), apiKey), bytes.NewReader(raw))
|
|
if err != nil {
|
|
return Response{}, err
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
responseStartedAt := time.Now()
|
|
resp, err := httpClient(request.HTTPClient, c.HTTPClient).Do(req)
|
|
if err != nil {
|
|
return Response{}, &ClientError{Code: "network", Message: err.Error(), Retryable: true}
|
|
}
|
|
requestID := requestIDFromHTTPResponse(resp)
|
|
result, err := decodeHTTPResponse(resp)
|
|
responseFinishedAt := time.Now()
|
|
if err != nil {
|
|
return Response{}, annotateResponseError(err, requestID, responseStartedAt, responseFinishedAt)
|
|
}
|
|
output := geminiResult(request, result)
|
|
if requestID == "" {
|
|
requestID = requestIDFromResult(output)
|
|
}
|
|
return Response{
|
|
Result: output,
|
|
RequestID: requestID,
|
|
Usage: geminiUsage(result),
|
|
Progress: providerProgress(request),
|
|
ResponseStartedAt: responseStartedAt,
|
|
ResponseFinishedAt: responseFinishedAt,
|
|
ResponseDurationMS: responseDurationMS(responseStartedAt, responseFinishedAt),
|
|
}, nil
|
|
}
|
|
|
|
// geminiURL builds the generateContent endpoint URL for model.
//
// baseURL is stripped of surrounding whitespace and trailing slashes; when the
// result is empty, https://generativelanguage.googleapis.com is used. A trailing
// "/v1beta" segment is removed so the version prefix is not duplicated. The
// model is path-escaped and apiKey is query-escaped into the "key" parameter.
func geminiURL(baseURL string, model string, apiKey string) string {
	const defaultRoot = "https://generativelanguage.googleapis.com"

	root := strings.TrimRight(strings.TrimSpace(baseURL), "/")
	if root == "" {
		root = defaultRoot
	}
	// TrimSuffix leaves the string untouched when the suffix is absent.
	root = strings.TrimSuffix(root, "/v1beta")

	return fmt.Sprintf(
		"%s/v1beta/models/%s:generateContent?key=%s",
		root,
		url.PathEscape(model),
		url.QueryEscape(apiKey),
	)
}
|
|
|
|
func geminiBody(request Request) map[string]any {
|
|
if contents, ok := request.Body["contents"]; ok {
|
|
return map[string]any{"contents": contents}
|
|
}
|
|
prompt := firstNonEmptyPrompt(request.Body, "")
|
|
if prompt != "" {
|
|
return map[string]any{
|
|
"contents": []any{map[string]any{
|
|
"role": "user",
|
|
"parts": []any{map[string]any{"text": prompt}},
|
|
}},
|
|
}
|
|
}
|
|
body := map[string]any{"contents": geminiContentsFromMessages(request.Body)}
|
|
if tools := geminiToolsFromOpenAITools(request.Body["tools"]); len(tools) > 0 {
|
|
body["tools"] = tools
|
|
}
|
|
contents, _ := body["contents"].([]any)
|
|
if len(contents) > 0 {
|
|
return body
|
|
}
|
|
return map[string]any{"contents": []any{map[string]any{
|
|
"role": "user",
|
|
"parts": []any{map[string]any{"text": textFromMessages(request.Body)}},
|
|
}}}
|
|
}
|
|
|
|
func geminiContentsFromMessages(body map[string]any) []any {
|
|
normalized := NormalizeChatCompletionRequestBody(body)
|
|
messages, _ := normalized["messages"].([]any)
|
|
contents := make([]any, 0, len(messages))
|
|
toolNames := map[string]string{}
|
|
for _, rawMessage := range messages {
|
|
message, _ := rawMessage.(map[string]any)
|
|
if len(message) == 0 {
|
|
continue
|
|
}
|
|
role := stringFromAny(message["role"])
|
|
if role == "tool" {
|
|
toolCallID := stringFromAny(message["tool_call_id"])
|
|
name := toolNames[toolCallID]
|
|
if name == "" {
|
|
name = toolCallID
|
|
}
|
|
if name == "" {
|
|
name = "tool"
|
|
}
|
|
contents = append(contents, map[string]any{
|
|
"role": "user",
|
|
"parts": []any{map[string]any{"functionResponse": map[string]any{
|
|
"name": name,
|
|
"response": geminiFunctionResponsePayload(message["content"]),
|
|
}}},
|
|
})
|
|
continue
|
|
}
|
|
parts := geminiTextParts(message["content"])
|
|
if role == "assistant" {
|
|
for _, rawToolCall := range toolCallsSlice(message["tool_calls"]) {
|
|
toolCall, _ := rawToolCall.(map[string]any)
|
|
function, _ := toolCall["function"].(map[string]any)
|
|
name := stringFromAny(function["name"])
|
|
if name == "" {
|
|
continue
|
|
}
|
|
if id := stringFromAny(toolCall["id"]); id != "" {
|
|
toolNames[id] = name
|
|
}
|
|
parts = append(parts, map[string]any{"functionCall": map[string]any{
|
|
"name": name,
|
|
"args": geminiFunctionArgs(function["arguments"]),
|
|
}})
|
|
}
|
|
}
|
|
if len(parts) == 0 {
|
|
continue
|
|
}
|
|
contents = append(contents, map[string]any{
|
|
"role": geminiRole(role),
|
|
"parts": parts,
|
|
})
|
|
}
|
|
return contents
|
|
}
|
|
|
|
// geminiRole maps a chat role to a Gemini role: "assistant" becomes "model",
// and every other value (including the empty string) becomes "user".
func geminiRole(role string) string {
	switch role {
	case "assistant":
		return "model"
	default:
		return "user"
	}
}
|
|
|
|
func geminiTextParts(content any) []any {
|
|
parts := make([]any, 0)
|
|
switch typed := content.(type) {
|
|
case string:
|
|
if strings.TrimSpace(typed) != "" {
|
|
parts = append(parts, map[string]any{"text": typed})
|
|
}
|
|
case []any:
|
|
for _, rawPart := range typed {
|
|
part, _ := rawPart.(map[string]any)
|
|
if text := stringFromAny(firstPresent(part["text"], part["content"])); strings.TrimSpace(text) != "" {
|
|
parts = append(parts, map[string]any{"text": text})
|
|
}
|
|
}
|
|
}
|
|
return parts
|
|
}
|
|
|
|
// toolCallsSlice normalises a tool_calls value to a slice: a []any is returned
// unchanged, a single map is wrapped in a one-element slice, and anything else
// yields nil.
func toolCallsSlice(value any) []any {
	if list, isList := value.([]any); isList {
		return list
	}
	if single, isMap := value.(map[string]any); isMap {
		return []any{single}
	}
	return nil
}
|
|
|
|
// geminiFunctionArgs converts tool-call arguments into the object used for a
// functionCall "args" field. The result is never nil, so it always serialises
// as a JSON object.
//
//   - nil, a nil map, or a blank string yields an empty map;
//   - a map is returned as-is;
//   - a string holding a JSON object is decoded; the JSON literal null decodes
//     to an empty map rather than a nil one;
//   - a string that is not a JSON object is wrapped as {"arguments": text};
//   - any other value is wrapped as {"arguments": value}.
func geminiFunctionArgs(value any) map[string]any {
	switch typed := value.(type) {
	case nil:
		return map[string]any{}
	case map[string]any:
		if typed == nil {
			return map[string]any{}
		}
		return typed
	case string:
		if strings.TrimSpace(typed) == "" {
			return map[string]any{}
		}
		var decoded map[string]any
		if err := json.Unmarshal([]byte(typed), &decoded); err != nil {
			return map[string]any{"arguments": typed}
		}
		// json.Unmarshal accepts "null" and leaves the map nil.
		if decoded == nil {
			return map[string]any{}
		}
		return decoded
	default:
		return map[string]any{"arguments": value}
	}
}
|
|
|
|
// geminiFunctionResponsePayload converts a tool message's content into the
// object used for a functionResponse "response" field. The result is never
// nil, so it always serialises as a JSON object.
//
//   - nil or a nil map yields an empty map;
//   - a map is returned as-is;
//   - a string holding a JSON object is decoded; the JSON literal null decodes
//     to an empty map rather than a nil one;
//   - a string that is not a JSON object (including "") is wrapped as
//     {"content": text};
//   - any other value is wrapped as {"content": value}.
func geminiFunctionResponsePayload(value any) map[string]any {
	switch typed := value.(type) {
	case nil:
		return map[string]any{}
	case map[string]any:
		if typed == nil {
			return map[string]any{}
		}
		return typed
	case string:
		var decoded map[string]any
		if err := json.Unmarshal([]byte(typed), &decoded); err != nil {
			return map[string]any{"content": typed}
		}
		// json.Unmarshal accepts "null" and leaves the map nil.
		if decoded == nil {
			return map[string]any{}
		}
		return decoded
	default:
		return map[string]any{"content": value}
	}
}
|
|
|
|
func geminiToolsFromOpenAITools(value any) []any {
|
|
tools, ok := value.([]any)
|
|
if !ok || len(tools) == 0 {
|
|
return nil
|
|
}
|
|
declarations := make([]any, 0, len(tools))
|
|
for _, rawTool := range tools {
|
|
tool, _ := rawTool.(map[string]any)
|
|
function, _ := tool["function"].(map[string]any)
|
|
name := stringFromAny(function["name"])
|
|
if name == "" {
|
|
continue
|
|
}
|
|
declaration := map[string]any{"name": name}
|
|
if description := stringFromAny(function["description"]); description != "" {
|
|
declaration["description"] = description
|
|
}
|
|
if parameters, ok := function["parameters"]; ok {
|
|
declaration["parameters"] = parameters
|
|
}
|
|
declarations = append(declarations, declaration)
|
|
}
|
|
if len(declarations) == 0 {
|
|
return nil
|
|
}
|
|
return []any{map[string]any{"functionDeclarations": declarations}}
|
|
}
|
|
|
|
func geminiResult(request Request, raw map[string]any) map[string]any {
|
|
if request.ModelType == "image" {
|
|
data := geminiImageData(raw)
|
|
if len(data) == 0 {
|
|
data = []any{map[string]any{"url": "/static/provider/gemini-image-placeholder.png"}}
|
|
}
|
|
return map[string]any{
|
|
"id": "gemini-image",
|
|
"created": nowUnix(),
|
|
"model": request.Model,
|
|
"data": data,
|
|
"raw": raw,
|
|
}
|
|
}
|
|
message, finishReason := geminiChatMessage(raw)
|
|
return map[string]any{
|
|
"id": "gemini-chat",
|
|
"object": "chat.completion",
|
|
"created": nowUnix(),
|
|
"model": request.Model,
|
|
"choices": []any{map[string]any{
|
|
"index": 0,
|
|
"finish_reason": finishReason,
|
|
"message": message,
|
|
}},
|
|
"usage": geminiUsageMap(raw),
|
|
"raw": raw,
|
|
}
|
|
}
|
|
|
|
// textFromMessages gathers the text of every entry in body["messages"] and
// joins the pieces with newlines, trimming whitespace from the joined result.
//
// String content is taken whole; list content contributes the string "text"
// value of each map element. Anything else is ignored.
func textFromMessages(body map[string]any) string {
	entries, _ := body["messages"].([]any)
	collected := make([]string, 0, len(entries))

	for _, entry := range entries {
		msg, _ := entry.(map[string]any)

		if text, isString := msg["content"].(string); isString {
			collected = append(collected, text)
			continue
		}

		segments, _ := msg["content"].([]any)
		for _, segment := range segments {
			fields, _ := segment.(map[string]any)
			if text, isString := fields["text"].(string); isString {
				collected = append(collected, text)
			}
		}
	}

	joined := strings.Join(collected, "\n")
	return strings.TrimSpace(joined)
}
|
|
|
|
func geminiText(raw map[string]any) string {
|
|
message, _ := geminiChatMessage(raw)
|
|
content, _ := message["content"].(string)
|
|
return content
|
|
}
|
|
|
|
func geminiChatMessage(raw map[string]any) (map[string]any, string) {
|
|
candidates, _ := raw["candidates"].([]any)
|
|
for _, candidate := range candidates {
|
|
candidateMap, _ := candidate.(map[string]any)
|
|
content, _ := candidateMap["content"].(map[string]any)
|
|
parts, _ := content["parts"].([]any)
|
|
textParts := make([]string, 0, len(parts))
|
|
toolCalls := make([]any, 0)
|
|
for _, part := range parts {
|
|
partMap, _ := part.(map[string]any)
|
|
if text, ok := partMap["text"].(string); ok && text != "" {
|
|
textParts = append(textParts, text)
|
|
}
|
|
functionCall := mapFromAny(firstPresent(partMap["functionCall"], partMap["function_call"]))
|
|
if len(functionCall) == 0 {
|
|
continue
|
|
}
|
|
if toolCall := normalizeGeminiFunctionCall(functionCall, len(toolCalls), false); toolCall != nil {
|
|
toolCalls = append(toolCalls, toolCall)
|
|
}
|
|
}
|
|
message := map[string]any{
|
|
"role": "assistant",
|
|
"content": strings.Join(textParts, ""),
|
|
}
|
|
if len(toolCalls) > 0 {
|
|
message["tool_calls"] = toolCalls
|
|
if len(textParts) == 0 {
|
|
message["content"] = nil
|
|
}
|
|
}
|
|
return message, geminiFinishReason(candidateMap, len(toolCalls) > 0)
|
|
}
|
|
return map[string]any{"role": "assistant", "content": ""}, "stop"
|
|
}
|
|
|
|
func geminiFinishReason(candidate map[string]any, hasToolCalls bool) string {
|
|
if hasToolCalls {
|
|
return "tool_calls"
|
|
}
|
|
switch strings.ToUpper(stringFromAny(candidate["finishReason"])) {
|
|
case "MAX_TOKENS":
|
|
return "length"
|
|
case "SAFETY", "RECITATION", "BLOCKLIST", "PROHIBITED_CONTENT", "SPII":
|
|
return "content_filter"
|
|
default:
|
|
return "stop"
|
|
}
|
|
}
|
|
|
|
// geminiImageData extracts inline image payloads from every candidate in raw.
//
// Each part with an "inlineData" (or snake_case "inline_data") object that has
// a non-empty string "data" yields {"b64_json": data, "mime_type": type}. The
// MIME type is read from "mimeType", falling back to snake_case "mime_type" so
// that payloads using the snake_case spelling throughout keep their type. The
// result is never nil; it is empty when no image is found.
func geminiImageData(raw map[string]any) []any {
	candidates, _ := raw["candidates"].([]any)
	out := []any{}
	for _, candidate := range candidates {
		candidateMap, _ := candidate.(map[string]any)
		content, _ := candidateMap["content"].(map[string]any)
		parts, _ := content["parts"].([]any)
		for _, part := range parts {
			partMap, _ := part.(map[string]any)
			inline, _ := partMap["inlineData"].(map[string]any)
			if inline == nil {
				inline, _ = partMap["inline_data"].(map[string]any)
			}

			data, ok := inline["data"].(string)
			if !ok || data == "" {
				continue
			}

			mimeType := inline["mimeType"]
			if mimeType == nil {
				mimeType = inline["mime_type"]
			}
			out = append(out, map[string]any{"b64_json": data, "mime_type": mimeType})
		}
	}
	return out
}
|
|
|
|
func geminiUsage(raw map[string]any) Usage {
|
|
usageMap := geminiUsageMap(raw)
|
|
input := intFromAny(usageMap["prompt_tokens"])
|
|
output := intFromAny(usageMap["completion_tokens"])
|
|
total := intFromAny(usageMap["total_tokens"])
|
|
return Usage{InputTokens: input, OutputTokens: output, TotalTokens: total}
|
|
}
|
|
|
|
func geminiUsageMap(raw map[string]any) map[string]any {
|
|
meta, _ := raw["usageMetadata"].(map[string]any)
|
|
input := intFromAny(meta["promptTokenCount"])
|
|
output := intFromAny(meta["candidatesTokenCount"])
|
|
total := intFromAny(meta["totalTokenCount"])
|
|
if total == 0 {
|
|
total = input + output
|
|
}
|
|
return map[string]any{"prompt_tokens": input, "completion_tokens": output, "total_tokens": total}
|
|
}
|