fix: request usage for OpenAI streams
This commit is contained in:
parent
4e54134e2a
commit
483f3ab1f5
@ -149,15 +149,19 @@ func TestOpenAIClientChatContract(t *testing.T) {
|
||||
|
||||
func TestOpenAIClientChatStreamContract(t *testing.T) {
|
||||
var gotStream bool
|
||||
var gotIncludeUsage bool
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var body map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
t.Fatalf("decode request: %v", err)
|
||||
}
|
||||
gotStream, _ = body["stream"].(bool)
|
||||
streamOptions, _ := body["stream_options"].(map[string]any)
|
||||
gotIncludeUsage, _ = streamOptions["include_usage"].(bool)
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("data: {\"id\":\"chatcmpl-stream\",\"object\":\"chat.completion.chunk\",\"model\":\"deepseek-v4-flash\",\"choices\":[{\"delta\":{\"content\":\"hello\"}}]}\n\n"))
|
||||
_, _ = w.Write([]byte("data: {\"id\":\"chatcmpl-stream\",\"object\":\"chat.completion.chunk\",\"model\":\"deepseek-v4-flash\",\"choices\":[{\"delta\":{\"content\":\" world\"},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":1,\"completion_tokens\":2,\"total_tokens\":3}}\n\n"))
|
||||
_, _ = w.Write([]byte("data: {\"id\":\"chatcmpl-stream\",\"object\":\"chat.completion.chunk\",\"model\":\"deepseek-v4-flash\",\"choices\":[{\"delta\":{\"content\":\"hello\"}}],\"usage\":null}\n\n"))
|
||||
_, _ = w.Write([]byte("data: {\"id\":\"chatcmpl-stream\",\"object\":\"chat.completion.chunk\",\"model\":\"deepseek-v4-flash\",\"choices\":[{\"delta\":{\"content\":\" world\"},\"finish_reason\":\"stop\"}],\"usage\":null}\n\n"))
|
||||
_, _ = w.Write([]byte("data: {\"id\":\"chatcmpl-stream\",\"object\":\"chat.completion.chunk\",\"model\":\"deepseek-v4-flash\",\"choices\":[],\"usage\":{\"prompt_tokens\":1,\"completion_tokens\":2,\"total_tokens\":3}}\n\n"))
|
||||
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
||||
}))
|
||||
defer server.Close()
|
||||
@ -182,6 +186,9 @@ func TestOpenAIClientChatStreamContract(t *testing.T) {
|
||||
if !gotStream {
|
||||
t.Fatalf("expected upstream stream request")
|
||||
}
|
||||
if !gotIncludeUsage {
|
||||
t.Fatalf("expected upstream stream_options.include_usage=true")
|
||||
}
|
||||
if response.Usage.TotalTokens != 3 {
|
||||
t.Fatalf("unexpected usage: %+v", response.Usage)
|
||||
}
|
||||
|
||||
@ -25,6 +25,7 @@ func (c OpenAIClient) Run(ctx context.Context, request Request) (Response, error
|
||||
body := cloneBody(request.Body)
|
||||
body["model"] = upstreamModelName(request.Candidate)
|
||||
stream := request.Stream || boolValue(body, "stream")
|
||||
ensureOpenAIStreamUsage(body, request.Kind, stream)
|
||||
raw, _ := json.Marshal(body)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, joinURL(request.Candidate.BaseURL, endpoint), bytes.NewReader(raw))
|
||||
if err != nil {
|
||||
@ -91,6 +92,20 @@ func cloneBody(body map[string]any) map[string]any {
|
||||
return out
|
||||
}
|
||||
|
||||
// ensureOpenAIStreamUsage forces stream_options.include_usage=true on the
// outbound request body so the upstream OpenAI-compatible API reports token
// usage for streamed chat completions. It is a no-op unless the request is
// a streaming "chat.completions" call. Caller-supplied stream_options keys
// are preserved; the existing map is copied rather than mutated in place,
// so a map the caller still holds a reference to is never modified.
func ensureOpenAIStreamUsage(body map[string]any, kind string, stream bool) {
	if kind != "chat.completions" {
		return
	}
	if !stream {
		return
	}
	merged := make(map[string]any)
	if prior, ok := body["stream_options"].(map[string]any); ok {
		for key, val := range prior {
			merged[key] = val
		}
	}
	merged["include_usage"] = true
	body["stream_options"] = merged
}
|
||||
|
||||
func joinURL(base string, path string) string {
|
||||
base = strings.TrimRight(strings.TrimSpace(base), "/")
|
||||
if base == "" {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user