Compare commits

..

9 Commits

Author SHA1 Message Date
ba419cd90a fix task duration recording 2026-05-18 01:06:52 +08:00
d09a4c2e4d feat(clients): 添加Keling客户端支持视频生成和多模态功能
- 实现KelingClient结构体及其Run方法,支持视频生成功能
- 添加对image2video、text2video和omni-video任务类型的完整支持
- 集成Keling平台的身份验证机制,包括JWT令牌生成
- 实现视频任务的提交和轮询逻辑,支持异步处理流程
- 添加对多种输入格式的支持,包括图像帧、基础视频和参考素材
- 实现Keling元素管理和清理机制,处理临时创建的素材
- 在服务初始化中注册keling和kling两个提供商标识
- 添加数据库迁移脚本,更新Keling模型的音频功能配置
- 完善错误处理和重试机制,提升服务稳定性
- 编写完整的单元测试,覆盖各种视频生成场景和边界情况
2026-05-17 22:08:55 +08:00
90c3315468 fix(runner): record failed task attempts 2026-05-17 20:50:20 +08:00
ae197a742f docs(api): 同步 /api/v1/chat/completions 的 OpenAPI 与同步响应
补充 Chat Completions 的兼容响应模型与路由注释,确保 /api/v1/chat/completions 按同步兼容格式返回并更新对应测试与 Swagger 文档。
2026-05-16 00:19:39 +08:00
34c3251c6d docs(api): 补全 OpenAPI 上传与系统设置文档
为文件上传、静态资源和文件存储设置接口补齐注释,并同步更新生成的 Swagger 文档。
2026-05-15 09:59:25 +08:00
62d426bdfb Merge pull request 'chore(dev): 配置本地开发环境' (#1) from chore/devenv-setup into main
Reviewed-on: #1
2026-05-15 09:46:35 +08:00
7abb6a1baf Merge pull request '补全 API OpenAPI 文档' (#2) from feature/openapi-docs into main
Reviewed-on: #2
2026-05-15 09:46:18 +08:00
5d3e543cba Merge remote-tracking branch 'origin/main' into feature/openapi-docs
# Conflicts:
#	apps/api/internal/httpapi/handlers.go
2026-05-15 09:42:44 +08:00
918dfbfee1 docs(api): 补全 OpenAPI 注释与生成文档
为接口、模型与脚本补齐 Swagger/OpenAPI 注释,生成最新文档,并增加一键生成与查看入口。
2026-05-14 18:18:27 +08:00
35 changed files with 19549 additions and 136 deletions

View File

@ -43,6 +43,15 @@ pnpm dev
后端热更新可通过 `GO_WATCH_SHUTDOWN_GRACE_MS``GO_WATCH_RESTART_DELAY_MS` 调整旧进程退出等待时间与重启间隔。
## OpenAPI 文档
修改 `apps/api/internal/httpapi` 下的接口、请求或响应类型后,请重新执行:
```bash
pnpm openapi
```
默认 EasyAI 部署里,`easyai-pgvector` 在容器网络内的连接串是:
```dotenv

View File

@ -15,6 +15,16 @@ import (
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
// @title EasyAI AI Gateway API
// @version 0.1.0
// @description EasyAI AI Gateway 的本地鉴权、平台模型管理、定价、运行策略、钱包和 AI 任务接口。
// @description 受保护接口使用 Authorization: Bearer <JWT 或 API Key>,管理接口只接受 JWT 用户凭证。
// @BasePath /
// @schemes http https
// @securityDefinitions.apikey BearerAuth
// @in header
// @name Authorization
// @description Bearer JWT 或 API Key。
func main() {
cfg := config.Load()
logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{

9808
apps/api/docs/swagger.json Normal file

File diff suppressed because it is too large Load Diff

6425
apps/api/docs/swagger.yaml Normal file

File diff suppressed because it is too large Load Diff

View File

@ -110,6 +110,7 @@ func TestOpenAIClientChatContract(t *testing.T) {
t.Fatalf("decode request: %v", err)
}
gotModel, _ = body["model"].(string)
time.Sleep(25 * time.Millisecond)
_ = json.NewEncoder(w).Encode(map[string]any{
"id": "chatcmpl-test",
"object": "chat.completion",
@ -145,6 +146,9 @@ func TestOpenAIClientChatContract(t *testing.T) {
if response.RequestID != "req-chat-test" || response.ResponseStartedAt.IsZero() || response.ResponseFinishedAt.IsZero() {
t.Fatalf("response metadata was not captured: %+v", response)
}
if response.ResponseDurationMS < 20 {
t.Fatalf("response duration should include upstream latency, got %dms", response.ResponseDurationMS)
}
}
func TestOpenAIClientChatStreamContract(t *testing.T) {
@ -662,6 +666,266 @@ func TestVolcesClientVideoResumePollsExistingTaskID(t *testing.T) {
}
}
func TestKelingClientVideoSubmitsAndPollsImageTask(t *testing.T) {
var submitPath string
var pollPath string
var gotAuth string
var submittedTaskID string
var submittedPayload map[string]any
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotAuth = r.Header.Get("Authorization")
switch r.Method + " " + r.URL.Path {
case "POST /videos/image2video":
submitPath = r.URL.Path
if err := json.NewDecoder(r.Body).Decode(&submittedPayload); err != nil {
t.Fatalf("decode keling submit: %v", err)
}
if _, ok := submittedPayload["aspect_ratio"]; ok {
t.Fatalf("image2video payload should not include aspect_ratio: %+v", submittedPayload)
}
_ = json.NewEncoder(w).Encode(map[string]any{
"code": 0,
"request_id": "req-submit",
"data": map[string]any{"task_id": "keling-task-1"},
})
case "GET /videos/image2video/keling-task-1":
pollPath = r.URL.Path
_ = json.NewEncoder(w).Encode(map[string]any{
"code": 0,
"request_id": "req-poll",
"data": map[string]any{
"task_id": "keling-task-1",
"task_status": "succeed",
"created_at": 456,
"task_result": map[string]any{
"videos": []any{map[string]any{"url": "https://example.com/keling.mp4", "duration": 6}},
},
},
})
default:
t.Fatalf("unexpected request %s %s", r.Method, r.URL.Path)
}
}))
defer server.Close()
response, err := (KelingClient{HTTPClient: server.Client()}).Run(context.Background(), Request{
Kind: "videos.generations",
ModelType: "image_to_video",
Model: "可灵2.6",
Body: map[string]any{
"model": "可灵2.6",
"prompt": "A clean product reveal",
"first_frame": "data:image/png;base64,Zmlyc3Q=",
"last_frame": "data:image/png;base64,bGFzdA==",
"duration": 6,
"resolution": "1080p",
"aspect_ratio": "16:9",
"audio": true,
"camera_control": "simple:zoom",
"camera_control_strength": 0.6,
},
Candidate: store.RuntimeModelCandidate{
BaseURL: server.URL,
Provider: "keling",
AuthType: "AccessKey-SecretKey",
ModelName: "可灵2.6",
ProviderModelName: "kling-v2-6",
Credentials: map[string]any{"accessKey": "ak", "secretKey": "sk"},
PlatformConfig: map[string]any{
"kelingPollIntervalMs": 100,
"kelingPollTimeoutSeconds": 1,
},
},
OnRemoteTaskSubmitted: func(remoteTaskID string, payload map[string]any) error {
submittedTaskID = remoteTaskID
if payload["endpoint"] != "/videos/image2video" || payload["taskType"] != "image2video" {
t.Fatalf("unexpected submitted keling payload: %+v", payload)
}
return nil
},
})
if err != nil {
t.Fatalf("run keling video: %v", err)
}
if submitPath != "/videos/image2video" || pollPath != "/videos/image2video/keling-task-1" || !strings.HasPrefix(gotAuth, "Bearer ") {
t.Fatalf("unexpected keling paths/auth submit=%s poll=%s auth=%s", submitPath, pollPath, gotAuth)
}
if submittedTaskID != "keling-task-1" {
t.Fatalf("remote task submit callback did not receive task id, got %q", submittedTaskID)
}
if submittedPayload["model_name"] != "kling-v2-6" ||
submittedPayload["prompt"] != "A clean product reveal" ||
submittedPayload["duration"] != "6" ||
submittedPayload["mode"] != "pro" ||
submittedPayload["sound"] != "on" ||
submittedPayload["image"] != "Zmlyc3Q=" ||
submittedPayload["image_tail"] != "bGFzdA==" {
t.Fatalf("unexpected keling submit payload: %+v", submittedPayload)
}
camera, _ := submittedPayload["camera_control"].(map[string]any)
config, _ := camera["config"].(map[string]any)
if camera["type"] != "simple" || numericValue(config["zoom"], 0) != 0.6 || numericValue(config["pan"], -1) != 0 {
t.Fatalf("unexpected keling camera conversion: %+v", submittedPayload["camera_control"])
}
data, _ := response.Result["data"].([]any)
item, _ := data[0].(map[string]any)
if response.Result["upstream_task_id"] != "keling-task-1" || item["url"] != "https://example.com/keling.mp4" || item["video_url"] != "https://example.com/keling.mp4" {
t.Fatalf("unexpected keling response: %+v", response.Result)
}
}
func TestKelingOmniPayloadConvertsGatewayContent(t *testing.T) {
payload, cleanupIDs, err := (KelingClient{}).kelingOmniPayload(context.Background(), Request{
Kind: "videos.generations",
ModelType: "omni_video",
Model: "可灵V3多模态",
Body: map[string]any{
"model": "可灵V3多模态",
"duration": 8,
"aspect_ratio": "9:16",
"resolution": "2160p",
"audio": true,
"content": []any{
map[string]any{"type": "text", "text": "Refine the base video"},
map[string]any{"type": "image_url", "role": "first_frame", "image_url": map[string]any{"url": "https://example.com/first.png"}},
map[string]any{"type": "image_url", "role": "last_frame", "image_url": map[string]any{"url": "https://example.com/last.png"}},
map[string]any{
"type": "video_url",
"role": "video_base",
"video_url": map[string]any{
"url": "https://example.com/base.mp4",
"keep_original_sound": "yes",
},
},
},
},
Candidate: store.RuntimeModelCandidate{
Provider: "keling",
ProviderModelName: "kling-v3-omni",
Capabilities: map[string]any{"omni_video": map[string]any{}},
},
}, "token")
if err != nil {
t.Fatalf("build keling omni payload: %v", err)
}
if len(cleanupIDs) != 0 {
t.Fatalf("unexpected cleanup ids: %+v", cleanupIDs)
}
if payload["model_name"] != "kling-v3-omni" || payload["mode"] != "4k" || payload["prompt"] != "Refine the base video" {
t.Fatalf("unexpected keling omni base fields: %+v", payload)
}
if _, ok := payload["sound"]; ok {
t.Fatalf("omni payload with base video should not include sound: %+v", payload)
}
if _, ok := payload["duration"]; ok {
t.Fatalf("base video edit should not include duration: %+v", payload)
}
if _, ok := payload["aspect_ratio"]; ok {
t.Fatalf("base video edit should not include aspect_ratio: %+v", payload)
}
watermark, _ := payload["watermark_info"].(map[string]any)
if watermark["enabled"] != false {
t.Fatalf("keling watermark should be disabled by default: %+v", payload)
}
images, _ := payload["image_list"].([]any)
if len(images) != 2 {
t.Fatalf("unexpected keling image_list: %+v", payload["image_list"])
}
firstImage, _ := images[0].(map[string]any)
lastImage, _ := images[1].(map[string]any)
if firstImage["type"] != "first_frame" || lastImage["type"] != "end_frame" {
t.Fatalf("frame roles should convert to keling omni types: %+v", images)
}
videos, _ := payload["video_list"].([]map[string]any)
if len(videos) != 1 || videos[0]["refer_type"] != "base" || videos[0]["keep_original_sound"] != "yes" {
t.Fatalf("video roles should convert to keling omni refer_type: %+v", payload["video_list"])
}
}
func TestKelingClientVideoResumePollsWithoutSubmitting(t *testing.T) {
var submitCalled bool
var pollPath string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.Method + " " + r.URL.Path {
case "POST /general/custom-elements", "POST /videos/omni-video":
submitCalled = true
t.Fatalf("resume should not submit or upload temporary elements")
case "GET /videos/omni-video/keling-existing":
pollPath = r.URL.Path
_ = json.NewEncoder(w).Encode(map[string]any{
"code": 0,
"request_id": "req-resume",
"data": map[string]any{
"task_id": "keling-existing",
"task_status": "succeed",
"task_result": map[string]any{
"videos": []any{map[string]any{"url": "https://example.com/resumed-keling.mp4"}},
},
},
})
default:
t.Fatalf("unexpected request %s %s", r.Method, r.URL.Path)
}
}))
defer server.Close()
response, err := (KelingClient{HTTPClient: server.Client()}).Run(context.Background(), Request{
Kind: "videos.generations",
ModelType: "omni_video",
Model: "可灵V3多模态",
Body: map[string]any{"prompt": "resume", "pollIntervalMs": 100, "pollTimeoutSeconds": 1},
RemoteTaskID: "keling-existing",
RemoteTaskPayload: map[string]any{
"endpoint": "/videos/omni-video",
},
Candidate: store.RuntimeModelCandidate{
BaseURL: server.URL,
Provider: "keling",
AuthType: "AccessKey-SecretKey",
ProviderModelName: "kling-v3-omni",
Credentials: map[string]any{"accessKey": "ak", "secretKey": "sk"},
},
})
if err != nil {
t.Fatalf("resume keling video: %v", err)
}
if submitCalled || pollPath != "/videos/omni-video/keling-existing" {
t.Fatalf("resume should poll existing task only, submit=%v poll=%s", submitCalled, pollPath)
}
data, _ := response.Result["data"].([]any)
item, _ := data[0].(map[string]any)
if response.Result["upstream_task_id"] != "keling-existing" || item["url"] != "https://example.com/resumed-keling.mp4" {
t.Fatalf("unexpected resumed keling response: %+v", response.Result)
}
}
func TestKelingElementPayloadMapsTags(t *testing.T) {
payload := kelingCreateElementPayload(map[string]any{
"name": "subject",
"frontal_image_url": "https://example.com/front.png",
"tags": []any{"character", "unknown"},
"refer_images": []any{
map[string]any{"url": "https://example.com/side.png"},
},
})
if payload["element_name"] != "subject" || payload["element_frontal_image"] != "https://example.com/front.png" {
t.Fatalf("unexpected element payload base fields: %+v", payload)
}
tags, _ := payload["tag_list"].([]any)
if len(tags) != 2 {
t.Fatalf("unexpected tag list: %+v", payload["tag_list"])
}
firstTag, _ := tags[0].(map[string]any)
secondTag, _ := tags[1].(map[string]any)
if firstTag["tag_id"] != "o_102" || secondTag["tag_id"] != "o_108" {
t.Fatalf("unexpected keling tag conversion: %+v", payload["tag_list"])
}
refs, _ := payload["element_refer_list"].([]any)
if len(refs) != 1 {
t.Fatalf("unexpected element references: %+v", payload["element_refer_list"])
}
}
func extractText(result map[string]any) string {
choices, _ := result["choices"].([]any)
choice, _ := choices[0].(map[string]any)

View File

@ -27,11 +27,11 @@ func (c GeminiClient) Run(ctx context.Context, request Request) (Response, error
return Response{}, err
}
req.Header.Set("Content-Type", "application/json")
responseStartedAt := time.Now()
resp, err := httpClient(request.HTTPClient, c.HTTPClient).Do(req)
if err != nil {
return Response{}, &ClientError{Code: "network", Message: err.Error(), Retryable: true}
}
responseStartedAt := time.Now()
requestID := requestIDFromHTTPResponse(resp)
result, err := decodeHTTPResponse(resp)
responseFinishedAt := time.Now()

View File

@ -0,0 +1,960 @@
package clients
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"math"
"net/http"
"sort"
"strings"
"time"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
"github.com/golang-jwt/jwt/v5"
)
type KelingClient struct {
HTTPClient *http.Client
}
type kelingPreparedTask struct {
Endpoint string
Payload map[string]any
RemoteTaskPayload map[string]any
CleanupElementIDs []string
}
func (c KelingClient) Run(ctx context.Context, request Request) (Response, error) {
if request.Kind != "videos.generations" {
return Response{}, &ClientError{Code: "unsupported_kind", Message: "unsupported keling request kind", Retryable: false}
}
token, err := kelingAuthToken(request.Candidate)
if err != nil {
return Response{}, err
}
return c.runVideo(ctx, request, token)
}
func (c KelingClient) runVideo(ctx context.Context, request Request, token string) (Response, error) {
submitStartedAt := time.Now()
submitRequestID := strings.TrimSpace(request.RemoteTaskID)
upstreamTaskID := strings.TrimSpace(request.RemoteTaskID)
prepared := kelingResumePreparedTask(request)
if upstreamTaskID == "" {
var err error
prepared, err = c.prepareVideoTask(ctx, request, token)
if err != nil {
return Response{}, err
}
}
defer func() {
if upstreamTaskID == "" {
_ = c.cleanupKelingElements(context.WithoutCancel(ctx), request, token, prepared.CleanupElementIDs)
}
}()
if upstreamTaskID == "" {
submitResult, requestID, err := c.postJSON(ctx, request, prepared.Endpoint, token, prepared.Payload)
submitRequestID = requestID
if err != nil {
return Response{}, annotateResponseError(err, submitRequestID, submitStartedAt, time.Now())
}
upstreamTaskID = strings.TrimSpace(stringFromAny(kelingData(submitResult)["task_id"]))
if upstreamTaskID == "" {
_ = c.cleanupKelingElements(context.WithoutCancel(ctx), request, token, prepared.CleanupElementIDs)
return Response{}, &ClientError{Code: "invalid_response", Message: "keling video task id is missing", RequestID: submitRequestID, Retryable: false}
}
prepared.RemoteTaskPayload["submit"] = submitResult
if request.OnRemoteTaskSubmitted != nil {
if err := request.OnRemoteTaskSubmitted(upstreamTaskID, prepared.RemoteTaskPayload); err != nil {
return Response{}, err
}
}
}
pollEndpoint := kelingPollEndpoint(request, prepared.Endpoint)
interval := kelingPollInterval(request)
timeout := kelingPollTimeout(request)
deadline := time.NewTimer(timeout)
defer deadline.Stop()
ticker := time.NewTicker(interval)
defer ticker.Stop()
var lastResult map[string]any
for {
select {
case <-ctx.Done():
return Response{}, &ClientError{Code: "cancelled", Message: ctx.Err().Error(), RequestID: submitRequestID, Retryable: true}
default:
}
pollStartedAt := time.Now()
pollResult, pollRequestID, err := c.getJSON(ctx, request, pollEndpoint+"/"+upstreamTaskID, token)
pollFinishedAt := time.Now()
requestID := firstNonEmpty(pollRequestID, submitRequestID, upstreamTaskID)
if err != nil {
return Response{}, annotateResponseError(err, requestID, pollStartedAt, pollFinishedAt)
}
lastResult = pollResult
switch kelingTaskStatus(pollResult) {
case "succeed":
_ = c.cleanupKelingElements(context.WithoutCancel(ctx), request, token, prepared.CleanupElementIDs)
prepared.CleanupElementIDs = nil
result := kelingVideoSuccessResult(request, upstreamTaskID, pollResult)
return Response{
Result: result,
RequestID: requestID,
Progress: kelingVideoProgress(request, upstreamTaskID),
ResponseStartedAt: submitStartedAt,
ResponseFinishedAt: pollFinishedAt,
ResponseDurationMS: responseDurationMS(submitStartedAt, pollFinishedAt),
}, nil
case "failed":
_ = c.cleanupKelingElements(context.WithoutCancel(ctx), request, token, prepared.CleanupElementIDs)
prepared.CleanupElementIDs = nil
return Response{}, &ClientError{
Code: kelingTaskErrorCode(pollResult),
Message: kelingTaskErrorMessage(request.Candidate, pollResult),
RequestID: requestID,
ResponseStartedAt: submitStartedAt,
ResponseFinishedAt: pollFinishedAt,
ResponseDurationMS: responseDurationMS(submitStartedAt, pollFinishedAt),
Retryable: false,
}
}
select {
case <-ctx.Done():
return Response{}, &ClientError{Code: "cancelled", Message: ctx.Err().Error(), RequestID: requestID, Retryable: true}
case <-deadline.C:
return Response{}, &ClientError{
Code: "timeout",
Message: fmt.Sprintf("keling video task %s did not finish before timeout; last status: %s", upstreamTaskID, kelingTaskStatus(lastResult)),
RequestID: requestID,
Retryable: true,
}
case <-ticker.C:
}
}
}
func (c KelingClient) prepareVideoTask(ctx context.Context, request Request, token string) (kelingPreparedTask, error) {
if kelingIsOmniRequest(request) {
payload, cleanupIDs, err := c.kelingOmniPayload(ctx, request, token)
if err != nil {
return kelingPreparedTask{}, err
}
return kelingPreparedTask{
Endpoint: "/videos/omni-video",
Payload: payload,
RemoteTaskPayload: map[string]any{"endpoint": "/videos/omni-video", "mode": "omni_video", "cleanupElementIds": cleanupIDs},
CleanupElementIDs: cleanupIDs,
}, nil
}
payload, taskType, err := kelingVideoPayload(ctx, request)
if err != nil {
return kelingPreparedTask{}, err
}
endpoint := "/videos/" + taskType
return kelingPreparedTask{
Endpoint: endpoint,
Payload: payload,
RemoteTaskPayload: map[string]any{"endpoint": endpoint, "taskType": taskType},
}, nil
}
func kelingResumePreparedTask(request Request) kelingPreparedTask {
endpoint := ""
for _, key := range []string{"endpoint", "pollEndpoint"} {
if value := strings.TrimSpace(stringFromAny(request.RemoteTaskPayload[key])); value != "" {
endpoint = value
break
}
}
if endpoint == "" {
if kelingIsOmniRequest(request) {
endpoint = "/videos/omni-video"
} else {
endpoint = "/videos/" + kelingTaskTypeFromRequest(request)
}
}
return kelingPreparedTask{Endpoint: endpoint, RemoteTaskPayload: map[string]any{"endpoint": endpoint}}
}
func kelingVideoPayload(ctx context.Context, request Request) (map[string]any, string, error) {
body := cleanProviderBody(request.Body)
content := contentItems(body["content"])
if len(content) == 0 {
content = buildVolcesContentFromBody(body)
}
prompt := firstKelingPrompt(content)
if prompt == "" {
return nil, "", &ClientError{Code: "invalid_parameter", Message: "keling video prompt is required", StatusCode: 400, Retryable: false}
}
firstFrame, lastFrame, referenceImages := kelingImageInputs(content)
isImage2Video := firstFrame != "" || lastFrame != "" || len(referenceImages) > 0
primaryImage := firstFrame
if primaryImage == "" && len(referenceImages) <= 1 && len(referenceImages) > 0 {
primaryImage = referenceImages[0]
}
if primaryImage == "" {
primaryImage = lastFrame
}
payload := map[string]any{
"prompt": prompt,
"model_name": upstreamModelName(request.Candidate),
"duration": fmtDuration(body["duration"], 5),
}
if value := strings.TrimSpace(stringFromAny(body["negative_prompt"])); value != "" {
payload["negative_prompt"] = value
}
if value, ok := body["cfg_scale"]; ok && numericValue(value, 0) > 0 {
payload["cfg_scale"] = value
}
if boolValue(body, "audio") || boolValue(body, "output_audio") {
payload["sound"] = "on"
}
if mode := kelingModeByResolution(firstNonEmptyStringValue(body, "resolution", "size")); mode != "" {
payload["mode"] = mode
}
if ratio := strings.TrimSpace(firstNonEmptyStringValue(body, "aspect_ratio", "aspectRatio", "ratio")); strings.Contains(ratio, ":") {
payload["aspect_ratio"] = ratio
}
if camera := kelingCameraControl(body); camera != nil {
payload["camera_control"] = camera
}
if primaryImage != "" {
encoded, err := kelingImageToBase64(ctx, request, primaryImage)
if err != nil {
return nil, "", err
}
payload["image"] = encoded
}
if lastFrame != "" {
encoded, err := kelingImageToBase64(ctx, request, lastFrame)
if err != nil {
return nil, "", err
}
payload["image_tail"] = encoded
}
if len(referenceImages) > 0 {
imageList := make([]any, 0, len(referenceImages))
for _, url := range referenceImages {
encoded, err := kelingImageToBase64(ctx, request, url)
if err != nil {
return nil, "", err
}
imageList = append(imageList, map[string]any{"image": encoded})
}
payload["image_list"] = imageList
}
if !strings.Contains(stringFromAny(payload["aspect_ratio"]), ":") || isImage2Video {
delete(payload, "aspect_ratio")
}
taskType := "text2video"
if primaryImage != "" {
taskType = "image2video"
} else if len(referenceImages) > 1 {
taskType = "multi-image2video"
}
return payload, taskType, nil
}
func kelingTaskTypeFromRequest(request Request) string {
body := cleanProviderBody(request.Body)
content := contentItems(body["content"])
if len(content) == 0 {
content = buildVolcesContentFromBody(body)
}
firstFrame, lastFrame, referenceImages := kelingImageInputs(content)
if firstFrame != "" || lastFrame != "" || len(referenceImages) == 1 {
return "image2video"
}
if len(referenceImages) > 1 {
return "multi-image2video"
}
return "text2video"
}
func (c KelingClient) kelingOmniPayload(ctx context.Context, request Request, token string) (map[string]any, []string, error) {
body := cleanProviderBody(request.Body)
content := contentItems(body["content"])
if len(content) == 0 {
content = buildVolcesContentFromBody(body)
}
prompt := firstKelingPrompt(content)
images := kelingOmniImageList(content)
videos := kelingOmniVideoList(content)
uploadedElementIDs := make([]string, 0)
elements, createdIDs, err := c.kelingOmniElementList(ctx, request, token, content)
if err != nil {
return nil, nil, err
}
uploadedElementIDs = append(uploadedElementIDs, createdIDs...)
shots := kelingShotPrompts(content)
hasMultiPrompt := len(shots) > 0
hasVideo := len(videos) > 0
hasVideoEdit := kelingHasBaseVideo(videos)
hasFirstFrame := kelingHasFirstFrame(images)
payload := map[string]any{
"model_name": upstreamModelName(request.Candidate),
"mode": kelingModeByResolution(firstNonEmptyStringValue(body, "resolution", "size")),
"watermark_info": map[string]any{"enabled": false},
"negative_prompt": strings.TrimSpace(stringFromAny(body["negative_prompt"])),
}
if !hasMultiPrompt {
payload["prompt"] = prompt
if body["duration"] != nil {
payload["duration"] = fmtDuration(body["duration"], 0)
}
}
if ratio := strings.TrimSpace(firstNonEmptyStringValue(body, "aspect_ratio", "aspectRatio", "ratio")); strings.Contains(ratio, ":") {
payload["aspect_ratio"] = ratio
}
if len(images) > 0 {
payload["image_list"] = images
}
if len(videos) > 0 {
payload["video_list"] = videos
}
if len(elements) > 0 {
payload["element_list"] = elements
}
if (boolValue(body, "audio") || boolValue(body, "output_audio")) && !hasVideo {
payload["sound"] = "on"
}
if hasMultiPrompt {
payload["multi_shot"] = true
payload["shot_type"] = "customize"
total := 0.0
multiPrompt := make([]any, 0, len(shots))
for index, shot := range shots {
duration := shot.duration
if duration <= 0 {
duration = 5
}
total += duration
multiPrompt = append(multiPrompt, map[string]any{
"index": index + 1,
"prompt": shot.text,
"duration": fmtDuration(duration, 5),
})
}
delete(payload, "prompt")
payload["multi_prompt"] = multiPrompt
payload["duration"] = fmtDuration(total, 0)
}
deleteEmptyStringFields(payload)
if hasVideoEdit {
delete(payload, "duration")
delete(payload, "aspect_ratio")
}
if hasVideo && !hasVideoEdit && !strings.Contains(stringFromAny(payload["aspect_ratio"]), ":") {
payload["aspect_ratio"] = "16:9"
}
if !hasVideoEdit && !hasFirstFrame && !strings.Contains(stringFromAny(payload["aspect_ratio"]), ":") {
payload["aspect_ratio"] = "16:9"
}
return payload, uploadedElementIDs, nil
}
func (c KelingClient) kelingOmniElementList(ctx context.Context, request Request, token string, content []map[string]any) ([]any, []string, error) {
elements := make([]any, 0)
createdIDs := make([]string, 0)
for _, item := range content {
if stringFromAny(item["type"]) != "element" {
continue
}
element := mapFromAny(item["element"])
if element == nil {
continue
}
if id := kelingStringFromAny(firstPresent(element["element_id"], element["id"])); id != "" {
elements = append(elements, map[string]any{"element_id": id})
continue
}
inline := mapFromAny(element["inline_element"])
if inline == nil {
continue
}
payload := kelingCreateElementPayload(inline)
if payload == nil {
continue
}
id, err := c.createKelingElement(ctx, request, token, payload)
if err != nil {
return nil, createdIDs, err
}
elements = append(elements, map[string]any{"element_id": id})
createdIDs = append(createdIDs, id)
}
return elements, createdIDs, nil
}
func (c KelingClient) postJSON(ctx context.Context, request Request, path string, token string, body map[string]any) (map[string]any, string, error) {
raw, _ := json.Marshal(body)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, joinURL(request.Candidate.BaseURL, path), bytes.NewReader(raw))
if err != nil {
return nil, "", err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+token)
resp, err := httpClient(request.HTTPClient, c.HTTPClient).Do(req)
if err != nil {
return nil, "", &ClientError{Code: "network", Message: err.Error(), Retryable: true}
}
requestID := requestIDFromHTTPResponse(resp)
result, err := decodeHTTPResponse(resp)
if err != nil {
return result, requestID, err
}
if code := intFromAny(result["code"]); code != 0 {
return result, requestID, &ClientError{Code: kelingEnvelopeErrorCode(result), Message: kelingEnvelopeErrorMessage(result), RequestID: firstNonEmpty(requestID, stringFromAny(result["request_id"])), Retryable: false}
}
return result, firstNonEmpty(requestID, stringFromAny(result["request_id"])), nil
}
func (c KelingClient) getJSON(ctx context.Context, request Request, path string, token string) (map[string]any, string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, joinURL(request.Candidate.BaseURL, path), nil)
if err != nil {
return nil, "", err
}
req.Header.Set("Authorization", "Bearer "+token)
resp, err := httpClient(request.HTTPClient, c.HTTPClient).Do(req)
if err != nil {
return nil, "", &ClientError{Code: "network", Message: err.Error(), Retryable: true}
}
requestID := requestIDFromHTTPResponse(resp)
result, err := decodeHTTPResponse(resp)
if err != nil {
return result, requestID, err
}
if code := intFromAny(result["code"]); code != 0 {
return result, requestID, &ClientError{Code: kelingEnvelopeErrorCode(result), Message: kelingEnvelopeErrorMessage(result), RequestID: firstNonEmpty(requestID, stringFromAny(result["request_id"])), Retryable: false}
}
return result, firstNonEmpty(requestID, stringFromAny(result["request_id"])), nil
}
func (c KelingClient) createKelingElement(ctx context.Context, request Request, token string, payload map[string]any) (string, error) {
raw, _ := json.Marshal(payload)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, joinURL(request.Candidate.BaseURL, "/general/custom-elements"), bytes.NewReader(raw))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+token)
resp, err := httpClient(request.HTTPClient, c.HTTPClient).Do(req)
if err != nil {
return "", &ClientError{Code: "network", Message: err.Error(), Retryable: true}
}
defer resp.Body.Close()
body, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024*1024))
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return "", &ClientError{Code: statusCodeName(resp.StatusCode), Message: errorMessage(body, resp.Status), StatusCode: resp.StatusCode, RequestID: requestIDFromHTTPResponse(resp), Retryable: HTTPRetryable(resp.StatusCode)}
}
var parsed struct {
Code int `json:"code"`
Message string `json:"message"`
RequestID string `json:"request_id"`
Data map[string]any `json:"data"`
}
decoder := json.NewDecoder(bytes.NewReader(body))
decoder.UseNumber()
if err := decoder.Decode(&parsed); err != nil {
return "", &ClientError{Code: "invalid_response", Message: err.Error(), Retryable: false}
}
if parsed.Code != 0 {
return "", &ClientError{Code: "keling_element_create_failed", Message: parsed.Message, RequestID: parsed.RequestID, Retryable: false}
}
id := kelingStringFromAny(parsed.Data["element_id"])
if id == "" {
return "", &ClientError{Code: "invalid_response", Message: "keling element id is missing", RequestID: parsed.RequestID, Retryable: false}
}
return id, nil
}
func (c KelingClient) cleanupKelingElements(ctx context.Context, request Request, token string, elementIDs []string) error {
for _, id := range elementIDs {
id = strings.TrimSpace(id)
if id == "" {
continue
}
_, _, _ = c.postJSON(ctx, request, "/general/delete-elements", token, map[string]any{"element_id": id})
}
return nil
}
func kelingAuthToken(candidate store.RuntimeModelCandidate) (string, error) {
apiKey := credential(candidate.Credentials, "apiKey", "api_key", "key", "token")
accessKey := credential(candidate.Credentials, "accessKey", "access_key", "ak")
secretKey := credential(candidate.Credentials, "secretKey", "secret_key", "sk")
if accessKey != "" || secretKey != "" || strings.EqualFold(strings.TrimSpace(candidate.AuthType), "AccessKey-SecretKey") {
if accessKey == "" || secretKey == "" {
return "", &ClientError{Code: "missing_credentials", Message: "keling accessKey and secretKey are required", Retryable: false}
}
now := time.Now()
claims := jwt.MapClaims{
"iss": accessKey,
"exp": now.Add(30 * time.Minute).Unix(),
"nbf": now.Add(-5 * time.Second).Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
signed, err := token.SignedString([]byte(secretKey))
if err != nil {
return "", &ClientError{Code: "auth_failed", Message: err.Error(), Retryable: false}
}
return signed, nil
}
if apiKey == "" {
return "", &ClientError{Code: "missing_credentials", Message: "keling api key is required", Retryable: false}
}
return apiKey, nil
}
func kelingImageToBase64(ctx context.Context, request Request, value string) (string, error) {
value = strings.TrimSpace(value)
if value == "" {
return "", nil
}
if strings.HasPrefix(value, "data:") {
parts := strings.SplitN(value, ",", 2)
if len(parts) == 2 {
return strings.TrimSpace(parts[1]), nil
}
}
if strings.HasPrefix(value, "http://") || strings.HasPrefix(value, "https://") {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, value, nil)
if err != nil {
return "", err
}
resp, err := httpClient(request.HTTPClient).Do(req)
if err != nil {
return "", &ClientError{Code: "network", Message: err.Error(), Retryable: true}
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
raw, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
return "", &ClientError{Code: statusCodeName(resp.StatusCode), Message: errorMessage(raw, resp.Status), StatusCode: resp.StatusCode, RequestID: requestIDFromHTTPResponse(resp), Retryable: HTTPRetryable(resp.StatusCode)}
}
raw, err := io.ReadAll(io.LimitReader(resp.Body, 16*1024*1024))
if err != nil {
return "", &ClientError{Code: "network", Message: err.Error(), Retryable: true}
}
return base64.StdEncoding.EncodeToString(raw), nil
}
return value, nil
}
func kelingIsOmniRequest(request Request) bool {
modelType := strings.TrimSpace(request.ModelType)
return modelType == "omni_video" || modelType == "omni" ||
request.Candidate.Capabilities["omni_video"] != nil ||
request.Candidate.Capabilities["omni"] != nil
}
func firstKelingPrompt(content []map[string]any) string {
for _, item := range content {
if stringFromAny(item["type"]) == "text" && stringFromAny(item["role"]) != "shot_prompt" && item["shot_index"] == nil {
if text := strings.TrimSpace(stringFromAny(item["text"])); text != "" {
return text
}
}
}
return ""
}
func kelingImageInputs(content []map[string]any) (string, string, []string) {
firstFrame := ""
lastFrame := ""
references := make([]string, 0)
for _, item := range content {
if !isKelingImageContent(item) {
continue
}
url := kelingNestedURL(item, "image_url")
if url == "" {
continue
}
switch stringFromAny(item["role"]) {
case "first_frame":
if firstFrame == "" {
firstFrame = url
}
case "last_frame":
if lastFrame == "" {
lastFrame = url
}
default:
references = append(references, url)
}
}
return firstFrame, lastFrame, references
}
func kelingOmniImageList(content []map[string]any) []any {
out := make([]any, 0)
for _, item := range content {
if !isKelingImageContent(item) {
continue
}
url := kelingNestedURL(item, "image_url")
if url == "" {
continue
}
image := map[string]any{"image_url": url}
switch stringFromAny(item["role"]) {
case "first_frame":
image["type"] = "first_frame"
case "last_frame":
image["type"] = "end_frame"
}
out = append(out, image)
}
return out
}
func kelingOmniVideoList(content []map[string]any) []map[string]any {
out := make([]map[string]any, 0)
for _, item := range content {
if !isKelingVideoContent(item) {
continue
}
nested := mapFromAny(item["video_url"])
url := strings.TrimSpace(stringFromAny(nested["url"]))
if url == "" {
continue
}
video := map[string]any{"video_url": url}
referType := strings.TrimSpace(stringFromAny(nested["refer_type"]))
if referType == "" {
switch stringFromAny(item["role"]) {
case "video_base":
referType = "base"
case "video_feature", "reference_video":
referType = "feature"
}
}
if referType == "base" || referType == "feature" {
video["refer_type"] = referType
}
if keep := strings.TrimSpace(stringFromAny(nested["keep_original_sound"])); keep != "" {
video["keep_original_sound"] = keep
}
out = append(out, video)
}
return out
}
type kelingShotPrompt struct {
index int
text string
duration float64
}
func kelingShotPrompts(content []map[string]any) []kelingShotPrompt {
shots := make([]kelingShotPrompt, 0)
for index, item := range content {
if stringFromAny(item["type"]) != "text" {
continue
}
if stringFromAny(item["role"]) != "shot_prompt" && item["shot_index"] == nil {
continue
}
text := strings.TrimSpace(stringFromAny(item["text"]))
if text == "" {
continue
}
shotIndex := int(math.Floor(numericValue(item["shot_index"], float64(index))))
shots = append(shots, kelingShotPrompt{index: shotIndex, text: text, duration: numericValue(item["duration"], 5)})
}
sort.SliceStable(shots, func(i, j int) bool { return shots[i].index < shots[j].index })
return shots
}
func kelingHasBaseVideo(videos []map[string]any) bool {
for _, video := range videos {
if stringFromAny(video["refer_type"]) == "base" {
return true
}
}
return false
}
func kelingHasFirstFrame(images []any) bool {
for _, item := range images {
image := mapFromAny(item)
if stringFromAny(image["type"]) == "first_frame" {
return true
}
}
return false
}
func kelingCreateElementPayload(inline map[string]any) map[string]any {
frontURL := strings.TrimSpace(firstNonEmptyStringValue(inline, "frontal_image_url", "frontalImageUrl", "element_frontal_image", "image_url", "imageUrl", "url"))
if frontURL == "" {
return nil
}
name := firstNonEmptyStringValue(inline, "name", "element_name", "elementName")
if name == "" {
name = "temporary element"
}
payload := map[string]any{
"element_name": name,
"element_description": firstNonEmpty(firstNonEmptyStringValue(inline, "description"), name),
"element_frontal_image": frontURL,
}
referImages := make([]any, 0)
for _, ref := range mapListFromAny(firstPresent(inline["refer_images"], inline["referImages"], inline["element_refer_list"])) {
url := strings.TrimSpace(firstNonEmptyStringValue(ref, "url", "image_url", "imageUrl"))
if url != "" {
referImages = append(referImages, map[string]any{"image_url": url})
}
}
if len(referImages) > 0 {
payload["element_refer_list"] = referImages
}
if tags := kelingElementTagList(inline["tags"]); len(tags) > 0 {
payload["tag_list"] = tags
}
return payload
}
func kelingElementTagList(value any) []any {
mapping := map[string]string{
"hot_meme": "o_101",
"character": "o_102",
"animal": "o_103",
"prop": "o_104",
"costume": "o_105",
"scene": "o_106",
"effect": "o_107",
"other": "o_108",
}
out := make([]any, 0)
for _, tag := range stringListFromAny(value) {
id := mapping[strings.TrimSpace(tag)]
if id == "" {
id = mapping["other"]
}
out = append(out, map[string]any{"tag_id": id})
}
return out
}
func kelingNestedURL(item map[string]any, key string) string {
nested := mapFromAny(item[key])
if nested != nil {
if value := strings.TrimSpace(stringFromAny(nested["url"])); value != "" {
return value
}
}
return strings.TrimSpace(stringFromAny(item[key]))
}
func isKelingImageContent(item map[string]any) bool {
return stringFromAny(item["type"]) == "image_url" || mapFromAny(item["image_url"]) != nil || strings.TrimSpace(stringFromAny(item["image_url"])) != ""
}
func isKelingVideoContent(item map[string]any) bool {
return stringFromAny(item["type"]) == "video_url" || mapFromAny(item["video_url"]) != nil || strings.TrimSpace(stringFromAny(item["video_url"])) != ""
}
func kelingModeByResolution(resolution string) string {
switch strings.TrimSpace(resolution) {
case "2160p":
return "4k"
case "1080p":
return "pro"
case "480p", "720p", "":
return "std"
default:
if strings.HasSuffix(strings.TrimSpace(resolution), "p") {
return "std"
}
return ""
}
}
func kelingCameraControl(body map[string]any) map[string]any {
cameraControl := strings.TrimSpace(stringFromAny(body["camera_control"]))
if cameraControl == "" {
return nil
}
if strings.HasPrefix(cameraControl, "simple") {
directions := []string{"horizontal", "vertical", "pan", "tilt", "roll", "zoom"}
current := ""
parts := strings.SplitN(cameraControl, ":", 2)
if len(parts) == 2 {
current = parts[1]
}
strength := firstPresent(body["camera_control_strength"], body["cameraControlStrength"])
config := map[string]any{}
for _, direction := range directions {
if direction == current {
config[direction] = strength
} else {
config[direction] = 0
}
}
return map[string]any{"type": "simple", "config": config}
}
return map[string]any{"type": cameraControl}
}
func kelingData(result map[string]any) map[string]any {
data, _ := result["data"].(map[string]any)
if data == nil {
return map[string]any{}
}
return data
}
func kelingTaskStatus(result map[string]any) string {
return strings.ToLower(strings.TrimSpace(stringFromAny(kelingData(result)["task_status"])))
}
func kelingTaskErrorCode(result map[string]any) string {
if code := intFromAny(result["code"]); code != 0 {
return fmt.Sprintf("keling_%d", code)
}
return "keling_task_failed"
}
func kelingTaskErrorMessage(candidate store.RuntimeModelCandidate, result map[string]any) string {
message := strings.TrimSpace(stringFromAny(kelingData(result)["task_status_msg"]))
if message == "" {
message = strings.TrimSpace(stringFromAny(result["message"]))
}
if message == "" {
message = "keling video task failed"
}
return fmt.Sprintf("Platform:%s,Code:%v,requestId:%s,message:%s", candidate.Provider, result["code"], stringFromAny(result["request_id"]), message)
}
func kelingEnvelopeErrorCode(result map[string]any) string {
if code := intFromAny(result["code"]); code != 0 {
return fmt.Sprintf("keling_%d", code)
}
return "keling_error"
}
func kelingEnvelopeErrorMessage(result map[string]any) string {
if message := strings.TrimSpace(stringFromAny(result["message"])); message != "" {
return message
}
return "keling request failed"
}
func kelingVideoSuccessResult(request Request, upstreamTaskID string, raw map[string]any) map[string]any {
data := kelingData(raw)
taskResult, _ := data["task_result"].(map[string]any)
videos, _ := taskResult["videos"].([]any)
items := make([]any, 0, len(videos))
for _, rawVideo := range videos {
video := mapFromAny(rawVideo)
url := strings.TrimSpace(stringFromAny(video["url"]))
if url == "" {
continue
}
item := map[string]any{"url": url, "video_url": url, "type": "video"}
if duration := intFromAny(video["duration"]); duration > 0 {
item["duration"] = duration
}
items = append(items, item)
}
created := intFromAny(data["created_at"])
if created == 0 {
created = int(nowUnix())
}
return map[string]any{
"id": upstreamTaskID,
"object": "video.generation",
"created": created,
"model": upstreamModelName(request.Candidate),
"status": "succeeded",
"upstream_task_id": upstreamTaskID,
"data": items,
"raw": raw,
}
}
func kelingVideoProgress(request Request, upstreamTaskID string) []Progress {
progress := providerProgress(request)
progress = append(progress, Progress{
Phase: "polling_result",
Progress: 0.9,
Message: "keling video task completed",
Payload: map[string]any{"upstreamTaskId": upstreamTaskID},
})
return progress
}
func kelingPollEndpoint(request Request, fallback string) string {
for _, key := range []string{"endpoint", "pollEndpoint"} {
if value := strings.TrimSpace(stringFromAny(request.RemoteTaskPayload[key])); value != "" {
return value
}
}
return fallback
}
func kelingPollInterval(request Request) time.Duration {
ms := numericValue(firstPresent(request.Candidate.PlatformConfig["kelingPollIntervalMs"], request.Candidate.PlatformConfig["klingPollIntervalMs"], request.Body["pollIntervalMs"], request.Body["poll_interval_ms"]), 15000)
if ms < 100 {
ms = 100
}
return time.Duration(ms) * time.Millisecond
}
func kelingPollTimeout(request Request) time.Duration {
seconds := numericValue(firstPresent(request.Candidate.PlatformConfig["kelingPollTimeoutSeconds"], request.Candidate.PlatformConfig["klingPollTimeoutSeconds"], request.Body["pollTimeoutSeconds"], request.Body["poll_timeout_seconds"]), 600)
if seconds < 1 {
seconds = 600
}
return time.Duration(seconds) * time.Second
}
func fmtDuration(value any, fallback float64) string {
duration := numericValue(value, fallback)
if math.Abs(duration-math.Round(duration)) < 1e-9 {
return fmt.Sprintf("%d", int(math.Round(duration)))
}
return strings.TrimRight(strings.TrimRight(fmt.Sprintf("%.6f", duration), "0"), ".")
}
func deleteEmptyStringFields(payload map[string]any) {
for key, value := range payload {
if text, ok := value.(string); ok && strings.TrimSpace(text) == "" {
delete(payload, key)
}
}
}
func kelingStringFromAny(value any) string {
switch typed := value.(type) {
case json.Number:
return typed.String()
case float64:
if math.Abs(typed-math.Round(typed)) < 1e-9 {
return fmt.Sprintf("%.0f", typed)
}
return fmt.Sprintf("%v", typed)
case int:
return fmt.Sprintf("%d", typed)
case int64:
return fmt.Sprintf("%d", typed)
case string:
return strings.TrimSpace(typed)
default:
return ""
}
}

View File

@ -33,11 +33,11 @@ func (c OpenAIClient) Run(ctx context.Context, request Request) (Response, error
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
responseStartedAt := time.Now()
resp, err := httpClient(request.HTTPClient, c.HTTPClient).Do(req)
if err != nil {
return Response{}, &ClientError{Code: "network", Message: err.Error(), Retryable: true}
}
responseStartedAt := time.Now()
requestID := requestIDFromHTTPResponse(resp)
result, err := decodeOpenAIResponse(resp, stream, request.StreamDelta)
responseFinishedAt := time.Now()

View File

@ -146,5 +146,8 @@ func responseDurationMS(startedAt time.Time, finishedAt time.Time) int64 {
if duration < 0 {
return 0
}
if duration == 0 && finishedAt.After(startedAt) {
return 1
}
return duration
}

View File

@ -45,11 +45,11 @@ func (c VolcesClient) runImage(ctx context.Context, request Request, apiKey stri
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
responseStartedAt := time.Now()
resp, err := httpClient(request.HTTPClient, c.HTTPClient).Do(req)
if err != nil {
return Response{}, &ClientError{Code: "network", Message: err.Error(), Retryable: true}
}
responseStartedAt := time.Now()
requestID := requestIDFromHTTPResponse(resp)
result, err := decodeHTTPResponse(resp)
responseFinishedAt := time.Now()

View File

@ -10,6 +10,17 @@ import (
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
// listAccessRules godoc
// @Summary 列出访问规则
// @Description 管理端返回用户组、租户、用户或 API Key 到平台、平台模型、基础模型的访问规则。
// @Tags access-rules
// @Produce json
// @Security BearerAuth
// @Success 200 {object} AccessRuleListResponse
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/access-rules [get]
func (s *Server) listAccessRules(w http.ResponseWriter, r *http.Request) {
items, err := s.store.ListAccessRules(r.Context())
if err != nil {
@ -20,6 +31,17 @@ func (s *Server) listAccessRules(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, map[string]any{"items": items})
}
// listAPIKeyAccessRules godoc
// @Summary 列出 API Key 访问规则
// @Description 返回当前本地用户可管理的 API Key 访问规则。
// @Tags api-keys
// @Produce json
// @Security BearerAuth
// @Success 200 {object} AccessRuleListResponse
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/v1/api-keys/access-rules [get]
func (s *Server) listAPIKeyAccessRules(w http.ResponseWriter, r *http.Request) {
user, _ := auth.UserFromContext(r.Context())
items, err := s.store.ListAPIKeyAccessRules(r.Context(), user)
@ -35,6 +57,21 @@ func (s *Server) listAPIKeyAccessRules(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, map[string]any{"items": items})
}
// createAccessRule godoc
// @Summary 创建访问规则
// @Description 管理端创建一条访问控制规则。
// @Tags access-rules
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param input body store.AccessRuleInput true "访问规则请求"
// @Success 201 {object} store.AccessRule
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 409 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/access-rules [post]
func (s *Server) createAccessRule(w http.ResponseWriter, r *http.Request) {
var input store.AccessRuleInput
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
@ -58,6 +95,20 @@ func (s *Server) createAccessRule(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusCreated, item)
}
// batchAccessRules godoc
// @Summary 批量写入访问规则
// @Description 管理端为同一主体批量新增、更新或删除资源访问规则。
// @Tags access-rules
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param input body store.AccessRuleBatchInput true "访问规则批量请求"
// @Success 200 {object} AccessRuleListResponse
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/access-rules/batch [post]
func (s *Server) batchAccessRules(w http.ResponseWriter, r *http.Request) {
var input store.AccessRuleBatchInput
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
@ -77,6 +128,21 @@ func (s *Server) batchAccessRules(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, map[string]any{"items": items})
}
// batchAPIKeyAccessRules godoc
// @Summary 批量写入 API Key 访问规则
// @Description 当前本地用户为自己的 API Key 批量新增、更新或删除可访问资源。
// @Tags api-keys
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param input body store.AccessRuleBatchInput true "API Key 访问规则批量请求subjectType 必须为 api_key"
// @Success 200 {object} AccessRuleListResponse
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/v1/api-keys/access-rules/batch [post]
func (s *Server) batchAPIKeyAccessRules(w http.ResponseWriter, r *http.Request) {
user, _ := auth.UserFromContext(r.Context())
var input store.AccessRuleBatchInput
@ -109,6 +175,23 @@ func (s *Server) batchAPIKeyAccessRules(w http.ResponseWriter, r *http.Request)
writeJSON(w, http.StatusOK, map[string]any{"items": items})
}
// updateAccessRule godoc
// @Summary 更新访问规则
// @Description 管理端更新一条访问控制规则。
// @Tags access-rules
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param ruleID path string true "访问规则 ID"
// @Param input body store.AccessRuleInput true "访问规则请求"
// @Success 200 {object} store.AccessRule
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 409 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/access-rules/{ruleID} [patch]
func (s *Server) updateAccessRule(w http.ResponseWriter, r *http.Request) {
var input store.AccessRuleInput
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
@ -136,6 +219,19 @@ func (s *Server) updateAccessRule(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, item)
}
// deleteAccessRule godoc
// @Summary 删除访问规则
// @Description 管理端删除一条访问控制规则。
// @Tags access-rules
// @Produce json
// @Security BearerAuth
// @Param ruleID path string true "访问规则 ID"
// @Success 204 "No Content"
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/access-rules/{ruleID} [delete]
func (s *Server) deleteAccessRule(w http.ResponseWriter, r *http.Request) {
if err := s.store.DeleteAccessRule(r.Context(), r.PathValue("ruleID")); err != nil {
if store.IsNotFound(err) {

View File

@ -12,13 +12,29 @@ import (
)
type walletBalanceRequest struct {
Currency string `json:"currency"`
Balance float64 `json:"balance"`
Reason string `json:"reason"`
IdempotencyKey string `json:"idempotencyKey"`
Currency string `json:"currency" example:"USD"`
Balance float64 `json:"balance" example:"100"`
Reason string `json:"reason" example:"manual recharge"`
IdempotencyKey string `json:"idempotencyKey" example:"wallet-set-20260514-001"`
Metadata map[string]any `json:"metadata"`
}
// setUserWalletBalance godoc
// @Summary 设置用户钱包余额
// @Description 管理端把指定用户钱包余额调整到目标值并记录审计日志balance 不允许为负数。
// @Tags billing
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param userID path string true "用户 ID"
// @Param input body walletBalanceRequest true "钱包余额设置请求"
// @Success 200 {object} WalletAdjustmentResponse
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/users/{userID}/wallet [patch]
func (s *Server) setUserWalletBalance(w http.ResponseWriter, r *http.Request) {
actor, _ := auth.UserFromContext(r.Context())
var input walletBalanceRequest
@ -79,6 +95,23 @@ func (s *Server) setUserWalletBalance(w http.ResponseWriter, r *http.Request) {
})
}
// listAuditLogs godoc
// @Summary 列出审计日志
// @Description 管理端按分类、动作、目标类型和目标 ID 查询审计日志。
// @Tags billing
// @Produce json
// @Security BearerAuth
// @Param category query string false "审计分类"
// @Param action query string false "审计动作"
// @Param targetType query string false "目标类型"
// @Param targetId query string false "目标 ID"
// @Param limit query int false "返回数量" default(100)
// @Success 200 {object} AuditLogListResponse
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/audit-logs [get]
func (s *Server) listAuditLogs(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
limit, err := positiveQueryInt(query.Get("limit"), 100)

View File

@ -9,6 +9,15 @@ import (
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
// listCatalogProviders godoc
// @Summary 列出目录供应商
// @Description 返回模型目录使用的供应商元数据;公共路径和管理路径返回同一结构。
// @Tags catalog
// @Produce json
// @Success 200 {object} CatalogProviderListResponse
// @Failure 500 {object} ErrorEnvelope
// @Router /api/v1/public/catalog/providers [get]
// @Router /api/admin/catalog/providers [get]
func (s *Server) listCatalogProviders(w http.ResponseWriter, r *http.Request) {
items, err := s.store.ListCatalogProviders(r.Context())
if err != nil {
@ -19,6 +28,21 @@ func (s *Server) listCatalogProviders(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, map[string]any{"items": items})
}
// createCatalogProvider godoc
// @Summary 创建目录供应商
// @Description 管理端新增模型目录供应商providerKey 和 displayName 必填。
// @Tags catalog
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param input body store.CatalogProviderInput true "目录供应商请求"
// @Success 201 {object} store.CatalogProvider
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 409 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/catalog/providers [post]
func (s *Server) createCatalogProvider(w http.ResponseWriter, r *http.Request) {
var input store.CatalogProviderInput
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
@ -42,6 +66,23 @@ func (s *Server) createCatalogProvider(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusCreated, item)
}
// updateCatalogProvider godoc
// @Summary 更新目录供应商
// @Description 管理端更新目录供应商展示信息、图标和元数据。
// @Tags catalog
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param providerID path string true "目录供应商 ID"
// @Param input body store.CatalogProviderInput true "目录供应商请求"
// @Success 200 {object} store.CatalogProvider
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 409 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/catalog/providers/{providerID} [patch]
func (s *Server) updateCatalogProvider(w http.ResponseWriter, r *http.Request) {
var input store.CatalogProviderInput
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
@ -69,6 +110,19 @@ func (s *Server) updateCatalogProvider(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, item)
}
// deleteCatalogProvider godoc
// @Summary 删除目录供应商
// @Description 管理端删除目录供应商。
// @Tags catalog
// @Produce json
// @Security BearerAuth
// @Param providerID path string true "目录供应商 ID"
// @Success 204 "No Content"
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/catalog/providers/{providerID} [delete]
func (s *Server) deleteCatalogProvider(w http.ResponseWriter, r *http.Request) {
if err := s.store.DeleteCatalogProvider(r.Context(), r.PathValue("providerID")); err != nil {
if store.IsNotFound(err) {
@ -82,6 +136,15 @@ func (s *Server) deleteCatalogProvider(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}
// listBaseModels godoc
// @Summary 列出基础模型
// @Description 返回基础模型目录;公共路径和管理路径返回同一结构。
// @Tags catalog
// @Produce json
// @Success 200 {object} BaseModelListResponse
// @Failure 500 {object} ErrorEnvelope
// @Router /api/v1/public/catalog/base-models [get]
// @Router /api/admin/catalog/base-models [get]
func (s *Server) listBaseModels(w http.ResponseWriter, r *http.Request) {
items, err := s.store.ListBaseModels(r.Context())
if err != nil {
@ -92,6 +155,21 @@ func (s *Server) listBaseModels(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, map[string]any{"items": items})
}
// createBaseModel godoc
// @Summary 创建基础模型
// @Description 管理端新增基础模型目录项providerKey、providerModelName 和 modelType 必填。
// @Tags catalog
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param input body store.BaseModelInput true "基础模型请求"
// @Success 201 {object} store.BaseModel
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 409 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/catalog/base-models [post]
func (s *Server) createBaseModel(w http.ResponseWriter, r *http.Request) {
var input store.BaseModelInput
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
@ -115,6 +193,23 @@ func (s *Server) createBaseModel(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusCreated, item)
}
// updateBaseModel godoc
// @Summary 更新基础模型
// @Description 管理端更新基础模型目录项及能力、图标、默认快照等元数据。
// @Tags catalog
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param baseModelID path string true "基础模型 ID"
// @Param input body store.BaseModelInput true "基础模型请求"
// @Success 200 {object} store.BaseModel
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 409 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/catalog/base-models/{baseModelID} [patch]
func (s *Server) updateBaseModel(w http.ResponseWriter, r *http.Request) {
var input store.BaseModelInput
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
@ -142,6 +237,20 @@ func (s *Server) updateBaseModel(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, item)
}
// resetBaseModel godoc
// @Summary 重置基础模型
// @Description 将指定基础模型恢复为系统默认快照;无默认快照时返回 409。
// @Tags catalog
// @Produce json
// @Security BearerAuth
// @Param baseModelID path string true "基础模型 ID"
// @Success 200 {object} store.BaseModel
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 409 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/catalog/base-models/{baseModelID}/reset [post]
func (s *Server) resetBaseModel(w http.ResponseWriter, r *http.Request) {
item, err := s.store.ResetBaseModelToDefault(r.Context(), r.PathValue("baseModelID"))
if err != nil {
@ -160,6 +269,17 @@ func (s *Server) resetBaseModel(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, item)
}
// resetAllBaseModels godoc
// @Summary 重置全部基础模型
// @Description 将所有具备系统默认快照的基础模型恢复为默认配置。
// @Tags catalog
// @Produce json
// @Security BearerAuth
// @Success 200 {object} BaseModelListResponse
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/catalog/base-models/reset-all [post]
func (s *Server) resetAllBaseModels(w http.ResponseWriter, r *http.Request) {
items, err := s.store.ResetAllBaseModelsToDefault(r.Context())
if err != nil {
@ -170,6 +290,19 @@ func (s *Server) resetAllBaseModels(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, map[string]any{"items": items})
}
// deleteBaseModel godoc
// @Summary 删除基础模型
// @Description 管理端删除基础模型目录项。
// @Tags catalog
// @Produce json
// @Security BearerAuth
// @Param baseModelID path string true "基础模型 ID"
// @Success 204 "No Content"
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/catalog/base-models/{baseModelID} [delete]
func (s *Server) deleteBaseModel(w http.ResponseWriter, r *http.Request) {
if err := s.store.DeleteBaseModel(r.Context(), r.PathValue("baseModelID")); err != nil {
if store.IsNotFound(err) {

View File

@ -0,0 +1,114 @@
package httpapi
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/runner"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
func TestPlanTaskResponseTreatsAPIV1ChatCompletionsAsSynchronousCompatibleResponse(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/api/v1/chat/completions", nil)
req.Header.Set("X-Async", "true")
plan := planTaskResponse("chat.completions", false, map[string]any{"stream": true}, req)
if plan.asyncMode {
t.Fatal("/api/v1/chat/completions must not enter async task mode")
}
if !plan.compatibleMode {
t.Fatal("/api/v1/chat/completions should return OpenAI-compatible response payloads")
}
if !plan.streamMode {
t.Fatal("stream=true should select SSE streaming mode")
}
}
func TestPlanTaskResponseKeepsAsyncTaskModeForOtherAPIV1Tasks(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/api/v1/images/generations", nil)
req.Header.Set("X-Async", "true")
plan := planTaskResponse("images.generations", false, map[string]any{"stream": true}, req)
if !plan.asyncMode {
t.Fatal("non-chat /api/v1 task endpoints should keep X-Async task mode")
}
if plan.compatibleMode {
t.Fatal("non-compatible /api/v1 task endpoints should not return OpenAI-compatible payloads")
}
}
func TestWriteCompatibleTaskResponseReturnsJSONWhenStreamIsFalse(t *testing.T) {
executor := &fakeTaskExecutor{output: map[string]any{"id": "chatcmpl-test", "object": "chat.completion"}}
req := httptest.NewRequest(http.MethodPost, "/api/v1/chat/completions", nil)
recorder := httptest.NewRecorder()
writeCompatibleTaskResponse(context.Background(), recorder, req, executor, "chat.completions", "gpt-test", store.GatewayTask{ID: "task-test"}, &auth.User{}, false)
if recorder.Code != http.StatusOK {
t.Fatalf("status=%d want=%d body=%s", recorder.Code, http.StatusOK, recorder.Body.String())
}
if executor.executeCalls != 1 || executor.streamCalls != 0 {
t.Fatalf("expected non-stream execute only, got execute=%d stream=%d", executor.executeCalls, executor.streamCalls)
}
var body map[string]any
if err := json.Unmarshal(recorder.Body.Bytes(), &body); err != nil {
t.Fatalf("decode response body: %v body=%s", err, recorder.Body.String())
}
if body["object"] != "chat.completion" {
t.Fatalf("unexpected compatible JSON response: %+v", body)
}
}
func TestWriteCompatibleTaskResponseReturnsSSEWhenStreamIsTrue(t *testing.T) {
executor := &fakeTaskExecutor{
deltas: []string{"hel", "lo"},
output: map[string]any{"id": "chatcmpl-test", "object": "chat.completion"},
}
req := httptest.NewRequest(http.MethodPost, "/api/v1/chat/completions", nil)
recorder := httptest.NewRecorder()
writeCompatibleTaskResponse(context.Background(), recorder, req, executor, "chat.completions", "gpt-test", store.GatewayTask{ID: "task-test"}, &auth.User{}, true)
if executor.executeCalls != 0 || executor.streamCalls != 1 {
t.Fatalf("expected stream execute only, got execute=%d stream=%d", executor.executeCalls, executor.streamCalls)
}
if contentType := recorder.Header().Get("Content-Type"); contentType != "text/event-stream" {
t.Fatalf("Content-Type=%q want text/event-stream", contentType)
}
body := recorder.Body.String()
for _, want := range []string{"event: message", `"content":"hel"`, `"content":"lo"`, `"finish_reason":"stop"`} {
if !strings.Contains(body, want) {
t.Fatalf("SSE body missing %s: %s", want, body)
}
}
}
type fakeTaskExecutor struct {
executeCalls int
streamCalls int
deltas []string
output map[string]any
}
func (f *fakeTaskExecutor) Execute(context.Context, store.GatewayTask, *auth.User) (runner.Result, error) {
f.executeCalls++
return runner.Result{Output: f.output}, nil
}
func (f *fakeTaskExecutor) ExecuteStream(_ context.Context, _ store.GatewayTask, _ *auth.User, onDelta clients.StreamDelta) (runner.Result, error) {
f.streamCalls++
for _, delta := range f.deltas {
if err := onDelta(delta); err != nil {
return runner.Result{}, err
}
}
return runner.Result{Output: f.output}, nil
}

View File

@ -5,6 +5,16 @@ import (
"strings"
)
// getNetworkProxyConfig godoc
// @Summary 获取网络代理配置
// @Description 管理端查看服务当前使用的全局 HTTP 代理配置及来源。
// @Tags config
// @Produce json
// @Security BearerAuth
// @Success 200 {object} NetworkProxyConfigResponse
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Router /api/admin/config/network-proxy [get]
func (s *Server) getNetworkProxyConfig(w http.ResponseWriter, r *http.Request) {
globalHTTPProxy := strings.TrimSpace(s.cfg.GlobalHTTPProxy)
writeJSON(w, http.StatusOK, map[string]any{

View File

@ -316,13 +316,17 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil {
} `json:"task"`
}
defaultTextModel := "openai:gpt-4o-mini"
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
var apiV1Chat map[string]any
doAPIV1ChatCompletionAndLoadTask(t, ctx, testPool, server.URL, apiKeyResponse.Secret, map[string]any{
"model": defaultTextModel,
"runMode": "simulation",
"simulation": true,
"simulationDurationMs": 5,
"messages": []map[string]any{{"role": "user", "content": "ping"}},
}, http.StatusAccepted, &taskResponse)
}, "default-chat-"+suffixText, http.StatusOK, &apiV1Chat, &taskResponse.Task)
if apiV1Chat["object"] != "chat.completion" {
t.Fatalf("unexpected api v1 chat response: %+v", apiV1Chat)
}
if taskResponse.Task.ID == "" || taskResponse.Task.Status != "succeeded" || taskResponse.Task.RunMode != "simulation" {
t.Fatalf("unexpected task response: %+v", taskResponse.Task)
}
@ -513,13 +517,13 @@ LIMIT 1`).Scan(&gptImageModelTypesRaw); err != nil {
ErrorCode string `json:"errorCode"`
} `json:"task"`
}
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", chatOnlyAPIKeyResponse.Secret, map[string]any{
doAPIV1ChatCompletionAndLoadTask(t, ctx, testPool, server.URL, chatOnlyAPIKeyResponse.Secret, map[string]any{
"model": deniedModel,
"runMode": "simulation",
"simulation": true,
"simulationDurationMs": 5,
"messages": []map[string]any{{"role": "user", "content": "permission deny"}},
}, http.StatusAccepted, &deniedTask)
}, "permission-deny-"+suffixText, http.StatusNotFound, nil, &deniedTask.Task)
if deniedTask.Task.Status != "failed" || deniedTask.Task.ErrorCode != "no_model_candidate" {
t.Fatalf("deny access rule should hide denied model from runtime candidates: %+v", deniedTask.Task)
}
@ -561,13 +565,13 @@ LIMIT 1`).Scan(&gptImageModelTypesRaw); err != nil {
ErrorCode string `json:"errorCode"`
} `json:"task"`
}
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", chatOnlyAPIKeyResponse.Secret, map[string]any{
doAPIV1ChatCompletionAndLoadTask(t, ctx, testPool, server.URL, chatOnlyAPIKeyResponse.Secret, map[string]any{
"model": controlledModel,
"runMode": "simulation",
"simulation": true,
"simulationDurationMs": 5,
"messages": []map[string]any{{"role": "user", "content": "allow should block other keys"}},
}, http.StatusAccepted, &blockedControlledTask)
}, "permission-allow-block-"+suffixText, http.StatusNotFound, nil, &blockedControlledTask.Task)
if blockedControlledTask.Task.Status != "failed" || blockedControlledTask.Task.ErrorCode != "no_model_candidate" {
t.Fatalf("allow access rule should make the resource unavailable to unmatched subjects: %+v", blockedControlledTask.Task)
}
@ -586,13 +590,13 @@ LIMIT 1`).Scan(&gptImageModelTypesRaw); err != nil {
Status string `json:"status"`
} `json:"task"`
}
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", chatOnlyAPIKeyResponse.Secret, map[string]any{
doAPIV1ChatCompletionAndLoadTask(t, ctx, testPool, server.URL, chatOnlyAPIKeyResponse.Secret, map[string]any{
"model": controlledModel,
"runMode": "simulation",
"simulation": true,
"simulationDurationMs": 5,
"messages": []map[string]any{{"role": "user", "content": "allow should pass"}},
}, http.StatusAccepted, &allowedControlledTask)
}, "permission-allow-pass-"+suffixText, http.StatusOK, nil, &allowedControlledTask.Task)
if allowedControlledTask.Task.Status != "succeeded" {
t.Fatalf("matching allow access rule should make the controlled model usable: %+v", allowedControlledTask.Task)
}
@ -645,13 +649,13 @@ WHERE gateway_user_id = $1::uuid
FinalChargeAmount float64 `json:"finalChargeAmount"`
} `json:"task"`
}
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
doAPIV1ChatCompletionAndLoadTask(t, ctx, testPool, server.URL, apiKeyResponse.Secret, map[string]any{
"model": pricingModel,
"runMode": "simulation",
"simulation": true,
"simulationDurationMs": 5,
"messages": []map[string]any{{"role": "user", "content": "priced ping"}},
}, http.StatusAccepted, &pricingTask)
}, "pricing-chat-"+suffixText, http.StatusOK, nil, &pricingTask.Task)
if pricingTask.Task.Status != "succeeded" || !floatNear(pricingTask.Task.FinalChargeAmount, 0.028) {
t.Fatalf("custom pricing rule set should drive text billing, got task=%+v", pricingTask.Task)
}
@ -757,14 +761,14 @@ WHERE reference_type = 'gateway_task'
ErrorCode string `json:"errorCode"`
} `json:"task"`
}
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
doAPIV1ChatCompletionAndLoadTask(t, ctx, testPool, server.URL, apiKeyResponse.Secret, map[string]any{
"model": rateLimitedModel,
"runMode": "simulation",
"simulation": true,
"simulationDurationMs": 5,
"simulationProfile": "non_retryable_failure",
"messages": []map[string]any{{"role": "user", "content": "failed first"}},
}, http.StatusAccepted, &rateLimitFailedTask)
}, "rate-limit-failed-first-"+suffixText, http.StatusBadGateway, nil, &rateLimitFailedTask.Task)
if rateLimitFailedTask.Task.Status != "failed" || rateLimitFailedTask.Task.ErrorCode != "bad_request" {
t.Fatalf("failed rate-limited task should fail before consuming rpm: %+v", rateLimitFailedTask.Task)
}
@ -774,13 +778,13 @@ WHERE reference_type = 'gateway_task'
Status string `json:"status"`
} `json:"task"`
}
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
doAPIV1ChatCompletionAndLoadTask(t, ctx, testPool, server.URL, apiKeyResponse.Secret, map[string]any{
"model": rateLimitedModel,
"runMode": "simulation",
"simulation": true,
"simulationDurationMs": 5,
"messages": []map[string]any{{"role": "user", "content": "first"}},
}, http.StatusAccepted, &rateLimitTaskOne)
}, "rate-limit-first-"+suffixText, http.StatusOK, nil, &rateLimitTaskOne.Task)
if rateLimitTaskOne.Task.Status != "succeeded" {
t.Fatalf("first rate-limited task should succeed: %+v", rateLimitTaskOne.Task)
}
@ -790,13 +794,13 @@ WHERE reference_type = 'gateway_task'
ErrorCode string `json:"errorCode"`
} `json:"task"`
}
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
doAPIV1ChatCompletionAndLoadTask(t, ctx, testPool, server.URL, apiKeyResponse.Secret, map[string]any{
"model": rateLimitedModel,
"runMode": "simulation",
"simulation": true,
"simulationDurationMs": 5,
"messages": []map[string]any{{"role": "user", "content": "second"}},
}, http.StatusAccepted, &rateLimitTaskTwo)
}, "rate-limit-second-"+suffixText, http.StatusTooManyRequests, nil, &rateLimitTaskTwo.Task)
if rateLimitTaskTwo.Task.Status != "failed" || rateLimitTaskTwo.Task.ErrorCode != "rate_limit" {
t.Fatalf("runtime policy rate limit should fail second task with rate_limit: %+v", rateLimitTaskTwo.Task)
}
@ -808,12 +812,12 @@ WHERE reference_type = 'gateway_task'
AsyncMode bool `json:"asyncMode"`
} `json:"task"`
}
doJSONWithHeaders(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
doJSONWithHeaders(t, server.URL, http.MethodPost, "/api/v1/responses", apiKeyResponse.Secret, map[string]any{
"model": rateLimitedModel,
"runMode": "simulation",
"simulation": true,
"simulationDurationMs": 5,
"messages": []map[string]any{{"role": "user", "content": "async queued"}},
"input": "async queued",
}, map[string]string{"X-Async": "true"}, http.StatusAccepted, &asyncRateLimitTask)
if asyncRateLimitTask.TaskID == "" || asyncRateLimitTask.Task.ID != asyncRateLimitTask.TaskID || !asyncRateLimitTask.Task.AsyncMode {
t.Fatalf("async task response should expose task id and async mode: %+v", asyncRateLimitTask)
@ -984,11 +988,11 @@ WHERE reference_type = 'gateway_task'
Status string `json:"status"`
} `json:"task"`
}
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
doAPIV1ChatCompletionAndLoadTask(t, ctx, testPool, server.URL, apiKeyResponse.Secret, map[string]any{
"model": failoverModel,
"runMode": "simulation",
"messages": []map[string]any{{"role": "user", "content": "retry please"}},
}, http.StatusAccepted, &failoverTask)
}, "failover-chat-"+suffixText, http.StatusOK, nil, &failoverTask.Task)
if failoverTask.Task.Status != "succeeded" {
t.Fatalf("failover task should succeed through second client: %+v", failoverTask.Task)
}
@ -1103,13 +1107,13 @@ WHERE failed.id = $1::uuid`, failedPlatform.ID, successPlatform.ID, unrelatedPri
Status string `json:"status"`
} `json:"task"`
}
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
doAPIV1ChatCompletionAndLoadTask(t, ctx, testPool, server.URL, apiKeyResponse.Secret, map[string]any{
"model": degradeModel,
"runMode": "simulation",
"simulation": true,
"simulationDurationMs": 5,
"messages": []map[string]any{{"role": "user", "content": "degrade please"}},
}, http.StatusAccepted, &degradeTask)
}, "degrade-chat-"+suffixText, http.StatusOK, nil, &degradeTask.Task)
if degradeTask.Task.Status != "succeeded" {
t.Fatalf("degrade task should fail over after cooling down failed model: %+v", degradeTask.Task)
}
@ -1170,13 +1174,13 @@ WHERE m.platform_id = $1::uuid
ErrorCode string `json:"errorCode"`
} `json:"task"`
}
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
doAPIV1ChatCompletionAndLoadTask(t, ctx, testPool, server.URL, apiKeyResponse.Secret, map[string]any{
"model": autoDisableModel,
"runMode": "simulation",
"simulation": true,
"simulationDurationMs": 5,
"messages": []map[string]any{{"role": "user", "content": "disable please"}},
}, http.StatusAccepted, &autoDisableTask)
}, "auto-disable-chat-"+suffixText, http.StatusBadGateway, nil, &autoDisableTask.Task)
if autoDisableTask.Task.Status != "failed" || autoDisableTask.Task.ErrorCode != "invalid_api_key" {
t.Fatalf("auto disable task should fail with invalid_api_key: %+v", autoDisableTask.Task)
}
@ -1293,12 +1297,12 @@ WHERE m.platform_id = $1::uuid
AsyncMode bool `json:"asyncMode"`
} `json:"task"`
}
doJSONWithHeaders(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
doJSONWithHeaders(t, server.URL, http.MethodPost, "/api/v1/responses", apiKeyResponse.Secret, map[string]any{
"model": defaultTextModel,
"runMode": "simulation",
"simulation": true,
"simulationDurationMs": 2000,
"messages": []map[string]any{{"role": "user", "content": "river worker restart"}},
"input": "river worker restart",
}, map[string]string{"X-Async": "true"}, http.StatusAccepted, &restartAsyncTask)
if restartAsyncTask.TaskID == "" || !restartAsyncTask.Task.AsyncMode {
t.Fatalf("restart async task should be accepted as async: %+v", restartAsyncTask)
@ -1453,6 +1457,20 @@ func doJSONWithHeaders(t *testing.T, baseURL string, method string, path string,
}
}
func doAPIV1ChatCompletionAndLoadTask(t *testing.T, ctx context.Context, pool *pgxpool.Pool, baseURL string, token string, payload map[string]any, marker string, expectedStatus int, responseOut any, taskDetailOut any) string {
t.Helper()
payload["integrationTestMarker"] = marker
if responseOut == nil {
responseOut = &map[string]any{}
}
doJSON(t, baseURL, http.MethodPost, "/api/v1/chat/completions", token, payload, expectedStatus, responseOut)
taskID := waitForTaskIDByRequestField(t, ctx, pool, "integrationTestMarker", marker, 2*time.Second)
if taskDetailOut != nil {
doJSON(t, baseURL, http.MethodGet, "/api/v1/tasks/"+taskID, token, nil, http.StatusOK, taskDetailOut)
}
return taskID
}
type taskWaitDetail struct {
ID string `json:"id"`
Status string `json:"status"`
@ -1481,6 +1499,11 @@ func waitForTaskStatus(t *testing.T, baseURL string, token string, taskID string
}
func waitForTaskIDByRequestMarker(t *testing.T, ctx context.Context, pool *pgxpool.Pool, marker string, timeout time.Duration) string {
t.Helper()
return waitForTaskIDByRequestField(t, ctx, pool, "cancelTestId", marker, timeout)
}
func waitForTaskIDByRequestField(t *testing.T, ctx context.Context, pool *pgxpool.Pool, key string, value string, timeout time.Duration) string {
t.Helper()
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
@ -1488,15 +1511,15 @@ func waitForTaskIDByRequestMarker(t *testing.T, ctx context.Context, pool *pgxpo
err := pool.QueryRow(ctx, `
SELECT id::text
FROM gateway_tasks
WHERE request->>'cancelTestId' = $1
WHERE request->>$1 = $2
ORDER BY created_at DESC
LIMIT 1`, marker).Scan(&taskID)
LIMIT 1`, key, value).Scan(&taskID)
if err == nil && taskID != "" {
return taskID
}
time.Sleep(50 * time.Millisecond)
}
t.Fatalf("task with request marker %s was not created within %s", marker, timeout)
t.Fatalf("task with request %s=%s was not created within %s", key, value, timeout)
return ""
}
@ -1577,13 +1600,13 @@ func assertLoadAvoidanceSimulatedRetryChain(t *testing.T, ctx context.Context, t
ErrorCode string `json:"errorCode"`
} `json:"task"`
}
doJSON(t, baseURL, http.MethodPost, "/api/v1/chat/completions", runtimeToken, map[string]any{
doAPIV1ChatCompletionAndLoadTask(t, ctx, testPool, baseURL, runtimeToken, map[string]any{
"model": model,
"runMode": "simulation",
"simulation": true,
"simulationDurationMs": 5,
"messages": []map[string]any{{"role": "user", "content": "load avoidance retry chain"}},
}, http.StatusAccepted, &taskResponse)
}, "load-avoidance-"+suffixText, http.StatusBadGateway, nil, &taskResponse.Task)
if taskResponse.Task.ID == "" || taskResponse.Task.Status != "failed" || taskResponse.Task.ErrorCode != "bad_request" {
t.Fatalf("load avoidance task should only fail after avoided clients are retried, got %+v", taskResponse.Task)
}

View File

@ -11,6 +11,22 @@ import (
const maxGatewayUploadBytes = 256 << 20
// uploadFile godoc
// @Summary 上传文件
// @Description 上传文件到配置的文件存储通道;没有启用通道时回退到本地静态上传目录。单文件最大 256MiB。
// @Tags files
// @Accept multipart/form-data
// @Produce json
// @Security BearerAuth
// @Param file formData file true "要上传的文件"
// @Param source formData string false "上传来源标识" default(ai-gateway-openapi)
// @Success 200 {object} FileUploadResponse
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 502 {object} ErrorEnvelope
// @Failure 503 {object} ErrorEnvelope
// @Router /api/v1/files/upload [post]
// @Router /v1/files/upload [post]
func (s *Server) uploadFile(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxGatewayUploadBytes)
if err := r.ParseMultipartForm(32 << 20); err != nil {

View File

@ -13,9 +13,17 @@ import (
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/netproxy"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/runner"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
// health godoc
// @Summary 健康检查
// @Description 返回服务进程、运行环境和身份模式,供负载均衡或人工排障使用。
// @Tags system
// @Produce json
// @Success 200 {object} HealthResponse
// @Router /healthz [get]
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, map[string]any{
"ok": true,
@ -25,6 +33,14 @@ func (s *Server) health(w http.ResponseWriter, r *http.Request) {
})
}
// ready godoc
// @Summary 就绪检查
// @Description 检查 Postgres 是否可用;数据库不可用时返回 503。
// @Tags system
// @Produce json
// @Success 200 {object} ReadyResponse
// @Failure 503 {object} ErrorEnvelope
// @Router /readyz [get]
func (s *Server) ready(w http.ResponseWriter, r *http.Request) {
if err := s.store.Ping(r.Context()); err != nil {
writeError(w, http.StatusServiceUnavailable, "postgres unavailable")
@ -33,11 +49,33 @@ func (s *Server) ready(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, map[string]any{"ok": true})
}
// me godoc
// @Summary 获取当前用户
// @Description 返回鉴权中解析出的用户、租户、用户组和 API Key 上下文。
// @Tags auth
// @Produce json
// @Security BearerAuth
// @Success 200 {object} auth.User
// @Failure 401 {object} ErrorEnvelope
// @Router /api/v1/me [get]
func (s *Server) me(w http.ResponseWriter, r *http.Request) {
user, _ := auth.UserFromContext(r.Context())
writeJSON(w, http.StatusOK, user)
}
// register godoc
// @Summary 本地注册
// @Description 在 standalone 或 hybrid 身份模式下创建本地用户,并返回 24 小时 JWT。
// @Tags auth
// @Accept json
// @Produce json
// @Param input body store.LocalRegisterInput true "注册请求password 至少 8 位invitationCode 取决于部署策略"
// @Success 201 {object} AuthResponse
// @Failure 400 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 409 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/v1/auth/register [post]
func (s *Server) register(w http.ResponseWriter, r *http.Request) {
if !s.localIdentityEnabled() {
writeError(w, http.StatusForbidden, "local registration is disabled")
@ -69,6 +107,19 @@ func (s *Server) register(w http.ResponseWriter, r *http.Request) {
s.writeAuthResponse(w, http.StatusCreated, user)
}
// login godoc
// @Summary 本地登录
// @Description 使用用户名或邮箱登录本地账号,并返回 24 小时 JWT。
// @Tags auth
// @Accept json
// @Produce json
// @Param input body store.LocalLoginInput true "登录请求account 可为用户名或邮箱"
// @Success 200 {object} AuthResponse
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/v1/auth/login [post]
func (s *Server) login(w http.ResponseWriter, r *http.Request) {
if !s.localIdentityEnabled() {
writeError(w, http.StatusForbidden, "local login is disabled")
@ -136,6 +187,17 @@ func authUserFromGatewayUser(user store.GatewayUser) *auth.User {
}
}
// listPlatforms godoc
// @Summary 列出平台
// @Description 管理端返回所有接入平台及其优先级、定价和运行策略摘要。
// @Tags platforms
// @Produce json
// @Security BearerAuth
// @Success 200 {object} PlatformListResponse
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/platforms [get]
func (s *Server) listPlatforms(w http.ResponseWriter, r *http.Request) {
platforms, err := s.store.ListPlatforms(r.Context())
if err != nil {
@ -146,6 +208,16 @@ func (s *Server) listPlatforms(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, map[string]any{"items": platforms})
}
// listPlayablePlatforms godoc
// @Summary 列出可用平台
// @Description 按当前用户可访问模型过滤平台,仅返回启用且存在可访问模型的平台。
// @Tags playground
// @Produce json
// @Security BearerAuth
// @Success 200 {object} PlatformListResponse
// @Failure 401 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/v1/platforms [get]
func (s *Server) listPlayablePlatforms(w http.ResponseWriter, r *http.Request) {
user, _ := auth.UserFromContext(r.Context())
models, err := s.store.ListAccessiblePlatformModels(r.Context(), user)
@ -173,6 +245,20 @@ func (s *Server) listPlayablePlatforms(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, map[string]any{"items": filtered})
}
// createPlatform godoc
// @Summary 创建平台
// @Description 新增模型供应商平台配置credentials 会被服务端保存并在返回值中脱敏。
// @Tags platforms
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param input body store.CreatePlatformInput true "平台配置请求"
// @Success 201 {object} store.Platform
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/platforms [post]
func (s *Server) createPlatform(w http.ResponseWriter, r *http.Request) {
var input store.CreatePlatformInput
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
@ -209,6 +295,23 @@ func (s *Server) createPlatform(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusCreated, platform)
}
// updatePlatform godoc
// @Summary 更新平台
// @Description 覆盖指定平台的基础配置、凭证、优先级、定价和运行策略。
// @Tags platforms
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param platformID path string true "平台 ID"
// @Param input body store.CreatePlatformInput true "平台配置请求"
// @Success 200 {object} store.Platform
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 409 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/platforms/{platformID} [patch]
func (s *Server) updatePlatform(w http.ResponseWriter, r *http.Request) {
var input store.CreatePlatformInput
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
@ -253,6 +356,19 @@ func (s *Server) updatePlatform(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, platform)
}
// deletePlatform godoc
// @Summary 删除平台
// @Description 删除指定平台及关联配置;不存在时返回 404。
// @Tags platforms
// @Produce json
// @Security BearerAuth
// @Param platformID path string true "平台 ID"
// @Success 204 "No Content"
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/platforms/{platformID} [delete]
func (s *Server) deletePlatform(w http.ResponseWriter, r *http.Request) {
if err := s.store.DeletePlatform(r.Context(), r.PathValue("platformID")); err != nil {
if store.IsNotFound(err) {
@ -266,6 +382,23 @@ func (s *Server) deletePlatform(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}
// createPlatformModel godoc
// @Summary 创建平台模型
// @Description 为平台新增一个可路由模型;路径中的 platformID 会覆盖请求体 platformId。
// @Tags platform-models
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param platformID path string true "平台 ID使用 /api/admin/platforms/{platformID}/models 时由路径提供"
// @Param input body store.CreatePlatformModelInput true "平台模型配置请求"
// @Success 201 {object} store.PlatformModel
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/platforms/{platformID}/models [post]
// @Router /api/admin/platform-models [post]
func (s *Server) createPlatformModel(w http.ResponseWriter, r *http.Request) {
var input store.CreatePlatformModelInput
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
@ -292,6 +425,22 @@ func (s *Server) createPlatformModel(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusCreated, s.platformModelResponse(r.Context(), model))
}
// replacePlatformModels godoc
// @Summary 替换平台模型
// @Description 用请求体中的 models 列表整体替换指定平台下的模型配置。
// @Tags platform-models
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param platformID path string true "平台 ID"
// @Param input body ReplacePlatformModelsRequest true "模型列表请求"
// @Success 200 {object} PlatformModelListResponse
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/platforms/{platformID}/models [put]
func (s *Server) replacePlatformModels(w http.ResponseWriter, r *http.Request) {
platformID := r.PathValue("platformID")
if platformID == "" {
@ -320,6 +469,19 @@ func (s *Server) replacePlatformModels(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, map[string]any{"items": s.platformModelResponses(r.Context(), models)})
}
// deletePlatformModel godoc
// @Summary 删除平台模型
// @Description 删除指定平台模型路由配置。
// @Tags platform-models
// @Produce json
// @Security BearerAuth
// @Param modelID path string true "平台模型 ID"
// @Success 204 "No Content"
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/platform-models/{modelID} [delete]
func (s *Server) deletePlatformModel(w http.ResponseWriter, r *http.Request) {
if err := s.store.DeletePlatformModel(r.Context(), r.PathValue("modelID")); err != nil {
if store.IsNotFound(err) {
@ -333,6 +495,17 @@ func (s *Server) deletePlatformModel(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}
// listModels godoc
// @Summary 列出平台模型
// @Description 管理端返回所有平台模型,并补齐有效计费配置。
// @Tags platform-models
// @Produce json
// @Security BearerAuth
// @Success 200 {object} PlatformModelListResponse
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/models [get]
func (s *Server) listModels(w http.ResponseWriter, r *http.Request) {
models, err := s.store.ListModels(r.Context())
if err != nil {
@ -343,6 +516,17 @@ func (s *Server) listModels(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, map[string]any{"items": s.platformModelResponses(r.Context(), models)})
}
// listPlayableModels godoc
// @Summary 列出可调用模型
// @Description 按当前用户权限返回可用于 Playground 或 API 调用的模型列表。
// @Tags playground
// @Produce json
// @Security BearerAuth
// @Success 200 {object} PlatformModelListResponse
// @Failure 401 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/v1/models [get]
// @Router /api/v1/playground/models [get]
func (s *Server) listPlayableModels(w http.ResponseWriter, r *http.Request) {
user, _ := auth.UserFromContext(r.Context())
models, err := s.store.ListAccessiblePlatformModels(r.Context(), user)
@ -354,6 +538,17 @@ func (s *Server) listPlayableModels(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, map[string]any{"items": s.platformModelResponses(r.Context(), models)})
}
// listPricingRules godoc
// @Summary 列出定价规则
// @Description 返回所有定价规则明细,便于管理端排查有效价格。
// @Tags pricing
// @Produce json
// @Security BearerAuth
// @Success 200 {object} PricingRuleListResponse
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/pricing/rules [get]
func (s *Server) listPricingRules(w http.ResponseWriter, r *http.Request) {
items, err := s.store.ListPricingRules(r.Context())
if err != nil {
@ -364,6 +559,17 @@ func (s *Server) listPricingRules(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, map[string]any{"items": items})
}
// listTenants godoc
// @Summary 列出租户
// @Description 管理端返回网关租户列表。
// @Tags identity
// @Produce json
// @Security BearerAuth
// @Success 200 {object} TenantListResponse
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/tenants [get]
func (s *Server) listTenants(w http.ResponseWriter, r *http.Request) {
items, err := s.store.ListTenants(r.Context())
if err != nil {
@ -374,6 +580,17 @@ func (s *Server) listTenants(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, map[string]any{"items": items})
}
// listUsers godoc
// @Summary 列出用户
// @Description 管理端返回网关用户列表及钱包摘要。
// @Tags identity
// @Produce json
// @Security BearerAuth
// @Success 200 {object} UserListResponse
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/users [get]
func (s *Server) listUsers(w http.ResponseWriter, r *http.Request) {
items, err := s.store.ListUsers(r.Context())
if err != nil {
@ -384,6 +601,17 @@ func (s *Server) listUsers(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, map[string]any{"items": items})
}
// listUserGroups godoc
// @Summary 列出用户组
// @Description 管理端返回用户组及其计费、限流和配额策略。
// @Tags identity
// @Produce json
// @Security BearerAuth
// @Success 200 {object} UserGroupListResponse
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/user-groups [get]
func (s *Server) listUserGroups(w http.ResponseWriter, r *http.Request) {
items, err := s.store.ListUserGroups(r.Context())
if err != nil {
@ -394,6 +622,16 @@ func (s *Server) listUserGroups(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, map[string]any{"items": items})
}
// listAPIKeys godoc
// @Summary 列出 API Key
// @Description 返回当前用户创建的 API Key 元数据secret 只在创建时返回。
// @Tags api-keys
// @Produce json
// @Security BearerAuth
// @Success 200 {object} APIKeyListResponse
// @Failure 401 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/v1/api-keys [get]
func (s *Server) listAPIKeys(w http.ResponseWriter, r *http.Request) {
user, _ := auth.UserFromContext(r.Context())
items, err := s.store.ListAPIKeys(r.Context(), user)
@ -405,6 +643,17 @@ func (s *Server) listAPIKeys(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, map[string]any{"items": items})
}
// listPlayableAPIKeys godoc
// @Summary 列出 Playground API Key
// @Description 返回当前本地用户可在 Playground 中直接使用的 API Key 和 secret。
// @Tags playground
// @Produce json
// @Security BearerAuth
// @Success 200 {object} PlayableAPIKeyListResponse
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/playground/api-keys [get]
func (s *Server) listPlayableAPIKeys(w http.ResponseWriter, r *http.Request) {
user, _ := auth.UserFromContext(r.Context())
items, err := s.store.ListPlayableAPIKeys(r.Context(), user)
@ -420,6 +669,19 @@ func (s *Server) listPlayableAPIKeys(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, map[string]any{"items": items})
}
// createAPIKey godoc
// @Summary 创建 API Key
// @Description 为当前本地用户创建 API Keysecret 仅在本次响应中返回。
// @Tags api-keys
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param input body store.CreateAPIKeyInput true "API Key 创建请求"
// @Success 201 {object} store.CreatedAPIKey
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/v1/api-keys [post]
func (s *Server) createAPIKey(w http.ResponseWriter, r *http.Request) {
user, _ := auth.UserFromContext(r.Context())
var input store.CreateAPIKeyInput
@ -440,6 +702,19 @@ func (s *Server) createAPIKey(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusCreated, created)
}
// disableAPIKey godoc
// @Summary 禁用 API Key
// @Description 禁用当前用户拥有的 API Key保留记录但不再允许调用。
// @Tags api-keys
// @Produce json
// @Security BearerAuth
// @Param apiKeyID path string true "API Key ID"
// @Success 200 {object} store.APIKey
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/v1/api-keys/{apiKeyID}/disable [patch]
func (s *Server) disableAPIKey(w http.ResponseWriter, r *http.Request) {
user, _ := auth.UserFromContext(r.Context())
item, err := s.store.DisableAPIKey(r.Context(), r.PathValue("apiKeyID"), user)
@ -459,6 +734,19 @@ func (s *Server) disableAPIKey(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusInternalServerError, "disable api key failed")
}
// deleteAPIKey godoc
// @Summary 删除 API Key
// @Description 删除当前用户拥有的 API Key。
// @Tags api-keys
// @Produce json
// @Security BearerAuth
// @Param apiKeyID path string true "API Key ID"
// @Success 204 "No Content"
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/v1/api-keys/{apiKeyID} [delete]
func (s *Server) deleteAPIKey(w http.ResponseWriter, r *http.Request) {
user, _ := auth.UserFromContext(r.Context())
err := s.store.DeleteAPIKey(r.Context(), r.PathValue("apiKeyID"), user)
@ -478,6 +766,22 @@ func (s *Server) deleteAPIKey(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusInternalServerError, "delete api key failed")
}
// estimatePricing godoc
// @Summary 估算请求价格
// @Description 按当前用户、模型候选、任务类型和请求参数估算计费条目。
// @Tags pricing
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param input body PricingEstimateRequest true "计费估算请求kind 默认为 chat.completions"
// @Success 200 {object} PricingEstimateResponse
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 429 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/v1/pricing/estimate [post]
func (s *Server) estimatePricing(w http.ResponseWriter, r *http.Request) {
user, _ := auth.UserFromContext(r.Context())
var body map[string]any
@ -511,6 +815,17 @@ func (s *Server) estimatePricing(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, estimate)
}
// listRateLimitWindows godoc
// @Summary 列出限流窗口
// @Description 管理端查看当前运行时限流窗口状态。
// @Tags runtime
// @Produce json
// @Security BearerAuth
// @Success 200 {object} RateLimitWindowListResponse
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/runtime/rate-limit-windows [get]
func (s *Server) listRateLimitWindows(w http.ResponseWriter, r *http.Request) {
items, err := s.store.ListRateLimitWindows(r.Context())
if err != nil {
@ -521,6 +836,17 @@ func (s *Server) listRateLimitWindows(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, map[string]any{"items": items})
}
// listModelRateLimitStatuses godoc
// @Summary 列出模型限流状态
// @Description 管理端查看平台模型维度的限流和冷却状态。
// @Tags runtime
// @Produce json
// @Security BearerAuth
// @Success 200 {object} ModelRateLimitStatusListResponse
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/runtime/model-rate-limits [get]
func (s *Server) listModelRateLimitStatuses(w http.ResponseWriter, r *http.Request) {
items, err := s.store.ListModelRateLimitStatuses(r.Context())
if err != nil {
@ -531,6 +857,36 @@ func (s *Server) listModelRateLimitStatuses(w http.ResponseWriter, r *http.Reque
writeJSON(w, http.StatusOK, map[string]any{"items": items})
}
// createTask godoc
// @Summary 创建或执行 AI 任务
// @Description 网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果OpenAI-compatible 路径同步返回兼容响应或 SSE 流。
// @Tags tasks
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param X-Async header bool false "true 时异步创建任务并返回 202"
// @Param input body TaskRequest true "AI 任务请求,字段随任务类型变化"
// @Success 200 {object} CompatibleResponse
// @Success 202 {object} TaskAcceptedResponse
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 402 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 429 {object} ErrorEnvelope
// @Failure 502 {object} ErrorEnvelope
// @Router /api/v1/responses [post]
// @Router /api/v1/images/generations [post]
// @Router /api/v1/images/edits [post]
// @Router /api/v1/videos/generations [post]
// @Router /chat/completions [post]
// @Router /v1/chat/completions [post]
// @Router /responses [post]
// @Router /v1/responses [post]
// @Router /images/generations [post]
// @Router /v1/images/generations [post]
// @Router /images/edits [post]
// @Router /v1/images/edits [post]
func (s *Server) createTask(kind string, compatible bool) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user, ok := auth.UserFromContext(r.Context())
@ -553,13 +909,13 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler {
writeError(w, http.StatusForbidden, "api key scope does not allow this capability")
return
}
asyncMode := asyncRequest(r)
responsePlan := planTaskResponse(kind, compatible, body, r)
task, err := s.store.CreateTask(r.Context(), store.CreateTaskInput{
Kind: kind,
Model: model,
RunMode: runModeFromRequest(body),
Async: asyncMode,
Async: responsePlan.asyncMode,
Request: body,
}, user)
if err != nil {
@ -567,7 +923,7 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler {
writeError(w, http.StatusInternalServerError, "create task failed")
return
}
if asyncMode {
if responsePlan.asyncMode {
if err := s.runner.EnqueueAsyncTask(r.Context(), task); err != nil {
writeError(w, http.StatusInternalServerError, err.Error(), "enqueue_failed")
return
@ -577,65 +933,8 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler {
}
runCtx, cancelRun := s.requestExecutionContext(r)
defer cancelRun()
if compatible {
if boolValue(body, "stream") {
flusher := prepareCompatibleStream(w)
result, runErr := s.runner.ExecuteStream(runCtx, task, user, func(delta string) error {
if !requestStillConnected(r) {
return nil
}
writeCompatibleDelta(w, kind, model, delta)
if flusher != nil {
flusher.Flush()
}
return nil
})
if runErr != nil {
if !requestStillConnected(r) {
return
}
status := statusFromRunError(runErr)
errorPayload := map[string]any{
"code": runErrorCode(runErr),
"message": runErrorMessage(runErr),
"status": status,
}
if result.Task.ID != "" {
errorPayload["taskId"] = result.Task.ID
}
if result.Task.RequestID != "" {
errorPayload["requestId"] = result.Task.RequestID
}
for key, value := range runErrorDetails(runErr) {
errorPayload[key] = value
}
sendSSE(w, "error", map[string]any{"error": errorPayload})
if flusher != nil {
flusher.Flush()
}
return
}
if !requestStillConnected(r) {
return
}
writeCompatibleDone(w, kind, model, result.Output)
if flusher != nil {
flusher.Flush()
}
return
}
result, runErr := s.runner.Execute(runCtx, task, user)
if runErr != nil {
if !requestStillConnected(r) {
return
}
writeErrorWithDetails(w, statusFromRunError(runErr), runErrorMessage(runErr), runErrorDetails(runErr), runErrorCode(runErr))
return
}
if !requestStillConnected(r) {
return
}
writeJSON(w, http.StatusOK, result.Output)
if responsePlan.compatibleMode {
writeCompatibleTaskResponse(runCtx, w, r, s.runner, kind, model, task, user, responsePlan.streamMode)
return
}
result, runErr := s.runner.Execute(runCtx, task, user)
@ -650,6 +949,29 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler {
})
}
// createAPIV1ChatCompletions godoc
// @Summary 创建 Chat Completions
// @Description /api/v1/chat/completions 同步执行stream=true 返回 text/event-stream SSEstream=false 或未传返回兼容 JSON该接口忽略 X-Async。
// @Tags tasks
// @Accept json
// @Produce json
// @Produce text/event-stream
// @Security BearerAuth
// @Param X-Async header bool false "该接口忽略此参数"
// @Param input body TaskRequest true "Chat Completions 请求"
// @Success 200 {object} ChatCompletionCompatibleResponse
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 402 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 429 {object} ErrorEnvelope
// @Failure 502 {object} ErrorEnvelope
// @Router /api/v1/chat/completions [post]
func (s *Server) createAPIV1ChatCompletions() http.Handler {
return s.createTask("chat.completions", false)
}
func (s *Server) requestExecutionContext(r *http.Request) (context.Context, context.CancelFunc) {
base := context.WithoutCancel(r.Context())
if s.ctx == nil {
@ -675,11 +997,98 @@ func requestStillConnected(r *http.Request) bool {
}
}
type taskExecutor interface {
Execute(context.Context, store.GatewayTask, *auth.User) (runner.Result, error)
ExecuteStream(context.Context, store.GatewayTask, *auth.User, clients.StreamDelta) (runner.Result, error)
}
func writeCompatibleTaskResponse(runCtx context.Context, w http.ResponseWriter, r *http.Request, executor taskExecutor, kind string, model string, task store.GatewayTask, user *auth.User, streamMode bool) {
if streamMode {
flusher := prepareCompatibleStream(w)
result, runErr := executor.ExecuteStream(runCtx, task, user, func(delta string) error {
if !requestStillConnected(r) {
return nil
}
writeCompatibleDelta(w, kind, model, delta)
if flusher != nil {
flusher.Flush()
}
return nil
})
if runErr != nil {
if !requestStillConnected(r) {
return
}
status := statusFromRunError(runErr)
errorPayload := map[string]any{
"code": runErrorCode(runErr),
"message": runErrorMessage(runErr),
"status": status,
}
if result.Task.ID != "" {
errorPayload["taskId"] = result.Task.ID
}
if result.Task.RequestID != "" {
errorPayload["requestId"] = result.Task.RequestID
}
for key, value := range runErrorDetails(runErr) {
errorPayload[key] = value
}
sendSSE(w, "error", map[string]any{"error": errorPayload})
if flusher != nil {
flusher.Flush()
}
return
}
if !requestStillConnected(r) {
return
}
writeCompatibleDone(w, kind, model, result.Output)
if flusher != nil {
flusher.Flush()
}
return
}
result, runErr := executor.Execute(runCtx, task, user)
if runErr != nil {
if !requestStillConnected(r) {
return
}
writeErrorWithDetails(w, statusFromRunError(runErr), runErrorMessage(runErr), runErrorDetails(runErr), runErrorCode(runErr))
return
}
if !requestStillConnected(r) {
return
}
writeJSON(w, http.StatusOK, result.Output)
}
func asyncRequest(r *http.Request) bool {
value := strings.TrimSpace(strings.ToLower(r.Header.Get("x-async")))
return value == "1" || value == "true" || value == "yes" || value == "on"
}
type taskResponsePlan struct {
asyncMode bool
compatibleMode bool
streamMode bool
}
func planTaskResponse(kind string, compatible bool, body map[string]any, r *http.Request) taskResponsePlan {
asyncMode := asyncRequest(r)
compatibleMode := compatible
if kind == "chat.completions" && !compatible {
asyncMode = false
compatibleMode = true
}
return taskResponsePlan{
asyncMode: asyncMode,
compatibleMode: compatibleMode,
streamMode: boolValue(body, "stream"),
}
}
func writeTaskAccepted(w http.ResponseWriter, task store.GatewayTask) {
writeJSON(w, http.StatusAccepted, map[string]any{
"taskId": task.ID,
@ -877,6 +1286,24 @@ func matchedRateLimitRule(policy map[string]any, metric string) map[string]any {
return nil
}
// listTasks godoc
// @Summary 列出任务
// @Description 按当前用户列出任务,支持关键字、模型类型、时间范围和分页过滤。
// @Tags tasks
// @Produce json
// @Security BearerAuth
// @Param q query string false "搜索关键字,别名 query"
// @Param modelType query string false "模型类型,别名 type"
// @Param createdFrom query string false "创建时间起点,支持 RFC3339 或日期格式,别名 from"
// @Param createdTo query string false "创建时间终点,支持 RFC3339 或日期格式,别名 to"
// @Param page query int false "页码" default(1)
// @Param pageSize query int false "每页数量,别名 limit" default(50)
// @Success 200 {object} TaskListResponse
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/workspace/tasks [get]
// @Router /api/v1/tasks [get]
func (s *Server) listTasks(w http.ResponseWriter, r *http.Request) {
user, ok := auth.UserFromContext(r.Context())
if !ok {
@ -972,6 +1399,19 @@ func boolValue(body map[string]any, key string) bool {
return value
}
// getTask godoc
// @Summary 获取任务详情
// @Description 返回指定任务的请求、状态、输出和执行摘要。
// @Tags tasks
// @Produce json
// @Security BearerAuth
// @Param taskID path string true "任务 ID"
// @Success 200 {object} store.GatewayTask
// @Failure 401 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/workspace/tasks/{taskID} [get]
// @Router /api/v1/tasks/{taskID} [get]
func (s *Server) getTask(w http.ResponseWriter, r *http.Request) {
task, err := s.store.GetTask(r.Context(), r.PathValue("taskID"))
if err == nil {
@ -986,6 +1426,19 @@ func (s *Server) getTask(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusInternalServerError, "get task failed")
}
// taskParamPreprocessing godoc
// @Summary 获取任务参数预处理日志
// @Description 返回指定任务在执行前的参数改写、校验或模板处理日志。
// @Tags tasks
// @Produce json
// @Security BearerAuth
// @Param taskID path string true "任务 ID"
// @Success 200 {object} TaskParamPreprocessingLogListResponse
// @Failure 401 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/workspace/tasks/{taskID}/param-preprocessing [get]
// @Router /api/v1/tasks/{taskID}/param-preprocessing [get]
func (s *Server) taskParamPreprocessing(w http.ResponseWriter, r *http.Request) {
task, err := s.store.GetTask(r.Context(), r.PathValue("taskID"))
if err != nil {
@ -1006,6 +1459,19 @@ func (s *Server) taskParamPreprocessing(w http.ResponseWriter, r *http.Request)
writeJSON(w, http.StatusOK, map[string]any{"items": logs})
}
// taskEvents godoc
// @Summary 订阅任务事件
// @Description 以 text/event-stream 返回指定任务的历史事件;无事件时返回 task.accepted 占位事件。
// @Tags tasks
// @Produce text/event-stream
// @Security BearerAuth
// @Param taskID path string true "任务 ID"
// @Success 200 {string} string "Server-Sent Eventsdata 为 store.TaskEvent 或 TaskAcceptedEvent"
// @Failure 401 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/workspace/tasks/{taskID}/events [get]
// @Router /api/v1/tasks/{taskID}/events [get]
func (s *Server) taskEvents(w http.ResponseWriter, r *http.Request) {
task, err := s.store.GetTask(r.Context(), r.PathValue("taskID"))
if err != nil {

View File

@ -8,6 +8,21 @@ import (
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
// createTenant godoc
// @Summary 创建租户
// @Description 管理端创建网关租户tenantKey 和 name 必填。
// @Tags identity
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param input body store.GatewayTenantInput true "租户请求"
// @Success 201 {object} store.GatewayTenant
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 409 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/tenants [post]
func (s *Server) createTenant(w http.ResponseWriter, r *http.Request) {
var input store.GatewayTenantInput
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
@ -31,6 +46,23 @@ func (s *Server) createTenant(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusCreated, item)
}
// updateTenant godoc
// @Summary 更新租户
// @Description 管理端更新网关租户信息。
// @Tags identity
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param tenantID path string true "租户 ID"
// @Param input body store.GatewayTenantInput true "租户请求"
// @Success 200 {object} store.GatewayTenant
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 409 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/tenants/{tenantID} [patch]
func (s *Server) updateTenant(w http.ResponseWriter, r *http.Request) {
var input store.GatewayTenantInput
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
@ -58,6 +90,19 @@ func (s *Server) updateTenant(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, item)
}
// deleteTenant godoc
// @Summary 删除租户
// @Description 管理端删除网关租户。
// @Tags identity
// @Produce json
// @Security BearerAuth
// @Param tenantID path string true "租户 ID"
// @Success 204 "No Content"
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/tenants/{tenantID} [delete]
func (s *Server) deleteTenant(w http.ResponseWriter, r *http.Request) {
if err := s.store.DeleteTenant(r.Context(), r.PathValue("tenantID")); err != nil {
if store.IsNotFound(err) {
@ -71,6 +116,21 @@ func (s *Server) deleteTenant(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}
// createGatewayUser godoc
// @Summary 创建用户
// @Description 管理端创建网关用户password 为空时不设置本地密码,非空时至少 8 位。
// @Tags identity
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param input body store.GatewayUserInput true "用户请求"
// @Success 201 {object} store.GatewayUser
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 409 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/users [post]
func (s *Server) createGatewayUser(w http.ResponseWriter, r *http.Request) {
var input store.GatewayUserInput
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
@ -98,6 +158,23 @@ func (s *Server) createGatewayUser(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusCreated, item)
}
// updateGatewayUser godoc
// @Summary 更新用户
// @Description 管理端更新网关用户资料、角色、默认用户组和可选本地密码。
// @Tags identity
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param userID path string true "用户 ID"
// @Param input body store.GatewayUserInput true "用户请求"
// @Success 200 {object} store.GatewayUser
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 409 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/users/{userID} [patch]
func (s *Server) updateGatewayUser(w http.ResponseWriter, r *http.Request) {
var input store.GatewayUserInput
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
@ -129,6 +206,19 @@ func (s *Server) updateGatewayUser(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, item)
}
// deleteGatewayUser godoc
// @Summary 删除用户
// @Description 管理端删除网关用户。
// @Tags identity
// @Produce json
// @Security BearerAuth
// @Param userID path string true "用户 ID"
// @Success 204 "No Content"
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/users/{userID} [delete]
func (s *Server) deleteGatewayUser(w http.ResponseWriter, r *http.Request) {
if err := s.store.DeleteGatewayUser(r.Context(), r.PathValue("userID")); err != nil {
if store.IsNotFound(err) {
@ -142,6 +232,21 @@ func (s *Server) deleteGatewayUser(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}
// createUserGroup godoc
// @Summary 创建用户组
// @Description 管理端创建用户组,可配置默认定价、运行策略、限流和配额策略。
// @Tags identity
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param input body store.UserGroupInput true "用户组请求"
// @Success 201 {object} store.UserGroup
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 409 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/user-groups [post]
func (s *Server) createUserGroup(w http.ResponseWriter, r *http.Request) {
var input store.UserGroupInput
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
@ -165,6 +270,23 @@ func (s *Server) createUserGroup(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusCreated, item)
}
// updateUserGroup godoc
// @Summary 更新用户组
// @Description 管理端更新用户组基础信息和策略配置。
// @Tags identity
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param groupID path string true "用户组 ID"
// @Param input body store.UserGroupInput true "用户组请求"
// @Success 200 {object} store.UserGroup
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 409 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/user-groups/{groupID} [patch]
func (s *Server) updateUserGroup(w http.ResponseWriter, r *http.Request) {
var input store.UserGroupInput
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
@ -192,6 +314,19 @@ func (s *Server) updateUserGroup(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, item)
}
// deleteUserGroup godoc
// @Summary 删除用户组
// @Description 管理端删除用户组。
// @Tags identity
// @Produce json
// @Security BearerAuth
// @Param groupID path string true "用户组 ID"
// @Success 204 "No Content"
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/user-groups/{groupID} [delete]
func (s *Server) deleteUserGroup(w http.ResponseWriter, r *http.Request) {
if err := s.store.DeleteUserGroup(r.Context(), r.PathValue("groupID")); err != nil {
if store.IsNotFound(err) {

View File

@ -123,6 +123,16 @@ type catalogGroup struct {
enabled bool
}
// listModelCatalog godoc
// @Summary 列出模型目录
// @Description 聚合平台模型、基础模型、供应商、运行策略和访问规则,返回前端模型目录所需的过滤器、摘要和展示字段。
// @Tags model-catalog
// @Produce json
// @Security BearerAuth
// @Success 200 {object} ModelCatalogResponse
// @Failure 401 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/v1/model-catalog [get]
func (s *Server) listModelCatalog(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
models, err := s.store.ListModels(ctx)

View File

@ -0,0 +1,282 @@
package httpapi
import (
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
type HealthResponse struct {
OK bool `json:"ok" example:"true"`
Service string `json:"service" example:"easyai-ai-gateway"`
Env string `json:"env" example:"development"`
IdentityMode string `json:"identityMode" example:"standalone"`
}
type ReadyResponse struct {
OK bool `json:"ok" example:"true"`
}
type ErrorEnvelope struct {
Error ErrorPayload `json:"error"`
}
type ErrorPayload struct {
Message string `json:"message" example:"invalid json body"`
Status int `json:"status" example:"400"`
Code string `json:"code,omitempty" example:"rate_limit"`
}
type AuthResponse struct {
AccessToken string `json:"accessToken" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."`
TokenType string `json:"tokenType" example:"Bearer"`
ExpiresIn int `json:"expiresIn" example:"86400"`
User *auth.User `json:"user"`
}
type ItemListResponse struct {
Items []map[string]interface{} `json:"items"`
}
type PlatformListResponse struct {
Items []store.Platform `json:"items"`
}
type PlatformModelListResponse struct {
Items []store.PlatformModel `json:"items"`
}
type CatalogProviderListResponse struct {
Items []store.CatalogProvider `json:"items"`
}
type BaseModelListResponse struct {
Items []store.BaseModel `json:"items"`
}
type TenantListResponse struct {
Items []store.GatewayTenant `json:"items"`
}
type UserListResponse struct {
Items []store.GatewayUser `json:"items"`
}
type UserGroupListResponse struct {
Items []store.UserGroup `json:"items"`
}
type AccessRuleListResponse struct {
Items []store.AccessRule `json:"items"`
}
type APIKeyListResponse struct {
Items []store.APIKey `json:"items"`
}
type PlayableAPIKeyListResponse struct {
Items []store.PlayableAPIKey `json:"items"`
}
type PricingRuleListResponse struct {
Items []store.PricingRule `json:"items"`
}
type PricingRuleSetListResponse struct {
Items []store.PricingRuleSet `json:"items"`
}
type RuntimePolicySetListResponse struct {
Items []store.RuntimePolicySet `json:"items"`
}
type RateLimitWindowListResponse struct {
Items []store.RateLimitWindow `json:"items"`
}
type ModelRateLimitStatusListResponse struct {
Items []store.ModelRateLimitStatus `json:"items"`
}
type AuditLogListResponse struct {
Items []store.AuditLog `json:"items"`
}
type WalletTransactionListResponse struct {
Items []store.GatewayWalletTransaction `json:"items"`
Total int `json:"total" example:"42"`
Page int `json:"page" example:"1"`
PageSize int `json:"pageSize" example:"50"`
}
type TaskListResponse struct {
Items []store.GatewayTask `json:"items"`
Total int `json:"total" example:"42"`
Page int `json:"page" example:"1"`
PageSize int `json:"pageSize" example:"50"`
}
type TaskParamPreprocessingLogListResponse struct {
Items []store.TaskParamPreprocessingLog `json:"items"`
}
type TaskEventListResponse struct {
Items []store.TaskEvent `json:"items"`
}
type FileStorageChannelListResponse struct {
Items []store.FileStorageChannel `json:"items"`
}
type FileUploadResponse struct {
ID string `json:"id,omitempty" example:"file_abc123"`
URL string `json:"url,omitempty" example:"/static/uploaded/upload-abc123.png"`
Filename string `json:"filename,omitempty" example:"image.png"`
ContentType string `json:"contentType,omitempty" example:"image/png"`
Size int `json:"size,omitempty" example:"1024"`
AssetStorage map[string]interface{} `json:"assetStorage,omitempty"`
}
type ReplacePlatformModelsRequest struct {
Models []store.CreatePlatformModelInput `json:"models"`
}
type TaskAcceptedResponse struct {
TaskID string `json:"taskId" example:"9f4d8f3d-5f5f-4bb7-a4be-344a9f930e25"`
Task store.GatewayTask `json:"task"`
Next TaskNextLinks `json:"next"`
}
type TaskNextLinks struct {
Events string `json:"events" example:"/api/v1/tasks/9f4d8f3d-5f5f-4bb7-a4be-344a9f930e25/events"`
Detail string `json:"detail" example:"/api/v1/tasks/9f4d8f3d-5f5f-4bb7-a4be-344a9f930e25"`
}
type TaskAcceptedEvent struct {
TaskID string `json:"taskId" example:"9f4d8f3d-5f5f-4bb7-a4be-344a9f930e25"`
Status string `json:"status" example:"pending"`
}
type PricingEstimateRequest struct {
Kind string `json:"kind" example:"chat.completions"`
Model string `json:"model" example:"gpt-4o-mini"`
Messages []ChatMessage `json:"messages,omitempty"`
Prompt string `json:"prompt,omitempty" example:"A small orange cat"`
MaxTokens int `json:"max_tokens,omitempty" example:"512"`
N int `json:"n,omitempty" example:"1"`
RunMode string `json:"runMode,omitempty" example:"simulation"`
}
type PricingEstimateResponse struct {
Items []map[string]interface{} `json:"items"`
Resolver string `json:"resolver" example:"effective-pricing-v1"`
}
type TaskRequest struct {
Model string `json:"model" example:"gpt-4o-mini"`
Messages []ChatMessage `json:"messages,omitempty"`
Input string `json:"input,omitempty" example:"Tell me a short story"`
Prompt string `json:"prompt,omitempty" example:"A watercolor robot reading a book"`
Stream bool `json:"stream,omitempty" example:"false"`
RunMode string `json:"runMode,omitempty" example:"simulation"`
MaxTokens int `json:"max_tokens,omitempty" example:"512"`
Size string `json:"size,omitempty" example:"1024x1024"`
Duration int `json:"duration,omitempty" example:"5"`
Resolution string `json:"resolution,omitempty" example:"720p"`
}
type ChatCompletionRequest struct {
Model string `json:"model" example:"gpt-4o-mini"`
Messages []ChatMessage `json:"messages"`
Temperature float64 `json:"temperature,omitempty" example:"0.7"`
MaxTokens int `json:"max_tokens,omitempty" example:"512"`
Stream bool `json:"stream,omitempty" example:"false"`
RunMode string `json:"runMode,omitempty" example:"simulation"`
}
type ChatMessage struct {
Role string `json:"role" example:"user"`
Content string `json:"content" example:"Hello"`
}
type ResponsesRequest struct {
Model string `json:"model" example:"gpt-4o-mini"`
Input interface{} `json:"input" example:"Tell me a short story"`
Stream bool `json:"stream,omitempty" example:"false"`
RunMode string `json:"runMode,omitempty" example:"simulation"`
}
type ImageGenerationRequest struct {
Model string `json:"model" example:"gpt-image-1"`
Prompt string `json:"prompt" example:"A watercolor robot reading a book"`
N int `json:"n,omitempty" example:"1"`
Size string `json:"size,omitempty" example:"1024x1024"`
Quality string `json:"quality,omitempty" example:"standard"`
ResponseFormat string `json:"response_format,omitempty" example:"url"`
RunMode string `json:"runMode,omitempty" example:"simulation"`
}
type ImageEditRequest struct {
Model string `json:"model" example:"gpt-image-1"`
Prompt string `json:"prompt" example:"Add a sunset background"`
Image string `json:"image,omitempty" example:"https://example.com/image.png"`
Mask string `json:"mask,omitempty" example:"https://example.com/mask.png"`
N int `json:"n,omitempty" example:"1"`
Size string `json:"size,omitempty" example:"1024x1024"`
ResponseFormat string `json:"response_format,omitempty" example:"url"`
RunMode string `json:"runMode,omitempty" example:"simulation"`
}
type VideoGenerationRequest struct {
Model string `json:"model" example:"video-model"`
Prompt string `json:"prompt" example:"A cinematic drone shot over mountains"`
Duration int `json:"duration,omitempty" example:"5"`
Resolution string `json:"resolution,omitempty" example:"720p"`
RunMode string `json:"runMode,omitempty" example:"simulation"`
}
type CompatibleResponse struct {
ID string `json:"id" example:"chatcmpl-123"`
Object string `json:"object" example:"chat.completion"`
Model string `json:"model" example:"gpt-4o-mini"`
Choices []map[string]interface{} `json:"choices,omitempty"`
Usage map[string]interface{} `json:"usage,omitempty"`
}
type ChatCompletionCompatibleResponse struct {
ID string `json:"id" example:"chatcmpl-123"`
Object string `json:"object" example:"chat.completion"`
Created int64 `json:"created,omitempty" example:"1710000000"`
Model string `json:"model" example:"gpt-4o-mini"`
Choices []ChatCompletionChoice `json:"choices"`
Usage *ChatCompletionUsage `json:"usage,omitempty"`
}
type ChatCompletionChoice struct {
Index int `json:"index" example:"0"`
Message ChatCompletionChoiceMessage `json:"message"`
FinishReason string `json:"finish_reason,omitempty" example:"stop"`
}
type ChatCompletionChoiceMessage struct {
Role string `json:"role" example:"assistant"`
Content string `json:"content" example:"Hello"`
}
type ChatCompletionUsage struct {
PromptTokens int `json:"prompt_tokens,omitempty" example:"12"`
CompletionTokens int `json:"completion_tokens,omitempty" example:"8"`
TotalTokens int `json:"total_tokens,omitempty" example:"20"`
}
type NetworkProxyConfigResponse struct {
GlobalHTTPProxy string `json:"globalHttpProxy" example:"http://127.0.0.1:7890"`
GlobalHTTPProxySet bool `json:"globalHttpProxySet" example:"true"`
GlobalHTTPProxySource string `json:"globalHttpProxySource" example:"env"`
}
type WalletAdjustmentResponse struct {
Account store.GatewayWalletAccount `json:"account"`
Before store.GatewayWalletAccount `json:"before"`
Transaction store.GatewayWalletTransaction `json:"transaction"`
AuditLog store.AuditLog `json:"auditLog"`
}

View File

@ -9,6 +9,17 @@ import (
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
// listPricingRuleSets godoc
// @Summary 列出定价规则集
// @Description 管理端返回可分配给平台、模型、租户或用户组的定价规则集。
// @Tags pricing
// @Produce json
// @Security BearerAuth
// @Success 200 {object} PricingRuleSetListResponse
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/pricing/rule-sets [get]
func (s *Server) listPricingRuleSets(w http.ResponseWriter, r *http.Request) {
items, err := s.store.ListPricingRuleSets(r.Context())
if err != nil {
@ -19,6 +30,21 @@ func (s *Server) listPricingRuleSets(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, map[string]any{"items": items})
}
// createPricingRuleSet godoc
// @Summary 创建定价规则集
// @Description 管理端创建定价规则集ruleSetKey、name 和至少一条 rule 必填。
// @Tags pricing
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param input body store.PricingRuleSetInput true "定价规则集请求"
// @Success 201 {object} store.PricingRuleSet
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 409 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/pricing/rule-sets [post]
func (s *Server) createPricingRuleSet(w http.ResponseWriter, r *http.Request) {
var input store.PricingRuleSetInput
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
@ -42,6 +68,23 @@ func (s *Server) createPricingRuleSet(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusCreated, item)
}
// updatePricingRuleSet godoc
// @Summary 更新定价规则集
// @Description 管理端更新定价规则集及其规则列表。
// @Tags pricing
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param ruleSetID path string true "定价规则集 ID"
// @Param input body store.PricingRuleSetInput true "定价规则集请求"
// @Success 200 {object} store.PricingRuleSet
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 409 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/pricing/rule-sets/{ruleSetID} [patch]
func (s *Server) updatePricingRuleSet(w http.ResponseWriter, r *http.Request) {
var input store.PricingRuleSetInput
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
@ -69,6 +112,19 @@ func (s *Server) updatePricingRuleSet(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, item)
}
// deletePricingRuleSet godoc
// @Summary 删除定价规则集
// @Description 管理端删除非默认定价规则集;默认规则集受保护。
// @Tags pricing
// @Produce json
// @Security BearerAuth
// @Param ruleSetID path string true "定价规则集 ID"
// @Success 204 "No Content"
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/pricing/rule-sets/{ruleSetID} [delete]
func (s *Server) deletePricingRuleSet(w http.ResponseWriter, r *http.Request) {
if err := s.store.DeletePricingRuleSet(r.Context(), r.PathValue("ruleSetID")); err != nil {
if store.IsNotFound(err) {

View File

@ -9,6 +9,17 @@ import (
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
// listRuntimePolicySets godoc
// @Summary 列出运行策略集
// @Description 管理端返回可分配给平台、模型或用户组的运行策略集。
// @Tags runtime
// @Produce json
// @Security BearerAuth
// @Success 200 {object} RuntimePolicySetListResponse
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/runtime/policy-sets [get]
func (s *Server) listRuntimePolicySets(w http.ResponseWriter, r *http.Request) {
items, err := s.store.ListRuntimePolicySets(r.Context())
if err != nil {
@ -19,6 +30,17 @@ func (s *Server) listRuntimePolicySets(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, map[string]any{"items": items})
}
// getRunnerPolicy godoc
// @Summary 获取 Runner 策略
// @Description 管理端获取当前生效的默认 Runner 调度策略。
// @Tags runtime
// @Produce json
// @Security BearerAuth
// @Success 200 {object} store.RunnerPolicy
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/runtime/runner-policy [get]
func (s *Server) getRunnerPolicy(w http.ResponseWriter, r *http.Request) {
item, err := s.store.GetActiveRunnerPolicy(r.Context())
if err != nil {
@ -29,6 +51,20 @@ func (s *Server) getRunnerPolicy(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, item)
}
// updateRunnerPolicy godoc
// @Summary 更新 Runner 策略
// @Description 管理端写入默认 Runner 调度策略。
// @Tags runtime
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param input body store.RunnerPolicyInput true "Runner 策略请求"
// @Success 200 {object} store.RunnerPolicy
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/runtime/runner-policy [patch]
func (s *Server) updateRunnerPolicy(w http.ResponseWriter, r *http.Request) {
var input store.RunnerPolicyInput
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
@ -45,10 +81,26 @@ func (s *Server) updateRunnerPolicy(w http.ResponseWriter, r *http.Request) {
}
type updatePlatformDynamicPriorityRequest struct {
DynamicPriority *int `json:"dynamicPriority"`
Reset bool `json:"reset"`
DynamicPriority *int `json:"dynamicPriority" example:"10"`
Reset bool `json:"reset" example:"false"`
}
// updatePlatformDynamicPriority godoc
// @Summary 更新平台动态优先级
// @Description 管理端调整平台运行时动态优先级reset 为 true 时清空动态值。
// @Tags runtime
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param platformID path string true "平台 ID"
// @Param input body updatePlatformDynamicPriorityRequest true "动态优先级请求"
// @Success 200 {object} store.Platform
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/platforms/{platformID}/dynamic-priority [patch]
func (s *Server) updatePlatformDynamicPriority(w http.ResponseWriter, r *http.Request) {
var input updatePlatformDynamicPriorityRequest
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
@ -82,6 +134,21 @@ func (s *Server) updatePlatformDynamicPriority(w http.ResponseWriter, r *http.Re
writeJSON(w, http.StatusOK, item)
}
// createRuntimePolicySet godoc
// @Summary 创建运行策略集
// @Description 管理端创建运行策略集policyKey 和 name 必填。
// @Tags runtime
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param input body store.RuntimePolicySetInput true "运行策略集请求"
// @Success 201 {object} store.RuntimePolicySet
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 409 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/runtime/policy-sets [post]
func (s *Server) createRuntimePolicySet(w http.ResponseWriter, r *http.Request) {
var input store.RuntimePolicySetInput
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
@ -105,6 +172,23 @@ func (s *Server) createRuntimePolicySet(w http.ResponseWriter, r *http.Request)
writeJSON(w, http.StatusCreated, item)
}
// updateRuntimePolicySet godoc
// @Summary 更新运行策略集
// @Description 管理端更新运行策略集及其限流、重试、超时等策略配置。
// @Tags runtime
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param policySetID path string true "运行策略集 ID"
// @Param input body store.RuntimePolicySetInput true "运行策略集请求"
// @Success 200 {object} store.RuntimePolicySet
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 409 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/runtime/policy-sets/{policySetID} [patch]
func (s *Server) updateRuntimePolicySet(w http.ResponseWriter, r *http.Request) {
var input store.RuntimePolicySetInput
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
@ -132,6 +216,19 @@ func (s *Server) updateRuntimePolicySet(w http.ResponseWriter, r *http.Request)
writeJSON(w, http.StatusOK, item)
}
// deleteRuntimePolicySet godoc
// @Summary 删除运行策略集
// @Description 管理端删除非默认运行策略集;默认策略集受保护。
// @Tags runtime
// @Produce json
// @Security BearerAuth
// @Param policySetID path string true "运行策略集 ID"
// @Success 204 "No Content"
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/runtime/policy-sets/{policySetID} [delete]
func (s *Server) deleteRuntimePolicySet(w http.ResponseWriter, r *http.Request) {
if err := s.store.DeleteRuntimePolicySet(r.Context(), r.PathValue("policySetID")); err != nil {
if store.IsNotFound(err) {

View File

@ -126,7 +126,7 @@ func NewServerWithContext(ctx context.Context, cfg config.Config, db *store.Stor
mux.Handle("GET /api/v1/playground/models", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.listPlayableModels)))
mux.Handle("GET /api/admin/runtime/rate-limit-windows", server.requireAdmin(auth.PermissionPower, http.HandlerFunc(server.listRateLimitWindows)))
mux.Handle("GET /api/admin/runtime/model-rate-limits", server.requireAdmin(auth.PermissionPower, http.HandlerFunc(server.listModelRateLimitStatuses)))
mux.Handle("POST /api/v1/chat/completions", server.auth.Require(auth.PermissionBasic, server.createTask("chat.completions", false)))
mux.Handle("POST /api/v1/chat/completions", server.auth.Require(auth.PermissionBasic, server.createAPIV1ChatCompletions()))
mux.Handle("POST /api/v1/responses", server.auth.Require(auth.PermissionBasic, server.createTask("responses", false)))
mux.Handle("POST /api/v1/images/generations", server.auth.Require(auth.PermissionBasic, server.createTask("images.generations", false)))
mux.Handle("POST /api/v1/images/edits", server.auth.Require(auth.PermissionBasic, server.createTask("images.edits", false)))

View File

@ -18,6 +18,16 @@ const simulationVideoMP4Base64 = "AAAAIGZ0eXBpc29tAAACAGlzb21pc28yYXZjMW1wNDEAAA
var simulationVideoMP4 = mustDecodeSimulationAsset(simulationVideoMP4Base64)
// serveSimulationAsset godoc
// @Summary 获取模拟资源
// @Description 返回本地模拟模式使用的图片、视频封面或短视频资源。
// @Tags simulation
// @Produce image/svg+xml
// @Produce video/mp4
// @Param asset path string true "资源文件名,可选 image.svg、image.png、image-edit.svg、image-edit.png、video-poster.svg、video.mp4"
// @Success 200 {file} binary
// @Failure 404 {string} string "Not Found"
// @Router /static/simulation/{asset} [get]
func serveSimulationAsset(w http.ResponseWriter, r *http.Request) {
asset := strings.ToLower(strings.TrimSpace(r.PathValue("asset")))
switch asset {

View File

@ -9,10 +9,28 @@ import (
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
)
// serveGeneratedStaticAsset godoc
// @Summary 获取本地生成资源
// @Description 从本地生成资源目录读取图片、视频等任务产物;不存在时返回 404。
// @Tags static
// @Produce octet-stream
// @Param asset path string true "资源文件名"
// @Success 200 {file} file
// @Failure 404 {string} string "Not Found"
// @Router /static/generated/{asset} [get]
func (s *Server) serveGeneratedStaticAsset(w http.ResponseWriter, r *http.Request) {
s.serveLocalStaticAsset(w, r, s.cfg.LocalGeneratedStorageDir, config.DefaultLocalGeneratedStorageDir)
}
// serveUploadedStaticAsset godoc
// @Summary 获取本地上传资源
// @Description 从本地上传资源目录读取用户上传文件;不存在时返回 404。
// @Tags static
// @Produce octet-stream
// @Param asset path string true "资源文件名"
// @Success 200 {file} file
// @Failure 404 {string} string "Not Found"
// @Router /static/uploaded/{asset} [get]
func (s *Server) serveUploadedStaticAsset(w http.ResponseWriter, r *http.Request) {
s.serveLocalStaticAsset(w, r, s.cfg.LocalUploadedStorageDir, config.DefaultLocalUploadedStorageDir)
}

View File

@ -8,6 +8,17 @@ import (
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
// listFileStorageChannels godoc
// @Summary 列出文件存储通道
// @Description 返回所有未删除的文件存储通道,用于管理上传与生成资源回传策略。
// @Tags system
// @Produce json
// @Security BearerAuth
// @Success 200 {object} FileStorageChannelListResponse
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/system/file-storage/channels [get]
func (s *Server) listFileStorageChannels(w http.ResponseWriter, r *http.Request) {
items, err := s.store.ListFileStorageChannels(r.Context())
if err != nil {
@ -18,6 +29,17 @@ func (s *Server) listFileStorageChannels(w http.ResponseWriter, r *http.Request)
writeJSON(w, http.StatusOK, map[string]any{"items": items})
}
// getFileStorageSettings godoc
// @Summary 获取文件存储设置
// @Description 返回文件存储系统设置;数据库对象尚未创建时返回默认设置。
// @Tags system
// @Produce json
// @Security BearerAuth
// @Success 200 {object} store.FileStorageSettings
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/system/file-storage/settings [get]
func (s *Server) getFileStorageSettings(w http.ResponseWriter, r *http.Request) {
settings, err := s.store.GetFileStorageSettings(r.Context())
if err != nil {
@ -32,6 +54,20 @@ func (s *Server) getFileStorageSettings(w http.ResponseWriter, r *http.Request)
writeJSON(w, http.StatusOK, settings)
}
// updateFileStorageSettings godoc
// @Summary 更新文件存储设置
// @Description 更新生成资源上传策略等文件存储系统设置。
// @Tags system
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param body body store.FileStorageSettingsInput true "文件存储设置"
// @Success 200 {object} store.FileStorageSettings
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/system/file-storage/settings [patch]
func (s *Server) updateFileStorageSettings(w http.ResponseWriter, r *http.Request) {
var input store.FileStorageSettingsInput
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
@ -47,6 +83,21 @@ func (s *Server) updateFileStorageSettings(w http.ResponseWriter, r *http.Reques
writeJSON(w, http.StatusOK, settings)
}
// createFileStorageChannel godoc
// @Summary 创建文件存储通道
// @Description 创建文件存储通道,当前主要用于配置 server-main OpenAPI 上传通道。
// @Tags system
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param body body store.FileStorageChannelInput true "文件存储通道"
// @Success 201 {object} store.FileStorageChannel
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 409 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/system/file-storage/channels [post]
func (s *Server) createFileStorageChannel(w http.ResponseWriter, r *http.Request) {
var input store.FileStorageChannelInput
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
@ -70,6 +121,23 @@ func (s *Server) createFileStorageChannel(w http.ResponseWriter, r *http.Request
writeJSON(w, http.StatusCreated, item)
}
// updateFileStorageChannel godoc
// @Summary 更新文件存储通道
// @Description 更新指定文件存储通道的名称、凭证、场景、优先级、状态和重试策略。
// @Tags system
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param channelID path string true "文件存储通道 ID"
// @Param body body store.FileStorageChannelInput true "文件存储通道"
// @Success 200 {object} store.FileStorageChannel
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 409 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/system/file-storage/channels/{channelID} [patch]
func (s *Server) updateFileStorageChannel(w http.ResponseWriter, r *http.Request) {
var input store.FileStorageChannelInput
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
@ -107,6 +175,19 @@ func (s *Server) updateFileStorageChannel(w http.ResponseWriter, r *http.Request
writeJSON(w, http.StatusOK, item)
}
// deleteFileStorageChannel godoc
// @Summary 删除文件存储通道
// @Description 软删除指定文件存储通道。
// @Tags system
// @Produce json
// @Security BearerAuth
// @Param channelID path string true "文件存储通道 ID"
// @Success 204 "No Content"
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/admin/system/file-storage/channels/{channelID} [delete]
func (s *Server) deleteFileStorageChannel(w http.ResponseWriter, r *http.Request) {
if err := s.store.DeleteFileStorageChannel(r.Context(), r.PathValue("channelID")); err != nil {
if store.IsNotFound(err) {

View File

@ -7,6 +7,17 @@ import (
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
// getWallet godoc
// @Summary 获取钱包摘要
// @Description 返回当前用户的钱包账户、余额和最近消费摘要,可按 currency 过滤。
// @Tags wallet
// @Produce json
// @Security BearerAuth
// @Param currency query string false "币种" default(USD)
// @Success 200 {object} store.WalletSummary
// @Failure 401 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/workspace/wallet [get]
func (s *Server) getWallet(w http.ResponseWriter, r *http.Request) {
user, ok := auth.UserFromContext(r.Context())
if !ok {
@ -22,6 +33,24 @@ func (s *Server) getWallet(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, summary)
}
// listWalletTransactions godoc
// @Summary 列出钱包交易
// @Description 返回当前用户的钱包交易流水,支持关键字、方向、交易类型、时间范围和分页过滤。
// @Tags wallet
// @Produce json
// @Security BearerAuth
// @Param q query string false "搜索关键字,别名 query"
// @Param direction query string false "交易方向"
// @Param transactionType query string false "交易类型"
// @Param createdFrom query string false "创建时间起点,别名 from"
// @Param createdTo query string false "创建时间终点,别名 to"
// @Param page query int false "页码" default(1)
// @Param pageSize query int false "每页数量,别名 limit" default(50)
// @Success 200 {object} WalletTransactionListResponse
// @Failure 400 {object} ErrorEnvelope
// @Failure 401 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/workspace/wallet/transactions [get]
func (s *Server) listWalletTransactions(w http.ResponseWriter, r *http.Request) {
user, ok := auth.UserFromContext(r.Context())
if !ok {

View File

@ -1,6 +1,7 @@
package runner
import (
"errors"
"strings"
"time"
@ -210,6 +211,9 @@ func failureMetrics(err error, simulated bool) (string, map[string]any, time.Tim
metrics["error"] = err.Error()
metrics["errorCategory"] = info.Category
metrics["retryable"] = retryable
if detail := rateLimitFailureDetail(err); len(detail) > 0 {
metrics["rateLimit"] = detail
}
}
if meta.StatusCode > 0 {
metrics["statusCode"] = meta.StatusCode
@ -226,6 +230,47 @@ func failureMetrics(err error, simulated bool) (string, map[string]any, time.Tim
return meta.RequestID, metrics, meta.ResponseStartedAt, meta.ResponseFinishedAt, meta.ResponseDurationMS
}
func rateLimitFailureDetail(err error) map[string]any {
var limitErr *store.RateLimitExceededError
if !errors.As(err, &limitErr) {
return nil
}
detail := map[string]any{
"scopeType": limitErr.ScopeType,
"scopeKey": limitErr.ScopeKey,
"scopeName": limitErr.ScopeName,
"metric": limitErr.Metric,
"limit": limitErr.Limit,
"amount": limitErr.Amount,
"current": limitErr.Current,
"used": limitErr.Used,
"reserved": limitErr.Reserved,
"projected": limitErr.Projected,
"windowSeconds": limitErr.WindowSeconds,
"retryable": limitErr.Retryable,
"exceeded": map[string]any{
"metric": limitErr.Metric,
"current": limitErr.Current,
"amount": limitErr.Amount,
"projected": limitErr.Projected,
"limit": limitErr.Limit,
},
}
if limitErr.RetryAfter > 0 {
detail["retryAfterMs"] = limitErr.RetryAfter.Milliseconds()
}
if !limitErr.ResetAt.IsZero() {
detail["resetAt"] = limitErr.ResetAt.UTC().Format(time.RFC3339Nano)
}
if len(limitErr.ScopeMetadata) > 0 {
detail["scopeMetadata"] = limitErr.ScopeMetadata
}
if len(limitErr.Policy) > 0 {
detail["rateLimitPolicy"] = limitErr.Policy
}
return detail
}
func mergeMetrics(values ...map[string]any) map[string]any {
out := map[string]any{}
for _, value := range values {

View File

@ -55,6 +55,8 @@ func New(cfg config.Config, db *store.Store, logger *slog.Logger) *Service {
"openai": clients.OpenAIClient{HTTPClient: httpClients.none},
"gemini": clients.GeminiClient{HTTPClient: httpClients.none},
"volces": clients.VolcesClient{HTTPClient: httpClients.none},
"keling": clients.KelingClient{HTTPClient: httpClients.none},
"kling": clients.KelingClient{HTTPClient: httpClients.none},
"simulation": clients.SimulationClient{},
},
httpClients: httpClients,
@ -82,6 +84,17 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
}
}
if err := validateRequest(task.Kind, body); err != nil {
s.recordFailedAttempt(ctx, failedAttemptRecord{
Task: task,
Body: body,
AttemptNo: task.AttemptCount + 1,
Code: "bad_request",
Cause: err,
Simulated: task.RunMode == "simulation",
Scope: "request_validation",
Reason: "request_validation_failed",
ModelType: modelType,
})
failed, finishErr := s.failTask(ctx, task.ID, "bad_request", err.Error(), task.RunMode == "simulation", err)
if finishErr != nil {
return Result{}, finishErr
@ -90,6 +103,17 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
}
candidates, err := s.store.ListModelCandidates(ctx, task.Model, modelType, user)
if err != nil {
s.recordFailedAttempt(ctx, failedAttemptRecord{
Task: task,
Body: body,
AttemptNo: task.AttemptCount + 1,
Code: store.ModelCandidateErrorCode(err),
Cause: err,
Simulated: task.RunMode == "simulation",
Scope: "candidate_selection",
Reason: "candidate_selection_failed",
ModelType: modelType,
})
failed, finishErr := s.failTask(ctx, task.ID, store.ModelCandidateErrorCode(err), err.Error(), task.RunMode == "simulation", err)
if finishErr != nil {
return Result{}, finishErr
@ -98,6 +122,7 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
}
firstCandidateBody := body
normalizedModelType := modelType
attemptNo := task.AttemptCount
var firstPreprocessing parameterPreprocessingLog
if len(candidates) > 0 {
preprocessing := preprocessRequestWithLog(task.Kind, body, candidates[0])
@ -106,9 +131,20 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
normalizedModelType = candidates[0].ModelType
if preprocessing.Err != nil {
clientErr := parameterPreprocessClientError(preprocessing.Err)
if logErr := s.recordTaskParameterPreprocessing(ctx, task.ID, "", 0, candidates[0], firstPreprocessing); logErr != nil {
return Result{}, logErr
}
attemptNo = s.recordFailedAttempt(ctx, failedAttemptRecord{
Task: task,
Body: firstCandidateBody,
Candidate: &candidates[0],
AttemptNo: attemptNo + 1,
Code: clients.ErrorCode(clientErr),
Cause: clientErr,
Simulated: task.RunMode == "simulation",
Scope: "parameter_preprocessing",
Reason: "parameter_preprocessing_failed",
ExtraMetrics: []map[string]any{parameterPreprocessingMetrics(firstPreprocessing)},
Preprocessing: &firstPreprocessing,
ModelType: normalizedModelType,
})
failed, finishErr := s.failTask(ctx, task.ID, clients.ErrorCode(clientErr), clientErr.Error(), task.RunMode == "simulation", clientErr, parameterPreprocessingMetrics(firstPreprocessing))
if finishErr != nil {
return Result{}, finishErr
@ -121,9 +157,20 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
estimatedBillings := s.estimatedBillings(ctx, user, task.Kind, firstCandidateBody, candidates[0])
if err := s.ensureWalletBalance(ctx, user, estimatedBillings); err != nil {
if errors.Is(err, store.ErrInsufficientWalletBalance) {
if logErr := s.recordTaskParameterPreprocessing(ctx, task.ID, "", 0, candidates[0], firstPreprocessing); logErr != nil {
return Result{}, logErr
}
attemptNo = s.recordFailedAttempt(ctx, failedAttemptRecord{
Task: task,
Body: firstCandidateBody,
Candidate: &candidates[0],
AttemptNo: attemptNo + 1,
Code: "insufficient_balance",
Cause: err,
Simulated: task.RunMode == "simulation",
Scope: "wallet_balance",
Reason: "wallet_balance_check_failed",
ExtraMetrics: []map[string]any{parameterPreprocessingMetrics(firstPreprocessing)},
Preprocessing: &firstPreprocessing,
ModelType: normalizedModelType,
})
failed, finishErr := s.failTask(ctx, task.ID, "insufficient_balance", err.Error(), task.RunMode == "simulation", err, parameterPreprocessingMetrics(firstPreprocessing))
if finishErr != nil {
return Result{}, finishErr
@ -143,7 +190,6 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
}
maxPlatforms := maxPlatformsForCandidates(candidates, runnerPolicy)
maxFailoverDuration := maxFailoverDurationForCandidates(candidates, runnerPolicy)
attemptNo := task.AttemptCount
var lastErr error
var lastCandidate store.RuntimeModelCandidate
var lastPreprocessing *parameterPreprocessingLog
@ -162,6 +208,20 @@ candidatesLoop:
lastPreprocessing = &preprocessingLog
if preprocessing.Err != nil {
lastErr = parameterPreprocessClientError(preprocessing.Err)
attemptNo = s.recordFailedAttempt(ctx, failedAttemptRecord{
Task: task,
Body: preprocessing.Body,
Candidate: &candidate,
AttemptNo: nextAttemptNo,
Code: clients.ErrorCode(lastErr),
Cause: lastErr,
Simulated: isSimulation(task, candidate),
Scope: "parameter_preprocessing",
Reason: "parameter_preprocessing_failed",
ExtraMetrics: []map[string]any{parameterPreprocessingMetrics(preprocessingLog)},
Preprocessing: &preprocessingLog,
ModelType: candidate.ModelType,
})
break candidatesLoop
}
candidateBody := preprocessing.Body
@ -222,6 +282,19 @@ candidatesLoop:
}
return Result{Task: queued, Output: queued.Result}, &TaskQueuedError{Delay: delay}
}
attemptNo = s.recordFailedAttempt(ctx, failedAttemptRecord{
Task: task,
Body: candidateBody,
Candidate: &candidate,
AttemptNo: nextAttemptNo,
Code: clients.ErrorCode(err),
Cause: err,
Simulated: isSimulation(task, candidate),
Scope: "rate_limit",
Reason: "local_rate_limit_blocked",
ExtraMetrics: []map[string]any{parameterPreprocessingMetrics(preprocessing.Log)},
ModelType: candidate.ModelType,
})
break candidatesLoop
}
attemptNo = nextAttemptNo
@ -616,6 +689,110 @@ func (s *Service) failTask(ctx context.Context, taskID string, code string, mess
return failed, nil
}
type failedAttemptRecord struct {
Task store.GatewayTask
Body map[string]any
Candidate *store.RuntimeModelCandidate
AttemptNo int
Code string
Cause error
Simulated bool
Scope string
Reason string
ExtraMetrics []map[string]any
Preprocessing *parameterPreprocessingLog
ModelType string
}
func (s *Service) recordFailedAttempt(ctx context.Context, input failedAttemptRecord) int {
attemptNo := input.AttemptNo
if attemptNo <= 0 {
attemptNo = input.Task.AttemptCount + 1
}
code := firstNonEmptyString(input.Code, clients.ErrorCode(input.Cause))
message := ""
if input.Cause != nil {
message = input.Cause.Error()
}
retryable := clients.IsRetryable(input.Cause)
requestID, failure, responseStartedAt, responseFinishedAt, responseDurationMS := failureMetrics(input.Cause, input.Simulated)
scope := firstNonEmptyString(input.Scope, "pre_provider")
reason := firstNonEmptyString(input.Reason, "pre_provider_failed")
trace := failureTraceEntryWithReason(input.Cause, retryable, scope, reason)
statusCode := clients.ErrorResponseMetadata(input.Cause).StatusCode
category := failureCategory(strings.ToLower(strings.TrimSpace(code)), statusCode, message)
if code != "" {
failure["errorCode"] = code
trace["errorCode"] = code
}
if category != "" {
failure["errorCategory"] = category
trace["category"] = category
}
failure["failureScope"] = scope
failure["failureReason"] = reason
failure["trace"] = []any{trace}
baseMetrics := map[string]any{
"attempt": attemptNo,
"kind": input.Task.Kind,
"runMode": input.Task.RunMode,
"requestedModel": input.Task.Model,
"simulated": input.Simulated,
}
if input.ModelType != "" {
baseMetrics["modelType"] = input.ModelType
}
var platformID, platformModelID, clientID, queueKey string
if input.Candidate != nil {
baseMetrics = attemptMetrics(*input.Candidate, attemptNo, input.Simulated)
baseMetrics["kind"] = input.Task.Kind
baseMetrics["runMode"] = input.Task.RunMode
baseMetrics["requestedModel"] = input.Task.Model
platformID = input.Candidate.PlatformID
platformModelID = input.Candidate.PlatformModelID
clientID = input.Candidate.ClientID
queueKey = input.Candidate.QueueKey
}
metrics := mergeMetrics(append([]map[string]any{baseMetrics, failure}, input.ExtraMetrics...)...)
attemptID, err := s.store.CreateTaskAttempt(ctx, store.CreateTaskAttemptInput{
TaskID: input.Task.ID,
AttemptNo: attemptNo,
PlatformID: platformID,
PlatformModelID: platformModelID,
ClientID: clientID,
QueueKey: queueKey,
Status: "running",
Simulated: input.Simulated,
RequestSnapshot: input.Body,
Metrics: metrics,
})
if err != nil {
s.logger.Warn("record failed task attempt failed", "taskID", input.Task.ID, "attempt", attemptNo, "error", err)
return attemptNo
}
if input.Preprocessing != nil && input.Candidate != nil {
if err := s.recordTaskParameterPreprocessing(ctx, input.Task.ID, attemptID, attemptNo, *input.Candidate, *input.Preprocessing); err != nil {
s.logger.Warn("record failed attempt parameter preprocessing failed", "taskID", input.Task.ID, "attempt", attemptNo, "error", err)
}
}
if err := s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{
AttemptID: attemptID,
Status: "failed",
Retryable: retryable,
RequestID: requestID,
Metrics: metrics,
ResponseStartedAt: responseStartedAt,
ResponseFinishedAt: responseFinishedAt,
ResponseDurationMS: responseDurationMS,
ErrorCode: code,
ErrorMessage: message,
}); err != nil {
s.logger.Warn("finish failed task attempt failed", "taskID", input.Task.ID, "attempt", attemptNo, "error", err)
}
return attemptNo
}
func (s *Service) requeueRateLimitedTask(ctx context.Context, task store.GatewayTask, cause error, candidate store.RuntimeModelCandidate) (store.GatewayTask, time.Duration, error) {
delay := localRateLimitRetryAfter(cause)
if delay <= 0 {

View File

@ -7,8 +7,12 @@ import (
)
func failureTraceEntry(err error, retryable bool) map[string]any {
return failureTraceEntryWithReason(err, retryable, "client", "client_call_failed")
}
func failureTraceEntryWithReason(err error, retryable bool, scope string, reason string) map[string]any {
info := failureInfoFromError(err)
entry := policyTraceEntry("failure", "client", "failed", "client_call_failed", policyRuleMatch{}, info)
entry := policyTraceEntry("failure", scope, "failed", reason, policyRuleMatch{}, info)
entry["retryable"] = retryable
return entry
}

View File

@ -0,0 +1,25 @@
UPDATE base_model_catalog
SET capabilities = jsonb_set(
jsonb_set(capabilities, '{omni_video,input_audio}', 'false'::jsonb, true),
'{omni_video,max_audios}', '0'::jsonb, true
),
metadata = jsonb_set(
jsonb_set(metadata, '{rawModel,capabilities,omni_video,input_audio}', 'false'::jsonb, true),
'{rawModel,capabilities,omni_video,max_audios}', '0'::jsonb, true
),
updated_at = now()
WHERE provider_key = 'keling'
AND provider_model_name IN ('kling-video-o1', 'kling-v3-omni')
AND capabilities ? 'omni_video';
UPDATE platform_models m
SET capabilities = jsonb_set(
jsonb_set(m.capabilities, '{omni_video,input_audio}', 'false'::jsonb, true),
'{omni_video,max_audios}', '0'::jsonb, true
),
updated_at = now()
FROM integration_platforms p
WHERE m.platform_id = p.id
AND p.provider = 'keling'
AND COALESCE(NULLIF(m.provider_model_name, ''), m.model_name) IN ('kling-video-o1', 'kling-v3-omni')
AND m.capabilities ? 'omni_video';

View File

@ -19,6 +19,14 @@
"command": "go run ./cmd/migrate"
}
},
"openapi": {
"executor": "nx:run-commands",
"outputs": ["{projectRoot}/docs/swagger.json", "{projectRoot}/docs/swagger.yaml"],
"options": {
"cwd": "apps/api",
"command": "go run github.com/swaggo/swag/cmd/swag@v1.16.4 init --parseInternal -d ./cmd/gateway,./internal/httpapi,./internal/store,./internal/auth -g main.go -o docs --outputTypes json,yaml"
}
},
"test": {
"executor": "nx:run-commands",
"outputs": ["{workspaceRoot}/coverage/apps/api"],

View File

@ -812,7 +812,7 @@ function TaskRecord(props: { task: GatewayTask; token: string; onCopyRequestId:
<TableCell>{props.task.apiKeyName || props.task.apiKeyPrefix || props.task.apiKeyId || '-'}</TableCell>
<TableCell className="taskRecordTokenCell">{tokenUsage}</TableCell>
<TableCell>{chargeText}</TableCell>
<TableCell>{formatDuration(props.task.responseDurationMs)}</TableCell>
<TableCell>{formatDuration(taskDurationMs(props.task))}</TableCell>
<TableCell>{formatDateTime(props.task.createdAt)}</TableCell>
<TableCell>
<Button type="button" variant="ghost" size="sm" className="taskRecordJsonButton" title={taskErrorText(props.task) || '查看原始 JSON'} onClick={() => props.onOpenJson(props.task)}>
@ -971,28 +971,33 @@ function TaskAttemptPopoverContent(props: { task: GatewayTask }) {
const attempts = props.task.attempts ?? [];
return (
<span className="taskRecordAttemptPopover" role="tooltip">
{attempts.map((attempt) => (
<span
key={attempt.id || `${props.task.id}-${attempt.attemptNo}`}
className={`taskRecordAttemptDetail ${attempt.status === 'failed' ? 'failed' : attempt.status === 'succeeded' ? 'succeeded' : ''}`}
>
<span className="taskRecordAttemptDetailHeader">
<strong>#{attempt.attemptNo} {taskAttemptTarget(attempt)}</strong>
<Badge variant={attempt.status === 'succeeded' ? 'success' : attempt.status === 'failed' ? 'destructive' : 'secondary'}>{taskAttemptStatusText(attempt.status)}</Badge>
</span>
<small>{taskAttemptMeta(attempt)}</small>
{attempt.status === 'failed' && <span className="taskRecordAttemptError">{taskAttemptFailureReason(attempt)}</span>}
{taskAttemptTrace(attempt).length > 0 && (
<span className="taskRecordAttemptTrace">
{taskAttemptTrace(attempt).map((entry, index) => (
<span key={`${attempt.id || attempt.attemptNo}-trace-${index}`} className="taskRecordAttemptTraceItem">
{taskAttemptTraceText(entry)}
</span>
))}
{attempts.map((attempt) => {
const trace = taskAttemptTrace(attempt);
const rateLimitText = taskAttemptRateLimitText(attempt);
return (
<span
key={attempt.id || `${props.task.id}-${attempt.attemptNo}`}
className={`taskRecordAttemptDetail ${attempt.status === 'failed' ? 'failed' : attempt.status === 'succeeded' ? 'succeeded' : ''}`}
>
<span className="taskRecordAttemptDetailHeader">
<strong>#{attempt.attemptNo} {taskAttemptTarget(attempt)}</strong>
<Badge variant={attempt.status === 'succeeded' ? 'success' : attempt.status === 'failed' ? 'destructive' : 'secondary'}>{taskAttemptStatusText(attempt.status)}</Badge>
</span>
)}
</span>
))}
<small>{taskAttemptMeta(attempt)}</small>
{attempt.status === 'failed' && <span className="taskRecordAttemptError">{taskAttemptFailureReason(attempt)}</span>}
{(rateLimitText || trace.length > 0) && (
<span className="taskRecordAttemptTrace">
{rateLimitText && <span className="taskRecordAttemptTraceItem">{rateLimitText}</span>}
{trace.map((entry, index) => (
<span key={`${attempt.id || attempt.attemptNo}-trace-${index}`} className="taskRecordAttemptTraceItem">
{taskAttemptTraceText(entry)}
</span>
))}
</span>
)}
</span>
);
})}
</span>
);
}
@ -1024,7 +1029,7 @@ function taskAttemptMeta(attempt: NonNullable<GatewayTask['attempts']>[number])
attempt.providerModelName || attempt.modelName || attempt.modelAlias,
attempt.requestId ? `RequestID ${attempt.requestId}` : '',
statusCode ? `状态码 ${statusCode}` : '',
attempt.responseDurationMs ? formatDuration(attempt.responseDurationMs) : '',
formatDuration(attemptDurationMs(attempt)),
].filter(Boolean);
return values.join(' · ') || attempt.clientId || '-';
}
@ -1055,6 +1060,29 @@ function taskAttemptTrace(attempt: NonNullable<GatewayTask['attempts']>[number])
return raw.filter((item): item is Record<string, unknown> => Boolean(item) && typeof item === 'object' && !Array.isArray(item));
}
function taskAttemptRateLimitText(attempt: NonNullable<GatewayTask['attempts']>[number]) {
const detail = metadataObject(attempt.metrics, 'rateLimit');
if (!Object.keys(detail).length) return '';
const scopeName = objectString(detail, 'scopeName') || objectString(detail, 'scopeKey') || '限流对象';
const metric = objectString(detail, 'metric') || 'rate_limit';
const current = metadataNumber(detail, 'current');
const amount = metadataNumber(detail, 'amount');
const projected = metadataNumber(detail, 'projected');
const limit = metadataNumber(detail, 'limit');
const windowSeconds = metadataNumber(detail, 'windowSeconds');
const retryAfterMs = metadataNumber(detail, 'retryAfterMs');
const values = [
`限流 ${scopeName} · ${metric}`,
current !== null ? `当前 ${formatCellValue(current)}` : '',
amount !== null ? `本次 ${formatCellValue(amount)}` : '',
projected !== null ? `预计 ${formatCellValue(projected)}` : '',
limit !== null ? `限制 ${formatCellValue(limit)}` : '',
windowSeconds !== null ? `窗口 ${Math.trunc(windowSeconds)}` : '',
retryAfterMs !== null ? `${formatDuration(Math.trunc(retryAfterMs))} 后可重试` : '',
].filter(Boolean);
return values.join(' · ');
}
function taskAttemptTraceText(entry: Record<string, unknown>) {
const event = objectString(entry, 'event');
const action = objectString(entry, 'action');
@ -1116,6 +1144,12 @@ function taskAttemptTraceReasonLabel(reason: string) {
client_retryable: '客户端标记可重试',
client_non_retryable: '客户端标记不可重试',
same_client_max_attempts: '达到本平台最大尝试次数',
request_validation_failed: '请求校验失败',
candidate_selection_failed: '候选模型选择失败',
parameter_preprocessing_failed: '参数预处理失败',
wallet_balance_check_failed: '余额校验失败',
local_rate_limit_blocked: '本地限流拦截',
pre_provider_failed: '调用上游前失败',
local_rate_limit_wait_queue: '本地限流排队等待',
failover_time_budget_exceeded: '超过全局切换时间预算',
runner_policy_disabled: '全局调度策略停用',
@ -1321,10 +1355,41 @@ function tokenValue(value: unknown) {
return Number.isFinite(numericValue) ? numericValue : null;
}
function taskDurationMs(task: GatewayTask) {
return (
positiveDurationMs(task.responseDurationMs) ??
elapsedDurationMs(task.responseStartedAt, task.responseFinishedAt) ??
elapsedDurationMs(task.createdAt, task.finishedAt)
);
}
function attemptDurationMs(attempt: NonNullable<GatewayTask['attempts']>[number]) {
return (
positiveDurationMs(attempt.responseDurationMs) ??
elapsedDurationMs(attempt.responseStartedAt, attempt.responseFinishedAt) ??
elapsedDurationMs(attempt.startedAt, attempt.finishedAt)
);
}
function positiveDurationMs(value?: number) {
if (value === undefined || value === null) return undefined;
const numericValue = Number(value);
return Number.isFinite(numericValue) && numericValue > 0 ? numericValue : undefined;
}
function elapsedDurationMs(start?: string, end?: string) {
if (!start || !end) return undefined;
const startedAt = new Date(start).getTime();
const finishedAt = new Date(end).getTime();
if (!Number.isFinite(startedAt) || !Number.isFinite(finishedAt)) return undefined;
const elapsed = finishedAt - startedAt;
return elapsed > 0 ? Math.max(1, Math.round(elapsed)) : undefined;
}
function formatDuration(value?: number) {
if (value === undefined || value === null) return '-';
const milliseconds = Math.max(0, Math.round(value));
if (milliseconds === 0) return '0秒';
if (milliseconds === 0) return '-';
if (milliseconds < 1000) return `${milliseconds}毫秒`;
const totalSeconds = Math.round(milliseconds / 1000);
const hours = Math.floor(totalSeconds / 3600);

View File

@ -9,7 +9,8 @@
"test": "nx run-many -t test -p api web",
"lint": "nx run-many -t lint -p web contracts",
"db:create": "scripts/create-database.sh",
"migrate": "nx run api:migrate"
"migrate": "nx run api:migrate",
"openapi": "nx run api:openapi"
},
"devDependencies": {
"@nx/vite": "^21.0.0",