补充 Chat Completions 的兼容响应模型与路由注释，确保 /api/v1/chat/completions 按同步兼容格式返回，并更新对应的测试与 Swagger 文档。
115 lines
4.2 KiB
Go
package httpapi
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/runner"
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
|
)
|
|
|
|
func TestPlanTaskResponseTreatsAPIV1ChatCompletionsAsSynchronousCompatibleResponse(t *testing.T) {
|
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/chat/completions", nil)
|
|
req.Header.Set("X-Async", "true")
|
|
|
|
plan := planTaskResponse("chat.completions", false, map[string]any{"stream": true}, req)
|
|
|
|
if plan.asyncMode {
|
|
t.Fatal("/api/v1/chat/completions must not enter async task mode")
|
|
}
|
|
if !plan.compatibleMode {
|
|
t.Fatal("/api/v1/chat/completions should return OpenAI-compatible response payloads")
|
|
}
|
|
if !plan.streamMode {
|
|
t.Fatal("stream=true should select SSE streaming mode")
|
|
}
|
|
}
|
|
|
|
func TestPlanTaskResponseKeepsAsyncTaskModeForOtherAPIV1Tasks(t *testing.T) {
|
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/images/generations", nil)
|
|
req.Header.Set("X-Async", "true")
|
|
|
|
plan := planTaskResponse("images.generations", false, map[string]any{"stream": true}, req)
|
|
|
|
if !plan.asyncMode {
|
|
t.Fatal("non-chat /api/v1 task endpoints should keep X-Async task mode")
|
|
}
|
|
if plan.compatibleMode {
|
|
t.Fatal("non-compatible /api/v1 task endpoints should not return OpenAI-compatible payloads")
|
|
}
|
|
}
|
|
|
|
func TestWriteCompatibleTaskResponseReturnsJSONWhenStreamIsFalse(t *testing.T) {
|
|
executor := &fakeTaskExecutor{output: map[string]any{"id": "chatcmpl-test", "object": "chat.completion"}}
|
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/chat/completions", nil)
|
|
recorder := httptest.NewRecorder()
|
|
|
|
writeCompatibleTaskResponse(context.Background(), recorder, req, executor, "chat.completions", "gpt-test", store.GatewayTask{ID: "task-test"}, &auth.User{}, false)
|
|
|
|
if recorder.Code != http.StatusOK {
|
|
t.Fatalf("status=%d want=%d body=%s", recorder.Code, http.StatusOK, recorder.Body.String())
|
|
}
|
|
if executor.executeCalls != 1 || executor.streamCalls != 0 {
|
|
t.Fatalf("expected non-stream execute only, got execute=%d stream=%d", executor.executeCalls, executor.streamCalls)
|
|
}
|
|
var body map[string]any
|
|
if err := json.Unmarshal(recorder.Body.Bytes(), &body); err != nil {
|
|
t.Fatalf("decode response body: %v body=%s", err, recorder.Body.String())
|
|
}
|
|
if body["object"] != "chat.completion" {
|
|
t.Fatalf("unexpected compatible JSON response: %+v", body)
|
|
}
|
|
}
|
|
|
|
func TestWriteCompatibleTaskResponseReturnsSSEWhenStreamIsTrue(t *testing.T) {
|
|
executor := &fakeTaskExecutor{
|
|
deltas: []string{"hel", "lo"},
|
|
output: map[string]any{"id": "chatcmpl-test", "object": "chat.completion"},
|
|
}
|
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/chat/completions", nil)
|
|
recorder := httptest.NewRecorder()
|
|
|
|
writeCompatibleTaskResponse(context.Background(), recorder, req, executor, "chat.completions", "gpt-test", store.GatewayTask{ID: "task-test"}, &auth.User{}, true)
|
|
|
|
if executor.executeCalls != 0 || executor.streamCalls != 1 {
|
|
t.Fatalf("expected stream execute only, got execute=%d stream=%d", executor.executeCalls, executor.streamCalls)
|
|
}
|
|
if contentType := recorder.Header().Get("Content-Type"); contentType != "text/event-stream" {
|
|
t.Fatalf("Content-Type=%q want text/event-stream", contentType)
|
|
}
|
|
body := recorder.Body.String()
|
|
for _, want := range []string{"event: message", `"content":"hel"`, `"content":"lo"`, `"finish_reason":"stop"`} {
|
|
if !strings.Contains(body, want) {
|
|
t.Fatalf("SSE body missing %s: %s", want, body)
|
|
}
|
|
}
|
|
}
|
|
|
|
// fakeTaskExecutor is a test double handed to writeCompatibleTaskResponse in place of a real
// executor. It counts which execution path was taken, replays canned deltas through the stream
// callback, and reports a canned output map as the result.
type fakeTaskExecutor struct {
	// Number of times Execute and ExecuteStream have been called, respectively.
	executeCalls, streamCalls int

	deltas []string       // chunks passed to the onDelta callback by ExecuteStream, in order
	output map[string]any // Output field of the runner.Result returned on success
}
|
|
|
|
func (f *fakeTaskExecutor) Execute(context.Context, store.GatewayTask, *auth.User) (runner.Result, error) {
|
|
f.executeCalls++
|
|
return runner.Result{Output: f.output}, nil
|
|
}
|
|
|
|
func (f *fakeTaskExecutor) ExecuteStream(_ context.Context, _ store.GatewayTask, _ *auth.User, onDelta clients.StreamDelta) (runner.Result, error) {
|
|
f.streamCalls++
|
|
for _, delta := range f.deltas {
|
|
if err := onDelta(delta); err != nil {
|
|
return runner.Result{}, err
|
|
}
|
|
}
|
|
return runner.Result{Output: f.output}, nil
|
|
}
|