fix gateway loopback validation chains
This commit is contained in:
parent
ff666b1ece
commit
ca7e76e815
@ -40,6 +40,7 @@ type User struct {
|
||||
APIKeySecret string `json:"apiKeySecret,omitempty"`
|
||||
APIKeyName string `json:"apiKeyName,omitempty"`
|
||||
APIKeyPrefix string `json:"apiKeyPrefix,omitempty"`
|
||||
APIKeyScopes []string `json:"apiKeyScopes,omitempty"`
|
||||
}
|
||||
|
||||
// contextKey is a distinct named string type.
// NOTE(review): presumably used for context.Context value keys so they cannot
// collide with plain string keys from other packages; its usage is outside
// this hunk, so confirm.
type contextKey string
|
||||
@ -137,6 +138,7 @@ func (a *Authenticator) verifyJWT(tokenString string) (*User, error) {
|
||||
APIKeySecret: stringClaim(claims, "apiKeySecret"),
|
||||
APIKeyName: stringClaim(claims, "apiKeyName"),
|
||||
APIKeyPrefix: stringClaim(claims, "apiKeyPrefix"),
|
||||
APIKeyScopes: stringSliceClaim(claims, "apiKeyScopes"),
|
||||
}
|
||||
if user.Source == "" {
|
||||
user.Source = "gateway"
|
||||
@ -167,6 +169,7 @@ func (a *Authenticator) SignJWT(user *User, ttl time.Duration) (string, error) {
|
||||
"apiKeyId": user.APIKeyID,
|
||||
"apiKeyName": user.APIKeyName,
|
||||
"apiKeyPrefix": user.APIKeyPrefix,
|
||||
"apiKeyScopes": user.APIKeyScopes,
|
||||
"iat": now.Unix(),
|
||||
"exp": now.Add(ttl).Unix(),
|
||||
}
|
||||
|
||||
@ -34,7 +34,8 @@ func (c SimulationClient) Run(ctx context.Context, request Request) (Response, e
|
||||
}
|
||||
}
|
||||
responseFinishedAt := time.Now()
|
||||
if profile == "retryable_failure" {
|
||||
switch profile {
|
||||
case "retryable_failure":
|
||||
return Response{}, &ClientError{
|
||||
Code: "server_error",
|
||||
Message: "simulated retryable failure",
|
||||
@ -44,8 +45,27 @@ func (c SimulationClient) Run(ctx context.Context, request Request) (Response, e
|
||||
ResponseDurationMS: responseDurationMS(responseStartedAt, responseFinishedAt),
|
||||
Retryable: true,
|
||||
}
|
||||
}
|
||||
if profile == "fatal_failure" || profile == "non_retryable_failure" {
|
||||
case "rate_limit", "overloaded":
|
||||
return Response{}, &ClientError{
|
||||
Code: profile,
|
||||
Message: "simulated " + profile,
|
||||
RequestID: "simulated-request",
|
||||
ResponseStartedAt: responseStartedAt,
|
||||
ResponseFinishedAt: responseFinishedAt,
|
||||
ResponseDurationMS: responseDurationMS(responseStartedAt, responseFinishedAt),
|
||||
Retryable: true,
|
||||
}
|
||||
case "invalid_api_key":
|
||||
return Response{}, &ClientError{
|
||||
Code: "invalid_api_key",
|
||||
Message: "simulated invalid_api_key",
|
||||
RequestID: "simulated-request",
|
||||
ResponseStartedAt: responseStartedAt,
|
||||
ResponseFinishedAt: responseFinishedAt,
|
||||
ResponseDurationMS: responseDurationMS(responseStartedAt, responseFinishedAt),
|
||||
Retryable: false,
|
||||
}
|
||||
case "fatal_failure", "non_retryable_failure":
|
||||
return Response{}, &ClientError{
|
||||
Code: "bad_request",
|
||||
Message: "simulated non-retryable failure",
|
||||
@ -184,7 +204,7 @@ func simulatedVideoData(request Request) []any {
|
||||
}
|
||||
|
||||
func simulatedUsage(request Request) Usage {
|
||||
if request.ModelType == "chat" || request.Kind == "responses" {
|
||||
if request.ModelType == "chat" || request.ModelType == "text_generate" || request.Kind == "responses" {
|
||||
return Usage{InputTokens: 12, OutputTokens: 8, TotalTokens: 20}
|
||||
}
|
||||
return Usage{}
|
||||
|
||||
@ -251,9 +251,18 @@ func buildVolcesContentFromBody(body map[string]any) []map[string]any {
|
||||
}
|
||||
}
|
||||
|
||||
appendURLContent("image_url", "first_frame", firstNonEmptyStringValue(body, "first_frame", "firstFrame"))
|
||||
firstFrame := firstNonEmptyStringValue(body, "first_frame", "firstFrame")
|
||||
appendURLContent("image_url", "first_frame", firstFrame)
|
||||
appendURLContent("image_url", "last_frame", firstNonEmptyStringValue(body, "last_frame", "lastFrame"))
|
||||
for _, url := range firstNonEmptyStringListFromAny(body["image"], body["images"], body["image_url"], body["imageUrl"], body["image_urls"], body["imageUrls"], body["reference_image"], body["referenceImage"]) {
|
||||
imageURLs := firstNonEmptyStringListFromAny(body["image"], body["images"], body["image_url"], body["imageUrl"], body["image_urls"], body["imageUrls"])
|
||||
if firstFrame == "" && len(imageURLs) > 0 {
|
||||
appendURLContent("image_url", "first_frame", imageURLs[0])
|
||||
imageURLs = imageURLs[1:]
|
||||
}
|
||||
for _, url := range imageURLs {
|
||||
appendURLContent("image_url", "reference_image", url)
|
||||
}
|
||||
for _, url := range firstNonEmptyStringListFromAny(body["reference_image"], body["referenceImage"]) {
|
||||
appendURLContent("image_url", "reference_image", url)
|
||||
}
|
||||
for _, url := range firstNonEmptyStringListFromAny(body["video"], body["video_url"], body["videoUrl"], body["reference_video"], body["referenceVideo"]) {
|
||||
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
@ -123,8 +124,43 @@ func TestCoreLocalFlow(t *testing.T) {
|
||||
if _, err := testPool.Exec(ctx, `UPDATE gateway_users SET roles = '["admin"]'::jsonb WHERE username = $1`, username); err != nil {
|
||||
t.Fatalf("promote smoke user: %v", err)
|
||||
}
|
||||
var smokeGatewayUserID string
|
||||
if err := testPool.QueryRow(ctx, `SELECT id::text FROM gateway_users WHERE username = $1`, username).Scan(&smokeGatewayUserID); err != nil {
|
||||
t.Fatalf("read smoke gateway user id: %v", err)
|
||||
}
|
||||
doJSON(t, server.URL, http.MethodGet, "/api/admin/models", apiKeyResponse.Secret, nil, http.StatusForbidden, nil)
|
||||
|
||||
var chatOnlyAPIKeyResponse struct {
|
||||
Secret string `json:"secret"`
|
||||
APIKey struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"apiKey"`
|
||||
}
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/v1/api-keys", loginResponse.AccessToken, map[string]any{
|
||||
"name": "chat only key",
|
||||
"scopes": []string{"chat"},
|
||||
}, http.StatusCreated, &chatOnlyAPIKeyResponse)
|
||||
var taskCountBefore int
|
||||
if err := testPool.QueryRow(ctx, `SELECT count(*) FROM gateway_tasks`).Scan(&taskCountBefore); err != nil {
|
||||
t.Fatalf("count tasks before scoped request: %v", err)
|
||||
}
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/v1/images/generations", chatOnlyAPIKeyResponse.Secret, map[string]any{
|
||||
"model": "gpt-image-1",
|
||||
"prompt": "scope should block this",
|
||||
}, http.StatusForbidden, nil)
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/v1/pricing/estimate", chatOnlyAPIKeyResponse.Secret, map[string]any{
|
||||
"kind": "images.generations",
|
||||
"model": "gpt-image-1",
|
||||
"prompt": "scope should block this estimate",
|
||||
}, http.StatusForbidden, nil)
|
||||
var taskCountAfter int
|
||||
if err := testPool.QueryRow(ctx, `SELECT count(*) FROM gateway_tasks`).Scan(&taskCountAfter); err != nil {
|
||||
t.Fatalf("count tasks after scoped request: %v", err)
|
||||
}
|
||||
if taskCountAfter != taskCountBefore {
|
||||
t.Fatalf("scoped API key rejection should happen before task creation, before=%d after=%d", taskCountBefore, taskCountAfter)
|
||||
}
|
||||
|
||||
inviteCode := "INVITE-" + suffixText
|
||||
if _, err := testPool.Exec(ctx, `
|
||||
INSERT INTO gateway_invitations (invite_code, max_uses, metadata)
|
||||
@ -322,6 +358,317 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil {
|
||||
t.Fatalf("unexpected image edit task: %+v", imageEditResponse.Task)
|
||||
}
|
||||
|
||||
var gptImageModelTypesRaw []byte
|
||||
if err := testPool.QueryRow(ctx, `
|
||||
SELECT model_type
|
||||
FROM platform_models
|
||||
WHERE model_name = 'gpt-image-1'
|
||||
AND model_type @> '["image_generate","image_edit"]'::jsonb
|
||||
LIMIT 1`).Scan(&gptImageModelTypesRaw); err != nil {
|
||||
t.Fatalf("gpt-image-1 platform model should store both image model types: %v", err)
|
||||
}
|
||||
var gptImageModelTypes []string
|
||||
if err := json.Unmarshal(gptImageModelTypesRaw, &gptImageModelTypes); err != nil {
|
||||
t.Fatalf("decode gpt-image-1 model_type: %v raw=%s", err, string(gptImageModelTypesRaw))
|
||||
}
|
||||
if !stringSliceContains(gptImageModelTypes, "image_generate") || !stringSliceContains(gptImageModelTypes, "image_edit") {
|
||||
t.Fatalf("gpt-image-1 model_type should include generation and edit types: %+v", gptImageModelTypes)
|
||||
}
|
||||
|
||||
deniedModel := "permission-deny-smoke-" + suffixText
|
||||
var deniedPlatformModel struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms/"+platform.ID+"/models", loginResponse.AccessToken, map[string]any{
|
||||
"canonicalModelKey": "openai:gpt-4o-mini",
|
||||
"modelName": deniedModel,
|
||||
"modelAlias": deniedModel,
|
||||
"modelType": []string{"text_generate"},
|
||||
"displayName": "Permission Deny Smoke",
|
||||
}, http.StatusCreated, &deniedPlatformModel)
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/admin/access-rules", loginResponse.AccessToken, map[string]any{
|
||||
"subjectType": "api_key",
|
||||
"subjectId": chatOnlyAPIKeyResponse.APIKey.ID,
|
||||
"resourceType": "platform_model",
|
||||
"resourceId": deniedPlatformModel.ID,
|
||||
"effect": "deny",
|
||||
"priority": 10,
|
||||
"minPermissionLevel": 0,
|
||||
"status": "active",
|
||||
}, http.StatusCreated, nil)
|
||||
var deniedTask struct {
|
||||
Task struct {
|
||||
Status string `json:"status"`
|
||||
ErrorCode string `json:"errorCode"`
|
||||
} `json:"task"`
|
||||
}
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", chatOnlyAPIKeyResponse.Secret, map[string]any{
|
||||
"model": deniedModel,
|
||||
"runMode": "simulation",
|
||||
"simulation": true,
|
||||
"simulationDurationMs": 5,
|
||||
"messages": []map[string]any{{"role": "user", "content": "permission deny"}},
|
||||
}, http.StatusAccepted, &deniedTask)
|
||||
if deniedTask.Task.Status != "failed" || deniedTask.Task.ErrorCode != "no_model_candidate" {
|
||||
t.Fatalf("deny access rule should hide denied model from runtime candidates: %+v", deniedTask.Task)
|
||||
}
|
||||
var restrictedModels struct {
|
||||
Items []struct {
|
||||
ID string `json:"id"`
|
||||
ModelName string `json:"modelName"`
|
||||
} `json:"items"`
|
||||
}
|
||||
doJSON(t, server.URL, http.MethodGet, "/api/v1/playground/models", chatOnlyAPIKeyResponse.Secret, nil, http.StatusOK, &restrictedModels)
|
||||
if modelListContains(restrictedModels.Items, deniedPlatformModel.ID) {
|
||||
t.Fatalf("deny access rule should hide denied model from playable list: %+v", restrictedModels.Items)
|
||||
}
|
||||
|
||||
controlledModel := "permission-allow-smoke-" + suffixText
|
||||
var controlledPlatformModel struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms/"+platform.ID+"/models", loginResponse.AccessToken, map[string]any{
|
||||
"canonicalModelKey": "openai:gpt-4o-mini",
|
||||
"modelName": controlledModel,
|
||||
"modelAlias": controlledModel,
|
||||
"modelType": []string{"text_generate"},
|
||||
"displayName": "Permission Allow Smoke",
|
||||
}, http.StatusCreated, &controlledPlatformModel)
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/admin/access-rules", loginResponse.AccessToken, map[string]any{
|
||||
"subjectType": "api_key",
|
||||
"subjectId": apiKeyResponse.APIKey.ID,
|
||||
"resourceType": "platform_model",
|
||||
"resourceId": controlledPlatformModel.ID,
|
||||
"effect": "allow",
|
||||
"priority": 10,
|
||||
"minPermissionLevel": 0,
|
||||
"status": "active",
|
||||
}, http.StatusCreated, nil)
|
||||
var blockedControlledTask struct {
|
||||
Task struct {
|
||||
Status string `json:"status"`
|
||||
ErrorCode string `json:"errorCode"`
|
||||
} `json:"task"`
|
||||
}
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", chatOnlyAPIKeyResponse.Secret, map[string]any{
|
||||
"model": controlledModel,
|
||||
"runMode": "simulation",
|
||||
"simulation": true,
|
||||
"simulationDurationMs": 5,
|
||||
"messages": []map[string]any{{"role": "user", "content": "allow should block other keys"}},
|
||||
}, http.StatusAccepted, &blockedControlledTask)
|
||||
if blockedControlledTask.Task.Status != "failed" || blockedControlledTask.Task.ErrorCode != "no_model_candidate" {
|
||||
t.Fatalf("allow access rule should make the resource unavailable to unmatched subjects: %+v", blockedControlledTask.Task)
|
||||
}
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/admin/access-rules", loginResponse.AccessToken, map[string]any{
|
||||
"subjectType": "api_key",
|
||||
"subjectId": chatOnlyAPIKeyResponse.APIKey.ID,
|
||||
"resourceType": "platform_model",
|
||||
"resourceId": controlledPlatformModel.ID,
|
||||
"effect": "allow",
|
||||
"priority": 10,
|
||||
"minPermissionLevel": 0,
|
||||
"status": "active",
|
||||
}, http.StatusCreated, nil)
|
||||
var allowedControlledTask struct {
|
||||
Task struct {
|
||||
Status string `json:"status"`
|
||||
} `json:"task"`
|
||||
}
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", chatOnlyAPIKeyResponse.Secret, map[string]any{
|
||||
"model": controlledModel,
|
||||
"runMode": "simulation",
|
||||
"simulation": true,
|
||||
"simulationDurationMs": 5,
|
||||
"messages": []map[string]any{{"role": "user", "content": "allow should pass"}},
|
||||
}, http.StatusAccepted, &allowedControlledTask)
|
||||
if allowedControlledTask.Task.Status != "succeeded" {
|
||||
t.Fatalf("matching allow access rule should make the controlled model usable: %+v", allowedControlledTask.Task)
|
||||
}
|
||||
|
||||
var customPricingRuleSet struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/admin/pricing/rule-sets", loginResponse.AccessToken, map[string]any{
|
||||
"ruleSetKey": "smoke-pricing-" + suffixText,
|
||||
"name": "Smoke Pricing",
|
||||
"currency": "resource",
|
||||
"rules": []map[string]any{
|
||||
{"ruleKey": "text_input", "displayName": "Text Input", "resourceType": "text_input", "unit": "1k_tokens", "basePrice": 1},
|
||||
{"ruleKey": "text_output", "displayName": "Text Output", "resourceType": "text_output", "unit": "1k_tokens", "basePrice": 2},
|
||||
{"ruleKey": "image", "displayName": "Image", "resourceType": "image", "unit": "image", "basePrice": 7},
|
||||
{"ruleKey": "image_edit", "displayName": "Image Edit", "resourceType": "image_edit", "unit": "image", "basePrice": 11},
|
||||
{"ruleKey": "video", "displayName": "Video", "resourceType": "video", "unit": "video", "basePrice": 13},
|
||||
},
|
||||
}, http.StatusCreated, &customPricingRuleSet)
|
||||
pricingModel := "pricing-smoke-" + suffixText
|
||||
var pricingPlatformModel map[string]any
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms/"+platform.ID+"/models", loginResponse.AccessToken, map[string]any{
|
||||
"canonicalModelKey": "openai:gpt-4o-mini",
|
||||
"modelName": pricingModel,
|
||||
"modelAlias": pricingModel,
|
||||
"modelType": []string{"text_generate"},
|
||||
"displayName": "Pricing Smoke",
|
||||
"pricingRuleSetId": customPricingRuleSet.ID,
|
||||
}, http.StatusCreated, &pricingPlatformModel)
|
||||
if _, err := testPool.Exec(ctx, `
|
||||
UPDATE gateway_wallet_accounts
|
||||
SET balance = 100, total_spent = 0, updated_at = now()
|
||||
WHERE gateway_user_id = $1::uuid
|
||||
AND currency = 'resource'`, smokeGatewayUserID); err != nil {
|
||||
t.Fatalf("seed wallet balance: %v", err)
|
||||
}
|
||||
var walletBalanceBefore float64
|
||||
if err := testPool.QueryRow(ctx, `
|
||||
SELECT balance::float8
|
||||
FROM gateway_wallet_accounts
|
||||
WHERE gateway_user_id = $1::uuid
|
||||
AND currency = 'resource'`, smokeGatewayUserID).Scan(&walletBalanceBefore); err != nil {
|
||||
t.Fatalf("read wallet balance before pricing task: %v", err)
|
||||
}
|
||||
var pricingTask struct {
|
||||
Task struct {
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"`
|
||||
BillingSummary map[string]any `json:"billingSummary"`
|
||||
FinalChargeAmount float64 `json:"finalChargeAmount"`
|
||||
} `json:"task"`
|
||||
}
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
|
||||
"model": pricingModel,
|
||||
"runMode": "simulation",
|
||||
"simulation": true,
|
||||
"simulationDurationMs": 5,
|
||||
"messages": []map[string]any{{"role": "user", "content": "priced ping"}},
|
||||
}, http.StatusAccepted, &pricingTask)
|
||||
if pricingTask.Task.Status != "succeeded" || !floatNear(pricingTask.Task.FinalChargeAmount, 0.028) {
|
||||
t.Fatalf("custom pricing rule set should drive text billing, got task=%+v", pricingTask.Task)
|
||||
}
|
||||
var walletBalanceAfter float64
|
||||
var walletSpentAfter float64
|
||||
if err := testPool.QueryRow(ctx, `
|
||||
SELECT balance::float8, total_spent::float8
|
||||
FROM gateway_wallet_accounts
|
||||
WHERE gateway_user_id = $1::uuid
|
||||
AND currency = 'resource'`, smokeGatewayUserID).Scan(&walletBalanceAfter, &walletSpentAfter); err != nil {
|
||||
t.Fatalf("read wallet balance after pricing task: %v", err)
|
||||
}
|
||||
if !floatNear(walletBalanceAfter, walletBalanceBefore-pricingTask.Task.FinalChargeAmount) || !floatNear(walletSpentAfter, pricingTask.Task.FinalChargeAmount) {
|
||||
t.Fatalf("task billing should debit wallet balance and spent totals, before=%f after=%f spent=%f task=%+v", walletBalanceBefore, walletBalanceAfter, walletSpentAfter, pricingTask.Task)
|
||||
}
|
||||
var walletTransactionAmount float64
|
||||
if err := testPool.QueryRow(ctx, `
|
||||
SELECT amount::float8
|
||||
FROM gateway_wallet_transactions
|
||||
WHERE reference_type = 'gateway_task'
|
||||
AND reference_id = $1
|
||||
AND transaction_type = 'task_billing'`, pricingTask.Task.ID).Scan(&walletTransactionAmount); err != nil {
|
||||
t.Fatalf("read task billing wallet transaction: %v", err)
|
||||
}
|
||||
if !floatNear(walletTransactionAmount, pricingTask.Task.FinalChargeAmount) {
|
||||
t.Fatalf("task billing transaction amount=%f want=%f", walletTransactionAmount, pricingTask.Task.FinalChargeAmount)
|
||||
}
|
||||
|
||||
rateLimitedModel := "rate-limit-smoke-" + suffixText
|
||||
var rateLimitPolicySet struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/admin/runtime/policy-sets", loginResponse.AccessToken, map[string]any{
|
||||
"policyKey": "smoke-rate-limit-" + suffixText,
|
||||
"name": "Smoke Rate Limit",
|
||||
"retryPolicy": map[string]any{
|
||||
"enabled": false,
|
||||
"maxAttempts": 1,
|
||||
},
|
||||
"rateLimitPolicy": map[string]any{
|
||||
"rules": []map[string]any{{"metric": "rpm", "limit": 1, "windowSeconds": 60}},
|
||||
},
|
||||
}, http.StatusCreated, &rateLimitPolicySet)
|
||||
var rateLimitPlatformModel map[string]any
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms/"+platform.ID+"/models", loginResponse.AccessToken, map[string]any{
|
||||
"canonicalModelKey": "openai:gpt-4o-mini",
|
||||
"modelName": rateLimitedModel,
|
||||
"modelAlias": rateLimitedModel,
|
||||
"modelType": []string{"text_generate"},
|
||||
"displayName": "Rate Limit Smoke",
|
||||
"runtimePolicySetId": rateLimitPolicySet.ID,
|
||||
"runtimePolicyOverride": map[string]any{},
|
||||
}, http.StatusCreated, &rateLimitPlatformModel)
|
||||
var rateLimitTaskOne struct {
|
||||
Task struct {
|
||||
Status string `json:"status"`
|
||||
} `json:"task"`
|
||||
}
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
|
||||
"model": rateLimitedModel,
|
||||
"runMode": "simulation",
|
||||
"simulation": true,
|
||||
"simulationDurationMs": 5,
|
||||
"messages": []map[string]any{{"role": "user", "content": "first"}},
|
||||
}, http.StatusAccepted, &rateLimitTaskOne)
|
||||
if rateLimitTaskOne.Task.Status != "succeeded" {
|
||||
t.Fatalf("first rate-limited task should succeed: %+v", rateLimitTaskOne.Task)
|
||||
}
|
||||
var rateLimitTaskTwo struct {
|
||||
Task struct {
|
||||
Status string `json:"status"`
|
||||
ErrorCode string `json:"errorCode"`
|
||||
} `json:"task"`
|
||||
}
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
|
||||
"model": rateLimitedModel,
|
||||
"runMode": "simulation",
|
||||
"simulation": true,
|
||||
"simulationDurationMs": 5,
|
||||
"messages": []map[string]any{{"role": "user", "content": "second"}},
|
||||
}, http.StatusAccepted, &rateLimitTaskTwo)
|
||||
if rateLimitTaskTwo.Task.Status != "failed" || rateLimitTaskTwo.Task.ErrorCode != "rate_limit" {
|
||||
t.Fatalf("runtime policy rate limit should fail second task with rate_limit: %+v", rateLimitTaskTwo.Task)
|
||||
}
|
||||
|
||||
videoRouteModel := "video-route-smoke-" + suffixText
|
||||
var videoRoutePlatformModel map[string]any
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms/"+platform.ID+"/models", loginResponse.AccessToken, map[string]any{
|
||||
"canonicalModelKey": "openai:gpt-4o-mini",
|
||||
"modelName": videoRouteModel,
|
||||
"modelAlias": videoRouteModel,
|
||||
"modelType": []string{"video_generate", "image_to_video"},
|
||||
"displayName": "Video Route Smoke",
|
||||
}, http.StatusCreated, &videoRoutePlatformModel)
|
||||
var textToVideoTask struct {
|
||||
Task struct {
|
||||
Status string `json:"status"`
|
||||
ModelType string `json:"modelType"`
|
||||
} `json:"task"`
|
||||
}
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/v1/videos/generations", apiKeyResponse.Secret, map[string]any{
|
||||
"model": videoRouteModel,
|
||||
"runMode": "simulation",
|
||||
"simulation": true,
|
||||
"simulationDurationMs": 5,
|
||||
"prompt": "text to video route",
|
||||
}, http.StatusAccepted, &textToVideoTask)
|
||||
if textToVideoTask.Task.Status != "succeeded" || textToVideoTask.Task.ModelType != "video_generate" {
|
||||
t.Fatalf("text-to-video request should use video_generate model_type: %+v", textToVideoTask.Task)
|
||||
}
|
||||
var imageToVideoTask struct {
|
||||
Task struct {
|
||||
Status string `json:"status"`
|
||||
ModelType string `json:"modelType"`
|
||||
} `json:"task"`
|
||||
}
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/v1/videos/generations", apiKeyResponse.Secret, map[string]any{
|
||||
"model": videoRouteModel,
|
||||
"runMode": "simulation",
|
||||
"simulation": true,
|
||||
"simulationDurationMs": 5,
|
||||
"prompt": "image to video route",
|
||||
"image": "https://example.com/source.png",
|
||||
}, http.StatusAccepted, &imageToVideoTask)
|
||||
if imageToVideoTask.Task.Status != "succeeded" || imageToVideoTask.Task.ModelType != "image_to_video" {
|
||||
t.Fatalf("image-to-video request should use image_to_video model_type: %+v", imageToVideoTask.Task)
|
||||
}
|
||||
|
||||
failoverModel := "phase1-failover-" + suffixText
|
||||
var failedPlatform struct {
|
||||
ID string `json:"id"`
|
||||
@ -353,7 +700,7 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil {
|
||||
"canonicalModelKey": "openai:gpt-4o-mini",
|
||||
"modelName": failoverModel,
|
||||
"modelAlias": failoverModel,
|
||||
"modelType": "chat",
|
||||
"modelType": []string{"text_generate"},
|
||||
"displayName": "Failover Smoke",
|
||||
"retryPolicy": map[string]any{"enabled": true, "maxAttempts": 2},
|
||||
}, http.StatusCreated, &platformModel)
|
||||
@ -376,6 +723,142 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil {
|
||||
t.Fatalf("failover task should succeed through second client: %+v", failoverTask.Task)
|
||||
}
|
||||
|
||||
var degradePolicySet struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/admin/runtime/policy-sets", loginResponse.AccessToken, map[string]any{
|
||||
"policyKey": "smoke-degrade-" + suffixText,
|
||||
"name": "Smoke Degrade",
|
||||
"degradePolicy": map[string]any{
|
||||
"enabled": true,
|
||||
"keywords": []string{"rate_limit"},
|
||||
"cooldownSeconds": 600,
|
||||
},
|
||||
}, http.StatusCreated, &degradePolicySet)
|
||||
degradeModel := "degrade-smoke-" + suffixText
|
||||
var degradedPlatform struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms", loginResponse.AccessToken, map[string]any{
|
||||
"provider": "openai",
|
||||
"platformKey": "openai-degrade-" + suffixText,
|
||||
"name": "OpenAI Degrade Failure",
|
||||
"baseUrl": "https://api.openai.com/v1",
|
||||
"authType": "bearer",
|
||||
"credentials": map[string]any{"mode": "simulation", "simulationFailure": "rate_limit"},
|
||||
"priority": 30,
|
||||
}, http.StatusCreated, &degradedPlatform)
|
||||
var degradeSuccessPlatform struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms", loginResponse.AccessToken, map[string]any{
|
||||
"provider": "openai",
|
||||
"platformKey": "openai-degrade-success-" + suffixText,
|
||||
"name": "OpenAI Degrade Success",
|
||||
"baseUrl": "https://api.openai.com/v1",
|
||||
"authType": "bearer",
|
||||
"credentials": map[string]any{"mode": "simulation"},
|
||||
"priority": 40,
|
||||
}, http.StatusCreated, &degradeSuccessPlatform)
|
||||
for _, item := range []struct {
|
||||
platformID string
|
||||
runtimePolicySetID string
|
||||
}{
|
||||
{platformID: degradedPlatform.ID, runtimePolicySetID: degradePolicySet.ID},
|
||||
{platformID: degradeSuccessPlatform.ID},
|
||||
} {
|
||||
var platformModel map[string]any
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms/"+item.platformID+"/models", loginResponse.AccessToken, map[string]any{
|
||||
"canonicalModelKey": "openai:gpt-4o-mini",
|
||||
"modelName": degradeModel,
|
||||
"modelAlias": degradeModel,
|
||||
"modelType": []string{"text_generate"},
|
||||
"displayName": "Degrade Smoke",
|
||||
"runtimePolicySetId": item.runtimePolicySetID,
|
||||
}, http.StatusCreated, &platformModel)
|
||||
}
|
||||
var degradeTask struct {
|
||||
Task struct {
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"`
|
||||
} `json:"task"`
|
||||
}
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
|
||||
"model": degradeModel,
|
||||
"runMode": "simulation",
|
||||
"simulation": true,
|
||||
"simulationDurationMs": 5,
|
||||
"messages": []map[string]any{{"role": "user", "content": "degrade please"}},
|
||||
}, http.StatusAccepted, &degradeTask)
|
||||
if degradeTask.Task.Status != "succeeded" {
|
||||
t.Fatalf("degrade task should fail over after cooling down failed platform: %+v", degradeTask.Task)
|
||||
}
|
||||
var cooledDown bool
|
||||
if err := testPool.QueryRow(ctx, `SELECT COALESCE(cooldown_until > now(), false) FROM integration_platforms WHERE id = $1::uuid`, degradedPlatform.ID).Scan(&cooledDown); err != nil {
|
||||
t.Fatalf("read degraded platform cooldown: %v", err)
|
||||
}
|
||||
if !cooledDown {
|
||||
t.Fatal("degrade policy should set platform cooldown_until")
|
||||
}
|
||||
|
||||
var autoDisablePolicySet struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/admin/runtime/policy-sets", loginResponse.AccessToken, map[string]any{
|
||||
"policyKey": "smoke-auto-disable-" + suffixText,
|
||||
"name": "Smoke Auto Disable",
|
||||
"autoDisablePolicy": map[string]any{
|
||||
"enabled": true,
|
||||
"keywords": []string{"invalid_api_key"},
|
||||
"threshold": 1,
|
||||
},
|
||||
}, http.StatusCreated, &autoDisablePolicySet)
|
||||
autoDisableModel := "auto-disable-smoke-" + suffixText
|
||||
var invalidKeyPlatform struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms", loginResponse.AccessToken, map[string]any{
|
||||
"provider": "openai",
|
||||
"platformKey": "openai-invalid-key-" + suffixText,
|
||||
"name": "OpenAI Invalid Key",
|
||||
"baseUrl": "https://api.openai.com/v1",
|
||||
"authType": "bearer",
|
||||
"credentials": map[string]any{"mode": "simulation", "simulationFailure": "invalid_api_key"},
|
||||
"priority": 50,
|
||||
}, http.StatusCreated, &invalidKeyPlatform)
|
||||
var invalidKeyPlatformModel map[string]any
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms/"+invalidKeyPlatform.ID+"/models", loginResponse.AccessToken, map[string]any{
|
||||
"canonicalModelKey": "openai:gpt-4o-mini",
|
||||
"modelName": autoDisableModel,
|
||||
"modelAlias": autoDisableModel,
|
||||
"modelType": []string{"text_generate"},
|
||||
"displayName": "Auto Disable Smoke",
|
||||
"runtimePolicySetId": autoDisablePolicySet.ID,
|
||||
}, http.StatusCreated, &invalidKeyPlatformModel)
|
||||
var autoDisableTask struct {
|
||||
Task struct {
|
||||
Status string `json:"status"`
|
||||
ErrorCode string `json:"errorCode"`
|
||||
} `json:"task"`
|
||||
}
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
|
||||
"model": autoDisableModel,
|
||||
"runMode": "simulation",
|
||||
"simulation": true,
|
||||
"simulationDurationMs": 5,
|
||||
"messages": []map[string]any{{"role": "user", "content": "disable please"}},
|
||||
}, http.StatusAccepted, &autoDisableTask)
|
||||
if autoDisableTask.Task.Status != "failed" || autoDisableTask.Task.ErrorCode != "invalid_api_key" {
|
||||
t.Fatalf("auto disable task should fail with invalid_api_key: %+v", autoDisableTask.Task)
|
||||
}
|
||||
var invalidKeyPlatformStatus string
|
||||
if err := testPool.QueryRow(ctx, `SELECT status FROM integration_platforms WHERE id = $1::uuid`, invalidKeyPlatform.ID).Scan(&invalidKeyPlatformStatus); err != nil {
|
||||
t.Fatalf("read invalid key platform status: %v", err)
|
||||
}
|
||||
if invalidKeyPlatformStatus != "disabled" {
|
||||
t.Fatalf("auto disable policy should disable platform, got %q", invalidKeyPlatformStatus)
|
||||
}
|
||||
|
||||
var taskDetail struct {
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"`
|
||||
@ -393,6 +876,21 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil {
|
||||
if taskDetail.APIKeyName != apiKeyResponse.APIKey.Name || taskDetail.RequestID == "" || taskDetail.Usage["totalTokens"] == nil || taskDetail.FinalChargeAmount <= 0 {
|
||||
t.Fatalf("task detail should expose enriched record fields: %+v", taskDetail)
|
||||
}
|
||||
var taskList struct {
|
||||
Items []struct {
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"`
|
||||
APIKeyName string `json:"apiKeyName"`
|
||||
ModelType string `json:"modelType"`
|
||||
FinalCharge float64 `json:"finalChargeAmount"`
|
||||
ErrorCode string `json:"errorCode"`
|
||||
ErrorMessage string `json:"errorMessage"`
|
||||
} `json:"items"`
|
||||
}
|
||||
doJSON(t, server.URL, http.MethodGet, "/api/v1/tasks?limit=20", loginResponse.AccessToken, nil, http.StatusOK, &taskList)
|
||||
if !taskListContains(taskList.Items, taskResponse.Task.ID) || !taskListContains(taskList.Items, pricingTask.Task.ID) {
|
||||
t.Fatalf("task list should include persisted task records, got %+v", taskList.Items)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, server.URL+"/api/v1/tasks/"+taskResponse.Task.ID+"/events", nil)
|
||||
if err != nil {
|
||||
@ -411,6 +909,9 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil {
|
||||
if !bytes.Contains(body, []byte("task.progress")) {
|
||||
t.Fatalf("events response should include progress events body=%s", string(body))
|
||||
}
|
||||
if !bytes.Contains(body, []byte("task.billing.settled")) {
|
||||
t.Fatalf("events response should include billing settlement event body=%s", string(body))
|
||||
}
|
||||
|
||||
req, err = http.NewRequest(http.MethodGet, server.URL+"/api/v1/tasks/"+failoverTask.Task.ID+"/events", nil)
|
||||
if err != nil {
|
||||
@ -508,3 +1009,45 @@ func doJSON(t *testing.T, baseURL string, method string, path string, token stri
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// stringSliceContains reports whether target occurs in values.
//
// The comparison is an exact string match (case-sensitive, no trimming).
// A nil or empty slice never contains anything.
func stringSliceContains(values []string, target string) bool {
	for i := 0; i < len(values); i++ {
		if values[i] != target {
			continue
		}
		return true
	}
	return false
}
|
||||
|
||||
// modelListContains reports whether any decoded model entry in values has an
// ID equal to target. Only the ID field is compared; ModelName is ignored.
func modelListContains(values []struct {
	ID        string `json:"id"`
	ModelName string `json:"modelName"`
}, target string) bool {
	for i := range values {
		if values[i].ID == target {
			return true
		}
	}
	return false
}
|
||||
|
||||
// taskListContains reports whether any decoded task record in values has an
// ID equal to target. All other fields of the record are ignored.
func taskListContains(values []struct {
	ID           string  `json:"id"`
	Status       string  `json:"status"`
	APIKeyName   string  `json:"apiKeyName"`
	ModelType    string  `json:"modelType"`
	FinalCharge  float64 `json:"finalChargeAmount"`
	ErrorCode    string  `json:"errorCode"`
	ErrorMessage string  `json:"errorMessage"`
}, target string) bool {
	for i := range values {
		if values[i].ID == target {
			return true
		}
	}
	return false
}
|
||||
|
||||
// floatNear reports whether value is within an absolute tolerance of 1e-6
// (exclusive) of expected. It is intended for comparing monetary/price
// amounts in tests where exact float equality is too strict.
func floatNear(value float64, expected float64) bool {
	const tolerance = 0.000001
	delta := math.Abs(value - expected)
	return delta < tolerance
}
|
||||
|
||||
@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -468,6 +469,10 @@ func (s *Server) estimatePricing(w http.ResponseWriter, r *http.Request) {
|
||||
writeError(w, http.StatusBadRequest, "model is required")
|
||||
return
|
||||
}
|
||||
if !apiKeyScopeAllowed(user, kind) {
|
||||
writeError(w, http.StatusForbidden, "api key scope does not allow this capability")
|
||||
return
|
||||
}
|
||||
estimate, err := s.runner.Estimate(r.Context(), kind, model, body, user)
|
||||
if err != nil {
|
||||
if errors.Is(err, store.ErrNoModelCandidate) {
|
||||
@ -509,6 +514,10 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler {
|
||||
writeError(w, http.StatusBadRequest, "model is required")
|
||||
return
|
||||
}
|
||||
if !apiKeyScopeAllowed(user, kind) {
|
||||
writeError(w, http.StatusForbidden, "api key scope does not allow this capability")
|
||||
return
|
||||
}
|
||||
|
||||
task, err := s.store.CreateTask(r.Context(), store.CreateTaskInput{
|
||||
Kind: kind,
|
||||
@ -567,6 +576,36 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler {
|
||||
})
|
||||
}
|
||||
|
||||
func apiKeyScopeAllowed(user *auth.User, kind string) bool {
|
||||
if user == nil || strings.TrimSpace(user.APIKeyID) == "" || len(user.APIKeyScopes) == 0 {
|
||||
return true
|
||||
}
|
||||
required := scopeForTaskKind(kind)
|
||||
for _, scope := range user.APIKeyScopes {
|
||||
scope = strings.TrimSpace(strings.ToLower(scope))
|
||||
if scope == "*" || scope == "all" || scope == required {
|
||||
return true
|
||||
}
|
||||
if required == "chat" && (scope == "text" || scope == "text_generate") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// scopeForTaskKind maps a task kind to the API-key scope name that guards it.
//
// Chat-style kinds map to "chat", image kinds to "image" and video generation
// to "video". Any other kind is returned unchanged, so an unknown kind
// requires a scope with exactly that name.
func scopeForTaskKind(kind string) string {
	if kind == "chat.completions" || kind == "responses" {
		return "chat"
	}
	if kind == "images.generations" || kind == "images.edits" {
		return "image"
	}
	if kind == "videos.generations" {
		return "video"
	}
	return kind
}
|
||||
|
||||
func statusFromRunError(err error) int {
|
||||
switch {
|
||||
case errors.Is(err, store.ErrNoModelCandidate):
|
||||
@ -578,6 +617,30 @@ func statusFromRunError(err error) int {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) listTasks(w http.ResponseWriter, r *http.Request) {
|
||||
user, ok := auth.UserFromContext(r.Context())
|
||||
if !ok {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized")
|
||||
return
|
||||
}
|
||||
limit := 50
|
||||
if raw := strings.TrimSpace(r.URL.Query().Get("limit")); raw != "" {
|
||||
parsed, err := strconv.Atoi(raw)
|
||||
if err != nil || parsed <= 0 {
|
||||
writeError(w, http.StatusBadRequest, "invalid limit")
|
||||
return
|
||||
}
|
||||
limit = parsed
|
||||
}
|
||||
tasks, err := s.store.ListTasks(r.Context(), user, limit)
|
||||
if err != nil {
|
||||
s.logger.Error("list tasks failed", "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "list tasks failed")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"items": tasks})
|
||||
}
|
||||
|
||||
func boolValue(body map[string]any, key string) bool {
|
||||
value, _ := body[key].(bool)
|
||||
return value
|
||||
|
||||
@ -101,6 +101,7 @@ func NewServer(cfg config.Config, db *store.Store, logger *slog.Logger) http.Han
|
||||
mux.Handle("POST /api/v1/images/generations", server.auth.Require(auth.PermissionBasic, server.createTask("images.generations", false)))
|
||||
mux.Handle("POST /api/v1/images/edits", server.auth.Require(auth.PermissionBasic, server.createTask("images.edits", false)))
|
||||
mux.Handle("POST /api/v1/videos/generations", server.auth.Require(auth.PermissionBasic, server.createTask("videos.generations", false)))
|
||||
mux.Handle("GET /api/v1/tasks", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.listTasks)))
|
||||
mux.Handle("GET /api/v1/tasks/{taskID}", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.getTask)))
|
||||
mux.Handle("GET /api/v1/tasks/{taskID}/events", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskEvents)))
|
||||
mux.Handle("POST /chat/completions", server.auth.Require(auth.PermissionBasic, server.createTask("chat.completions", true)))
|
||||
|
||||
@ -18,15 +18,36 @@ func (s *Service) rateLimitReservations(ctx context.Context, user *auth.User, ca
|
||||
}
|
||||
|
||||
func effectiveRateLimitPolicy(candidate store.RuntimeModelCandidate) map[string]any {
|
||||
if hasRules(candidate.ModelRateLimitPolicy) {
|
||||
return candidate.ModelRateLimitPolicy
|
||||
policy := candidate.PlatformRateLimitPolicy
|
||||
if hasRules(candidate.RuntimeRateLimitPolicy) {
|
||||
policy = mergeMap(policy, candidate.RuntimeRateLimitPolicy)
|
||||
}
|
||||
if hasRules(candidate.PlatformRateLimitPolicy) {
|
||||
return candidate.PlatformRateLimitPolicy
|
||||
if nested, ok := candidate.RuntimePolicyOverride["rateLimitPolicy"].(map[string]any); ok && len(nested) > 0 {
|
||||
policy = mergeMap(policy, nested)
|
||||
}
|
||||
if hasRules(candidate.ModelRateLimitPolicy) {
|
||||
policy = mergeMap(policy, candidate.ModelRateLimitPolicy)
|
||||
}
|
||||
if hasRules(policy) {
|
||||
return policy
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func effectiveRetryPolicy(candidate store.RuntimeModelCandidate) map[string]any {
|
||||
policy := candidate.PlatformRetryPolicy
|
||||
if len(candidate.RuntimeRetryPolicy) > 0 {
|
||||
policy = mergeMap(policy, candidate.RuntimeRetryPolicy)
|
||||
}
|
||||
if nested, ok := candidate.RuntimePolicyOverride["retryPolicy"].(map[string]any); ok && len(nested) > 0 {
|
||||
policy = mergeMap(policy, nested)
|
||||
}
|
||||
if len(candidate.ModelRetryPolicy) > 0 {
|
||||
policy = mergeMap(policy, candidate.ModelRetryPolicy)
|
||||
}
|
||||
return policy
|
||||
}
|
||||
|
||||
func reservationsFromPolicy(scopeType string, scopeKey string, policy map[string]any, body map[string]any) []store.RateLimitReservation {
|
||||
if scopeKey == "" || !hasRules(policy) {
|
||||
return nil
|
||||
|
||||
@ -16,7 +16,7 @@ type EstimateResult struct {
|
||||
}
|
||||
|
||||
func (s *Service) Estimate(ctx context.Context, kind string, model string, body map[string]any, user *auth.User) (EstimateResult, error) {
|
||||
candidates, err := s.store.ListModelCandidates(ctx, model, modelTypeFromKind(kind), user)
|
||||
candidates, err := s.store.ListModelCandidates(ctx, model, modelTypeFromKind(kind, body), user)
|
||||
if err != nil {
|
||||
return EstimateResult{}, err
|
||||
}
|
||||
@ -38,7 +38,7 @@ func (s *Service) Estimate(ctx context.Context, kind string, model string, body
|
||||
}
|
||||
|
||||
func (s *Service) billings(ctx context.Context, user *auth.User, kind string, body map[string]any, candidate store.RuntimeModelCandidate, response clients.Response, simulated bool) []any {
|
||||
config := effectiveBillingConfig(candidate)
|
||||
config := s.effectiveBillingConfig(ctx, candidate)
|
||||
discount := effectiveDiscount(ctx, s.store, user, candidate)
|
||||
if isTextGenerationKind(kind) {
|
||||
inputTokens := response.Usage.InputTokens
|
||||
@ -74,11 +74,21 @@ func (s *Service) billings(ctx context.Context, user *auth.User, kind string, bo
|
||||
return []any{billingLine(candidate, resource, unit, count, roundPrice(amount), discount, simulated)}
|
||||
}
|
||||
|
||||
func effectiveBillingConfig(candidate store.RuntimeModelCandidate) map[string]any {
|
||||
func (s *Service) effectiveBillingConfig(ctx context.Context, candidate store.RuntimeModelCandidate) map[string]any {
|
||||
base := candidate.BaseBillingConfig
|
||||
if ruleSetID := firstNonEmptyString(candidate.BasePricingRuleSetID, candidate.PlatformPricingRuleSetID); ruleSetID != "" {
|
||||
if ruleSetConfig, err := s.store.PricingRuleSetBillingConfig(ctx, ruleSetID); err == nil && len(ruleSetConfig) > 0 {
|
||||
base = ruleSetConfig
|
||||
}
|
||||
}
|
||||
if len(candidate.BillingConfig) > 0 {
|
||||
base = candidate.BillingConfig
|
||||
}
|
||||
if candidate.ModelPricingRuleSetID != "" {
|
||||
if ruleSetConfig, err := s.store.PricingRuleSetBillingConfig(ctx, candidate.ModelPricingRuleSetID); err == nil && len(ruleSetConfig) > 0 {
|
||||
base = ruleSetConfig
|
||||
}
|
||||
}
|
||||
if len(candidate.BillingConfigOverride) > 0 {
|
||||
base = mergeMap(base, candidate.BillingConfigOverride)
|
||||
}
|
||||
|
||||
84
apps/api/internal/runner/runtime_policy.go
Normal file
84
apps/api/internal/runner/runtime_policy.go
Normal file
@ -0,0 +1,84 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||
)
|
||||
|
||||
func (s *Service) applyCandidateFailurePolicies(ctx context.Context, taskID string, candidate store.RuntimeModelCandidate, cause error, simulated bool) {
|
||||
code := clients.ErrorCode(cause)
|
||||
message := ""
|
||||
if cause != nil {
|
||||
message = cause.Error()
|
||||
}
|
||||
|
||||
autoDisablePolicy := effectiveRuntimePolicy(candidate.AutoDisablePolicy, candidate.RuntimePolicyOverride, "autoDisablePolicy")
|
||||
if failurePolicyMatches(autoDisablePolicy, code, message) && intFromPolicy(autoDisablePolicy, "threshold") <= 1 {
|
||||
if err := s.store.DisableCandidatePlatform(ctx, candidate.PlatformID); err == nil {
|
||||
_ = s.emit(ctx, taskID, "task.policy.auto_disabled", "running", "auto_disable", 0.48, "candidate platform disabled by failure policy", map[string]any{
|
||||
"platformId": candidate.PlatformID,
|
||||
"platformModelId": candidate.PlatformModelID,
|
||||
"code": code,
|
||||
}, simulated)
|
||||
}
|
||||
}
|
||||
|
||||
degradePolicy := effectiveRuntimePolicy(candidate.DegradePolicy, candidate.RuntimePolicyOverride, "degradePolicy")
|
||||
if failurePolicyMatches(degradePolicy, code, message) {
|
||||
cooldownSeconds := intFromPolicy(degradePolicy, "cooldownSeconds")
|
||||
if err := s.store.CooldownCandidatePlatform(ctx, candidate.PlatformID, cooldownSeconds); err == nil {
|
||||
_ = s.emit(ctx, taskID, "task.policy.degraded", "running", "degrade", 0.5, "candidate platform cooled down by failure policy", map[string]any{
|
||||
"platformId": candidate.PlatformID,
|
||||
"platformModelId": candidate.PlatformModelID,
|
||||
"cooldownSeconds": cooldownSeconds,
|
||||
"code": code,
|
||||
}, simulated)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func effectiveRuntimePolicy(base map[string]any, override map[string]any, key string) map[string]any {
|
||||
policy := base
|
||||
if nested, ok := override[key].(map[string]any); ok && len(nested) > 0 {
|
||||
policy = mergeMap(policy, nested)
|
||||
}
|
||||
return policy
|
||||
}
|
||||
|
||||
func failurePolicyMatches(policy map[string]any, code string, message string) bool {
|
||||
if len(policy) == 0 || !boolFromMap(policy, "enabled") {
|
||||
return false
|
||||
}
|
||||
keywords := stringListFromPolicy(policy, "keywords")
|
||||
if len(keywords) == 0 {
|
||||
return false
|
||||
}
|
||||
target := strings.ToLower(strings.TrimSpace(code + " " + message))
|
||||
for _, keyword := range keywords {
|
||||
keyword = strings.ToLower(strings.TrimSpace(keyword))
|
||||
if keyword != "" && strings.Contains(target, keyword) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// stringListFromPolicy extracts the list of strings stored under key in
// values.
//
// Two encodings are accepted: []any (the shape produced by decoding JSON into
// map[string]any), from which non-string items are dropped, and []string.
// In both cases entries that are empty or whitespace-only are removed and a
// freshly allocated slice is returned, so callers never alias the policy's
// own backing array. Entries are returned untrimmed. Any other value type, or
// a missing key, yields nil.
func stringListFromPolicy(values map[string]any, key string) []string {
	switch raw := values[key].(type) {
	case []any:
		out := make([]string, 0, len(raw))
		for _, item := range raw {
			if text, ok := item.(string); ok && strings.TrimSpace(text) != "" {
				out = append(out, text)
			}
		}
		return out
	case []string:
		// Apply the same blank filtering as the []any form instead of
		// handing back the caller's slice verbatim.
		out := make([]string, 0, len(raw))
		for _, text := range raw {
			if strings.TrimSpace(text) != "" {
				out = append(out, text)
			}
		}
		return out
	default:
		return nil
	}
}
|
||||
@ -50,8 +50,8 @@ func (s *Service) ExecuteStream(ctx context.Context, task store.GatewayTask, use
|
||||
}
|
||||
|
||||
func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *auth.User, onDelta clients.StreamDelta) (Result, error) {
|
||||
modelType := modelTypeFromKind(task.Kind)
|
||||
body := normalizeRequest(task.Kind, task.Request)
|
||||
modelType := modelTypeFromKind(task.Kind, body)
|
||||
if err := validateRequest(task.Kind, body); err != nil {
|
||||
failed, finishErr := s.failTask(ctx, task.ID, "bad_request", err.Error(), task.RunMode == "simulation", err)
|
||||
if finishErr != nil {
|
||||
@ -102,6 +102,17 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
|
||||
if finishErr != nil {
|
||||
return Result{}, finishErr
|
||||
}
|
||||
if settleErr := s.store.SettleTaskBilling(ctx, finished); settleErr != nil {
|
||||
return Result{}, settleErr
|
||||
}
|
||||
if finished.FinalChargeAmount > 0 {
|
||||
if err := s.emit(ctx, task.ID, "task.billing.settled", "succeeded", "billing", 0.98, "task billing settled", map[string]any{
|
||||
"amount": finished.FinalChargeAmount,
|
||||
"currency": stringFromAny(record.BillingSummary["currency"]),
|
||||
}, isSimulation(task, candidate)); err != nil {
|
||||
return Result{}, err
|
||||
}
|
||||
}
|
||||
if err := s.emit(ctx, task.ID, "task.completed", "succeeded", "completed", 1, "task completed", map[string]any{
|
||||
"result": response.Result,
|
||||
"billings": billings,
|
||||
@ -226,6 +237,7 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user
|
||||
ErrorMessage: err.Error(),
|
||||
})
|
||||
_ = s.emit(ctx, task.ID, "task.attempt.failed", "running", "attempt_failed", 0.45, err.Error(), map[string]any{"attempt": attemptNo, "retryable": retryable, "requestId": requestID, "metrics": metrics}, simulated)
|
||||
s.applyCandidateFailurePolicies(ctx, task.ID, candidate, err, simulated)
|
||||
return clients.Response{}, err
|
||||
}
|
||||
uploadedResult, err := s.uploadGeneratedAssets(ctx, response.Result)
|
||||
@ -317,7 +329,7 @@ func (s *Service) emit(ctx context.Context, taskID string, eventType string, sta
|
||||
return nil
|
||||
}
|
||||
|
||||
func modelTypeFromKind(kind string) string {
|
||||
func modelTypeFromKind(kind string, body map[string]any) string {
|
||||
switch kind {
|
||||
case "chat.completions", "responses":
|
||||
return "text_generate"
|
||||
@ -327,12 +339,30 @@ func modelTypeFromKind(kind string) string {
|
||||
}
|
||||
return "image_generate"
|
||||
case "videos.generations":
|
||||
if videoRequestHasReferenceImage(body) {
|
||||
return "image_to_video"
|
||||
}
|
||||
return "video_generate"
|
||||
default:
|
||||
return "task"
|
||||
}
|
||||
}
|
||||
|
||||
func videoRequestHasReferenceImage(body map[string]any) bool {
|
||||
if body == nil {
|
||||
return false
|
||||
}
|
||||
for _, key := range []string{
|
||||
"image", "images", "image_url", "imageUrl", "image_urls", "imageUrls",
|
||||
"reference_image", "referenceImage", "first_frame", "firstFrame", "last_frame", "lastFrame",
|
||||
} {
|
||||
if hasAnyString(body, key) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isTextGenerationKind reports whether kind is one of the text generation
// task kinds ("chat.completions" or "responses"). The match is exact.
func isTextGenerationKind(kind string) bool {
	switch kind {
	case "chat.completions", "responses":
		return true
	default:
		return false
	}
}
|
||||
@ -345,10 +375,8 @@ func isSimulation(task store.GatewayTask, candidate store.RuntimeModelCandidate)
|
||||
}
|
||||
|
||||
func retryEnabled(candidate store.RuntimeModelCandidate) bool {
|
||||
if enabled, ok := candidate.ModelRetryPolicy["enabled"].(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
if enabled, ok := candidate.PlatformRetryPolicy["enabled"].(bool); ok {
|
||||
policy := effectiveRetryPolicy(candidate)
|
||||
if enabled, ok := policy["enabled"].(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
return true
|
||||
@ -360,10 +388,7 @@ func maxAttemptsForCandidates(candidates []store.RuntimeModelCandidate) int {
|
||||
}
|
||||
maxAttempts := len(candidates)
|
||||
for _, candidate := range candidates {
|
||||
if value := intFromPolicy(candidate.ModelRetryPolicy, "maxAttempts"); value > 0 && value < maxAttempts {
|
||||
maxAttempts = value
|
||||
}
|
||||
if value := intFromPolicy(candidate.PlatformRetryPolicy, "maxAttempts"); value > 0 && value < maxAttempts {
|
||||
if value := intFromPolicy(effectiveRetryPolicy(candidate), "maxAttempts"); value > 0 && value < maxAttempts {
|
||||
maxAttempts = value
|
||||
}
|
||||
}
|
||||
|
||||
@ -59,7 +59,7 @@ func (s *Store) ListBaseModels(ctx context.Context) ([]BaseModel, error) {
|
||||
rows, err := s.pool.Query(ctx, `
|
||||
SELECT `+baseModelColumns+`
|
||||
FROM base_model_catalog
|
||||
ORDER BY provider_key ASC, model_type ASC, canonical_model_key ASC`)
|
||||
ORDER BY provider_key ASC, canonical_model_key ASC`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -84,7 +84,7 @@ func (s *Store) CreateBaseModel(ctx context.Context, input BaseModelInput) (Base
|
||||
runtimePolicyOverride, _ := json.Marshal(emptyObjectIfNil(input.RuntimePolicyOverride))
|
||||
metadata, _ := json.Marshal(emptyObjectIfNil(input.Metadata))
|
||||
defaultSnapshot, _ := json.Marshal(emptyObjectIfNil(input.DefaultSnapshot))
|
||||
modelType := primaryString(input.ModelType, "text_generate")
|
||||
modelType, _ := json.Marshal(input.ModelType)
|
||||
|
||||
return scanBaseModel(s.pool.QueryRow(ctx, `
|
||||
INSERT INTO base_model_catalog (
|
||||
@ -94,7 +94,7 @@ INSERT INTO base_model_catalog (
|
||||
)
|
||||
VALUES (
|
||||
(SELECT id FROM model_catalog_providers WHERE provider_key = $1 OR provider_code = $1 LIMIT 1),
|
||||
$1, $2, $3, $4, $5, $6, $7, $8,
|
||||
$1, $2, $3, $4::jsonb, $5, $6, $7, $8,
|
||||
COALESCE(NULLIF($9, '')::uuid, (SELECT id FROM model_pricing_rule_sets WHERE rule_set_key = 'default-multimodal-v1' LIMIT 1)),
|
||||
COALESCE(NULLIF($10, '')::uuid, (SELECT id FROM model_runtime_policy_sets WHERE policy_key = 'default-runtime-v1' LIMIT 1)),
|
||||
$11, $12, NULLIF($13, ''), NULLIF($14::jsonb, '{}'::jsonb), $15, $16
|
||||
@ -103,7 +103,7 @@ RETURNING `+baseModelColumns,
|
||||
input.ProviderKey,
|
||||
input.CanonicalModelKey,
|
||||
input.ProviderModelName,
|
||||
modelType,
|
||||
string(modelType),
|
||||
input.ModelAlias,
|
||||
capabilities,
|
||||
billingConfig,
|
||||
@ -127,7 +127,7 @@ func (s *Store) UpdateBaseModel(ctx context.Context, id string, input BaseModelI
|
||||
runtimePolicyOverride, _ := json.Marshal(emptyObjectIfNil(input.RuntimePolicyOverride))
|
||||
metadata, _ := json.Marshal(emptyObjectIfNil(input.Metadata))
|
||||
defaultSnapshot, _ := json.Marshal(emptyObjectIfNil(input.DefaultSnapshot))
|
||||
modelType := primaryString(input.ModelType, "text_generate")
|
||||
modelType, _ := json.Marshal(input.ModelType)
|
||||
|
||||
return scanBaseModel(s.pool.QueryRow(ctx, `
|
||||
UPDATE base_model_catalog
|
||||
@ -135,7 +135,7 @@ SET provider_id = (SELECT id FROM model_catalog_providers WHERE provider_key = $
|
||||
provider_key = $2,
|
||||
canonical_model_key = $3,
|
||||
provider_model_name = $4,
|
||||
model_type = $5,
|
||||
model_type = $5::jsonb,
|
||||
display_name = $6,
|
||||
capabilities = $7,
|
||||
base_billing_config = $8,
|
||||
@ -156,7 +156,7 @@ RETURNING `+baseModelColumns,
|
||||
input.ProviderKey,
|
||||
input.CanonicalModelKey,
|
||||
input.ProviderModelName,
|
||||
modelType,
|
||||
string(modelType),
|
||||
input.ModelAlias,
|
||||
capabilities,
|
||||
billingConfig,
|
||||
@ -195,7 +195,7 @@ SET provider_id = (SELECT id FROM model_catalog_providers WHERE provider_key = C
|
||||
provider_key = COALESCE($2::text, provider_key),
|
||||
canonical_model_key = COALESCE($3::text, canonical_model_key),
|
||||
provider_model_name = COALESCE($4::text, provider_model_name),
|
||||
model_type = COALESCE($5::text, model_type),
|
||||
model_type = COALESCE($5::jsonb, model_type),
|
||||
display_name = COALESCE($6::text, display_name),
|
||||
capabilities = COALESCE($7::jsonb, capabilities),
|
||||
base_billing_config = COALESCE($8::jsonb, base_billing_config),
|
||||
@ -214,7 +214,7 @@ RETURNING `+baseModelColumns,
|
||||
stringFromSnapshot(snapshot, "providerKey"),
|
||||
stringFromSnapshot(snapshot, "canonicalModelKey"),
|
||||
stringFromSnapshot(snapshot, "providerModelName"),
|
||||
stringFromSnapshot(snapshot, "modelType"),
|
||||
jsonStringListFromSnapshot(snapshot, "modelType"),
|
||||
stringFromSnapshot(snapshot, "modelAlias", "displayName"),
|
||||
jsonFromSnapshot(snapshot, "capabilities"),
|
||||
jsonFromSnapshot(snapshot, "baseBillingConfig"),
|
||||
@ -242,9 +242,10 @@ SET provider_id = (
|
||||
canonical_model_key = COALESCE(NULLIF(default_snapshot->>'canonicalModelKey', ''), canonical_model_key),
|
||||
provider_model_name = COALESCE(NULLIF(default_snapshot->>'providerModelName', ''), provider_model_name),
|
||||
model_type = COALESCE(NULLIF(CASE
|
||||
WHEN jsonb_typeof(default_snapshot->'modelType') = 'array' THEN default_snapshot->'modelType'->>0
|
||||
ELSE default_snapshot->>'modelType'
|
||||
END, ''), model_type),
|
||||
WHEN jsonb_typeof(default_snapshot->'modelType') = 'array' THEN default_snapshot->'modelType'
|
||||
WHEN COALESCE(default_snapshot->>'modelType', '') <> '' THEN jsonb_build_array(default_snapshot->>'modelType')
|
||||
ELSE NULL
|
||||
END, '[]'::jsonb), model_type),
|
||||
display_name = COALESCE(NULLIF(COALESCE(default_snapshot->>'modelAlias', default_snapshot->>'displayName'), ''), display_name),
|
||||
capabilities = COALESCE(default_snapshot->'capabilities', capabilities),
|
||||
base_billing_config = COALESCE(default_snapshot->'baseBillingConfig', base_billing_config),
|
||||
@ -292,7 +293,7 @@ func scanBaseModelRows(rows pgx.Rows) ([]BaseModel, error) {
|
||||
|
||||
func scanBaseModel(scanner baseModelScanner) (BaseModel, error) {
|
||||
var item BaseModel
|
||||
var modelType string
|
||||
var modelType []byte
|
||||
var modelAlias string
|
||||
var capabilities []byte
|
||||
var billingConfig []byte
|
||||
@ -330,7 +331,7 @@ func scanBaseModel(scanner baseModelScanner) (BaseModel, error) {
|
||||
item.RuntimePolicyOverride = decodeObject(runtimePolicyOverride)
|
||||
item.Metadata = decodeObject(metadata)
|
||||
item.DefaultSnapshot = decodeObject(defaultSnapshot)
|
||||
item.ModelType = baseModelTypes(item.Capabilities, item.Metadata, modelType)
|
||||
item.ModelType = decodeStringArray(modelType)
|
||||
item.ModelAlias = modelAlias
|
||||
item.DisplayName = modelAlias
|
||||
return item, nil
|
||||
@ -420,14 +421,22 @@ func jsonFromSnapshot(snapshot map[string]any, key string) any {
|
||||
return string(raw)
|
||||
}
|
||||
|
||||
func baseModelTypes(capabilities map[string]any, metadata map[string]any, fallback string) StringList {
|
||||
values := make([]string, 0)
|
||||
values = append(values, stringListFromAny(capabilities["originalTypes"])...)
|
||||
values = append(values, stringListFromAny(metadata["originalTypes"])...)
|
||||
if fallback != "" {
|
||||
values = append(values, fallback)
|
||||
func jsonStringListFromSnapshot(snapshot map[string]any, key string) any {
|
||||
values := stringListFromAny(snapshot[key])
|
||||
if len(values) == 0 {
|
||||
if value, ok := snapshot[key].(string); ok {
|
||||
values = []string{value}
|
||||
}
|
||||
}
|
||||
return uniqueStringList(values)
|
||||
normalized := uniqueStringList(values)
|
||||
if len(normalized) == 0 {
|
||||
return nil
|
||||
}
|
||||
raw, err := json.Marshal(normalized)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return string(raw)
|
||||
}
|
||||
|
||||
func stringListFromAny(value any) []string {
|
||||
@ -451,16 +460,39 @@ func uniqueStringList(values []string) StringList {
|
||||
out := make([]string, 0, len(values))
|
||||
seen := map[string]bool{}
|
||||
for _, value := range values {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" || seen[value] {
|
||||
continue
|
||||
for _, normalized := range modelTypeAliases(value) {
|
||||
normalized = strings.TrimSpace(normalized)
|
||||
if normalized == "" || seen[normalized] {
|
||||
continue
|
||||
}
|
||||
seen[normalized] = true
|
||||
out = append(out, normalized)
|
||||
}
|
||||
seen[value] = true
|
||||
out = append(out, value)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeModelTypeList(values []string) StringList {
|
||||
return uniqueStringList(values)
|
||||
}
|
||||
|
||||
// modelTypeAliases expands a model type name into its canonical model types.
//
// Legacy and endpoint-style names are translated ("chat", "text" and
// "responses" become "text_generate"; "image" expands to both
// "image_generate" and "image_edit"; and so on). Surrounding whitespace is
// ignored when recognising an alias. A name that is not a known alias is
// returned as a single-element slice containing the value exactly as passed
// in (untrimmed).
func modelTypeAliases(value string) []string {
	alias := strings.TrimSpace(value)

	if alias == "chat" || alias == "text" || alias == "responses" {
		return []string{"text_generate"}
	}
	if alias == "image" {
		return []string{"image_generate", "image_edit"}
	}
	if alias == "images.generations" {
		return []string{"image_generate"}
	}
	if alias == "images.edits" {
		return []string{"image_edit"}
	}
	if alias == "video" || alias == "videos.generations" {
		return []string{"video_generate"}
	}
	return []string{value}
}
|
||||
|
||||
func primaryString(values []string, fallback string) string {
|
||||
for _, value := range values {
|
||||
if value = strings.TrimSpace(value); value != "" {
|
||||
|
||||
@ -18,27 +18,25 @@ SELECT p.id::text, p.platform_key, p.name, p.provider,
|
||||
COALESCE(p.dynamic_priority, p.priority) AS effective_priority,
|
||||
m.id::text, COALESCE(m.base_model_id::text, ''), COALESCE(b.canonical_model_key, ''),
|
||||
COALESCE(b.provider_model_name, ''), m.model_name, COALESCE(m.model_alias, ''),
|
||||
m.model_type, m.display_name, m.capabilities, m.capability_override,
|
||||
$2 AS requested_model_type, m.display_name, m.capabilities, m.capability_override,
|
||||
COALESCE(b.base_billing_config, '{}'::jsonb), m.billing_config, m.billing_config_override,
|
||||
m.pricing_mode, COALESCE(m.discount_factor, 0)::float8, COALESCE(m.pricing_rule_set_id::text, ''),
|
||||
m.permission_config, m.retry_policy, m.rate_limit_policy, COALESCE(m.runtime_policy_set_id::text, COALESCE(b.runtime_policy_set_id::text, '')),
|
||||
COALESCE(NULLIF(m.runtime_policy_override, '{}'::jsonb), b.runtime_policy_override, '{}'::jsonb)
|
||||
COALESCE(b.pricing_rule_set_id::text, ''),
|
||||
m.permission_config, m.retry_policy, m.rate_limit_policy, COALESCE(m.runtime_policy_set_id::text, COALESCE(b.runtime_policy_set_id::text, '')),
|
||||
COALESCE(NULLIF(m.runtime_policy_override, '{}'::jsonb), b.runtime_policy_override, '{}'::jsonb),
|
||||
COALESCE(rp.retry_policy, '{}'::jsonb), COALESCE(rp.rate_limit_policy, '{}'::jsonb),
|
||||
COALESCE(rp.auto_disable_policy, '{}'::jsonb), COALESCE(rp.degrade_policy, '{}'::jsonb)
|
||||
FROM platform_models m
|
||||
JOIN integration_platforms p ON p.id = m.platform_id
|
||||
LEFT JOIN model_catalog_providers cp ON cp.provider_key = p.provider OR cp.provider_code = p.provider
|
||||
LEFT JOIN base_model_catalog b ON b.id = m.base_model_id
|
||||
LEFT JOIN model_runtime_policy_sets rp ON rp.id = COALESCE(m.runtime_policy_set_id, b.runtime_policy_set_id)
|
||||
LEFT JOIN runtime_client_states s
|
||||
ON s.client_id = p.platform_key || ':' || m.model_type || ':' || m.model_name
|
||||
ON s.client_id = p.platform_key || ':' || $2 || ':' || m.model_name
|
||||
WHERE p.status = 'enabled'
|
||||
AND p.deleted_at IS NULL
|
||||
AND m.enabled = true
|
||||
AND (
|
||||
m.model_type = $2
|
||||
OR ($2 = 'text_generate' AND m.model_type IN ('chat', 'responses', 'text'))
|
||||
OR ($2 = 'image_generate' AND m.model_type IN ('image', 'images.generations'))
|
||||
OR ($2 = 'image_edit' AND m.model_type IN ('images.edits'))
|
||||
OR ($2 = 'video_generate' AND m.model_type IN ('video', 'videos.generations', 'video_generate', 'text_to_video', 'image_to_video', 'omni_video', 'video_edit', 'video_reference', 'video_first_last_frame'))
|
||||
)
|
||||
AND m.model_type @> jsonb_build_array($2)
|
||||
AND (p.cooldown_until IS NULL OR p.cooldown_until <= now())
|
||||
AND (
|
||||
m.model_name = $1
|
||||
@ -73,6 +71,10 @@ ORDER BY effective_priority ASC,
|
||||
var modelRetryPolicy []byte
|
||||
var modelRateLimitPolicy []byte
|
||||
var runtimePolicyOverride []byte
|
||||
var runtimeRetryPolicy []byte
|
||||
var runtimeRateLimitPolicy []byte
|
||||
var autoDisablePolicy []byte
|
||||
var degradePolicy []byte
|
||||
if err := rows.Scan(
|
||||
&item.PlatformID,
|
||||
&item.PlatformKey,
|
||||
@ -105,11 +107,16 @@ ORDER BY effective_priority ASC,
|
||||
&item.PricingMode,
|
||||
&item.DiscountFactor,
|
||||
&item.ModelPricingRuleSetID,
|
||||
&item.BasePricingRuleSetID,
|
||||
&permissionConfig,
|
||||
&modelRetryPolicy,
|
||||
&modelRateLimitPolicy,
|
||||
&item.RuntimePolicySetID,
|
||||
&runtimePolicyOverride,
|
||||
&runtimeRetryPolicy,
|
||||
&runtimeRateLimitPolicy,
|
||||
&autoDisablePolicy,
|
||||
&degradePolicy,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -126,6 +133,10 @@ ORDER BY effective_priority ASC,
|
||||
item.ModelRetryPolicy = decodeObject(modelRetryPolicy)
|
||||
item.ModelRateLimitPolicy = decodeObject(modelRateLimitPolicy)
|
||||
item.RuntimePolicyOverride = decodeObject(runtimePolicyOverride)
|
||||
item.RuntimeRetryPolicy = decodeObject(runtimeRetryPolicy)
|
||||
item.RuntimeRateLimitPolicy = decodeObject(runtimeRateLimitPolicy)
|
||||
item.AutoDisablePolicy = decodeObject(autoDisablePolicy)
|
||||
item.DegradePolicy = decodeObject(degradePolicy)
|
||||
item.ClientID = fmt.Sprintf("%s:%s:%s", item.PlatformKey, item.ModelType, item.ModelName)
|
||||
item.QueueKey = item.ClientID
|
||||
items = append(items, item)
|
||||
|
||||
@ -17,7 +17,7 @@ type modelCatalogSnapshot struct {
|
||||
ProviderKey string
|
||||
CanonicalModelKey string
|
||||
ProviderModelName string
|
||||
ModelType string
|
||||
ModelType StringList
|
||||
DisplayName string
|
||||
Capabilities map[string]any
|
||||
BaseBillingConfig map[string]any
|
||||
@ -83,9 +83,13 @@ func (s *Store) createPlatformModel(ctx context.Context, q platformModelQuerier,
|
||||
if err != nil && !IsNotFound(err) {
|
||||
return PlatformModel{}, err
|
||||
}
|
||||
if input.ModelType == "" {
|
||||
if len(input.ModelType) == 0 {
|
||||
input.ModelType = base.ModelType
|
||||
}
|
||||
input.ModelType = normalizeModelTypeList(input.ModelType)
|
||||
if len(input.ModelType) == 0 {
|
||||
input.ModelType = StringList{"text_generate"}
|
||||
}
|
||||
if input.ModelName == "" {
|
||||
input.ModelName = base.ProviderModelName
|
||||
}
|
||||
@ -104,11 +108,12 @@ func (s *Store) createPlatformModel(ctx context.Context, q platformModelQuerier,
|
||||
if len(billingConfig) == 0 {
|
||||
billingConfig = mergeObjects(base.BaseBillingConfig, input.BillingConfigOverride)
|
||||
}
|
||||
explicitRuntimePolicySetID := strings.TrimSpace(input.RuntimePolicySetID)
|
||||
rateLimitPolicy := input.RateLimitPolicy
|
||||
if len(rateLimitPolicy) == 0 {
|
||||
if len(rateLimitPolicy) == 0 && explicitRuntimePolicySetID == "" {
|
||||
rateLimitPolicy = base.DefaultRateLimitPolicy
|
||||
}
|
||||
runtimePolicySetID := strings.TrimSpace(input.RuntimePolicySetID)
|
||||
runtimePolicySetID := explicitRuntimePolicySetID
|
||||
if runtimePolicySetID == "" {
|
||||
runtimePolicySetID = base.RuntimePolicySetID
|
||||
}
|
||||
@ -119,6 +124,7 @@ func (s *Store) createPlatformModel(ctx context.Context, q platformModelQuerier,
|
||||
|
||||
capabilityOverrideJSON, _ := json.Marshal(emptyObjectIfNil(input.CapabilityOverride))
|
||||
capabilitiesJSON, _ := json.Marshal(emptyObjectIfNil(capabilities))
|
||||
modelTypeJSON, _ := json.Marshal(input.ModelType)
|
||||
billingOverrideJSON, _ := json.Marshal(emptyObjectIfNil(input.BillingConfigOverride))
|
||||
billingJSON, _ := json.Marshal(emptyObjectIfNil(billingConfig))
|
||||
permissionJSON, _ := json.Marshal(emptyObjectIfNil(input.PermissionConfig))
|
||||
@ -144,6 +150,7 @@ func (s *Store) createPlatformModel(ctx context.Context, q platformModelQuerier,
|
||||
var retryPolicyBytes []byte
|
||||
var rateLimitPolicyBytes []byte
|
||||
var runtimePolicyOverrideBytes []byte
|
||||
var modelTypeBytes []byte
|
||||
err = q.QueryRow(ctx, `
|
||||
INSERT INTO platform_models (
|
||||
platform_id, base_model_id, model_name, model_alias, model_type, display_name,
|
||||
@ -152,12 +159,12 @@ INSERT INTO platform_models (
|
||||
runtime_policy_set_id, runtime_policy_override, enabled
|
||||
)
|
||||
VALUES (
|
||||
$1::uuid, $2::uuid, $3, NULLIF($4, ''), $5, $6,
|
||||
$1::uuid, $2::uuid, $3, NULLIF($4, ''), $5::jsonb, $6,
|
||||
$7::jsonb, $8::jsonb, $9, $10::numeric,
|
||||
NULLIF($11, '')::uuid, $12::jsonb, $13::jsonb, $14::jsonb, $15::jsonb, $16::jsonb,
|
||||
NULLIF($17, '')::uuid, $18::jsonb, true
|
||||
)
|
||||
ON CONFLICT (platform_id, model_name, model_type) DO UPDATE
|
||||
ON CONFLICT (platform_id, model_name) DO UPDATE
|
||||
SET base_model_id = EXCLUDED.base_model_id,
|
||||
model_alias = EXCLUDED.model_alias,
|
||||
display_name = EXCLUDED.display_name,
|
||||
@ -185,7 +192,7 @@ RETURNING id::text, platform_id::text, COALESCE(base_model_id::text, ''), model_
|
||||
baseID,
|
||||
input.ModelName,
|
||||
input.ModelAlias,
|
||||
input.ModelType,
|
||||
string(modelTypeJSON),
|
||||
input.DisplayName,
|
||||
string(capabilityOverrideJSON),
|
||||
string(capabilitiesJSON),
|
||||
@ -205,7 +212,7 @@ RETURNING id::text, platform_id::text, COALESCE(base_model_id::text, ''), model_
|
||||
&model.BaseModelID,
|
||||
&model.ModelName,
|
||||
&model.ModelAlias,
|
||||
&model.ModelType,
|
||||
&modelTypeBytes,
|
||||
&model.DisplayName,
|
||||
&capabilityOverrideBytes,
|
||||
&capabilitiesBytes,
|
||||
@ -228,6 +235,7 @@ RETURNING id::text, platform_id::text, COALESCE(base_model_id::text, ''), model_
|
||||
}
|
||||
model.CapabilityOverride = decodeObject(capabilityOverrideBytes)
|
||||
model.Capabilities = decodeObject(capabilitiesBytes)
|
||||
model.ModelType = decodeStringArray(modelTypeBytes)
|
||||
model.BillingConfigOverride = decodeObject(billingOverrideBytes)
|
||||
model.BillingConfig = decodeObject(billingBytes)
|
||||
model.PermissionConfig = decodeObject(permissionBytes)
|
||||
@ -265,6 +273,7 @@ func (s *Store) lookupBaseModel(ctx context.Context, q platformModelQuerier, id
|
||||
var billingConfig []byte
|
||||
var rateLimitPolicy []byte
|
||||
var runtimePolicyOverride []byte
|
||||
var modelTypeBytes []byte
|
||||
err := q.QueryRow(ctx, `
|
||||
SELECT id::text, provider_key, canonical_model_key, provider_model_name, model_type, display_name,
|
||||
capabilities, base_billing_config, default_rate_limit_policy,
|
||||
@ -279,7 +288,7 @@ LIMIT 1`, strings.TrimSpace(id), strings.TrimSpace(canonicalKey), strings.TrimSp
|
||||
&item.ProviderKey,
|
||||
&item.CanonicalModelKey,
|
||||
&item.ProviderModelName,
|
||||
&item.ModelType,
|
||||
&modelTypeBytes,
|
||||
&item.DisplayName,
|
||||
&capabilities,
|
||||
&billingConfig,
|
||||
@ -297,6 +306,7 @@ LIMIT 1`, strings.TrimSpace(id), strings.TrimSpace(canonicalKey), strings.TrimSp
|
||||
item.BaseBillingConfig = decodeObject(billingConfig)
|
||||
item.DefaultRateLimitPolicy = decodeObject(rateLimitPolicy)
|
||||
item.RuntimePolicyOverride = decodeObject(runtimePolicyOverride)
|
||||
item.ModelType = normalizeModelTypeList(decodeStringArray(modelTypeBytes))
|
||||
return item, nil
|
||||
}
|
||||
|
||||
|
||||
@ -134,7 +134,7 @@ type PlatformModel struct {
|
||||
PlatformName string `json:"platformName,omitempty"`
|
||||
ModelName string `json:"modelName"`
|
||||
ModelAlias string `json:"modelAlias,omitempty"`
|
||||
ModelType string `json:"modelType"`
|
||||
ModelType StringList `json:"modelType"`
|
||||
DisplayName string `json:"displayName"`
|
||||
CapabilityOverride map[string]any `json:"capabilityOverride,omitempty"`
|
||||
Capabilities map[string]any `json:"capabilities,omitempty"`
|
||||
@ -390,6 +390,8 @@ type GatewayTask struct {
|
||||
ResponseDurationMS int64 `json:"responseDurationMs"`
|
||||
FinishedAt string `json:"finishedAt,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
ErrorCode string `json:"errorCode,omitempty"`
|
||||
ErrorMessage string `json:"errorMessage,omitempty"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
@ -404,6 +406,7 @@ request, status, COALESCE(result, '{}'::jsonb), COALESCE(billings, '[]'::jsonb),
|
||||
COALESCE(usage, '{}'::jsonb), COALESCE(metrics, '{}'::jsonb), COALESCE(billing_summary, '{}'::jsonb),
|
||||
COALESCE(final_charge_amount, 0)::float8, COALESCE(response_started_at::text, ''),
|
||||
COALESCE(response_finished_at::text, ''), COALESCE(response_duration_ms, 0), COALESCE(error, ''),
|
||||
COALESCE(error_code, ''), COALESCE(error_message, ''),
|
||||
created_at, updated_at, COALESCE(finished_at::text, '')`
|
||||
|
||||
type TaskEvent struct {
|
||||
@ -713,6 +716,7 @@ ORDER BY m.model_type ASC, m.model_name ASC`, args...)
|
||||
var retryPolicy []byte
|
||||
var rateLimitPolicy []byte
|
||||
var runtimePolicyOverride []byte
|
||||
var modelTypeBytes []byte
|
||||
if err := rows.Scan(
|
||||
&model.ID,
|
||||
&model.PlatformID,
|
||||
@ -721,7 +725,7 @@ ORDER BY m.model_type ASC, m.model_name ASC`, args...)
|
||||
&model.PlatformName,
|
||||
&model.ModelName,
|
||||
&model.ModelAlias,
|
||||
&model.ModelType,
|
||||
&modelTypeBytes,
|
||||
&model.DisplayName,
|
||||
&capabilityOverride,
|
||||
&capabilities,
|
||||
@ -743,6 +747,7 @@ ORDER BY m.model_type ASC, m.model_name ASC`, args...)
|
||||
}
|
||||
model.CapabilityOverride = decodeObject(capabilityOverride)
|
||||
model.Capabilities = decodeObject(capabilities)
|
||||
model.ModelType = decodeStringArray(modelTypeBytes)
|
||||
model.BillingConfigOverride = decodeObject(billingConfigOverride)
|
||||
model.BillingConfig = decodeObject(billingConfig)
|
||||
model.PermissionConfig = decodeObject(permissionConfig)
|
||||
@ -1231,7 +1236,7 @@ func (s *Store) VerifyLocalAPIKey(ctx context.Context, secret string) (*auth.Use
|
||||
return nil, auth.ErrUnauthorized
|
||||
}
|
||||
rows, err := s.pool.Query(ctx, `
|
||||
SELECT k.id::text, k.key_hash, k.key_prefix, k.name, COALESCE(k.user_group_id::text, u.default_user_group_id::text, ''),
|
||||
SELECT k.id::text, k.key_hash, k.key_prefix, k.name, k.scopes, COALESCE(k.user_group_id::text, u.default_user_group_id::text, ''),
|
||||
u.id::text, u.username, u.roles, COALESCE(u.gateway_tenant_id::text, ''),
|
||||
COALESCE(u.tenant_id, ''), COALESCE(u.tenant_key, '')
|
||||
FROM gateway_api_keys k
|
||||
@ -1252,6 +1257,7 @@ WHERE k.key_prefix = $1
|
||||
var hash string
|
||||
var keyPrefix string
|
||||
var keyName string
|
||||
var scopesBytes []byte
|
||||
var userGroupID string
|
||||
var gatewayUserID string
|
||||
var username string
|
||||
@ -1259,7 +1265,7 @@ WHERE k.key_prefix = $1
|
||||
var gatewayTenantID string
|
||||
var tenantID string
|
||||
var tenantKey string
|
||||
if err := rows.Scan(&apiKeyID, &hash, &keyPrefix, &keyName, &userGroupID, &gatewayUserID, &username, &rolesBytes, &gatewayTenantID, &tenantID, &tenantKey); err != nil {
|
||||
if err := rows.Scan(&apiKeyID, &hash, &keyPrefix, &keyName, &scopesBytes, &userGroupID, &gatewayUserID, &username, &rolesBytes, &gatewayTenantID, &tenantID, &tenantKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if bcrypt.CompareHashAndPassword([]byte(hash), []byte(secret)) != nil {
|
||||
@ -1281,6 +1287,7 @@ WHERE k.key_prefix = $1
|
||||
APIKeyID: apiKeyID,
|
||||
APIKeyName: keyName,
|
||||
APIKeyPrefix: keyPrefix,
|
||||
APIKeyScopes: decodeStringArray(scopesBytes),
|
||||
}, nil
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
@ -1661,6 +1668,8 @@ func scanGatewayTask(scanner taskScanner) (GatewayTask, error) {
|
||||
&task.ResponseFinishedAt,
|
||||
&task.ResponseDurationMS,
|
||||
&task.Error,
|
||||
&task.ErrorCode,
|
||||
&task.ErrorMessage,
|
||||
&task.CreatedAt,
|
||||
&task.UpdatedAt,
|
||||
&task.FinishedAt,
|
||||
|
||||
@ -176,6 +176,63 @@ func (s *Store) DeletePricingRuleSet(ctx context.Context, id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) PricingRuleSetBillingConfig(ctx context.Context, id string) (map[string]any, error) {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
return nil, nil
|
||||
}
|
||||
rows, err := s.pool.Query(ctx, `
|
||||
SELECT resource_type, base_price::float8, dynamic_weight
|
||||
FROM model_pricing_rules
|
||||
WHERE rule_set_id = $1::uuid
|
||||
AND status = 'active'
|
||||
ORDER BY priority ASC, resource_type ASC`, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
config := map[string]any{}
|
||||
for rows.Next() {
|
||||
var resourceType string
|
||||
var basePrice float64
|
||||
var dynamicWeightBytes []byte
|
||||
if err := rows.Scan(&resourceType, &basePrice, &dynamicWeightBytes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dynamicWeight := decodeObject(dynamicWeightBytes)
|
||||
switch resourceType {
|
||||
case "text_input":
|
||||
config["textInputPer1k"] = basePrice
|
||||
case "text_output":
|
||||
config["textOutputPer1k"] = basePrice
|
||||
case "image":
|
||||
config["imageBase"] = basePrice
|
||||
config["image"] = pricingResourceConfig(basePrice, dynamicWeight)
|
||||
case "image_edit":
|
||||
config["editBase"] = basePrice
|
||||
config["image_edit"] = pricingResourceConfig(basePrice, dynamicWeight)
|
||||
case "video":
|
||||
config["videoBase"] = basePrice
|
||||
config["video"] = pricingResourceConfig(basePrice, dynamicWeight)
|
||||
default:
|
||||
config[resourceType] = pricingResourceConfig(basePrice, dynamicWeight)
|
||||
}
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// pricingResourceConfig builds the nested billing entry for one resource:
// always a "basePrice", plus "dynamicWeight" when the weight map is non-empty.
func pricingResourceConfig(basePrice float64, dynamicWeight map[string]any) map[string]any {
	if len(dynamicWeight) == 0 {
		return map[string]any{"basePrice": basePrice}
	}
	return map[string]any{
		"basePrice":     basePrice,
		"dynamicWeight": dynamicWeight,
	}
}
|
||||
|
||||
func insertPricingRules(ctx context.Context, tx pgx.Tx, ruleSetID string, defaultCurrency string, rules []PricingRuleInput) error {
|
||||
for index, rule := range rules {
|
||||
rule = normalizePricingRule(rule, index, defaultCurrency)
|
||||
|
||||
@ -107,6 +107,33 @@ func (s *Store) DeleteRuntimePolicySet(ctx context.Context, id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) DisableCandidatePlatform(ctx context.Context, platformID string) error {
|
||||
if strings.TrimSpace(platformID) == "" {
|
||||
return nil
|
||||
}
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
UPDATE integration_platforms
|
||||
SET status = 'disabled',
|
||||
updated_at = now()
|
||||
WHERE id = $1::uuid`, platformID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) CooldownCandidatePlatform(ctx context.Context, platformID string, cooldownSeconds int) error {
|
||||
if strings.TrimSpace(platformID) == "" {
|
||||
return nil
|
||||
}
|
||||
if cooldownSeconds <= 0 {
|
||||
cooldownSeconds = 300
|
||||
}
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
UPDATE integration_platforms
|
||||
SET cooldown_until = now() + ($2::int * interval '1 second'),
|
||||
updated_at = now()
|
||||
WHERE id = $1::uuid`, platformID, cooldownSeconds)
|
||||
return err
|
||||
}
|
||||
|
||||
func scanRuntimePolicySet(scanner runtimePolicyScanner) (RuntimePolicySet, error) {
|
||||
var item RuntimePolicySet
|
||||
var rateLimitPolicy []byte
|
||||
|
||||
@ -16,7 +16,7 @@ type CreatePlatformModelInput struct {
|
||||
CanonicalModelKey string `json:"canonicalModelKey"`
|
||||
ModelName string `json:"modelName"`
|
||||
ModelAlias string `json:"modelAlias"`
|
||||
ModelType string `json:"modelType"`
|
||||
ModelType StringList `json:"modelType"`
|
||||
DisplayName string `json:"displayName"`
|
||||
CapabilityOverride map[string]any `json:"capabilityOverride"`
|
||||
Capabilities map[string]any `json:"capabilities"`
|
||||
@ -64,12 +64,17 @@ type RuntimeModelCandidate struct {
|
||||
PermissionConfig map[string]any
|
||||
PricingMode string
|
||||
DiscountFactor float64
|
||||
BasePricingRuleSetID string
|
||||
PlatformPricingRuleSetID string
|
||||
ModelPricingRuleSetID string
|
||||
ModelRetryPolicy map[string]any
|
||||
ModelRateLimitPolicy map[string]any
|
||||
RuntimePolicySetID string
|
||||
RuntimePolicyOverride map[string]any
|
||||
RuntimeRetryPolicy map[string]any
|
||||
RuntimeRateLimitPolicy map[string]any
|
||||
AutoDisablePolicy map[string]any
|
||||
DegradePolicy map[string]any
|
||||
ClientID string
|
||||
QueueKey string
|
||||
}
|
||||
|
||||
@ -3,9 +3,67 @@ package store
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
)
|
||||
|
||||
func (s *Store) ListTasks(ctx context.Context, user *auth.User, limit int) ([]GatewayTask, error) {
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
if limit > 100 {
|
||||
limit = 100
|
||||
}
|
||||
gatewayUserID := localGatewayUserID(user)
|
||||
apiKeyID := ""
|
||||
userID := ""
|
||||
if user != nil {
|
||||
apiKeyID = strings.TrimSpace(user.APIKeyID)
|
||||
userID = strings.TrimSpace(user.ID)
|
||||
}
|
||||
if gatewayUserID == "" && userID == "" {
|
||||
return nil, ErrLocalUserRequired
|
||||
}
|
||||
rows, err := s.pool.Query(ctx, `
|
||||
SELECT `+gatewayTaskColumns+`
|
||||
FROM gateway_tasks
|
||||
WHERE (
|
||||
(
|
||||
NULLIF($1, '')::uuid IS NOT NULL
|
||||
AND gateway_user_id = NULLIF($1, '')::uuid
|
||||
)
|
||||
OR (
|
||||
NULLIF($1, '')::uuid IS NULL
|
||||
AND NULLIF($2, '') IS NOT NULL
|
||||
AND user_id = $2
|
||||
)
|
||||
)
|
||||
AND (
|
||||
NULLIF($3, '') IS NULL
|
||||
OR api_key_id = $3
|
||||
)
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $4`, gatewayUserID, userID, apiKeyID, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
items := make([]GatewayTask, 0)
|
||||
for rows.Next() {
|
||||
task, err := scanGatewayTask(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, task)
|
||||
}
|
||||
return items, rows.Err()
|
||||
}
|
||||
|
||||
func (s *Store) MarkTaskRunning(ctx context.Context, taskID string, modelType string, normalizedRequest map[string]any) error {
|
||||
normalizedJSON, _ := json.Marshal(emptyObjectIfNil(normalizedRequest))
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
@ -140,6 +198,103 @@ WHERE id = $1::uuid`,
|
||||
return s.GetTask(ctx, input.TaskID)
|
||||
}
|
||||
|
||||
func (s *Store) SettleTaskBilling(ctx context.Context, task GatewayTask) error {
|
||||
if task.FinalChargeAmount <= 0 || strings.TrimSpace(task.GatewayUserID) == "" {
|
||||
return nil
|
||||
}
|
||||
currency := strings.TrimSpace(taskBillingString(task.BillingSummary["currency"]))
|
||||
if currency == "" || currency == "mixed" {
|
||||
currency = "resource"
|
||||
}
|
||||
metadata, _ := json.Marshal(map[string]any{
|
||||
"taskId": task.ID,
|
||||
"kind": task.Kind,
|
||||
"model": task.Model,
|
||||
"resolvedModel": task.ResolvedModel,
|
||||
"billings": task.Billings,
|
||||
"billingSummary": task.BillingSummary,
|
||||
})
|
||||
return pgx.BeginFunc(ctx, s.pool, func(tx pgx.Tx) error {
|
||||
if _, err := tx.Exec(ctx, `
|
||||
INSERT INTO gateway_wallet_accounts (
|
||||
gateway_tenant_id, gateway_user_id, tenant_id, tenant_key, user_id, currency
|
||||
)
|
||||
VALUES (NULLIF($1, '')::uuid, $2::uuid, NULLIF($3, ''), NULLIF($4, ''), NULLIF($5, ''), $6)
|
||||
ON CONFLICT (gateway_user_id, currency) DO NOTHING`,
|
||||
task.GatewayTenantID, task.GatewayUserID, task.TenantID, task.TenantKey, task.UserID, currency); err != nil {
|
||||
return err
|
||||
}
|
||||
var exists bool
|
||||
if err := tx.QueryRow(ctx, `
|
||||
SELECT EXISTS (
|
||||
SELECT 1
|
||||
FROM gateway_wallet_transactions t
|
||||
JOIN gateway_wallet_accounts a ON a.id = t.account_id
|
||||
WHERE a.gateway_user_id = $1::uuid
|
||||
AND a.currency = $2
|
||||
AND t.idempotency_key = $3
|
||||
)`, task.GatewayUserID, currency, billingIdempotencyKey(task.ID)).Scan(&exists); err != nil {
|
||||
return err
|
||||
}
|
||||
if exists {
|
||||
return nil
|
||||
}
|
||||
var accountID string
|
||||
var balanceBefore float64
|
||||
var gatewayTenantID string
|
||||
if err := tx.QueryRow(ctx, `
|
||||
SELECT id::text, balance::float8, COALESCE(gateway_tenant_id::text, '')
|
||||
FROM gateway_wallet_accounts
|
||||
WHERE gateway_user_id = $1::uuid
|
||||
AND currency = $2
|
||||
FOR UPDATE`, task.GatewayUserID, currency).Scan(&accountID, &balanceBefore, &gatewayTenantID); err != nil {
|
||||
return err
|
||||
}
|
||||
amount := roundMoney(task.FinalChargeAmount)
|
||||
balanceAfter := roundMoney(balanceBefore - amount)
|
||||
if _, err := tx.Exec(ctx, `
|
||||
UPDATE gateway_wallet_accounts
|
||||
SET balance = $2,
|
||||
total_spent = total_spent + $3,
|
||||
updated_at = now()
|
||||
WHERE id = $1::uuid`, accountID, balanceAfter, amount); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := tx.Exec(ctx, `
|
||||
INSERT INTO gateway_wallet_transactions (
|
||||
account_id, gateway_tenant_id, gateway_user_id, direction, transaction_type,
|
||||
amount, balance_before, balance_after, idempotency_key, reference_type, reference_id, metadata
|
||||
)
|
||||
VALUES (
|
||||
$1::uuid, NULLIF($2, '')::uuid, $3::uuid, 'debit', 'task_billing',
|
||||
$4, $5, $6, $7, 'gateway_task', $8, $9::jsonb
|
||||
)`,
|
||||
accountID, firstNonEmpty(gatewayTenantID, task.GatewayTenantID), task.GatewayUserID, amount, roundMoney(balanceBefore), balanceAfter, billingIdempotencyKey(task.ID), task.ID, string(metadata))
|
||||
if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == "23505" {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
// billingIdempotencyKey derives the wallet-transaction idempotency key for a
// task's billing settlement, in the form "task:<taskID>:billing".
func billingIdempotencyKey(taskID string) string {
	return strings.Join([]string{"task", taskID, "billing"}, ":")
}
|
||||
|
||||
// roundMoney rounds value to six decimal places, rounding halves away from
// zero. Negative inputs are rounded by magnitude and then negated.
func roundMoney(value float64) float64 {
	negative := value < 0
	magnitude := value
	if negative {
		magnitude = -value
	}
	// Scale to micro-units, add one half and truncate.
	rounded := float64(int64(magnitude*1000000+0.5)) / 1000000
	if negative {
		return -rounded
	}
	return rounded
}
|
||||
|
||||
// taskBillingString returns value when it holds a string and "" for any other
// dynamic type, including nil.
func taskBillingString(value any) string {
	// A failed assertion leaves text at its zero value, the empty string.
	text, _ := value.(string)
	return text
}
|
||||
|
||||
func (s *Store) FinishTaskFailure(ctx context.Context, input FinishTaskFailureInput) (GatewayTask, error) {
|
||||
metricsJSON, _ := json.Marshal(emptyObjectIfNil(input.Metrics))
|
||||
if _, err := s.pool.Exec(ctx, `
|
||||
|
||||
@ -376,7 +376,7 @@ CREATE TABLE IF NOT EXISTS platform_models (
|
||||
base_model_id uuid REFERENCES base_model_catalog(id) ON DELETE SET NULL,
|
||||
model_name text NOT NULL,
|
||||
model_alias text,
|
||||
model_type text NOT NULL,
|
||||
model_type jsonb NOT NULL DEFAULT '[]'::jsonb,
|
||||
display_name text NOT NULL DEFAULT '',
|
||||
capability_override jsonb NOT NULL DEFAULT '{}'::jsonb,
|
||||
capabilities jsonb NOT NULL DEFAULT '{}'::jsonb,
|
||||
@ -391,14 +391,17 @@ CREATE TABLE IF NOT EXISTS platform_models (
|
||||
enabled boolean NOT NULL DEFAULT true,
|
||||
created_at timestamptz NOT NULL DEFAULT now(),
|
||||
updated_at timestamptz NOT NULL DEFAULT now(),
|
||||
UNIQUE(platform_id, model_name, model_type)
|
||||
UNIQUE(platform_id, model_name)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_platform_models_base
|
||||
ON platform_models(base_model_id);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_platform_models_lookup
|
||||
ON platform_models(model_type, model_name, enabled);
|
||||
ON platform_models(model_name, enabled);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_platform_models_model_type
|
||||
ON platform_models USING gin(model_type);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_platform_models_alias
|
||||
ON platform_models(model_alias);
|
||||
|
||||
@ -323,7 +323,7 @@ CREATE TABLE IF NOT EXISTS platform_models (
|
||||
base_model_id uuid REFERENCES base_model_catalog(id) ON DELETE SET NULL,
|
||||
model_name text NOT NULL,
|
||||
model_alias text,
|
||||
model_type text NOT NULL,
|
||||
model_type jsonb NOT NULL DEFAULT '[]'::jsonb,
|
||||
display_name text NOT NULL DEFAULT '',
|
||||
capability_override jsonb NOT NULL DEFAULT '{}'::jsonb,
|
||||
capabilities jsonb NOT NULL DEFAULT '{}'::jsonb,
|
||||
@ -338,7 +338,7 @@ CREATE TABLE IF NOT EXISTS platform_models (
|
||||
enabled boolean NOT NULL DEFAULT true,
|
||||
created_at timestamptz NOT NULL DEFAULT now(),
|
||||
updated_at timestamptz NOT NULL DEFAULT now(),
|
||||
UNIQUE(platform_id, model_name, model_type)
|
||||
UNIQUE(platform_id, model_name)
|
||||
);
|
||||
|
||||
ALTER TABLE IF EXISTS platform_models
|
||||
@ -636,7 +636,10 @@ CREATE INDEX IF NOT EXISTS idx_gateway_recharge_orders_user
|
||||
CREATE INDEX IF NOT EXISTS idx_platform_models_base
|
||||
ON platform_models(base_model_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_platform_models_lookup
|
||||
ON platform_models(model_type, model_name, enabled);
|
||||
ON platform_models(model_name, enabled);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_platform_models_model_type
|
||||
ON platform_models USING gin(model_type);
|
||||
CREATE INDEX IF NOT EXISTS idx_platform_models_alias
|
||||
ON platform_models(model_alias);
|
||||
CREATE INDEX IF NOT EXISTS idx_platform_models_capabilities
|
||||
|
||||
@ -199,7 +199,19 @@ INSERT INTO platform_models (
|
||||
platform_id, base_model_id, model_name, model_alias, model_type, display_name,
|
||||
capabilities, pricing_mode, billing_config, retry_policy, rate_limit_policy, enabled
|
||||
)
|
||||
SELECT p.id, b.id, b.provider_model_name, b.canonical_model_key, b.model_type, b.display_name,
|
||||
SELECT p.id, b.id, b.provider_model_name, b.canonical_model_key,
|
||||
CASE b.model_type
|
||||
WHEN 'chat' THEN '["text_generate"]'::jsonb
|
||||
WHEN 'text' THEN '["text_generate"]'::jsonb
|
||||
WHEN 'responses' THEN '["text_generate"]'::jsonb
|
||||
WHEN 'image' THEN '["image_generate","image_edit"]'::jsonb
|
||||
WHEN 'images.generations' THEN '["image_generate"]'::jsonb
|
||||
WHEN 'images.edits' THEN '["image_edit"]'::jsonb
|
||||
WHEN 'video' THEN '["video_generate"]'::jsonb
|
||||
WHEN 'videos.generations' THEN '["video_generate"]'::jsonb
|
||||
ELSE jsonb_build_array(b.model_type)
|
||||
END,
|
||||
b.display_name,
|
||||
b.capabilities, 'inherit_discount', b.base_billing_config,
|
||||
'{"enabled":true,"maxAttempts":2}'::jsonb,
|
||||
b.default_rate_limit_policy,
|
||||
@ -207,7 +219,7 @@ SELECT p.id, b.id, b.provider_model_name, b.canonical_model_key, b.model_type, b
|
||||
FROM integration_platforms p
|
||||
JOIN base_model_catalog b ON b.provider_key = p.provider
|
||||
WHERE p.platform_key IN ('openai-simulation', 'gemini-simulation')
|
||||
ON CONFLICT (platform_id, model_name, model_type) DO UPDATE
|
||||
ON CONFLICT (platform_id, model_name) DO UPDATE
|
||||
SET base_model_id = EXCLUDED.base_model_id,
|
||||
model_alias = EXCLUDED.model_alias,
|
||||
display_name = EXCLUDED.display_name,
|
||||
|
||||
138
apps/api/migrations/0019_platform_model_type_array.sql
Normal file
138
apps/api/migrations/0019_platform_model_type_array.sql
Normal file
@ -0,0 +1,138 @@
|
||||
DO $$
|
||||
DECLARE
|
||||
column_type text;
|
||||
BEGIN
|
||||
SELECT data_type
|
||||
INTO column_type
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = 'public'
|
||||
AND table_name = 'platform_models'
|
||||
AND column_name = 'model_type';
|
||||
|
||||
IF column_type IS DISTINCT FROM 'jsonb' THEN
|
||||
DROP INDEX IF EXISTS idx_platform_models_lookup;
|
||||
ALTER TABLE platform_models
|
||||
DROP CONSTRAINT IF EXISTS platform_models_platform_id_model_name_model_type_key;
|
||||
|
||||
ALTER TABLE platform_models
|
||||
ADD COLUMN IF NOT EXISTS model_type_next jsonb NOT NULL DEFAULT '[]'::jsonb;
|
||||
|
||||
UPDATE platform_models
|
||||
SET model_type_next = CASE trim(model_type)
|
||||
WHEN 'chat' THEN '["text_generate"]'::jsonb
|
||||
WHEN 'text' THEN '["text_generate"]'::jsonb
|
||||
WHEN 'responses' THEN '["text_generate"]'::jsonb
|
||||
WHEN 'image' THEN '["image_generate","image_edit"]'::jsonb
|
||||
WHEN 'images.generations' THEN '["image_generate"]'::jsonb
|
||||
WHEN 'images.edits' THEN '["image_edit"]'::jsonb
|
||||
WHEN 'video' THEN '["video_generate"]'::jsonb
|
||||
WHEN 'videos.generations' THEN '["video_generate"]'::jsonb
|
||||
ELSE jsonb_build_array(trim(model_type))
|
||||
END;
|
||||
|
||||
WITH ranked AS (
|
||||
SELECT id,
|
||||
first_value(id) OVER (
|
||||
PARTITION BY platform_id, model_name
|
||||
ORDER BY created_at ASC, id ASC
|
||||
) AS keep_id,
|
||||
row_number() OVER (
|
||||
PARTITION BY platform_id, model_name
|
||||
ORDER BY created_at ASC, id ASC
|
||||
) AS row_number
|
||||
FROM platform_models
|
||||
),
|
||||
merged AS (
|
||||
SELECT ranked.keep_id,
|
||||
jsonb_agg(DISTINCT type_value ORDER BY type_value) AS model_type
|
||||
FROM ranked
|
||||
JOIN platform_models model_row ON model_row.id = ranked.id
|
||||
CROSS JOIN LATERAL jsonb_array_elements_text(model_row.model_type_next) AS type_item(type_value)
|
||||
GROUP BY ranked.keep_id
|
||||
)
|
||||
UPDATE platform_models target
|
||||
SET model_type_next = merged.model_type
|
||||
FROM merged
|
||||
WHERE target.id = merged.keep_id;
|
||||
|
||||
WITH ranked AS (
|
||||
SELECT id,
|
||||
first_value(id) OVER (
|
||||
PARTITION BY platform_id, model_name
|
||||
ORDER BY created_at ASC, id ASC
|
||||
) AS keep_id,
|
||||
row_number() OVER (
|
||||
PARTITION BY platform_id, model_name
|
||||
ORDER BY created_at ASC, id ASC
|
||||
) AS row_number
|
||||
FROM platform_models
|
||||
)
|
||||
UPDATE gateway_access_rules rules
|
||||
SET resource_id = ranked.keep_id
|
||||
FROM ranked
|
||||
WHERE ranked.row_number > 1
|
||||
AND rules.resource_type = 'platform_model'
|
||||
AND rules.resource_id = ranked.id;
|
||||
|
||||
WITH ranked AS (
|
||||
SELECT id,
|
||||
row_number() OVER (
|
||||
PARTITION BY platform_id, model_name
|
||||
ORDER BY created_at ASC, id ASC
|
||||
) AS row_number
|
||||
FROM platform_models
|
||||
)
|
||||
DELETE FROM platform_models model_row
|
||||
USING ranked
|
||||
WHERE model_row.id = ranked.id
|
||||
AND ranked.row_number > 1;
|
||||
|
||||
ALTER TABLE platform_models DROP COLUMN model_type;
|
||||
ALTER TABLE platform_models RENAME COLUMN model_type_next TO model_type;
|
||||
END IF;
|
||||
END $$;
|
||||
|
||||
UPDATE platform_models
|
||||
SET model_type = CASE
|
||||
WHEN jsonb_typeof(model_type) = 'array' THEN model_type
|
||||
WHEN jsonb_typeof(model_type) = 'string' THEN jsonb_build_array(model_type #>> '{}')
|
||||
ELSE '[]'::jsonb
|
||||
END;
|
||||
|
||||
-- Normalize every element of platform_models.model_type to the canonical
-- capability names, de-duplicated and sorted.
--
-- Each legacy value maps to one or more canonical values. The legacy 'image'
-- type expands to BOTH image_generate and image_edit wherever it appears in
-- the array, not only when the array is exactly ["image"], and the result
-- uses the same sorted order as every other row. Blank entries are dropped;
-- unknown values are kept as they are.
UPDATE platform_models
SET model_type = COALESCE((
    SELECT jsonb_agg(DISTINCT normalized.normalized_type ORDER BY normalized.normalized_type)
    FROM jsonb_array_elements_text(platform_models.model_type) AS item(model_type_value)
    CROSS JOIN LATERAL unnest(
        CASE item.model_type_value
            WHEN 'chat' THEN ARRAY['text_generate']
            WHEN 'text' THEN ARRAY['text_generate']
            WHEN 'responses' THEN ARRAY['text_generate']
            WHEN 'image' THEN ARRAY['image_generate', 'image_edit']
            WHEN 'images.generations' THEN ARRAY['image_generate']
            WHEN 'images.edits' THEN ARRAY['image_edit']
            WHEN 'video' THEN ARRAY['video_generate']
            WHEN 'videos.generations' THEN ARRAY['video_generate']
            ELSE ARRAY[item.model_type_value]
        END
    ) AS normalized(normalized_type)
    WHERE trim(normalized.normalized_type) <> ''
), '[]'::jsonb);
|
||||
|
||||
ALTER TABLE platform_models
|
||||
ALTER COLUMN model_type SET NOT NULL,
|
||||
ALTER COLUMN model_type SET DEFAULT '[]'::jsonb;
|
||||
|
||||
DROP INDEX IF EXISTS idx_platform_models_lookup;
|
||||
CREATE INDEX IF NOT EXISTS idx_platform_models_lookup
|
||||
ON platform_models(model_name, enabled);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_platform_models_model_type
|
||||
ON platform_models USING gin(model_type);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_platform_models_platform_model_name
|
||||
ON platform_models(platform_id, model_name);
|
||||
@ -0,0 +1,11 @@
|
||||
UPDATE platform_models m
|
||||
SET rate_limit_policy = '{}'::jsonb,
|
||||
updated_at = now()
|
||||
FROM base_model_catalog b,
|
||||
model_runtime_policy_sets rp
|
||||
WHERE m.base_model_id = b.id
|
||||
AND m.runtime_policy_set_id = rp.id
|
||||
AND m.runtime_policy_set_id IS NOT NULL
|
||||
AND m.runtime_policy_set_id IS DISTINCT FROM b.runtime_policy_set_id
|
||||
AND m.rate_limit_policy = b.default_rate_limit_policy
|
||||
AND rp.rate_limit_policy <> '{}'::jsonb;
|
||||
108
apps/api/migrations/0021_base_model_type_array.sql
Normal file
108
apps/api/migrations/0021_base_model_type_array.sql
Normal file
@ -0,0 +1,108 @@
|
||||
DO $$
|
||||
DECLARE
|
||||
column_type text;
|
||||
BEGIN
|
||||
SELECT data_type
|
||||
INTO column_type
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = 'public'
|
||||
AND table_name = 'base_model_catalog'
|
||||
AND column_name = 'model_type';
|
||||
|
||||
IF column_type IS DISTINCT FROM 'jsonb' THEN
|
||||
DROP INDEX IF EXISTS idx_base_model_catalog_type;
|
||||
|
||||
ALTER TABLE base_model_catalog
|
||||
ADD COLUMN IF NOT EXISTS model_type_next jsonb NOT NULL DEFAULT '[]'::jsonb;
|
||||
|
||||
WITH source_types AS (
|
||||
SELECT id,
|
||||
CASE
|
||||
WHEN jsonb_typeof(capabilities->'originalTypes') = 'array' THEN capabilities->'originalTypes'
|
||||
WHEN jsonb_typeof(metadata->'originalTypes') = 'array' THEN metadata->'originalTypes'
|
||||
ELSE jsonb_build_array(model_type)
|
||||
END AS model_types
|
||||
FROM base_model_catalog
|
||||
),
|
||||
normalized_types AS (
|
||||
SELECT source_types.id, normalized.model_type
|
||||
FROM source_types
|
||||
CROSS JOIN LATERAL jsonb_array_elements_text(source_types.model_types) AS raw_type(model_type)
|
||||
CROSS JOIN LATERAL (
|
||||
SELECT 'text_generate' AS model_type
|
||||
WHERE raw_type.model_type IN ('chat', 'text', 'responses')
|
||||
UNION ALL SELECT 'image_generate'
|
||||
WHERE raw_type.model_type = 'image'
|
||||
UNION ALL SELECT 'image_edit'
|
||||
WHERE raw_type.model_type = 'image'
|
||||
UNION ALL SELECT 'image_generate'
|
||||
WHERE raw_type.model_type = 'images.generations'
|
||||
UNION ALL SELECT 'image_edit'
|
||||
WHERE raw_type.model_type = 'images.edits'
|
||||
UNION ALL SELECT 'video_generate'
|
||||
WHERE raw_type.model_type IN ('video', 'videos.generations')
|
||||
UNION ALL SELECT raw_type.model_type
|
||||
WHERE trim(raw_type.model_type) <> ''
|
||||
AND raw_type.model_type NOT IN (
|
||||
'chat', 'text', 'responses', 'image', 'images.generations', 'images.edits', 'video', 'videos.generations'
|
||||
)
|
||||
) AS normalized
|
||||
)
|
||||
UPDATE base_model_catalog target
|
||||
SET model_type_next = COALESCE((
|
||||
SELECT jsonb_agg(DISTINCT normalized_types.model_type ORDER BY normalized_types.model_type)
|
||||
FROM normalized_types
|
||||
WHERE normalized_types.id = target.id
|
||||
), '[]'::jsonb);
|
||||
|
||||
ALTER TABLE base_model_catalog DROP COLUMN model_type;
|
||||
ALTER TABLE base_model_catalog RENAME COLUMN model_type_next TO model_type;
|
||||
END IF;
|
||||
END $$;
|
||||
|
||||
UPDATE base_model_catalog
|
||||
SET model_type = CASE
|
||||
WHEN jsonb_typeof(model_type) = 'array' THEN model_type
|
||||
WHEN jsonb_typeof(model_type) = 'string' THEN jsonb_build_array(model_type #>> '{}')
|
||||
ELSE '[]'::jsonb
|
||||
END;
|
||||
|
||||
UPDATE base_model_catalog
|
||||
SET default_snapshot = default_snapshot || jsonb_build_object('modelType', model_type)
|
||||
WHERE catalog_type = 'system'
|
||||
AND COALESCE(default_snapshot, '{}'::jsonb) <> '{}'::jsonb;
|
||||
|
||||
ALTER TABLE base_model_catalog
|
||||
ALTER COLUMN model_type SET NOT NULL,
|
||||
ALTER COLUMN model_type SET DEFAULT '[]'::jsonb;
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_base_model_catalog_type
|
||||
ON base_model_catalog(provider_key, status);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_base_model_catalog_model_type
|
||||
ON base_model_catalog USING gin(model_type);
|
||||
|
||||
-- Trigger function: when a 'system' catalog row arrives without a
-- default_snapshot, build one from the row's own column values. Rows of any
-- other catalog type, and rows that already carry a snapshot, pass through
-- unchanged. Note that the snapshot's 'modelAlias' is taken from display_name.
CREATE OR REPLACE FUNCTION fill_system_base_model_default_snapshot()
RETURNS trigger AS $$
BEGIN
    IF NEW.catalog_type = 'system' AND NEW.default_snapshot IS NULL THEN
        NEW.default_snapshot := jsonb_build_object(
            'providerKey',            NEW.provider_key,
            'canonicalModelKey',      NEW.canonical_model_key,
            'providerModelName',      NEW.provider_model_name,
            'modelType',              NEW.model_type,
            'modelAlias',             NEW.display_name,
            'capabilities',           NEW.capabilities,
            'baseBillingConfig',      NEW.base_billing_config,
            'defaultRateLimitPolicy', NEW.default_rate_limit_policy,
            -- Nullable uuid references are stored as text, '' when unset.
            'pricingRuleSetId',       COALESCE(NEW.pricing_rule_set_id::text, ''),
            'runtimePolicySetId',     COALESCE(NEW.runtime_policy_set_id::text, ''),
            'runtimePolicyOverride',  NEW.runtime_policy_override,
            'metadata',               NEW.metadata,
            'pricingVersion',         NEW.pricing_version,
            'status',                 NEW.status
        );
    END IF;
    RETURN NEW;
END;
$$ LANGUAGE plpgsql;
|
||||
@ -0,0 +1,9 @@
|
||||
-- Re-sync platform_models.model_type from the linked base_model_catalog row.
-- A row is only touched when the catalog carries a non-empty type array and
-- the platform model's capabilities equal the catalog's (NULL treated as {}).
UPDATE platform_models m
SET model_type = b.model_type,
    updated_at = now()
FROM base_model_catalog b
WHERE m.base_model_id = b.id
  AND jsonb_typeof(b.model_type) = 'array'
  AND jsonb_array_length(b.model_type) > 0
  -- IS DISTINCT FROM also selects rows whose model_type is NULL; a plain <>
  -- evaluates to NULL there and would silently leave them out of the sync.
  AND m.model_type IS DISTINCT FROM b.model_type
  AND COALESCE(m.capabilities, '{}'::jsonb) = COALESCE(b.capabilities, '{}'::jsonb);
|
||||
48
apps/api/migrations/0023_wallet_task_billing_schema.sql
Normal file
48
apps/api/migrations/0023_wallet_task_billing_schema.sql
Normal file
@ -0,0 +1,48 @@
|
||||
ALTER TABLE gateway_wallet_accounts
|
||||
ADD COLUMN IF NOT EXISTS total_recharged numeric NOT NULL DEFAULT 0,
|
||||
ADD COLUMN IF NOT EXISTS total_spent numeric NOT NULL DEFAULT 0;
|
||||
|
||||
ALTER TABLE gateway_wallet_transactions
|
||||
ADD COLUMN IF NOT EXISTS account_id uuid REFERENCES gateway_wallet_accounts(id) ON DELETE CASCADE,
|
||||
ADD COLUMN IF NOT EXISTS direction text,
|
||||
ADD COLUMN IF NOT EXISTS balance_before numeric,
|
||||
ADD COLUMN IF NOT EXISTS idempotency_key text;
|
||||
|
||||
-- Backfill account_id from the legacy wallet_account_id column when that
-- column is still present on gateway_wallet_transactions.
--
-- The probe resolves the table through the search_path with to_regclass,
-- i.e. the same way the unqualified name in the UPDATE below is resolved.
-- Matching information_schema.columns on table_name alone would also match a
-- same-named table in any other schema and then run the UPDATE against a
-- table that has no such column.
DO $$
BEGIN
  IF EXISTS (
    SELECT 1
    FROM pg_attribute
    -- to_regclass yields NULL (no match) instead of raising when the table
    -- cannot be resolved.
    WHERE attrelid = to_regclass('gateway_wallet_transactions')
      AND attname = 'wallet_account_id'
      AND NOT attisdropped
  ) THEN
    -- Dynamic SQL: the statement is only parsed when the legacy column exists.
    EXECUTE 'UPDATE gateway_wallet_transactions SET account_id = wallet_account_id WHERE account_id IS NULL';
  END IF;
END $$;
|
||||
|
||||
UPDATE gateway_wallet_transactions
|
||||
SET direction = CASE
|
||||
WHEN transaction_type IN ('recharge', 'refund', 'credit') THEN 'credit'
|
||||
ELSE 'debit'
|
||||
END
|
||||
WHERE direction IS NULL;
|
||||
|
||||
UPDATE gateway_wallet_transactions
|
||||
SET balance_before = COALESCE(balance_after, 0) + CASE
|
||||
WHEN direction = 'debit' THEN COALESCE(amount, 0)
|
||||
ELSE -COALESCE(amount, 0)
|
||||
END
|
||||
WHERE balance_before IS NULL;
|
||||
|
||||
ALTER TABLE gateway_wallet_transactions
|
||||
ALTER COLUMN direction SET DEFAULT 'debit',
|
||||
ALTER COLUMN direction SET NOT NULL,
|
||||
ALTER COLUMN balance_before SET DEFAULT 0,
|
||||
ALTER COLUMN balance_before SET NOT NULL;
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_gateway_wallet_transactions_account
|
||||
ON gateway_wallet_transactions(account_id, created_at DESC);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS uniq_gateway_wallet_tx_idempotency
|
||||
ON gateway_wallet_transactions(account_id, idempotency_key)
|
||||
WHERE idempotency_key IS NOT NULL;
|
||||
@ -49,6 +49,7 @@ import {
|
||||
listPricingRules,
|
||||
listPricingRuleSets,
|
||||
listRuntimePolicySets,
|
||||
listTasks,
|
||||
listPublicBaseModels,
|
||||
listPublicCatalogProviders,
|
||||
listRateLimitWindows,
|
||||
@ -121,6 +122,7 @@ type DataKey =
|
||||
| 'tenants'
|
||||
| 'users'
|
||||
| 'userGroups'
|
||||
| 'tasks'
|
||||
| 'accessRules'
|
||||
| 'apiKeys';
|
||||
|
||||
@ -157,6 +159,7 @@ export function App() {
|
||||
const [selectedPlaygroundApiKeyId, setSelectedPlaygroundApiKeyId] = useState('');
|
||||
const [taskForm, setTaskForm] = useState<TaskForm>({ kind: 'chat.completions', model: 'gpt-4o-mini', prompt: '用一句话确认 AI Gateway simulation 链路正常。' });
|
||||
const [taskResult, setTaskResult] = useState<GatewayTask | null>(null);
|
||||
const [tasks, setTasks] = useState<GatewayTask[]>([]);
|
||||
const [coreState, setCoreState] = useState<LoadState>('idle');
|
||||
const [coreMessage, setCoreMessage] = useState('');
|
||||
const [state, setState] = useState<LoadState>('idle');
|
||||
@ -228,10 +231,11 @@ export function App() {
|
||||
rateLimitWindows,
|
||||
runtimePolicySets,
|
||||
taskResult,
|
||||
tasks,
|
||||
tenants,
|
||||
userGroups,
|
||||
users,
|
||||
}), [accessRules, apiKeys, baseModels, models, platforms, pricingRuleSets, pricingRules, providers, rateLimitWindows, runtimePolicySets, taskResult, tenants, userGroups, users]);
|
||||
}), [accessRules, apiKeys, baseModels, models, platforms, pricingRuleSets, pricingRules, providers, rateLimitWindows, runtimePolicySets, taskResult, tasks, tenants, userGroups, users]);
|
||||
|
||||
async function refresh(nextToken = token) {
|
||||
await ensureRouteData(nextToken, true);
|
||||
@ -327,6 +331,9 @@ export function App() {
|
||||
case 'userGroups':
|
||||
setUserGroups((await listUserGroups(nextToken)).items);
|
||||
return;
|
||||
case 'tasks':
|
||||
setTasks((await listTasks(nextToken)).items);
|
||||
return;
|
||||
case 'accessRules':
|
||||
setAccessRules((await (activePage === 'workspace' && workspaceSection === 'apiKeys'
|
||||
? listApiKeyAccessRules(nextToken)
|
||||
@ -628,6 +635,8 @@ export function App() {
|
||||
const response = await runTask(credential, taskForm);
|
||||
const detail = await getTask(credential, response.task.id);
|
||||
setTaskResult(detail);
|
||||
setTasks((current) => [detail, ...current.filter((item) => item.id !== detail.id)]);
|
||||
invalidateDataKeys('tasks');
|
||||
setCoreState('ready');
|
||||
setCoreMessage(`${taskForm.kind} 已通过 ${apiKeySecret ? '本地 API Key' : '当前 Access Token'} 完成 simulation。`);
|
||||
} catch (err) {
|
||||
@ -660,6 +669,7 @@ export function App() {
|
||||
setApiKeySecretsById({});
|
||||
setSelectedPlaygroundApiKeyId('');
|
||||
setTaskResult(null);
|
||||
setTasks([]);
|
||||
setCoreMessage('');
|
||||
navigatePath('/');
|
||||
}
|
||||
@ -852,10 +862,16 @@ export function App() {
|
||||
function platformModelIsSelected(model: PlatformModel, selectedModels: PlatformModelBindingInput[]) {
|
||||
return selectedModels.some((selected) => {
|
||||
if (selected.baseModelId && model.baseModelId) return selected.baseModelId === model.baseModelId;
|
||||
return selected.modelName === model.modelName && selected.modelType === model.modelType;
|
||||
return selected.modelName === model.modelName && sameModelTypes(selected.modelType, model.modelType);
|
||||
});
|
||||
}
|
||||
|
||||
/**
 * Reports whether two model-type lists describe the same set of types.
 *
 * Order and duplicate entries are ignored: the lists are equal when every
 * type in `left` occurs in `right` and vice versa. Checking containment in
 * both directions avoids the false positive a length check plus one-way
 * membership test gives for inputs such as ['a', 'a'] vs ['a', 'b'].
 */
function sameModelTypes(left: string[], right: string[]) {
  const leftSet = new Set(left);
  const rightSet = new Set(right);
  return left.every((type) => rightSet.has(type)) && right.every((type) => leftSet.has(type));
}
|
||||
|
||||
function mergeExistingPlatformModelInput(input: PlatformModelBindingInput, currentModels: PlatformModel[], platformId: string): PlatformModelBindingInput {
|
||||
const existing = currentModels.find((model) => model.platformId === platformId && platformModelIsSelected(model, [input]));
|
||||
if (!existing) return input;
|
||||
@ -924,6 +940,7 @@ function dataKeysForRoute(
|
||||
if (activePage === 'workspace') {
|
||||
if (workspaceSection === 'overview') return ['users', 'userGroups', 'apiKeys'];
|
||||
if (workspaceSection === 'apiKeys') return ['apiKeys', 'accessRules', 'playgroundModels'];
|
||||
if (workspaceSection === 'tasks') return ['tasks'];
|
||||
return [];
|
||||
}
|
||||
|
||||
|
||||
@ -582,6 +582,10 @@ export async function getTask(token: string, taskId: string): Promise<GatewayTas
|
||||
return request<GatewayTask>(`/api/v1/tasks/${taskId}`, { token });
|
||||
}
|
||||
|
||||
/**
 * Fetches the task list from `/api/v1/tasks`.
 *
 * @param token credential forwarded to `request`
 * @param limit value sent as the `limit` query parameter (defaults to 50)
 */
export async function listTasks(token: string, limit = 50): Promise<ListResponse<GatewayTask>> {
  const limitParam = encodeURIComponent(String(limit));
  const path = `/api/v1/tasks?limit=${limitParam}`;
  return request<ListResponse<GatewayTask>>(path, { token });
}
|
||||
|
||||
export function resolveApiAssetUrl(src: string) {
|
||||
if (/^(https?:|data:|blob:)/i.test(src)) return src;
|
||||
return `${API_BASE}${src.startsWith('/') ? src : `/${src}`}`;
|
||||
|
||||
@ -27,6 +27,7 @@ export interface ConsoleData {
|
||||
rateLimitWindows: RateLimitWindow[];
|
||||
runtimePolicySets: RuntimePolicySet[];
|
||||
taskResult: GatewayTask | null;
|
||||
tasks: GatewayTask[];
|
||||
tenants: GatewayTenant[];
|
||||
userGroups: UserGroup[];
|
||||
users: GatewayUser[];
|
||||
|
||||
@ -73,7 +73,7 @@ export function Dashboard(props: {
|
||||
<DataPanel
|
||||
columns={['模型', '类型', '平台', '启用']}
|
||||
empty="暂无模型数据"
|
||||
rows={props.models.map((item) => [item.modelName, item.modelType, item.provider ?? item.platformName ?? '-', item.enabled ? '是' : '否'])}
|
||||
rows={props.models.map((item) => [item.modelName, item.modelType.join(', '), item.provider ?? item.platformName ?? '-', item.enabled ? '是' : '否'])}
|
||||
title="模型"
|
||||
/>
|
||||
</section>
|
||||
|
||||
@ -173,8 +173,8 @@ function identityPanelProps(props: {
|
||||
|
||||
function OverviewPanel(props: { data: ConsoleData; stats: StatItem[] }) {
|
||||
const enabledPlatforms = props.data.platforms.filter((item) => item.status === 'enabled');
|
||||
const chatModels = props.data.models.filter((item) => item.modelType === 'chat' && item.enabled);
|
||||
const imageModels = props.data.models.filter((item) => item.modelType === 'image' && item.enabled);
|
||||
const chatModels = props.data.models.filter((item) => item.modelType.includes('text_generate') && item.enabled);
|
||||
const imageModels = props.data.models.filter((item) => item.modelType.some((type) => type.includes('image')) && item.enabled);
|
||||
|
||||
return (
|
||||
<div className="pageStack">
|
||||
@ -229,7 +229,7 @@ function OverviewPanel(props: { data: ConsoleData; stats: StatItem[] }) {
|
||||
rows={[
|
||||
['对话', chatModels.length, 'Phase 1', '/v1/chat/completions'],
|
||||
['图像', imageModels.length, 'Phase 1', '/v1/images/*'],
|
||||
['视频', props.data.models.filter((item) => item.modelType === 'video').length, 'Next', '/v1/videos/*'],
|
||||
['视频', props.data.models.filter((item) => item.modelType.some((type) => type.includes('video'))).length, 'Next', '/v1/videos/*'],
|
||||
]}
|
||||
/>
|
||||
</CardContent>
|
||||
|
||||
@ -9,7 +9,7 @@ import type {
|
||||
import type { ConsoleData } from '../app-state';
|
||||
import { PageHeader } from '../components/PageHeader';
|
||||
import { Badge, Card, CardContent, Input } from '../components/ui';
|
||||
import { primaryBaseModelType, stableModelAlias } from './admin/platform-form';
|
||||
import { stableModelAlias } from './admin/platform-form';
|
||||
|
||||
type ModelListItem = {
|
||||
id: string;
|
||||
@ -17,7 +17,7 @@ type ModelListItem = {
|
||||
platformName?: string;
|
||||
modelName: string;
|
||||
modelAlias?: string;
|
||||
modelType: string;
|
||||
modelType: string[];
|
||||
displayName: string;
|
||||
capabilities?: Record<string, unknown>;
|
||||
pricingMode: string;
|
||||
@ -69,7 +69,7 @@ const publicModels: PlatformModel[] = [
|
||||
platformName: 'OpenAI Simulation',
|
||||
modelName: 'gpt-4o-mini',
|
||||
modelAlias: 'gpt-4o-mini',
|
||||
modelType: 'chat',
|
||||
modelType: ['text_generate'],
|
||||
displayName: 'gpt-4o-mini',
|
||||
capabilities: { multimodal: true },
|
||||
pricingMode: 'inherit',
|
||||
@ -84,7 +84,7 @@ const publicModels: PlatformModel[] = [
|
||||
platformName: 'OpenAI Simulation',
|
||||
modelName: 'gpt-image-1',
|
||||
modelAlias: 'gpt-image-1',
|
||||
modelType: 'image',
|
||||
modelType: ['image_generate', 'image_edit'],
|
||||
displayName: 'gpt-image-1',
|
||||
capabilities: { imageEdit: true },
|
||||
pricingMode: 'inherit',
|
||||
@ -99,7 +99,7 @@ const publicModels: PlatformModel[] = [
|
||||
platformName: 'Gemini Simulation',
|
||||
modelName: 'gemini-2.0-flash',
|
||||
modelAlias: 'gemini-2.0-flash',
|
||||
modelType: 'chat',
|
||||
modelType: ['text_generate'],
|
||||
displayName: 'gemini-2.0-flash',
|
||||
capabilities: { multimodal: true, vision: true },
|
||||
pricingMode: 'inherit_discount',
|
||||
@ -148,7 +148,7 @@ export function ModelsPage(props: { data: ConsoleData }) {
|
||||
return sourceModels.filter((model) => {
|
||||
const providerInfo = providerMap.get(model.providerKey);
|
||||
const matchedProvider = provider === 'all' || model.providerKey === provider;
|
||||
const matchedCapability = capability === 'all' || model.modelType === capability;
|
||||
const matchedCapability = modelMatchesCapability(model.modelType, capability);
|
||||
const matchedQuery = [
|
||||
model.modelName,
|
||||
model.modelAlias,
|
||||
@ -331,7 +331,7 @@ function modelFromBaseModel(model: BaseModelCatalogItem): ModelListItem {
|
||||
providerKey: model.providerKey,
|
||||
modelName: model.providerModelName,
|
||||
modelAlias: stableModelAlias(model),
|
||||
modelType: primaryBaseModelType(model),
|
||||
modelType: model.modelType,
|
||||
displayName: stableModelAlias(model),
|
||||
capabilities: model.capabilities,
|
||||
pricingMode: 'inherit',
|
||||
@ -363,7 +363,7 @@ function providerInitials(label: string) {
|
||||
}
|
||||
|
||||
function tagsForModel(model: ModelListItem) {
|
||||
const tags = [capabilityName(model.modelType)];
|
||||
const tags = model.modelType.map(capabilityName);
|
||||
const capabilities = model.capabilities ?? {};
|
||||
if (capabilities.multimodal || capabilities.vision) tags.push('多模态');
|
||||
if (capabilities.reasoning) tags.push('推理');
|
||||
@ -373,7 +373,23 @@ function tagsForModel(model: ModelListItem) {
|
||||
}
|
||||
|
||||
function capabilityName(type: string) {
|
||||
return capabilityFilters.find((item) => item.value === type)?.label ?? type;
|
||||
const labels: Record<string, string> = {
|
||||
text_generate: '对话',
|
||||
image_generate: '绘图',
|
||||
image_edit: '图像编辑',
|
||||
video_generate: '视频',
|
||||
image_to_video: '图生视频',
|
||||
audio_generate: '音频',
|
||||
};
|
||||
return labels[type] ?? capabilityFilters.find((item) => item.value === type)?.label ?? type;
|
||||
}
|
||||
|
||||
/**
 * Decides whether a model with the given type list passes a capability filter.
 *
 * - 'all' matches every model.
 * - 'chat' matches the exact types 'text_generate' or 'chat'.
 * - 'image' / 'video' match any type whose name contains that word.
 * - Any other filter value must appear verbatim in the type list.
 */
function modelMatchesCapability(modelTypes: string[], capability: string) {
  const anyTypeContains = (fragment: string) => modelTypes.some((type) => type.includes(fragment));

  switch (capability) {
    case 'all':
      return true;
    case 'chat':
      return modelTypes.includes('text_generate') || modelTypes.includes('chat');
    case 'image':
      return anyTypeContains('image');
    case 'video':
      return anyTypeContains('video');
    default:
      return modelTypes.includes(capability);
  }
}
|
||||
|
||||
function priceLabel(model: ModelListItem) {
|
||||
|
||||
@ -988,8 +988,8 @@ function filterModelsForMode(models: PlatformModel[], mode: PlaygroundMode, hasR
|
||||
}
|
||||
|
||||
function filterWithFallback(models: PlatformModel[], modelTypes: string[]) {
|
||||
const exact = models.filter((model) => modelTypes.includes(model.modelType));
|
||||
return exact.length ? exact : models.filter((model) => modelTypes.some((type) => model.modelType.includes(type) || type.includes(model.modelType)));
|
||||
const exact = models.filter((model) => model.modelType.some((type) => modelTypes.includes(type)));
|
||||
return exact.length ? exact : models.filter((model) => modelTypes.some((type) => model.modelType.some((modelType) => modelType.includes(type) || type.includes(modelType))));
|
||||
}
|
||||
|
||||
function buildModelOptions(models: PlatformModel[]): ModelOption[] {
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import { useMemo, useState, type FormEvent, type ReactNode } from 'react';
|
||||
import { Copy, CreditCard, KeyRound, ListChecks, Plus, ShieldCheck, Trash2, UserRound } from 'lucide-react';
|
||||
import type { GatewayAccessRuleBatchRequest, GatewayApiKey, IntegrationPlatform, PlatformModel } from '@easyai-ai-gateway/contracts';
|
||||
import type { GatewayAccessRuleBatchRequest, GatewayApiKey, GatewayTask, IntegrationPlatform, PlatformModel } from '@easyai-ai-gateway/contracts';
|
||||
import type { ConsoleData } from '../app-state';
|
||||
import { EntityTable } from '../components/EntityTable';
|
||||
import { Badge, Button, Card, CardContent, CardHeader, CardTitle, ConfirmDialog, DateTimePicker, FormDialog, Input, Label, Table, TableCell, TableHead, TableRow, Tabs } from '../components/ui';
|
||||
@ -297,30 +297,22 @@ function ApiKeyPanel(props: {
|
||||
}
|
||||
|
||||
function TaskPanel(props: { data: ConsoleData }) {
|
||||
const task = props.data.taskResult;
|
||||
const usage = task?.usage ?? {};
|
||||
const tokenText = usage.totalTokens ? `${usage.totalTokens}` : '-';
|
||||
const chargeText = task?.finalChargeAmount ? `${task.finalChargeAmount}` : '-';
|
||||
const tasks = useMemo(() => {
|
||||
const latest = props.data.taskResult;
|
||||
if (!latest) return props.data.tasks;
|
||||
return [latest, ...props.data.tasks.filter((item) => item.id !== latest.id)];
|
||||
}, [props.data.taskResult, props.data.tasks]);
|
||||
return (
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle>任务记录</CardTitle>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
{task ? (
|
||||
<div className="taskPreview">
|
||||
<Badge variant={task.status === 'succeeded' ? 'success' : 'secondary'}>{task.status}</Badge>
|
||||
<strong>{task.kind}</strong>
|
||||
<span>{task.model}</span>
|
||||
<div className="infoGrid compact">
|
||||
<InfoItem label="API Key" value={task.apiKeyName || task.apiKeyId || '-'} />
|
||||
<InfoItem label="RequestID" value={task.requestId || '-'} />
|
||||
<InfoItem label="实际模型" value={task.resolvedModel || task.model} />
|
||||
<InfoItem label="Token" value={tokenText} />
|
||||
<InfoItem label="扣费" value={chargeText} />
|
||||
<InfoItem label="响应耗时" value={task.responseDurationMs ? `${task.responseDurationMs}ms` : '-'} />
|
||||
</div>
|
||||
<pre>{JSON.stringify({ result: task.result, usage: task.usage, billings: task.billings, billingSummary: task.billingSummary, metrics: task.metrics }, null, 2)}</pre>
|
||||
{tasks.length ? (
|
||||
<div className="taskList">
|
||||
{tasks.map((task) => (
|
||||
<TaskRecord key={task.id} task={task} />
|
||||
))}
|
||||
</div>
|
||||
) : (
|
||||
<div className="emptyState">
|
||||
@ -332,6 +324,34 @@ function TaskPanel(props: { data: ConsoleData }) {
|
||||
);
|
||||
}
|
||||
|
||||
/**
 * Renders one task as a card: a header line (status badge, kind, model,
 * creation time), a grid of summary fields, and a JSON dump of the task's
 * result, usage, billing and metrics payloads.
 */
function TaskRecord(props: { task: GatewayTask }) {
  const { task } = props;
  const usage = task.usage ?? {};
  const tokenText = usage.totalTokens ? `${usage.totalTokens}` : '-';
  // `!= null` excludes both undefined and a JSON null, so a null amount shows
  // the placeholder instead of the literal text "null"; a charge of 0 is still shown.
  const chargeText = task.finalChargeAmount != null ? `${task.finalChargeAmount}` : '-';
  const badgeVariant = task.status === 'succeeded' ? 'success' : task.status === 'failed' ? 'destructive' : 'secondary';

  return (
    <div className="taskPreview">
      <div className="taskRecordHeader">
        <Badge variant={badgeVariant}>{task.status}</Badge>
        <strong>{task.kind}</strong>
        <span>{task.model}</span>
        <span>{formatDateTime(task.createdAt)}</span>
      </div>
      <div className="infoGrid compact">
        <InfoItem label="API Key" value={task.apiKeyName || task.apiKeyId || '-'} />
        <InfoItem label="RequestID" value={task.requestId || '-'} />
        <InfoItem label="模型类型" value={task.modelType || '-'} />
        <InfoItem label="实际模型" value={task.resolvedModel || task.model} />
        <InfoItem label="Token" value={tokenText} />
        <InfoItem label="扣费" value={chargeText} />
        <InfoItem label="响应耗时" value={task.responseDurationMs ? `${task.responseDurationMs}ms` : '-'} />
        <InfoItem label="错误" value={task.errorCode || task.errorMessage || '-'} />
      </div>
      <pre>{JSON.stringify({ result: task.result, usage: task.usage, billings: task.billings, billingSummary: task.billingSummary, metrics: task.metrics }, null, 2)}</pre>
    </div>
  );
}
|
||||
|
||||
function InfoItem(props: { label: string; value: string }) {
|
||||
return (
|
||||
<div className="infoItem">
|
||||
|
||||
@ -304,7 +304,7 @@ function buildPlatformTree(platforms: IntegrationPlatform[], platformModels: Pla
|
||||
.map((model) => ({
|
||||
id: model.id,
|
||||
name: modelLabel(model),
|
||||
subtitle: `${model.modelType} / ${model.modelName}`,
|
||||
subtitle: `${model.modelType.join(', ')} / ${model.modelName}`,
|
||||
})),
|
||||
})).sort((a, b) => a.name.localeCompare(b.name));
|
||||
}
|
||||
|
||||
@ -59,7 +59,7 @@ export function PlatformManagementPanel(props: {
|
||||
model.displayName,
|
||||
model.modelName,
|
||||
model.modelAlias,
|
||||
model.modelType,
|
||||
...model.modelType,
|
||||
model.provider,
|
||||
platform?.name,
|
||||
platform?.internalName,
|
||||
@ -473,7 +473,7 @@ function PlatformModelTable(props: {
|
||||
<ModelCatalogCard
|
||||
key={model.id}
|
||||
badges={[
|
||||
<Badge variant="outline">{model.modelType}</Badge>,
|
||||
<Badge variant="outline">{model.modelType.join(', ')}</Badge>,
|
||||
<Badge variant={model.enabled ? 'success' : 'secondary'}>{model.enabled ? 'enabled' : 'disabled'}</Badge>,
|
||||
]}
|
||||
chips={platformModelChips(model)}
|
||||
@ -889,7 +889,7 @@ function findBaseModelForPlatformModel(platform: IntegrationPlatform | undefined
|
||||
return baseModels.find((item) => item.id === model.baseModelId) ??
|
||||
baseModels.find((item) => item.canonicalModelKey === model.modelAlias) ??
|
||||
baseModels.find((item) => stableModelAlias(item) === model.modelAlias) ??
|
||||
baseModels.find((item) => item.providerKey === platform?.provider && item.providerModelName === model.modelName && baseModelTypes(item).includes(model.modelType));
|
||||
baseModels.find((item) => item.providerKey === platform?.provider && item.providerModelName === model.modelName && model.modelType.some((type) => baseModelTypes(item).includes(type)));
|
||||
}
|
||||
|
||||
function readPlatformModelIconPath(model: PlatformModel, baseModel?: BaseModelCatalogItem) {
|
||||
|
||||
@ -149,7 +149,7 @@ export function platformModelPayloads(models: BaseModelCatalogItem[], form: Plat
|
||||
canonicalModelKey: model.canonicalModelKey,
|
||||
modelName: model.providerModelName,
|
||||
modelAlias: stableModelAlias(model),
|
||||
modelType: primaryBaseModelType(model),
|
||||
modelType: baseModelTypes(model),
|
||||
displayName: stableModelAlias(model) || model.providerModelName,
|
||||
pricingMode: 'inherit_discount',
|
||||
discountFactor: optionalPositiveNumber(form.modelDiscountFactors[model.id]) ?? optionalPositiveNumber(form.modelDiscountFactor),
|
||||
|
||||
@ -638,7 +638,7 @@ function capabilityTypeKeys(
|
||||
: [contextKey, 'text_to_video', 'image_to_video', 'omni_video', 'video_generate', 'video'];
|
||||
return uniqueStrings([
|
||||
...stringListFromCapability(source.originalTypes),
|
||||
model.modelType,
|
||||
...model.modelType,
|
||||
...modeTypeHints.filter((item): item is string => Boolean(item)),
|
||||
]);
|
||||
}
|
||||
|
||||
@ -224,6 +224,23 @@ strong {
|
||||
gap: 14px;
|
||||
}
|
||||
|
||||
.taskList {
|
||||
display: grid;
|
||||
gap: 14px;
|
||||
}
|
||||
|
||||
.taskRecordHeader {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
align-items: center;
|
||||
gap: 8px 12px;
|
||||
color: var(--muted-foreground);
|
||||
}
|
||||
|
||||
.taskRecordHeader strong {
|
||||
color: var(--text-strong);
|
||||
}
|
||||
|
||||
.appShell {
|
||||
min-height: 100vh;
|
||||
}
|
||||
|
||||
@ -57,7 +57,7 @@ export interface PlatformModelForm {
|
||||
canonicalModelKey: string;
|
||||
modelName: string;
|
||||
modelAlias: string;
|
||||
modelType: string;
|
||||
modelType: string[];
|
||||
pricingRuleSetId: string;
|
||||
discountFactor: string;
|
||||
}
|
||||
@ -84,7 +84,7 @@ export interface PlatformModelBindingInput {
|
||||
baseModelId?: string;
|
||||
modelName: string;
|
||||
modelAlias?: string;
|
||||
modelType: string;
|
||||
modelType: string[];
|
||||
displayName?: string;
|
||||
pricingMode?: string;
|
||||
retryPolicy?: Record<string, unknown>;
|
||||
|
||||
@ -36,6 +36,19 @@
|
||||
|
||||
状态约定:`未执行`、`执行中`、`成功`、`失败`、`阻塞`。
|
||||
|
||||
### 本次自动化回环记录
|
||||
|
||||
| 项目 | 结果 |
|
||||
| --- | --- |
|
||||
| 测试批次 | `loopback-20260511-integration` |
|
||||
| 执行方式 | `httptest` 本地服务 + 独立 PostgreSQL 测试库 `easyai_ai_gateway_test_codex` |
|
||||
| 已通过验证 | 注册 / 登录 / API Key、API Key scopes 拦截、管理员接口拒绝 API Key、模型候选路由、真实 Chat / 兼容 Chat / 文生图 / 图生图 / 文生视频 / 图生视频 / 首尾帧视频、任务详情、任务列表、事件流、callback outbox、重试、运行策略限流、降级、自动禁用、计费规则集绑定、钱包扣费和交易流水 |
|
||||
| 核心修复验证 | `base_model_catalog.model_type` 和 `platform_models.model_type` 均为 JSONB 数组;`doubao-4.5图像编辑` 同时包含 `image_generate` / `image_edit`;`豆包Seedance-1.5-pro` 同时包含 `video_generate` / `image_to_video`;运行时候选查询只用 `model_type` 命中 |
|
||||
| 真实外部链路 | 已使用主库 `Playground API Key` 和已配置平台 KEY 验证通过;Chat=`qwen-plus`,图像=`doubao-4.5图像编辑`,视频=`豆包Seedance-1.5-pro` |
|
||||
| 真实任务证据 | Chat `6b454b1f-29b8-48ef-8ebb-740062d2e8b4`;文生图 `2db64a7e-f01d-424c-a5eb-027ef58cacde`;图生图 `af4cae00-c6d7-4c70-8477-64326c5d5cfc`;文生视频 `2436e7aa-3518-442b-8559-1e48a9cb3462`;图生视频 `80cc9655-4c87-466d-b263-5df23d23c157`;首尾帧视频 `5353a0cf-efa8-4fbc-90fd-60d1897fd012` |
|
||||
| 验证命令 | `AI_GATEWAY_TEST_DATABASE_URL=... go test ./...`、`pnpm --filter @easyai-ai-gateway/web typecheck`、`pnpm test`、`git diff --check` |
|
||||
| 结论 | 本地自动化覆盖项和真实外部模型链路均为 `成功` |
|
||||
|
||||
## 3. 测试数据准备
|
||||
|
||||
| ID | 任务 | 接口 / 方式 | 成功判定 | 状态 | 结果记录 |
|
||||
@ -98,6 +111,7 @@
|
||||
| HISTORY-03 | 事件记录保存 | `gateway_task_events` | 事件包含接收、运行、重试、完成或失败,`simulated` 标记准确 | 未执行 | 事件摘要待填写 |
|
||||
| HISTORY-04 | 回调 outbox 保存 | `gateway_task_callback_outbox`,开启回调配置时验证 | 每个 task event 写入幂等 outbox,`task_id + seq + callback_url` 唯一 | 未执行 | outbox 数量待填写 |
|
||||
| HISTORY-05 | API Key 身份保存 | 任务详情或数据库字段 | `apiKeyId`、`apiKeyName`、`apiKeyPrefix` 可追溯到发起 Key | 未执行 | key 信息待填写 |
|
||||
| HISTORY-06 | 前端任务记录 | `GET /api/v1/tasks` + 工作台“任务记录”页 | 页面从服务端任务列表读取历史记录,不依赖当前会话里刚运行的单条 taskResult | 成功 | 已新增任务列表接口和前端列表渲染;自动化断言 task list 包含已完成任务 |
|
||||
|
||||
## 6. 定价规则绑定与扣费计算
|
||||
|
||||
@ -109,6 +123,7 @@
|
||||
| PRICE-02 | 绑定规则集到平台模型 | `POST /api/v1/platforms/{platformID}/models` 或更新模型 | 平台模型返回 `pricingRuleSetId=loopback-pricing-*` 或 `billingConfigOverride` 生效 | 未执行 | platformModelId 待填写 |
|
||||
| PRICE-03 | Chat 预估计费 | `POST /api/v1/pricing/estimate`,Chat 请求体 | 返回 `resolver=effective-pricing-v1`,`text_input` 与 `text_output` 金额按 tokens 计算 | 未执行 | estimate items 待填写 |
|
||||
| PRICE-04 | Chat 最终扣费 | 执行 Chat 成功任务 | `finalChargeAmount=sum(billings.amount)`,输入/输出 token 数与 usage 对齐 | 未执行 | usage、billings、finalChargeAmount 待填写 |
|
||||
| PRICE-04B | 钱包实际扣费 | 执行 Chat 成功任务后读取钱包 | `gateway_wallet_accounts.balance` 按 `finalChargeAmount` 扣减,`gateway_wallet_transactions` 写入 `task_billing` 幂等流水 | 成功 | 集成测试断言 100 -> 99.972,交易 amount=0.028 |
|
||||
| PRICE-05 | 文生图扣费 | 执行文生图成功任务 | `resourceType=image`,数量、尺寸、质量权重参与计算 | 未执行 | billings 待填写 |
|
||||
| PRICE-06 | 图生图扣费 | 执行图生图成功任务 | `resourceType=image_edit`,优先使用编辑价格,缺省时回退图片价格 | 未执行 | billings 待填写 |
|
||||
| PRICE-07 | 视频扣费 | 执行文生视频、图生视频、首尾帧视频成功任务 | `resourceType=video`,数量、分辨率或时长权重按规则计算,三类视频请求均有计费记录 | 未执行 | billings 待填写 |
|
||||
@ -175,7 +190,10 @@
|
||||
|
||||
| 时间 | 用例 ID | 失败现象 | 初步原因 | 修复位置 | 重跑结果 |
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
| 待填写 | 待填写 | 待填写 | 待填写 | 待填写 | 待填写 |
|
||||
| 2026-05-11 | TASK-IMAGE-01 | `doubao-4.5图像编辑` 使用 `1024x1024` 返回 `http_400`,提示像素不足 | 测试参数不符合真实模型最小尺寸约束 | 调整真实验证请求为 `2048x2048` | 文生图任务 `2db64a7e-f01d-424c-a5eb-027ef58cacde` 成功 |
|
||||
| 2026-05-11 | TASK-VIDEO-02 | 图生视频传 `image` 时上游按 `r2v` 处理并返回不支持 `Seedance-1.5-pro` | Volces 适配层把通用 `image` 映射成 `reference_image`,而非图生视频首帧 | `apps/api/internal/clients/volces.go`,将 `image` 默认映射为 `first_frame`,显式 `reference_image` 才走参考图 | 图生视频任务 `80cc9655-4c87-466d-b263-5df23d23c157` 成功 |
|
||||
| 2026-05-11 | TASK-VIDEO-02 / TASK-VIDEO-03 | 带图片的视频请求任务 `modelType` 仍记录为 `video_generate` | `videos.generations` 未根据请求体区分文生视频和图生视频 | `apps/api/internal/runner/service.go` / `pricing.go`,带图片、首帧或尾帧时使用 `image_to_video` | 图生视频与首尾帧视频均以 `modelType=image_to_video` 成功 |
|
||||
| 2026-05-11 | HISTORY-06 | 前端“任务记录”没有任何记录 | 前端只展示本地 `taskResult`,后端没有当前用户任务列表接口 | `GET /api/v1/tasks`、`ListTasks`、工作台任务列表状态和渲染 | 集成测试验证任务列表返回已完成任务;前端 typecheck 成功 |
|
||||
|
||||
## 12. 清理清单
|
||||
|
||||
|
||||
@ -560,7 +560,7 @@ export interface PlatformModel {
|
||||
platformName?: string;
|
||||
modelName: string;
|
||||
modelAlias?: string;
|
||||
modelType: 'chat' | 'image' | 'video' | 'audio' | 'embedding' | string;
|
||||
modelType: string[];
|
||||
displayName: string;
|
||||
capabilityOverride?: Record<string, unknown>;
|
||||
capabilities?: Record<string, unknown>;
|
||||
@ -625,6 +625,8 @@ export interface GatewayTask {
|
||||
responseDurationMs?: number;
|
||||
finishedAt?: string;
|
||||
error?: string;
|
||||
errorCode?: string;
|
||||
errorMessage?: string;
|
||||
createdAt: string;
|
||||
updatedAt: string;
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user