121 lines
3.8 KiB
Go
121 lines
3.8 KiB
Go
package clients
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
|
)
|
|
|
|
func TestOpenAIClientChatContract(t *testing.T) {
|
|
var gotPath string
|
|
var gotAuth string
|
|
var gotModel string
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
gotPath = r.URL.Path
|
|
gotAuth = r.Header.Get("Authorization")
|
|
var body map[string]any
|
|
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
|
t.Fatalf("decode request: %v", err)
|
|
}
|
|
gotModel, _ = body["model"].(string)
|
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
|
"id": "chatcmpl-test",
|
|
"object": "chat.completion",
|
|
"model": gotModel,
|
|
"choices": []any{map[string]any{
|
|
"message": map[string]any{"role": "assistant", "content": "ok"},
|
|
}},
|
|
"usage": map[string]any{"prompt_tokens": 3, "completion_tokens": 2, "total_tokens": 5},
|
|
})
|
|
}))
|
|
defer server.Close()
|
|
|
|
response, err := (OpenAIClient{HTTPClient: server.Client()}).Run(context.Background(), Request{
|
|
Kind: "chat.completions",
|
|
Model: "openai:gpt-4o-mini",
|
|
Body: map[string]any{"model": "openai:gpt-4o-mini", "messages": []any{map[string]any{"role": "user", "content": "ping"}}},
|
|
Candidate: store.RuntimeModelCandidate{
|
|
BaseURL: server.URL,
|
|
ModelName: "gpt-4o-mini",
|
|
Credentials: map[string]any{"apiKey": "test-key"},
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("run openai client: %v", err)
|
|
}
|
|
if gotPath != "/chat/completions" || gotAuth != "Bearer test-key" || gotModel != "gpt-4o-mini" {
|
|
t.Fatalf("unexpected request path=%s auth=%s model=%s", gotPath, gotAuth, gotModel)
|
|
}
|
|
if response.Usage.TotalTokens != 5 || response.Result["id"] != "chatcmpl-test" {
|
|
t.Fatalf("unexpected response: %+v", response)
|
|
}
|
|
}
|
|
|
|
func TestGeminiClientChatContract(t *testing.T) {
|
|
var gotPath string
|
|
var gotKey string
|
|
var gotText string
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
gotPath = r.URL.Path
|
|
gotKey = r.URL.Query().Get("key")
|
|
var body map[string]any
|
|
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
|
t.Fatalf("decode request: %v", err)
|
|
}
|
|
contents, _ := body["contents"].([]any)
|
|
first, _ := contents[0].(map[string]any)
|
|
parts, _ := first["parts"].([]any)
|
|
part, _ := parts[0].(map[string]any)
|
|
gotText, _ = part["text"].(string)
|
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
|
"candidates": []any{map[string]any{
|
|
"content": map[string]any{
|
|
"parts": []any{map[string]any{"text": "gemini ok"}},
|
|
},
|
|
}},
|
|
"usageMetadata": map[string]any{
|
|
"promptTokenCount": 4,
|
|
"candidatesTokenCount": 6,
|
|
"totalTokenCount": 10,
|
|
},
|
|
})
|
|
}))
|
|
defer server.Close()
|
|
|
|
response, err := (GeminiClient{HTTPClient: server.Client()}).Run(context.Background(), Request{
|
|
Kind: "chat.completions",
|
|
Model: "gemini:gemini-2.5-flash",
|
|
Body: map[string]any{
|
|
"model": "gemini:gemini-2.5-flash",
|
|
"messages": []any{map[string]any{"role": "user", "content": "ping"}},
|
|
},
|
|
Candidate: store.RuntimeModelCandidate{
|
|
BaseURL: server.URL,
|
|
ModelName: "gemini-2.5-flash",
|
|
ModelType: "chat",
|
|
Credentials: map[string]any{"apiKey": "gemini-key"},
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("run gemini client: %v", err)
|
|
}
|
|
if gotPath != "/v1beta/models/gemini-2.5-flash:generateContent" || gotKey != "gemini-key" || gotText != "ping" {
|
|
t.Fatalf("unexpected request path=%s key=%s text=%s", gotPath, gotKey, gotText)
|
|
}
|
|
if response.Usage.TotalTokens != 10 || extractText(response.Result) != "gemini ok" {
|
|
t.Fatalf("unexpected response: %+v", response)
|
|
}
|
|
}
|
|
|
|
// extractText returns result["choices"][0]["message"]["content"] from an
// OpenAI-style chat completion result.
//
// It returns "" when any step of that path is missing or has an unexpected
// type, including a missing or empty "choices" list (which previously caused
// an index-out-of-range panic).
func extractText(result map[string]any) string {
	choices, _ := result["choices"].([]any)
	if len(choices) == 0 {
		return ""
	}
	// Failed type assertions yield nil maps, and reading a nil map is safe,
	// so the remaining steps fall through to "" without extra checks.
	choice, _ := choices[0].(map[string]any)
	message, _ := choice["message"].(map[string]any)
	text, _ := message["content"].(string)
	return text
}
|