115 lines
3.3 KiB
Go
115 lines
3.3 KiB
Go
package clients
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// OpenAIClient sends provider requests to an OpenAI-style HTTP API.
type OpenAIClient struct {
	// HTTPClient performs the outgoing HTTP calls. When it is nil,
	// http.DefaultClient is used instead.
	HTTPClient *http.Client
}
|
|
|
|
func (c OpenAIClient) Run(ctx context.Context, request Request) (Response, error) {
|
|
apiKey := credential(request.Candidate.Credentials, "apiKey", "api_key", "key", "token")
|
|
if apiKey == "" {
|
|
return Response{}, &ClientError{Code: "missing_credentials", Message: "openai api key is required", Retryable: false}
|
|
}
|
|
endpoint := openAIEndpoint(request.Kind)
|
|
if endpoint == "" {
|
|
return Response{}, &ClientError{Code: "unsupported_kind", Message: "unsupported openai request kind", Retryable: false}
|
|
}
|
|
body := cloneBody(request.Body)
|
|
body["model"] = request.Candidate.ModelName
|
|
stream := request.Stream || boolValue(body, "stream")
|
|
raw, _ := json.Marshal(body)
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, joinURL(request.Candidate.BaseURL, endpoint), bytes.NewReader(raw))
|
|
if err != nil {
|
|
return Response{}, err
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
|
resp, err := httpClient(c.HTTPClient).Do(req)
|
|
if err != nil {
|
|
return Response{}, &ClientError{Code: "network", Message: err.Error(), Retryable: true}
|
|
}
|
|
responseStartedAt := time.Now()
|
|
requestID := requestIDFromHTTPResponse(resp)
|
|
result, err := decodeOpenAIResponse(resp, stream, request.StreamDelta)
|
|
responseFinishedAt := time.Now()
|
|
if err != nil {
|
|
return Response{}, annotateResponseError(err, requestID, responseStartedAt, responseFinishedAt)
|
|
}
|
|
if requestID == "" {
|
|
requestID = requestIDFromResult(result)
|
|
}
|
|
return Response{
|
|
Result: result,
|
|
RequestID: requestID,
|
|
Usage: usageFromOpenAI(result),
|
|
Progress: providerProgress(request),
|
|
ResponseStartedAt: responseStartedAt,
|
|
ResponseFinishedAt: responseFinishedAt,
|
|
ResponseDurationMS: responseDurationMS(responseStartedAt, responseFinishedAt),
|
|
}, nil
|
|
}
|
|
|
|
func decodeOpenAIResponse(resp *http.Response, stream bool, onDelta StreamDelta) (map[string]any, error) {
|
|
if stream {
|
|
result, err := decodeOpenAIStreamResponse(resp, onDelta)
|
|
if err == nil {
|
|
return result, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
return decodeHTTPResponse(resp)
|
|
}
|
|
|
|
// openAIEndpoint maps a request kind to its API path (with a leading slash).
// It returns the empty string for kinds that are not supported.
func openAIEndpoint(kind string) string {
	var path string
	switch kind {
	case "chat.completions":
		path = "/chat/completions"
	case "responses":
		path = "/responses"
	case "images.generations":
		path = "/images/generations"
	case "images.edits":
		path = "/images/edits"
	}
	return path
}
|
|
|
|
// cloneBody returns a shallow copy of body: top-level keys are copied, while
// nested values are shared with the original. The result is never nil, even
// for a nil input, so callers can write to it directly.
func cloneBody(body map[string]any) map[string]any {
	dup := make(map[string]any, len(body))
	for k := range body {
		dup[k] = body[k]
	}
	return dup
}
|
|
|
|
// joinURL appends path to base after stripping surrounding whitespace and any
// trailing slashes from base. When base is empty after that cleanup, the
// default OpenAI API root is used. path is appended verbatim.
func joinURL(base string, path string) string {
	const defaultBase = "https://api.openai.com/v1"

	root := strings.TrimSpace(base)
	root = strings.TrimRight(root, "/")
	if root == "" {
		return defaultBase + path
	}
	return root + path
}
|
|
|
|
// httpClient returns client unchanged when it is set and falls back to
// http.DefaultClient when it is nil.
func httpClient(client *http.Client) *http.Client {
	if client == nil {
		return http.DefaultClient
	}
	return client
}
|
|
|
|
func providerProgress(request Request) []Progress {
|
|
return []Progress{
|
|
{Phase: "submitting", Progress: 0.35, Message: "provider request submitted", Payload: map[string]any{"clientId": request.Candidate.ClientID}},
|
|
{Phase: "fetching_result", Progress: 0.8, Message: "provider response received", Payload: map[string]any{"provider": request.Candidate.Provider}},
|
|
}
|
|
}
|