easyai-ai-gateway/apps/api/internal/httpapi/handlers.go

1058 lines
32 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package httpapi
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"time"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/netproxy"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
// health answers liveness probes.
//
// It always responds 200 with the service name plus the configured
// environment and identity mode; no dependency is contacted.
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
	payload := map[string]any{
		"ok":           true,
		"service":      "easyai-ai-gateway",
		"env":          s.cfg.AppEnv,
		"identityMode": s.cfg.IdentityMode,
	}
	writeJSON(w, http.StatusOK, payload)
}
// ready answers readiness probes by pinging the store.
//
// A failed ping yields 503 "postgres unavailable"; otherwise the handler
// responds 200 with {"ok": true}.
func (s *Server) ready(w http.ResponseWriter, r *http.Request) {
	pingErr := s.store.Ping(r.Context())
	if pingErr == nil {
		writeJSON(w, http.StatusOK, map[string]any{"ok": true})
		return
	}
	writeError(w, http.StatusServiceUnavailable, "postgres unavailable")
}
// me returns the user attached to the request context as JSON.
//
// The presence flag from auth.UserFromContext is deliberately ignored here,
// so whatever value it returns (including an empty one) is serialized as-is.
func (s *Server) me(w http.ResponseWriter, r *http.Request) {
	current, _ := auth.UserFromContext(r.Context())
	writeJSON(w, http.StatusOK, current)
}
// register creates a local user from the JSON request body and, on success,
// responds 201 with a freshly signed access token.
//
// Responses:
//   - 403 when local identity is disabled by configuration
//   - 400 for an undecodable body, a weak password or an invalid invitation
//   - 409 when the user already exists
//   - 500 (logged) for any other store failure
func (s *Server) register(w http.ResponseWriter, r *http.Request) {
	if !s.localIdentityEnabled() {
		writeError(w, http.StatusForbidden, "local registration is disabled")
		return
	}

	var payload store.LocalRegisterInput
	if decodeErr := json.NewDecoder(r.Body).Decode(&payload); decodeErr != nil {
		writeError(w, http.StatusBadRequest, "invalid json body")
		return
	}

	created, err := s.store.RegisterLocalUser(r.Context(), payload)
	switch {
	case err == nil:
		s.writeAuthResponse(w, http.StatusCreated, created)
	case errors.Is(err, store.ErrWeakPassword), errors.Is(err, store.ErrInvalidInvitation):
		// Validation failures carry a message that is safe to show the caller.
		writeError(w, http.StatusBadRequest, err.Error())
	case errors.Is(err, store.ErrUserAlreadyExists):
		writeError(w, http.StatusConflict, err.Error())
	default:
		s.logger.Error("register local user failed", "error", err)
		writeError(w, http.StatusInternalServerError, "register local user failed")
	}
}
// login authenticates a local user from the JSON request body and responds
// 200 with a signed access token.
//
// Responses:
//   - 403 when local identity is disabled by configuration
//   - 400 for an undecodable body
//   - 401 when the store reports invalid credentials
//   - 500 (logged) for any other store failure
func (s *Server) login(w http.ResponseWriter, r *http.Request) {
	if !s.localIdentityEnabled() {
		writeError(w, http.StatusForbidden, "local login is disabled")
		return
	}

	var credentials store.LocalLoginInput
	if decodeErr := json.NewDecoder(r.Body).Decode(&credentials); decodeErr != nil {
		writeError(w, http.StatusBadRequest, "invalid json body")
		return
	}

	account, err := s.store.AuthenticateLocalUser(r.Context(), credentials)
	switch {
	case err == nil:
		s.writeAuthResponse(w, http.StatusOK, account)
	case errors.Is(err, store.ErrInvalidCredentials):
		// A fixed message is used so the response does not echo store details.
		writeError(w, http.StatusUnauthorized, "invalid account or password")
	default:
		s.logger.Error("login local user failed", "error", err)
		writeError(w, http.StatusInternalServerError, "login failed")
	}
}
// localIdentityEnabled reports whether local registration and login are
// allowed. The configured identity mode is compared case-insensitively after
// trimming; an empty mode, "standalone" and "hybrid" enable local identity.
func (s *Server) localIdentityEnabled() bool {
	switch strings.ToLower(strings.TrimSpace(s.cfg.IdentityMode)) {
	case "", "standalone", "hybrid":
		return true
	default:
		return false
	}
}
// writeAuthResponse signs a 24-hour JWT for user and writes the token
// envelope (accessToken, tokenType, expiresIn in seconds, user) with the
// given status. A signing failure is logged and answered with 500.
func (s *Server) writeAuthResponse(w http.ResponseWriter, status int, user store.GatewayUser) {
	const tokenLifetime = 24 * time.Hour

	principal := authUserFromGatewayUser(user)
	signed, err := s.auth.SignJWT(principal, tokenLifetime)
	if err != nil {
		s.logger.Error("sign local jwt failed", "error", err)
		writeError(w, http.StatusInternalServerError, "token sign failed")
		return
	}

	envelope := map[string]any{
		"accessToken": signed,
		"tokenType":   "Bearer",
		"expiresIn":   int(tokenLifetime.Seconds()),
		"user":        principal,
	}
	writeJSON(w, status, envelope)
}
// authUserFromGatewayUser converts a stored gateway user into the auth.User
// embedded in tokens and responses.
//
// Defaults applied: an empty role list becomes ["user"], and an empty
// TenantID falls back to the user's TenantKey. Source is always "gateway"
// and GatewayUserID mirrors the user's ID.
func authUserFromGatewayUser(user store.GatewayUser) *auth.User {
	converted := &auth.User{
		ID:              user.ID,
		Username:        user.Username,
		Roles:           user.Roles,
		TenantID:        user.TenantID,
		GatewayTenantID: user.GatewayTenantID,
		TenantKey:       user.TenantKey,
		Source:          "gateway",
		GatewayUserID:   user.ID,
		UserGroupID:     user.DefaultUserGroupID,
	}
	if len(converted.Roles) == 0 {
		converted.Roles = []string{"user"}
	}
	if converted.TenantID == "" {
		converted.TenantID = user.TenantKey
	}
	return converted
}
// listPlatforms responds with every platform known to the store, wrapped in
// an {"items": ...} envelope. Store failures are logged and answered with 500.
func (s *Server) listPlatforms(w http.ResponseWriter, r *http.Request) {
	all, err := s.store.ListPlatforms(r.Context())
	if err == nil {
		writeJSON(w, http.StatusOK, map[string]any{"items": all})
		return
	}
	s.logger.Error("list platforms failed", "error", err)
	writeError(w, http.StatusInternalServerError, "list platforms failed")
}
// listPlayablePlatforms responds with the platforms the current user can
// actually use: those whose status is "enabled" and that own at least one
// model returned by ListAccessiblePlatformModels for that user.
//
// Either store failure is logged and answered with 500
// "list playable platforms failed".
func (s *Server) listPlayablePlatforms(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	user, _ := auth.UserFromContext(ctx)

	accessible, err := s.store.ListAccessiblePlatformModels(ctx, user)
	if err != nil {
		s.logger.Error("list playable platform models failed", "error", err)
		writeError(w, http.StatusInternalServerError, "list playable platforms failed")
		return
	}

	// Collect the IDs of platforms that expose at least one accessible model.
	reachable := make(map[string]struct{}, len(accessible))
	for _, m := range accessible {
		reachable[m.PlatformID] = struct{}{}
	}

	platforms, err := s.store.ListPlatforms(ctx)
	if err != nil {
		s.logger.Error("list platforms failed", "error", err)
		writeError(w, http.StatusInternalServerError, "list playable platforms failed")
		return
	}

	// Filter in place: the kept entries reuse the slice returned by the store.
	kept := platforms[:0]
	for _, p := range platforms {
		if p.Status != "enabled" {
			continue
		}
		if _, ok := reachable[p.ID]; !ok {
			continue
		}
		kept = append(kept, p)
	}
	writeJSON(w, http.StatusOK, map[string]any{"items": kept})
}
// createPlatform creates a platform from the JSON request body.
//
// Provider, Name, InternalName and Status are trimmed before validation.
// Provider and Name are required; Status, when given, must be "enabled" or
// "disabled"; an empty AuthType defaults to "bearer". The platform config is
// normalized through netproxy.NormalizePlatformConfig before it is stored.
//
// Responses:
//   - 201 with the created platform
//   - 400 for an undecodable body, failed validation or an invalid config
//   - 409 when the store reports a unique-constraint violation (same mapping
//     as updatePlatform, instead of surfacing a generic 500)
//   - 500 (logged) for any other store failure
func (s *Server) createPlatform(w http.ResponseWriter, r *http.Request) {
	var input store.CreatePlatformInput
	if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
		writeError(w, http.StatusBadRequest, "invalid json body")
		return
	}
	input.Provider = strings.TrimSpace(input.Provider)
	input.Name = strings.TrimSpace(input.Name)
	input.InternalName = strings.TrimSpace(input.InternalName)
	input.Status = strings.TrimSpace(input.Status)
	if input.Provider == "" || input.Name == "" {
		writeError(w, http.StatusBadRequest, "provider and name are required")
		return
	}
	if input.Status != "" && input.Status != "enabled" && input.Status != "disabled" {
		writeError(w, http.StatusBadRequest, "status must be enabled or disabled")
		return
	}
	if input.AuthType == "" {
		input.AuthType = "bearer"
	}
	config, err := netproxy.NormalizePlatformConfig(input.Config)
	if err != nil {
		writeError(w, http.StatusBadRequest, err.Error())
		return
	}
	input.Config = config
	platform, err := s.store.CreatePlatform(r.Context(), input)
	if err != nil {
		// A duplicate key is a client-side conflict, not a server fault.
		if store.IsUniqueViolation(err) {
			writeError(w, http.StatusConflict, "platform key already exists")
			return
		}
		s.logger.Error("create platform failed", "error", err)
		writeError(w, http.StatusInternalServerError, "create platform failed")
		return
	}
	writeJSON(w, http.StatusCreated, platform)
}
// updatePlatform replaces the platform identified by the "platformID" path
// value with the JSON request body.
//
// The same normalization and validation as on creation is applied: string
// fields are trimmed, Provider and Name are required, Status must be empty,
// "enabled" or "disabled", AuthType defaults to "bearer", and the config is
// normalized through netproxy.NormalizePlatformConfig.
//
// Responses: 200 with the updated platform, 400 on invalid input, 404 when
// the platform does not exist, 409 on a unique-constraint violation, and
// 500 (logged) for any other store failure.
func (s *Server) updatePlatform(w http.ResponseWriter, r *http.Request) {
	var patch store.CreatePlatformInput
	if decodeErr := json.NewDecoder(r.Body).Decode(&patch); decodeErr != nil {
		writeError(w, http.StatusBadRequest, "invalid json body")
		return
	}

	patch.Provider = strings.TrimSpace(patch.Provider)
	patch.Name = strings.TrimSpace(patch.Name)
	patch.InternalName = strings.TrimSpace(patch.InternalName)
	patch.Status = strings.TrimSpace(patch.Status)

	if patch.Provider == "" || patch.Name == "" {
		writeError(w, http.StatusBadRequest, "provider and name are required")
		return
	}
	switch patch.Status {
	case "", "enabled", "disabled":
		// accepted
	default:
		writeError(w, http.StatusBadRequest, "status must be enabled or disabled")
		return
	}
	if patch.AuthType == "" {
		patch.AuthType = "bearer"
	}

	normalized, err := netproxy.NormalizePlatformConfig(patch.Config)
	if err != nil {
		writeError(w, http.StatusBadRequest, err.Error())
		return
	}
	patch.Config = normalized

	updated, err := s.store.UpdatePlatform(r.Context(), r.PathValue("platformID"), patch)
	switch {
	case err == nil:
		writeJSON(w, http.StatusOK, updated)
	case store.IsNotFound(err):
		writeError(w, http.StatusNotFound, "platform not found")
	case store.IsUniqueViolation(err):
		writeError(w, http.StatusConflict, "platform key already exists")
	default:
		s.logger.Error("update platform failed", "error", err)
		writeError(w, http.StatusInternalServerError, "update platform failed")
	}
}
// deletePlatform removes the platform named by the "platformID" path value.
// It responds 204 on success, 404 when the store reports the platform as
// missing, and 500 (logged) for any other failure.
func (s *Server) deletePlatform(w http.ResponseWriter, r *http.Request) {
	err := s.store.DeletePlatform(r.Context(), r.PathValue("platformID"))
	switch {
	case err == nil:
		w.WriteHeader(http.StatusNoContent)
	case store.IsNotFound(err):
		writeError(w, http.StatusNotFound, "platform not found")
	default:
		s.logger.Error("delete platform failed", "error", err)
		writeError(w, http.StatusInternalServerError, "delete platform failed")
	}
}
// createPlatformModel attaches a model to a platform.
//
// The platform ID comes from the "platformID" path value when present and
// otherwise from the JSON body; a request with neither is rejected with 400.
// A not-found error from the store is reported as 404 "base model not found";
// other failures are logged and answered with 500.
func (s *Server) createPlatformModel(w http.ResponseWriter, r *http.Request) {
	var spec store.CreatePlatformModelInput
	if decodeErr := json.NewDecoder(r.Body).Decode(&spec); decodeErr != nil {
		writeError(w, http.StatusBadRequest, "invalid json body")
		return
	}

	// The URL takes precedence over whatever the body claims.
	if fromPath := r.PathValue("platformID"); fromPath != "" {
		spec.PlatformID = fromPath
	}
	if spec.PlatformID == "" {
		writeError(w, http.StatusBadRequest, "platformId is required")
		return
	}

	created, err := s.store.CreatePlatformModel(r.Context(), spec)
	switch {
	case err == nil:
		writeJSON(w, http.StatusCreated, s.platformModelResponse(r.Context(), created))
	case store.IsNotFound(err):
		writeError(w, http.StatusNotFound, "base model not found")
	default:
		s.logger.Error("create platform model failed", "error", err)
		writeError(w, http.StatusInternalServerError, "create platform model failed")
	}
}
// replacePlatformModels swaps the full model list of the platform named by
// the "platformID" path value for the "models" array in the JSON body.
//
// Responses: 200 with the resulting models under "items", 400 for a missing
// platform ID or undecodable body, 404 "base model not found" when the store
// reports a not-found error, and 500 (logged) otherwise.
func (s *Server) replacePlatformModels(w http.ResponseWriter, r *http.Request) {
	id := r.PathValue("platformID")
	if id == "" {
		writeError(w, http.StatusBadRequest, "platformId is required")
		return
	}

	var payload struct {
		Models []store.CreatePlatformModelInput `json:"models"`
	}
	if decodeErr := json.NewDecoder(r.Body).Decode(&payload); decodeErr != nil {
		writeError(w, http.StatusBadRequest, "invalid json body")
		return
	}

	replaced, err := s.store.ReplacePlatformModels(r.Context(), id, payload.Models)
	switch {
	case err == nil:
		writeJSON(w, http.StatusOK, map[string]any{"items": s.platformModelResponses(r.Context(), replaced)})
	case store.IsNotFound(err):
		writeError(w, http.StatusNotFound, "base model not found")
	default:
		s.logger.Error("replace platform models failed", "error", err)
		writeError(w, http.StatusInternalServerError, "replace platform models failed")
	}
}
// deletePlatformModel removes the platform model named by the "modelID" path
// value. It responds 204 on success, 404 when the model is missing, and
// 500 (logged) for any other store failure.
func (s *Server) deletePlatformModel(w http.ResponseWriter, r *http.Request) {
	err := s.store.DeletePlatformModel(r.Context(), r.PathValue("modelID"))
	switch {
	case err == nil:
		w.WriteHeader(http.StatusNoContent)
	case store.IsNotFound(err):
		writeError(w, http.StatusNotFound, "platform model not found")
	default:
		s.logger.Error("delete platform model failed", "error", err)
		writeError(w, http.StatusInternalServerError, "delete platform model failed")
	}
}
// listModels responds with every model in the store, converted through
// platformModelResponses and wrapped in an {"items": ...} envelope.
// Store failures are logged and answered with 500.
func (s *Server) listModels(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	all, err := s.store.ListModels(ctx)
	if err == nil {
		writeJSON(w, http.StatusOK, map[string]any{"items": s.platformModelResponses(ctx, all)})
		return
	}
	s.logger.Error("list models failed", "error", err)
	writeError(w, http.StatusInternalServerError, "list models failed")
}
// listPlayableModels responds with the platform models the store considers
// accessible to the current user, converted through platformModelResponses.
// Store failures are logged and answered with 500.
func (s *Server) listPlayableModels(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	caller, _ := auth.UserFromContext(ctx)

	accessible, err := s.store.ListAccessiblePlatformModels(ctx, caller)
	if err == nil {
		writeJSON(w, http.StatusOK, map[string]any{"items": s.platformModelResponses(ctx, accessible)})
		return
	}
	s.logger.Error("list playable models failed", "error", err)
	writeError(w, http.StatusInternalServerError, "list playable models failed")
}
// listPricingRules responds with all pricing rules from the store under
// "items". Store failures are logged and answered with 500.
func (s *Server) listPricingRules(w http.ResponseWriter, r *http.Request) {
	rules, err := s.store.ListPricingRules(r.Context())
	if err == nil {
		writeJSON(w, http.StatusOK, map[string]any{"items": rules})
		return
	}
	s.logger.Error("list pricing rules failed", "error", err)
	writeError(w, http.StatusInternalServerError, "list pricing rules failed")
}
// listTenants responds with all tenants from the store under "items".
// Store failures are logged and answered with 500.
func (s *Server) listTenants(w http.ResponseWriter, r *http.Request) {
	tenants, err := s.store.ListTenants(r.Context())
	if err == nil {
		writeJSON(w, http.StatusOK, map[string]any{"items": tenants})
		return
	}
	s.logger.Error("list tenants failed", "error", err)
	writeError(w, http.StatusInternalServerError, "list tenants failed")
}
// listUsers responds with all users from the store under "items".
// Store failures are logged and answered with 500.
func (s *Server) listUsers(w http.ResponseWriter, r *http.Request) {
	users, err := s.store.ListUsers(r.Context())
	if err == nil {
		writeJSON(w, http.StatusOK, map[string]any{"items": users})
		return
	}
	s.logger.Error("list users failed", "error", err)
	writeError(w, http.StatusInternalServerError, "list users failed")
}
// listUserGroups responds with all user groups from the store under "items".
// Store failures are logged and answered with 500.
func (s *Server) listUserGroups(w http.ResponseWriter, r *http.Request) {
	groups, err := s.store.ListUserGroups(r.Context())
	if err == nil {
		writeJSON(w, http.StatusOK, map[string]any{"items": groups})
		return
	}
	s.logger.Error("list user groups failed", "error", err)
	writeError(w, http.StatusInternalServerError, "list user groups failed")
}
// listAPIKeys responds with the API keys the store returns for the current
// user, under "items". Store failures are logged and answered with 500.
func (s *Server) listAPIKeys(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	caller, _ := auth.UserFromContext(ctx)

	keys, err := s.store.ListAPIKeys(ctx, caller)
	if err == nil {
		writeJSON(w, http.StatusOK, map[string]any{"items": keys})
		return
	}
	s.logger.Error("list api keys failed", "error", err)
	writeError(w, http.StatusInternalServerError, "list api keys failed")
}
// listPlayableAPIKeys responds with the API keys returned by
// ListPlayableAPIKeys for the current user, under "items".
//
// store.ErrLocalUserRequired is reported as 400 with the error's own message;
// any other failure is logged and answered with 500.
func (s *Server) listPlayableAPIKeys(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	caller, _ := auth.UserFromContext(ctx)

	keys, err := s.store.ListPlayableAPIKeys(ctx, caller)
	switch {
	case err == nil:
		writeJSON(w, http.StatusOK, map[string]any{"items": keys})
	case errors.Is(err, store.ErrLocalUserRequired):
		writeError(w, http.StatusBadRequest, err.Error())
	default:
		s.logger.Error("list playable api keys failed", "error", err)
		writeError(w, http.StatusInternalServerError, "list playable api keys failed")
	}
}
// createAPIKey creates an API key for the current user from the JSON body
// and responds 201 with whatever the store returns for the new key.
//
// Responses: 400 for an undecodable body or store.ErrLocalUserRequired,
// 500 (logged) for any other store failure.
func (s *Server) createAPIKey(w http.ResponseWriter, r *http.Request) {
	caller, _ := auth.UserFromContext(r.Context())

	var spec store.CreateAPIKeyInput
	if decodeErr := json.NewDecoder(r.Body).Decode(&spec); decodeErr != nil {
		writeError(w, http.StatusBadRequest, "invalid json body")
		return
	}

	key, err := s.store.CreateAPIKey(r.Context(), spec, caller)
	switch {
	case err == nil:
		writeJSON(w, http.StatusCreated, key)
	case errors.Is(err, store.ErrLocalUserRequired):
		writeError(w, http.StatusBadRequest, err.Error())
	default:
		s.logger.Error("create api key failed", "error", err)
		writeError(w, http.StatusInternalServerError, "create api key failed")
	}
}
// disableAPIKey disables the API key named by the "apiKeyID" path value on
// behalf of the current user and responds 200 with the store's result.
//
// Responses: 400 for store.ErrLocalUserRequired, 404 when the key is not
// found, 500 (logged) for any other store failure.
func (s *Server) disableAPIKey(w http.ResponseWriter, r *http.Request) {
	caller, _ := auth.UserFromContext(r.Context())

	disabled, err := s.store.DisableAPIKey(r.Context(), r.PathValue("apiKeyID"), caller)
	switch {
	case err == nil:
		writeJSON(w, http.StatusOK, disabled)
	case errors.Is(err, store.ErrLocalUserRequired):
		writeError(w, http.StatusBadRequest, err.Error())
	case store.IsNotFound(err):
		writeError(w, http.StatusNotFound, "api key not found")
	default:
		s.logger.Error("disable api key failed", "error", err)
		writeError(w, http.StatusInternalServerError, "disable api key failed")
	}
}
// deleteAPIKey deletes the API key named by the "apiKeyID" path value on
// behalf of the current user.
//
// Responses: 204 on success, 400 for store.ErrLocalUserRequired, 404 when
// the key is not found, 500 (logged) for any other store failure.
func (s *Server) deleteAPIKey(w http.ResponseWriter, r *http.Request) {
	caller, _ := auth.UserFromContext(r.Context())

	err := s.store.DeleteAPIKey(r.Context(), r.PathValue("apiKeyID"), caller)
	switch {
	case err == nil:
		w.WriteHeader(http.StatusNoContent)
	case errors.Is(err, store.ErrLocalUserRequired):
		writeError(w, http.StatusBadRequest, err.Error())
	case store.IsNotFound(err):
		writeError(w, http.StatusNotFound, "api key not found")
	default:
		s.logger.Error("delete api key failed", "error", err)
		writeError(w, http.StatusInternalServerError, "delete api key failed")
	}
}
// estimatePricing asks the runner for a price estimate of the request
// described by the JSON body.
//
// The body must carry a non-empty string "model"; "kind" is optional and
// defaults to "chat.completions". The whole decoded body is forwarded to the
// runner unchanged.
//
// Responses: 200 with the estimate; 400 for an undecodable body or missing
// model; 403 when the caller's API key scopes do not cover the kind; for
// store.ErrNoModelCandidate the status from statusFromRunError plus the
// candidate error code; 500 (logged) otherwise.
func (s *Server) estimatePricing(w http.ResponseWriter, r *http.Request) {
	caller, _ := auth.UserFromContext(r.Context())

	var payload map[string]any
	if decodeErr := json.NewDecoder(r.Body).Decode(&payload); decodeErr != nil {
		writeError(w, http.StatusBadRequest, "invalid json body")
		return
	}

	modelName, _ := payload["model"].(string)
	if modelName == "" {
		writeError(w, http.StatusBadRequest, "model is required")
		return
	}
	taskKind, _ := payload["kind"].(string)
	if taskKind == "" {
		taskKind = "chat.completions"
	}
	if !apiKeyScopeAllowed(caller, taskKind) {
		writeError(w, http.StatusForbidden, "api key scope does not allow this capability")
		return
	}

	quote, err := s.runner.Estimate(r.Context(), taskKind, modelName, payload, caller)
	switch {
	case err == nil:
		writeJSON(w, http.StatusOK, quote)
	case errors.Is(err, store.ErrNoModelCandidate):
		writeError(w, statusFromRunError(err), err.Error(), store.ModelCandidateErrorCode(err))
	default:
		s.logger.Error("estimate pricing failed", "error", err)
		writeError(w, http.StatusInternalServerError, "estimate pricing failed")
	}
}
// listRateLimitWindows responds with the rate-limit windows held by the
// store, under "items". Store failures are logged and answered with 500.
func (s *Server) listRateLimitWindows(w http.ResponseWriter, r *http.Request) {
	windows, err := s.store.ListRateLimitWindows(r.Context())
	if err == nil {
		writeJSON(w, http.StatusOK, map[string]any{"items": windows})
		return
	}
	s.logger.Error("list rate limit windows failed", "error", err)
	writeError(w, http.StatusInternalServerError, "list rate limit windows failed")
}
// listModelRateLimitStatuses responds with the per-model rate-limit statuses
// reported by the store, under "items". Store failures are logged and
// answered with 500.
func (s *Server) listModelRateLimitStatuses(w http.ResponseWriter, r *http.Request) {
	statuses, err := s.store.ListModelRateLimitStatuses(r.Context())
	if err == nil {
		writeJSON(w, http.StatusOK, map[string]any{"items": statuses})
		return
	}
	s.logger.Error("list model rate limit statuses failed", "error", err)
	writeError(w, http.StatusInternalServerError, "list model rate limit statuses failed")
}
// createTask builds the handler that accepts a task of the given kind.
//
// Request handling, in order:
//  1. The caller must be present in the request context (401 otherwise).
//  2. The JSON body must decode and carry a non-empty string "model" (400).
//  3. The caller's API key scopes must cover kind (403).
//  4. The task is persisted through the store (500, logged, on failure).
//
// What happens next depends on the mode:
//   - Async (see asyncRequest): the task is enqueued and 202 is returned.
//     An enqueue failure is logged and answered with 500 "enqueue_failed".
//   - compatible && body "stream" is true: deltas are streamed to the client;
//     a run error is sent as an SSE "error" event.
//   - compatible, non-streaming: the runner output is returned with 200, or
//     the run error is mapped through statusFromRunError.
//   - otherwise: the task is executed and the resulting task is returned
//     with 202 even when the run failed (the failure is only logged).
//
// Execution uses requestExecutionContext, so a client disconnect does not
// cancel the run; once the client is gone the handler simply stops writing.
func (s *Server) createTask(kind string, compatible bool) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		user, ok := auth.UserFromContext(r.Context())
		if !ok {
			writeError(w, http.StatusUnauthorized, "unauthorized")
			return
		}
		var body map[string]any
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			writeError(w, http.StatusBadRequest, "invalid json body")
			return
		}
		model, _ := body["model"].(string)
		if model == "" {
			writeError(w, http.StatusBadRequest, "model is required")
			return
		}
		if !apiKeyScopeAllowed(user, kind) {
			writeError(w, http.StatusForbidden, "api key scope does not allow this capability")
			return
		}
		asyncMode := asyncRequest(r)
		task, err := s.store.CreateTask(r.Context(), store.CreateTaskInput{
			Kind:    kind,
			Model:   model,
			RunMode: runModeFromRequest(body),
			Async:   asyncMode,
			Request: body,
		}, user)
		if err != nil {
			s.logger.Error("create task failed", "kind", kind, "error", err)
			writeError(w, http.StatusInternalServerError, "create task failed")
			return
		}
		if asyncMode {
			if err := s.runner.EnqueueAsyncTask(r.Context(), task); err != nil {
				// Log like every other 500 path so the failure is traceable
				// to a task; the response itself is unchanged.
				s.logger.Error("enqueue async task failed", "kind", kind, "taskId", task.ID, "error", err)
				writeError(w, http.StatusInternalServerError, err.Error(), "enqueue_failed")
				return
			}
			writeTaskAccepted(w, task)
			return
		}
		// From here on the run is detached from the client connection.
		runCtx, cancelRun := s.requestExecutionContext(r)
		defer cancelRun()
		if compatible {
			if boolValue(body, "stream") {
				flusher := prepareCompatibleStream(w)
				result, runErr := s.runner.ExecuteStream(runCtx, task, user, func(delta string) error {
					// A vanished client must not abort the run: swallow the
					// delta and keep going.
					if !requestStillConnected(r) {
						return nil
					}
					writeCompatibleDelta(w, kind, model, delta)
					if flusher != nil {
						flusher.Flush()
					}
					return nil
				})
				if runErr != nil {
					if !requestStillConnected(r) {
						return
					}
					status := statusFromRunError(runErr)
					errorPayload := map[string]any{
						"code":    runErrorCode(runErr),
						"message": runErrorMessage(runErr),
						"status":  status,
					}
					if result.Task.ID != "" {
						errorPayload["taskId"] = result.Task.ID
					}
					if result.Task.RequestID != "" {
						errorPayload["requestId"] = result.Task.RequestID
					}
					for key, value := range runErrorDetails(runErr) {
						errorPayload[key] = value
					}
					// The stream has already started, so the failure travels
					// as an SSE event rather than an HTTP status.
					sendSSE(w, "error", map[string]any{"error": errorPayload})
					if flusher != nil {
						flusher.Flush()
					}
					return
				}
				if !requestStillConnected(r) {
					return
				}
				writeCompatibleDone(w, kind, model, result.Output)
				if flusher != nil {
					flusher.Flush()
				}
				return
			}
			result, runErr := s.runner.Execute(runCtx, task, user)
			if runErr != nil {
				if !requestStillConnected(r) {
					return
				}
				writeErrorWithDetails(w, statusFromRunError(runErr), runErrorMessage(runErr), runErrorDetails(runErr), runErrorCode(runErr))
				return
			}
			if !requestStillConnected(r) {
				return
			}
			writeJSON(w, http.StatusOK, result.Output)
			return
		}
		result, runErr := s.runner.Execute(runCtx, task, user)
		if runErr != nil {
			// Native mode reports failures through the task record itself.
			s.logger.Warn("task completed with failure", "kind", kind, "taskId", task.ID, "error", runErr)
		}
		if !requestStillConnected(r) {
			return
		}
		writeTaskAccepted(w, result.Task)
	})
}
// requestExecutionContext returns the context a task run should use for the
// given request.
//
// The returned context keeps the request's values but is detached from its
// cancellation (context.WithoutCancel), so a client disconnect does not abort
// the run. When the server has a lifecycle context (s.ctx), the returned
// context is cancelled as soon as that context is done.
//
// The caller must invoke the returned CancelFunc when the run is finished; it
// releases the context and unregisters the shutdown hook.
func (s *Server) requestExecutionContext(r *http.Request) (context.Context, context.CancelFunc) {
	base := context.WithoutCancel(r.Context())
	if s.ctx == nil {
		return base, func() {}
	}
	ctx, cancel := context.WithCancel(base)
	// AfterFunc forwards server shutdown without parking a goroutine for the
	// whole lifetime of every in-flight request.
	stop := context.AfterFunc(s.ctx, cancel)
	return ctx, func() {
		stop()
		cancel()
	}
}
// requestStillConnected reports whether the request's context is still live,
// i.e. it has not been cancelled or timed out yet. It never blocks.
func requestStillConnected(r *http.Request) bool {
	return r.Context().Err() == nil
}
// asyncRequest reports whether the caller asked for asynchronous execution
// via the "x-async" header. The value is compared case-insensitively after
// trimming; "1", "true", "yes" and "on" enable async mode.
func asyncRequest(r *http.Request) bool {
	switch strings.TrimSpace(strings.ToLower(r.Header.Get("x-async"))) {
	case "1", "true", "yes", "on":
		return true
	}
	return false
}
// writeTaskAccepted responds 202 with the task, its ID and the follow-up
// URLs ("events" and "detail") a client can use to track it.
func writeTaskAccepted(w http.ResponseWriter, task store.GatewayTask) {
	followUps := map[string]string{
		"events": fmt.Sprintf("/api/v1/tasks/%s/events", task.ID),
		"detail": fmt.Sprintf("/api/v1/tasks/%s", task.ID),
	}
	payload := map[string]any{
		"taskId": task.ID,
		"task":   task,
		"next":   followUps,
	}
	writeJSON(w, http.StatusAccepted, payload)
}
// apiKeyScopeAllowed reports whether the caller may run a task of the given
// kind.
//
// Callers that are nil, that did not authenticate with an API key, or whose
// key lists no scopes are unrestricted. Otherwise one of the key's scopes
// (compared lower-cased and trimmed) must be "*", "all", or the scope
// required by scopeForTaskKind; "text" and "text_generate" are additionally
// accepted wherever "chat" is required.
func apiKeyScopeAllowed(user *auth.User, kind string) bool {
	if user == nil {
		return true
	}
	if strings.TrimSpace(user.APIKeyID) == "" || len(user.APIKeyScopes) == 0 {
		return true
	}

	needed := scopeForTaskKind(kind)
	for i := range user.APIKeyScopes {
		granted := strings.TrimSpace(strings.ToLower(user.APIKeyScopes[i]))
		switch granted {
		case "*", "all", needed:
			return true
		case "text", "text_generate":
			// Text-generation scope names are accepted as aliases for chat.
			if needed == "chat" {
				return true
			}
		}
	}
	return false
}
// scopeForTaskKind maps a task kind to the API key scope that authorizes it:
// chat completions and responses need "chat", image generations and edits
// need "image", video generations need "video". Any other kind is its own
// scope name.
func scopeForTaskKind(kind string) string {
	if kind == "chat.completions" || kind == "responses" {
		return "chat"
	}
	if kind == "images.generations" || kind == "images.edits" {
		return "image"
	}
	if kind == "videos.generations" {
		return "video"
	}
	return kind
}
// statusFromRunError picks the HTTP status for a failed run.
//
// Checked in order:
//   - candidate code "platform_cooling_down"/"model_cooling_down" -> 429
//   - store.ErrNoModelCandidate                                   -> 404
//   - store.ErrRateLimited or client error code "rate_limit"      -> 429
//   - store.ErrInsufficientWalletBalance                          -> 402
//   - anything else                                               -> 502
func statusFromRunError(err error) int {
	candidateCode := store.ModelCandidateErrorCode(err)
	if candidateCode == "platform_cooling_down" || candidateCode == "model_cooling_down" {
		return http.StatusTooManyRequests
	}
	if errors.Is(err, store.ErrNoModelCandidate) {
		return http.StatusNotFound
	}
	if errors.Is(err, store.ErrRateLimited) || clients.ErrorCode(err) == "rate_limit" {
		return http.StatusTooManyRequests
	}
	if errors.Is(err, store.ErrInsufficientWalletBalance) {
		return http.StatusPaymentRequired
	}
	return http.StatusBadGateway
}
// runErrorCode returns the machine-readable code for a run error: the model
// candidate code when the error is a store.ErrNoModelCandidate, otherwise the
// code reported by clients.ErrorCode.
func runErrorCode(err error) string {
	if !errors.Is(err, store.ErrNoModelCandidate) {
		return clients.ErrorCode(err)
	}
	return store.ModelCandidateErrorCode(err)
}
// runErrorMessage returns the human-readable message for a run error.
//
// A nil error yields "". When the error carries rate-limit details, the
// summary from rateLimitErrorSummary is appended after a "; " delimiter so it
// does not run straight into the error text.
func runErrorMessage(err error) string {
	if err == nil {
		return ""
	}
	if summary := rateLimitErrorSummary(err); summary != "" {
		return err.Error() + "; " + summary
	}
	return err.Error()
}
// runErrorDetails returns extra structured fields for a run error. Currently
// that is only {"rateLimit": ...} when rateLimitErrorDetail produces a
// non-empty map; otherwise nil.
func runErrorDetails(err error) map[string]any {
	rateLimit := rateLimitErrorDetail(err)
	if len(rateLimit) == 0 {
		return nil
	}
	return map[string]any{"rateLimit": rateLimit}
}
// rateLimitErrorSummary renders a one-line, human-readable (Chinese) summary
// of a rate-limit violation, or "" when err does not wrap a
// *store.RateLimitExceededError.
//
// The summary names the limited scope and metric, then lists the current,
// requested, projected and limit values, the window length (when positive)
// and either the retry delay or a note that the request can never fit.
// Sections are joined with "; " and the four values with ", " so adjacent
// numbers and labels stay readable.
func rateLimitErrorSummary(err error) string {
	var limitErr *store.RateLimitExceededError
	if !errors.As(err, &limitErr) {
		return ""
	}
	scopeLabel := "限流对象"
	switch limitErr.ScopeType {
	case "user_group":
		scopeLabel = "用户组"
	case "platform_model":
		scopeLabel = "平台模型"
	}
	// Prefer the display name; fall back to the scope key.
	scopeName := strings.TrimSpace(limitErr.ScopeName)
	if scopeName == "" {
		scopeName = strings.TrimSpace(limitErr.ScopeKey)
	}
	if groupKey := stringValue(limitErr.ScopeMetadata["groupKey"]); limitErr.ScopeType == "user_group" && groupKey != "" && groupKey != scopeName {
		scopeName = fmt.Sprintf("%s(%s)", scopeName, groupKey)
	}
	// Derive the projected usage when the error did not carry one.
	projected := limitErr.Projected
	if projected <= 0 {
		projected = limitErr.Current + limitErr.Amount
	}
	parts := []string{
		fmt.Sprintf("限流摘要:%s %s 的 %s 超限", scopeLabel, scopeName, limitErr.Metric),
		fmt.Sprintf("当前 %s, 本次 %s, 预计 %s, 限制 %s", formatRateLimitValue(limitErr.Current), formatRateLimitValue(limitErr.Amount), formatRateLimitValue(projected), formatRateLimitValue(limitErr.Limit)),
	}
	if limitErr.WindowSeconds > 0 {
		parts = append(parts, fmt.Sprintf("窗口 %d 秒", limitErr.WindowSeconds))
	}
	if limitErr.RetryAfter > 0 {
		parts = append(parts, fmt.Sprintf("约%s后可重试", formatRateLimitDuration(limitErr.RetryAfter)))
	} else if !limitErr.Retryable {
		parts = append(parts, "该请求超过单次限额,不能排队重试")
	}
	return strings.Join(parts, "; ")
}
// rateLimitErrorDetail flattens a *store.RateLimitExceededError wrapped by
// err into a JSON-friendly map, or returns nil when err wraps none.
//
// Besides the raw counters the map carries an "exceeded" digest and, when
// available: retryAfterMs, resetAt (UTC, RFC 3339 with nanoseconds), the
// policy plus the rule matching the metric, the scope metadata, and a
// "userGroup" object for user_group scopes.
func rateLimitErrorDetail(err error) map[string]any {
	var exceededErr *store.RateLimitExceededError
	if !errors.As(err, &exceededErr) {
		return nil
	}

	exceeded := map[string]any{
		"metric":    exceededErr.Metric,
		"current":   exceededErr.Current,
		"amount":    exceededErr.Amount,
		"projected": exceededErr.Projected,
		"limit":     exceededErr.Limit,
	}
	out := map[string]any{
		"scopeType":     exceededErr.ScopeType,
		"scopeKey":      exceededErr.ScopeKey,
		"scopeName":     exceededErr.ScopeName,
		"metric":        exceededErr.Metric,
		"limit":         exceededErr.Limit,
		"amount":        exceededErr.Amount,
		"current":       exceededErr.Current,
		"used":          exceededErr.Used,
		"reserved":      exceededErr.Reserved,
		"projected":     exceededErr.Projected,
		"windowSeconds": exceededErr.WindowSeconds,
		"retryable":     exceededErr.Retryable,
		"exceeded":      exceeded,
	}

	// Optional timing hints.
	if exceededErr.RetryAfter > 0 {
		out["retryAfterMs"] = exceededErr.RetryAfter.Milliseconds()
	}
	if !exceededErr.ResetAt.IsZero() {
		out["resetAt"] = exceededErr.ResetAt.UTC().Format(time.RFC3339Nano)
	}

	// Policy and the specific rule that applies to the violated metric.
	if len(exceededErr.Policy) > 0 {
		out["rateLimitPolicy"] = exceededErr.Policy
		rule := matchedRateLimitRule(exceededErr.Policy, exceededErr.Metric)
		if len(rule) > 0 {
			out["matchedRule"] = rule
		}
	}

	if len(exceededErr.ScopeMetadata) > 0 {
		out["scopeMetadata"] = exceededErr.ScopeMetadata
	}

	if exceededErr.ScopeType == "user_group" {
		group := map[string]any{
			"id":   exceededErr.ScopeKey,
			"name": exceededErr.ScopeName,
		}
		groupKey := stringValue(exceededErr.ScopeMetadata["groupKey"])
		if groupKey != "" {
			group["groupKey"] = groupKey
		}
		out["userGroup"] = group
	}
	return out
}
// formatRateLimitValue renders value in plain decimal notation (no exponent)
// using the fewest digits that still round-trip the float64.
func formatRateLimitValue(value float64) string {
	digits := strconv.AppendFloat(nil, value, 'f', -1, 64)
	return string(digits)
}
// formatRateLimitDuration renders a duration for the rate-limit summary:
// whole milliseconds followed by "毫秒" for anything below one second,
// otherwise (possibly fractional) seconds followed by "秒".
func formatRateLimitDuration(duration time.Duration) string {
	if duration >= time.Second {
		return strconv.FormatFloat(duration.Seconds(), 'f', -1, 64) + "秒"
	}
	millis := duration.Milliseconds()
	return strconv.FormatInt(millis, 10) + "毫秒"
}
// matchedRateLimitRule returns the first entry of policy["rules"] whose
// "metric" (read through stringValue) equals metric, or nil when there is
// none or "rules" is not a []any.
//
// Entries that are not map[string]any are not skipped: they are treated as
// an empty (nil) rule, exactly as a failed type assertion yields.
func matchedRateLimitRule(policy map[string]any, metric string) map[string]any {
	entries, _ := policy["rules"].([]any)
	for i := range entries {
		candidate, _ := entries[i].(map[string]any)
		if stringValue(candidate["metric"]) != metric {
			continue
		}
		return candidate
	}
	return nil
}
// listTasks responds with a page of tasks for the current user.
//
// Query parameters (aliases in parentheses):
//   - page                 positive integer, default 1
//   - pageSize (limit)     positive integer, default 50
//   - createdFrom (from)   lower creation-time bound, see parseTaskListTime
//   - createdTo (to)       upper creation-time bound, see parseTaskListTime
//   - q (query)            free-text filter passed to the store
//   - modelType (type)     model-type filter passed to the store
//
// Responses: 200 with items/total/page/pageSize as reported by the store,
// 400 for a malformed parameter, 401 without a user in the context, and
// 500 (logged) on store failure.
func (s *Server) listTasks(w http.ResponseWriter, r *http.Request) {
	caller, authenticated := auth.UserFromContext(r.Context())
	if !authenticated {
		writeError(w, http.StatusUnauthorized, "unauthorized")
		return
	}
	params := r.URL.Query()

	pageNumber, err := positiveQueryInt(params.Get("page"), 1)
	if err != nil {
		writeError(w, http.StatusBadRequest, "invalid page")
		return
	}

	// "limit" is honored only when "pageSize" is absent or empty.
	rawSize := params.Get("pageSize")
	if rawSize == "" {
		rawSize = params.Get("limit")
	}
	size, err := positiveQueryInt(rawSize, 50)
	if err != nil {
		writeError(w, http.StatusBadRequest, "invalid pageSize")
		return
	}

	from, err := parseTaskListTime(params.Get("createdFrom"), params.Get("from"))
	if err != nil {
		writeError(w, http.StatusBadRequest, "invalid createdFrom")
		return
	}
	to, err := parseTaskListTime(params.Get("createdTo"), params.Get("to"))
	if err != nil {
		writeError(w, http.StatusBadRequest, "invalid createdTo")
		return
	}

	filter := store.TaskListFilter{
		Query:       firstNonEmpty(params.Get("q"), params.Get("query")),
		ModelType:   firstNonEmpty(params.Get("modelType"), params.Get("type")),
		CreatedFrom: from,
		CreatedTo:   to,
		Page:        pageNumber,
		PageSize:    size,
	}
	listing, err := s.store.ListTasks(r.Context(), caller, filter)
	if err != nil {
		s.logger.Error("list tasks failed", "error", err)
		writeError(w, http.StatusInternalServerError, "list tasks failed")
		return
	}

	writeJSON(w, http.StatusOK, map[string]any{
		"items":    listing.Items,
		"total":    listing.Total,
		"page":     listing.Page,
		"pageSize": listing.PageSize,
	})
}
// positiveQueryInt parses a query-string integer that must be greater than
// zero. A blank value (after trimming) yields fallback; anything that is not
// a positive base-10 integer yields 0 and an error.
func positiveQueryInt(raw string, fallback int) (int, error) {
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return fallback, nil
	}
	parsed, err := strconv.Atoi(trimmed)
	if err == nil && parsed > 0 {
		return parsed, nil
	}
	return 0, fmt.Errorf("invalid positive integer")
}
// parseTaskListTime parses the first non-empty of values as a timestamp.
//
// It returns (nil, nil) when every value is blank. Accepted layouts, tried in
// order: RFC 3339 with and without fractional seconds, "2006-01-02T15:04",
// "2006-01-02 15:04:05" and "2006-01-02". Inputs without a zone are
// interpreted in time.Local. When no layout matches, the error of the last
// attempt is returned.
func parseTaskListTime(values ...string) (*time.Time, error) {
	candidate := strings.TrimSpace(firstNonEmpty(values...))
	if candidate == "" {
		return nil, nil
	}

	var parseErr error
	for _, layout := range [...]string{
		time.RFC3339Nano,
		time.RFC3339,
		"2006-01-02T15:04",
		"2006-01-02 15:04:05",
		"2006-01-02",
	} {
		var moment time.Time
		moment, parseErr = time.ParseInLocation(layout, candidate, time.Local)
		if parseErr == nil {
			return &moment, nil
		}
	}
	return nil, parseErr
}
// firstNonEmpty returns the first value that is not blank, with surrounding
// whitespace removed, or "" when every value is blank.
func firstNonEmpty(values ...string) string {
	for i := 0; i < len(values); i++ {
		candidate := strings.TrimSpace(values[i])
		if candidate == "" {
			continue
		}
		return candidate
	}
	return ""
}
// boolValue reads body[key] as a bool. A missing key or a value of any other
// type (including the string "true") yields false.
func boolValue(body map[string]any, key string) bool {
	flag, isBool := body[key].(bool)
	return isBool && flag
}
// getTask responds with the task named by the "taskID" path value.
//
// Responses: 200 with the task, 404 when the store reports it missing, and
// 500 (logged) for any other store failure.
//
// NOTE(review): the lookup is by ID only; whether access is restricted to the
// caller's own tasks elsewhere (store or middleware) is not visible here —
// confirm.
func (s *Server) getTask(w http.ResponseWriter, r *http.Request) {
	found, err := s.store.GetTask(r.Context(), r.PathValue("taskID"))
	switch {
	case err == nil:
		writeJSON(w, http.StatusOK, found)
	case store.IsNotFound(err):
		writeError(w, http.StatusNotFound, "task not found")
	default:
		s.logger.Error("get task failed", "error", err)
		writeError(w, http.StatusInternalServerError, "get task failed")
	}
}
// taskParamPreprocessing responds with the parameter-preprocessing logs of
// the task named by the "taskID" path value, under "items".
//
// The task is loaded first so that an unknown ID yields 404 rather than an
// empty list. Store failures are logged and answered with 500.
func (s *Server) taskParamPreprocessing(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	found, err := s.store.GetTask(ctx, r.PathValue("taskID"))
	switch {
	case err == nil:
		// fall through to the log lookup below
	case store.IsNotFound(err):
		writeError(w, http.StatusNotFound, "task not found")
		return
	default:
		s.logger.Error("get task failed", "error", err)
		writeError(w, http.StatusInternalServerError, "get task failed")
		return
	}

	entries, err := s.store.ListTaskParamPreprocessingLogs(ctx, found.ID)
	if err != nil {
		s.logger.Error("list task parameter preprocessing logs failed", "taskID", found.ID, "error", err)
		writeError(w, http.StatusInternalServerError, "list task parameter preprocessing logs failed")
		return
	}
	writeJSON(w, http.StatusOK, map[string]any{"items": entries})
}
// taskEvents replays the stored events of the task named by the "taskID"
// path value as a server-sent event stream.
//
// Each stored event is sent under its own event type. When the task has no
// events yet, a single synthetic "task.accepted" event carrying the task ID
// and status is sent instead. The handler returns once the replay is done.
//
// Both lookups happen before any SSE header is written, so failures are
// reported as regular error responses: 404 for an unknown task, 500 (logged)
// for any store failure.
func (s *Server) taskEvents(w http.ResponseWriter, r *http.Request) {
	task, err := s.store.GetTask(r.Context(), r.PathValue("taskID"))
	if err != nil {
		if store.IsNotFound(err) {
			writeError(w, http.StatusNotFound, "task not found")
			return
		}
		s.logger.Error("get task failed", "error", err)
		writeError(w, http.StatusInternalServerError, "get task failed")
		return
	}
	// Load the events before committing to the stream; otherwise a failure
	// here would surface as an empty 200 text/event-stream response.
	events, err := s.store.ListTaskEvents(r.Context(), task.ID)
	if err != nil {
		s.logger.Error("list task events failed", "taskID", task.ID, "error", err)
		writeError(w, http.StatusInternalServerError, "list task events failed")
		return
	}

	header := w.Header()
	header.Set("Content-Type", "text/event-stream")
	header.Set("Cache-Control", "no-cache")
	header.Set("Connection", "keep-alive")

	flusher, _ := w.(http.Flusher)
	flush := func() {
		if flusher != nil {
			flusher.Flush()
		}
	}

	for _, event := range events {
		sendSSE(w, event.EventType, event)
		flush()
	}
	if len(events) == 0 {
		sendSSE(w, "task.accepted", map[string]any{
			"taskId": task.ID,
			"status": task.Status,
		})
		flush()
	}
}
// runModeFromRequest extracts the requested run mode from a task body.
//
// Precedence: a string "runMode" (even an empty one), then a string "mode",
// then "simulation" when either the "simulation" or "testMode" flag is the
// boolean true. Without any of these the result is "".
func runModeFromRequest(body map[string]any) string {
	for _, key := range [...]string{"runMode", "mode"} {
		if mode, isString := body[key].(string); isString {
			return mode
		}
	}
	for _, key := range [...]string{"simulation", "testMode"} {
		if enabled, isBool := body[key].(bool); isBool && enabled {
			return "simulation"
		}
	}
	return ""
}