134 lines
2.9 KiB
Go
134 lines
2.9 KiB
Go
package clients
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"math"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// credential returns the first non-blank string stored under any of keys in
// candidate, with surrounding whitespace removed. Keys are tried in the order
// given; entries that are missing, not strings, or whitespace-only are
// skipped. It returns "" when no key yields a usable value.
func credential(candidate map[string]any, keys ...string) string {
	for _, name := range keys {
		raw, isString := candidate[name].(string)
		if !isString {
			continue
		}
		if trimmed := strings.TrimSpace(raw); trimmed != "" {
			return trimmed
		}
	}
	return ""
}
|
|
|
|
// boolValue reports the bool stored under key in body. A missing key, a nil
// map, or a value of any other type (including the string "true") yields
// false.
func boolValue(body map[string]any, key string) bool {
	if flag, ok := body[key].(bool); ok {
		return flag
	}
	return false
}
|
|
|
|
// stringValue returns the string stored under key in body with leading and
// trailing whitespace removed. Missing keys and non-string values yield "".
func stringValue(body map[string]any, key string) string {
	text, ok := body[key].(string)
	if !ok {
		return ""
	}
	return strings.TrimSpace(text)
}
|
|
|
|
// intValue reads body[key] as an int, returning fallback when the key is
// absent or its value is not a usable number.
//
// float64 values (what encoding/json produces for JSON numbers) are rounded
// half away from zero. int and int64 values are used as-is; int64 is
// converted directly and may wrap on platforms where int is 32 bits. A float
// that is NaN, infinite, or outside the range of int yields fallback. Go's
// float-to-int conversion of such values is implementation-dependent, so they
// are rejected explicitly.
func intValue(body map[string]any, key string, fallback int) int {
	switch value := body[key].(type) {
	case float64:
		rounded := math.Round(value)
		// -float64(math.MinInt) is exactly 2^(bits-1), the first float just
		// past the int range. NaN fails both comparisons and so is rejected too.
		if rounded >= float64(math.MinInt) && rounded < -float64(math.MinInt) {
			return int(rounded)
		}
		return fallback
	case int:
		return value
	case int64:
		return int(value)
	default:
		return fallback
	}
}
|
|
|
|
func decodeHTTPResponse(resp *http.Response) (map[string]any, error) {
|
|
defer resp.Body.Close()
|
|
raw, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024*1024))
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return nil, &ClientError{
|
|
Code: statusCodeName(resp.StatusCode),
|
|
Message: errorMessage(raw, resp.Status),
|
|
StatusCode: resp.StatusCode,
|
|
Retryable: HTTPRetryable(resp.StatusCode),
|
|
}
|
|
}
|
|
var out map[string]any
|
|
if len(raw) == 0 {
|
|
return map[string]any{}, nil
|
|
}
|
|
if err := json.Unmarshal(raw, &out); err != nil {
|
|
return nil, &ClientError{Code: "invalid_response", Message: err.Error(), Retryable: false}
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func usageFromOpenAI(result map[string]any) Usage {
|
|
usage, _ := result["usage"].(map[string]any)
|
|
input := intFromAny(firstPresent(usage["prompt_tokens"], usage["input_tokens"]))
|
|
output := intFromAny(firstPresent(usage["completion_tokens"], usage["output_tokens"]))
|
|
total := intFromAny(usage["total_tokens"])
|
|
if total == 0 {
|
|
total = input + output
|
|
}
|
|
return Usage{InputTokens: input, OutputTokens: output, TotalTokens: total}
|
|
}
|
|
|
|
// intFromAny converts a dynamically typed numeric value to int, returning 0
// for nil, non-numeric values, and floats that cannot be represented.
//
// float64 values (the type encoding/json decodes numbers into) are rounded
// half away from zero. A float that is NaN, infinite, or outside the range of
// int yields 0. Go's float-to-int conversion of such values is
// implementation-dependent, so they are rejected explicitly. int64 is
// converted directly and may wrap where int is 32 bits.
func intFromAny(value any) int {
	switch typed := value.(type) {
	case float64:
		rounded := math.Round(typed)
		// -float64(math.MinInt) is exactly 2^(bits-1), the first float just
		// past the int range. NaN fails both comparisons and so is rejected too.
		if rounded >= float64(math.MinInt) && rounded < -float64(math.MinInt) {
			return int(rounded)
		}
		return 0
	case int:
		return typed
	case int64:
		return int(typed)
	default:
		return 0
	}
}
|
|
|
|
// firstPresent returns the first argument that is not a nil interface value,
// or nil when every argument is nil or none are given.
func firstPresent(values ...any) any {
	for i := range values {
		if candidate := values[i]; candidate != nil {
			return candidate
		}
	}
	return nil
}
|
|
|
|
// errorMessage extracts a human-readable message from an error response
// body.
//
// JSON bodies are searched, in order, for {"error": {"message": "..."}},
// {"error": "..."} and {"message": "..."}. The first candidate that is a
// non-blank string wins and is returned trimmed. If no candidate qualifies,
// or the body is not a JSON object, the body text itself is returned with
// surrounding whitespace removed. fallback is returned when the body is empty
// or whitespace-only.
func errorMessage(raw []byte, fallback string) string {
	text := strings.TrimSpace(string(raw))
	if text == "" {
		return fallback
	}

	var parsed map[string]any
	if json.Unmarshal(raw, &parsed) == nil {
		var candidates []any
		if errObj, ok := parsed["error"].(map[string]any); ok {
			candidates = append(candidates, errObj["message"])
		}
		// Some APIs report the error as a plain string rather than an object.
		candidates = append(candidates, parsed["error"], parsed["message"])
		for _, candidate := range candidates {
			// Skip blank messages so an empty "error.message" does not hide a
			// more useful field later in the list.
			if msg, ok := candidate.(string); ok && strings.TrimSpace(msg) != "" {
				return strings.TrimSpace(msg)
			}
		}
	}
	return text
}
|
|
|
|
// statusCodeName maps an HTTP status code to a short error-code string:
// 429 → "rate_limit", 408 → "timeout", 401/403 → "auth_failed", any 5xx or
// higher → "server_error", and anything else → "http_<status>".
func statusCodeName(status int) string {
	switch {
	case status == http.StatusTooManyRequests:
		return "rate_limit"
	case status == http.StatusRequestTimeout:
		return "timeout"
	case status == http.StatusUnauthorized || status == http.StatusForbidden:
		return "auth_failed"
	case status >= 500:
		return "server_error"
	}
	return fmt.Sprintf("http_%d", status)
}
|
|
|
|
// nowUnix reports the current wall-clock time as whole seconds since the
// Unix epoch.
func nowUnix() int64 {
	now := time.Now()
	return now.Unix()
}
|