easyai-ai-gateway/apps/api/internal/httpapi/chat_completions_mode_test.go

151 lines
8.5 KiB
Go

package httpapi
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/runner"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
// TestPlanTaskResponseTreatsAPIV1ChatCompletionsAsSynchronousCompatibleResponse
// builds a POST to /api/v1/chat/completions that carries an "X-Async: true"
// header together with a body of {"stream": true}, and verifies that the plan
// returned by planTaskResponse has asyncMode off while compatibleMode and
// streamMode are both on.
func TestPlanTaskResponseTreatsAPIV1ChatCompletionsAsSynchronousCompatibleResponse(t *testing.T) {
	request := httptest.NewRequest(http.MethodPost, "/api/v1/chat/completions", nil)
	request.Header.Set("X-Async", "true")
	payload := map[string]any{"stream": true}

	got := planTaskResponse("chat.completions", false, payload, request)

	// The cases are evaluated top to bottom, so the first violated
	// expectation is the one that gets reported.
	switch {
	case got.asyncMode:
		t.Fatal("/api/v1/chat/completions must not enter async task mode")
	case !got.compatibleMode:
		t.Fatal("/api/v1/chat/completions should return OpenAI-compatible response payloads")
	case !got.streamMode:
		t.Fatal("stream=true should select SSE streaming mode")
	}
}
// TestPlanTaskResponseKeepsAsyncTaskModeForOtherAPIV1Tasks sends the same
// "X-Async: true" header and {"stream": true} body to
// /api/v1/images/generations and verifies that the resulting plan has
// asyncMode on and compatibleMode off.
func TestPlanTaskResponseKeepsAsyncTaskModeForOtherAPIV1Tasks(t *testing.T) {
	request := httptest.NewRequest(http.MethodPost, "/api/v1/images/generations", nil)
	request.Header.Set("X-Async", "true")
	payload := map[string]any{"stream": true}

	got := planTaskResponse("images.generations", false, payload, request)

	// Checked in order; only the first failing expectation is reported.
	switch {
	case !got.asyncMode:
		t.Fatal("non-chat /api/v1 task endpoints should keep X-Async task mode")
	case got.compatibleMode:
		t.Fatal("non-compatible /api/v1 task endpoints should not return OpenAI-compatible payloads")
	}
}
// TestWriteCompatibleTaskResponseReturnsJSONWhenStreamIsFalse calls
// writeCompatibleTaskResponse with both trailing boolean arguments set to
// false and verifies that:
//   - the recorder reports HTTP 200,
//   - the fake executor saw exactly one Execute call and no ExecuteStream call,
//   - the response body decodes as JSON whose "object" is "chat.completion".
func TestWriteCompatibleTaskResponseReturnsJSONWhenStreamIsFalse(t *testing.T) {
	fake := &fakeTaskExecutor{
		output: map[string]any{"id": "chatcmpl-test", "object": "chat.completion"},
	}
	request := httptest.NewRequest(http.MethodPost, "/api/v1/chat/completions", nil)
	rec := httptest.NewRecorder()
	task := store.GatewayTask{ID: "task-test"}

	writeCompatibleTaskResponse(context.Background(), rec, request, fake, "chat.completions", "gpt-test", task, &auth.User{}, false, false)

	// Nothing writes to the recorder after this point, so one snapshot of the
	// body serves every diagnostic message below.
	raw := rec.Body.String()

	if rec.Code != http.StatusOK {
		t.Fatalf("status=%d want=%d body=%s", rec.Code, http.StatusOK, raw)
	}
	if !(fake.executeCalls == 1 && fake.streamCalls == 0) {
		t.Fatalf("expected non-stream execute only, got execute=%d stream=%d", fake.executeCalls, fake.streamCalls)
	}

	var decoded map[string]any
	err := json.Unmarshal([]byte(raw), &decoded)
	if err != nil {
		t.Fatalf("decode response body: %v body=%s", err, raw)
	}
	if decoded["object"] != "chat.completion" {
		t.Fatalf("unexpected compatible JSON response: %+v", decoded)
	}
}
// TestWriteCompatibleTaskResponseReturnsSSEWhenStreamIsTrue calls
// writeCompatibleTaskResponse with both trailing boolean arguments set to true
// while the fake executor emits two text deltas ("hel", "lo") and a final
// output containing usage counters. It verifies that:
//   - only ExecuteStream was invoked,
//   - the Content-Type header is exactly "text/event-stream",
//   - the body contains every expected fragment, ending with "data: [DONE]",
//   - the body never contains an "event: message" line.
func TestWriteCompatibleTaskResponseReturnsSSEWhenStreamIsTrue(t *testing.T) {
	usage := map[string]any{"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}
	fake := &fakeTaskExecutor{
		deltas: []clients.StreamDeltaEvent{{Text: "hel"}, {Text: "lo"}},
		output: map[string]any{"id": "chatcmpl-test", "object": "chat.completion", "usage": usage},
	}
	request := httptest.NewRequest(http.MethodPost, "/api/v1/chat/completions", nil)
	rec := httptest.NewRecorder()
	task := store.GatewayTask{ID: "task-test"}

	writeCompatibleTaskResponse(context.Background(), rec, request, fake, "chat.completions", "gpt-test", task, &auth.User{}, true, true)

	if !(fake.executeCalls == 0 && fake.streamCalls == 1) {
		t.Fatalf("expected stream execute only, got execute=%d stream=%d", fake.executeCalls, fake.streamCalls)
	}

	contentType := rec.Header().Get("Content-Type")
	if contentType != "text/event-stream" {
		t.Fatalf("Content-Type=%q want text/event-stream", contentType)
	}

	stream := rec.Body.String()

	// Fragments that must all appear somewhere in the emitted stream.
	expected := []string{
		`data: {`,
		`"role":"assistant"`,
		`"created":`,
		`"system_fingerprint":`,
		`"content":"hel"`,
		`"content":"lo"`,
		`"finish_reason":"stop"`,
		`"usage":{"completion_tokens":2,"prompt_tokens":1,"total_tokens":3}`,
		"data: [DONE]",
	}
	for i := 0; i < len(expected); i++ {
		if strings.Contains(stream, expected[i]) {
			continue
		}
		t.Fatalf("SSE body missing %s: %s", expected[i], stream)
	}

	if strings.Contains(stream, "event: message") {
		t.Fatalf("chat completions stream should use OpenAI data-only SSE frames: %s", stream)
	}
}
// TestWriteCompatibleTaskResponseStreamsStructuredToolAndReasoningDeltas feeds
// six structured chunk events through the fake executor: reasoning_details, a
// <think>-tagged content string, a legacy functionCall, two tool_calls
// fragments, and a trailing usage-only chunk with no choices. It verifies that:
//   - the assistant role appears before the merged reasoning content,
//   - the body contains every expected converted fragment,
//   - the source-side spellings "reasoning_details", "<think>" and
//     "functionCall" do not appear in the body.
func TestWriteCompatibleTaskResponseStreamsStructuredToolAndReasoningDeltas(t *testing.T) {
	// upstreamChunk builds a single-choice chunk event. The first five events
	// share the same id, object, created, model and system_fingerprint and
	// differ only in their delta and finish_reason.
	upstreamChunk := func(delta map[string]any, finishReason any) map[string]any {
		choice := map[string]any{"index": float64(0), "delta": delta, "finish_reason": finishReason}
		return map[string]any{
			"id":                 "chatcmpl-upstream",
			"object":             "chat.completion.chunk",
			"created":            float64(1710000000),
			"model":              "deepseek-v4",
			"system_fingerprint": "fp-test",
			"choices":            []any{choice},
		}
	}

	reasoningDelta := map[string]any{"reasoning_details": []any{
		map[string]any{"type": "reasoning.text", "text": "detail-"},
		map[string]any{"type": "reasoning.summary", "summary": "summary"},
		map[string]any{"type": "reasoning.encrypted", "data": "secret"},
	}}
	taggedContentDelta := map[string]any{"content": "<think>tagged</think>answer"}
	legacyCallDelta := map[string]any{
		"functionCall": map[string]any{"name": "legacy_lookup", "arguments": "{\"city\":\"Boston\"}"},
	}
	toolCallOpenDelta := map[string]any{"tool_calls": []any{
		map[string]any{"index": float64(0), "id": "call_1", "type": "function", "function": map[string]any{"name": "lookup", "arguments": "{\"q\":"}},
	}}
	toolCallCloseDelta := map[string]any{"tool_calls": []any{
		map[string]any{"index": float64(0), "function": map[string]any{"arguments": "\"weather\"}"}},
	}}
	// The final event has an empty choices list, no system_fingerprint, and
	// carries only the usage counters.
	usageOnlyEvent := map[string]any{
		"id":      "chatcmpl-upstream",
		"object":  "chat.completion.chunk",
		"created": float64(1710000000),
		"model":   "deepseek-v4",
		"choices": []any{},
		"usage":   map[string]any{"prompt_tokens": float64(4), "completion_tokens": float64(5), "total_tokens": float64(9)},
	}

	fake := &fakeTaskExecutor{
		deltas: []clients.StreamDeltaEvent{
			{Event: upstreamChunk(reasoningDelta, nil)},
			{Event: upstreamChunk(taggedContentDelta, nil)},
			{Event: upstreamChunk(legacyCallDelta, nil)},
			{Event: upstreamChunk(toolCallOpenDelta, nil)},
			{Event: upstreamChunk(toolCallCloseDelta, "tool_calls")},
			{Event: usageOnlyEvent},
		},
		output: map[string]any{"id": "chatcmpl-upstream", "object": "chat.completion", "model": "deepseek-v4"},
	}
	request := httptest.NewRequest(http.MethodPost, "/api/v1/chat/completions", nil)
	rec := httptest.NewRecorder()
	task := store.GatewayTask{ID: "task-test"}

	writeCompatibleTaskResponse(context.Background(), rec, request, fake, "chat.completions", "gpt-test", task, &auth.User{}, true, true)

	stream := rec.Body.String()

	// Both markers must be present, and the role must not come after the
	// merged reasoning content.
	rolePos := strings.Index(stream, `"role":"assistant"`)
	reasoningPos := strings.Index(stream, `"reasoning_content":"detail-summary"`)
	if rolePos < 0 || reasoningPos < 0 || rolePos > reasoningPos {
		t.Fatalf("assistant role should be emitted before structured deltas: %s", stream)
	}

	expected := []string{
		`"system_fingerprint":"fp-test"`,
		`"created":1710000000`,
		`"reasoning_content":"tagged"`,
		`"content":"answer"`,
		`"tool_calls":[{"function":{"arguments":"{\"city\":\"Boston\"}","name":"legacy_lookup"}`,
		`"tool_calls":[{"function":{"arguments":"{\"q\":"`,
		`"finish_reason":"tool_calls"`,
		`"choices":[],"created":1710000000`,
		`"usage":{"completion_tokens":5,"prompt_tokens":4,"total_tokens":9}`,
		"data: [DONE]",
	}
	for i := 0; i < len(expected); i++ {
		if strings.Contains(stream, expected[i]) {
			continue
		}
		t.Fatalf("SSE body missing %s: %s", expected[i], stream)
	}

	// None of the source-side field spellings may survive into the output.
	for _, forbidden := range []string{"reasoning_details", "<think>", "functionCall"} {
		if strings.Contains(stream, forbidden) {
			t.Fatalf("provider-specific reasoning/tool fields should be converted away: %s", stream)
		}
	}
}
// fakeTaskExecutor is a test double whose Execute and ExecuteStream methods
// return a canned output and count how often each was called.
type fakeTaskExecutor struct {
	// executeCalls and streamCalls are incremented by Execute and
	// ExecuteStream respectively.
	executeCalls, streamCalls int

	// deltas are passed, in order, to the callback given to ExecuteStream.
	deltas []clients.StreamDeltaEvent

	// output is returned as runner.Result.Output by both methods.
	output map[string]any
}
// Execute counts the call and returns the canned output with a nil error. All
// three arguments are ignored.
func (f *fakeTaskExecutor) Execute(_ context.Context, _ store.GatewayTask, _ *auth.User) (runner.Result, error) {
	f.executeCalls++
	canned := runner.Result{Output: f.output}
	return canned, nil
}
// ExecuteStream counts the call, then hands each configured delta to onDelta
// in order. If onDelta returns an error, delivery stops and that error is
// returned with an empty result; otherwise the canned output is returned with
// a nil error. The context, task and user arguments are ignored.
func (f *fakeTaskExecutor) ExecuteStream(_ context.Context, _ store.GatewayTask, _ *auth.User, onDelta clients.StreamDelta) (runner.Result, error) {
	f.streamCalls++

	pending := f.deltas
	for i := 0; i < len(pending); i++ {
		err := onDelta(pending[i])
		if err != nil {
			return runner.Result{}, err
		}
	}

	canned := runner.Result{Output: f.output}
	return canned, nil
}