diff --git a/apps/api/internal/clients/clients_test.go b/apps/api/internal/clients/clients_test.go
index a4cce96..da8563c 100644
--- a/apps/api/internal/clients/clients_test.go
+++ b/apps/api/internal/clients/clients_test.go
@@ -149,15 +149,19 @@ func TestOpenAIClientChatContract(t *testing.T) {
 
 func TestOpenAIClientChatStreamContract(t *testing.T) {
 	var gotStream bool
+	var gotIncludeUsage bool
 	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		var body map[string]any
 		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
 			t.Fatalf("decode request: %v", err)
 		}
 		gotStream, _ = body["stream"].(bool)
+		streamOptions, _ := body["stream_options"].(map[string]any)
+		gotIncludeUsage, _ = streamOptions["include_usage"].(bool)
 		w.Header().Set("Content-Type", "text/event-stream")
-		_, _ = w.Write([]byte("data: {\"id\":\"chatcmpl-stream\",\"object\":\"chat.completion.chunk\",\"model\":\"deepseek-v4-flash\",\"choices\":[{\"delta\":{\"content\":\"hello\"}}]}\n\n"))
-		_, _ = w.Write([]byte("data: {\"id\":\"chatcmpl-stream\",\"object\":\"chat.completion.chunk\",\"model\":\"deepseek-v4-flash\",\"choices\":[{\"delta\":{\"content\":\" world\"},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":1,\"completion_tokens\":2,\"total_tokens\":3}}\n\n"))
+		_, _ = w.Write([]byte("data: {\"id\":\"chatcmpl-stream\",\"object\":\"chat.completion.chunk\",\"model\":\"deepseek-v4-flash\",\"choices\":[{\"delta\":{\"content\":\"hello\"}}],\"usage\":null}\n\n"))
+		_, _ = w.Write([]byte("data: {\"id\":\"chatcmpl-stream\",\"object\":\"chat.completion.chunk\",\"model\":\"deepseek-v4-flash\",\"choices\":[{\"delta\":{\"content\":\" world\"},\"finish_reason\":\"stop\"}],\"usage\":null}\n\n"))
+		_, _ = w.Write([]byte("data: {\"id\":\"chatcmpl-stream\",\"object\":\"chat.completion.chunk\",\"model\":\"deepseek-v4-flash\",\"choices\":[],\"usage\":{\"prompt_tokens\":1,\"completion_tokens\":2,\"total_tokens\":3}}\n\n"))
 		_, _ = w.Write([]byte("data: [DONE]\n\n"))
 	}))
 	defer server.Close()
@@ -182,6 +186,9 @@ func TestOpenAIClientChatStreamContract(t *testing.T) {
 	if !gotStream {
 		t.Fatalf("expected upstream stream request")
 	}
+	if !gotIncludeUsage {
+		t.Fatalf("expected upstream stream_options.include_usage=true")
+	}
 	if response.Usage.TotalTokens != 3 {
 		t.Fatalf("unexpected usage: %+v", response.Usage)
 	}
diff --git a/apps/api/internal/clients/openai.go b/apps/api/internal/clients/openai.go
index 8dcd1a8..534f148 100644
--- a/apps/api/internal/clients/openai.go
+++ b/apps/api/internal/clients/openai.go
@@ -25,6 +25,7 @@ func (c OpenAIClient) Run(ctx context.Context, request Request) (Response, error
 	body := cloneBody(request.Body)
 	body["model"] = upstreamModelName(request.Candidate)
 	stream := request.Stream || boolValue(body, "stream")
+	ensureOpenAIStreamUsage(body, request.Kind, stream)
 	raw, _ := json.Marshal(body)
 	req, err := http.NewRequestWithContext(ctx, http.MethodPost, joinURL(request.Candidate.BaseURL, endpoint), bytes.NewReader(raw))
 	if err != nil {
@@ -91,6 +92,20 @@ func cloneBody(body map[string]any) map[string]any {
 	return out
 }
 
+func ensureOpenAIStreamUsage(body map[string]any, kind string, stream bool) {
+	if !stream || kind != "chat.completions" {
+		return
+	}
+	streamOptions := map[string]any{}
+	if existing, ok := body["stream_options"].(map[string]any); ok {
+		for key, value := range existing {
+			streamOptions[key] = value
+		}
+	}
+	streamOptions["include_usage"] = true
+	body["stream_options"] = streamOptions
+}
+
 func joinURL(base string, path string) string {
 	base = strings.TrimRight(strings.TrimSpace(base), "/")
 	if base == "" {