feat: add river-backed async task queue

This commit is contained in:
wangbo 2026-05-12 10:11:54 +08:00
parent d69aaed444
commit 7e220b7477
30 changed files with 1342 additions and 200 deletions

View File

@ -33,18 +33,19 @@ func main() {
if recovery, err := db.RecoverInterruptedRuntimeState(ctx); err != nil {
logger.Error("recover interrupted runtime state failed", "error", err)
os.Exit(1)
} else if recovery.ReleasedConcurrencyLeases > 0 || recovery.ReleasedRateReservations > 0 || recovery.FailedAttempts > 0 || recovery.FailedTasks > 0 {
} else if recovery.ReleasedConcurrencyLeases > 0 || recovery.ReleasedRateReservations > 0 || recovery.FailedAttempts > 0 || recovery.FailedTasks > 0 || recovery.RequeuedAsyncTasks > 0 {
logger.Warn("interrupted runtime state recovered",
"releasedConcurrencyLeases", recovery.ReleasedConcurrencyLeases,
"releasedRateReservations", recovery.ReleasedRateReservations,
"failedAttempts", recovery.FailedAttempts,
"failedTasks", recovery.FailedTasks,
"requeuedAsyncTasks", recovery.RequeuedAsyncTasks,
)
}
server := &http.Server{
Addr: cfg.HTTPAddr,
Handler: httpapi.NewServer(cfg, db, logger),
Handler: httpapi.NewServerWithContext(ctx, cfg, db, logger),
ReadHeaderTimeout: 10 * time.Second,
}

View File

@ -4,14 +4,28 @@ go 1.23
require (
github.com/golang-jwt/jwt/v5 v5.2.2
github.com/jackc/pgx/v5 v5.7.2
golang.org/x/crypto v0.31.0
github.com/jackc/pgx/v5 v5.9.2
github.com/riverqueue/river v0.24.0
github.com/riverqueue/river/riverdriver/riverpgxv5 v0.24.0
github.com/riverqueue/river/rivertype v0.24.0
golang.org/x/crypto v0.37.0
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
golang.org/x/sync v0.10.0 // indirect
golang.org/x/text v0.21.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/riverqueue/river/riverdriver v0.24.0 // indirect
github.com/riverqueue/river/rivershared v0.24.0 // indirect
github.com/stretchr/testify v1.11.1 // indirect
github.com/tidwall/gjson v1.18.0 // indirect
github.com/tidwall/match v1.2.0 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
go.uber.org/goleak v1.3.0 // indirect
golang.org/x/sync v0.20.0 // indirect
golang.org/x/text v0.36.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@ -3,28 +3,63 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 h1:Dj0L5fhJ9F82ZJyVOmBx6msDp/kfd1t9GRfny/mfJA0=
github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.7.2 h1:mLoDLV6sonKlvjIEsV56SkWNCnuNv531l94GaIzO+XI=
github.com/jackc/pgx/v5 v5.7.2/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ=
github.com/jackc/pgx/v5 v5.9.2 h1:3ZhOzMWnR4yJ+RW1XImIPsD1aNSz4T4fyP7zlQb56hw=
github.com/jackc/pgx/v5 v5.9.2/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/riverqueue/river v0.24.0 h1:CesL6vymWgz0d+zNwtnSGRWaB+E8Dax+o9cxD7sUmKc=
github.com/riverqueue/river v0.24.0/go.mod h1:UZ3AxU5t6WtyqNssaea/AkRS8h/kJ+E9ImSB3xyb3ns=
github.com/riverqueue/river/riverdriver v0.24.0 h1:HqGgGkls11u+YKDA7cKOdYKlQwRNJyHuGa3UtOvpdT0=
github.com/riverqueue/river/riverdriver v0.24.0/go.mod h1:dEew9DDIKenNvzpm8Edw8+PkqP3c0zl1fKjiQTq2n/w=
github.com/riverqueue/river/riverdriver/riverpgxv5 v0.24.0 h1:yV37OIbRrhRwIiGeRT7P4D3szhAemu87BgCf8gTCoU4=
github.com/riverqueue/river/riverdriver/riverpgxv5 v0.24.0/go.mod h1:QfznySVKC4ljx53syd/bA/LRSsydAyuD3Q9/EbSniKA=
github.com/riverqueue/river/rivershared v0.24.0 h1:KysokksW75pug2a5RTOc6WESOupWmsylVc6VWvAx+4Y=
github.com/riverqueue/river/rivershared v0.24.0/go.mod h1:UIBfSdai0oWFlwFcoqG4DZX83iA/fLWTEBGrj7Oe1ho=
github.com/riverqueue/river/rivertype v0.24.0 h1:xrQZm/h6U8TBPyTsQPYD5leOapuoBAcdz30bdBwTqOg=
github.com/riverqueue/river/rivertype v0.24.0/go.mod h1:lmdl3vLNDfchDWbYdW2uAocIuwIN+ZaXqAukdSCFqWs=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM=
github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -329,6 +329,7 @@ func TestVolcesClientVideoSubmitsAndPollsTask(t *testing.T) {
var gotModel string
var gotText string
var gotFirstFrameRole string
var submittedRemoteTaskID string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotAuth = r.Header.Get("Authorization")
switch r.Method + " " + r.URL.Path {
@ -385,6 +386,13 @@ func TestVolcesClientVideoSubmitsAndPollsTask(t *testing.T) {
"volcesPollTimeoutSeconds": 1,
},
},
OnRemoteTaskSubmitted: func(remoteTaskID string, payload map[string]any) error {
submittedRemoteTaskID = remoteTaskID
if payload["id"] != "cgt-test" {
t.Fatalf("unexpected submitted payload: %+v", payload)
}
return nil
},
})
if err != nil {
t.Fatalf("run volces video: %v", err)
@ -392,6 +400,9 @@ func TestVolcesClientVideoSubmitsAndPollsTask(t *testing.T) {
if submitPath != "/contents/generations/tasks" || pollPath != "/contents/generations/tasks/cgt-test" || gotAuth != "Bearer volces-key" {
t.Fatalf("unexpected paths/auth submit=%s poll=%s auth=%s", submitPath, pollPath, gotAuth)
}
if submittedRemoteTaskID != "cgt-test" {
t.Fatalf("remote task submit callback did not receive task id, got %q", submittedRemoteTaskID)
}
if gotModel != "doubao-seedance-2-0-260128" || gotFirstFrameRole != "first_frame" {
t.Fatalf("unexpected submitted model=%s role=%s", gotModel, gotFirstFrameRole)
}
@ -407,6 +418,53 @@ func TestVolcesClientVideoSubmitsAndPollsTask(t *testing.T) {
}
}
// TestVolcesClientVideoResumePollsExistingTaskID verifies that a request
// carrying RemoteTaskID skips the upstream submit call entirely and goes
// straight to polling the existing remote task.
func TestVolcesClientVideoResumePollsExistingTaskID(t *testing.T) {
	var submitCalled bool
	var pollPath string
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// This handler runs on the server's goroutine; t.Fatalf is unsafe
		// here because FailNow must be called from the test goroutine.
		// Record failures with t.Errorf and return instead.
		switch r.Method + " " + r.URL.Path {
		case "POST /contents/generations/tasks":
			submitCalled = true
			t.Errorf("resume should skip upstream submit when remote task id exists")
			return
		case "GET /contents/generations/tasks/cgt-existing":
			pollPath = r.URL.Path
			_ = json.NewEncoder(w).Encode(map[string]any{
				"id":         "cgt-existing",
				"status":     "succeeded",
				"created_at": 789,
				"content":    map[string]any{"video_url": "https://example.com/resumed.mp4"},
			})
		default:
			t.Errorf("unexpected request %s %s", r.Method, r.URL.Path)
		}
	}))
	defer server.Close()
	response, err := (VolcesClient{HTTPClient: server.Client()}).Run(context.Background(), Request{
		Kind:         "videos.generations",
		ModelType:    "video_generate",
		Model:        "豆包Seedance-2.0",
		Body:         map[string]any{"prompt": "resume polling", "pollIntervalMs": 100, "pollTimeoutSeconds": 1},
		RemoteTaskID: "cgt-existing",
		Candidate: store.RuntimeModelCandidate{
			BaseURL:     server.URL,
			ModelName:   "豆包Seedance-2.0",
			Credentials: map[string]any{"apiKey": "volces-key"},
		},
	})
	if err != nil {
		t.Fatalf("resume volces video: %v", err)
	}
	if submitCalled || pollPath != "/contents/generations/tasks/cgt-existing" {
		t.Fatalf("resume should poll existing task only, submit=%v poll=%s", submitCalled, pollPath)
	}
	data, _ := response.Result["data"].([]any)
	// Guard the index: a missing/empty data array should fail the test, not
	// panic it.
	if len(data) == 0 {
		t.Fatalf("resumed response should carry data items: %+v", response.Result)
	}
	item, _ := data[0].(map[string]any)
	if response.Result["upstream_task_id"] != "cgt-existing" || item["url"] != "https://example.com/resumed.mp4" {
		t.Fatalf("unexpected resumed response: %+v", response.Result)
	}
}
func extractText(result map[string]any) string {
choices, _ := result["choices"].([]any)
choice, _ := choices[0].(map[string]any)

View File

@ -11,14 +11,17 @@ import (
)
type Request struct {
Kind string
ModelType string
Model string
Body map[string]any
Candidate store.RuntimeModelCandidate
HTTPClient *http.Client
Stream bool
StreamDelta StreamDelta
Kind string
ModelType string
Model string
Body map[string]any
Candidate store.RuntimeModelCandidate
HTTPClient *http.Client
RemoteTaskID string
RemoteTaskPayload map[string]any
OnRemoteTaskSubmitted func(remoteTaskID string, payload map[string]any) error
Stream bool
StreamDelta StreamDelta
}
type Response struct {

View File

@ -67,16 +67,25 @@ func (c VolcesClient) runImage(ctx context.Context, request Request, apiKey stri
}
func (c VolcesClient) runVideo(ctx context.Context, request Request, apiKey string) (Response, error) {
body := volcesVideoBody(request)
submitStartedAt := time.Now()
submitResult, submitRequestID, err := c.postJSON(ctx, request, request.Candidate.BaseURL, "/contents/generations/tasks", apiKey, body)
submitFinishedAt := time.Now()
if err != nil {
return Response{}, annotateResponseError(err, submitRequestID, submitStartedAt, submitFinishedAt)
}
upstreamTaskID := strings.TrimSpace(stringFromAny(submitResult["id"]))
submitRequestID := strings.TrimSpace(request.RemoteTaskID)
upstreamTaskID := strings.TrimSpace(request.RemoteTaskID)
if upstreamTaskID == "" {
return Response{}, &ClientError{Code: "invalid_response", Message: "volces video task id is missing", RequestID: submitRequestID, Retryable: false}
body := volcesVideoBody(request)
submitResult, requestID, err := c.postJSON(ctx, request, request.Candidate.BaseURL, "/contents/generations/tasks", apiKey, body)
submitRequestID = requestID
if err != nil {
return Response{}, annotateResponseError(err, submitRequestID, submitStartedAt, time.Now())
}
upstreamTaskID = strings.TrimSpace(stringFromAny(submitResult["id"]))
if upstreamTaskID == "" {
return Response{}, &ClientError{Code: "invalid_response", Message: "volces video task id is missing", RequestID: submitRequestID, Retryable: false}
}
if request.OnRemoteTaskSubmitted != nil {
if err := request.OnRemoteTaskSubmitted(upstreamTaskID, submitResult); err != nil {
return Response{}, err
}
}
}
interval := volcesPollInterval(request)

View File

@ -18,6 +18,7 @@ import (
"testing"
"time"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
"github.com/jackc/pgx/v5/pgxpool"
@ -37,7 +38,11 @@ func TestCoreLocalFlow(t *testing.T) {
}
defer db.Close()
handler := NewServer(config.Config{
assertRuntimeRecoveryReleasesPendingRateReservations(t, ctx, db)
serverCtx, cancelServer := context.WithCancel(ctx)
defer cancelServer()
handler := NewServerWithContext(serverCtx, config.Config{
AppEnv: "test",
HTTPAddr: ":0",
DatabaseURL: databaseURL,
@ -169,13 +174,14 @@ func TestCoreLocalFlow(t *testing.T) {
if err := testPool.QueryRow(ctx, `SELECT count(*) FROM gateway_tasks`).Scan(&taskCountBefore); err != nil {
t.Fatalf("count tasks before scoped request: %v", err)
}
defaultImageModel := "openai:gpt-image-1"
doJSON(t, server.URL, http.MethodPost, "/api/v1/images/generations", chatOnlyAPIKeyResponse.Secret, map[string]any{
"model": "gpt-image-1",
"model": defaultImageModel,
"prompt": "scope should block this",
}, http.StatusForbidden, nil)
doJSON(t, server.URL, http.MethodPost, "/api/v1/pricing/estimate", chatOnlyAPIKeyResponse.Secret, map[string]any{
"kind": "images.generations",
"model": "gpt-image-1",
"model": defaultImageModel,
"prompt": "scope should block this estimate",
}, http.StatusForbidden, nil)
var taskCountAfter int
@ -309,8 +315,9 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil {
Result map[string]any `json:"result"`
} `json:"task"`
}
defaultTextModel := "openai:gpt-4o-mini"
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
"model": "gpt-4o-mini",
"model": defaultTextModel,
"runMode": "simulation",
"simulation": true,
"simulationDurationMs": 5,
@ -334,7 +341,7 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil {
var compatChat map[string]any
doJSON(t, server.URL, http.MethodPost, "/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
"model": "gpt-4o-mini",
"model": defaultTextModel,
"runMode": "simulation",
"messages": []map[string]any{{"role": "user", "content": "ping"}},
"simulation": true,
@ -352,7 +359,7 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil {
} `json:"task"`
}
doJSON(t, server.URL, http.MethodPost, "/api/v1/images/generations", apiKeyResponse.Secret, map[string]any{
"model": "gpt-image-1",
"model": defaultImageModel,
"runMode": "simulation",
"prompt": "a tiny gateway console",
"size": "1024x1024",
@ -372,7 +379,7 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil {
} `json:"task"`
}
doJSON(t, server.URL, http.MethodPost, "/api/v1/images/edits", apiKeyResponse.Secret, map[string]any{
"model": "gpt-image-1",
"model": defaultImageModel,
"runMode": "simulation",
"prompt": "replace background with clean studio light",
"image": "https://example.com/source.png",
@ -384,6 +391,42 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil {
t.Fatalf("unexpected image edit task: %+v", imageEditResponse.Task)
}
doubaoLiteImageEditModel := "doubao-5.0-lite图像编辑"
var doubaoLitePlatformModel struct {
ID string `json:"id"`
}
doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms/"+platform.ID+"/models", loginResponse.AccessToken, map[string]any{
"canonicalModelKey": "easyai:doubao-5.0-lite图像编辑",
"modelName": doubaoLiteImageEditModel,
"modelAlias": doubaoLiteImageEditModel,
"modelType": []string{"image_edit", "image_generate"},
"displayName": doubaoLiteImageEditModel,
}, http.StatusCreated, &doubaoLitePlatformModel)
var doubaoLiteAsyncEdit struct {
TaskID string `json:"taskId"`
Task struct {
ID string `json:"id"`
Status string `json:"status"`
AsyncMode bool `json:"asyncMode"`
} `json:"task"`
}
doJSONWithHeaders(t, server.URL, http.MethodPost, "/api/v1/images/edits", apiKeyResponse.Secret, map[string]any{
"model": doubaoLiteImageEditModel,
"runMode": "simulation",
"prompt": "turn the attached bright desktop object into a clean product-style render",
"image": "https://example.com/doubao-lite-source.png",
"size": "2048x2048",
"simulation": true,
"simulationDurationMs": 5,
}, map[string]string{"X-Async": "true"}, http.StatusAccepted, &doubaoLiteAsyncEdit)
if doubaoLiteAsyncEdit.TaskID == "" || !doubaoLiteAsyncEdit.Task.AsyncMode {
t.Fatalf("doubao-5.0-lite image edit async task should be accepted: %+v", doubaoLiteAsyncEdit)
}
doubaoLiteAsyncEditDone := waitForTaskStatus(t, server.URL, apiKeyResponse.Secret, doubaoLiteAsyncEdit.TaskID, []string{"succeeded"}, 5*time.Second)
if doubaoLiteAsyncEditDone.Status != "succeeded" {
t.Fatalf("doubao-5.0-lite image edit async task should succeed through river queue, got %+v", doubaoLiteAsyncEditDone)
}
var gptImageModelTypesRaw []byte
if err := testPool.QueryRow(ctx, `
SELECT model_type
@ -641,6 +684,7 @@ WHERE reference_type = 'gateway_task'
}
rateLimitedModel := "rate-limit-smoke-" + suffixText
rateLimitWindowSeconds := 3
var rateLimitPolicySet struct {
ID string `json:"id"`
}
@ -652,7 +696,7 @@ WHERE reference_type = 'gateway_task'
"maxAttempts": 1,
},
"rateLimitPolicy": map[string]any{
"rules": []map[string]any{{"metric": "rpm", "limit": 1, "windowSeconds": 60}},
"rules": []map[string]any{{"metric": "rpm", "limit": 1, "windowSeconds": rateLimitWindowSeconds}},
},
}, http.StatusCreated, &rateLimitPolicySet)
var rateLimitPlatformModel map[string]any
@ -682,6 +726,7 @@ WHERE reference_type = 'gateway_task'
if rateLimitFailedTask.Task.Status != "failed" || rateLimitFailedTask.Task.ErrorCode != "bad_request" {
t.Fatalf("failed rate-limited task should fail before consuming rpm: %+v", rateLimitFailedTask.Task)
}
waitForRateLimitWindowHead(t, rateLimitWindowSeconds)
var rateLimitTaskOne struct {
Task struct {
Status string `json:"status"`
@ -713,6 +758,38 @@ WHERE reference_type = 'gateway_task'
if rateLimitTaskTwo.Task.Status != "failed" || rateLimitTaskTwo.Task.ErrorCode != "rate_limit" {
t.Fatalf("runtime policy rate limit should fail second task with rate_limit: %+v", rateLimitTaskTwo.Task)
}
var asyncRateLimitTask struct {
TaskID string `json:"taskId"`
Task struct {
ID string `json:"id"`
Status string `json:"status"`
AsyncMode bool `json:"asyncMode"`
} `json:"task"`
}
doJSONWithHeaders(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
"model": rateLimitedModel,
"runMode": "simulation",
"simulation": true,
"simulationDurationMs": 5,
"messages": []map[string]any{{"role": "user", "content": "async queued"}},
}, map[string]string{"X-Async": "true"}, http.StatusAccepted, &asyncRateLimitTask)
if asyncRateLimitTask.TaskID == "" || asyncRateLimitTask.Task.ID != asyncRateLimitTask.TaskID || !asyncRateLimitTask.Task.AsyncMode {
t.Fatalf("async task response should expose task id and async mode: %+v", asyncRateLimitTask)
}
asyncRateLimitDetail := waitForTaskStatus(t, server.URL, apiKeyResponse.Secret, asyncRateLimitTask.TaskID, []string{"queued"}, 2*time.Second)
if asyncRateLimitDetail.Status != "queued" {
t.Fatalf("async rate-limited task should return to queued state, got %+v", asyncRateLimitDetail)
}
if len(asyncRateLimitDetail.Attempts) == 0 || asyncRateLimitDetail.Attempts[0].ErrorCode != "rate_limit" {
t.Fatalf("async rate-limited task should record a rate_limit attempt before requeue: %+v", asyncRateLimitDetail)
}
asyncRateLimitCompleted := waitForTaskStatus(t, server.URL, apiKeyResponse.Secret, asyncRateLimitTask.TaskID, []string{"succeeded"}, time.Duration(rateLimitWindowSeconds+3)*time.Second)
if asyncRateLimitCompleted.Status != "succeeded" {
t.Fatalf("async rate-limited task should be pulled from queue after the limit window resets, got %+v", asyncRateLimitCompleted)
}
if len(asyncRateLimitCompleted.Attempts) < 2 || asyncRateLimitCompleted.Attempts[len(asyncRateLimitCompleted.Attempts)-1].Status != "succeeded" {
t.Fatalf("async rate-limited task should create a new successful attempt after requeue: %+v", asyncRateLimitCompleted)
}
videoRouteModel := "video-route-smoke-" + suffixText
var videoRoutePlatformModel map[string]any
@ -823,16 +900,19 @@ WHERE reference_type = 'gateway_task'
Metrics map[string]any `json:"metrics"`
}
doJSON(t, server.URL, http.MethodGet, "/api/v1/tasks/"+failoverTask.Task.ID, apiKeyResponse.Secret, nil, http.StatusOK, &failoverDetail)
if len(failoverDetail.Attempts) != 2 {
t.Fatalf("failover task history should include two attempts, got %+v", failoverDetail.Attempts)
if len(failoverDetail.Attempts) != 3 {
t.Fatalf("failover task history should include two failed retries plus one successful failover attempt, got %+v", failoverDetail.Attempts)
}
if failoverDetail.Attempts[0].PlatformName != "OpenAI Retryable Failure" || failoverDetail.Attempts[0].Status != "failed" || !failoverDetail.Attempts[0].Retryable || failoverDetail.Attempts[0].ErrorCode == "" {
t.Fatalf("first failover attempt should preserve failed platform and reason: %+v", failoverDetail.Attempts[0])
}
if failoverDetail.Attempts[1].PlatformName != "OpenAI Retry Success" || failoverDetail.Attempts[1].Status != "succeeded" || failoverDetail.Attempts[1].ResponseMS <= 0 {
t.Fatalf("second failover attempt should preserve successful platform: %+v", failoverDetail.Attempts[1])
if failoverDetail.Attempts[1].PlatformName != "OpenAI Retryable Failure" || failoverDetail.Attempts[1].Status != "failed" || !failoverDetail.Attempts[1].Retryable {
t.Fatalf("second failover attempt should preserve the same-client retry failure: %+v", failoverDetail.Attempts[1])
}
if summary, ok := failoverDetail.Metrics["attempts"].([]any); !ok || len(summary) != 2 {
if failoverDetail.Attempts[2].PlatformName != "OpenAI Retry Success" || failoverDetail.Attempts[2].Status != "succeeded" || failoverDetail.Attempts[2].ResponseMS <= 0 {
t.Fatalf("third failover attempt should preserve successful platform: %+v", failoverDetail.Attempts[2])
}
if summary, ok := failoverDetail.Metrics["attempts"].([]any); !ok || len(summary) != 3 {
t.Fatalf("task metrics should keep attempt-chain summary, got %+v", failoverDetail.Metrics)
}
@ -1045,6 +1125,9 @@ WHERE m.platform_id = $1::uuid
if resp.StatusCode != http.StatusOK || !bytes.Contains(body, []byte("task.completed")) {
t.Fatalf("unexpected events response status=%d body=%s", resp.StatusCode, string(body))
}
if !bytes.Contains(body, []byte("task.running")) {
t.Fatalf("events response should include running transition event body=%s", string(body))
}
if !bytes.Contains(body, []byte("task.progress")) {
t.Fatalf("events response should include progress events body=%s", string(body))
}
@ -1074,6 +1157,47 @@ WHERE m.platform_id = $1::uuid
if callbackRows == 0 {
t.Fatal("task progress callback outbox should receive events")
}
var restartAsyncTask struct {
TaskID string `json:"taskId"`
Task struct {
ID string `json:"id"`
Status string `json:"status"`
AsyncMode bool `json:"asyncMode"`
} `json:"task"`
}
doJSONWithHeaders(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
"model": defaultTextModel,
"runMode": "simulation",
"simulation": true,
"simulationDurationMs": 2000,
"messages": []map[string]any{{"role": "user", "content": "river worker restart"}},
}, map[string]string{"X-Async": "true"}, http.StatusAccepted, &restartAsyncTask)
if restartAsyncTask.TaskID == "" || !restartAsyncTask.Task.AsyncMode {
t.Fatalf("restart async task should be accepted as async: %+v", restartAsyncTask)
}
restartRunning := waitForTaskStatus(t, server.URL, apiKeyResponse.Secret, restartAsyncTask.TaskID, []string{"running"}, 3*time.Second)
if restartRunning.Status != "running" {
t.Fatalf("restart async task should be running before worker restart, got %+v", restartRunning)
}
cancelServer()
serverCtx2, cancelServer2 := context.WithCancel(ctx)
defer cancelServer2()
server2 := httptest.NewServer(NewServerWithContext(serverCtx2, config.Config{
AppEnv: "test",
HTTPAddr: ":0",
DatabaseURL: databaseURL,
IdentityMode: "hybrid",
JWTSecret: "test-secret",
TaskProgressCallbackEnabled: true,
TaskProgressCallbackURL: "http://callback.local/task-progress",
CORSAllowedOrigin: "*",
}, db, slog.New(slog.NewTextHandler(io.Discard, nil))))
defer server2.Close()
restartRecovered := waitForTaskStatus(t, server2.URL, apiKeyResponse.Secret, restartAsyncTask.TaskID, []string{"succeeded"}, 8*time.Second)
if restartRecovered.Status != "succeeded" {
t.Fatalf("river worker restart should recover and finish async task, got %+v", restartRecovered)
}
}
func TestOriginAllowedSupportsCommaSeparatedOrigins(t *testing.T) {
@ -1089,6 +1213,51 @@ func TestOriginAllowedSupportsCommaSeparatedOrigins(t *testing.T) {
}
}
// assertRuntimeRecoveryReleasesPendingRateReservations seeds an async task
// with a pending rate-limit reservation, runs runtime recovery, and asserts
// that the reservation is released and its counter value zeroed.
func assertRuntimeRecoveryReleasesPendingRateReservations(t *testing.T, ctx context.Context, db *store.Store) {
	t.Helper()
	nonce := strconv.FormatInt(time.Now().UnixNano(), 10)
	createdTask, createErr := db.CreateTask(ctx, store.CreateTaskInput{
		Kind:    "chat.completions",
		Model:   "recovery-reservation-smoke-" + nonce,
		RunMode: "test",
		Async:   true,
		Request: map[string]any{"messages": []any{}},
	}, &auth.User{ID: "recovery-user-" + nonce})
	if createErr != nil {
		t.Fatalf("create recovery reservation task: %v", createErr)
	}
	clientScopeKey := "recovery-client-" + nonce
	reservations := []store.RateLimitReservation{{
		ScopeType:     "client",
		ScopeKey:      clientScopeKey,
		Metric:        "rpm",
		Limit:         10,
		Amount:        3,
		WindowSeconds: 60,
	}}
	if _, reserveErr := db.ReserveRateLimits(ctx, createdTask.ID, "", reservations); reserveErr != nil {
		t.Fatalf("reserve recovery rate limit: %v", reserveErr)
	}
	recovery, recoverErr := db.RecoverInterruptedRuntimeState(ctx)
	if recoverErr != nil {
		t.Fatalf("recover interrupted runtime state with pending reservation: %v", recoverErr)
	}
	if recovery.ReleasedRateReservations == 0 {
		t.Fatalf("recovery should release pending rate reservation, got %+v", recovery)
	}
	// The counter row must show zero reserved value after recovery.
	var remainingReserved float64
	if scanErr := db.Pool().QueryRow(ctx, `
SELECT COALESCE(MAX(reserved_value), 0)::float8
FROM gateway_rate_limit_counters
WHERE scope_type = 'client'
AND scope_key = $1
AND metric = 'rpm'`, clientScopeKey).Scan(&remainingReserved); scanErr != nil {
		t.Fatalf("query recovered rate limit counter: %v", scanErr)
	}
	if remainingReserved != 0 {
		t.Fatalf("recovery should release reserved counter value, got %f", remainingReserved)
	}
}
func applyMigration(t *testing.T, ctx context.Context, databaseURL string) {
t.Helper()
_, filename, _, _ := runtime.Caller(0)
@ -1114,6 +1283,11 @@ func applyMigration(t *testing.T, ctx context.Context, databaseURL string) {
}
func doJSON(t *testing.T, baseURL string, method string, path string, token string, payload any, expectedStatus int, out any) {
t.Helper()
doJSONWithHeaders(t, baseURL, method, path, token, payload, nil, expectedStatus, out)
}
func doJSONWithHeaders(t *testing.T, baseURL string, method string, path string, token string, payload any, headers map[string]string, expectedStatus int, out any) {
t.Helper()
var body io.Reader
if payload != nil {
@ -1133,6 +1307,9 @@ func doJSON(t *testing.T, baseURL string, method string, path string, token stri
if token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
for key, value := range headers {
req.Header.Set(key, value)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("%s %s: %v", method, path, err)
@ -1149,6 +1326,50 @@ func doJSON(t *testing.T, baseURL string, method string, path string, token stri
}
}
// taskWaitDetail is the minimal projection of the task detail endpoint used
// by the polling helpers: the task id, its current status, and per-attempt
// status/error codes.
type taskWaitDetail struct {
ID string `json:"id"`
Status string `json:"status"`
// Attempts mirrors the task's attempt history in API order.
Attempts []struct {
Status string `json:"status"`
ErrorCode string `json:"errorCode"`
} `json:"attempts"`
}
// waitForTaskStatus polls the task detail endpoint until the task reaches
// one of the wanted statuses or the timeout elapses. The last observed
// detail is returned either way so callers can inspect it.
func waitForTaskStatus(t *testing.T, baseURL string, token string, taskID string, statuses []string, timeout time.Duration) taskWaitDetail {
	t.Helper()
	accepted := make(map[string]bool, len(statuses))
	for _, status := range statuses {
		accepted[status] = true
	}
	var latest taskWaitDetail
	for stop := time.Now().Add(timeout); time.Now().Before(stop); time.Sleep(100 * time.Millisecond) {
		doJSON(t, baseURL, http.MethodGet, "/api/v1/tasks/"+taskID, token, nil, http.StatusOK, &latest)
		if accepted[latest.Status] {
			break
		}
	}
	return latest
}
func waitForRateLimitWindowHead(t *testing.T, windowSeconds int) {
t.Helper()
if windowSeconds <= 0 {
return
}
window := time.Duration(windowSeconds) * time.Second
deadline := time.Now().Add(window + time.Second)
for time.Now().Before(deadline) {
elapsed := time.Duration(time.Now().UnixNano() % int64(window))
if elapsed < 300*time.Millisecond {
return
}
time.Sleep(20 * time.Millisecond)
}
t.Fatalf("timed out waiting for %ds rate limit window head", windowSeconds)
}
func stringSliceContains(values []string, target string) bool {
for _, value := range values {
if value == target {

View File

@ -542,11 +542,13 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler {
writeError(w, http.StatusForbidden, "api key scope does not allow this capability")
return
}
asyncMode := asyncRequest(r)
task, err := s.store.CreateTask(r.Context(), store.CreateTaskInput{
Kind: kind,
Model: model,
RunMode: runModeFromRequest(body),
Async: asyncMode,
Request: body,
}, user)
if err != nil {
@ -554,6 +556,14 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler {
writeError(w, http.StatusInternalServerError, "create task failed")
return
}
if asyncMode {
if err := s.runner.EnqueueAsyncTask(r.Context(), task); err != nil {
writeError(w, http.StatusInternalServerError, err.Error(), "enqueue_failed")
return
}
writeTaskAccepted(w, task)
return
}
if compatible {
if boolValue(body, "stream") {
flusher := prepareCompatibleStream(w)
@ -602,13 +612,23 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler {
s.logger.Warn("task completed with failure", "kind", kind, "taskId", task.ID, "error", runErr)
}
writeJSON(w, http.StatusAccepted, map[string]any{
"task": result.Task,
"next": map[string]string{
"events": fmt.Sprintf("/api/v1/tasks/%s/events", task.ID),
"detail": fmt.Sprintf("/api/v1/tasks/%s", task.ID),
},
})
writeTaskAccepted(w, result.Task)
})
}
func asyncRequest(r *http.Request) bool {
value := strings.TrimSpace(strings.ToLower(r.Header.Get("x-async")))
return value == "1" || value == "true" || value == "yes" || value == "on"
}
// writeTaskAccepted responds with 202 Accepted and the standard async task
// envelope: the task id, the task snapshot, and follow-up URLs for the event
// stream and the task detail endpoint.
func writeTaskAccepted(w http.ResponseWriter, task store.GatewayTask) {
	next := map[string]string{
		"events": fmt.Sprintf("/api/v1/tasks/%s/events", task.ID),
		"detail": fmt.Sprintf("/api/v1/tasks/%s", task.ID),
	}
	body := map[string]any{
		"taskId": task.ID,
		"task":   task,
		"next":   next,
	}
	writeJSON(w, http.StatusAccepted, body)
}

View File

@ -1,6 +1,7 @@
package httpapi
import (
"context"
"log/slog"
"net/http"
"strings"
@ -12,6 +13,7 @@ import (
)
type Server struct {
ctx context.Context
cfg config.Config
store *store.Store
auth *auth.Authenticator
@ -20,7 +22,12 @@ type Server struct {
}
// NewServer builds the HTTP handler with a background context. It is a
// convenience wrapper around NewServerWithContext for callers that do not
// manage the server lifetime via a context.
func NewServer(cfg config.Config, db *store.Store, logger *slog.Logger) http.Handler {
	return NewServerWithContext(context.Background(), cfg, db, logger)
}
func NewServerWithContext(ctx context.Context, cfg config.Config, db *store.Store, logger *slog.Logger) http.Handler {
server := &Server{
ctx: ctx,
cfg: cfg,
store: db,
auth: auth.New(cfg.JWTSecret, cfg.ServerMainBaseURL, cfg.ServerMainInternalToken),
@ -28,6 +35,7 @@ func NewServer(cfg config.Config, db *store.Store, logger *slog.Logger) http.Han
logger: logger,
}
server.auth.LocalAPIKeyVerifier = db.VerifyLocalAPIKey
server.runner.StartAsyncQueueWorker(ctx)
mux := http.NewServeMux()
mux.HandleFunc("GET /healthz", server.health)
@ -146,7 +154,7 @@ func (s *Server) cors(next http.Handler) http.Handler {
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Vary", "Origin")
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Comfy-Api-Key")
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Comfy-Api-Key, X-Async")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
}
if r.Method == http.MethodOptions {

View File

@ -2,13 +2,49 @@ package runner
import (
"context"
"errors"
"strings"
"time"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
// localRateLimitError wraps a client-facing rate limit error together with
// the underlying store error and a suggested retry delay. It is produced when
// a local (gateway-side) rate limit reservation fails.
type localRateLimitError struct {
	clientErr  *clients.ClientError // client-visible error; may be nil
	cause      error                // underlying store error; may be nil
	retryAfter time.Duration        // suggested wait before retrying; 0 when unknown
}
// Error returns the client-visible message, falling back to the generic
// store.ErrRateLimited message for a nil receiver or missing client error.
func (e *localRateLimitError) Error() string {
	if e != nil && e.clientErr != nil {
		return e.clientErr.Error()
	}
	return store.ErrRateLimited.Error()
}
// Unwrap exposes the wrapped errors for errors.Is / errors.As. The client
// error (when present) comes first; the cause follows, with
// store.ErrRateLimited substituted whenever the cause is missing so that
// errors.Is(err, store.ErrRateLimited) always matches.
func (e *localRateLimitError) Unwrap() []error {
	switch {
	case e != nil && e.clientErr != nil && e.cause != nil:
		return []error{e.clientErr, e.cause}
	case e != nil && e.clientErr != nil:
		return []error{e.clientErr, store.ErrRateLimited}
	case e != nil && e.cause != nil:
		return []error{e.cause}
	default:
		return []error{store.ErrRateLimited}
	}
}
// localRateLimitRetryAfter extracts the retry delay carried by a
// localRateLimitError anywhere in err's chain. When none is present (or the
// delay is non-positive) it defers to the store's generic rate limit
// retry-after extraction.
func localRateLimitRetryAfter(err error) time.Duration {
	var limitErr *localRateLimitError
	if !errors.As(err, &limitErr) || limitErr.retryAfter <= 0 {
		return store.RateLimitRetryAfter(err)
	}
	return limitErr.retryAfter
}
func (s *Service) rateLimitReservations(ctx context.Context, user *auth.User, candidate store.RuntimeModelCandidate, body map[string]any) []store.RateLimitReservation {
out := make([]store.RateLimitReservation, 0)
out = append(out, reservationsFromPolicy("platform_model", candidate.PlatformModelID, effectiveRateLimitPolicy(candidate), body)...)

View File

@ -0,0 +1,216 @@
package runner
import (
"context"
"errors"
"fmt"
"os"
"strings"
"time"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
"github.com/riverqueue/river"
"github.com/riverqueue/river/riverdriver/riverpgxv5"
"github.com/riverqueue/river/rivermigrate"
"github.com/riverqueue/river/rivertype"
)
// asyncTaskQueueName is the river queue all gateway async tasks run on.
const asyncTaskQueueName = "gateway_tasks"

// asyncTaskArgs is the river job payload: just the gateway task id. The
// `river:"unique"` tag makes the task id part of the job's uniqueness key.
type asyncTaskArgs struct {
	TaskID string `json:"task_id" river:"unique"`
}

// Kind identifies this job type to river.
func (asyncTaskArgs) Kind() string { return "gateway_task_run" }
// asyncTaskWorker executes queued gateway tasks. It embeds river's worker
// defaults and carries the runner service that performs the actual work.
type asyncTaskWorker struct {
	river.WorkerDefaults[asyncTaskArgs]
	service *Service
}
// Work runs one queued gateway task. Tasks already in a terminal state are
// treated as no-ops so duplicate or recovered jobs complete cleanly. A
// TaskQueuedError from the runner snoozes the river job for the requested
// delay; a context cancellation (worker shutdown) requeues the task outside
// the cancelled context and snoozes so another worker can pick it up. Any
// other execution failure completes the job without a river retry — the
// failure has already been recorded on the task by the runner.
func (w *asyncTaskWorker) Work(ctx context.Context, job *river.Job[asyncTaskArgs]) error {
	task, err := w.service.store.GetTask(ctx, job.Args.TaskID)
	if err != nil {
		return err
	}
	// Already finished (duplicate insert or a recovery race): nothing to do.
	if task.Status == "succeeded" || task.Status == "failed" || task.Status == "cancelled" {
		return nil
	}
	result, runErr := w.service.Execute(ctx, task, authUserFromTask(task))
	if runErr == nil {
		w.service.logger.Debug("river async task completed", "taskID", task.ID, "status", result.Task.Status, "riverJobID", job.ID)
		return nil
	}
	var queuedErr *TaskQueuedError
	if errors.As(runErr, &queuedErr) {
		// The runner put the task back in the queue (e.g. local rate limit);
		// snooze the river job for the suggested delay.
		return river.JobSnooze(queuedErr.Delay)
	}
	if ctx.Err() != nil {
		// Worker is shutting down mid-run: persist the requeue outside the
		// cancelled context so the task is retried elsewhere.
		queued, queueErr := w.service.requeueInterruptedAsyncTask(context.WithoutCancel(ctx), task)
		if queueErr != nil {
			return queueErr
		}
		w.service.logger.Debug("river async task interrupted and requeued", "taskID", task.ID, "status", queued.Status, "riverJobID", job.ID)
		return river.JobSnooze(0)
	}
	// Genuine execution failure: the task itself has been marked failed, so
	// complete the river job rather than retrying.
	w.service.logger.Warn("river async task completed with failure", "taskID", task.ID, "error", runErr, "riverJobID", job.ID)
	return nil
}
// StartAsyncQueueWorker boots the river-backed async queue. A startup failure
// is fatal: the error is logged and then re-raised as a panic so the process
// does not keep running without its task queue.
func (s *Service) StartAsyncQueueWorker(ctx context.Context) {
	err := s.startRiverQueue(ctx)
	if err == nil {
		return
	}
	s.logger.Error("start river async queue failed", "error", err)
	panic(err)
}
// startRiverQueue migrates river's schema, registers the async task worker,
// starts a river client on the store's shared pgx pool, re-enqueues any
// persisted queued tasks, and arranges a graceful stop once ctx is cancelled.
func (s *Service) startRiverQueue(ctx context.Context) error {
	driver := riverpgxv5.New(s.store.Pool())
	// Run river's own schema migrations up front so the client can start.
	migrator, err := rivermigrate.New(driver, nil)
	if err != nil {
		return err
	}
	if _, err := migrator.Migrate(ctx, rivermigrate.DirectionUp, nil); err != nil {
		return err
	}
	workers := river.NewWorkers()
	if err := river.AddWorkerSafely(workers, &asyncTaskWorker{service: s}); err != nil {
		return err
	}
	riverClient, err := river.NewClient(driver, &river.Config{
		ID:                          asyncWorkerID(),
		JobTimeout:                  -1, // no per-job timeout; tasks may run long
		Logger:                      s.logger,
		CompletedJobRetentionPeriod: 24 * time.Hour,
		Queues: map[string]river.QueueConfig{
			asyncTaskQueueName: {MaxWorkers: 32},
		},
		RescueStuckJobsAfter: 30 * time.Second,
		TestOnly:             s.cfg.AppEnv == "test",
		Workers:              workers,
	})
	if err != nil {
		return err
	}
	s.riverClient = riverClient
	if err := riverClient.Start(ctx); err != nil {
		return err
	}
	// Re-insert jobs for tasks that were still queued before the last shutdown.
	if err := s.recoverAsyncRiverJobs(ctx); err != nil {
		return err
	}
	// Stop the queue with a bounded grace period when the server context ends.
	go func() {
		<-ctx.Done()
		stopCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()
		if err := riverClient.StopAndCancel(stopCtx); err != nil {
			s.logger.Warn("stop river async queue failed", "error", err)
		}
	}()
	return nil
}
// EnqueueAsyncTask inserts a river job for the given task and, when a job row
// was actually created (not deduplicated away by the uniqueness constraint),
// records the river job id on the task. It fails when the queue has not been
// started yet.
func (s *Service) EnqueueAsyncTask(ctx context.Context, task store.GatewayTask) error {
	if s.riverClient == nil {
		return errors.New("river async queue is not started")
	}
	inserted, err := s.riverClient.Insert(ctx, asyncTaskArgs{TaskID: task.ID}, asyncTaskInsertOpts(task))
	if err != nil {
		return err
	}
	if inserted.Job == nil {
		// A matching job already exists; nothing to record.
		return nil
	}
	return s.store.SetTaskRiverJobID(ctx, task.ID, inserted.Job.ID)
}
// WakeAsyncQueueAfter is intentionally a no-op with the river-backed queue:
// river schedules snoozed jobs itself, so there is nothing to wake.
// NOTE(review): presumably kept for call-site compatibility with a previous
// queue implementation — confirm before removing.
func (s *Service) WakeAsyncQueueAfter(ctx context.Context, delay time.Duration) {
}
// RunAsyncTask enqueues the task for background execution. Enqueue failures
// are logged but not surfaced to the caller. The user argument is unused here
// (the worker reconstructs the user from the task record); presumably kept for
// signature compatibility — TODO confirm.
func (s *Service) RunAsyncTask(ctx context.Context, task store.GatewayTask, user *auth.User) {
	err := s.EnqueueAsyncTask(ctx, task)
	if err == nil {
		return
	}
	s.logger.Warn("enqueue river async task failed", "taskID", task.ID, "error", err)
}
// recoverAsyncRiverJobs re-inserts river jobs for async tasks that were left
// queued in the database (up to 1000), e.g. after an unclean shutdown, and
// records the new river job ids on the tasks.
func (s *Service) recoverAsyncRiverJobs(ctx context.Context) error {
	recoverable, err := s.store.ListRecoverableAsyncTasks(ctx, 1000)
	if err != nil {
		return err
	}
	for _, entry := range recoverable {
		opts := asyncTaskInsertOpts(store.GatewayTask{ID: entry.ID})
		inserted, insertErr := s.riverClient.Insert(ctx, asyncTaskArgs{TaskID: entry.ID}, opts)
		if insertErr != nil {
			return insertErr
		}
		if inserted.Job == nil {
			// Deduplicated: a matching job already exists.
			continue
		}
		if err := s.store.SetTaskRiverJobID(ctx, entry.ID, inserted.Job.ID); err != nil {
			return err
		}
	}
	if len(recoverable) > 0 {
		s.logger.Info("river async queue recovered persisted tasks", "count", len(recoverable))
	}
	return nil
}
// asyncTaskInsertOpts builds the river insert options for a gateway task job:
// a high retry budget, the shared gateway queue, a "gateway-task" tag, and
// uniqueness by args+queue across every pending/active state so the same task
// is never enqueued twice.
// NOTE(review): a task with an empty ID gets priority 3 instead of 2 — the
// intent (deprioritizing placeholder inserts?) is not evident here; confirm.
func asyncTaskInsertOpts(task store.GatewayTask) *river.InsertOpts {
	priority := 2
	if task.ID == "" {
		priority = 3
	}
	return &river.InsertOpts{
		MaxAttempts: 1000,
		Priority:    priority,
		Queue:       asyncTaskQueueName,
		Tags:        []string{"gateway-task"},
		UniqueOpts: river.UniqueOpts{
			ByArgs:  true,
			ByQueue: true,
			// Dedupe against any job that has not yet finished.
			ByState: []rivertype.JobState{
				rivertype.JobStateAvailable,
				rivertype.JobStatePending,
				rivertype.JobStateRetryable,
				rivertype.JobStateRunning,
				rivertype.JobStateScheduled,
			},
		},
	}
}
// authUserFromTask rebuilds an auth.User from the identity fields persisted on
// a task, so a background worker can execute with the original caller's
// identity. Tasks without a user id carry no roles.
func authUserFromTask(task store.GatewayTask) *auth.User {
	var roles []string
	if strings.TrimSpace(task.UserID) != "" {
		roles = []string{"user"}
	}
	user := &auth.User{
		ID:              firstNonEmptyString(task.GatewayUserID, task.UserID),
		Roles:           roles,
		TenantID:        task.TenantID,
		GatewayTenantID: task.GatewayTenantID,
		TenantKey:       task.TenantKey,
		Source:          firstNonEmptyString(task.UserSource, "gateway"),
		GatewayUserID:   task.GatewayUserID,
		UserGroupID:     task.UserGroupID,
		UserGroupKey:    task.UserGroupKey,
		APIKeyID:        task.APIKeyID,
		APIKeyName:      task.APIKeyName,
		APIKeyPrefix:    task.APIKeyPrefix,
	}
	return user
}
func asyncWorkerID() string {
host, _ := os.Hostname()
host = strings.TrimSpace(host)
if host == "" {
host = "localhost"
}
return fmt.Sprintf("%s:%d:%d", host, os.Getpid(), time.Now().UnixNano())
}
// Compile-time assertion that asyncTaskWorker satisfies river.Worker.
var _ river.Worker[asyncTaskArgs] = (*asyncTaskWorker)(nil)

View File

@ -1,6 +1,7 @@
package runner
import (
"errors"
"fmt"
"strings"
@ -54,6 +55,9 @@ func shouldRetrySameClient(candidate store.RuntimeModelCandidate, err error) boo
func retryDecisionForCandidate(candidate store.RuntimeModelCandidate, err error) retryDecision {
policy := effectiveRetryPolicy(candidate)
info := failureInfoFromError(err)
if errors.Is(err, store.ErrRateLimited) {
return retryDecision{Retry: false, Reason: "local_rate_limit_wait_queue", Match: policyRuleMatch{Source: "gateway_rate_limits", Policy: "rateLimitPolicy", Rule: "localCapacity", Value: "exceeded"}, Info: info}
}
if !boolFromPolicy(policy, "enabled", true) {
return retryDecision{Retry: false, Reason: "retry_disabled", Match: policyRuleMatch{Source: "model_runtime_policy_sets.retry_policy", Policy: "retryPolicy", Rule: "enabled", Value: "false"}, Info: info}
}
@ -94,6 +98,9 @@ func failoverDecisionForCandidate(runnerPolicy store.RunnerPolicy, candidate sto
if cooldownSeconds <= 0 {
cooldownSeconds = 300
}
if errors.Is(err, store.ErrRateLimited) && store.RateLimitRetryable(err) {
return failoverDecision{Retry: true, Action: "next", Reason: "local_rate_limit_try_next_candidate", CooldownSeconds: cooldownSeconds, Match: policyRuleMatch{Source: "gateway_rate_limits", Policy: "rateLimitPolicy", Rule: "localCapacity", Value: "exceeded"}, Info: info}
}
if match, ok := failoverAllowMatchWithSources(runnerPolicy.FailoverPolicy, overridePolicy, info); ok {
return failoverDecision{Retry: true, Action: action, Reason: "failover_allow_policy", CooldownSeconds: cooldownSeconds, Match: match, Info: info}
}

View File

@ -2,6 +2,7 @@ package runner
import (
"context"
"errors"
"strings"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
@ -65,6 +66,9 @@ func (s *Service) applyFailoverAction(ctx context.Context, taskID string, candid
}
func (s *Service) applyPriorityDemotePolicy(ctx context.Context, taskID string, attemptNo int, runnerPolicy store.RunnerPolicy, candidate store.RuntimeModelCandidate, cause error, simulated bool) {
if errors.Is(cause, store.ErrRateLimited) {
return
}
decision := priorityDemoteDecisionForCandidate(runnerPolicy, cause)
if !decision.Demote {
return

View File

@ -3,6 +3,7 @@ package runner
import (
"context"
"errors"
"fmt"
"log/slog"
"strconv"
"strings"
@ -12,6 +13,8 @@ import (
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
"github.com/jackc/pgx/v5"
"github.com/riverqueue/river"
)
type Service struct {
@ -20,6 +23,7 @@ type Service struct {
logger *slog.Logger
clients map[string]clients.Client
httpClients *httpClientCache
riverClient *river.Client[pgx.Tx]
}
type Result struct {
@ -27,6 +31,20 @@ type Result struct {
Output map[string]any
}
// ErrTaskQueued is the sentinel matched by errors.Is when a task execution
// ended by being put back on the queue rather than finishing.
var ErrTaskQueued = errors.New("task queued")

// TaskQueuedError signals that a task was requeued instead of completed,
// carrying the delay after which it should be attempted again.
type TaskQueuedError struct {
	Delay time.Duration
}

// Error returns the sentinel's message.
func (e *TaskQueuedError) Error() string {
	return ErrTaskQueued.Error()
}

// Is makes errors.Is(err, ErrTaskQueued) match this typed error.
func (e *TaskQueuedError) Is(target error) bool {
	return target == ErrTaskQueued
}
func New(cfg config.Config, db *store.Store, logger *slog.Logger) *Service {
httpClients := newHTTPClientCache()
return &Service{
@ -55,6 +73,14 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
executeStartedAt := time.Now()
body := normalizeRequest(task.Kind, task.Request)
modelType := modelTypeFromKind(task.Kind, body)
if err := s.store.MarkTaskRunning(ctx, task.ID, modelType, body); err != nil {
return Result{}, err
}
if task.Status != "running" {
if err := s.emit(ctx, task.ID, "task.running", "running", "starting", 0.12, "task pulled from queue and started", map[string]any{"modelType": modelType}, task.RunMode == "simulation"); err != nil {
return Result{}, err
}
}
if err := validateRequest(task.Kind, body); err != nil {
failed, finishErr := s.failTask(ctx, task.ID, "bad_request", err.Error(), task.RunMode == "simulation", err)
if finishErr != nil {
@ -83,9 +109,6 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
return Result{}, err
}
}
if err := s.store.MarkTaskRunning(ctx, task.ID, modelType, body); err != nil {
return Result{}, err
}
if err := s.emit(ctx, task.ID, "task.progress", "running", "normalizing", 0.15, "request normalized", map[string]any{"modelType": modelType}, task.RunMode == "simulation"); err != nil {
return Result{}, err
}
@ -96,7 +119,7 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
}
maxPlatforms := maxPlatformsForCandidates(candidates, runnerPolicy)
maxFailoverDuration := maxFailoverDurationForCandidates(candidates, runnerPolicy)
attemptNo := 0
attemptNo := task.AttemptCount
var lastErr error
for index, candidate := range candidates {
if index >= maxPlatforms {
@ -251,6 +274,20 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
if lastErr != nil {
message = lastErr.Error()
}
if task.AsyncMode && ctx.Err() != nil {
queued, queueErr := s.requeueInterruptedAsyncTask(context.WithoutCancel(ctx), task)
if queueErr != nil {
return Result{}, queueErr
}
return Result{Task: queued, Output: queued.Result}, &TaskQueuedError{Delay: 0}
}
if task.AsyncMode && errors.Is(lastErr, store.ErrRateLimited) && store.RateLimitRetryable(lastErr) {
queued, delay, queueErr := s.requeueRateLimitedTask(ctx, task, lastErr)
if queueErr != nil {
return Result{}, queueErr
}
return Result{Task: queued, Output: queued.Result}, &TaskQueuedError{Delay: delay}
}
failed, err := s.failTask(ctx, task.ID, code, message, task.RunMode == "simulation", lastErr)
if err != nil {
return Result{}, err
@ -261,7 +298,7 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user *auth.User, body map[string]any, candidate store.RuntimeModelCandidate, attemptNo int, onDelta clients.StreamDelta) (clients.Response, error) {
simulated := isSimulation(task, candidate)
if err := s.emit(ctx, task.ID, "task.attempt.started", "running", "submitting", 0.25, "client attempt started", map[string]any{"attempt": attemptNo, "clientId": candidate.ClientID}, simulated); err != nil {
return clients.Response{}, err
return clients.Response{}, fmt.Errorf("emit attempt started: %w", err)
}
attemptID, err := s.store.CreateTaskAttempt(ctx, store.CreateTaskAttemptInput{
TaskID: task.ID,
@ -276,21 +313,22 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user
Metrics: attemptMetrics(candidate, attemptNo, simulated),
})
if err != nil {
return clients.Response{}, err
return clients.Response{}, fmt.Errorf("create task attempt: %w", err)
}
reservations := s.rateLimitReservations(ctx, user, candidate, body)
limitResult, err := s.store.ReserveRateLimits(ctx, task.ID, attemptID, reservations)
if err != nil {
clientErr := &clients.ClientError{Code: "rate_limit", Message: err.Error(), Retryable: false}
retryable := store.RateLimitRetryable(err)
clientErr := &clients.ClientError{Code: "rate_limit", Message: err.Error(), Retryable: retryable}
_ = s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{
AttemptID: attemptID,
Status: "failed",
Retryable: false,
Metrics: mergeMetrics(attemptMetrics(candidate, attemptNo, simulated), map[string]any{"error": err.Error(), "retryable": false, "trace": []any{failureTraceEntry(clientErr, false)}}),
Retryable: retryable,
Metrics: mergeMetrics(attemptMetrics(candidate, attemptNo, simulated), map[string]any{"error": err.Error(), "retryable": retryable, "retryAfterMs": localRateLimitRetryAfter(err).Milliseconds(), "trace": []any{failureTraceEntry(clientErr, retryable)}}),
ErrorCode: "rate_limit",
ErrorMessage: err.Error(),
})
return clients.Response{}, clientErr
return clients.Response{}, &localRateLimitError{clientErr: clientErr, cause: err, retryAfter: localRateLimitRetryAfter(err)}
}
rateReservationsFinalized := false
defer func() {
@ -301,7 +339,7 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user
defer s.store.ReleaseConcurrencyLeases(context.WithoutCancel(ctx), limitResult.LeaseIDs)
if err := s.store.RecordClientAssignment(ctx, candidate); err != nil {
return clients.Response{}, err
return clients.Response{}, fmt.Errorf("record client assignment: %w", err)
}
defer s.store.RecordClientRelease(context.WithoutCancel(ctx), candidate.ClientID, "")
@ -315,17 +353,25 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user
ErrorCode: clients.ErrorCode(err),
ErrorMessage: err.Error(),
})
return clients.Response{}, err
return clients.Response{}, fmt.Errorf("prepare http client: %w", err)
}
client := s.clientFor(candidate, simulated)
callStartedAt := time.Now()
response, err := client.Run(ctx, clients.Request{
Kind: task.Kind,
ModelType: candidate.ModelType,
Model: task.Model,
Body: body,
Candidate: candidate,
HTTPClient: requestHTTPClient,
Kind: task.Kind,
ModelType: candidate.ModelType,
Model: task.Model,
Body: body,
Candidate: candidate,
HTTPClient: requestHTTPClient,
RemoteTaskID: task.RemoteTaskID,
RemoteTaskPayload: task.RemoteTaskPayload,
OnRemoteTaskSubmitted: func(remoteTaskID string, payload map[string]any) error {
if strings.TrimSpace(remoteTaskID) == "" {
return nil
}
return s.store.SetTaskRemoteTask(context.WithoutCancel(ctx), task.ID, attemptID, remoteTaskID, payload)
},
Stream: boolFromMap(body, "stream"),
StreamDelta: onDelta,
})
@ -400,11 +446,11 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user
response.Result = uploadedResult
for _, progress := range response.Progress {
if err := s.emit(ctx, task.ID, "task.progress", "running", progress.Phase, progress.Progress, progress.Message, progress.Payload, simulated); err != nil {
return clients.Response{}, err
return clients.Response{}, fmt.Errorf("emit task progress: %w", err)
}
}
if err := s.store.CommitRateLimitReservations(ctx, limitResult.Reservations, tokenUsageAmounts(response.Usage)); err != nil {
return clients.Response{}, err
return clients.Response{}, fmt.Errorf("commit rate limit reservations: %w", err)
}
rateReservationsFinalized = true
if err := s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{
@ -418,7 +464,7 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user
ResponseFinishedAt: response.ResponseFinishedAt,
ResponseDurationMS: response.ResponseDurationMS,
}); err != nil {
return clients.Response{}, err
return clients.Response{}, fmt.Errorf("finish task attempt: %w", err)
}
return response, nil
}
@ -459,6 +505,41 @@ func (s *Service) failTask(ctx context.Context, taskID string, code string, mess
return failed, nil
}
// requeueRateLimitedTask moves a locally rate-limited async task back to
// queued state, emits a task.queued event describing the limit, and returns
// the queued task plus the delay before the next attempt (defaulting to 5s
// when the limiter supplied no positive retry-after).
func (s *Service) requeueRateLimitedTask(ctx context.Context, task store.GatewayTask, cause error) (store.GatewayTask, time.Duration, error) {
	delay := localRateLimitRetryAfter(cause)
	if delay <= 0 {
		delay = 5 * time.Second
	}
	queued, err := s.store.RequeueTask(ctx, task.ID, delay)
	if err != nil {
		return store.GatewayTask{}, 0, err
	}
	detail := map[string]any{
		"code":         "rate_limit",
		"message":      cause.Error(),
		"retryAfterMs": delay.Milliseconds(),
	}
	eventErr := s.emit(ctx, task.ID, "task.queued", "queued", "rate_limited", 0.2, "task queued by local rate limit", detail, task.RunMode == "simulation")
	if eventErr != nil {
		return store.GatewayTask{}, 0, eventErr
	}
	return queued, delay, nil
}
// requeueInterruptedAsyncTask puts an async task back into queued state with
// no delay after its worker was interrupted, emitting a task.queued event
// that includes the remote task id when one is known.
func (s *Service) requeueInterruptedAsyncTask(ctx context.Context, task store.GatewayTask) (store.GatewayTask, error) {
	queued, err := s.store.RequeueTask(ctx, task.ID, 0)
	if err != nil {
		return store.GatewayTask{}, err
	}
	detail := map[string]any{"code": "worker_interrupted"}
	if task.RemoteTaskID != "" {
		detail["remoteTaskId"] = task.RemoteTaskID
	}
	eventErr := s.emit(ctx, task.ID, "task.queued", "queued", "worker_interrupted", 0.2, "async task queued after worker interruption", detail, task.RunMode == "simulation")
	if eventErr != nil {
		return store.GatewayTask{}, eventErr
	}
	return queued, nil
}
func (s *Service) withAttemptHistory(ctx context.Context, taskID string, metrics map[string]any) map[string]any {
attempts, err := s.store.ListTaskAttempts(ctx, taskID)
if err != nil {

View File

@ -19,7 +19,7 @@ SELECT p.id::text, p.platform_key, p.name, p.provider,
COALESCE(p.dynamic_priority, p.priority) AS effective_priority,
m.id::text, COALESCE(m.base_model_id::text, ''), COALESCE(b.canonical_model_key, ''),
COALESCE(NULLIF(m.provider_model_name, ''), m.model_name), m.model_name, COALESCE(m.model_alias, ''),
$2 AS requested_model_type, m.display_name, m.capabilities, m.capability_override,
$2::text AS requested_model_type, m.display_name, m.capabilities, m.capability_override,
COALESCE(b.base_billing_config, '{}'::jsonb), m.billing_config, m.billing_config_override,
m.pricing_mode, COALESCE(m.discount_factor, 0)::float8, COALESCE(m.pricing_rule_set_id::text, ''),
COALESCE(b.pricing_rule_set_id::text, ''),
@ -33,21 +33,21 @@ LEFT JOIN model_catalog_providers cp ON cp.provider_key = p.provider OR cp.provi
LEFT JOIN base_model_catalog b ON b.id = m.base_model_id
LEFT JOIN model_runtime_policy_sets rp ON rp.id = COALESCE(m.runtime_policy_set_id, b.runtime_policy_set_id)
LEFT JOIN runtime_client_states s
ON s.client_id = p.platform_key || ':' || $2 || ':' || COALESCE(NULLIF(m.provider_model_name, ''), m.model_name)
ON s.client_id = p.platform_key || ':' || $2::text || ':' || COALESCE(NULLIF(m.provider_model_name, ''), m.model_name)
WHERE p.status = 'enabled'
AND p.deleted_at IS NULL
AND m.enabled = true
AND m.model_type @> jsonb_build_array($2)
AND m.model_type @> jsonb_build_array($2::text)
AND (p.cooldown_until IS NULL OR p.cooldown_until <= now())
AND (m.cooldown_until IS NULL OR m.cooldown_until <= now())
AND (
(COALESCE(m.model_alias, '') <> '' AND m.model_alias = $1)
(COALESCE(m.model_alias, '') <> '' AND m.model_alias = $1::text)
OR (
COALESCE(m.model_alias, '') = ''
AND (
m.model_name = $1
OR b.canonical_model_key = $1
OR b.provider_model_name = $1
m.model_name = $1::text
OR b.canonical_model_key = $1::text
OR b.provider_model_name = $1::text
)
)
)
@ -184,15 +184,15 @@ LEFT JOIN base_model_catalog b ON b.id = m.base_model_id
WHERE p.status = 'enabled'
AND p.deleted_at IS NULL
AND m.enabled = true
AND m.model_type @> jsonb_build_array($2)
AND m.model_type @> jsonb_build_array($2::text)
AND (
(COALESCE(m.model_alias, '') <> '' AND m.model_alias = $1)
(COALESCE(m.model_alias, '') <> '' AND m.model_alias = $1::text)
OR (
COALESCE(m.model_alias, '') = ''
AND (
m.model_name = $1
OR b.canonical_model_key = $1
OR b.provider_model_name = $1
m.model_name = $1::text
OR b.canonical_model_key = $1::text
OR b.provider_model_name = $1::text
)
)
)

View File

@ -53,6 +53,10 @@ func (s *Store) Ping(ctx context.Context) error {
return s.pool.Ping(ctx)
}
// Pool exposes the underlying pgx connection pool, e.g. for wiring external
// components (such as the river queue driver) onto the same database.
func (s *Store) Pool() *pgxpool.Pool {
	return s.pool
}
type Platform struct {
ID string `json:"id"`
Provider string `json:"provider"`
@ -374,6 +378,7 @@ type CreateTaskInput struct {
Kind string `json:"kind"`
Model string `json:"model"`
RunMode string `json:"runMode"`
Async bool `json:"async"`
Request map[string]any `json:"request"`
}
@ -398,7 +403,12 @@ type GatewayTask struct {
ResolvedModel string `json:"resolvedModel,omitempty"`
RequestID string `json:"requestId,omitempty"`
Request map[string]any `json:"request,omitempty"`
AsyncMode bool `json:"asyncMode"`
RiverJobID int64 `json:"riverJobId,omitempty"`
Status string `json:"status"`
AttemptCount int `json:"attemptCount"`
RemoteTaskID string `json:"remoteTaskId,omitempty"`
RemoteTaskPayload map[string]any `json:"remoteTaskPayload,omitempty"`
Result map[string]any `json:"result,omitempty"`
Billings []any `json:"billings,omitempty"`
Usage map[string]any `json:"usage"`
@ -423,7 +433,9 @@ COALESCE(gateway_tenant_id::text, ''), COALESCE(tenant_id, ''), COALESCE(tenant_
COALESCE(api_key_id, ''), COALESCE(api_key_name, ''), COALESCE(api_key_prefix, ''),
COALESCE(user_group_id::text, ''), COALESCE(user_group_key, ''), model,
COALESCE(model_type, ''), COALESCE(requested_model, ''), COALESCE(resolved_model, ''), COALESCE(request_id, ''),
request, status, COALESCE(result, '{}'::jsonb), COALESCE(billings, '[]'::jsonb),
request, COALESCE(async_mode, false), COALESCE(river_job_id, 0), status, COALESCE(attempt_count, 0),
COALESCE(remote_task_id, ''), COALESCE(remote_task_payload, '{}'::jsonb),
COALESCE(result, '{}'::jsonb), COALESCE(billings, '[]'::jsonb),
COALESCE(usage, '{}'::jsonb), COALESCE(metrics, '{}'::jsonb), COALESCE(billing_summary, '{}'::jsonb),
COALESCE(final_charge_amount, 0)::float8, COALESCE(response_started_at::text, ''),
COALESCE(response_finished_at::text, ''), COALESCE(response_duration_ms, 0), COALESCE(error, ''),
@ -1675,11 +1687,11 @@ func (s *Store) CreateTask(ctx context.Context, input CreateTaskInput, user *aut
INSERT INTO gateway_tasks (
kind, run_mode, user_id, gateway_user_id, user_source, gateway_tenant_id, tenant_id, tenant_key,
api_key_id, api_key_name, api_key_prefix, user_group_id, user_group_key,
model, requested_model, request, status, result, billings, finished_at
model, requested_model, request, async_mode, status, result, billings, finished_at
)
VALUES ($1, $2, $3, NULLIF($4, '')::uuid, COALESCE(NULLIF($5, ''), 'gateway'), NULLIF($6, '')::uuid, NULLIF($7, ''), NULLIF($8, ''), NULLIF($9, ''), NULLIF($10, ''), NULLIF($11, ''), NULLIF($12, '')::uuid, NULLIF($13, ''), $14, $14, $15, $16, $17::jsonb, $18::jsonb, CASE WHEN $19 THEN now() ELSE NULL END)
VALUES ($1, $2, $3, NULLIF($4, '')::uuid, COALESCE(NULLIF($5, ''), 'gateway'), NULLIF($6, '')::uuid, NULLIF($7, ''), NULLIF($8, ''), NULLIF($9, ''), NULLIF($10, ''), NULLIF($11, ''), NULLIF($12, '')::uuid, NULLIF($13, ''), $14, $14, $15, $16, $17, $18::jsonb, $19::jsonb, CASE WHEN $20 THEN now() ELSE NULL END)
RETURNING `+gatewayTaskColumns,
input.Kind, runMode, user.ID, user.GatewayUserID, user.Source, user.GatewayTenantID, user.TenantID, user.TenantKey, user.APIKeyID, user.APIKeyName, user.APIKeyPrefix, user.UserGroupID, user.UserGroupKey, input.Model, requestBody, status, resultBody, billingsBody, false,
input.Kind, runMode, user.ID, user.GatewayUserID, user.Source, user.GatewayTenantID, user.TenantID, user.TenantKey, user.APIKeyID, user.APIKeyName, user.APIKeyPrefix, user.UserGroupID, user.UserGroupKey, input.Model, requestBody, input.Async, status, resultBody, billingsBody, false,
))
if err != nil {
return GatewayTask{}, err
@ -1689,7 +1701,7 @@ func (s *Store) CreateTask(ctx context.Context, input CreateTaskInput, user *aut
payload, _ := json.Marshal(event.Payload)
if _, err := tx.Exec(ctx, `
INSERT INTO gateway_task_events (task_id, seq, event_type, status, phase, progress, message, payload, simulated)
VALUES ($1::uuid, $2, $3, NULLIF($4, ''), NULLIF($5, ''), $6, NULLIF($7, ''), $8::jsonb, $9)`,
VALUES ($1::uuid, $2, $3::text, NULLIF($4::text, ''), NULLIF($5::text, ''), $6, NULLIF($7::text, ''), $8::jsonb, $9)`,
task.ID, event.Seq, event.EventType, event.Status, event.Phase, event.Progress, event.Message, string(payload), event.Simulated,
); err != nil {
return GatewayTask{}, err
@ -1730,6 +1742,7 @@ func scanGatewayTask(scanner taskScanner) (GatewayTask, error) {
var usageBytes []byte
var metricsBytes []byte
var billingSummaryBytes []byte
var remoteTaskPayloadBytes []byte
if err := scanner.Scan(
&task.ID,
&task.Kind,
@ -1751,7 +1764,12 @@ func scanGatewayTask(scanner taskScanner) (GatewayTask, error) {
&task.ResolvedModel,
&task.RequestID,
&requestBytes,
&task.AsyncMode,
&task.RiverJobID,
&task.Status,
&task.AttemptCount,
&task.RemoteTaskID,
&remoteTaskPayloadBytes,
&resultBytes,
&billingsBytes,
&usageBytes,
@ -1771,6 +1789,7 @@ func scanGatewayTask(scanner taskScanner) (GatewayTask, error) {
return GatewayTask{}, err
}
task.Request = decodeObject(requestBytes)
task.RemoteTaskPayload = decodeObject(remoteTaskPayloadBytes)
task.Result = decodeObject(resultBytes)
task.Billings = decodeArray(billingsBytes)
task.Usage = decodeObject(usageBytes)

View File

@ -31,6 +31,7 @@ type ModelRateLimitStatus struct {
PlatformCooldownUntil string `json:"platformCooldownUntil,omitempty"`
ModelCooldownUntil string `json:"modelCooldownUntil,omitempty"`
Concurrent RateLimitMetricStatus `json:"concurrent"`
QueuedTasks float64 `json:"queuedTasks"`
RPM RateLimitMetricStatus `json:"rpm"`
TPM RateLimitMetricStatus `json:"tpm"`
LoadRatio float64 `json:"loadRatio"`
@ -38,15 +39,16 @@ type ModelRateLimitStatus struct {
func (s *Store) ListModelRateLimitStatuses(ctx context.Context) ([]ModelRateLimitStatus, error) {
rows, err := s.pool.Query(ctx, `
SELECT m.id::text, m.platform_id::text, p.name, p.provider,
m.model_name, COALESCE(NULLIF(m.provider_model_name, ''), m.model_name), COALESCE(m.model_alias, ''),
m.model_type, m.display_name, m.enabled,
p.rate_limit_policy, COALESCE(rp.rate_limit_policy, '{}'::jsonb), COALESCE(NULLIF(m.runtime_policy_override, '{}'::jsonb), b.runtime_policy_override, '{}'::jsonb), m.rate_limit_policy,
COALESCE(to_char(p.cooldown_until AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.MS"Z"'), ''),
COALESCE(to_char(m.cooldown_until AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.MS"Z"'), ''),
COALESCE(con.active, 0)::float8,
COALESCE(rpm.used_value, 0)::float8, COALESCE(rpm.reserved_value, 0)::float8, COALESCE(rpm.reset_at::text, ''),
COALESCE(tpm.used_value, 0)::float8, COALESCE(tpm.reserved_value, 0)::float8, COALESCE(tpm.reset_at::text, '')
SELECT m.id::text, m.platform_id::text, p.name, p.provider,
m.model_name, COALESCE(NULLIF(m.provider_model_name, ''), m.model_name), COALESCE(m.model_alias, ''),
m.model_type, m.display_name, m.enabled,
p.rate_limit_policy, COALESCE(rp.rate_limit_policy, '{}'::jsonb), COALESCE(NULLIF(m.runtime_policy_override, '{}'::jsonb), b.runtime_policy_override, '{}'::jsonb), m.rate_limit_policy,
COALESCE(to_char(p.cooldown_until AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.MS"Z"'), ''),
COALESCE(to_char(m.cooldown_until AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.MS"Z"'), ''),
COALESCE(con.active, 0)::float8,
COALESCE(queued.waiting, 0)::float8,
COALESCE(rpm.used_value, 0)::float8, COALESCE(rpm.reserved_value, 0)::float8, COALESCE(rpm.reset_at::text, ''),
COALESCE(tpm.used_value, 0)::float8, COALESCE(tpm.reserved_value, 0)::float8, COALESCE(tpm.reset_at::text, '')
FROM platform_models m
JOIN integration_platforms p ON p.id = m.platform_id
LEFT JOIN base_model_catalog b ON b.id = m.base_model_id
@ -59,6 +61,19 @@ LEFT JOIN (
AND expires_at > now()
GROUP BY scope_key
) con ON con.scope_key = m.id::text
LEFT JOIN (
SELECT latest.platform_model_id, COUNT(*) AS waiting
FROM (
SELECT DISTINCT ON (a.task_id) a.task_id, a.platform_model_id::text AS platform_model_id
FROM gateway_tasks t
JOIN gateway_task_attempts a ON a.task_id = t.id
WHERE t.async_mode = true
AND t.status = 'queued'
AND a.platform_model_id IS NOT NULL
ORDER BY a.task_id, a.attempt_no DESC, a.started_at DESC
) latest
GROUP BY latest.platform_model_id
) queued ON queued.platform_model_id = m.id::text
LEFT JOIN (
SELECT DISTINCT ON (scope_key) scope_key, used_value, reserved_value, reset_at
FROM gateway_rate_limit_counters
@ -93,6 +108,7 @@ ORDER BY p.priority ASC, m.model_name ASC`)
var platformCooldownUntil string
var modelCooldownUntil string
var concurrentCurrent float64
var queuedTasks float64
var rpmUsed float64
var rpmReserved float64
var rpmResetAt string
@ -117,6 +133,7 @@ ORDER BY p.priority ASC, m.model_name ASC`)
&platformCooldownUntil,
&modelCooldownUntil,
&concurrentCurrent,
&queuedTasks,
&rpmUsed,
&rpmReserved,
&rpmResetAt,
@ -136,6 +153,7 @@ ORDER BY p.priority ASC, m.model_name ASC`)
item.PlatformCooldownUntil = platformCooldownUntil
item.ModelCooldownUntil = modelCooldownUntil
item.RateLimitPolicy = policy
item.QueuedTasks = queuedTasks
item.Concurrent = metricStatus(concurrentCurrent, concurrentCurrent, 0, rateLimitForMetric(policy, "concurrent"), "")
item.RPM = metricStatus(rpmUsed+rpmReserved, rpmUsed, rpmReserved, rateLimitForMetric(policy, "rpm"), rpmResetAt)
item.TPM = metricStatus(tpmUsed+tpmReserved, tpmUsed, tpmReserved, tpmLimit(policy), tpmResetAt)

View File

@ -3,6 +3,7 @@ package store
import (
"context"
"errors"
"fmt"
"time"
"github.com/jackc/pgx/v5"
@ -13,6 +14,7 @@ type RuntimeRecoveryResult struct {
ReleasedRateReservations int64 `json:"releasedRateReservations"`
FailedAttempts int64 `json:"failedAttempts"`
FailedTasks int64 `json:"failedTasks"`
RequeuedAsyncTasks int64 `json:"requeuedAsyncTasks"`
}
func (s *Store) ReserveRateLimits(ctx context.Context, taskID string, attemptID string, reservations []RateLimitReservation) (RateLimitResult, error) {
@ -28,7 +30,11 @@ func (s *Store) ReserveRateLimits(ctx context.Context, taskID string, attemptID
continue
}
if reservation.Metric == "" || reservation.Amount > reservation.Limit {
return RateLimitResult{}, ErrRateLimited
return RateLimitResult{}, &RateLimitExceededError{
Metric: reservation.Metric,
Message: fmt.Sprintf("rate limit exceeded: %s request amount %.0f is greater than limit %.0f", reservation.Metric, reservation.Amount, reservation.Limit),
Retryable: false,
}
}
if reservation.WindowSeconds <= 0 {
reservation.WindowSeconds = 60
@ -55,8 +61,10 @@ func reserveConcurrencyLease(ctx context.Context, tx pgx.Tx, taskID string, atte
reservation.LeaseTTLSeconds = 120
}
var active float64
var nextAvailableAt time.Time
if err := tx.QueryRow(ctx, `
SELECT COALESCE(SUM(lease_value), 0)::float8
SELECT COALESCE(SUM(lease_value), 0)::float8,
COALESCE(MIN(expires_at), now() + ($3::int * interval '1 second'))
FROM gateway_concurrency_leases
WHERE scope_type = $1
AND scope_key = $2
@ -64,11 +72,17 @@ WHERE scope_type = $1
AND expires_at > now()`,
reservation.ScopeType,
reservation.ScopeKey,
).Scan(&active); err != nil {
reservation.LeaseTTLSeconds,
).Scan(&active, &nextAvailableAt); err != nil {
return "", err
}
if active+reservation.Amount > reservation.Limit {
return "", ErrRateLimited
return "", &RateLimitExceededError{
Metric: reservation.Metric,
Message: fmt.Sprintf("rate limit exceeded: concurrent active %.0f plus request %.0f is greater than limit %.0f", active, reservation.Amount, reservation.Limit),
RetryAfter: concurrencyRetryAfter(nextAvailableAt),
Retryable: true,
}
}
var leaseID string
if err := tx.QueryRow(ctx, `
@ -92,13 +106,16 @@ func reserveCounterWindow(ctx context.Context, tx pgx.Tx, taskID string, attempt
reservedAmount := reservation.Amount
var windowStart time.Time
err := tx.QueryRow(ctx, `
WITH bounds AS (
SELECT
to_timestamp(floor(extract(epoch FROM now()) / $7::int) * $7::int) AS window_start,
to_timestamp(floor(extract(epoch FROM now()) / $7::int) * $7::int) + ($7::int * interval '1 second') AS reset_at
)
INSERT INTO gateway_rate_limit_counters (
scope_type, scope_key, metric, window_start, limit_value, used_value, reserved_value, reset_at
)
VALUES (
$1, $2, $3, date_trunc('minute', now()), $4, $5, $6,
date_trunc('minute', now()) + ($7::int * interval '1 second')
)
SELECT $1, $2, $3, bounds.window_start, $4, $5, $6, bounds.reset_at
FROM bounds
ON CONFLICT (scope_type, scope_key, metric, window_start) DO UPDATE
SET limit_value = EXCLUDED.limit_value,
used_value = gateway_rate_limit_counters.used_value + EXCLUDED.used_value,
@ -117,7 +134,28 @@ RETURNING window_start`,
).Scan(&windowStart)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return RateLimitReservation{}, ErrRateLimited
resetAt := time.Now().Add(time.Duration(reservation.WindowSeconds) * time.Second)
_ = tx.QueryRow(ctx, `
WITH bounds AS (
SELECT to_timestamp(floor(extract(epoch FROM now()) / $4::int) * $4::int) AS window_start
)
SELECT counters.reset_at
FROM gateway_rate_limit_counters counters
JOIN bounds ON counters.window_start = bounds.window_start
WHERE scope_type = $1
AND scope_key = $2
AND metric = $3`,
reservation.ScopeType,
reservation.ScopeKey,
reservation.Metric,
reservation.WindowSeconds,
).Scan(&resetAt)
return RateLimitReservation{}, &RateLimitExceededError{
Metric: reservation.Metric,
Message: fmt.Sprintf("rate limit exceeded: %s window has no remaining capacity", reservation.Metric),
RetryAfter: retryAfterUntil(resetAt),
Retryable: true,
}
}
return RateLimitReservation{}, err
}
@ -144,6 +182,28 @@ RETURNING id::text`,
return reservation, nil
}
func retryAfterUntil(when time.Time) time.Duration {
if when.IsZero() {
return 0
}
duration := time.Until(when)
if duration < time.Second {
return time.Second
}
return duration
}
func concurrencyRetryAfter(leaseExpiresAt time.Time) time.Duration {
if leaseExpiresAt.IsZero() {
return time.Second
}
duration := time.Until(leaseExpiresAt)
if duration <= time.Second {
return time.Second
}
return time.Second
}
// CommitRateLimitReservations finalizes successful reservations by delegating
// to finishRateLimitReservations with status "committed" and reason "success",
// which settles each reservation's actual usage amount.
func (s *Store) CommitRateLimitReservations(ctx context.Context, reservations []RateLimitReservation, actualByMetric map[string]float64) error {
	return s.finishRateLimitReservations(ctx, reservations, actualByMetric, "committed", "success")
}
@ -189,23 +249,26 @@ RETURNING scope_type, scope_key, metric, window_start, reserved_amount::float8`)
if err != nil {
return RuntimeRecoveryResult{}, err
}
reservations := make([]RateLimitReservation, 0)
for rows.Next() {
var reservation RateLimitReservation
if err := rows.Scan(&reservation.ScopeType, &reservation.ScopeKey, &reservation.Metric, &reservation.WindowStart, &reservation.Amount); err != nil {
rows.Close()
return RuntimeRecoveryResult{}, err
}
if err := releaseCounterReservation(ctx, tx, reservation.ScopeType, reservation.ScopeKey, reservation.Metric, reservation.WindowStart, reservation.Amount); err != nil {
rows.Close()
return RuntimeRecoveryResult{}, err
}
result.ReleasedRateReservations++
reservations = append(reservations, reservation)
}
if err := rows.Err(); err != nil {
rows.Close()
return RuntimeRecoveryResult{}, err
}
rows.Close()
for _, reservation := range reservations {
if err := releaseCounterReservation(ctx, tx, reservation.ScopeType, reservation.ScopeKey, reservation.Metric, reservation.WindowStart, reservation.Amount); err != nil {
return RuntimeRecoveryResult{}, err
}
}
result.ReleasedRateReservations = int64(len(reservations))
tag, err := tx.Exec(ctx, `
UPDATE gateway_concurrency_leases
@ -220,7 +283,7 @@ WHERE released_at IS NULL
tag, err = tx.Exec(ctx, `
UPDATE gateway_task_attempts
SET status = 'failed',
retryable = false,
retryable = true,
error_code = 'server_restarted',
error_message = 'attempt interrupted by service restart',
finished_at = now()
@ -230,6 +293,57 @@ WHERE status = 'running'`)
}
result.FailedAttempts = tag.RowsAffected()
asyncTaskRows, err := tx.Query(ctx, `
UPDATE gateway_tasks
SET status = 'queued',
error = NULL,
error_code = NULL,
error_message = NULL,
locked_by = NULL,
locked_at = NULL,
heartbeat_at = NULL,
next_run_at = now(),
finished_at = NULL,
updated_at = now()
WHERE async_mode = true
AND status = 'running'
RETURNING id::text`)
if err != nil {
return RuntimeRecoveryResult{}, err
}
asyncTaskIDs := make([]string, 0)
for asyncTaskRows.Next() {
var taskID string
if err := asyncTaskRows.Scan(&taskID); err != nil {
asyncTaskRows.Close()
return RuntimeRecoveryResult{}, err
}
asyncTaskIDs = append(asyncTaskIDs, taskID)
}
if err := asyncTaskRows.Err(); err != nil {
asyncTaskRows.Close()
return RuntimeRecoveryResult{}, err
}
asyncTaskRows.Close()
for _, taskID := range asyncTaskIDs {
if _, err := tx.Exec(ctx, `
INSERT INTO gateway_task_events (task_id, seq, event_type, status, phase, progress, message, payload, simulated)
VALUES (
$1::uuid,
COALESCE((SELECT MAX(seq) + 1 FROM gateway_task_events WHERE task_id = $1::uuid), 1),
'task.recovered',
'queued',
'recovered',
0.2,
'async task recovered after service restart',
'{"code":"server_restarted"}'::jsonb,
false
)`, taskID); err != nil {
return RuntimeRecoveryResult{}, err
}
}
result.RequeuedAsyncTasks = int64(len(asyncTaskIDs))
taskRows, err := tx.Query(ctx, `
UPDATE gateway_tasks
SET status = 'failed',
@ -238,7 +352,8 @@ SET status = 'failed',
error_message = 'task interrupted by service restart',
finished_at = now(),
updated_at = now()
WHERE status IN ('queued', 'running')
WHERE async_mode = false
AND status = 'running'
RETURNING id::text`)
if err != nil {
return RuntimeRecoveryResult{}, err
@ -301,9 +416,9 @@ func (s *Store) finishRateLimitReservations(ctx context.Context, reservations []
var stored RateLimitReservation
err := tx.QueryRow(ctx, `
UPDATE gateway_rate_limit_reservations
SET status = $2,
reason = NULLIF($3, ''),
actual_amount = CASE WHEN $2 = 'committed' THEN $4 ELSE actual_amount END,
SET status = $2::text,
reason = NULLIF($3::text, ''),
actual_amount = CASE WHEN $2::text = 'committed' THEN $4 ELSE actual_amount END,
finalized_at = now(),
updated_at = now()
WHERE id = $1::uuid

View File

@ -32,6 +32,43 @@ func ModelCandidateErrorCode(err error) string {
return "no_model_candidate"
}
// RateLimitExceededError describes a rejected rate-limit reservation with
// enough structure for callers to build a useful throttling response.
// It wraps the ErrRateLimited sentinel (see Unwrap) for errors.Is checks.
type RateLimitExceededError struct {
	// Metric names the limit that was hit (e.g. "concurrent", or a
	// counter-window metric such as rpm/tpm).
	Metric string
	// Message is a human-readable explanation; may be empty, in which case
	// Error() falls back to a generic string.
	Message string
	// RetryAfter suggests how long to wait before retrying; zero when unknown.
	RetryAfter time.Duration
	// Retryable reports whether waiting could let the request succeed
	// (false when the request amount exceeds the limit outright).
	Retryable bool
}
// Error renders the most specific description available: the explicit
// message, then a metric-based fallback, then the generic sentinel text.
func (e *RateLimitExceededError) Error() string {
	switch {
	case strings.TrimSpace(e.Message) != "":
		return e.Message
	case strings.TrimSpace(e.Metric) != "":
		return "rate limit exceeded for " + e.Metric
	default:
		return ErrRateLimited.Error()
	}
}
// Unwrap makes errors.Is(err, ErrRateLimited) match any
// *RateLimitExceededError, keeping compatibility with callers that still
// compare against the sentinel error.
func (e *RateLimitExceededError) Unwrap() error {
	return ErrRateLimited
}
// RateLimitRetryAfter extracts the suggested retry delay carried by a
// rate-limit error. It reports 0 when err is not a *RateLimitExceededError
// or carries no positive hint.
func RateLimitRetryAfter(err error) time.Duration {
	var limitErr *RateLimitExceededError
	if !errors.As(err, &limitErr) {
		return 0
	}
	if limitErr.RetryAfter <= 0 {
		return 0
	}
	return limitErr.RetryAfter
}
// RateLimitRetryable reports whether a rate-limit error is worth retrying.
// Structured errors carry an explicit Retryable flag; any other error is
// retryable only if it wraps the ErrRateLimited sentinel.
func RateLimitRetryable(err error) bool {
	var limitErr *RateLimitExceededError
	switch {
	case errors.As(err, &limitErr):
		return limitErr.Retryable
	default:
		return errors.Is(err, ErrRateLimited)
	}
}
type CreatePlatformModelInput struct {
PlatformID string `json:"platformId"`
BaseModelID string `json:"baseModelId"`
@ -132,6 +169,11 @@ type CreateTaskAttemptInput struct {
Metrics map[string]any
}
// AsyncTaskQueueItem is a lightweight projection of a recoverable async task,
// carrying just what re-enqueueing needs.
type AsyncTaskQueueItem struct {
	// ID is the gateway task id (uuid rendered as text).
	ID string
	// Priority orders claiming; lower values are picked first
	// (queue queries ORDER BY priority ASC).
	Priority int
}
type FinishTaskAttemptInput struct {
AttemptID string
Status string

View File

@ -157,7 +157,7 @@ func (s *Store) MarkTaskRunning(ctx context.Context, taskID string, modelType st
_, err := s.pool.Exec(ctx, `
UPDATE gateway_tasks
SET status = 'running',
model_type = NULLIF($2, ''),
model_type = NULLIF($2::text, ''),
normalized_request = $3::jsonb,
locked_at = now(),
heartbeat_at = now(),
@ -166,6 +166,124 @@ WHERE id = $1::uuid`, taskID, modelType, string(normalizedJSON))
return err
}
// ClaimAsyncQueuedTask atomically claims the next due async task for a
// worker: it picks the highest-priority (lowest value), oldest task with
// async_mode = true, status 'queued', and next_run_at <= now(), flips it to
// 'running', and stamps the lock/heartbeat columns with workerID (an empty
// workerID is normalized to NULL via NULLIF).
//
// FOR UPDATE SKIP LOCKED lets concurrent workers poll without blocking on
// each other's candidate row. When nothing is claimable the UPDATE returns
// no row (presumably surfaced as pgx.ErrNoRows by scanGatewayTask — confirm
// how callers distinguish "queue empty" from real errors).
func (s *Store) ClaimAsyncQueuedTask(ctx context.Context, workerID string) (GatewayTask, error) {
	return scanGatewayTask(s.pool.QueryRow(ctx, `
WITH picked AS (
SELECT id AS task_id
FROM gateway_tasks
WHERE async_mode = true
AND status = 'queued'
AND next_run_at <= now()
ORDER BY priority ASC, created_at ASC
LIMIT 1
FOR UPDATE SKIP LOCKED
)
UPDATE gateway_tasks t
SET status = 'running',
locked_by = NULLIF($1::text, ''),
locked_at = now(),
heartbeat_at = now(),
updated_at = now()
FROM picked
WHERE t.id = picked.task_id
RETURNING `+gatewayTaskColumns, workerID))
}
// RequeueTask puts a task back into the async queue after a transient
// failure: it clears lock, heartbeat, and error state and schedules the next
// run. The delay is clamped to [1s, 10m] so retries neither spin nor stall.
func (s *Store) RequeueTask(ctx context.Context, taskID string, delay time.Duration) (GatewayTask, error) {
	const (
		minRequeueDelay = time.Second
		maxRequeueDelay = 10 * time.Minute
	)
	switch {
	case delay < minRequeueDelay:
		delay = minRequeueDelay
	case delay > maxRequeueDelay:
		delay = maxRequeueDelay
	}
	runAt := time.Now().Add(delay)
	row := s.pool.QueryRow(ctx, `
UPDATE gateway_tasks
SET status = 'queued',
locked_by = NULL,
locked_at = NULL,
heartbeat_at = NULL,
next_run_at = $2::timestamptz,
error = NULL,
error_code = NULL,
error_message = NULL,
updated_at = now()
WHERE id = $1::uuid
RETURNING `+gatewayTaskColumns, taskID, runAt)
	return scanGatewayTask(row)
}
// SetTaskRiverJobID links a gateway task to the River job that drives it.
// Non-positive job IDs are silently ignored so callers may pass through an
// unset ID without special-casing.
func (s *Store) SetTaskRiverJobID(ctx context.Context, taskID string, riverJobID int64) error {
	if riverJobID <= 0 {
		return nil
	}
	if _, err := s.pool.Exec(ctx, `
UPDATE gateway_tasks
SET river_job_id = $2,
updated_at = now()
WHERE id = $1::uuid`, taskID, riverJobID); err != nil {
		return err
	}
	return nil
}
// SetTaskRemoteTask records the upstream provider's task identifier and raw
// payload on a gateway task and, when attemptID is non-empty, mirrors them
// onto that attempt. Both writes run in one transaction so the task and the
// attempt never disagree about the remote task.
//
// payload is serialized as JSON ('{}' when nil); the Marshal error is
// intentionally ignored — inputs are plain map[string]any values
// (presumably always JSON-serializable; confirm at call sites).
func (s *Store) SetTaskRemoteTask(ctx context.Context, taskID string, attemptID string, remoteTaskID string, payload map[string]any) error {
	payloadJSON, _ := json.Marshal(emptyObjectIfNil(payload))
	return pgx.BeginFunc(ctx, s.pool, func(tx pgx.Tx) error {
		// Task row: an empty remoteTaskID is normalized to NULL via NULLIF.
		if _, err := tx.Exec(ctx, `
UPDATE gateway_tasks
SET remote_task_id = NULLIF($2::text, ''),
remote_task_payload = $3::jsonb,
updated_at = now()
WHERE id = $1::uuid`,
			taskID,
			remoteTaskID,
			string(payloadJSON),
		); err != nil {
			return err
		}
		// No attempt to annotate — commit the task-level update alone.
		if strings.TrimSpace(attemptID) == "" {
			return nil
		}
		// Attempt row: merge the payload under response_snapshot.remote_task_payload
		// so the attempt keeps a self-contained record of the remote response.
		_, err := tx.Exec(ctx, `
UPDATE gateway_task_attempts
SET remote_task_id = NULLIF($2::text, ''),
response_snapshot = COALESCE(response_snapshot, '{}'::jsonb) || jsonb_build_object('remote_task_payload', $3::jsonb)
WHERE id = $1::uuid`,
			attemptID,
			remoteTaskID,
			string(payloadJSON),
		)
		return err
	})
}
// ListRecoverableAsyncTasks returns up to limit async tasks still in a
// non-terminal state ('queued' or 'running'), ordered by priority then age,
// so they can be re-enqueued after a restart. A non-positive limit defaults
// to 500. The result is always a non-nil slice.
func (s *Store) ListRecoverableAsyncTasks(ctx context.Context, limit int) ([]AsyncTaskQueueItem, error) {
	const defaultLimit = 500
	if limit <= 0 {
		limit = defaultLimit
	}
	queueRows, err := s.pool.Query(ctx, `
SELECT id::text, priority
FROM gateway_tasks
WHERE async_mode = true
AND status IN ('queued', 'running')
ORDER BY priority ASC, created_at ASC
LIMIT $1`, limit)
	if err != nil {
		return nil, err
	}
	defer queueRows.Close()
	items := make([]AsyncTaskQueueItem, 0)
	for queueRows.Next() {
		var item AsyncTaskQueueItem
		if scanErr := queueRows.Scan(&item.ID, &item.Priority); scanErr != nil {
			return nil, scanErr
		}
		items = append(items, item)
	}
	if rowsErr := queueRows.Err(); rowsErr != nil {
		return nil, rowsErr
	}
	return items, nil
}
func (s *Store) CreateTaskAttempt(ctx context.Context, input CreateTaskAttemptInput) (string, error) {
requestJSON, _ := json.Marshal(emptyObjectIfNil(input.RequestSnapshot))
metricsJSON, _ := json.Marshal(emptyObjectIfNil(input.Metrics))
@ -182,7 +300,7 @@ INSERT INTO gateway_task_attempts (
status, simulated, request_snapshot, metrics
)
VALUES (
$1::uuid, $2, NULLIF($3, '')::uuid, NULLIF($4, '')::uuid, NULLIF($5, ''), $6,
$1::uuid, $2::int, NULLIF($3::text, '')::uuid, NULLIF($4::text, '')::uuid, NULLIF($5::text, ''), $6,
$7, $8, $9::jsonb, $10::jsonb
)
RETURNING id::text`,
@ -202,7 +320,7 @@ RETURNING id::text`,
}
if _, err := tx.Exec(ctx, `
UPDATE gateway_tasks
SET attempt_count = GREATEST(attempt_count, $2), updated_at = now()
SET attempt_count = GREATEST(attempt_count, $2::int), updated_at = now()
WHERE id = $1::uuid`, input.TaskID, input.AttemptNo); err != nil {
return "", err
}
@ -252,7 +370,7 @@ SET metrics = jsonb_set(
true
)
WHERE task_id = $1::uuid
AND attempt_no = $2`, taskID, attemptNo, string(entryJSON))
AND attempt_no = $2::int`, taskID, attemptNo, string(entryJSON))
return err
}
@ -386,17 +504,17 @@ func (s *Store) FinishTaskAttempt(ctx context.Context, input FinishTaskAttemptIn
metricsJSON, _ := json.Marshal(emptyObjectIfNil(input.Metrics))
_, err := s.pool.Exec(ctx, `
UPDATE gateway_task_attempts
SET status = $2,
SET status = $2::text,
retryable = $3,
request_id = NULLIF($4, ''),
request_id = NULLIF($4::text, ''),
usage = $5::jsonb,
metrics = $6::jsonb,
response_snapshot = $7::jsonb,
response_started_at = $8::timestamptz,
response_finished_at = $9::timestamptz,
response_duration_ms = $10,
error_code = NULLIF($11, ''),
error_message = NULLIF($12, ''),
error_code = NULLIF($11::text, ''),
error_message = NULLIF($12::text, ''),
finished_at = now()
WHERE id = $1::uuid`,
input.AttemptID,
@ -438,6 +556,9 @@ SET status = 'succeeded',
error = NULL,
error_code = NULL,
error_message = NULL,
locked_by = NULL,
locked_at = NULL,
heartbeat_at = NULL,
finished_at = now(),
updated_at = now()
WHERE id = $1::uuid`,
@ -561,14 +682,17 @@ func (s *Store) FinishTaskFailure(ctx context.Context, input FinishTaskFailureIn
if _, err := s.pool.Exec(ctx, `
UPDATE gateway_tasks
SET status = 'failed',
error = NULLIF($2, ''),
error_code = NULLIF($3, ''),
error_message = NULLIF($2, ''),
request_id = NULLIF($4, ''),
error = NULLIF($2::text, ''),
error_code = NULLIF($3::text, ''),
error_message = NULLIF($2::text, ''),
request_id = NULLIF($4::text, ''),
metrics = $5::jsonb,
response_started_at = $6::timestamptz,
response_finished_at = $7::timestamptz,
response_duration_ms = $8,
locked_by = NULL,
locked_at = NULL,
heartbeat_at = NULL,
finished_at = now(),
updated_at = now()
WHERE id = $1::uuid`,
@ -604,7 +728,7 @@ WITH next_seq AS (
WHERE task_id = $1::uuid
)
INSERT INTO gateway_task_events (task_id, seq, event_type, status, phase, progress, message, payload, simulated)
SELECT $1::uuid, next_seq.seq, $2, NULLIF($3, ''), NULLIF($4, ''), $5, NULLIF($6, ''), $7::jsonb, $8
SELECT $1::uuid, next_seq.seq, $2::text, NULLIF($3::text, ''), NULLIF($4::text, ''), $5, NULLIF($6::text, ''), $7::jsonb, $8
FROM next_seq
RETURNING id::text, task_id::text, seq, event_type, COALESCE(status, ''), COALESCE(phase, ''),
COALESCE(progress, 0)::float8, COALESCE(message, ''), payload, simulated, created_at`,
@ -688,7 +812,7 @@ func (s *Store) RecordClientRelease(ctx context.Context, clientID string, lastEr
_, err := s.pool.Exec(ctx, `
UPDATE runtime_client_states
SET running_count = GREATEST(running_count - 1, 0),
last_error = NULLIF($2, ''),
last_error = NULLIF($2::text, ''),
updated_at = now()
WHERE client_id = $1`, clientID, lastError)
return err

View File

@ -0,0 +1,6 @@
-- Adds async-mode support to gateway tasks; existing rows stay synchronous.
ALTER TABLE IF EXISTS gateway_tasks
ADD COLUMN IF NOT EXISTS async_mode boolean NOT NULL DEFAULT false;

-- Partial index backing the async worker queue scan
-- (async_mode = true AND status = 'queued' AND next_run_at <= now(),
-- ordered by priority then created_at).
CREATE INDEX IF NOT EXISTS idx_gateway_tasks_async_queue
ON gateway_tasks(async_mode, status, next_run_at, priority, created_at)
WHERE async_mode = true;

View File

@ -0,0 +1,19 @@
-- Links gateway tasks to the River job queue and to upstream provider tasks.
ALTER TABLE IF EXISTS gateway_tasks
ADD COLUMN IF NOT EXISTS river_job_id bigint,
ADD COLUMN IF NOT EXISTS remote_task_id text,
ADD COLUMN IF NOT EXISTS remote_task_payload jsonb;

-- Backfill existing rows before setting the default so all rows carry '{}'.
UPDATE gateway_tasks
SET remote_task_payload = '{}'::jsonb
WHERE remote_task_payload IS NULL;

ALTER TABLE IF EXISTS gateway_tasks
ALTER COLUMN remote_task_payload SET DEFAULT '{}'::jsonb;

-- Partial index for looking up a task by its River job id.
CREATE INDEX IF NOT EXISTS idx_gateway_tasks_river_job
ON gateway_tasks(river_job_id)
WHERE river_job_id IS NOT NULL;

-- Partial index backing startup recovery scans over interrupted async tasks.
CREATE INDEX IF NOT EXISTS idx_gateway_tasks_async_recover
ON gateway_tasks(async_mode, status, priority, created_at)
WHERE async_mode = true AND status IN ('queued', 'running');

View File

@ -46,7 +46,6 @@ import {
getHealth,
getNetworkProxyConfig,
getRunnerPolicy,
getTask,
getWalletSummary,
listAccessRules,
listAuditLogs,
@ -72,6 +71,7 @@ import {
listUserGroups,
listUsers,
loginLocalAccount,
pollTaskUntilSettled,
registerLocalAccount,
replacePlatformModels,
setUserWalletBalance,
@ -788,7 +788,11 @@ export function App() {
setCoreMessage('');
try {
const response = await runTask(credential, taskForm);
const detail = await getTask(credential, response.task.id);
const syncTask = (detail: GatewayTask) => {
setTaskResult(detail);
setTasks((current) => [detail, ...current.filter((item) => item.id !== detail.id)]);
};
const detail = await pollTaskUntilSettled(credential, response.task, { onUpdate: syncTask });
setTaskResult(detail);
setTasks((current) => [detail, ...current.filter((item) => item.id !== detail.id)]);
invalidateDataKeys('tasks', 'wallet', 'walletTransactions');

View File

@ -521,6 +521,7 @@ export async function createChatTask(
): Promise<{ task: GatewayTask; next: Record<string, string> }> {
return request<{ task: GatewayTask; next: Record<string, string> }>('/api/v1/chat/completions', {
body: input,
headers: { 'X-Async': 'true' },
method: 'POST',
token,
});
@ -598,6 +599,7 @@ export async function createImageGenerationTask(
): Promise<{ task: GatewayTask; next: Record<string, string> }> {
return request<{ task: GatewayTask; next: Record<string, string> }>('/api/v1/images/generations', {
body: input,
headers: { 'X-Async': 'true' },
method: 'POST',
token,
});
@ -609,6 +611,7 @@ export async function createImageEditTask(
): Promise<{ task: GatewayTask; next: Record<string, string> }> {
return request<{ task: GatewayTask; next: Record<string, string> }>('/api/v1/images/edits', {
body: input,
headers: { 'X-Async': 'true' },
method: 'POST',
token,
});
@ -636,6 +639,7 @@ export async function createVideoGenerationTask(
): Promise<{ task: GatewayTask; next: Record<string, string> }> {
return request<{ task: GatewayTask; next: Record<string, string> }>('/api/v1/videos/generations', {
body: input,
headers: { 'X-Async': 'true' },
method: 'POST',
token,
});
@ -656,6 +660,33 @@ export async function getTask(token: string, taskId: string): Promise<GatewayTas
return request<GatewayTask>(`/api/workspace/tasks/${taskId}`, { token });
}
// Polls the task detail endpoint until the task leaves a pending state
// ('queued' | 'running' | 'submitting') or maxAttempts is exhausted
// (default: poll forever). options.onUpdate fires after every successful
// fetch so callers can mirror intermediate snapshots into UI state.
// Poll errors are swallowed on purpose: backend restarts or short network
// gaps must not turn a durable task into a failed UI run — only an explicit
// terminal status from the task detail endpoint settles the run.
export async function pollTaskUntilSettled(
  token: string,
  task: GatewayTask,
  options: { intervalMs?: number; maxAttempts?: number | null; onUpdate?: (task: GatewayTask) => void } = {},
): Promise<GatewayTask> {
  const intervalMs = options.intervalMs ?? 1200;
  const maxAttempts = options.maxAttempts ?? Number.POSITIVE_INFINITY;
  let detail = task;
  let attempt = 0;
  while (attempt < maxAttempts) {
    if (!taskIsPending(detail.status)) return detail;
    try {
      detail = await getTask(token, detail.id);
      options.onUpdate?.(detail);
      if (!taskIsPending(detail.status)) return detail;
    } catch {
      // Keep polling: a transient fetch failure never settles a durable task.
    }
    await delay(intervalMs);
    attempt += 1;
  }
  return detail;
}
// Reports whether a task status is still pending (not yet terminal),
// i.e. one of 'queued', 'running', or 'submitting'.
export function taskIsPending(status: string) {
  const pendingStatuses = ['queued', 'running', 'submitting'];
  return pendingStatuses.includes(status);
}
export async function listTasks(token: string, query: WorkspaceTaskQuery): Promise<ListResponse<GatewayTask>> {
const search = new URLSearchParams({
page: String(query.page),
@ -707,9 +738,9 @@ export async function getNetworkProxyConfig(token: string): Promise<GatewayNetwo
async function request<T>(
path: string,
options: { token?: string; auth?: boolean; method?: string; body?: unknown } = {},
options: { token?: string; auth?: boolean; method?: string; body?: unknown; headers?: Record<string, string> } = {},
): Promise<T> {
const headers: Record<string, string> = {};
const headers: Record<string, string> = { ...(options.headers ?? {}) };
if (options.auth !== false && options.token) {
headers.Authorization = `Bearer ${options.token}`;
}
@ -731,6 +762,10 @@ async function request<T>(
return response.json() as Promise<T>;
}
// Resolves after ms milliseconds; thin Promise wrapper over window.setTimeout
// used to pace task polling.
function delay(ms: number) {
  return new Promise((resolve) => window.setTimeout(resolve, ms));
}
function parseErrorMessage(body: string) {
return formatGatewayErrorDetails(parseErrorDetails(body));
}

View File

@ -21,7 +21,7 @@ import { mermaid } from '@streamdown/mermaid';
import type { GatewayApiKey, GatewayTask, PlatformModel } from '@easyai-ai-gateway/contracts';
import { Bot, ChevronDown, Image as ImageIcon, MessageSquarePlus, Paperclip, Send, Sparkles, Video } from 'lucide-react';
import { Badge, Button, Select, Textarea } from '../components/ui';
import { GatewayApiError, createImageGenerationTask, createVideoGenerationTask, getTask, streamChatCompletionText } from '../api';
import { GatewayApiError, createImageGenerationTask, createVideoGenerationTask, pollTaskUntilSettled, streamChatCompletionText, taskIsPending } from '../api';
import type { PlaygroundMode } from '../types';
import {
defaultMediaGenerationSettings,
@ -170,14 +170,12 @@ export function PlaygroundPage(props: {
resumableRuns.forEach((run) => {
if (!run.task?.id) return;
resumedTaskIdsRef.current.add(run.task.id);
void pollTaskUntilSettled(credential, run.task)
void pollTaskUntilSettled(credential, run.task, {
onUpdate: (detail) => updateMediaRunFromTask(run.localId, detail),
})
.then((detail) => {
if (!isMountedRef.current) return;
setMediaRuns((current) => updateMediaRun(current, run.localId, {
error: gatewayTaskErrorText(detail, '任务执行失败'),
status: detail.status,
task: detail,
}));
updateMediaRunFromTask(run.localId, detail);
})
.catch((err) => {
if (!isMountedRef.current) return;
@ -249,13 +247,11 @@ export function PlaygroundPage(props: {
async function pollMediaRunUntilSettled(credential: string, localId: string, task: GatewayTask) {
try {
const detail = await pollTaskUntilSettled(credential, task);
const detail = await pollTaskUntilSettled(credential, task, {
onUpdate: (nextTask) => updateMediaRunFromTask(localId, nextTask),
});
if (!isMountedRef.current) return;
setMediaRuns((current) => updateMediaRun(current, localId, {
error: gatewayTaskErrorText(detail, '任务执行失败'),
status: detail.status,
task: detail,
}));
updateMediaRunFromTask(localId, detail);
} catch (err) {
if (!isMountedRef.current) return;
const errorMessage = err instanceof Error ? err.message : '任务状态同步失败';
@ -263,6 +259,15 @@ export function PlaygroundPage(props: {
}
}
// Mirrors a fetched task snapshot into the matching media run: copies the
// task status and object, and keeps the error text empty while the task is
// still pending so intermediate polls don't surface stale error messages.
// No-ops after unmount to avoid setState on an unmounted component.
function updateMediaRunFromTask(localId: string, task: GatewayTask) {
  if (!isMountedRef.current) return;
  setMediaRuns((current) => updateMediaRun(current, localId, {
    error: taskIsPending(task.status) ? '' : gatewayTaskErrorText(task, '任务执行失败'),
    status: task.status,
    task,
  }));
}
function editMediaRun(run: MediaGenerationRun) {
setPrompt(run.prompt);
setMediaSettings(run.settings);
@ -1042,20 +1047,6 @@ function updateMediaRun(runs: MediaGenerationRun[], localId: string, patch: Part
return runs.map((run) => run.localId === localId ? { ...run, ...patch } : run);
}
async function pollTaskUntilSettled(token: string, task: GatewayTask) {
let detail = task;
for (let attempt = 0; attempt < 20; attempt += 1) {
detail = await getTask(token, detail.id);
if (!taskIsPending(detail.status)) return detail;
await delay(1200);
}
return detail;
}
function taskIsPending(status: string) {
return status === 'queued' || status === 'running' || status === 'submitting';
}
function readStoredMediaRuns(): MediaGenerationRun[] {
if (typeof window === 'undefined') return [];
try {
@ -1182,10 +1173,6 @@ function booleanFromUnknown(value: unknown, fallback: boolean) {
return fallback;
}
function delay(ms: number) {
return new Promise((resolve) => window.setTimeout(resolve, ms));
}
function newLocalId() {
return typeof crypto !== 'undefined' && 'randomUUID' in crypto
? crypto.randomUUID()

View File

@ -593,11 +593,14 @@ function RateLimitStatusTable(props: { statuses: ModelRateLimitStatus[]; platfor
<TableRow className="shTableHeader">
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead>TPM</TableHead>
<TableHead>RPM</TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead className="platformLimitMetricHead platformLimitNumberHead" title="正在执行 / 并发上限 / 排队任务">
<span></span>
<small> / / </small>
</TableHead>
<TableHead className="platformLimitNumberHead">TPM</TableHead>
<TableHead className="platformLimitNumberHead">RPM</TableHead>
<TableHead className="platformLimitStatusHead"></TableHead>
<TableHead className="platformLimitNumberHead"></TableHead>
</TableRow>
{props.statuses.map((status) => {
const platform = props.platformMap.get(status.platformId);
@ -615,12 +618,12 @@ function RateLimitStatusTable(props: { statuses: ModelRateLimitStatus[]; platfor
<small>{status.provider}</small>
</span>
</TableCell>
<TableCell>{metricCell(status.concurrent)}</TableCell>
<TableCell>{metricCell(status.tpm, true)}</TableCell>
<TableCell>{metricCell(status.rpm)}</TableCell>
<TableCell>{modelRuntimeStatusCell(status, props.now)}</TableCell>
<TableCell>
<span className="rateLoadCell">
<TableCell className="platformLimitNumberCell">{concurrencyMetricCell(status)}</TableCell>
<TableCell className="platformLimitNumberCell">{metricCell(status.tpm, true)}</TableCell>
<TableCell className="platformLimitNumberCell">{metricCell(status.rpm)}</TableCell>
<TableCell className="platformLimitStatusCell">{modelRuntimeStatusCell(status, props.now)}</TableCell>
<TableCell className="platformLimitNumberCell">
<span className="rateLoadCell" data-overloaded={status.loadRatio > 0.8 ? 'true' : undefined}>
<strong>{formatPercent(status.loadRatio)}</strong>
<span className="rateLoadTrack"><i style={{ width: `${Math.min(status.loadRatio * 100, 100)}%` }} /></span>
</span>
@ -1210,6 +1213,16 @@ function metricCell(metric: ModelRateLimitStatus['rpm'], includeReserved = false
);
}
// Renders the concurrency table cell as "active / limit / queued".
// The limit falls back to '不限' (unlimited) when no concurrency cap is set;
// queuedTasks defaults to 0 when the API payload omits the field.
function concurrencyMetricCell(status: ModelRateLimitStatus) {
  const queuedTasks = status.queuedTasks ?? 0;
  const limitText = status.concurrent.limited ? formatLimit(status.concurrent.limitValue) : '不限';
  return (
    <span className="rateMetricCell" title="正在执行 / 并发上限 / 排队任务">
      <strong>{formatLimit(status.concurrent.currentValue)} / {limitText} / {formatLimit(queuedTasks)}</strong>
    </span>
  );
}
// Builds the "settled + reserved" tooltip text for a rate-limit metric.
function reservedMetricText(metric: ModelRateLimitStatus['rpm']) {
  const settled = formatLimit(metric.usedValue);
  const reserved = formatLimit(metric.reservedValue);
  return `已结算 ${settled} + 预占 ${reserved}`;
}

View File

@ -1016,27 +1016,65 @@
}
.platformLimitTable .shTableRow {
grid-template-columns: clamp(150px, 16vw, 220px) minmax(148px, 1fr) minmax(104px, max-content) minmax(136px, max-content) minmax(122px, max-content) minmax(132px, max-content) minmax(128px, max-content);
min-width: 920px;
grid-template-columns: minmax(180px, 1.15fr) minmax(160px, 0.95fr) 150px 170px 140px 132px 132px;
min-width: 1064px;
}
.platformLimitTable .shTableHead,
.platformLimitTable .shTableCell {
display: grid;
align-content: center;
min-height: 68px;
padding-right: 10px;
padding-left: 10px;
}
.platformLimitTable .shTableHead {
min-height: 74px;
}
.platformLimitMetricHead {
display: grid;
align-content: center;
gap: 3px;
white-space: normal;
}
.platformLimitMetricHead small {
overflow: hidden;
color: var(--muted-foreground);
font-size: var(--font-size-xs);
font-weight: var(--font-weight-medium);
line-height: 1.2;
text-overflow: ellipsis;
white-space: nowrap;
}
.platformLimitNumberHead,
.platformLimitNumberCell {
justify-items: start;
}
.platformLimitStatusHead,
.platformLimitStatusCell {
justify-items: start;
}
.rateMetricCell,
.rateLoadCell {
display: grid;
min-width: 0;
gap: 4px;
width: 100%;
align-content: start;
font-variant-numeric: tabular-nums;
}
.rateMetricCell strong,
.rateLoadCell strong {
color: var(--text-strong);
font-size: var(--font-size-sm);
line-height: 1.25;
}
.rateMetricCell small {
@ -1050,6 +1088,7 @@
.rateLoadTrack {
display: block;
height: 6px;
width: min(112px, 100%);
overflow: hidden;
border-radius: 999px;
background: #eef2f6;
@ -1062,6 +1101,18 @@
background: #0f766e;
}
.rateLoadCell[data-overloaded="true"] strong {
color: var(--destructive);
}
.rateLoadCell[data-overloaded="true"] .rateLoadTrack {
background: #fee2e2;
}
.rateLoadCell[data-overloaded="true"] .rateLoadTrack i {
background: var(--destructive);
}
.platformModelToolbar {
display: grid;
grid-template-columns: minmax(220px, 0.6fr) minmax(260px, 1fr);

View File

@ -1,5 +1,3 @@
go 1.23
use (
./apps/api
)
use ./apps/api

View File

@ -1,16 +1,12 @@
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.7.2 h1:mLoDLV6sonKlvjIEsV56SkWNCnuNv531l94GaIzO+XI=
github.com/jackc/pgx/v5 v5.7.2/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
golang.org/x/mod v0.35.0/go.mod h1:+GwiRhIInF8wPm+4AoT6L0FA1QWAad3OMdTRx4tFYlU=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=

View File

@ -775,6 +775,7 @@ export interface ModelRateLimitStatus {
platformCooldownUntil?: string;
modelCooldownUntil?: string;
concurrent: RateLimitMetricStatus;
queuedTasks: number;
rpm: RateLimitMetricStatus;
tpm: RateLimitMetricStatus;
loadRatio: number;
@ -808,6 +809,7 @@ export interface GatewayTask {
resolvedModel?: string;
requestId?: string;
request?: Record<string, unknown>;
asyncMode?: boolean;
status: 'queued' | 'running' | 'succeeded' | 'failed' | 'cancelled' | string;
result?: Record<string, unknown>;
billings?: unknown[];