easyai-ai-gateway/apps/api/internal/clients/gemini.go
wangbo 8ad5b06c18 feat(api): 添加多媒体内容支持并优化钱包计费系统
- 在 API 接口定义中为 video_url 和 audio_url 类型添加 mime_type 字段
- 实现 Google Gemini 客户端对视频和音频内容的支持,包括媒体类型检测和数据传输
- 添加 Gemini 客户端测试用例验证多媒体内容转换功能
- 重构 Playground 页面的媒体上传逻辑以支持 MIME 类型传递
- 实现钱包计费预留机制,确保任务执行前余额充足
- 添加钱包冻结余额管理,防止并发操作导致的超扣问题
- 实现计费预留释放逻辑,处理任务失败或取消情况下的资金返还
- 优化数据库事务处理,确保计费操作的原子性和一致性
- 添加数据库集成测试验证迁移脚本执行流程
- 统一 Google Gemini 相关模型提供商标识符映射
2026-05-22 23:46:08 +08:00

535 lines
15 KiB
Go

package clients
import (
"bytes"
"context"
"encoding/json"
"fmt"
"mime"
"net/http"
"net/url"
"path"
"strings"
"time"
)
// GeminiClient sends requests to Google's native Gemini generateContent REST API.
//
// HTTPClient is handed to httpClient as the fallback after request.HTTPClient; presumably a
// nil value there resolves to some default client (TODO confirm in httpClient).
type GeminiClient struct {
	HTTPClient *http.Client
}

// Run sends request to Gemini's generateContent endpoint and converts the reply into an
// OpenAI-style result (a chat completion, or image data when request.ModelType is "image").
//
// The API key is read from the candidate credentials (apiKey, api_key, key or token) and sent
// as the "key" query parameter built by geminiURL. Errors returned:
//   - no key: *ClientError{Code: "missing_credentials", Retryable: false};
//   - the request body cannot be JSON-encoded, or the request cannot be built: a plain error;
//   - transport failure: *ClientError{Code: "network", Retryable: true};
//   - failures reported by decodeHTTPResponse: that error, annotated by annotateResponseError
//     with the upstream request ID and timing.
//
// Because the key is part of the URL and net/http / net/url include the request URL in their
// error text, the key is redacted from every error message produced here.
func (c GeminiClient) Run(ctx context.Context, request Request) (Response, error) {
	apiKey := credential(request.Candidate.Credentials, "apiKey", "api_key", "key", "token")
	if apiKey == "" {
		return Response{}, &ClientError{Code: "missing_credentials", Message: "gemini api key is required", Retryable: false}
	}
	redactKey := func(message string) string {
		message = strings.ReplaceAll(message, url.QueryEscape(apiKey), "[REDACTED]")
		return strings.ReplaceAll(message, apiKey, "[REDACTED]")
	}

	raw, err := json.Marshal(geminiBody(request))
	if err != nil {
		// Never send a silently empty body upstream.
		return Response{}, fmt.Errorf("encode gemini request body: %w", err)
	}
	endpoint := geminiURL(request.Candidate.BaseURL, upstreamModelName(request.Candidate), apiKey)
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(raw))
	if err != nil {
		return Response{}, fmt.Errorf("build gemini request: %s", redactKey(err.Error()))
	}
	req.Header.Set("Content-Type", "application/json")

	responseStartedAt := time.Now()
	resp, err := httpClient(request.HTTPClient, c.HTTPClient).Do(req)
	if err != nil {
		return Response{}, &ClientError{Code: "network", Message: redactKey(err.Error()), Retryable: true}
	}
	// Closing an already-closed net/http response body is harmless, so this only guards
	// against a leak in case decodeHTTPResponse does not close the body itself.
	defer resp.Body.Close()

	requestID := requestIDFromHTTPResponse(resp)
	result, err := decodeHTTPResponse(resp)
	responseFinishedAt := time.Now()
	if err != nil {
		return Response{}, annotateResponseError(err, requestID, responseStartedAt, responseFinishedAt)
	}
	output := geminiResult(request, result)
	if requestID == "" {
		requestID = requestIDFromResult(output)
	}
	return Response{
		Result:             output,
		RequestID:          requestID,
		Usage:              geminiUsage(result),
		Progress:           providerProgress(request),
		ResponseStartedAt:  responseStartedAt,
		ResponseFinishedAt: responseFinishedAt,
		ResponseDurationMS: responseDurationMS(responseStartedAt, responseFinishedAt),
	}, nil
}
// geminiURL builds the native generateContent endpoint for model.
//
// baseURL may be empty, in which case the public Generative Language API host is used. It may
// also carry trailing "/", "/openai" (OpenAI-compatible endpoint) or "/v1beta" suffixes. Those
// are stripped so "/v1beta/models/{model}:generateContent" is appended exactly once.
//
// model is trimmed, and a leading "models/" resource prefix is dropped, so "models/gemini-x"
// and "gemini-x" resolve to the same URL instead of an escaped "models/models%2F..." path.
//
// apiKey is query-escaped into the "key" parameter.
func geminiURL(baseURL string, model string, apiKey string) string {
	base := strings.TrimRight(strings.TrimSpace(baseURL), "/")
	if base == "" {
		base = "https://generativelanguage.googleapis.com"
	}
	base = strings.TrimSuffix(base, "/openai")
	base = strings.TrimSuffix(base, "/v1beta")
	model = strings.TrimPrefix(strings.TrimSpace(model), "models/")
	return fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s", base, url.PathEscape(model), url.QueryEscape(apiKey))
}
// geminiBody builds the JSON request body for generateContent from request.Body.
//
// It takes the first applicable source, in order:
//  1. a caller-supplied native "contents" value, forwarded unchanged (no other keys are copied);
//  2. a non-empty prompt (firstNonEmptyPrompt), sent as a single user text turn;
//  3. OpenAI-style "messages" converted by geminiContentsFromMessages, with OpenAI "tools"
//     translated into Gemini functionDeclarations;
//  4. when that conversion yields no turns, the concatenated message text as one user turn
//     (tools are not attached in this case).
func geminiBody(request Request) map[string]any {
	singleUserTurn := func(text string) map[string]any {
		return map[string]any{"contents": []any{map[string]any{
			"role":  "user",
			"parts": []any{map[string]any{"text": text}},
		}}}
	}

	if native, present := request.Body["contents"]; present {
		return map[string]any{"contents": native}
	}
	if prompt := firstNonEmptyPrompt(request.Body, ""); prompt != "" {
		return singleUserTurn(prompt)
	}

	converted := geminiContentsFromMessages(request.Body)
	if len(converted) == 0 {
		return singleUserTurn(textFromMessages(request.Body))
	}
	out := map[string]any{"contents": converted}
	if tools := geminiToolsFromOpenAITools(request.Body["tools"]); len(tools) > 0 {
		out["tools"] = tools
	}
	return out
}
// geminiContentsFromMessages converts the OpenAI-style chat messages in body (after
// NormalizeChatCompletionRequestBody) into Gemini "contents" turns.
//
//   - "assistant" messages become "model" turns; every other role becomes a "user" turn
//     (see geminiRole).
//   - Assistant "tool_calls" are appended to that turn as functionCall parts. Their ids are
//     remembered so a later "tool" message can be labelled with the function name.
//   - "tool" messages become functionResponse parts. Consecutive tool messages are grouped
//     into a single user turn, because Gemini expects all responses to a (parallel)
//     function-call turn to arrive together, one functionResponse part per functionCall part.
//   - Messages that produce no parts are dropped.
func geminiContentsFromMessages(body map[string]any) []any {
	normalized := NormalizeChatCompletionRequestBody(body)
	messages, _ := normalized["messages"].([]any)
	contents := make([]any, 0, len(messages))
	toolNames := map[string]string{}
	// toolTurn is the functionResponse turn currently being filled. It is cleared whenever a
	// non-tool turn is appended, so only adjacent tool responses are merged.
	var toolTurn map[string]any
	for _, rawMessage := range messages {
		message, _ := rawMessage.(map[string]any)
		if len(message) == 0 {
			continue
		}
		role := stringFromAny(message["role"])
		if role == "tool" {
			toolCallID := stringFromAny(message["tool_call_id"])
			// Prefer the name recorded from the matching assistant tool call, then an explicit
			// name on the tool message, then the call id, then a generic placeholder.
			name := toolNames[toolCallID]
			if name == "" {
				name = stringFromAny(message["name"])
			}
			if name == "" {
				name = toolCallID
			}
			if name == "" {
				name = "tool"
			}
			response := map[string]any{"functionResponse": map[string]any{
				"name":     name,
				"response": geminiFunctionResponsePayload(message["content"]),
			}}
			if toolTurn != nil {
				parts, _ := toolTurn["parts"].([]any)
				toolTurn["parts"] = append(parts, response)
				continue
			}
			toolTurn = map[string]any{"role": "user", "parts": []any{response}}
			contents = append(contents, toolTurn)
			continue
		}
		parts := geminiContentParts(message["content"])
		if role == "assistant" {
			for _, rawToolCall := range toolCallsSlice(message["tool_calls"]) {
				toolCall, _ := rawToolCall.(map[string]any)
				function, _ := toolCall["function"].(map[string]any)
				name := stringFromAny(function["name"])
				if name == "" {
					continue
				}
				if id := stringFromAny(toolCall["id"]); id != "" {
					toolNames[id] = name
				}
				parts = append(parts, map[string]any{"functionCall": map[string]any{
					"name": name,
					"args": geminiFunctionArgs(function["arguments"]),
				}})
			}
		}
		if len(parts) == 0 {
			continue
		}
		toolTurn = nil
		contents = append(contents, map[string]any{
			"role":  geminiRole(role),
			"parts": parts,
		})
	}
	return contents
}
// geminiRole maps an OpenAI chat role onto Gemini's two content roles. Only the exact string
// "assistant" becomes "model"; every other role (user, system, tool, empty, …) becomes "user".
func geminiRole(role string) string {
	switch role {
	case "assistant":
		return "model"
	default:
		return "user"
	}
}
// geminiContentParts converts an OpenAI message "content" value into Gemini parts.
//
// A string content becomes one text part, kept verbatim, unless it is blank. A []any content
// is walked part by part:
//   - image_url, video_url and audio_url parts go through geminiMediaPart;
//   - input_audio parts go through geminiInputAudioPart;
//   - "text" and any unrecognised type fall back to their trimmed text/content field.
//
// Parts that convert to nothing are skipped. Any other content type yields an empty,
// non-nil slice.
func geminiContentParts(content any) []any {
	out := make([]any, 0)
	if whole, isString := content.(string); isString {
		if strings.TrimSpace(whole) != "" {
			out = append(out, map[string]any{"text": whole})
		}
		return out
	}
	items, isList := content.([]any)
	if !isList {
		return out
	}
	for _, item := range items {
		part, _ := item.(map[string]any)
		if len(part) == 0 {
			continue
		}
		var converted map[string]any
		switch partType := stringFromAny(part["type"]); partType {
		case "image_url":
			converted = geminiMediaPart(part, partType, "image")
		case "video_url":
			converted = geminiMediaPart(part, partType, "video")
		case "audio_url":
			converted = geminiMediaPart(part, partType, "audio")
		case "input_audio":
			converted = geminiInputAudioPart(part)
		default:
			if text := strings.TrimSpace(stringFromAny(firstPresent(part["text"], part["content"]))); text != "" {
				converted = map[string]any{"text": text}
			}
		}
		if converted != nil {
			out = append(out, converted)
		}
	}
	return out
}
// geminiMediaPart converts an OpenAI media content part (image_url / video_url / audio_url)
// into a Gemini inlineData or fileData part.
//
// key names the part's media field. The media location is taken from the first non-empty of:
// part[key].url, part.url, or part[key] itself when it is a plain string. An explicit MIME type
// is read from mime_type/mimeType on the nested object first, then on the part. mediaType
// ("image", "video", "audio") picks the MIME fallback. Returns nil when no location is present.
func geminiMediaPart(part map[string]any, key string, mediaType string) map[string]any {
	inner := mapFromAny(part[key])
	source := firstNonEmptyString(inner["url"], part["url"], part[key])
	if source == "" {
		return nil
	}
	return geminiMediaURLPart(
		source,
		firstNonEmptyString(inner["mime_type"], inner["mimeType"], part["mime_type"], part["mimeType"]),
		mediaType,
	)
}
// geminiInputAudioPart converts an OpenAI "input_audio" content part into a Gemini media part.
//
// The payload comes from input_audio.data, falling back to input_audio.url. In the OpenAI chat
// format, data is bare base64 and the container is named by input_audio.format ("wav",
// "mp3", …); such a payload is sent as inlineData. Data URLs and remote URIs are delegated to
// geminiMediaURLPart. An explicit mime_type/mimeType wins over format; "mp3" maps to
// audio/mpeg, a format already containing "/" is used as-is, and anything else becomes
// "audio/<format>". Returns nil when no payload is present.
func geminiInputAudioPart(part map[string]any) map[string]any {
	audio := mapFromAny(part["input_audio"])
	uri := firstNonEmptyString(audio["data"], audio["url"])
	if uri == "" {
		return nil
	}
	mimeType := firstNonEmptyString(audio["mime_type"], audio["mimeType"])
	if mimeType == "" {
		format := strings.ToLower(strings.TrimPrefix(stringFromAny(audio["format"]), "."))
		switch {
		case strings.Contains(format, "/"):
			mimeType = format
		case format == "mp3":
			mimeType = "audio/mpeg"
		case format != "":
			mimeType = "audio/" + format
		}
	}
	// Base64 never contains ':', while every data URL or remote URI (data:, https://, gs://)
	// does. A colon-free data field is therefore the raw OpenAI payload: it must be inlined,
	// not sent as a fileUri.
	if data := firstNonEmptyString(audio["data"]); data != "" && !strings.Contains(data, ":") {
		return map[string]any{"inlineData": map[string]any{
			"mimeType": geminiMediaMime(mimeType, "audio"),
			"data":     data,
		}}
	}
	return geminiMediaURLPart(uri, mimeType, "audio")
}
// geminiMediaURLPart turns a media location into a Gemini part.
//
// A base64 data URL (see geminiDataURL) becomes an inlineData part carrying the payload.
// Anything else becomes a fileData part whose fileUri is uri unchanged. The MIME type is
// explicitMimeType when set; otherwise it comes from the data URL header or the URI's file
// extension. In either case it is normalised for mediaType by geminiMediaMime.
func geminiMediaURLPart(uri string, explicitMimeType string, mediaType string) map[string]any {
	parsed := geminiDataURL(uri)
	if parsed == nil {
		detected := firstNonEmptyString(explicitMimeType, mimeFromURI(uri))
		return map[string]any{"fileData": map[string]any{
			"fileUri":  uri,
			"mimeType": geminiMediaMime(detected, mediaType),
		}}
	}
	declared := firstNonEmptyString(explicitMimeType, parsed.mimeType)
	return map[string]any{"inlineData": map[string]any{
		"mimeType": geminiMediaMime(declared, mediaType),
		"data":     parsed.data,
	}}
}
// geminiParsedDataURL holds the pieces of a base64 "data:" URL needed for an inlineData part.
type geminiParsedDataURL struct {
	mimeType string // media type from the header, or "application/octet-stream" when it names none
	data     string // everything after the first comma, passed through without decoding
}

// geminiDataURL parses value as a base64 data URL ("data:<mime>[;params];base64,<payload>").
//
// It returns nil when value does not start with "data:", has no comma, or its header (the
// text before the first comma) lacks ";base64". Percent-encoded data URLs are therefore not
// recognised. The payload is neither validated nor decoded.
func geminiDataURL(value string) *geminiParsedDataURL {
	header, payload, hasComma := strings.Cut(value, ",")
	isBase64DataURL := hasComma && strings.HasPrefix(header, "data:") && strings.Contains(header, ";base64")
	if !isBase64DataURL {
		return nil
	}
	mediaType, _, _ := strings.Cut(strings.TrimPrefix(header, "data:"), ";")
	if mediaType == "" {
		mediaType = "application/octet-stream"
	}
	return &geminiParsedDataURL{mimeType: mediaType, data: payload}
}
// mimeFromURI guesses a MIME type from the file extension in value's path. The query and
// fragment are ignored when value parses as a URL, and the extension is matched
// case-insensitively. Returns "" when there is no extension or the type is unknown.
//
// mime.TypeByExtension's built-in table covers only a few web types. Audio and video types
// depend on the host's mime.types files, which minimal containers often lack. When it reports
// nothing, a small fixed table of common media extensions is consulted so results do not vary
// with the deployment image.
func mimeFromURI(value string) string {
	pathValue := value
	if parsed, err := url.Parse(value); err == nil && parsed.Path != "" {
		pathValue = parsed.Path
	}
	extension := strings.ToLower(path.Ext(pathValue))
	if extension == "" {
		return ""
	}
	if detected := mime.TypeByExtension(extension); detected != "" {
		return detected
	}
	return fallbackMediaMimeByExtension(extension)
}

// fallbackMediaMimeByExtension maps common audio/video/image extensions (lower-case, with the
// leading dot) to MIME types. Returns "" for anything it does not know.
func fallbackMediaMimeByExtension(extension string) string {
	switch extension {
	case ".mp4", ".m4v":
		return "video/mp4"
	case ".mov":
		return "video/quicktime"
	case ".mpeg":
		return "video/mpeg"
	case ".mpg":
		return "video/mpg"
	case ".avi":
		return "video/x-msvideo"
	case ".flv":
		return "video/x-flv"
	case ".webm":
		return "video/webm"
	case ".wmv":
		return "video/wmv"
	case ".3gp", ".3gpp":
		return "video/3gpp"
	case ".mp3":
		return "audio/mpeg"
	case ".wav":
		return "audio/wav"
	case ".aif", ".aiff":
		return "audio/aiff"
	case ".aac":
		return "audio/aac"
	case ".ogg", ".oga":
		return "audio/ogg"
	case ".flac":
		return "audio/flac"
	case ".m4a":
		return "audio/mp4"
	case ".weba":
		return "audio/webm"
	case ".heic":
		return "image/heic"
	case ".heif":
		return "image/heif"
	default:
		return ""
	}
}
// geminiMediaMime normalises mimeType to a value Gemini accepts for mediaType ("image",
// "video" or "audio"). Parameters after ";" are dropped and the result is lower-cased.
//
// Common non-canonical aliases are mapped to Gemini's names (e.g. image/jpg → image/jpeg,
// video/x-msvideo → video/avi, audio/x-flac → audio/flac). Unsupported or empty values fall
// back to image/png, video/mp4 or audio/mpeg respectively; SVG is treated as unsupported.
// Any other mediaType yields "application/octet-stream".
func geminiMediaMime(mimeType string, mediaType string) string {
	base, _, _ := strings.Cut(mimeType, ";")
	normalized := strings.ToLower(strings.TrimSpace(base))
	switch mediaType {
	case "image":
		switch normalized {
		case "image/jpg", "image/pjpeg":
			return "image/jpeg"
		case "image/svg+xml":
			return "image/png"
		}
		if strings.HasPrefix(normalized, "image/") {
			return normalized
		}
		return "image/png"
	case "video":
		switch normalized {
		case "video/x-msvideo":
			return "video/avi"
		case "video/x-ms-wmv":
			return "video/wmv"
		case "video/quicktime", "video/mov", "video/mpeg", "video/mp4", "video/avi", "video/x-flv", "video/mpg", "video/webm", "video/wmv", "video/3gpp":
			return normalized
		default:
			return "video/mp4"
		}
	case "audio":
		switch normalized {
		case "audio/x-wav", "audio/wave", "audio/vnd.wave":
			return "audio/wav"
		case "audio/x-aiff":
			return "audio/aiff"
		case "audio/x-aac":
			return "audio/aac"
		case "audio/x-flac":
			return "audio/flac"
		case "audio/x-m4a", "audio/m4a":
			return "audio/mp4"
		case "audio/mpeg", "audio/mp3", "audio/wav", "audio/aiff", "audio/aac", "audio/ogg", "audio/flac", "audio/mp4", "audio/webm":
			return normalized
		default:
			return "audio/mpeg"
		}
	default:
		return "application/octet-stream"
	}
}
// toolCallsSlice normalises a "tool_calls" value into a list. A list is returned as-is, a
// single object is wrapped in a one-element list, and anything else yields nil.
func toolCallsSlice(value any) []any {
	if list, isList := value.([]any); isList {
		return list
	}
	if single, isObject := value.(map[string]any); isObject {
		return []any{single}
	}
	return nil
}
// geminiFunctionArgs converts OpenAI tool-call "arguments" into the object Gemini expects in
// functionCall.args. The result is never nil, so it always encodes as a JSON object.
//
//   - nil, a nil map, a blank string, or the JSON literal "null" → {}.
//   - A map is returned as-is.
//   - A string is parsed as a JSON object. If parsing fails, the raw text is wrapped as
//     {"arguments": text}.
//   - Any other value is wrapped as {"arguments": value}.
func geminiFunctionArgs(value any) map[string]any {
	switch typed := value.(type) {
	case nil:
		return map[string]any{}
	case map[string]any:
		if typed == nil {
			return map[string]any{}
		}
		return typed
	case string:
		if strings.TrimSpace(typed) == "" {
			return map[string]any{}
		}
		var args map[string]any
		if err := json.Unmarshal([]byte(typed), &args); err != nil {
			return map[string]any{"arguments": typed}
		}
		if args == nil {
			// The JSON literal null decodes without error into a nil map.
			return map[string]any{}
		}
		return args
	default:
		return map[string]any{"arguments": value}
	}
}
// geminiFunctionResponsePayload converts an OpenAI tool message "content" into the object
// Gemini requires in functionResponse.response. The result is never nil, so it always encodes
// as a JSON object.
//
//   - A map is returned as-is (a nil map becomes {}).
//   - A string is parsed as a JSON object. If parsing fails (plain text, arrays, ""), it is
//     wrapped as {"content": text}. The JSON literal "null" becomes {}.
//   - nil becomes {}.
//   - Any other value is wrapped as {"content": value}.
func geminiFunctionResponsePayload(value any) map[string]any {
	switch typed := value.(type) {
	case nil:
		return map[string]any{}
	case map[string]any:
		if typed == nil {
			return map[string]any{}
		}
		return typed
	case string:
		var payload map[string]any
		if err := json.Unmarshal([]byte(typed), &payload); err != nil {
			return map[string]any{"content": typed}
		}
		if payload == nil {
			// The JSON literal null decodes without error into a nil map.
			return map[string]any{}
		}
		return payload
	default:
		return map[string]any{"content": value}
	}
}
// geminiToolsFromOpenAITools translates an OpenAI "tools" list into Gemini's single
// {"functionDeclarations": [...]} tool entry.
//
// Each entry's function.name is required; entries without one (including non-function tools)
// are skipped. A non-empty description is copied. parameters is forwarded verbatim whenever
// the key is present. NOTE(review): Gemini's schema dialect may reject some JSON-Schema
// keywords that OpenAI accepts; confirm whether sanitising is needed.
// Returns nil when value is not a list or nothing usable remains.
func geminiToolsFromOpenAITools(value any) []any {
	tools, _ := value.([]any)
	var declarations []any
	for _, entry := range tools {
		tool, _ := entry.(map[string]any)
		fn, _ := tool["function"].(map[string]any)
		name := stringFromAny(fn["name"])
		if name == "" {
			continue
		}
		decl := map[string]any{"name": name}
		if desc := stringFromAny(fn["description"]); desc != "" {
			decl["description"] = desc
		}
		if params, present := fn["parameters"]; present {
			decl["parameters"] = params
		}
		declarations = append(declarations, decl)
	}
	if len(declarations) == 0 {
		return nil
	}
	return []any{map[string]any{"functionDeclarations": declarations}}
}
// geminiResult shapes a decoded generateContent response into the gateway's OpenAI-style
// result, keeping the untouched upstream payload under "raw".
//
// For request.ModelType == "image" it returns an images-style object whose "data" holds the
// inline images from geminiImageData. NOTE(review): when no image came back, a static
// placeholder URL is returned instead, so callers cannot tell a failed generation from a real
// image; confirm this is intended.
// Every other model type yields a single-choice "chat.completion" built by geminiChatMessage,
// with usage from geminiUsageMap.
func geminiResult(request Request, raw map[string]any) map[string]any {
	if request.ModelType != "image" {
		message, finishReason := geminiChatMessage(raw)
		choice := map[string]any{
			"index":         0,
			"finish_reason": finishReason,
			"message":       message,
		}
		return map[string]any{
			"id":      "gemini-chat",
			"object":  "chat.completion",
			"created": nowUnix(),
			"model":   request.Model,
			"choices": []any{choice},
			"usage":   geminiUsageMap(raw),
			"raw":     raw,
		}
	}
	images := geminiImageData(raw)
	if len(images) == 0 {
		images = []any{map[string]any{"url": "/static/provider/gemini-image-placeholder.png"}}
	}
	return map[string]any{
		"id":      "gemini-image",
		"created": nowUnix(),
		"model":   request.Model,
		"data":    images,
		"raw":     raw,
	}
}
// textFromMessages flattens body["messages"] into plain text. It collects every string
// content and every string "text" field of list-form content parts, joins them with newlines,
// and trims only the outer whitespace. Non-object messages and non-text parts are ignored.
func textFromMessages(body map[string]any) string {
	messages, _ := body["messages"].([]any)
	collected := make([]string, 0, len(messages))
	for _, rawMessage := range messages {
		message, _ := rawMessage.(map[string]any)
		if whole, isString := message["content"].(string); isString {
			collected = append(collected, whole)
			continue
		}
		pieces, _ := message["content"].([]any)
		for _, rawPiece := range pieces {
			piece, _ := rawPiece.(map[string]any)
			if text, isString := piece["text"].(string); isString {
				collected = append(collected, text)
			}
		}
	}
	return strings.TrimSpace(strings.Join(collected, "\n"))
}
// geminiText returns the assistant text of the first candidate in raw. It returns "" when the
// converted message has non-string content, e.g. a tool-call-only reply where content is nil.
func geminiText(raw map[string]any) string {
	message, _ := geminiChatMessage(raw)
	text, isString := message["content"].(string)
	if !isString {
		return ""
	}
	return text
}
// geminiChatMessage converts the first candidate of a generateContent response into an
// OpenAI assistant message and finish reason.
//
// Non-empty text parts are concatenated without separators. functionCall / function_call
// parts are converted by normalizeGeminiFunctionCall into "tool_calls". When tool calls are
// present and there was no text, content is nil, matching OpenAI's tool-call shape.
// With no candidates it returns an empty assistant message and "stop".
func geminiChatMessage(raw map[string]any) (map[string]any, string) {
	candidates, _ := raw["candidates"].([]any)
	if len(candidates) == 0 {
		return map[string]any{"role": "assistant", "content": ""}, "stop"
	}
	// Only the first candidate is converted; any further candidates are ignored.
	first, _ := candidates[0].(map[string]any)
	content, _ := first["content"].(map[string]any)
	parts, _ := content["parts"].([]any)

	var text strings.Builder
	sawText := false
	toolCalls := make([]any, 0)
	for _, rawPart := range parts {
		part, _ := rawPart.(map[string]any)
		if fragment, isString := part["text"].(string); isString && fragment != "" {
			text.WriteString(fragment)
			sawText = true
		}
		call := mapFromAny(firstPresent(part["functionCall"], part["function_call"]))
		if len(call) == 0 {
			continue
		}
		if normalized := normalizeGeminiFunctionCall(call, len(toolCalls), false); normalized != nil {
			toolCalls = append(toolCalls, normalized)
		}
	}

	hasToolCalls := len(toolCalls) > 0
	message := map[string]any{"role": "assistant", "content": text.String()}
	if hasToolCalls {
		message["tool_calls"] = toolCalls
		if !sawText {
			message["content"] = nil
		}
	}
	return message, geminiFinishReason(first, hasToolCalls)
}
// geminiFinishReason maps a Gemini candidate's finishReason onto OpenAI's finish_reason.
//
// Any tool call wins ("tool_calls"). MAX_TOKENS becomes "length". Safety, recitation and
// blocklist style stops, including the image-generation variants, become "content_filter".
// Everything else (STOP, missing, unknown) becomes "stop". Matching is case-insensitive.
func geminiFinishReason(candidate map[string]any, hasToolCalls bool) string {
	if hasToolCalls {
		return "tool_calls"
	}
	switch strings.ToUpper(stringFromAny(candidate["finishReason"])) {
	case "MAX_TOKENS":
		return "length"
	case "SAFETY", "RECITATION", "BLOCKLIST", "PROHIBITED_CONTENT", "SPII",
		"IMAGE_SAFETY", "IMAGE_PROHIBITED_CONTENT", "IMAGE_RECITATION":
		return "content_filter"
	default:
		return "stop"
	}
}
// geminiImageData collects inline images from every candidate's parts as OpenAI-style
// {"b64_json", "mime_type"} entries.
//
// Both the camelCase (inlineData / mimeType) and snake_case (inline_data / mime_type)
// spellings are accepted; camelCase wins when both are present. Parts without non-empty
// string data are skipped. mime_type is nil when the part names none.
// Always returns a non-nil slice.
func geminiImageData(raw map[string]any) []any {
	candidates, _ := raw["candidates"].([]any)
	out := []any{}
	for _, candidate := range candidates {
		candidateMap, _ := candidate.(map[string]any)
		content, _ := candidateMap["content"].(map[string]any)
		parts, _ := content["parts"].([]any)
		for _, part := range parts {
			partMap, _ := part.(map[string]any)
			inline, _ := partMap["inlineData"].(map[string]any)
			if inline == nil {
				inline, _ = partMap["inline_data"].(map[string]any)
			}
			data, _ := inline["data"].(string)
			if data == "" {
				continue
			}
			mimeType := inline["mimeType"]
			if mimeType == nil {
				mimeType = inline["mime_type"]
			}
			out = append(out, map[string]any{"b64_json": data, "mime_type": mimeType})
		}
	}
	return out
}
// geminiUsage converts the response's usageMetadata into the gateway Usage struct, reusing the
// OpenAI-style counts computed by geminiUsageMap.
func geminiUsage(raw map[string]any) Usage {
	counts := geminiUsageMap(raw)
	return Usage{
		InputTokens:  intFromAny(counts["prompt_tokens"]),
		OutputTokens: intFromAny(counts["completion_tokens"]),
		TotalTokens:  intFromAny(counts["total_tokens"]),
	}
}
// geminiUsageMap converts Gemini usageMetadata into OpenAI-style usage counts.
//
// Field mapping:
//   - prompt_tokens = promptTokenCount.
//   - completion_tokens = candidatesTokenCount + thoughtsTokenCount. Gemini reports thinking
//     tokens separately from candidatesTokenCount, but bills them as output and includes them
//     in totalTokenCount, so leaving them out under-counts billable output. When present they
//     are also exposed as completion_tokens_details.reasoning_tokens, mirroring OpenAI.
//   - total_tokens = totalTokenCount, or prompt + completion when upstream omits it.
func geminiUsageMap(raw map[string]any) map[string]any {
	meta, _ := raw["usageMetadata"].(map[string]any)
	input := intFromAny(meta["promptTokenCount"])
	candidates := intFromAny(meta["candidatesTokenCount"])
	thoughts := intFromAny(meta["thoughtsTokenCount"])
	output := candidates + thoughts
	total := intFromAny(meta["totalTokenCount"])
	if total == 0 {
		total = input + output
	}
	usage := map[string]any{"prompt_tokens": input, "completion_tokens": output, "total_tokens": total}
	if thoughts > 0 {
		usage["completion_tokens_details"] = map[string]any{"reasoning_tokens": thoughts}
	}
	return usage
}