easyai-ai-gateway/apps/api/internal/clients/helpers.go

303 lines
7.3 KiB
Go

package clients
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"math"
"net/http"
"strings"
"time"
)
// credential returns the first non-blank string stored in candidate under any
// of keys, with surrounding whitespace removed. Keys are tried in the order
// given; values that are missing, not strings, or whitespace-only are skipped.
// It returns "" when no key yields a usable value.
func credential(candidate map[string]any, keys ...string) string {
	for i := 0; i < len(keys); i++ {
		raw, isString := candidate[keys[i]].(string)
		if !isString {
			continue
		}
		if trimmed := strings.TrimSpace(raw); trimmed != "" {
			return trimmed
		}
	}
	return ""
}
// boolValue reports whether body[key] holds the boolean true. Missing keys and
// values of any other type yield false.
func boolValue(body map[string]any, key string) bool {
	if flag, ok := body[key].(bool); ok {
		return flag
	}
	return false
}
// stringValue returns body[key] with leading and trailing whitespace trimmed,
// or "" when the key is absent or does not hold a string.
func stringValue(body map[string]any, key string) string {
	text, ok := body[key].(string)
	if !ok {
		return ""
	}
	return strings.TrimSpace(text)
}
// intValue reads body[key] as an int, returning fallback when the key is
// absent or holds a non-numeric value.
//
// float64 values (the type encoding/json uses for numbers) are rounded with
// math.Round. Native int and int64 values, which Go callers may put into the
// map directly, are accepted as well; this keeps intValue consistent with
// intFromAny, which already handles int64.
func intValue(body map[string]any, key string, fallback int) int {
	switch value := body[key].(type) {
	case float64:
		return int(math.Round(value))
	case int:
		return value
	case int64:
		return int(value)
	default:
		return fallback
	}
}
// decodeHTTPResponse reads resp.Body (capped at 16 MiB), closes it, and
// decodes it as a JSON object.
//
// Non-2xx statuses become a *ClientError whose message is extracted from the
// (possibly partial) body by errorMessage. For 2xx responses:
//   - a failure while reading the body is reported as a retryable
//     "response_read_error" instead of being mistaken for an empty or
//     malformed payload;
//   - an empty body, or a JSON null, decodes to an empty map;
//   - a body that is not a JSON object yields a non-retryable
//     "invalid_response".
func decodeHTTPResponse(resp *http.Response) (map[string]any, error) {
	defer resp.Body.Close()
	raw, readErr := io.ReadAll(io.LimitReader(resp.Body, 16*1024*1024))
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		// Best effort: whatever was read before a failure still helps explain
		// the status, and errorMessage falls back to resp.Status when empty.
		return nil, &ClientError{
			Code:       statusCodeName(resp.StatusCode),
			Message:    errorMessage(raw, resp.Status),
			StatusCode: resp.StatusCode,
			RequestID:  requestIDFromHTTPResponse(resp),
			Retryable:  HTTPRetryable(resp.StatusCode),
		}
	}
	if readErr != nil {
		// A truncated success body must not be decoded as if it were complete.
		return nil, &ClientError{
			Code:       "response_read_error",
			Message:    readErr.Error(),
			StatusCode: resp.StatusCode,
			RequestID:  requestIDFromHTTPResponse(resp),
			Retryable:  true,
		}
	}
	if len(raw) == 0 {
		return map[string]any{}, nil
	}
	var out map[string]any
	if err := json.Unmarshal(raw, &out); err != nil {
		return nil, &ClientError{Code: "invalid_response", Message: err.Error(), Retryable: false}
	}
	if out == nil {
		// A literal JSON null decodes without error into a nil map; hand callers
		// a writable empty map instead.
		out = map[string]any{}
	}
	return out, nil
}
// decodeOpenAIStreamResponse consumes an OpenAI-style streaming HTTP response
// and returns the aggregated result, closing the body when done.
//
// Non-2xx statuses become a *ClientError built from the (size-capped) body.
// Otherwise the body is handed to decodeOpenAIStreamReader, which forwards
// each text delta to onDelta (when non-nil) and also accepts a plain JSON
// object body. Errors from the reader (including onDelta's) are returned as-is.
// A 2xx body in which no SSE event could be decoded and which is not a JSON
// object is reported as "invalid_response" rather than silently producing an
// empty result.
func decodeOpenAIStreamResponse(resp *http.Response, onDelta StreamDelta) (map[string]any, error) {
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		// Best effort: on a read failure errorMessage falls back to resp.Status.
		raw, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024*1024))
		return nil, &ClientError{
			Code:       statusCodeName(resp.StatusCode),
			Message:    errorMessage(raw, resp.Status),
			StatusCode: resp.StatusCode,
			RequestID:  requestIDFromHTTPResponse(resp),
			Retryable:  HTTPRetryable(resp.StatusCode),
		}
	}
	result, ok, err := decodeOpenAIStreamReader(resp.Body, onDelta)
	if err != nil {
		return nil, err
	}
	if !ok {
		return nil, &ClientError{
			Code:       "invalid_response",
			Message:    "response body contained no decodable SSE events and is not a JSON object",
			StatusCode: resp.StatusCode,
			RequestID:  requestIDFromHTTPResponse(resp),
			Retryable:  false,
		}
	}
	return result, nil
}
// decodeOpenAIStreamReader scans reader line by line as a Server-Sent Events
// stream of OpenAI-style JSON chunks and aggregates them into one result.
//
// For every "data:" line whose payload decodes to a non-null JSON object, the
// chunk's text (see streamEventText) is collected and, when non-empty, passed
// to onDelta if it is non-nil. The latest usage block reporting a positive
// total is kept. Blank payloads, "[DONE]", JSON null and undecodable payloads
// are skipped.
//
// Return values:
//   - (result, true, nil): at least one event was decoded and result comes
//     from buildOpenAIStreamResult; or no event was decoded and the input is
//     empty (empty map) or parses as a plain JSON object.
//   - (nil, false, nil): no event was decoded and the input is not a JSON object.
//   - (nil, true, err): onDelta failed (its error is returned unchanged) or
//     the scanner failed, e.g. a line exceeded the 16 MiB buffer
//     ("stream_read_error", retryable).
func decodeOpenAIStreamReader(reader io.Reader, onDelta StreamDelta) (map[string]any, bool, error) {
	scanner := bufio.NewScanner(reader)
	scanner.Buffer(make([]byte, 0, 64*1024), 16*1024*1024)
	// rawLines is only needed for the plain-JSON fallback used when no SSE
	// event is decoded, so it is released once the first event arrives; memory
	// then grows with the generated text rather than with the whole stream.
	rawLines := make([]string, 0)
	parts := make([]string, 0)
	var last map[string]any
	var usage Usage
	for scanner.Scan() {
		rawLine := scanner.Text()
		if last == nil {
			rawLines = append(rawLines, rawLine)
		}
		line := strings.TrimSpace(rawLine)
		if !strings.HasPrefix(line, "data:") {
			continue
		}
		payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
		if payload == "" || payload == "[DONE]" {
			continue
		}
		var event map[string]any
		if err := json.Unmarshal([]byte(payload), &event); err != nil || event == nil {
			// "data: null" unmarshals into a nil map without error; accepting it
			// would reset last and discard everything aggregated so far.
			continue
		}
		if last == nil {
			rawLines = nil
		}
		last = event
		if text := streamEventText(event); text != "" {
			parts = append(parts, text)
			if onDelta != nil {
				if err := onDelta(text); err != nil {
					return nil, true, err
				}
			}
		}
		if eventUsage := usageFromOpenAI(event); eventUsage.TotalTokens > 0 {
			usage = eventUsage
		}
	}
	if err := scanner.Err(); err != nil {
		return nil, true, &ClientError{Code: "stream_read_error", Message: err.Error(), Retryable: true}
	}
	if last == nil {
		// No SSE event was decoded: the body may be an ordinary JSON response.
		raw := []byte(strings.Join(rawLines, "\n"))
		if len(raw) == 0 {
			return map[string]any{}, true, nil
		}
		var out map[string]any
		if err := json.Unmarshal(raw, &out); err != nil {
			return nil, false, nil
		}
		return out, true, nil
	}
	return buildOpenAIStreamResult(last, parts, usage), true, nil
}
// decodeOpenAIStream aggregates an in-memory body as an OpenAI-style SSE
// stream. It returns (nil, false) when raw contains no "data:" marker at all;
// otherwise it runs decodeOpenAIStreamReader without a delta callback and
// reports success only when the reader recognised the input and returned no
// error.
func decodeOpenAIStream(raw []byte) (map[string]any, bool) {
	looksLikeSSE := bytes.Contains(raw, []byte("data:"))
	if !looksLikeSSE {
		return nil, false
	}
	result, recognised, err := decodeOpenAIStreamReader(bytes.NewReader(raw), nil)
	if err != nil {
		return result, false
	}
	return result, recognised
}
// buildOpenAIStreamResult assembles a non-streaming "chat.completion" object
// from the aggregated pieces of an OpenAI-style stream.
//
// When no text was collected (parts is empty), last is returned unchanged.
// Otherwise the concatenated parts become the single assistant message, and:
//   - id is copied from last, falling back to "chatcmpl-stream" when it is
//     missing, empty or not a string;
//   - model is copied from last;
//   - finish_reason is taken from the first choice of last when that chunk
//     carries a non-empty one, and defaults to "stop" otherwise (for example
//     when the final chunk only reports usage);
//   - a usage object is attached only when usage.TotalTokens is positive.
func buildOpenAIStreamResult(last map[string]any, parts []string, usage Usage) map[string]any {
	if len(parts) == 0 {
		return last
	}
	id := stringFromAny(last["id"])
	if id == "" {
		id = "chatcmpl-stream"
	}
	// Preserve the upstream stop cause (e.g. "length") instead of always
	// claiming a natural stop.
	finishReason := "stop"
	if choices, ok := last["choices"].([]any); ok && len(choices) > 0 {
		if choice, ok := choices[0].(map[string]any); ok {
			if reason, ok := choice["finish_reason"].(string); ok && reason != "" {
				finishReason = reason
			}
		}
	}
	out := map[string]any{
		"id":     id,
		"object": "chat.completion",
		"model":  stringFromAny(last["model"]),
		"choices": []any{map[string]any{
			"index": 0,
			"message": map[string]any{
				"role":    "assistant",
				"content": strings.Join(parts, ""),
			},
			"finish_reason": finishReason,
		}},
	}
	if usage.TotalTokens > 0 {
		out["usage"] = map[string]any{
			"prompt_tokens":     usage.InputTokens,
			"completion_tokens": usage.OutputTokens,
			"total_tokens":      usage.TotalTokens,
		}
	}
	return out
}
// streamEventText extracts the text carried by a single decoded stream event.
//
// The "choices" array is inspected first. The first choice with a string
// "delta.content", or failing that a string "message.content", wins; an empty
// string still counts as found and ends the search. Otherwise a top-level
// string "delta" and then a top-level string "output_text" are tried. The
// result is "" when no text is present.
func streamEventText(event map[string]any) string {
	choices, _ := event["choices"].([]any)
	for _, item := range choices {
		choice, _ := item.(map[string]any)
		delta, _ := choice["delta"].(map[string]any)
		if content, found := delta["content"].(string); found {
			return content
		}
		message, _ := choice["message"].(map[string]any)
		if content, found := message["content"].(string); found {
			return content
		}
	}
	for _, key := range []string{"delta", "output_text"} {
		if text, found := event[key].(string); found {
			return text
		}
	}
	return ""
}
// usageFromOpenAI reads token counts from result["usage"].
//
// Input tokens come from "prompt_tokens", or from "input_tokens" when the
// former is absent. Output tokens come from "completion_tokens", or from
// "output_tokens" when absent. A present-but-non-numeric key still wins and
// counts as 0. When "total_tokens" is missing or zero, the total is derived
// as input + output. A missing or non-object usage field yields all zeros.
func usageFromOpenAI(result map[string]any) Usage {
	counts, _ := result["usage"].(map[string]any)
	usage := Usage{
		InputTokens:  intFromAny(firstPresent(counts["prompt_tokens"], counts["input_tokens"])),
		OutputTokens: intFromAny(firstPresent(counts["completion_tokens"], counts["output_tokens"])),
		TotalTokens:  intFromAny(counts["total_tokens"]),
	}
	if usage.TotalTokens == 0 {
		usage.TotalTokens = usage.InputTokens + usage.OutputTokens
	}
	return usage
}
// requestIDFromHTTPResponse returns the first non-blank request identifier
// found in resp's headers, with surrounding whitespace trimmed. Headers are
// checked in this order: x-request-id, x-requestid, request-id,
// x-amzn-requestid, x-amz-request-id, cf-ray. It returns "" for a nil
// response or when none of them is set.
func requestIDFromHTTPResponse(resp *http.Response) string {
	if resp != nil {
		headerNames := [...]string{
			"x-request-id",
			"x-requestid",
			"request-id",
			"x-amzn-requestid",
			"x-amz-request-id",
			"cf-ray",
		}
		for i := range headerNames {
			id := strings.TrimSpace(resp.Header.Get(headerNames[i]))
			if id != "" {
				return id
			}
		}
	}
	return ""
}
// requestIDFromResult returns the first non-blank string identifier found in
// a decoded response body, trimmed of surrounding whitespace. Keys are checked
// in this order: request_id, requestId, id, response_id, responseId. Non-string
// values are ignored. It returns "" when none of them applies.
func requestIDFromResult(result map[string]any) string {
	candidates := []string{"request_id", "requestId", "id", "response_id", "responseId"}
	for _, name := range candidates {
		raw, _ := result[name].(string)
		if id := strings.TrimSpace(raw); id != "" {
			return id
		}
	}
	return ""
}
// intFromAny converts a loosely typed numeric value to int. float64 values are
// rounded with math.Round. int and int64 values are converted directly.
// Anything else, including nil and other integer widths, yields 0.
func intFromAny(value any) int {
	if f, ok := value.(float64); ok {
		return int(math.Round(f))
	}
	if i, ok := value.(int); ok {
		return i
	}
	if i64, ok := value.(int64); ok {
		return int(i64)
	}
	return 0
}
// stringFromAny returns value unchanged when it holds a string and "" for any
// other type, including nil.
func stringFromAny(value any) string {
	text, _ := value.(string)
	return text
}
// firstPresent returns the first argument that is not a nil interface value,
// or nil when every argument is nil or none is given. An interface holding a
// typed nil (for example a nil map) is not a nil interface and therefore
// counts as present.
func firstPresent(values ...any) any {
	for i := 0; i < len(values); i++ {
		if values[i] != nil {
			return values[i]
		}
	}
	return nil
}
// errorMessage extracts a human-readable message from an upstream error body.
//
// If the body parses as a JSON object, these are tried in order, and each is
// used only when it is a non-blank string:
//   - error.message;
//   - a string-valued "error" field (e.g. {"error":"quota exceeded"});
//   - a top-level "message".
//
// If none applies, the body itself is returned with surrounding whitespace
// trimmed. fallback is returned when the body is empty or whitespace-only.
func errorMessage(raw []byte, fallback string) string {
	body := bytes.TrimSpace(raw)
	if len(body) == 0 {
		return fallback
	}
	var parsed map[string]any
	if json.Unmarshal(body, &parsed) == nil {
		switch detail := parsed["error"].(type) {
		case map[string]any:
			if message, ok := detail["message"].(string); ok && strings.TrimSpace(message) != "" {
				return message
			}
		case string:
			if strings.TrimSpace(detail) != "" {
				return detail
			}
		}
		if message, ok := parsed["message"].(string); ok && strings.TrimSpace(message) != "" {
			return message
		}
	}
	return string(body)
}
// statusCodeName maps an HTTP status code to the short error code stored in
// ClientError.Code:
//   - 429 → "rate_limit"
//   - 408 → "timeout"
//   - 401 and 403 → "auth_failed"
//   - 500 and above → "server_error"
//   - anything else → "http_<status>"
func statusCodeName(status int) string {
	switch {
	case status == http.StatusTooManyRequests:
		return "rate_limit"
	case status == http.StatusRequestTimeout:
		return "timeout"
	case status == http.StatusUnauthorized || status == http.StatusForbidden:
		return "auth_failed"
	case status >= 500:
		return "server_error"
	}
	return fmt.Sprintf("http_%d", status)
}
// nowUnix returns the current wall-clock time as whole seconds since the Unix
// epoch.
func nowUnix() int64 {
	now := time.Now()
	return now.Unix()
}