easyai-ai-gateway/apps/api/internal/httpapi/chat_completions_mode_test.go
chensipeng ae197a742f docs(api): 同步 /api/v1/chat/completions 的 OpenAPI 与同步响应
补充 Chat Completions 的兼容响应模型与路由注释,确保 /api/v1/chat/completions 按同步兼容格式返回并更新对应测试与 Swagger 文档。
2026-05-16 00:19:39 +08:00

115 lines
4.2 KiB
Go

package httpapi
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/runner"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
// TestPlanTaskResponseTreatsAPIV1ChatCompletionsAsSynchronousCompatibleResponse
// verifies that planning a "chat.completions" task on /api/v1/chat/completions
// ignores the X-Async header, selects the OpenAI-compatible response mode, and
// honours stream=true from the request payload by selecting SSE streaming.
//
// The three plan flags are checked independently, so failures are reported
// with t.Error rather than t.Fatal: one run surfaces every mismatched flag.
func TestPlanTaskResponseTreatsAPIV1ChatCompletionsAsSynchronousCompatibleResponse(t *testing.T) {
	req := httptest.NewRequest(http.MethodPost, "/api/v1/chat/completions", nil)
	// X-Async is set deliberately: this route is expected to ignore it.
	req.Header.Set("X-Async", "true")
	plan := planTaskResponse("chat.completions", false, map[string]any{"stream": true}, req)
	if plan.asyncMode {
		t.Error("/api/v1/chat/completions must not enter async task mode")
	}
	if !plan.compatibleMode {
		t.Error("/api/v1/chat/completions should return OpenAI-compatible response payloads")
	}
	if !plan.streamMode {
		t.Error("stream=true should select SSE streaming mode")
	}
}
// TestPlanTaskResponseKeepsAsyncTaskModeForOtherAPIV1Tasks verifies that the
// synchronous special case is limited to chat completions. An
// "images.generations" task on another /api/v1 endpoint with X-Async: true
// must still be planned in async task mode, without the OpenAI-compatible
// response mode.
//
// Both flags are checked independently, so failures use t.Error to report
// every mismatch in a single run.
func TestPlanTaskResponseKeepsAsyncTaskModeForOtherAPIV1Tasks(t *testing.T) {
	req := httptest.NewRequest(http.MethodPost, "/api/v1/images/generations", nil)
	req.Header.Set("X-Async", "true")
	plan := planTaskResponse("images.generations", false, map[string]any{"stream": true}, req)
	if !plan.asyncMode {
		t.Error("non-chat /api/v1 task endpoints should keep X-Async task mode")
	}
	if plan.compatibleMode {
		t.Error("non-compatible /api/v1 task endpoints should not return OpenAI-compatible payloads")
	}
}
// TestWriteCompatibleTaskResponseReturnsJSONWhenStreamIsFalse exercises the
// non-streaming compatible path. With stream=false, writeCompatibleTaskResponse
// must reply 200, call the executor's Execute exactly once and never
// ExecuteStream, and write a JSON body whose "object" field matches the
// executor output.
func TestWriteCompatibleTaskResponseReturnsJSONWhenStreamIsFalse(t *testing.T) {
	fake := &fakeTaskExecutor{output: map[string]any{"id": "chatcmpl-test", "object": "chat.completion"}}
	r := httptest.NewRequest(http.MethodPost, "/api/v1/chat/completions", nil)
	rec := httptest.NewRecorder()
	task := store.GatewayTask{ID: "task-test"}

	writeCompatibleTaskResponse(context.Background(), rec, r, fake, "chat.completions", "gpt-test", task, &auth.User{}, false)

	if got, want := rec.Code, http.StatusOK; got != want {
		t.Fatalf("status=%d want=%d body=%s", got, want, rec.Body.String())
	}
	if execs, streams := fake.executeCalls, fake.streamCalls; execs != 1 || streams != 0 {
		t.Fatalf("expected non-stream execute only, got execute=%d stream=%d", execs, streams)
	}

	var decoded map[string]any
	err := json.Unmarshal(rec.Body.Bytes(), &decoded)
	if err != nil {
		t.Fatalf("decode response body: %v body=%s", err, rec.Body.String())
	}
	if obj := decoded["object"]; obj != "chat.completion" {
		t.Fatalf("unexpected compatible JSON response: %+v", decoded)
	}
}
// TestWriteCompatibleTaskResponseReturnsSSEWhenStreamIsTrue exercises the
// streaming compatible path. With stream=true, writeCompatibleTaskResponse
// must:
//   - call ExecuteStream exactly once and never Execute;
//   - reply 200 with Content-Type text/event-stream;
//   - emit a body containing an "event: message" line, each fake delta as a
//     "content" fragment, and a "stop" finish_reason.
func TestWriteCompatibleTaskResponseReturnsSSEWhenStreamIsTrue(t *testing.T) {
	executor := &fakeTaskExecutor{
		deltas: []string{"hel", "lo"},
		output: map[string]any{"id": "chatcmpl-test", "object": "chat.completion"},
	}
	req := httptest.NewRequest(http.MethodPost, "/api/v1/chat/completions", nil)
	recorder := httptest.NewRecorder()
	writeCompatibleTaskResponse(context.Background(), recorder, req, executor, "chat.completions", "gpt-test", store.GatewayTask{ID: "task-test"}, &auth.User{}, true)
	if executor.executeCalls != 0 || executor.streamCalls != 1 {
		t.Fatalf("expected stream execute only, got execute=%d stream=%d", executor.executeCalls, executor.streamCalls)
	}
	// Match the JSON-path test: a streamed response must still be a 200, and
	// a non-200 reply with a plausible body would otherwise go unnoticed.
	if recorder.Code != http.StatusOK {
		t.Fatalf("status=%d want=%d body=%s", recorder.Code, http.StatusOK, recorder.Body.String())
	}
	if contentType := recorder.Header().Get("Content-Type"); contentType != "text/event-stream" {
		t.Fatalf("Content-Type=%q want text/event-stream", contentType)
	}
	body := recorder.Body.String()
	for _, want := range []string{"event: message", `"content":"hel"`, `"content":"lo"`, `"finish_reason":"stop"`} {
		if !strings.Contains(body, want) {
			t.Fatalf("SSE body missing %s: %s", want, body)
		}
	}
}
// fakeTaskExecutor is a test double passed as the executor to
// writeCompatibleTaskResponse. It counts calls to each entry point, replays
// the configured deltas through the stream callback, and returns output as
// the result's Output.
type fakeTaskExecutor struct {
	executeCalls int            // number of Execute invocations
	streamCalls  int            // number of ExecuteStream invocations
	deltas       []string       // chunks handed to onDelta, in order, by ExecuteStream
	output       map[string]any // Output of the runner.Result returned on success
}

// Execute records the call and returns the canned output with a nil error.
// All arguments are ignored.
func (e *fakeTaskExecutor) Execute(_ context.Context, _ store.GatewayTask, _ *auth.User) (runner.Result, error) {
	e.executeCalls++
	res := runner.Result{Output: e.output}
	return res, nil
}

// ExecuteStream records the call, then feeds each configured delta to onDelta
// in order. The first callback error stops the replay and is returned with a
// zero Result. Otherwise the canned output is returned with a nil error.
func (e *fakeTaskExecutor) ExecuteStream(_ context.Context, _ store.GatewayTask, _ *auth.User, onDelta clients.StreamDelta) (runner.Result, error) {
	e.streamCalls++
	chunks := e.deltas
	for i := range chunks {
		if err := onDelta(chunks[i]); err != nil {
			return runner.Result{}, err
		}
	}
	res := runner.Result{Output: e.output}
	return res, nil
}