From ca7e76e815f518781bbd6e2ef48242836457b303 Mon Sep 17 00:00:00 2001 From: wangbo Date: Mon, 11 May 2026 08:48:02 +0800 Subject: [PATCH] fix gateway loopback validation chains --- apps/api/internal/auth/auth.go | 3 + apps/api/internal/clients/simulation.go | 28 +- apps/api/internal/clients/volces.go | 13 +- .../httpapi/core_flow_integration_test.go | 545 +++++++++++++++++- apps/api/internal/httpapi/handlers.go | 63 ++ apps/api/internal/httpapi/server.go | 1 + apps/api/internal/runner/limits.go | 29 +- apps/api/internal/runner/pricing.go | 16 +- apps/api/internal/runner/runtime_policy.go | 84 +++ apps/api/internal/runner/service.go | 45 +- apps/api/internal/store/base_models.go | 84 ++- apps/api/internal/store/candidates.go | 33 +- apps/api/internal/store/platform_models.go | 28 +- apps/api/internal/store/postgres.go | 17 +- apps/api/internal/store/pricing_rules.go | 57 ++ apps/api/internal/store/runtime_policies.go | 27 + apps/api/internal/store/runtime_types.go | 7 +- apps/api/internal/store/tasks_runtime.go | 155 +++++ apps/api/migrations/0001_init.sql | 9 +- .../0002_invitation_relationship_only.sql | 9 +- .../migrations/0003_phase1_seed_runtime.sql | 16 +- .../0019_platform_model_type_array.sql | 138 +++++ ...020_runtime_policy_set_owns_rate_limit.sql | 11 + .../migrations/0021_base_model_type_array.sql | 108 ++++ ...022_sync_platform_model_type_from_base.sql | 9 + .../0023_wallet_task_billing_schema.sql | 48 ++ apps/web/src/App.tsx | 21 +- apps/web/src/api.ts | 4 + apps/web/src/app-state.ts | 1 + apps/web/src/components/Dashboard.tsx | 2 +- apps/web/src/pages/AdminPage.tsx | 6 +- apps/web/src/pages/ModelsPage.tsx | 34 +- apps/web/src/pages/PlaygroundPage.tsx | 4 +- apps/web/src/pages/WorkspacePage.tsx | 58 +- .../pages/admin/AccessPermissionEditor.tsx | 2 +- .../pages/admin/PlatformManagementPanel.tsx | 6 +- apps/web/src/pages/admin/platform-form.ts | 2 +- apps/web/src/pages/playground-media.tsx | 2 +- apps/web/src/styles.css | 17 + apps/web/src/types.ts | 4 +- docs/test/loopback-test-checklist.md | 20 +- packages/contracts/src/index.ts | 4 +- 42 files changed, 1641 insertions(+), 129 deletions(-) create mode 100644 apps/api/internal/runner/runtime_policy.go create mode 100644 apps/api/migrations/0019_platform_model_type_array.sql create mode 100644 apps/api/migrations/0020_runtime_policy_set_owns_rate_limit.sql create mode 100644 apps/api/migrations/0021_base_model_type_array.sql create mode 100644 apps/api/migrations/0022_sync_platform_model_type_from_base.sql create mode 100644 apps/api/migrations/0023_wallet_task_billing_schema.sql diff --git a/apps/api/internal/auth/auth.go b/apps/api/internal/auth/auth.go index 3d76b91..db90018 100644 --- a/apps/api/internal/auth/auth.go +++ b/apps/api/internal/auth/auth.go @@ -40,6 +40,7 @@ type User struct { APIKeySecret string `json:"apiKeySecret,omitempty"` APIKeyName string `json:"apiKeyName,omitempty"` APIKeyPrefix string `json:"apiKeyPrefix,omitempty"` + APIKeyScopes []string `json:"apiKeyScopes,omitempty"` } type contextKey string @@ -137,6 +138,7 @@ func (a *Authenticator) verifyJWT(tokenString string) (*User, error) { APIKeySecret: stringClaim(claims, "apiKeySecret"), APIKeyName: stringClaim(claims, "apiKeyName"), APIKeyPrefix: stringClaim(claims, "apiKeyPrefix"), + APIKeyScopes: stringSliceClaim(claims, "apiKeyScopes"), } if user.Source == "" { user.Source = "gateway" @@ -167,6 +169,7 @@ func (a *Authenticator) SignJWT(user *User, ttl time.Duration) (string, error) { "apiKeyId": user.APIKeyID, "apiKeyName": user.APIKeyName, "apiKeyPrefix": user.APIKeyPrefix, + "apiKeyScopes": user.APIKeyScopes, "iat": now.Unix(), "exp": now.Add(ttl).Unix(), } diff --git a/apps/api/internal/clients/simulation.go b/apps/api/internal/clients/simulation.go index f6521be..c4e7085 100644 --- a/apps/api/internal/clients/simulation.go +++ b/apps/api/internal/clients/simulation.go @@ -34,7 +34,8 @@ func (c SimulationClient) Run(ctx context.Context, request Request) (Response, e } } responseFinishedAt := time.Now() - if profile == "retryable_failure" { + switch profile { + case "retryable_failure": return Response{}, &ClientError{ Code: "server_error", Message: "simulated retryable failure", @@ -44,8 +45,27 @@ func (c SimulationClient) Run(ctx context.Context, request Request) (Response, e ResponseDurationMS: responseDurationMS(responseStartedAt, responseFinishedAt), Retryable: true, } - } - if profile == "fatal_failure" || profile == "non_retryable_failure" { + case "rate_limit", "overloaded": + return Response{}, &ClientError{ + Code: profile, + Message: "simulated " + profile, + RequestID: "simulated-request", + ResponseStartedAt: responseStartedAt, + ResponseFinishedAt: responseFinishedAt, + ResponseDurationMS: responseDurationMS(responseStartedAt, responseFinishedAt), + Retryable: true, + } + case "invalid_api_key": + return Response{}, &ClientError{ + Code: "invalid_api_key", + Message: "simulated invalid_api_key", + RequestID: "simulated-request", + ResponseStartedAt: responseStartedAt, + ResponseFinishedAt: responseFinishedAt, + ResponseDurationMS: responseDurationMS(responseStartedAt, responseFinishedAt), + Retryable: false, + } + case "fatal_failure", "non_retryable_failure": return Response{}, &ClientError{ Code: "bad_request", Message: "simulated non-retryable failure", @@ -184,7 +204,7 @@ func simulatedVideoData(request Request) []any { } func simulatedUsage(request Request) Usage { - if request.ModelType == "chat" || request.Kind == "responses" { + if request.ModelType == "chat" || request.ModelType == "text_generate" || request.Kind == "responses" { return Usage{InputTokens: 12, OutputTokens: 8, TotalTokens: 20} } return Usage{} diff --git a/apps/api/internal/clients/volces.go b/apps/api/internal/clients/volces.go index 3144e9e..acdbe18 100644 --- a/apps/api/internal/clients/volces.go +++ b/apps/api/internal/clients/volces.go @@ -251,9 +251,18 @@ func buildVolcesContentFromBody(body map[string]any) []map[string]any { } } - appendURLContent("image_url", "first_frame", firstNonEmptyStringValue(body, "first_frame", "firstFrame")) + firstFrame := firstNonEmptyStringValue(body, "first_frame", "firstFrame") + appendURLContent("image_url", "first_frame", firstFrame) appendURLContent("image_url", "last_frame", firstNonEmptyStringValue(body, "last_frame", "lastFrame")) - for _, url := range firstNonEmptyStringListFromAny(body["image"], body["images"], body["image_url"], body["imageUrl"], body["image_urls"], body["imageUrls"], body["reference_image"], body["referenceImage"]) { + imageURLs := firstNonEmptyStringListFromAny(body["image"], body["images"], body["image_url"], body["imageUrl"], body["image_urls"], body["imageUrls"]) + if firstFrame == "" && len(imageURLs) > 0 { + appendURLContent("image_url", "first_frame", imageURLs[0]) + imageURLs = imageURLs[1:] + } + for _, url := range imageURLs { + appendURLContent("image_url", "reference_image", url) + } + for _, url := range firstNonEmptyStringListFromAny(body["reference_image"], body["referenceImage"]) { appendURLContent("image_url", "reference_image", url) } for _, url := range firstNonEmptyStringListFromAny(body["video"], body["video_url"], body["videoUrl"], body["reference_video"], body["referenceVideo"]) { diff --git a/apps/api/internal/httpapi/core_flow_integration_test.go b/apps/api/internal/httpapi/core_flow_integration_test.go index 26f42f5..b056b5e 100644 --- a/apps/api/internal/httpapi/core_flow_integration_test.go +++ b/apps/api/internal/httpapi/core_flow_integration_test.go @@ -6,6 +6,7 @@ import ( "encoding/json" "io" "log/slog" + "math" "net/http" "net/http/httptest" "os" @@ -123,8 +124,43 @@ func TestCoreLocalFlow(t *testing.T) { if _, err := testPool.Exec(ctx, `UPDATE gateway_users SET roles = '["admin"]'::jsonb WHERE username = $1`, username); err != nil { t.Fatalf("promote smoke user: %v", err) } + var smokeGatewayUserID string + if err := testPool.QueryRow(ctx, `SELECT id::text FROM gateway_users WHERE username = $1`, username).Scan(&smokeGatewayUserID); err != nil { + t.Fatalf("read smoke gateway user id: %v", err) + } doJSON(t, server.URL, http.MethodGet, "/api/admin/models", apiKeyResponse.Secret, nil, http.StatusForbidden, nil) + var chatOnlyAPIKeyResponse struct { + Secret string `json:"secret"` + APIKey struct { + ID string `json:"id"` + } `json:"apiKey"` + } + doJSON(t, server.URL, http.MethodPost, "/api/v1/api-keys", loginResponse.AccessToken, map[string]any{ + "name": "chat only key", + "scopes": []string{"chat"}, + }, http.StatusCreated, &chatOnlyAPIKeyResponse) + var taskCountBefore int + if err := testPool.QueryRow(ctx, `SELECT count(*) FROM gateway_tasks`).Scan(&taskCountBefore); err != nil { + t.Fatalf("count tasks before scoped request: %v", err) + } + doJSON(t, server.URL, http.MethodPost, "/api/v1/images/generations", chatOnlyAPIKeyResponse.Secret, map[string]any{ + "model": "gpt-image-1", + "prompt": "scope should block this", + }, http.StatusForbidden, nil) + doJSON(t, server.URL, http.MethodPost, "/api/v1/pricing/estimate", chatOnlyAPIKeyResponse.Secret, map[string]any{ + "kind": "images.generations", + "model": "gpt-image-1", + "prompt": "scope should block this estimate", + }, http.StatusForbidden, nil) + var taskCountAfter int + if err := testPool.QueryRow(ctx, `SELECT count(*) FROM gateway_tasks`).Scan(&taskCountAfter); err != nil { + t.Fatalf("count tasks after scoped request: %v", err) + } + if taskCountAfter != taskCountBefore { + t.Fatalf("scoped API key rejection should happen before task creation, before=%d after=%d", taskCountBefore, taskCountAfter) + } + inviteCode := "INVITE-" + suffixText if _, err := testPool.Exec(ctx, ` INSERT INTO gateway_invitations (invite_code, max_uses, metadata) @@ -322,6 +358,317 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil { t.Fatalf("unexpected image edit task: %+v", imageEditResponse.Task) } + var gptImageModelTypesRaw []byte + if err := testPool.QueryRow(ctx, ` +SELECT model_type +FROM platform_models +WHERE model_name = 'gpt-image-1' + AND model_type @> '["image_generate","image_edit"]'::jsonb +LIMIT 1`).Scan(&gptImageModelTypesRaw); err != nil { + t.Fatalf("gpt-image-1 platform model should store both image model types: %v", err) + } + var gptImageModelTypes []string + if err := json.Unmarshal(gptImageModelTypesRaw, &gptImageModelTypes); err != nil { + t.Fatalf("decode gpt-image-1 model_type: %v raw=%s", err, string(gptImageModelTypesRaw)) + } + if !stringSliceContains(gptImageModelTypes, "image_generate") || !stringSliceContains(gptImageModelTypes, "image_edit") { + t.Fatalf("gpt-image-1 model_type should include generation and edit types: %+v", gptImageModelTypes) + } + + deniedModel := "permission-deny-smoke-" + suffixText + var deniedPlatformModel struct { + ID string `json:"id"` + } + doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms/"+platform.ID+"/models", loginResponse.AccessToken, map[string]any{ + "canonicalModelKey": "openai:gpt-4o-mini", + "modelName": deniedModel, + "modelAlias": deniedModel, + "modelType": []string{"text_generate"}, + "displayName": "Permission Deny Smoke", + }, http.StatusCreated, &deniedPlatformModel) + doJSON(t, server.URL, http.MethodPost, "/api/admin/access-rules", loginResponse.AccessToken, map[string]any{ + "subjectType": "api_key", + "subjectId": chatOnlyAPIKeyResponse.APIKey.ID, + "resourceType": "platform_model", + "resourceId": deniedPlatformModel.ID, + "effect": "deny", + "priority": 10, + "minPermissionLevel": 0, + "status": "active", + }, http.StatusCreated, nil) + var deniedTask struct { + Task struct { + Status string `json:"status"` + ErrorCode string `json:"errorCode"` + } `json:"task"` + } + doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", chatOnlyAPIKeyResponse.Secret, map[string]any{ + "model": deniedModel, + "runMode": "simulation", + "simulation": true, + "simulationDurationMs": 5, + "messages": []map[string]any{{"role": "user", "content": "permission deny"}}, + }, http.StatusAccepted, &deniedTask) + if deniedTask.Task.Status != "failed" || deniedTask.Task.ErrorCode != "no_model_candidate" { + t.Fatalf("deny access rule should hide denied model from runtime candidates: %+v", deniedTask.Task) + } + var restrictedModels struct { + Items []struct { + ID string `json:"id"` + ModelName string `json:"modelName"` + } `json:"items"` + } + doJSON(t, server.URL, http.MethodGet, "/api/v1/playground/models", chatOnlyAPIKeyResponse.Secret, nil, http.StatusOK, &restrictedModels) + if modelListContains(restrictedModels.Items, deniedPlatformModel.ID) { + t.Fatalf("deny access rule should hide denied model from playable list: %+v", restrictedModels.Items) + } + + controlledModel := "permission-allow-smoke-" + suffixText + var controlledPlatformModel struct { + ID string `json:"id"` + } + doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms/"+platform.ID+"/models", loginResponse.AccessToken, map[string]any{ + "canonicalModelKey": "openai:gpt-4o-mini", + "modelName": controlledModel, + "modelAlias": controlledModel, + "modelType": []string{"text_generate"}, + "displayName": "Permission Allow Smoke", + }, http.StatusCreated, &controlledPlatformModel) + doJSON(t, server.URL, http.MethodPost, "/api/admin/access-rules", loginResponse.AccessToken, map[string]any{ + "subjectType": "api_key", + "subjectId": apiKeyResponse.APIKey.ID, + "resourceType": "platform_model", + "resourceId": controlledPlatformModel.ID, + "effect": "allow", + "priority": 10, + "minPermissionLevel": 0, + "status": "active", + }, http.StatusCreated, nil) + var blockedControlledTask struct { + Task struct { + Status string `json:"status"` + ErrorCode string `json:"errorCode"` + } `json:"task"` + } + doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", chatOnlyAPIKeyResponse.Secret, map[string]any{ + "model": controlledModel, + "runMode": "simulation", + "simulation": true, + "simulationDurationMs": 5, + "messages": []map[string]any{{"role": "user", "content": "allow should block other keys"}}, + }, http.StatusAccepted, &blockedControlledTask) + if blockedControlledTask.Task.Status != "failed" || blockedControlledTask.Task.ErrorCode != "no_model_candidate" { + t.Fatalf("allow access rule should make the resource unavailable to unmatched subjects: %+v", blockedControlledTask.Task) + } + doJSON(t, server.URL, http.MethodPost, "/api/admin/access-rules", loginResponse.AccessToken, map[string]any{ + "subjectType": "api_key", + "subjectId": chatOnlyAPIKeyResponse.APIKey.ID, + "resourceType": "platform_model", + "resourceId": controlledPlatformModel.ID, + "effect": "allow", + "priority": 10, + "minPermissionLevel": 0, + "status": "active", + }, http.StatusCreated, nil) + var allowedControlledTask struct { + Task struct { + Status string `json:"status"` + } `json:"task"` + } + doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", chatOnlyAPIKeyResponse.Secret, map[string]any{ + "model": controlledModel, + "runMode": "simulation", + "simulation": true, + "simulationDurationMs": 5, + "messages": []map[string]any{{"role": "user", "content": "allow should pass"}}, + }, http.StatusAccepted, &allowedControlledTask) + if allowedControlledTask.Task.Status != "succeeded" { + t.Fatalf("matching allow access rule should make the controlled model usable: %+v", allowedControlledTask.Task) + } + + var customPricingRuleSet struct { + ID string `json:"id"` + } + doJSON(t, server.URL, http.MethodPost, "/api/admin/pricing/rule-sets", loginResponse.AccessToken, map[string]any{ + "ruleSetKey": "smoke-pricing-" + suffixText, + "name": "Smoke Pricing", + "currency": "resource", + "rules": []map[string]any{ + {"ruleKey": "text_input", "displayName": "Text Input", "resourceType": "text_input", "unit": "1k_tokens", "basePrice": 1}, + {"ruleKey": "text_output", "displayName": "Text Output", "resourceType": "text_output", "unit": "1k_tokens", "basePrice": 2}, + {"ruleKey": "image", "displayName": "Image", "resourceType": "image", "unit": "image", "basePrice": 7}, + {"ruleKey": "image_edit", "displayName": "Image Edit", "resourceType": "image_edit", "unit": "image", "basePrice": 11}, + {"ruleKey": "video", "displayName": "Video", "resourceType": "video", "unit": "video", "basePrice": 13}, + }, + }, http.StatusCreated, &customPricingRuleSet) + pricingModel := "pricing-smoke-" + suffixText + var pricingPlatformModel map[string]any + doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms/"+platform.ID+"/models", loginResponse.AccessToken, map[string]any{ + "canonicalModelKey": "openai:gpt-4o-mini", + "modelName": pricingModel, + "modelAlias": pricingModel, + "modelType": []string{"text_generate"}, + "displayName": "Pricing Smoke", + "pricingRuleSetId": customPricingRuleSet.ID, + }, http.StatusCreated, &pricingPlatformModel) + if _, err := testPool.Exec(ctx, ` +UPDATE gateway_wallet_accounts +SET balance = 100, total_spent = 0, updated_at = now() +WHERE gateway_user_id = $1::uuid + AND currency = 'resource'`, smokeGatewayUserID); err != nil { + t.Fatalf("seed wallet balance: %v", err) + } + var walletBalanceBefore float64 + if err := testPool.QueryRow(ctx, ` +SELECT balance::float8 +FROM gateway_wallet_accounts +WHERE gateway_user_id = $1::uuid + AND currency = 'resource'`, smokeGatewayUserID).Scan(&walletBalanceBefore); err != nil { + t.Fatalf("read wallet balance before pricing task: %v", err) + } + var pricingTask struct { + Task struct { + ID string `json:"id"` + Status string `json:"status"` + BillingSummary map[string]any `json:"billingSummary"` + FinalChargeAmount float64 `json:"finalChargeAmount"` + } `json:"task"` + } + doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{ + "model": pricingModel, + "runMode": "simulation", + "simulation": true, + "simulationDurationMs": 5, + "messages": []map[string]any{{"role": "user", "content": "priced ping"}}, + }, http.StatusAccepted, &pricingTask) + if pricingTask.Task.Status != "succeeded" || !floatNear(pricingTask.Task.FinalChargeAmount, 0.028) { + t.Fatalf("custom pricing rule set should drive text billing, got task=%+v", pricingTask.Task) + } + var walletBalanceAfter float64 + var walletSpentAfter float64 + if err := testPool.QueryRow(ctx, ` +SELECT balance::float8, total_spent::float8 +FROM gateway_wallet_accounts +WHERE gateway_user_id = $1::uuid + AND currency = 'resource'`, smokeGatewayUserID).Scan(&walletBalanceAfter, &walletSpentAfter); err != nil { + t.Fatalf("read wallet balance after pricing task: %v", err) + } + if !floatNear(walletBalanceAfter, walletBalanceBefore-pricingTask.Task.FinalChargeAmount) || !floatNear(walletSpentAfter, pricingTask.Task.FinalChargeAmount) { + t.Fatalf("task billing should debit wallet balance and spent totals, before=%f after=%f spent=%f task=%+v", walletBalanceBefore, walletBalanceAfter, walletSpentAfter, pricingTask.Task) + } + var walletTransactionAmount float64 + if err := testPool.QueryRow(ctx, ` +SELECT amount::float8 +FROM gateway_wallet_transactions +WHERE reference_type = 'gateway_task' + AND reference_id = $1 + AND transaction_type = 'task_billing'`, pricingTask.Task.ID).Scan(&walletTransactionAmount); err != nil { + t.Fatalf("read task billing wallet transaction: %v", err) + } + if !floatNear(walletTransactionAmount, pricingTask.Task.FinalChargeAmount) { + t.Fatalf("task billing transaction amount=%f want=%f", walletTransactionAmount, pricingTask.Task.FinalChargeAmount) + } + + rateLimitedModel := "rate-limit-smoke-" + suffixText + var rateLimitPolicySet struct { + ID string `json:"id"` + } + doJSON(t, server.URL, http.MethodPost, "/api/admin/runtime/policy-sets", loginResponse.AccessToken, map[string]any{ + "policyKey": "smoke-rate-limit-" + suffixText, + "name": "Smoke Rate Limit", + "retryPolicy": map[string]any{ + "enabled": false, + "maxAttempts": 1, + }, + "rateLimitPolicy": map[string]any{ + "rules": []map[string]any{{"metric": "rpm", "limit": 1, "windowSeconds": 60}}, + }, + }, http.StatusCreated, &rateLimitPolicySet) + var rateLimitPlatformModel map[string]any + doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms/"+platform.ID+"/models", loginResponse.AccessToken, map[string]any{ + "canonicalModelKey": "openai:gpt-4o-mini", + "modelName": rateLimitedModel, + "modelAlias": rateLimitedModel, + "modelType": []string{"text_generate"}, + "displayName": "Rate Limit Smoke", + "runtimePolicySetId": rateLimitPolicySet.ID, + "runtimePolicyOverride": map[string]any{}, + }, http.StatusCreated, &rateLimitPlatformModel) + var rateLimitTaskOne struct { + Task struct { + Status string `json:"status"` + } `json:"task"` + } + doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{ + "model": rateLimitedModel, + "runMode": "simulation", + "simulation": true, + "simulationDurationMs": 5, + "messages": []map[string]any{{"role": "user", "content": "first"}}, + }, http.StatusAccepted, &rateLimitTaskOne) + if rateLimitTaskOne.Task.Status != "succeeded" { + t.Fatalf("first rate-limited task should succeed: %+v", rateLimitTaskOne.Task) + } + var rateLimitTaskTwo struct { + Task struct { + Status string `json:"status"` + ErrorCode string `json:"errorCode"` + } `json:"task"` + } + doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{ + "model": rateLimitedModel, + "runMode": "simulation", + "simulation": true, + "simulationDurationMs": 5, + "messages": []map[string]any{{"role": "user", "content": "second"}}, + }, http.StatusAccepted, &rateLimitTaskTwo) + if rateLimitTaskTwo.Task.Status != "failed" || rateLimitTaskTwo.Task.ErrorCode != "rate_limit" { + t.Fatalf("runtime policy rate limit should fail second task with rate_limit: %+v", rateLimitTaskTwo.Task) + } + + videoRouteModel := "video-route-smoke-" + suffixText + var videoRoutePlatformModel map[string]any + doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms/"+platform.ID+"/models", loginResponse.AccessToken, map[string]any{ + "canonicalModelKey": "openai:gpt-4o-mini", + "modelName": videoRouteModel, + "modelAlias": videoRouteModel, + "modelType": []string{"video_generate", "image_to_video"}, + "displayName": "Video Route Smoke", + }, http.StatusCreated, &videoRoutePlatformModel) + var textToVideoTask struct { + Task struct { + Status string `json:"status"` + ModelType string `json:"modelType"` + } `json:"task"` + } + doJSON(t, server.URL, http.MethodPost, "/api/v1/videos/generations", apiKeyResponse.Secret, map[string]any{ + "model": videoRouteModel, + "runMode": "simulation", + "simulation": true, + "simulationDurationMs": 5, + "prompt": "text to video route", + }, http.StatusAccepted, &textToVideoTask) + if textToVideoTask.Task.Status != "succeeded" || textToVideoTask.Task.ModelType != "video_generate" { + t.Fatalf("text-to-video request should use video_generate model_type: %+v", textToVideoTask.Task) + } + var imageToVideoTask struct { + Task struct { + Status string `json:"status"` + ModelType string `json:"modelType"` + } `json:"task"` + } + doJSON(t, server.URL, http.MethodPost, "/api/v1/videos/generations", apiKeyResponse.Secret, map[string]any{ + "model": videoRouteModel, + "runMode": "simulation", + "simulation": true, + "simulationDurationMs": 5, + "prompt": "image to video route", + "image": "https://example.com/source.png", + }, http.StatusAccepted, &imageToVideoTask) + if imageToVideoTask.Task.Status != "succeeded" || imageToVideoTask.Task.ModelType != "image_to_video" { + t.Fatalf("image-to-video request should use image_to_video model_type: %+v", imageToVideoTask.Task) + } + failoverModel := "phase1-failover-" + suffixText var failedPlatform struct { ID string `json:"id"` @@ -353,7 +700,7 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil { "canonicalModelKey": "openai:gpt-4o-mini", "modelName": failoverModel, "modelAlias": failoverModel, - "modelType": "chat", + "modelType": []string{"text_generate"}, "displayName": "Failover Smoke", "retryPolicy": map[string]any{"enabled": true, "maxAttempts": 2}, }, http.StatusCreated, &platformModel) @@ -376,6 +723,142 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil { t.Fatalf("failover task should succeed through second client: %+v", failoverTask.Task) } + var degradePolicySet struct { + ID string `json:"id"` + } + doJSON(t, server.URL, http.MethodPost, "/api/admin/runtime/policy-sets", loginResponse.AccessToken, map[string]any{ + "policyKey": "smoke-degrade-" + suffixText, + "name": "Smoke Degrade", + "degradePolicy": map[string]any{ + "enabled": true, + "keywords": []string{"rate_limit"}, + "cooldownSeconds": 600, + }, + }, http.StatusCreated, °radePolicySet) + degradeModel := "degrade-smoke-" + suffixText + var degradedPlatform struct { + ID string `json:"id"` + } + doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms", loginResponse.AccessToken, map[string]any{ + "provider": "openai", + "platformKey": "openai-degrade-" + suffixText, + "name": "OpenAI Degrade Failure", + "baseUrl": "https://api.openai.com/v1", + "authType": "bearer", + "credentials": map[string]any{"mode": "simulation", "simulationFailure": "rate_limit"}, + "priority": 30, + }, http.StatusCreated, °radedPlatform) + var degradeSuccessPlatform struct { + ID string `json:"id"` + } + doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms", loginResponse.AccessToken, map[string]any{ + "provider": "openai", + "platformKey": "openai-degrade-success-" + suffixText, + "name": "OpenAI Degrade Success", + "baseUrl": "https://api.openai.com/v1", + "authType": "bearer", + "credentials": map[string]any{"mode": "simulation"}, + "priority": 40, + }, http.StatusCreated, °radeSuccessPlatform) + for _, item := range []struct { + platformID string + runtimePolicySetID string + }{ + {platformID: degradedPlatform.ID, runtimePolicySetID: degradePolicySet.ID}, + {platformID: degradeSuccessPlatform.ID}, + } { + var platformModel map[string]any + doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms/"+item.platformID+"/models", loginResponse.AccessToken, map[string]any{ + "canonicalModelKey": "openai:gpt-4o-mini", + "modelName": degradeModel, + "modelAlias": degradeModel, + "modelType": []string{"text_generate"}, + "displayName": "Degrade Smoke", + "runtimePolicySetId": item.runtimePolicySetID, + }, http.StatusCreated, &platformModel) + } + var degradeTask struct { + Task struct { + ID string `json:"id"` + Status string `json:"status"` + } `json:"task"` + } + doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{ + "model": degradeModel, + "runMode": "simulation", + "simulation": true, + "simulationDurationMs": 5, + "messages": []map[string]any{{"role": "user", "content": "degrade please"}}, + }, http.StatusAccepted, °radeTask) + if degradeTask.Task.Status != "succeeded" { + t.Fatalf("degrade task should fail over after cooling down failed platform: %+v", degradeTask.Task) + } + var cooledDown bool + if err := testPool.QueryRow(ctx, `SELECT COALESCE(cooldown_until > now(), false) FROM integration_platforms WHERE id = $1::uuid`, degradedPlatform.ID).Scan(&cooledDown); err != nil { + t.Fatalf("read degraded platform cooldown: %v", err) + } + if !cooledDown { + t.Fatal("degrade policy should set platform cooldown_until") + } + + var autoDisablePolicySet struct { + ID string `json:"id"` + } + doJSON(t, server.URL, http.MethodPost, "/api/admin/runtime/policy-sets", loginResponse.AccessToken, map[string]any{ + "policyKey": "smoke-auto-disable-" + suffixText, + "name": "Smoke Auto Disable", + "autoDisablePolicy": map[string]any{ + "enabled": true, + "keywords": []string{"invalid_api_key"}, + "threshold": 1, + }, + }, http.StatusCreated, &autoDisablePolicySet) + autoDisableModel := "auto-disable-smoke-" + suffixText + var invalidKeyPlatform struct { + ID string `json:"id"` + } + doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms", loginResponse.AccessToken, map[string]any{ + "provider": "openai", + "platformKey": "openai-invalid-key-" + suffixText, + "name": "OpenAI Invalid Key", + "baseUrl": "https://api.openai.com/v1", + "authType": "bearer", + "credentials": map[string]any{"mode": "simulation", "simulationFailure": "invalid_api_key"}, + "priority": 50, + }, http.StatusCreated, &invalidKeyPlatform) + var invalidKeyPlatformModel map[string]any + doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms/"+invalidKeyPlatform.ID+"/models", loginResponse.AccessToken, map[string]any{ + "canonicalModelKey": "openai:gpt-4o-mini", + "modelName": autoDisableModel, + "modelAlias": autoDisableModel, + "modelType": []string{"text_generate"}, + "displayName": "Auto Disable Smoke", + "runtimePolicySetId": autoDisablePolicySet.ID, + }, http.StatusCreated, &invalidKeyPlatformModel) + var autoDisableTask struct { + Task struct { + Status string `json:"status"` + ErrorCode string `json:"errorCode"` + } `json:"task"` + } + doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{ + "model": autoDisableModel, + "runMode": "simulation", + "simulation": true, + "simulationDurationMs": 5, + "messages": []map[string]any{{"role": "user", "content": "disable please"}}, + }, http.StatusAccepted, &autoDisableTask) + if autoDisableTask.Task.Status != "failed" || autoDisableTask.Task.ErrorCode != "invalid_api_key" { + t.Fatalf("auto disable task should fail with invalid_api_key: %+v", autoDisableTask.Task) + } + var invalidKeyPlatformStatus string + if err := testPool.QueryRow(ctx, `SELECT status FROM integration_platforms WHERE id = $1::uuid`, invalidKeyPlatform.ID).Scan(&invalidKeyPlatformStatus); err != nil { + t.Fatalf("read invalid key platform status: %v", err) + } + if invalidKeyPlatformStatus != "disabled" { + t.Fatalf("auto disable policy should disable platform, got %q", invalidKeyPlatformStatus) + } + var taskDetail struct { ID string `json:"id"` Status string `json:"status"` @@ -393,6 +876,21 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil { if taskDetail.APIKeyName != apiKeyResponse.APIKey.Name || taskDetail.RequestID == "" || taskDetail.Usage["totalTokens"] == nil || taskDetail.FinalChargeAmount <= 0 { t.Fatalf("task detail should expose enriched record fields: %+v", taskDetail) } + var taskList struct { + Items []struct { + ID string `json:"id"` + Status string `json:"status"` + APIKeyName string `json:"apiKeyName"` + ModelType string `json:"modelType"` + FinalCharge float64 `json:"finalChargeAmount"` + ErrorCode string `json:"errorCode"` + ErrorMessage string `json:"errorMessage"` + } `json:"items"` + } + doJSON(t, server.URL, http.MethodGet, "/api/v1/tasks?limit=20", loginResponse.AccessToken, nil, http.StatusOK, &taskList) + if !taskListContains(taskList.Items, taskResponse.Task.ID) || !taskListContains(taskList.Items, pricingTask.Task.ID) { + t.Fatalf("task list should include persisted task records, got %+v", taskList.Items) + } req, err := http.NewRequest(http.MethodGet, server.URL+"/api/v1/tasks/"+taskResponse.Task.ID+"/events", nil) if err != nil { @@ -411,6 +909,9 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil { if !bytes.Contains(body, []byte("task.progress")) { t.Fatalf("events response should include progress events body=%s", string(body)) } + if !bytes.Contains(body, []byte("task.billing.settled")) { + t.Fatalf("events response should include billing settlement event body=%s", string(body)) + } req, err = http.NewRequest(http.MethodGet, server.URL+"/api/v1/tasks/"+failoverTask.Task.ID+"/events", nil) if err != nil { @@ -508,3 +1009,45 @@ func doJSON(t *testing.T, baseURL string, method string, path string, token stri } } } + +func stringSliceContains(values []string, target string) bool { + for _, value := range values { + if value == target { + return true + } + } + return false +} + +func modelListContains(values []struct { + ID string `json:"id"` + ModelName string `json:"modelName"` +}, target string) bool { + for _, value := range values { + if value.ID == target { + return true + } + } + return false +} + +func taskListContains(values []struct { + ID string `json:"id"` + Status string `json:"status"` + APIKeyName string `json:"apiKeyName"` + ModelType string `json:"modelType"` + FinalCharge float64 `json:"finalChargeAmount"` + ErrorCode string `json:"errorCode"` + ErrorMessage string `json:"errorMessage"` +}, target string) bool { + for _, value := range values { + if value.ID == target { + return true + } + } + return false +} + +func floatNear(value float64, expected float64) bool { + return math.Abs(value-expected) < 0.000001 +} diff --git a/apps/api/internal/httpapi/handlers.go b/apps/api/internal/httpapi/handlers.go index b4adff5..6d400fb 100644 --- a/apps/api/internal/httpapi/handlers.go +++ b/apps/api/internal/httpapi/handlers.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net/http" + "strconv" "strings" "time" @@ -468,6 +469,10 @@ func (s *Server) estimatePricing(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusBadRequest, "model is required") return } + if !apiKeyScopeAllowed(user, kind) { + writeError(w, http.StatusForbidden, "api key scope does not allow this capability") + return + } estimate, err := s.runner.Estimate(r.Context(), kind, model, body, user) if err != nil { if errors.Is(err, store.ErrNoModelCandidate) { @@ -509,6 +514,10 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler { writeError(w, http.StatusBadRequest, "model is required") return } + if !apiKeyScopeAllowed(user, kind) { + writeError(w, http.StatusForbidden, "api key scope does not allow this capability") + return + } task, err := s.store.CreateTask(r.Context(), store.CreateTaskInput{ Kind: kind, @@ -567,6 +576,36 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler { }) } +func apiKeyScopeAllowed(user *auth.User, kind string) bool { + if user == nil || strings.TrimSpace(user.APIKeyID) == "" || len(user.APIKeyScopes) == 0 { + return true + } + required := scopeForTaskKind(kind) + for _, scope := range user.APIKeyScopes { + scope = strings.TrimSpace(strings.ToLower(scope)) + if scope == "*" || scope == "all" || scope == required { + return true + } + if required == "chat" && (scope == "text" || scope == "text_generate") { + return true + } + } + return false +} + +func scopeForTaskKind(kind string) string { + switch kind { + case "chat.completions", "responses": + return "chat" + case "images.generations", "images.edits": + return "image" + case "videos.generations": + return "video" + default: + return kind + } +} + func statusFromRunError(err error) int { switch { case errors.Is(err, store.ErrNoModelCandidate): @@ -578,6 +617,30 @@ func statusFromRunError(err error) int { } } +func (s *Server) listTasks(w http.ResponseWriter, r *http.Request) { + user, ok := auth.UserFromContext(r.Context()) + if !ok { + writeError(w, http.StatusUnauthorized, "unauthorized") + return + } + limit := 50 + if raw := strings.TrimSpace(r.URL.Query().Get("limit")); raw != "" { + parsed, err := strconv.Atoi(raw) + if err != nil || parsed <= 0 { + writeError(w, http.StatusBadRequest, "invalid limit") + return + } + limit = parsed + } + tasks, err := s.store.ListTasks(r.Context(), user, limit) + if err != nil { + s.logger.Error("list tasks failed", "error", err) + writeError(w, http.StatusInternalServerError, "list tasks failed") + return + } + writeJSON(w, http.StatusOK, map[string]any{"items": tasks}) +} + func boolValue(body map[string]any, key string) bool { value, _ := body[key].(bool) return value diff --git a/apps/api/internal/httpapi/server.go b/apps/api/internal/httpapi/server.go index 07d5803..47abf86 100644 --- a/apps/api/internal/httpapi/server.go +++ b/apps/api/internal/httpapi/server.go @@ -101,6 +101,7 @@ func NewServer(cfg config.Config, db *store.Store, logger *slog.Logger) http.Han mux.Handle("POST /api/v1/images/generations", server.auth.Require(auth.PermissionBasic, server.createTask("images.generations", false))) mux.Handle("POST /api/v1/images/edits", server.auth.Require(auth.PermissionBasic, server.createTask("images.edits", false))) mux.Handle("POST /api/v1/videos/generations", server.auth.Require(auth.PermissionBasic, server.createTask("videos.generations", false))) + mux.Handle("GET /api/v1/tasks", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.listTasks))) mux.Handle("GET /api/v1/tasks/{taskID}", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.getTask))) mux.Handle("GET /api/v1/tasks/{taskID}/events", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskEvents))) mux.Handle("POST /chat/completions", server.auth.Require(auth.PermissionBasic, server.createTask("chat.completions", true))) diff --git a/apps/api/internal/runner/limits.go b/apps/api/internal/runner/limits.go index fbf0bb5..c83d366 100644 --- a/apps/api/internal/runner/limits.go +++ b/apps/api/internal/runner/limits.go @@ -18,15 +18,36 @@ func (s *Service) rateLimitReservations(ctx context.Context, user *auth.User, ca } func effectiveRateLimitPolicy(candidate store.RuntimeModelCandidate) map[string]any { - if hasRules(candidate.ModelRateLimitPolicy) { - return candidate.ModelRateLimitPolicy + policy := candidate.PlatformRateLimitPolicy + if hasRules(candidate.RuntimeRateLimitPolicy) { + policy = mergeMap(policy, candidate.RuntimeRateLimitPolicy) } - if hasRules(candidate.PlatformRateLimitPolicy) { - return candidate.PlatformRateLimitPolicy + if nested, ok := candidate.RuntimePolicyOverride["rateLimitPolicy"].(map[string]any); ok && len(nested) > 0 { + policy = mergeMap(policy, nested) + } + if hasRules(candidate.ModelRateLimitPolicy) { + policy = mergeMap(policy, candidate.ModelRateLimitPolicy) + } + if hasRules(policy) { + return policy } return nil } +func effectiveRetryPolicy(candidate store.RuntimeModelCandidate) map[string]any { + policy := candidate.PlatformRetryPolicy + if len(candidate.RuntimeRetryPolicy) > 0 { + policy = mergeMap(policy, candidate.RuntimeRetryPolicy) + } + if nested, ok := candidate.RuntimePolicyOverride["retryPolicy"].(map[string]any); ok && len(nested) > 0 { + policy = mergeMap(policy, nested) + } + if len(candidate.ModelRetryPolicy) > 0 { + policy = mergeMap(policy, candidate.ModelRetryPolicy) + } + return policy +} + func reservationsFromPolicy(scopeType string, scopeKey string, policy map[string]any, body map[string]any) []store.RateLimitReservation { if scopeKey == "" || !hasRules(policy) { return nil diff --git a/apps/api/internal/runner/pricing.go b/apps/api/internal/runner/pricing.go index 1b4418c..82976e3 100644 --- a/apps/api/internal/runner/pricing.go +++ b/apps/api/internal/runner/pricing.go @@ -16,7 +16,7 @@ type EstimateResult struct { } func (s *Service) Estimate(ctx context.Context, kind string, model string, body map[string]any, user *auth.User) (EstimateResult, error) { - candidates, err := s.store.ListModelCandidates(ctx, model, modelTypeFromKind(kind), user) + candidates, err := s.store.ListModelCandidates(ctx, model, modelTypeFromKind(kind, body), user) if err != nil { return EstimateResult{}, err } @@ -38,7 +38,7 @@ func (s *Service) Estimate(ctx context.Context, kind string, model string, body } func (s *Service) billings(ctx context.Context, user *auth.User, kind string, body map[string]any, candidate store.RuntimeModelCandidate, response clients.Response, simulated bool) []any { - config := effectiveBillingConfig(candidate) + config := s.effectiveBillingConfig(ctx, candidate) discount := effectiveDiscount(ctx, s.store, user, candidate) if isTextGenerationKind(kind) { inputTokens := response.Usage.InputTokens @@ -74,11 +74,21 @@ func (s *Service) billings(ctx context.Context, user *auth.User, kind string, bo return []any{billingLine(candidate, resource, unit, count, roundPrice(amount), discount, simulated)} } -func effectiveBillingConfig(candidate store.RuntimeModelCandidate) map[string]any { +func (s *Service) effectiveBillingConfig(ctx context.Context, candidate store.RuntimeModelCandidate) map[string]any { base := candidate.BaseBillingConfig + if ruleSetID := firstNonEmptyString(candidate.BasePricingRuleSetID, candidate.PlatformPricingRuleSetID); ruleSetID != "" { + if ruleSetConfig, err := s.store.PricingRuleSetBillingConfig(ctx, ruleSetID); err == nil && len(ruleSetConfig) > 0 { + base = ruleSetConfig + } + } if len(candidate.BillingConfig) > 0 { base = candidate.BillingConfig } + if candidate.ModelPricingRuleSetID != "" { + if ruleSetConfig, err := s.store.PricingRuleSetBillingConfig(ctx, candidate.ModelPricingRuleSetID); err == nil && len(ruleSetConfig) > 0 { + base = ruleSetConfig + } + } if len(candidate.BillingConfigOverride) > 0 { base = mergeMap(base, candidate.BillingConfigOverride) } diff --git a/apps/api/internal/runner/runtime_policy.go b/apps/api/internal/runner/runtime_policy.go new file mode 100644 index 0000000..ab0c378 --- /dev/null +++ b/apps/api/internal/runner/runtime_policy.go @@ -0,0 +1,84 @@ +package runner + +import ( + "context" + "strings" + + "github.com/easyai/easyai-ai-gateway/apps/api/internal/clients" + "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" +) + +func (s *Service) applyCandidateFailurePolicies(ctx context.Context, taskID string, candidate store.RuntimeModelCandidate, cause error, simulated bool) { + code := clients.ErrorCode(cause) + message := "" + if cause != nil { + message = cause.Error() + } + + autoDisablePolicy := effectiveRuntimePolicy(candidate.AutoDisablePolicy, candidate.RuntimePolicyOverride, "autoDisablePolicy") + if failurePolicyMatches(autoDisablePolicy, code, message) && intFromPolicy(autoDisablePolicy, "threshold") <= 1 { + if err := s.store.DisableCandidatePlatform(ctx, candidate.PlatformID); err == nil { + _ = s.emit(ctx, taskID, "task.policy.auto_disabled", "running", "auto_disable", 0.48, "candidate platform disabled by failure policy", map[string]any{ + "platformId": candidate.PlatformID, + "platformModelId": candidate.PlatformModelID, + "code": code, + }, simulated) + } + } + + degradePolicy := effectiveRuntimePolicy(candidate.DegradePolicy, candidate.RuntimePolicyOverride, "degradePolicy") + if failurePolicyMatches(degradePolicy, code, message) { + cooldownSeconds := intFromPolicy(degradePolicy, "cooldownSeconds") + if err := s.store.CooldownCandidatePlatform(ctx, candidate.PlatformID, cooldownSeconds); err == nil { + _ = s.emit(ctx, taskID, "task.policy.degraded", "running", "degrade", 0.5, "candidate platform cooled down by failure policy", map[string]any{ + "platformId": candidate.PlatformID, + "platformModelId": candidate.PlatformModelID, + "cooldownSeconds": cooldownSeconds, + "code": code, + }, simulated) + } + } +} + +func effectiveRuntimePolicy(base map[string]any, override map[string]any, key string) map[string]any { + policy := base + if nested, ok := override[key].(map[string]any); ok && len(nested) > 0 { + policy = mergeMap(policy, nested) + } + return policy +} + +func failurePolicyMatches(policy map[string]any, code string, message string) bool { + if len(policy) == 0 || !boolFromMap(policy, "enabled") { + return false + } + keywords := stringListFromPolicy(policy, "keywords") + if len(keywords) == 0 { + return false + } + target := strings.ToLower(strings.TrimSpace(code + " " + message)) + for _, keyword := range keywords { + keyword = strings.ToLower(strings.TrimSpace(keyword)) + if keyword != "" && strings.Contains(target, keyword) { + return true + } + } + return false +} + +func stringListFromPolicy(values map[string]any, key string) []string { + raw, ok := values[key].([]any) + if !ok { + if typed, ok := values[key].([]string); ok { + return typed + } + return nil + } + out := make([]string, 0, len(raw)) + for _, item := range raw { + if text, ok := item.(string); ok && strings.TrimSpace(text) != "" { + out = append(out, text) + } + } + return out +} diff --git a/apps/api/internal/runner/service.go b/apps/api/internal/runner/service.go index 72827ad..4fdc45a 100644 --- a/apps/api/internal/runner/service.go +++ b/apps/api/internal/runner/service.go @@ -50,8 +50,8 @@ func (s *Service) ExecuteStream(ctx context.Context, task store.GatewayTask, use } func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *auth.User, onDelta clients.StreamDelta) (Result, error) { - modelType := modelTypeFromKind(task.Kind) body := normalizeRequest(task.Kind, task.Request) + modelType := modelTypeFromKind(task.Kind, body) if err := validateRequest(task.Kind, body); err != nil { failed, finishErr := s.failTask(ctx, task.ID, "bad_request", err.Error(), task.RunMode == "simulation", err) if finishErr != nil { @@ -102,6 +102,17 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut if finishErr != nil { return Result{}, finishErr } + if settleErr := s.store.SettleTaskBilling(ctx, finished); settleErr != nil { + return Result{}, settleErr + } + if finished.FinalChargeAmount > 0 { + if err := s.emit(ctx, task.ID, "task.billing.settled", "succeeded", "billing", 0.98, "task billing settled", map[string]any{ + "amount": finished.FinalChargeAmount, + "currency": stringFromAny(record.BillingSummary["currency"]), + }, isSimulation(task, candidate)); err != nil { + return Result{}, err + } + } if err := s.emit(ctx, task.ID, "task.completed", "succeeded", "completed", 1, "task completed", map[string]any{ "result": response.Result, "billings": billings, @@ -226,6 +237,7 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user ErrorMessage: err.Error(), }) _ = s.emit(ctx, task.ID, "task.attempt.failed", "running", "attempt_failed", 0.45, err.Error(), map[string]any{"attempt": attemptNo, "retryable": retryable, "requestId": requestID, "metrics": metrics}, simulated) + s.applyCandidateFailurePolicies(ctx, task.ID, candidate, err, simulated) return clients.Response{}, err } uploadedResult, err := s.uploadGeneratedAssets(ctx, response.Result) @@ -317,7 +329,7 @@ func (s *Service) emit(ctx context.Context, taskID string, eventType string, sta return nil } -func modelTypeFromKind(kind string) string { +func modelTypeFromKind(kind string, body map[string]any) string { switch kind { case "chat.completions", "responses": return "text_generate" @@ -327,12 +339,30 @@ func modelTypeFromKind(kind string) string { } return "image_generate" case "videos.generations": + if videoRequestHasReferenceImage(body) { + return "image_to_video" + } return "video_generate" default: return "task" } } +func videoRequestHasReferenceImage(body map[string]any) bool { + if body == nil { + return false + } + for _, key := range []string{ + "image", "images", "image_url", "imageUrl", "image_urls", "imageUrls", + "reference_image", "referenceImage", "first_frame", "firstFrame", "last_frame", "lastFrame", + } { + if hasAnyString(body, key) { + return true + } + } + return false +} + func isTextGenerationKind(kind string) bool { return kind == "chat.completions" || kind == "responses" } @@ -345,10 +375,8 @@ func isSimulation(task store.GatewayTask, candidate store.RuntimeModelCandidate) } func retryEnabled(candidate store.RuntimeModelCandidate) bool { - if enabled, ok := candidate.ModelRetryPolicy["enabled"].(bool); ok { - return enabled - } - if enabled, ok := candidate.PlatformRetryPolicy["enabled"].(bool); ok { + policy := effectiveRetryPolicy(candidate) + if enabled, ok := policy["enabled"].(bool); ok { return enabled } return true @@ -360,10 +388,7 @@ func maxAttemptsForCandidates(candidates []store.RuntimeModelCandidate) int { } maxAttempts := len(candidates) for _, candidate := range candidates { - if value := intFromPolicy(candidate.ModelRetryPolicy, "maxAttempts"); value > 0 && value < maxAttempts { - maxAttempts = value - } - if value := intFromPolicy(candidate.PlatformRetryPolicy, "maxAttempts"); value > 0 && value < maxAttempts { + if value := intFromPolicy(effectiveRetryPolicy(candidate), "maxAttempts"); value > 0 && value < maxAttempts { maxAttempts = value } } diff --git a/apps/api/internal/store/base_models.go b/apps/api/internal/store/base_models.go index 3254c51..31c8898 100644 --- a/apps/api/internal/store/base_models.go +++ b/apps/api/internal/store/base_models.go @@ -59,7 +59,7 @@ func (s *Store) ListBaseModels(ctx context.Context) ([]BaseModel, error) { rows, err := s.pool.Query(ctx, ` SELECT `+baseModelColumns+` FROM base_model_catalog -ORDER BY provider_key ASC, model_type ASC, canonical_model_key ASC`) +ORDER BY provider_key ASC, canonical_model_key ASC`) if err != nil { return nil, err } @@ -84,7 +84,7 @@ func (s *Store) CreateBaseModel(ctx context.Context, input BaseModelInput) (Base runtimePolicyOverride, _ := json.Marshal(emptyObjectIfNil(input.RuntimePolicyOverride)) metadata, _ := json.Marshal(emptyObjectIfNil(input.Metadata)) defaultSnapshot, _ := json.Marshal(emptyObjectIfNil(input.DefaultSnapshot)) - modelType := primaryString(input.ModelType, "text_generate") + modelType, _ := json.Marshal(input.ModelType) return scanBaseModel(s.pool.QueryRow(ctx, ` INSERT INTO base_model_catalog ( @@ -94,7 +94,7 @@ INSERT INTO base_model_catalog ( ) VALUES ( (SELECT id FROM model_catalog_providers WHERE provider_key = $1 OR provider_code = $1 LIMIT 1), - $1, $2, $3, $4, $5, $6, $7, $8, + $1, $2, $3, $4::jsonb, $5, $6, $7, $8, COALESCE(NULLIF($9, '')::uuid, (SELECT id FROM model_pricing_rule_sets WHERE rule_set_key = 'default-multimodal-v1' LIMIT 1)), COALESCE(NULLIF($10, '')::uuid, (SELECT id FROM model_runtime_policy_sets WHERE policy_key = 'default-runtime-v1' LIMIT 1)), $11, $12, NULLIF($13, ''), NULLIF($14::jsonb, '{}'::jsonb), $15, $16 @@ -103,7 +103,7 @@ RETURNING `+baseModelColumns, input.ProviderKey, input.CanonicalModelKey, input.ProviderModelName, - modelType, + string(modelType), input.ModelAlias, capabilities, billingConfig, @@ -127,7 +127,7 @@ func (s *Store) UpdateBaseModel(ctx context.Context, id string, input BaseModelI runtimePolicyOverride, _ := json.Marshal(emptyObjectIfNil(input.RuntimePolicyOverride)) metadata, _ := json.Marshal(emptyObjectIfNil(input.Metadata)) defaultSnapshot, _ := json.Marshal(emptyObjectIfNil(input.DefaultSnapshot)) - modelType := primaryString(input.ModelType, "text_generate") + modelType, _ := json.Marshal(input.ModelType) return scanBaseModel(s.pool.QueryRow(ctx, ` UPDATE base_model_catalog @@ -135,7 +135,7 @@ SET provider_id = (SELECT id FROM model_catalog_providers WHERE provider_key = $ provider_key = $2, canonical_model_key = $3, provider_model_name = $4, - model_type = $5, + model_type = $5::jsonb, display_name = $6, capabilities = $7, base_billing_config = $8, @@ -156,7 +156,7 @@ RETURNING `+baseModelColumns, input.ProviderKey, input.CanonicalModelKey, input.ProviderModelName, - modelType, + string(modelType), input.ModelAlias, capabilities, billingConfig, @@ -195,7 +195,7 @@ SET provider_id = (SELECT id FROM model_catalog_providers WHERE provider_key = C provider_key = COALESCE($2::text, provider_key), canonical_model_key = COALESCE($3::text, canonical_model_key), provider_model_name = COALESCE($4::text, provider_model_name), - model_type = COALESCE($5::text, model_type), + model_type = COALESCE($5::jsonb, model_type), display_name = COALESCE($6::text, display_name), capabilities = COALESCE($7::jsonb, capabilities), base_billing_config = COALESCE($8::jsonb, base_billing_config), @@ -214,7 +214,7 @@ RETURNING `+baseModelColumns, stringFromSnapshot(snapshot, "providerKey"), stringFromSnapshot(snapshot, "canonicalModelKey"), stringFromSnapshot(snapshot, "providerModelName"), - stringFromSnapshot(snapshot, "modelType"), + jsonStringListFromSnapshot(snapshot, "modelType"), stringFromSnapshot(snapshot, "modelAlias", "displayName"), jsonFromSnapshot(snapshot, "capabilities"), jsonFromSnapshot(snapshot, "baseBillingConfig"), @@ -242,9 +242,10 @@ SET provider_id = ( canonical_model_key = COALESCE(NULLIF(default_snapshot->>'canonicalModelKey', ''), canonical_model_key), provider_model_name = COALESCE(NULLIF(default_snapshot->>'providerModelName', ''), provider_model_name), model_type = COALESCE(NULLIF(CASE - WHEN jsonb_typeof(default_snapshot->'modelType') = 'array' THEN default_snapshot->'modelType'->>0 - ELSE default_snapshot->>'modelType' - END, ''), model_type), + WHEN jsonb_typeof(default_snapshot->'modelType') = 'array' THEN default_snapshot->'modelType' + WHEN COALESCE(default_snapshot->>'modelType', '') <> '' THEN jsonb_build_array(default_snapshot->>'modelType') + ELSE NULL + END, '[]'::jsonb), model_type), display_name = COALESCE(NULLIF(COALESCE(default_snapshot->>'modelAlias', default_snapshot->>'displayName'), ''), display_name), capabilities = COALESCE(default_snapshot->'capabilities', capabilities), base_billing_config = COALESCE(default_snapshot->'baseBillingConfig', base_billing_config), @@ -292,7 +293,7 @@ func scanBaseModelRows(rows pgx.Rows) ([]BaseModel, error) { func scanBaseModel(scanner baseModelScanner) (BaseModel, error) { var item BaseModel - var modelType string + var modelType []byte var modelAlias string var capabilities []byte var billingConfig []byte @@ -330,7 +331,7 @@ func scanBaseModel(scanner baseModelScanner) (BaseModel, error) { item.RuntimePolicyOverride = decodeObject(runtimePolicyOverride) item.Metadata = decodeObject(metadata) item.DefaultSnapshot = decodeObject(defaultSnapshot) - item.ModelType = baseModelTypes(item.Capabilities, item.Metadata, modelType) + item.ModelType = decodeStringArray(modelType) item.ModelAlias = modelAlias item.DisplayName = modelAlias return item, nil @@ -420,14 +421,22 @@ func jsonFromSnapshot(snapshot map[string]any, key string) any { return string(raw) } -func baseModelTypes(capabilities map[string]any, metadata map[string]any, fallback string) StringList { - values := make([]string, 0) - values = append(values, stringListFromAny(capabilities["originalTypes"])...) - values = append(values, stringListFromAny(metadata["originalTypes"])...) - if fallback != "" { - values = append(values, fallback) +func jsonStringListFromSnapshot(snapshot map[string]any, key string) any { + values := stringListFromAny(snapshot[key]) + if len(values) == 0 { + if value, ok := snapshot[key].(string); ok { + values = []string{value} + } } - return uniqueStringList(values) + normalized := uniqueStringList(values) + if len(normalized) == 0 { + return nil + } + raw, err := json.Marshal(normalized) + if err != nil { + return nil + } + return string(raw) } func stringListFromAny(value any) []string { @@ -451,16 +460,39 @@ func uniqueStringList(values []string) StringList { out := make([]string, 0, len(values)) seen := map[string]bool{} for _, value := range values { - value = strings.TrimSpace(value) - if value == "" || seen[value] { - continue + for _, normalized := range modelTypeAliases(value) { + normalized = strings.TrimSpace(normalized) + if normalized == "" || seen[normalized] { + continue + } + seen[normalized] = true + out = append(out, normalized) } - seen[value] = true - out = append(out, value) } return out } +func normalizeModelTypeList(values []string) StringList { + return uniqueStringList(values) +} + +func modelTypeAliases(value string) []string { + switch strings.TrimSpace(value) { + case "chat", "text", "responses": + return []string{"text_generate"} + case "image": + return []string{"image_generate", "image_edit"} + case "images.generations": + return []string{"image_generate"} + case "images.edits": + return []string{"image_edit"} + case "video", "videos.generations": + return []string{"video_generate"} + default: + return []string{value} + } +} + func primaryString(values []string, fallback string) string { for _, value := range values { if value = strings.TrimSpace(value); value != "" { diff --git a/apps/api/internal/store/candidates.go b/apps/api/internal/store/candidates.go index c18a4ef..6f77e8d 100644 --- a/apps/api/internal/store/candidates.go +++ b/apps/api/internal/store/candidates.go @@ -18,27 +18,25 @@ SELECT p.id::text, p.platform_key, p.name, p.provider, COALESCE(p.dynamic_priority, p.priority) AS effective_priority, m.id::text, COALESCE(m.base_model_id::text, ''), COALESCE(b.canonical_model_key, ''), COALESCE(b.provider_model_name, ''), m.model_name, COALESCE(m.model_alias, ''), - m.model_type, m.display_name, m.capabilities, m.capability_override, + $2 AS requested_model_type, m.display_name, m.capabilities, m.capability_override, COALESCE(b.base_billing_config, '{}'::jsonb), m.billing_config, m.billing_config_override, m.pricing_mode, COALESCE(m.discount_factor, 0)::float8, COALESCE(m.pricing_rule_set_id::text, ''), - m.permission_config, m.retry_policy, m.rate_limit_policy, COALESCE(m.runtime_policy_set_id::text, COALESCE(b.runtime_policy_set_id::text, '')), - COALESCE(NULLIF(m.runtime_policy_override, '{}'::jsonb), b.runtime_policy_override, '{}'::jsonb) + COALESCE(b.pricing_rule_set_id::text, ''), + m.permission_config, m.retry_policy, m.rate_limit_policy, COALESCE(m.runtime_policy_set_id::text, COALESCE(b.runtime_policy_set_id::text, '')), + COALESCE(NULLIF(m.runtime_policy_override, '{}'::jsonb), b.runtime_policy_override, '{}'::jsonb), + COALESCE(rp.retry_policy, '{}'::jsonb), COALESCE(rp.rate_limit_policy, '{}'::jsonb), + COALESCE(rp.auto_disable_policy, '{}'::jsonb), COALESCE(rp.degrade_policy, '{}'::jsonb) FROM platform_models m JOIN integration_platforms p ON p.id = m.platform_id LEFT JOIN model_catalog_providers cp ON cp.provider_key = p.provider OR cp.provider_code = p.provider LEFT JOIN base_model_catalog b ON b.id = m.base_model_id +LEFT JOIN model_runtime_policy_sets rp ON rp.id = COALESCE(m.runtime_policy_set_id, b.runtime_policy_set_id) LEFT JOIN runtime_client_states s - ON s.client_id = p.platform_key || ':' || m.model_type || ':' || m.model_name + ON s.client_id = p.platform_key || ':' || $2 || ':' || m.model_name WHERE p.status = 'enabled' AND p.deleted_at IS NULL AND m.enabled = true - AND ( - m.model_type = $2 - OR ($2 = 'text_generate' AND m.model_type IN ('chat', 'responses', 'text')) - OR ($2 = 'image_generate' AND m.model_type IN ('image', 'images.generations')) - OR ($2 = 'image_edit' AND m.model_type IN ('images.edits')) - OR ($2 = 'video_generate' AND m.model_type IN ('video', 'videos.generations', 'video_generate', 'text_to_video', 'image_to_video', 'omni_video', 'video_edit', 'video_reference', 'video_first_last_frame')) - ) + AND m.model_type @> jsonb_build_array($2) AND (p.cooldown_until IS NULL OR p.cooldown_until <= now()) AND ( m.model_name = $1 @@ -73,6 +71,10 @@ ORDER BY effective_priority ASC, var modelRetryPolicy []byte var modelRateLimitPolicy []byte var runtimePolicyOverride []byte + var runtimeRetryPolicy []byte + var runtimeRateLimitPolicy []byte + var autoDisablePolicy []byte + var degradePolicy []byte if err := rows.Scan( &item.PlatformID, &item.PlatformKey, @@ -105,11 +107,16 @@ ORDER BY effective_priority ASC, &item.PricingMode, &item.DiscountFactor, &item.ModelPricingRuleSetID, + &item.BasePricingRuleSetID, &permissionConfig, &modelRetryPolicy, &modelRateLimitPolicy, &item.RuntimePolicySetID, &runtimePolicyOverride, + &runtimeRetryPolicy, + &runtimeRateLimitPolicy, + &autoDisablePolicy, + °radePolicy, ); err != nil { return nil, err } @@ -126,6 +133,10 @@ ORDER BY effective_priority ASC, item.ModelRetryPolicy = decodeObject(modelRetryPolicy) item.ModelRateLimitPolicy = decodeObject(modelRateLimitPolicy) item.RuntimePolicyOverride = decodeObject(runtimePolicyOverride) + item.RuntimeRetryPolicy = decodeObject(runtimeRetryPolicy) + item.RuntimeRateLimitPolicy = decodeObject(runtimeRateLimitPolicy) + item.AutoDisablePolicy = decodeObject(autoDisablePolicy) + item.DegradePolicy = decodeObject(degradePolicy) item.ClientID = fmt.Sprintf("%s:%s:%s", item.PlatformKey, item.ModelType, item.ModelName) item.QueueKey = item.ClientID items = append(items, item) diff --git a/apps/api/internal/store/platform_models.go b/apps/api/internal/store/platform_models.go index ebeab8a..4609bf2 100644 --- a/apps/api/internal/store/platform_models.go +++ b/apps/api/internal/store/platform_models.go @@ -17,7 +17,7 @@ type modelCatalogSnapshot struct { ProviderKey string CanonicalModelKey string ProviderModelName string - ModelType string + ModelType StringList DisplayName string Capabilities map[string]any BaseBillingConfig map[string]any @@ -83,9 +83,13 @@ func (s *Store) createPlatformModel(ctx context.Context, q platformModelQuerier, if err != nil && !IsNotFound(err) { return PlatformModel{}, err } - if input.ModelType == "" { + if len(input.ModelType) == 0 { input.ModelType = base.ModelType } + input.ModelType = normalizeModelTypeList(input.ModelType) + if len(input.ModelType) == 0 { + input.ModelType = StringList{"text_generate"} + } if input.ModelName == "" { input.ModelName = base.ProviderModelName } @@ -104,11 +108,12 @@ func (s *Store) createPlatformModel(ctx context.Context, q platformModelQuerier, if len(billingConfig) == 0 { billingConfig = mergeObjects(base.BaseBillingConfig, input.BillingConfigOverride) } + explicitRuntimePolicySetID := strings.TrimSpace(input.RuntimePolicySetID) rateLimitPolicy := input.RateLimitPolicy - if len(rateLimitPolicy) == 0 { + if len(rateLimitPolicy) == 0 && explicitRuntimePolicySetID == "" { rateLimitPolicy = base.DefaultRateLimitPolicy } - runtimePolicySetID := strings.TrimSpace(input.RuntimePolicySetID) + runtimePolicySetID := explicitRuntimePolicySetID if runtimePolicySetID == "" { runtimePolicySetID = base.RuntimePolicySetID } @@ -119,6 +124,7 @@ func (s *Store) createPlatformModel(ctx context.Context, q platformModelQuerier, capabilityOverrideJSON, _ := json.Marshal(emptyObjectIfNil(input.CapabilityOverride)) capabilitiesJSON, _ := json.Marshal(emptyObjectIfNil(capabilities)) + modelTypeJSON, _ := json.Marshal(input.ModelType) billingOverrideJSON, _ := json.Marshal(emptyObjectIfNil(input.BillingConfigOverride)) billingJSON, _ := json.Marshal(emptyObjectIfNil(billingConfig)) permissionJSON, _ := json.Marshal(emptyObjectIfNil(input.PermissionConfig)) @@ -144,6 +150,7 @@ func (s *Store) createPlatformModel(ctx context.Context, q platformModelQuerier, var retryPolicyBytes []byte var rateLimitPolicyBytes []byte var runtimePolicyOverrideBytes []byte + var modelTypeBytes []byte err = q.QueryRow(ctx, ` INSERT INTO platform_models ( platform_id, base_model_id, model_name, model_alias, model_type, display_name, @@ -152,12 +159,12 @@ INSERT INTO platform_models ( runtime_policy_set_id, runtime_policy_override, enabled ) VALUES ( - $1::uuid, $2::uuid, $3, NULLIF($4, ''), $5, $6, + $1::uuid, $2::uuid, $3, NULLIF($4, ''), $5::jsonb, $6, $7::jsonb, $8::jsonb, $9, $10::numeric, NULLIF($11, '')::uuid, $12::jsonb, $13::jsonb, $14::jsonb, $15::jsonb, $16::jsonb, NULLIF($17, '')::uuid, $18::jsonb, true ) -ON CONFLICT (platform_id, model_name, model_type) DO UPDATE +ON CONFLICT (platform_id, model_name) DO UPDATE SET base_model_id = EXCLUDED.base_model_id, model_alias = EXCLUDED.model_alias, display_name = EXCLUDED.display_name, @@ -185,7 +192,7 @@ RETURNING id::text, platform_id::text, COALESCE(base_model_id::text, ''), model_ baseID, input.ModelName, input.ModelAlias, - input.ModelType, + string(modelTypeJSON), input.DisplayName, string(capabilityOverrideJSON), string(capabilitiesJSON), @@ -205,7 +212,7 @@ RETURNING id::text, platform_id::text, COALESCE(base_model_id::text, ''), model_ &model.BaseModelID, &model.ModelName, &model.ModelAlias, - &model.ModelType, + &modelTypeBytes, &model.DisplayName, &capabilityOverrideBytes, &capabilitiesBytes, @@ -228,6 +235,7 @@ RETURNING id::text, platform_id::text, COALESCE(base_model_id::text, ''), model_ } model.CapabilityOverride = decodeObject(capabilityOverrideBytes) model.Capabilities = decodeObject(capabilitiesBytes) + model.ModelType = decodeStringArray(modelTypeBytes) model.BillingConfigOverride = decodeObject(billingOverrideBytes) model.BillingConfig = decodeObject(billingBytes) model.PermissionConfig = decodeObject(permissionBytes) @@ -265,6 +273,7 @@ func (s *Store) lookupBaseModel(ctx context.Context, q platformModelQuerier, id var billingConfig []byte var rateLimitPolicy []byte var runtimePolicyOverride []byte + var modelTypeBytes []byte err := q.QueryRow(ctx, ` SELECT id::text, provider_key, canonical_model_key, provider_model_name, model_type, display_name, capabilities, base_billing_config, default_rate_limit_policy, @@ -279,7 +288,7 @@ LIMIT 1`, strings.TrimSpace(id), strings.TrimSpace(canonicalKey), strings.TrimSp &item.ProviderKey, &item.CanonicalModelKey, &item.ProviderModelName, - &item.ModelType, + &modelTypeBytes, &item.DisplayName, &capabilities, &billingConfig, @@ -297,6 +306,7 @@ LIMIT 1`, strings.TrimSpace(id), strings.TrimSpace(canonicalKey), strings.TrimSp item.BaseBillingConfig = decodeObject(billingConfig) item.DefaultRateLimitPolicy = decodeObject(rateLimitPolicy) item.RuntimePolicyOverride = decodeObject(runtimePolicyOverride) + item.ModelType = normalizeModelTypeList(decodeStringArray(modelTypeBytes)) return item, nil } diff --git a/apps/api/internal/store/postgres.go b/apps/api/internal/store/postgres.go index 9794bc5..84659b9 100644 --- a/apps/api/internal/store/postgres.go +++ b/apps/api/internal/store/postgres.go @@ -134,7 +134,7 @@ type PlatformModel struct { PlatformName string `json:"platformName,omitempty"` ModelName string `json:"modelName"` ModelAlias string `json:"modelAlias,omitempty"` - ModelType string `json:"modelType"` + ModelType StringList `json:"modelType"` DisplayName string `json:"displayName"` CapabilityOverride map[string]any `json:"capabilityOverride,omitempty"` Capabilities map[string]any `json:"capabilities,omitempty"` @@ -390,6 +390,8 @@ type GatewayTask struct { ResponseDurationMS int64 `json:"responseDurationMs"` FinishedAt string `json:"finishedAt,omitempty"` Error string `json:"error,omitempty"` + ErrorCode string `json:"errorCode,omitempty"` + ErrorMessage string `json:"errorMessage,omitempty"` CreatedAt time.Time `json:"createdAt"` UpdatedAt time.Time `json:"updatedAt"` } @@ -404,6 +406,7 @@ request, status, COALESCE(result, '{}'::jsonb), COALESCE(billings, '[]'::jsonb), COALESCE(usage, '{}'::jsonb), COALESCE(metrics, '{}'::jsonb), COALESCE(billing_summary, '{}'::jsonb), COALESCE(final_charge_amount, 0)::float8, COALESCE(response_started_at::text, ''), COALESCE(response_finished_at::text, ''), COALESCE(response_duration_ms, 0), COALESCE(error, ''), +COALESCE(error_code, ''), COALESCE(error_message, ''), created_at, updated_at, COALESCE(finished_at::text, '')` type TaskEvent struct { @@ -713,6 +716,7 @@ ORDER BY m.model_type ASC, m.model_name ASC`, args...) var retryPolicy []byte var rateLimitPolicy []byte var runtimePolicyOverride []byte + var modelTypeBytes []byte if err := rows.Scan( &model.ID, &model.PlatformID, @@ -721,7 +725,7 @@ ORDER BY m.model_type ASC, m.model_name ASC`, args...) &model.PlatformName, &model.ModelName, &model.ModelAlias, - &model.ModelType, + &modelTypeBytes, &model.DisplayName, &capabilityOverride, &capabilities, @@ -743,6 +747,7 @@ ORDER BY m.model_type ASC, m.model_name ASC`, args...) } model.CapabilityOverride = decodeObject(capabilityOverride) model.Capabilities = decodeObject(capabilities) + model.ModelType = decodeStringArray(modelTypeBytes) model.BillingConfigOverride = decodeObject(billingConfigOverride) model.BillingConfig = decodeObject(billingConfig) model.PermissionConfig = decodeObject(permissionConfig) @@ -1231,7 +1236,7 @@ func (s *Store) VerifyLocalAPIKey(ctx context.Context, secret string) (*auth.Use return nil, auth.ErrUnauthorized } rows, err := s.pool.Query(ctx, ` -SELECT k.id::text, k.key_hash, k.key_prefix, k.name, COALESCE(k.user_group_id::text, u.default_user_group_id::text, ''), +SELECT k.id::text, k.key_hash, k.key_prefix, k.name, k.scopes, COALESCE(k.user_group_id::text, u.default_user_group_id::text, ''), u.id::text, u.username, u.roles, COALESCE(u.gateway_tenant_id::text, ''), COALESCE(u.tenant_id, ''), COALESCE(u.tenant_key, '') FROM gateway_api_keys k @@ -1252,6 +1257,7 @@ WHERE k.key_prefix = $1 var hash string var keyPrefix string var keyName string + var scopesBytes []byte var userGroupID string var gatewayUserID string var username string @@ -1259,7 +1265,7 @@ WHERE k.key_prefix = $1 var gatewayTenantID string var tenantID string var tenantKey string - if err := rows.Scan(&apiKeyID, &hash, &keyPrefix, &keyName, &userGroupID, &gatewayUserID, &username, &rolesBytes, &gatewayTenantID, &tenantID, &tenantKey); err != nil { + if err := rows.Scan(&apiKeyID, &hash, &keyPrefix, &keyName, &scopesBytes, &userGroupID, &gatewayUserID, &username, &rolesBytes, &gatewayTenantID, &tenantID, &tenantKey); err != nil { return nil, err } if bcrypt.CompareHashAndPassword([]byte(hash), []byte(secret)) != nil { @@ -1281,6 +1287,7 @@ WHERE k.key_prefix = $1 APIKeyID: apiKeyID, APIKeyName: keyName, APIKeyPrefix: keyPrefix, + APIKeyScopes: decodeStringArray(scopesBytes), }, nil } if err := rows.Err(); err != nil { @@ -1661,6 +1668,8 @@ func scanGatewayTask(scanner taskScanner) (GatewayTask, error) { &task.ResponseFinishedAt, &task.ResponseDurationMS, &task.Error, + &task.ErrorCode, + &task.ErrorMessage, &task.CreatedAt, &task.UpdatedAt, &task.FinishedAt, diff --git a/apps/api/internal/store/pricing_rules.go b/apps/api/internal/store/pricing_rules.go index 8cf2ea8..5aab217 100644 --- a/apps/api/internal/store/pricing_rules.go +++ b/apps/api/internal/store/pricing_rules.go @@ -176,6 +176,63 @@ func (s *Store) DeletePricingRuleSet(ctx context.Context, id string) error { return nil } +func (s *Store) PricingRuleSetBillingConfig(ctx context.Context, id string) (map[string]any, error) { + id = strings.TrimSpace(id) + if id == "" { + return nil, nil + } + rows, err := s.pool.Query(ctx, ` +SELECT resource_type, base_price::float8, dynamic_weight +FROM model_pricing_rules +WHERE rule_set_id = $1::uuid + AND status = 'active' +ORDER BY priority ASC, resource_type ASC`, id) + if err != nil { + return nil, err + } + defer rows.Close() + + config := map[string]any{} + for rows.Next() { + var resourceType string + var basePrice float64 + var dynamicWeightBytes []byte + if err := rows.Scan(&resourceType, &basePrice, &dynamicWeightBytes); err != nil { + return nil, err + } + dynamicWeight := decodeObject(dynamicWeightBytes) + switch resourceType { + case "text_input": + config["textInputPer1k"] = basePrice + case "text_output": + config["textOutputPer1k"] = basePrice + case "image": + config["imageBase"] = basePrice + config["image"] = pricingResourceConfig(basePrice, dynamicWeight) + case "image_edit": + config["editBase"] = basePrice + config["image_edit"] = pricingResourceConfig(basePrice, dynamicWeight) + case "video": + config["videoBase"] = basePrice + config["video"] = pricingResourceConfig(basePrice, dynamicWeight) + default: + config[resourceType] = pricingResourceConfig(basePrice, dynamicWeight) + } + } + if err := rows.Err(); err != nil { + return nil, err + } + return config, nil +} + +func pricingResourceConfig(basePrice float64, dynamicWeight map[string]any) map[string]any { + config := map[string]any{"basePrice": basePrice} + if len(dynamicWeight) > 0 { + config["dynamicWeight"] = dynamicWeight + } + return config +} + func insertPricingRules(ctx context.Context, tx pgx.Tx, ruleSetID string, defaultCurrency string, rules []PricingRuleInput) error { for index, rule := range rules { rule = normalizePricingRule(rule, index, defaultCurrency) diff --git a/apps/api/internal/store/runtime_policies.go b/apps/api/internal/store/runtime_policies.go index 309c4ca..6c44db0 100644 --- a/apps/api/internal/store/runtime_policies.go +++ b/apps/api/internal/store/runtime_policies.go @@ -107,6 +107,33 @@ func (s *Store) DeleteRuntimePolicySet(ctx context.Context, id string) error { return nil } +func (s *Store) DisableCandidatePlatform(ctx context.Context, platformID string) error { + if strings.TrimSpace(platformID) == "" { + return nil + } + _, err := s.pool.Exec(ctx, ` +UPDATE integration_platforms +SET status = 'disabled', + updated_at = now() +WHERE id = $1::uuid`, platformID) + return err +} + +func (s *Store) CooldownCandidatePlatform(ctx context.Context, platformID string, cooldownSeconds int) error { + if strings.TrimSpace(platformID) == "" { + return nil + } + if cooldownSeconds <= 0 { + cooldownSeconds = 300 + } + _, err := s.pool.Exec(ctx, ` +UPDATE integration_platforms +SET cooldown_until = now() + ($2::int * interval '1 second'), + updated_at = now() +WHERE id = $1::uuid`, platformID, cooldownSeconds) + return err +} + func scanRuntimePolicySet(scanner runtimePolicyScanner) (RuntimePolicySet, error) { var item RuntimePolicySet var rateLimitPolicy []byte diff --git a/apps/api/internal/store/runtime_types.go b/apps/api/internal/store/runtime_types.go index 8dfa0a6..ad968b7 100644 --- a/apps/api/internal/store/runtime_types.go +++ b/apps/api/internal/store/runtime_types.go @@ -16,7 +16,7 @@ type CreatePlatformModelInput struct { CanonicalModelKey string `json:"canonicalModelKey"` ModelName string `json:"modelName"` ModelAlias string `json:"modelAlias"` - ModelType string `json:"modelType"` + ModelType StringList `json:"modelType"` DisplayName string `json:"displayName"` CapabilityOverride map[string]any `json:"capabilityOverride"` Capabilities map[string]any `json:"capabilities"` @@ -64,12 +64,17 @@ type RuntimeModelCandidate struct { PermissionConfig map[string]any PricingMode string DiscountFactor float64 + BasePricingRuleSetID string PlatformPricingRuleSetID string ModelPricingRuleSetID string ModelRetryPolicy map[string]any ModelRateLimitPolicy map[string]any RuntimePolicySetID string RuntimePolicyOverride map[string]any + RuntimeRetryPolicy map[string]any + RuntimeRateLimitPolicy map[string]any + AutoDisablePolicy map[string]any + DegradePolicy map[string]any ClientID string QueueKey string } diff --git a/apps/api/internal/store/tasks_runtime.go b/apps/api/internal/store/tasks_runtime.go index 2d622ac..0406047 100644 --- a/apps/api/internal/store/tasks_runtime.go +++ b/apps/api/internal/store/tasks_runtime.go @@ -3,9 +3,67 @@ package store import ( "context" "encoding/json" + "strings" "time" + + "github.com/easyai/easyai-ai-gateway/apps/api/internal/auth" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" ) +func (s *Store) ListTasks(ctx context.Context, user *auth.User, limit int) ([]GatewayTask, error) { + if limit <= 0 { + limit = 50 + } + if limit > 100 { + limit = 100 + } + gatewayUserID := localGatewayUserID(user) + apiKeyID := "" + userID := "" + if user != nil { + apiKeyID = strings.TrimSpace(user.APIKeyID) + userID = strings.TrimSpace(user.ID) + } + if gatewayUserID == "" && userID == "" { + return nil, ErrLocalUserRequired + } + rows, err := s.pool.Query(ctx, ` +SELECT `+gatewayTaskColumns+` +FROM gateway_tasks +WHERE ( + ( + NULLIF($1, '')::uuid IS NOT NULL + AND gateway_user_id = NULLIF($1, '')::uuid + ) + OR ( + NULLIF($1, '')::uuid IS NULL + AND NULLIF($2, '') IS NOT NULL + AND user_id = $2 + ) + ) + AND ( + NULLIF($3, '') IS NULL + OR api_key_id = $3 + ) +ORDER BY created_at DESC +LIMIT $4`, gatewayUserID, userID, apiKeyID, limit) + if err != nil { + return nil, err + } + defer rows.Close() + + items := make([]GatewayTask, 0) + for rows.Next() { + task, err := scanGatewayTask(rows) + if err != nil { + return nil, err + } + items = append(items, task) + } + return items, rows.Err() +} + func (s *Store) MarkTaskRunning(ctx context.Context, taskID string, modelType string, normalizedRequest map[string]any) error { normalizedJSON, _ := json.Marshal(emptyObjectIfNil(normalizedRequest)) _, err := s.pool.Exec(ctx, ` @@ -140,6 +198,103 @@ WHERE id = $1::uuid`, return s.GetTask(ctx, input.TaskID) } +func (s *Store) SettleTaskBilling(ctx context.Context, task GatewayTask) error { + if task.FinalChargeAmount <= 0 || strings.TrimSpace(task.GatewayUserID) == "" { + return nil + } + currency := strings.TrimSpace(taskBillingString(task.BillingSummary["currency"])) + if currency == "" || currency == "mixed" { + currency = "resource" + } + metadata, _ := json.Marshal(map[string]any{ + "taskId": task.ID, + "kind": task.Kind, + "model": task.Model, + "resolvedModel": task.ResolvedModel, + "billings": task.Billings, + "billingSummary": task.BillingSummary, + }) + return pgx.BeginFunc(ctx, s.pool, func(tx pgx.Tx) error { + if _, err := tx.Exec(ctx, ` +INSERT INTO gateway_wallet_accounts ( + gateway_tenant_id, gateway_user_id, tenant_id, tenant_key, user_id, currency +) +VALUES (NULLIF($1, '')::uuid, $2::uuid, NULLIF($3, ''), NULLIF($4, ''), NULLIF($5, ''), $6) +ON CONFLICT (gateway_user_id, currency) DO NOTHING`, + task.GatewayTenantID, task.GatewayUserID, task.TenantID, task.TenantKey, task.UserID, currency); err != nil { + return err + } + var exists bool + if err := tx.QueryRow(ctx, ` +SELECT EXISTS ( + SELECT 1 + FROM gateway_wallet_transactions t + JOIN gateway_wallet_accounts a ON a.id = t.account_id + WHERE a.gateway_user_id = $1::uuid + AND a.currency = $2 + AND t.idempotency_key = $3 +)`, task.GatewayUserID, currency, billingIdempotencyKey(task.ID)).Scan(&exists); err != nil { + return err + } + if exists { + return nil + } + var accountID string + var balanceBefore float64 + var gatewayTenantID string + if err := tx.QueryRow(ctx, ` +SELECT id::text, balance::float8, COALESCE(gateway_tenant_id::text, '') +FROM gateway_wallet_accounts +WHERE gateway_user_id = $1::uuid + AND currency = $2 +FOR UPDATE`, task.GatewayUserID, currency).Scan(&accountID, &balanceBefore, &gatewayTenantID); err != nil { + return err + } + amount := roundMoney(task.FinalChargeAmount) + balanceAfter := roundMoney(balanceBefore - amount) + if _, err := tx.Exec(ctx, ` +UPDATE gateway_wallet_accounts +SET balance = $2, + total_spent = total_spent + $3, + updated_at = now() +WHERE id = $1::uuid`, accountID, balanceAfter, amount); err != nil { + return err + } + _, err := tx.Exec(ctx, ` +INSERT INTO gateway_wallet_transactions ( + account_id, gateway_tenant_id, gateway_user_id, direction, transaction_type, + amount, balance_before, balance_after, idempotency_key, reference_type, reference_id, metadata +) +VALUES ( + $1::uuid, NULLIF($2, '')::uuid, $3::uuid, 'debit', 'task_billing', + $4, $5, $6, $7, 'gateway_task', $8, $9::jsonb +)`, + accountID, firstNonEmpty(gatewayTenantID, task.GatewayTenantID), task.GatewayUserID, amount, roundMoney(balanceBefore), balanceAfter, billingIdempotencyKey(task.ID), task.ID, string(metadata)) + if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == "23505" { + return nil + } + return err + }) +} + +func billingIdempotencyKey(taskID string) string { + return "task:" + taskID + ":billing" +} + +func roundMoney(value float64) float64 { + if value < 0 { + return -roundMoney(-value) + } + return float64(int64(value*1000000+0.5)) / 1000000 +} + +func taskBillingString(value any) string { + if text, ok := value.(string); ok { + return text + } + return "" +} + func (s *Store) FinishTaskFailure(ctx context.Context, input FinishTaskFailureInput) (GatewayTask, error) { metricsJSON, _ := json.Marshal(emptyObjectIfNil(input.Metrics)) if _, err := s.pool.Exec(ctx, ` diff --git a/apps/api/migrations/0001_init.sql b/apps/api/migrations/0001_init.sql index b993663..199dae4 100644 --- a/apps/api/migrations/0001_init.sql +++ b/apps/api/migrations/0001_init.sql @@ -376,7 +376,7 @@ CREATE TABLE IF NOT EXISTS platform_models ( base_model_id uuid REFERENCES base_model_catalog(id) ON DELETE SET NULL, model_name text NOT NULL, model_alias text, - model_type text NOT NULL, + model_type jsonb NOT NULL DEFAULT '[]'::jsonb, display_name text NOT NULL DEFAULT '', capability_override jsonb NOT NULL DEFAULT '{}'::jsonb, capabilities jsonb NOT NULL DEFAULT '{}'::jsonb, @@ -391,14 +391,17 @@ CREATE TABLE IF NOT EXISTS platform_models ( enabled boolean NOT NULL DEFAULT true, created_at timestamptz NOT NULL DEFAULT now(), updated_at timestamptz NOT NULL DEFAULT now(), - UNIQUE(platform_id, model_name, model_type) + UNIQUE(platform_id, model_name) ); CREATE INDEX IF NOT EXISTS idx_platform_models_base ON platform_models(base_model_id); CREATE INDEX IF NOT EXISTS idx_platform_models_lookup - ON platform_models(model_type, model_name, enabled); + ON platform_models(model_name, enabled); + +CREATE INDEX IF NOT EXISTS idx_platform_models_model_type + ON platform_models USING gin(model_type); CREATE INDEX IF NOT EXISTS idx_platform_models_alias ON platform_models(model_alias); diff --git a/apps/api/migrations/0002_invitation_relationship_only.sql b/apps/api/migrations/0002_invitation_relationship_only.sql index 65dad25..5a5b22f 100644 --- a/apps/api/migrations/0002_invitation_relationship_only.sql +++ b/apps/api/migrations/0002_invitation_relationship_only.sql @@ -323,7 +323,7 @@ CREATE TABLE IF NOT EXISTS platform_models ( base_model_id uuid REFERENCES base_model_catalog(id) ON DELETE SET NULL, model_name text NOT NULL, model_alias text, - model_type text NOT NULL, + model_type jsonb NOT NULL DEFAULT '[]'::jsonb, display_name text NOT NULL DEFAULT '', capability_override jsonb NOT NULL DEFAULT '{}'::jsonb, capabilities jsonb NOT NULL DEFAULT '{}'::jsonb, @@ -338,7 +338,7 @@ CREATE TABLE IF NOT EXISTS platform_models ( enabled boolean NOT NULL DEFAULT true, created_at timestamptz NOT NULL DEFAULT now(), updated_at timestamptz NOT NULL DEFAULT now(), - UNIQUE(platform_id, model_name, model_type) + UNIQUE(platform_id, model_name) ); ALTER TABLE IF EXISTS platform_models @@ -636,7 +636,10 @@ CREATE INDEX IF NOT EXISTS idx_gateway_recharge_orders_user CREATE INDEX IF NOT EXISTS idx_platform_models_base ON platform_models(base_model_id); CREATE INDEX IF NOT EXISTS idx_platform_models_lookup - ON platform_models(model_type, model_name, enabled); + ON platform_models(model_name, enabled); + +CREATE INDEX IF NOT EXISTS idx_platform_models_model_type + ON platform_models USING gin(model_type); CREATE INDEX IF NOT EXISTS idx_platform_models_alias ON platform_models(model_alias); CREATE INDEX IF NOT EXISTS idx_platform_models_capabilities diff --git a/apps/api/migrations/0003_phase1_seed_runtime.sql b/apps/api/migrations/0003_phase1_seed_runtime.sql index 9757de6..e20e43f 100644 --- a/apps/api/migrations/0003_phase1_seed_runtime.sql +++ b/apps/api/migrations/0003_phase1_seed_runtime.sql @@ -199,7 +199,19 @@ INSERT INTO platform_models ( platform_id, base_model_id, model_name, model_alias, model_type, display_name, capabilities, pricing_mode, billing_config, retry_policy, rate_limit_policy, enabled ) -SELECT p.id, b.id, b.provider_model_name, b.canonical_model_key, b.model_type, b.display_name, +SELECT p.id, b.id, b.provider_model_name, b.canonical_model_key, + CASE b.model_type + WHEN 'chat' THEN '["text_generate"]'::jsonb + WHEN 'text' THEN '["text_generate"]'::jsonb + WHEN 'responses' THEN '["text_generate"]'::jsonb + WHEN 'image' THEN '["image_generate","image_edit"]'::jsonb + WHEN 'images.generations' THEN '["image_generate"]'::jsonb + WHEN 'images.edits' THEN '["image_edit"]'::jsonb + WHEN 'video' THEN '["video_generate"]'::jsonb + WHEN 'videos.generations' THEN '["video_generate"]'::jsonb + ELSE jsonb_build_array(b.model_type) + END, + b.display_name, b.capabilities, 'inherit_discount', b.base_billing_config, '{"enabled":true,"maxAttempts":2}'::jsonb, b.default_rate_limit_policy, @@ -207,7 +219,7 @@ SELECT p.id, b.id, b.provider_model_name, b.canonical_model_key, b.model_type, b FROM integration_platforms p JOIN base_model_catalog b ON b.provider_key = p.provider WHERE p.platform_key IN ('openai-simulation', 'gemini-simulation') -ON CONFLICT (platform_id, model_name, model_type) DO UPDATE +ON CONFLICT (platform_id, model_name) DO UPDATE SET base_model_id = EXCLUDED.base_model_id, model_alias = EXCLUDED.model_alias, display_name = EXCLUDED.display_name, diff --git a/apps/api/migrations/0019_platform_model_type_array.sql b/apps/api/migrations/0019_platform_model_type_array.sql new file mode 100644 index 0000000..0a241ff --- /dev/null +++ b/apps/api/migrations/0019_platform_model_type_array.sql @@ -0,0 +1,138 @@ +DO $$ +DECLARE + column_type text; +BEGIN + SELECT data_type + INTO column_type + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'platform_models' + AND column_name = 'model_type'; + + IF column_type IS DISTINCT FROM 'jsonb' THEN + DROP INDEX IF EXISTS idx_platform_models_lookup; + ALTER TABLE platform_models + DROP CONSTRAINT IF EXISTS platform_models_platform_id_model_name_model_type_key; + + ALTER TABLE platform_models + ADD COLUMN IF NOT EXISTS model_type_next jsonb NOT NULL DEFAULT '[]'::jsonb; + + UPDATE platform_models + SET model_type_next = CASE trim(model_type) + WHEN 'chat' THEN '["text_generate"]'::jsonb + WHEN 'text' THEN '["text_generate"]'::jsonb + WHEN 'responses' THEN '["text_generate"]'::jsonb + WHEN 'image' THEN '["image_generate","image_edit"]'::jsonb + WHEN 'images.generations' THEN '["image_generate"]'::jsonb + WHEN 'images.edits' THEN '["image_edit"]'::jsonb + WHEN 'video' THEN '["video_generate"]'::jsonb + WHEN 'videos.generations' THEN '["video_generate"]'::jsonb + ELSE jsonb_build_array(trim(model_type)) + END; + + WITH ranked AS ( + SELECT id, + first_value(id) OVER ( + PARTITION BY platform_id, model_name + ORDER BY created_at ASC, id ASC + ) AS keep_id, + row_number() OVER ( + PARTITION BY platform_id, model_name + ORDER BY created_at ASC, id ASC + ) AS row_number + FROM platform_models + ), + merged AS ( + SELECT ranked.keep_id, + jsonb_agg(DISTINCT type_value ORDER BY type_value) AS model_type + FROM ranked + JOIN platform_models model_row ON model_row.id = ranked.id + CROSS JOIN LATERAL jsonb_array_elements_text(model_row.model_type_next) AS type_item(type_value) + GROUP BY ranked.keep_id + ) + UPDATE platform_models target + SET model_type_next = merged.model_type + FROM merged + WHERE target.id = merged.keep_id; + + WITH ranked AS ( + SELECT id, + first_value(id) OVER ( + PARTITION BY platform_id, model_name + ORDER BY created_at ASC, id ASC + ) AS keep_id, + row_number() OVER ( + PARTITION BY platform_id, model_name + ORDER BY created_at ASC, id ASC + ) AS row_number + FROM platform_models + ) + UPDATE gateway_access_rules rules + SET resource_id = ranked.keep_id + FROM ranked + WHERE ranked.row_number > 1 + AND rules.resource_type = 'platform_model' + AND rules.resource_id = ranked.id; + + WITH ranked AS ( + SELECT id, + row_number() OVER ( + PARTITION BY platform_id, model_name + ORDER BY created_at ASC, id ASC + ) AS row_number + FROM platform_models + ) + DELETE FROM platform_models model_row + USING ranked + WHERE model_row.id = ranked.id + AND ranked.row_number > 1; + + ALTER TABLE platform_models DROP COLUMN model_type; + ALTER TABLE platform_models RENAME COLUMN model_type_next TO model_type; + END IF; +END $$; + +UPDATE platform_models +SET model_type = CASE + WHEN jsonb_typeof(model_type) = 'array' THEN model_type + WHEN jsonb_typeof(model_type) = 'string' THEN jsonb_build_array(model_type #>> '{}') + ELSE '[]'::jsonb +END; + +UPDATE platform_models +SET model_type = COALESCE(( + SELECT jsonb_agg(DISTINCT normalized_type ORDER BY normalized_type) + FROM jsonb_array_elements_text(platform_models.model_type) AS item(model_type_value) + CROSS JOIN LATERAL ( + VALUES + (CASE item.model_type_value + WHEN 'chat' THEN 'text_generate' + WHEN 'text' THEN 'text_generate' + WHEN 'responses' THEN 'text_generate' + WHEN 'images.generations' THEN 'image_generate' + WHEN 'images.edits' THEN 'image_edit' + WHEN 'video' THEN 'video_generate' + WHEN 'videos.generations' THEN 'video_generate' + ELSE item.model_type_value + END) + ) AS normalized(normalized_type) + WHERE trim(normalized_type) <> '' +), '[]'::jsonb); + +UPDATE platform_models +SET model_type = '["image_generate","image_edit"]'::jsonb +WHERE model_type = '["image"]'::jsonb; + +ALTER TABLE platform_models + ALTER COLUMN model_type SET NOT NULL, + ALTER COLUMN model_type SET DEFAULT '[]'::jsonb; + +DROP INDEX IF EXISTS idx_platform_models_lookup; +CREATE INDEX IF NOT EXISTS idx_platform_models_lookup + ON platform_models(model_name, enabled); + +CREATE INDEX IF NOT EXISTS idx_platform_models_model_type + ON platform_models USING gin(model_type); + +CREATE UNIQUE INDEX IF NOT EXISTS idx_platform_models_platform_model_name + ON platform_models(platform_id, model_name); diff --git a/apps/api/migrations/0020_runtime_policy_set_owns_rate_limit.sql b/apps/api/migrations/0020_runtime_policy_set_owns_rate_limit.sql new file mode 100644 index 0000000..4648ec5 --- /dev/null +++ b/apps/api/migrations/0020_runtime_policy_set_owns_rate_limit.sql @@ -0,0 +1,11 @@ +UPDATE platform_models m +SET rate_limit_policy = '{}'::jsonb, + updated_at = now() +FROM base_model_catalog b, + model_runtime_policy_sets rp +WHERE m.base_model_id = b.id + AND m.runtime_policy_set_id = rp.id + AND m.runtime_policy_set_id IS NOT NULL + AND m.runtime_policy_set_id IS DISTINCT FROM b.runtime_policy_set_id + AND m.rate_limit_policy = b.default_rate_limit_policy + AND rp.rate_limit_policy <> '{}'::jsonb; diff --git a/apps/api/migrations/0021_base_model_type_array.sql b/apps/api/migrations/0021_base_model_type_array.sql new file mode 100644 index 0000000..5e4ba4d --- /dev/null +++ b/apps/api/migrations/0021_base_model_type_array.sql @@ -0,0 +1,108 @@ +DO $$ +DECLARE + column_type text; +BEGIN + SELECT data_type + INTO column_type + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'base_model_catalog' + AND column_name = 'model_type'; + + IF column_type IS DISTINCT FROM 'jsonb' THEN + DROP INDEX IF EXISTS idx_base_model_catalog_type; + + ALTER TABLE base_model_catalog + ADD COLUMN IF NOT EXISTS model_type_next jsonb NOT NULL DEFAULT '[]'::jsonb; + + WITH source_types AS ( + SELECT id, + CASE + WHEN jsonb_typeof(capabilities->'originalTypes') = 'array' THEN capabilities->'originalTypes' + WHEN jsonb_typeof(metadata->'originalTypes') = 'array' THEN metadata->'originalTypes' + ELSE jsonb_build_array(model_type) + END AS model_types + FROM base_model_catalog + ), + normalized_types AS ( + SELECT source_types.id, normalized.model_type + FROM source_types + CROSS JOIN LATERAL jsonb_array_elements_text(source_types.model_types) AS raw_type(model_type) + CROSS JOIN LATERAL ( + SELECT 'text_generate' AS model_type + WHERE raw_type.model_type IN ('chat', 'text', 'responses') + UNION ALL SELECT 'image_generate' + WHERE raw_type.model_type = 'image' + UNION ALL SELECT 'image_edit' + WHERE raw_type.model_type = 'image' + UNION ALL SELECT 'image_generate' + WHERE raw_type.model_type = 'images.generations' + UNION ALL SELECT 'image_edit' + WHERE raw_type.model_type = 'images.edits' + UNION ALL SELECT 'video_generate' + WHERE raw_type.model_type IN ('video', 'videos.generations') + UNION ALL SELECT raw_type.model_type + WHERE trim(raw_type.model_type) <> '' + AND raw_type.model_type NOT IN ( + 'chat', 'text', 'responses', 'image', 'images.generations', 'images.edits', 'video', 'videos.generations' + ) + ) AS normalized + ) + UPDATE base_model_catalog target + SET model_type_next = COALESCE(( + SELECT jsonb_agg(DISTINCT normalized_types.model_type ORDER BY normalized_types.model_type) + FROM normalized_types + WHERE normalized_types.id = target.id + ), '[]'::jsonb); + + ALTER TABLE base_model_catalog DROP COLUMN model_type; + ALTER TABLE base_model_catalog RENAME COLUMN model_type_next TO model_type; + END IF; +END $$; + +UPDATE base_model_catalog +SET model_type = CASE + WHEN jsonb_typeof(model_type) = 'array' THEN model_type + WHEN jsonb_typeof(model_type) = 'string' THEN jsonb_build_array(model_type #>> '{}') + ELSE '[]'::jsonb +END; + +UPDATE base_model_catalog +SET default_snapshot = default_snapshot || jsonb_build_object('modelType', model_type) +WHERE catalog_type = 'system' + AND COALESCE(default_snapshot, '{}'::jsonb) <> '{}'::jsonb; + +ALTER TABLE base_model_catalog + ALTER COLUMN model_type SET NOT NULL, + ALTER COLUMN model_type SET DEFAULT '[]'::jsonb; + +CREATE INDEX IF NOT EXISTS idx_base_model_catalog_type + ON base_model_catalog(provider_key, status); + +CREATE INDEX IF NOT EXISTS idx_base_model_catalog_model_type + ON base_model_catalog USING gin(model_type); + +CREATE OR REPLACE FUNCTION fill_system_base_model_default_snapshot() +RETURNS trigger AS $$ +BEGIN + IF NEW.catalog_type = 'system' AND NEW.default_snapshot IS NULL THEN + NEW.default_snapshot = jsonb_build_object( + 'providerKey', NEW.provider_key, + 'canonicalModelKey', NEW.canonical_model_key, + 'providerModelName', NEW.provider_model_name, + 'modelType', NEW.model_type, + 'modelAlias', NEW.display_name, + 'capabilities', NEW.capabilities, + 'baseBillingConfig', NEW.base_billing_config, + 'defaultRateLimitPolicy', NEW.default_rate_limit_policy, + 'pricingRuleSetId', COALESCE(NEW.pricing_rule_set_id::text, ''), + 'runtimePolicySetId', COALESCE(NEW.runtime_policy_set_id::text, ''), + 'runtimePolicyOverride', NEW.runtime_policy_override, + 'metadata', NEW.metadata, + 'pricingVersion', NEW.pricing_version, + 'status', NEW.status + ); + END IF; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; diff --git a/apps/api/migrations/0022_sync_platform_model_type_from_base.sql b/apps/api/migrations/0022_sync_platform_model_type_from_base.sql new file mode 100644 index 0000000..69ad26b --- /dev/null +++ b/apps/api/migrations/0022_sync_platform_model_type_from_base.sql @@ -0,0 +1,9 @@ +UPDATE platform_models m +SET model_type = b.model_type, + updated_at = now() +FROM base_model_catalog b +WHERE m.base_model_id = b.id + AND jsonb_typeof(b.model_type) = 'array' + AND jsonb_array_length(b.model_type) > 0 + AND m.model_type <> b.model_type + AND COALESCE(m.capabilities, '{}'::jsonb) = COALESCE(b.capabilities, '{}'::jsonb); diff --git a/apps/api/migrations/0023_wallet_task_billing_schema.sql b/apps/api/migrations/0023_wallet_task_billing_schema.sql new file mode 100644 index 0000000..cd85776 --- /dev/null +++ b/apps/api/migrations/0023_wallet_task_billing_schema.sql @@ -0,0 +1,48 @@ +ALTER TABLE gateway_wallet_accounts + ADD COLUMN IF NOT EXISTS total_recharged numeric NOT NULL DEFAULT 0, + ADD COLUMN IF NOT EXISTS total_spent numeric NOT NULL DEFAULT 0; + +ALTER TABLE gateway_wallet_transactions + ADD COLUMN IF NOT EXISTS account_id uuid REFERENCES gateway_wallet_accounts(id) ON DELETE CASCADE, + ADD COLUMN IF NOT EXISTS direction text, + ADD COLUMN IF NOT EXISTS balance_before numeric, + ADD COLUMN IF NOT EXISTS idempotency_key text; + +DO $$ +BEGIN + IF EXISTS ( + SELECT 1 + FROM information_schema.columns + WHERE table_name = 'gateway_wallet_transactions' + AND column_name = 'wallet_account_id' + ) THEN + EXECUTE 'UPDATE gateway_wallet_transactions SET account_id = wallet_account_id WHERE account_id IS NULL'; + END IF; +END $$; + +UPDATE gateway_wallet_transactions +SET direction = CASE + WHEN transaction_type IN ('recharge', 'refund', 'credit') THEN 'credit' + ELSE 'debit' +END +WHERE direction IS NULL; + +UPDATE gateway_wallet_transactions +SET balance_before = COALESCE(balance_after, 0) + CASE + WHEN direction = 'debit' THEN COALESCE(amount, 0) + ELSE -COALESCE(amount, 0) +END +WHERE balance_before IS NULL; + +ALTER TABLE gateway_wallet_transactions + ALTER COLUMN direction SET DEFAULT 'debit', + ALTER COLUMN direction SET NOT NULL, + ALTER COLUMN balance_before SET DEFAULT 0, + ALTER COLUMN balance_before SET NOT NULL; + +CREATE INDEX IF NOT EXISTS idx_gateway_wallet_transactions_account + ON gateway_wallet_transactions(account_id, created_at DESC); + +CREATE UNIQUE INDEX IF NOT EXISTS uniq_gateway_wallet_tx_idempotency + ON gateway_wallet_transactions(account_id, idempotency_key) + WHERE idempotency_key IS NOT NULL; diff --git a/apps/web/src/App.tsx b/apps/web/src/App.tsx index 13fe34b..b02b89e 100644 --- a/apps/web/src/App.tsx +++ b/apps/web/src/App.tsx @@ -49,6 +49,7 @@ import { listPricingRules, listPricingRuleSets, listRuntimePolicySets, + listTasks, listPublicBaseModels, listPublicCatalogProviders, listRateLimitWindows, @@ -121,6 +122,7 @@ type DataKey = | 'tenants' | 'users' | 'userGroups' + | 'tasks' | 'accessRules' | 'apiKeys'; @@ -157,6 +159,7 @@ export function App() { const [selectedPlaygroundApiKeyId, setSelectedPlaygroundApiKeyId] = useState(''); const [taskForm, setTaskForm] = useState({ kind: 'chat.completions', model: 'gpt-4o-mini', prompt: '用一句话确认 AI Gateway simulation 链路正常。' }); const [taskResult, setTaskResult] = useState(null); + const [tasks, setTasks] = useState([]); const [coreState, setCoreState] = useState('idle'); const [coreMessage, setCoreMessage] = useState(''); const [state, setState] = useState('idle'); @@ -228,10 +231,11 @@ export function App() { rateLimitWindows, runtimePolicySets, taskResult, + tasks, tenants, userGroups, users, - }), [accessRules, apiKeys, baseModels, models, platforms, pricingRuleSets, pricingRules, providers, rateLimitWindows, runtimePolicySets, taskResult, tenants, userGroups, users]); + }), [accessRules, apiKeys, baseModels, models, platforms, pricingRuleSets, pricingRules, providers, rateLimitWindows, runtimePolicySets, taskResult, tasks, tenants, userGroups, users]); async function refresh(nextToken = token) { await ensureRouteData(nextToken, true); @@ -327,6 +331,9 @@ export function App() { case 'userGroups': setUserGroups((await listUserGroups(nextToken)).items); return; + case 'tasks': + setTasks((await listTasks(nextToken)).items); + return; case 'accessRules': setAccessRules((await (activePage === 'workspace' && workspaceSection === 'apiKeys' ? listApiKeyAccessRules(nextToken) @@ -628,6 +635,8 @@ export function App() { const response = await runTask(credential, taskForm); const detail = await getTask(credential, response.task.id); setTaskResult(detail); + setTasks((current) => [detail, ...current.filter((item) => item.id !== detail.id)]); + invalidateDataKeys('tasks'); setCoreState('ready'); setCoreMessage(`${taskForm.kind} 已通过 ${apiKeySecret ? '本地 API Key' : '当前 Access Token'} 完成 simulation。`); } catch (err) { @@ -660,6 +669,7 @@ export function App() { setApiKeySecretsById({}); setSelectedPlaygroundApiKeyId(''); setTaskResult(null); + setTasks([]); setCoreMessage(''); navigatePath('/'); } @@ -852,10 +862,16 @@ export function App() { function platformModelIsSelected(model: PlatformModel, selectedModels: PlatformModelBindingInput[]) { return selectedModels.some((selected) => { if (selected.baseModelId && model.baseModelId) return selected.baseModelId === model.baseModelId; - return selected.modelName === model.modelName && selected.modelType === model.modelType; + return selected.modelName === model.modelName && sameModelTypes(selected.modelType, model.modelType); }); } +function sameModelTypes(left: string[], right: string[]) { + if (left.length !== right.length) return false; + const rightSet = new Set(right); + return left.every((type) => rightSet.has(type)); +} + function mergeExistingPlatformModelInput(input: PlatformModelBindingInput, currentModels: PlatformModel[], platformId: string): PlatformModelBindingInput { const existing = currentModels.find((model) => model.platformId === platformId && platformModelIsSelected(model, [input])); if (!existing) return input; @@ -924,6 +940,7 @@ function dataKeysForRoute( if (activePage === 'workspace') { if (workspaceSection === 'overview') return ['users', 'userGroups', 'apiKeys']; if (workspaceSection === 'apiKeys') return ['apiKeys', 'accessRules', 'playgroundModels']; + if (workspaceSection === 'tasks') return ['tasks']; return []; } diff --git a/apps/web/src/api.ts b/apps/web/src/api.ts index 062a8cd..79af9e0 100644 --- a/apps/web/src/api.ts +++ b/apps/web/src/api.ts @@ -582,6 +582,10 @@ export async function getTask(token: string, taskId: string): Promise(`/api/v1/tasks/${taskId}`, { token }); } +export async function listTasks(token: string, limit = 50): Promise> { + return request>(`/api/v1/tasks?limit=${encodeURIComponent(String(limit))}`, { token }); +} + export function resolveApiAssetUrl(src: string) { if (/^(https?:|data:|blob:)/i.test(src)) return src; return `${API_BASE}${src.startsWith('/') ? src : `/${src}`}`; diff --git a/apps/web/src/app-state.ts b/apps/web/src/app-state.ts index eb95b9b..93e8e51 100644 --- a/apps/web/src/app-state.ts +++ b/apps/web/src/app-state.ts @@ -27,6 +27,7 @@ export interface ConsoleData { rateLimitWindows: RateLimitWindow[]; runtimePolicySets: RuntimePolicySet[]; taskResult: GatewayTask | null; + tasks: GatewayTask[]; tenants: GatewayTenant[]; userGroups: UserGroup[]; users: GatewayUser[]; diff --git a/apps/web/src/components/Dashboard.tsx b/apps/web/src/components/Dashboard.tsx index 34af16c..4eada04 100644 --- a/apps/web/src/components/Dashboard.tsx +++ b/apps/web/src/components/Dashboard.tsx @@ -73,7 +73,7 @@ export function Dashboard(props: { [item.modelName, item.modelType, item.provider ?? item.platformName ?? '-', item.enabled ? '是' : '否'])} + rows={props.models.map((item) => [item.modelName, item.modelType.join(', '), item.provider ?? item.platformName ?? '-', item.enabled ? '是' : '否'])} title="模型" /> diff --git a/apps/web/src/pages/AdminPage.tsx b/apps/web/src/pages/AdminPage.tsx index 574d063..cdd4c12 100644 --- a/apps/web/src/pages/AdminPage.tsx +++ b/apps/web/src/pages/AdminPage.tsx @@ -173,8 +173,8 @@ function identityPanelProps(props: { function OverviewPanel(props: { data: ConsoleData; stats: StatItem[] }) { const enabledPlatforms = props.data.platforms.filter((item) => item.status === 'enabled'); - const chatModels = props.data.models.filter((item) => item.modelType === 'chat' && item.enabled); - const imageModels = props.data.models.filter((item) => item.modelType === 'image' && item.enabled); + const chatModels = props.data.models.filter((item) => item.modelType.includes('text_generate') && item.enabled); + const imageModels = props.data.models.filter((item) => item.modelType.some((type) => type.includes('image')) && item.enabled); return (
@@ -229,7 +229,7 @@ function OverviewPanel(props: { data: ConsoleData; stats: StatItem[] }) { rows={[ ['对话', chatModels.length, 'Phase 1', '/v1/chat/completions'], ['图像', imageModels.length, 'Phase 1', '/v1/images/*'], - ['视频', props.data.models.filter((item) => item.modelType === 'video').length, 'Next', '/v1/videos/*'], + ['视频', props.data.models.filter((item) => item.modelType.some((type) => type.includes('video'))).length, 'Next', '/v1/videos/*'], ]} /> diff --git a/apps/web/src/pages/ModelsPage.tsx b/apps/web/src/pages/ModelsPage.tsx index 4f5bd6e..0428c2f 100644 --- a/apps/web/src/pages/ModelsPage.tsx +++ b/apps/web/src/pages/ModelsPage.tsx @@ -9,7 +9,7 @@ import type { import type { ConsoleData } from '../app-state'; import { PageHeader } from '../components/PageHeader'; import { Badge, Card, CardContent, Input } from '../components/ui'; -import { primaryBaseModelType, stableModelAlias } from './admin/platform-form'; +import { stableModelAlias } from './admin/platform-form'; type ModelListItem = { id: string; @@ -17,7 +17,7 @@ type ModelListItem = { platformName?: string; modelName: string; modelAlias?: string; - modelType: string; + modelType: string[]; displayName: string; capabilities?: Record; pricingMode: string; @@ -69,7 +69,7 @@ const publicModels: PlatformModel[] = [ platformName: 'OpenAI Simulation', modelName: 'gpt-4o-mini', modelAlias: 'gpt-4o-mini', - modelType: 'chat', + modelType: ['text_generate'], displayName: 'gpt-4o-mini', capabilities: { multimodal: true }, pricingMode: 'inherit', @@ -84,7 +84,7 @@ const publicModels: PlatformModel[] = [ platformName: 'OpenAI Simulation', modelName: 'gpt-image-1', modelAlias: 'gpt-image-1', - modelType: 'image', + modelType: ['image_generate', 'image_edit'], displayName: 'gpt-image-1', capabilities: { imageEdit: true }, pricingMode: 'inherit', @@ -99,7 +99,7 @@ const publicModels: PlatformModel[] = [ platformName: 'Gemini Simulation', modelName: 'gemini-2.0-flash', modelAlias: 'gemini-2.0-flash', - modelType: 'chat', + modelType: ['text_generate'], displayName: 'gemini-2.0-flash', capabilities: { multimodal: true, vision: true }, pricingMode: 'inherit_discount', @@ -148,7 +148,7 @@ export function ModelsPage(props: { data: ConsoleData }) { return sourceModels.filter((model) => { const providerInfo = providerMap.get(model.providerKey); const matchedProvider = provider === 'all' || model.providerKey === provider; - const matchedCapability = capability === 'all' || model.modelType === capability; + const matchedCapability = modelMatchesCapability(model.modelType, capability); const matchedQuery = [ model.modelName, model.modelAlias, @@ -331,7 +331,7 @@ function modelFromBaseModel(model: BaseModelCatalogItem): ModelListItem { providerKey: model.providerKey, modelName: model.providerModelName, modelAlias: stableModelAlias(model), - modelType: primaryBaseModelType(model), + modelType: model.modelType, displayName: stableModelAlias(model), capabilities: model.capabilities, pricingMode: 'inherit', @@ -363,7 +363,7 @@ function providerInitials(label: string) { } function tagsForModel(model: ModelListItem) { - const tags = [capabilityName(model.modelType)]; + const tags = model.modelType.map(capabilityName); const capabilities = model.capabilities ?? {}; if (capabilities.multimodal || capabilities.vision) tags.push('多模态'); if (capabilities.reasoning) tags.push('推理'); @@ -373,7 +373,23 @@ function tagsForModel(model: ModelListItem) { } function capabilityName(type: string) { - return capabilityFilters.find((item) => item.value === type)?.label ?? type; + const labels: Record = { + text_generate: '对话', + image_generate: '绘图', + image_edit: '图像编辑', + video_generate: '视频', + image_to_video: '图生视频', + audio_generate: '音频', + }; + return labels[type] ?? capabilityFilters.find((item) => item.value === type)?.label ?? type; +} + +function modelMatchesCapability(modelTypes: string[], capability: string) { + if (capability === 'all') return true; + if (capability === 'chat') return modelTypes.includes('text_generate') || modelTypes.includes('chat'); + if (capability === 'image') return modelTypes.some((type) => type.includes('image')); + if (capability === 'video') return modelTypes.some((type) => type.includes('video')); + return modelTypes.includes(capability); } function priceLabel(model: ModelListItem) { diff --git a/apps/web/src/pages/PlaygroundPage.tsx b/apps/web/src/pages/PlaygroundPage.tsx index d4639f7..0762593 100644 --- a/apps/web/src/pages/PlaygroundPage.tsx +++ b/apps/web/src/pages/PlaygroundPage.tsx @@ -988,8 +988,8 @@ function filterModelsForMode(models: PlatformModel[], mode: PlaygroundMode, hasR } function filterWithFallback(models: PlatformModel[], modelTypes: string[]) { - const exact = models.filter((model) => modelTypes.includes(model.modelType)); - return exact.length ? exact : models.filter((model) => modelTypes.some((type) => model.modelType.includes(type) || type.includes(model.modelType))); + const exact = models.filter((model) => model.modelType.some((type) => modelTypes.includes(type))); + return exact.length ? exact : models.filter((model) => modelTypes.some((type) => model.modelType.some((modelType) => modelType.includes(type) || type.includes(modelType)))); } function buildModelOptions(models: PlatformModel[]): ModelOption[] { diff --git a/apps/web/src/pages/WorkspacePage.tsx b/apps/web/src/pages/WorkspacePage.tsx index 1fda859..eedabb4 100644 --- a/apps/web/src/pages/WorkspacePage.tsx +++ b/apps/web/src/pages/WorkspacePage.tsx @@ -1,6 +1,6 @@ import { useMemo, useState, type FormEvent, type ReactNode } from 'react'; import { Copy, CreditCard, KeyRound, ListChecks, Plus, ShieldCheck, Trash2, UserRound } from 'lucide-react'; -import type { GatewayAccessRuleBatchRequest, GatewayApiKey, IntegrationPlatform, PlatformModel } from '@easyai-ai-gateway/contracts'; +import type { GatewayAccessRuleBatchRequest, GatewayApiKey, GatewayTask, IntegrationPlatform, PlatformModel } from '@easyai-ai-gateway/contracts'; import type { ConsoleData } from '../app-state'; import { EntityTable } from '../components/EntityTable'; import { Badge, Button, Card, CardContent, CardHeader, CardTitle, ConfirmDialog, DateTimePicker, FormDialog, Input, Label, Table, TableCell, TableHead, TableRow, Tabs } from '../components/ui'; @@ -297,30 +297,22 @@ function ApiKeyPanel(props: { } function TaskPanel(props: { data: ConsoleData }) { - const task = props.data.taskResult; - const usage = task?.usage ?? {}; - const tokenText = usage.totalTokens ? `${usage.totalTokens}` : '-'; - const chargeText = task?.finalChargeAmount ? `${task.finalChargeAmount}` : '-'; + const tasks = useMemo(() => { + const latest = props.data.taskResult; + if (!latest) return props.data.tasks; + return [latest, ...props.data.tasks.filter((item) => item.id !== latest.id)]; + }, [props.data.taskResult, props.data.tasks]); return ( 任务记录 - {task ? ( -
- {task.status} - {task.kind} - {task.model} -
- - - - - - -
-
{JSON.stringify({ result: task.result, usage: task.usage, billings: task.billings, billingSummary: task.billingSummary, metrics: task.metrics }, null, 2)}
+ {tasks.length ? ( +
+ {tasks.map((task) => ( + + ))}
) : (
@@ -332,6 +324,34 @@ function TaskPanel(props: { data: ConsoleData }) { ); } +function TaskRecord(props: { task: GatewayTask }) { + const usage = props.task.usage ?? {}; + const tokenText = usage.totalTokens ? `${usage.totalTokens}` : '-'; + const chargeText = props.task.finalChargeAmount !== undefined ? `${props.task.finalChargeAmount}` : '-'; + const badgeVariant = props.task.status === 'succeeded' ? 'success' : props.task.status === 'failed' ? 'destructive' : 'secondary'; + return ( +
+
+ {props.task.status} + {props.task.kind} + {props.task.model} + {formatDateTime(props.task.createdAt)} +
+
+ + + + + + + + +
+
{JSON.stringify({ result: props.task.result, usage: props.task.usage, billings: props.task.billings, billingSummary: props.task.billingSummary, metrics: props.task.metrics }, null, 2)}
+
+ ); +} + function InfoItem(props: { label: string; value: string }) { return (
diff --git a/apps/web/src/pages/admin/AccessPermissionEditor.tsx b/apps/web/src/pages/admin/AccessPermissionEditor.tsx index 8e27f5a..70013f8 100644 --- a/apps/web/src/pages/admin/AccessPermissionEditor.tsx +++ b/apps/web/src/pages/admin/AccessPermissionEditor.tsx @@ -304,7 +304,7 @@ function buildPlatformTree(platforms: IntegrationPlatform[], platformModels: Pla .map((model) => ({ id: model.id, name: modelLabel(model), - subtitle: `${model.modelType} / ${model.modelName}`, + subtitle: `${model.modelType.join(', ')} / ${model.modelName}`, })), })).sort((a, b) => a.name.localeCompare(b.name)); } diff --git a/apps/web/src/pages/admin/PlatformManagementPanel.tsx b/apps/web/src/pages/admin/PlatformManagementPanel.tsx index e4ab2cc..cb83b95 100644 --- a/apps/web/src/pages/admin/PlatformManagementPanel.tsx +++ b/apps/web/src/pages/admin/PlatformManagementPanel.tsx @@ -59,7 +59,7 @@ export function PlatformManagementPanel(props: { model.displayName, model.modelName, model.modelAlias, - model.modelType, + ...model.modelType, model.provider, platform?.name, platform?.internalName, @@ -473,7 +473,7 @@ function PlatformModelTable(props: { {model.modelType}, + {model.modelType.join(', ')}, {model.enabled ? 'enabled' : 'disabled'}, ]} chips={platformModelChips(model)} @@ -889,7 +889,7 @@ function findBaseModelForPlatformModel(platform: IntegrationPlatform | undefined return baseModels.find((item) => item.id === model.baseModelId) ?? baseModels.find((item) => item.canonicalModelKey === model.modelAlias) ?? baseModels.find((item) => stableModelAlias(item) === model.modelAlias) ?? - baseModels.find((item) => item.providerKey === platform?.provider && item.providerModelName === model.modelName && baseModelTypes(item).includes(model.modelType)); + baseModels.find((item) => item.providerKey === platform?.provider && item.providerModelName === model.modelName && model.modelType.some((type) => baseModelTypes(item).includes(type))); } function readPlatformModelIconPath(model: PlatformModel, baseModel?: BaseModelCatalogItem) { diff --git a/apps/web/src/pages/admin/platform-form.ts b/apps/web/src/pages/admin/platform-form.ts index cdd232a..9fea8dc 100644 --- a/apps/web/src/pages/admin/platform-form.ts +++ b/apps/web/src/pages/admin/platform-form.ts @@ -149,7 +149,7 @@ export function platformModelPayloads(models: BaseModelCatalogItem[], form: Plat canonicalModelKey: model.canonicalModelKey, modelName: model.providerModelName, modelAlias: stableModelAlias(model), - modelType: primaryBaseModelType(model), + modelType: baseModelTypes(model), displayName: stableModelAlias(model) || model.providerModelName, pricingMode: 'inherit_discount', discountFactor: optionalPositiveNumber(form.modelDiscountFactors[model.id]) ?? optionalPositiveNumber(form.modelDiscountFactor), diff --git a/apps/web/src/pages/playground-media.tsx b/apps/web/src/pages/playground-media.tsx index 5d2e95b..fc85413 100644 --- a/apps/web/src/pages/playground-media.tsx +++ b/apps/web/src/pages/playground-media.tsx @@ -638,7 +638,7 @@ function capabilityTypeKeys( : [contextKey, 'text_to_video', 'image_to_video', 'omni_video', 'video_generate', 'video']; return uniqueStrings([ ...stringListFromCapability(source.originalTypes), - model.modelType, + ...model.modelType, ...modeTypeHints.filter((item): item is string => Boolean(item)), ]); } diff --git a/apps/web/src/styles.css b/apps/web/src/styles.css index 3021a96..4fb1808 100644 --- a/apps/web/src/styles.css +++ b/apps/web/src/styles.css @@ -224,6 +224,23 @@ strong { gap: 14px; } +.taskList { + display: grid; + gap: 14px; +} + +.taskRecordHeader { + display: flex; + flex-wrap: wrap; + align-items: center; + gap: 8px 12px; + color: var(--muted-foreground); +} + +.taskRecordHeader strong { + color: var(--text-strong); +} + .appShell { min-height: 100vh; } diff --git a/apps/web/src/types.ts b/apps/web/src/types.ts index 6644c7b..8ead73f 100644 --- a/apps/web/src/types.ts +++ b/apps/web/src/types.ts @@ -57,7 +57,7 @@ export interface PlatformModelForm { canonicalModelKey: string; modelName: string; modelAlias: string; - modelType: string; + modelType: string[]; pricingRuleSetId: string; discountFactor: string; } @@ -84,7 +84,7 @@ export interface PlatformModelBindingInput { baseModelId?: string; modelName: string; modelAlias?: string; - modelType: string; + modelType: string[]; displayName?: string; pricingMode?: string; retryPolicy?: Record; diff --git a/docs/test/loopback-test-checklist.md b/docs/test/loopback-test-checklist.md index 9bff7d3..dbe124f 100644 --- a/docs/test/loopback-test-checklist.md +++ b/docs/test/loopback-test-checklist.md @@ -36,6 +36,19 @@ 状态约定:`未执行`、`执行中`、`成功`、`失败`、`阻塞`。 +### 本次自动化回环记录 + +| 项目 | 结果 | +| --- | --- | +| 测试批次 | `loopback-20260511-integration` | +| 执行方式 | `httptest` 本地服务 + 独立 PostgreSQL 测试库 `easyai_ai_gateway_test_codex` | +| 已通过验证 | 注册 / 登录 / API Key、API Key scopes 拦截、管理员接口拒绝 API Key、模型候选路由、真实 Chat / 兼容 Chat / 文生图 / 图生图 / 文生视频 / 图生视频 / 首尾帧视频、任务详情、任务列表、事件流、callback outbox、重试、运行策略限流、降级、自动禁用、计费规则集绑定、钱包扣费和交易流水 | +| 核心修复验证 | `base_model_catalog.model_type` 和 `platform_models.model_type` 均为 JSONB 数组;`doubao-4.5图像编辑` 同时包含 `image_generate` / `image_edit`;`豆包Seedance-1.5-pro` 同时包含 `video_generate` / `image_to_video`;运行时候选查询只用 `model_type` 命中 | +| 真实外部链路 | 已使用主库 `Playground API Key` 和已配置平台 KEY 验证通过;Chat=`qwen-plus`,图像=`doubao-4.5图像编辑`,视频=`豆包Seedance-1.5-pro` | +| 真实任务证据 | Chat `6b454b1f-29b8-48ef-8ebb-740062d2e8b4`;文生图 `2db64a7e-f01d-424c-a5eb-027ef58cacde`;图生图 `af4cae00-c6d7-4c70-8477-64326c5d5cfc`;文生视频 `2436e7aa-3518-442b-8559-1e48a9cb3462`;图生视频 `80cc9655-4c87-466d-b263-5df23d23c157`;首尾帧视频 `5353a0cf-efa8-4fbc-90fd-60d1897fd012` | +| 验证命令 | `AI_GATEWAY_TEST_DATABASE_URL=... go test ./...`、`pnpm --filter @easyai-ai-gateway/web typecheck`、`pnpm test`、`git diff --check` | +| 结论 | 本地自动化覆盖项和真实外部模型链路均为 `成功` | + ## 3. 测试数据准备 | ID | 任务 | 接口 / 方式 | 成功判定 | 状态 | 结果记录 | @@ -98,6 +111,7 @@ | HISTORY-03 | 事件记录保存 | `gateway_task_events` | 事件包含接收、运行、重试、完成或失败,`simulated` 标记准确 | 未执行 | 事件摘要待填写 | | HISTORY-04 | 回调 outbox 保存 | `gateway_task_callback_outbox`,开启回调配置时验证 | 每个 task event 写入幂等 outbox,`task_id + seq + callback_url` 唯一 | 未执行 | outbox 数量待填写 | | HISTORY-05 | API Key 身份保存 | 任务详情或数据库字段 | `apiKeyId`、`apiKeyName`、`apiKeyPrefix` 可追溯到发起 Key | 未执行 | key 信息待填写 | +| HISTORY-06 | 前端任务记录 | `GET /api/v1/tasks` + 工作台“任务记录”页 | 页面从服务端任务列表读取历史记录,不依赖当前会话里刚运行的单条 taskResult | 成功 | 已新增任务列表接口和前端列表渲染;自动化断言 task list 包含已完成任务 | ## 6. 定价规则绑定与扣费计算 @@ -109,6 +123,7 @@ | PRICE-02 | 绑定规则集到平台模型 | `POST /api/v1/platforms/{platformID}/models` 或更新模型 | 平台模型返回 `pricingRuleSetId=loopback-pricing-*` 或 `billingConfigOverride` 生效 | 未执行 | platformModelId 待填写 | | PRICE-03 | Chat 预估计费 | `POST /api/v1/pricing/estimate`,Chat 请求体 | 返回 `resolver=effective-pricing-v1`,`text_input` 与 `text_output` 金额按 tokens 计算 | 未执行 | estimate items 待填写 | | PRICE-04 | Chat 最终扣费 | 执行 Chat 成功任务 | `finalChargeAmount=sum(billings.amount)`,输入/输出 token 数与 usage 对齐 | 未执行 | usage、billings、finalChargeAmount 待填写 | +| PRICE-04B | 钱包实际扣费 | 执行 Chat 成功任务后读取钱包 | `gateway_wallet_accounts.balance` 按 `finalChargeAmount` 扣减,`gateway_wallet_transactions` 写入 `task_billing` 幂等流水 | 成功 | 集成测试断言 100 -> 99.972,交易 amount=0.028 | | PRICE-05 | 文生图扣费 | 执行文生图成功任务 | `resourceType=image`,数量、尺寸、质量权重参与计算 | 未执行 | billings 待填写 | | PRICE-06 | 图生图扣费 | 执行图生图成功任务 | `resourceType=image_edit`,优先使用编辑价格,缺省时回退图片价格 | 未执行 | billings 待填写 | | PRICE-07 | 视频扣费 | 执行文生视频、图生视频、首尾帧视频成功任务 | `resourceType=video`,数量、分辨率或时长权重按规则计算,三类视频请求均有计费记录 | 未执行 | billings 待填写 | @@ -175,7 +190,10 @@ | 时间 | 用例 ID | 失败现象 | 初步原因 | 修复位置 | 重跑结果 | | --- | --- | --- | --- | --- | --- | -| 待填写 | 待填写 | 待填写 | 待填写 | 待填写 | 待填写 | +| 2026-05-11 | TASK-IMAGE-01 | `doubao-4.5图像编辑` 使用 `1024x1024` 返回 `http_400`,提示像素不足 | 测试参数不符合真实模型最小尺寸约束 | 调整真实验证请求为 `2048x2048` | 文生图任务 `2db64a7e-f01d-424c-a5eb-027ef58cacde` 成功 | +| 2026-05-11 | TASK-VIDEO-02 | 图生视频传 `image` 时上游按 `r2v` 处理并返回不支持 `Seedance-1.5-pro` | Volces 适配层把通用 `image` 映射成 `reference_image`,而非图生视频首帧 | `apps/api/internal/clients/volces.go`,将 `image` 默认映射为 `first_frame`,显式 `reference_image` 才走参考图 | 图生视频任务 `80cc9655-4c87-466d-b263-5df23d23c157` 成功 | +| 2026-05-11 | TASK-VIDEO-02 / TASK-VIDEO-03 | 带图片的视频请求任务 `modelType` 仍记录为 `video_generate` | `videos.generations` 未根据请求体区分文生视频和图生视频 | `apps/api/internal/runner/service.go` / `pricing.go`,带图片、首帧或尾帧时使用 `image_to_video` | 图生视频与首尾帧视频均以 `modelType=image_to_video` 成功 | +| 2026-05-11 | HISTORY-06 | 前端“任务记录”没有任何记录 | 前端只展示本地 `taskResult`,后端没有当前用户任务列表接口 | `GET /api/v1/tasks`、`ListTasks`、工作台任务列表状态和渲染 | 集成测试验证任务列表返回已完成任务;前端 typecheck 成功 | ## 12. 清理清单 diff --git a/packages/contracts/src/index.ts b/packages/contracts/src/index.ts index 1f4c267..a8d5aff 100644 --- a/packages/contracts/src/index.ts +++ b/packages/contracts/src/index.ts @@ -560,7 +560,7 @@ export interface PlatformModel { platformName?: string; modelName: string; modelAlias?: string; - modelType: 'chat' | 'image' | 'video' | 'audio' | 'embedding' | string; + modelType: string[]; displayName: string; capabilityOverride?: Record; capabilities?: Record; @@ -625,6 +625,8 @@ export interface GatewayTask { responseDurationMs?: number; finishedAt?: string; error?: string; + errorCode?: string; + errorMessage?: string; createdAt: string; updatedAt: string; }