easyai-ai-gateway/apps/api/internal/clients/clients_test.go

320 lines
11 KiB
Go

package clients
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
// TestOpenAIClientChatContract verifies the request OpenAIClient.Run sends for
// a non-streaming chat completion and the response fields it captures.
//
// The stub upstream records the request path, the Authorization header and the
// "model" field of the JSON body, then replies with a minimal chat.completion
// payload and an X-Request-Id header. The test asserts that:
//   - the request targets /chat/completions with "Bearer test-key" auth,
//   - the upstream body carries the candidate's ModelName ("gpt-4o-mini")
//     rather than the requested "openai:gpt-4o-mini",
//   - usage, result id, request id and both response timestamps are populated.
func TestOpenAIClientChatContract(t *testing.T) {
	var gotPath, gotAuth, gotModel string
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotPath = r.URL.Path
		gotAuth = r.Header.Get("Authorization")
		w.Header().Set("X-Request-Id", "req-chat-test")
		var body map[string]any
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			// The handler runs on a server goroutine, where t.Fatalf must not
			// be called; record the failure and answer with an error status.
			t.Errorf("decode request: %v", err)
			http.Error(w, "invalid request body", http.StatusBadRequest)
			return
		}
		gotModel, _ = body["model"].(string)
		reply := map[string]any{
			"id":     "chatcmpl-test",
			"object": "chat.completion",
			"model":  gotModel,
			"choices": []any{map[string]any{
				"message": map[string]any{"role": "assistant", "content": "ok"},
			}},
			"usage": map[string]any{"prompt_tokens": 3, "completion_tokens": 2, "total_tokens": 5},
		}
		if err := json.NewEncoder(w).Encode(reply); err != nil {
			t.Errorf("encode response: %v", err)
		}
	}))
	defer server.Close()

	response, err := (OpenAIClient{HTTPClient: server.Client()}).Run(context.Background(), Request{
		Kind:  "chat.completions",
		Model: "openai:gpt-4o-mini",
		Body:  map[string]any{"model": "openai:gpt-4o-mini", "messages": []any{map[string]any{"role": "user", "content": "ping"}}},
		Candidate: store.RuntimeModelCandidate{
			BaseURL:     server.URL,
			ModelName:   "gpt-4o-mini",
			Credentials: map[string]any{"apiKey": "test-key"},
		},
	})
	if err != nil {
		t.Fatalf("run openai client: %v", err)
	}
	if gotPath != "/chat/completions" || gotAuth != "Bearer test-key" || gotModel != "gpt-4o-mini" {
		t.Fatalf("unexpected request path=%s auth=%s model=%s", gotPath, gotAuth, gotModel)
	}
	if response.Usage.TotalTokens != 5 || response.Result["id"] != "chatcmpl-test" {
		t.Fatalf("unexpected response: %+v", response)
	}
	if response.RequestID != "req-chat-test" || response.ResponseStartedAt.IsZero() || response.ResponseFinishedAt.IsZero() {
		t.Fatalf("response metadata was not captured: %+v", response)
	}
}
// TestOpenAIClientChatStreamContract verifies that OpenAIClient.Run forwards
// "stream": true to the upstream and turns a server-sent-event reply into a
// single aggregated response.
//
// The stub upstream emits two chat.completion.chunk events ("hello" and
// " world", the second carrying usage) followed by the [DONE] sentinel. The
// test asserts that the upstream saw stream=true, that usage.total_tokens is 3,
// and that the first choice's message content is "hello world".
func TestOpenAIClientChatStreamContract(t *testing.T) {
	var gotStream bool
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var body map[string]any
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			// The handler runs on a server goroutine, where t.Fatalf must not
			// be called; record the failure and answer with an error status.
			t.Errorf("decode request: %v", err)
			http.Error(w, "invalid request body", http.StatusBadRequest)
			return
		}
		gotStream, _ = body["stream"].(bool)
		w.Header().Set("Content-Type", "text/event-stream")
		events := []string{
			"data: {\"id\":\"chatcmpl-stream\",\"object\":\"chat.completion.chunk\",\"model\":\"deepseek-v4-flash\",\"choices\":[{\"delta\":{\"content\":\"hello\"}}]}\n\n",
			"data: {\"id\":\"chatcmpl-stream\",\"object\":\"chat.completion.chunk\",\"model\":\"deepseek-v4-flash\",\"choices\":[{\"delta\":{\"content\":\" world\"},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":1,\"completion_tokens\":2,\"total_tokens\":3}}\n\n",
			"data: [DONE]\n\n",
		}
		for _, event := range events {
			if _, err := w.Write([]byte(event)); err != nil {
				t.Errorf("write stream event: %v", err)
				return
			}
		}
	}))
	defer server.Close()

	response, err := (OpenAIClient{HTTPClient: server.Client()}).Run(context.Background(), Request{
		Kind:  "chat.completions",
		Model: "DeepSeek-V4-Flash",
		Body: map[string]any{
			"model":    "DeepSeek-V4-Flash",
			"messages": []any{map[string]any{"role": "user", "content": "ping"}},
			"stream":   true,
		},
		Candidate: store.RuntimeModelCandidate{
			BaseURL:     server.URL,
			ModelName:   "deepseek-v4-flash",
			Credentials: map[string]any{"apiKey": "test-key"},
		},
	})
	if err != nil {
		t.Fatalf("run openai stream client: %v", err)
	}
	if !gotStream {
		t.Fatalf("expected upstream stream request")
	}
	if response.Usage.TotalTokens != 3 {
		t.Fatalf("unexpected usage: %+v", response.Usage)
	}
	choices, _ := response.Result["choices"].([]any)
	// Guard the index so a missing or empty choices list fails the test with a
	// readable message instead of an index-out-of-range panic.
	if len(choices) == 0 {
		t.Fatalf("stream response has no choices: %+v", response.Result)
	}
	choice, _ := choices[0].(map[string]any)
	message, _ := choice["message"].(map[string]any)
	if message["content"] != "hello world" {
		t.Fatalf("unexpected stream response: %+v", response.Result)
	}
}
// TestGeminiClientChatContract verifies how GeminiClient.Run translates a
// chat.completions request for the upstream and how it reads the reply.
//
// The stub upstream records the request path, the "key" query parameter and
// the text of the first part of the first "contents" entry, then replies with a
// candidates/usageMetadata payload. The test asserts that:
//   - the request targets /v1beta/models/gemini-2.5-flash:generateContent,
//   - the API key travels in the "key" query parameter,
//   - the user message "ping" arrives as contents[0].parts[0].text,
//   - the response exposes totalTokenCount as Usage.TotalTokens and the
//     candidate text through the choices/message/content shape that
//     extractText reads.
func TestGeminiClientChatContract(t *testing.T) {
	var gotPath, gotKey, gotText string
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotPath = r.URL.Path
		gotKey = r.URL.Query().Get("key")
		var body map[string]any
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			// The handler runs on a server goroutine, where t.Fatalf must not
			// be called; record the failure and answer with an error status.
			t.Errorf("decode request: %v", err)
			http.Error(w, "invalid request body", http.StatusBadRequest)
			return
		}
		contents, _ := body["contents"].([]any)
		if len(contents) == 0 {
			t.Errorf("request has no contents: %+v", body)
			http.Error(w, "missing contents", http.StatusBadRequest)
			return
		}
		first, _ := contents[0].(map[string]any)
		parts, _ := first["parts"].([]any)
		if len(parts) == 0 {
			t.Errorf("first content entry has no parts: %+v", first)
			http.Error(w, "missing parts", http.StatusBadRequest)
			return
		}
		part, _ := parts[0].(map[string]any)
		gotText, _ = part["text"].(string)
		reply := map[string]any{
			"candidates": []any{map[string]any{
				"content": map[string]any{
					"parts": []any{map[string]any{"text": "gemini ok"}},
				},
			}},
			"usageMetadata": map[string]any{
				"promptTokenCount":     4,
				"candidatesTokenCount": 6,
				"totalTokenCount":      10,
			},
		}
		if err := json.NewEncoder(w).Encode(reply); err != nil {
			t.Errorf("encode response: %v", err)
		}
	}))
	defer server.Close()

	response, err := (GeminiClient{HTTPClient: server.Client()}).Run(context.Background(), Request{
		Kind:  "chat.completions",
		Model: "gemini:gemini-2.5-flash",
		Body: map[string]any{
			"model":    "gemini:gemini-2.5-flash",
			"messages": []any{map[string]any{"role": "user", "content": "ping"}},
		},
		Candidate: store.RuntimeModelCandidate{
			BaseURL:     server.URL,
			ModelName:   "gemini-2.5-flash",
			ModelType:   "chat",
			Credentials: map[string]any{"apiKey": "gemini-key"},
		},
	})
	if err != nil {
		t.Fatalf("run gemini client: %v", err)
	}
	if gotPath != "/v1beta/models/gemini-2.5-flash:generateContent" || gotKey != "gemini-key" || gotText != "ping" {
		t.Fatalf("unexpected request path=%s key=%s text=%s", gotPath, gotKey, gotText)
	}
	if response.Usage.TotalTokens != 10 || extractText(response.Result) != "gemini ok" {
		t.Fatalf("unexpected response: %+v", response)
	}
}
// TestGeminiURLAcceptsVersionedBaseURL checks that geminiURL, given a base URL
// that already ends in the /v1beta segment, produces a generateContent endpoint
// in which that segment appears once and the API key is passed as the "key"
// query parameter.
func TestGeminiURLAcceptsVersionedBaseURL(t *testing.T) {
	const want = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent?key=test-key"

	if got := geminiURL("https://generativelanguage.googleapis.com/v1beta", "gemini-2.5-flash", "test-key"); got != want {
		t.Fatalf("unexpected gemini url: %s", got)
	}
}
// TestVolcesClientImageEditUsesGenerationEndpoint verifies that VolcesClient.Run
// sends an images.edits request to the image generation endpoint.
//
// The stub upstream records the path, the Authorization header and the
// "model", "image" and "sequential_image_generation" body fields. The test
// asserts that:
//   - the request targets /images/generations with "Bearer volces-key" auth,
//   - the upstream body carries the candidate's ModelName instead of the
//     requested display name, and the source image URL is passed through,
//   - "sequential_image_generation" is "auto" (presumably derived from the
//     candidate's image_edit.output_multiple_images capability; confirm
//     against the client implementation),
//   - the upstream "id" is exposed in response.Result.
func TestVolcesClientImageEditUsesGenerationEndpoint(t *testing.T) {
	var gotPath, gotAuth, gotModel, gotImage, gotSequential string
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotPath = r.URL.Path
		gotAuth = r.Header.Get("Authorization")
		var body map[string]any
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			// The handler runs on a server goroutine, where t.Fatalf must not
			// be called; record the failure and answer with an error status.
			t.Errorf("decode request: %v", err)
			http.Error(w, "invalid request body", http.StatusBadRequest)
			return
		}
		gotModel, _ = body["model"].(string)
		gotImage, _ = body["image"].(string)
		gotSequential, _ = body["sequential_image_generation"].(string)
		reply := map[string]any{
			"id":      "img-volces-edit",
			"created": 123,
			"data":    []any{map[string]any{"url": "https://example.com/out.png"}},
		}
		if err := json.NewEncoder(w).Encode(reply); err != nil {
			t.Errorf("encode response: %v", err)
		}
	}))
	defer server.Close()

	response, err := (VolcesClient{HTTPClient: server.Client()}).Run(context.Background(), Request{
		Kind:      "images.edits",
		ModelType: "image_edit",
		Model:     "doubao-4.0图像编辑",
		Body: map[string]any{
			"model":  "doubao-4.0图像编辑",
			"prompt": "make it brighter",
			"image":  "https://example.com/source.png",
		},
		Candidate: store.RuntimeModelCandidate{
			BaseURL:     server.URL,
			ModelName:   "doubao-seedream-4-0-250828",
			Credentials: map[string]any{"apiKey": "volces-key"},
			Capabilities: map[string]any{
				"image_edit": map[string]any{"output_multiple_images": true},
			},
		},
	})
	if err != nil {
		t.Fatalf("run volces image edit: %v", err)
	}
	if gotPath != "/images/generations" || gotAuth != "Bearer volces-key" {
		t.Fatalf("unexpected request path=%s auth=%s", gotPath, gotAuth)
	}
	if gotModel != "doubao-seedream-4-0-250828" || gotImage != "https://example.com/source.png" || gotSequential != "auto" {
		t.Fatalf("unexpected body model=%s image=%s sequential=%s", gotModel, gotImage, gotSequential)
	}
	if response.Result["id"] != "img-volces-edit" {
		t.Fatalf("unexpected response: %+v", response.Result)
	}
}
// TestVolcesClientVideoSubmitsAndPollsTask verifies the two-step video flow of
// VolcesClient.Run: a task is submitted with POST /contents/generations/tasks
// and its result is fetched with GET /contents/generations/tasks/<id>.
//
// The stub upstream answers the submit call with task id "cgt-test" and the
// poll call with a succeeded task. The test asserts that:
//   - both endpoints were hit with "Bearer volces-key" auth,
//   - the submitted body carries the candidate's ModelName and does not contain
//     the top-level "prompt" or "first_frame" convenience fields,
//   - content[0].text contains the prompt plus the "--dur 6", "--ratio 16:9",
//     "--watermark false" and "--seed -1" fragments,
//   - content[1].role is "first_frame",
//   - the polled video_url is exposed as response.Result["data"][0]["url"] and
//     usage.total_tokens as response.Usage.TotalTokens.
func TestVolcesClientVideoSubmitsAndPollsTask(t *testing.T) {
	var submitPath, pollPath, gotAuth, gotModel, gotText, gotFirstFrameRole string
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The handler runs on a server goroutine, where t.Fatalf must not be
		// called; failures are recorded with t.Errorf and, where the request
		// cannot be served, answered with an error status.
		gotAuth = r.Header.Get("Authorization")
		switch r.Method + " " + r.URL.Path {
		case "POST /contents/generations/tasks":
			submitPath = r.URL.Path
			var body map[string]any
			if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
				t.Errorf("decode request: %v", err)
				http.Error(w, "invalid request body", http.StatusBadRequest)
				return
			}
			gotModel, _ = body["model"].(string)
			if body["prompt"] != nil || body["first_frame"] != nil {
				t.Errorf("video convenience fields leaked upstream: %+v", body)
			}
			content, _ := body["content"].([]any)
			// Both the text item and the first-frame item are read below.
			if len(content) < 2 {
				t.Errorf("expected text and first-frame content items, got: %+v", body["content"])
				http.Error(w, "missing content items", http.StatusBadRequest)
				return
			}
			textItem, _ := content[0].(map[string]any)
			gotText, _ = textItem["text"].(string)
			frameItem, _ := content[1].(map[string]any)
			gotFirstFrameRole, _ = frameItem["role"].(string)
			if err := json.NewEncoder(w).Encode(map[string]any{"id": "cgt-test"}); err != nil {
				t.Errorf("encode submit response: %v", err)
			}
		case "GET /contents/generations/tasks/cgt-test":
			pollPath = r.URL.Path
			reply := map[string]any{
				"id":         "cgt-test",
				"model":      "doubao-seedance-2-0-260128",
				"status":     "succeeded",
				"created_at": 456,
				"content":    map[string]any{"video_url": "https://example.com/out.mp4"},
				"usage":      map[string]any{"completion_tokens": 7, "total_tokens": 9},
			}
			if err := json.NewEncoder(w).Encode(reply); err != nil {
				t.Errorf("encode poll response: %v", err)
			}
		default:
			t.Errorf("unexpected request %s %s", r.Method, r.URL.Path)
			http.Error(w, "unexpected request", http.StatusNotFound)
		}
	}))
	defer server.Close()

	response, err := (VolcesClient{HTTPClient: server.Client()}).Run(context.Background(), Request{
		Kind:      "videos.generations",
		ModelType: "video_generate",
		Model:     "豆包Seedance-2.0",
		Body: map[string]any{
			"model":        "豆包Seedance-2.0",
			"prompt":       "A clean product reveal",
			"first_frame":  "https://example.com/first.png",
			"duration":     6,
			"aspect_ratio": "16:9",
		},
		Candidate: store.RuntimeModelCandidate{
			BaseURL:     server.URL,
			ModelName:   "doubao-seedance-2-0-260128",
			Credentials: map[string]any{"apiKey": "volces-key"},
			PlatformConfig: map[string]any{
				"volcesPollIntervalMs":     100,
				"volcesPollTimeoutSeconds": 1,
			},
		},
	})
	if err != nil {
		t.Fatalf("run volces video: %v", err)
	}
	if submitPath != "/contents/generations/tasks" || pollPath != "/contents/generations/tasks/cgt-test" || gotAuth != "Bearer volces-key" {
		t.Fatalf("unexpected paths/auth submit=%s poll=%s auth=%s", submitPath, pollPath, gotAuth)
	}
	if gotModel != "doubao-seedance-2-0-260128" || gotFirstFrameRole != "first_frame" {
		t.Fatalf("unexpected submitted model=%s role=%s", gotModel, gotFirstFrameRole)
	}
	for _, fragment := range []string{"A clean product reveal", "--dur 6", "--ratio 16:9", "--watermark false", "--seed -1"} {
		if !strings.Contains(gotText, fragment) {
			t.Fatalf("expected text to contain %q, got %q", fragment, gotText)
		}
	}
	data, _ := response.Result["data"].([]any)
	// Guard the index so a missing or empty data list fails the test with a
	// readable message instead of an index-out-of-range panic.
	if len(data) == 0 {
		t.Fatalf("video response has no data items: %+v", response.Result)
	}
	item, _ := data[0].(map[string]any)
	if item["url"] != "https://example.com/out.mp4" || response.Usage.TotalTokens != 9 {
		t.Fatalf("unexpected response: %+v usage=%+v", response.Result, response.Usage)
	}
}
// extractText returns result["choices"][0]["message"]["content"] as a string.
//
// It returns "" when any step of that path is missing or has an unexpected
// type, including when "choices" is absent or empty, so callers can compare
// the result directly without risking an index-out-of-range panic.
func extractText(result map[string]any) string {
	choices, _ := result["choices"].([]any)
	if len(choices) == 0 {
		return ""
	}
	// Failed assertions yield nil maps, and reading a nil map is safe in Go,
	// so the remaining lookups fall through to the zero value.
	choice, _ := choices[0].(map[string]any)
	message, _ := choice["message"].(map[string]any)
	text, _ := message["content"].(string)
	return text
}