1471 lines
48 KiB
Go
1471 lines
48 KiB
Go
package httpapi
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"net/http"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
|
||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/netproxy"
|
||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||
)
|
||
|
||
// health godoc
|
||
// @Summary 健康检查
|
||
// @Description 返回服务进程、运行环境和身份模式,供负载均衡或人工排障使用。
|
||
// @Tags system
|
||
// @Produce json
|
||
// @Success 200 {object} HealthResponse
|
||
// @Router /healthz [get]
|
||
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
|
||
writeJSON(w, http.StatusOK, map[string]any{
|
||
"ok": true,
|
||
"service": "easyai-ai-gateway",
|
||
"env": s.cfg.AppEnv,
|
||
"identityMode": s.cfg.IdentityMode,
|
||
})
|
||
}
|
||
|
||
// ready godoc
|
||
// @Summary 就绪检查
|
||
// @Description 检查 Postgres 是否可用;数据库不可用时返回 503。
|
||
// @Tags system
|
||
// @Produce json
|
||
// @Success 200 {object} ReadyResponse
|
||
// @Failure 503 {object} ErrorEnvelope
|
||
// @Router /readyz [get]
|
||
func (s *Server) ready(w http.ResponseWriter, r *http.Request) {
|
||
if err := s.store.Ping(r.Context()); err != nil {
|
||
writeError(w, http.StatusServiceUnavailable, "postgres unavailable")
|
||
return
|
||
}
|
||
writeJSON(w, http.StatusOK, map[string]any{"ok": true})
|
||
}
|
||
|
||
// me godoc
|
||
// @Summary 获取当前用户
|
||
// @Description 返回鉴权中解析出的用户、租户、用户组和 API Key 上下文。
|
||
// @Tags auth
|
||
// @Produce json
|
||
// @Security BearerAuth
|
||
// @Success 200 {object} auth.User
|
||
// @Failure 401 {object} ErrorEnvelope
|
||
// @Router /api/v1/me [get]
|
||
func (s *Server) me(w http.ResponseWriter, r *http.Request) {
|
||
user, _ := auth.UserFromContext(r.Context())
|
||
writeJSON(w, http.StatusOK, user)
|
||
}
|
||
|
||
// register godoc
|
||
// @Summary 本地注册
|
||
// @Description 在 standalone 或 hybrid 身份模式下创建本地用户,并返回 24 小时 JWT。
|
||
// @Tags auth
|
||
// @Accept json
|
||
// @Produce json
|
||
// @Param input body store.LocalRegisterInput true "注册请求,password 至少 8 位,invitationCode 取决于部署策略"
|
||
// @Success 201 {object} AuthResponse
|
||
// @Failure 400 {object} ErrorEnvelope
|
||
// @Failure 403 {object} ErrorEnvelope
|
||
// @Failure 409 {object} ErrorEnvelope
|
||
// @Failure 500 {object} ErrorEnvelope
|
||
// @Router /api/v1/auth/register [post]
|
||
func (s *Server) register(w http.ResponseWriter, r *http.Request) {
|
||
if !s.localIdentityEnabled() {
|
||
writeError(w, http.StatusForbidden, "local registration is disabled")
|
||
return
|
||
}
|
||
var input store.LocalRegisterInput
|
||
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
|
||
writeError(w, http.StatusBadRequest, "invalid json body")
|
||
return
|
||
}
|
||
user, err := s.store.RegisterLocalUser(r.Context(), input)
|
||
if err != nil {
|
||
if errors.Is(err, store.ErrWeakPassword) {
|
||
writeError(w, http.StatusBadRequest, err.Error())
|
||
return
|
||
}
|
||
if errors.Is(err, store.ErrInvalidInvitation) {
|
||
writeError(w, http.StatusBadRequest, err.Error())
|
||
return
|
||
}
|
||
if errors.Is(err, store.ErrUserAlreadyExists) {
|
||
writeError(w, http.StatusConflict, err.Error())
|
||
return
|
||
}
|
||
s.logger.Error("register local user failed", "error", err)
|
||
writeError(w, http.StatusInternalServerError, "register local user failed")
|
||
return
|
||
}
|
||
s.writeAuthResponse(w, http.StatusCreated, user)
|
||
}
|
||
|
||
// login godoc
|
||
// @Summary 本地登录
|
||
// @Description 使用用户名或邮箱登录本地账号,并返回 24 小时 JWT。
|
||
// @Tags auth
|
||
// @Accept json
|
||
// @Produce json
|
||
// @Param input body store.LocalLoginInput true "登录请求,account 可为用户名或邮箱"
|
||
// @Success 200 {object} AuthResponse
|
||
// @Failure 400 {object} ErrorEnvelope
|
||
// @Failure 401 {object} ErrorEnvelope
|
||
// @Failure 403 {object} ErrorEnvelope
|
||
// @Failure 500 {object} ErrorEnvelope
|
||
// @Router /api/v1/auth/login [post]
|
||
func (s *Server) login(w http.ResponseWriter, r *http.Request) {
|
||
if !s.localIdentityEnabled() {
|
||
writeError(w, http.StatusForbidden, "local login is disabled")
|
||
return
|
||
}
|
||
var input store.LocalLoginInput
|
||
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
|
||
writeError(w, http.StatusBadRequest, "invalid json body")
|
||
return
|
||
}
|
||
user, err := s.store.AuthenticateLocalUser(r.Context(), input)
|
||
if err != nil {
|
||
if errors.Is(err, store.ErrInvalidCredentials) {
|
||
writeError(w, http.StatusUnauthorized, "invalid account or password")
|
||
return
|
||
}
|
||
s.logger.Error("login local user failed", "error", err)
|
||
writeError(w, http.StatusInternalServerError, "login failed")
|
||
return
|
||
}
|
||
s.writeAuthResponse(w, http.StatusOK, user)
|
||
}
|
||
|
||
func (s *Server) localIdentityEnabled() bool {
|
||
mode := strings.ToLower(strings.TrimSpace(s.cfg.IdentityMode))
|
||
return mode == "" || mode == "standalone" || mode == "hybrid"
|
||
}
|
||
|
||
func (s *Server) writeAuthResponse(w http.ResponseWriter, status int, user store.GatewayUser) {
|
||
authUser := authUserFromGatewayUser(user)
|
||
const ttl = 24 * time.Hour
|
||
token, err := s.auth.SignJWT(authUser, ttl)
|
||
if err != nil {
|
||
s.logger.Error("sign local jwt failed", "error", err)
|
||
writeError(w, http.StatusInternalServerError, "token sign failed")
|
||
return
|
||
}
|
||
writeJSON(w, status, map[string]any{
|
||
"accessToken": token,
|
||
"tokenType": "Bearer",
|
||
"expiresIn": int(ttl.Seconds()),
|
||
"user": authUser,
|
||
})
|
||
}
|
||
|
||
func authUserFromGatewayUser(user store.GatewayUser) *auth.User {
|
||
roles := user.Roles
|
||
if len(roles) == 0 {
|
||
roles = []string{"user"}
|
||
}
|
||
tenantID := user.TenantID
|
||
if tenantID == "" {
|
||
tenantID = user.TenantKey
|
||
}
|
||
return &auth.User{
|
||
ID: user.ID,
|
||
Username: user.Username,
|
||
Roles: roles,
|
||
TenantID: tenantID,
|
||
GatewayTenantID: user.GatewayTenantID,
|
||
TenantKey: user.TenantKey,
|
||
Source: "gateway",
|
||
GatewayUserID: user.ID,
|
||
UserGroupID: user.DefaultUserGroupID,
|
||
}
|
||
}
|
||
|
||
// listPlatforms godoc
|
||
// @Summary 列出平台
|
||
// @Description 管理端返回所有接入平台及其优先级、定价和运行策略摘要。
|
||
// @Tags platforms
|
||
// @Produce json
|
||
// @Security BearerAuth
|
||
// @Success 200 {object} PlatformListResponse
|
||
// @Failure 401 {object} ErrorEnvelope
|
||
// @Failure 403 {object} ErrorEnvelope
|
||
// @Failure 500 {object} ErrorEnvelope
|
||
// @Router /api/admin/platforms [get]
|
||
func (s *Server) listPlatforms(w http.ResponseWriter, r *http.Request) {
|
||
platforms, err := s.store.ListPlatforms(r.Context())
|
||
if err != nil {
|
||
s.logger.Error("list platforms failed", "error", err)
|
||
writeError(w, http.StatusInternalServerError, "list platforms failed")
|
||
return
|
||
}
|
||
writeJSON(w, http.StatusOK, map[string]any{"items": platforms})
|
||
}
|
||
|
||
// listPlayablePlatforms godoc
|
||
// @Summary 列出可用平台
|
||
// @Description 按当前用户可访问模型过滤平台,仅返回启用且存在可访问模型的平台。
|
||
// @Tags playground
|
||
// @Produce json
|
||
// @Security BearerAuth
|
||
// @Success 200 {object} PlatformListResponse
|
||
// @Failure 401 {object} ErrorEnvelope
|
||
// @Failure 500 {object} ErrorEnvelope
|
||
// @Router /api/v1/platforms [get]
|
||
func (s *Server) listPlayablePlatforms(w http.ResponseWriter, r *http.Request) {
|
||
user, _ := auth.UserFromContext(r.Context())
|
||
models, err := s.store.ListAccessiblePlatformModels(r.Context(), user)
|
||
if err != nil {
|
||
s.logger.Error("list playable platform models failed", "error", err)
|
||
writeError(w, http.StatusInternalServerError, "list playable platforms failed")
|
||
return
|
||
}
|
||
allowedPlatformIDs := map[string]bool{}
|
||
for _, model := range models {
|
||
allowedPlatformIDs[model.PlatformID] = true
|
||
}
|
||
platforms, err := s.store.ListPlatforms(r.Context())
|
||
if err != nil {
|
||
s.logger.Error("list platforms failed", "error", err)
|
||
writeError(w, http.StatusInternalServerError, "list playable platforms failed")
|
||
return
|
||
}
|
||
filtered := platforms[:0]
|
||
for _, platform := range platforms {
|
||
if platform.Status == "enabled" && allowedPlatformIDs[platform.ID] {
|
||
filtered = append(filtered, platform)
|
||
}
|
||
}
|
||
writeJSON(w, http.StatusOK, map[string]any{"items": filtered})
|
||
}
|
||
|
||
// createPlatform godoc
|
||
// @Summary 创建平台
|
||
// @Description 新增模型供应商平台配置;credentials 会被服务端保存并在返回值中脱敏。
|
||
// @Tags platforms
|
||
// @Accept json
|
||
// @Produce json
|
||
// @Security BearerAuth
|
||
// @Param input body store.CreatePlatformInput true "平台配置请求"
|
||
// @Success 201 {object} store.Platform
|
||
// @Failure 400 {object} ErrorEnvelope
|
||
// @Failure 401 {object} ErrorEnvelope
|
||
// @Failure 403 {object} ErrorEnvelope
|
||
// @Failure 500 {object} ErrorEnvelope
|
||
// @Router /api/admin/platforms [post]
|
||
func (s *Server) createPlatform(w http.ResponseWriter, r *http.Request) {
|
||
var input store.CreatePlatformInput
|
||
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
|
||
writeError(w, http.StatusBadRequest, "invalid json body")
|
||
return
|
||
}
|
||
input.Provider = strings.TrimSpace(input.Provider)
|
||
input.Name = strings.TrimSpace(input.Name)
|
||
input.InternalName = strings.TrimSpace(input.InternalName)
|
||
input.Status = strings.TrimSpace(input.Status)
|
||
if input.Provider == "" || input.Name == "" {
|
||
writeError(w, http.StatusBadRequest, "provider and name are required")
|
||
return
|
||
}
|
||
if input.Status != "" && input.Status != "enabled" && input.Status != "disabled" {
|
||
writeError(w, http.StatusBadRequest, "status must be enabled or disabled")
|
||
return
|
||
}
|
||
if input.AuthType == "" {
|
||
input.AuthType = "bearer"
|
||
}
|
||
config, err := netproxy.NormalizePlatformConfig(input.Config)
|
||
if err != nil {
|
||
writeError(w, http.StatusBadRequest, err.Error())
|
||
return
|
||
}
|
||
input.Config = config
|
||
platform, err := s.store.CreatePlatform(r.Context(), input)
|
||
if err != nil {
|
||
s.logger.Error("create platform failed", "error", err)
|
||
writeError(w, http.StatusInternalServerError, "create platform failed")
|
||
return
|
||
}
|
||
writeJSON(w, http.StatusCreated, platform)
|
||
}
|
||
|
||
// updatePlatform godoc
|
||
// @Summary 更新平台
|
||
// @Description 覆盖指定平台的基础配置、凭证、优先级、定价和运行策略。
|
||
// @Tags platforms
|
||
// @Accept json
|
||
// @Produce json
|
||
// @Security BearerAuth
|
||
// @Param platformID path string true "平台 ID"
|
||
// @Param input body store.CreatePlatformInput true "平台配置请求"
|
||
// @Success 200 {object} store.Platform
|
||
// @Failure 400 {object} ErrorEnvelope
|
||
// @Failure 401 {object} ErrorEnvelope
|
||
// @Failure 403 {object} ErrorEnvelope
|
||
// @Failure 404 {object} ErrorEnvelope
|
||
// @Failure 409 {object} ErrorEnvelope
|
||
// @Failure 500 {object} ErrorEnvelope
|
||
// @Router /api/admin/platforms/{platformID} [patch]
|
||
func (s *Server) updatePlatform(w http.ResponseWriter, r *http.Request) {
|
||
var input store.CreatePlatformInput
|
||
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
|
||
writeError(w, http.StatusBadRequest, "invalid json body")
|
||
return
|
||
}
|
||
input.Provider = strings.TrimSpace(input.Provider)
|
||
input.Name = strings.TrimSpace(input.Name)
|
||
input.InternalName = strings.TrimSpace(input.InternalName)
|
||
input.Status = strings.TrimSpace(input.Status)
|
||
if input.Provider == "" || input.Name == "" {
|
||
writeError(w, http.StatusBadRequest, "provider and name are required")
|
||
return
|
||
}
|
||
if input.Status != "" && input.Status != "enabled" && input.Status != "disabled" {
|
||
writeError(w, http.StatusBadRequest, "status must be enabled or disabled")
|
||
return
|
||
}
|
||
if input.AuthType == "" {
|
||
input.AuthType = "bearer"
|
||
}
|
||
config, err := netproxy.NormalizePlatformConfig(input.Config)
|
||
if err != nil {
|
||
writeError(w, http.StatusBadRequest, err.Error())
|
||
return
|
||
}
|
||
input.Config = config
|
||
platform, err := s.store.UpdatePlatform(r.Context(), r.PathValue("platformID"), input)
|
||
if err != nil {
|
||
if store.IsNotFound(err) {
|
||
writeError(w, http.StatusNotFound, "platform not found")
|
||
return
|
||
}
|
||
if store.IsUniqueViolation(err) {
|
||
writeError(w, http.StatusConflict, "platform key already exists")
|
||
return
|
||
}
|
||
s.logger.Error("update platform failed", "error", err)
|
||
writeError(w, http.StatusInternalServerError, "update platform failed")
|
||
return
|
||
}
|
||
writeJSON(w, http.StatusOK, platform)
|
||
}
|
||
|
||
// deletePlatform godoc
|
||
// @Summary 删除平台
|
||
// @Description 删除指定平台及关联配置;不存在时返回 404。
|
||
// @Tags platforms
|
||
// @Produce json
|
||
// @Security BearerAuth
|
||
// @Param platformID path string true "平台 ID"
|
||
// @Success 204 "No Content"
|
||
// @Failure 401 {object} ErrorEnvelope
|
||
// @Failure 403 {object} ErrorEnvelope
|
||
// @Failure 404 {object} ErrorEnvelope
|
||
// @Failure 500 {object} ErrorEnvelope
|
||
// @Router /api/admin/platforms/{platformID} [delete]
|
||
func (s *Server) deletePlatform(w http.ResponseWriter, r *http.Request) {
|
||
if err := s.store.DeletePlatform(r.Context(), r.PathValue("platformID")); err != nil {
|
||
if store.IsNotFound(err) {
|
||
writeError(w, http.StatusNotFound, "platform not found")
|
||
return
|
||
}
|
||
s.logger.Error("delete platform failed", "error", err)
|
||
writeError(w, http.StatusInternalServerError, "delete platform failed")
|
||
return
|
||
}
|
||
w.WriteHeader(http.StatusNoContent)
|
||
}
|
||
|
||
// createPlatformModel godoc
|
||
// @Summary 创建平台模型
|
||
// @Description 为平台新增一个可路由模型;路径中的 platformID 会覆盖请求体 platformId。
|
||
// @Tags platform-models
|
||
// @Accept json
|
||
// @Produce json
|
||
// @Security BearerAuth
|
||
// @Param platformID path string true "平台 ID,使用 /api/admin/platforms/{platformID}/models 时由路径提供"
|
||
// @Param input body store.CreatePlatformModelInput true "平台模型配置请求"
|
||
// @Success 201 {object} store.PlatformModel
|
||
// @Failure 400 {object} ErrorEnvelope
|
||
// @Failure 401 {object} ErrorEnvelope
|
||
// @Failure 403 {object} ErrorEnvelope
|
||
// @Failure 404 {object} ErrorEnvelope
|
||
// @Failure 500 {object} ErrorEnvelope
|
||
// @Router /api/admin/platforms/{platformID}/models [post]
|
||
// @Router /api/admin/platform-models [post]
|
||
func (s *Server) createPlatformModel(w http.ResponseWriter, r *http.Request) {
|
||
var input store.CreatePlatformModelInput
|
||
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
|
||
writeError(w, http.StatusBadRequest, "invalid json body")
|
||
return
|
||
}
|
||
if pathPlatformID := r.PathValue("platformID"); pathPlatformID != "" {
|
||
input.PlatformID = pathPlatformID
|
||
}
|
||
if input.PlatformID == "" {
|
||
writeError(w, http.StatusBadRequest, "platformId is required")
|
||
return
|
||
}
|
||
model, err := s.store.CreatePlatformModel(r.Context(), input)
|
||
if err != nil {
|
||
if store.IsNotFound(err) {
|
||
writeError(w, http.StatusNotFound, "base model not found")
|
||
return
|
||
}
|
||
s.logger.Error("create platform model failed", "error", err)
|
||
writeError(w, http.StatusInternalServerError, "create platform model failed")
|
||
return
|
||
}
|
||
writeJSON(w, http.StatusCreated, s.platformModelResponse(r.Context(), model))
|
||
}
|
||
|
||
// replacePlatformModels godoc
|
||
// @Summary 替换平台模型
|
||
// @Description 用请求体中的 models 列表整体替换指定平台下的模型配置。
|
||
// @Tags platform-models
|
||
// @Accept json
|
||
// @Produce json
|
||
// @Security BearerAuth
|
||
// @Param platformID path string true "平台 ID"
|
||
// @Param input body ReplacePlatformModelsRequest true "模型列表请求"
|
||
// @Success 200 {object} PlatformModelListResponse
|
||
// @Failure 400 {object} ErrorEnvelope
|
||
// @Failure 401 {object} ErrorEnvelope
|
||
// @Failure 403 {object} ErrorEnvelope
|
||
// @Failure 404 {object} ErrorEnvelope
|
||
// @Failure 500 {object} ErrorEnvelope
|
||
// @Router /api/admin/platforms/{platformID}/models [put]
|
||
func (s *Server) replacePlatformModels(w http.ResponseWriter, r *http.Request) {
|
||
platformID := r.PathValue("platformID")
|
||
if platformID == "" {
|
||
writeError(w, http.StatusBadRequest, "platformId is required")
|
||
return
|
||
}
|
||
|
||
var input struct {
|
||
Models []store.CreatePlatformModelInput `json:"models"`
|
||
}
|
||
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
|
||
writeError(w, http.StatusBadRequest, "invalid json body")
|
||
return
|
||
}
|
||
|
||
models, err := s.store.ReplacePlatformModels(r.Context(), platformID, input.Models)
|
||
if err != nil {
|
||
if store.IsNotFound(err) {
|
||
writeError(w, http.StatusNotFound, "base model not found")
|
||
return
|
||
}
|
||
s.logger.Error("replace platform models failed", "error", err)
|
||
writeError(w, http.StatusInternalServerError, "replace platform models failed")
|
||
return
|
||
}
|
||
writeJSON(w, http.StatusOK, map[string]any{"items": s.platformModelResponses(r.Context(), models)})
|
||
}
|
||
|
||
// deletePlatformModel godoc
|
||
// @Summary 删除平台模型
|
||
// @Description 删除指定平台模型路由配置。
|
||
// @Tags platform-models
|
||
// @Produce json
|
||
// @Security BearerAuth
|
||
// @Param modelID path string true "平台模型 ID"
|
||
// @Success 204 "No Content"
|
||
// @Failure 401 {object} ErrorEnvelope
|
||
// @Failure 403 {object} ErrorEnvelope
|
||
// @Failure 404 {object} ErrorEnvelope
|
||
// @Failure 500 {object} ErrorEnvelope
|
||
// @Router /api/admin/platform-models/{modelID} [delete]
|
||
func (s *Server) deletePlatformModel(w http.ResponseWriter, r *http.Request) {
|
||
if err := s.store.DeletePlatformModel(r.Context(), r.PathValue("modelID")); err != nil {
|
||
if store.IsNotFound(err) {
|
||
writeError(w, http.StatusNotFound, "platform model not found")
|
||
return
|
||
}
|
||
s.logger.Error("delete platform model failed", "error", err)
|
||
writeError(w, http.StatusInternalServerError, "delete platform model failed")
|
||
return
|
||
}
|
||
w.WriteHeader(http.StatusNoContent)
|
||
}
|
||
|
||
// listModels godoc
|
||
// @Summary 列出平台模型
|
||
// @Description 管理端返回所有平台模型,并补齐有效计费配置。
|
||
// @Tags platform-models
|
||
// @Produce json
|
||
// @Security BearerAuth
|
||
// @Success 200 {object} PlatformModelListResponse
|
||
// @Failure 401 {object} ErrorEnvelope
|
||
// @Failure 403 {object} ErrorEnvelope
|
||
// @Failure 500 {object} ErrorEnvelope
|
||
// @Router /api/admin/models [get]
|
||
func (s *Server) listModels(w http.ResponseWriter, r *http.Request) {
|
||
models, err := s.store.ListModels(r.Context())
|
||
if err != nil {
|
||
s.logger.Error("list models failed", "error", err)
|
||
writeError(w, http.StatusInternalServerError, "list models failed")
|
||
return
|
||
}
|
||
writeJSON(w, http.StatusOK, map[string]any{"items": s.platformModelResponses(r.Context(), models)})
|
||
}
|
||
|
||
// listPlayableModels godoc
|
||
// @Summary 列出可调用模型
|
||
// @Description 按当前用户权限返回可用于 Playground 或 API 调用的模型列表。
|
||
// @Tags playground
|
||
// @Produce json
|
||
// @Security BearerAuth
|
||
// @Success 200 {object} PlatformModelListResponse
|
||
// @Failure 401 {object} ErrorEnvelope
|
||
// @Failure 500 {object} ErrorEnvelope
|
||
// @Router /api/v1/models [get]
|
||
// @Router /api/v1/playground/models [get]
|
||
func (s *Server) listPlayableModels(w http.ResponseWriter, r *http.Request) {
|
||
user, _ := auth.UserFromContext(r.Context())
|
||
models, err := s.store.ListAccessiblePlatformModels(r.Context(), user)
|
||
if err != nil {
|
||
s.logger.Error("list playable models failed", "error", err)
|
||
writeError(w, http.StatusInternalServerError, "list playable models failed")
|
||
return
|
||
}
|
||
writeJSON(w, http.StatusOK, map[string]any{"items": s.platformModelResponses(r.Context(), models)})
|
||
}
|
||
|
||
// listPricingRules godoc
|
||
// @Summary 列出定价规则
|
||
// @Description 返回所有定价规则明细,便于管理端排查有效价格。
|
||
// @Tags pricing
|
||
// @Produce json
|
||
// @Security BearerAuth
|
||
// @Success 200 {object} PricingRuleListResponse
|
||
// @Failure 401 {object} ErrorEnvelope
|
||
// @Failure 403 {object} ErrorEnvelope
|
||
// @Failure 500 {object} ErrorEnvelope
|
||
// @Router /api/admin/pricing/rules [get]
|
||
func (s *Server) listPricingRules(w http.ResponseWriter, r *http.Request) {
|
||
items, err := s.store.ListPricingRules(r.Context())
|
||
if err != nil {
|
||
s.logger.Error("list pricing rules failed", "error", err)
|
||
writeError(w, http.StatusInternalServerError, "list pricing rules failed")
|
||
return
|
||
}
|
||
writeJSON(w, http.StatusOK, map[string]any{"items": items})
|
||
}
|
||
|
||
// listTenants godoc
|
||
// @Summary 列出租户
|
||
// @Description 管理端返回网关租户列表。
|
||
// @Tags identity
|
||
// @Produce json
|
||
// @Security BearerAuth
|
||
// @Success 200 {object} TenantListResponse
|
||
// @Failure 401 {object} ErrorEnvelope
|
||
// @Failure 403 {object} ErrorEnvelope
|
||
// @Failure 500 {object} ErrorEnvelope
|
||
// @Router /api/admin/tenants [get]
|
||
func (s *Server) listTenants(w http.ResponseWriter, r *http.Request) {
|
||
items, err := s.store.ListTenants(r.Context())
|
||
if err != nil {
|
||
s.logger.Error("list tenants failed", "error", err)
|
||
writeError(w, http.StatusInternalServerError, "list tenants failed")
|
||
return
|
||
}
|
||
writeJSON(w, http.StatusOK, map[string]any{"items": items})
|
||
}
|
||
|
||
// listUsers godoc
|
||
// @Summary 列出用户
|
||
// @Description 管理端返回网关用户列表及钱包摘要。
|
||
// @Tags identity
|
||
// @Produce json
|
||
// @Security BearerAuth
|
||
// @Success 200 {object} UserListResponse
|
||
// @Failure 401 {object} ErrorEnvelope
|
||
// @Failure 403 {object} ErrorEnvelope
|
||
// @Failure 500 {object} ErrorEnvelope
|
||
// @Router /api/admin/users [get]
|
||
func (s *Server) listUsers(w http.ResponseWriter, r *http.Request) {
|
||
items, err := s.store.ListUsers(r.Context())
|
||
if err != nil {
|
||
s.logger.Error("list users failed", "error", err)
|
||
writeError(w, http.StatusInternalServerError, "list users failed")
|
||
return
|
||
}
|
||
writeJSON(w, http.StatusOK, map[string]any{"items": items})
|
||
}
|
||
|
||
// listUserGroups godoc
|
||
// @Summary 列出用户组
|
||
// @Description 管理端返回用户组及其计费、限流和配额策略。
|
||
// @Tags identity
|
||
// @Produce json
|
||
// @Security BearerAuth
|
||
// @Success 200 {object} UserGroupListResponse
|
||
// @Failure 401 {object} ErrorEnvelope
|
||
// @Failure 403 {object} ErrorEnvelope
|
||
// @Failure 500 {object} ErrorEnvelope
|
||
// @Router /api/admin/user-groups [get]
|
||
func (s *Server) listUserGroups(w http.ResponseWriter, r *http.Request) {
|
||
items, err := s.store.ListUserGroups(r.Context())
|
||
if err != nil {
|
||
s.logger.Error("list user groups failed", "error", err)
|
||
writeError(w, http.StatusInternalServerError, "list user groups failed")
|
||
return
|
||
}
|
||
writeJSON(w, http.StatusOK, map[string]any{"items": items})
|
||
}
|
||
|
||
// listAPIKeys godoc
|
||
// @Summary 列出 API Key
|
||
// @Description 返回当前用户创建的 API Key 元数据,secret 只在创建时返回。
|
||
// @Tags api-keys
|
||
// @Produce json
|
||
// @Security BearerAuth
|
||
// @Success 200 {object} APIKeyListResponse
|
||
// @Failure 401 {object} ErrorEnvelope
|
||
// @Failure 500 {object} ErrorEnvelope
|
||
// @Router /api/v1/api-keys [get]
|
||
func (s *Server) listAPIKeys(w http.ResponseWriter, r *http.Request) {
|
||
user, _ := auth.UserFromContext(r.Context())
|
||
items, err := s.store.ListAPIKeys(r.Context(), user)
|
||
if err != nil {
|
||
s.logger.Error("list api keys failed", "error", err)
|
||
writeError(w, http.StatusInternalServerError, "list api keys failed")
|
||
return
|
||
}
|
||
writeJSON(w, http.StatusOK, map[string]any{"items": items})
|
||
}
|
||
|
||
// listPlayableAPIKeys godoc
|
||
// @Summary 列出 Playground API Key
|
||
// @Description 返回当前本地用户可在 Playground 中直接使用的 API Key 和 secret。
|
||
// @Tags playground
|
||
// @Produce json
|
||
// @Security BearerAuth
|
||
// @Success 200 {object} PlayableAPIKeyListResponse
|
||
// @Failure 400 {object} ErrorEnvelope
|
||
// @Failure 401 {object} ErrorEnvelope
|
||
// @Failure 500 {object} ErrorEnvelope
|
||
// @Router /api/playground/api-keys [get]
|
||
func (s *Server) listPlayableAPIKeys(w http.ResponseWriter, r *http.Request) {
|
||
user, _ := auth.UserFromContext(r.Context())
|
||
items, err := s.store.ListPlayableAPIKeys(r.Context(), user)
|
||
if err != nil {
|
||
if errors.Is(err, store.ErrLocalUserRequired) {
|
||
writeError(w, http.StatusBadRequest, err.Error())
|
||
return
|
||
}
|
||
s.logger.Error("list playable api keys failed", "error", err)
|
||
writeError(w, http.StatusInternalServerError, "list playable api keys failed")
|
||
return
|
||
}
|
||
writeJSON(w, http.StatusOK, map[string]any{"items": items})
|
||
}
|
||
|
||
// createAPIKey godoc
|
||
// @Summary 创建 API Key
|
||
// @Description 为当前本地用户创建 API Key;secret 仅在本次响应中返回。
|
||
// @Tags api-keys
|
||
// @Accept json
|
||
// @Produce json
|
||
// @Security BearerAuth
|
||
// @Param input body store.CreateAPIKeyInput true "API Key 创建请求"
|
||
// @Success 201 {object} store.CreatedAPIKey
|
||
// @Failure 400 {object} ErrorEnvelope
|
||
// @Failure 401 {object} ErrorEnvelope
|
||
// @Failure 500 {object} ErrorEnvelope
|
||
// @Router /api/v1/api-keys [post]
|
||
func (s *Server) createAPIKey(w http.ResponseWriter, r *http.Request) {
|
||
user, _ := auth.UserFromContext(r.Context())
|
||
var input store.CreateAPIKeyInput
|
||
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
|
||
writeError(w, http.StatusBadRequest, "invalid json body")
|
||
return
|
||
}
|
||
created, err := s.store.CreateAPIKey(r.Context(), input, user)
|
||
if err != nil {
|
||
if errors.Is(err, store.ErrLocalUserRequired) {
|
||
writeError(w, http.StatusBadRequest, err.Error())
|
||
return
|
||
}
|
||
s.logger.Error("create api key failed", "error", err)
|
||
writeError(w, http.StatusInternalServerError, "create api key failed")
|
||
return
|
||
}
|
||
writeJSON(w, http.StatusCreated, created)
|
||
}
|
||
|
||
// disableAPIKey godoc
|
||
// @Summary 禁用 API Key
|
||
// @Description 禁用当前用户拥有的 API Key,保留记录但不再允许调用。
|
||
// @Tags api-keys
|
||
// @Produce json
|
||
// @Security BearerAuth
|
||
// @Param apiKeyID path string true "API Key ID"
|
||
// @Success 200 {object} store.APIKey
|
||
// @Failure 400 {object} ErrorEnvelope
|
||
// @Failure 401 {object} ErrorEnvelope
|
||
// @Failure 404 {object} ErrorEnvelope
|
||
// @Failure 500 {object} ErrorEnvelope
|
||
// @Router /api/v1/api-keys/{apiKeyID}/disable [patch]
|
||
func (s *Server) disableAPIKey(w http.ResponseWriter, r *http.Request) {
|
||
user, _ := auth.UserFromContext(r.Context())
|
||
item, err := s.store.DisableAPIKey(r.Context(), r.PathValue("apiKeyID"), user)
|
||
if err == nil {
|
||
writeJSON(w, http.StatusOK, item)
|
||
return
|
||
}
|
||
if errors.Is(err, store.ErrLocalUserRequired) {
|
||
writeError(w, http.StatusBadRequest, err.Error())
|
||
return
|
||
}
|
||
if store.IsNotFound(err) {
|
||
writeError(w, http.StatusNotFound, "api key not found")
|
||
return
|
||
}
|
||
s.logger.Error("disable api key failed", "error", err)
|
||
writeError(w, http.StatusInternalServerError, "disable api key failed")
|
||
}
|
||
|
||
// deleteAPIKey godoc
|
||
// @Summary 删除 API Key
|
||
// @Description 删除当前用户拥有的 API Key。
|
||
// @Tags api-keys
|
||
// @Produce json
|
||
// @Security BearerAuth
|
||
// @Param apiKeyID path string true "API Key ID"
|
||
// @Success 204 "No Content"
|
||
// @Failure 400 {object} ErrorEnvelope
|
||
// @Failure 401 {object} ErrorEnvelope
|
||
// @Failure 404 {object} ErrorEnvelope
|
||
// @Failure 500 {object} ErrorEnvelope
|
||
// @Router /api/v1/api-keys/{apiKeyID} [delete]
|
||
func (s *Server) deleteAPIKey(w http.ResponseWriter, r *http.Request) {
|
||
user, _ := auth.UserFromContext(r.Context())
|
||
err := s.store.DeleteAPIKey(r.Context(), r.PathValue("apiKeyID"), user)
|
||
if err == nil {
|
||
w.WriteHeader(http.StatusNoContent)
|
||
return
|
||
}
|
||
if errors.Is(err, store.ErrLocalUserRequired) {
|
||
writeError(w, http.StatusBadRequest, err.Error())
|
||
return
|
||
}
|
||
if store.IsNotFound(err) {
|
||
writeError(w, http.StatusNotFound, "api key not found")
|
||
return
|
||
}
|
||
s.logger.Error("delete api key failed", "error", err)
|
||
writeError(w, http.StatusInternalServerError, "delete api key failed")
|
||
}
|
||
|
||
// estimatePricing godoc
|
||
// @Summary 估算请求价格
|
||
// @Description 按当前用户、模型候选、任务类型和请求参数估算计费条目。
|
||
// @Tags pricing
|
||
// @Accept json
|
||
// @Produce json
|
||
// @Security BearerAuth
|
||
// @Param input body PricingEstimateRequest true "计费估算请求,kind 默认为 chat.completions"
|
||
// @Success 200 {object} PricingEstimateResponse
|
||
// @Failure 400 {object} ErrorEnvelope
|
||
// @Failure 401 {object} ErrorEnvelope
|
||
// @Failure 403 {object} ErrorEnvelope
|
||
// @Failure 404 {object} ErrorEnvelope
|
||
// @Failure 429 {object} ErrorEnvelope
|
||
// @Failure 500 {object} ErrorEnvelope
|
||
// @Router /api/v1/pricing/estimate [post]
|
||
func (s *Server) estimatePricing(w http.ResponseWriter, r *http.Request) {
|
||
user, _ := auth.UserFromContext(r.Context())
|
||
var body map[string]any
|
||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||
writeError(w, http.StatusBadRequest, "invalid json body")
|
||
return
|
||
}
|
||
model, _ := body["model"].(string)
|
||
kind, _ := body["kind"].(string)
|
||
if kind == "" {
|
||
kind = "chat.completions"
|
||
}
|
||
if model == "" {
|
||
writeError(w, http.StatusBadRequest, "model is required")
|
||
return
|
||
}
|
||
if !apiKeyScopeAllowed(user, kind) {
|
||
writeError(w, http.StatusForbidden, "api key scope does not allow this capability")
|
||
return
|
||
}
|
||
estimate, err := s.runner.Estimate(r.Context(), kind, model, body, user)
|
||
if err != nil {
|
||
if errors.Is(err, store.ErrNoModelCandidate) {
|
||
writeError(w, statusFromRunError(err), err.Error(), store.ModelCandidateErrorCode(err))
|
||
return
|
||
}
|
||
s.logger.Error("estimate pricing failed", "error", err)
|
||
writeError(w, http.StatusInternalServerError, "estimate pricing failed")
|
||
return
|
||
}
|
||
writeJSON(w, http.StatusOK, estimate)
|
||
}
|
||
|
||
// listRateLimitWindows godoc
|
||
// @Summary 列出限流窗口
|
||
// @Description 管理端查看当前运行时限流窗口状态。
|
||
// @Tags runtime
|
||
// @Produce json
|
||
// @Security BearerAuth
|
||
// @Success 200 {object} RateLimitWindowListResponse
|
||
// @Failure 401 {object} ErrorEnvelope
|
||
// @Failure 403 {object} ErrorEnvelope
|
||
// @Failure 500 {object} ErrorEnvelope
|
||
// @Router /api/admin/runtime/rate-limit-windows [get]
|
||
func (s *Server) listRateLimitWindows(w http.ResponseWriter, r *http.Request) {
|
||
items, err := s.store.ListRateLimitWindows(r.Context())
|
||
if err != nil {
|
||
s.logger.Error("list rate limit windows failed", "error", err)
|
||
writeError(w, http.StatusInternalServerError, "list rate limit windows failed")
|
||
return
|
||
}
|
||
writeJSON(w, http.StatusOK, map[string]any{"items": items})
|
||
}
|
||
|
||
// listModelRateLimitStatuses godoc
|
||
// @Summary 列出模型限流状态
|
||
// @Description 管理端查看平台模型维度的限流和冷却状态。
|
||
// @Tags runtime
|
||
// @Produce json
|
||
// @Security BearerAuth
|
||
// @Success 200 {object} ModelRateLimitStatusListResponse
|
||
// @Failure 401 {object} ErrorEnvelope
|
||
// @Failure 403 {object} ErrorEnvelope
|
||
// @Failure 500 {object} ErrorEnvelope
|
||
// @Router /api/admin/runtime/model-rate-limits [get]
|
||
func (s *Server) listModelRateLimitStatuses(w http.ResponseWriter, r *http.Request) {
|
||
items, err := s.store.ListModelRateLimitStatuses(r.Context())
|
||
if err != nil {
|
||
s.logger.Error("list model rate limit statuses failed", "error", err)
|
||
writeError(w, http.StatusInternalServerError, "list model rate limit statuses failed")
|
||
return
|
||
}
|
||
writeJSON(w, http.StatusOK, map[string]any{"items": items})
|
||
}
|
||
|
||
// createTask godoc
|
||
// @Summary 创建或执行 AI 任务
|
||
// @Description 网关任务接口按 model 选择平台模型;/api/v1 路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。
|
||
// @Tags tasks
|
||
// @Accept json
|
||
// @Produce json
|
||
// @Security BearerAuth
|
||
// @Param X-Async header bool false "true 时异步创建任务并返回 202"
|
||
// @Param input body TaskRequest true "AI 任务请求,字段随任务类型变化"
|
||
// @Success 200 {object} CompatibleResponse
|
||
// @Success 202 {object} TaskAcceptedResponse
|
||
// @Failure 400 {object} ErrorEnvelope
|
||
// @Failure 401 {object} ErrorEnvelope
|
||
// @Failure 402 {object} ErrorEnvelope
|
||
// @Failure 403 {object} ErrorEnvelope
|
||
// @Failure 404 {object} ErrorEnvelope
|
||
// @Failure 429 {object} ErrorEnvelope
|
||
// @Failure 502 {object} ErrorEnvelope
|
||
// @Router /api/v1/chat/completions [post]
|
||
// @Router /api/v1/responses [post]
|
||
// @Router /api/v1/images/generations [post]
|
||
// @Router /api/v1/images/edits [post]
|
||
// @Router /api/v1/videos/generations [post]
|
||
// @Router /chat/completions [post]
|
||
// @Router /v1/chat/completions [post]
|
||
// @Router /responses [post]
|
||
// @Router /v1/responses [post]
|
||
// @Router /images/generations [post]
|
||
// @Router /v1/images/generations [post]
|
||
// @Router /images/edits [post]
|
||
// @Router /v1/images/edits [post]
|
||
func (s *Server) createTask(kind string, compatible bool) http.Handler {
|
||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
user, ok := auth.UserFromContext(r.Context())
|
||
if !ok {
|
||
writeError(w, http.StatusUnauthorized, "unauthorized")
|
||
return
|
||
}
|
||
|
||
var body map[string]any
|
||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||
writeError(w, http.StatusBadRequest, "invalid json body")
|
||
return
|
||
}
|
||
model, _ := body["model"].(string)
|
||
if model == "" {
|
||
writeError(w, http.StatusBadRequest, "model is required")
|
||
return
|
||
}
|
||
if !apiKeyScopeAllowed(user, kind) {
|
||
writeError(w, http.StatusForbidden, "api key scope does not allow this capability")
|
||
return
|
||
}
|
||
asyncMode := asyncRequest(r)
|
||
|
||
task, err := s.store.CreateTask(r.Context(), store.CreateTaskInput{
|
||
Kind: kind,
|
||
Model: model,
|
||
RunMode: runModeFromRequest(body),
|
||
Async: asyncMode,
|
||
Request: body,
|
||
}, user)
|
||
if err != nil {
|
||
s.logger.Error("create task failed", "kind", kind, "error", err)
|
||
writeError(w, http.StatusInternalServerError, "create task failed")
|
||
return
|
||
}
|
||
if asyncMode {
|
||
if err := s.runner.EnqueueAsyncTask(r.Context(), task); err != nil {
|
||
writeError(w, http.StatusInternalServerError, err.Error(), "enqueue_failed")
|
||
return
|
||
}
|
||
writeTaskAccepted(w, task)
|
||
return
|
||
}
|
||
runCtx, cancelRun := s.requestExecutionContext(r)
|
||
defer cancelRun()
|
||
if compatible {
|
||
if boolValue(body, "stream") {
|
||
flusher := prepareCompatibleStream(w)
|
||
result, runErr := s.runner.ExecuteStream(runCtx, task, user, func(delta string) error {
|
||
if !requestStillConnected(r) {
|
||
return nil
|
||
}
|
||
writeCompatibleDelta(w, kind, model, delta)
|
||
if flusher != nil {
|
||
flusher.Flush()
|
||
}
|
||
return nil
|
||
})
|
||
if runErr != nil {
|
||
if !requestStillConnected(r) {
|
||
return
|
||
}
|
||
status := statusFromRunError(runErr)
|
||
errorPayload := map[string]any{
|
||
"code": runErrorCode(runErr),
|
||
"message": runErrorMessage(runErr),
|
||
"status": status,
|
||
}
|
||
if result.Task.ID != "" {
|
||
errorPayload["taskId"] = result.Task.ID
|
||
}
|
||
if result.Task.RequestID != "" {
|
||
errorPayload["requestId"] = result.Task.RequestID
|
||
}
|
||
for key, value := range runErrorDetails(runErr) {
|
||
errorPayload[key] = value
|
||
}
|
||
sendSSE(w, "error", map[string]any{"error": errorPayload})
|
||
if flusher != nil {
|
||
flusher.Flush()
|
||
}
|
||
return
|
||
}
|
||
if !requestStillConnected(r) {
|
||
return
|
||
}
|
||
writeCompatibleDone(w, kind, model, result.Output)
|
||
if flusher != nil {
|
||
flusher.Flush()
|
||
}
|
||
return
|
||
}
|
||
result, runErr := s.runner.Execute(runCtx, task, user)
|
||
if runErr != nil {
|
||
if !requestStillConnected(r) {
|
||
return
|
||
}
|
||
writeErrorWithDetails(w, statusFromRunError(runErr), runErrorMessage(runErr), runErrorDetails(runErr), runErrorCode(runErr))
|
||
return
|
||
}
|
||
if !requestStillConnected(r) {
|
||
return
|
||
}
|
||
writeJSON(w, http.StatusOK, result.Output)
|
||
return
|
||
}
|
||
result, runErr := s.runner.Execute(runCtx, task, user)
|
||
if runErr != nil {
|
||
s.logger.Warn("task completed with failure", "kind", kind, "taskId", task.ID, "error", runErr)
|
||
}
|
||
|
||
if !requestStillConnected(r) {
|
||
return
|
||
}
|
||
writeTaskAccepted(w, result.Task)
|
||
})
|
||
}
|
||
|
||
func (s *Server) requestExecutionContext(r *http.Request) (context.Context, context.CancelFunc) {
|
||
base := context.WithoutCancel(r.Context())
|
||
if s.ctx == nil {
|
||
return base, func() {}
|
||
}
|
||
ctx, cancel := context.WithCancel(base)
|
||
go func() {
|
||
select {
|
||
case <-s.ctx.Done():
|
||
cancel()
|
||
case <-ctx.Done():
|
||
}
|
||
}()
|
||
return ctx, cancel
|
||
}
|
||
|
||
// requestStillConnected reports whether r's context is still live, meaning it
// has been neither cancelled nor timed out. For incoming server requests,
// net/http cancels this context when the client connection closes, so
// handlers use this check to skip writing responses nobody will read.
func requestStillConnected(r *http.Request) bool {
	// A context's Err is nil exactly until its Done channel is closed.
	return r.Context().Err() == nil
}
|
||
|
||
// asyncRequest reports whether the client opted into asynchronous execution
// through the X-Async header. The header value is lower-cased and trimmed;
// "1", "true", "yes" and "on" opt in. Any other value, including a missing
// header, means synchronous execution.
func asyncRequest(r *http.Request) bool {
	switch strings.TrimSpace(strings.ToLower(r.Header.Get("x-async"))) {
	case "1", "true", "yes", "on":
		return true
	default:
		return false
	}
}
|
||
|
||
func writeTaskAccepted(w http.ResponseWriter, task store.GatewayTask) {
|
||
writeJSON(w, http.StatusAccepted, map[string]any{
|
||
"taskId": task.ID,
|
||
"task": task,
|
||
"next": map[string]string{
|
||
"events": fmt.Sprintf("/api/v1/tasks/%s/events", task.ID),
|
||
"detail": fmt.Sprintf("/api/v1/tasks/%s", task.ID),
|
||
},
|
||
})
|
||
}
|
||
|
||
func apiKeyScopeAllowed(user *auth.User, kind string) bool {
|
||
if user == nil || strings.TrimSpace(user.APIKeyID) == "" || len(user.APIKeyScopes) == 0 {
|
||
return true
|
||
}
|
||
required := scopeForTaskKind(kind)
|
||
for _, scope := range user.APIKeyScopes {
|
||
scope = strings.TrimSpace(strings.ToLower(scope))
|
||
if scope == "*" || scope == "all" || scope == required {
|
||
return true
|
||
}
|
||
if required == "chat" && (scope == "text" || scope == "text_generate") {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// scopeForTaskKind maps a task kind to the API-key capability scope that
// authorises it. Chat and responses map to "chat", image generation and
// editing map to "image", and video generation maps to "video". Any other
// kind is its own scope and is returned unchanged.
func scopeForTaskKind(kind string) string {
	scope := kind
	switch kind {
	case "chat.completions", "responses":
		scope = "chat"
	case "images.generations", "images.edits":
		scope = "image"
	case "videos.generations":
		scope = "video"
	}
	return scope
}
|
||
|
||
func statusFromRunError(err error) int {
|
||
switch {
|
||
case store.ModelCandidateErrorCode(err) == "platform_cooling_down" || store.ModelCandidateErrorCode(err) == "model_cooling_down":
|
||
return http.StatusTooManyRequests
|
||
case errors.Is(err, store.ErrNoModelCandidate):
|
||
return http.StatusNotFound
|
||
case errors.Is(err, store.ErrRateLimited):
|
||
return http.StatusTooManyRequests
|
||
case clients.ErrorCode(err) == "rate_limit":
|
||
return http.StatusTooManyRequests
|
||
case errors.Is(err, store.ErrInsufficientWalletBalance):
|
||
return http.StatusPaymentRequired
|
||
default:
|
||
return http.StatusBadGateway
|
||
}
|
||
}
|
||
|
||
func runErrorCode(err error) string {
|
||
if errors.Is(err, store.ErrNoModelCandidate) {
|
||
return store.ModelCandidateErrorCode(err)
|
||
}
|
||
return clients.ErrorCode(err)
|
||
}
|
||
|
||
func runErrorMessage(err error) string {
|
||
if err == nil {
|
||
return ""
|
||
}
|
||
if summary := rateLimitErrorSummary(err); summary != "" {
|
||
return err.Error() + ";" + summary
|
||
}
|
||
return err.Error()
|
||
}
|
||
|
||
func runErrorDetails(err error) map[string]any {
|
||
if detail := rateLimitErrorDetail(err); len(detail) > 0 {
|
||
return map[string]any{"rateLimit": detail}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func rateLimitErrorSummary(err error) string {
|
||
var limitErr *store.RateLimitExceededError
|
||
if !errors.As(err, &limitErr) {
|
||
return ""
|
||
}
|
||
scopeLabel := "限流对象"
|
||
switch limitErr.ScopeType {
|
||
case "user_group":
|
||
scopeLabel = "用户组"
|
||
case "platform_model":
|
||
scopeLabel = "平台模型"
|
||
}
|
||
scopeName := strings.TrimSpace(limitErr.ScopeName)
|
||
if scopeName == "" {
|
||
scopeName = strings.TrimSpace(limitErr.ScopeKey)
|
||
}
|
||
if groupKey := stringValue(limitErr.ScopeMetadata["groupKey"]); limitErr.ScopeType == "user_group" && groupKey != "" && groupKey != scopeName {
|
||
scopeName = fmt.Sprintf("%s(%s)", scopeName, groupKey)
|
||
}
|
||
projected := limitErr.Projected
|
||
if projected <= 0 {
|
||
projected = limitErr.Current + limitErr.Amount
|
||
}
|
||
parts := []string{
|
||
fmt.Sprintf("限流摘要:%s %s 的 %s 超限", scopeLabel, scopeName, limitErr.Metric),
|
||
fmt.Sprintf("当前 %s,本次 %s,预计 %s,限制 %s", formatRateLimitValue(limitErr.Current), formatRateLimitValue(limitErr.Amount), formatRateLimitValue(projected), formatRateLimitValue(limitErr.Limit)),
|
||
}
|
||
if limitErr.WindowSeconds > 0 {
|
||
parts = append(parts, fmt.Sprintf("窗口 %d 秒", limitErr.WindowSeconds))
|
||
}
|
||
if limitErr.RetryAfter > 0 {
|
||
parts = append(parts, fmt.Sprintf("约%s后可重试", formatRateLimitDuration(limitErr.RetryAfter)))
|
||
} else if !limitErr.Retryable {
|
||
parts = append(parts, "该请求超过单次限额,不能排队重试")
|
||
}
|
||
return strings.Join(parts, ",")
|
||
}
|
||
|
||
func rateLimitErrorDetail(err error) map[string]any {
|
||
var limitErr *store.RateLimitExceededError
|
||
if !errors.As(err, &limitErr) {
|
||
return nil
|
||
}
|
||
detail := map[string]any{
|
||
"scopeType": limitErr.ScopeType,
|
||
"scopeKey": limitErr.ScopeKey,
|
||
"scopeName": limitErr.ScopeName,
|
||
"metric": limitErr.Metric,
|
||
"limit": limitErr.Limit,
|
||
"amount": limitErr.Amount,
|
||
"current": limitErr.Current,
|
||
"used": limitErr.Used,
|
||
"reserved": limitErr.Reserved,
|
||
"projected": limitErr.Projected,
|
||
"windowSeconds": limitErr.WindowSeconds,
|
||
"retryable": limitErr.Retryable,
|
||
"exceeded": map[string]any{
|
||
"metric": limitErr.Metric,
|
||
"current": limitErr.Current,
|
||
"amount": limitErr.Amount,
|
||
"projected": limitErr.Projected,
|
||
"limit": limitErr.Limit,
|
||
},
|
||
}
|
||
if limitErr.RetryAfter > 0 {
|
||
detail["retryAfterMs"] = limitErr.RetryAfter.Milliseconds()
|
||
}
|
||
if !limitErr.ResetAt.IsZero() {
|
||
detail["resetAt"] = limitErr.ResetAt.UTC().Format(time.RFC3339Nano)
|
||
}
|
||
if len(limitErr.Policy) > 0 {
|
||
detail["rateLimitPolicy"] = limitErr.Policy
|
||
if matchedRule := matchedRateLimitRule(limitErr.Policy, limitErr.Metric); len(matchedRule) > 0 {
|
||
detail["matchedRule"] = matchedRule
|
||
}
|
||
}
|
||
if len(limitErr.ScopeMetadata) > 0 {
|
||
detail["scopeMetadata"] = limitErr.ScopeMetadata
|
||
}
|
||
if limitErr.ScopeType == "user_group" {
|
||
userGroup := map[string]any{
|
||
"id": limitErr.ScopeKey,
|
||
"name": limitErr.ScopeName,
|
||
}
|
||
if groupKey := stringValue(limitErr.ScopeMetadata["groupKey"]); groupKey != "" {
|
||
userGroup["groupKey"] = groupKey
|
||
}
|
||
detail["userGroup"] = userGroup
|
||
}
|
||
return detail
|
||
}
|
||
|
||
// formatRateLimitValue renders a rate-limit counter in the shortest decimal
// form that round-trips, without an exponent. For example, 3 gives "3" and
// 1.5 gives "1.5".
func formatRateLimitValue(value float64) string {
	return string(strconv.AppendFloat(nil, value, 'f', -1, 64))
}
|
||
|
||
// formatRateLimitDuration renders a retry-after duration for the user-facing
// rate-limit summary. Durations under one second are shown as whole
// milliseconds ("250毫秒"); longer ones are shown as decimal seconds ("1.5秒").
//
// Positive durations are first rounded up to whole milliseconds. This keeps a
// retry hint from understating the wait, avoids "0毫秒" for sub-millisecond
// waits, and avoids nanosecond noise such as "1.999999999秒".
func formatRateLimitDuration(duration time.Duration) string {
	if rem := duration % time.Millisecond; duration > 0 && rem != 0 {
		// Guard against overflow for durations within 1ms of the maximum.
		if rounded := duration + (time.Millisecond - rem); rounded > duration {
			duration = rounded
		}
	}
	if duration < time.Second {
		return strconv.FormatInt(duration.Milliseconds(), 10) + "毫秒"
	}
	return strconv.FormatFloat(duration.Seconds(), 'f', -1, 64) + "秒"
}
|
||
|
||
func matchedRateLimitRule(policy map[string]any, metric string) map[string]any {
|
||
rules, _ := policy["rules"].([]any)
|
||
for _, rawRule := range rules {
|
||
rule, _ := rawRule.(map[string]any)
|
||
if stringValue(rule["metric"]) == metric {
|
||
return rule
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// listTasks godoc
|
||
// @Summary 列出任务
|
||
// @Description 按当前用户列出任务,支持关键字、模型类型、时间范围和分页过滤。
|
||
// @Tags tasks
|
||
// @Produce json
|
||
// @Security BearerAuth
|
||
// @Param q query string false "搜索关键字,别名 query"
|
||
// @Param modelType query string false "模型类型,别名 type"
|
||
// @Param createdFrom query string false "创建时间起点,支持 RFC3339 或日期格式,别名 from"
|
||
// @Param createdTo query string false "创建时间终点,支持 RFC3339 或日期格式,别名 to"
|
||
// @Param page query int false "页码" default(1)
|
||
// @Param pageSize query int false "每页数量,别名 limit" default(50)
|
||
// @Success 200 {object} TaskListResponse
|
||
// @Failure 400 {object} ErrorEnvelope
|
||
// @Failure 401 {object} ErrorEnvelope
|
||
// @Failure 500 {object} ErrorEnvelope
|
||
// @Router /api/workspace/tasks [get]
|
||
// @Router /api/v1/tasks [get]
|
||
func (s *Server) listTasks(w http.ResponseWriter, r *http.Request) {
|
||
user, ok := auth.UserFromContext(r.Context())
|
||
if !ok {
|
||
writeError(w, http.StatusUnauthorized, "unauthorized")
|
||
return
|
||
}
|
||
query := r.URL.Query()
|
||
page, err := positiveQueryInt(query.Get("page"), 1)
|
||
if err != nil {
|
||
writeError(w, http.StatusBadRequest, "invalid page")
|
||
return
|
||
}
|
||
pageSizeRaw := query.Get("pageSize")
|
||
if pageSizeRaw == "" {
|
||
pageSizeRaw = query.Get("limit")
|
||
}
|
||
pageSize, err := positiveQueryInt(pageSizeRaw, 50)
|
||
if err != nil {
|
||
writeError(w, http.StatusBadRequest, "invalid pageSize")
|
||
return
|
||
}
|
||
createdFrom, err := parseTaskListTime(query.Get("createdFrom"), query.Get("from"))
|
||
if err != nil {
|
||
writeError(w, http.StatusBadRequest, "invalid createdFrom")
|
||
return
|
||
}
|
||
createdTo, err := parseTaskListTime(query.Get("createdTo"), query.Get("to"))
|
||
if err != nil {
|
||
writeError(w, http.StatusBadRequest, "invalid createdTo")
|
||
return
|
||
}
|
||
result, err := s.store.ListTasks(r.Context(), user, store.TaskListFilter{
|
||
Query: firstNonEmpty(query.Get("q"), query.Get("query")),
|
||
ModelType: firstNonEmpty(query.Get("modelType"), query.Get("type")),
|
||
CreatedFrom: createdFrom,
|
||
CreatedTo: createdTo,
|
||
Page: page,
|
||
PageSize: pageSize,
|
||
})
|
||
if err != nil {
|
||
s.logger.Error("list tasks failed", "error", err)
|
||
writeError(w, http.StatusInternalServerError, "list tasks failed")
|
||
return
|
||
}
|
||
writeJSON(w, http.StatusOK, map[string]any{
|
||
"items": result.Items,
|
||
"total": result.Total,
|
||
"page": result.Page,
|
||
"pageSize": result.PageSize,
|
||
})
|
||
}
|
||
|
||
// positiveQueryInt parses a query-string integer that must be at least 1.
// Surrounding whitespace is ignored, and an empty value yields fallback. A
// non-integer or non-positive value is an error, and the returned int is
// then 0.
func positiveQueryInt(raw string, fallback int) (int, error) {
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return fallback, nil
	}
	if n, err := strconv.Atoi(trimmed); err == nil && n > 0 {
		return n, nil
	}
	return 0, errors.New("invalid positive integer")
}
|
||
|
||
func parseTaskListTime(values ...string) (*time.Time, error) {
|
||
raw := strings.TrimSpace(firstNonEmpty(values...))
|
||
if raw == "" {
|
||
return nil, nil
|
||
}
|
||
layouts := []string{time.RFC3339Nano, time.RFC3339, "2006-01-02T15:04", "2006-01-02 15:04:05", "2006-01-02"}
|
||
var lastErr error
|
||
for _, layout := range layouts {
|
||
parsed, err := time.ParseInLocation(layout, raw, time.Local)
|
||
if err == nil {
|
||
return &parsed, nil
|
||
}
|
||
lastErr = err
|
||
}
|
||
return nil, lastErr
|
||
}
|
||
|
||
// firstNonEmpty returns the first value that is non-blank after trimming
// surrounding whitespace; the result is the trimmed value. It returns "" when
// every value is blank or none are given.
func firstNonEmpty(values ...string) string {
	for _, candidate := range values {
		candidate = strings.TrimSpace(candidate)
		if candidate != "" {
			return candidate
		}
	}
	return ""
}
|
||
|
||
// boolValue returns body[key] when it holds a Go bool, which is how JSON
// booleans decode into map[string]any. A missing key, a nil map, or a value of
// any other type (including the string "true") yields false.
func boolValue(body map[string]any, key string) bool {
	flag, isBool := body[key].(bool)
	return isBool && flag
}
|
||
|
||
// getTask godoc
|
||
// @Summary 获取任务详情
|
||
// @Description 返回指定任务的请求、状态、输出和执行摘要。
|
||
// @Tags tasks
|
||
// @Produce json
|
||
// @Security BearerAuth
|
||
// @Param taskID path string true "任务 ID"
|
||
// @Success 200 {object} store.GatewayTask
|
||
// @Failure 401 {object} ErrorEnvelope
|
||
// @Failure 404 {object} ErrorEnvelope
|
||
// @Failure 500 {object} ErrorEnvelope
|
||
// @Router /api/workspace/tasks/{taskID} [get]
|
||
// @Router /api/v1/tasks/{taskID} [get]
|
||
func (s *Server) getTask(w http.ResponseWriter, r *http.Request) {
|
||
task, err := s.store.GetTask(r.Context(), r.PathValue("taskID"))
|
||
if err == nil {
|
||
writeJSON(w, http.StatusOK, task)
|
||
return
|
||
}
|
||
if store.IsNotFound(err) {
|
||
writeError(w, http.StatusNotFound, "task not found")
|
||
return
|
||
}
|
||
s.logger.Error("get task failed", "error", err)
|
||
writeError(w, http.StatusInternalServerError, "get task failed")
|
||
}
|
||
|
||
// taskParamPreprocessing godoc
|
||
// @Summary 获取任务参数预处理日志
|
||
// @Description 返回指定任务在执行前的参数改写、校验或模板处理日志。
|
||
// @Tags tasks
|
||
// @Produce json
|
||
// @Security BearerAuth
|
||
// @Param taskID path string true "任务 ID"
|
||
// @Success 200 {object} TaskParamPreprocessingLogListResponse
|
||
// @Failure 401 {object} ErrorEnvelope
|
||
// @Failure 404 {object} ErrorEnvelope
|
||
// @Failure 500 {object} ErrorEnvelope
|
||
// @Router /api/workspace/tasks/{taskID}/param-preprocessing [get]
|
||
// @Router /api/v1/tasks/{taskID}/param-preprocessing [get]
|
||
func (s *Server) taskParamPreprocessing(w http.ResponseWriter, r *http.Request) {
|
||
task, err := s.store.GetTask(r.Context(), r.PathValue("taskID"))
|
||
if err != nil {
|
||
if store.IsNotFound(err) {
|
||
writeError(w, http.StatusNotFound, "task not found")
|
||
return
|
||
}
|
||
s.logger.Error("get task failed", "error", err)
|
||
writeError(w, http.StatusInternalServerError, "get task failed")
|
||
return
|
||
}
|
||
logs, err := s.store.ListTaskParamPreprocessingLogs(r.Context(), task.ID)
|
||
if err != nil {
|
||
s.logger.Error("list task parameter preprocessing logs failed", "taskID", task.ID, "error", err)
|
||
writeError(w, http.StatusInternalServerError, "list task parameter preprocessing logs failed")
|
||
return
|
||
}
|
||
writeJSON(w, http.StatusOK, map[string]any{"items": logs})
|
||
}
|
||
|
||
// taskEvents godoc
|
||
// @Summary 订阅任务事件
|
||
// @Description 以 text/event-stream 返回指定任务的历史事件;无事件时返回 task.accepted 占位事件。
|
||
// @Tags tasks
|
||
// @Produce text/event-stream
|
||
// @Security BearerAuth
|
||
// @Param taskID path string true "任务 ID"
|
||
// @Success 200 {string} string "Server-Sent Events,data 为 store.TaskEvent 或 TaskAcceptedEvent"
|
||
// @Failure 401 {object} ErrorEnvelope
|
||
// @Failure 404 {object} ErrorEnvelope
|
||
// @Failure 500 {object} ErrorEnvelope
|
||
// @Router /api/workspace/tasks/{taskID}/events [get]
|
||
// @Router /api/v1/tasks/{taskID}/events [get]
|
||
func (s *Server) taskEvents(w http.ResponseWriter, r *http.Request) {
|
||
task, err := s.store.GetTask(r.Context(), r.PathValue("taskID"))
|
||
if err != nil {
|
||
if store.IsNotFound(err) {
|
||
writeError(w, http.StatusNotFound, "task not found")
|
||
return
|
||
}
|
||
writeError(w, http.StatusInternalServerError, "get task failed")
|
||
return
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "text/event-stream")
|
||
w.Header().Set("Cache-Control", "no-cache")
|
||
w.Header().Set("Connection", "keep-alive")
|
||
|
||
events, err := s.store.ListTaskEvents(r.Context(), task.ID)
|
||
if err != nil {
|
||
s.logger.Error("list task events failed", "error", err)
|
||
return
|
||
}
|
||
for _, event := range events {
|
||
sendSSE(w, event.EventType, event)
|
||
if flusher, ok := w.(http.Flusher); ok {
|
||
flusher.Flush()
|
||
}
|
||
}
|
||
if len(events) == 0 {
|
||
sendSSE(w, "task.accepted", map[string]any{
|
||
"taskId": task.ID,
|
||
"status": task.Status,
|
||
})
|
||
}
|
||
}
|
||
|
||
// runModeFromRequest derives the task run mode from a decoded request body.
//
// Precedence:
//   - a non-blank "runMode" string;
//   - a non-blank "mode" string;
//   - "simulation" if the "simulation" or "testMode" boolean is true;
//   - otherwise "".
//
// String values are returned unmodified. Blank strings are treated as absent,
// so {"runMode": "", "simulation": true} resolves to "simulation".
// Previously an explicit blank runMode short-circuited to "" and hid the
// later keys.
func runModeFromRequest(body map[string]any) string {
	for _, key := range []string{"runMode", "mode"} {
		if value, ok := body[key].(string); ok && strings.TrimSpace(value) != "" {
			return value
		}
	}
	for _, key := range []string{"simulation", "testMode"} {
		if flag, ok := body[key].(bool); ok && flag {
			return "simulation"
		}
	}
	return ""
}
|