fix(clients): preserve gemini tool call responses
This commit is contained in:
parent
baffccf8f8
commit
69d23efb57
@ -655,6 +655,63 @@ func TestGeminiClientChatRestoresToolContext(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiClientChatConvertsFunctionCallResponse(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"candidates": []any{map[string]any{
|
||||
"finishReason": "STOP",
|
||||
"content": map[string]any{"parts": []any{
|
||||
map[string]any{"functionCall": map[string]any{
|
||||
"name": "get_weather",
|
||||
"args": map[string]any{"city": "SF"},
|
||||
}},
|
||||
}},
|
||||
}},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
response, err := (GeminiClient{HTTPClient: server.Client()}).Run(context.Background(), Request{
|
||||
Kind: "chat.completions",
|
||||
Model: "gemini:gemini-2.5-flash",
|
||||
Body: map[string]any{
|
||||
"model": "gemini:gemini-2.5-flash",
|
||||
"messages": []any{map[string]any{"role": "user", "content": "weather?"}},
|
||||
"tools": []any{map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{"name": "get_weather"},
|
||||
}},
|
||||
},
|
||||
Candidate: store.RuntimeModelCandidate{
|
||||
BaseURL: server.URL,
|
||||
ProviderModelName: "gemini-2.5-flash",
|
||||
ModelType: "chat",
|
||||
Credentials: map[string]any{"apiKey": "gemini-key"},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("run gemini client: %v", err)
|
||||
}
|
||||
choices, _ := response.Result["choices"].([]any)
|
||||
choice, _ := choices[0].(map[string]any)
|
||||
if choice["finish_reason"] != "tool_calls" {
|
||||
t.Fatalf("Gemini function call should use tool_calls finish reason: %+v", response.Result)
|
||||
}
|
||||
message, _ := choice["message"].(map[string]any)
|
||||
if message["content"] != nil {
|
||||
t.Fatalf("tool-only Gemini response should keep nullable content: %+v", message)
|
||||
}
|
||||
toolCalls, _ := message["tool_calls"].([]any)
|
||||
if len(toolCalls) != 1 {
|
||||
t.Fatalf("Gemini function call was not converted: %+v", message)
|
||||
}
|
||||
toolCall, _ := toolCalls[0].(map[string]any)
|
||||
function, _ := toolCall["function"].(map[string]any)
|
||||
if toolCall["type"] != "function" || toolCall["id"] != "call_0" || function["name"] != "get_weather" || function["arguments"] != `{"city":"SF"}` {
|
||||
t.Fatalf("unexpected Gemini tool call: %+v", toolCall)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiURLAcceptsVersionedBaseURL(t *testing.T) {
|
||||
got := geminiURL("https://generativelanguage.googleapis.com/v1beta", "gemini-2.5-flash", "test-key")
|
||||
want := "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent?key=test-key"
|
||||
|
||||
@ -265,7 +265,7 @@ func geminiResult(request Request, raw map[string]any) map[string]any {
|
||||
"raw": raw,
|
||||
}
|
||||
}
|
||||
content := geminiText(raw)
|
||||
message, finishReason := geminiChatMessage(raw)
|
||||
return map[string]any{
|
||||
"id": "gemini-chat",
|
||||
"object": "chat.completion",
|
||||
@ -273,8 +273,8 @@ func geminiResult(request Request, raw map[string]any) map[string]any {
|
||||
"model": request.Model,
|
||||
"choices": []any{map[string]any{
|
||||
"index": 0,
|
||||
"finish_reason": "stop",
|
||||
"message": map[string]any{"role": "assistant", "content": content},
|
||||
"finish_reason": finishReason,
|
||||
"message": message,
|
||||
}},
|
||||
"usage": geminiUsageMap(raw),
|
||||
"raw": raw,
|
||||
@ -303,19 +303,59 @@ func textFromMessages(body map[string]any) string {
|
||||
}
|
||||
|
||||
func geminiText(raw map[string]any) string {
|
||||
message, _ := geminiChatMessage(raw)
|
||||
content, _ := message["content"].(string)
|
||||
return content
|
||||
}
|
||||
|
||||
func geminiChatMessage(raw map[string]any) (map[string]any, string) {
|
||||
candidates, _ := raw["candidates"].([]any)
|
||||
for _, candidate := range candidates {
|
||||
candidateMap, _ := candidate.(map[string]any)
|
||||
content, _ := candidateMap["content"].(map[string]any)
|
||||
parts, _ := content["parts"].([]any)
|
||||
textParts := make([]string, 0, len(parts))
|
||||
toolCalls := make([]any, 0)
|
||||
for _, part := range parts {
|
||||
partMap, _ := part.(map[string]any)
|
||||
if text, ok := partMap["text"].(string); ok && text != "" {
|
||||
return text
|
||||
textParts = append(textParts, text)
|
||||
}
|
||||
functionCall := mapFromAny(firstPresent(partMap["functionCall"], partMap["function_call"]))
|
||||
if len(functionCall) == 0 {
|
||||
continue
|
||||
}
|
||||
if toolCall := normalizeGeminiFunctionCall(functionCall, len(toolCalls), false); toolCall != nil {
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
}
|
||||
}
|
||||
message := map[string]any{
|
||||
"role": "assistant",
|
||||
"content": strings.Join(textParts, ""),
|
||||
}
|
||||
if len(toolCalls) > 0 {
|
||||
message["tool_calls"] = toolCalls
|
||||
if len(textParts) == 0 {
|
||||
message["content"] = nil
|
||||
}
|
||||
}
|
||||
return message, geminiFinishReason(candidateMap, len(toolCalls) > 0)
|
||||
}
|
||||
return map[string]any{"role": "assistant", "content": ""}, "stop"
|
||||
}
|
||||
|
||||
func geminiFinishReason(candidate map[string]any, hasToolCalls bool) string {
|
||||
if hasToolCalls {
|
||||
return "tool_calls"
|
||||
}
|
||||
switch strings.ToUpper(stringFromAny(candidate["finishReason"])) {
|
||||
case "MAX_TOKENS":
|
||||
return "length"
|
||||
case "SAFETY", "RECITATION", "BLOCKLIST", "PROHIBITED_CONTENT", "SPII":
|
||||
return "content_filter"
|
||||
default:
|
||||
return "stop"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func geminiImageData(raw map[string]any) []any {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user