Add runtime restore and temp asset cleanup
This commit is contained in:
parent
644a6f9d17
commit
d41d9482c7
@ -4,6 +4,7 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -23,6 +24,7 @@ type Config struct {
|
|||||||
PublicBaseURL string
|
PublicBaseURL string
|
||||||
LocalGeneratedStorageDir string
|
LocalGeneratedStorageDir string
|
||||||
LocalUploadedStorageDir string
|
LocalUploadedStorageDir string
|
||||||
|
LocalTempAssetTTLHours int
|
||||||
TaskProgressCallbackEnabled bool
|
TaskProgressCallbackEnabled bool
|
||||||
TaskProgressCallbackURL string
|
TaskProgressCallbackURL string
|
||||||
TaskProgressCallbackTimeoutMS string
|
TaskProgressCallbackTimeoutMS string
|
||||||
@ -49,6 +51,7 @@ func Load() Config {
|
|||||||
PublicBaseURL: strings.TrimRight(env("AI_GATEWAY_PUBLIC_BASE_URL", env("PUBLIC_BASE_URL", "")), "/"),
|
PublicBaseURL: strings.TrimRight(env("AI_GATEWAY_PUBLIC_BASE_URL", env("PUBLIC_BASE_URL", "")), "/"),
|
||||||
LocalGeneratedStorageDir: env("AI_GATEWAY_GENERATED_STORAGE_DIR", env("LOCAL_GENERATED_STORAGE_DIR", env("AI_GATEWAY_STATIC_STORAGE_DIR", DefaultLocalGeneratedStorageDir))),
|
LocalGeneratedStorageDir: env("AI_GATEWAY_GENERATED_STORAGE_DIR", env("LOCAL_GENERATED_STORAGE_DIR", env("AI_GATEWAY_STATIC_STORAGE_DIR", DefaultLocalGeneratedStorageDir))),
|
||||||
LocalUploadedStorageDir: env("AI_GATEWAY_UPLOADED_STORAGE_DIR", env("LOCAL_UPLOADED_STORAGE_DIR", DefaultLocalUploadedStorageDir)),
|
LocalUploadedStorageDir: env("AI_GATEWAY_UPLOADED_STORAGE_DIR", env("LOCAL_UPLOADED_STORAGE_DIR", DefaultLocalUploadedStorageDir)),
|
||||||
|
LocalTempAssetTTLHours: envInt("AI_GATEWAY_LOCAL_TEMP_ASSET_TTL_HOURS", 24),
|
||||||
TaskProgressCallbackEnabled: env("TASK_PROGRESS_CALLBACK_ENABLED", "true") == "true",
|
TaskProgressCallbackEnabled: env("TASK_PROGRESS_CALLBACK_ENABLED", "true") == "true",
|
||||||
TaskProgressCallbackURL: env("TASK_PROGRESS_CALLBACK_URL",
|
TaskProgressCallbackURL: env("TASK_PROGRESS_CALLBACK_URL",
|
||||||
strings.TrimRight(env("SERVER_MAIN_BASE_URL", "http://localhost:3000"), "/")+"/internal/platform/task-progress-callbacks",
|
strings.TrimRight(env("SERVER_MAIN_BASE_URL", "http://localhost:3000"), "/")+"/internal/platform/task-progress-callbacks",
|
||||||
@ -136,6 +139,18 @@ func env(key string, fallback string) string {
|
|||||||
return fallback
|
return fallback
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func envInt(key string, fallback int) int {
|
||||||
|
value := envValue(key)
|
||||||
|
if value == "" {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
parsed, err := strconv.Atoi(value)
|
||||||
|
if err != nil {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
return parsed
|
||||||
|
}
|
||||||
|
|
||||||
func logLevel(value string) slog.Level {
|
func logLevel(value string) slog.Level {
|
||||||
switch strings.ToLower(value) {
|
switch strings.ToLower(value) {
|
||||||
case "debug":
|
case "debug":
|
||||||
|
|||||||
@ -916,13 +916,26 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
responsePlan := planTaskResponse(kind, compatible, body, r)
|
responsePlan := planTaskResponse(kind, compatible, body, r)
|
||||||
|
prepared, err := s.prepareTaskRequest(r.Context(), r, user, body)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Warn("prepare task request failed", "kind", kind, "error", err)
|
||||||
|
status := http.StatusBadRequest
|
||||||
|
if code := clients.ErrorCode(err); strings.HasPrefix(code, "upload_") || code == "request_asset_upload_failed" {
|
||||||
|
status = http.StatusBadGateway
|
||||||
|
}
|
||||||
|
writeError(w, status, err.Error(), clients.ErrorCode(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
task, err := s.store.CreateTask(r.Context(), store.CreateTaskInput{
|
task, err := s.store.CreateTask(r.Context(), store.CreateTaskInput{
|
||||||
Kind: kind,
|
Kind: kind,
|
||||||
Model: model,
|
Model: model,
|
||||||
RunMode: runModeFromRequest(body),
|
RunMode: runModeFromRequest(prepared.Body),
|
||||||
Async: responsePlan.asyncMode,
|
Async: responsePlan.asyncMode,
|
||||||
Request: body,
|
Request: prepared.Body,
|
||||||
|
ConversationID: prepared.ConversationID,
|
||||||
|
NewMessageCount: prepared.NewMessageCount,
|
||||||
|
MessageRefs: prepared.MessageRefs,
|
||||||
}, user)
|
}, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Error("create task failed", "kind", kind, "error", err)
|
s.logger.Error("create task failed", "kind", kind, "error", err)
|
||||||
|
|||||||
70
apps/api/internal/httpapi/local_temp_assets.go
Normal file
70
apps/api/internal/httpapi/local_temp_assets.go
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
package httpapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Server) startLocalTempAssetCleanup(ctx context.Context) {
|
||||||
|
go func() {
|
||||||
|
s.cleanupExpiredLocalTempAssets(ctx, time.Now())
|
||||||
|
ticker := time.NewTicker(time.Hour)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case now := <-ticker.C:
|
||||||
|
s.cleanupExpiredLocalTempAssets(ctx, now)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) cleanupExpiredLocalTempAssets(ctx context.Context, now time.Time) int {
|
||||||
|
storageDir := strings.TrimSpace(s.cfg.LocalUploadedStorageDir)
|
||||||
|
if storageDir == "" {
|
||||||
|
storageDir = config.DefaultLocalUploadedStorageDir
|
||||||
|
}
|
||||||
|
entries, err := os.ReadDir(storageDir)
|
||||||
|
if err != nil {
|
||||||
|
if !errors.Is(err, os.ErrNotExist) {
|
||||||
|
s.logger.Warn("read local temp asset dir failed", "dir", storageDir, "error", err)
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
ttl := time.Duration(s.localTempAssetTTLHours()) * time.Hour
|
||||||
|
expiredBefore := now.Add(-ttl)
|
||||||
|
deleted := 0
|
||||||
|
for _, entry := range entries {
|
||||||
|
if entry.IsDir() || !strings.HasPrefix(entry.Name(), requestAssetFilePrefix) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
info, err := entry.Info()
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if info.ModTime().After(expiredBefore) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
localPath := filepath.Join(storageDir, entry.Name())
|
||||||
|
if err := os.Remove(localPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||||
|
s.logger.Warn("remove local temp asset failed", "path", localPath, "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
deleted++
|
||||||
|
if s.store != nil {
|
||||||
|
if err := s.store.MarkRequestAssetExpiredByLocalPath(ctx, localPath, now); err != nil && !store.IsUndefinedDatabaseObject(err) {
|
||||||
|
s.logger.Warn("mark local temp asset expired failed", "path", localPath, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return deleted
|
||||||
|
}
|
||||||
583
apps/api/internal/httpapi/request_preparation.go
Normal file
583
apps/api/internal/httpapi/request_preparation.go
Normal file
@ -0,0 +1,583 @@
|
|||||||
|
package httpapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"mime"
|
||||||
|
"net/http"
|
||||||
|
"path/filepath"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/runner"
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
// requestAssetFilePrefix is the file-name prefix requestAssetFileName gives to
// uploaded request assets. The local temp-asset cleanup only considers files
// whose names start with it.
const requestAssetFilePrefix = "gateway-request-asset-"
|
||||||
|
|
||||||
|
type preparedTaskRequest struct {
|
||||||
|
Body map[string]any
|
||||||
|
ConversationID string
|
||||||
|
MessageRefs []store.TaskMessageRefInput
|
||||||
|
NewMessageCount int
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodedRequestAsset is an inline media payload extracted from a request
// value, after base64 decoding.
type decodedRequestAsset struct {
	// Bytes holds the decoded payload.
	Bytes []byte
	// ContentType is the media type chosen for the payload.
	ContentType string
}
|
||||||
|
|
||||||
|
func (s *Server) prepareTaskRequest(ctx context.Context, r *http.Request, user *auth.User, body map[string]any) (preparedTaskRequest, error) {
|
||||||
|
preparedBody, err := s.prepareRequestAssetRefs(ctx, body)
|
||||||
|
if err != nil {
|
||||||
|
return preparedTaskRequest{}, err
|
||||||
|
}
|
||||||
|
result := preparedTaskRequest{Body: preparedBody}
|
||||||
|
conversationKey := requestConversationKey(r, preparedBody)
|
||||||
|
if conversationKey == "" {
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
messages, ok := preparedBody["messages"].([]any)
|
||||||
|
if !ok || len(messages) == 0 {
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
conversationID, err := s.store.EnsureConversation(ctx, user, conversationKey, map[string]any{
|
||||||
|
"source": "ai-gateway",
|
||||||
|
"conversationKey": conversationKey,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return preparedTaskRequest{}, err
|
||||||
|
}
|
||||||
|
inputs := make([]store.ConversationMessageInput, 0, len(messages))
|
||||||
|
for _, rawMessage := range messages {
|
||||||
|
message, _ := rawMessage.(map[string]any)
|
||||||
|
if message == nil {
|
||||||
|
message = map[string]any{"content": rawMessage}
|
||||||
|
}
|
||||||
|
hash, assetHashes := canonicalConversationMessageHash(message)
|
||||||
|
inputs = append(inputs, store.ConversationMessageInput{
|
||||||
|
Hash: hash,
|
||||||
|
Role: stringFromRequestAny(message["role"]),
|
||||||
|
Snapshot: message,
|
||||||
|
AssetSHA256s: assetHashes,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
refs, newCount, err := s.store.UpsertConversationMessages(ctx, conversationID, inputs)
|
||||||
|
if err != nil {
|
||||||
|
return preparedTaskRequest{}, err
|
||||||
|
}
|
||||||
|
preparedBody["conversationId"] = conversationKey
|
||||||
|
preparedBody["conversationRecordId"] = conversationID
|
||||||
|
preparedBody["messageRefs"] = messageRefsForRequest(refs)
|
||||||
|
preparedBody["newMessageCount"] = newCount
|
||||||
|
delete(preparedBody, "messages")
|
||||||
|
result.ConversationID = conversationID
|
||||||
|
result.MessageRefs = refs
|
||||||
|
result.NewMessageCount = newCount
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) prepareRequestAssetRefs(ctx context.Context, body map[string]any) (map[string]any, error) {
|
||||||
|
value, err := s.prepareRequestAssetValue(ctx, body, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
next, _ := value.(map[string]any)
|
||||||
|
if next == nil {
|
||||||
|
return map[string]any{}, nil
|
||||||
|
}
|
||||||
|
return next, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) prepareRequestAssetValue(ctx context.Context, value any, path []string, siblings map[string]any) (any, error) {
|
||||||
|
switch typed := value.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
if typed["assetRef"] != nil {
|
||||||
|
return typed, nil
|
||||||
|
}
|
||||||
|
next := make(map[string]any, len(typed))
|
||||||
|
keys := make([]string, 0, len(typed))
|
||||||
|
for key := range typed {
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
|
sort.Strings(keys)
|
||||||
|
for _, key := range keys {
|
||||||
|
item := typed[key]
|
||||||
|
if decoded, ok, err := requestAssetFromValue(key, path, item, typed); err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else if ok {
|
||||||
|
ref, err := s.ensureRequestAsset(ctx, decoded)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
next[key] = requestAssetWrapper(ref)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
prepared, err := s.prepareRequestAssetValue(ctx, item, append(path, key), typed)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
next[key] = prepared
|
||||||
|
}
|
||||||
|
return next, nil
|
||||||
|
case []any:
|
||||||
|
next := make([]any, 0, len(typed))
|
||||||
|
for index, item := range typed {
|
||||||
|
prepared, err := s.prepareRequestAssetValue(ctx, item, append(path, fmt.Sprintf("[%d]", index)), siblings)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
next = append(next, prepared)
|
||||||
|
}
|
||||||
|
return next, nil
|
||||||
|
default:
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestAssetFromValue(key string, path []string, value any, siblings map[string]any) (decodedRequestAsset, bool, error) {
|
||||||
|
text, ok := value.(string)
|
||||||
|
if !ok {
|
||||||
|
return decodedRequestAsset{}, false, nil
|
||||||
|
}
|
||||||
|
raw := strings.TrimSpace(text)
|
||||||
|
if raw == "" || mediaURLString(raw) {
|
||||||
|
return decodedRequestAsset{}, false, nil
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(strings.ToLower(raw), "data:") {
|
||||||
|
contentType, encoded, ok, err := parseRequestDataURL(raw)
|
||||||
|
if err != nil || !ok {
|
||||||
|
return decodedRequestAsset{}, false, err
|
||||||
|
}
|
||||||
|
payload, err := decodeRequestBase64(encoded)
|
||||||
|
if err != nil {
|
||||||
|
return decodedRequestAsset{}, false, requestAssetDecodeError(err)
|
||||||
|
}
|
||||||
|
return decodedRequestAsset{
|
||||||
|
Bytes: payload,
|
||||||
|
ContentType: requestAssetContentType(contentType, payload, key, path, siblings),
|
||||||
|
}, true, nil
|
||||||
|
}
|
||||||
|
strict := strictRequestBase64Field(key, path)
|
||||||
|
if !strict && !likelyRequestBase64MediaField(key, path, raw) {
|
||||||
|
return decodedRequestAsset{}, false, nil
|
||||||
|
}
|
||||||
|
payload, err := decodeRequestBase64(raw)
|
||||||
|
if err != nil {
|
||||||
|
if strict {
|
||||||
|
return decodedRequestAsset{}, false, requestAssetDecodeError(err)
|
||||||
|
}
|
||||||
|
return decodedRequestAsset{}, false, nil
|
||||||
|
}
|
||||||
|
contentType := requestAssetContentType("", payload, key, path, siblings)
|
||||||
|
if !strict && !requestContentTypeIsMedia(contentType) {
|
||||||
|
return decodedRequestAsset{}, false, nil
|
||||||
|
}
|
||||||
|
return decodedRequestAsset{Bytes: payload, ContentType: contentType}, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) ensureRequestAsset(ctx context.Context, decoded decodedRequestAsset) (map[string]any, error) {
|
||||||
|
sum := sha256.Sum256(decoded.Bytes)
|
||||||
|
sha := hex.EncodeToString(sum[:])
|
||||||
|
contentType := strings.TrimSpace(decoded.ContentType)
|
||||||
|
if contentType == "" {
|
||||||
|
contentType = "application/octet-stream"
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
if existing, ok, err := s.store.FindRequestAsset(ctx, sha, contentType); err != nil && !store.IsUndefinedDatabaseObject(err) {
|
||||||
|
return nil, err
|
||||||
|
} else if ok && requestAssetStillUsable(existing, now) {
|
||||||
|
if err := s.store.IncrementRequestAssetRefCount(ctx, sha, contentType); err != nil && !store.IsUndefinedDatabaseObject(err) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return requestAssetRef(existing), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
upload, err := s.runner.UploadFile(ctx, runner.FileUploadPayload{
|
||||||
|
Bytes: decoded.Bytes,
|
||||||
|
ContentType: contentType,
|
||||||
|
FileName: requestAssetFileName(sha, contentType),
|
||||||
|
Scene: store.FileStorageSceneRequestAsset,
|
||||||
|
Source: "ai-gateway-request",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
storageProvider := requestAssetStorageProvider(upload)
|
||||||
|
url := stringFromRequestAny(upload["url"])
|
||||||
|
if url == "" {
|
||||||
|
return nil, &clients.ClientError{Code: "request_asset_upload_failed", Message: "file storage response did not include url", Retryable: false}
|
||||||
|
}
|
||||||
|
var expiresAt *time.Time
|
||||||
|
localPath := ""
|
||||||
|
if storageProvider == "local_static" {
|
||||||
|
expiry := now.Add(time.Duration(s.localTempAssetTTLHours()) * time.Hour)
|
||||||
|
expiresAt = &expiry
|
||||||
|
localPath = requestAssetLocalPath(s.cfg.LocalUploadedStorageDir, stringFromRequestAny(upload["fileName"]))
|
||||||
|
}
|
||||||
|
asset, err := s.store.UpsertRequestAsset(ctx, store.RequestAssetInput{
|
||||||
|
SHA256: sha,
|
||||||
|
ContentType: contentType,
|
||||||
|
ByteSize: int64(len(decoded.Bytes)),
|
||||||
|
URL: url,
|
||||||
|
StorageProvider: storageProvider,
|
||||||
|
LocalPath: localPath,
|
||||||
|
ExpiresAt: expiresAt,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
if store.IsUndefinedDatabaseObject(err) {
|
||||||
|
return map[string]any{
|
||||||
|
"sha256": sha,
|
||||||
|
"url": url,
|
||||||
|
"contentType": contentType,
|
||||||
|
"size": len(decoded.Bytes),
|
||||||
|
"storageProvider": storageProvider,
|
||||||
|
"expiresAt": timePtrToRFC3339(expiresAt),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return requestAssetRef(asset), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestConversationKey(r *http.Request, body map[string]any) string {
|
||||||
|
if r != nil {
|
||||||
|
if value := strings.TrimSpace(r.Header.Get("X-EasyAI-Conversation-ID")); value != "" {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if value := firstNonEmptyRequestString(body, "conversation_id", "conversationId"); value != "" {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
if metadata, ok := body["metadata"].(map[string]any); ok {
|
||||||
|
return firstNonEmptyRequestString(metadata, "conversation_id", "conversationId")
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func messageRefsForRequest(refs []store.TaskMessageRefInput) []any {
|
||||||
|
out := make([]any, 0, len(refs))
|
||||||
|
for _, ref := range refs {
|
||||||
|
out = append(out, map[string]any{"messageId": ref.MessageID, "position": ref.Position})
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func canonicalConversationMessageHash(message map[string]any) (string, []string) {
|
||||||
|
parts := map[string]any{
|
||||||
|
"role": stringFromRequestAny(message["role"]),
|
||||||
|
}
|
||||||
|
texts := []string{}
|
||||||
|
assetHashes := []string{}
|
||||||
|
collectConversationMessageParts(message, &texts, &assetHashes)
|
||||||
|
parts["text"] = strings.Join(texts, "\n")
|
||||||
|
parts["assets"] = assetHashes
|
||||||
|
if parts["text"] == "" && len(assetHashes) == 0 {
|
||||||
|
encoded, _ := json.Marshal(message)
|
||||||
|
parts["fallback"] = string(encoded)
|
||||||
|
}
|
||||||
|
encoded, _ := json.Marshal(parts)
|
||||||
|
sum := sha256.Sum256(encoded)
|
||||||
|
return hex.EncodeToString(sum[:]), uniqueRequestStrings(assetHashes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectConversationMessageParts(value any, texts *[]string, assetHashes *[]string) {
|
||||||
|
switch typed := value.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
if ref, ok := typed["assetRef"].(map[string]any); ok {
|
||||||
|
if sha := stringFromRequestAny(ref["sha256"]); sha != "" {
|
||||||
|
*assetHashes = append(*assetHashes, sha)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for key, item := range typed {
|
||||||
|
switch key {
|
||||||
|
case "content", "text":
|
||||||
|
if text := stringFromRequestAny(item); text != "" && !strings.HasPrefix(strings.ToLower(text), "data:") {
|
||||||
|
*texts = append(*texts, text)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
switch item.(type) {
|
||||||
|
case map[string]any, []any:
|
||||||
|
collectConversationMessageParts(item, texts, assetHashes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case []any:
|
||||||
|
for _, item := range typed {
|
||||||
|
collectConversationMessageParts(item, texts, assetHashes)
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
text := strings.TrimSpace(typed)
|
||||||
|
if text != "" && !mediaURLString(text) && !strings.HasPrefix(strings.ToLower(text), "data:") {
|
||||||
|
*texts = append(*texts, text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// requestAssetWrapper builds the object that replaces an inline media value in
// a request body: the asset reference under "assetRef" plus a top-level copy
// of its "url".
func requestAssetWrapper(ref map[string]any) map[string]any {
	wrapped := make(map[string]any, 2)
	wrapped["url"] = ref["url"]
	wrapped["assetRef"] = ref
	return wrapped
}
|
||||||
|
|
||||||
|
func requestAssetRef(asset store.RequestAsset) map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"sha256": asset.SHA256,
|
||||||
|
"url": asset.URL,
|
||||||
|
"contentType": asset.ContentType,
|
||||||
|
"size": asset.ByteSize,
|
||||||
|
"storageProvider": asset.StorageProvider,
|
||||||
|
"expiresAt": timePtrToRFC3339(asset.ExpiresAt),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestAssetStillUsable(asset store.RequestAsset, now time.Time) bool {
|
||||||
|
if asset.ExpiredAt != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if asset.ExpiresAt != nil && !asset.ExpiresAt.After(now) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(asset.URL) != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestAssetStorageProvider(upload map[string]any) string {
|
||||||
|
if channel, ok := upload["storageChannel"].(map[string]any); ok {
|
||||||
|
if provider := stringFromRequestAny(channel["provider"]); provider != "" {
|
||||||
|
return provider
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestAssetLocalPath(storageDir string, fileName string) string {
|
||||||
|
if strings.TrimSpace(storageDir) == "" {
|
||||||
|
storageDir = config.DefaultLocalUploadedStorageDir
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(fileName) == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return filepath.Join(storageDir, filepath.Base(fileName))
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestAssetFileName(sha string, contentType string) string {
|
||||||
|
prefix := sha
|
||||||
|
if len(prefix) > 16 {
|
||||||
|
prefix = prefix[:16]
|
||||||
|
}
|
||||||
|
return requestAssetFilePrefix + prefix + requestAssetExtension(contentType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// requestAssetExtension maps a content type to a file extension, including
// the leading dot. A fixed table covers the common image, audio and video
// types; anything else is looked up with mime.ExtensionsByType, and ".bin" is
// used when no extension is known.
func requestAssetExtension(contentType string) string {
	known := map[string]string{
		"image/png":   ".png",
		"image/jpeg":  ".jpg",
		"image/jpg":   ".jpg",
		"image/webp":  ".webp",
		"image/gif":   ".gif",
		"audio/mpeg":  ".mp3",
		"audio/wav":   ".wav",
		"audio/x-wav": ".wav",
		"video/mp4":   ".mp4",
	}
	if ext, ok := known[strings.ToLower(strings.TrimSpace(contentType))]; ok {
		return ext
	}

	candidates, err := mime.ExtensionsByType(contentType)
	if err != nil || len(candidates) == 0 {
		return ".bin"
	}
	return candidates[0]
}
|
||||||
|
|
||||||
|
func requestAssetContentType(explicit string, payload []byte, key string, path []string, siblings map[string]any) string {
|
||||||
|
if contentType := firstNonEmptyRequestString(siblings, "content_type", "contentType", "mime_type", "mimeType"); contentType != "" {
|
||||||
|
return contentType
|
||||||
|
}
|
||||||
|
if explicit = strings.TrimSpace(explicit); explicit != "" {
|
||||||
|
return explicit
|
||||||
|
}
|
||||||
|
detected := ""
|
||||||
|
if len(payload) > 0 {
|
||||||
|
detected = http.DetectContentType(payload)
|
||||||
|
}
|
||||||
|
if requestContentTypeIsMedia(detected) {
|
||||||
|
return detected
|
||||||
|
}
|
||||||
|
switch requestMediaKind(key, path, siblings) {
|
||||||
|
case "audio":
|
||||||
|
return "audio/mpeg"
|
||||||
|
case "video":
|
||||||
|
return "video/mp4"
|
||||||
|
default:
|
||||||
|
return "image/png"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseRequestDataURL(value string) (string, string, bool, error) {
|
||||||
|
prefix, payload, ok := strings.Cut(value, ",")
|
||||||
|
if !ok {
|
||||||
|
return "", "", false, &clients.ClientError{Code: "request_asset_decode_failed", Message: "invalid data URL media payload", Retryable: false}
|
||||||
|
}
|
||||||
|
meta := strings.TrimPrefix(strings.TrimPrefix(prefix, "data:"), "DATA:")
|
||||||
|
parts := strings.Split(meta, ";")
|
||||||
|
contentType := strings.TrimSpace(parts[0])
|
||||||
|
for _, part := range parts[1:] {
|
||||||
|
if strings.EqualFold(strings.TrimSpace(part), "base64") {
|
||||||
|
return contentType, payload, true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", "", false, &clients.ClientError{Code: "request_asset_decode_failed", Message: "data URL media payload is not base64 encoded", Retryable: false}
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeRequestBase64 decodes value as base64, accepting the standard and
// URL-safe alphabets, each with or without padding. Surrounding whitespace
// and embedded spaces, tabs and line breaks are removed first.
//
// The first variant that decodes to a non-empty payload wins. If none does,
// the error of the last failing variant is returned, or "empty base64
// payload" when decoding succeeded but produced no bytes.
func decodeRequestBase64(value string) ([]byte, error) {
	compact := strings.Map(func(r rune) rune {
		if r == ' ' || r == '\t' || r == '\n' || r == '\r' {
			return -1
		}
		return r
	}, strings.TrimSpace(value))

	variants := [...]*base64.Encoding{
		base64.StdEncoding,
		base64.RawStdEncoding,
		base64.URLEncoding,
		base64.RawURLEncoding,
	}

	var lastErr error
	for _, variant := range variants {
		decoded, err := variant.DecodeString(compact)
		switch {
		case err != nil:
			lastErr = err
		case len(decoded) > 0:
			return decoded, nil
		}
	}

	if lastErr != nil {
		return nil, lastErr
	}
	return nil, fmt.Errorf("empty base64 payload")
}
|
||||||
|
|
||||||
|
func requestAssetDecodeError(err error) error {
|
||||||
|
return &clients.ClientError{Code: "request_asset_decode_failed", Message: err.Error(), Retryable: false}
|
||||||
|
}
|
||||||
|
|
||||||
|
// strictRequestBase64Field reports whether a field is one whose value is
// expected to be base64, judged by its name alone. That is the case when the
// lower-cased, trimmed key is "b64_json" or "b64", contains "base64" or
// "_b64", or is "data" directly under an "input_audio" parent (the last path
// element, ignoring surrounding brackets and letter case).
func strictRequestBase64Field(key string, path []string) bool {
	name := strings.ToLower(strings.TrimSpace(key))
	if strings.Contains(name, "base64") || strings.Contains(name, "_b64") {
		return true
	}

	switch name {
	case "b64_json", "b64":
		return true
	case "data":
		if len(path) == 0 {
			return false
		}
		parent := strings.ToLower(strings.Trim(path[len(path)-1], "[]"))
		return parent == "input_audio"
	}
	return false
}
|
||||||
|
|
||||||
|
func likelyRequestBase64MediaField(key string, path []string, value string) bool {
|
||||||
|
if len(value) < 64 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return requestMediaKind(key, path, nil) != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// requestMediaKind guesses the media kind ("audio", "video" or "image") of a
// field from names around it, returning "" when nothing matches.
//
// Candidates are checked in order: the key, each path element, then the
// string values of the siblings' "type" and "role" fields. The first
// candidate whose lower-cased, trimmed form contains "audio", "video" or
// "image" (tested in that order) decides; "mask" and "b64_json" also count
// as image.
func requestMediaKind(key string, path []string, siblings map[string]any) string {
	candidates := make([]string, 0, len(path)+3)
	candidates = append(candidates, key)
	candidates = append(candidates, path...)
	if siblings != nil {
		// Non-string values contribute an empty candidate, which never matches.
		kind, _ := siblings["type"].(string)
		role, _ := siblings["role"].(string)
		candidates = append(candidates, kind, role)
	}

	for _, candidate := range candidates {
		name := strings.ToLower(strings.TrimSpace(candidate))
		if strings.Contains(name, "audio") {
			return "audio"
		}
		if strings.Contains(name, "video") {
			return "video"
		}
		if strings.Contains(name, "image") || name == "mask" || name == "b64_json" {
			return "image"
		}
	}
	return ""
}
|
||||||
|
|
||||||
|
// requestContentTypeIsMedia reports whether contentType, ignoring letter case
// and surrounding whitespace, belongs to the image, video or audio families.
func requestContentTypeIsMedia(contentType string) bool {
	normalized := strings.ToLower(strings.TrimSpace(contentType))
	for _, family := range []string{"image/", "video/", "audio/"} {
		if strings.HasPrefix(normalized, family) {
			return true
		}
	}
	return false
}
|
||||||
|
|
||||||
|
// mediaURLString reports whether value looks like a URL or path rather than
// inline data: after trimming, it starts with "/" or contains "://" (which
// covers http and https). Blank strings and "data:" URLs are never treated as
// URLs. The check ignores letter case.
func mediaURLString(value string) bool {
	candidate := strings.ToLower(strings.TrimSpace(value))
	switch {
	case candidate == "":
		return false
	case strings.HasPrefix(candidate, "data:"):
		return false
	case strings.HasPrefix(candidate, "/"):
		return true
	}
	return strings.Contains(candidate, "://")
}
|
||||||
|
|
||||||
|
// firstNonEmptyRequestString returns the trimmed value of the first key in
// keys whose entry in values is a non-blank string. It returns "" when values
// is nil or no such key exists.
func firstNonEmptyRequestString(values map[string]any, keys ...string) string {
	if values == nil {
		return ""
	}
	for _, key := range keys {
		// Non-string entries are treated like missing ones.
		text, _ := values[key].(string)
		if text = strings.TrimSpace(text); text != "" {
			return text
		}
	}
	return ""
}
|
||||||
|
|
||||||
|
// stringFromRequestAny returns value as a whitespace-trimmed string, or ""
// when value is not a string.
func stringFromRequestAny(value any) string {
	if text, isString := value.(string); isString {
		return strings.TrimSpace(text)
	}
	return ""
}
|
||||||
|
|
||||||
|
// uniqueRequestStrings returns the trimmed, non-empty entries of values with
// duplicates removed, keeping the order of first occurrence. The result is
// never nil.
func uniqueRequestStrings(values []string) []string {
	unique := make([]string, 0, len(values))
	seen := make(map[string]struct{}, len(values))
	for _, raw := range values {
		item := strings.TrimSpace(raw)
		if item == "" {
			continue
		}
		if _, dup := seen[item]; dup {
			continue
		}
		seen[item] = struct{}{}
		unique = append(unique, item)
	}
	return unique
}
|
||||||
|
|
||||||
|
// timePtrToRFC3339 formats value as an RFC 3339 timestamp in UTC. It returns
// an untyped nil, not an empty string, when value is nil.
func timePtrToRFC3339(value *time.Time) any {
	if value != nil {
		return value.UTC().Format(time.RFC3339)
	}
	return nil
}
|
||||||
|
|
||||||
|
func (s *Server) localTempAssetTTLHours() int {
|
||||||
|
if s.cfg.LocalTempAssetTTLHours <= 0 {
|
||||||
|
return 24
|
||||||
|
}
|
||||||
|
return s.cfg.LocalTempAssetTTLHours
|
||||||
|
}
|
||||||
136
apps/api/internal/httpapi/request_preparation_test.go
Normal file
136
apps/api/internal/httpapi/request_preparation_test.go
Normal file
@ -0,0 +1,136 @@
|
|||||||
|
package httpapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestAssetFromValueDetectsDataURLAndRawBase64(t *testing.T) {
|
||||||
|
payload := base64.StdEncoding.EncodeToString([]byte("inline image"))
|
||||||
|
decoded, ok, err := requestAssetFromValue("url", []string{"messages", "[0]", "content", "[1]", "image_url"}, "data:image/png;base64,"+payload, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decode data URL: %v", err)
|
||||||
|
}
|
||||||
|
if !ok || decoded.ContentType != "image/png" || string(decoded.Bytes) != "inline image" {
|
||||||
|
t.Fatalf("unexpected data URL asset: ok=%v decoded=%+v", ok, decoded)
|
||||||
|
}
|
||||||
|
|
||||||
|
audio := base64.StdEncoding.EncodeToString([]byte("inline audio"))
|
||||||
|
decoded, ok, err = requestAssetFromValue("data", []string{"input_audio"}, audio, map[string]any{"format": "mp3"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decode raw audio: %v", err)
|
||||||
|
}
|
||||||
|
if !ok || decoded.ContentType != "audio/mpeg" || string(decoded.Bytes) != "inline audio" {
|
||||||
|
t.Fatalf("unexpected raw audio asset: ok=%v decoded=%+v", ok, decoded)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCanonicalConversationMessageHashUsesTextAndAssetRefs(t *testing.T) {
|
||||||
|
message := map[string]any{
|
||||||
|
"role": "user",
|
||||||
|
"content": []any{
|
||||||
|
map[string]any{"type": "text", "text": "describe it"},
|
||||||
|
map[string]any{"type": "image_url", "image_url": map[string]any{
|
||||||
|
"url": "https://cdn.example/a.png",
|
||||||
|
"assetRef": map[string]any{"sha256": "sha-a", "url": "https://cdn.example/a.png"},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
sameMessage := map[string]any{
|
||||||
|
"role": "user",
|
||||||
|
"content": []any{
|
||||||
|
map[string]any{"type": "text", "text": "describe it"},
|
||||||
|
map[string]any{"type": "image_url", "image_url": map[string]any{
|
||||||
|
"url": "https://different.example/a.png",
|
||||||
|
"assetRef": map[string]any{"sha256": "sha-a", "url": "https://different.example/a.png"},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
changedMessage := map[string]any{
|
||||||
|
"role": "user",
|
||||||
|
"content": "describe something else",
|
||||||
|
}
|
||||||
|
|
||||||
|
firstHash, assetHashes := canonicalConversationMessageHash(message)
|
||||||
|
secondHash, _ := canonicalConversationMessageHash(sameMessage)
|
||||||
|
changedHash, _ := canonicalConversationMessageHash(changedMessage)
|
||||||
|
if firstHash != secondHash {
|
||||||
|
t.Fatalf("message hash should ignore resource URL drift when asset sha is stable")
|
||||||
|
}
|
||||||
|
if firstHash == changedHash {
|
||||||
|
t.Fatalf("message hash should change when text changes")
|
||||||
|
}
|
||||||
|
if len(assetHashes) != 1 || assetHashes[0] != "sha-a" {
|
||||||
|
t.Fatalf("unexpected asset hashes: %+v", assetHashes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCleanupExpiredLocalTempAssetsOnlyDeletesExpiredPrefixedFiles(t *testing.T) {
|
||||||
|
storageDir := t.TempDir()
|
||||||
|
oldTemp := filepath.Join(storageDir, requestAssetFilePrefix+"old.png")
|
||||||
|
freshTemp := filepath.Join(storageDir, requestAssetFilePrefix+"fresh.png")
|
||||||
|
oldGenerated := filepath.Join(storageDir, "gateway-result-old.png")
|
||||||
|
for _, path := range []string{oldTemp, freshTemp, oldGenerated} {
|
||||||
|
if err := os.WriteFile(path, []byte("asset"), 0o644); err != nil {
|
||||||
|
t.Fatalf("write fixture %s: %v", path, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
if err := os.Chtimes(oldTemp, now.Add(-25*time.Hour), now.Add(-25*time.Hour)); err != nil {
|
||||||
|
t.Fatalf("touch old temp: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.Chtimes(freshTemp, now.Add(-23*time.Hour), now.Add(-23*time.Hour)); err != nil {
|
||||||
|
t.Fatalf("touch fresh temp: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.Chtimes(oldGenerated, now.Add(-25*time.Hour), now.Add(-25*time.Hour)); err != nil {
|
||||||
|
t.Fatalf("touch old generated: %v", err)
|
||||||
|
}
|
||||||
|
server := &Server{
|
||||||
|
cfg: config.Config{LocalUploadedStorageDir: storageDir, LocalTempAssetTTLHours: 24},
|
||||||
|
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
|
||||||
|
}
|
||||||
|
|
||||||
|
deleted := server.cleanupExpiredLocalTempAssets(context.Background(), now)
|
||||||
|
|
||||||
|
if deleted != 1 {
|
||||||
|
t.Fatalf("expected one expired temp asset delete, got %d", deleted)
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(oldTemp); !os.IsNotExist(err) {
|
||||||
|
t.Fatalf("old prefixed temp asset should be deleted, stat err=%v", err)
|
||||||
|
}
|
||||||
|
for _, path := range []string{freshTemp, oldGenerated} {
|
||||||
|
if _, err := os.Stat(path); err != nil {
|
||||||
|
t.Fatalf("non-expired or non-prefixed file should remain %s: %v", path, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestConversationKeyPriority(t *testing.T) {
|
||||||
|
request := httptest.NewRequest(http.MethodPost, "/api/v1/chat/completions", nil)
|
||||||
|
request.Header.Set("X-EasyAI-Conversation-ID", "from-header")
|
||||||
|
body := map[string]any{
|
||||||
|
"conversation_id": "from-body",
|
||||||
|
"metadata": map[string]any{"conversation_id": "from-metadata"},
|
||||||
|
}
|
||||||
|
if got := requestConversationKey(request, body); got != "from-header" {
|
||||||
|
t.Fatalf("expected header conversation id, got %q", got)
|
||||||
|
}
|
||||||
|
request.Header.Del("X-EasyAI-Conversation-ID")
|
||||||
|
if got := requestConversationKey(request, body); got != "from-body" {
|
||||||
|
t.Fatalf("expected body conversation id, got %q", got)
|
||||||
|
}
|
||||||
|
delete(body, "conversation_id")
|
||||||
|
if got := requestConversationKey(request, body); got != "from-metadata" {
|
||||||
|
t.Fatalf("expected metadata conversation id, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -134,6 +134,33 @@ func (s *Server) updatePlatformDynamicPriority(w http.ResponseWriter, r *http.Re
|
|||||||
writeJSON(w, http.StatusOK, item)
|
writeJSON(w, http.StatusOK, item)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// restorePlatformModelRuntimeStatus godoc
|
||||||
|
// @Summary 恢复平台模型运行状态
|
||||||
|
// @Description 管理端手动解除平台模型停用、模型冷却、平台冷却或平台禁用状态,使其重新参与路由。
|
||||||
|
// @Tags runtime
|
||||||
|
// @Produce json
|
||||||
|
// @Security BearerAuth
|
||||||
|
// @Param platformModelID path string true "平台模型 ID"
|
||||||
|
// @Success 200 {object} store.ModelRateLimitStatus
|
||||||
|
// @Failure 401 {object} ErrorEnvelope
|
||||||
|
// @Failure 403 {object} ErrorEnvelope
|
||||||
|
// @Failure 404 {object} ErrorEnvelope
|
||||||
|
// @Failure 500 {object} ErrorEnvelope
|
||||||
|
// @Router /api/admin/runtime/model-rate-limits/{platformModelID}/restore [post]
|
||||||
|
func (s *Server) restorePlatformModelRuntimeStatus(w http.ResponseWriter, r *http.Request) {
|
||||||
|
item, err := s.store.RestorePlatformModelRuntimeStatus(r.Context(), r.PathValue("platformModelID"))
|
||||||
|
if err != nil {
|
||||||
|
if store.IsNotFound(err) {
|
||||||
|
writeError(w, http.StatusNotFound, "platform model not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.logger.Error("restore platform model runtime status failed", "error", err)
|
||||||
|
writeError(w, http.StatusInternalServerError, "restore platform model runtime status failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusOK, item)
|
||||||
|
}
|
||||||
|
|
||||||
// createRuntimePolicySet godoc
|
// createRuntimePolicySet godoc
|
||||||
// @Summary 创建运行策略集
|
// @Summary 创建运行策略集
|
||||||
// @Description 管理端创建运行策略集,policyKey 和 name 必填。
|
// @Description 管理端创建运行策略集,policyKey 和 name 必填。
|
||||||
|
|||||||
@ -36,6 +36,7 @@ func NewServerWithContext(ctx context.Context, cfg config.Config, db *store.Stor
|
|||||||
}
|
}
|
||||||
server.auth.LocalAPIKeyVerifier = db.VerifyLocalAPIKey
|
server.auth.LocalAPIKeyVerifier = db.VerifyLocalAPIKey
|
||||||
server.runner.StartAsyncQueueWorker(ctx)
|
server.runner.StartAsyncQueueWorker(ctx)
|
||||||
|
server.startLocalTempAssetCleanup(ctx)
|
||||||
|
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
mux.HandleFunc("GET /healthz", server.health)
|
mux.HandleFunc("GET /healthz", server.health)
|
||||||
@ -103,6 +104,7 @@ func NewServerWithContext(ctx context.Context, cfg config.Config, db *store.Stor
|
|||||||
mux.Handle("DELETE /api/admin/runtime/policy-sets/{policySetID}", server.requireAdmin(auth.PermissionManager, http.HandlerFunc(server.deleteRuntimePolicySet)))
|
mux.Handle("DELETE /api/admin/runtime/policy-sets/{policySetID}", server.requireAdmin(auth.PermissionManager, http.HandlerFunc(server.deleteRuntimePolicySet)))
|
||||||
mux.Handle("GET /api/admin/runtime/runner-policy", server.requireAdmin(auth.PermissionPower, http.HandlerFunc(server.getRunnerPolicy)))
|
mux.Handle("GET /api/admin/runtime/runner-policy", server.requireAdmin(auth.PermissionPower, http.HandlerFunc(server.getRunnerPolicy)))
|
||||||
mux.Handle("PATCH /api/admin/runtime/runner-policy", server.requireAdmin(auth.PermissionManager, http.HandlerFunc(server.updateRunnerPolicy)))
|
mux.Handle("PATCH /api/admin/runtime/runner-policy", server.requireAdmin(auth.PermissionManager, http.HandlerFunc(server.updateRunnerPolicy)))
|
||||||
|
mux.Handle("POST /api/admin/runtime/model-rate-limits/{platformModelID}/restore", server.requireAdmin(auth.PermissionManager, http.HandlerFunc(server.restorePlatformModelRuntimeStatus)))
|
||||||
mux.Handle("GET /api/admin/config/network-proxy", server.requireAdmin(auth.PermissionPower, http.HandlerFunc(server.getNetworkProxyConfig)))
|
mux.Handle("GET /api/admin/config/network-proxy", server.requireAdmin(auth.PermissionPower, http.HandlerFunc(server.getNetworkProxyConfig)))
|
||||||
mux.Handle("GET /api/admin/system/file-storage/settings", server.requireAdmin(auth.PermissionPower, http.HandlerFunc(server.getFileStorageSettings)))
|
mux.Handle("GET /api/admin/system/file-storage/settings", server.requireAdmin(auth.PermissionPower, http.HandlerFunc(server.getFileStorageSettings)))
|
||||||
mux.Handle("PATCH /api/admin/system/file-storage/settings", server.requireAdmin(auth.PermissionManager, http.HandlerFunc(server.updateFileStorageSettings)))
|
mux.Handle("PATCH /api/admin/system/file-storage/settings", server.requireAdmin(auth.PermissionManager, http.HandlerFunc(server.updateFileStorageSettings)))
|
||||||
@ -173,7 +175,7 @@ func (s *Server) cors(next http.Handler) http.Handler {
|
|||||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||||
w.Header().Set("Vary", "Origin")
|
w.Header().Set("Vary", "Origin")
|
||||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||||
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Comfy-Api-Key, X-Async")
|
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Comfy-Api-Key, X-Async, X-EasyAI-Conversation-ID")
|
||||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
||||||
}
|
}
|
||||||
if r.Method == http.MethodOptions {
|
if r.Method == http.MethodOptions {
|
||||||
|
|||||||
@ -11,8 +11,18 @@ func stringFromMap(values map[string]any, key string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func stringFromAny(value any) string {
|
func stringFromAny(value any) string {
|
||||||
text, _ := value.(string)
|
switch typed := value.(type) {
|
||||||
return strings.TrimSpace(text)
|
case string:
|
||||||
|
return strings.TrimSpace(typed)
|
||||||
|
case map[string]any:
|
||||||
|
if text := stringFromAny(typed["url"]); text != "" {
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
if ref, ok := typed["assetRef"].(map[string]any); ok {
|
||||||
|
return stringFromAny(ref["url"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func boolFromMap(values map[string]any, key string) bool {
|
func boolFromMap(values map[string]any, key string) bool {
|
||||||
|
|||||||
249
apps/api/internal/runner/request_assets.go
Normal file
249
apps/api/internal/runner/request_assets.go
Normal file
@ -0,0 +1,249 @@
|
|||||||
|
package runner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Service) restoreTaskRequestReferences(ctx context.Context, task store.GatewayTask) (map[string]any, error) {
|
||||||
|
body := cloneMap(task.Request)
|
||||||
|
if body["messages"] != nil || body["messageRefs"] == nil || s.store == nil {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
refs, err := s.store.ListTaskConversationMessages(ctx, task.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(refs) == 0 {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
messages := make([]any, 0, len(refs))
|
||||||
|
for _, ref := range refs {
|
||||||
|
messages = append(messages, ref.Message)
|
||||||
|
}
|
||||||
|
body["messages"] = messages
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) slimTaskRequestSnapshot(task store.GatewayTask, body map[string]any) map[string]any {
|
||||||
|
out := cloneMap(body)
|
||||||
|
messageRefs := task.Request["messageRefs"]
|
||||||
|
if messageRefs == nil {
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
delete(out, "messages")
|
||||||
|
out["messageRefs"] = messageRefs
|
||||||
|
for _, key := range []string{"conversationId", "conversationRecordId", "newMessageCount"} {
|
||||||
|
if value := task.Request[key]; value != nil {
|
||||||
|
out[key] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) slimParameterPreprocessingLog(task store.GatewayTask, log parameterPreprocessingLog) parameterPreprocessingLog {
|
||||||
|
log.Input = s.slimTaskRequestSnapshot(task, log.Input)
|
||||||
|
log.Output = s.slimTaskRequestSnapshot(task, log.Output)
|
||||||
|
return log
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) hydrateProviderRequestAssets(ctx context.Context, body map[string]any) (map[string]any, error) {
|
||||||
|
value, err := s.hydrateProviderRequestAssetValue(ctx, body, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
out, _ := value.(map[string]any)
|
||||||
|
if out == nil {
|
||||||
|
return map[string]any{}, nil
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) hydrateProviderRequestAssetValue(ctx context.Context, value any, path []string) (any, error) {
|
||||||
|
switch typed := value.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
if ref, ok := typed["assetRef"].(map[string]any); ok {
|
||||||
|
return s.hydrateProviderRequestAssetRef(ctx, ref, path)
|
||||||
|
}
|
||||||
|
next := make(map[string]any, len(typed))
|
||||||
|
for key, item := range typed {
|
||||||
|
hydrated, err := s.hydrateProviderRequestAssetValue(ctx, item, append(path, key))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
next[key] = hydrated
|
||||||
|
}
|
||||||
|
return next, nil
|
||||||
|
case []any:
|
||||||
|
next := make([]any, 0, len(typed))
|
||||||
|
for index, item := range typed {
|
||||||
|
hydrated, err := s.hydrateProviderRequestAssetValue(ctx, item, append(path, fmt.Sprintf("[%d]", index)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
next = append(next, hydrated)
|
||||||
|
}
|
||||||
|
return next, nil
|
||||||
|
default:
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) hydrateProviderRequestAssetRef(ctx context.Context, ref map[string]any, path []string) (any, error) {
|
||||||
|
asset, err := s.resolveRequestAsset(ctx, ref)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if providerFieldNeedsBase64(path) {
|
||||||
|
payload, err := s.readRequestAssetBytes(ctx, asset)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return base64.StdEncoding.EncodeToString(payload), nil
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(asset.URL) == "" {
|
||||||
|
return nil, requestAssetExpiredError(asset)
|
||||||
|
}
|
||||||
|
return asset.URL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) resolveRequestAsset(ctx context.Context, ref map[string]any) (store.RequestAsset, error) {
|
||||||
|
sha := stringFromAny(ref["sha256"])
|
||||||
|
contentType := stringFromAny(ref["contentType"])
|
||||||
|
asset := store.RequestAsset{
|
||||||
|
SHA256: sha,
|
||||||
|
ContentType: contentType,
|
||||||
|
URL: stringFromAny(ref["url"]),
|
||||||
|
StorageProvider: stringFromAny(ref["storageProvider"]),
|
||||||
|
}
|
||||||
|
if size := floatFromAny(ref["size"]); size > 0 {
|
||||||
|
asset.ByteSize = int64(size)
|
||||||
|
}
|
||||||
|
if expiresAt := stringFromAny(ref["expiresAt"]); expiresAt != "" {
|
||||||
|
if parsed, err := time.Parse(time.RFC3339, expiresAt); err == nil {
|
||||||
|
asset.ExpiresAt = &parsed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if s.store != nil && sha != "" && contentType != "" {
|
||||||
|
if stored, ok, err := s.store.FindRequestAsset(ctx, sha, contentType); err != nil && !store.IsUndefinedDatabaseObject(err) {
|
||||||
|
return store.RequestAsset{}, err
|
||||||
|
} else if ok {
|
||||||
|
asset = stored
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if requestAssetIsExpired(asset, time.Now()) {
|
||||||
|
return store.RequestAsset{}, requestAssetExpiredError(asset)
|
||||||
|
}
|
||||||
|
return asset, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) readRequestAssetBytes(ctx context.Context, asset store.RequestAsset) ([]byte, error) {
|
||||||
|
if requestAssetIsExpired(asset, time.Now()) {
|
||||||
|
return nil, requestAssetExpiredError(asset)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(asset.LocalPath) != "" {
|
||||||
|
payload, err := os.ReadFile(asset.LocalPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, requestAssetExpiredError(asset)
|
||||||
|
}
|
||||||
|
return payload, nil
|
||||||
|
}
|
||||||
|
if localPath := s.localPathFromRequestAssetURL(asset.URL); localPath != "" {
|
||||||
|
payload, err := os.ReadFile(localPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, requestAssetExpiredError(asset)
|
||||||
|
}
|
||||||
|
return payload, nil
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(asset.URL, "http://") || strings.HasPrefix(asset.URL, "https://") {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, asset.URL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &clients.ClientError{Code: "request_asset_fetch_failed", Message: err.Error(), Retryable: true}
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
return nil, &clients.ClientError{Code: "request_asset_fetch_failed", Message: resp.Status, StatusCode: resp.StatusCode, Retryable: clients.HTTPRetryable(resp.StatusCode)}
|
||||||
|
}
|
||||||
|
payload, err := io.ReadAll(io.LimitReader(resp.Body, 256<<20))
|
||||||
|
if err != nil {
|
||||||
|
return nil, &clients.ClientError{Code: "request_asset_fetch_failed", Message: err.Error(), Retryable: true}
|
||||||
|
}
|
||||||
|
return payload, nil
|
||||||
|
}
|
||||||
|
return nil, requestAssetExpiredError(asset)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) localPathFromRequestAssetURL(value string) string {
|
||||||
|
raw := strings.TrimSpace(value)
|
||||||
|
if raw == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
pathValue := raw
|
||||||
|
if parsed, err := url.Parse(raw); err == nil && parsed.Path != "" {
|
||||||
|
pathValue = parsed.Path
|
||||||
|
}
|
||||||
|
const uploadedPrefix = "/static/uploaded/"
|
||||||
|
if !strings.HasPrefix(pathValue, uploadedPrefix) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
fileName := filepath.Base(strings.TrimPrefix(pathValue, uploadedPrefix))
|
||||||
|
if !strings.HasPrefix(fileName, "gateway-request-asset-") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
storageDir := strings.TrimSpace(s.cfg.LocalUploadedStorageDir)
|
||||||
|
if storageDir == "" {
|
||||||
|
storageDir = config.DefaultLocalUploadedStorageDir
|
||||||
|
}
|
||||||
|
return filepath.Join(storageDir, fileName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// providerFieldNeedsBase64 reports whether the field addressed by path expects
// base64 content. Segments are compared lower-cased with surrounding '[' and
// ']' characters trimmed. The last segment qualifies when it is "b64_json" or
// "b64", contains "base64" or "_b64", or is "data" directly under
// "input_audio". An empty path never qualifies.
func providerFieldNeedsBase64(path []string) bool {
	depth := len(path)
	if depth == 0 {
		return false
	}
	normalize := func(segment string) string {
		return strings.ToLower(strings.Trim(segment, "[]"))
	}

	leaf := normalize(path[depth-1])
	switch {
	case leaf == "b64_json", leaf == "b64":
		return true
	case strings.Contains(leaf, "base64"), strings.Contains(leaf, "_b64"):
		return true
	case leaf == "data" && depth > 1:
		return normalize(path[depth-2]) == "input_audio"
	}
	return false
}
|
||||||
|
|
||||||
|
func requestAssetIsExpired(asset store.RequestAsset, now time.Time) bool {
|
||||||
|
if asset.ExpiredAt != nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if asset.ExpiresAt != nil && !asset.ExpiresAt.After(now) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestAssetExpiredError(asset store.RequestAsset) error {
|
||||||
|
message := "request asset is expired or unavailable"
|
||||||
|
if asset.SHA256 != "" {
|
||||||
|
message = "request asset is expired or unavailable: " + asset.SHA256
|
||||||
|
}
|
||||||
|
return &clients.ClientError{Code: "request_asset_expired", Message: message, Retryable: false}
|
||||||
|
}
|
||||||
101
apps/api/internal/runner/request_assets_test.go
Normal file
101
apps/api/internal/runner/request_assets_test.go
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
package runner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHydrateProviderRequestAssetsConvertsStrictBase64Field(t *testing.T) {
|
||||||
|
storageDir := t.TempDir()
|
||||||
|
fileName := "gateway-request-asset-test.png"
|
||||||
|
if err := os.WriteFile(filepath.Join(storageDir, fileName), []byte("image bytes"), 0o644); err != nil {
|
||||||
|
t.Fatalf("write request asset: %v", err)
|
||||||
|
}
|
||||||
|
service := &Service{cfg: config.Config{LocalUploadedStorageDir: storageDir}}
|
||||||
|
body := map[string]any{
|
||||||
|
"model": "demo",
|
||||||
|
"b64_json": map[string]any{
|
||||||
|
"assetRef": map[string]any{
|
||||||
|
"sha256": "sha-test",
|
||||||
|
"contentType": "image/png",
|
||||||
|
"url": "/static/uploaded/" + fileName,
|
||||||
|
"storageProvider": "local_static",
|
||||||
|
},
|
||||||
|
"url": "/static/uploaded/" + fileName,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
hydrated, err := service.hydrateProviderRequestAssets(context.Background(), body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("hydrate request assets: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := stringFromAny(hydrated["b64_json"]), base64.StdEncoding.EncodeToString([]byte("image bytes")); got != want {
|
||||||
|
t.Fatalf("unexpected hydrated base64: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHydrateProviderRequestAssetsReturnsExpiredError(t *testing.T) {
|
||||||
|
expiredAt := time.Now().Add(-time.Minute).UTC().Format(time.RFC3339)
|
||||||
|
service := &Service{}
|
||||||
|
body := map[string]any{
|
||||||
|
"image_url": map[string]any{
|
||||||
|
"assetRef": map[string]any{
|
||||||
|
"sha256": "sha-expired",
|
||||||
|
"contentType": "image/png",
|
||||||
|
"url": "/static/uploaded/gateway-request-asset-expired.png",
|
||||||
|
"storageProvider": "local_static",
|
||||||
|
"expiresAt": expiredAt,
|
||||||
|
},
|
||||||
|
"url": "/static/uploaded/gateway-request-asset-expired.png",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := service.hydrateProviderRequestAssets(context.Background(), body)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected expired request asset error")
|
||||||
|
}
|
||||||
|
var clientErr *clients.ClientError
|
||||||
|
if !errors.As(err, &clientErr) || clientErr.Code != "request_asset_expired" {
|
||||||
|
t.Fatalf("expected request_asset_expired, got %T %v", err, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStringFromAnyReadsRequestAssetWrapperURL(t *testing.T) {
|
||||||
|
wrapper := map[string]any{
|
||||||
|
"assetRef": map[string]any{"url": "https://cdn.example/request.png"},
|
||||||
|
}
|
||||||
|
if got := stringFromAny(wrapper); got != "https://cdn.example/request.png" {
|
||||||
|
t.Fatalf("expected wrapper URL, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSlimTaskRequestSnapshotKeepsMessageRefs(t *testing.T) {
|
||||||
|
service := &Service{}
|
||||||
|
task := store.GatewayTask{Request: map[string]any{
|
||||||
|
"conversationId": "conv-1",
|
||||||
|
"messageRefs": []any{map[string]any{"messageId": "msg-1", "position": 0}},
|
||||||
|
"newMessageCount": 1,
|
||||||
|
}}
|
||||||
|
body := map[string]any{
|
||||||
|
"model": "demo",
|
||||||
|
"messages": []any{map[string]any{"role": "user", "content": "hello"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
snapshot := service.slimTaskRequestSnapshot(task, body)
|
||||||
|
|
||||||
|
if snapshot["messages"] != nil {
|
||||||
|
t.Fatalf("snapshot should not persist restored messages: %+v", snapshot)
|
||||||
|
}
|
||||||
|
if snapshot["messageRefs"] == nil || snapshot["newMessageCount"] != 1 {
|
||||||
|
t.Fatalf("snapshot should keep message refs and new count: %+v", snapshot)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -87,9 +87,13 @@ func (s *Service) ExecuteStream(ctx context.Context, task store.GatewayTask, use
|
|||||||
|
|
||||||
func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *auth.User, onDelta clients.StreamDelta) (Result, error) {
|
func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *auth.User, onDelta clients.StreamDelta) (Result, error) {
|
||||||
executeStartedAt := time.Now()
|
executeStartedAt := time.Now()
|
||||||
body := normalizeRequest(task.Kind, task.Request)
|
restoredRequest, err := s.restoreTaskRequestReferences(ctx, task)
|
||||||
|
if err != nil {
|
||||||
|
return Result{}, err
|
||||||
|
}
|
||||||
|
body := normalizeRequest(task.Kind, restoredRequest)
|
||||||
modelType := modelTypeFromKind(task.Kind, body)
|
modelType := modelTypeFromKind(task.Kind, body)
|
||||||
if err := s.store.MarkTaskRunning(ctx, task.ID, modelType, body); err != nil {
|
if err := s.store.MarkTaskRunning(ctx, task.ID, modelType, s.slimTaskRequestSnapshot(task, body)); err != nil {
|
||||||
return Result{}, err
|
return Result{}, err
|
||||||
}
|
}
|
||||||
if task.Status != "running" {
|
if task.Status != "running" {
|
||||||
@ -194,7 +198,7 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
|
|||||||
}
|
}
|
||||||
return Result{Task: failed, Output: failed.Result}, clientErr
|
return Result{Task: failed, Output: failed.Result}, clientErr
|
||||||
}
|
}
|
||||||
if err := s.store.MarkTaskRunning(ctx, task.ID, candidates[0].ModelType, firstCandidateBody); err != nil {
|
if err := s.store.MarkTaskRunning(ctx, task.ID, candidates[0].ModelType, s.slimTaskRequestSnapshot(task, firstCandidateBody)); err != nil {
|
||||||
return Result{}, err
|
return Result{}, err
|
||||||
}
|
}
|
||||||
estimatedBillings := s.estimatedBillings(ctx, user, task.Kind, firstCandidateBody, candidates[0])
|
estimatedBillings := s.estimatedBillings(ctx, user, task.Kind, firstCandidateBody, candidates[0])
|
||||||
@ -508,13 +512,13 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user
|
|||||||
QueueKey: candidate.QueueKey,
|
QueueKey: candidate.QueueKey,
|
||||||
Status: "running",
|
Status: "running",
|
||||||
Simulated: simulated,
|
Simulated: simulated,
|
||||||
RequestSnapshot: body,
|
RequestSnapshot: s.slimTaskRequestSnapshot(task, body),
|
||||||
Metrics: baseAttemptMetrics,
|
Metrics: baseAttemptMetrics,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return clients.Response{}, fmt.Errorf("create task attempt: %w", err)
|
return clients.Response{}, fmt.Errorf("create task attempt: %w", err)
|
||||||
}
|
}
|
||||||
if err := s.recordTaskParameterPreprocessing(ctx, task.ID, attemptID, attemptNo, candidate, preprocessing); err != nil {
|
if err := s.recordTaskParameterPreprocessing(ctx, task.ID, attemptID, attemptNo, candidate, s.slimParameterPreprocessingLog(task, preprocessing)); err != nil {
|
||||||
clientErr := &clients.ClientError{Code: "runtime_error", Message: err.Error(), Retryable: false}
|
clientErr := &clients.ClientError{Code: "runtime_error", Message: err.Error(), Retryable: false}
|
||||||
_ = s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{
|
_ = s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{
|
||||||
AttemptID: attemptID,
|
AttemptID: attemptID,
|
||||||
@ -560,12 +564,24 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user
|
|||||||
return clients.Response{}, fmt.Errorf("prepare http client: %w", err)
|
return clients.Response{}, fmt.Errorf("prepare http client: %w", err)
|
||||||
}
|
}
|
||||||
client := s.clientFor(candidate, simulated)
|
client := s.clientFor(candidate, simulated)
|
||||||
|
providerBody, err := s.hydrateProviderRequestAssets(ctx, body)
|
||||||
|
if err != nil {
|
||||||
|
_ = s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{
|
||||||
|
AttemptID: attemptID,
|
||||||
|
Status: "failed",
|
||||||
|
Retryable: false,
|
||||||
|
Metrics: mergeMetrics(baseAttemptMetrics, map[string]any{"error": err.Error(), "retryable": false, "trace": []any{failureTraceEntry(err, false)}}),
|
||||||
|
ErrorCode: clients.ErrorCode(err),
|
||||||
|
ErrorMessage: err.Error(),
|
||||||
|
})
|
||||||
|
return clients.Response{}, err
|
||||||
|
}
|
||||||
callStartedAt := time.Now()
|
callStartedAt := time.Now()
|
||||||
response, err := client.Run(ctx, clients.Request{
|
response, err := client.Run(ctx, clients.Request{
|
||||||
Kind: task.Kind,
|
Kind: task.Kind,
|
||||||
ModelType: candidate.ModelType,
|
ModelType: candidate.ModelType,
|
||||||
Model: task.Model,
|
Model: task.Model,
|
||||||
Body: body,
|
Body: providerBody,
|
||||||
Candidate: candidate,
|
Candidate: candidate,
|
||||||
HTTPClient: requestHTTPClient,
|
HTTPClient: requestHTTPClient,
|
||||||
RemoteTaskID: task.RemoteTaskID,
|
RemoteTaskID: task.RemoteTaskID,
|
||||||
@ -576,7 +592,7 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user
|
|||||||
}
|
}
|
||||||
return s.store.SetTaskRemoteTask(context.WithoutCancel(ctx), task.ID, attemptID, remoteTaskID, payload)
|
return s.store.SetTaskRemoteTask(context.WithoutCancel(ctx), task.ID, attemptID, remoteTaskID, payload)
|
||||||
},
|
},
|
||||||
Stream: boolFromMap(body, "stream"),
|
Stream: boolFromMap(providerBody, "stream"),
|
||||||
StreamDelta: onDelta,
|
StreamDelta: onDelta,
|
||||||
})
|
})
|
||||||
callFinishedAt := time.Now()
|
callFinishedAt := time.Now()
|
||||||
@ -826,7 +842,7 @@ func (s *Service) recordFailedAttempt(ctx context.Context, input failedAttemptRe
|
|||||||
QueueKey: queueKey,
|
QueueKey: queueKey,
|
||||||
Status: "running",
|
Status: "running",
|
||||||
Simulated: input.Simulated,
|
Simulated: input.Simulated,
|
||||||
RequestSnapshot: input.Body,
|
RequestSnapshot: s.slimTaskRequestSnapshot(input.Task, input.Body),
|
||||||
Metrics: metrics,
|
Metrics: metrics,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -834,7 +850,8 @@ func (s *Service) recordFailedAttempt(ctx context.Context, input failedAttemptRe
|
|||||||
return attemptNo
|
return attemptNo
|
||||||
}
|
}
|
||||||
if input.Preprocessing != nil && input.Candidate != nil {
|
if input.Preprocessing != nil && input.Candidate != nil {
|
||||||
if err := s.recordTaskParameterPreprocessing(ctx, input.Task.ID, attemptID, attemptNo, *input.Candidate, *input.Preprocessing); err != nil {
|
preprocessing := s.slimParameterPreprocessingLog(input.Task, *input.Preprocessing)
|
||||||
|
if err := s.recordTaskParameterPreprocessing(ctx, input.Task.ID, attemptID, attemptNo, *input.Candidate, preprocessing); err != nil {
|
||||||
s.logger.Warn("record failed attempt parameter preprocessing failed", "taskID", input.Task.ID, "attempt", attemptNo, "error", err)
|
s.logger.Warn("record failed attempt parameter preprocessing failed", "taskID", input.Task.ID, "attempt", attemptNo, "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
149
apps/api/internal/store/conversations.go
Normal file
149
apps/api/internal/store/conversations.go
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConversationMessageInput is one message handed to
// Store.UpsertConversationMessages.
type ConversationMessageInput struct {
	// Hash is written to message_hash; together with the conversation ID it is
	// the conflict key that deduplicates messages.
	Hash string
	// Role is written to the role column; an empty string is stored as NULL.
	Role string
	// Snapshot is serialized to JSON and written to message_snapshot.
	Snapshot map[string]any
	// AssetSHA256s is written to the asset_sha256s column.
	AssetSHA256s []string
}
|
||||||
|
|
||||||
|
// TaskMessageRefInput links a task to a stored conversation message at a given
// position. insertTaskMessageRefs writes one gateway_task_message_refs row per
// value and skips entries whose MessageID is blank.
type TaskMessageRefInput struct {
	// MessageID is the conversation message ID (cast to uuid on insert).
	MessageID string
	// Position is the index of the message within the task's message list.
	Position int
}
|
||||||
|
|
||||||
|
// ConversationMessageRef is one message referenced by a task, as returned by
// Store.ListTaskConversationMessages.
type ConversationMessageRef struct {
	// MessageID is the referenced conversation message ID.
	MessageID string `json:"messageId"`
	// Position is the message's position within the task.
	Position int `json:"position"`
	// Message is the decoded message_snapshot of the referenced message.
	Message map[string]any `json:"message"`
}
|
||||||
|
|
||||||
|
func (s *Store) EnsureConversation(ctx context.Context, user *auth.User, conversationKey string, metadata map[string]any) (string, error) {
|
||||||
|
conversationKey = strings.TrimSpace(conversationKey)
|
||||||
|
if conversationKey == "" {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
userID := ""
|
||||||
|
gatewayUserID := ""
|
||||||
|
if user != nil {
|
||||||
|
userID = strings.TrimSpace(user.ID)
|
||||||
|
gatewayUserID = strings.TrimSpace(user.GatewayUserID)
|
||||||
|
}
|
||||||
|
if userID == "" {
|
||||||
|
userID = "anonymous"
|
||||||
|
}
|
||||||
|
metadataJSON, _ := json.Marshal(emptyObjectIfNil(metadata))
|
||||||
|
var conversationID string
|
||||||
|
err := s.pool.QueryRow(ctx, `
|
||||||
|
INSERT INTO gateway_conversations (user_id, gateway_user_id, conversation_key, metadata)
|
||||||
|
VALUES ($1, NULLIF($2, '')::uuid, $3, $4::jsonb)
|
||||||
|
ON CONFLICT (user_id, conversation_key) DO UPDATE
|
||||||
|
SET gateway_user_id = COALESCE(gateway_conversations.gateway_user_id, EXCLUDED.gateway_user_id),
|
||||||
|
metadata = gateway_conversations.metadata || EXCLUDED.metadata,
|
||||||
|
updated_at = now()
|
||||||
|
RETURNING id::text`, userID, gatewayUserID, conversationKey, string(metadataJSON)).Scan(&conversationID)
|
||||||
|
return conversationID, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) UpsertConversationMessages(ctx context.Context, conversationID string, messages []ConversationMessageInput) ([]TaskMessageRefInput, int, error) {
|
||||||
|
if strings.TrimSpace(conversationID) == "" || len(messages) == 0 {
|
||||||
|
return nil, 0, nil
|
||||||
|
}
|
||||||
|
tx, err := s.pool.Begin(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
defer tx.Rollback(ctx)
|
||||||
|
|
||||||
|
refs := make([]TaskMessageRefInput, 0, len(messages))
|
||||||
|
newCount := 0
|
||||||
|
for index, message := range messages {
|
||||||
|
snapshotJSON, _ := json.Marshal(emptyObjectIfNil(message.Snapshot))
|
||||||
|
var messageID string
|
||||||
|
var inserted bool
|
||||||
|
if err := tx.QueryRow(ctx, `
|
||||||
|
INSERT INTO gateway_conversation_messages (
|
||||||
|
conversation_id, message_hash, role, message_snapshot, asset_sha256s
|
||||||
|
)
|
||||||
|
VALUES ($1::uuid, $2, NULLIF($3, ''), $4::jsonb, $5)
|
||||||
|
ON CONFLICT (conversation_id, message_hash) DO UPDATE
|
||||||
|
SET updated_at = gateway_conversation_messages.updated_at
|
||||||
|
RETURNING id::text, (xmax = 0) AS inserted`,
|
||||||
|
conversationID,
|
||||||
|
message.Hash,
|
||||||
|
message.Role,
|
||||||
|
string(snapshotJSON),
|
||||||
|
message.AssetSHA256s,
|
||||||
|
).Scan(&messageID, &inserted); err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
if inserted {
|
||||||
|
newCount++
|
||||||
|
}
|
||||||
|
refs = append(refs, TaskMessageRefInput{MessageID: messageID, Position: index})
|
||||||
|
}
|
||||||
|
if err := tx.Commit(ctx); err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
return refs, newCount, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) ListTaskConversationMessages(ctx context.Context, taskID string) ([]ConversationMessageRef, error) {
|
||||||
|
rows, err := s.pool.Query(ctx, `
|
||||||
|
SELECT refs.message_id::text, refs.position, messages.message_snapshot
|
||||||
|
FROM gateway_task_message_refs refs
|
||||||
|
JOIN gateway_conversation_messages messages ON messages.id = refs.message_id
|
||||||
|
WHERE refs.task_id = $1::uuid
|
||||||
|
ORDER BY refs.position ASC`, taskID)
|
||||||
|
if err != nil {
|
||||||
|
if IsUndefinedDatabaseObject(err) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
items := make([]ConversationMessageRef, 0)
|
||||||
|
for rows.Next() {
|
||||||
|
var item ConversationMessageRef
|
||||||
|
var snapshot []byte
|
||||||
|
if err := rows.Scan(&item.MessageID, &item.Position, &snapshot); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
item.Message = decodeObject(snapshot)
|
||||||
|
items = append(items, item)
|
||||||
|
}
|
||||||
|
return items, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func insertTaskMessageRefs(ctx context.Context, tx pgx.Tx, taskID string, refs []TaskMessageRefInput) error {
|
||||||
|
if len(refs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
for _, ref := range refs {
|
||||||
|
if strings.TrimSpace(ref.MessageID) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, err := tx.Exec(ctx, `
|
||||||
|
INSERT INTO gateway_task_message_refs (task_id, message_id, position)
|
||||||
|
VALUES ($1::uuid, $2::uuid, $3)
|
||||||
|
ON CONFLICT (task_id, position) DO UPDATE
|
||||||
|
SET message_id = EXCLUDED.message_id`,
|
||||||
|
taskID,
|
||||||
|
ref.MessageID,
|
||||||
|
ref.Position,
|
||||||
|
); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@ -12,8 +12,9 @@ import (
|
|||||||
const defaultServerMainUploadURL = "http://127.0.0.1:3001/v1/files/upload"
|
const defaultServerMainUploadURL = "http://127.0.0.1:3001/v1/files/upload"
|
||||||
|
|
||||||
const (
|
const (
|
||||||
FileStorageSceneUpload = "upload"
|
FileStorageSceneUpload = "upload"
|
||||||
FileStorageSceneImageResult = "image_result"
|
FileStorageSceneImageResult = "image_result"
|
||||||
|
FileStorageSceneRequestAsset = "request_asset"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@ -380,11 +380,14 @@ type RateLimitWindow struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type CreateTaskInput struct {
|
type CreateTaskInput struct {
|
||||||
Kind string `json:"kind"`
|
Kind string `json:"kind"`
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
RunMode string `json:"runMode"`
|
RunMode string `json:"runMode"`
|
||||||
Async bool `json:"async"`
|
Async bool `json:"async"`
|
||||||
Request map[string]any `json:"request"`
|
Request map[string]any `json:"request"`
|
||||||
|
ConversationID string `json:"conversationId"`
|
||||||
|
NewMessageCount int `json:"newMessageCount"`
|
||||||
|
MessageRefs []TaskMessageRefInput `json:"messageRefs"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type GatewayTask struct {
|
type GatewayTask struct {
|
||||||
@ -407,6 +410,8 @@ type GatewayTask struct {
|
|||||||
RequestedModel string `json:"requestedModel,omitempty"`
|
RequestedModel string `json:"requestedModel,omitempty"`
|
||||||
ResolvedModel string `json:"resolvedModel,omitempty"`
|
ResolvedModel string `json:"resolvedModel,omitempty"`
|
||||||
RequestID string `json:"requestId,omitempty"`
|
RequestID string `json:"requestId,omitempty"`
|
||||||
|
ConversationID string `json:"conversationId,omitempty"`
|
||||||
|
NewMessageCount int `json:"newMessageCount,omitempty"`
|
||||||
Request map[string]any `json:"request,omitempty"`
|
Request map[string]any `json:"request,omitempty"`
|
||||||
AsyncMode bool `json:"asyncMode"`
|
AsyncMode bool `json:"asyncMode"`
|
||||||
RiverJobID int64 `json:"riverJobId,omitempty"`
|
RiverJobID int64 `json:"riverJobId,omitempty"`
|
||||||
@ -438,6 +443,7 @@ COALESCE(gateway_tenant_id::text, ''), COALESCE(tenant_id, ''), COALESCE(tenant_
|
|||||||
COALESCE(api_key_id, ''), COALESCE(api_key_name, ''), COALESCE(api_key_prefix, ''),
|
COALESCE(api_key_id, ''), COALESCE(api_key_name, ''), COALESCE(api_key_prefix, ''),
|
||||||
COALESCE(user_group_id::text, ''), COALESCE(user_group_key, ''), model,
|
COALESCE(user_group_id::text, ''), COALESCE(user_group_key, ''), model,
|
||||||
COALESCE(model_type, ''), COALESCE(requested_model, ''), COALESCE(resolved_model, ''), COALESCE(request_id, ''),
|
COALESCE(model_type, ''), COALESCE(requested_model, ''), COALESCE(resolved_model, ''), COALESCE(request_id, ''),
|
||||||
|
COALESCE(conversation_id::text, ''), COALESCE(new_message_count, 0),
|
||||||
request, COALESCE(async_mode, false), COALESCE(river_job_id, 0), status, COALESCE(attempt_count, 0),
|
request, COALESCE(async_mode, false), COALESCE(river_job_id, 0), status, COALESCE(attempt_count, 0),
|
||||||
COALESCE(remote_task_id, ''), COALESCE(remote_task_payload, '{}'::jsonb),
|
COALESCE(remote_task_id, ''), COALESCE(remote_task_payload, '{}'::jsonb),
|
||||||
COALESCE(result, '{}'::jsonb), COALESCE(billings, '[]'::jsonb),
|
COALESCE(result, '{}'::jsonb), COALESCE(billings, '[]'::jsonb),
|
||||||
@ -1746,15 +1752,18 @@ func (s *Store) CreateTask(ctx context.Context, input CreateTaskInput, user *aut
|
|||||||
INSERT INTO gateway_tasks (
|
INSERT INTO gateway_tasks (
|
||||||
kind, run_mode, user_id, gateway_user_id, user_source, gateway_tenant_id, tenant_id, tenant_key,
|
kind, run_mode, user_id, gateway_user_id, user_source, gateway_tenant_id, tenant_id, tenant_key,
|
||||||
api_key_id, api_key_name, api_key_prefix, user_group_id, user_group_key,
|
api_key_id, api_key_name, api_key_prefix, user_group_id, user_group_key,
|
||||||
model, requested_model, request, async_mode, status, result, billings, finished_at
|
model, requested_model, request, async_mode, status, result, billings, conversation_id, new_message_count, finished_at
|
||||||
)
|
)
|
||||||
VALUES ($1, $2, $3, NULLIF($4, '')::uuid, COALESCE(NULLIF($5, ''), 'gateway'), NULLIF($6, '')::uuid, NULLIF($7, ''), NULLIF($8, ''), NULLIF($9, ''), NULLIF($10, ''), NULLIF($11, ''), NULLIF($12, '')::uuid, NULLIF($13, ''), $14, $14, $15, $16, $17, $18::jsonb, $19::jsonb, CASE WHEN $20 THEN now() ELSE NULL END)
|
VALUES ($1, $2, $3, NULLIF($4, '')::uuid, COALESCE(NULLIF($5, ''), 'gateway'), NULLIF($6, '')::uuid, NULLIF($7, ''), NULLIF($8, ''), NULLIF($9, ''), NULLIF($10, ''), NULLIF($11, ''), NULLIF($12, '')::uuid, NULLIF($13, ''), $14, $14, $15, $16, $17, $18::jsonb, $19::jsonb, NULLIF($20, '')::uuid, $21, CASE WHEN $22 THEN now() ELSE NULL END)
|
||||||
RETURNING `+gatewayTaskColumns,
|
RETURNING `+gatewayTaskColumns,
|
||||||
input.Kind, runMode, user.ID, user.GatewayUserID, user.Source, user.GatewayTenantID, user.TenantID, user.TenantKey, user.APIKeyID, user.APIKeyName, user.APIKeyPrefix, user.UserGroupID, user.UserGroupKey, input.Model, requestBody, input.Async, status, resultBody, billingsBody, false,
|
input.Kind, runMode, user.ID, user.GatewayUserID, user.Source, user.GatewayTenantID, user.TenantID, user.TenantKey, user.APIKeyID, user.APIKeyName, user.APIKeyPrefix, user.UserGroupID, user.UserGroupKey, input.Model, requestBody, input.Async, status, resultBody, billingsBody, input.ConversationID, input.NewMessageCount, false,
|
||||||
))
|
))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return GatewayTask{}, err
|
return GatewayTask{}, err
|
||||||
}
|
}
|
||||||
|
if err := insertTaskMessageRefs(ctx, tx, task.ID, input.MessageRefs); err != nil {
|
||||||
|
return GatewayTask{}, err
|
||||||
|
}
|
||||||
events := taskEventsForCreate(task.ID, runMode, status, nil)
|
events := taskEventsForCreate(task.ID, runMode, status, nil)
|
||||||
for _, event := range events {
|
for _, event := range events {
|
||||||
payload, _ := json.Marshal(event.Payload)
|
payload, _ := json.Marshal(event.Payload)
|
||||||
@ -1822,6 +1831,8 @@ func scanGatewayTask(scanner taskScanner) (GatewayTask, error) {
|
|||||||
&task.RequestedModel,
|
&task.RequestedModel,
|
||||||
&task.ResolvedModel,
|
&task.ResolvedModel,
|
||||||
&task.RequestID,
|
&task.RequestID,
|
||||||
|
&task.ConversationID,
|
||||||
|
&task.NewMessageCount,
|
||||||
&requestBytes,
|
&requestBytes,
|
||||||
&task.AsyncMode,
|
&task.AsyncMode,
|
||||||
&task.RiverJobID,
|
&task.RiverJobID,
|
||||||
|
|||||||
@ -7,6 +7,8 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RateLimitMetricStatus struct {
|
type RateLimitMetricStatus struct {
|
||||||
@ -82,6 +84,59 @@ type PlatformPolicyEvent struct {
|
|||||||
CreatedAt time.Time `json:"createdAt"`
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Store) RestorePlatformModelRuntimeStatus(ctx context.Context, platformModelID string) (ModelRateLimitStatus, error) {
|
||||||
|
platformModelID = strings.TrimSpace(platformModelID)
|
||||||
|
if platformModelID == "" {
|
||||||
|
return ModelRateLimitStatus{}, pgx.ErrNoRows
|
||||||
|
}
|
||||||
|
|
||||||
|
tx, err := s.pool.Begin(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return ModelRateLimitStatus{}, err
|
||||||
|
}
|
||||||
|
defer tx.Rollback(ctx)
|
||||||
|
|
||||||
|
var restoredModelID string
|
||||||
|
if err := tx.QueryRow(ctx, `
|
||||||
|
WITH restored_model AS (
|
||||||
|
UPDATE platform_models
|
||||||
|
SET enabled = true,
|
||||||
|
cooldown_until = NULL,
|
||||||
|
updated_at = now()
|
||||||
|
WHERE id = $1::uuid
|
||||||
|
RETURNING id::text, platform_id
|
||||||
|
),
|
||||||
|
restored_platform AS (
|
||||||
|
UPDATE integration_platforms
|
||||||
|
SET status = 'enabled',
|
||||||
|
disabled_reason = NULL,
|
||||||
|
cooldown_until = NULL,
|
||||||
|
updated_at = now()
|
||||||
|
WHERE id = (SELECT platform_id FROM restored_model)
|
||||||
|
AND deleted_at IS NULL
|
||||||
|
RETURNING id
|
||||||
|
)
|
||||||
|
SELECT id
|
||||||
|
FROM restored_model
|
||||||
|
WHERE EXISTS (SELECT 1 FROM restored_platform)`, platformModelID).Scan(&restoredModelID); err != nil {
|
||||||
|
return ModelRateLimitStatus{}, err
|
||||||
|
}
|
||||||
|
if err := tx.Commit(ctx); err != nil {
|
||||||
|
return ModelRateLimitStatus{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
items, err := s.ListModelRateLimitStatuses(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return ModelRateLimitStatus{}, err
|
||||||
|
}
|
||||||
|
for _, item := range items {
|
||||||
|
if item.PlatformModelID == restoredModelID {
|
||||||
|
return item, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ModelRateLimitStatus{}, pgx.ErrNoRows
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Store) ListModelRateLimitStatuses(ctx context.Context) ([]ModelRateLimitStatus, error) {
|
func (s *Store) ListModelRateLimitStatuses(ctx context.Context) ([]ModelRateLimitStatus, error) {
|
||||||
rows, err := s.pool.Query(ctx, `
|
rows, err := s.pool.Query(ctx, `
|
||||||
SELECT m.id::text, m.platform_id::text, p.name, p.provider, p.status,
|
SELECT m.id::text, m.platform_id::text, p.name, p.provider, p.status,
|
||||||
|
|||||||
130
apps/api/internal/store/request_assets.go
Normal file
130
apps/api/internal/store/request_assets.go
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
	"context"
	"database/sql"
	"errors"
	"time"

	"github.com/jackc/pgx/v5"
)
|
||||||
|
|
||||||
|
// RequestAsset is one row of gateway_request_assets as read by
// scanRequestAsset. Rows are unique per (SHA256, ContentType).
type RequestAsset struct {
	// ID is the row's uuid rendered as text.
	ID string `json:"id"`
	// SHA256 and ContentType together identify the asset.
	SHA256      string `json:"sha256"`
	ContentType string `json:"contentType"`
	// ByteSize is the stored byte_size value.
	ByteSize int64 `json:"byteSize"`
	// URL is the stored url value.
	URL string `json:"url"`
	// StorageProvider is the stored storage_provider value (for example
	// "local_static", which MarkRequestAssetExpiredByLocalPath filters on).
	StorageProvider string `json:"storageProvider"`
	// LocalPath is empty when the local_path column is NULL.
	LocalPath string `json:"localPath,omitempty"`
	// ExpiresAt and ExpiredAt are nil when the corresponding column is NULL.
	ExpiresAt *time.Time `json:"expiresAt,omitempty"`
	ExpiredAt *time.Time `json:"expiredAt,omitempty"`
	// RefCount is the stored ref_count value.
	RefCount  int       `json:"refCount"`
	CreatedAt time.Time `json:"createdAt"`
	UpdatedAt time.Time `json:"updatedAt"`
}
|
||||||
|
|
||||||
|
// RequestAssetInput carries the values Store.UpsertRequestAsset writes to
// gateway_request_assets.
type RequestAssetInput struct {
	// SHA256 and ContentType form the conflict key of the upsert.
	SHA256      string
	ContentType string
	// ByteSize, URL and StorageProvider are written as given.
	ByteSize        int64
	URL             string
	StorageProvider string
	// LocalPath is stored as NULL when empty.
	LocalPath string
	// ExpiresAt is written to expires_at as given.
	ExpiresAt *time.Time
}
|
||||||
|
|
||||||
|
func (s *Store) FindRequestAsset(ctx context.Context, sha256 string, contentType string) (RequestAsset, bool, error) {
|
||||||
|
asset, err := scanRequestAsset(s.pool.QueryRow(ctx, `
|
||||||
|
SELECT id::text, sha256, content_type, byte_size, url, storage_provider,
|
||||||
|
COALESCE(local_path, ''), expires_at, expired_at, ref_count, created_at, updated_at
|
||||||
|
FROM gateway_request_assets
|
||||||
|
WHERE sha256 = $1 AND content_type = $2`, sha256, contentType))
|
||||||
|
if err != nil {
|
||||||
|
if err == pgx.ErrNoRows {
|
||||||
|
return RequestAsset{}, false, nil
|
||||||
|
}
|
||||||
|
return RequestAsset{}, false, err
|
||||||
|
}
|
||||||
|
return asset, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) UpsertRequestAsset(ctx context.Context, input RequestAssetInput) (RequestAsset, error) {
|
||||||
|
return scanRequestAsset(s.pool.QueryRow(ctx, `
|
||||||
|
INSERT INTO gateway_request_assets (
|
||||||
|
sha256, content_type, byte_size, url, storage_provider, local_path, expires_at, expired_at, ref_count
|
||||||
|
)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, NULLIF($6, ''), $7, NULL, 1)
|
||||||
|
ON CONFLICT (sha256, content_type) DO UPDATE
|
||||||
|
SET byte_size = EXCLUDED.byte_size,
|
||||||
|
url = EXCLUDED.url,
|
||||||
|
storage_provider = EXCLUDED.storage_provider,
|
||||||
|
local_path = EXCLUDED.local_path,
|
||||||
|
expires_at = EXCLUDED.expires_at,
|
||||||
|
expired_at = NULL,
|
||||||
|
ref_count = gateway_request_assets.ref_count + 1,
|
||||||
|
updated_at = now()
|
||||||
|
RETURNING id::text, sha256, content_type, byte_size, url, storage_provider,
|
||||||
|
COALESCE(local_path, ''), expires_at, expired_at, ref_count, created_at, updated_at`,
|
||||||
|
input.SHA256,
|
||||||
|
input.ContentType,
|
||||||
|
input.ByteSize,
|
||||||
|
input.URL,
|
||||||
|
input.StorageProvider,
|
||||||
|
input.LocalPath,
|
||||||
|
input.ExpiresAt,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) IncrementRequestAssetRefCount(ctx context.Context, sha256 string, contentType string) error {
|
||||||
|
_, err := s.pool.Exec(ctx, `
|
||||||
|
UPDATE gateway_request_assets
|
||||||
|
SET ref_count = ref_count + 1,
|
||||||
|
updated_at = now()
|
||||||
|
WHERE sha256 = $1 AND content_type = $2`, sha256, contentType)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) MarkRequestAssetExpiredByLocalPath(ctx context.Context, localPath string, expiredAt time.Time) error {
|
||||||
|
if localPath == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
_, err := s.pool.Exec(ctx, `
|
||||||
|
UPDATE gateway_request_assets
|
||||||
|
SET expired_at = COALESCE(expired_at, $2),
|
||||||
|
updated_at = now()
|
||||||
|
WHERE local_path = $1
|
||||||
|
AND storage_provider = 'local_static'
|
||||||
|
AND expired_at IS NULL`, localPath, expiredAt)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func scanRequestAsset(scanner interface{ Scan(dest ...any) error }) (RequestAsset, error) {
|
||||||
|
var asset RequestAsset
|
||||||
|
var localPath string
|
||||||
|
var expiresAt sql.NullTime
|
||||||
|
var expiredAt sql.NullTime
|
||||||
|
if err := scanner.Scan(
|
||||||
|
&asset.ID,
|
||||||
|
&asset.SHA256,
|
||||||
|
&asset.ContentType,
|
||||||
|
&asset.ByteSize,
|
||||||
|
&asset.URL,
|
||||||
|
&asset.StorageProvider,
|
||||||
|
&localPath,
|
||||||
|
&expiresAt,
|
||||||
|
&expiredAt,
|
||||||
|
&asset.RefCount,
|
||||||
|
&asset.CreatedAt,
|
||||||
|
&asset.UpdatedAt,
|
||||||
|
); err != nil {
|
||||||
|
return RequestAsset{}, err
|
||||||
|
}
|
||||||
|
asset.LocalPath = localPath
|
||||||
|
if expiresAt.Valid {
|
||||||
|
asset.ExpiresAt = &expiresAt.Time
|
||||||
|
}
|
||||||
|
if expiredAt.Valid {
|
||||||
|
asset.ExpiredAt = &expiredAt.Time
|
||||||
|
}
|
||||||
|
return asset, nil
|
||||||
|
}
|
||||||
@ -0,0 +1,68 @@
|
|||||||
|
-- Request assets, unique per (sha256, content_type).
CREATE TABLE IF NOT EXISTS gateway_request_assets (
    id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
    sha256 text NOT NULL,
    content_type text NOT NULL,
    byte_size bigint NOT NULL,
    url text NOT NULL,
    storage_provider text NOT NULL,
    local_path text,
    expires_at timestamptz,
    expired_at timestamptz,
    ref_count integer NOT NULL DEFAULT 0,
    created_at timestamptz NOT NULL DEFAULT now(),
    updated_at timestamptz NOT NULL DEFAULT now(),
    UNIQUE (sha256, content_type)
);

-- Partial index over assets that have an expiry and are not yet marked expired.
CREATE INDEX IF NOT EXISTS idx_gateway_request_assets_expires
    ON gateway_request_assets(storage_provider, expires_at)
    WHERE expires_at IS NOT NULL AND expired_at IS NULL;

-- Partial index supporting lookups by local_path.
CREATE INDEX IF NOT EXISTS idx_gateway_request_assets_local_path
    ON gateway_request_assets(local_path)
    WHERE local_path IS NOT NULL;

-- Conversations, unique per (user_id, conversation_key).
CREATE TABLE IF NOT EXISTS gateway_conversations (
    id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
    user_id text NOT NULL,
    gateway_user_id uuid REFERENCES gateway_users(id) ON DELETE SET NULL,
    conversation_key text NOT NULL,
    metadata jsonb NOT NULL DEFAULT '{}'::jsonb,
    created_at timestamptz NOT NULL DEFAULT now(),
    updated_at timestamptz NOT NULL DEFAULT now(),
    UNIQUE (user_id, conversation_key)
);

-- Messages of a conversation, deduplicated per (conversation_id, message_hash).
CREATE TABLE IF NOT EXISTS gateway_conversation_messages (
    id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
    conversation_id uuid NOT NULL REFERENCES gateway_conversations(id) ON DELETE CASCADE,
    message_hash text NOT NULL,
    role text,
    message_snapshot jsonb NOT NULL,
    asset_sha256s text[] NOT NULL DEFAULT '{}',
    created_at timestamptz NOT NULL DEFAULT now(),
    updated_at timestamptz NOT NULL DEFAULT now(),
    UNIQUE (conversation_id, message_hash)
);

CREATE INDEX IF NOT EXISTS idx_gateway_conversation_messages_conversation
    ON gateway_conversation_messages(conversation_id, created_at);

-- Ordered references from a task to conversation messages; one row per
-- (task_id, position).
CREATE TABLE IF NOT EXISTS gateway_task_message_refs (
    task_id uuid NOT NULL REFERENCES gateway_tasks(id) ON DELETE CASCADE,
    message_id uuid NOT NULL REFERENCES gateway_conversation_messages(id) ON DELETE CASCADE,
    position integer NOT NULL,
    created_at timestamptz NOT NULL DEFAULT now(),
    PRIMARY KEY (task_id, position)
);

CREATE INDEX IF NOT EXISTS idx_gateway_task_message_refs_message
    ON gateway_task_message_refs(message_id);

-- Link tasks to their conversation and record how many messages were new.
ALTER TABLE gateway_tasks
    ADD COLUMN IF NOT EXISTS conversation_id uuid REFERENCES gateway_conversations(id) ON DELETE SET NULL,
    ADD COLUMN IF NOT EXISTS new_message_count integer NOT NULL DEFAULT 0;

-- Partial index over tasks that belong to a conversation, newest first.
CREATE INDEX IF NOT EXISTS idx_gateway_tasks_conversation_created
    ON gateway_tasks(conversation_id, created_at DESC)
    WHERE conversation_id IS NOT NULL;
|
||||||
@ -84,6 +84,7 @@ import {
|
|||||||
pollTaskUntilSettled,
|
pollTaskUntilSettled,
|
||||||
registerLocalAccount,
|
registerLocalAccount,
|
||||||
replacePlatformModels,
|
replacePlatformModels,
|
||||||
|
restoreModelRuntimeStatus,
|
||||||
setUserWalletBalance,
|
setUserWalletBalance,
|
||||||
type HealthResponse,
|
type HealthResponse,
|
||||||
updateAccessRule,
|
updateAccessRule,
|
||||||
@ -635,6 +636,50 @@ export function App() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Restores one platform model through the admin API and mirrors the result in
// local state: every rate-limit row of the same platform is shown as enabled
// with its platform cooldown and disabled reason cleared, the restored row also
// takes the model fields from the response, and the matching platform and model
// entries are marked enabled. Errors are surfaced via the core message and then
// rethrown so the caller can react as well.
async function restoreRuntimeModel(platformModelId: string) {
  setCoreState('loading');
  setCoreMessage('');
  try {
    const restored = await restoreModelRuntimeStatus(token, platformModelId);

    setModelRateLimits((previous) => previous.map((entry) => {
      if (entry.platformId !== restored.platformId) {
        return entry;
      }
      if (entry.platformModelId !== platformModelId) {
        // Sibling model on the same platform: only the platform fields change.
        return {
          ...entry,
          platformStatus: 'enabled',
          platformCooldownUntil: undefined,
          platformDisabledReason: undefined,
        };
      }
      return {
        ...entry,
        platformStatus: 'enabled',
        platformCooldownUntil: undefined,
        platformDisabledReason: undefined,
        enabled: restored.enabled,
        modelCooldownUntil: restored.modelCooldownUntil,
      };
    }));

    setPlatforms((previous) => previous.map((platform) => {
      if (platform.id !== restored.platformId) {
        return platform;
      }
      return { ...platform, status: 'enabled', cooldownUntil: undefined };
    }));

    setModels((previous) => previous.map((model) => {
      if (model.id !== platformModelId) {
        return model;
      }
      return { ...model, enabled: true, cooldownUntil: undefined };
    }));

    invalidateDataKeys('modelCatalog', 'modelRateLimits', 'models', 'platforms', 'playgroundModels');
    setCoreState('ready');
    setCoreMessage('模型运行状态已恢复。');
  } catch (err) {
    setCoreState('error');
    const message = err instanceof Error ? err.message : '恢复模型运行状态失败';
    setCoreMessage(message);
    throw err;
  }
}
|
||||||
|
|
||||||
async function removePlatform(platformId: string) {
|
async function removePlatform(platformId: string) {
|
||||||
setCoreState('loading');
|
setCoreState('loading');
|
||||||
setCoreMessage('');
|
setCoreMessage('');
|
||||||
@ -1143,6 +1188,7 @@ export function App() {
|
|||||||
onResetBaseModel={resetBaseModelToDefault}
|
onResetBaseModel={resetBaseModelToDefault}
|
||||||
onSavePlatform={savePlatformWithModels}
|
onSavePlatform={savePlatformWithModels}
|
||||||
onSavePlatformDynamicPriority={savePlatformDynamicPriority}
|
onSavePlatformDynamicPriority={savePlatformDynamicPriority}
|
||||||
|
onRestoreRuntimeModel={restoreRuntimeModel}
|
||||||
onTogglePlatformStatus={savePlatformStatus}
|
onTogglePlatformStatus={savePlatformStatus}
|
||||||
onSaveProvider={saveProvider}
|
onSaveProvider={saveProvider}
|
||||||
onSavePricingRuleSet={savePricingRuleSet}
|
onSavePricingRuleSet={savePricingRuleSet}
|
||||||
|
|||||||
@ -913,6 +913,13 @@ export async function listModelRateLimitStatuses(token: string): Promise<ListRes
|
|||||||
return request<ListResponse<ModelRateLimitStatus>>('/api/admin/runtime/model-rate-limits', { token });
|
return request<ListResponse<ModelRateLimitStatus>>('/api/admin/runtime/model-rate-limits', { token });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Asks the admin API to restore the runtime status of one platform model.
 *
 * @param token - Auth token forwarded to `request`.
 * @param platformModelId - ID of the platform model to restore.
 * @returns The rate-limit status of the model as reported by the API.
 */
export async function restoreModelRuntimeStatus(token: string, platformModelId: string): Promise<ModelRateLimitStatus> {
  // Encode the ID so characters such as "/" or "?" cannot change the request path.
  const modelSegment = encodeURIComponent(platformModelId);
  return request<ModelRateLimitStatus>(`/api/admin/runtime/model-rate-limits/${modelSegment}/restore`, {
    method: 'POST',
    token,
  });
}
|
||||||
|
|
||||||
export async function getNetworkProxyConfig(token: string): Promise<GatewayNetworkProxyConfig> {
|
export async function getNetworkProxyConfig(token: string): Promise<GatewayNetworkProxyConfig> {
|
||||||
return request<GatewayNetworkProxyConfig>('/api/admin/config/network-proxy', { token });
|
return request<GatewayNetworkProxyConfig>('/api/admin/config/network-proxy', { token });
|
||||||
}
|
}
|
||||||
|
|||||||
@ -71,6 +71,7 @@ export function AdminPage(props: {
|
|||||||
onBatchAccessRules: (input: GatewayAccessRuleBatchRequest) => Promise<void>;
|
onBatchAccessRules: (input: GatewayAccessRuleBatchRequest) => Promise<void>;
|
||||||
onSavePlatform: (input: PlatformWithModelsInput) => Promise<void>;
|
onSavePlatform: (input: PlatformWithModelsInput) => Promise<void>;
|
||||||
onSavePlatformDynamicPriority: (platformId: string, input: PlatformDynamicPriorityUpdateRequest) => Promise<void>;
|
onSavePlatformDynamicPriority: (platformId: string, input: PlatformDynamicPriorityUpdateRequest) => Promise<void>;
|
||||||
|
onRestoreRuntimeModel: (platformModelId: string) => Promise<void>;
|
||||||
onTogglePlatformStatus: (platform: IntegrationPlatform, status: 'enabled' | 'disabled') => Promise<void>;
|
onTogglePlatformStatus: (platform: IntegrationPlatform, status: 'enabled' | 'disabled') => Promise<void>;
|
||||||
onSaveProvider: (input: CatalogProviderUpsertRequest, providerId?: string) => Promise<void>;
|
onSaveProvider: (input: CatalogProviderUpsertRequest, providerId?: string) => Promise<void>;
|
||||||
onSavePricingRuleSet: (input: PricingRuleSetUpsertRequest, ruleSetId?: string) => Promise<void>;
|
onSavePricingRuleSet: (input: PricingRuleSetUpsertRequest, ruleSetId?: string) => Promise<void>;
|
||||||
@ -173,6 +174,7 @@ export function AdminPage(props: {
|
|||||||
modelRateLimitsUpdatedAt={props.data.modelRateLimitsUpdatedAt}
|
modelRateLimitsUpdatedAt={props.data.modelRateLimitsUpdatedAt}
|
||||||
platforms={props.data.platforms}
|
platforms={props.data.platforms}
|
||||||
onSavePlatformDynamicPriority={props.onSavePlatformDynamicPriority}
|
onSavePlatformDynamicPriority={props.onSavePlatformDynamicPriority}
|
||||||
|
onRestoreRuntimeModel={props.onRestoreRuntimeModel}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{props.section === 'tenants' && <TenantsPanel {...identityPanelProps(props)} />}
|
{props.section === 'tenants' && <TenantsPanel {...identityPanelProps(props)} />}
|
||||||
|
|||||||
@ -9,11 +9,13 @@ export function RealtimeLoadPanel(props: {
|
|||||||
modelRateLimitsUpdatedAt: number | null;
|
modelRateLimitsUpdatedAt: number | null;
|
||||||
platforms: IntegrationPlatform[];
|
platforms: IntegrationPlatform[];
|
||||||
onSavePlatformDynamicPriority: (platformId: string, input: PlatformDynamicPriorityUpdateRequest) => Promise<void>;
|
onSavePlatformDynamicPriority: (platformId: string, input: PlatformDynamicPriorityUpdateRequest) => Promise<void>;
|
||||||
|
onRestoreRuntimeModel: (platformModelId: string) => Promise<void>;
|
||||||
}) {
|
}) {
|
||||||
const [now, setNow] = useState(() => Date.now());
|
const [now, setNow] = useState(() => Date.now());
|
||||||
const [priorityDialog, setPriorityDialog] = useState<PriorityDialogState | null>(null);
|
const [priorityDialog, setPriorityDialog] = useState<PriorityDialogState | null>(null);
|
||||||
const [priorityError, setPriorityError] = useState('');
|
const [priorityError, setPriorityError] = useState('');
|
||||||
const [prioritySaving, setPrioritySaving] = useState(false);
|
const [prioritySaving, setPrioritySaving] = useState(false);
|
||||||
|
const [restoreSavingId, setRestoreSavingId] = useState<string | null>(null);
|
||||||
const platformMap = useMemo(() => new Map(props.platforms.map((item) => [item.id, item])), [props.platforms]);
|
const platformMap = useMemo(() => new Map(props.platforms.map((item) => [item.id, item])), [props.platforms]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
@ -65,6 +67,17 @@ export function RealtimeLoadPanel(props: {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async function restoreRuntimeModel(platformModelId: string) {
|
||||||
|
setRestoreSavingId(platformModelId);
|
||||||
|
try {
|
||||||
|
await props.onRestoreRuntimeModel(platformModelId);
|
||||||
|
} catch {
|
||||||
|
// App-level state owns the error message.
|
||||||
|
} finally {
|
||||||
|
setRestoreSavingId(null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<section className="pageStack">
|
<section className="pageStack">
|
||||||
<Card>
|
<Card>
|
||||||
@ -81,6 +94,8 @@ export function RealtimeLoadPanel(props: {
|
|||||||
statuses={props.modelRateLimits}
|
statuses={props.modelRateLimits}
|
||||||
updatedAt={props.modelRateLimitsUpdatedAt}
|
updatedAt={props.modelRateLimitsUpdatedAt}
|
||||||
onAdjustPriority={openPriorityDialog}
|
onAdjustPriority={openPriorityDialog}
|
||||||
|
onRestoreRuntimeModel={restoreRuntimeModel}
|
||||||
|
restoreSavingId={restoreSavingId}
|
||||||
/>
|
/>
|
||||||
</CardContent>
|
</CardContent>
|
||||||
</Card>
|
</Card>
|
||||||
@ -109,6 +124,8 @@ function RateLimitStatusTable(props: {
|
|||||||
now: number;
|
now: number;
|
||||||
updatedAt: number | null;
|
updatedAt: number | null;
|
||||||
onAdjustPriority: (status: ModelRateLimitStatus, platform: IntegrationPlatform | undefined) => void;
|
onAdjustPriority: (status: ModelRateLimitStatus, platform: IntegrationPlatform | undefined) => void;
|
||||||
|
onRestoreRuntimeModel: (platformModelId: string) => Promise<void>;
|
||||||
|
restoreSavingId: string | null;
|
||||||
}) {
|
}) {
|
||||||
if (!props.statuses.length) {
|
if (!props.statuses.length) {
|
||||||
return <EmptyState title="暂无实时负载" description="模型产生请求后会在这里显示实时 RPM、TPM 和并发窗口。" />;
|
return <EmptyState title="暂无实时负载" description="模型产生请求后会在这里显示实时 RPM、TPM 和并发窗口。" />;
|
||||||
@ -150,7 +167,15 @@ function RateLimitStatusTable(props: {
|
|||||||
<small>{status.provider}</small>
|
<small>{status.provider}</small>
|
||||||
</span>
|
</span>
|
||||||
</TableCell>
|
</TableCell>
|
||||||
<TableCell className="platformLimitStatusCell">{modelRuntimeStatusCell(status, platform, props.now)}</TableCell>
|
<TableCell className="platformLimitStatusCell">
|
||||||
|
{modelRuntimeStatusCell(
|
||||||
|
status,
|
||||||
|
platform,
|
||||||
|
props.now,
|
||||||
|
props.onRestoreRuntimeModel,
|
||||||
|
props.restoreSavingId === status.platformModelId,
|
||||||
|
)}
|
||||||
|
</TableCell>
|
||||||
<TableCell className="platformLimitNumberCell">{platformPriorityCell(status, platform, props.onAdjustPriority)}</TableCell>
|
<TableCell className="platformLimitNumberCell">{platformPriorityCell(status, platform, props.onAdjustPriority)}</TableCell>
|
||||||
<TableCell className="platformLimitNumberCell">
|
<TableCell className="platformLimitNumberCell">
|
||||||
<span className="rateLoadCell" data-overloaded={status.loadRatio > 0.8 ? 'true' : undefined}>
|
<span className="rateLoadCell" data-overloaded={status.loadRatio > 0.8 ? 'true' : undefined}>
|
||||||
@ -447,50 +472,104 @@ function shortId(value: string | undefined) {
|
|||||||
return value.length > 8 ? value.slice(0, 8) : value;
|
return value.length > 8 ? value.slice(0, 8) : value;
|
||||||
}
|
}
|
||||||
|
|
||||||
function modelRuntimeStatusCell(status: ModelRateLimitStatus, platform: IntegrationPlatform | undefined, now: number) {
|
function modelRuntimeStatusCell(
|
||||||
|
status: ModelRateLimitStatus,
|
||||||
|
platform: IntegrationPlatform | undefined,
|
||||||
|
now: number,
|
||||||
|
onRestore: (platformModelId: string) => Promise<void>,
|
||||||
|
restoring: boolean,
|
||||||
|
) {
|
||||||
const modelCooldownMs = cooldownRemainingMs(status.modelCooldownUntil, now);
|
const modelCooldownMs = cooldownRemainingMs(status.modelCooldownUntil, now);
|
||||||
const platformCooldownMs = cooldownRemainingMs(status.platformCooldownUntil, now);
|
const platformCooldownMs = cooldownRemainingMs(status.platformCooldownUntil, now);
|
||||||
const platformStatus = platform?.status || status.platformStatus || 'enabled';
|
const platformStatus = platform?.status || status.platformStatus || 'enabled';
|
||||||
|
const restoreButton = runtimeRestoreButton(status, platformStatus, modelCooldownMs, platformCooldownMs, onRestore, restoring);
|
||||||
if (modelCooldownMs > 0) {
|
if (modelCooldownMs > 0) {
|
||||||
return (
|
return (
|
||||||
<span className="platformTableName">
|
<span className="platformRuntimeStatusCell">
|
||||||
<strong><Badge variant="warning">模型冷却中</Badge></strong>
|
<span className="platformTableName">
|
||||||
<small>剩余 {formatCooldownRemaining(modelCooldownMs)}</small>
|
<strong><Badge variant="warning">模型冷却中</Badge></strong>
|
||||||
|
<small>剩余 {formatCooldownRemaining(modelCooldownMs)}</small>
|
||||||
|
</span>
|
||||||
|
{restoreButton}
|
||||||
</span>
|
</span>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
if (platformStatus !== 'enabled') {
|
if (platformStatus !== 'enabled') {
|
||||||
const badge = <Badge variant="warning">已禁用</Badge>;
|
const badge = <Badge variant="warning">已禁用</Badge>;
|
||||||
return (
|
return (
|
||||||
<AntPopover
|
<span className="platformRuntimeStatusCell">
|
||||||
align={{ offset: [0, 8] }}
|
<AntPopover
|
||||||
content={<PlatformDisabledReasonPopover record={status.platformDisabledReason} />}
|
align={{ offset: [0, 8] }}
|
||||||
overlayClassName="priorityDemotionAntPopover"
|
content={<PlatformDisabledReasonPopover record={status.platformDisabledReason} />}
|
||||||
placement="bottomLeft"
|
overlayClassName="priorityDemotionAntPopover"
|
||||||
trigger={['hover', 'focus']}
|
placement="bottomLeft"
|
||||||
>
|
trigger={['hover', 'focus']}
|
||||||
<span className="platformTableName" tabIndex={0}>
|
>
|
||||||
<strong>{badge}</strong>
|
<span className="platformTableName" tabIndex={0}>
|
||||||
|
<strong>{badge}</strong>
|
||||||
|
</span>
|
||||||
|
</AntPopover>
|
||||||
|
{restoreButton}
|
||||||
|
</span>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (!status.enabled) {
|
||||||
|
return (
|
||||||
|
<span className="platformRuntimeStatusCell">
|
||||||
|
<span className="platformTableName">
|
||||||
|
<strong><Badge variant="secondary">已停用</Badge></strong>
|
||||||
|
<small>不参与路由</small>
|
||||||
</span>
|
</span>
|
||||||
</AntPopover>
|
{restoreButton}
|
||||||
|
</span>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
if (platformCooldownMs > 0) {
|
if (platformCooldownMs > 0) {
|
||||||
return (
|
return (
|
||||||
<span className="platformTableName">
|
<span className="platformRuntimeStatusCell">
|
||||||
<strong><Badge variant="warning">平台冷却中</Badge></strong>
|
<span className="platformTableName">
|
||||||
<small>剩余 {formatCooldownRemaining(platformCooldownMs)}</small>
|
<strong><Badge variant="warning">平台冷却中</Badge></strong>
|
||||||
|
<small>剩余 {formatCooldownRemaining(platformCooldownMs)}</small>
|
||||||
|
</span>
|
||||||
|
{restoreButton}
|
||||||
</span>
|
</span>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
return (
|
return (
|
||||||
<span className="platformTableName">
|
<span className="platformRuntimeStatusCell">
|
||||||
<strong><Badge variant={status.enabled ? 'success' : 'secondary'}>{status.enabled ? '可用' : '已停用'}</Badge></strong>
|
<span className="platformTableName">
|
||||||
<small>{status.enabled ? '参与路由' : '不参与路由'}</small>
|
<strong><Badge variant="success">可用</Badge></strong>
|
||||||
|
<small>参与路由</small>
|
||||||
|
</span>
|
||||||
</span>
|
</span>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function runtimeRestoreButton(
|
||||||
|
status: ModelRateLimitStatus,
|
||||||
|
platformStatus: string,
|
||||||
|
modelCooldownMs: number,
|
||||||
|
platformCooldownMs: number,
|
||||||
|
onRestore: (platformModelId: string) => Promise<void>,
|
||||||
|
restoring: boolean,
|
||||||
|
) {
|
||||||
|
const canRestore = modelCooldownMs > 0 || platformCooldownMs > 0 || platformStatus !== 'enabled' || !status.enabled;
|
||||||
|
if (!canRestore) return null;
|
||||||
|
return (
|
||||||
|
<Button
|
||||||
|
className="platformRestoreButton"
|
||||||
|
type="button"
|
||||||
|
variant="outline"
|
||||||
|
size="xs"
|
||||||
|
disabled={restoring}
|
||||||
|
onClick={() => void onRestore(status.platformModelId)}
|
||||||
|
>
|
||||||
|
<RotateCcw size={13} />
|
||||||
|
{restoring ? '恢复中' : '恢复'}
|
||||||
|
</Button>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
function cooldownRemainingMs(cooldownUntil: string | undefined, now: number) {
|
function cooldownRemainingMs(cooldownUntil: string | undefined, now: number) {
|
||||||
if (!cooldownUntil) return 0;
|
if (!cooldownUntil) return 0;
|
||||||
const until = Date.parse(cooldownUntil);
|
const until = Date.parse(cooldownUntil);
|
||||||
|
|||||||
@ -1086,8 +1086,8 @@
|
|||||||
}
|
}
|
||||||
|
|
||||||
.platformLimitTable .shTableRow {
|
.platformLimitTable .shTableRow {
|
||||||
grid-template-columns: minmax(180px, 1.1fr) minmax(160px, 0.9fr) 160px 132px 150px 170px 140px 132px;
|
grid-template-columns: minmax(180px, 1.1fr) minmax(160px, 0.9fr) 178px 132px 150px 170px 140px 132px;
|
||||||
min-width: 1224px;
|
min-width: 1242px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.platformLimitTable .shTableHead,
|
.platformLimitTable .shTableHead,
|
||||||
@ -1130,6 +1130,22 @@
|
|||||||
justify-items: start;
|
justify-items: start;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.platformRuntimeStatusCell {
|
||||||
|
display: grid;
|
||||||
|
min-width: 0;
|
||||||
|
gap: 7px;
|
||||||
|
align-content: start;
|
||||||
|
}
|
||||||
|
|
||||||
|
.platformRestoreButton {
|
||||||
|
justify-self: start;
|
||||||
|
min-height: 22px;
|
||||||
|
padding-inline: 8px;
|
||||||
|
border-color: var(--border);
|
||||||
|
color: var(--text-normal);
|
||||||
|
background: #fff;
|
||||||
|
}
|
||||||
|
|
||||||
.rateMetricCell,
|
.rateMetricCell,
|
||||||
.rateLoadCell {
|
.rateLoadCell {
|
||||||
display: grid;
|
display: grid;
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user