Add runtime restore and temp asset cleanup
This commit is contained in:
parent
644a6f9d17
commit
d41d9482c7
@ -4,6 +4,7 @@ import (
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@ -23,6 +24,7 @@ type Config struct {
|
||||
PublicBaseURL string
|
||||
LocalGeneratedStorageDir string
|
||||
LocalUploadedStorageDir string
|
||||
LocalTempAssetTTLHours int
|
||||
TaskProgressCallbackEnabled bool
|
||||
TaskProgressCallbackURL string
|
||||
TaskProgressCallbackTimeoutMS string
|
||||
@ -49,6 +51,7 @@ func Load() Config {
|
||||
PublicBaseURL: strings.TrimRight(env("AI_GATEWAY_PUBLIC_BASE_URL", env("PUBLIC_BASE_URL", "")), "/"),
|
||||
LocalGeneratedStorageDir: env("AI_GATEWAY_GENERATED_STORAGE_DIR", env("LOCAL_GENERATED_STORAGE_DIR", env("AI_GATEWAY_STATIC_STORAGE_DIR", DefaultLocalGeneratedStorageDir))),
|
||||
LocalUploadedStorageDir: env("AI_GATEWAY_UPLOADED_STORAGE_DIR", env("LOCAL_UPLOADED_STORAGE_DIR", DefaultLocalUploadedStorageDir)),
|
||||
LocalTempAssetTTLHours: envInt("AI_GATEWAY_LOCAL_TEMP_ASSET_TTL_HOURS", 24),
|
||||
TaskProgressCallbackEnabled: env("TASK_PROGRESS_CALLBACK_ENABLED", "true") == "true",
|
||||
TaskProgressCallbackURL: env("TASK_PROGRESS_CALLBACK_URL",
|
||||
strings.TrimRight(env("SERVER_MAIN_BASE_URL", "http://localhost:3000"), "/")+"/internal/platform/task-progress-callbacks",
|
||||
@ -136,6 +139,18 @@ func env(key string, fallback string) string {
|
||||
return fallback
|
||||
}
|
||||
|
||||
func envInt(key string, fallback int) int {
|
||||
value := envValue(key)
|
||||
if value == "" {
|
||||
return fallback
|
||||
}
|
||||
parsed, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return fallback
|
||||
}
|
||||
return parsed
|
||||
}
|
||||
|
||||
func logLevel(value string) slog.Level {
|
||||
switch strings.ToLower(value) {
|
||||
case "debug":
|
||||
|
||||
@ -916,13 +916,26 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler {
|
||||
return
|
||||
}
|
||||
responsePlan := planTaskResponse(kind, compatible, body, r)
|
||||
prepared, err := s.prepareTaskRequest(r.Context(), r, user, body)
|
||||
if err != nil {
|
||||
s.logger.Warn("prepare task request failed", "kind", kind, "error", err)
|
||||
status := http.StatusBadRequest
|
||||
if code := clients.ErrorCode(err); strings.HasPrefix(code, "upload_") || code == "request_asset_upload_failed" {
|
||||
status = http.StatusBadGateway
|
||||
}
|
||||
writeError(w, status, err.Error(), clients.ErrorCode(err))
|
||||
return
|
||||
}
|
||||
|
||||
task, err := s.store.CreateTask(r.Context(), store.CreateTaskInput{
|
||||
Kind: kind,
|
||||
Model: model,
|
||||
RunMode: runModeFromRequest(body),
|
||||
Async: responsePlan.asyncMode,
|
||||
Request: body,
|
||||
Kind: kind,
|
||||
Model: model,
|
||||
RunMode: runModeFromRequest(prepared.Body),
|
||||
Async: responsePlan.asyncMode,
|
||||
Request: prepared.Body,
|
||||
ConversationID: prepared.ConversationID,
|
||||
NewMessageCount: prepared.NewMessageCount,
|
||||
MessageRefs: prepared.MessageRefs,
|
||||
}, user)
|
||||
if err != nil {
|
||||
s.logger.Error("create task failed", "kind", kind, "error", err)
|
||||
|
||||
70
apps/api/internal/httpapi/local_temp_assets.go
Normal file
70
apps/api/internal/httpapi/local_temp_assets.go
Normal file
@ -0,0 +1,70 @@
|
||||
package httpapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||
)
|
||||
|
||||
func (s *Server) startLocalTempAssetCleanup(ctx context.Context) {
|
||||
go func() {
|
||||
s.cleanupExpiredLocalTempAssets(ctx, time.Now())
|
||||
ticker := time.NewTicker(time.Hour)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case now := <-ticker.C:
|
||||
s.cleanupExpiredLocalTempAssets(ctx, now)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Server) cleanupExpiredLocalTempAssets(ctx context.Context, now time.Time) int {
|
||||
storageDir := strings.TrimSpace(s.cfg.LocalUploadedStorageDir)
|
||||
if storageDir == "" {
|
||||
storageDir = config.DefaultLocalUploadedStorageDir
|
||||
}
|
||||
entries, err := os.ReadDir(storageDir)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
s.logger.Warn("read local temp asset dir failed", "dir", storageDir, "error", err)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
ttl := time.Duration(s.localTempAssetTTLHours()) * time.Hour
|
||||
expiredBefore := now.Add(-ttl)
|
||||
deleted := 0
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasPrefix(entry.Name(), requestAssetFilePrefix) {
|
||||
continue
|
||||
}
|
||||
info, err := entry.Info()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if info.ModTime().After(expiredBefore) {
|
||||
continue
|
||||
}
|
||||
localPath := filepath.Join(storageDir, entry.Name())
|
||||
if err := os.Remove(localPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
s.logger.Warn("remove local temp asset failed", "path", localPath, "error", err)
|
||||
continue
|
||||
}
|
||||
deleted++
|
||||
if s.store != nil {
|
||||
if err := s.store.MarkRequestAssetExpiredByLocalPath(ctx, localPath, now); err != nil && !store.IsUndefinedDatabaseObject(err) {
|
||||
s.logger.Warn("mark local temp asset expired failed", "path", localPath, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return deleted
|
||||
}
|
||||
583
apps/api/internal/httpapi/request_preparation.go
Normal file
583
apps/api/internal/httpapi/request_preparation.go
Normal file
@ -0,0 +1,583 @@
|
||||
package httpapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"mime"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/runner"
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||
)
|
||||
|
||||
const requestAssetFilePrefix = "gateway-request-asset-"
|
||||
|
||||
type preparedTaskRequest struct {
|
||||
Body map[string]any
|
||||
ConversationID string
|
||||
MessageRefs []store.TaskMessageRefInput
|
||||
NewMessageCount int
|
||||
}
|
||||
|
||||
type decodedRequestAsset struct {
|
||||
Bytes []byte
|
||||
ContentType string
|
||||
}
|
||||
|
||||
func (s *Server) prepareTaskRequest(ctx context.Context, r *http.Request, user *auth.User, body map[string]any) (preparedTaskRequest, error) {
|
||||
preparedBody, err := s.prepareRequestAssetRefs(ctx, body)
|
||||
if err != nil {
|
||||
return preparedTaskRequest{}, err
|
||||
}
|
||||
result := preparedTaskRequest{Body: preparedBody}
|
||||
conversationKey := requestConversationKey(r, preparedBody)
|
||||
if conversationKey == "" {
|
||||
return result, nil
|
||||
}
|
||||
messages, ok := preparedBody["messages"].([]any)
|
||||
if !ok || len(messages) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
conversationID, err := s.store.EnsureConversation(ctx, user, conversationKey, map[string]any{
|
||||
"source": "ai-gateway",
|
||||
"conversationKey": conversationKey,
|
||||
})
|
||||
if err != nil {
|
||||
return preparedTaskRequest{}, err
|
||||
}
|
||||
inputs := make([]store.ConversationMessageInput, 0, len(messages))
|
||||
for _, rawMessage := range messages {
|
||||
message, _ := rawMessage.(map[string]any)
|
||||
if message == nil {
|
||||
message = map[string]any{"content": rawMessage}
|
||||
}
|
||||
hash, assetHashes := canonicalConversationMessageHash(message)
|
||||
inputs = append(inputs, store.ConversationMessageInput{
|
||||
Hash: hash,
|
||||
Role: stringFromRequestAny(message["role"]),
|
||||
Snapshot: message,
|
||||
AssetSHA256s: assetHashes,
|
||||
})
|
||||
}
|
||||
refs, newCount, err := s.store.UpsertConversationMessages(ctx, conversationID, inputs)
|
||||
if err != nil {
|
||||
return preparedTaskRequest{}, err
|
||||
}
|
||||
preparedBody["conversationId"] = conversationKey
|
||||
preparedBody["conversationRecordId"] = conversationID
|
||||
preparedBody["messageRefs"] = messageRefsForRequest(refs)
|
||||
preparedBody["newMessageCount"] = newCount
|
||||
delete(preparedBody, "messages")
|
||||
result.ConversationID = conversationID
|
||||
result.MessageRefs = refs
|
||||
result.NewMessageCount = newCount
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *Server) prepareRequestAssetRefs(ctx context.Context, body map[string]any) (map[string]any, error) {
|
||||
value, err := s.prepareRequestAssetValue(ctx, body, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
next, _ := value.(map[string]any)
|
||||
if next == nil {
|
||||
return map[string]any{}, nil
|
||||
}
|
||||
return next, nil
|
||||
}
|
||||
|
||||
func (s *Server) prepareRequestAssetValue(ctx context.Context, value any, path []string, siblings map[string]any) (any, error) {
|
||||
switch typed := value.(type) {
|
||||
case map[string]any:
|
||||
if typed["assetRef"] != nil {
|
||||
return typed, nil
|
||||
}
|
||||
next := make(map[string]any, len(typed))
|
||||
keys := make([]string, 0, len(typed))
|
||||
for key := range typed {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
for _, key := range keys {
|
||||
item := typed[key]
|
||||
if decoded, ok, err := requestAssetFromValue(key, path, item, typed); err != nil {
|
||||
return nil, err
|
||||
} else if ok {
|
||||
ref, err := s.ensureRequestAsset(ctx, decoded)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
next[key] = requestAssetWrapper(ref)
|
||||
continue
|
||||
}
|
||||
prepared, err := s.prepareRequestAssetValue(ctx, item, append(path, key), typed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
next[key] = prepared
|
||||
}
|
||||
return next, nil
|
||||
case []any:
|
||||
next := make([]any, 0, len(typed))
|
||||
for index, item := range typed {
|
||||
prepared, err := s.prepareRequestAssetValue(ctx, item, append(path, fmt.Sprintf("[%d]", index)), siblings)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
next = append(next, prepared)
|
||||
}
|
||||
return next, nil
|
||||
default:
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
|
||||
func requestAssetFromValue(key string, path []string, value any, siblings map[string]any) (decodedRequestAsset, bool, error) {
|
||||
text, ok := value.(string)
|
||||
if !ok {
|
||||
return decodedRequestAsset{}, false, nil
|
||||
}
|
||||
raw := strings.TrimSpace(text)
|
||||
if raw == "" || mediaURLString(raw) {
|
||||
return decodedRequestAsset{}, false, nil
|
||||
}
|
||||
if strings.HasPrefix(strings.ToLower(raw), "data:") {
|
||||
contentType, encoded, ok, err := parseRequestDataURL(raw)
|
||||
if err != nil || !ok {
|
||||
return decodedRequestAsset{}, false, err
|
||||
}
|
||||
payload, err := decodeRequestBase64(encoded)
|
||||
if err != nil {
|
||||
return decodedRequestAsset{}, false, requestAssetDecodeError(err)
|
||||
}
|
||||
return decodedRequestAsset{
|
||||
Bytes: payload,
|
||||
ContentType: requestAssetContentType(contentType, payload, key, path, siblings),
|
||||
}, true, nil
|
||||
}
|
||||
strict := strictRequestBase64Field(key, path)
|
||||
if !strict && !likelyRequestBase64MediaField(key, path, raw) {
|
||||
return decodedRequestAsset{}, false, nil
|
||||
}
|
||||
payload, err := decodeRequestBase64(raw)
|
||||
if err != nil {
|
||||
if strict {
|
||||
return decodedRequestAsset{}, false, requestAssetDecodeError(err)
|
||||
}
|
||||
return decodedRequestAsset{}, false, nil
|
||||
}
|
||||
contentType := requestAssetContentType("", payload, key, path, siblings)
|
||||
if !strict && !requestContentTypeIsMedia(contentType) {
|
||||
return decodedRequestAsset{}, false, nil
|
||||
}
|
||||
return decodedRequestAsset{Bytes: payload, ContentType: contentType}, true, nil
|
||||
}
|
||||
|
||||
func (s *Server) ensureRequestAsset(ctx context.Context, decoded decodedRequestAsset) (map[string]any, error) {
|
||||
sum := sha256.Sum256(decoded.Bytes)
|
||||
sha := hex.EncodeToString(sum[:])
|
||||
contentType := strings.TrimSpace(decoded.ContentType)
|
||||
if contentType == "" {
|
||||
contentType = "application/octet-stream"
|
||||
}
|
||||
now := time.Now()
|
||||
if existing, ok, err := s.store.FindRequestAsset(ctx, sha, contentType); err != nil && !store.IsUndefinedDatabaseObject(err) {
|
||||
return nil, err
|
||||
} else if ok && requestAssetStillUsable(existing, now) {
|
||||
if err := s.store.IncrementRequestAssetRefCount(ctx, sha, contentType); err != nil && !store.IsUndefinedDatabaseObject(err) {
|
||||
return nil, err
|
||||
}
|
||||
return requestAssetRef(existing), nil
|
||||
}
|
||||
|
||||
upload, err := s.runner.UploadFile(ctx, runner.FileUploadPayload{
|
||||
Bytes: decoded.Bytes,
|
||||
ContentType: contentType,
|
||||
FileName: requestAssetFileName(sha, contentType),
|
||||
Scene: store.FileStorageSceneRequestAsset,
|
||||
Source: "ai-gateway-request",
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
storageProvider := requestAssetStorageProvider(upload)
|
||||
url := stringFromRequestAny(upload["url"])
|
||||
if url == "" {
|
||||
return nil, &clients.ClientError{Code: "request_asset_upload_failed", Message: "file storage response did not include url", Retryable: false}
|
||||
}
|
||||
var expiresAt *time.Time
|
||||
localPath := ""
|
||||
if storageProvider == "local_static" {
|
||||
expiry := now.Add(time.Duration(s.localTempAssetTTLHours()) * time.Hour)
|
||||
expiresAt = &expiry
|
||||
localPath = requestAssetLocalPath(s.cfg.LocalUploadedStorageDir, stringFromRequestAny(upload["fileName"]))
|
||||
}
|
||||
asset, err := s.store.UpsertRequestAsset(ctx, store.RequestAssetInput{
|
||||
SHA256: sha,
|
||||
ContentType: contentType,
|
||||
ByteSize: int64(len(decoded.Bytes)),
|
||||
URL: url,
|
||||
StorageProvider: storageProvider,
|
||||
LocalPath: localPath,
|
||||
ExpiresAt: expiresAt,
|
||||
})
|
||||
if err != nil {
|
||||
if store.IsUndefinedDatabaseObject(err) {
|
||||
return map[string]any{
|
||||
"sha256": sha,
|
||||
"url": url,
|
||||
"contentType": contentType,
|
||||
"size": len(decoded.Bytes),
|
||||
"storageProvider": storageProvider,
|
||||
"expiresAt": timePtrToRFC3339(expiresAt),
|
||||
}, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return requestAssetRef(asset), nil
|
||||
}
|
||||
|
||||
func requestConversationKey(r *http.Request, body map[string]any) string {
|
||||
if r != nil {
|
||||
if value := strings.TrimSpace(r.Header.Get("X-EasyAI-Conversation-ID")); value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
if value := firstNonEmptyRequestString(body, "conversation_id", "conversationId"); value != "" {
|
||||
return value
|
||||
}
|
||||
if metadata, ok := body["metadata"].(map[string]any); ok {
|
||||
return firstNonEmptyRequestString(metadata, "conversation_id", "conversationId")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func messageRefsForRequest(refs []store.TaskMessageRefInput) []any {
|
||||
out := make([]any, 0, len(refs))
|
||||
for _, ref := range refs {
|
||||
out = append(out, map[string]any{"messageId": ref.MessageID, "position": ref.Position})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func canonicalConversationMessageHash(message map[string]any) (string, []string) {
|
||||
parts := map[string]any{
|
||||
"role": stringFromRequestAny(message["role"]),
|
||||
}
|
||||
texts := []string{}
|
||||
assetHashes := []string{}
|
||||
collectConversationMessageParts(message, &texts, &assetHashes)
|
||||
parts["text"] = strings.Join(texts, "\n")
|
||||
parts["assets"] = assetHashes
|
||||
if parts["text"] == "" && len(assetHashes) == 0 {
|
||||
encoded, _ := json.Marshal(message)
|
||||
parts["fallback"] = string(encoded)
|
||||
}
|
||||
encoded, _ := json.Marshal(parts)
|
||||
sum := sha256.Sum256(encoded)
|
||||
return hex.EncodeToString(sum[:]), uniqueRequestStrings(assetHashes)
|
||||
}
|
||||
|
||||
func collectConversationMessageParts(value any, texts *[]string, assetHashes *[]string) {
|
||||
switch typed := value.(type) {
|
||||
case map[string]any:
|
||||
if ref, ok := typed["assetRef"].(map[string]any); ok {
|
||||
if sha := stringFromRequestAny(ref["sha256"]); sha != "" {
|
||||
*assetHashes = append(*assetHashes, sha)
|
||||
}
|
||||
}
|
||||
for key, item := range typed {
|
||||
switch key {
|
||||
case "content", "text":
|
||||
if text := stringFromRequestAny(item); text != "" && !strings.HasPrefix(strings.ToLower(text), "data:") {
|
||||
*texts = append(*texts, text)
|
||||
continue
|
||||
}
|
||||
}
|
||||
switch item.(type) {
|
||||
case map[string]any, []any:
|
||||
collectConversationMessageParts(item, texts, assetHashes)
|
||||
}
|
||||
}
|
||||
case []any:
|
||||
for _, item := range typed {
|
||||
collectConversationMessageParts(item, texts, assetHashes)
|
||||
}
|
||||
case string:
|
||||
text := strings.TrimSpace(typed)
|
||||
if text != "" && !mediaURLString(text) && !strings.HasPrefix(strings.ToLower(text), "data:") {
|
||||
*texts = append(*texts, text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func requestAssetWrapper(ref map[string]any) map[string]any {
|
||||
return map[string]any{
|
||||
"assetRef": ref,
|
||||
"url": ref["url"],
|
||||
}
|
||||
}
|
||||
|
||||
func requestAssetRef(asset store.RequestAsset) map[string]any {
|
||||
return map[string]any{
|
||||
"sha256": asset.SHA256,
|
||||
"url": asset.URL,
|
||||
"contentType": asset.ContentType,
|
||||
"size": asset.ByteSize,
|
||||
"storageProvider": asset.StorageProvider,
|
||||
"expiresAt": timePtrToRFC3339(asset.ExpiresAt),
|
||||
}
|
||||
}
|
||||
|
||||
func requestAssetStillUsable(asset store.RequestAsset, now time.Time) bool {
|
||||
if asset.ExpiredAt != nil {
|
||||
return false
|
||||
}
|
||||
if asset.ExpiresAt != nil && !asset.ExpiresAt.After(now) {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(asset.URL) != ""
|
||||
}
|
||||
|
||||
func requestAssetStorageProvider(upload map[string]any) string {
|
||||
if channel, ok := upload["storageChannel"].(map[string]any); ok {
|
||||
if provider := stringFromRequestAny(channel["provider"]); provider != "" {
|
||||
return provider
|
||||
}
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func requestAssetLocalPath(storageDir string, fileName string) string {
|
||||
if strings.TrimSpace(storageDir) == "" {
|
||||
storageDir = config.DefaultLocalUploadedStorageDir
|
||||
}
|
||||
if strings.TrimSpace(fileName) == "" {
|
||||
return ""
|
||||
}
|
||||
return filepath.Join(storageDir, filepath.Base(fileName))
|
||||
}
|
||||
|
||||
func requestAssetFileName(sha string, contentType string) string {
|
||||
prefix := sha
|
||||
if len(prefix) > 16 {
|
||||
prefix = prefix[:16]
|
||||
}
|
||||
return requestAssetFilePrefix + prefix + requestAssetExtension(contentType)
|
||||
}
|
||||
|
||||
func requestAssetExtension(contentType string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(contentType)) {
|
||||
case "image/png":
|
||||
return ".png"
|
||||
case "image/jpeg", "image/jpg":
|
||||
return ".jpg"
|
||||
case "image/webp":
|
||||
return ".webp"
|
||||
case "image/gif":
|
||||
return ".gif"
|
||||
case "audio/mpeg":
|
||||
return ".mp3"
|
||||
case "audio/wav", "audio/x-wav":
|
||||
return ".wav"
|
||||
case "video/mp4":
|
||||
return ".mp4"
|
||||
}
|
||||
if extensions, err := mime.ExtensionsByType(contentType); err == nil && len(extensions) > 0 {
|
||||
return extensions[0]
|
||||
}
|
||||
return ".bin"
|
||||
}
|
||||
|
||||
func requestAssetContentType(explicit string, payload []byte, key string, path []string, siblings map[string]any) string {
|
||||
if contentType := firstNonEmptyRequestString(siblings, "content_type", "contentType", "mime_type", "mimeType"); contentType != "" {
|
||||
return contentType
|
||||
}
|
||||
if explicit = strings.TrimSpace(explicit); explicit != "" {
|
||||
return explicit
|
||||
}
|
||||
detected := ""
|
||||
if len(payload) > 0 {
|
||||
detected = http.DetectContentType(payload)
|
||||
}
|
||||
if requestContentTypeIsMedia(detected) {
|
||||
return detected
|
||||
}
|
||||
switch requestMediaKind(key, path, siblings) {
|
||||
case "audio":
|
||||
return "audio/mpeg"
|
||||
case "video":
|
||||
return "video/mp4"
|
||||
default:
|
||||
return "image/png"
|
||||
}
|
||||
}
|
||||
|
||||
func parseRequestDataURL(value string) (string, string, bool, error) {
|
||||
prefix, payload, ok := strings.Cut(value, ",")
|
||||
if !ok {
|
||||
return "", "", false, &clients.ClientError{Code: "request_asset_decode_failed", Message: "invalid data URL media payload", Retryable: false}
|
||||
}
|
||||
meta := strings.TrimPrefix(strings.TrimPrefix(prefix, "data:"), "DATA:")
|
||||
parts := strings.Split(meta, ";")
|
||||
contentType := strings.TrimSpace(parts[0])
|
||||
for _, part := range parts[1:] {
|
||||
if strings.EqualFold(strings.TrimSpace(part), "base64") {
|
||||
return contentType, payload, true, nil
|
||||
}
|
||||
}
|
||||
return "", "", false, &clients.ClientError{Code: "request_asset_decode_failed", Message: "data URL media payload is not base64 encoded", Retryable: false}
|
||||
}
|
||||
|
||||
func decodeRequestBase64(value string) ([]byte, error) {
|
||||
normalized := strings.Map(func(r rune) rune {
|
||||
switch r {
|
||||
case '\n', '\r', '\t', ' ':
|
||||
return -1
|
||||
default:
|
||||
return r
|
||||
}
|
||||
}, strings.TrimSpace(value))
|
||||
encodings := []*base64.Encoding{
|
||||
base64.StdEncoding,
|
||||
base64.RawStdEncoding,
|
||||
base64.URLEncoding,
|
||||
base64.RawURLEncoding,
|
||||
}
|
||||
var lastErr error
|
||||
for _, encoding := range encodings {
|
||||
payload, err := encoding.DecodeString(normalized)
|
||||
if err == nil && len(payload) > 0 {
|
||||
return payload, nil
|
||||
}
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
if lastErr == nil {
|
||||
lastErr = fmt.Errorf("empty base64 payload")
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
func requestAssetDecodeError(err error) error {
|
||||
return &clients.ClientError{Code: "request_asset_decode_failed", Message: err.Error(), Retryable: false}
|
||||
}
|
||||
|
||||
func strictRequestBase64Field(key string, path []string) bool {
|
||||
lower := strings.ToLower(strings.TrimSpace(key))
|
||||
parent := ""
|
||||
if len(path) > 0 {
|
||||
parent = strings.ToLower(strings.Trim(path[len(path)-1], "[]"))
|
||||
}
|
||||
return lower == "b64_json" ||
|
||||
lower == "base64" ||
|
||||
lower == "b64" ||
|
||||
strings.Contains(lower, "base64") ||
|
||||
strings.Contains(lower, "_b64") ||
|
||||
(parent == "input_audio" && lower == "data")
|
||||
}
|
||||
|
||||
func likelyRequestBase64MediaField(key string, path []string, value string) bool {
|
||||
if len(value) < 64 {
|
||||
return false
|
||||
}
|
||||
return requestMediaKind(key, path, nil) != ""
|
||||
}
|
||||
|
||||
func requestMediaKind(key string, path []string, siblings map[string]any) string {
|
||||
candidates := append([]string{key}, path...)
|
||||
if siblings != nil {
|
||||
candidates = append(candidates, stringFromRequestAny(siblings["type"]), stringFromRequestAny(siblings["role"]))
|
||||
}
|
||||
for _, item := range candidates {
|
||||
lower := strings.ToLower(strings.TrimSpace(item))
|
||||
switch {
|
||||
case strings.Contains(lower, "audio"):
|
||||
return "audio"
|
||||
case strings.Contains(lower, "video"):
|
||||
return "video"
|
||||
case strings.Contains(lower, "image") || lower == "mask" || lower == "b64_json":
|
||||
return "image"
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func requestContentTypeIsMedia(contentType string) bool {
|
||||
lower := strings.ToLower(strings.TrimSpace(contentType))
|
||||
return strings.HasPrefix(lower, "image/") ||
|
||||
strings.HasPrefix(lower, "video/") ||
|
||||
strings.HasPrefix(lower, "audio/")
|
||||
}
|
||||
|
||||
func mediaURLString(value string) bool {
|
||||
raw := strings.TrimSpace(value)
|
||||
if raw == "" {
|
||||
return false
|
||||
}
|
||||
lower := strings.ToLower(raw)
|
||||
if strings.HasPrefix(lower, "data:") {
|
||||
return false
|
||||
}
|
||||
return strings.HasPrefix(lower, "http://") ||
|
||||
strings.HasPrefix(lower, "https://") ||
|
||||
strings.HasPrefix(lower, "/") ||
|
||||
strings.Contains(lower, "://")
|
||||
}
|
||||
|
||||
func firstNonEmptyRequestString(values map[string]any, keys ...string) string {
|
||||
for _, key := range keys {
|
||||
if values == nil {
|
||||
continue
|
||||
}
|
||||
if value := stringFromRequestAny(values[key]); value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func stringFromRequestAny(value any) string {
|
||||
text, _ := value.(string)
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
|
||||
func uniqueRequestStrings(values []string) []string {
|
||||
seen := map[string]bool{}
|
||||
out := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" || seen[value] {
|
||||
continue
|
||||
}
|
||||
seen[value] = true
|
||||
out = append(out, value)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func timePtrToRFC3339(value *time.Time) any {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
return value.UTC().Format(time.RFC3339)
|
||||
}
|
||||
|
||||
func (s *Server) localTempAssetTTLHours() int {
|
||||
if s.cfg.LocalTempAssetTTLHours <= 0 {
|
||||
return 24
|
||||
}
|
||||
return s.cfg.LocalTempAssetTTLHours
|
||||
}
|
||||
136
apps/api/internal/httpapi/request_preparation_test.go
Normal file
136
apps/api/internal/httpapi/request_preparation_test.go
Normal file
@ -0,0 +1,136 @@
|
||||
package httpapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
|
||||
)
|
||||
|
||||
func TestRequestAssetFromValueDetectsDataURLAndRawBase64(t *testing.T) {
|
||||
payload := base64.StdEncoding.EncodeToString([]byte("inline image"))
|
||||
decoded, ok, err := requestAssetFromValue("url", []string{"messages", "[0]", "content", "[1]", "image_url"}, "data:image/png;base64,"+payload, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("decode data URL: %v", err)
|
||||
}
|
||||
if !ok || decoded.ContentType != "image/png" || string(decoded.Bytes) != "inline image" {
|
||||
t.Fatalf("unexpected data URL asset: ok=%v decoded=%+v", ok, decoded)
|
||||
}
|
||||
|
||||
audio := base64.StdEncoding.EncodeToString([]byte("inline audio"))
|
||||
decoded, ok, err = requestAssetFromValue("data", []string{"input_audio"}, audio, map[string]any{"format": "mp3"})
|
||||
if err != nil {
|
||||
t.Fatalf("decode raw audio: %v", err)
|
||||
}
|
||||
if !ok || decoded.ContentType != "audio/mpeg" || string(decoded.Bytes) != "inline audio" {
|
||||
t.Fatalf("unexpected raw audio asset: ok=%v decoded=%+v", ok, decoded)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanonicalConversationMessageHashUsesTextAndAssetRefs(t *testing.T) {
|
||||
message := map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{"type": "text", "text": "describe it"},
|
||||
map[string]any{"type": "image_url", "image_url": map[string]any{
|
||||
"url": "https://cdn.example/a.png",
|
||||
"assetRef": map[string]any{"sha256": "sha-a", "url": "https://cdn.example/a.png"},
|
||||
}},
|
||||
},
|
||||
}
|
||||
sameMessage := map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{"type": "text", "text": "describe it"},
|
||||
map[string]any{"type": "image_url", "image_url": map[string]any{
|
||||
"url": "https://different.example/a.png",
|
||||
"assetRef": map[string]any{"sha256": "sha-a", "url": "https://different.example/a.png"},
|
||||
}},
|
||||
},
|
||||
}
|
||||
changedMessage := map[string]any{
|
||||
"role": "user",
|
||||
"content": "describe something else",
|
||||
}
|
||||
|
||||
firstHash, assetHashes := canonicalConversationMessageHash(message)
|
||||
secondHash, _ := canonicalConversationMessageHash(sameMessage)
|
||||
changedHash, _ := canonicalConversationMessageHash(changedMessage)
|
||||
if firstHash != secondHash {
|
||||
t.Fatalf("message hash should ignore resource URL drift when asset sha is stable")
|
||||
}
|
||||
if firstHash == changedHash {
|
||||
t.Fatalf("message hash should change when text changes")
|
||||
}
|
||||
if len(assetHashes) != 1 || assetHashes[0] != "sha-a" {
|
||||
t.Fatalf("unexpected asset hashes: %+v", assetHashes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanupExpiredLocalTempAssetsOnlyDeletesExpiredPrefixedFiles(t *testing.T) {
|
||||
storageDir := t.TempDir()
|
||||
oldTemp := filepath.Join(storageDir, requestAssetFilePrefix+"old.png")
|
||||
freshTemp := filepath.Join(storageDir, requestAssetFilePrefix+"fresh.png")
|
||||
oldGenerated := filepath.Join(storageDir, "gateway-result-old.png")
|
||||
for _, path := range []string{oldTemp, freshTemp, oldGenerated} {
|
||||
if err := os.WriteFile(path, []byte("asset"), 0o644); err != nil {
|
||||
t.Fatalf("write fixture %s: %v", path, err)
|
||||
}
|
||||
}
|
||||
now := time.Now()
|
||||
if err := os.Chtimes(oldTemp, now.Add(-25*time.Hour), now.Add(-25*time.Hour)); err != nil {
|
||||
t.Fatalf("touch old temp: %v", err)
|
||||
}
|
||||
if err := os.Chtimes(freshTemp, now.Add(-23*time.Hour), now.Add(-23*time.Hour)); err != nil {
|
||||
t.Fatalf("touch fresh temp: %v", err)
|
||||
}
|
||||
if err := os.Chtimes(oldGenerated, now.Add(-25*time.Hour), now.Add(-25*time.Hour)); err != nil {
|
||||
t.Fatalf("touch old generated: %v", err)
|
||||
}
|
||||
server := &Server{
|
||||
cfg: config.Config{LocalUploadedStorageDir: storageDir, LocalTempAssetTTLHours: 24},
|
||||
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
|
||||
}
|
||||
|
||||
deleted := server.cleanupExpiredLocalTempAssets(context.Background(), now)
|
||||
|
||||
if deleted != 1 {
|
||||
t.Fatalf("expected one expired temp asset delete, got %d", deleted)
|
||||
}
|
||||
if _, err := os.Stat(oldTemp); !os.IsNotExist(err) {
|
||||
t.Fatalf("old prefixed temp asset should be deleted, stat err=%v", err)
|
||||
}
|
||||
for _, path := range []string{freshTemp, oldGenerated} {
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
t.Fatalf("non-expired or non-prefixed file should remain %s: %v", path, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestConversationKeyPriority(t *testing.T) {
|
||||
request := httptest.NewRequest(http.MethodPost, "/api/v1/chat/completions", nil)
|
||||
request.Header.Set("X-EasyAI-Conversation-ID", "from-header")
|
||||
body := map[string]any{
|
||||
"conversation_id": "from-body",
|
||||
"metadata": map[string]any{"conversation_id": "from-metadata"},
|
||||
}
|
||||
if got := requestConversationKey(request, body); got != "from-header" {
|
||||
t.Fatalf("expected header conversation id, got %q", got)
|
||||
}
|
||||
request.Header.Del("X-EasyAI-Conversation-ID")
|
||||
if got := requestConversationKey(request, body); got != "from-body" {
|
||||
t.Fatalf("expected body conversation id, got %q", got)
|
||||
}
|
||||
delete(body, "conversation_id")
|
||||
if got := requestConversationKey(request, body); got != "from-metadata" {
|
||||
t.Fatalf("expected metadata conversation id, got %q", got)
|
||||
}
|
||||
}
|
||||
@ -134,6 +134,33 @@ func (s *Server) updatePlatformDynamicPriority(w http.ResponseWriter, r *http.Re
|
||||
writeJSON(w, http.StatusOK, item)
|
||||
}
|
||||
|
||||
// restorePlatformModelRuntimeStatus godoc
|
||||
// @Summary 恢复平台模型运行状态
|
||||
// @Description 管理端手动解除平台模型停用、模型冷却、平台冷却或平台禁用状态,使其重新参与路由。
|
||||
// @Tags runtime
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param platformModelID path string true "平台模型 ID"
|
||||
// @Success 200 {object} store.ModelRateLimitStatus
|
||||
// @Failure 401 {object} ErrorEnvelope
|
||||
// @Failure 403 {object} ErrorEnvelope
|
||||
// @Failure 404 {object} ErrorEnvelope
|
||||
// @Failure 500 {object} ErrorEnvelope
|
||||
// @Router /api/admin/runtime/model-rate-limits/{platformModelID}/restore [post]
|
||||
func (s *Server) restorePlatformModelRuntimeStatus(w http.ResponseWriter, r *http.Request) {
|
||||
item, err := s.store.RestorePlatformModelRuntimeStatus(r.Context(), r.PathValue("platformModelID"))
|
||||
if err != nil {
|
||||
if store.IsNotFound(err) {
|
||||
writeError(w, http.StatusNotFound, "platform model not found")
|
||||
return
|
||||
}
|
||||
s.logger.Error("restore platform model runtime status failed", "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "restore platform model runtime status failed")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, item)
|
||||
}
|
||||
|
||||
// createRuntimePolicySet godoc
|
||||
// @Summary 创建运行策略集
|
||||
// @Description 管理端创建运行策略集,policyKey 和 name 必填。
|
||||
|
||||
@ -36,6 +36,7 @@ func NewServerWithContext(ctx context.Context, cfg config.Config, db *store.Stor
|
||||
}
|
||||
server.auth.LocalAPIKeyVerifier = db.VerifyLocalAPIKey
|
||||
server.runner.StartAsyncQueueWorker(ctx)
|
||||
server.startLocalTempAssetCleanup(ctx)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("GET /healthz", server.health)
|
||||
@ -103,6 +104,7 @@ func NewServerWithContext(ctx context.Context, cfg config.Config, db *store.Stor
|
||||
mux.Handle("DELETE /api/admin/runtime/policy-sets/{policySetID}", server.requireAdmin(auth.PermissionManager, http.HandlerFunc(server.deleteRuntimePolicySet)))
|
||||
mux.Handle("GET /api/admin/runtime/runner-policy", server.requireAdmin(auth.PermissionPower, http.HandlerFunc(server.getRunnerPolicy)))
|
||||
mux.Handle("PATCH /api/admin/runtime/runner-policy", server.requireAdmin(auth.PermissionManager, http.HandlerFunc(server.updateRunnerPolicy)))
|
||||
mux.Handle("POST /api/admin/runtime/model-rate-limits/{platformModelID}/restore", server.requireAdmin(auth.PermissionManager, http.HandlerFunc(server.restorePlatformModelRuntimeStatus)))
|
||||
mux.Handle("GET /api/admin/config/network-proxy", server.requireAdmin(auth.PermissionPower, http.HandlerFunc(server.getNetworkProxyConfig)))
|
||||
mux.Handle("GET /api/admin/system/file-storage/settings", server.requireAdmin(auth.PermissionPower, http.HandlerFunc(server.getFileStorageSettings)))
|
||||
mux.Handle("PATCH /api/admin/system/file-storage/settings", server.requireAdmin(auth.PermissionManager, http.HandlerFunc(server.updateFileStorageSettings)))
|
||||
@ -173,7 +175,7 @@ func (s *Server) cors(next http.Handler) http.Handler {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
w.Header().Set("Vary", "Origin")
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Comfy-Api-Key, X-Async")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Comfy-Api-Key, X-Async, X-EasyAI-Conversation-ID")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
||||
}
|
||||
if r.Method == http.MethodOptions {
|
||||
|
||||
@ -11,8 +11,18 @@ func stringFromMap(values map[string]any, key string) string {
|
||||
}
|
||||
|
||||
func stringFromAny(value any) string {
|
||||
text, _ := value.(string)
|
||||
return strings.TrimSpace(text)
|
||||
switch typed := value.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(typed)
|
||||
case map[string]any:
|
||||
if text := stringFromAny(typed["url"]); text != "" {
|
||||
return text
|
||||
}
|
||||
if ref, ok := typed["assetRef"].(map[string]any); ok {
|
||||
return stringFromAny(ref["url"])
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func boolFromMap(values map[string]any, key string) bool {
|
||||
|
||||
249
apps/api/internal/runner/request_assets.go
Normal file
249
apps/api/internal/runner/request_assets.go
Normal file
@ -0,0 +1,249 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||
)
|
||||
|
||||
func (s *Service) restoreTaskRequestReferences(ctx context.Context, task store.GatewayTask) (map[string]any, error) {
|
||||
body := cloneMap(task.Request)
|
||||
if body["messages"] != nil || body["messageRefs"] == nil || s.store == nil {
|
||||
return body, nil
|
||||
}
|
||||
refs, err := s.store.ListTaskConversationMessages(ctx, task.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(refs) == 0 {
|
||||
return body, nil
|
||||
}
|
||||
messages := make([]any, 0, len(refs))
|
||||
for _, ref := range refs {
|
||||
messages = append(messages, ref.Message)
|
||||
}
|
||||
body["messages"] = messages
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func (s *Service) slimTaskRequestSnapshot(task store.GatewayTask, body map[string]any) map[string]any {
|
||||
out := cloneMap(body)
|
||||
messageRefs := task.Request["messageRefs"]
|
||||
if messageRefs == nil {
|
||||
return out
|
||||
}
|
||||
delete(out, "messages")
|
||||
out["messageRefs"] = messageRefs
|
||||
for _, key := range []string{"conversationId", "conversationRecordId", "newMessageCount"} {
|
||||
if value := task.Request[key]; value != nil {
|
||||
out[key] = value
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *Service) slimParameterPreprocessingLog(task store.GatewayTask, log parameterPreprocessingLog) parameterPreprocessingLog {
|
||||
log.Input = s.slimTaskRequestSnapshot(task, log.Input)
|
||||
log.Output = s.slimTaskRequestSnapshot(task, log.Output)
|
||||
return log
|
||||
}
|
||||
|
||||
func (s *Service) hydrateProviderRequestAssets(ctx context.Context, body map[string]any) (map[string]any, error) {
|
||||
value, err := s.hydrateProviderRequestAssetValue(ctx, body, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out, _ := value.(map[string]any)
|
||||
if out == nil {
|
||||
return map[string]any{}, nil
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *Service) hydrateProviderRequestAssetValue(ctx context.Context, value any, path []string) (any, error) {
|
||||
switch typed := value.(type) {
|
||||
case map[string]any:
|
||||
if ref, ok := typed["assetRef"].(map[string]any); ok {
|
||||
return s.hydrateProviderRequestAssetRef(ctx, ref, path)
|
||||
}
|
||||
next := make(map[string]any, len(typed))
|
||||
for key, item := range typed {
|
||||
hydrated, err := s.hydrateProviderRequestAssetValue(ctx, item, append(path, key))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
next[key] = hydrated
|
||||
}
|
||||
return next, nil
|
||||
case []any:
|
||||
next := make([]any, 0, len(typed))
|
||||
for index, item := range typed {
|
||||
hydrated, err := s.hydrateProviderRequestAssetValue(ctx, item, append(path, fmt.Sprintf("[%d]", index)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
next = append(next, hydrated)
|
||||
}
|
||||
return next, nil
|
||||
default:
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) hydrateProviderRequestAssetRef(ctx context.Context, ref map[string]any, path []string) (any, error) {
|
||||
asset, err := s.resolveRequestAsset(ctx, ref)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if providerFieldNeedsBase64(path) {
|
||||
payload, err := s.readRequestAssetBytes(ctx, asset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(payload), nil
|
||||
}
|
||||
if strings.TrimSpace(asset.URL) == "" {
|
||||
return nil, requestAssetExpiredError(asset)
|
||||
}
|
||||
return asset.URL, nil
|
||||
}
|
||||
|
||||
func (s *Service) resolveRequestAsset(ctx context.Context, ref map[string]any) (store.RequestAsset, error) {
|
||||
sha := stringFromAny(ref["sha256"])
|
||||
contentType := stringFromAny(ref["contentType"])
|
||||
asset := store.RequestAsset{
|
||||
SHA256: sha,
|
||||
ContentType: contentType,
|
||||
URL: stringFromAny(ref["url"]),
|
||||
StorageProvider: stringFromAny(ref["storageProvider"]),
|
||||
}
|
||||
if size := floatFromAny(ref["size"]); size > 0 {
|
||||
asset.ByteSize = int64(size)
|
||||
}
|
||||
if expiresAt := stringFromAny(ref["expiresAt"]); expiresAt != "" {
|
||||
if parsed, err := time.Parse(time.RFC3339, expiresAt); err == nil {
|
||||
asset.ExpiresAt = &parsed
|
||||
}
|
||||
}
|
||||
if s.store != nil && sha != "" && contentType != "" {
|
||||
if stored, ok, err := s.store.FindRequestAsset(ctx, sha, contentType); err != nil && !store.IsUndefinedDatabaseObject(err) {
|
||||
return store.RequestAsset{}, err
|
||||
} else if ok {
|
||||
asset = stored
|
||||
}
|
||||
}
|
||||
if requestAssetIsExpired(asset, time.Now()) {
|
||||
return store.RequestAsset{}, requestAssetExpiredError(asset)
|
||||
}
|
||||
return asset, nil
|
||||
}
|
||||
|
||||
func (s *Service) readRequestAssetBytes(ctx context.Context, asset store.RequestAsset) ([]byte, error) {
|
||||
if requestAssetIsExpired(asset, time.Now()) {
|
||||
return nil, requestAssetExpiredError(asset)
|
||||
}
|
||||
if strings.TrimSpace(asset.LocalPath) != "" {
|
||||
payload, err := os.ReadFile(asset.LocalPath)
|
||||
if err != nil {
|
||||
return nil, requestAssetExpiredError(asset)
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
if localPath := s.localPathFromRequestAssetURL(asset.URL); localPath != "" {
|
||||
payload, err := os.ReadFile(localPath)
|
||||
if err != nil {
|
||||
return nil, requestAssetExpiredError(asset)
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
if strings.HasPrefix(asset.URL, "http://") || strings.HasPrefix(asset.URL, "https://") {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, asset.URL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, &clients.ClientError{Code: "request_asset_fetch_failed", Message: err.Error(), Retryable: true}
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, &clients.ClientError{Code: "request_asset_fetch_failed", Message: resp.Status, StatusCode: resp.StatusCode, Retryable: clients.HTTPRetryable(resp.StatusCode)}
|
||||
}
|
||||
payload, err := io.ReadAll(io.LimitReader(resp.Body, 256<<20))
|
||||
if err != nil {
|
||||
return nil, &clients.ClientError{Code: "request_asset_fetch_failed", Message: err.Error(), Retryable: true}
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
return nil, requestAssetExpiredError(asset)
|
||||
}
|
||||
|
||||
func (s *Service) localPathFromRequestAssetURL(value string) string {
|
||||
raw := strings.TrimSpace(value)
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
pathValue := raw
|
||||
if parsed, err := url.Parse(raw); err == nil && parsed.Path != "" {
|
||||
pathValue = parsed.Path
|
||||
}
|
||||
const uploadedPrefix = "/static/uploaded/"
|
||||
if !strings.HasPrefix(pathValue, uploadedPrefix) {
|
||||
return ""
|
||||
}
|
||||
fileName := filepath.Base(strings.TrimPrefix(pathValue, uploadedPrefix))
|
||||
if !strings.HasPrefix(fileName, "gateway-request-asset-") {
|
||||
return ""
|
||||
}
|
||||
storageDir := strings.TrimSpace(s.cfg.LocalUploadedStorageDir)
|
||||
if storageDir == "" {
|
||||
storageDir = config.DefaultLocalUploadedStorageDir
|
||||
}
|
||||
return filepath.Join(storageDir, fileName)
|
||||
}
|
||||
|
||||
func providerFieldNeedsBase64(path []string) bool {
|
||||
if len(path) == 0 {
|
||||
return false
|
||||
}
|
||||
key := strings.ToLower(strings.Trim(path[len(path)-1], "[]"))
|
||||
parent := ""
|
||||
if len(path) > 1 {
|
||||
parent = strings.ToLower(strings.Trim(path[len(path)-2], "[]"))
|
||||
}
|
||||
return key == "b64_json" ||
|
||||
key == "base64" ||
|
||||
key == "b64" ||
|
||||
strings.Contains(key, "base64") ||
|
||||
strings.Contains(key, "_b64") ||
|
||||
(parent == "input_audio" && key == "data")
|
||||
}
|
||||
|
||||
func requestAssetIsExpired(asset store.RequestAsset, now time.Time) bool {
|
||||
if asset.ExpiredAt != nil {
|
||||
return true
|
||||
}
|
||||
if asset.ExpiresAt != nil && !asset.ExpiresAt.After(now) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func requestAssetExpiredError(asset store.RequestAsset) error {
|
||||
message := "request asset is expired or unavailable"
|
||||
if asset.SHA256 != "" {
|
||||
message = "request asset is expired or unavailable: " + asset.SHA256
|
||||
}
|
||||
return &clients.ClientError{Code: "request_asset_expired", Message: message, Retryable: false}
|
||||
}
|
||||
101
apps/api/internal/runner/request_assets_test.go
Normal file
101
apps/api/internal/runner/request_assets_test.go
Normal file
@ -0,0 +1,101 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||
)
|
||||
|
||||
func TestHydrateProviderRequestAssetsConvertsStrictBase64Field(t *testing.T) {
|
||||
storageDir := t.TempDir()
|
||||
fileName := "gateway-request-asset-test.png"
|
||||
if err := os.WriteFile(filepath.Join(storageDir, fileName), []byte("image bytes"), 0o644); err != nil {
|
||||
t.Fatalf("write request asset: %v", err)
|
||||
}
|
||||
service := &Service{cfg: config.Config{LocalUploadedStorageDir: storageDir}}
|
||||
body := map[string]any{
|
||||
"model": "demo",
|
||||
"b64_json": map[string]any{
|
||||
"assetRef": map[string]any{
|
||||
"sha256": "sha-test",
|
||||
"contentType": "image/png",
|
||||
"url": "/static/uploaded/" + fileName,
|
||||
"storageProvider": "local_static",
|
||||
},
|
||||
"url": "/static/uploaded/" + fileName,
|
||||
},
|
||||
}
|
||||
|
||||
hydrated, err := service.hydrateProviderRequestAssets(context.Background(), body)
|
||||
if err != nil {
|
||||
t.Fatalf("hydrate request assets: %v", err)
|
||||
}
|
||||
if got, want := stringFromAny(hydrated["b64_json"]), base64.StdEncoding.EncodeToString([]byte("image bytes")); got != want {
|
||||
t.Fatalf("unexpected hydrated base64: got %q want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHydrateProviderRequestAssetsReturnsExpiredError(t *testing.T) {
|
||||
expiredAt := time.Now().Add(-time.Minute).UTC().Format(time.RFC3339)
|
||||
service := &Service{}
|
||||
body := map[string]any{
|
||||
"image_url": map[string]any{
|
||||
"assetRef": map[string]any{
|
||||
"sha256": "sha-expired",
|
||||
"contentType": "image/png",
|
||||
"url": "/static/uploaded/gateway-request-asset-expired.png",
|
||||
"storageProvider": "local_static",
|
||||
"expiresAt": expiredAt,
|
||||
},
|
||||
"url": "/static/uploaded/gateway-request-asset-expired.png",
|
||||
},
|
||||
}
|
||||
|
||||
_, err := service.hydrateProviderRequestAssets(context.Background(), body)
|
||||
if err == nil {
|
||||
t.Fatal("expected expired request asset error")
|
||||
}
|
||||
var clientErr *clients.ClientError
|
||||
if !errors.As(err, &clientErr) || clientErr.Code != "request_asset_expired" {
|
||||
t.Fatalf("expected request_asset_expired, got %T %v", err, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringFromAnyReadsRequestAssetWrapperURL(t *testing.T) {
|
||||
wrapper := map[string]any{
|
||||
"assetRef": map[string]any{"url": "https://cdn.example/request.png"},
|
||||
}
|
||||
if got := stringFromAny(wrapper); got != "https://cdn.example/request.png" {
|
||||
t.Fatalf("expected wrapper URL, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlimTaskRequestSnapshotKeepsMessageRefs(t *testing.T) {
|
||||
service := &Service{}
|
||||
task := store.GatewayTask{Request: map[string]any{
|
||||
"conversationId": "conv-1",
|
||||
"messageRefs": []any{map[string]any{"messageId": "msg-1", "position": 0}},
|
||||
"newMessageCount": 1,
|
||||
}}
|
||||
body := map[string]any{
|
||||
"model": "demo",
|
||||
"messages": []any{map[string]any{"role": "user", "content": "hello"}},
|
||||
}
|
||||
|
||||
snapshot := service.slimTaskRequestSnapshot(task, body)
|
||||
|
||||
if snapshot["messages"] != nil {
|
||||
t.Fatalf("snapshot should not persist restored messages: %+v", snapshot)
|
||||
}
|
||||
if snapshot["messageRefs"] == nil || snapshot["newMessageCount"] != 1 {
|
||||
t.Fatalf("snapshot should keep message refs and new count: %+v", snapshot)
|
||||
}
|
||||
}
|
||||
@ -87,9 +87,13 @@ func (s *Service) ExecuteStream(ctx context.Context, task store.GatewayTask, use
|
||||
|
||||
func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *auth.User, onDelta clients.StreamDelta) (Result, error) {
|
||||
executeStartedAt := time.Now()
|
||||
body := normalizeRequest(task.Kind, task.Request)
|
||||
restoredRequest, err := s.restoreTaskRequestReferences(ctx, task)
|
||||
if err != nil {
|
||||
return Result{}, err
|
||||
}
|
||||
body := normalizeRequest(task.Kind, restoredRequest)
|
||||
modelType := modelTypeFromKind(task.Kind, body)
|
||||
if err := s.store.MarkTaskRunning(ctx, task.ID, modelType, body); err != nil {
|
||||
if err := s.store.MarkTaskRunning(ctx, task.ID, modelType, s.slimTaskRequestSnapshot(task, body)); err != nil {
|
||||
return Result{}, err
|
||||
}
|
||||
if task.Status != "running" {
|
||||
@ -194,7 +198,7 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
|
||||
}
|
||||
return Result{Task: failed, Output: failed.Result}, clientErr
|
||||
}
|
||||
if err := s.store.MarkTaskRunning(ctx, task.ID, candidates[0].ModelType, firstCandidateBody); err != nil {
|
||||
if err := s.store.MarkTaskRunning(ctx, task.ID, candidates[0].ModelType, s.slimTaskRequestSnapshot(task, firstCandidateBody)); err != nil {
|
||||
return Result{}, err
|
||||
}
|
||||
estimatedBillings := s.estimatedBillings(ctx, user, task.Kind, firstCandidateBody, candidates[0])
|
||||
@ -508,13 +512,13 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user
|
||||
QueueKey: candidate.QueueKey,
|
||||
Status: "running",
|
||||
Simulated: simulated,
|
||||
RequestSnapshot: body,
|
||||
RequestSnapshot: s.slimTaskRequestSnapshot(task, body),
|
||||
Metrics: baseAttemptMetrics,
|
||||
})
|
||||
if err != nil {
|
||||
return clients.Response{}, fmt.Errorf("create task attempt: %w", err)
|
||||
}
|
||||
if err := s.recordTaskParameterPreprocessing(ctx, task.ID, attemptID, attemptNo, candidate, preprocessing); err != nil {
|
||||
if err := s.recordTaskParameterPreprocessing(ctx, task.ID, attemptID, attemptNo, candidate, s.slimParameterPreprocessingLog(task, preprocessing)); err != nil {
|
||||
clientErr := &clients.ClientError{Code: "runtime_error", Message: err.Error(), Retryable: false}
|
||||
_ = s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{
|
||||
AttemptID: attemptID,
|
||||
@ -560,12 +564,24 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user
|
||||
return clients.Response{}, fmt.Errorf("prepare http client: %w", err)
|
||||
}
|
||||
client := s.clientFor(candidate, simulated)
|
||||
providerBody, err := s.hydrateProviderRequestAssets(ctx, body)
|
||||
if err != nil {
|
||||
_ = s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{
|
||||
AttemptID: attemptID,
|
||||
Status: "failed",
|
||||
Retryable: false,
|
||||
Metrics: mergeMetrics(baseAttemptMetrics, map[string]any{"error": err.Error(), "retryable": false, "trace": []any{failureTraceEntry(err, false)}}),
|
||||
ErrorCode: clients.ErrorCode(err),
|
||||
ErrorMessage: err.Error(),
|
||||
})
|
||||
return clients.Response{}, err
|
||||
}
|
||||
callStartedAt := time.Now()
|
||||
response, err := client.Run(ctx, clients.Request{
|
||||
Kind: task.Kind,
|
||||
ModelType: candidate.ModelType,
|
||||
Model: task.Model,
|
||||
Body: body,
|
||||
Body: providerBody,
|
||||
Candidate: candidate,
|
||||
HTTPClient: requestHTTPClient,
|
||||
RemoteTaskID: task.RemoteTaskID,
|
||||
@ -576,7 +592,7 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user
|
||||
}
|
||||
return s.store.SetTaskRemoteTask(context.WithoutCancel(ctx), task.ID, attemptID, remoteTaskID, payload)
|
||||
},
|
||||
Stream: boolFromMap(body, "stream"),
|
||||
Stream: boolFromMap(providerBody, "stream"),
|
||||
StreamDelta: onDelta,
|
||||
})
|
||||
callFinishedAt := time.Now()
|
||||
@ -826,7 +842,7 @@ func (s *Service) recordFailedAttempt(ctx context.Context, input failedAttemptRe
|
||||
QueueKey: queueKey,
|
||||
Status: "running",
|
||||
Simulated: input.Simulated,
|
||||
RequestSnapshot: input.Body,
|
||||
RequestSnapshot: s.slimTaskRequestSnapshot(input.Task, input.Body),
|
||||
Metrics: metrics,
|
||||
})
|
||||
if err != nil {
|
||||
@ -834,7 +850,8 @@ func (s *Service) recordFailedAttempt(ctx context.Context, input failedAttemptRe
|
||||
return attemptNo
|
||||
}
|
||||
if input.Preprocessing != nil && input.Candidate != nil {
|
||||
if err := s.recordTaskParameterPreprocessing(ctx, input.Task.ID, attemptID, attemptNo, *input.Candidate, *input.Preprocessing); err != nil {
|
||||
preprocessing := s.slimParameterPreprocessingLog(input.Task, *input.Preprocessing)
|
||||
if err := s.recordTaskParameterPreprocessing(ctx, input.Task.ID, attemptID, attemptNo, *input.Candidate, preprocessing); err != nil {
|
||||
s.logger.Warn("record failed attempt parameter preprocessing failed", "taskID", input.Task.ID, "attempt", attemptNo, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
149
apps/api/internal/store/conversations.go
Normal file
149
apps/api/internal/store/conversations.go
Normal file
@ -0,0 +1,149 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
||||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
type ConversationMessageInput struct {
|
||||
Hash string
|
||||
Role string
|
||||
Snapshot map[string]any
|
||||
AssetSHA256s []string
|
||||
}
|
||||
|
||||
type TaskMessageRefInput struct {
|
||||
MessageID string
|
||||
Position int
|
||||
}
|
||||
|
||||
type ConversationMessageRef struct {
|
||||
MessageID string `json:"messageId"`
|
||||
Position int `json:"position"`
|
||||
Message map[string]any `json:"message"`
|
||||
}
|
||||
|
||||
func (s *Store) EnsureConversation(ctx context.Context, user *auth.User, conversationKey string, metadata map[string]any) (string, error) {
|
||||
conversationKey = strings.TrimSpace(conversationKey)
|
||||
if conversationKey == "" {
|
||||
return "", nil
|
||||
}
|
||||
userID := ""
|
||||
gatewayUserID := ""
|
||||
if user != nil {
|
||||
userID = strings.TrimSpace(user.ID)
|
||||
gatewayUserID = strings.TrimSpace(user.GatewayUserID)
|
||||
}
|
||||
if userID == "" {
|
||||
userID = "anonymous"
|
||||
}
|
||||
metadataJSON, _ := json.Marshal(emptyObjectIfNil(metadata))
|
||||
var conversationID string
|
||||
err := s.pool.QueryRow(ctx, `
|
||||
INSERT INTO gateway_conversations (user_id, gateway_user_id, conversation_key, metadata)
|
||||
VALUES ($1, NULLIF($2, '')::uuid, $3, $4::jsonb)
|
||||
ON CONFLICT (user_id, conversation_key) DO UPDATE
|
||||
SET gateway_user_id = COALESCE(gateway_conversations.gateway_user_id, EXCLUDED.gateway_user_id),
|
||||
metadata = gateway_conversations.metadata || EXCLUDED.metadata,
|
||||
updated_at = now()
|
||||
RETURNING id::text`, userID, gatewayUserID, conversationKey, string(metadataJSON)).Scan(&conversationID)
|
||||
return conversationID, err
|
||||
}
|
||||
|
||||
func (s *Store) UpsertConversationMessages(ctx context.Context, conversationID string, messages []ConversationMessageInput) ([]TaskMessageRefInput, int, error) {
|
||||
if strings.TrimSpace(conversationID) == "" || len(messages) == 0 {
|
||||
return nil, 0, nil
|
||||
}
|
||||
tx, err := s.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
refs := make([]TaskMessageRefInput, 0, len(messages))
|
||||
newCount := 0
|
||||
for index, message := range messages {
|
||||
snapshotJSON, _ := json.Marshal(emptyObjectIfNil(message.Snapshot))
|
||||
var messageID string
|
||||
var inserted bool
|
||||
if err := tx.QueryRow(ctx, `
|
||||
INSERT INTO gateway_conversation_messages (
|
||||
conversation_id, message_hash, role, message_snapshot, asset_sha256s
|
||||
)
|
||||
VALUES ($1::uuid, $2, NULLIF($3, ''), $4::jsonb, $5)
|
||||
ON CONFLICT (conversation_id, message_hash) DO UPDATE
|
||||
SET updated_at = gateway_conversation_messages.updated_at
|
||||
RETURNING id::text, (xmax = 0) AS inserted`,
|
||||
conversationID,
|
||||
message.Hash,
|
||||
message.Role,
|
||||
string(snapshotJSON),
|
||||
message.AssetSHA256s,
|
||||
).Scan(&messageID, &inserted); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if inserted {
|
||||
newCount++
|
||||
}
|
||||
refs = append(refs, TaskMessageRefInput{MessageID: messageID, Position: index})
|
||||
}
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return refs, newCount, nil
|
||||
}
|
||||
|
||||
func (s *Store) ListTaskConversationMessages(ctx context.Context, taskID string) ([]ConversationMessageRef, error) {
|
||||
rows, err := s.pool.Query(ctx, `
|
||||
SELECT refs.message_id::text, refs.position, messages.message_snapshot
|
||||
FROM gateway_task_message_refs refs
|
||||
JOIN gateway_conversation_messages messages ON messages.id = refs.message_id
|
||||
WHERE refs.task_id = $1::uuid
|
||||
ORDER BY refs.position ASC`, taskID)
|
||||
if err != nil {
|
||||
if IsUndefinedDatabaseObject(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
items := make([]ConversationMessageRef, 0)
|
||||
for rows.Next() {
|
||||
var item ConversationMessageRef
|
||||
var snapshot []byte
|
||||
if err := rows.Scan(&item.MessageID, &item.Position, &snapshot); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
item.Message = decodeObject(snapshot)
|
||||
items = append(items, item)
|
||||
}
|
||||
return items, rows.Err()
|
||||
}
|
||||
|
||||
func insertTaskMessageRefs(ctx context.Context, tx pgx.Tx, taskID string, refs []TaskMessageRefInput) error {
|
||||
if len(refs) == 0 {
|
||||
return nil
|
||||
}
|
||||
for _, ref := range refs {
|
||||
if strings.TrimSpace(ref.MessageID) == "" {
|
||||
continue
|
||||
}
|
||||
if _, err := tx.Exec(ctx, `
|
||||
INSERT INTO gateway_task_message_refs (task_id, message_id, position)
|
||||
VALUES ($1::uuid, $2::uuid, $3)
|
||||
ON CONFLICT (task_id, position) DO UPDATE
|
||||
SET message_id = EXCLUDED.message_id`,
|
||||
taskID,
|
||||
ref.MessageID,
|
||||
ref.Position,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -12,8 +12,9 @@ import (
|
||||
const defaultServerMainUploadURL = "http://127.0.0.1:3001/v1/files/upload"
|
||||
|
||||
const (
|
||||
FileStorageSceneUpload = "upload"
|
||||
FileStorageSceneImageResult = "image_result"
|
||||
FileStorageSceneUpload = "upload"
|
||||
FileStorageSceneImageResult = "image_result"
|
||||
FileStorageSceneRequestAsset = "request_asset"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@ -380,11 +380,14 @@ type RateLimitWindow struct {
|
||||
}
|
||||
|
||||
type CreateTaskInput struct {
|
||||
Kind string `json:"kind"`
|
||||
Model string `json:"model"`
|
||||
RunMode string `json:"runMode"`
|
||||
Async bool `json:"async"`
|
||||
Request map[string]any `json:"request"`
|
||||
Kind string `json:"kind"`
|
||||
Model string `json:"model"`
|
||||
RunMode string `json:"runMode"`
|
||||
Async bool `json:"async"`
|
||||
Request map[string]any `json:"request"`
|
||||
ConversationID string `json:"conversationId"`
|
||||
NewMessageCount int `json:"newMessageCount"`
|
||||
MessageRefs []TaskMessageRefInput `json:"messageRefs"`
|
||||
}
|
||||
|
||||
type GatewayTask struct {
|
||||
@ -407,6 +410,8 @@ type GatewayTask struct {
|
||||
RequestedModel string `json:"requestedModel,omitempty"`
|
||||
ResolvedModel string `json:"resolvedModel,omitempty"`
|
||||
RequestID string `json:"requestId,omitempty"`
|
||||
ConversationID string `json:"conversationId,omitempty"`
|
||||
NewMessageCount int `json:"newMessageCount,omitempty"`
|
||||
Request map[string]any `json:"request,omitempty"`
|
||||
AsyncMode bool `json:"asyncMode"`
|
||||
RiverJobID int64 `json:"riverJobId,omitempty"`
|
||||
@ -438,6 +443,7 @@ COALESCE(gateway_tenant_id::text, ''), COALESCE(tenant_id, ''), COALESCE(tenant_
|
||||
COALESCE(api_key_id, ''), COALESCE(api_key_name, ''), COALESCE(api_key_prefix, ''),
|
||||
COALESCE(user_group_id::text, ''), COALESCE(user_group_key, ''), model,
|
||||
COALESCE(model_type, ''), COALESCE(requested_model, ''), COALESCE(resolved_model, ''), COALESCE(request_id, ''),
|
||||
COALESCE(conversation_id::text, ''), COALESCE(new_message_count, 0),
|
||||
request, COALESCE(async_mode, false), COALESCE(river_job_id, 0), status, COALESCE(attempt_count, 0),
|
||||
COALESCE(remote_task_id, ''), COALESCE(remote_task_payload, '{}'::jsonb),
|
||||
COALESCE(result, '{}'::jsonb), COALESCE(billings, '[]'::jsonb),
|
||||
@ -1746,15 +1752,18 @@ func (s *Store) CreateTask(ctx context.Context, input CreateTaskInput, user *aut
|
||||
INSERT INTO gateway_tasks (
|
||||
kind, run_mode, user_id, gateway_user_id, user_source, gateway_tenant_id, tenant_id, tenant_key,
|
||||
api_key_id, api_key_name, api_key_prefix, user_group_id, user_group_key,
|
||||
model, requested_model, request, async_mode, status, result, billings, finished_at
|
||||
model, requested_model, request, async_mode, status, result, billings, conversation_id, new_message_count, finished_at
|
||||
)
|
||||
VALUES ($1, $2, $3, NULLIF($4, '')::uuid, COALESCE(NULLIF($5, ''), 'gateway'), NULLIF($6, '')::uuid, NULLIF($7, ''), NULLIF($8, ''), NULLIF($9, ''), NULLIF($10, ''), NULLIF($11, ''), NULLIF($12, '')::uuid, NULLIF($13, ''), $14, $14, $15, $16, $17, $18::jsonb, $19::jsonb, CASE WHEN $20 THEN now() ELSE NULL END)
|
||||
VALUES ($1, $2, $3, NULLIF($4, '')::uuid, COALESCE(NULLIF($5, ''), 'gateway'), NULLIF($6, '')::uuid, NULLIF($7, ''), NULLIF($8, ''), NULLIF($9, ''), NULLIF($10, ''), NULLIF($11, ''), NULLIF($12, '')::uuid, NULLIF($13, ''), $14, $14, $15, $16, $17, $18::jsonb, $19::jsonb, NULLIF($20, '')::uuid, $21, CASE WHEN $22 THEN now() ELSE NULL END)
|
||||
RETURNING `+gatewayTaskColumns,
|
||||
input.Kind, runMode, user.ID, user.GatewayUserID, user.Source, user.GatewayTenantID, user.TenantID, user.TenantKey, user.APIKeyID, user.APIKeyName, user.APIKeyPrefix, user.UserGroupID, user.UserGroupKey, input.Model, requestBody, input.Async, status, resultBody, billingsBody, false,
|
||||
input.Kind, runMode, user.ID, user.GatewayUserID, user.Source, user.GatewayTenantID, user.TenantID, user.TenantKey, user.APIKeyID, user.APIKeyName, user.APIKeyPrefix, user.UserGroupID, user.UserGroupKey, input.Model, requestBody, input.Async, status, resultBody, billingsBody, input.ConversationID, input.NewMessageCount, false,
|
||||
))
|
||||
if err != nil {
|
||||
return GatewayTask{}, err
|
||||
}
|
||||
if err := insertTaskMessageRefs(ctx, tx, task.ID, input.MessageRefs); err != nil {
|
||||
return GatewayTask{}, err
|
||||
}
|
||||
events := taskEventsForCreate(task.ID, runMode, status, nil)
|
||||
for _, event := range events {
|
||||
payload, _ := json.Marshal(event.Payload)
|
||||
@ -1822,6 +1831,8 @@ func scanGatewayTask(scanner taskScanner) (GatewayTask, error) {
|
||||
&task.RequestedModel,
|
||||
&task.ResolvedModel,
|
||||
&task.RequestID,
|
||||
&task.ConversationID,
|
||||
&task.NewMessageCount,
|
||||
&requestBytes,
|
||||
&task.AsyncMode,
|
||||
&task.RiverJobID,
|
||||
|
||||
@ -7,6 +7,8 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
type RateLimitMetricStatus struct {
|
||||
@ -82,6 +84,59 @@ type PlatformPolicyEvent struct {
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
func (s *Store) RestorePlatformModelRuntimeStatus(ctx context.Context, platformModelID string) (ModelRateLimitStatus, error) {
|
||||
platformModelID = strings.TrimSpace(platformModelID)
|
||||
if platformModelID == "" {
|
||||
return ModelRateLimitStatus{}, pgx.ErrNoRows
|
||||
}
|
||||
|
||||
tx, err := s.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return ModelRateLimitStatus{}, err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
var restoredModelID string
|
||||
if err := tx.QueryRow(ctx, `
|
||||
WITH restored_model AS (
|
||||
UPDATE platform_models
|
||||
SET enabled = true,
|
||||
cooldown_until = NULL,
|
||||
updated_at = now()
|
||||
WHERE id = $1::uuid
|
||||
RETURNING id::text, platform_id
|
||||
),
|
||||
restored_platform AS (
|
||||
UPDATE integration_platforms
|
||||
SET status = 'enabled',
|
||||
disabled_reason = NULL,
|
||||
cooldown_until = NULL,
|
||||
updated_at = now()
|
||||
WHERE id = (SELECT platform_id FROM restored_model)
|
||||
AND deleted_at IS NULL
|
||||
RETURNING id
|
||||
)
|
||||
SELECT id
|
||||
FROM restored_model
|
||||
WHERE EXISTS (SELECT 1 FROM restored_platform)`, platformModelID).Scan(&restoredModelID); err != nil {
|
||||
return ModelRateLimitStatus{}, err
|
||||
}
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return ModelRateLimitStatus{}, err
|
||||
}
|
||||
|
||||
items, err := s.ListModelRateLimitStatuses(ctx)
|
||||
if err != nil {
|
||||
return ModelRateLimitStatus{}, err
|
||||
}
|
||||
for _, item := range items {
|
||||
if item.PlatformModelID == restoredModelID {
|
||||
return item, nil
|
||||
}
|
||||
}
|
||||
return ModelRateLimitStatus{}, pgx.ErrNoRows
|
||||
}
|
||||
|
||||
func (s *Store) ListModelRateLimitStatuses(ctx context.Context) ([]ModelRateLimitStatus, error) {
|
||||
rows, err := s.pool.Query(ctx, `
|
||||
SELECT m.id::text, m.platform_id::text, p.name, p.provider, p.status,
|
||||
|
||||
130
apps/api/internal/store/request_assets.go
Normal file
130
apps/api/internal/store/request_assets.go
Normal file
@ -0,0 +1,130 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
type RequestAsset struct {
|
||||
ID string `json:"id"`
|
||||
SHA256 string `json:"sha256"`
|
||||
ContentType string `json:"contentType"`
|
||||
ByteSize int64 `json:"byteSize"`
|
||||
URL string `json:"url"`
|
||||
StorageProvider string `json:"storageProvider"`
|
||||
LocalPath string `json:"localPath,omitempty"`
|
||||
ExpiresAt *time.Time `json:"expiresAt,omitempty"`
|
||||
ExpiredAt *time.Time `json:"expiredAt,omitempty"`
|
||||
RefCount int `json:"refCount"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
type RequestAssetInput struct {
|
||||
SHA256 string
|
||||
ContentType string
|
||||
ByteSize int64
|
||||
URL string
|
||||
StorageProvider string
|
||||
LocalPath string
|
||||
ExpiresAt *time.Time
|
||||
}
|
||||
|
||||
func (s *Store) FindRequestAsset(ctx context.Context, sha256 string, contentType string) (RequestAsset, bool, error) {
|
||||
asset, err := scanRequestAsset(s.pool.QueryRow(ctx, `
|
||||
SELECT id::text, sha256, content_type, byte_size, url, storage_provider,
|
||||
COALESCE(local_path, ''), expires_at, expired_at, ref_count, created_at, updated_at
|
||||
FROM gateway_request_assets
|
||||
WHERE sha256 = $1 AND content_type = $2`, sha256, contentType))
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return RequestAsset{}, false, nil
|
||||
}
|
||||
return RequestAsset{}, false, err
|
||||
}
|
||||
return asset, true, nil
|
||||
}
|
||||
|
||||
func (s *Store) UpsertRequestAsset(ctx context.Context, input RequestAssetInput) (RequestAsset, error) {
|
||||
return scanRequestAsset(s.pool.QueryRow(ctx, `
|
||||
INSERT INTO gateway_request_assets (
|
||||
sha256, content_type, byte_size, url, storage_provider, local_path, expires_at, expired_at, ref_count
|
||||
)
|
||||
VALUES ($1, $2, $3, $4, $5, NULLIF($6, ''), $7, NULL, 1)
|
||||
ON CONFLICT (sha256, content_type) DO UPDATE
|
||||
SET byte_size = EXCLUDED.byte_size,
|
||||
url = EXCLUDED.url,
|
||||
storage_provider = EXCLUDED.storage_provider,
|
||||
local_path = EXCLUDED.local_path,
|
||||
expires_at = EXCLUDED.expires_at,
|
||||
expired_at = NULL,
|
||||
ref_count = gateway_request_assets.ref_count + 1,
|
||||
updated_at = now()
|
||||
RETURNING id::text, sha256, content_type, byte_size, url, storage_provider,
|
||||
COALESCE(local_path, ''), expires_at, expired_at, ref_count, created_at, updated_at`,
|
||||
input.SHA256,
|
||||
input.ContentType,
|
||||
input.ByteSize,
|
||||
input.URL,
|
||||
input.StorageProvider,
|
||||
input.LocalPath,
|
||||
input.ExpiresAt,
|
||||
))
|
||||
}
|
||||
|
||||
func (s *Store) IncrementRequestAssetRefCount(ctx context.Context, sha256 string, contentType string) error {
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
UPDATE gateway_request_assets
|
||||
SET ref_count = ref_count + 1,
|
||||
updated_at = now()
|
||||
WHERE sha256 = $1 AND content_type = $2`, sha256, contentType)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) MarkRequestAssetExpiredByLocalPath(ctx context.Context, localPath string, expiredAt time.Time) error {
|
||||
if localPath == "" {
|
||||
return nil
|
||||
}
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
UPDATE gateway_request_assets
|
||||
SET expired_at = COALESCE(expired_at, $2),
|
||||
updated_at = now()
|
||||
WHERE local_path = $1
|
||||
AND storage_provider = 'local_static'
|
||||
AND expired_at IS NULL`, localPath, expiredAt)
|
||||
return err
|
||||
}
|
||||
|
||||
func scanRequestAsset(scanner interface{ Scan(dest ...any) error }) (RequestAsset, error) {
|
||||
var asset RequestAsset
|
||||
var localPath string
|
||||
var expiresAt sql.NullTime
|
||||
var expiredAt sql.NullTime
|
||||
if err := scanner.Scan(
|
||||
&asset.ID,
|
||||
&asset.SHA256,
|
||||
&asset.ContentType,
|
||||
&asset.ByteSize,
|
||||
&asset.URL,
|
||||
&asset.StorageProvider,
|
||||
&localPath,
|
||||
&expiresAt,
|
||||
&expiredAt,
|
||||
&asset.RefCount,
|
||||
&asset.CreatedAt,
|
||||
&asset.UpdatedAt,
|
||||
); err != nil {
|
||||
return RequestAsset{}, err
|
||||
}
|
||||
asset.LocalPath = localPath
|
||||
if expiresAt.Valid {
|
||||
asset.ExpiresAt = &expiresAt.Time
|
||||
}
|
||||
if expiredAt.Valid {
|
||||
asset.ExpiredAt = &expiredAt.Time
|
||||
}
|
||||
return asset, nil
|
||||
}
|
||||
@ -0,0 +1,68 @@
|
||||
CREATE TABLE IF NOT EXISTS gateway_request_assets (
|
||||
id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
sha256 text NOT NULL,
|
||||
content_type text NOT NULL,
|
||||
byte_size bigint NOT NULL,
|
||||
url text NOT NULL,
|
||||
storage_provider text NOT NULL,
|
||||
local_path text,
|
||||
expires_at timestamptz,
|
||||
expired_at timestamptz,
|
||||
ref_count integer NOT NULL DEFAULT 0,
|
||||
created_at timestamptz NOT NULL DEFAULT now(),
|
||||
updated_at timestamptz NOT NULL DEFAULT now(),
|
||||
UNIQUE (sha256, content_type)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_gateway_request_assets_expires
|
||||
ON gateway_request_assets(storage_provider, expires_at)
|
||||
WHERE expires_at IS NOT NULL AND expired_at IS NULL;
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_gateway_request_assets_local_path
|
||||
ON gateway_request_assets(local_path)
|
||||
WHERE local_path IS NOT NULL;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS gateway_conversations (
|
||||
id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id text NOT NULL,
|
||||
gateway_user_id uuid REFERENCES gateway_users(id) ON DELETE SET NULL,
|
||||
conversation_key text NOT NULL,
|
||||
metadata jsonb NOT NULL DEFAULT '{}'::jsonb,
|
||||
created_at timestamptz NOT NULL DEFAULT now(),
|
||||
updated_at timestamptz NOT NULL DEFAULT now(),
|
||||
UNIQUE (user_id, conversation_key)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS gateway_conversation_messages (
|
||||
id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
conversation_id uuid NOT NULL REFERENCES gateway_conversations(id) ON DELETE CASCADE,
|
||||
message_hash text NOT NULL,
|
||||
role text,
|
||||
message_snapshot jsonb NOT NULL,
|
||||
asset_sha256s text[] NOT NULL DEFAULT '{}',
|
||||
created_at timestamptz NOT NULL DEFAULT now(),
|
||||
updated_at timestamptz NOT NULL DEFAULT now(),
|
||||
UNIQUE (conversation_id, message_hash)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_gateway_conversation_messages_conversation
|
||||
ON gateway_conversation_messages(conversation_id, created_at);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS gateway_task_message_refs (
|
||||
task_id uuid NOT NULL REFERENCES gateway_tasks(id) ON DELETE CASCADE,
|
||||
message_id uuid NOT NULL REFERENCES gateway_conversation_messages(id) ON DELETE CASCADE,
|
||||
position integer NOT NULL,
|
||||
created_at timestamptz NOT NULL DEFAULT now(),
|
||||
PRIMARY KEY (task_id, position)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_gateway_task_message_refs_message
|
||||
ON gateway_task_message_refs(message_id);
|
||||
|
||||
ALTER TABLE gateway_tasks
|
||||
ADD COLUMN IF NOT EXISTS conversation_id uuid REFERENCES gateway_conversations(id) ON DELETE SET NULL,
|
||||
ADD COLUMN IF NOT EXISTS new_message_count integer NOT NULL DEFAULT 0;
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_gateway_tasks_conversation_created
|
||||
ON gateway_tasks(conversation_id, created_at DESC)
|
||||
WHERE conversation_id IS NOT NULL;
|
||||
@ -84,6 +84,7 @@ import {
|
||||
pollTaskUntilSettled,
|
||||
registerLocalAccount,
|
||||
replacePlatformModels,
|
||||
restoreModelRuntimeStatus,
|
||||
setUserWalletBalance,
|
||||
type HealthResponse,
|
||||
updateAccessRule,
|
||||
@ -635,6 +636,50 @@ export function App() {
|
||||
}
|
||||
}
|
||||
|
||||
async function restoreRuntimeModel(platformModelId: string) {
|
||||
setCoreState('loading');
|
||||
setCoreMessage('');
|
||||
try {
|
||||
const restored = await restoreModelRuntimeStatus(token, platformModelId);
|
||||
setModelRateLimits((current) => current.map((status) => {
|
||||
if (status.platformId !== restored.platformId) return status;
|
||||
return {
|
||||
...status,
|
||||
platformStatus: 'enabled',
|
||||
platformCooldownUntil: undefined,
|
||||
platformDisabledReason: undefined,
|
||||
...(status.platformModelId === platformModelId
|
||||
? {
|
||||
enabled: restored.enabled,
|
||||
modelCooldownUntil: restored.modelCooldownUntil,
|
||||
}
|
||||
: {}),
|
||||
};
|
||||
}));
|
||||
setPlatforms((current) => current.map((platform) => platform.id === restored.platformId
|
||||
? {
|
||||
...platform,
|
||||
status: 'enabled',
|
||||
cooldownUntil: undefined,
|
||||
}
|
||||
: platform));
|
||||
setModels((current) => current.map((model) => model.id === platformModelId
|
||||
? {
|
||||
...model,
|
||||
enabled: true,
|
||||
cooldownUntil: undefined,
|
||||
}
|
||||
: model));
|
||||
invalidateDataKeys('modelCatalog', 'modelRateLimits', 'models', 'platforms', 'playgroundModels');
|
||||
setCoreState('ready');
|
||||
setCoreMessage('模型运行状态已恢复。');
|
||||
} catch (err) {
|
||||
setCoreState('error');
|
||||
setCoreMessage(err instanceof Error ? err.message : '恢复模型运行状态失败');
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
async function removePlatform(platformId: string) {
|
||||
setCoreState('loading');
|
||||
setCoreMessage('');
|
||||
@ -1143,6 +1188,7 @@ export function App() {
|
||||
onResetBaseModel={resetBaseModelToDefault}
|
||||
onSavePlatform={savePlatformWithModels}
|
||||
onSavePlatformDynamicPriority={savePlatformDynamicPriority}
|
||||
onRestoreRuntimeModel={restoreRuntimeModel}
|
||||
onTogglePlatformStatus={savePlatformStatus}
|
||||
onSaveProvider={saveProvider}
|
||||
onSavePricingRuleSet={savePricingRuleSet}
|
||||
|
||||
@ -913,6 +913,13 @@ export async function listModelRateLimitStatuses(token: string): Promise<ListRes
|
||||
return request<ListResponse<ModelRateLimitStatus>>('/api/admin/runtime/model-rate-limits', { token });
|
||||
}
|
||||
|
||||
export async function restoreModelRuntimeStatus(token: string, platformModelId: string): Promise<ModelRateLimitStatus> {
|
||||
return request<ModelRateLimitStatus>(`/api/admin/runtime/model-rate-limits/${platformModelId}/restore`, {
|
||||
method: 'POST',
|
||||
token,
|
||||
});
|
||||
}
|
||||
|
||||
export async function getNetworkProxyConfig(token: string): Promise<GatewayNetworkProxyConfig> {
|
||||
return request<GatewayNetworkProxyConfig>('/api/admin/config/network-proxy', { token });
|
||||
}
|
||||
|
||||
@ -71,6 +71,7 @@ export function AdminPage(props: {
|
||||
onBatchAccessRules: (input: GatewayAccessRuleBatchRequest) => Promise<void>;
|
||||
onSavePlatform: (input: PlatformWithModelsInput) => Promise<void>;
|
||||
onSavePlatformDynamicPriority: (platformId: string, input: PlatformDynamicPriorityUpdateRequest) => Promise<void>;
|
||||
onRestoreRuntimeModel: (platformModelId: string) => Promise<void>;
|
||||
onTogglePlatformStatus: (platform: IntegrationPlatform, status: 'enabled' | 'disabled') => Promise<void>;
|
||||
onSaveProvider: (input: CatalogProviderUpsertRequest, providerId?: string) => Promise<void>;
|
||||
onSavePricingRuleSet: (input: PricingRuleSetUpsertRequest, ruleSetId?: string) => Promise<void>;
|
||||
@ -173,6 +174,7 @@ export function AdminPage(props: {
|
||||
modelRateLimitsUpdatedAt={props.data.modelRateLimitsUpdatedAt}
|
||||
platforms={props.data.platforms}
|
||||
onSavePlatformDynamicPriority={props.onSavePlatformDynamicPriority}
|
||||
onRestoreRuntimeModel={props.onRestoreRuntimeModel}
|
||||
/>
|
||||
)}
|
||||
{props.section === 'tenants' && <TenantsPanel {...identityPanelProps(props)} />}
|
||||
|
||||
@ -9,11 +9,13 @@ export function RealtimeLoadPanel(props: {
|
||||
modelRateLimitsUpdatedAt: number | null;
|
||||
platforms: IntegrationPlatform[];
|
||||
onSavePlatformDynamicPriority: (platformId: string, input: PlatformDynamicPriorityUpdateRequest) => Promise<void>;
|
||||
onRestoreRuntimeModel: (platformModelId: string) => Promise<void>;
|
||||
}) {
|
||||
const [now, setNow] = useState(() => Date.now());
|
||||
const [priorityDialog, setPriorityDialog] = useState<PriorityDialogState | null>(null);
|
||||
const [priorityError, setPriorityError] = useState('');
|
||||
const [prioritySaving, setPrioritySaving] = useState(false);
|
||||
const [restoreSavingId, setRestoreSavingId] = useState<string | null>(null);
|
||||
const platformMap = useMemo(() => new Map(props.platforms.map((item) => [item.id, item])), [props.platforms]);
|
||||
|
||||
useEffect(() => {
|
||||
@ -65,6 +67,17 @@ export function RealtimeLoadPanel(props: {
|
||||
}
|
||||
}
|
||||
|
||||
async function restoreRuntimeModel(platformModelId: string) {
|
||||
setRestoreSavingId(platformModelId);
|
||||
try {
|
||||
await props.onRestoreRuntimeModel(platformModelId);
|
||||
} catch {
|
||||
// App-level state owns the error message.
|
||||
} finally {
|
||||
setRestoreSavingId(null);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<section className="pageStack">
|
||||
<Card>
|
||||
@ -81,6 +94,8 @@ export function RealtimeLoadPanel(props: {
|
||||
statuses={props.modelRateLimits}
|
||||
updatedAt={props.modelRateLimitsUpdatedAt}
|
||||
onAdjustPriority={openPriorityDialog}
|
||||
onRestoreRuntimeModel={restoreRuntimeModel}
|
||||
restoreSavingId={restoreSavingId}
|
||||
/>
|
||||
</CardContent>
|
||||
</Card>
|
||||
@ -109,6 +124,8 @@ function RateLimitStatusTable(props: {
|
||||
now: number;
|
||||
updatedAt: number | null;
|
||||
onAdjustPriority: (status: ModelRateLimitStatus, platform: IntegrationPlatform | undefined) => void;
|
||||
onRestoreRuntimeModel: (platformModelId: string) => Promise<void>;
|
||||
restoreSavingId: string | null;
|
||||
}) {
|
||||
if (!props.statuses.length) {
|
||||
return <EmptyState title="暂无实时负载" description="模型产生请求后会在这里显示实时 RPM、TPM 和并发窗口。" />;
|
||||
@ -150,7 +167,15 @@ function RateLimitStatusTable(props: {
|
||||
<small>{status.provider}</small>
|
||||
</span>
|
||||
</TableCell>
|
||||
<TableCell className="platformLimitStatusCell">{modelRuntimeStatusCell(status, platform, props.now)}</TableCell>
|
||||
<TableCell className="platformLimitStatusCell">
|
||||
{modelRuntimeStatusCell(
|
||||
status,
|
||||
platform,
|
||||
props.now,
|
||||
props.onRestoreRuntimeModel,
|
||||
props.restoreSavingId === status.platformModelId,
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell className="platformLimitNumberCell">{platformPriorityCell(status, platform, props.onAdjustPriority)}</TableCell>
|
||||
<TableCell className="platformLimitNumberCell">
|
||||
<span className="rateLoadCell" data-overloaded={status.loadRatio > 0.8 ? 'true' : undefined}>
|
||||
@ -447,50 +472,104 @@ function shortId(value: string | undefined) {
|
||||
return value.length > 8 ? value.slice(0, 8) : value;
|
||||
}
|
||||
|
||||
function modelRuntimeStatusCell(status: ModelRateLimitStatus, platform: IntegrationPlatform | undefined, now: number) {
|
||||
function modelRuntimeStatusCell(
|
||||
status: ModelRateLimitStatus,
|
||||
platform: IntegrationPlatform | undefined,
|
||||
now: number,
|
||||
onRestore: (platformModelId: string) => Promise<void>,
|
||||
restoring: boolean,
|
||||
) {
|
||||
const modelCooldownMs = cooldownRemainingMs(status.modelCooldownUntil, now);
|
||||
const platformCooldownMs = cooldownRemainingMs(status.platformCooldownUntil, now);
|
||||
const platformStatus = platform?.status || status.platformStatus || 'enabled';
|
||||
const restoreButton = runtimeRestoreButton(status, platformStatus, modelCooldownMs, platformCooldownMs, onRestore, restoring);
|
||||
if (modelCooldownMs > 0) {
|
||||
return (
|
||||
<span className="platformTableName">
|
||||
<strong><Badge variant="warning">模型冷却中</Badge></strong>
|
||||
<small>剩余 {formatCooldownRemaining(modelCooldownMs)}</small>
|
||||
<span className="platformRuntimeStatusCell">
|
||||
<span className="platformTableName">
|
||||
<strong><Badge variant="warning">模型冷却中</Badge></strong>
|
||||
<small>剩余 {formatCooldownRemaining(modelCooldownMs)}</small>
|
||||
</span>
|
||||
{restoreButton}
|
||||
</span>
|
||||
);
|
||||
}
|
||||
if (platformStatus !== 'enabled') {
|
||||
const badge = <Badge variant="warning">已禁用</Badge>;
|
||||
return (
|
||||
<AntPopover
|
||||
align={{ offset: [0, 8] }}
|
||||
content={<PlatformDisabledReasonPopover record={status.platformDisabledReason} />}
|
||||
overlayClassName="priorityDemotionAntPopover"
|
||||
placement="bottomLeft"
|
||||
trigger={['hover', 'focus']}
|
||||
>
|
||||
<span className="platformTableName" tabIndex={0}>
|
||||
<strong>{badge}</strong>
|
||||
<span className="platformRuntimeStatusCell">
|
||||
<AntPopover
|
||||
align={{ offset: [0, 8] }}
|
||||
content={<PlatformDisabledReasonPopover record={status.platformDisabledReason} />}
|
||||
overlayClassName="priorityDemotionAntPopover"
|
||||
placement="bottomLeft"
|
||||
trigger={['hover', 'focus']}
|
||||
>
|
||||
<span className="platformTableName" tabIndex={0}>
|
||||
<strong>{badge}</strong>
|
||||
</span>
|
||||
</AntPopover>
|
||||
{restoreButton}
|
||||
</span>
|
||||
);
|
||||
}
|
||||
if (!status.enabled) {
|
||||
return (
|
||||
<span className="platformRuntimeStatusCell">
|
||||
<span className="platformTableName">
|
||||
<strong><Badge variant="secondary">已停用</Badge></strong>
|
||||
<small>不参与路由</small>
|
||||
</span>
|
||||
</AntPopover>
|
||||
{restoreButton}
|
||||
</span>
|
||||
);
|
||||
}
|
||||
if (platformCooldownMs > 0) {
|
||||
return (
|
||||
<span className="platformTableName">
|
||||
<strong><Badge variant="warning">平台冷却中</Badge></strong>
|
||||
<small>剩余 {formatCooldownRemaining(platformCooldownMs)}</small>
|
||||
<span className="platformRuntimeStatusCell">
|
||||
<span className="platformTableName">
|
||||
<strong><Badge variant="warning">平台冷却中</Badge></strong>
|
||||
<small>剩余 {formatCooldownRemaining(platformCooldownMs)}</small>
|
||||
</span>
|
||||
{restoreButton}
|
||||
</span>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<span className="platformTableName">
|
||||
<strong><Badge variant={status.enabled ? 'success' : 'secondary'}>{status.enabled ? '可用' : '已停用'}</Badge></strong>
|
||||
<small>{status.enabled ? '参与路由' : '不参与路由'}</small>
|
||||
<span className="platformRuntimeStatusCell">
|
||||
<span className="platformTableName">
|
||||
<strong><Badge variant="success">可用</Badge></strong>
|
||||
<small>参与路由</small>
|
||||
</span>
|
||||
</span>
|
||||
);
|
||||
}
|
||||
|
||||
function runtimeRestoreButton(
|
||||
status: ModelRateLimitStatus,
|
||||
platformStatus: string,
|
||||
modelCooldownMs: number,
|
||||
platformCooldownMs: number,
|
||||
onRestore: (platformModelId: string) => Promise<void>,
|
||||
restoring: boolean,
|
||||
) {
|
||||
const canRestore = modelCooldownMs > 0 || platformCooldownMs > 0 || platformStatus !== 'enabled' || !status.enabled;
|
||||
if (!canRestore) return null;
|
||||
return (
|
||||
<Button
|
||||
className="platformRestoreButton"
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="xs"
|
||||
disabled={restoring}
|
||||
onClick={() => void onRestore(status.platformModelId)}
|
||||
>
|
||||
<RotateCcw size={13} />
|
||||
{restoring ? '恢复中' : '恢复'}
|
||||
</Button>
|
||||
);
|
||||
}
|
||||
|
||||
function cooldownRemainingMs(cooldownUntil: string | undefined, now: number) {
|
||||
if (!cooldownUntil) return 0;
|
||||
const until = Date.parse(cooldownUntil);
|
||||
|
||||
@ -1086,8 +1086,8 @@
|
||||
}
|
||||
|
||||
.platformLimitTable .shTableRow {
|
||||
grid-template-columns: minmax(180px, 1.1fr) minmax(160px, 0.9fr) 160px 132px 150px 170px 140px 132px;
|
||||
min-width: 1224px;
|
||||
grid-template-columns: minmax(180px, 1.1fr) minmax(160px, 0.9fr) 178px 132px 150px 170px 140px 132px;
|
||||
min-width: 1242px;
|
||||
}
|
||||
|
||||
.platformLimitTable .shTableHead,
|
||||
@ -1130,6 +1130,22 @@
|
||||
justify-items: start;
|
||||
}
|
||||
|
||||
.platformRuntimeStatusCell {
|
||||
display: grid;
|
||||
min-width: 0;
|
||||
gap: 7px;
|
||||
align-content: start;
|
||||
}
|
||||
|
||||
.platformRestoreButton {
|
||||
justify-self: start;
|
||||
min-height: 22px;
|
||||
padding-inline: 8px;
|
||||
border-color: var(--border);
|
||||
color: var(--text-normal);
|
||||
background: #fff;
|
||||
}
|
||||
|
||||
.rateMetricCell,
|
||||
.rateLoadCell {
|
||||
display: grid;
|
||||
|
||||
Loading…
Reference in New Issue
Block a user