diff --git a/apps/api/internal/config/config.go b/apps/api/internal/config/config.go index 7b2188d..273c1e8 100644 --- a/apps/api/internal/config/config.go +++ b/apps/api/internal/config/config.go @@ -4,6 +4,7 @@ import ( "log/slog" "net/url" "os" + "strconv" "strings" ) @@ -23,6 +24,7 @@ type Config struct { PublicBaseURL string LocalGeneratedStorageDir string LocalUploadedStorageDir string + LocalTempAssetTTLHours int TaskProgressCallbackEnabled bool TaskProgressCallbackURL string TaskProgressCallbackTimeoutMS string @@ -49,6 +51,7 @@ func Load() Config { PublicBaseURL: strings.TrimRight(env("AI_GATEWAY_PUBLIC_BASE_URL", env("PUBLIC_BASE_URL", "")), "/"), LocalGeneratedStorageDir: env("AI_GATEWAY_GENERATED_STORAGE_DIR", env("LOCAL_GENERATED_STORAGE_DIR", env("AI_GATEWAY_STATIC_STORAGE_DIR", DefaultLocalGeneratedStorageDir))), LocalUploadedStorageDir: env("AI_GATEWAY_UPLOADED_STORAGE_DIR", env("LOCAL_UPLOADED_STORAGE_DIR", DefaultLocalUploadedStorageDir)), + LocalTempAssetTTLHours: envInt("AI_GATEWAY_LOCAL_TEMP_ASSET_TTL_HOURS", 24), TaskProgressCallbackEnabled: env("TASK_PROGRESS_CALLBACK_ENABLED", "true") == "true", TaskProgressCallbackURL: env("TASK_PROGRESS_CALLBACK_URL", strings.TrimRight(env("SERVER_MAIN_BASE_URL", "http://localhost:3000"), "/")+"/internal/platform/task-progress-callbacks", @@ -136,6 +139,18 @@ func env(key string, fallback string) string { return fallback } +func envInt(key string, fallback int) int { + value := envValue(key) + if value == "" { + return fallback + } + parsed, err := strconv.Atoi(value) + if err != nil { + return fallback + } + return parsed +} + func logLevel(value string) slog.Level { switch strings.ToLower(value) { case "debug": diff --git a/apps/api/internal/httpapi/handlers.go b/apps/api/internal/httpapi/handlers.go index dbc244a..7a4146a 100644 --- a/apps/api/internal/httpapi/handlers.go +++ b/apps/api/internal/httpapi/handlers.go @@ -916,13 +916,26 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler { return } responsePlan := planTaskResponse(kind, compatible, body, r) + prepared, err := s.prepareTaskRequest(r.Context(), r, user, body) + if err != nil { + s.logger.Warn("prepare task request failed", "kind", kind, "error", err) + status := http.StatusBadRequest + if code := clients.ErrorCode(err); strings.HasPrefix(code, "upload_") || code == "request_asset_upload_failed" { + status = http.StatusBadGateway + } + writeError(w, status, err.Error(), clients.ErrorCode(err)) + return + } task, err := s.store.CreateTask(r.Context(), store.CreateTaskInput{ - Kind: kind, - Model: model, - RunMode: runModeFromRequest(body), - Async: responsePlan.asyncMode, - Request: body, + Kind: kind, + Model: model, + RunMode: runModeFromRequest(prepared.Body), + Async: responsePlan.asyncMode, + Request: prepared.Body, + ConversationID: prepared.ConversationID, + NewMessageCount: prepared.NewMessageCount, + MessageRefs: prepared.MessageRefs, }, user) if err != nil { s.logger.Error("create task failed", "kind", kind, "error", err) diff --git a/apps/api/internal/httpapi/local_temp_assets.go b/apps/api/internal/httpapi/local_temp_assets.go new file mode 100644 index 0000000..de71e86 --- /dev/null +++ b/apps/api/internal/httpapi/local_temp_assets.go @@ -0,0 +1,70 @@ +package httpapi + +import ( + "context" + "errors" + "os" + "path/filepath" + "strings" + "time" + + "github.com/easyai/easyai-ai-gateway/apps/api/internal/config" + "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" +) + +func (s *Server) startLocalTempAssetCleanup(ctx context.Context) { + go func() { + s.cleanupExpiredLocalTempAssets(ctx, time.Now()) + ticker := time.NewTicker(time.Hour) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case now := <-ticker.C: + s.cleanupExpiredLocalTempAssets(ctx, now) + } + } + }() +} + +func (s *Server) cleanupExpiredLocalTempAssets(ctx context.Context, now time.Time) int { + storageDir := strings.TrimSpace(s.cfg.LocalUploadedStorageDir) + if storageDir == "" { + storageDir = config.DefaultLocalUploadedStorageDir + } + entries, err := os.ReadDir(storageDir) + if err != nil { + if !errors.Is(err, os.ErrNotExist) { + s.logger.Warn("read local temp asset dir failed", "dir", storageDir, "error", err) + } + return 0 + } + ttl := time.Duration(s.localTempAssetTTLHours()) * time.Hour + expiredBefore := now.Add(-ttl) + deleted := 0 + for _, entry := range entries { + if entry.IsDir() || !strings.HasPrefix(entry.Name(), requestAssetFilePrefix) { + continue + } + info, err := entry.Info() + if err != nil { + continue + } + if info.ModTime().After(expiredBefore) { + continue + } + localPath := filepath.Join(storageDir, entry.Name()) + if err := os.Remove(localPath); err != nil && !errors.Is(err, os.ErrNotExist) { + s.logger.Warn("remove local temp asset failed", "path", localPath, "error", err) + continue + } + deleted++ + if s.store != nil { + if err := s.store.MarkRequestAssetExpiredByLocalPath(ctx, localPath, now); err != nil && !store.IsUndefinedDatabaseObject(err) { + s.logger.Warn("mark local temp asset expired failed", "path", localPath, "error", err) + } + } + } + return deleted +} diff --git a/apps/api/internal/httpapi/request_preparation.go b/apps/api/internal/httpapi/request_preparation.go new file mode 100644 index 0000000..22ddd06 --- /dev/null +++ b/apps/api/internal/httpapi/request_preparation.go @@ -0,0 +1,583 @@ +package httpapi + +import ( + "context" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "mime" + "net/http" + "path/filepath" + "sort" + "strings" + "time" + + "github.com/easyai/easyai-ai-gateway/apps/api/internal/auth" + "github.com/easyai/easyai-ai-gateway/apps/api/internal/clients" + "github.com/easyai/easyai-ai-gateway/apps/api/internal/config" + "github.com/easyai/easyai-ai-gateway/apps/api/internal/runner" + "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" +) + +const requestAssetFilePrefix = "gateway-request-asset-" + +type preparedTaskRequest struct { + Body map[string]any + ConversationID string + MessageRefs []store.TaskMessageRefInput + NewMessageCount int +} + +type decodedRequestAsset struct { + Bytes []byte + ContentType string +} + +func (s *Server) prepareTaskRequest(ctx context.Context, r *http.Request, user *auth.User, body map[string]any) (preparedTaskRequest, error) { + preparedBody, err := s.prepareRequestAssetRefs(ctx, body) + if err != nil { + return preparedTaskRequest{}, err + } + result := preparedTaskRequest{Body: preparedBody} + conversationKey := requestConversationKey(r, preparedBody) + if conversationKey == "" { + return result, nil + } + messages, ok := preparedBody["messages"].([]any) + if !ok || len(messages) == 0 { + return result, nil + } + conversationID, err := s.store.EnsureConversation(ctx, user, conversationKey, map[string]any{ + "source": "ai-gateway", + "conversationKey": conversationKey, + }) + if err != nil { + return preparedTaskRequest{}, err + } + inputs := make([]store.ConversationMessageInput, 0, len(messages)) + for _, rawMessage := range messages { + message, _ := rawMessage.(map[string]any) + if message == nil { + message = map[string]any{"content": rawMessage} + } + hash, assetHashes := canonicalConversationMessageHash(message) + inputs = append(inputs, store.ConversationMessageInput{ + Hash: hash, + Role: stringFromRequestAny(message["role"]), + Snapshot: message, + AssetSHA256s: assetHashes, + }) + } + refs, newCount, err := s.store.UpsertConversationMessages(ctx, conversationID, inputs) + if err != nil { + return preparedTaskRequest{}, err + } + preparedBody["conversationId"] = conversationKey + preparedBody["conversationRecordId"] = conversationID + preparedBody["messageRefs"] = messageRefsForRequest(refs) + preparedBody["newMessageCount"] = newCount + delete(preparedBody, "messages") + result.ConversationID = conversationID + result.MessageRefs = refs + result.NewMessageCount = newCount + return result, nil +} + +func (s *Server) prepareRequestAssetRefs(ctx context.Context, body map[string]any) (map[string]any, error) { + value, err := s.prepareRequestAssetValue(ctx, body, nil, nil) + if err != nil { + return nil, err + } + next, _ := value.(map[string]any) + if next == nil { + return map[string]any{}, nil + } + return next, nil +} + +func (s *Server) prepareRequestAssetValue(ctx context.Context, value any, path []string, siblings map[string]any) (any, error) { + switch typed := value.(type) { + case map[string]any: + if typed["assetRef"] != nil { + return typed, nil + } + next := make(map[string]any, len(typed)) + keys := make([]string, 0, len(typed)) + for key := range typed { + keys = append(keys, key) + } + sort.Strings(keys) + for _, key := range keys { + item := typed[key] + if decoded, ok, err := requestAssetFromValue(key, path, item, typed); err != nil { + return nil, err + } else if ok { + ref, err := s.ensureRequestAsset(ctx, decoded) + if err != nil { + return nil, err + } + next[key] = requestAssetWrapper(ref) + continue + } + prepared, err := s.prepareRequestAssetValue(ctx, item, append(path, key), typed) + if err != nil { + return nil, err + } + next[key] = prepared + } + return next, nil + case []any: + next := make([]any, 0, len(typed)) + for index, item := range typed { + prepared, err := s.prepareRequestAssetValue(ctx, item, append(path, fmt.Sprintf("[%d]", index)), siblings) + if err != nil { + return nil, err + } + next = append(next, prepared) + } + return next, nil + default: + return value, nil + } +} + +func requestAssetFromValue(key string, path []string, value any, siblings map[string]any) (decodedRequestAsset, bool, error) { + text, ok := value.(string) + if !ok { + return decodedRequestAsset{}, false, nil + } + raw := strings.TrimSpace(text) + if raw == "" || mediaURLString(raw) { + return decodedRequestAsset{}, false, nil + } + if strings.HasPrefix(strings.ToLower(raw), "data:") { + contentType, encoded, ok, err := parseRequestDataURL(raw) + if err != nil || !ok { + return decodedRequestAsset{}, false, err + } + payload, err := decodeRequestBase64(encoded) + if err != nil { + return decodedRequestAsset{}, false, requestAssetDecodeError(err) + } + return decodedRequestAsset{ + Bytes: payload, + ContentType: requestAssetContentType(contentType, payload, key, path, siblings), + }, true, nil + } + strict := strictRequestBase64Field(key, path) + if !strict && !likelyRequestBase64MediaField(key, path, raw) { + return decodedRequestAsset{}, false, nil + } + payload, err := decodeRequestBase64(raw) + if err != nil { + if strict { + return decodedRequestAsset{}, false, requestAssetDecodeError(err) + } + return decodedRequestAsset{}, false, nil + } + contentType := requestAssetContentType("", payload, key, path, siblings) + if !strict && !requestContentTypeIsMedia(contentType) { + return decodedRequestAsset{}, false, nil + } + return decodedRequestAsset{Bytes: payload, ContentType: contentType}, true, nil +} + +func (s *Server) ensureRequestAsset(ctx context.Context, decoded decodedRequestAsset) (map[string]any, error) { + sum := sha256.Sum256(decoded.Bytes) + sha := hex.EncodeToString(sum[:]) + contentType := strings.TrimSpace(decoded.ContentType) + if contentType == "" { + contentType = "application/octet-stream" + } + now := time.Now() + if existing, ok, err := s.store.FindRequestAsset(ctx, sha, contentType); err != nil && !store.IsUndefinedDatabaseObject(err) { + return nil, err + } else if ok && requestAssetStillUsable(existing, now) { + if err := s.store.IncrementRequestAssetRefCount(ctx, sha, contentType); err != nil && !store.IsUndefinedDatabaseObject(err) { + return nil, err + } + return requestAssetRef(existing), nil + } + + upload, err := s.runner.UploadFile(ctx, runner.FileUploadPayload{ + Bytes: decoded.Bytes, + ContentType: contentType, + FileName: requestAssetFileName(sha, contentType), + Scene: store.FileStorageSceneRequestAsset, + Source: "ai-gateway-request", + }) + if err != nil { + return nil, err + } + storageProvider := requestAssetStorageProvider(upload) + url := stringFromRequestAny(upload["url"]) + if url == "" { + return nil, &clients.ClientError{Code: "request_asset_upload_failed", Message: "file storage response did not include url", Retryable: false} + } + var expiresAt *time.Time + localPath := "" + if storageProvider == "local_static" { + expiry := now.Add(time.Duration(s.localTempAssetTTLHours()) * time.Hour) + expiresAt = &expiry + localPath = requestAssetLocalPath(s.cfg.LocalUploadedStorageDir, stringFromRequestAny(upload["fileName"])) + } + asset, err := s.store.UpsertRequestAsset(ctx, store.RequestAssetInput{ + SHA256: sha, + ContentType: contentType, + ByteSize: int64(len(decoded.Bytes)), + URL: url, + StorageProvider: storageProvider, + LocalPath: localPath, + ExpiresAt: expiresAt, + }) + if err != nil { + if store.IsUndefinedDatabaseObject(err) { + return map[string]any{ + "sha256": sha, + "url": url, + "contentType": contentType, + "size": len(decoded.Bytes), + "storageProvider": storageProvider, + "expiresAt": timePtrToRFC3339(expiresAt), + }, nil + } + return nil, err + } + return requestAssetRef(asset), nil +} + +func requestConversationKey(r *http.Request, body map[string]any) string { + if r != nil { + if value := strings.TrimSpace(r.Header.Get("X-EasyAI-Conversation-ID")); value != "" { + return value + } + } + if value := firstNonEmptyRequestString(body, "conversation_id", "conversationId"); value != "" { + return value + } + if metadata, ok := body["metadata"].(map[string]any); ok { + return firstNonEmptyRequestString(metadata, "conversation_id", "conversationId") + } + return "" +} + +func messageRefsForRequest(refs []store.TaskMessageRefInput) []any { + out := make([]any, 0, len(refs)) + for _, ref := range refs { + out = append(out, map[string]any{"messageId": ref.MessageID, "position": ref.Position}) + } + return out +} + +func canonicalConversationMessageHash(message map[string]any) (string, []string) { + parts := map[string]any{ + "role": stringFromRequestAny(message["role"]), + } + texts := []string{} + assetHashes := []string{} + collectConversationMessageParts(message, &texts, &assetHashes) + parts["text"] = strings.Join(texts, "\n") + parts["assets"] = assetHashes + if parts["text"] == "" && len(assetHashes) == 0 { + encoded, _ := json.Marshal(message) + parts["fallback"] = string(encoded) + } + encoded, _ := json.Marshal(parts) + sum := sha256.Sum256(encoded) + return hex.EncodeToString(sum[:]), uniqueRequestStrings(assetHashes) +} + +func collectConversationMessageParts(value any, texts *[]string, assetHashes *[]string) { + switch typed := value.(type) { + case map[string]any: + if ref, ok := typed["assetRef"].(map[string]any); ok { + if sha := stringFromRequestAny(ref["sha256"]); sha != "" { + *assetHashes = append(*assetHashes, sha) + } + } + for key, item := range typed { + switch key { + case "content", "text": + if text := stringFromRequestAny(item); text != "" && !strings.HasPrefix(strings.ToLower(text), "data:") { + *texts = append(*texts, text) + continue + } + } + switch item.(type) { + case map[string]any, []any: + collectConversationMessageParts(item, texts, assetHashes) + } + } + case []any: + for _, item := range typed { + collectConversationMessageParts(item, texts, assetHashes) + } + case string: + text := strings.TrimSpace(typed) + if text != "" && !mediaURLString(text) && !strings.HasPrefix(strings.ToLower(text), "data:") { + *texts = append(*texts, text) + } + } +} + +func requestAssetWrapper(ref map[string]any) map[string]any { + return map[string]any{ + "assetRef": ref, + "url": ref["url"], + } +} + +func requestAssetRef(asset store.RequestAsset) map[string]any { + return map[string]any{ + "sha256": asset.SHA256, + "url": asset.URL, + "contentType": asset.ContentType, + "size": asset.ByteSize, + "storageProvider": asset.StorageProvider, + "expiresAt": timePtrToRFC3339(asset.ExpiresAt), + } +} + +func requestAssetStillUsable(asset store.RequestAsset, now time.Time) bool { + if asset.ExpiredAt != nil { + return false + } + if asset.ExpiresAt != nil && !asset.ExpiresAt.After(now) { + return false + } + return strings.TrimSpace(asset.URL) != "" +} + +func requestAssetStorageProvider(upload map[string]any) string { + if channel, ok := upload["storageChannel"].(map[string]any); ok { + if provider := stringFromRequestAny(channel["provider"]); provider != "" { + return provider + } + } + return "unknown" +} + +func requestAssetLocalPath(storageDir string, fileName string) string { + if strings.TrimSpace(storageDir) == "" { + storageDir = config.DefaultLocalUploadedStorageDir + } + if strings.TrimSpace(fileName) == "" { + return "" + } + return filepath.Join(storageDir, filepath.Base(fileName)) +} + +func requestAssetFileName(sha string, contentType string) string { + prefix := sha + if len(prefix) > 16 { + prefix = prefix[:16] + } + return requestAssetFilePrefix + prefix + requestAssetExtension(contentType) +} + +func requestAssetExtension(contentType string) string { + switch strings.ToLower(strings.TrimSpace(contentType)) { + case "image/png": + return ".png" + case "image/jpeg", "image/jpg": + return ".jpg" + case "image/webp": + return ".webp" + case "image/gif": + return ".gif" + case "audio/mpeg": + return ".mp3" + case "audio/wav", "audio/x-wav": + return ".wav" + case "video/mp4": + return ".mp4" + } + if extensions, err := mime.ExtensionsByType(contentType); err == nil && len(extensions) > 0 { + return extensions[0] + } + return ".bin" +} + +func requestAssetContentType(explicit string, payload []byte, key string, path []string, siblings map[string]any) string { + if contentType := firstNonEmptyRequestString(siblings, "content_type", "contentType", "mime_type", "mimeType"); contentType != "" { + return contentType + } + if explicit = strings.TrimSpace(explicit); explicit != "" { + return explicit + } + detected := "" + if len(payload) > 0 { + detected = http.DetectContentType(payload) + } + if requestContentTypeIsMedia(detected) { + return detected + } + switch requestMediaKind(key, path, siblings) { + case "audio": + return "audio/mpeg" + case "video": + return "video/mp4" + default: + return "image/png" + } +} + +func parseRequestDataURL(value string) (string, string, bool, error) { + prefix, payload, ok := strings.Cut(value, ",") + if !ok { + return "", "", false, &clients.ClientError{Code: "request_asset_decode_failed", Message: "invalid data URL media payload", Retryable: false} + } + meta := strings.TrimPrefix(strings.TrimPrefix(prefix, "data:"), "DATA:") + parts := strings.Split(meta, ";") + contentType := strings.TrimSpace(parts[0]) + for _, part := range parts[1:] { + if strings.EqualFold(strings.TrimSpace(part), "base64") { + return contentType, payload, true, nil + } + } + return "", "", false, &clients.ClientError{Code: "request_asset_decode_failed", Message: "data URL media payload is not base64 encoded", Retryable: false} +} + +func decodeRequestBase64(value string) ([]byte, error) { + normalized := strings.Map(func(r rune) rune { + switch r { + case '\n', '\r', '\t', ' ': + return -1 + default: + return r + } + }, strings.TrimSpace(value)) + encodings := []*base64.Encoding{ + base64.StdEncoding, + base64.RawStdEncoding, + base64.URLEncoding, + base64.RawURLEncoding, + } + var lastErr error + for _, encoding := range encodings { + payload, err := encoding.DecodeString(normalized) + if err == nil && len(payload) > 0 { + return payload, nil + } + if err != nil { + lastErr = err + } + } + if lastErr == nil { + lastErr = fmt.Errorf("empty base64 payload") + } + return nil, lastErr +} + +func requestAssetDecodeError(err error) error { + return &clients.ClientError{Code: "request_asset_decode_failed", Message: err.Error(), Retryable: false} +} + +func strictRequestBase64Field(key string, path []string) bool { + lower := strings.ToLower(strings.TrimSpace(key)) + parent := "" + if len(path) > 0 { + parent = strings.ToLower(strings.Trim(path[len(path)-1], "[]")) + } + return lower == "b64_json" || + lower == "base64" || + lower == "b64" || + strings.Contains(lower, "base64") || + strings.Contains(lower, "_b64") || + (parent == "input_audio" && lower == "data") +} + +func likelyRequestBase64MediaField(key string, path []string, value string) bool { + if len(value) < 64 { + return false + } + return requestMediaKind(key, path, nil) != "" +} + +func requestMediaKind(key string, path []string, siblings map[string]any) string { + candidates := append([]string{key}, path...) + if siblings != nil { + candidates = append(candidates, stringFromRequestAny(siblings["type"]), stringFromRequestAny(siblings["role"])) + } + for _, item := range candidates { + lower := strings.ToLower(strings.TrimSpace(item)) + switch { + case strings.Contains(lower, "audio"): + return "audio" + case strings.Contains(lower, "video"): + return "video" + case strings.Contains(lower, "image") || lower == "mask" || lower == "b64_json": + return "image" + } + } + return "" +} + +func requestContentTypeIsMedia(contentType string) bool { + lower := strings.ToLower(strings.TrimSpace(contentType)) + return strings.HasPrefix(lower, "image/") || + strings.HasPrefix(lower, "video/") || + strings.HasPrefix(lower, "audio/") +} + +func mediaURLString(value string) bool { + raw := strings.TrimSpace(value) + if raw == "" { + return false + } + lower := strings.ToLower(raw) + if strings.HasPrefix(lower, "data:") { + return false + } + return strings.HasPrefix(lower, "http://") || + strings.HasPrefix(lower, "https://") || + strings.HasPrefix(lower, "/") || + strings.Contains(lower, "://") +} + +func firstNonEmptyRequestString(values map[string]any, keys ...string) string { + for _, key := range keys { + if values == nil { + continue + } + if value := stringFromRequestAny(values[key]); value != "" { + return value + } + } + return "" +} + +func stringFromRequestAny(value any) string { + text, _ := value.(string) + return strings.TrimSpace(text) +} + +func uniqueRequestStrings(values []string) []string { + seen := map[string]bool{} + out := make([]string, 0, len(values)) + for _, value := range values { + value = strings.TrimSpace(value) + if value == "" || seen[value] { + continue + } + seen[value] = true + out = append(out, value) + } + return out +} + +func timePtrToRFC3339(value *time.Time) any { + if value == nil { + return nil + } + return value.UTC().Format(time.RFC3339) +} + +func (s *Server) localTempAssetTTLHours() int { + if s.cfg.LocalTempAssetTTLHours <= 0 { + return 24 + } + return s.cfg.LocalTempAssetTTLHours +} diff --git a/apps/api/internal/httpapi/request_preparation_test.go b/apps/api/internal/httpapi/request_preparation_test.go new file mode 100644 index 0000000..61dbe0e --- /dev/null +++ b/apps/api/internal/httpapi/request_preparation_test.go @@ -0,0 +1,136 @@ +package httpapi + +import ( + "context" + "encoding/base64" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + "time" + + "github.com/easyai/easyai-ai-gateway/apps/api/internal/config" +) + +func TestRequestAssetFromValueDetectsDataURLAndRawBase64(t *testing.T) { + payload := base64.StdEncoding.EncodeToString([]byte("inline image")) + decoded, ok, err := requestAssetFromValue("url", []string{"messages", "[0]", "content", "[1]", "image_url"}, "data:image/png;base64,"+payload, nil) + if err != nil { + t.Fatalf("decode data URL: %v", err) + } + if !ok || decoded.ContentType != "image/png" || string(decoded.Bytes) != "inline image" { + t.Fatalf("unexpected data URL asset: ok=%v decoded=%+v", ok, decoded) + } + + audio := base64.StdEncoding.EncodeToString([]byte("inline audio")) + decoded, ok, err = requestAssetFromValue("data", []string{"input_audio"}, audio, map[string]any{"format": "mp3"}) + if err != nil { + t.Fatalf("decode raw audio: %v", err) + } + if !ok || decoded.ContentType != "audio/mpeg" || string(decoded.Bytes) != "inline audio" { + t.Fatalf("unexpected raw audio asset: ok=%v decoded=%+v", ok, decoded) + } +} + +func TestCanonicalConversationMessageHashUsesTextAndAssetRefs(t *testing.T) { + message := map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "text", "text": "describe it"}, + map[string]any{"type": "image_url", "image_url": map[string]any{ + "url": "https://cdn.example/a.png", + "assetRef": map[string]any{"sha256": "sha-a", "url": "https://cdn.example/a.png"}, + }}, + }, + } + sameMessage := map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "text", "text": "describe it"}, + map[string]any{"type": "image_url", "image_url": map[string]any{ + "url": "https://different.example/a.png", + "assetRef": map[string]any{"sha256": "sha-a", "url": "https://different.example/a.png"}, + }}, + }, + } + changedMessage := map[string]any{ + "role": "user", + "content": "describe something else", + } + + firstHash, assetHashes := canonicalConversationMessageHash(message) + secondHash, _ := canonicalConversationMessageHash(sameMessage) + changedHash, _ := canonicalConversationMessageHash(changedMessage) + if firstHash != secondHash { + t.Fatalf("message hash should ignore resource URL drift when asset sha is stable") + } + if firstHash == changedHash { + t.Fatalf("message hash should change when text changes") + } + if len(assetHashes) != 1 || assetHashes[0] != "sha-a" { + t.Fatalf("unexpected asset hashes: %+v", assetHashes) + } +} + +func TestCleanupExpiredLocalTempAssetsOnlyDeletesExpiredPrefixedFiles(t *testing.T) { + storageDir := t.TempDir() + oldTemp := filepath.Join(storageDir, requestAssetFilePrefix+"old.png") + freshTemp := filepath.Join(storageDir, requestAssetFilePrefix+"fresh.png") + oldGenerated := filepath.Join(storageDir, "gateway-result-old.png") + for _, path := range []string{oldTemp, freshTemp, oldGenerated} { + if err := os.WriteFile(path, []byte("asset"), 0o644); err != nil { + t.Fatalf("write fixture %s: %v", path, err) + } + } + now := time.Now() + if err := os.Chtimes(oldTemp, now.Add(-25*time.Hour), now.Add(-25*time.Hour)); err != nil { + t.Fatalf("touch old temp: %v", err) + } + if err := os.Chtimes(freshTemp, now.Add(-23*time.Hour), now.Add(-23*time.Hour)); err != nil { + t.Fatalf("touch fresh temp: %v", err) + } + if err := os.Chtimes(oldGenerated, now.Add(-25*time.Hour), now.Add(-25*time.Hour)); err != nil { + t.Fatalf("touch old generated: %v", err) + } + server := &Server{ + cfg: config.Config{LocalUploadedStorageDir: storageDir, LocalTempAssetTTLHours: 24}, + logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + } + + deleted := server.cleanupExpiredLocalTempAssets(context.Background(), now) + + if deleted != 1 { + t.Fatalf("expected one expired temp asset delete, got %d", deleted) + } + if _, err := os.Stat(oldTemp); !os.IsNotExist(err) { + t.Fatalf("old prefixed temp asset should be deleted, stat err=%v", err) + } + for _, path := range []string{freshTemp, oldGenerated} { + if _, err := os.Stat(path); err != nil { + t.Fatalf("non-expired or non-prefixed file should remain %s: %v", path, err) + } + } +} + +func TestRequestConversationKeyPriority(t *testing.T) { + request := httptest.NewRequest(http.MethodPost, "/api/v1/chat/completions", nil) + request.Header.Set("X-EasyAI-Conversation-ID", "from-header") + body := map[string]any{ + "conversation_id": "from-body", + "metadata": map[string]any{"conversation_id": "from-metadata"}, + } + if got := requestConversationKey(request, body); got != "from-header" { + t.Fatalf("expected header conversation id, got %q", got) + } + request.Header.Del("X-EasyAI-Conversation-ID") + if got := requestConversationKey(request, body); got != "from-body" { + t.Fatalf("expected body conversation id, got %q", got) + } + delete(body, "conversation_id") + if got := requestConversationKey(request, body); got != "from-metadata" { + t.Fatalf("expected metadata conversation id, got %q", got) + } +} diff --git a/apps/api/internal/httpapi/runtime_policy_handlers.go b/apps/api/internal/httpapi/runtime_policy_handlers.go index 5440584..50f37bb 100644 --- a/apps/api/internal/httpapi/runtime_policy_handlers.go +++ b/apps/api/internal/httpapi/runtime_policy_handlers.go @@ -134,6 +134,33 @@ func (s *Server) updatePlatformDynamicPriority(w http.ResponseWriter, r *http.Re writeJSON(w, http.StatusOK, item) } +// restorePlatformModelRuntimeStatus godoc +// @Summary 恢复平台模型运行状态 +// @Description 管理端手动解除平台模型停用、模型冷却、平台冷却或平台禁用状态,使其重新参与路由。 +// @Tags runtime +// @Produce json +// @Security BearerAuth +// @Param platformModelID path string true "平台模型 ID" +// @Success 200 {object} store.ModelRateLimitStatus +// @Failure 401 {object} ErrorEnvelope +// @Failure 403 {object} ErrorEnvelope +// @Failure 404 {object} ErrorEnvelope +// @Failure 500 {object} ErrorEnvelope +// @Router /api/admin/runtime/model-rate-limits/{platformModelID}/restore [post] +func (s *Server) restorePlatformModelRuntimeStatus(w http.ResponseWriter, r *http.Request) { + item, err := s.store.RestorePlatformModelRuntimeStatus(r.Context(), r.PathValue("platformModelID")) + if err != nil { + if store.IsNotFound(err) { + writeError(w, http.StatusNotFound, "platform model not found") + return + } + s.logger.Error("restore platform model runtime status failed", "error", err) + writeError(w, http.StatusInternalServerError, "restore platform model runtime status failed") + return + } + writeJSON(w, http.StatusOK, item) +} + // createRuntimePolicySet godoc // @Summary 创建运行策略集 // @Description 管理端创建运行策略集,policyKey 和 name 必填。 diff --git a/apps/api/internal/httpapi/server.go b/apps/api/internal/httpapi/server.go index 577087a..6b26271 100644 --- a/apps/api/internal/httpapi/server.go +++ b/apps/api/internal/httpapi/server.go @@ -36,6 +36,7 @@ func NewServerWithContext(ctx context.Context, cfg config.Config, db *store.Stor } server.auth.LocalAPIKeyVerifier = db.VerifyLocalAPIKey server.runner.StartAsyncQueueWorker(ctx) + server.startLocalTempAssetCleanup(ctx) mux := http.NewServeMux() mux.HandleFunc("GET /healthz", server.health) @@ -103,6 +104,7 @@ func NewServerWithContext(ctx context.Context, cfg config.Config, db *store.Stor mux.Handle("DELETE /api/admin/runtime/policy-sets/{policySetID}", server.requireAdmin(auth.PermissionManager, http.HandlerFunc(server.deleteRuntimePolicySet))) mux.Handle("GET /api/admin/runtime/runner-policy", server.requireAdmin(auth.PermissionPower, http.HandlerFunc(server.getRunnerPolicy))) mux.Handle("PATCH /api/admin/runtime/runner-policy", server.requireAdmin(auth.PermissionManager, http.HandlerFunc(server.updateRunnerPolicy))) + mux.Handle("POST /api/admin/runtime/model-rate-limits/{platformModelID}/restore", server.requireAdmin(auth.PermissionManager, http.HandlerFunc(server.restorePlatformModelRuntimeStatus))) mux.Handle("GET /api/admin/config/network-proxy", server.requireAdmin(auth.PermissionPower, http.HandlerFunc(server.getNetworkProxyConfig))) mux.Handle("GET /api/admin/system/file-storage/settings", server.requireAdmin(auth.PermissionPower, http.HandlerFunc(server.getFileStorageSettings))) mux.Handle("PATCH /api/admin/system/file-storage/settings", server.requireAdmin(auth.PermissionManager, http.HandlerFunc(server.updateFileStorageSettings))) @@ -173,7 +175,7 @@ func (s *Server) cors(next http.Handler) http.Handler { w.Header().Set("Access-Control-Allow-Origin", origin) w.Header().Set("Vary", "Origin") w.Header().Set("Access-Control-Allow-Credentials", "true") - w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Comfy-Api-Key, X-Async") + w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Comfy-Api-Key, X-Async, X-EasyAI-Conversation-ID") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") } if r.Method == http.MethodOptions { diff --git a/apps/api/internal/runner/helpers.go b/apps/api/internal/runner/helpers.go index 8b0834f..8684009 100644 --- a/apps/api/internal/runner/helpers.go +++ b/apps/api/internal/runner/helpers.go @@ -11,8 +11,18 @@ func stringFromMap(values map[string]any, key string) string { } func stringFromAny(value any) string { - text, _ := value.(string) - return strings.TrimSpace(text) + switch typed := value.(type) { + case string: + return strings.TrimSpace(typed) + case map[string]any: + if text := stringFromAny(typed["url"]); text != "" { + return text + } + if ref, ok := typed["assetRef"].(map[string]any); ok { + return stringFromAny(ref["url"]) + } + } + return "" } func boolFromMap(values map[string]any, key string) bool { diff --git a/apps/api/internal/runner/request_assets.go b/apps/api/internal/runner/request_assets.go new file mode 100644 index 0000000..91dd665 --- /dev/null +++ b/apps/api/internal/runner/request_assets.go @@ -0,0 +1,249 @@ +package runner + +import ( + "context" + "encoding/base64" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + "time" + + "github.com/easyai/easyai-ai-gateway/apps/api/internal/clients" + "github.com/easyai/easyai-ai-gateway/apps/api/internal/config" + "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" +) + +func (s *Service) restoreTaskRequestReferences(ctx context.Context, task store.GatewayTask) (map[string]any, error) { + body := cloneMap(task.Request) + if body["messages"] != nil || body["messageRefs"] == nil || s.store == nil { + return body, nil + } + refs, err := s.store.ListTaskConversationMessages(ctx, task.ID) + if err != nil { + return nil, err + } + if len(refs) == 0 { + return body, nil + } + messages := make([]any, 0, len(refs)) + for _, ref := range refs { + messages = append(messages, ref.Message) + } + body["messages"] = messages + return body, nil +} + +func (s *Service) slimTaskRequestSnapshot(task store.GatewayTask, body map[string]any) map[string]any { + out := cloneMap(body) + messageRefs := task.Request["messageRefs"] + if messageRefs == nil { + return out + } + delete(out, "messages") + out["messageRefs"] = messageRefs + for _, key := range []string{"conversationId", "conversationRecordId", "newMessageCount"} { + if value := task.Request[key]; value != nil { + out[key] = value + } + } + return out +} + +func (s *Service) slimParameterPreprocessingLog(task store.GatewayTask, log parameterPreprocessingLog) parameterPreprocessingLog { + log.Input = s.slimTaskRequestSnapshot(task, log.Input) + log.Output = s.slimTaskRequestSnapshot(task, log.Output) + return log +} + +func (s *Service) hydrateProviderRequestAssets(ctx context.Context, body map[string]any) (map[string]any, error) { + value, err := s.hydrateProviderRequestAssetValue(ctx, body, nil) + if err != nil { + return nil, err + } + out, _ := value.(map[string]any) + if out == nil { + return map[string]any{}, nil + } + return out, nil +} + +func (s *Service) hydrateProviderRequestAssetValue(ctx context.Context, value any, path []string) (any, error) { + switch typed := value.(type) { + case map[string]any: + if ref, ok := typed["assetRef"].(map[string]any); ok { + return s.hydrateProviderRequestAssetRef(ctx, ref, path) + } + next := make(map[string]any, len(typed)) + for key, item := range typed { + hydrated, err := s.hydrateProviderRequestAssetValue(ctx, item, append(path, key)) + if err != nil { + return nil, err + } + next[key] = hydrated + } + return next, nil + case []any: + next := make([]any, 0, len(typed)) + for index, item := range typed { + hydrated, err := s.hydrateProviderRequestAssetValue(ctx, item, append(path, fmt.Sprintf("[%d]", index))) + if err != nil { + return nil, err + } + next = append(next, hydrated) + } + return next, nil + default: + return value, nil + } +} + +func (s *Service) hydrateProviderRequestAssetRef(ctx context.Context, ref map[string]any, path []string) (any, error) { + asset, err := s.resolveRequestAsset(ctx, ref) + if err != nil { + return nil, err + } + if providerFieldNeedsBase64(path) { + payload, err := s.readRequestAssetBytes(ctx, asset) + if err != nil { + return nil, err + } + return base64.StdEncoding.EncodeToString(payload), nil + } + if strings.TrimSpace(asset.URL) == "" { + return nil, requestAssetExpiredError(asset) + } + return asset.URL, nil +} + +func (s *Service) resolveRequestAsset(ctx context.Context, ref map[string]any) (store.RequestAsset, error) { + sha := stringFromAny(ref["sha256"]) + contentType := stringFromAny(ref["contentType"]) + asset := store.RequestAsset{ + SHA256: sha, + ContentType: contentType, + URL: stringFromAny(ref["url"]), + StorageProvider: stringFromAny(ref["storageProvider"]), + } + if size := floatFromAny(ref["size"]); size > 0 { + asset.ByteSize = int64(size) + } + if expiresAt := stringFromAny(ref["expiresAt"]); expiresAt != "" { + if parsed, err := time.Parse(time.RFC3339, expiresAt); err == nil { + asset.ExpiresAt = &parsed + } + } + if s.store != nil && sha != "" && contentType != "" { + if stored, ok, err := s.store.FindRequestAsset(ctx, sha, contentType); err != nil && !store.IsUndefinedDatabaseObject(err) { + return store.RequestAsset{}, err + } else if ok { + asset = stored + } + } + if requestAssetIsExpired(asset, time.Now()) { + return store.RequestAsset{}, requestAssetExpiredError(asset) + } + return asset, nil +} + +func (s *Service) readRequestAssetBytes(ctx context.Context, asset store.RequestAsset) ([]byte, error) { + if requestAssetIsExpired(asset, time.Now()) { + return nil, requestAssetExpiredError(asset) + } + if strings.TrimSpace(asset.LocalPath) != "" { + payload, err := os.ReadFile(asset.LocalPath) + if err != nil { + return nil, requestAssetExpiredError(asset) + } + return payload, nil + } + if localPath := s.localPathFromRequestAssetURL(asset.URL); localPath != "" { + payload, err := os.ReadFile(localPath) + if err != nil { + return nil, requestAssetExpiredError(asset) + } + return payload, nil + } + if strings.HasPrefix(asset.URL, "http://") || strings.HasPrefix(asset.URL, "https://") { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, asset.URL, nil) + if err != nil { + return nil, err + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, &clients.ClientError{Code: "request_asset_fetch_failed", Message: err.Error(), Retryable: true} + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, &clients.ClientError{Code: "request_asset_fetch_failed", Message: resp.Status, StatusCode: resp.StatusCode, Retryable: clients.HTTPRetryable(resp.StatusCode)} + } + payload, err := io.ReadAll(io.LimitReader(resp.Body, 256<<20)) + if err != nil { + return nil, &clients.ClientError{Code: "request_asset_fetch_failed", Message: err.Error(), Retryable: true} + } + return payload, nil + } + return nil, requestAssetExpiredError(asset) +} + +func (s *Service) localPathFromRequestAssetURL(value string) string { + raw := strings.TrimSpace(value) + if raw == "" { + return "" + } + pathValue := raw + if parsed, err := url.Parse(raw); err == nil && parsed.Path != "" { + pathValue = parsed.Path + } + const uploadedPrefix = "/static/uploaded/" + if !strings.HasPrefix(pathValue, uploadedPrefix) { + return "" + } + fileName := filepath.Base(strings.TrimPrefix(pathValue, uploadedPrefix)) + if !strings.HasPrefix(fileName, "gateway-request-asset-") { + return "" + } + storageDir := strings.TrimSpace(s.cfg.LocalUploadedStorageDir) + if storageDir == "" { + storageDir = config.DefaultLocalUploadedStorageDir + } + return filepath.Join(storageDir, fileName) +} + +func providerFieldNeedsBase64(path []string) bool { + if len(path) == 0 { + return false + } + key := strings.ToLower(strings.Trim(path[len(path)-1], "[]")) + parent := "" + if len(path) > 1 { + parent = strings.ToLower(strings.Trim(path[len(path)-2], "[]")) + } + return key == "b64_json" || + key == "base64" || + key == "b64" || + strings.Contains(key, "base64") || + strings.Contains(key, "_b64") || + (parent == "input_audio" && key == "data") +} + +func requestAssetIsExpired(asset store.RequestAsset, now time.Time) bool { + if asset.ExpiredAt != nil { + return true + } + if asset.ExpiresAt != nil && !asset.ExpiresAt.After(now) { + return true + } + return false +} + +func requestAssetExpiredError(asset store.RequestAsset) error { + message := "request asset is expired or unavailable" + if asset.SHA256 != "" { + message = "request asset is expired or unavailable: " + asset.SHA256 + } + return &clients.ClientError{Code: "request_asset_expired", Message: message, Retryable: false} +} diff --git a/apps/api/internal/runner/request_assets_test.go b/apps/api/internal/runner/request_assets_test.go new file mode 100644 index 0000000..1610010 --- /dev/null +++ b/apps/api/internal/runner/request_assets_test.go @@ -0,0 +1,101 @@ +package runner + +import ( + "context" + "encoding/base64" + "errors" + "os" + "path/filepath" + "testing" + "time" + + "github.com/easyai/easyai-ai-gateway/apps/api/internal/clients" + "github.com/easyai/easyai-ai-gateway/apps/api/internal/config" + "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" +) + +func TestHydrateProviderRequestAssetsConvertsStrictBase64Field(t *testing.T) { + storageDir := t.TempDir() + fileName := "gateway-request-asset-test.png" + if err := os.WriteFile(filepath.Join(storageDir, fileName), []byte("image bytes"), 0o644); err != nil { + t.Fatalf("write request asset: %v", err) + } + service := &Service{cfg: config.Config{LocalUploadedStorageDir: storageDir}} + body := map[string]any{ + "model": "demo", + "b64_json": map[string]any{ + "assetRef": map[string]any{ + "sha256": "sha-test", + "contentType": "image/png", + "url": "/static/uploaded/" + fileName, + "storageProvider": "local_static", + }, + "url": "/static/uploaded/" + fileName, + }, + } + + hydrated, err := service.hydrateProviderRequestAssets(context.Background(), body) + if err != nil { + t.Fatalf("hydrate request assets: %v", err) + } + if got, want := stringFromAny(hydrated["b64_json"]), base64.StdEncoding.EncodeToString([]byte("image bytes")); got != want { + t.Fatalf("unexpected hydrated base64: got %q want %q", got, want) + } +} + +func TestHydrateProviderRequestAssetsReturnsExpiredError(t *testing.T) { + expiredAt := time.Now().Add(-time.Minute).UTC().Format(time.RFC3339) + service := &Service{} + body := map[string]any{ + "image_url": map[string]any{ + "assetRef": map[string]any{ + "sha256": "sha-expired", + "contentType": "image/png", + "url": "/static/uploaded/gateway-request-asset-expired.png", + "storageProvider": "local_static", + "expiresAt": expiredAt, + }, + "url": "/static/uploaded/gateway-request-asset-expired.png", + }, + } + + _, err := service.hydrateProviderRequestAssets(context.Background(), body) + if err == nil { + t.Fatal("expected expired request asset error") + } + var clientErr *clients.ClientError + if !errors.As(err, &clientErr) || clientErr.Code != "request_asset_expired" { + t.Fatalf("expected request_asset_expired, got %T %v", err, err) + } +} + +func TestStringFromAnyReadsRequestAssetWrapperURL(t *testing.T) { + wrapper := map[string]any{ + "assetRef": map[string]any{"url": "https://cdn.example/request.png"}, + } + if got := stringFromAny(wrapper); got != "https://cdn.example/request.png" { + t.Fatalf("expected wrapper URL, got %q", got) + } +} + +func TestSlimTaskRequestSnapshotKeepsMessageRefs(t *testing.T) { + service := &Service{} + task := store.GatewayTask{Request: map[string]any{ + "conversationId": "conv-1", + "messageRefs": []any{map[string]any{"messageId": "msg-1", "position": 0}}, + "newMessageCount": 1, + }} + body := map[string]any{ + "model": "demo", + "messages": []any{map[string]any{"role": "user", "content": "hello"}}, + } + + snapshot := service.slimTaskRequestSnapshot(task, body) + + if snapshot["messages"] != nil { + t.Fatalf("snapshot should not persist restored messages: %+v", snapshot) + } + if snapshot["messageRefs"] == nil || snapshot["newMessageCount"] != 1 { + t.Fatalf("snapshot should keep message refs and new count: %+v", snapshot) + } +} diff --git a/apps/api/internal/runner/service.go b/apps/api/internal/runner/service.go index e9ff718..c3635a0 100644 --- a/apps/api/internal/runner/service.go +++ b/apps/api/internal/runner/service.go @@ -87,9 +87,13 @@ func (s *Service) ExecuteStream(ctx context.Context, task store.GatewayTask, use func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *auth.User, onDelta clients.StreamDelta) (Result, error) { executeStartedAt := time.Now() - body := normalizeRequest(task.Kind, task.Request) + restoredRequest, err := s.restoreTaskRequestReferences(ctx, task) + if err != nil { + return Result{}, err + } + body := normalizeRequest(task.Kind, restoredRequest) modelType := modelTypeFromKind(task.Kind, body) - if err := s.store.MarkTaskRunning(ctx, task.ID, modelType, body); err != nil { + if err := s.store.MarkTaskRunning(ctx, task.ID, modelType, s.slimTaskRequestSnapshot(task, body)); err != nil { return Result{}, err } if task.Status != "running" { @@ -194,7 +198,7 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut } return Result{Task: failed, Output: failed.Result}, clientErr } - if err := s.store.MarkTaskRunning(ctx, task.ID, candidates[0].ModelType, firstCandidateBody); err != nil { + if err := s.store.MarkTaskRunning(ctx, task.ID, candidates[0].ModelType, s.slimTaskRequestSnapshot(task, firstCandidateBody)); err != nil { return Result{}, err } estimatedBillings := s.estimatedBillings(ctx, user, task.Kind, firstCandidateBody, candidates[0]) @@ -508,13 +512,13 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user QueueKey: candidate.QueueKey, Status: "running", Simulated: simulated, - RequestSnapshot: body, + RequestSnapshot: s.slimTaskRequestSnapshot(task, body), Metrics: baseAttemptMetrics, }) if err != nil { return clients.Response{}, fmt.Errorf("create task attempt: %w", err) } - if err := s.recordTaskParameterPreprocessing(ctx, task.ID, attemptID, attemptNo, candidate, preprocessing); err != nil { + if err := s.recordTaskParameterPreprocessing(ctx, task.ID, attemptID, attemptNo, candidate, s.slimParameterPreprocessingLog(task, preprocessing)); err != nil { clientErr := &clients.ClientError{Code: "runtime_error", Message: err.Error(), Retryable: false} _ = s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{ AttemptID: attemptID, @@ -560,12 +564,24 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user return clients.Response{}, fmt.Errorf("prepare http client: %w", err) } client := s.clientFor(candidate, simulated) + providerBody, err := s.hydrateProviderRequestAssets(ctx, body) + if err != nil { + _ = s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{ + AttemptID: attemptID, + Status: "failed", + Retryable: false, + Metrics: mergeMetrics(baseAttemptMetrics, map[string]any{"error": err.Error(), "retryable": false, "trace": []any{failureTraceEntry(err, false)}}), + ErrorCode: clients.ErrorCode(err), + ErrorMessage: err.Error(), + }) + return clients.Response{}, err + } callStartedAt := time.Now() response, err := client.Run(ctx, clients.Request{ Kind: task.Kind, ModelType: candidate.ModelType, Model: task.Model, - Body: body, + Body: providerBody, Candidate: candidate, HTTPClient: requestHTTPClient, RemoteTaskID: task.RemoteTaskID, @@ -576,7 +592,7 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user } return s.store.SetTaskRemoteTask(context.WithoutCancel(ctx), task.ID, attemptID, remoteTaskID, payload) }, - Stream: boolFromMap(body, "stream"), + Stream: boolFromMap(providerBody, "stream"), StreamDelta: onDelta, }) callFinishedAt := time.Now() @@ -826,7 +842,7 @@ func (s *Service) recordFailedAttempt(ctx context.Context, input failedAttemptRe QueueKey: queueKey, Status: "running", Simulated: input.Simulated, - RequestSnapshot: input.Body, + RequestSnapshot: s.slimTaskRequestSnapshot(input.Task, input.Body), Metrics: metrics, }) if err != nil { @@ -834,7 +850,8 @@ func (s *Service) recordFailedAttempt(ctx context.Context, input failedAttemptRe return attemptNo } if input.Preprocessing != nil && input.Candidate != nil { - if err := s.recordTaskParameterPreprocessing(ctx, input.Task.ID, attemptID, attemptNo, *input.Candidate, *input.Preprocessing); err != nil { + preprocessing := s.slimParameterPreprocessingLog(input.Task, *input.Preprocessing) + if err := s.recordTaskParameterPreprocessing(ctx, input.Task.ID, attemptID, attemptNo, *input.Candidate, preprocessing); err != nil { s.logger.Warn("record failed attempt parameter preprocessing failed", "taskID", input.Task.ID, "attempt", attemptNo, "error", err) } } diff --git a/apps/api/internal/store/conversations.go b/apps/api/internal/store/conversations.go new file mode 100644 index 0000000..e686c64 --- /dev/null +++ b/apps/api/internal/store/conversations.go @@ -0,0 +1,149 @@ +package store + +import ( + "context" + "encoding/json" + "strings" + + "github.com/easyai/easyai-ai-gateway/apps/api/internal/auth" + "github.com/jackc/pgx/v5" +) + +type ConversationMessageInput struct { + Hash string + Role string + Snapshot map[string]any + AssetSHA256s []string +} + +type TaskMessageRefInput struct { + MessageID string + Position int +} + +type ConversationMessageRef struct { + MessageID string `json:"messageId"` + Position int `json:"position"` + Message map[string]any `json:"message"` +} + +func (s *Store) EnsureConversation(ctx context.Context, user *auth.User, conversationKey string, metadata map[string]any) (string, error) { + conversationKey = strings.TrimSpace(conversationKey) + if conversationKey == "" { + return "", nil + } + userID := "" + gatewayUserID := "" + if user != nil { + userID = strings.TrimSpace(user.ID) + gatewayUserID = strings.TrimSpace(user.GatewayUserID) + } + if userID == "" { + userID = "anonymous" + } + metadataJSON, _ := json.Marshal(emptyObjectIfNil(metadata)) + var conversationID string + err := s.pool.QueryRow(ctx, ` +INSERT INTO gateway_conversations (user_id, gateway_user_id, conversation_key, metadata) +VALUES ($1, NULLIF($2, '')::uuid, $3, $4::jsonb) +ON CONFLICT (user_id, conversation_key) DO UPDATE +SET gateway_user_id = COALESCE(gateway_conversations.gateway_user_id, EXCLUDED.gateway_user_id), + metadata = gateway_conversations.metadata || EXCLUDED.metadata, + updated_at = now() +RETURNING id::text`, userID, gatewayUserID, conversationKey, string(metadataJSON)).Scan(&conversationID) + return conversationID, err +} + +func (s *Store) UpsertConversationMessages(ctx context.Context, conversationID string, messages []ConversationMessageInput) ([]TaskMessageRefInput, int, error) { + if strings.TrimSpace(conversationID) == "" || len(messages) == 0 { + return nil, 0, nil + } + tx, err := s.pool.Begin(ctx) + if err != nil { + return nil, 0, err + } + defer tx.Rollback(ctx) + + refs := make([]TaskMessageRefInput, 0, len(messages)) + newCount := 0 + for index, message := range messages { + snapshotJSON, _ := json.Marshal(emptyObjectIfNil(message.Snapshot)) + var messageID string + var inserted bool + if err := tx.QueryRow(ctx, ` +INSERT INTO gateway_conversation_messages ( + conversation_id, message_hash, role, message_snapshot, asset_sha256s +) +VALUES ($1::uuid, $2, NULLIF($3, ''), $4::jsonb, $5) +ON CONFLICT (conversation_id, message_hash) DO UPDATE +SET updated_at = gateway_conversation_messages.updated_at +RETURNING id::text, (xmax = 0) AS inserted`, + conversationID, + message.Hash, + message.Role, + string(snapshotJSON), + message.AssetSHA256s, + ).Scan(&messageID, &inserted); err != nil { + return nil, 0, err + } + if inserted { + newCount++ + } + refs = append(refs, TaskMessageRefInput{MessageID: messageID, Position: index}) + } + if err := tx.Commit(ctx); err != nil { + return nil, 0, err + } + return refs, newCount, nil +} + +func (s *Store) ListTaskConversationMessages(ctx context.Context, taskID string) ([]ConversationMessageRef, error) { + rows, err := s.pool.Query(ctx, ` +SELECT refs.message_id::text, refs.position, messages.message_snapshot +FROM gateway_task_message_refs refs +JOIN gateway_conversation_messages messages ON messages.id = refs.message_id +WHERE refs.task_id = $1::uuid +ORDER BY refs.position ASC`, taskID) + if err != nil { + if IsUndefinedDatabaseObject(err) { + return nil, nil + } + return nil, err + } + defer rows.Close() + + items := make([]ConversationMessageRef, 0) + for rows.Next() { + var item ConversationMessageRef + var snapshot []byte + if err := rows.Scan(&item.MessageID, &item.Position, &snapshot); err != nil { + return nil, err + } + item.Message = decodeObject(snapshot) + items = append(items, item) + } + return items, rows.Err() +} + +func insertTaskMessageRefs(ctx context.Context, tx pgx.Tx, taskID string, refs []TaskMessageRefInput) error { + if len(refs) == 0 { + return nil + } + for _, ref := range refs { + if strings.TrimSpace(ref.MessageID) == "" { + continue + } + if _, err := tx.Exec(ctx, ` +INSERT INTO gateway_task_message_refs (task_id, message_id, position) +VALUES ($1::uuid, $2::uuid, $3) +ON CONFLICT (task_id, position) DO UPDATE +SET message_id = EXCLUDED.message_id`, + taskID, + ref.MessageID, + ref.Position, + ); err != nil { + return err + } + } + return nil +} diff --git a/apps/api/internal/store/file_storage_channels.go b/apps/api/internal/store/file_storage_channels.go index 7bcba18..64276f9 100644 --- a/apps/api/internal/store/file_storage_channels.go +++ b/apps/api/internal/store/file_storage_channels.go @@ -12,8 +12,9 @@ import ( const defaultServerMainUploadURL = "http://127.0.0.1:3001/v1/files/upload" const ( - FileStorageSceneUpload = "upload" - FileStorageSceneImageResult = "image_result" + FileStorageSceneUpload = "upload" + FileStorageSceneImageResult = "image_result" + FileStorageSceneRequestAsset = "request_asset" ) const ( diff --git a/apps/api/internal/store/postgres.go b/apps/api/internal/store/postgres.go index 1adf780..a907baa 100644 --- a/apps/api/internal/store/postgres.go +++ b/apps/api/internal/store/postgres.go @@ -380,11 +380,14 @@ type RateLimitWindow struct { } type CreateTaskInput struct { - Kind string `json:"kind"` - Model string `json:"model"` - RunMode string `json:"runMode"` - Async bool `json:"async"` - Request map[string]any `json:"request"` + Kind string `json:"kind"` + Model string `json:"model"` + RunMode string `json:"runMode"` + Async bool `json:"async"` + Request map[string]any `json:"request"` + ConversationID string `json:"conversationId"` + NewMessageCount int `json:"newMessageCount"` + MessageRefs []TaskMessageRefInput `json:"messageRefs"` } type GatewayTask struct { @@ -407,6 +410,8 @@ type GatewayTask struct { RequestedModel string `json:"requestedModel,omitempty"` ResolvedModel string `json:"resolvedModel,omitempty"` RequestID string `json:"requestId,omitempty"` + ConversationID string `json:"conversationId,omitempty"` + NewMessageCount int `json:"newMessageCount,omitempty"` Request map[string]any `json:"request,omitempty"` AsyncMode bool `json:"asyncMode"` RiverJobID int64 `json:"riverJobId,omitempty"` @@ -438,6 +443,7 @@ COALESCE(gateway_tenant_id::text, ''), COALESCE(tenant_id, ''), COALESCE(tenant_ COALESCE(api_key_id, ''), COALESCE(api_key_name, ''), COALESCE(api_key_prefix, ''), COALESCE(user_group_id::text, ''), COALESCE(user_group_key, ''), model, COALESCE(model_type, ''), COALESCE(requested_model, ''), COALESCE(resolved_model, ''), COALESCE(request_id, ''), +COALESCE(conversation_id::text, ''), COALESCE(new_message_count, 0), request, COALESCE(async_mode, false), COALESCE(river_job_id, 0), status, COALESCE(attempt_count, 0), COALESCE(remote_task_id, ''), COALESCE(remote_task_payload, '{}'::jsonb), COALESCE(result, '{}'::jsonb), COALESCE(billings, '[]'::jsonb), @@ -1746,15 +1752,18 @@ func (s *Store) CreateTask(ctx context.Context, input CreateTaskInput, user *aut INSERT INTO gateway_tasks ( kind, run_mode, user_id, gateway_user_id, user_source, gateway_tenant_id, tenant_id, tenant_key, api_key_id, api_key_name, api_key_prefix, user_group_id, user_group_key, - model, requested_model, request, async_mode, status, result, billings, finished_at + model, requested_model, request, async_mode, status, result, billings, conversation_id, new_message_count, finished_at ) - VALUES ($1, $2, $3, NULLIF($4, '')::uuid, COALESCE(NULLIF($5, ''), 'gateway'), NULLIF($6, '')::uuid, NULLIF($7, ''), NULLIF($8, ''), NULLIF($9, ''), NULLIF($10, ''), NULLIF($11, ''), NULLIF($12, '')::uuid, NULLIF($13, ''), $14, $14, $15, $16, $17, $18::jsonb, $19::jsonb, CASE WHEN $20 THEN now() ELSE NULL END) + VALUES ($1, $2, $3, NULLIF($4, '')::uuid, COALESCE(NULLIF($5, ''), 'gateway'), NULLIF($6, '')::uuid, NULLIF($7, ''), NULLIF($8, ''), NULLIF($9, ''), NULLIF($10, ''), NULLIF($11, ''), NULLIF($12, '')::uuid, NULLIF($13, ''), $14, $14, $15, $16, $17, $18::jsonb, $19::jsonb, NULLIF($20, '')::uuid, $21, CASE WHEN $22 THEN now() ELSE NULL END) RETURNING `+gatewayTaskColumns, - input.Kind, runMode, user.ID, user.GatewayUserID, user.Source, user.GatewayTenantID, user.TenantID, user.TenantKey, user.APIKeyID, user.APIKeyName, user.APIKeyPrefix, user.UserGroupID, user.UserGroupKey, input.Model, requestBody, input.Async, status, resultBody, billingsBody, false, + input.Kind, runMode, user.ID, user.GatewayUserID, user.Source, user.GatewayTenantID, user.TenantID, user.TenantKey, user.APIKeyID, user.APIKeyName, user.APIKeyPrefix, user.UserGroupID, user.UserGroupKey, input.Model, requestBody, input.Async, status, resultBody, billingsBody, input.ConversationID, input.NewMessageCount, false, )) if err != nil { return GatewayTask{}, err } + if err := insertTaskMessageRefs(ctx, tx, task.ID, input.MessageRefs); err != nil { + return GatewayTask{}, err + } events := taskEventsForCreate(task.ID, runMode, status, nil) for _, event := range events { payload, _ := json.Marshal(event.Payload) @@ -1822,6 +1831,8 @@ func scanGatewayTask(scanner taskScanner) (GatewayTask, error) { &task.RequestedModel, &task.ResolvedModel, &task.RequestID, + &task.ConversationID, + &task.NewMessageCount, &requestBytes, &task.AsyncMode, &task.RiverJobID, diff --git a/apps/api/internal/store/rate_limit_status.go b/apps/api/internal/store/rate_limit_status.go index cd5fef3..7c4cde6 100644 --- a/apps/api/internal/store/rate_limit_status.go +++ b/apps/api/internal/store/rate_limit_status.go @@ -7,6 +7,8 @@ import ( "strconv" "strings" "time" + + "github.com/jackc/pgx/v5" ) type RateLimitMetricStatus struct { @@ -82,6 +84,59 @@ type PlatformPolicyEvent struct { CreatedAt time.Time `json:"createdAt"` } +func (s *Store) RestorePlatformModelRuntimeStatus(ctx context.Context, platformModelID string) (ModelRateLimitStatus, error) { + platformModelID = strings.TrimSpace(platformModelID) + if platformModelID == "" { + return ModelRateLimitStatus{}, pgx.ErrNoRows + } + + tx, err := s.pool.Begin(ctx) + if err != nil { + return ModelRateLimitStatus{}, err + } + defer tx.Rollback(ctx) + + var restoredModelID string + if err := tx.QueryRow(ctx, ` +WITH restored_model AS ( + UPDATE platform_models + SET enabled = true, + cooldown_until = NULL, + updated_at = now() + WHERE id = $1::uuid + RETURNING id::text, platform_id +), +restored_platform AS ( + UPDATE integration_platforms + SET status = 'enabled', + disabled_reason = NULL, + cooldown_until = NULL, + updated_at = now() + WHERE id = (SELECT platform_id FROM restored_model) + AND deleted_at IS NULL + RETURNING id +) +SELECT id +FROM restored_model +WHERE EXISTS (SELECT 1 FROM restored_platform)`, platformModelID).Scan(&restoredModelID); err != nil { + return ModelRateLimitStatus{}, err + } + if err := tx.Commit(ctx); err != nil { + return ModelRateLimitStatus{}, err + } + + items, err := s.ListModelRateLimitStatuses(ctx) + if err != nil { + return ModelRateLimitStatus{}, err + } + for _, item := range items { + if item.PlatformModelID == restoredModelID { + return item, nil + } + } + return ModelRateLimitStatus{}, pgx.ErrNoRows +} + func (s *Store) ListModelRateLimitStatuses(ctx context.Context) ([]ModelRateLimitStatus, error) { rows, err := s.pool.Query(ctx, ` SELECT m.id::text, m.platform_id::text, p.name, p.provider, p.status, diff --git a/apps/api/internal/store/request_assets.go b/apps/api/internal/store/request_assets.go new file mode 100644 index 0000000..56a03e0 --- /dev/null +++ b/apps/api/internal/store/request_assets.go @@ -0,0 +1,130 @@ +package store + +import ( + "context" + "database/sql" + "time" + + "github.com/jackc/pgx/v5" +) + +type RequestAsset struct { + ID string `json:"id"` + SHA256 string `json:"sha256"` + ContentType string `json:"contentType"` + ByteSize int64 `json:"byteSize"` + URL string `json:"url"` + StorageProvider string `json:"storageProvider"` + LocalPath string `json:"localPath,omitempty"` + ExpiresAt *time.Time `json:"expiresAt,omitempty"` + ExpiredAt *time.Time `json:"expiredAt,omitempty"` + RefCount int `json:"refCount"` + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` +} + +type RequestAssetInput struct { + SHA256 string + ContentType string + ByteSize int64 + URL string + StorageProvider string + LocalPath string + ExpiresAt *time.Time +} + +func (s *Store) FindRequestAsset(ctx context.Context, sha256 string, contentType string) (RequestAsset, bool, error) { + asset, err := scanRequestAsset(s.pool.QueryRow(ctx, ` +SELECT id::text, sha256, content_type, byte_size, url, storage_provider, + COALESCE(local_path, ''), expires_at, expired_at, ref_count, created_at, updated_at +FROM gateway_request_assets +WHERE sha256 = $1 AND content_type = $2`, sha256, contentType)) + if err != nil { + if err == pgx.ErrNoRows { + return RequestAsset{}, false, nil + } + return RequestAsset{}, false, err + } + return asset, true, nil +} + +func (s *Store) UpsertRequestAsset(ctx context.Context, input RequestAssetInput) (RequestAsset, error) { + return scanRequestAsset(s.pool.QueryRow(ctx, ` +INSERT INTO gateway_request_assets ( + sha256, content_type, byte_size, url, storage_provider, local_path, expires_at, expired_at, ref_count +) +VALUES ($1, $2, $3, $4, $5, NULLIF($6, ''), $7, NULL, 1) +ON CONFLICT (sha256, content_type) DO UPDATE +SET byte_size = EXCLUDED.byte_size, + url = EXCLUDED.url, + storage_provider = EXCLUDED.storage_provider, + local_path = EXCLUDED.local_path, + expires_at = EXCLUDED.expires_at, + expired_at = NULL, + ref_count = gateway_request_assets.ref_count + 1, + updated_at = now() +RETURNING id::text, sha256, content_type, byte_size, url, storage_provider, + COALESCE(local_path, ''), expires_at, expired_at, ref_count, created_at, updated_at`, + input.SHA256, + input.ContentType, + input.ByteSize, + input.URL, + input.StorageProvider, + input.LocalPath, + input.ExpiresAt, + )) +} + +func (s *Store) IncrementRequestAssetRefCount(ctx context.Context, sha256 string, contentType string) error { + _, err := s.pool.Exec(ctx, ` +UPDATE gateway_request_assets +SET ref_count = ref_count + 1, + updated_at = now() +WHERE sha256 = $1 AND content_type = $2`, sha256, contentType) + return err +} + +func (s *Store) MarkRequestAssetExpiredByLocalPath(ctx context.Context, localPath string, expiredAt time.Time) error { + if localPath == "" { + return nil + } + _, err := s.pool.Exec(ctx, ` +UPDATE gateway_request_assets +SET expired_at = COALESCE(expired_at, $2), + updated_at = now() +WHERE local_path = $1 + AND storage_provider = 'local_static' + AND expired_at IS NULL`, localPath, expiredAt) + return err +} + +func scanRequestAsset(scanner interface{ Scan(dest ...any) error }) (RequestAsset, error) { + var asset RequestAsset + var localPath string + var expiresAt sql.NullTime + var expiredAt sql.NullTime + if err := scanner.Scan( + &asset.ID, + &asset.SHA256, + &asset.ContentType, + &asset.ByteSize, + &asset.URL, + &asset.StorageProvider, + &localPath, + &expiresAt, + &expiredAt, + &asset.RefCount, + &asset.CreatedAt, + &asset.UpdatedAt, + ); err != nil { + return RequestAsset{}, err + } + asset.LocalPath = localPath + if expiresAt.Valid { + asset.ExpiresAt = &expiresAt.Time + } + if expiredAt.Valid { + asset.ExpiredAt = &expiredAt.Time + } + return asset, nil +} diff --git a/apps/api/migrations/0045_request_assets_conversation_dedupe.sql b/apps/api/migrations/0045_request_assets_conversation_dedupe.sql new file mode 100644 index 0000000..fe587ef --- /dev/null +++ b/apps/api/migrations/0045_request_assets_conversation_dedupe.sql @@ -0,0 +1,68 @@ +CREATE TABLE IF NOT EXISTS gateway_request_assets ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + sha256 text NOT NULL, + content_type text NOT NULL, + byte_size bigint NOT NULL, + url text NOT NULL, + storage_provider text NOT NULL, + local_path text, + expires_at timestamptz, + expired_at timestamptz, + ref_count integer NOT NULL DEFAULT 0, + created_at timestamptz NOT NULL DEFAULT now(), + updated_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (sha256, content_type) +); + +CREATE INDEX IF NOT EXISTS idx_gateway_request_assets_expires + ON gateway_request_assets(storage_provider, expires_at) + WHERE expires_at IS NOT NULL AND expired_at IS NULL; + +CREATE INDEX IF NOT EXISTS idx_gateway_request_assets_local_path + ON gateway_request_assets(local_path) + WHERE local_path IS NOT NULL; + +CREATE TABLE IF NOT EXISTS gateway_conversations ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id text NOT NULL, + gateway_user_id uuid REFERENCES gateway_users(id) ON DELETE SET NULL, + conversation_key text NOT NULL, + metadata jsonb NOT NULL DEFAULT '{}'::jsonb, + created_at timestamptz NOT NULL DEFAULT now(), + updated_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (user_id, conversation_key) +); + +CREATE TABLE IF NOT EXISTS gateway_conversation_messages ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + conversation_id uuid NOT NULL REFERENCES gateway_conversations(id) ON DELETE CASCADE, + message_hash text NOT NULL, + role text, + message_snapshot jsonb NOT NULL, + asset_sha256s text[] NOT NULL DEFAULT '{}', + created_at timestamptz NOT NULL DEFAULT now(), + updated_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (conversation_id, message_hash) +); + +CREATE INDEX IF NOT EXISTS idx_gateway_conversation_messages_conversation + ON gateway_conversation_messages(conversation_id, created_at); + +CREATE TABLE IF NOT EXISTS gateway_task_message_refs ( + task_id uuid NOT NULL REFERENCES gateway_tasks(id) ON DELETE CASCADE, + message_id uuid NOT NULL REFERENCES gateway_conversation_messages(id) ON DELETE CASCADE, + position integer NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), + PRIMARY KEY (task_id, position) +); + +CREATE INDEX IF NOT EXISTS idx_gateway_task_message_refs_message + ON gateway_task_message_refs(message_id); + +ALTER TABLE gateway_tasks + ADD COLUMN IF NOT EXISTS conversation_id uuid REFERENCES gateway_conversations(id) ON DELETE SET NULL, + ADD COLUMN IF NOT EXISTS new_message_count integer NOT NULL DEFAULT 0; + +CREATE INDEX IF NOT EXISTS idx_gateway_tasks_conversation_created + ON gateway_tasks(conversation_id, created_at DESC) + WHERE conversation_id IS NOT NULL; diff --git a/apps/web/src/App.tsx b/apps/web/src/App.tsx index 1d2cee3..d661428 100644 --- a/apps/web/src/App.tsx +++ b/apps/web/src/App.tsx @@ -84,6 +84,7 @@ import { pollTaskUntilSettled, registerLocalAccount, replacePlatformModels, + restoreModelRuntimeStatus, setUserWalletBalance, type HealthResponse, updateAccessRule, @@ -635,6 +636,50 @@ export function App() { } } + async function restoreRuntimeModel(platformModelId: string) { + setCoreState('loading'); + setCoreMessage(''); + try { + const restored = await restoreModelRuntimeStatus(token, platformModelId); + setModelRateLimits((current) => current.map((status) => { + if (status.platformId !== restored.platformId) return status; + return { + ...status, + platformStatus: 'enabled', + platformCooldownUntil: undefined, + platformDisabledReason: undefined, + ...(status.platformModelId === platformModelId + ? { + enabled: restored.enabled, + modelCooldownUntil: restored.modelCooldownUntil, + } + : {}), + }; + })); + setPlatforms((current) => current.map((platform) => platform.id === restored.platformId + ? { + ...platform, + status: 'enabled', + cooldownUntil: undefined, + } + : platform)); + setModels((current) => current.map((model) => model.id === platformModelId + ? { + ...model, + enabled: true, + cooldownUntil: undefined, + } + : model)); + invalidateDataKeys('modelCatalog', 'modelRateLimits', 'models', 'platforms', 'playgroundModels'); + setCoreState('ready'); + setCoreMessage('模型运行状态已恢复。'); + } catch (err) { + setCoreState('error'); + setCoreMessage(err instanceof Error ? err.message : '恢复模型运行状态失败'); + throw err; + } + } + async function removePlatform(platformId: string) { setCoreState('loading'); setCoreMessage(''); @@ -1143,6 +1188,7 @@ export function App() { onResetBaseModel={resetBaseModelToDefault} onSavePlatform={savePlatformWithModels} onSavePlatformDynamicPriority={savePlatformDynamicPriority} + onRestoreRuntimeModel={restoreRuntimeModel} onTogglePlatformStatus={savePlatformStatus} onSaveProvider={saveProvider} onSavePricingRuleSet={savePricingRuleSet} diff --git a/apps/web/src/api.ts b/apps/web/src/api.ts index c3b2c02..9386fcb 100644 --- a/apps/web/src/api.ts +++ b/apps/web/src/api.ts @@ -913,6 +913,13 @@ export async function listModelRateLimitStatuses(token: string): Promise>('/api/admin/runtime/model-rate-limits', { token }); } +export async function restoreModelRuntimeStatus(token: string, platformModelId: string): Promise { + return request(`/api/admin/runtime/model-rate-limits/${platformModelId}/restore`, { + method: 'POST', + token, + }); +} + export async function getNetworkProxyConfig(token: string): Promise { return request('/api/admin/config/network-proxy', { token }); } diff --git a/apps/web/src/pages/AdminPage.tsx b/apps/web/src/pages/AdminPage.tsx index 3df5d1b..7a24623 100644 --- a/apps/web/src/pages/AdminPage.tsx +++ b/apps/web/src/pages/AdminPage.tsx @@ -71,6 +71,7 @@ export function AdminPage(props: { onBatchAccessRules: (input: GatewayAccessRuleBatchRequest) => Promise; onSavePlatform: (input: PlatformWithModelsInput) => Promise; onSavePlatformDynamicPriority: (platformId: string, input: PlatformDynamicPriorityUpdateRequest) => Promise; + onRestoreRuntimeModel: (platformModelId: string) => Promise; onTogglePlatformStatus: (platform: IntegrationPlatform, status: 'enabled' | 'disabled') => Promise; onSaveProvider: (input: CatalogProviderUpsertRequest, providerId?: string) => Promise; onSavePricingRuleSet: (input: PricingRuleSetUpsertRequest, ruleSetId?: string) => Promise; @@ -173,6 +174,7 @@ export function AdminPage(props: { modelRateLimitsUpdatedAt={props.data.modelRateLimitsUpdatedAt} platforms={props.data.platforms} onSavePlatformDynamicPriority={props.onSavePlatformDynamicPriority} + onRestoreRuntimeModel={props.onRestoreRuntimeModel} /> )} {props.section === 'tenants' && } diff --git a/apps/web/src/pages/admin/RealtimeLoadPanel.tsx b/apps/web/src/pages/admin/RealtimeLoadPanel.tsx index 449a587..fd25a33 100644 --- a/apps/web/src/pages/admin/RealtimeLoadPanel.tsx +++ b/apps/web/src/pages/admin/RealtimeLoadPanel.tsx @@ -9,11 +9,13 @@ export function RealtimeLoadPanel(props: { modelRateLimitsUpdatedAt: number | null; platforms: IntegrationPlatform[]; onSavePlatformDynamicPriority: (platformId: string, input: PlatformDynamicPriorityUpdateRequest) => Promise; + onRestoreRuntimeModel: (platformModelId: string) => Promise; }) { const [now, setNow] = useState(() => Date.now()); const [priorityDialog, setPriorityDialog] = useState(null); const [priorityError, setPriorityError] = useState(''); const [prioritySaving, setPrioritySaving] = useState(false); + const [restoreSavingId, setRestoreSavingId] = useState(null); const platformMap = useMemo(() => new Map(props.platforms.map((item) => [item.id, item])), [props.platforms]); useEffect(() => { @@ -65,6 +67,17 @@ export function RealtimeLoadPanel(props: { } } + async function restoreRuntimeModel(platformModelId: string) { + setRestoreSavingId(platformModelId); + try { + await props.onRestoreRuntimeModel(platformModelId); + } catch { + // App-level state owns the error message. + } finally { + setRestoreSavingId(null); + } + } + return (
@@ -81,6 +94,8 @@ export function RealtimeLoadPanel(props: { statuses={props.modelRateLimits} updatedAt={props.modelRateLimitsUpdatedAt} onAdjustPriority={openPriorityDialog} + onRestoreRuntimeModel={restoreRuntimeModel} + restoreSavingId={restoreSavingId} /> @@ -109,6 +124,8 @@ function RateLimitStatusTable(props: { now: number; updatedAt: number | null; onAdjustPriority: (status: ModelRateLimitStatus, platform: IntegrationPlatform | undefined) => void; + onRestoreRuntimeModel: (platformModelId: string) => Promise; + restoreSavingId: string | null; }) { if (!props.statuses.length) { return ; @@ -150,7 +167,15 @@ function RateLimitStatusTable(props: { {status.provider} - {modelRuntimeStatusCell(status, platform, props.now)} + + {modelRuntimeStatusCell( + status, + platform, + props.now, + props.onRestoreRuntimeModel, + props.restoreSavingId === status.platformModelId, + )} + {platformPriorityCell(status, platform, props.onAdjustPriority)} 0.8 ? 'true' : undefined}> @@ -447,50 +472,104 @@ function shortId(value: string | undefined) { return value.length > 8 ? value.slice(0, 8) : value; } -function modelRuntimeStatusCell(status: ModelRateLimitStatus, platform: IntegrationPlatform | undefined, now: number) { +function modelRuntimeStatusCell( + status: ModelRateLimitStatus, + platform: IntegrationPlatform | undefined, + now: number, + onRestore: (platformModelId: string) => Promise, + restoring: boolean, +) { const modelCooldownMs = cooldownRemainingMs(status.modelCooldownUntil, now); const platformCooldownMs = cooldownRemainingMs(status.platformCooldownUntil, now); const platformStatus = platform?.status || status.platformStatus || 'enabled'; + const restoreButton = runtimeRestoreButton(status, platformStatus, modelCooldownMs, platformCooldownMs, onRestore, restoring); if (modelCooldownMs > 0) { return ( - - 模型冷却中 - 剩余 {formatCooldownRemaining(modelCooldownMs)} + + + 模型冷却中 + 剩余 {formatCooldownRemaining(modelCooldownMs)} + + {restoreButton} ); } if (platformStatus !== 'enabled') { const badge = 已禁用; return ( - } - overlayClassName="priorityDemotionAntPopover" - placement="bottomLeft" - trigger={['hover', 'focus']} - > - - {badge} + + } + overlayClassName="priorityDemotionAntPopover" + placement="bottomLeft" + trigger={['hover', 'focus']} + > + + {badge} + + + {restoreButton} + + ); + } + if (!status.enabled) { + return ( + + + 已停用 + 不参与路由 - + {restoreButton} + ); } if (platformCooldownMs > 0) { return ( - - 平台冷却中 - 剩余 {formatCooldownRemaining(platformCooldownMs)} + + + 平台冷却中 + 剩余 {formatCooldownRemaining(platformCooldownMs)} + + {restoreButton} ); } return ( - - {status.enabled ? '可用' : '已停用'} - {status.enabled ? '参与路由' : '不参与路由'} + + + 可用 + 参与路由 + ); } +function runtimeRestoreButton( + status: ModelRateLimitStatus, + platformStatus: string, + modelCooldownMs: number, + platformCooldownMs: number, + onRestore: (platformModelId: string) => Promise, + restoring: boolean, +) { + const canRestore = modelCooldownMs > 0 || platformCooldownMs > 0 || platformStatus !== 'enabled' || !status.enabled; + if (!canRestore) return null; + return ( + + ); +} + function cooldownRemainingMs(cooldownUntil: string | undefined, now: number) { if (!cooldownUntil) return 0; const until = Date.parse(cooldownUntil); diff --git a/apps/web/src/styles/pages.css b/apps/web/src/styles/pages.css index 79e9356..5c80749 100644 --- a/apps/web/src/styles/pages.css +++ b/apps/web/src/styles/pages.css @@ -1086,8 +1086,8 @@ } .platformLimitTable .shTableRow { - grid-template-columns: minmax(180px, 1.1fr) minmax(160px, 0.9fr) 160px 132px 150px 170px 140px 132px; - min-width: 1224px; + grid-template-columns: minmax(180px, 1.1fr) minmax(160px, 0.9fr) 178px 132px 150px 170px 140px 132px; + min-width: 1242px; } .platformLimitTable .shTableHead, @@ -1130,6 +1130,22 @@ justify-items: start; } +.platformRuntimeStatusCell { + display: grid; + min-width: 0; + gap: 7px; + align-content: start; +} + +.platformRestoreButton { + justify-self: start; + min-height: 22px; + padding-inline: 8px; + border-color: var(--border); + color: var(--text-normal); + background: #fff; +} + .rateMetricCell, .rateLoadCell { display: grid;