fix gateway loopback validation chains

This commit is contained in:
wangbo 2026-05-11 08:48:02 +08:00
parent ff666b1ece
commit ca7e76e815
42 changed files with 1641 additions and 129 deletions

View File

@ -40,6 +40,7 @@ type User struct {
APIKeySecret string `json:"apiKeySecret,omitempty"`
APIKeyName string `json:"apiKeyName,omitempty"`
APIKeyPrefix string `json:"apiKeyPrefix,omitempty"`
APIKeyScopes []string `json:"apiKeyScopes,omitempty"`
}
type contextKey string
@ -137,6 +138,7 @@ func (a *Authenticator) verifyJWT(tokenString string) (*User, error) {
APIKeySecret: stringClaim(claims, "apiKeySecret"),
APIKeyName: stringClaim(claims, "apiKeyName"),
APIKeyPrefix: stringClaim(claims, "apiKeyPrefix"),
APIKeyScopes: stringSliceClaim(claims, "apiKeyScopes"),
}
if user.Source == "" {
user.Source = "gateway"
@ -167,6 +169,7 @@ func (a *Authenticator) SignJWT(user *User, ttl time.Duration) (string, error) {
"apiKeyId": user.APIKeyID,
"apiKeyName": user.APIKeyName,
"apiKeyPrefix": user.APIKeyPrefix,
"apiKeyScopes": user.APIKeyScopes,
"iat": now.Unix(),
"exp": now.Add(ttl).Unix(),
}

View File

@ -34,7 +34,8 @@ func (c SimulationClient) Run(ctx context.Context, request Request) (Response, e
}
}
responseFinishedAt := time.Now()
if profile == "retryable_failure" {
switch profile {
case "retryable_failure":
return Response{}, &ClientError{
Code: "server_error",
Message: "simulated retryable failure",
@ -44,8 +45,27 @@ func (c SimulationClient) Run(ctx context.Context, request Request) (Response, e
ResponseDurationMS: responseDurationMS(responseStartedAt, responseFinishedAt),
Retryable: true,
}
}
if profile == "fatal_failure" || profile == "non_retryable_failure" {
case "rate_limit", "overloaded":
return Response{}, &ClientError{
Code: profile,
Message: "simulated " + profile,
RequestID: "simulated-request",
ResponseStartedAt: responseStartedAt,
ResponseFinishedAt: responseFinishedAt,
ResponseDurationMS: responseDurationMS(responseStartedAt, responseFinishedAt),
Retryable: true,
}
case "invalid_api_key":
return Response{}, &ClientError{
Code: "invalid_api_key",
Message: "simulated invalid_api_key",
RequestID: "simulated-request",
ResponseStartedAt: responseStartedAt,
ResponseFinishedAt: responseFinishedAt,
ResponseDurationMS: responseDurationMS(responseStartedAt, responseFinishedAt),
Retryable: false,
}
case "fatal_failure", "non_retryable_failure":
return Response{}, &ClientError{
Code: "bad_request",
Message: "simulated non-retryable failure",
@ -184,7 +204,7 @@ func simulatedVideoData(request Request) []any {
}
func simulatedUsage(request Request) Usage {
if request.ModelType == "chat" || request.Kind == "responses" {
if request.ModelType == "chat" || request.ModelType == "text_generate" || request.Kind == "responses" {
return Usage{InputTokens: 12, OutputTokens: 8, TotalTokens: 20}
}
return Usage{}

View File

@ -251,9 +251,18 @@ func buildVolcesContentFromBody(body map[string]any) []map[string]any {
}
}
appendURLContent("image_url", "first_frame", firstNonEmptyStringValue(body, "first_frame", "firstFrame"))
firstFrame := firstNonEmptyStringValue(body, "first_frame", "firstFrame")
appendURLContent("image_url", "first_frame", firstFrame)
appendURLContent("image_url", "last_frame", firstNonEmptyStringValue(body, "last_frame", "lastFrame"))
for _, url := range firstNonEmptyStringListFromAny(body["image"], body["images"], body["image_url"], body["imageUrl"], body["image_urls"], body["imageUrls"], body["reference_image"], body["referenceImage"]) {
imageURLs := firstNonEmptyStringListFromAny(body["image"], body["images"], body["image_url"], body["imageUrl"], body["image_urls"], body["imageUrls"])
if firstFrame == "" && len(imageURLs) > 0 {
appendURLContent("image_url", "first_frame", imageURLs[0])
imageURLs = imageURLs[1:]
}
for _, url := range imageURLs {
appendURLContent("image_url", "reference_image", url)
}
for _, url := range firstNonEmptyStringListFromAny(body["reference_image"], body["referenceImage"]) {
appendURLContent("image_url", "reference_image", url)
}
for _, url := range firstNonEmptyStringListFromAny(body["video"], body["video_url"], body["videoUrl"], body["reference_video"], body["referenceVideo"]) {

View File

@ -6,6 +6,7 @@ import (
"encoding/json"
"io"
"log/slog"
"math"
"net/http"
"net/http/httptest"
"os"
@ -123,8 +124,43 @@ func TestCoreLocalFlow(t *testing.T) {
if _, err := testPool.Exec(ctx, `UPDATE gateway_users SET roles = '["admin"]'::jsonb WHERE username = $1`, username); err != nil {
t.Fatalf("promote smoke user: %v", err)
}
var smokeGatewayUserID string
if err := testPool.QueryRow(ctx, `SELECT id::text FROM gateway_users WHERE username = $1`, username).Scan(&smokeGatewayUserID); err != nil {
t.Fatalf("read smoke gateway user id: %v", err)
}
doJSON(t, server.URL, http.MethodGet, "/api/admin/models", apiKeyResponse.Secret, nil, http.StatusForbidden, nil)
var chatOnlyAPIKeyResponse struct {
Secret string `json:"secret"`
APIKey struct {
ID string `json:"id"`
} `json:"apiKey"`
}
doJSON(t, server.URL, http.MethodPost, "/api/v1/api-keys", loginResponse.AccessToken, map[string]any{
"name": "chat only key",
"scopes": []string{"chat"},
}, http.StatusCreated, &chatOnlyAPIKeyResponse)
var taskCountBefore int
if err := testPool.QueryRow(ctx, `SELECT count(*) FROM gateway_tasks`).Scan(&taskCountBefore); err != nil {
t.Fatalf("count tasks before scoped request: %v", err)
}
doJSON(t, server.URL, http.MethodPost, "/api/v1/images/generations", chatOnlyAPIKeyResponse.Secret, map[string]any{
"model": "gpt-image-1",
"prompt": "scope should block this",
}, http.StatusForbidden, nil)
doJSON(t, server.URL, http.MethodPost, "/api/v1/pricing/estimate", chatOnlyAPIKeyResponse.Secret, map[string]any{
"kind": "images.generations",
"model": "gpt-image-1",
"prompt": "scope should block this estimate",
}, http.StatusForbidden, nil)
var taskCountAfter int
if err := testPool.QueryRow(ctx, `SELECT count(*) FROM gateway_tasks`).Scan(&taskCountAfter); err != nil {
t.Fatalf("count tasks after scoped request: %v", err)
}
if taskCountAfter != taskCountBefore {
t.Fatalf("scoped API key rejection should happen before task creation, before=%d after=%d", taskCountBefore, taskCountAfter)
}
inviteCode := "INVITE-" + suffixText
if _, err := testPool.Exec(ctx, `
INSERT INTO gateway_invitations (invite_code, max_uses, metadata)
@ -322,6 +358,317 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil {
t.Fatalf("unexpected image edit task: %+v", imageEditResponse.Task)
}
var gptImageModelTypesRaw []byte
if err := testPool.QueryRow(ctx, `
SELECT model_type
FROM platform_models
WHERE model_name = 'gpt-image-1'
AND model_type @> '["image_generate","image_edit"]'::jsonb
LIMIT 1`).Scan(&gptImageModelTypesRaw); err != nil {
t.Fatalf("gpt-image-1 platform model should store both image model types: %v", err)
}
var gptImageModelTypes []string
if err := json.Unmarshal(gptImageModelTypesRaw, &gptImageModelTypes); err != nil {
t.Fatalf("decode gpt-image-1 model_type: %v raw=%s", err, string(gptImageModelTypesRaw))
}
if !stringSliceContains(gptImageModelTypes, "image_generate") || !stringSliceContains(gptImageModelTypes, "image_edit") {
t.Fatalf("gpt-image-1 model_type should include generation and edit types: %+v", gptImageModelTypes)
}
deniedModel := "permission-deny-smoke-" + suffixText
var deniedPlatformModel struct {
ID string `json:"id"`
}
doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms/"+platform.ID+"/models", loginResponse.AccessToken, map[string]any{
"canonicalModelKey": "openai:gpt-4o-mini",
"modelName": deniedModel,
"modelAlias": deniedModel,
"modelType": []string{"text_generate"},
"displayName": "Permission Deny Smoke",
}, http.StatusCreated, &deniedPlatformModel)
doJSON(t, server.URL, http.MethodPost, "/api/admin/access-rules", loginResponse.AccessToken, map[string]any{
"subjectType": "api_key",
"subjectId": chatOnlyAPIKeyResponse.APIKey.ID,
"resourceType": "platform_model",
"resourceId": deniedPlatformModel.ID,
"effect": "deny",
"priority": 10,
"minPermissionLevel": 0,
"status": "active",
}, http.StatusCreated, nil)
var deniedTask struct {
Task struct {
Status string `json:"status"`
ErrorCode string `json:"errorCode"`
} `json:"task"`
}
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", chatOnlyAPIKeyResponse.Secret, map[string]any{
"model": deniedModel,
"runMode": "simulation",
"simulation": true,
"simulationDurationMs": 5,
"messages": []map[string]any{{"role": "user", "content": "permission deny"}},
}, http.StatusAccepted, &deniedTask)
if deniedTask.Task.Status != "failed" || deniedTask.Task.ErrorCode != "no_model_candidate" {
t.Fatalf("deny access rule should hide denied model from runtime candidates: %+v", deniedTask.Task)
}
var restrictedModels struct {
Items []struct {
ID string `json:"id"`
ModelName string `json:"modelName"`
} `json:"items"`
}
doJSON(t, server.URL, http.MethodGet, "/api/v1/playground/models", chatOnlyAPIKeyResponse.Secret, nil, http.StatusOK, &restrictedModels)
if modelListContains(restrictedModels.Items, deniedPlatformModel.ID) {
t.Fatalf("deny access rule should hide denied model from playable list: %+v", restrictedModels.Items)
}
controlledModel := "permission-allow-smoke-" + suffixText
var controlledPlatformModel struct {
ID string `json:"id"`
}
doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms/"+platform.ID+"/models", loginResponse.AccessToken, map[string]any{
"canonicalModelKey": "openai:gpt-4o-mini",
"modelName": controlledModel,
"modelAlias": controlledModel,
"modelType": []string{"text_generate"},
"displayName": "Permission Allow Smoke",
}, http.StatusCreated, &controlledPlatformModel)
doJSON(t, server.URL, http.MethodPost, "/api/admin/access-rules", loginResponse.AccessToken, map[string]any{
"subjectType": "api_key",
"subjectId": apiKeyResponse.APIKey.ID,
"resourceType": "platform_model",
"resourceId": controlledPlatformModel.ID,
"effect": "allow",
"priority": 10,
"minPermissionLevel": 0,
"status": "active",
}, http.StatusCreated, nil)
var blockedControlledTask struct {
Task struct {
Status string `json:"status"`
ErrorCode string `json:"errorCode"`
} `json:"task"`
}
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", chatOnlyAPIKeyResponse.Secret, map[string]any{
"model": controlledModel,
"runMode": "simulation",
"simulation": true,
"simulationDurationMs": 5,
"messages": []map[string]any{{"role": "user", "content": "allow should block other keys"}},
}, http.StatusAccepted, &blockedControlledTask)
if blockedControlledTask.Task.Status != "failed" || blockedControlledTask.Task.ErrorCode != "no_model_candidate" {
t.Fatalf("allow access rule should make the resource unavailable to unmatched subjects: %+v", blockedControlledTask.Task)
}
doJSON(t, server.URL, http.MethodPost, "/api/admin/access-rules", loginResponse.AccessToken, map[string]any{
"subjectType": "api_key",
"subjectId": chatOnlyAPIKeyResponse.APIKey.ID,
"resourceType": "platform_model",
"resourceId": controlledPlatformModel.ID,
"effect": "allow",
"priority": 10,
"minPermissionLevel": 0,
"status": "active",
}, http.StatusCreated, nil)
var allowedControlledTask struct {
Task struct {
Status string `json:"status"`
} `json:"task"`
}
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", chatOnlyAPIKeyResponse.Secret, map[string]any{
"model": controlledModel,
"runMode": "simulation",
"simulation": true,
"simulationDurationMs": 5,
"messages": []map[string]any{{"role": "user", "content": "allow should pass"}},
}, http.StatusAccepted, &allowedControlledTask)
if allowedControlledTask.Task.Status != "succeeded" {
t.Fatalf("matching allow access rule should make the controlled model usable: %+v", allowedControlledTask.Task)
}
var customPricingRuleSet struct {
ID string `json:"id"`
}
doJSON(t, server.URL, http.MethodPost, "/api/admin/pricing/rule-sets", loginResponse.AccessToken, map[string]any{
"ruleSetKey": "smoke-pricing-" + suffixText,
"name": "Smoke Pricing",
"currency": "resource",
"rules": []map[string]any{
{"ruleKey": "text_input", "displayName": "Text Input", "resourceType": "text_input", "unit": "1k_tokens", "basePrice": 1},
{"ruleKey": "text_output", "displayName": "Text Output", "resourceType": "text_output", "unit": "1k_tokens", "basePrice": 2},
{"ruleKey": "image", "displayName": "Image", "resourceType": "image", "unit": "image", "basePrice": 7},
{"ruleKey": "image_edit", "displayName": "Image Edit", "resourceType": "image_edit", "unit": "image", "basePrice": 11},
{"ruleKey": "video", "displayName": "Video", "resourceType": "video", "unit": "video", "basePrice": 13},
},
}, http.StatusCreated, &customPricingRuleSet)
pricingModel := "pricing-smoke-" + suffixText
var pricingPlatformModel map[string]any
doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms/"+platform.ID+"/models", loginResponse.AccessToken, map[string]any{
"canonicalModelKey": "openai:gpt-4o-mini",
"modelName": pricingModel,
"modelAlias": pricingModel,
"modelType": []string{"text_generate"},
"displayName": "Pricing Smoke",
"pricingRuleSetId": customPricingRuleSet.ID,
}, http.StatusCreated, &pricingPlatformModel)
if _, err := testPool.Exec(ctx, `
UPDATE gateway_wallet_accounts
SET balance = 100, total_spent = 0, updated_at = now()
WHERE gateway_user_id = $1::uuid
AND currency = 'resource'`, smokeGatewayUserID); err != nil {
t.Fatalf("seed wallet balance: %v", err)
}
var walletBalanceBefore float64
if err := testPool.QueryRow(ctx, `
SELECT balance::float8
FROM gateway_wallet_accounts
WHERE gateway_user_id = $1::uuid
AND currency = 'resource'`, smokeGatewayUserID).Scan(&walletBalanceBefore); err != nil {
t.Fatalf("read wallet balance before pricing task: %v", err)
}
var pricingTask struct {
Task struct {
ID string `json:"id"`
Status string `json:"status"`
BillingSummary map[string]any `json:"billingSummary"`
FinalChargeAmount float64 `json:"finalChargeAmount"`
} `json:"task"`
}
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
"model": pricingModel,
"runMode": "simulation",
"simulation": true,
"simulationDurationMs": 5,
"messages": []map[string]any{{"role": "user", "content": "priced ping"}},
}, http.StatusAccepted, &pricingTask)
if pricingTask.Task.Status != "succeeded" || !floatNear(pricingTask.Task.FinalChargeAmount, 0.028) {
t.Fatalf("custom pricing rule set should drive text billing, got task=%+v", pricingTask.Task)
}
var walletBalanceAfter float64
var walletSpentAfter float64
if err := testPool.QueryRow(ctx, `
SELECT balance::float8, total_spent::float8
FROM gateway_wallet_accounts
WHERE gateway_user_id = $1::uuid
AND currency = 'resource'`, smokeGatewayUserID).Scan(&walletBalanceAfter, &walletSpentAfter); err != nil {
t.Fatalf("read wallet balance after pricing task: %v", err)
}
if !floatNear(walletBalanceAfter, walletBalanceBefore-pricingTask.Task.FinalChargeAmount) || !floatNear(walletSpentAfter, pricingTask.Task.FinalChargeAmount) {
t.Fatalf("task billing should debit wallet balance and spent totals, before=%f after=%f spent=%f task=%+v", walletBalanceBefore, walletBalanceAfter, walletSpentAfter, pricingTask.Task)
}
var walletTransactionAmount float64
if err := testPool.QueryRow(ctx, `
SELECT amount::float8
FROM gateway_wallet_transactions
WHERE reference_type = 'gateway_task'
AND reference_id = $1
AND transaction_type = 'task_billing'`, pricingTask.Task.ID).Scan(&walletTransactionAmount); err != nil {
t.Fatalf("read task billing wallet transaction: %v", err)
}
if !floatNear(walletTransactionAmount, pricingTask.Task.FinalChargeAmount) {
t.Fatalf("task billing transaction amount=%f want=%f", walletTransactionAmount, pricingTask.Task.FinalChargeAmount)
}
rateLimitedModel := "rate-limit-smoke-" + suffixText
var rateLimitPolicySet struct {
ID string `json:"id"`
}
doJSON(t, server.URL, http.MethodPost, "/api/admin/runtime/policy-sets", loginResponse.AccessToken, map[string]any{
"policyKey": "smoke-rate-limit-" + suffixText,
"name": "Smoke Rate Limit",
"retryPolicy": map[string]any{
"enabled": false,
"maxAttempts": 1,
},
"rateLimitPolicy": map[string]any{
"rules": []map[string]any{{"metric": "rpm", "limit": 1, "windowSeconds": 60}},
},
}, http.StatusCreated, &rateLimitPolicySet)
var rateLimitPlatformModel map[string]any
doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms/"+platform.ID+"/models", loginResponse.AccessToken, map[string]any{
"canonicalModelKey": "openai:gpt-4o-mini",
"modelName": rateLimitedModel,
"modelAlias": rateLimitedModel,
"modelType": []string{"text_generate"},
"displayName": "Rate Limit Smoke",
"runtimePolicySetId": rateLimitPolicySet.ID,
"runtimePolicyOverride": map[string]any{},
}, http.StatusCreated, &rateLimitPlatformModel)
var rateLimitTaskOne struct {
Task struct {
Status string `json:"status"`
} `json:"task"`
}
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
"model": rateLimitedModel,
"runMode": "simulation",
"simulation": true,
"simulationDurationMs": 5,
"messages": []map[string]any{{"role": "user", "content": "first"}},
}, http.StatusAccepted, &rateLimitTaskOne)
if rateLimitTaskOne.Task.Status != "succeeded" {
t.Fatalf("first rate-limited task should succeed: %+v", rateLimitTaskOne.Task)
}
var rateLimitTaskTwo struct {
Task struct {
Status string `json:"status"`
ErrorCode string `json:"errorCode"`
} `json:"task"`
}
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
"model": rateLimitedModel,
"runMode": "simulation",
"simulation": true,
"simulationDurationMs": 5,
"messages": []map[string]any{{"role": "user", "content": "second"}},
}, http.StatusAccepted, &rateLimitTaskTwo)
if rateLimitTaskTwo.Task.Status != "failed" || rateLimitTaskTwo.Task.ErrorCode != "rate_limit" {
t.Fatalf("runtime policy rate limit should fail second task with rate_limit: %+v", rateLimitTaskTwo.Task)
}
videoRouteModel := "video-route-smoke-" + suffixText
var videoRoutePlatformModel map[string]any
doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms/"+platform.ID+"/models", loginResponse.AccessToken, map[string]any{
"canonicalModelKey": "openai:gpt-4o-mini",
"modelName": videoRouteModel,
"modelAlias": videoRouteModel,
"modelType": []string{"video_generate", "image_to_video"},
"displayName": "Video Route Smoke",
}, http.StatusCreated, &videoRoutePlatformModel)
var textToVideoTask struct {
Task struct {
Status string `json:"status"`
ModelType string `json:"modelType"`
} `json:"task"`
}
doJSON(t, server.URL, http.MethodPost, "/api/v1/videos/generations", apiKeyResponse.Secret, map[string]any{
"model": videoRouteModel,
"runMode": "simulation",
"simulation": true,
"simulationDurationMs": 5,
"prompt": "text to video route",
}, http.StatusAccepted, &textToVideoTask)
if textToVideoTask.Task.Status != "succeeded" || textToVideoTask.Task.ModelType != "video_generate" {
t.Fatalf("text-to-video request should use video_generate model_type: %+v", textToVideoTask.Task)
}
var imageToVideoTask struct {
Task struct {
Status string `json:"status"`
ModelType string `json:"modelType"`
} `json:"task"`
}
doJSON(t, server.URL, http.MethodPost, "/api/v1/videos/generations", apiKeyResponse.Secret, map[string]any{
"model": videoRouteModel,
"runMode": "simulation",
"simulation": true,
"simulationDurationMs": 5,
"prompt": "image to video route",
"image": "https://example.com/source.png",
}, http.StatusAccepted, &imageToVideoTask)
if imageToVideoTask.Task.Status != "succeeded" || imageToVideoTask.Task.ModelType != "image_to_video" {
t.Fatalf("image-to-video request should use image_to_video model_type: %+v", imageToVideoTask.Task)
}
failoverModel := "phase1-failover-" + suffixText
var failedPlatform struct {
ID string `json:"id"`
@ -353,7 +700,7 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil {
"canonicalModelKey": "openai:gpt-4o-mini",
"modelName": failoverModel,
"modelAlias": failoverModel,
"modelType": "chat",
"modelType": []string{"text_generate"},
"displayName": "Failover Smoke",
"retryPolicy": map[string]any{"enabled": true, "maxAttempts": 2},
}, http.StatusCreated, &platformModel)
@ -376,6 +723,142 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil {
t.Fatalf("failover task should succeed through second client: %+v", failoverTask.Task)
}
var degradePolicySet struct {
ID string `json:"id"`
}
doJSON(t, server.URL, http.MethodPost, "/api/admin/runtime/policy-sets", loginResponse.AccessToken, map[string]any{
"policyKey": "smoke-degrade-" + suffixText,
"name": "Smoke Degrade",
"degradePolicy": map[string]any{
"enabled": true,
"keywords": []string{"rate_limit"},
"cooldownSeconds": 600,
},
}, http.StatusCreated, &degradePolicySet)
degradeModel := "degrade-smoke-" + suffixText
var degradedPlatform struct {
ID string `json:"id"`
}
doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms", loginResponse.AccessToken, map[string]any{
"provider": "openai",
"platformKey": "openai-degrade-" + suffixText,
"name": "OpenAI Degrade Failure",
"baseUrl": "https://api.openai.com/v1",
"authType": "bearer",
"credentials": map[string]any{"mode": "simulation", "simulationFailure": "rate_limit"},
"priority": 30,
}, http.StatusCreated, &degradedPlatform)
var degradeSuccessPlatform struct {
ID string `json:"id"`
}
doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms", loginResponse.AccessToken, map[string]any{
"provider": "openai",
"platformKey": "openai-degrade-success-" + suffixText,
"name": "OpenAI Degrade Success",
"baseUrl": "https://api.openai.com/v1",
"authType": "bearer",
"credentials": map[string]any{"mode": "simulation"},
"priority": 40,
}, http.StatusCreated, &degradeSuccessPlatform)
for _, item := range []struct {
platformID string
runtimePolicySetID string
}{
{platformID: degradedPlatform.ID, runtimePolicySetID: degradePolicySet.ID},
{platformID: degradeSuccessPlatform.ID},
} {
var platformModel map[string]any
doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms/"+item.platformID+"/models", loginResponse.AccessToken, map[string]any{
"canonicalModelKey": "openai:gpt-4o-mini",
"modelName": degradeModel,
"modelAlias": degradeModel,
"modelType": []string{"text_generate"},
"displayName": "Degrade Smoke",
"runtimePolicySetId": item.runtimePolicySetID,
}, http.StatusCreated, &platformModel)
}
var degradeTask struct {
Task struct {
ID string `json:"id"`
Status string `json:"status"`
} `json:"task"`
}
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
"model": degradeModel,
"runMode": "simulation",
"simulation": true,
"simulationDurationMs": 5,
"messages": []map[string]any{{"role": "user", "content": "degrade please"}},
}, http.StatusAccepted, &degradeTask)
if degradeTask.Task.Status != "succeeded" {
t.Fatalf("degrade task should fail over after cooling down failed platform: %+v", degradeTask.Task)
}
var cooledDown bool
if err := testPool.QueryRow(ctx, `SELECT COALESCE(cooldown_until > now(), false) FROM integration_platforms WHERE id = $1::uuid`, degradedPlatform.ID).Scan(&cooledDown); err != nil {
t.Fatalf("read degraded platform cooldown: %v", err)
}
if !cooledDown {
t.Fatal("degrade policy should set platform cooldown_until")
}
var autoDisablePolicySet struct {
ID string `json:"id"`
}
doJSON(t, server.URL, http.MethodPost, "/api/admin/runtime/policy-sets", loginResponse.AccessToken, map[string]any{
"policyKey": "smoke-auto-disable-" + suffixText,
"name": "Smoke Auto Disable",
"autoDisablePolicy": map[string]any{
"enabled": true,
"keywords": []string{"invalid_api_key"},
"threshold": 1,
},
}, http.StatusCreated, &autoDisablePolicySet)
autoDisableModel := "auto-disable-smoke-" + suffixText
var invalidKeyPlatform struct {
ID string `json:"id"`
}
doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms", loginResponse.AccessToken, map[string]any{
"provider": "openai",
"platformKey": "openai-invalid-key-" + suffixText,
"name": "OpenAI Invalid Key",
"baseUrl": "https://api.openai.com/v1",
"authType": "bearer",
"credentials": map[string]any{"mode": "simulation", "simulationFailure": "invalid_api_key"},
"priority": 50,
}, http.StatusCreated, &invalidKeyPlatform)
var invalidKeyPlatformModel map[string]any
doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms/"+invalidKeyPlatform.ID+"/models", loginResponse.AccessToken, map[string]any{
"canonicalModelKey": "openai:gpt-4o-mini",
"modelName": autoDisableModel,
"modelAlias": autoDisableModel,
"modelType": []string{"text_generate"},
"displayName": "Auto Disable Smoke",
"runtimePolicySetId": autoDisablePolicySet.ID,
}, http.StatusCreated, &invalidKeyPlatformModel)
var autoDisableTask struct {
Task struct {
Status string `json:"status"`
ErrorCode string `json:"errorCode"`
} `json:"task"`
}
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
"model": autoDisableModel,
"runMode": "simulation",
"simulation": true,
"simulationDurationMs": 5,
"messages": []map[string]any{{"role": "user", "content": "disable please"}},
}, http.StatusAccepted, &autoDisableTask)
if autoDisableTask.Task.Status != "failed" || autoDisableTask.Task.ErrorCode != "invalid_api_key" {
t.Fatalf("auto disable task should fail with invalid_api_key: %+v", autoDisableTask.Task)
}
var invalidKeyPlatformStatus string
if err := testPool.QueryRow(ctx, `SELECT status FROM integration_platforms WHERE id = $1::uuid`, invalidKeyPlatform.ID).Scan(&invalidKeyPlatformStatus); err != nil {
t.Fatalf("read invalid key platform status: %v", err)
}
if invalidKeyPlatformStatus != "disabled" {
t.Fatalf("auto disable policy should disable platform, got %q", invalidKeyPlatformStatus)
}
var taskDetail struct {
ID string `json:"id"`
Status string `json:"status"`
@ -393,6 +876,21 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil {
if taskDetail.APIKeyName != apiKeyResponse.APIKey.Name || taskDetail.RequestID == "" || taskDetail.Usage["totalTokens"] == nil || taskDetail.FinalChargeAmount <= 0 {
t.Fatalf("task detail should expose enriched record fields: %+v", taskDetail)
}
var taskList struct {
Items []struct {
ID string `json:"id"`
Status string `json:"status"`
APIKeyName string `json:"apiKeyName"`
ModelType string `json:"modelType"`
FinalCharge float64 `json:"finalChargeAmount"`
ErrorCode string `json:"errorCode"`
ErrorMessage string `json:"errorMessage"`
} `json:"items"`
}
doJSON(t, server.URL, http.MethodGet, "/api/v1/tasks?limit=20", loginResponse.AccessToken, nil, http.StatusOK, &taskList)
if !taskListContains(taskList.Items, taskResponse.Task.ID) || !taskListContains(taskList.Items, pricingTask.Task.ID) {
t.Fatalf("task list should include persisted task records, got %+v", taskList.Items)
}
req, err := http.NewRequest(http.MethodGet, server.URL+"/api/v1/tasks/"+taskResponse.Task.ID+"/events", nil)
if err != nil {
@ -411,6 +909,9 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil {
if !bytes.Contains(body, []byte("task.progress")) {
t.Fatalf("events response should include progress events body=%s", string(body))
}
if !bytes.Contains(body, []byte("task.billing.settled")) {
t.Fatalf("events response should include billing settlement event body=%s", string(body))
}
req, err = http.NewRequest(http.MethodGet, server.URL+"/api/v1/tasks/"+failoverTask.Task.ID+"/events", nil)
if err != nil {
@ -508,3 +1009,45 @@ func doJSON(t *testing.T, baseURL string, method string, path string, token stri
}
}
}
// stringSliceContains reports whether target is present in values.
// A nil or empty slice never contains anything.
func stringSliceContains(values []string, target string) bool {
	for i := 0; i < len(values); i++ {
		if values[i] == target {
			return true
		}
	}
	return false
}
// modelListContains reports whether any decoded model entry in values has
// an ID equal to target. ModelName is not consulted.
func modelListContains(values []struct {
	ID        string `json:"id"`
	ModelName string `json:"modelName"`
}, target string) bool {
	for i := range values {
		if values[i].ID == target {
			return true
		}
	}
	return false
}
// taskListContains reports whether any decoded task list entry in values
// has an ID equal to target. Only the ID field is compared.
func taskListContains(values []struct {
	ID           string  `json:"id"`
	Status       string  `json:"status"`
	APIKeyName   string  `json:"apiKeyName"`
	ModelType    string  `json:"modelType"`
	FinalCharge  float64 `json:"finalChargeAmount"`
	ErrorCode    string  `json:"errorCode"`
	ErrorMessage string  `json:"errorMessage"`
}, target string) bool {
	found := false
	for i := range values {
		if values[i].ID == target {
			found = true
			break
		}
	}
	return found
}
// floatNear reports whether value and expected differ by strictly less
// than a fixed absolute tolerance of 0.000001.
func floatNear(value float64, expected float64) bool {
	const tolerance = 0.000001
	delta := math.Abs(value - expected)
	return delta < tolerance
}

View File

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"time"
@ -468,6 +469,10 @@ func (s *Server) estimatePricing(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusBadRequest, "model is required")
return
}
if !apiKeyScopeAllowed(user, kind) {
writeError(w, http.StatusForbidden, "api key scope does not allow this capability")
return
}
estimate, err := s.runner.Estimate(r.Context(), kind, model, body, user)
if err != nil {
if errors.Is(err, store.ErrNoModelCandidate) {
@ -509,6 +514,10 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler {
writeError(w, http.StatusBadRequest, "model is required")
return
}
if !apiKeyScopeAllowed(user, kind) {
writeError(w, http.StatusForbidden, "api key scope does not allow this capability")
return
}
task, err := s.store.CreateTask(r.Context(), store.CreateTaskInput{
Kind: kind,
@ -567,6 +576,36 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler {
})
}
// apiKeyScopeAllowed reports whether the caller may use the capability
// identified by the task kind.
//
// The check is skipped (returns true) when there is no user, when the user
// carries no API key ID, or when the key has no scopes recorded. Otherwise
// each scope is trimmed and lowercased and access is granted when it is
// "*", "all", the scope required for kind, or — for the "chat" scope only —
// one of the aliases "text" / "text_generate".
func apiKeyScopeAllowed(user *auth.User, kind string) bool {
	if user == nil {
		return true
	}
	if strings.TrimSpace(user.APIKeyID) == "" {
		return true
	}
	if len(user.APIKeyScopes) == 0 {
		return true
	}

	needed := scopeForTaskKind(kind)
	for _, raw := range user.APIKeyScopes {
		switch granted := strings.TrimSpace(strings.ToLower(raw)); granted {
		case "*", "all", needed:
			return true
		case "text", "text_generate":
			// Text scopes are accepted as aliases for chat only.
			if needed == "chat" {
				return true
			}
		}
	}
	return false
}
// scopeForTaskKind maps a task kind to the API key scope name that guards it.
// Kinds without a dedicated mapping are returned unchanged, so the kind
// itself acts as the required scope.
func scopeForTaskKind(kind string) string {
	if kind == "chat.completions" || kind == "responses" {
		return "chat"
	}
	if kind == "images.generations" || kind == "images.edits" {
		return "image"
	}
	if kind == "videos.generations" {
		return "video"
	}
	return kind
}
func statusFromRunError(err error) int {
switch {
case errors.Is(err, store.ErrNoModelCandidate):
@ -578,6 +617,30 @@ func statusFromRunError(err error) int {
}
}
// listTasks handles the task listing request for the authenticated user.
//
// It responds 401 when no user is attached to the request context, 400 when
// the optional "limit" query parameter is not a positive integer, and 500
// when the store lookup fails. On success it writes {"items": tasks} with
// status 200. The limit defaults to 50 when the parameter is absent or blank.
func (s *Server) listTasks(w http.ResponseWriter, r *http.Request) {
	user, authenticated := auth.UserFromContext(r.Context())
	if !authenticated {
		writeError(w, http.StatusUnauthorized, "unauthorized")
		return
	}

	limit := 50
	rawLimit := strings.TrimSpace(r.URL.Query().Get("limit"))
	if rawLimit != "" {
		requested, convErr := strconv.Atoi(rawLimit)
		if convErr != nil || requested <= 0 {
			writeError(w, http.StatusBadRequest, "invalid limit")
			return
		}
		limit = requested
	}

	items, err := s.store.ListTasks(r.Context(), user, limit)
	if err != nil {
		s.logger.Error("list tasks failed", "error", err)
		writeError(w, http.StatusInternalServerError, "list tasks failed")
		return
	}

	payload := map[string]any{"items": items}
	writeJSON(w, http.StatusOK, payload)
}
func boolValue(body map[string]any, key string) bool {
value, _ := body[key].(bool)
return value

View File

@ -101,6 +101,7 @@ func NewServer(cfg config.Config, db *store.Store, logger *slog.Logger) http.Han
mux.Handle("POST /api/v1/images/generations", server.auth.Require(auth.PermissionBasic, server.createTask("images.generations", false)))
mux.Handle("POST /api/v1/images/edits", server.auth.Require(auth.PermissionBasic, server.createTask("images.edits", false)))
mux.Handle("POST /api/v1/videos/generations", server.auth.Require(auth.PermissionBasic, server.createTask("videos.generations", false)))
mux.Handle("GET /api/v1/tasks", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.listTasks)))
mux.Handle("GET /api/v1/tasks/{taskID}", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.getTask)))
mux.Handle("GET /api/v1/tasks/{taskID}/events", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskEvents)))
mux.Handle("POST /chat/completions", server.auth.Require(auth.PermissionBasic, server.createTask("chat.completions", true)))

View File

@ -18,15 +18,36 @@ func (s *Service) rateLimitReservations(ctx context.Context, user *auth.User, ca
}
func effectiveRateLimitPolicy(candidate store.RuntimeModelCandidate) map[string]any {
if hasRules(candidate.ModelRateLimitPolicy) {
return candidate.ModelRateLimitPolicy
policy := candidate.PlatformRateLimitPolicy
if hasRules(candidate.RuntimeRateLimitPolicy) {
policy = mergeMap(policy, candidate.RuntimeRateLimitPolicy)
}
if hasRules(candidate.PlatformRateLimitPolicy) {
return candidate.PlatformRateLimitPolicy
if nested, ok := candidate.RuntimePolicyOverride["rateLimitPolicy"].(map[string]any); ok && len(nested) > 0 {
policy = mergeMap(policy, nested)
}
if hasRules(candidate.ModelRateLimitPolicy) {
policy = mergeMap(policy, candidate.ModelRateLimitPolicy)
}
if hasRules(policy) {
return policy
}
return nil
}
// effectiveRetryPolicy builds the retry policy used for a candidate by
// starting from the platform retry policy and applying, in order, the runtime
// policy set's retry policy, the "retryPolicy" object nested in the runtime
// policy override, and finally the model's own retry policy. Each non-empty
// layer is combined with the accumulated result through mergeMap; empty
// layers are skipped.
func effectiveRetryPolicy(candidate store.RuntimeModelCandidate) map[string]any {
	merged := candidate.PlatformRetryPolicy

	if runtimePolicy := candidate.RuntimeRetryPolicy; len(runtimePolicy) > 0 {
		merged = mergeMap(merged, runtimePolicy)
	}

	// A missing key or a value of another type yields a nil map, which is
	// skipped by the length check just like an empty object.
	overridePolicy, _ := candidate.RuntimePolicyOverride["retryPolicy"].(map[string]any)
	if len(overridePolicy) > 0 {
		merged = mergeMap(merged, overridePolicy)
	}

	if modelPolicy := candidate.ModelRetryPolicy; len(modelPolicy) > 0 {
		merged = mergeMap(merged, modelPolicy)
	}

	return merged
}
func reservationsFromPolicy(scopeType string, scopeKey string, policy map[string]any, body map[string]any) []store.RateLimitReservation {
if scopeKey == "" || !hasRules(policy) {
return nil

View File

@ -16,7 +16,7 @@ type EstimateResult struct {
}
func (s *Service) Estimate(ctx context.Context, kind string, model string, body map[string]any, user *auth.User) (EstimateResult, error) {
candidates, err := s.store.ListModelCandidates(ctx, model, modelTypeFromKind(kind), user)
candidates, err := s.store.ListModelCandidates(ctx, model, modelTypeFromKind(kind, body), user)
if err != nil {
return EstimateResult{}, err
}
@ -38,7 +38,7 @@ func (s *Service) Estimate(ctx context.Context, kind string, model string, body
}
func (s *Service) billings(ctx context.Context, user *auth.User, kind string, body map[string]any, candidate store.RuntimeModelCandidate, response clients.Response, simulated bool) []any {
config := effectiveBillingConfig(candidate)
config := s.effectiveBillingConfig(ctx, candidate)
discount := effectiveDiscount(ctx, s.store, user, candidate)
if isTextGenerationKind(kind) {
inputTokens := response.Usage.InputTokens
@ -74,11 +74,21 @@ func (s *Service) billings(ctx context.Context, user *auth.User, kind string, bo
return []any{billingLine(candidate, resource, unit, count, roundPrice(amount), discount, simulated)}
}
func effectiveBillingConfig(candidate store.RuntimeModelCandidate) map[string]any {
func (s *Service) effectiveBillingConfig(ctx context.Context, candidate store.RuntimeModelCandidate) map[string]any {
base := candidate.BaseBillingConfig
if ruleSetID := firstNonEmptyString(candidate.BasePricingRuleSetID, candidate.PlatformPricingRuleSetID); ruleSetID != "" {
if ruleSetConfig, err := s.store.PricingRuleSetBillingConfig(ctx, ruleSetID); err == nil && len(ruleSetConfig) > 0 {
base = ruleSetConfig
}
}
if len(candidate.BillingConfig) > 0 {
base = candidate.BillingConfig
}
if candidate.ModelPricingRuleSetID != "" {
if ruleSetConfig, err := s.store.PricingRuleSetBillingConfig(ctx, candidate.ModelPricingRuleSetID); err == nil && len(ruleSetConfig) > 0 {
base = ruleSetConfig
}
}
if len(candidate.BillingConfigOverride) > 0 {
base = mergeMap(base, candidate.BillingConfigOverride)
}

View File

@ -0,0 +1,84 @@
package runner
import (
"context"
"strings"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
// applyCandidateFailurePolicies evaluates the candidate's auto-disable and
// degrade policies against a failed attempt.
//
// The error code (from clients.ErrorCode) and the error message are matched
// against each policy's keywords. When the auto-disable policy matches and its
// "threshold" is at most 1, the candidate's platform is disabled. When the
// degrade policy matches, the platform is put into cooldown. A task event is
// emitted after each store update that succeeds; store and emit errors are
// intentionally not propagated, so this never fails the calling attempt.
func (s *Service) applyCandidateFailurePolicies(ctx context.Context, taskID string, candidate store.RuntimeModelCandidate, cause error, simulated bool) {
	code := clients.ErrorCode(cause)
	var message string
	if cause != nil {
		message = cause.Error()
	}

	disablePolicy := effectiveRuntimePolicy(candidate.AutoDisablePolicy, candidate.RuntimePolicyOverride, "autoDisablePolicy")
	shouldDisable := failurePolicyMatches(disablePolicy, code, message) && intFromPolicy(disablePolicy, "threshold") <= 1
	if shouldDisable {
		if disableErr := s.store.DisableCandidatePlatform(ctx, candidate.PlatformID); disableErr == nil {
			disabledPayload := map[string]any{
				"platformId":      candidate.PlatformID,
				"platformModelId": candidate.PlatformModelID,
				"code":            code,
			}
			_ = s.emit(ctx, taskID, "task.policy.auto_disabled", "running", "auto_disable", 0.48, "candidate platform disabled by failure policy", disabledPayload, simulated)
		}
	}

	cooldownPolicy := effectiveRuntimePolicy(candidate.DegradePolicy, candidate.RuntimePolicyOverride, "degradePolicy")
	if !failurePolicyMatches(cooldownPolicy, code, message) {
		return
	}

	// NOTE(review): the event reports the raw policy value; the store call may
	// substitute its own default for non-positive values — confirm.
	cooldownSeconds := intFromPolicy(cooldownPolicy, "cooldownSeconds")
	if cooldownErr := s.store.CooldownCandidatePlatform(ctx, candidate.PlatformID, cooldownSeconds); cooldownErr != nil {
		return
	}

	degradedPayload := map[string]any{
		"platformId":      candidate.PlatformID,
		"platformModelId": candidate.PlatformModelID,
		"cooldownSeconds": cooldownSeconds,
		"code":            code,
	}
	_ = s.emit(ctx, taskID, "task.policy.degraded", "running", "degrade", 0.5, "candidate platform cooled down by failure policy", degradedPayload, simulated)
}
// effectiveRuntimePolicy returns base combined (via mergeMap) with the object
// stored under key in override. When override has no such key, the value is
// not an object, or the object is empty, base is returned unchanged.
func effectiveRuntimePolicy(base map[string]any, override map[string]any, key string) map[string]any {
	nested, isObject := override[key].(map[string]any)
	if !isObject || len(nested) == 0 {
		return base
	}
	return mergeMap(base, nested)
}
// failurePolicyMatches reports whether a failure described by code and message
// triggers policy.
//
// A policy matches only when it is non-empty, its "enabled" flag is true, and
// at least one of its non-blank "keywords" occurs in the text "code message".
// The comparison is case-insensitive and ignores surrounding whitespace on
// both the keywords and the combined text. A policy without keywords never
// matches.
func failurePolicyMatches(policy map[string]any, code string, message string) bool {
	if len(policy) == 0 {
		return false
	}
	if !boolFromMap(policy, "enabled") {
		return false
	}

	keywords := stringListFromPolicy(policy, "keywords")
	if len(keywords) == 0 {
		return false
	}

	haystack := strings.ToLower(strings.TrimSpace(code + " " + message))
	for _, rawKeyword := range keywords {
		needle := strings.ToLower(strings.TrimSpace(rawKeyword))
		if needle == "" {
			continue
		}
		if strings.Contains(haystack, needle) {
			return true
		}
	}
	return false
}
// stringListFromPolicy returns the non-blank string entries stored under key
// in values.
//
// Two representations are accepted: []any (the shape produced by decoding JSON
// into map[string]any), from which non-string elements are dropped, and
// []string. In both cases entries that are empty or whitespace-only are
// skipped and the result is a freshly allocated slice, so callers never share
// backing storage with the policy map. Entries are returned untrimmed. Any
// other value type, including a missing key or a nil map, yields nil.
func stringListFromPolicy(values map[string]any, key string) []string {
	switch raw := values[key].(type) {
	case []any:
		out := make([]string, 0, len(raw))
		for _, item := range raw {
			if text, ok := item.(string); ok && strings.TrimSpace(text) != "" {
				out = append(out, text)
			}
		}
		return out
	case []string:
		// Apply the same blank filtering as the []any branch and copy, rather
		// than handing back the caller's slice as-is.
		out := make([]string, 0, len(raw))
		for _, text := range raw {
			if strings.TrimSpace(text) != "" {
				out = append(out, text)
			}
		}
		return out
	default:
		return nil
	}
}

View File

@ -50,8 +50,8 @@ func (s *Service) ExecuteStream(ctx context.Context, task store.GatewayTask, use
}
func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *auth.User, onDelta clients.StreamDelta) (Result, error) {
modelType := modelTypeFromKind(task.Kind)
body := normalizeRequest(task.Kind, task.Request)
modelType := modelTypeFromKind(task.Kind, body)
if err := validateRequest(task.Kind, body); err != nil {
failed, finishErr := s.failTask(ctx, task.ID, "bad_request", err.Error(), task.RunMode == "simulation", err)
if finishErr != nil {
@ -102,6 +102,17 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
if finishErr != nil {
return Result{}, finishErr
}
if settleErr := s.store.SettleTaskBilling(ctx, finished); settleErr != nil {
return Result{}, settleErr
}
if finished.FinalChargeAmount > 0 {
if err := s.emit(ctx, task.ID, "task.billing.settled", "succeeded", "billing", 0.98, "task billing settled", map[string]any{
"amount": finished.FinalChargeAmount,
"currency": stringFromAny(record.BillingSummary["currency"]),
}, isSimulation(task, candidate)); err != nil {
return Result{}, err
}
}
if err := s.emit(ctx, task.ID, "task.completed", "succeeded", "completed", 1, "task completed", map[string]any{
"result": response.Result,
"billings": billings,
@ -226,6 +237,7 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user
ErrorMessage: err.Error(),
})
_ = s.emit(ctx, task.ID, "task.attempt.failed", "running", "attempt_failed", 0.45, err.Error(), map[string]any{"attempt": attemptNo, "retryable": retryable, "requestId": requestID, "metrics": metrics}, simulated)
s.applyCandidateFailurePolicies(ctx, task.ID, candidate, err, simulated)
return clients.Response{}, err
}
uploadedResult, err := s.uploadGeneratedAssets(ctx, response.Result)
@ -317,7 +329,7 @@ func (s *Service) emit(ctx context.Context, taskID string, eventType string, sta
return nil
}
func modelTypeFromKind(kind string) string {
func modelTypeFromKind(kind string, body map[string]any) string {
switch kind {
case "chat.completions", "responses":
return "text_generate"
@ -327,12 +339,30 @@ func modelTypeFromKind(kind string) string {
}
return "image_generate"
case "videos.generations":
if videoRequestHasReferenceImage(body) {
return "image_to_video"
}
return "video_generate"
default:
return "task"
}
}
// videoRequestHasReferenceImage reports whether a video request body carries a
// reference image under any of the recognised key spellings (snake_case and
// camelCase variants of image, image URL, reference image, first frame and
// last frame). A nil body never has one. Presence is decided by hasAnyString
// for each key.
func videoRequestHasReferenceImage(body map[string]any) bool {
	if body == nil {
		return false
	}

	referenceKeys := [...]string{
		"image", "images", "image_url", "imageUrl", "image_urls", "imageUrls",
		"reference_image", "referenceImage", "first_frame", "firstFrame", "last_frame", "lastFrame",
	}
	for i := range referenceKeys {
		if hasAnyString(body, referenceKeys[i]) {
			return true
		}
	}
	return false
}
// isTextGenerationKind reports whether kind is one of the text generation task
// kinds, "chat.completions" or "responses". The match is exact.
func isTextGenerationKind(kind string) bool {
	switch kind {
	case "chat.completions", "responses":
		return true
	default:
		return false
	}
}
@ -345,10 +375,8 @@ func isSimulation(task store.GatewayTask, candidate store.RuntimeModelCandidate)
}
func retryEnabled(candidate store.RuntimeModelCandidate) bool {
if enabled, ok := candidate.ModelRetryPolicy["enabled"].(bool); ok {
return enabled
}
if enabled, ok := candidate.PlatformRetryPolicy["enabled"].(bool); ok {
policy := effectiveRetryPolicy(candidate)
if enabled, ok := policy["enabled"].(bool); ok {
return enabled
}
return true
@ -360,10 +388,7 @@ func maxAttemptsForCandidates(candidates []store.RuntimeModelCandidate) int {
}
maxAttempts := len(candidates)
for _, candidate := range candidates {
if value := intFromPolicy(candidate.ModelRetryPolicy, "maxAttempts"); value > 0 && value < maxAttempts {
maxAttempts = value
}
if value := intFromPolicy(candidate.PlatformRetryPolicy, "maxAttempts"); value > 0 && value < maxAttempts {
if value := intFromPolicy(effectiveRetryPolicy(candidate), "maxAttempts"); value > 0 && value < maxAttempts {
maxAttempts = value
}
}

View File

@ -59,7 +59,7 @@ func (s *Store) ListBaseModels(ctx context.Context) ([]BaseModel, error) {
rows, err := s.pool.Query(ctx, `
SELECT `+baseModelColumns+`
FROM base_model_catalog
ORDER BY provider_key ASC, model_type ASC, canonical_model_key ASC`)
ORDER BY provider_key ASC, canonical_model_key ASC`)
if err != nil {
return nil, err
}
@ -84,7 +84,7 @@ func (s *Store) CreateBaseModel(ctx context.Context, input BaseModelInput) (Base
runtimePolicyOverride, _ := json.Marshal(emptyObjectIfNil(input.RuntimePolicyOverride))
metadata, _ := json.Marshal(emptyObjectIfNil(input.Metadata))
defaultSnapshot, _ := json.Marshal(emptyObjectIfNil(input.DefaultSnapshot))
modelType := primaryString(input.ModelType, "text_generate")
modelType, _ := json.Marshal(input.ModelType)
return scanBaseModel(s.pool.QueryRow(ctx, `
INSERT INTO base_model_catalog (
@ -94,7 +94,7 @@ INSERT INTO base_model_catalog (
)
VALUES (
(SELECT id FROM model_catalog_providers WHERE provider_key = $1 OR provider_code = $1 LIMIT 1),
$1, $2, $3, $4, $5, $6, $7, $8,
$1, $2, $3, $4::jsonb, $5, $6, $7, $8,
COALESCE(NULLIF($9, '')::uuid, (SELECT id FROM model_pricing_rule_sets WHERE rule_set_key = 'default-multimodal-v1' LIMIT 1)),
COALESCE(NULLIF($10, '')::uuid, (SELECT id FROM model_runtime_policy_sets WHERE policy_key = 'default-runtime-v1' LIMIT 1)),
$11, $12, NULLIF($13, ''), NULLIF($14::jsonb, '{}'::jsonb), $15, $16
@ -103,7 +103,7 @@ RETURNING `+baseModelColumns,
input.ProviderKey,
input.CanonicalModelKey,
input.ProviderModelName,
modelType,
string(modelType),
input.ModelAlias,
capabilities,
billingConfig,
@ -127,7 +127,7 @@ func (s *Store) UpdateBaseModel(ctx context.Context, id string, input BaseModelI
runtimePolicyOverride, _ := json.Marshal(emptyObjectIfNil(input.RuntimePolicyOverride))
metadata, _ := json.Marshal(emptyObjectIfNil(input.Metadata))
defaultSnapshot, _ := json.Marshal(emptyObjectIfNil(input.DefaultSnapshot))
modelType := primaryString(input.ModelType, "text_generate")
modelType, _ := json.Marshal(input.ModelType)
return scanBaseModel(s.pool.QueryRow(ctx, `
UPDATE base_model_catalog
@ -135,7 +135,7 @@ SET provider_id = (SELECT id FROM model_catalog_providers WHERE provider_key = $
provider_key = $2,
canonical_model_key = $3,
provider_model_name = $4,
model_type = $5,
model_type = $5::jsonb,
display_name = $6,
capabilities = $7,
base_billing_config = $8,
@ -156,7 +156,7 @@ RETURNING `+baseModelColumns,
input.ProviderKey,
input.CanonicalModelKey,
input.ProviderModelName,
modelType,
string(modelType),
input.ModelAlias,
capabilities,
billingConfig,
@ -195,7 +195,7 @@ SET provider_id = (SELECT id FROM model_catalog_providers WHERE provider_key = C
provider_key = COALESCE($2::text, provider_key),
canonical_model_key = COALESCE($3::text, canonical_model_key),
provider_model_name = COALESCE($4::text, provider_model_name),
model_type = COALESCE($5::text, model_type),
model_type = COALESCE($5::jsonb, model_type),
display_name = COALESCE($6::text, display_name),
capabilities = COALESCE($7::jsonb, capabilities),
base_billing_config = COALESCE($8::jsonb, base_billing_config),
@ -214,7 +214,7 @@ RETURNING `+baseModelColumns,
stringFromSnapshot(snapshot, "providerKey"),
stringFromSnapshot(snapshot, "canonicalModelKey"),
stringFromSnapshot(snapshot, "providerModelName"),
stringFromSnapshot(snapshot, "modelType"),
jsonStringListFromSnapshot(snapshot, "modelType"),
stringFromSnapshot(snapshot, "modelAlias", "displayName"),
jsonFromSnapshot(snapshot, "capabilities"),
jsonFromSnapshot(snapshot, "baseBillingConfig"),
@ -242,9 +242,10 @@ SET provider_id = (
canonical_model_key = COALESCE(NULLIF(default_snapshot->>'canonicalModelKey', ''), canonical_model_key),
provider_model_name = COALESCE(NULLIF(default_snapshot->>'providerModelName', ''), provider_model_name),
model_type = COALESCE(NULLIF(CASE
WHEN jsonb_typeof(default_snapshot->'modelType') = 'array' THEN default_snapshot->'modelType'->>0
ELSE default_snapshot->>'modelType'
END, ''), model_type),
WHEN jsonb_typeof(default_snapshot->'modelType') = 'array' THEN default_snapshot->'modelType'
WHEN COALESCE(default_snapshot->>'modelType', '') <> '' THEN jsonb_build_array(default_snapshot->>'modelType')
ELSE NULL
END, '[]'::jsonb), model_type),
display_name = COALESCE(NULLIF(COALESCE(default_snapshot->>'modelAlias', default_snapshot->>'displayName'), ''), display_name),
capabilities = COALESCE(default_snapshot->'capabilities', capabilities),
base_billing_config = COALESCE(default_snapshot->'baseBillingConfig', base_billing_config),
@ -292,7 +293,7 @@ func scanBaseModelRows(rows pgx.Rows) ([]BaseModel, error) {
func scanBaseModel(scanner baseModelScanner) (BaseModel, error) {
var item BaseModel
var modelType string
var modelType []byte
var modelAlias string
var capabilities []byte
var billingConfig []byte
@ -330,7 +331,7 @@ func scanBaseModel(scanner baseModelScanner) (BaseModel, error) {
item.RuntimePolicyOverride = decodeObject(runtimePolicyOverride)
item.Metadata = decodeObject(metadata)
item.DefaultSnapshot = decodeObject(defaultSnapshot)
item.ModelType = baseModelTypes(item.Capabilities, item.Metadata, modelType)
item.ModelType = decodeStringArray(modelType)
item.ModelAlias = modelAlias
item.DisplayName = modelAlias
return item, nil
@ -420,14 +421,22 @@ func jsonFromSnapshot(snapshot map[string]any, key string) any {
return string(raw)
}
func baseModelTypes(capabilities map[string]any, metadata map[string]any, fallback string) StringList {
values := make([]string, 0)
values = append(values, stringListFromAny(capabilities["originalTypes"])...)
values = append(values, stringListFromAny(metadata["originalTypes"])...)
if fallback != "" {
values = append(values, fallback)
func jsonStringListFromSnapshot(snapshot map[string]any, key string) any {
values := stringListFromAny(snapshot[key])
if len(values) == 0 {
if value, ok := snapshot[key].(string); ok {
values = []string{value}
}
}
return uniqueStringList(values)
normalized := uniqueStringList(values)
if len(normalized) == 0 {
return nil
}
raw, err := json.Marshal(normalized)
if err != nil {
return nil
}
return string(raw)
}
func stringListFromAny(value any) []string {
@ -451,16 +460,39 @@ func uniqueStringList(values []string) StringList {
out := make([]string, 0, len(values))
seen := map[string]bool{}
for _, value := range values {
value = strings.TrimSpace(value)
if value == "" || seen[value] {
continue
for _, normalized := range modelTypeAliases(value) {
normalized = strings.TrimSpace(normalized)
if normalized == "" || seen[normalized] {
continue
}
seen[normalized] = true
out = append(out, normalized)
}
seen[value] = true
out = append(out, value)
}
return out
}
// normalizeModelTypeList returns the result of passing values through
// uniqueStringList; it exists to give that operation a model-type-specific
// name at call sites.
func normalizeModelTypeList(values []string) StringList {
return uniqueStringList(values)
}
// modelTypeAliases maps a model type label to the model type names it stands
// for. The label is compared after trimming surrounding whitespace:
//
//	chat, text, responses       -> text_generate
//	image                       -> image_generate, image_edit
//	images.generations          -> image_generate
//	images.edits                -> image_edit
//	video, videos.generations   -> video_generate
//
// Any other label is returned as a single-element slice holding the original,
// untrimmed value.
func modelTypeAliases(value string) []string {
	var aliases []string
	switch label := strings.TrimSpace(value); label {
	case "chat", "text", "responses":
		aliases = []string{"text_generate"}
	case "image":
		aliases = []string{"image_generate", "image_edit"}
	case "images.generations":
		aliases = []string{"image_generate"}
	case "images.edits":
		aliases = []string{"image_edit"}
	case "video", "videos.generations":
		aliases = []string{"video_generate"}
	default:
		aliases = []string{value}
	}
	return aliases
}
func primaryString(values []string, fallback string) string {
for _, value := range values {
if value = strings.TrimSpace(value); value != "" {

View File

@ -18,27 +18,25 @@ SELECT p.id::text, p.platform_key, p.name, p.provider,
COALESCE(p.dynamic_priority, p.priority) AS effective_priority,
m.id::text, COALESCE(m.base_model_id::text, ''), COALESCE(b.canonical_model_key, ''),
COALESCE(b.provider_model_name, ''), m.model_name, COALESCE(m.model_alias, ''),
m.model_type, m.display_name, m.capabilities, m.capability_override,
$2 AS requested_model_type, m.display_name, m.capabilities, m.capability_override,
COALESCE(b.base_billing_config, '{}'::jsonb), m.billing_config, m.billing_config_override,
m.pricing_mode, COALESCE(m.discount_factor, 0)::float8, COALESCE(m.pricing_rule_set_id::text, ''),
m.permission_config, m.retry_policy, m.rate_limit_policy, COALESCE(m.runtime_policy_set_id::text, COALESCE(b.runtime_policy_set_id::text, '')),
COALESCE(NULLIF(m.runtime_policy_override, '{}'::jsonb), b.runtime_policy_override, '{}'::jsonb)
COALESCE(b.pricing_rule_set_id::text, ''),
m.permission_config, m.retry_policy, m.rate_limit_policy, COALESCE(m.runtime_policy_set_id::text, COALESCE(b.runtime_policy_set_id::text, '')),
COALESCE(NULLIF(m.runtime_policy_override, '{}'::jsonb), b.runtime_policy_override, '{}'::jsonb),
COALESCE(rp.retry_policy, '{}'::jsonb), COALESCE(rp.rate_limit_policy, '{}'::jsonb),
COALESCE(rp.auto_disable_policy, '{}'::jsonb), COALESCE(rp.degrade_policy, '{}'::jsonb)
FROM platform_models m
JOIN integration_platforms p ON p.id = m.platform_id
LEFT JOIN model_catalog_providers cp ON cp.provider_key = p.provider OR cp.provider_code = p.provider
LEFT JOIN base_model_catalog b ON b.id = m.base_model_id
LEFT JOIN model_runtime_policy_sets rp ON rp.id = COALESCE(m.runtime_policy_set_id, b.runtime_policy_set_id)
LEFT JOIN runtime_client_states s
ON s.client_id = p.platform_key || ':' || m.model_type || ':' || m.model_name
ON s.client_id = p.platform_key || ':' || $2 || ':' || m.model_name
WHERE p.status = 'enabled'
AND p.deleted_at IS NULL
AND m.enabled = true
AND (
m.model_type = $2
OR ($2 = 'text_generate' AND m.model_type IN ('chat', 'responses', 'text'))
OR ($2 = 'image_generate' AND m.model_type IN ('image', 'images.generations'))
OR ($2 = 'image_edit' AND m.model_type IN ('images.edits'))
OR ($2 = 'video_generate' AND m.model_type IN ('video', 'videos.generations', 'video_generate', 'text_to_video', 'image_to_video', 'omni_video', 'video_edit', 'video_reference', 'video_first_last_frame'))
)
AND m.model_type @> jsonb_build_array($2)
AND (p.cooldown_until IS NULL OR p.cooldown_until <= now())
AND (
m.model_name = $1
@ -73,6 +71,10 @@ ORDER BY effective_priority ASC,
var modelRetryPolicy []byte
var modelRateLimitPolicy []byte
var runtimePolicyOverride []byte
var runtimeRetryPolicy []byte
var runtimeRateLimitPolicy []byte
var autoDisablePolicy []byte
var degradePolicy []byte
if err := rows.Scan(
&item.PlatformID,
&item.PlatformKey,
@ -105,11 +107,16 @@ ORDER BY effective_priority ASC,
&item.PricingMode,
&item.DiscountFactor,
&item.ModelPricingRuleSetID,
&item.BasePricingRuleSetID,
&permissionConfig,
&modelRetryPolicy,
&modelRateLimitPolicy,
&item.RuntimePolicySetID,
&runtimePolicyOverride,
&runtimeRetryPolicy,
&runtimeRateLimitPolicy,
&autoDisablePolicy,
&degradePolicy,
); err != nil {
return nil, err
}
@ -126,6 +133,10 @@ ORDER BY effective_priority ASC,
item.ModelRetryPolicy = decodeObject(modelRetryPolicy)
item.ModelRateLimitPolicy = decodeObject(modelRateLimitPolicy)
item.RuntimePolicyOverride = decodeObject(runtimePolicyOverride)
item.RuntimeRetryPolicy = decodeObject(runtimeRetryPolicy)
item.RuntimeRateLimitPolicy = decodeObject(runtimeRateLimitPolicy)
item.AutoDisablePolicy = decodeObject(autoDisablePolicy)
item.DegradePolicy = decodeObject(degradePolicy)
item.ClientID = fmt.Sprintf("%s:%s:%s", item.PlatformKey, item.ModelType, item.ModelName)
item.QueueKey = item.ClientID
items = append(items, item)

View File

@ -17,7 +17,7 @@ type modelCatalogSnapshot struct {
ProviderKey string
CanonicalModelKey string
ProviderModelName string
ModelType string
ModelType StringList
DisplayName string
Capabilities map[string]any
BaseBillingConfig map[string]any
@ -83,9 +83,13 @@ func (s *Store) createPlatformModel(ctx context.Context, q platformModelQuerier,
if err != nil && !IsNotFound(err) {
return PlatformModel{}, err
}
if input.ModelType == "" {
if len(input.ModelType) == 0 {
input.ModelType = base.ModelType
}
input.ModelType = normalizeModelTypeList(input.ModelType)
if len(input.ModelType) == 0 {
input.ModelType = StringList{"text_generate"}
}
if input.ModelName == "" {
input.ModelName = base.ProviderModelName
}
@ -104,11 +108,12 @@ func (s *Store) createPlatformModel(ctx context.Context, q platformModelQuerier,
if len(billingConfig) == 0 {
billingConfig = mergeObjects(base.BaseBillingConfig, input.BillingConfigOverride)
}
explicitRuntimePolicySetID := strings.TrimSpace(input.RuntimePolicySetID)
rateLimitPolicy := input.RateLimitPolicy
if len(rateLimitPolicy) == 0 {
if len(rateLimitPolicy) == 0 && explicitRuntimePolicySetID == "" {
rateLimitPolicy = base.DefaultRateLimitPolicy
}
runtimePolicySetID := strings.TrimSpace(input.RuntimePolicySetID)
runtimePolicySetID := explicitRuntimePolicySetID
if runtimePolicySetID == "" {
runtimePolicySetID = base.RuntimePolicySetID
}
@ -119,6 +124,7 @@ func (s *Store) createPlatformModel(ctx context.Context, q platformModelQuerier,
capabilityOverrideJSON, _ := json.Marshal(emptyObjectIfNil(input.CapabilityOverride))
capabilitiesJSON, _ := json.Marshal(emptyObjectIfNil(capabilities))
modelTypeJSON, _ := json.Marshal(input.ModelType)
billingOverrideJSON, _ := json.Marshal(emptyObjectIfNil(input.BillingConfigOverride))
billingJSON, _ := json.Marshal(emptyObjectIfNil(billingConfig))
permissionJSON, _ := json.Marshal(emptyObjectIfNil(input.PermissionConfig))
@ -144,6 +150,7 @@ func (s *Store) createPlatformModel(ctx context.Context, q platformModelQuerier,
var retryPolicyBytes []byte
var rateLimitPolicyBytes []byte
var runtimePolicyOverrideBytes []byte
var modelTypeBytes []byte
err = q.QueryRow(ctx, `
INSERT INTO platform_models (
platform_id, base_model_id, model_name, model_alias, model_type, display_name,
@ -152,12 +159,12 @@ INSERT INTO platform_models (
runtime_policy_set_id, runtime_policy_override, enabled
)
VALUES (
$1::uuid, $2::uuid, $3, NULLIF($4, ''), $5, $6,
$1::uuid, $2::uuid, $3, NULLIF($4, ''), $5::jsonb, $6,
$7::jsonb, $8::jsonb, $9, $10::numeric,
NULLIF($11, '')::uuid, $12::jsonb, $13::jsonb, $14::jsonb, $15::jsonb, $16::jsonb,
NULLIF($17, '')::uuid, $18::jsonb, true
)
ON CONFLICT (platform_id, model_name, model_type) DO UPDATE
ON CONFLICT (platform_id, model_name) DO UPDATE
SET base_model_id = EXCLUDED.base_model_id,
model_alias = EXCLUDED.model_alias,
display_name = EXCLUDED.display_name,
@ -185,7 +192,7 @@ RETURNING id::text, platform_id::text, COALESCE(base_model_id::text, ''), model_
baseID,
input.ModelName,
input.ModelAlias,
input.ModelType,
string(modelTypeJSON),
input.DisplayName,
string(capabilityOverrideJSON),
string(capabilitiesJSON),
@ -205,7 +212,7 @@ RETURNING id::text, platform_id::text, COALESCE(base_model_id::text, ''), model_
&model.BaseModelID,
&model.ModelName,
&model.ModelAlias,
&model.ModelType,
&modelTypeBytes,
&model.DisplayName,
&capabilityOverrideBytes,
&capabilitiesBytes,
@ -228,6 +235,7 @@ RETURNING id::text, platform_id::text, COALESCE(base_model_id::text, ''), model_
}
model.CapabilityOverride = decodeObject(capabilityOverrideBytes)
model.Capabilities = decodeObject(capabilitiesBytes)
model.ModelType = decodeStringArray(modelTypeBytes)
model.BillingConfigOverride = decodeObject(billingOverrideBytes)
model.BillingConfig = decodeObject(billingBytes)
model.PermissionConfig = decodeObject(permissionBytes)
@ -265,6 +273,7 @@ func (s *Store) lookupBaseModel(ctx context.Context, q platformModelQuerier, id
var billingConfig []byte
var rateLimitPolicy []byte
var runtimePolicyOverride []byte
var modelTypeBytes []byte
err := q.QueryRow(ctx, `
SELECT id::text, provider_key, canonical_model_key, provider_model_name, model_type, display_name,
capabilities, base_billing_config, default_rate_limit_policy,
@ -279,7 +288,7 @@ LIMIT 1`, strings.TrimSpace(id), strings.TrimSpace(canonicalKey), strings.TrimSp
&item.ProviderKey,
&item.CanonicalModelKey,
&item.ProviderModelName,
&item.ModelType,
&modelTypeBytes,
&item.DisplayName,
&capabilities,
&billingConfig,
@ -297,6 +306,7 @@ LIMIT 1`, strings.TrimSpace(id), strings.TrimSpace(canonicalKey), strings.TrimSp
item.BaseBillingConfig = decodeObject(billingConfig)
item.DefaultRateLimitPolicy = decodeObject(rateLimitPolicy)
item.RuntimePolicyOverride = decodeObject(runtimePolicyOverride)
item.ModelType = normalizeModelTypeList(decodeStringArray(modelTypeBytes))
return item, nil
}

View File

@ -134,7 +134,7 @@ type PlatformModel struct {
PlatformName string `json:"platformName,omitempty"`
ModelName string `json:"modelName"`
ModelAlias string `json:"modelAlias,omitempty"`
ModelType string `json:"modelType"`
ModelType StringList `json:"modelType"`
DisplayName string `json:"displayName"`
CapabilityOverride map[string]any `json:"capabilityOverride,omitempty"`
Capabilities map[string]any `json:"capabilities,omitempty"`
@ -390,6 +390,8 @@ type GatewayTask struct {
ResponseDurationMS int64 `json:"responseDurationMs"`
FinishedAt string `json:"finishedAt,omitempty"`
Error string `json:"error,omitempty"`
ErrorCode string `json:"errorCode,omitempty"`
ErrorMessage string `json:"errorMessage,omitempty"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
@ -404,6 +406,7 @@ request, status, COALESCE(result, '{}'::jsonb), COALESCE(billings, '[]'::jsonb),
COALESCE(usage, '{}'::jsonb), COALESCE(metrics, '{}'::jsonb), COALESCE(billing_summary, '{}'::jsonb),
COALESCE(final_charge_amount, 0)::float8, COALESCE(response_started_at::text, ''),
COALESCE(response_finished_at::text, ''), COALESCE(response_duration_ms, 0), COALESCE(error, ''),
COALESCE(error_code, ''), COALESCE(error_message, ''),
created_at, updated_at, COALESCE(finished_at::text, '')`
type TaskEvent struct {
@ -713,6 +716,7 @@ ORDER BY m.model_type ASC, m.model_name ASC`, args...)
var retryPolicy []byte
var rateLimitPolicy []byte
var runtimePolicyOverride []byte
var modelTypeBytes []byte
if err := rows.Scan(
&model.ID,
&model.PlatformID,
@ -721,7 +725,7 @@ ORDER BY m.model_type ASC, m.model_name ASC`, args...)
&model.PlatformName,
&model.ModelName,
&model.ModelAlias,
&model.ModelType,
&modelTypeBytes,
&model.DisplayName,
&capabilityOverride,
&capabilities,
@ -743,6 +747,7 @@ ORDER BY m.model_type ASC, m.model_name ASC`, args...)
}
model.CapabilityOverride = decodeObject(capabilityOverride)
model.Capabilities = decodeObject(capabilities)
model.ModelType = decodeStringArray(modelTypeBytes)
model.BillingConfigOverride = decodeObject(billingConfigOverride)
model.BillingConfig = decodeObject(billingConfig)
model.PermissionConfig = decodeObject(permissionConfig)
@ -1231,7 +1236,7 @@ func (s *Store) VerifyLocalAPIKey(ctx context.Context, secret string) (*auth.Use
return nil, auth.ErrUnauthorized
}
rows, err := s.pool.Query(ctx, `
SELECT k.id::text, k.key_hash, k.key_prefix, k.name, COALESCE(k.user_group_id::text, u.default_user_group_id::text, ''),
SELECT k.id::text, k.key_hash, k.key_prefix, k.name, k.scopes, COALESCE(k.user_group_id::text, u.default_user_group_id::text, ''),
u.id::text, u.username, u.roles, COALESCE(u.gateway_tenant_id::text, ''),
COALESCE(u.tenant_id, ''), COALESCE(u.tenant_key, '')
FROM gateway_api_keys k
@ -1252,6 +1257,7 @@ WHERE k.key_prefix = $1
var hash string
var keyPrefix string
var keyName string
var scopesBytes []byte
var userGroupID string
var gatewayUserID string
var username string
@ -1259,7 +1265,7 @@ WHERE k.key_prefix = $1
var gatewayTenantID string
var tenantID string
var tenantKey string
if err := rows.Scan(&apiKeyID, &hash, &keyPrefix, &keyName, &userGroupID, &gatewayUserID, &username, &rolesBytes, &gatewayTenantID, &tenantID, &tenantKey); err != nil {
if err := rows.Scan(&apiKeyID, &hash, &keyPrefix, &keyName, &scopesBytes, &userGroupID, &gatewayUserID, &username, &rolesBytes, &gatewayTenantID, &tenantID, &tenantKey); err != nil {
return nil, err
}
if bcrypt.CompareHashAndPassword([]byte(hash), []byte(secret)) != nil {
@ -1281,6 +1287,7 @@ WHERE k.key_prefix = $1
APIKeyID: apiKeyID,
APIKeyName: keyName,
APIKeyPrefix: keyPrefix,
APIKeyScopes: decodeStringArray(scopesBytes),
}, nil
}
if err := rows.Err(); err != nil {
@ -1661,6 +1668,8 @@ func scanGatewayTask(scanner taskScanner) (GatewayTask, error) {
&task.ResponseFinishedAt,
&task.ResponseDurationMS,
&task.Error,
&task.ErrorCode,
&task.ErrorMessage,
&task.CreatedAt,
&task.UpdatedAt,
&task.FinishedAt,

View File

@ -176,6 +176,63 @@ func (s *Store) DeletePricingRuleSet(ctx context.Context, id string) error {
return nil
}
// PricingRuleSetBillingConfig converts the active pricing rules of a rule set
// into a billing config map.
//
// A blank id returns (nil, nil) without querying. Otherwise every rule with
// status 'active' is read in ascending priority (then resource type) order and
// written into the map keyed by resource type:
//
//   - "text_input" / "text_output" set the flat keys "textInputPer1k" /
//     "textOutputPer1k" to the rule's base price;
//   - "image", "image_edit" and "video" set the flat keys "imageBase",
//     "editBase" and "videoBase", and additionally a nested entry under the
//     resource type built by pricingResourceConfig;
//   - any other resource type only gets the nested entry.
//
// When several rules write the same key, the row read last wins. Query, scan
// and iteration errors are returned as-is.
func (s *Store) PricingRuleSetBillingConfig(ctx context.Context, id string) (map[string]any, error) {
	ruleSetID := strings.TrimSpace(id)
	if ruleSetID == "" {
		return nil, nil
	}

	rows, queryErr := s.pool.Query(ctx, `
SELECT resource_type, base_price::float8, dynamic_weight
FROM model_pricing_rules
WHERE rule_set_id = $1::uuid
AND status = 'active'
ORDER BY priority ASC, resource_type ASC`, ruleSetID)
	if queryErr != nil {
		return nil, queryErr
	}
	defer rows.Close()

	config := map[string]any{}
	for rows.Next() {
		var (
			resource  string
			price     float64
			weightRaw []byte
		)
		if scanErr := rows.Scan(&resource, &price, &weightRaw); scanErr != nil {
			return nil, scanErr
		}
		weight := decodeObject(weightRaw)

		// Text resources only have a flat per-1k price; the media resources
		// get a flat base price and then fall through to the nested entry.
		switch resource {
		case "text_input":
			config["textInputPer1k"] = price
			continue
		case "text_output":
			config["textOutputPer1k"] = price
			continue
		case "image":
			config["imageBase"] = price
		case "image_edit":
			config["editBase"] = price
		case "video":
			config["videoBase"] = price
		}
		config[resource] = pricingResourceConfig(price, weight)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return config, nil
}
// pricingResourceConfig builds the nested billing entry for one resource: it
// always contains "basePrice", and contains "dynamicWeight" only when the
// supplied weight map is non-empty.
func pricingResourceConfig(basePrice float64, dynamicWeight map[string]any) map[string]any {
	if len(dynamicWeight) == 0 {
		return map[string]any{"basePrice": basePrice}
	}
	return map[string]any{
		"basePrice":     basePrice,
		"dynamicWeight": dynamicWeight,
	}
}
func insertPricingRules(ctx context.Context, tx pgx.Tx, ruleSetID string, defaultCurrency string, rules []PricingRuleInput) error {
for index, rule := range rules {
rule = normalizePricingRule(rule, index, defaultCurrency)

View File

@ -107,6 +107,33 @@ func (s *Store) DeleteRuntimePolicySet(ctx context.Context, id string) error {
return nil
}
// DisableCandidatePlatform sets the status of the integration platform with
// the given ID to 'disabled' and refreshes its updated_at timestamp.
//
// A blank platformID is treated as "nothing to do" and returns nil. The ID is
// trimmed once and the trimmed value is what gets sent to the database, so the
// value that passed validation is the same value the ::uuid cast receives.
// Updating zero rows (unknown ID) is not reported as an error.
func (s *Store) DisableCandidatePlatform(ctx context.Context, platformID string) error {
	platformID = strings.TrimSpace(platformID)
	if platformID == "" {
		return nil
	}
	_, err := s.pool.Exec(ctx, `
UPDATE integration_platforms
SET status = 'disabled',
updated_at = now()
WHERE id = $1::uuid`, platformID)
	return err
}
// CooldownCandidatePlatform sets cooldown_until on the integration platform
// with the given ID to now + cooldownSeconds and refreshes updated_at.
//
// A blank platformID is treated as "nothing to do" and returns nil. The ID is
// trimmed once and the trimmed value is what gets sent to the database, so the
// value that passed validation is the same value the ::uuid cast receives.
// A non-positive cooldownSeconds falls back to 300 seconds. Updating zero rows
// (unknown ID) is not reported as an error.
func (s *Store) CooldownCandidatePlatform(ctx context.Context, platformID string, cooldownSeconds int) error {
	platformID = strings.TrimSpace(platformID)
	if platformID == "" {
		return nil
	}
	if cooldownSeconds <= 0 {
		// Default cooldown window when the caller supplies no usable value.
		cooldownSeconds = 300
	}
	_, err := s.pool.Exec(ctx, `
UPDATE integration_platforms
SET cooldown_until = now() + ($2::int * interval '1 second'),
updated_at = now()
WHERE id = $1::uuid`, platformID, cooldownSeconds)
	return err
}
func scanRuntimePolicySet(scanner runtimePolicyScanner) (RuntimePolicySet, error) {
var item RuntimePolicySet
var rateLimitPolicy []byte

View File

@ -16,7 +16,7 @@ type CreatePlatformModelInput struct {
CanonicalModelKey string `json:"canonicalModelKey"`
ModelName string `json:"modelName"`
ModelAlias string `json:"modelAlias"`
ModelType string `json:"modelType"`
ModelType StringList `json:"modelType"`
DisplayName string `json:"displayName"`
CapabilityOverride map[string]any `json:"capabilityOverride"`
Capabilities map[string]any `json:"capabilities"`
@ -64,12 +64,17 @@ type RuntimeModelCandidate struct {
PermissionConfig map[string]any
PricingMode string
DiscountFactor float64
BasePricingRuleSetID string
PlatformPricingRuleSetID string
ModelPricingRuleSetID string
ModelRetryPolicy map[string]any
ModelRateLimitPolicy map[string]any
RuntimePolicySetID string
RuntimePolicyOverride map[string]any
RuntimeRetryPolicy map[string]any
RuntimeRateLimitPolicy map[string]any
AutoDisablePolicy map[string]any
DegradePolicy map[string]any
ClientID string
QueueKey string
}

View File

@ -3,9 +3,67 @@ package store
import (
"context"
"encoding/json"
"strings"
"time"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
)
// ListTasks returns the caller's most recent gateway tasks, newest first.
//
// limit is clamped to [1, 100]; a non-positive limit becomes 50. Tasks are
// matched by gateway user ID when one can be resolved for the caller, and by
// the trimmed user.ID otherwise. When the caller carries an API key ID the
// result is further restricted to tasks created with that key.
// ErrLocalUserRequired is returned when neither identifier is available
// (including a nil user). The returned slice is non-nil on success.
func (s *Store) ListTasks(ctx context.Context, user *auth.User, limit int) ([]GatewayTask, error) {
	switch {
	case limit <= 0:
		limit = 50
	case limit > 100:
		limit = 100
	}

	gatewayUserID := localGatewayUserID(user)
	var apiKeyID, userID string
	if user != nil {
		apiKeyID = strings.TrimSpace(user.APIKeyID)
		userID = strings.TrimSpace(user.ID)
	}
	if gatewayUserID == "" && userID == "" {
		return nil, ErrLocalUserRequired
	}

	// $1 wins over $2: the user_id branch is only considered when no gateway
	// user ID was supplied. $3 is an optional API-key filter.
	rows, err := s.pool.Query(ctx, `
SELECT `+gatewayTaskColumns+`
FROM gateway_tasks
WHERE (
(
NULLIF($1, '')::uuid IS NOT NULL
AND gateway_user_id = NULLIF($1, '')::uuid
)
OR (
NULLIF($1, '')::uuid IS NULL
AND NULLIF($2, '') IS NOT NULL
AND user_id = $2
)
)
AND (
NULLIF($3, '') IS NULL
OR api_key_id = $3
)
ORDER BY created_at DESC
LIMIT $4`, gatewayUserID, userID, apiKeyID, limit)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	tasks := make([]GatewayTask, 0)
	for rows.Next() {
		task, scanErr := scanGatewayTask(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		tasks = append(tasks, task)
	}
	return tasks, rows.Err()
}
func (s *Store) MarkTaskRunning(ctx context.Context, taskID string, modelType string, normalizedRequest map[string]any) error {
normalizedJSON, _ := json.Marshal(emptyObjectIfNil(normalizedRequest))
_, err := s.pool.Exec(ctx, `
@ -140,6 +198,103 @@ WHERE id = $1::uuid`,
return s.GetTask(ctx, input.TaskID)
}
// SettleTaskBilling debits task.FinalChargeAmount from the wallet of the
// task's gateway user and records a single 'task_billing' ledger transaction.
//
// Nothing is done (nil is returned) when the charge is not positive or the
// task has no gateway user ID. The wallet currency comes from
// task.BillingSummary["currency"]; a blank or "mixed" value falls back to
// "resource". The wallet account is created on demand.
//
// The operation is idempotent per task: the ledger row is keyed by
// billingIdempotencyKey(task.ID), and a task that already has such a row is
// left untouched.
//
// Concurrency notes:
//   - The account row is locked with FOR UPDATE before the idempotency check,
//     so concurrent settlements for the same account are serialized and the
//     later one sees the earlier one's ledger row.
//   - A unique violation (SQLSTATE 23505) on the ledger insert is still
//     treated as "already settled". The error is returned from the
//     transaction callback so pgx rolls the transaction back (undoing the
//     balance update), and only afterwards is it translated into a nil
//     result. Returning nil from inside the callback would make pgx COMMIT an
//     aborted transaction, which surfaces as pgx.ErrTxCommitRollback.
func (s *Store) SettleTaskBilling(ctx context.Context, task GatewayTask) error {
	if task.FinalChargeAmount <= 0 || strings.TrimSpace(task.GatewayUserID) == "" {
		return nil
	}
	currency := strings.TrimSpace(taskBillingString(task.BillingSummary["currency"]))
	if currency == "" || currency == "mixed" {
		currency = "resource"
	}
	metadata, err := json.Marshal(map[string]any{
		"taskId":         task.ID,
		"kind":           task.Kind,
		"model":          task.Model,
		"resolvedModel":  task.ResolvedModel,
		"billings":       task.Billings,
		"billingSummary": task.BillingSummary,
	})
	if err != nil {
		// An empty payload would be rejected by the ::jsonb cast anyway; fail
		// before opening a transaction.
		return err
	}
	idempotencyKey := billingIdempotencyKey(task.ID)

	alreadySettled := false
	err = pgx.BeginFunc(ctx, s.pool, func(tx pgx.Tx) error {
		// Make sure the wallet account exists; an existing one is left as is.
		if _, execErr := tx.Exec(ctx, `
INSERT INTO gateway_wallet_accounts (
gateway_tenant_id, gateway_user_id, tenant_id, tenant_key, user_id, currency
)
VALUES (NULLIF($1, '')::uuid, $2::uuid, NULLIF($3, ''), NULLIF($4, ''), NULLIF($5, ''), $6)
ON CONFLICT (gateway_user_id, currency) DO NOTHING`,
			task.GatewayTenantID, task.GatewayUserID, task.TenantID, task.TenantKey, task.UserID, currency); execErr != nil {
			return execErr
		}

		// Lock the account row first so the idempotency check below cannot
		// race with another settlement for the same account.
		var accountID string
		var balanceBefore float64
		var gatewayTenantID string
		if scanErr := tx.QueryRow(ctx, `
SELECT id::text, balance::float8, COALESCE(gateway_tenant_id::text, '')
FROM gateway_wallet_accounts
WHERE gateway_user_id = $1::uuid
AND currency = $2
FOR UPDATE`, task.GatewayUserID, currency).Scan(&accountID, &balanceBefore, &gatewayTenantID); scanErr != nil {
			return scanErr
		}

		var exists bool
		if scanErr := tx.QueryRow(ctx, `
SELECT EXISTS (
SELECT 1
FROM gateway_wallet_transactions t
JOIN gateway_wallet_accounts a ON a.id = t.account_id
WHERE a.gateway_user_id = $1::uuid
AND a.currency = $2
AND t.idempotency_key = $3
)`, task.GatewayUserID, currency, idempotencyKey).Scan(&exists); scanErr != nil {
			return scanErr
		}
		if exists {
			// Already billed; committing the (at most account-creating)
			// transaction is harmless.
			return nil
		}

		amount := roundMoney(task.FinalChargeAmount)
		balanceAfter := roundMoney(balanceBefore - amount)
		if _, execErr := tx.Exec(ctx, `
UPDATE gateway_wallet_accounts
SET balance = $2,
total_spent = total_spent + $3,
updated_at = now()
WHERE id = $1::uuid`, accountID, balanceAfter, amount); execErr != nil {
			return execErr
		}

		_, insertErr := tx.Exec(ctx, `
INSERT INTO gateway_wallet_transactions (
account_id, gateway_tenant_id, gateway_user_id, direction, transaction_type,
amount, balance_before, balance_after, idempotency_key, reference_type, reference_id, metadata
)
VALUES (
$1::uuid, NULLIF($2, '')::uuid, $3::uuid, 'debit', 'task_billing',
$4, $5, $6, $7, 'gateway_task', $8, $9::jsonb
)`,
			accountID, firstNonEmpty(gatewayTenantID, task.GatewayTenantID), task.GatewayUserID, amount, roundMoney(balanceBefore), balanceAfter, idempotencyKey, task.ID, string(metadata))
		if pgErr, ok := insertErr.(*pgconn.PgError); ok && pgErr.Code == "23505" {
			// Duplicate ledger row: remember it, but still return the error so
			// the transaction (including the balance update) is rolled back.
			alreadySettled = true
		}
		return insertErr
	})
	if alreadySettled {
		return nil
	}
	return err
}
// billingIdempotencyKey returns the ledger idempotency key used for the
// billing transaction of the given task: "task:<taskID>:billing".
func billingIdempotencyKey(taskID string) string {
	const (
		prefix = "task:"
		suffix = ":billing"
	)
	return prefix + taskID + suffix
}
// roundMoney rounds value to six decimal places, rounding halves away from
// zero. Negative inputs are rounded by magnitude and then negated again.
func roundMoney(value float64) float64 {
	sign := 1.0
	if value < 0 {
		sign, value = -1.0, -value
	}
	// Shift six decimals left, add one half and truncate toward zero.
	return sign * (float64(int64(value*1000000+0.5)) / 1000000)
}
// taskBillingString returns value unchanged when it holds a string and the
// empty string for every other dynamic type, including nil.
func taskBillingString(value any) string {
	switch text := value.(type) {
	case string:
		return text
	default:
		return ""
	}
}
func (s *Store) FinishTaskFailure(ctx context.Context, input FinishTaskFailureInput) (GatewayTask, error) {
metricsJSON, _ := json.Marshal(emptyObjectIfNil(input.Metrics))
if _, err := s.pool.Exec(ctx, `

View File

@ -376,7 +376,7 @@ CREATE TABLE IF NOT EXISTS platform_models (
base_model_id uuid REFERENCES base_model_catalog(id) ON DELETE SET NULL,
model_name text NOT NULL,
model_alias text,
model_type text NOT NULL,
model_type jsonb NOT NULL DEFAULT '[]'::jsonb,
display_name text NOT NULL DEFAULT '',
capability_override jsonb NOT NULL DEFAULT '{}'::jsonb,
capabilities jsonb NOT NULL DEFAULT '{}'::jsonb,
@ -391,14 +391,17 @@ CREATE TABLE IF NOT EXISTS platform_models (
enabled boolean NOT NULL DEFAULT true,
created_at timestamptz NOT NULL DEFAULT now(),
updated_at timestamptz NOT NULL DEFAULT now(),
UNIQUE(platform_id, model_name, model_type)
UNIQUE(platform_id, model_name)
);
CREATE INDEX IF NOT EXISTS idx_platform_models_base
ON platform_models(base_model_id);
CREATE INDEX IF NOT EXISTS idx_platform_models_lookup
ON platform_models(model_type, model_name, enabled);
ON platform_models(model_name, enabled);
CREATE INDEX IF NOT EXISTS idx_platform_models_model_type
ON platform_models USING gin(model_type);
CREATE INDEX IF NOT EXISTS idx_platform_models_alias
ON platform_models(model_alias);

View File

@ -323,7 +323,7 @@ CREATE TABLE IF NOT EXISTS platform_models (
base_model_id uuid REFERENCES base_model_catalog(id) ON DELETE SET NULL,
model_name text NOT NULL,
model_alias text,
model_type text NOT NULL,
model_type jsonb NOT NULL DEFAULT '[]'::jsonb,
display_name text NOT NULL DEFAULT '',
capability_override jsonb NOT NULL DEFAULT '{}'::jsonb,
capabilities jsonb NOT NULL DEFAULT '{}'::jsonb,
@ -338,7 +338,7 @@ CREATE TABLE IF NOT EXISTS platform_models (
enabled boolean NOT NULL DEFAULT true,
created_at timestamptz NOT NULL DEFAULT now(),
updated_at timestamptz NOT NULL DEFAULT now(),
UNIQUE(platform_id, model_name, model_type)
UNIQUE(platform_id, model_name)
);
ALTER TABLE IF EXISTS platform_models
@ -636,7 +636,10 @@ CREATE INDEX IF NOT EXISTS idx_gateway_recharge_orders_user
CREATE INDEX IF NOT EXISTS idx_platform_models_base
ON platform_models(base_model_id);
CREATE INDEX IF NOT EXISTS idx_platform_models_lookup
ON platform_models(model_type, model_name, enabled);
ON platform_models(model_name, enabled);
CREATE INDEX IF NOT EXISTS idx_platform_models_model_type
ON platform_models USING gin(model_type);
CREATE INDEX IF NOT EXISTS idx_platform_models_alias
ON platform_models(model_alias);
CREATE INDEX IF NOT EXISTS idx_platform_models_capabilities

View File

@ -199,7 +199,19 @@ INSERT INTO platform_models (
platform_id, base_model_id, model_name, model_alias, model_type, display_name,
capabilities, pricing_mode, billing_config, retry_policy, rate_limit_policy, enabled
)
SELECT p.id, b.id, b.provider_model_name, b.canonical_model_key, b.model_type, b.display_name,
SELECT p.id, b.id, b.provider_model_name, b.canonical_model_key,
CASE b.model_type
WHEN 'chat' THEN '["text_generate"]'::jsonb
WHEN 'text' THEN '["text_generate"]'::jsonb
WHEN 'responses' THEN '["text_generate"]'::jsonb
WHEN 'image' THEN '["image_generate","image_edit"]'::jsonb
WHEN 'images.generations' THEN '["image_generate"]'::jsonb
WHEN 'images.edits' THEN '["image_edit"]'::jsonb
WHEN 'video' THEN '["video_generate"]'::jsonb
WHEN 'videos.generations' THEN '["video_generate"]'::jsonb
ELSE jsonb_build_array(b.model_type)
END,
b.display_name,
b.capabilities, 'inherit_discount', b.base_billing_config,
'{"enabled":true,"maxAttempts":2}'::jsonb,
b.default_rate_limit_policy,
@ -207,7 +219,7 @@ SELECT p.id, b.id, b.provider_model_name, b.canonical_model_key, b.model_type, b
FROM integration_platforms p
JOIN base_model_catalog b ON b.provider_key = p.provider
WHERE p.platform_key IN ('openai-simulation', 'gemini-simulation')
ON CONFLICT (platform_id, model_name, model_type) DO UPDATE
ON CONFLICT (platform_id, model_name) DO UPDATE
SET base_model_id = EXCLUDED.base_model_id,
model_alias = EXCLUDED.model_alias,
display_name = EXCLUDED.display_name,

View File

@ -0,0 +1,138 @@
-- Converts platform_models.model_type to a jsonb array of type names and
-- collapses rows that share (platform_id, model_name) into a single row.
-- The whole conversion is skipped when the column is already jsonb.
DO $$
DECLARE
column_type text;
BEGIN
-- Read the column's current data type from the catalog.
SELECT data_type
INTO column_type
FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = 'platform_models'
AND column_name = 'model_type';
-- NOTE(review): column_type stays NULL when the column is missing, which also
-- enters this branch; confirm the column always exists when this runs.
IF column_type IS DISTINCT FROM 'jsonb' THEN
-- Remove the lookup index and the unique constraint whose name includes
-- model_type before the column is replaced.
DROP INDEX IF EXISTS idx_platform_models_lookup;
ALTER TABLE platform_models
DROP CONSTRAINT IF EXISTS platform_models_platform_id_model_name_model_type_key;
-- Stage the converted values in a temporary column.
ALTER TABLE platform_models
ADD COLUMN IF NOT EXISTS model_type_next jsonb NOT NULL DEFAULT '[]'::jsonb;
-- Map each legacy type name to its new list of type names; names without a
-- mapping are kept as a one-element array.
UPDATE platform_models
SET model_type_next = CASE trim(model_type)
WHEN 'chat' THEN '["text_generate"]'::jsonb
WHEN 'text' THEN '["text_generate"]'::jsonb
WHEN 'responses' THEN '["text_generate"]'::jsonb
WHEN 'image' THEN '["image_generate","image_edit"]'::jsonb
WHEN 'images.generations' THEN '["image_generate"]'::jsonb
WHEN 'images.edits' THEN '["image_edit"]'::jsonb
WHEN 'video' THEN '["video_generate"]'::jsonb
WHEN 'videos.generations' THEN '["video_generate"]'::jsonb
ELSE jsonb_build_array(trim(model_type))
END;
-- Within every (platform_id, model_name) group the earliest row (created_at,
-- then id) is the one to keep; give it the sorted, de-duplicated union of the
-- staged types of all rows in the group.
WITH ranked AS (
SELECT id,
first_value(id) OVER (
PARTITION BY platform_id, model_name
ORDER BY created_at ASC, id ASC
) AS keep_id,
row_number() OVER (
PARTITION BY platform_id, model_name
ORDER BY created_at ASC, id ASC
) AS row_number
FROM platform_models
),
merged AS (
SELECT ranked.keep_id,
jsonb_agg(DISTINCT type_value ORDER BY type_value) AS model_type
FROM ranked
JOIN platform_models model_row ON model_row.id = ranked.id
CROSS JOIN LATERAL jsonb_array_elements_text(model_row.model_type_next) AS type_item(type_value)
GROUP BY ranked.keep_id
)
UPDATE platform_models target
SET model_type_next = merged.model_type
FROM merged
WHERE target.id = merged.keep_id;
-- Re-point 'platform_model' access rules from the later duplicates to the row
-- that is kept, using the same ranking as above.
WITH ranked AS (
SELECT id,
first_value(id) OVER (
PARTITION BY platform_id, model_name
ORDER BY created_at ASC, id ASC
) AS keep_id,
row_number() OVER (
PARTITION BY platform_id, model_name
ORDER BY created_at ASC, id ASC
) AS row_number
FROM platform_models
)
UPDATE gateway_access_rules rules
SET resource_id = ranked.keep_id
FROM ranked
WHERE ranked.row_number > 1
AND rules.resource_type = 'platform_model'
AND rules.resource_id = ranked.id;
-- Delete every row after the first one in each (platform_id, model_name) group.
WITH ranked AS (
SELECT id,
row_number() OVER (
PARTITION BY platform_id, model_name
ORDER BY created_at ASC, id ASC
) AS row_number
FROM platform_models
)
DELETE FROM platform_models model_row
USING ranked
WHERE model_row.id = ranked.id
AND ranked.row_number > 1;
-- Swap the staged jsonb column in under the original column name.
ALTER TABLE platform_models DROP COLUMN model_type;
ALTER TABLE platform_models RENAME COLUMN model_type_next TO model_type;
END IF;
END $$;
-- Force every model_type value into array shape: arrays are kept, a JSON
-- string is wrapped into a one-element array, anything else becomes [].
UPDATE platform_models
SET model_type = CASE
WHEN jsonb_typeof(model_type) = 'array' THEN model_type
WHEN jsonb_typeof(model_type) = 'string' THEN jsonb_build_array(model_type #>> '{}')
ELSE '[]'::jsonb
END;
-- Rewrite each array element from its legacy name to the new type name, drop
-- blank entries, and store the result sorted and de-duplicated. Rows whose
-- array ends up with no elements are set to [].
-- NOTE(review): 'image' has no mapping here, so an array such as
-- ["chat","image"] keeps the 'image' element; only the exact value ["image"]
-- is expanded by the statement that follows. Confirm that is intended.
UPDATE platform_models
SET model_type = COALESCE((
SELECT jsonb_agg(DISTINCT normalized_type ORDER BY normalized_type)
FROM jsonb_array_elements_text(platform_models.model_type) AS item(model_type_value)
CROSS JOIN LATERAL (
VALUES
(CASE item.model_type_value
WHEN 'chat' THEN 'text_generate'
WHEN 'text' THEN 'text_generate'
WHEN 'responses' THEN 'text_generate'
WHEN 'images.generations' THEN 'image_generate'
WHEN 'images.edits' THEN 'image_edit'
WHEN 'video' THEN 'video_generate'
WHEN 'videos.generations' THEN 'video_generate'
ELSE item.model_type_value
END)
) AS normalized(normalized_type)
WHERE trim(normalized_type) <> ''
), '[]'::jsonb);
-- A model typed exactly ["image"] becomes both image generation and image edit.
UPDATE platform_models
SET model_type = '["image_generate","image_edit"]'::jsonb
WHERE model_type = '["image"]'::jsonb;
-- Pin the final column contract: non-null with an empty-array default.
ALTER TABLE platform_models
ALTER COLUMN model_type SET NOT NULL,
ALTER COLUMN model_type SET DEFAULT '[]'::jsonb;
-- Recreate the lookup index without model_type, add a GIN index on the jsonb
-- column, and enforce one row per (platform_id, model_name).
DROP INDEX IF EXISTS idx_platform_models_lookup;
CREATE INDEX IF NOT EXISTS idx_platform_models_lookup
ON platform_models(model_name, enabled);
CREATE INDEX IF NOT EXISTS idx_platform_models_model_type
ON platform_models USING gin(model_type);
CREATE UNIQUE INDEX IF NOT EXISTS idx_platform_models_platform_model_name
ON platform_models(platform_id, model_name);

View File

@ -0,0 +1,11 @@
-- Clears rate_limit_policy on platform models for which all of the following
-- hold: the model references a runtime policy set that differs from its base
-- model's, its rate_limit_policy still equals the base model's default, and
-- the referenced runtime policy set has a non-empty rate_limit_policy.
-- NOTE(review): presumably this lets the runtime policy set's limits apply
-- instead of the copied base default; confirm against the policy resolver.
UPDATE platform_models m
SET rate_limit_policy = '{}'::jsonb,
updated_at = now()
FROM base_model_catalog b,
model_runtime_policy_sets rp
WHERE m.base_model_id = b.id
AND m.runtime_policy_set_id = rp.id
AND m.runtime_policy_set_id IS NOT NULL
AND m.runtime_policy_set_id IS DISTINCT FROM b.runtime_policy_set_id
AND m.rate_limit_policy = b.default_rate_limit_policy
AND rp.rate_limit_policy <> '{}'::jsonb;

View File

@ -0,0 +1,108 @@
-- Converts base_model_catalog.model_type to a jsonb array of type names.
-- The whole conversion is skipped when the column is already jsonb.
DO $$
DECLARE
column_type text;
BEGIN
-- Read the column's current data type from the catalog.
SELECT data_type
INTO column_type
FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = 'base_model_catalog'
AND column_name = 'model_type';
IF column_type IS DISTINCT FROM 'jsonb' THEN
DROP INDEX IF EXISTS idx_base_model_catalog_type;
-- Stage the converted values in a temporary column.
ALTER TABLE base_model_catalog
ADD COLUMN IF NOT EXISTS model_type_next jsonb NOT NULL DEFAULT '[]'::jsonb;
-- source_types: pick the list of raw type names per row, preferring
-- capabilities.originalTypes, then metadata.originalTypes, and finally the
-- single value of the existing model_type column.
WITH source_types AS (
SELECT id,
CASE
WHEN jsonb_typeof(capabilities->'originalTypes') = 'array' THEN capabilities->'originalTypes'
WHEN jsonb_typeof(metadata->'originalTypes') = 'array' THEN metadata->'originalTypes'
ELSE jsonb_build_array(model_type)
END AS model_types
FROM base_model_catalog
),
-- normalized_types: expand every raw name into zero or more new names.
-- 'image' yields both image_generate and image_edit; names without a mapping
-- are kept unless they are blank.
normalized_types AS (
SELECT source_types.id, normalized.model_type
FROM source_types
CROSS JOIN LATERAL jsonb_array_elements_text(source_types.model_types) AS raw_type(model_type)
CROSS JOIN LATERAL (
SELECT 'text_generate' AS model_type
WHERE raw_type.model_type IN ('chat', 'text', 'responses')
UNION ALL SELECT 'image_generate'
WHERE raw_type.model_type = 'image'
UNION ALL SELECT 'image_edit'
WHERE raw_type.model_type = 'image'
UNION ALL SELECT 'image_generate'
WHERE raw_type.model_type = 'images.generations'
UNION ALL SELECT 'image_edit'
WHERE raw_type.model_type = 'images.edits'
UNION ALL SELECT 'video_generate'
WHERE raw_type.model_type IN ('video', 'videos.generations')
UNION ALL SELECT raw_type.model_type
WHERE trim(raw_type.model_type) <> ''
AND raw_type.model_type NOT IN (
'chat', 'text', 'responses', 'image', 'images.generations', 'images.edits', 'video', 'videos.generations'
)
) AS normalized
)
-- Store the sorted, de-duplicated names per row; rows without any name get [].
UPDATE base_model_catalog target
SET model_type_next = COALESCE((
SELECT jsonb_agg(DISTINCT normalized_types.model_type ORDER BY normalized_types.model_type)
FROM normalized_types
WHERE normalized_types.id = target.id
), '[]'::jsonb);
-- Swap the staged jsonb column in under the original column name.
ALTER TABLE base_model_catalog DROP COLUMN model_type;
ALTER TABLE base_model_catalog RENAME COLUMN model_type_next TO model_type;
END IF;
END $$;
-- Force every model_type value into array shape: arrays are kept, a JSON
-- string is wrapped into a one-element array, anything else becomes [].
UPDATE base_model_catalog
SET model_type = CASE
WHEN jsonb_typeof(model_type) = 'array' THEN model_type
WHEN jsonb_typeof(model_type) = 'string' THEN jsonb_build_array(model_type #>> '{}')
ELSE '[]'::jsonb
END;
-- Overwrite the 'modelType' key of every non-empty system default_snapshot
-- with the row's current model_type array.
UPDATE base_model_catalog
SET default_snapshot = default_snapshot || jsonb_build_object('modelType', model_type)
WHERE catalog_type = 'system'
AND COALESCE(default_snapshot, '{}'::jsonb) <> '{}'::jsonb;
-- Pin the final column contract: non-null with an empty-array default.
ALTER TABLE base_model_catalog
ALTER COLUMN model_type SET NOT NULL,
ALTER COLUMN model_type SET DEFAULT '[]'::jsonb;
-- idx_base_model_catalog_type is created on (provider_key, status) only;
-- model_type itself is covered by the separate GIN index.
CREATE INDEX IF NOT EXISTS idx_base_model_catalog_type
ON base_model_catalog(provider_key, status);
CREATE INDEX IF NOT EXISTS idx_base_model_catalog_model_type
ON base_model_catalog USING gin(model_type);
-- Trigger function: when a row has catalog_type = 'system' and no
-- default_snapshot, build the snapshot from the row's own columns. Rows that
-- already carry a snapshot, and non-system rows, pass through unchanged.
-- Only the function is (re)defined here; the CREATE TRIGGER statement that
-- binds it is not part of this script section.
CREATE OR REPLACE FUNCTION fill_system_base_model_default_snapshot()
RETURNS trigger AS $$
BEGIN
IF NEW.catalog_type = 'system' AND NEW.default_snapshot IS NULL THEN
NEW.default_snapshot = jsonb_build_object(
'providerKey', NEW.provider_key,
'canonicalModelKey', NEW.canonical_model_key,
'providerModelName', NEW.provider_model_name,
-- model_type is stored as-is, i.e. as the jsonb value held by the column.
'modelType', NEW.model_type,
-- The snapshot's modelAlias is taken from display_name.
'modelAlias', NEW.display_name,
'capabilities', NEW.capabilities,
'baseBillingConfig', NEW.base_billing_config,
'defaultRateLimitPolicy', NEW.default_rate_limit_policy,
-- NULL references are stored as empty strings rather than JSON null.
'pricingRuleSetId', COALESCE(NEW.pricing_rule_set_id::text, ''),
'runtimePolicySetId', COALESCE(NEW.runtime_policy_set_id::text, ''),
'runtimePolicyOverride', NEW.runtime_policy_override,
'metadata', NEW.metadata,
'pricingVersion', NEW.pricing_version,
'status', NEW.status
);
END IF;
RETURN NEW;
END;
$$ LANGUAGE plpgsql;

View File

@ -0,0 +1,9 @@
-- Copies model_type from the linked base model onto a platform model when the
-- base value is a non-empty array, the two values differ, and the platform
-- model's capabilities equal the base model's (NULL is treated as {}).
-- NOTE(review): the capabilities equality presumably serves as a proxy for
-- "this platform model was not customized"; confirm with the author.
UPDATE platform_models m
SET model_type = b.model_type,
updated_at = now()
FROM base_model_catalog b
WHERE m.base_model_id = b.id
AND jsonb_typeof(b.model_type) = 'array'
AND jsonb_array_length(b.model_type) > 0
AND m.model_type <> b.model_type
AND COALESCE(m.capabilities, '{}'::jsonb) = COALESCE(b.capabilities, '{}'::jsonb);

View File

@ -0,0 +1,48 @@
-- Running totals per wallet account; existing rows start at 0.
ALTER TABLE gateway_wallet_accounts
ADD COLUMN IF NOT EXISTS total_recharged numeric NOT NULL DEFAULT 0,
ADD COLUMN IF NOT EXISTS total_spent numeric NOT NULL DEFAULT 0;
-- Ledger columns are added as nullable first; direction and balance_before
-- are backfilled and tightened to NOT NULL by later statements in this script.
ALTER TABLE gateway_wallet_transactions
ADD COLUMN IF NOT EXISTS account_id uuid REFERENCES gateway_wallet_accounts(id) ON DELETE CASCADE,
ADD COLUMN IF NOT EXISTS direction text,
ADD COLUMN IF NOT EXISTS balance_before numeric,
ADD COLUMN IF NOT EXISTS idempotency_key text;
-- Backfills account_id from a wallet_account_id column when such a column
-- exists. The UPDATE is issued through EXECUTE so the script still parses on
-- databases where wallet_account_id is absent.
-- NOTE(review): the lookup does not filter on table_schema, so a table of the
-- same name in another schema would also satisfy it.
DO $$
BEGIN
IF EXISTS (
SELECT 1
FROM information_schema.columns
WHERE table_name = 'gateway_wallet_transactions'
AND column_name = 'wallet_account_id'
) THEN
EXECUTE 'UPDATE gateway_wallet_transactions SET account_id = wallet_account_id WHERE account_id IS NULL';
END IF;
END $$;
-- Derive the missing direction from transaction_type: recharge, refund and
-- credit count as 'credit', every other type as 'debit'.
UPDATE gateway_wallet_transactions
SET direction = CASE
WHEN transaction_type IN ('recharge', 'refund', 'credit') THEN 'credit'
ELSE 'debit'
END
WHERE direction IS NULL;
-- Reconstruct the missing balance_before from balance_after: add the amount
-- back for debits, subtract it for credits. NULL amounts and balances count
-- as 0.
UPDATE gateway_wallet_transactions
SET balance_before = COALESCE(balance_after, 0) + CASE
WHEN direction = 'debit' THEN COALESCE(amount, 0)
ELSE -COALESCE(amount, 0)
END
WHERE balance_before IS NULL;
-- With the backfill done, make direction and balance_before mandatory and
-- give them defaults.
ALTER TABLE gateway_wallet_transactions
ALTER COLUMN direction SET DEFAULT 'debit',
ALTER COLUMN direction SET NOT NULL,
ALTER COLUMN balance_before SET DEFAULT 0,
ALTER COLUMN balance_before SET NOT NULL;
-- Per-account history lookups, newest first.
CREATE INDEX IF NOT EXISTS idx_gateway_wallet_transactions_account
ON gateway_wallet_transactions(account_id, created_at DESC);
-- At most one transaction per (account, idempotency key); rows without a key
-- are not constrained.
CREATE UNIQUE INDEX IF NOT EXISTS uniq_gateway_wallet_tx_idempotency
ON gateway_wallet_transactions(account_id, idempotency_key)
WHERE idempotency_key IS NOT NULL;

View File

@ -49,6 +49,7 @@ import {
listPricingRules,
listPricingRuleSets,
listRuntimePolicySets,
listTasks,
listPublicBaseModels,
listPublicCatalogProviders,
listRateLimitWindows,
@ -121,6 +122,7 @@ type DataKey =
| 'tenants'
| 'users'
| 'userGroups'
| 'tasks'
| 'accessRules'
| 'apiKeys';
@ -157,6 +159,7 @@ export function App() {
const [selectedPlaygroundApiKeyId, setSelectedPlaygroundApiKeyId] = useState('');
const [taskForm, setTaskForm] = useState<TaskForm>({ kind: 'chat.completions', model: 'gpt-4o-mini', prompt: '用一句话确认 AI Gateway simulation 链路正常。' });
const [taskResult, setTaskResult] = useState<GatewayTask | null>(null);
const [tasks, setTasks] = useState<GatewayTask[]>([]);
const [coreState, setCoreState] = useState<LoadState>('idle');
const [coreMessage, setCoreMessage] = useState('');
const [state, setState] = useState<LoadState>('idle');
@ -228,10 +231,11 @@ export function App() {
rateLimitWindows,
runtimePolicySets,
taskResult,
tasks,
tenants,
userGroups,
users,
}), [accessRules, apiKeys, baseModels, models, platforms, pricingRuleSets, pricingRules, providers, rateLimitWindows, runtimePolicySets, taskResult, tenants, userGroups, users]);
}), [accessRules, apiKeys, baseModels, models, platforms, pricingRuleSets, pricingRules, providers, rateLimitWindows, runtimePolicySets, taskResult, tasks, tenants, userGroups, users]);
async function refresh(nextToken = token) {
await ensureRouteData(nextToken, true);
@ -327,6 +331,9 @@ export function App() {
case 'userGroups':
setUserGroups((await listUserGroups(nextToken)).items);
return;
case 'tasks':
setTasks((await listTasks(nextToken)).items);
return;
case 'accessRules':
setAccessRules((await (activePage === 'workspace' && workspaceSection === 'apiKeys'
? listApiKeyAccessRules(nextToken)
@ -628,6 +635,8 @@ export function App() {
const response = await runTask(credential, taskForm);
const detail = await getTask(credential, response.task.id);
setTaskResult(detail);
setTasks((current) => [detail, ...current.filter((item) => item.id !== detail.id)]);
invalidateDataKeys('tasks');
setCoreState('ready');
setCoreMessage(`${taskForm.kind} 已通过 ${apiKeySecret ? '本地 API Key' : '当前 Access Token'} 完成 simulation。`);
} catch (err) {
@ -660,6 +669,7 @@ export function App() {
setApiKeySecretsById({});
setSelectedPlaygroundApiKeyId('');
setTaskResult(null);
setTasks([]);
setCoreMessage('');
navigatePath('/');
}
@ -852,10 +862,16 @@ export function App() {
function platformModelIsSelected(model: PlatformModel, selectedModels: PlatformModelBindingInput[]) {
return selectedModels.some((selected) => {
if (selected.baseModelId && model.baseModelId) return selected.baseModelId === model.baseModelId;
return selected.modelName === model.modelName && selected.modelType === model.modelType;
return selected.modelName === model.modelName && sameModelTypes(selected.modelType, model.modelType);
});
}
/**
 * Reports whether two model-type lists describe the same set of types.
 *
 * Order and duplicates are ignored: the lists are compared as sets. Comparing
 * raw lengths plus a one-way membership test would report ['a', 'a'] and
 * ['a', 'b'] as equal, so both sides are de-duplicated first.
 */
function sameModelTypes(left: string[], right: string[]) {
  const leftSet = new Set(left);
  const rightSet = new Set(right);
  // Equal cardinality plus "left is a subset of right" implies set equality.
  if (leftSet.size !== rightSet.size) return false;
  return left.every((type) => rightSet.has(type));
}
function mergeExistingPlatformModelInput(input: PlatformModelBindingInput, currentModels: PlatformModel[], platformId: string): PlatformModelBindingInput {
const existing = currentModels.find((model) => model.platformId === platformId && platformModelIsSelected(model, [input]));
if (!existing) return input;
@ -924,6 +940,7 @@ function dataKeysForRoute(
if (activePage === 'workspace') {
if (workspaceSection === 'overview') return ['users', 'userGroups', 'apiKeys'];
if (workspaceSection === 'apiKeys') return ['apiKeys', 'accessRules', 'playgroundModels'];
if (workspaceSection === 'tasks') return ['tasks'];
return [];
}

View File

@ -582,6 +582,10 @@ export async function getTask(token: string, taskId: string): Promise<GatewayTas
return request<GatewayTask>(`/api/v1/tasks/${taskId}`, { token });
}
/**
 * Fetches the caller's recent gateway tasks from `/api/v1/tasks`.
 *
 * @param token credential forwarded to `request`.
 * @param limit maximum number of tasks to ask for; sent as the `limit` query
 *   parameter (defaults to 50).
 */
export async function listTasks(token: string, limit = 50): Promise<ListResponse<GatewayTask>> {
  const query = `limit=${encodeURIComponent(String(limit))}`;
  return request<ListResponse<GatewayTask>>(`/api/v1/tasks?${query}`, { token });
}
export function resolveApiAssetUrl(src: string) {
if (/^(https?:|data:|blob:)/i.test(src)) return src;
return `${API_BASE}${src.startsWith('/') ? src : `/${src}`}`;

View File

@ -27,6 +27,7 @@ export interface ConsoleData {
rateLimitWindows: RateLimitWindow[];
runtimePolicySets: RuntimePolicySet[];
taskResult: GatewayTask | null;
tasks: GatewayTask[];
tenants: GatewayTenant[];
userGroups: UserGroup[];
users: GatewayUser[];

View File

@ -73,7 +73,7 @@ export function Dashboard(props: {
<DataPanel
columns={['模型', '类型', '平台', '启用']}
empty="暂无模型数据"
rows={props.models.map((item) => [item.modelName, item.modelType, item.provider ?? item.platformName ?? '-', item.enabled ? '是' : '否'])}
rows={props.models.map((item) => [item.modelName, item.modelType.join(', '), item.provider ?? item.platformName ?? '-', item.enabled ? '是' : '否'])}
title="模型"
/>
</section>

View File

@ -173,8 +173,8 @@ function identityPanelProps(props: {
function OverviewPanel(props: { data: ConsoleData; stats: StatItem[] }) {
const enabledPlatforms = props.data.platforms.filter((item) => item.status === 'enabled');
const chatModels = props.data.models.filter((item) => item.modelType === 'chat' && item.enabled);
const imageModels = props.data.models.filter((item) => item.modelType === 'image' && item.enabled);
const chatModels = props.data.models.filter((item) => item.modelType.includes('text_generate') && item.enabled);
const imageModels = props.data.models.filter((item) => item.modelType.some((type) => type.includes('image')) && item.enabled);
return (
<div className="pageStack">
@ -229,7 +229,7 @@ function OverviewPanel(props: { data: ConsoleData; stats: StatItem[] }) {
rows={[
['对话', chatModels.length, 'Phase 1', '/v1/chat/completions'],
['图像', imageModels.length, 'Phase 1', '/v1/images/*'],
['视频', props.data.models.filter((item) => item.modelType === 'video').length, 'Next', '/v1/videos/*'],
['视频', props.data.models.filter((item) => item.modelType.some((type) => type.includes('video'))).length, 'Next', '/v1/videos/*'],
]}
/>
</CardContent>

View File

@ -9,7 +9,7 @@ import type {
import type { ConsoleData } from '../app-state';
import { PageHeader } from '../components/PageHeader';
import { Badge, Card, CardContent, Input } from '../components/ui';
import { primaryBaseModelType, stableModelAlias } from './admin/platform-form';
import { stableModelAlias } from './admin/platform-form';
type ModelListItem = {
id: string;
@ -17,7 +17,7 @@ type ModelListItem = {
platformName?: string;
modelName: string;
modelAlias?: string;
modelType: string;
modelType: string[];
displayName: string;
capabilities?: Record<string, unknown>;
pricingMode: string;
@ -69,7 +69,7 @@ const publicModels: PlatformModel[] = [
platformName: 'OpenAI Simulation',
modelName: 'gpt-4o-mini',
modelAlias: 'gpt-4o-mini',
modelType: 'chat',
modelType: ['text_generate'],
displayName: 'gpt-4o-mini',
capabilities: { multimodal: true },
pricingMode: 'inherit',
@ -84,7 +84,7 @@ const publicModels: PlatformModel[] = [
platformName: 'OpenAI Simulation',
modelName: 'gpt-image-1',
modelAlias: 'gpt-image-1',
modelType: 'image',
modelType: ['image_generate', 'image_edit'],
displayName: 'gpt-image-1',
capabilities: { imageEdit: true },
pricingMode: 'inherit',
@ -99,7 +99,7 @@ const publicModels: PlatformModel[] = [
platformName: 'Gemini Simulation',
modelName: 'gemini-2.0-flash',
modelAlias: 'gemini-2.0-flash',
modelType: 'chat',
modelType: ['text_generate'],
displayName: 'gemini-2.0-flash',
capabilities: { multimodal: true, vision: true },
pricingMode: 'inherit_discount',
@ -148,7 +148,7 @@ export function ModelsPage(props: { data: ConsoleData }) {
return sourceModels.filter((model) => {
const providerInfo = providerMap.get(model.providerKey);
const matchedProvider = provider === 'all' || model.providerKey === provider;
const matchedCapability = capability === 'all' || model.modelType === capability;
const matchedCapability = modelMatchesCapability(model.modelType, capability);
const matchedQuery = [
model.modelName,
model.modelAlias,
@ -331,7 +331,7 @@ function modelFromBaseModel(model: BaseModelCatalogItem): ModelListItem {
providerKey: model.providerKey,
modelName: model.providerModelName,
modelAlias: stableModelAlias(model),
modelType: primaryBaseModelType(model),
modelType: model.modelType,
displayName: stableModelAlias(model),
capabilities: model.capabilities,
pricingMode: 'inherit',
@ -363,7 +363,7 @@ function providerInitials(label: string) {
}
function tagsForModel(model: ModelListItem) {
const tags = [capabilityName(model.modelType)];
const tags = model.modelType.map(capabilityName);
const capabilities = model.capabilities ?? {};
if (capabilities.multimodal || capabilities.vision) tags.push('多模态');
if (capabilities.reasoning) tags.push('推理');
@ -373,7 +373,23 @@ function tagsForModel(model: ModelListItem) {
}
function capabilityName(type: string) {
return capabilityFilters.find((item) => item.value === type)?.label ?? type;
const labels: Record<string, string> = {
text_generate: '对话',
image_generate: '绘图',
image_edit: '图像编辑',
video_generate: '视频',
image_to_video: '图生视频',
audio_generate: '音频',
};
return labels[type] ?? capabilityFilters.find((item) => item.value === type)?.label ?? type;
}
/**
 * Decides whether a model with the given type list passes a capability filter.
 *
 * - 'all' matches every model.
 * - 'chat' matches models typed 'text_generate' or 'chat'.
 * - 'image' / 'video' match any type whose name contains that word.
 * - any other filter value must appear verbatim in the type list.
 */
function modelMatchesCapability(modelTypes: string[], capability: string) {
  switch (capability) {
    case 'all':
      return true;
    case 'chat':
      return modelTypes.some((type) => type === 'text_generate' || type === 'chat');
    case 'image':
    case 'video':
      return modelTypes.some((type) => type.includes(capability));
    default:
      return modelTypes.includes(capability);
  }
}
function priceLabel(model: ModelListItem) {

View File

@ -988,8 +988,8 @@ function filterModelsForMode(models: PlatformModel[], mode: PlaygroundMode, hasR
}
function filterWithFallback(models: PlatformModel[], modelTypes: string[]) {
const exact = models.filter((model) => modelTypes.includes(model.modelType));
return exact.length ? exact : models.filter((model) => modelTypes.some((type) => model.modelType.includes(type) || type.includes(model.modelType)));
const exact = models.filter((model) => model.modelType.some((type) => modelTypes.includes(type)));
return exact.length ? exact : models.filter((model) => modelTypes.some((type) => model.modelType.some((modelType) => modelType.includes(type) || type.includes(modelType))));
}
function buildModelOptions(models: PlatformModel[]): ModelOption[] {

View File

@ -1,6 +1,6 @@
import { useMemo, useState, type FormEvent, type ReactNode } from 'react';
import { Copy, CreditCard, KeyRound, ListChecks, Plus, ShieldCheck, Trash2, UserRound } from 'lucide-react';
import type { GatewayAccessRuleBatchRequest, GatewayApiKey, IntegrationPlatform, PlatformModel } from '@easyai-ai-gateway/contracts';
import type { GatewayAccessRuleBatchRequest, GatewayApiKey, GatewayTask, IntegrationPlatform, PlatformModel } from '@easyai-ai-gateway/contracts';
import type { ConsoleData } from '../app-state';
import { EntityTable } from '../components/EntityTable';
import { Badge, Button, Card, CardContent, CardHeader, CardTitle, ConfirmDialog, DateTimePicker, FormDialog, Input, Label, Table, TableCell, TableHead, TableRow, Tabs } from '../components/ui';
@ -297,30 +297,22 @@ function ApiKeyPanel(props: {
}
function TaskPanel(props: { data: ConsoleData }) {
const task = props.data.taskResult;
const usage = task?.usage ?? {};
const tokenText = usage.totalTokens ? `${usage.totalTokens}` : '-';
const chargeText = task?.finalChargeAmount ? `${task.finalChargeAmount}` : '-';
const tasks = useMemo(() => {
const latest = props.data.taskResult;
if (!latest) return props.data.tasks;
return [latest, ...props.data.tasks.filter((item) => item.id !== latest.id)];
}, [props.data.taskResult, props.data.tasks]);
return (
<Card>
<CardHeader>
<CardTitle></CardTitle>
</CardHeader>
<CardContent>
{task ? (
<div className="taskPreview">
<Badge variant={task.status === 'succeeded' ? 'success' : 'secondary'}>{task.status}</Badge>
<strong>{task.kind}</strong>
<span>{task.model}</span>
<div className="infoGrid compact">
<InfoItem label="API Key" value={task.apiKeyName || task.apiKeyId || '-'} />
<InfoItem label="RequestID" value={task.requestId || '-'} />
<InfoItem label="实际模型" value={task.resolvedModel || task.model} />
<InfoItem label="Token" value={tokenText} />
<InfoItem label="扣费" value={chargeText} />
<InfoItem label="响应耗时" value={task.responseDurationMs ? `${task.responseDurationMs}ms` : '-'} />
</div>
<pre>{JSON.stringify({ result: task.result, usage: task.usage, billings: task.billings, billingSummary: task.billingSummary, metrics: task.metrics }, null, 2)}</pre>
{tasks.length ? (
<div className="taskList">
{tasks.map((task) => (
<TaskRecord key={task.id} task={task} />
))}
</div>
) : (
<div className="emptyState">
@ -332,6 +324,34 @@ function TaskPanel(props: { data: ConsoleData }) {
);
}
/**
 * Renders a single gateway task as a preview card: a header line (status badge,
 * task kind, requested model, creation time), a grid of key facts, and a raw
 * JSON dump of the task's result / usage / billing / metrics payloads.
 *
 * Absent values are shown as '-'. Numeric fields use a `!= null` check rather
 * than truthiness so that a legitimate 0 (zero tokens, 0ms, zero charge) is
 * still displayed and a JSON `null` is never rendered as the text "null".
 *
 * @param props.task The task record to display.
 */
function TaskRecord(props: { task: GatewayTask }) {
  const { task } = props;
  const usage = task.usage ?? {};

  // `!= null` excludes both undefined and null while keeping 0 visible.
  const tokenText = usage.totalTokens != null ? `${usage.totalTokens}` : '-';
  const chargeText = task.finalChargeAmount != null ? `${task.finalChargeAmount}` : '-';
  const durationText = task.responseDurationMs != null ? `${task.responseDurationMs}ms` : '-';

  // Show code and message together when both exist; fall back to the plain
  // `error` field so no failure detail carried by the task is hidden.
  const errorText = [task.errorCode, task.errorMessage].filter(Boolean).join(': ') || task.error || '-';

  const badgeVariant = task.status === 'succeeded' ? 'success' : task.status === 'failed' ? 'destructive' : 'secondary';

  return (
    <div className="taskPreview">
      <div className="taskRecordHeader">
        <Badge variant={badgeVariant}>{task.status}</Badge>
        <strong>{task.kind}</strong>
        <span>{task.model}</span>
        <span>{formatDateTime(task.createdAt)}</span>
      </div>
      <div className="infoGrid compact">
        <InfoItem label="API Key" value={task.apiKeyName || task.apiKeyId || '-'} />
        <InfoItem label="RequestID" value={task.requestId || '-'} />
        <InfoItem label="模型类型" value={task.modelType || '-'} />
        <InfoItem label="实际模型" value={task.resolvedModel || task.model} />
        <InfoItem label="Token" value={tokenText} />
        <InfoItem label="扣费" value={chargeText} />
        <InfoItem label="响应耗时" value={durationText} />
        <InfoItem label="错误" value={errorText} />
      </div>
      <pre>{JSON.stringify({ result: task.result, usage: task.usage, billings: task.billings, billingSummary: task.billingSummary, metrics: task.metrics }, null, 2)}</pre>
    </div>
  );
}
function InfoItem(props: { label: string; value: string }) {
return (
<div className="infoItem">

View File

@ -304,7 +304,7 @@ function buildPlatformTree(platforms: IntegrationPlatform[], platformModels: Pla
.map((model) => ({
id: model.id,
name: modelLabel(model),
subtitle: `${model.modelType} / ${model.modelName}`,
subtitle: `${model.modelType.join(', ')} / ${model.modelName}`,
})),
})).sort((a, b) => a.name.localeCompare(b.name));
}

View File

@ -59,7 +59,7 @@ export function PlatformManagementPanel(props: {
model.displayName,
model.modelName,
model.modelAlias,
model.modelType,
...model.modelType,
model.provider,
platform?.name,
platform?.internalName,
@ -473,7 +473,7 @@ function PlatformModelTable(props: {
<ModelCatalogCard
key={model.id}
badges={[
<Badge variant="outline">{model.modelType}</Badge>,
<Badge variant="outline">{model.modelType.join(', ')}</Badge>,
<Badge variant={model.enabled ? 'success' : 'secondary'}>{model.enabled ? 'enabled' : 'disabled'}</Badge>,
]}
chips={platformModelChips(model)}
@ -889,7 +889,7 @@ function findBaseModelForPlatformModel(platform: IntegrationPlatform | undefined
return baseModels.find((item) => item.id === model.baseModelId) ??
baseModels.find((item) => item.canonicalModelKey === model.modelAlias) ??
baseModels.find((item) => stableModelAlias(item) === model.modelAlias) ??
baseModels.find((item) => item.providerKey === platform?.provider && item.providerModelName === model.modelName && baseModelTypes(item).includes(model.modelType));
baseModels.find((item) => item.providerKey === platform?.provider && item.providerModelName === model.modelName && model.modelType.some((type) => baseModelTypes(item).includes(type)));
}
function readPlatformModelIconPath(model: PlatformModel, baseModel?: BaseModelCatalogItem) {

View File

@ -149,7 +149,7 @@ export function platformModelPayloads(models: BaseModelCatalogItem[], form: Plat
canonicalModelKey: model.canonicalModelKey,
modelName: model.providerModelName,
modelAlias: stableModelAlias(model),
modelType: primaryBaseModelType(model),
modelType: baseModelTypes(model),
displayName: stableModelAlias(model) || model.providerModelName,
pricingMode: 'inherit_discount',
discountFactor: optionalPositiveNumber(form.modelDiscountFactors[model.id]) ?? optionalPositiveNumber(form.modelDiscountFactor),

View File

@ -638,7 +638,7 @@ function capabilityTypeKeys(
: [contextKey, 'text_to_video', 'image_to_video', 'omni_video', 'video_generate', 'video'];
return uniqueStrings([
...stringListFromCapability(source.originalTypes),
model.modelType,
...model.modelType,
...modeTypeHints.filter((item): item is string => Boolean(item)),
]);
}

View File

@ -224,6 +224,23 @@ strong {
gap: 14px;
}
/* Grid container for the task records; grid items are separated by a 14px gap. */
.taskList {
display: grid;
gap: 14px;
}
/* Header line of a task record: a wrapping flex container whose items are
   centered on the cross axis, with an 8px row gap and a 12px column gap.
   Text defaults to the muted foreground color. */
.taskRecordHeader {
display: flex;
flex-wrap: wrap;
align-items: center;
gap: 8px 12px;
color: var(--muted-foreground);
}
/* <strong> elements inside the header use the strong text color instead of
   the muted color set on .taskRecordHeader. */
.taskRecordHeader strong {
color: var(--text-strong);
}
.appShell {
min-height: 100vh;
}

View File

@ -57,7 +57,7 @@ export interface PlatformModelForm {
canonicalModelKey: string;
modelName: string;
modelAlias: string;
modelType: string;
modelType: string[];
pricingRuleSetId: string;
discountFactor: string;
}
@ -84,7 +84,7 @@ export interface PlatformModelBindingInput {
baseModelId?: string;
modelName: string;
modelAlias?: string;
modelType: string;
modelType: string[];
displayName?: string;
pricingMode?: string;
retryPolicy?: Record<string, unknown>;

View File

@ -36,6 +36,19 @@
状态约定:`未执行`、`执行中`、`成功`、`失败`、`阻塞`。
### 本次自动化回环记录
| 项目 | 结果 |
| --- | --- |
| 测试批次 | `loopback-20260511-integration` |
| 执行方式 | `httptest` 本地服务 + 独立 PostgreSQL 测试库 `easyai_ai_gateway_test_codex` |
| 已通过验证 | 注册 / 登录 / API Key、API Key scopes 拦截、管理员接口拒绝 API Key、模型候选路由、真实 Chat / 兼容 Chat / 文生图 / 图生图 / 文生视频 / 图生视频 / 首尾帧视频、任务详情、任务列表、事件流、callback outbox、重试、运行策略限流、降级、自动禁用、计费规则集绑定、钱包扣费和交易流水 |
| 核心修复验证 | `base_model_catalog.model_type` 和 `platform_models.model_type` 均为 JSONB 数组;`doubao-4.5图像编辑` 同时包含 `image_generate` / `image_edit`;`豆包Seedance-1.5-pro` 同时包含 `video_generate` / `image_to_video`;运行时候选查询只用 `model_type` 命中 |
| 真实外部链路 | 已使用主库 `Playground API Key` 和已配置平台 KEY 验证通过:Chat=`qwen-plus`,图像=`doubao-4.5图像编辑`,视频=`豆包Seedance-1.5-pro` |
| 真实任务证据 | Chat `6b454b1f-29b8-48ef-8ebb-740062d2e8b4`;文生图 `2db64a7e-f01d-424c-a5eb-027ef58cacde`;图生图 `af4cae00-c6d7-4c70-8477-64326c5d5cfc`;文生视频 `2436e7aa-3518-442b-8559-1e48a9cb3462`;图生视频 `80cc9655-4c87-466d-b263-5df23d23c157`;首尾帧视频 `5353a0cf-efa8-4fbc-90fd-60d1897fd012` |
| 验证命令 | `AI_GATEWAY_TEST_DATABASE_URL=... go test ./...`、`pnpm --filter @easyai-ai-gateway/web typecheck`、`pnpm test`、`git diff --check` |
| 结论 | 本地自动化覆盖项和真实外部模型链路均为 `成功` |
## 3. 测试数据准备
| ID | 任务 | 接口 / 方式 | 成功判定 | 状态 | 结果记录 |
@ -98,6 +111,7 @@
| HISTORY-03 | 事件记录保存 | `gateway_task_events` | 事件包含接收、运行、重试、完成或失败,`simulated` 标记准确 | 未执行 | 事件摘要待填写 |
| HISTORY-04 | 回调 outbox 保存 | `gateway_task_callback_outbox`,开启回调配置时验证 | 每个 task event 写入幂等 outbox,`task_id + seq + callback_url` 唯一 | 未执行 | outbox 数量待填写 |
| HISTORY-05 | API Key 身份保存 | 任务详情或数据库字段 | `apiKeyId`、`apiKeyName`、`apiKeyPrefix` 可追溯到发起 Key | 未执行 | key 信息待填写 |
| HISTORY-06 | 前端任务记录 | `GET /api/v1/tasks` + 工作台“任务记录”页 | 页面从服务端任务列表读取历史记录,不依赖当前会话里刚运行的单条 taskResult | 成功 | 已新增任务列表接口和前端列表渲染;自动化断言 task list 包含已完成任务 |
## 6. 定价规则绑定与扣费计算
@ -109,6 +123,7 @@
| PRICE-02 | 绑定规则集到平台模型 | `POST /api/v1/platforms/{platformID}/models` 或更新模型 | 平台模型返回 `pricingRuleSetId=loopback-pricing-*`,`billingConfigOverride` 生效 | 未执行 | platformModelId 待填写 |
| PRICE-03 | Chat 预估计费 | `POST /api/v1/pricing/estimate`,Chat 请求体 | 返回 `resolver=effective-pricing-v1`,`text_input` 与 `text_output` 金额按 tokens 计算 | 未执行 | estimate items 待填写 |
| PRICE-04 | Chat 最终扣费 | 执行 Chat 成功任务 | `finalChargeAmount=sum(billings.amount)`,输入/输出 token 数与 usage 对齐 | 未执行 | usage、billings、finalChargeAmount 待填写 |
| PRICE-04B | 钱包实际扣费 | 执行 Chat 成功任务后读取钱包 | `gateway_wallet_accounts.balance` 按 `finalChargeAmount` 扣减,`gateway_wallet_transactions` 写入 `task_billing` 幂等流水 | 成功 | 集成测试断言 100 -> 99.972,交易 amount=0.028 |
| PRICE-05 | 文生图扣费 | 执行文生图成功任务 | `resourceType=image`,数量、尺寸、质量权重参与计算 | 未执行 | billings 待填写 |
| PRICE-06 | 图生图扣费 | 执行图生图成功任务 | `resourceType=image_edit`,优先使用编辑价格,缺省时回退图片价格 | 未执行 | billings 待填写 |
| PRICE-07 | 视频扣费 | 执行文生视频、图生视频、首尾帧视频成功任务 | `resourceType=video`,数量、分辨率或时长权重按规则计算,三类视频请求均有计费记录 | 未执行 | billings 待填写 |
@ -175,7 +190,10 @@
| 时间 | 用例 ID | 失败现象 | 初步原因 | 修复位置 | 重跑结果 |
| --- | --- | --- | --- | --- | --- |
| 待填写 | 待填写 | 待填写 | 待填写 | 待填写 | 待填写 |
| 2026-05-11 | TASK-IMAGE-01 | `doubao-4.5图像编辑` 使用 `1024x1024` 返回 `http_400`,提示像素不足 | 测试参数不符合真实模型最小尺寸约束 | 调整真实验证请求为 `2048x2048` | 文生图任务 `2db64a7e-f01d-424c-a5eb-027ef58cacde` 成功 |
| 2026-05-11 | TASK-VIDEO-02 | 图生视频传 `image` 时上游按 `r2v` 处理并返回不支持 `Seedance-1.5-pro` | Volces 适配层把通用 `image` 映射成 `reference_image`,而非图生视频首帧 | `apps/api/internal/clients/volces.go`,将 `image` 默认映射为 `first_frame`,显式 `reference_image` 才走参考图 | 图生视频任务 `80cc9655-4c87-466d-b263-5df23d23c157` 成功 |
| 2026-05-11 | TASK-VIDEO-02 / TASK-VIDEO-03 | 带图片的视频请求任务 `modelType` 仍记录为 `video_generate` | `videos.generations` 未根据请求体区分文生视频和图生视频 | `apps/api/internal/runner/service.go` / `pricing.go`,带图片、首帧或尾帧时使用 `image_to_video` | 图生视频与首尾帧视频均以 `modelType=image_to_video` 成功 |
| 2026-05-11 | HISTORY-06 | 前端“任务记录”没有任何记录 | 前端只展示本地 `taskResult`,后端没有当前用户任务列表接口 | `GET /api/v1/tasks`、`ListTasks`、工作台任务列表状态和渲染 | 集成测试验证任务列表返回已完成任务;前端 typecheck 成功 |
## 12. 清理清单

View File

@ -560,7 +560,7 @@ export interface PlatformModel {
platformName?: string;
modelName: string;
modelAlias?: string;
modelType: 'chat' | 'image' | 'video' | 'audio' | 'embedding' | string;
modelType: string[];
displayName: string;
capabilityOverride?: Record<string, unknown>;
capabilities?: Record<string, unknown>;
@ -625,6 +625,8 @@ export interface GatewayTask {
responseDurationMs?: number;
finishedAt?: string;
error?: string;
errorCode?: string;
errorMessage?: string;
createdAt: string;
updatedAt: string;
}