easyai-ai-gateway/apps/api/internal/clients/openai.go

115 lines
3.3 KiB
Go

package clients
import (
"bytes"
"context"
"encoding/json"
"net/http"
"strings"
"time"
)
// OpenAIClient issues provider requests against an OpenAI-style HTTP API.
type OpenAIClient struct {
// HTTPClient performs the outbound HTTP calls. When it is nil, Run falls
// back to http.DefaultClient (see httpClient).
HTTPClient *http.Client
}
// Run posts request to the endpoint selected by request.Kind and returns the
// decoded provider response together with timing and usage metadata.
//
// The API key is resolved by credential from request.Candidate.Credentials
// using the names "apiKey", "api_key", "key" and "token". The JSON payload is
// a shallow copy of request.Body whose "model" entry is overwritten with
// request.Candidate.ModelName. The response is decoded as a stream when
// request.Stream is set or boolValue reports the body's "stream" entry as true.
//
// Errors:
//   - *ClientError "missing_credentials" (not retryable) when no API key is found.
//   - *ClientError "unsupported_kind" (not retryable) when request.Kind has no endpoint.
//   - the raw encoding error when the body cannot be marshalled to JSON.
//   - the raw error from http.NewRequestWithContext.
//   - *ClientError "network" (retryable) when the HTTP round trip fails.
//   - decode errors, passed through annotateResponseError along with the
//     request ID and response timestamps.
func (c OpenAIClient) Run(ctx context.Context, request Request) (Response, error) {
	apiKey := credential(request.Candidate.Credentials, "apiKey", "api_key", "key", "token")
	if apiKey == "" {
		return Response{}, &ClientError{Code: "missing_credentials", Message: "openai api key is required", Retryable: false}
	}
	endpoint := openAIEndpoint(request.Kind)
	if endpoint == "" {
		return Response{}, &ClientError{Code: "unsupported_kind", Message: "unsupported openai request kind", Retryable: false}
	}

	// Work on a copy so the caller's body map is not mutated by the model override.
	body := cloneBody(request.Body)
	body["model"] = request.Candidate.ModelName
	stream := request.Stream || boolValue(body, "stream")

	// Body values are arbitrary (map[string]any) and may not be encodable;
	// surface that instead of silently sending an empty payload.
	raw, err := json.Marshal(body)
	if err != nil {
		return Response{}, err
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, joinURL(request.Candidate.BaseURL, endpoint), bytes.NewReader(raw))
	if err != nil {
		return Response{}, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+apiKey)

	resp, err := httpClient(c.HTTPClient).Do(req)
	if err != nil {
		return Response{}, &ClientError{Code: "network", Message: err.Error(), Retryable: true}
	}
	// Run is the caller of Do and therefore responsible for releasing the body
	// on every path. NOTE(review): the decode helpers live outside this block
	// and may already close it; a repeated Close on a net/http response body
	// is harmless.
	defer resp.Body.Close()

	// The timing window covers reading and decoding the response, not the
	// round trip that produced the headers.
	responseStartedAt := time.Now()
	requestID := requestIDFromHTTPResponse(resp)
	result, err := decodeOpenAIResponse(resp, stream, request.StreamDelta)
	responseFinishedAt := time.Now()
	if err != nil {
		return Response{}, annotateResponseError(err, requestID, responseStartedAt, responseFinishedAt)
	}

	// Fall back to an ID taken from the decoded payload when the HTTP
	// response did not yield one.
	if requestID == "" {
		requestID = requestIDFromResult(result)
	}
	return Response{
		Result:             result,
		RequestID:          requestID,
		Usage:              usageFromOpenAI(result),
		Progress:           providerProgress(request),
		ResponseStartedAt:  responseStartedAt,
		ResponseFinishedAt: responseFinishedAt,
		ResponseDurationMS: responseDurationMS(responseStartedAt, responseFinishedAt),
	}, nil
}
// decodeOpenAIResponse decodes resp into a generic JSON object.
//
// Non-streaming responses are handed to decodeHTTPResponse. Streaming
// responses are handed to decodeOpenAIStreamResponse together with onDelta;
// if that fails, the error is returned and any partial result is dropped.
func decodeOpenAIResponse(resp *http.Response, stream bool, onDelta StreamDelta) (map[string]any, error) {
	if !stream {
		return decodeHTTPResponse(resp)
	}

	decoded, streamErr := decodeOpenAIStreamResponse(resp, onDelta)
	if streamErr != nil {
		return nil, streamErr
	}
	return decoded, nil
}
// openAIEndpoint maps a request kind to the URL path it is posted to.
// The match is exact and case-sensitive; an unrecognised kind yields "".
func openAIEndpoint(kind string) string {
	var path string // stays empty for unsupported kinds
	switch kind {
	case "chat.completions":
		path = "/chat/completions"
	case "responses":
		path = "/responses"
	case "images.generations":
		path = "/images/generations"
	case "images.edits":
		path = "/images/edits"
	}
	return path
}
// cloneBody returns a new map holding the same top-level entries as body.
// The copy is shallow: nested values are shared with the original. The result
// is never nil, even for a nil body, so callers can write to it directly.
func cloneBody(body map[string]any) map[string]any {
	dup := make(map[string]any, len(body))
	for k, v := range body {
		dup[k] = v
	}
	return dup
}
// joinURL appends path to base by plain concatenation.
// Surrounding whitespace and then any trailing slashes are stripped from base
// first; if nothing remains, the public OpenAI v1 base URL is used instead.
// path is appended verbatim, so it is expected to carry its own leading slash.
func joinURL(base string, path string) string {
	const defaultBase = "https://api.openai.com/v1"

	root := strings.TrimRight(strings.TrimSpace(base), "/")
	if root == "" {
		return defaultBase + path
	}
	return root + path
}
// httpClient returns client unchanged when one is supplied and
// http.DefaultClient when client is nil.
func httpClient(client *http.Client) *http.Client {
	if client == nil {
		return http.DefaultClient
	}
	return client
}
// providerProgress builds the two progress entries attached to a successful
// response: "submitting" followed by "fetching_result". The phase names,
// fractions and messages are fixed constants; only the payloads vary, carrying
// the candidate's ClientID and Provider respectively.
func providerProgress(request Request) []Progress {
	submitted := Progress{
		Phase:    "submitting",
		Progress: 0.35,
		Message:  "provider request submitted",
		Payload:  map[string]any{"clientId": request.Candidate.ClientID},
	}
	received := Progress{
		Phase:    "fetching_result",
		Progress: 0.8,
		Message:  "provider response received",
		Payload:  map[string]any{"provider": request.Candidate.Provider},
	}
	return []Progress{submitted, received}
}