// easyai-ai-gateway/apps/api/internal/httpapi/request_preparation.go
//
// 632 lines
// 18 KiB
// Go

package httpapi
import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"mime"
"net/http"
"path/filepath"
"sort"
"strings"
"time"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/runner"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
// requestAssetFilePrefix is the leading part of every file name generated by
// requestAssetFileName for an uploaded request asset.
const requestAssetFilePrefix = "gateway-request-asset-"
// preparedTaskRequest is the outcome of prepareTaskRequest: the rewritten
// request body plus, when the request belongs to a conversation, the stored
// conversation bookkeeping.
type preparedTaskRequest struct {
// Body is the request body after inline media has been replaced by asset
// wrappers and, for conversation requests, after "messages" has been
// swapped for message references.
Body map[string]any
// ConversationID is the identifier returned by the store for the
// conversation; empty when the request carries no conversation key.
ConversationID string
// MessageRefs are the references returned by the store when the request
// messages were upserted.
MessageRefs []store.TaskMessageRefInput
// NewMessageCount is the count reported by the store's message upsert.
NewMessageCount int
}
// decodedRequestAsset is an inline media payload extracted from a request
// value, ready to be uploaded.
type decodedRequestAsset struct {
// Bytes is the decoded (binary) payload.
Bytes []byte
// ContentType is the MIME type chosen for the payload.
ContentType string
}
// requestAssetOptions tunes how ensureRequestAssetWithOptions stores an asset.
type requestAssetOptions struct {
// RequirePublicURL rejects (or re-uploads) assets whose URL is not
// considered public by requestAssetURLIsPublic.
RequirePublicURL bool
// UploadScene is the scene passed to the file upload; when blank the
// request-asset scene is used.
UploadScene string
// Source is the source label passed to the file upload; when blank
// "ai-gateway-request" is used.
Source string
}
// prepareTaskRequest normalizes an incoming task body before it is queued.
//
// Inline media in body is first replaced by asset wrappers. If the request
// names a conversation (see requestConversationKey) and carries a non-empty
// "messages" array, the conversation is ensured in the store, every message is
// hashed and upserted, and the body is rewritten: "messages" is removed and
// "conversationId", "conversationRecordId", "messageRefs" and
// "newMessageCount" are set. Otherwise only the asset-rewritten body is
// returned. Any store or asset error yields a zero result and that error.
func (s *Server) prepareTaskRequest(ctx context.Context, r *http.Request, user *auth.User, body map[string]any) (preparedTaskRequest, error) {
	rewritten, err := s.prepareRequestAssetRefs(ctx, body)
	if err != nil {
		return preparedTaskRequest{}, err
	}
	withoutConversation := preparedTaskRequest{Body: rewritten}

	key := requestConversationKey(r, rewritten)
	if key == "" {
		return withoutConversation, nil
	}
	rawMessages, isList := rewritten["messages"].([]any)
	if !isList || len(rawMessages) == 0 {
		return withoutConversation, nil
	}

	conversationMeta := map[string]any{
		"source":          "ai-gateway",
		"conversationKey": key,
	}
	recordID, err := s.store.EnsureConversation(ctx, user, key, conversationMeta)
	if err != nil {
		return preparedTaskRequest{}, err
	}

	inputs := make([]store.ConversationMessageInput, len(rawMessages))
	for i, raw := range rawMessages {
		snapshot, isObject := raw.(map[string]any)
		if !isObject || snapshot == nil {
			// Non-object messages are wrapped so they can be hashed and stored uniformly.
			snapshot = map[string]any{"content": raw}
		}
		digest, assets := canonicalConversationMessageHash(snapshot)
		inputs[i] = store.ConversationMessageInput{
			Hash:         digest,
			Role:         stringFromRequestAny(snapshot["role"]),
			Snapshot:     snapshot,
			AssetSHA256s: assets,
		}
	}

	refs, added, err := s.store.UpsertConversationMessages(ctx, recordID, inputs)
	if err != nil {
		return preparedTaskRequest{}, err
	}

	// The stored message references replace the raw messages in the body.
	delete(rewritten, "messages")
	rewritten["conversationId"] = key
	rewritten["conversationRecordId"] = recordID
	rewritten["messageRefs"] = messageRefsForRequest(refs)
	rewritten["newMessageCount"] = added

	return preparedTaskRequest{
		Body:            rewritten,
		ConversationID:  recordID,
		MessageRefs:     refs,
		NewMessageCount: added,
	}, nil
}
// prepareRequestAssetRefs walks body and returns a copy in which inline media
// values have been replaced by asset wrappers. When the walk does not produce
// a non-nil object, an empty map is returned instead.
func (s *Server) prepareRequestAssetRefs(ctx context.Context, body map[string]any) (map[string]any, error) {
	prepared, err := s.prepareRequestAssetValue(ctx, body, nil, nil)
	if err != nil {
		return nil, err
	}
	if object, ok := prepared.(map[string]any); ok && object != nil {
		return object, nil
	}
	return map[string]any{}, nil
}
// prepareRequestAssetValue recursively rewrites value.
//
// Objects that already carry a non-nil "assetRef" are returned untouched.
// Other objects are copied key by key (in sorted key order): a field
// recognized by requestAssetFromValue is stored through ensureRequestAsset and
// replaced by an asset wrapper, every other field is processed recursively.
// Arrays are copied element by element, keeping the siblings of the enclosing
// object. Any other value is returned as is. path lists the keys and "[i]"
// index markers leading to value.
func (s *Server) prepareRequestAssetValue(ctx context.Context, value any, path []string, siblings map[string]any) (any, error) {
	if object, isObject := value.(map[string]any); isObject {
		if object["assetRef"] != nil {
			return object, nil
		}

		names := make([]string, 0, len(object))
		for name := range object {
			names = append(names, name)
		}
		sort.Strings(names)

		rewritten := make(map[string]any, len(object))
		for _, name := range names {
			field := object[name]

			asset, isAsset, err := requestAssetFromValue(name, path, field, object)
			if err != nil {
				return nil, err
			}
			if isAsset {
				ref, err := s.ensureRequestAsset(ctx, asset)
				if err != nil {
					return nil, err
				}
				rewritten[name] = requestAssetWrapper(ref)
				continue
			}

			child, err := s.prepareRequestAssetValue(ctx, field, append(path, name), object)
			if err != nil {
				return nil, err
			}
			rewritten[name] = child
		}
		return rewritten, nil
	}

	if list, isList := value.([]any); isList {
		rewritten := make([]any, len(list))
		for i, element := range list {
			child, err := s.prepareRequestAssetValue(ctx, element, append(path, fmt.Sprintf("[%d]", i)), siblings)
			if err != nil {
				return nil, err
			}
			rewritten[i] = child
		}
		return rewritten, nil
	}

	return value, nil
}
// requestAssetFromValue decides whether value (found under key at path, next
// to siblings) is an inline media payload and, if so, decodes it.
//
// It returns (asset, true, nil) for a decodable payload, (zero, false, nil)
// when the value should be left untouched, and (zero, false, err) when the
// value must be media but cannot be decoded.
//
// Rules, in order:
//   - non-strings, blank strings and URLs are left untouched;
//   - "data:" URLs must be base64 data URLs, otherwise an error is returned;
//   - other strings are decoded as raw base64 when the field is a strict
//     base64 field, or looks like a media field with a long enough value. A
//     decode failure is an error only for strict fields.
//
// Standard base64 contains '/', and any payload whose first byte is 0xFF
// encodes to a string starting with '/' (every JPEG starts with "/9j/"). Such
// values look like absolute paths to mediaURLString, so they are not rejected
// up front. They are decoded, and accepted only when the decoded bytes are
// sniffed as image/audio/video; otherwise they keep being treated as paths
// (left untouched, never an error).
func requestAssetFromValue(key string, path []string, value any, siblings map[string]any) (decodedRequestAsset, bool, error) {
	text, ok := value.(string)
	if !ok {
		return decodedRequestAsset{}, false, nil
	}
	raw := strings.TrimSpace(text)
	if raw == "" {
		return decodedRequestAsset{}, false, nil
	}

	// A leading '/' is ambiguous: absolute path or base64 of 0xFF-leading bytes.
	slashPrefixed := strings.HasPrefix(raw, "/")
	if !slashPrefixed && mediaURLString(raw) {
		return decodedRequestAsset{}, false, nil
	}

	if strings.HasPrefix(strings.ToLower(raw), "data:") {
		contentType, encoded, ok, err := parseRequestDataURL(raw)
		if err != nil || !ok {
			return decodedRequestAsset{}, false, err
		}
		payload, err := decodeRequestBase64(encoded)
		if err != nil {
			return decodedRequestAsset{}, false, requestAssetDecodeError(err)
		}
		return decodedRequestAsset{
			Bytes:       payload,
			ContentType: requestAssetContentType(contentType, payload, key, path, siblings),
		}, true, nil
	}

	strict := strictRequestBase64Field(key, path)
	if !strict && !likelyRequestBase64MediaField(key, path, raw) {
		return decodedRequestAsset{}, false, nil
	}

	payload, err := decodeRequestBase64(raw)
	if err != nil {
		// A '/'-prefixed value that is not base64 is simply a path, even in a strict field.
		if strict && !slashPrefixed {
			return decodedRequestAsset{}, false, requestAssetDecodeError(err)
		}
		return decodedRequestAsset{}, false, nil
	}
	if slashPrefixed && !requestContentTypeIsMedia(http.DetectContentType(payload)) {
		// Decodable by coincidence but not recognizable media: keep it as a path.
		return decodedRequestAsset{}, false, nil
	}

	contentType := requestAssetContentType("", payload, key, path, siblings)
	if !strict && !requestContentTypeIsMedia(contentType) {
		return decodedRequestAsset{}, false, nil
	}
	return decodedRequestAsset{Bytes: payload, ContentType: contentType}, true, nil
}
// ensureRequestAsset stores decoded with the default request-asset options:
// the request-asset upload scene, source "ai-gateway-request", and no
// public-URL requirement.
func (s *Server) ensureRequestAsset(ctx context.Context, decoded decodedRequestAsset) (map[string]any, error) {
	var options requestAssetOptions
	options.UploadScene = store.FileStorageSceneRequestAsset
	options.Source = "ai-gateway-request"
	return s.ensureRequestAssetWithOptions(ctx, decoded, options)
}
// ensurePublicRequestAsset stores decoded and insists on a public URL, using
// the general upload scene and source "ai-gateway-form-data".
func (s *Server) ensurePublicRequestAsset(ctx context.Context, decoded decodedRequestAsset) (map[string]any, error) {
	var options requestAssetOptions
	options.RequirePublicURL = true
	options.UploadScene = store.FileStorageSceneUpload
	options.Source = "ai-gateway-form-data"
	return s.ensureRequestAssetWithOptions(ctx, decoded, options)
}
// ensureRequestAssetWithOptions returns an asset reference map for decoded,
// reusing a previously stored asset when possible and uploading otherwise.
//
// Assets are keyed by the SHA-256 of their bytes plus content type (blank
// content types become "application/octet-stream"). A stored asset is reused
// when it is still usable and, if options.RequirePublicURL is set, has a
// public URL; its reference count is then incremented. Otherwise the bytes are
// uploaded through the runner and the result is upserted into the store.
//
// Store errors for which store.IsUndefinedDatabaseObject reports true are
// tolerated: the lookup is skipped and, after upload, an unpersisted reference
// map is returned (presumably to keep working before the asset table exists —
// confirm against the store package).
//
// Errors: store/runner errors, "request_asset_upload_failed" when the upload
// response has no url, and "request_asset_public_url_required" when a public
// URL is required but the upload produced a non-public one.
func (s *Server) ensureRequestAssetWithOptions(ctx context.Context, decoded decodedRequestAsset, options requestAssetOptions) (map[string]any, error) {
// Content-addressed key for deduplication.
sum := sha256.Sum256(decoded.Bytes)
sha := hex.EncodeToString(sum[:])
contentType := strings.TrimSpace(decoded.ContentType)
if contentType == "" {
contentType = "application/octet-stream"
}
now := time.Now()
// Reuse path: an existing, unexpired asset that satisfies the URL requirement.
if existing, ok, err := s.store.FindRequestAsset(ctx, sha, contentType); err != nil && !store.IsUndefinedDatabaseObject(err) {
return nil, err
} else if ok && requestAssetStillUsable(existing, now) {
ref := requestAssetRef(existing)
if !options.RequirePublicURL || requestAssetRefHasPublicURL(ref) {
if err := s.store.IncrementRequestAssetRefCount(ctx, sha, contentType); err != nil && !store.IsUndefinedDatabaseObject(err) {
return nil, err
}
return ref, nil
}
// Otherwise fall through and upload again to obtain a public URL.
}
// Upload path: apply defaults for blank options.
uploadScene := strings.TrimSpace(options.UploadScene)
if uploadScene == "" {
uploadScene = store.FileStorageSceneRequestAsset
}
source := strings.TrimSpace(options.Source)
if source == "" {
source = "ai-gateway-request"
}
upload, err := s.runner.UploadFile(ctx, runner.FileUploadPayload{
Bytes: decoded.Bytes,
ContentType: contentType,
FileName: requestAssetFileName(sha, contentType),
Scene: uploadScene,
Source: source,
})
if err != nil {
return nil, err
}
storageProvider := requestAssetStorageProvider(upload)
url := stringFromRequestAny(upload["url"])
if url == "" {
return nil, &clients.ClientError{Code: "request_asset_upload_failed", Message: "file storage response did not include url", Retryable: false}
}
if options.RequirePublicURL && !requestAssetURLIsPublic(storageProvider, url) {
return nil, &clients.ClientError{Code: "request_asset_public_url_required", Message: "multipart image assets require a public file storage URL; enable a non-local file storage channel for uploads", Retryable: false}
}
// Only locally stored files get an expiry and a local path recorded.
var expiresAt *time.Time
localPath := ""
if storageProvider == "local_static" {
expiry := now.Add(time.Duration(s.localTempAssetTTLHours()) * time.Hour)
expiresAt = &expiry
localPath = requestAssetLocalPath(s.cfg.LocalUploadedStorageDir, stringFromRequestAny(upload["fileName"]))
}
asset, err := s.store.UpsertRequestAsset(ctx, store.RequestAssetInput{
SHA256: sha,
ContentType: contentType,
ByteSize: int64(len(decoded.Bytes)),
URL: url,
StorageProvider: storageProvider,
LocalPath: localPath,
ExpiresAt: expiresAt,
})
if err != nil {
if store.IsUndefinedDatabaseObject(err) {
// Persistence unavailable: hand back an equivalent reference built from the upload.
return map[string]any{
"sha256": sha,
"url": url,
"contentType": contentType,
"size": len(decoded.Bytes),
"storageProvider": storageProvider,
"expiresAt": timePtrToRFC3339(expiresAt),
}, nil
}
return nil, err
}
return requestAssetRef(asset), nil
}
// requestConversationKey resolves the caller-supplied conversation key.
//
// Lookup order: the X-EasyAI-Conversation-ID header (when r is non-nil), then
// body["conversation_id"] / body["conversationId"], then the same two keys
// inside body["metadata"]. It returns "" when none is set.
func requestConversationKey(r *http.Request, body map[string]any) string {
	var fromHeader string
	if r != nil {
		fromHeader = strings.TrimSpace(r.Header.Get("X-EasyAI-Conversation-ID"))
	}
	if fromHeader != "" {
		return fromHeader
	}

	if fromBody := firstNonEmptyRequestString(body, "conversation_id", "conversationId"); fromBody != "" {
		return fromBody
	}

	metadata, hasMetadata := body["metadata"].(map[string]any)
	if !hasMetadata {
		return ""
	}
	return firstNonEmptyRequestString(metadata, "conversation_id", "conversationId")
}
// messageRefsForRequest converts stored message references into the generic
// {"messageId", "position"} objects embedded in the request body. The result
// is never nil.
func messageRefsForRequest(refs []store.TaskMessageRefInput) []any {
	converted := make([]any, len(refs))
	for i := range refs {
		converted[i] = map[string]any{
			"messageId": refs[i].MessageID,
			"position":  refs[i].Position,
		}
	}
	return converted
}
func canonicalConversationMessageHash(message map[string]any) (string, []string) {
parts := map[string]any{
"role": stringFromRequestAny(message["role"]),
}
texts := []string{}
assetHashes := []string{}
collectConversationMessageParts(message, &texts, &assetHashes)
parts["text"] = strings.Join(texts, "\n")
parts["assets"] = assetHashes
if parts["text"] == "" && len(assetHashes) == 0 {
encoded, _ := json.Marshal(message)
parts["fallback"] = string(encoded)
}
encoded, _ := json.Marshal(parts)
sum := sha256.Sum256(encoded)
return hex.EncodeToString(sum[:]), uniqueRequestStrings(assetHashes)
}
// collectConversationMessageParts walks value and appends, in a deterministic
// order, the text fragments and asset SHA-256s it finds to texts and
// assetHashes.
//
// For objects: an "assetRef" object contributes its "sha256"; a non-empty
// string under "content" or "text" that is not a data URL contributes a text
// fragment; nested objects and arrays are walked recursively. Strings under
// other keys are ignored. For arrays every element is walked. A bare string
// contributes a fragment unless it is blank, a URL or a data URL.
//
// Object keys are visited in sorted order. Go randomizes map iteration, and
// the collected order feeds the canonical message hash, so ranging over the
// map directly would let the same message hash differently between calls.
func collectConversationMessageParts(value any, texts *[]string, assetHashes *[]string) {
	switch typed := value.(type) {
	case map[string]any:
		if ref, ok := typed["assetRef"].(map[string]any); ok {
			if sha := stringFromRequestAny(ref["sha256"]); sha != "" {
				*assetHashes = append(*assetHashes, sha)
			}
		}

		keys := make([]string, 0, len(typed))
		for key := range typed {
			keys = append(keys, key)
		}
		sort.Strings(keys)

		for _, key := range keys {
			item := typed[key]
			if key == "content" || key == "text" {
				if text := stringFromRequestAny(item); text != "" && !strings.HasPrefix(strings.ToLower(text), "data:") {
					*texts = append(*texts, text)
					continue
				}
			}
			switch item.(type) {
			case map[string]any, []any:
				collectConversationMessageParts(item, texts, assetHashes)
			}
		}
	case []any:
		for _, item := range typed {
			collectConversationMessageParts(item, texts, assetHashes)
		}
	case string:
		text := strings.TrimSpace(typed)
		if text != "" && !mediaURLString(text) && !strings.HasPrefix(strings.ToLower(text), "data:") {
			*texts = append(*texts, text)
		}
	}
}
// requestAssetWrapper builds the object that replaces an inline media value in
// the request: the full reference under "assetRef" plus its "url" copied to
// the top level.
func requestAssetWrapper(ref map[string]any) map[string]any {
	wrapper := make(map[string]any, 2)
	wrapper["url"] = ref["url"]
	wrapper["assetRef"] = ref
	return wrapper
}
// requestAssetRef renders a stored asset as the reference map embedded in
// request bodies. "expiresAt" is an RFC 3339 string, or nil when the asset has
// no expiry.
func requestAssetRef(asset store.RequestAsset) map[string]any {
	ref := make(map[string]any, 6)
	ref["sha256"] = asset.SHA256
	ref["url"] = asset.URL
	ref["contentType"] = asset.ContentType
	ref["size"] = asset.ByteSize
	ref["storageProvider"] = asset.StorageProvider
	ref["expiresAt"] = timePtrToRFC3339(asset.ExpiresAt)
	return ref
}
func requestAssetRefHasPublicURL(ref map[string]any) bool {
return requestAssetURLIsPublic(stringFromRequestAny(ref["storageProvider"]), stringFromRequestAny(ref["url"]))
}
// requestAssetURLIsPublic reports whether url can be handed out as a public
// link: the storage provider must not be "local_static" (case-insensitive) and
// the URL must start with http:// or https://.
func requestAssetURLIsPublic(storageProvider string, url string) bool {
	provider := strings.TrimSpace(storageProvider)
	if strings.EqualFold(provider, "local_static") {
		return false
	}
	normalized := strings.ToLower(strings.TrimSpace(url))
	for _, scheme := range [...]string{"http://", "https://"} {
		if strings.HasPrefix(normalized, scheme) {
			return true
		}
	}
	return false
}
// requestAssetStillUsable reports whether a stored asset may be reused at time
// now: it must not be marked expired, its expiry (if any) must lie strictly
// after now, and it must have a non-blank URL.
func requestAssetStillUsable(asset store.RequestAsset, now time.Time) bool {
	switch {
	case asset.ExpiredAt != nil:
		return false
	case asset.ExpiresAt != nil && !asset.ExpiresAt.After(now):
		return false
	}
	return strings.TrimSpace(asset.URL) != ""
}
// requestAssetStorageProvider extracts storageChannel.provider from an upload
// response, returning "unknown" when it is missing or empty.
func requestAssetStorageProvider(upload map[string]any) string {
	const fallback = "unknown"
	channel, ok := upload["storageChannel"].(map[string]any)
	if !ok {
		return fallback
	}
	provider := stringFromRequestAny(channel["provider"])
	if provider == "" {
		return fallback
	}
	return provider
}
// requestAssetLocalPath returns the on-disk location of a locally stored
// upload: the base name of fileName inside storageDir, or inside the
// configured default directory when storageDir is blank. A blank fileName
// yields "".
func requestAssetLocalPath(storageDir string, fileName string) string {
	if strings.TrimSpace(fileName) == "" {
		return ""
	}
	dir := storageDir
	if strings.TrimSpace(dir) == "" {
		dir = config.DefaultLocalUploadedStorageDir
	}
	// Only the base name is used, so directory components in fileName are dropped.
	return filepath.Join(dir, filepath.Base(fileName))
}
// requestAssetFileName derives the upload file name for an asset: the shared
// prefix, at most the first 16 characters of sha, and an extension chosen from
// contentType.
func requestAssetFileName(sha string, contentType string) string {
	const shaPrefixLength = 16
	short := sha
	if len(short) > shaPrefixLength {
		short = short[:shaPrefixLength]
	}
	stem := requestAssetFilePrefix + short
	return stem + requestAssetExtension(contentType)
}
// requestAssetExtension maps a content type to a file extension (with leading
// dot). Common image/audio/video types use fixed extensions; anything else is
// looked up through the mime package, and ".bin" is the last resort.
func requestAssetExtension(contentType string) string {
	normalized := strings.ToLower(strings.TrimSpace(contentType))
	switch normalized {
	case "image/jpeg", "image/jpg":
		return ".jpg"
	case "image/png":
		return ".png"
	case "image/gif":
		return ".gif"
	case "image/webp":
		return ".webp"
	case "audio/wav", "audio/x-wav":
		return ".wav"
	case "audio/mpeg":
		return ".mp3"
	case "video/mp4":
		return ".mp4"
	}

	registered, err := mime.ExtensionsByType(contentType)
	if err != nil || len(registered) == 0 {
		return ".bin"
	}
	return registered[0]
}
// requestAssetContentType picks the content type for a decoded payload.
//
// Precedence: a type declared by a sibling field ("content_type",
// "contentType", "mime_type", "mimeType"); then the explicit type (for
// example from a data URL); then the sniffed type when it is
// image/audio/video; finally a default based on the media kind implied by the
// field names — "audio/mpeg", "video/mp4", or "image/png" otherwise.
func requestAssetContentType(explicit string, payload []byte, key string, path []string, siblings map[string]any) string {
	if declared := firstNonEmptyRequestString(siblings, "content_type", "contentType", "mime_type", "mimeType"); declared != "" {
		return declared
	}
	if trimmed := strings.TrimSpace(explicit); trimmed != "" {
		return trimmed
	}
	if len(payload) > 0 {
		if sniffed := http.DetectContentType(payload); requestContentTypeIsMedia(sniffed) {
			return sniffed
		}
	}

	kind := requestMediaKind(key, path, siblings)
	if kind == "audio" {
		return "audio/mpeg"
	}
	if kind == "video" {
		return "video/mp4"
	}
	return "image/png"
}
// parseRequestDataURL splits a base64 data URL into its content type and
// still-encoded payload.
//
// It returns (contentType, payload, true, nil) on success. A value without a
// comma, or whose metadata lacks a "base64" marker, yields a
// request_asset_decode_failed client error. The content type may be empty
// when the URL does not declare one.
//
// The "data:" scheme is stripped case-insensitively. Callers detect data URLs
// by lowercasing the value, so spellings such as "Data:" reach this function
// and must not leak into the content type.
func parseRequestDataURL(value string) (string, string, bool, error) {
	const scheme = "data:"

	header, payload, found := strings.Cut(value, ",")
	if !found {
		return "", "", false, &clients.ClientError{Code: "request_asset_decode_failed", Message: "invalid data URL media payload", Retryable: false}
	}

	meta := header
	if len(meta) >= len(scheme) && strings.EqualFold(meta[:len(scheme)], scheme) {
		meta = meta[len(scheme):]
	}

	// meta is "<content type>[;param]...[;base64]".
	parts := strings.Split(meta, ";")
	contentType := strings.TrimSpace(parts[0])
	for _, part := range parts[1:] {
		if strings.EqualFold(strings.TrimSpace(part), "base64") {
			return contentType, payload, true, nil
		}
	}
	return "", "", false, &clients.ClientError{Code: "request_asset_decode_failed", Message: "data URL media payload is not base64 encoded", Retryable: false}
}
// decodeRequestBase64 decodes a base64 string leniently.
//
// Surrounding whitespace and embedded spaces, tabs and line breaks are
// removed, then the standard, unpadded standard, URL-safe and unpadded
// URL-safe alphabets are tried in that order. The first non-empty result
// wins. If nothing decodes, the last decoder error is returned, or
// "empty base64 payload" when the input decodes to zero bytes.
func decodeRequestBase64(value string) ([]byte, error) {
	compact := strings.Map(func(r rune) rune {
		if r == ' ' || r == '\t' || r == '\r' || r == '\n' {
			return -1
		}
		return r
	}, strings.TrimSpace(value))

	var failure error
	for _, codec := range [...]*base64.Encoding{
		base64.StdEncoding,
		base64.RawStdEncoding,
		base64.URLEncoding,
		base64.RawURLEncoding,
	} {
		decoded, err := codec.DecodeString(compact)
		switch {
		case err != nil:
			failure = err
		case len(decoded) > 0:
			return decoded, nil
		}
	}

	if failure != nil {
		return nil, failure
	}
	return nil, fmt.Errorf("empty base64 payload")
}
// requestAssetDecodeError wraps a decoding failure as a non-retryable client
// error with code "request_asset_decode_failed", keeping the original message.
func requestAssetDecodeError(err error) error {
	clientErr := clients.ClientError{
		Code:      "request_asset_decode_failed",
		Message:   err.Error(),
		Retryable: false,
	}
	return &clientErr
}
// strictRequestBase64Field reports whether a field is, by name, guaranteed to
// hold base64 data, so that a decode failure should be reported rather than
// ignored.
//
// That is the case when the lowercased key is "b64_json" or "b64", or contains
// "base64" or "_b64"; or when the key is "data" directly under "input_audio",
// "inlineData" or "inline_data" (the last path element, compared
// case-insensitively with surrounding brackets removed).
func strictRequestBase64Field(key string, path []string) bool {
	name := strings.ToLower(strings.TrimSpace(key))
	if name == "b64_json" || name == "b64" || strings.Contains(name, "base64") || strings.Contains(name, "_b64") {
		return true
	}
	if name != "data" || len(path) == 0 {
		return false
	}
	switch strings.ToLower(strings.Trim(path[len(path)-1], "[]")) {
	case "input_audio", "inlinedata", "inline_data":
		return true
	}
	return false
}
// likelyRequestBase64MediaField reports whether a non-strict field is worth
// trying to decode as inline media: the value must be at least 64 bytes long
// and the key or path must hint at a media kind.
func likelyRequestBase64MediaField(key string, path []string, value string) bool {
	const minimumInlineLength = 64
	return len(value) >= minimumInlineLength && requestMediaKind(key, path, nil) != ""
}
// requestMediaKind guesses the media kind ("audio", "video", "image" or "")
// from names around a value.
//
// Candidates are checked in order: the key, each path element, then the
// siblings' "type" and "role" values when siblings is non-nil. The first
// candidate that matches decides; within one candidate "audio" beats "video"
// beats "image". "mask" and "b64_json" count as image.
func requestMediaKind(key string, path []string, siblings map[string]any) string {
	names := make([]string, 0, len(path)+3)
	names = append(names, key)
	names = append(names, path...)
	if siblings != nil {
		names = append(names, stringFromRequestAny(siblings["type"]))
		names = append(names, stringFromRequestAny(siblings["role"]))
	}

	for _, candidate := range names {
		name := strings.ToLower(strings.TrimSpace(candidate))
		if strings.Contains(name, "audio") {
			return "audio"
		}
		if strings.Contains(name, "video") {
			return "video"
		}
		if name == "mask" || name == "b64_json" || strings.Contains(name, "image") {
			return "image"
		}
	}
	return ""
}
// requestContentTypeIsMedia reports whether contentType belongs to the image,
// video or audio families (case-insensitive, surrounding whitespace ignored).
func requestContentTypeIsMedia(contentType string) bool {
	normalized := strings.ToLower(strings.TrimSpace(contentType))
	for _, family := range [...]string{"image/", "video/", "audio/"} {
		if strings.HasPrefix(normalized, family) {
			return true
		}
	}
	return false
}
// mediaURLString reports whether value looks like a URL or absolute path
// rather than inline data: it is non-blank, is not a "data:" URL, and either
// starts with "/" or contains "://".
//
// Note that every value with a leading "/" counts, which includes base64 text
// that happens to begin with that character.
func mediaURLString(value string) bool {
	candidate := strings.ToLower(strings.TrimSpace(value))
	switch {
	case candidate == "":
		return false
	case strings.HasPrefix(candidate, "data:"):
		return false
	case strings.HasPrefix(candidate, "/"):
		return true
	default:
		// Covers http://, https:// and any other scheme.
		return strings.Contains(candidate, "://")
	}
}
// firstNonEmptyRequestString returns the first non-empty string found in
// values under keys, checked in the given order. Non-string entries count as
// empty. It returns "" for a nil map or when no key matches.
func firstNonEmptyRequestString(values map[string]any, keys ...string) string {
	if values == nil {
		return ""
	}
	for _, key := range keys {
		if found := stringFromRequestAny(values[key]); found != "" {
			return found
		}
	}
	return ""
}
// stringFromRequestAny returns value as a whitespace-trimmed string, or "" when
// value is not a string.
func stringFromRequestAny(value any) string {
	if text, isString := value.(string); isString {
		return strings.TrimSpace(text)
	}
	return ""
}
// uniqueRequestStrings trims every value, drops blanks and duplicates, and
// returns the remainder in first-seen order. The result is never nil.
func uniqueRequestStrings(values []string) []string {
	unique := make([]string, 0, len(values))
	known := make(map[string]struct{}, len(values))
	for _, candidate := range values {
		trimmed := strings.TrimSpace(candidate)
		if trimmed == "" {
			continue
		}
		if _, duplicate := known[trimmed]; duplicate {
			continue
		}
		known[trimmed] = struct{}{}
		unique = append(unique, trimmed)
	}
	return unique
}
func timePtrToRFC3339(value *time.Time) any {
if value == nil {
return nil
}
return value.UTC().Format(time.RFC3339)
}
// localTempAssetTTLHours returns the configured lifetime, in hours, of locally
// stored temporary assets, falling back to 24 when the setting is zero or
// negative.
func (s *Server) localTempAssetTTLHours() int {
	const defaultTTLHours = 24
	if hours := s.cfg.LocalTempAssetTTLHours; hours > 0 {
		return hours
	}
	return defaultTTLHours
}