easyai-ai-gateway/apps/api/internal/runner/request_assets.go

392 lines
12 KiB
Go

package runner
import (
"context"
"encoding/base64"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"time"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
// restoreTaskRequestReferences returns a copy of task.Request with its
// "messages" entry rebuilt from the conversation messages stored for the task.
//
// The copy is returned unchanged when it already carries "messages", when it
// has no "messageRefs", when the service has no store, or when the store
// returns no messages for the task. A store error is returned as-is.
func (s *Service) restoreTaskRequestReferences(ctx context.Context, task store.GatewayTask) (map[string]any, error) {
	request := cloneMap(task.Request)

	// Nothing to restore: messages are inline, there are no references to
	// follow, or there is nowhere to load them from.
	switch {
	case request["messages"] != nil:
		return request, nil
	case request["messageRefs"] == nil:
		return request, nil
	case s.store == nil:
		return request, nil
	}

	stored, err := s.store.ListTaskConversationMessages(ctx, task.ID)
	if err != nil {
		return nil, err
	}
	if len(stored) == 0 {
		return request, nil
	}

	restored := make([]any, len(stored))
	for i, item := range stored {
		restored[i] = item.Message
	}
	request["messages"] = restored
	return request, nil
}
// slimTaskRequestSnapshot returns a copy of body suitable for persisting next
// to task. When task.Request carries "messageRefs", the copy drops "messages",
// takes over "messageRefs", and also takes over any non-nil "conversationId",
// "conversationRecordId" and "newMessageCount" values from task.Request.
// Without "messageRefs" the copy is returned untouched.
func (s *Service) slimTaskRequestSnapshot(task store.GatewayTask, body map[string]any) map[string]any {
	snapshot := cloneMap(body)

	refs := task.Request["messageRefs"]
	if refs == nil {
		return snapshot
	}

	// NOTE(review): the writes below assume cloneMap never returns a nil map,
	// even for a nil body — confirm against cloneMap.
	delete(snapshot, "messages")
	snapshot["messageRefs"] = refs

	carryOver := func(key string) {
		if value := task.Request[key]; value != nil {
			snapshot[key] = value
		}
	}
	carryOver("conversationId")
	carryOver("conversationRecordId")
	carryOver("newMessageCount")

	return snapshot
}
// slimParameterPreprocessingLog returns log with its Input and Output replaced
// by the results of slimTaskRequestSnapshot for task.
func (s *Service) slimParameterPreprocessingLog(task store.GatewayTask, log parameterPreprocessingLog) parameterPreprocessingLog {
	slimmed := log
	slimmed.Input = s.slimTaskRequestSnapshot(task, log.Input)
	slimmed.Output = s.slimTaskRequestSnapshot(task, log.Output)
	return slimmed
}
// requestAssetHydrationStyle names the form an asset reference takes once it
// has been expanded into a provider request.
type requestAssetHydrationStyle string

const (
	// requestAssetHydrateURL substitutes the asset's URL.
	requestAssetHydrateURL requestAssetHydrationStyle = "url"
	// requestAssetHydrateDataURL substitutes a "data:<type>;base64,<payload>" string.
	requestAssetHydrateDataURL requestAssetHydrationStyle = "data_url"
	// requestAssetHydrateRawBase64 substitutes the base64 payload with no prefix.
	requestAssetHydrateRawBase64 requestAssetHydrationStyle = "raw_base64"
)
// hydrateProviderRequestAssets walks body and returns a new map in which every
// asset reference has been expanded for candidate. If the hydrated result is
// not a non-nil map[string]any, an empty map is returned instead.
func (s *Service) hydrateProviderRequestAssets(ctx context.Context, body map[string]any, candidate store.RuntimeModelCandidate) (map[string]any, error) {
	hydrated, err := s.hydrateProviderRequestAssetValue(ctx, body, nil, candidate)
	if err != nil {
		return nil, err
	}
	if result, ok := hydrated.(map[string]any); ok && result != nil {
		return result, nil
	}
	return map[string]any{}, nil
}
// hydrateProviderRequestAssetValue recursively copies value, replacing every
// map that holds an "assetRef" map with the hydrated form of that reference.
//
// path is the list of map keys and "[index]" segments leading to value; it is
// extended for each child and handed to hydrateProviderRequestAssetRef.
// Values that are neither map[string]any nor []any are returned unchanged.
// The first error met while descending aborts the walk.
func (s *Service) hydrateProviderRequestAssetValue(ctx context.Context, value any, path []string, candidate store.RuntimeModelCandidate) (any, error) {
	if object, isObject := value.(map[string]any); isObject {
		// A map carrying an assetRef is replaced as a whole by the hydrated value.
		if ref, isRef := object["assetRef"].(map[string]any); isRef {
			return s.hydrateProviderRequestAssetRef(ctx, ref, path, candidate)
		}
		hydrated := make(map[string]any, len(object))
		for field, child := range object {
			resolved, err := s.hydrateProviderRequestAssetValue(ctx, child, append(path, field), candidate)
			if err != nil {
				return nil, err
			}
			hydrated[field] = resolved
		}
		return hydrated, nil
	}

	if list, isList := value.([]any); isList {
		hydrated := make([]any, len(list))
		for i, child := range list {
			resolved, err := s.hydrateProviderRequestAssetValue(ctx, child, append(path, fmt.Sprintf("[%d]", i)), candidate)
			if err != nil {
				return nil, err
			}
			hydrated[i] = resolved
		}
		return hydrated, nil
	}

	return value, nil
}
// hydrateProviderRequestAssetRef resolves ref and returns the string that
// replaces it in the provider request. The form is chosen by
// requestAssetHydrationForField for path and candidate:
//
//   - data URL: "data:<content type>;base64,<payload>", with the content type
//     defaulting to "application/octet-stream" when the asset has none;
//   - raw base64: the base64 payload alone;
//   - otherwise: the asset URL, or an expired-asset error if the URL is blank.
func (s *Service) hydrateProviderRequestAssetRef(ctx context.Context, ref map[string]any, path []string, candidate store.RuntimeModelCandidate) (any, error) {
	asset, err := s.resolveRequestAsset(ctx, ref)
	if err != nil {
		return nil, err
	}

	style := requestAssetHydrationForField(path, candidate)
	if style == requestAssetHydrateDataURL || style == requestAssetHydrateRawBase64 {
		// Both inline forms need the asset bytes; they differ only in the prefix.
		payload, err := s.readRequestAssetBytes(ctx, asset)
		if err != nil {
			return nil, err
		}
		encoded := base64.StdEncoding.EncodeToString(payload)
		if style == requestAssetHydrateRawBase64 {
			return encoded, nil
		}
		mediaType := strings.TrimSpace(asset.ContentType)
		if mediaType == "" {
			mediaType = "application/octet-stream"
		}
		return "data:" + mediaType + ";base64," + encoded, nil
	}

	if strings.TrimSpace(asset.URL) == "" {
		return nil, requestAssetExpiredError(asset)
	}
	return asset.URL, nil
}
// resolveRequestAsset builds a store.RequestAsset from the fields of ref
// ("sha256", "contentType", "url", "storageProvider", "size", "expiresAt").
// When the service has a store and ref names both a hash and a content type,
// a record found in the store replaces the one built from ref.
//
// It returns the expired-asset error when the resulting asset is expired, and
// any store error other than an undefined-database-object error.
func (s *Service) resolveRequestAsset(ctx context.Context, ref map[string]any) (store.RequestAsset, error) {
	digest := stringFromAny(ref["sha256"])
	mediaType := stringFromAny(ref["contentType"])

	var asset store.RequestAsset
	asset.SHA256 = digest
	asset.ContentType = mediaType
	asset.URL = stringFromAny(ref["url"])
	asset.StorageProvider = stringFromAny(ref["storageProvider"])

	if size := floatFromAny(ref["size"]); size > 0 {
		asset.ByteSize = int64(size)
	}
	if raw := stringFromAny(ref["expiresAt"]); raw != "" {
		// A timestamp that is not valid RFC 3339 is ignored, leaving the
		// asset without an expiry.
		if deadline, err := time.Parse(time.RFC3339, raw); err == nil {
			asset.ExpiresAt = &deadline
		}
	}

	if s.store != nil && digest != "" && mediaType != "" {
		stored, found, err := s.store.FindRequestAsset(ctx, digest, mediaType)
		if err != nil && !store.IsUndefinedDatabaseObject(err) {
			return store.RequestAsset{}, err
		}
		if found {
			asset = stored
		}
	}

	if requestAssetIsExpired(asset, time.Now()) {
		return store.RequestAsset{}, requestAssetExpiredError(asset)
	}
	return asset, nil
}
// readRequestAssetBytes loads the raw bytes of asset, trying in order:
//
//  1. asset.LocalPath, when set;
//  2. a local file derived from asset.URL by localPathFromRequestAssetURL;
//  3. an HTTP GET of asset.URL when it starts with "http://" or "https://".
//
// Expired assets, unreadable local files and assets with no usable location
// all yield the error from requestAssetExpiredError. HTTP failures yield a
// *clients.ClientError with code "request_asset_fetch_failed".
//
// Remote downloads are capped at 256 MiB. A larger body is rejected with a
// non-retryable error instead of being silently truncated, because a cut-off
// payload would otherwise be encoded and forwarded as if it were complete.
func (s *Service) readRequestAssetBytes(ctx context.Context, asset store.RequestAsset) ([]byte, error) {
	// Upper bound on the size of a remotely fetched asset.
	const maxRemoteBytes = 256 << 20

	if requestAssetIsExpired(asset, time.Now()) {
		return nil, requestAssetExpiredError(asset)
	}

	if strings.TrimSpace(asset.LocalPath) != "" {
		payload, err := os.ReadFile(asset.LocalPath)
		if err != nil {
			// A missing or unreadable local file is reported as an expired asset.
			return nil, requestAssetExpiredError(asset)
		}
		return payload, nil
	}

	if localPath := s.localPathFromRequestAssetURL(asset.URL); localPath != "" {
		payload, err := os.ReadFile(localPath)
		if err != nil {
			return nil, requestAssetExpiredError(asset)
		}
		return payload, nil
	}

	if strings.HasPrefix(asset.URL, "http://") || strings.HasPrefix(asset.URL, "https://") {
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, asset.URL, nil)
		if err != nil {
			return nil, err
		}
		resp, err := http.DefaultClient.Do(req)
		if err != nil {
			return nil, &clients.ClientError{Code: "request_asset_fetch_failed", Message: err.Error(), Retryable: true}
		}
		defer resp.Body.Close()
		if resp.StatusCode < 200 || resp.StatusCode >= 300 {
			return nil, &clients.ClientError{Code: "request_asset_fetch_failed", Message: resp.Status, StatusCode: resp.StatusCode, Retryable: clients.HTTPRetryable(resp.StatusCode)}
		}
		// Read one byte past the cap so an oversized body can be told apart
		// from one that is exactly at the limit.
		payload, err := io.ReadAll(io.LimitReader(resp.Body, maxRemoteBytes+1))
		if err != nil {
			return nil, &clients.ClientError{Code: "request_asset_fetch_failed", Message: err.Error(), Retryable: true}
		}
		if len(payload) > maxRemoteBytes {
			return nil, &clients.ClientError{Code: "request_asset_fetch_failed", Message: "request asset exceeds the 256 MiB download limit", Retryable: false}
		}
		return payload, nil
	}

	return nil, requestAssetExpiredError(asset)
}
// localPathFromRequestAssetURL maps an uploaded-asset URL (or bare path) to the
// file it refers to inside the local upload directory. It returns "" unless
// the path starts with "/static/uploaded/" and the file name starts with
// "gateway-request-asset-".
//
// Only the base name of the path is used, so directory components in value
// never reach the returned path. The directory comes from
// s.cfg.LocalUploadedStorageDir, falling back to
// config.DefaultLocalUploadedStorageDir when that is blank.
func (s *Service) localPathFromRequestAssetURL(value string) string {
	const uploadedPrefix = "/static/uploaded/"
	const assetFilePrefix = "gateway-request-asset-"

	trimmed := strings.TrimSpace(value)
	if trimmed == "" {
		return ""
	}

	// Prefer the path component of a parseable URL; otherwise use the input as-is.
	location := trimmed
	if parsed, err := url.Parse(trimmed); err == nil && parsed.Path != "" {
		location = parsed.Path
	}
	if !strings.HasPrefix(location, uploadedPrefix) {
		return ""
	}

	name := filepath.Base(location[len(uploadedPrefix):])
	if !strings.HasPrefix(name, assetFilePrefix) {
		return ""
	}

	dir := strings.TrimSpace(s.cfg.LocalUploadedStorageDir)
	if dir == "" {
		dir = config.DefaultLocalUploadedStorageDir
	}
	return filepath.Join(dir, name)
}
// requestAssetHydrationForField picks how the asset at path is delivered to
// candidate. Precedence, highest first:
//
//  1. fields recognised by providerFieldNeedsRawBase64 get raw base64;
//  2. media URL fields use the style configured for the candidate, if any;
//  3. media URL fields of providers that need inline data get a data URL;
//  4. everything else gets the plain URL.
func requestAssetHydrationForField(path []string, candidate store.RuntimeModelCandidate) requestAssetHydrationStyle {
	if providerFieldNeedsRawBase64(path) {
		return requestAssetHydrateRawBase64
	}
	if !mediaURLFieldNeedsHydration(path) {
		return requestAssetHydrateURL
	}
	if configured := configuredRequestAssetMediaURLHydration(candidate, requestAssetMediaURLKind(path)); configured != "" {
		return configured
	}
	if providerMediaURLNeedsDataURL(candidate) {
		return requestAssetHydrateDataURL
	}
	return requestAssetHydrateURL
}
// mediaURLFieldNeedsHydration reports whether path ends in a "url" field whose
// parent is one of "image_url", "audio_url", "video_url" or "file_url".
func mediaURLFieldNeedsHydration(path []string) bool {
	key, parent := requestAssetFieldPath(path)
	if key != "url" {
		return false
	}
	switch parent {
	case "image_url", "audio_url", "video_url", "file_url":
		return true
	}
	return false
}
// providerFieldNeedsRawBase64 reports whether the field at path expects a bare
// base64 payload: the key is "b64_json" or "b64", the key contains "base64" or
// "_b64", or the field is "data" directly under "input_audio".
func providerFieldNeedsRawBase64(path []string) bool {
	key, parent := requestAssetFieldPath(path)
	switch {
	case key == "b64_json", key == "b64":
		return true
	case strings.Contains(key, "base64"), strings.Contains(key, "_b64"):
		// The substring match also covers a key that is exactly "base64".
		return true
	}
	return parent == "input_audio" && key == "data"
}
// requestAssetMediaURLKind returns the media kind ("image", "audio", "video"
// or "file") implied by the parent field of path, or "" when the parent is not
// one of the recognised "*_url" fields.
func requestAssetMediaURLKind(path []string) string {
	_, parent := requestAssetFieldPath(path)
	kinds := map[string]string{
		"image_url": "image",
		"audio_url": "audio",
		"video_url": "video",
		"file_url":  "file",
	}
	// A missing entry yields the zero value "", the "unknown kind" result.
	return kinds[parent]
}
// configuredRequestAssetMediaURLHydration looks up an explicit hydration style
// in candidate.PlatformConfig. Keys specific to kind ("image", "audio",
// "video", "file") are consulted first, then the kind-independent keys; each
// key exists in camelCase and snake_case. The first key holding a recognised
// style wins. It returns "" when no key does.
func configuredRequestAssetMediaURLHydration(candidate store.RuntimeModelCandidate, kind string) requestAssetHydrationStyle {
	kindSpecific := map[string][]string{
		"image": {"requestAssetImageURLFormat", "request_asset_image_url_format"},
		"audio": {"requestAssetAudioURLFormat", "request_asset_audio_url_format"},
		"video": {"requestAssetVideoURLFormat", "request_asset_video_url_format"},
		"file":  {"requestAssetFileURLFormat", "request_asset_file_url_format"},
	}
	generic := []string{
		"requestAssetMediaURLFormat",
		"request_asset_media_url_format",
		"mediaURLAssetFormat",
		"media_url_asset_format",
	}

	// An unknown kind contributes no specific keys, leaving only the generic ones.
	for _, group := range [][]string{kindSpecific[kind], generic} {
		for _, key := range group {
			if style := requestAssetHydrationStyleFromString(stringFromAny(candidate.PlatformConfig[key])); style != "" {
				return style
			}
		}
	}
	return ""
}
// requestAssetHydrationStyleFromString parses a configured style name. Input
// is trimmed, lower-cased, and has "-" and " " turned into "_" before it is
// matched against the accepted aliases. Unrecognised input yields "".
func requestAssetHydrationStyleFromString(value string) requestAssetHydrationStyle {
	lowered := strings.ToLower(strings.TrimSpace(value))
	normalized := strings.NewReplacer("-", "_", " ", "_").Replace(lowered)

	switch normalized {
	case "url",
		"remote_url",
		"public_url":
		return requestAssetHydrateURL
	case "data_url",
		"dataurl",
		"prefixed_base64",
		"base64_with_prefix",
		"base64_with_data_url_prefix":
		return requestAssetHydrateDataURL
	case "raw_base64",
		"base64",
		"bare_base64",
		"naked_base64":
		return requestAssetHydrateRawBase64
	}
	return ""
}
// providerMediaURLNeedsDataURL reports whether candidate's Provider, SpecType
// or PlatformKey normalises to one of "openai", "volces", "volces_openai",
// "gemini" or "vidu".
func providerMediaURLNeedsDataURL(candidate store.RuntimeModelCandidate) bool {
	identifiers := [...]string{candidate.Provider, candidate.SpecType, candidate.PlatformKey}
	for i := range identifiers {
		key := normalizeProviderKey(identifiers[i])
		if key == "openai" || key == "volces" || key == "volces_openai" || key == "gemini" || key == "vidu" {
			return true
		}
	}
	return false
}
// normalizeProviderKey canonicalises a provider identifier: surrounding
// whitespace is removed, letters are lower-cased, and every "-" or " " becomes
// "_".
func normalizeProviderKey(value string) string {
	lowered := strings.ToLower(strings.TrimSpace(value))
	return strings.NewReplacer("-", "_", " ", "_").Replace(lowered)
}
// requestAssetFieldPath returns the last two named segments of path as
// (key, parent). Segments are stripped of surrounding "[" and "]", trimmed and
// lower-cased; segments that end up empty or purely numeric (list indexes) are
// skipped. Either result is "" when there is no such segment.
func requestAssetFieldPath(path []string) (string, string) {
	var key, parent string
	found := 0
	// Scan from the end: the first named segment met is the key, the second
	// is its parent, and nothing earlier matters.
	for i := len(path) - 1; i >= 0 && found < 2; i-- {
		name := strings.ToLower(strings.TrimSpace(strings.Trim(path[i], "[]")))
		if name == "" || requestAssetPathSegmentIsIndex(name) {
			continue
		}
		if found == 0 {
			key = name
		} else {
			parent = name
		}
		found++
	}
	return key, parent
}
// requestAssetPathSegmentIsIndex reports whether value is a non-empty run of
// ASCII digits.
func requestAssetPathSegmentIsIndex(value string) bool {
	if len(value) == 0 {
		return false
	}
	// Byte-wise scan: any byte outside '0'..'9', including every byte of a
	// multi-byte rune, disqualifies the segment.
	for i := 0; i < len(value); i++ {
		if c := value[i]; c < '0' || c > '9' {
			return false
		}
	}
	return true
}
// requestAssetIsExpired reports whether asset counts as expired at now: either
// ExpiredAt is set, or ExpiresAt is set and is not after now.
func requestAssetIsExpired(asset store.RequestAsset, now time.Time) bool {
	switch {
	case asset.ExpiredAt != nil:
		return true
	case asset.ExpiresAt == nil:
		return false
	default:
		// A deadline equal to now is already expired.
		return !asset.ExpiresAt.After(now)
	}
}
// requestAssetExpiredError builds the non-retryable *clients.ClientError with
// code "request_asset_expired" used for expired or unavailable assets. The
// message includes the asset's SHA-256 when it is known.
func requestAssetExpiredError(asset store.RequestAsset) error {
	const base = "request asset is expired or unavailable"

	detail := base
	if asset.SHA256 != "" {
		detail = base + ": " + asset.SHA256
	}
	return &clients.ClientError{
		Code:      "request_asset_expired",
		Message:   detail,
		Retryable: false,
	}
}