303 lines
7.3 KiB
Go
303 lines
7.3 KiB
Go
package clients
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"math"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// credential returns the first non-blank string stored in candidate under
// any of keys, checked in the order given, with surrounding whitespace
// trimmed. Keys that are missing, hold a non-string value, or hold only
// whitespace are skipped; "" is returned when no key qualifies.
func credential(candidate map[string]any, keys ...string) string {
	for i := 0; i < len(keys); i++ {
		raw, isString := candidate[keys[i]].(string)
		if !isString {
			continue
		}
		if trimmed := strings.TrimSpace(raw); trimmed != "" {
			return trimmed
		}
	}
	return ""
}
|
|
|
|
// boolValue returns body[key] when it holds a bool. A missing key or a
// value of any other type (including the string "true") yields false.
func boolValue(body map[string]any, key string) bool {
	if flag, isBool := body[key].(bool); isBool {
		return flag
	}
	return false
}
|
|
|
|
// stringValue returns body[key] with leading and trailing whitespace removed
// when it holds a string, and "" for a missing key or a non-string value.
func stringValue(body map[string]any, key string) string {
	text, isString := body[key].(string)
	if !isString {
		return ""
	}
	return strings.TrimSpace(text)
}
|
|
|
|
// intValue reads body[key] as an int, returning fallback when the key is
// missing or holds a value that is not a usable number.
//
// JSON-decoded numbers arrive as float64 and are rounded to the nearest
// integer (halves away from zero, per math.Round). int and int64 values are
// accepted as-is, matching intFromAny. NaN and ±Inf have no integer value,
// so they also yield fallback.
func intValue(body map[string]any, key string, fallback int) int {
	switch value := body[key].(type) {
	case float64:
		// Converting NaN/Inf to int is implementation-dependent in Go; refuse
		// it rather than return an arbitrary number.
		if math.IsNaN(value) || math.IsInf(value, 0) {
			return fallback
		}
		return int(math.Round(value))
	case int:
		return value
	case int64:
		return int(value)
	default:
		return fallback
	}
}
|
|
|
|
func decodeHTTPResponse(resp *http.Response) (map[string]any, error) {
|
|
defer resp.Body.Close()
|
|
raw, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024*1024))
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return nil, &ClientError{
|
|
Code: statusCodeName(resp.StatusCode),
|
|
Message: errorMessage(raw, resp.Status),
|
|
StatusCode: resp.StatusCode,
|
|
RequestID: requestIDFromHTTPResponse(resp),
|
|
Retryable: HTTPRetryable(resp.StatusCode),
|
|
}
|
|
}
|
|
var out map[string]any
|
|
if len(raw) == 0 {
|
|
return map[string]any{}, nil
|
|
}
|
|
if err := json.Unmarshal(raw, &out); err != nil {
|
|
return nil, &ClientError{Code: "invalid_response", Message: err.Error(), Retryable: false}
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func decodeOpenAIStreamResponse(resp *http.Response, onDelta StreamDelta) (map[string]any, error) {
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
raw, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024*1024))
|
|
return nil, &ClientError{
|
|
Code: statusCodeName(resp.StatusCode),
|
|
Message: errorMessage(raw, resp.Status),
|
|
StatusCode: resp.StatusCode,
|
|
RequestID: requestIDFromHTTPResponse(resp),
|
|
Retryable: HTTPRetryable(resp.StatusCode),
|
|
}
|
|
}
|
|
if result, ok, err := decodeOpenAIStreamReader(resp.Body, onDelta); ok || err != nil {
|
|
return result, err
|
|
}
|
|
return map[string]any{}, nil
|
|
}
|
|
|
|
func decodeOpenAIStreamReader(reader io.Reader, onDelta StreamDelta) (map[string]any, bool, error) {
|
|
scanner := bufio.NewScanner(reader)
|
|
scanner.Buffer(make([]byte, 0, 64*1024), 16*1024*1024)
|
|
rawLines := make([]string, 0)
|
|
parts := make([]string, 0)
|
|
var last map[string]any
|
|
var usage Usage
|
|
for scanner.Scan() {
|
|
rawLine := scanner.Text()
|
|
rawLines = append(rawLines, rawLine)
|
|
line := strings.TrimSpace(rawLine)
|
|
if !strings.HasPrefix(line, "data:") {
|
|
continue
|
|
}
|
|
payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
|
if payload == "" || payload == "[DONE]" {
|
|
continue
|
|
}
|
|
var event map[string]any
|
|
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
|
continue
|
|
}
|
|
last = event
|
|
if text := streamEventText(event); text != "" {
|
|
parts = append(parts, text)
|
|
if onDelta != nil {
|
|
if err := onDelta(text); err != nil {
|
|
return nil, true, err
|
|
}
|
|
}
|
|
}
|
|
if eventUsage := usageFromOpenAI(event); eventUsage.TotalTokens > 0 {
|
|
usage = eventUsage
|
|
}
|
|
}
|
|
if err := scanner.Err(); err != nil {
|
|
return nil, true, &ClientError{Code: "stream_read_error", Message: err.Error(), Retryable: true}
|
|
}
|
|
if last == nil {
|
|
raw := []byte(strings.Join(rawLines, "\n"))
|
|
if len(raw) == 0 {
|
|
return map[string]any{}, true, nil
|
|
}
|
|
var out map[string]any
|
|
if err := json.Unmarshal(raw, &out); err != nil {
|
|
return nil, false, nil
|
|
}
|
|
return out, true, nil
|
|
}
|
|
return buildOpenAIStreamResult(last, parts, usage), true, nil
|
|
}
|
|
|
|
func decodeOpenAIStream(raw []byte) (map[string]any, bool) {
|
|
if !bytes.Contains(raw, []byte("data:")) {
|
|
return nil, false
|
|
}
|
|
result, ok, err := decodeOpenAIStreamReader(bytes.NewReader(raw), nil)
|
|
return result, ok && err == nil
|
|
}
|
|
|
|
func buildOpenAIStreamResult(last map[string]any, parts []string, usage Usage) map[string]any {
|
|
if len(parts) == 0 {
|
|
return last
|
|
}
|
|
var out map[string]any
|
|
out = map[string]any{
|
|
"id": stringFromAny(firstPresent(last["id"], "chatcmpl-stream")),
|
|
"object": "chat.completion",
|
|
"model": stringFromAny(last["model"]),
|
|
"choices": []any{map[string]any{
|
|
"index": 0,
|
|
"message": map[string]any{
|
|
"role": "assistant",
|
|
"content": strings.Join(parts, ""),
|
|
},
|
|
"finish_reason": "stop",
|
|
}},
|
|
}
|
|
if usage.TotalTokens > 0 {
|
|
out["usage"] = map[string]any{
|
|
"prompt_tokens": usage.InputTokens,
|
|
"completion_tokens": usage.OutputTokens,
|
|
"total_tokens": usage.TotalTokens,
|
|
}
|
|
}
|
|
return out
|
|
}
|
|
|
|
// streamEventText extracts the text carried by one decoded stream event.
//
// Choices are examined in order. For each one, a string "content" inside its
// "delta" object is returned first, then a string "content" inside its
// "message" object. The first choice providing either wins, even if that
// content is "". Failing that, a top-level string "delta" and then a
// top-level string "output_text" are tried. "" is returned when nothing
// matches.
func streamEventText(event map[string]any) string {
	contentOf := func(choice map[string]any, field string) (string, bool) {
		inner, isObject := choice[field].(map[string]any)
		if !isObject {
			return "", false
		}
		text, isString := inner["content"].(string)
		return text, isString
	}

	choices, _ := event["choices"].([]any)
	for _, entry := range choices {
		choice, _ := entry.(map[string]any)
		if text, found := contentOf(choice, "delta"); found {
			return text
		}
		if text, found := contentOf(choice, "message"); found {
			return text
		}
	}

	for _, field := range []string{"delta", "output_text"} {
		if text, isString := event[field].(string); isString {
			return text
		}
	}
	return ""
}
|
|
|
|
func usageFromOpenAI(result map[string]any) Usage {
|
|
usage, _ := result["usage"].(map[string]any)
|
|
input := intFromAny(firstPresent(usage["prompt_tokens"], usage["input_tokens"]))
|
|
output := intFromAny(firstPresent(usage["completion_tokens"], usage["output_tokens"]))
|
|
total := intFromAny(usage["total_tokens"])
|
|
if total == 0 {
|
|
total = input + output
|
|
}
|
|
return Usage{InputTokens: input, OutputTokens: output, TotalTokens: total}
|
|
}
|
|
|
|
// requestIDFromHTTPResponse returns the first non-blank request identifier
// found in resp's headers, trimmed of surrounding whitespace. Headers are
// checked in this priority order: x-request-id, x-requestid, request-id,
// x-amzn-requestid, x-amz-request-id, cf-ray. A nil resp, or no matching
// header, yields "".
func requestIDFromHTTPResponse(resp *http.Response) string {
	if resp != nil {
		headerNames := [...]string{
			"x-request-id",
			"x-requestid",
			"request-id",
			"x-amzn-requestid",
			"x-amz-request-id",
			"cf-ray",
		}
		for i := range headerNames {
			// Header.Get canonicalizes the name, so lower-case keys match.
			if id := strings.TrimSpace(resp.Header.Get(headerNames[i])); id != "" {
				return id
			}
		}
	}
	return ""
}
|
|
|
|
func requestIDFromResult(result map[string]any) string {
|
|
for _, key := range []string{"request_id", "requestId", "id", "response_id", "responseId"} {
|
|
if value := strings.TrimSpace(stringFromAny(result[key])); value != "" {
|
|
return value
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// intFromAny converts a dynamically typed number to int. float64 values
// (the type encoding/json produces) are rounded half away from zero. int and
// int64 are converted directly. Anything else, including nil, yields 0.
func intFromAny(value any) int {
	if f, isFloat := value.(float64); isFloat {
		return int(math.Round(f))
	}
	if i, isInt := value.(int); isInt {
		return i
	}
	if i64, isInt64 := value.(int64); isInt64 {
		return int(i64)
	}
	return 0
}
|
|
|
|
// stringFromAny returns value unchanged when it holds a string and "" for
// any other value, including nil. Unlike stringValue, it does not trim.
func stringFromAny(value any) string {
	text, _ := value.(string)
	return text
}
|
|
|
|
// firstPresent returns the first argument that is not a nil interface value,
// or nil when every argument is nil (or none are given). Zero values such as
// 0 or "" count as present.
func firstPresent(values ...any) any {
	for i := range values {
		if values[i] == nil {
			continue
		}
		return values[i]
	}
	return nil
}
|
|
|
|
// errorMessage extracts a human-readable message from an HTTP error body. It
// returns fallback when the body is empty or whitespace-only.
//
// For a JSON object body it prefers, in order:
//   - a non-blank error.message (the {"error": {"message": ...}} shape);
//   - a non-blank string-valued "error" (the {"error": "..."} shape);
//   - a non-blank top-level "message".
//
// Otherwise the body itself is returned with surrounding whitespace trimmed.
func errorMessage(raw []byte, fallback string) string {
	trimmed := strings.TrimSpace(string(raw))
	if trimmed == "" {
		return fallback
	}

	var parsed map[string]any
	if json.Unmarshal(raw, &parsed) == nil {
		switch errField := parsed["error"].(type) {
		case map[string]any:
			if message, ok := errField["message"].(string); ok && strings.TrimSpace(message) != "" {
				return message
			}
		case string:
			if strings.TrimSpace(errField) != "" {
				return errField
			}
		}
		if message, ok := parsed["message"].(string); ok && strings.TrimSpace(message) != "" {
			return message
		}
	}
	return trimmed
}
|
|
|
|
// statusCodeName maps an HTTP status to the short code stored in
// ClientError.Code:
//   - 429 → "rate_limit"
//   - 408 → "timeout"
//   - 401 and 403 → "auth_failed"
//   - any status ≥ 500 → "server_error"
//   - everything else → "http_<status>"
func statusCodeName(status int) string {
	switch {
	case status == http.StatusTooManyRequests:
		return "rate_limit"
	case status == http.StatusRequestTimeout:
		return "timeout"
	case status == http.StatusUnauthorized || status == http.StatusForbidden:
		return "auth_failed"
	case status >= 500:
		return "server_error"
	}
	return fmt.Sprintf("http_%d", status)
}
|
|
|
|
// nowUnix returns the current wall-clock time as seconds since the Unix
// epoch.
func nowUnix() int64 {
	now := time.Now()
	return now.Unix()
}
|