feat: add river-backed async task queue

This commit is contained in:
wangbo 2026-05-12 10:11:54 +08:00
parent d69aaed444
commit 7e220b7477
30 changed files with 1342 additions and 200 deletions

View File

@ -33,18 +33,19 @@ func main() {
if recovery, err := db.RecoverInterruptedRuntimeState(ctx); err != nil { if recovery, err := db.RecoverInterruptedRuntimeState(ctx); err != nil {
logger.Error("recover interrupted runtime state failed", "error", err) logger.Error("recover interrupted runtime state failed", "error", err)
os.Exit(1) os.Exit(1)
} else if recovery.ReleasedConcurrencyLeases > 0 || recovery.ReleasedRateReservations > 0 || recovery.FailedAttempts > 0 || recovery.FailedTasks > 0 { } else if recovery.ReleasedConcurrencyLeases > 0 || recovery.ReleasedRateReservations > 0 || recovery.FailedAttempts > 0 || recovery.FailedTasks > 0 || recovery.RequeuedAsyncTasks > 0 {
logger.Warn("interrupted runtime state recovered", logger.Warn("interrupted runtime state recovered",
"releasedConcurrencyLeases", recovery.ReleasedConcurrencyLeases, "releasedConcurrencyLeases", recovery.ReleasedConcurrencyLeases,
"releasedRateReservations", recovery.ReleasedRateReservations, "releasedRateReservations", recovery.ReleasedRateReservations,
"failedAttempts", recovery.FailedAttempts, "failedAttempts", recovery.FailedAttempts,
"failedTasks", recovery.FailedTasks, "failedTasks", recovery.FailedTasks,
"requeuedAsyncTasks", recovery.RequeuedAsyncTasks,
) )
} }
server := &http.Server{ server := &http.Server{
Addr: cfg.HTTPAddr, Addr: cfg.HTTPAddr,
Handler: httpapi.NewServer(cfg, db, logger), Handler: httpapi.NewServerWithContext(ctx, cfg, db, logger),
ReadHeaderTimeout: 10 * time.Second, ReadHeaderTimeout: 10 * time.Second,
} }

View File

@ -4,14 +4,28 @@ go 1.23
require ( require (
github.com/golang-jwt/jwt/v5 v5.2.2 github.com/golang-jwt/jwt/v5 v5.2.2
github.com/jackc/pgx/v5 v5.7.2 github.com/jackc/pgx/v5 v5.9.2
golang.org/x/crypto v0.31.0 github.com/riverqueue/river v0.24.0
github.com/riverqueue/river/riverdriver/riverpgxv5 v0.24.0
github.com/riverqueue/river/rivertype v0.24.0
golang.org/x/crypto v0.37.0
) )
require ( require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect
golang.org/x/sync v0.10.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/text v0.21.0 // indirect github.com/riverqueue/river/riverdriver v0.24.0 // indirect
github.com/riverqueue/river/rivershared v0.24.0 // indirect
github.com/stretchr/testify v1.11.1 // indirect
github.com/tidwall/gjson v1.18.0 // indirect
github.com/tidwall/match v1.2.0 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
go.uber.org/goleak v1.3.0 // indirect
golang.org/x/sync v0.20.0 // indirect
golang.org/x/text v0.36.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
) )

View File

@ -3,28 +3,63 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 h1:Dj0L5fhJ9F82ZJyVOmBx6msDp/kfd1t9GRfny/mfJA0=
github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.7.2 h1:mLoDLV6sonKlvjIEsV56SkWNCnuNv531l94GaIzO+XI= github.com/jackc/pgx/v5 v5.9.2 h1:3ZhOzMWnR4yJ+RW1XImIPsD1aNSz4T4fyP7zlQb56hw=
github.com/jackc/pgx/v5 v5.7.2/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ= github.com/jackc/pgx/v5 v5.9.2/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/riverqueue/river v0.24.0 h1:CesL6vymWgz0d+zNwtnSGRWaB+E8Dax+o9cxD7sUmKc=
github.com/riverqueue/river v0.24.0/go.mod h1:UZ3AxU5t6WtyqNssaea/AkRS8h/kJ+E9ImSB3xyb3ns=
github.com/riverqueue/river/riverdriver v0.24.0 h1:HqGgGkls11u+YKDA7cKOdYKlQwRNJyHuGa3UtOvpdT0=
github.com/riverqueue/river/riverdriver v0.24.0/go.mod h1:dEew9DDIKenNvzpm8Edw8+PkqP3c0zl1fKjiQTq2n/w=
github.com/riverqueue/river/riverdriver/riverpgxv5 v0.24.0 h1:yV37OIbRrhRwIiGeRT7P4D3szhAemu87BgCf8gTCoU4=
github.com/riverqueue/river/riverdriver/riverpgxv5 v0.24.0/go.mod h1:QfznySVKC4ljx53syd/bA/LRSsydAyuD3Q9/EbSniKA=
github.com/riverqueue/river/rivershared v0.24.0 h1:KysokksW75pug2a5RTOc6WESOupWmsylVc6VWvAx+4Y=
github.com/riverqueue/river/rivershared v0.24.0/go.mod h1:UIBfSdai0oWFlwFcoqG4DZX83iA/fLWTEBGrj7Oe1ho=
github.com/riverqueue/river/rivertype v0.24.0 h1:xrQZm/h6U8TBPyTsQPYD5leOapuoBAcdz30bdBwTqOg=
github.com/riverqueue/river/rivertype v0.24.0/go.mod h1:lmdl3vLNDfchDWbYdW2uAocIuwIN+ZaXqAukdSCFqWs=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -329,6 +329,7 @@ func TestVolcesClientVideoSubmitsAndPollsTask(t *testing.T) {
var gotModel string var gotModel string
var gotText string var gotText string
var gotFirstFrameRole string var gotFirstFrameRole string
var submittedRemoteTaskID string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotAuth = r.Header.Get("Authorization") gotAuth = r.Header.Get("Authorization")
switch r.Method + " " + r.URL.Path { switch r.Method + " " + r.URL.Path {
@ -385,6 +386,13 @@ func TestVolcesClientVideoSubmitsAndPollsTask(t *testing.T) {
"volcesPollTimeoutSeconds": 1, "volcesPollTimeoutSeconds": 1,
}, },
}, },
OnRemoteTaskSubmitted: func(remoteTaskID string, payload map[string]any) error {
submittedRemoteTaskID = remoteTaskID
if payload["id"] != "cgt-test" {
t.Fatalf("unexpected submitted payload: %+v", payload)
}
return nil
},
}) })
if err != nil { if err != nil {
t.Fatalf("run volces video: %v", err) t.Fatalf("run volces video: %v", err)
@ -392,6 +400,9 @@ func TestVolcesClientVideoSubmitsAndPollsTask(t *testing.T) {
if submitPath != "/contents/generations/tasks" || pollPath != "/contents/generations/tasks/cgt-test" || gotAuth != "Bearer volces-key" { if submitPath != "/contents/generations/tasks" || pollPath != "/contents/generations/tasks/cgt-test" || gotAuth != "Bearer volces-key" {
t.Fatalf("unexpected paths/auth submit=%s poll=%s auth=%s", submitPath, pollPath, gotAuth) t.Fatalf("unexpected paths/auth submit=%s poll=%s auth=%s", submitPath, pollPath, gotAuth)
} }
if submittedRemoteTaskID != "cgt-test" {
t.Fatalf("remote task submit callback did not receive task id, got %q", submittedRemoteTaskID)
}
if gotModel != "doubao-seedance-2-0-260128" || gotFirstFrameRole != "first_frame" { if gotModel != "doubao-seedance-2-0-260128" || gotFirstFrameRole != "first_frame" {
t.Fatalf("unexpected submitted model=%s role=%s", gotModel, gotFirstFrameRole) t.Fatalf("unexpected submitted model=%s role=%s", gotModel, gotFirstFrameRole)
} }
@ -407,6 +418,53 @@ func TestVolcesClientVideoSubmitsAndPollsTask(t *testing.T) {
} }
} }
func TestVolcesClientVideoResumePollsExistingTaskID(t *testing.T) {
	// Verifies that Run skips the upstream submit call when RemoteTaskID is
	// already set and instead resumes by polling the existing task until it
	// reports a terminal state.
	var submitCalled bool
	var pollPath string
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.Method + " " + r.URL.Path {
		case "POST /contents/generations/tasks":
			// NOTE: t.Fatal*/FailNow must only be called from the goroutine
			// running the test (testing package docs); this handler runs on a
			// server goroutine, so record the violation with Errorf and fail
			// the request instead.
			submitCalled = true
			t.Errorf("resume should skip upstream submit when remote task id exists")
			http.Error(w, "unexpected submit", http.StatusBadRequest)
		case "GET /contents/generations/tasks/cgt-existing":
			pollPath = r.URL.Path
			_ = json.NewEncoder(w).Encode(map[string]any{
				"id":         "cgt-existing",
				"status":     "succeeded",
				"created_at": 789,
				"content":    map[string]any{"video_url": "https://example.com/resumed.mp4"},
			})
		default:
			t.Errorf("unexpected request %s %s", r.Method, r.URL.Path)
			http.Error(w, "unexpected request", http.StatusNotFound)
		}
	}))
	defer server.Close()
	response, err := (VolcesClient{HTTPClient: server.Client()}).Run(context.Background(), Request{
		Kind:      "videos.generations",
		ModelType: "video_generate",
		Model:     "豆包Seedance-2.0",
		Body:      map[string]any{"prompt": "resume polling", "pollIntervalMs": 100, "pollTimeoutSeconds": 1},
		// Pre-existing upstream task id: the client must resume, not resubmit.
		RemoteTaskID: "cgt-existing",
		Candidate: store.RuntimeModelCandidate{
			BaseURL:     server.URL,
			ModelName:   "豆包Seedance-2.0",
			Credentials: map[string]any{"apiKey": "volces-key"},
		},
	})
	if err != nil {
		t.Fatalf("resume volces video: %v", err)
	}
	if submitCalled || pollPath != "/contents/generations/tasks/cgt-existing" {
		t.Fatalf("resume should poll existing task only, submit=%v poll=%s", submitCalled, pollPath)
	}
	data, _ := response.Result["data"].([]any)
	// Guard the index so an empty payload fails with a message rather than a
	// panic that obscures the actual response.
	if len(data) == 0 {
		t.Fatalf("resumed response should contain data items: %+v", response.Result)
	}
	item, _ := data[0].(map[string]any)
	if response.Result["upstream_task_id"] != "cgt-existing" || item["url"] != "https://example.com/resumed.mp4" {
		t.Fatalf("unexpected resumed response: %+v", response.Result)
	}
}
func extractText(result map[string]any) string { func extractText(result map[string]any) string {
choices, _ := result["choices"].([]any) choices, _ := result["choices"].([]any)
choice, _ := choices[0].(map[string]any) choice, _ := choices[0].(map[string]any)

View File

@ -11,14 +11,17 @@ import (
) )
type Request struct { type Request struct {
Kind string Kind string
ModelType string ModelType string
Model string Model string
Body map[string]any Body map[string]any
Candidate store.RuntimeModelCandidate Candidate store.RuntimeModelCandidate
HTTPClient *http.Client HTTPClient *http.Client
Stream bool RemoteTaskID string
StreamDelta StreamDelta RemoteTaskPayload map[string]any
OnRemoteTaskSubmitted func(remoteTaskID string, payload map[string]any) error
Stream bool
StreamDelta StreamDelta
} }
type Response struct { type Response struct {

View File

@ -67,16 +67,25 @@ func (c VolcesClient) runImage(ctx context.Context, request Request, apiKey stri
} }
func (c VolcesClient) runVideo(ctx context.Context, request Request, apiKey string) (Response, error) { func (c VolcesClient) runVideo(ctx context.Context, request Request, apiKey string) (Response, error) {
body := volcesVideoBody(request)
submitStartedAt := time.Now() submitStartedAt := time.Now()
submitResult, submitRequestID, err := c.postJSON(ctx, request, request.Candidate.BaseURL, "/contents/generations/tasks", apiKey, body) submitRequestID := strings.TrimSpace(request.RemoteTaskID)
submitFinishedAt := time.Now() upstreamTaskID := strings.TrimSpace(request.RemoteTaskID)
if err != nil {
return Response{}, annotateResponseError(err, submitRequestID, submitStartedAt, submitFinishedAt)
}
upstreamTaskID := strings.TrimSpace(stringFromAny(submitResult["id"]))
if upstreamTaskID == "" { if upstreamTaskID == "" {
return Response{}, &ClientError{Code: "invalid_response", Message: "volces video task id is missing", RequestID: submitRequestID, Retryable: false} body := volcesVideoBody(request)
submitResult, requestID, err := c.postJSON(ctx, request, request.Candidate.BaseURL, "/contents/generations/tasks", apiKey, body)
submitRequestID = requestID
if err != nil {
return Response{}, annotateResponseError(err, submitRequestID, submitStartedAt, time.Now())
}
upstreamTaskID = strings.TrimSpace(stringFromAny(submitResult["id"]))
if upstreamTaskID == "" {
return Response{}, &ClientError{Code: "invalid_response", Message: "volces video task id is missing", RequestID: submitRequestID, Retryable: false}
}
if request.OnRemoteTaskSubmitted != nil {
if err := request.OnRemoteTaskSubmitted(upstreamTaskID, submitResult); err != nil {
return Response{}, err
}
}
} }
interval := volcesPollInterval(request) interval := volcesPollInterval(request)

View File

@ -18,6 +18,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config" "github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store" "github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
"github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/pgxpool"
@ -37,7 +38,11 @@ func TestCoreLocalFlow(t *testing.T) {
} }
defer db.Close() defer db.Close()
handler := NewServer(config.Config{ assertRuntimeRecoveryReleasesPendingRateReservations(t, ctx, db)
serverCtx, cancelServer := context.WithCancel(ctx)
defer cancelServer()
handler := NewServerWithContext(serverCtx, config.Config{
AppEnv: "test", AppEnv: "test",
HTTPAddr: ":0", HTTPAddr: ":0",
DatabaseURL: databaseURL, DatabaseURL: databaseURL,
@ -169,13 +174,14 @@ func TestCoreLocalFlow(t *testing.T) {
if err := testPool.QueryRow(ctx, `SELECT count(*) FROM gateway_tasks`).Scan(&taskCountBefore); err != nil { if err := testPool.QueryRow(ctx, `SELECT count(*) FROM gateway_tasks`).Scan(&taskCountBefore); err != nil {
t.Fatalf("count tasks before scoped request: %v", err) t.Fatalf("count tasks before scoped request: %v", err)
} }
defaultImageModel := "openai:gpt-image-1"
doJSON(t, server.URL, http.MethodPost, "/api/v1/images/generations", chatOnlyAPIKeyResponse.Secret, map[string]any{ doJSON(t, server.URL, http.MethodPost, "/api/v1/images/generations", chatOnlyAPIKeyResponse.Secret, map[string]any{
"model": "gpt-image-1", "model": defaultImageModel,
"prompt": "scope should block this", "prompt": "scope should block this",
}, http.StatusForbidden, nil) }, http.StatusForbidden, nil)
doJSON(t, server.URL, http.MethodPost, "/api/v1/pricing/estimate", chatOnlyAPIKeyResponse.Secret, map[string]any{ doJSON(t, server.URL, http.MethodPost, "/api/v1/pricing/estimate", chatOnlyAPIKeyResponse.Secret, map[string]any{
"kind": "images.generations", "kind": "images.generations",
"model": "gpt-image-1", "model": defaultImageModel,
"prompt": "scope should block this estimate", "prompt": "scope should block this estimate",
}, http.StatusForbidden, nil) }, http.StatusForbidden, nil)
var taskCountAfter int var taskCountAfter int
@ -309,8 +315,9 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil {
Result map[string]any `json:"result"` Result map[string]any `json:"result"`
} `json:"task"` } `json:"task"`
} }
defaultTextModel := "openai:gpt-4o-mini"
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{ doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
"model": "gpt-4o-mini", "model": defaultTextModel,
"runMode": "simulation", "runMode": "simulation",
"simulation": true, "simulation": true,
"simulationDurationMs": 5, "simulationDurationMs": 5,
@ -334,7 +341,7 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil {
var compatChat map[string]any var compatChat map[string]any
doJSON(t, server.URL, http.MethodPost, "/v1/chat/completions", apiKeyResponse.Secret, map[string]any{ doJSON(t, server.URL, http.MethodPost, "/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
"model": "gpt-4o-mini", "model": defaultTextModel,
"runMode": "simulation", "runMode": "simulation",
"messages": []map[string]any{{"role": "user", "content": "ping"}}, "messages": []map[string]any{{"role": "user", "content": "ping"}},
"simulation": true, "simulation": true,
@ -352,7 +359,7 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil {
} `json:"task"` } `json:"task"`
} }
doJSON(t, server.URL, http.MethodPost, "/api/v1/images/generations", apiKeyResponse.Secret, map[string]any{ doJSON(t, server.URL, http.MethodPost, "/api/v1/images/generations", apiKeyResponse.Secret, map[string]any{
"model": "gpt-image-1", "model": defaultImageModel,
"runMode": "simulation", "runMode": "simulation",
"prompt": "a tiny gateway console", "prompt": "a tiny gateway console",
"size": "1024x1024", "size": "1024x1024",
@ -372,7 +379,7 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil {
} `json:"task"` } `json:"task"`
} }
doJSON(t, server.URL, http.MethodPost, "/api/v1/images/edits", apiKeyResponse.Secret, map[string]any{ doJSON(t, server.URL, http.MethodPost, "/api/v1/images/edits", apiKeyResponse.Secret, map[string]any{
"model": "gpt-image-1", "model": defaultImageModel,
"runMode": "simulation", "runMode": "simulation",
"prompt": "replace background with clean studio light", "prompt": "replace background with clean studio light",
"image": "https://example.com/source.png", "image": "https://example.com/source.png",
@ -384,6 +391,42 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil {
t.Fatalf("unexpected image edit task: %+v", imageEditResponse.Task) t.Fatalf("unexpected image edit task: %+v", imageEditResponse.Task)
} }
doubaoLiteImageEditModel := "doubao-5.0-lite图像编辑"
var doubaoLitePlatformModel struct {
ID string `json:"id"`
}
doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms/"+platform.ID+"/models", loginResponse.AccessToken, map[string]any{
"canonicalModelKey": "easyai:doubao-5.0-lite图像编辑",
"modelName": doubaoLiteImageEditModel,
"modelAlias": doubaoLiteImageEditModel,
"modelType": []string{"image_edit", "image_generate"},
"displayName": doubaoLiteImageEditModel,
}, http.StatusCreated, &doubaoLitePlatformModel)
var doubaoLiteAsyncEdit struct {
TaskID string `json:"taskId"`
Task struct {
ID string `json:"id"`
Status string `json:"status"`
AsyncMode bool `json:"asyncMode"`
} `json:"task"`
}
doJSONWithHeaders(t, server.URL, http.MethodPost, "/api/v1/images/edits", apiKeyResponse.Secret, map[string]any{
"model": doubaoLiteImageEditModel,
"runMode": "simulation",
"prompt": "turn the attached bright desktop object into a clean product-style render",
"image": "https://example.com/doubao-lite-source.png",
"size": "2048x2048",
"simulation": true,
"simulationDurationMs": 5,
}, map[string]string{"X-Async": "true"}, http.StatusAccepted, &doubaoLiteAsyncEdit)
if doubaoLiteAsyncEdit.TaskID == "" || !doubaoLiteAsyncEdit.Task.AsyncMode {
t.Fatalf("doubao-5.0-lite image edit async task should be accepted: %+v", doubaoLiteAsyncEdit)
}
doubaoLiteAsyncEditDone := waitForTaskStatus(t, server.URL, apiKeyResponse.Secret, doubaoLiteAsyncEdit.TaskID, []string{"succeeded"}, 5*time.Second)
if doubaoLiteAsyncEditDone.Status != "succeeded" {
t.Fatalf("doubao-5.0-lite image edit async task should succeed through river queue, got %+v", doubaoLiteAsyncEditDone)
}
var gptImageModelTypesRaw []byte var gptImageModelTypesRaw []byte
if err := testPool.QueryRow(ctx, ` if err := testPool.QueryRow(ctx, `
SELECT model_type SELECT model_type
@ -641,6 +684,7 @@ WHERE reference_type = 'gateway_task'
} }
rateLimitedModel := "rate-limit-smoke-" + suffixText rateLimitedModel := "rate-limit-smoke-" + suffixText
rateLimitWindowSeconds := 3
var rateLimitPolicySet struct { var rateLimitPolicySet struct {
ID string `json:"id"` ID string `json:"id"`
} }
@ -652,7 +696,7 @@ WHERE reference_type = 'gateway_task'
"maxAttempts": 1, "maxAttempts": 1,
}, },
"rateLimitPolicy": map[string]any{ "rateLimitPolicy": map[string]any{
"rules": []map[string]any{{"metric": "rpm", "limit": 1, "windowSeconds": 60}}, "rules": []map[string]any{{"metric": "rpm", "limit": 1, "windowSeconds": rateLimitWindowSeconds}},
}, },
}, http.StatusCreated, &rateLimitPolicySet) }, http.StatusCreated, &rateLimitPolicySet)
var rateLimitPlatformModel map[string]any var rateLimitPlatformModel map[string]any
@ -682,6 +726,7 @@ WHERE reference_type = 'gateway_task'
if rateLimitFailedTask.Task.Status != "failed" || rateLimitFailedTask.Task.ErrorCode != "bad_request" { if rateLimitFailedTask.Task.Status != "failed" || rateLimitFailedTask.Task.ErrorCode != "bad_request" {
t.Fatalf("failed rate-limited task should fail before consuming rpm: %+v", rateLimitFailedTask.Task) t.Fatalf("failed rate-limited task should fail before consuming rpm: %+v", rateLimitFailedTask.Task)
} }
waitForRateLimitWindowHead(t, rateLimitWindowSeconds)
var rateLimitTaskOne struct { var rateLimitTaskOne struct {
Task struct { Task struct {
Status string `json:"status"` Status string `json:"status"`
@ -713,6 +758,38 @@ WHERE reference_type = 'gateway_task'
if rateLimitTaskTwo.Task.Status != "failed" || rateLimitTaskTwo.Task.ErrorCode != "rate_limit" { if rateLimitTaskTwo.Task.Status != "failed" || rateLimitTaskTwo.Task.ErrorCode != "rate_limit" {
t.Fatalf("runtime policy rate limit should fail second task with rate_limit: %+v", rateLimitTaskTwo.Task) t.Fatalf("runtime policy rate limit should fail second task with rate_limit: %+v", rateLimitTaskTwo.Task)
} }
var asyncRateLimitTask struct {
TaskID string `json:"taskId"`
Task struct {
ID string `json:"id"`
Status string `json:"status"`
AsyncMode bool `json:"asyncMode"`
} `json:"task"`
}
doJSONWithHeaders(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
"model": rateLimitedModel,
"runMode": "simulation",
"simulation": true,
"simulationDurationMs": 5,
"messages": []map[string]any{{"role": "user", "content": "async queued"}},
}, map[string]string{"X-Async": "true"}, http.StatusAccepted, &asyncRateLimitTask)
if asyncRateLimitTask.TaskID == "" || asyncRateLimitTask.Task.ID != asyncRateLimitTask.TaskID || !asyncRateLimitTask.Task.AsyncMode {
t.Fatalf("async task response should expose task id and async mode: %+v", asyncRateLimitTask)
}
asyncRateLimitDetail := waitForTaskStatus(t, server.URL, apiKeyResponse.Secret, asyncRateLimitTask.TaskID, []string{"queued"}, 2*time.Second)
if asyncRateLimitDetail.Status != "queued" {
t.Fatalf("async rate-limited task should return to queued state, got %+v", asyncRateLimitDetail)
}
if len(asyncRateLimitDetail.Attempts) == 0 || asyncRateLimitDetail.Attempts[0].ErrorCode != "rate_limit" {
t.Fatalf("async rate-limited task should record a rate_limit attempt before requeue: %+v", asyncRateLimitDetail)
}
asyncRateLimitCompleted := waitForTaskStatus(t, server.URL, apiKeyResponse.Secret, asyncRateLimitTask.TaskID, []string{"succeeded"}, time.Duration(rateLimitWindowSeconds+3)*time.Second)
if asyncRateLimitCompleted.Status != "succeeded" {
t.Fatalf("async rate-limited task should be pulled from queue after the limit window resets, got %+v", asyncRateLimitCompleted)
}
if len(asyncRateLimitCompleted.Attempts) < 2 || asyncRateLimitCompleted.Attempts[len(asyncRateLimitCompleted.Attempts)-1].Status != "succeeded" {
t.Fatalf("async rate-limited task should create a new successful attempt after requeue: %+v", asyncRateLimitCompleted)
}
videoRouteModel := "video-route-smoke-" + suffixText videoRouteModel := "video-route-smoke-" + suffixText
var videoRoutePlatformModel map[string]any var videoRoutePlatformModel map[string]any
@ -823,16 +900,19 @@ WHERE reference_type = 'gateway_task'
Metrics map[string]any `json:"metrics"` Metrics map[string]any `json:"metrics"`
} }
doJSON(t, server.URL, http.MethodGet, "/api/v1/tasks/"+failoverTask.Task.ID, apiKeyResponse.Secret, nil, http.StatusOK, &failoverDetail) doJSON(t, server.URL, http.MethodGet, "/api/v1/tasks/"+failoverTask.Task.ID, apiKeyResponse.Secret, nil, http.StatusOK, &failoverDetail)
if len(failoverDetail.Attempts) != 2 { if len(failoverDetail.Attempts) != 3 {
t.Fatalf("failover task history should include two attempts, got %+v", failoverDetail.Attempts) t.Fatalf("failover task history should include two failed retries plus one successful failover attempt, got %+v", failoverDetail.Attempts)
} }
if failoverDetail.Attempts[0].PlatformName != "OpenAI Retryable Failure" || failoverDetail.Attempts[0].Status != "failed" || !failoverDetail.Attempts[0].Retryable || failoverDetail.Attempts[0].ErrorCode == "" { if failoverDetail.Attempts[0].PlatformName != "OpenAI Retryable Failure" || failoverDetail.Attempts[0].Status != "failed" || !failoverDetail.Attempts[0].Retryable || failoverDetail.Attempts[0].ErrorCode == "" {
t.Fatalf("first failover attempt should preserve failed platform and reason: %+v", failoverDetail.Attempts[0]) t.Fatalf("first failover attempt should preserve failed platform and reason: %+v", failoverDetail.Attempts[0])
} }
if failoverDetail.Attempts[1].PlatformName != "OpenAI Retry Success" || failoverDetail.Attempts[1].Status != "succeeded" || failoverDetail.Attempts[1].ResponseMS <= 0 { if failoverDetail.Attempts[1].PlatformName != "OpenAI Retryable Failure" || failoverDetail.Attempts[1].Status != "failed" || !failoverDetail.Attempts[1].Retryable {
t.Fatalf("second failover attempt should preserve successful platform: %+v", failoverDetail.Attempts[1]) t.Fatalf("second failover attempt should preserve the same-client retry failure: %+v", failoverDetail.Attempts[1])
} }
if summary, ok := failoverDetail.Metrics["attempts"].([]any); !ok || len(summary) != 2 { if failoverDetail.Attempts[2].PlatformName != "OpenAI Retry Success" || failoverDetail.Attempts[2].Status != "succeeded" || failoverDetail.Attempts[2].ResponseMS <= 0 {
t.Fatalf("third failover attempt should preserve successful platform: %+v", failoverDetail.Attempts[2])
}
if summary, ok := failoverDetail.Metrics["attempts"].([]any); !ok || len(summary) != 3 {
t.Fatalf("task metrics should keep attempt-chain summary, got %+v", failoverDetail.Metrics) t.Fatalf("task metrics should keep attempt-chain summary, got %+v", failoverDetail.Metrics)
} }
@ -1045,6 +1125,9 @@ WHERE m.platform_id = $1::uuid
if resp.StatusCode != http.StatusOK || !bytes.Contains(body, []byte("task.completed")) { if resp.StatusCode != http.StatusOK || !bytes.Contains(body, []byte("task.completed")) {
t.Fatalf("unexpected events response status=%d body=%s", resp.StatusCode, string(body)) t.Fatalf("unexpected events response status=%d body=%s", resp.StatusCode, string(body))
} }
if !bytes.Contains(body, []byte("task.running")) {
t.Fatalf("events response should include running transition event body=%s", string(body))
}
if !bytes.Contains(body, []byte("task.progress")) { if !bytes.Contains(body, []byte("task.progress")) {
t.Fatalf("events response should include progress events body=%s", string(body)) t.Fatalf("events response should include progress events body=%s", string(body))
} }
@ -1074,6 +1157,47 @@ WHERE m.platform_id = $1::uuid
if callbackRows == 0 { if callbackRows == 0 {
t.Fatal("task progress callback outbox should receive events") t.Fatal("task progress callback outbox should receive events")
} }
var restartAsyncTask struct {
TaskID string `json:"taskId"`
Task struct {
ID string `json:"id"`
Status string `json:"status"`
AsyncMode bool `json:"asyncMode"`
} `json:"task"`
}
doJSONWithHeaders(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
"model": defaultTextModel,
"runMode": "simulation",
"simulation": true,
"simulationDurationMs": 2000,
"messages": []map[string]any{{"role": "user", "content": "river worker restart"}},
}, map[string]string{"X-Async": "true"}, http.StatusAccepted, &restartAsyncTask)
if restartAsyncTask.TaskID == "" || !restartAsyncTask.Task.AsyncMode {
t.Fatalf("restart async task should be accepted as async: %+v", restartAsyncTask)
}
restartRunning := waitForTaskStatus(t, server.URL, apiKeyResponse.Secret, restartAsyncTask.TaskID, []string{"running"}, 3*time.Second)
if restartRunning.Status != "running" {
t.Fatalf("restart async task should be running before worker restart, got %+v", restartRunning)
}
cancelServer()
serverCtx2, cancelServer2 := context.WithCancel(ctx)
defer cancelServer2()
server2 := httptest.NewServer(NewServerWithContext(serverCtx2, config.Config{
AppEnv: "test",
HTTPAddr: ":0",
DatabaseURL: databaseURL,
IdentityMode: "hybrid",
JWTSecret: "test-secret",
TaskProgressCallbackEnabled: true,
TaskProgressCallbackURL: "http://callback.local/task-progress",
CORSAllowedOrigin: "*",
}, db, slog.New(slog.NewTextHandler(io.Discard, nil))))
defer server2.Close()
restartRecovered := waitForTaskStatus(t, server2.URL, apiKeyResponse.Secret, restartAsyncTask.TaskID, []string{"succeeded"}, 8*time.Second)
if restartRecovered.Status != "succeeded" {
t.Fatalf("river worker restart should recover and finish async task, got %+v", restartRecovered)
}
} }
func TestOriginAllowedSupportsCommaSeparatedOrigins(t *testing.T) { func TestOriginAllowedSupportsCommaSeparatedOrigins(t *testing.T) {
@ -1089,6 +1213,51 @@ func TestOriginAllowedSupportsCommaSeparatedOrigins(t *testing.T) {
} }
} }
// assertRuntimeRecoveryReleasesPendingRateReservations verifies that
// RecoverInterruptedRuntimeState releases rate-limit reservations left
// pending by an interrupted attempt: it creates an async task, reserves rpm
// capacity for a unique client scope without ever committing it, runs
// recovery, and checks both the recovery counters and the persisted counter
// row.
func assertRuntimeRecoveryReleasesPendingRateReservations(t *testing.T, ctx context.Context, db *store.Store) {
	t.Helper()
	// Unique suffix keeps the helper re-runnable against a shared database.
	suffix := strconv.FormatInt(time.Now().UnixNano(), 10)
	task, err := db.CreateTask(ctx, store.CreateTaskInput{
		Kind:    "chat.completions",
		Model:   "recovery-reservation-smoke-" + suffix,
		RunMode: "test",
		Async:   true,
		Request: map[string]any{"messages": []any{}},
	}, &auth.User{ID: "recovery-user-" + suffix})
	if err != nil {
		t.Fatalf("create recovery reservation task: %v", err)
	}
	scopeKey := "recovery-client-" + suffix
	// Reserve capacity with an empty attempt ID and never release it,
	// simulating an attempt that was interrupted mid-flight.
	if _, err := db.ReserveRateLimits(ctx, task.ID, "", []store.RateLimitReservation{{
		ScopeType:     "client",
		ScopeKey:      scopeKey,
		Metric:        "rpm",
		Limit:         10,
		Amount:        3,
		WindowSeconds: 60,
	}}); err != nil {
		t.Fatalf("reserve recovery rate limit: %v", err)
	}
	recovery, err := db.RecoverInterruptedRuntimeState(ctx)
	if err != nil {
		t.Fatalf("recover interrupted runtime state with pending reservation: %v", err)
	}
	if recovery.ReleasedRateReservations == 0 {
		t.Fatalf("recovery should release pending rate reservation, got %+v", recovery)
	}
	// Cross-check against the raw counter table: the reserved value for the
	// scope above must be back to zero after recovery.
	var reservedValue float64
	if err := db.Pool().QueryRow(ctx, `
SELECT COALESCE(MAX(reserved_value), 0)::float8
FROM gateway_rate_limit_counters
WHERE scope_type = 'client'
AND scope_key = $1
AND metric = 'rpm'`, scopeKey).Scan(&reservedValue); err != nil {
		t.Fatalf("query recovered rate limit counter: %v", err)
	}
	if reservedValue != 0 {
		t.Fatalf("recovery should release reserved counter value, got %f", reservedValue)
	}
}
func applyMigration(t *testing.T, ctx context.Context, databaseURL string) { func applyMigration(t *testing.T, ctx context.Context, databaseURL string) {
t.Helper() t.Helper()
_, filename, _, _ := runtime.Caller(0) _, filename, _, _ := runtime.Caller(0)
@ -1114,6 +1283,11 @@ func applyMigration(t *testing.T, ctx context.Context, databaseURL string) {
} }
func doJSON(t *testing.T, baseURL string, method string, path string, token string, payload any, expectedStatus int, out any) { func doJSON(t *testing.T, baseURL string, method string, path string, token string, payload any, expectedStatus int, out any) {
t.Helper()
doJSONWithHeaders(t, baseURL, method, path, token, payload, nil, expectedStatus, out)
}
func doJSONWithHeaders(t *testing.T, baseURL string, method string, path string, token string, payload any, headers map[string]string, expectedStatus int, out any) {
t.Helper() t.Helper()
var body io.Reader var body io.Reader
if payload != nil { if payload != nil {
@ -1133,6 +1307,9 @@ func doJSON(t *testing.T, baseURL string, method string, path string, token stri
if token != "" { if token != "" {
req.Header.Set("Authorization", "Bearer "+token) req.Header.Set("Authorization", "Bearer "+token)
} }
for key, value := range headers {
req.Header.Set(key, value)
}
resp, err := http.DefaultClient.Do(req) resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
t.Fatalf("%s %s: %v", method, path, err) t.Fatalf("%s %s: %v", method, path, err)
@ -1149,6 +1326,50 @@ func doJSON(t *testing.T, baseURL string, method string, path string, token stri
} }
} }
// taskWaitDetail is the subset of the task detail response that the polling
// helpers below need: the task id, its current status, and per-attempt
// status/error codes for failure diagnostics in test output.
type taskWaitDetail struct {
	ID       string `json:"id"`
	Status   string `json:"status"`
	Attempts []struct {
		Status    string `json:"status"`
		ErrorCode string `json:"errorCode"`
	} `json:"attempts"`
}
// waitForTaskStatus polls GET /api/v1/tasks/{id} every 100ms until the task
// reaches one of the wanted statuses or the timeout elapses. The most recent
// detail is returned either way; callers assert on Status themselves.
func waitForTaskStatus(t *testing.T, baseURL string, token string, taskID string, statuses []string, timeout time.Duration) taskWaitDetail {
	t.Helper()
	accepted := make(map[string]bool, len(statuses))
	for _, s := range statuses {
		accepted[s] = true
	}
	var latest taskWaitDetail
	for stopAt := time.Now().Add(timeout); time.Now().Before(stopAt); time.Sleep(100 * time.Millisecond) {
		doJSON(t, baseURL, http.MethodGet, "/api/v1/tasks/"+taskID, token, nil, http.StatusOK, &latest)
		if accepted[latest.Status] {
			break
		}
	}
	return latest
}
func waitForRateLimitWindowHead(t *testing.T, windowSeconds int) {
t.Helper()
if windowSeconds <= 0 {
return
}
window := time.Duration(windowSeconds) * time.Second
deadline := time.Now().Add(window + time.Second)
for time.Now().Before(deadline) {
elapsed := time.Duration(time.Now().UnixNano() % int64(window))
if elapsed < 300*time.Millisecond {
return
}
time.Sleep(20 * time.Millisecond)
}
t.Fatalf("timed out waiting for %ds rate limit window head", windowSeconds)
}
func stringSliceContains(values []string, target string) bool { func stringSliceContains(values []string, target string) bool {
for _, value := range values { for _, value := range values {
if value == target { if value == target {

View File

@ -542,11 +542,13 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler {
writeError(w, http.StatusForbidden, "api key scope does not allow this capability") writeError(w, http.StatusForbidden, "api key scope does not allow this capability")
return return
} }
asyncMode := asyncRequest(r)
task, err := s.store.CreateTask(r.Context(), store.CreateTaskInput{ task, err := s.store.CreateTask(r.Context(), store.CreateTaskInput{
Kind: kind, Kind: kind,
Model: model, Model: model,
RunMode: runModeFromRequest(body), RunMode: runModeFromRequest(body),
Async: asyncMode,
Request: body, Request: body,
}, user) }, user)
if err != nil { if err != nil {
@ -554,6 +556,14 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler {
writeError(w, http.StatusInternalServerError, "create task failed") writeError(w, http.StatusInternalServerError, "create task failed")
return return
} }
if asyncMode {
if err := s.runner.EnqueueAsyncTask(r.Context(), task); err != nil {
writeError(w, http.StatusInternalServerError, err.Error(), "enqueue_failed")
return
}
writeTaskAccepted(w, task)
return
}
if compatible { if compatible {
if boolValue(body, "stream") { if boolValue(body, "stream") {
flusher := prepareCompatibleStream(w) flusher := prepareCompatibleStream(w)
@ -602,13 +612,23 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler {
s.logger.Warn("task completed with failure", "kind", kind, "taskId", task.ID, "error", runErr) s.logger.Warn("task completed with failure", "kind", kind, "taskId", task.ID, "error", runErr)
} }
writeJSON(w, http.StatusAccepted, map[string]any{ writeTaskAccepted(w, result.Task)
"task": result.Task, })
"next": map[string]string{ }
"events": fmt.Sprintf("/api/v1/tasks/%s/events", task.ID),
"detail": fmt.Sprintf("/api/v1/tasks/%s", task.ID), func asyncRequest(r *http.Request) bool {
}, value := strings.TrimSpace(strings.ToLower(r.Header.Get("x-async")))
}) return value == "1" || value == "true" || value == "yes" || value == "on"
}
func writeTaskAccepted(w http.ResponseWriter, task store.GatewayTask) {
writeJSON(w, http.StatusAccepted, map[string]any{
"taskId": task.ID,
"task": task,
"next": map[string]string{
"events": fmt.Sprintf("/api/v1/tasks/%s/events", task.ID),
"detail": fmt.Sprintf("/api/v1/tasks/%s", task.ID),
},
}) })
} }

View File

@ -1,6 +1,7 @@
package httpapi package httpapi
import ( import (
"context"
"log/slog" "log/slog"
"net/http" "net/http"
"strings" "strings"
@ -12,6 +13,7 @@ import (
) )
type Server struct { type Server struct {
ctx context.Context
cfg config.Config cfg config.Config
store *store.Store store *store.Store
auth *auth.Authenticator auth *auth.Authenticator
@ -20,7 +22,12 @@ type Server struct {
} }
func NewServer(cfg config.Config, db *store.Store, logger *slog.Logger) http.Handler { func NewServer(cfg config.Config, db *store.Store, logger *slog.Logger) http.Handler {
return NewServerWithContext(context.Background(), cfg, db, logger)
}
func NewServerWithContext(ctx context.Context, cfg config.Config, db *store.Store, logger *slog.Logger) http.Handler {
server := &Server{ server := &Server{
ctx: ctx,
cfg: cfg, cfg: cfg,
store: db, store: db,
auth: auth.New(cfg.JWTSecret, cfg.ServerMainBaseURL, cfg.ServerMainInternalToken), auth: auth.New(cfg.JWTSecret, cfg.ServerMainBaseURL, cfg.ServerMainInternalToken),
@ -28,6 +35,7 @@ func NewServer(cfg config.Config, db *store.Store, logger *slog.Logger) http.Han
logger: logger, logger: logger,
} }
server.auth.LocalAPIKeyVerifier = db.VerifyLocalAPIKey server.auth.LocalAPIKeyVerifier = db.VerifyLocalAPIKey
server.runner.StartAsyncQueueWorker(ctx)
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("GET /healthz", server.health) mux.HandleFunc("GET /healthz", server.health)
@ -146,7 +154,7 @@ func (s *Server) cors(next http.Handler) http.Handler {
w.Header().Set("Access-Control-Allow-Origin", origin) w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Vary", "Origin") w.Header().Set("Vary", "Origin")
w.Header().Set("Access-Control-Allow-Credentials", "true") w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Comfy-Api-Key") w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Comfy-Api-Key, X-Async")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
} }
if r.Method == http.MethodOptions { if r.Method == http.MethodOptions {

View File

@ -2,13 +2,49 @@ package runner
import ( import (
"context" "context"
"errors"
"strings" "strings"
"time"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth" "github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients" "github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store" "github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
) )
// localRateLimitError wraps a rate-limit failure raised by the gateway's own
// reservation bookkeeping (as opposed to an upstream 429). It carries the
// client-shaped error used in attempt records, the underlying store error,
// and an optional retry-after hint.
type localRateLimitError struct {
	clientErr  *clients.ClientError
	cause      error
	retryAfter time.Duration
}
// Error reports the wrapped client error's message, falling back to the
// store's rate-limit sentinel when the receiver or its client error is nil.
func (e *localRateLimitError) Error() string {
	if e != nil && e.clientErr != nil {
		return e.clientErr.Error()
	}
	return store.ErrRateLimited.Error()
}
// Unwrap exposes the wrapped errors for errors.Is/errors.As: the client error
// when present, followed by the concrete cause when present, otherwise the
// store's rate-limit sentinel so errors.Is(err, store.ErrRateLimited) holds.
func (e *localRateLimitError) Unwrap() []error {
	var chain []error
	if e != nil && e.clientErr != nil {
		chain = append(chain, e.clientErr)
	}
	if e != nil && e.cause != nil {
		chain = append(chain, e.cause)
	} else {
		chain = append(chain, store.ErrRateLimited)
	}
	return chain
}
// localRateLimitRetryAfter resolves the retry-after hint for a rate-limit
// failure: a positive hint carried by a localRateLimitError wins, otherwise
// the store-level default for the error applies.
func localRateLimitRetryAfter(err error) time.Duration {
	var rle *localRateLimitError
	if !errors.As(err, &rle) || rle.retryAfter <= 0 {
		return store.RateLimitRetryAfter(err)
	}
	return rle.retryAfter
}
func (s *Service) rateLimitReservations(ctx context.Context, user *auth.User, candidate store.RuntimeModelCandidate, body map[string]any) []store.RateLimitReservation { func (s *Service) rateLimitReservations(ctx context.Context, user *auth.User, candidate store.RuntimeModelCandidate, body map[string]any) []store.RateLimitReservation {
out := make([]store.RateLimitReservation, 0) out := make([]store.RateLimitReservation, 0)
out = append(out, reservationsFromPolicy("platform_model", candidate.PlatformModelID, effectiveRateLimitPolicy(candidate), body)...) out = append(out, reservationsFromPolicy("platform_model", candidate.PlatformModelID, effectiveRateLimitPolicy(candidate), body)...)

View File

@ -0,0 +1,216 @@
package runner
import (
"context"
"errors"
"fmt"
"os"
"strings"
"time"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
"github.com/riverqueue/river"
"github.com/riverqueue/river/riverdriver/riverpgxv5"
"github.com/riverqueue/river/rivermigrate"
"github.com/riverqueue/river/rivertype"
)
// asyncTaskQueueName is the dedicated river queue for gateway task execution.
const asyncTaskQueueName = "gateway_tasks"

// asyncTaskArgs is the river job payload: just the gateway task id. The
// `river:"unique"` tag makes the id part of the job's uniqueness key.
type asyncTaskArgs struct {
	TaskID string `json:"task_id" river:"unique"`
}

// Kind identifies this job type to river.
func (asyncTaskArgs) Kind() string { return "gateway_task_run" }

// asyncTaskWorker runs queued gateway tasks through the runner Service.
type asyncTaskWorker struct {
	river.WorkerDefaults[asyncTaskArgs]
	service *Service
}
// Work executes one queued gateway task; it is the river handler for
// asyncTaskArgs and layers at-least-once semantics on the task row's status:
//   - tasks already terminal (succeeded/failed/cancelled) are acked as no-ops,
//     making duplicate deliveries safe;
//   - a TaskQueuedError from Execute snoozes the job for the requested delay
//     (e.g. local rate-limit backoff);
//   - when the worker context is cancelled (shutdown), the task row is
//     re-queued using an uncancelled context and the job is snoozed so a
//     surviving worker picks it up;
//   - any other execution error is logged and swallowed: Execute has already
//     recorded the failure on the task, so erroring here would double-fail it.
func (w *asyncTaskWorker) Work(ctx context.Context, job *river.Job[asyncTaskArgs]) error {
	task, err := w.service.store.GetTask(ctx, job.Args.TaskID)
	if err != nil {
		return err
	}
	if task.Status == "succeeded" || task.Status == "failed" || task.Status == "cancelled" {
		return nil
	}
	result, runErr := w.service.Execute(ctx, task, authUserFromTask(task))
	if runErr == nil {
		w.service.logger.Debug("river async task completed", "taskID", task.ID, "status", result.Task.Status, "riverJobID", job.ID)
		return nil
	}
	var queuedErr *TaskQueuedError
	if errors.As(runErr, &queuedErr) {
		return river.JobSnooze(queuedErr.Delay)
	}
	if ctx.Err() != nil {
		// Shutdown path: persist the requeue with a context that survives
		// cancellation, then snooze so another worker resumes the task.
		queued, queueErr := w.service.requeueInterruptedAsyncTask(context.WithoutCancel(ctx), task)
		if queueErr != nil {
			return queueErr
		}
		w.service.logger.Debug("river async task interrupted and requeued", "taskID", task.ID, "status", queued.Status, "riverJobID", job.ID)
		return river.JobSnooze(0)
	}
	w.service.logger.Warn("river async task completed with failure", "taskID", task.ID, "error", runErr, "riverJobID", job.ID)
	return nil
}
// StartAsyncQueueWorker boots the river-backed async queue; startup failure
// is fatal (logged, then panicked) because the service cannot serve async
// requests without it.
func (s *Service) StartAsyncQueueWorker(ctx context.Context) {
	err := s.startRiverQueue(ctx)
	if err == nil {
		return
	}
	s.logger.Error("start river async queue failed", "error", err)
	panic(err)
}
// startRiverQueue wires up the river-backed async task queue: it applies the
// river schema migrations, registers the task worker, starts the client, and
// re-enqueues persisted async tasks whose jobs were lost. A goroutine watches
// ctx and stops the client with a 10s grace period on shutdown.
func (s *Service) startRiverQueue(ctx context.Context) error {
	driver := riverpgxv5.New(s.store.Pool())
	migrator, err := rivermigrate.New(driver, nil)
	if err != nil {
		return err
	}
	// River migrations are versioned, so re-running them on every boot is
	// safe and idempotent.
	if _, err := migrator.Migrate(ctx, rivermigrate.DirectionUp, nil); err != nil {
		return err
	}
	workers := river.NewWorkers()
	if err := river.AddWorkerSafely(workers, &asyncTaskWorker{service: s}); err != nil {
		return err
	}
	riverClient, err := river.NewClient(driver, &river.Config{
		ID: asyncWorkerID(),
		// A negative JobTimeout disables river's per-job timeout; task
		// duration is governed by the task's own context instead.
		JobTimeout:                  -1,
		Logger:                      s.logger,
		CompletedJobRetentionPeriod: 24 * time.Hour,
		Queues: map[string]river.QueueConfig{
			asyncTaskQueueName: {MaxWorkers: 32},
		},
		RescueStuckJobsAfter: 30 * time.Second,
		TestOnly:             s.cfg.AppEnv == "test",
		Workers:              workers,
	})
	if err != nil {
		return err
	}
	s.riverClient = riverClient
	if err := riverClient.Start(ctx); err != nil {
		return err
	}
	if err := s.recoverAsyncRiverJobs(ctx); err != nil {
		return err
	}
	go func() {
		<-ctx.Done()
		// ctx is already cancelled here, so the stop deadline needs a
		// detached context.
		stopCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()
		if err := riverClient.StopAndCancel(stopCtx); err != nil {
			s.logger.Warn("stop river async queue failed", "error", err)
		}
	}()
	return nil
}
// EnqueueAsyncTask inserts a river job for the task and, when a job row was
// actually created (not deduplicated away), persists the job id on the task.
// It fails fast when the queue has not been started.
func (s *Service) EnqueueAsyncTask(ctx context.Context, task store.GatewayTask) error {
	if s.riverClient == nil {
		return errors.New("river async queue is not started")
	}
	inserted, err := s.riverClient.Insert(ctx, asyncTaskArgs{TaskID: task.ID}, asyncTaskInsertOpts(task))
	switch {
	case err != nil:
		return err
	case inserted.Job == nil:
		return nil
	default:
		return s.store.SetTaskRiverJobID(ctx, task.ID, inserted.Job.ID)
	}
}
// WakeAsyncQueueAfter is intentionally a no-op with the river-backed queue:
// river's own scheduler wakes snoozed/scheduled jobs, so there is nothing to
// poke here. NOTE(review): presumably retained so callers of the previous
// polling implementation keep compiling — confirm before removing.
func (s *Service) WakeAsyncQueueAfter(ctx context.Context, delay time.Duration) {
}
// RunAsyncTask enqueues the task on the river queue, logging (but not
// surfacing) enqueue failures. The user argument is unused here: the worker
// rebuilds the auth user from the persisted task row via authUserFromTask,
// presumably because execution may happen on another process — confirm with
// callers before relying on the parameter.
func (s *Service) RunAsyncTask(ctx context.Context, task store.GatewayTask, user *auth.User) {
	if err := s.EnqueueAsyncTask(ctx, task); err != nil {
		s.logger.Warn("enqueue river async task failed", "taskID", task.ID, "error", err)
	}
}
// recoverAsyncRiverJobs re-inserts river jobs for async tasks the store still
// considers recoverable (e.g. after a crash lost their in-flight jobs).
// Inserts use the same unique options as normal enqueueing, so tasks that
// already hold a live job are deduplicated by river rather than duplicated.
func (s *Service) recoverAsyncRiverJobs(ctx context.Context) error {
	items, err := s.store.ListRecoverableAsyncTasks(ctx, 1000)
	if err != nil {
		return err
	}
	for _, item := range items {
		// Only the ID is needed for both the insert options and job args.
		task := store.GatewayTask{ID: item.ID}
		result, err := s.riverClient.Insert(ctx, asyncTaskArgs{TaskID: item.ID}, asyncTaskInsertOpts(task))
		if err != nil {
			return err
		}
		// result.Job is nil when the insert was deduplicated; only persist
		// the job id for freshly created jobs.
		if result.Job != nil {
			if err := s.store.SetTaskRiverJobID(ctx, item.ID, result.Job.ID); err != nil {
				return err
			}
		}
	}
	if len(items) > 0 {
		s.logger.Info("river async queue recovered persisted tasks", "count", len(items))
	}
	return nil
}
// asyncTaskInsertOpts builds the river insert options shared by enqueueing
// and crash recovery. Uniqueness is keyed by args+queue across every
// non-terminal job state, so one task holds at most one live job at a time.
func asyncTaskInsertOpts(task store.GatewayTask) *river.InsertOpts {
	// NOTE(review): tasks with an empty ID get priority 3 instead of 2; it is
	// unclear when an empty ID occurs or whether this branch is reachable —
	// confirm the intent before relying on it.
	priority := 2
	if task.ID == "" {
		priority = 3
	}
	return &river.InsertOpts{
		MaxAttempts: 1000,
		Priority:    priority,
		Queue:       asyncTaskQueueName,
		Tags:        []string{"gateway-task"},
		UniqueOpts: river.UniqueOpts{
			ByArgs:  true,
			ByQueue: true,
			ByState: []rivertype.JobState{
				rivertype.JobStateAvailable,
				rivertype.JobStatePending,
				rivertype.JobStateRetryable,
				rivertype.JobStateRunning,
				rivertype.JobStateScheduled,
			},
		},
	}
}
// authUserFromTask reconstructs the auth user for a queued task from the
// identity fields persisted on the task row, so the worker can execute it
// without the original request's credentials. The "user" role is granted only
// when the task carries a non-blank UserID.
func authUserFromTask(task store.GatewayTask) *auth.User {
	user := &auth.User{
		ID:              firstNonEmptyString(task.GatewayUserID, task.UserID),
		TenantID:        task.TenantID,
		GatewayTenantID: task.GatewayTenantID,
		TenantKey:       task.TenantKey,
		Source:          firstNonEmptyString(task.UserSource, "gateway"),
		GatewayUserID:   task.GatewayUserID,
		UserGroupID:     task.UserGroupID,
		UserGroupKey:    task.UserGroupKey,
		APIKeyID:        task.APIKeyID,
		APIKeyName:      task.APIKeyName,
		APIKeyPrefix:    task.APIKeyPrefix,
	}
	if strings.TrimSpace(task.UserID) != "" {
		user.Roles = []string{"user"}
	}
	return user
}
func asyncWorkerID() string {
host, _ := os.Hostname()
host = strings.TrimSpace(host)
if host == "" {
host = "localhost"
}
return fmt.Sprintf("%s:%d:%d", host, os.Getpid(), time.Now().UnixNano())
}
// Compile-time check that asyncTaskWorker satisfies river.Worker.
var _ river.Worker[asyncTaskArgs] = (*asyncTaskWorker)(nil)

View File

@ -1,6 +1,7 @@
package runner package runner
import ( import (
"errors"
"fmt" "fmt"
"strings" "strings"
@ -54,6 +55,9 @@ func shouldRetrySameClient(candidate store.RuntimeModelCandidate, err error) boo
func retryDecisionForCandidate(candidate store.RuntimeModelCandidate, err error) retryDecision { func retryDecisionForCandidate(candidate store.RuntimeModelCandidate, err error) retryDecision {
policy := effectiveRetryPolicy(candidate) policy := effectiveRetryPolicy(candidate)
info := failureInfoFromError(err) info := failureInfoFromError(err)
if errors.Is(err, store.ErrRateLimited) {
return retryDecision{Retry: false, Reason: "local_rate_limit_wait_queue", Match: policyRuleMatch{Source: "gateway_rate_limits", Policy: "rateLimitPolicy", Rule: "localCapacity", Value: "exceeded"}, Info: info}
}
if !boolFromPolicy(policy, "enabled", true) { if !boolFromPolicy(policy, "enabled", true) {
return retryDecision{Retry: false, Reason: "retry_disabled", Match: policyRuleMatch{Source: "model_runtime_policy_sets.retry_policy", Policy: "retryPolicy", Rule: "enabled", Value: "false"}, Info: info} return retryDecision{Retry: false, Reason: "retry_disabled", Match: policyRuleMatch{Source: "model_runtime_policy_sets.retry_policy", Policy: "retryPolicy", Rule: "enabled", Value: "false"}, Info: info}
} }
@ -94,6 +98,9 @@ func failoverDecisionForCandidate(runnerPolicy store.RunnerPolicy, candidate sto
if cooldownSeconds <= 0 { if cooldownSeconds <= 0 {
cooldownSeconds = 300 cooldownSeconds = 300
} }
if errors.Is(err, store.ErrRateLimited) && store.RateLimitRetryable(err) {
return failoverDecision{Retry: true, Action: "next", Reason: "local_rate_limit_try_next_candidate", CooldownSeconds: cooldownSeconds, Match: policyRuleMatch{Source: "gateway_rate_limits", Policy: "rateLimitPolicy", Rule: "localCapacity", Value: "exceeded"}, Info: info}
}
if match, ok := failoverAllowMatchWithSources(runnerPolicy.FailoverPolicy, overridePolicy, info); ok { if match, ok := failoverAllowMatchWithSources(runnerPolicy.FailoverPolicy, overridePolicy, info); ok {
return failoverDecision{Retry: true, Action: action, Reason: "failover_allow_policy", CooldownSeconds: cooldownSeconds, Match: match, Info: info} return failoverDecision{Retry: true, Action: action, Reason: "failover_allow_policy", CooldownSeconds: cooldownSeconds, Match: match, Info: info}
} }

View File

@ -2,6 +2,7 @@ package runner
import ( import (
"context" "context"
"errors"
"strings" "strings"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients" "github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
@ -65,6 +66,9 @@ func (s *Service) applyFailoverAction(ctx context.Context, taskID string, candid
} }
func (s *Service) applyPriorityDemotePolicy(ctx context.Context, taskID string, attemptNo int, runnerPolicy store.RunnerPolicy, candidate store.RuntimeModelCandidate, cause error, simulated bool) { func (s *Service) applyPriorityDemotePolicy(ctx context.Context, taskID string, attemptNo int, runnerPolicy store.RunnerPolicy, candidate store.RuntimeModelCandidate, cause error, simulated bool) {
if errors.Is(cause, store.ErrRateLimited) {
return
}
decision := priorityDemoteDecisionForCandidate(runnerPolicy, cause) decision := priorityDemoteDecisionForCandidate(runnerPolicy, cause)
if !decision.Demote { if !decision.Demote {
return return

View File

@ -3,6 +3,7 @@ package runner
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"log/slog" "log/slog"
"strconv" "strconv"
"strings" "strings"
@ -12,6 +13,8 @@ import (
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients" "github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config" "github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store" "github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
"github.com/jackc/pgx/v5"
"github.com/riverqueue/river"
) )
type Service struct { type Service struct {
@ -20,6 +23,7 @@ type Service struct {
logger *slog.Logger logger *slog.Logger
clients map[string]clients.Client clients map[string]clients.Client
httpClients *httpClientCache httpClients *httpClientCache
riverClient *river.Client[pgx.Tx]
} }
type Result struct { type Result struct {
@ -27,6 +31,20 @@ type Result struct {
Output map[string]any Output map[string]any
} }
var ErrTaskQueued = errors.New("task queued")
type TaskQueuedError struct {
Delay time.Duration
}
func (e *TaskQueuedError) Error() string {
return ErrTaskQueued.Error()
}
func (e *TaskQueuedError) Is(target error) bool {
return target == ErrTaskQueued
}
func New(cfg config.Config, db *store.Store, logger *slog.Logger) *Service { func New(cfg config.Config, db *store.Store, logger *slog.Logger) *Service {
httpClients := newHTTPClientCache() httpClients := newHTTPClientCache()
return &Service{ return &Service{
@ -55,6 +73,14 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
executeStartedAt := time.Now() executeStartedAt := time.Now()
body := normalizeRequest(task.Kind, task.Request) body := normalizeRequest(task.Kind, task.Request)
modelType := modelTypeFromKind(task.Kind, body) modelType := modelTypeFromKind(task.Kind, body)
if err := s.store.MarkTaskRunning(ctx, task.ID, modelType, body); err != nil {
return Result{}, err
}
if task.Status != "running" {
if err := s.emit(ctx, task.ID, "task.running", "running", "starting", 0.12, "task pulled from queue and started", map[string]any{"modelType": modelType}, task.RunMode == "simulation"); err != nil {
return Result{}, err
}
}
if err := validateRequest(task.Kind, body); err != nil { if err := validateRequest(task.Kind, body); err != nil {
failed, finishErr := s.failTask(ctx, task.ID, "bad_request", err.Error(), task.RunMode == "simulation", err) failed, finishErr := s.failTask(ctx, task.ID, "bad_request", err.Error(), task.RunMode == "simulation", err)
if finishErr != nil { if finishErr != nil {
@ -83,9 +109,6 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
return Result{}, err return Result{}, err
} }
} }
if err := s.store.MarkTaskRunning(ctx, task.ID, modelType, body); err != nil {
return Result{}, err
}
if err := s.emit(ctx, task.ID, "task.progress", "running", "normalizing", 0.15, "request normalized", map[string]any{"modelType": modelType}, task.RunMode == "simulation"); err != nil { if err := s.emit(ctx, task.ID, "task.progress", "running", "normalizing", 0.15, "request normalized", map[string]any{"modelType": modelType}, task.RunMode == "simulation"); err != nil {
return Result{}, err return Result{}, err
} }
@ -96,7 +119,7 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
} }
maxPlatforms := maxPlatformsForCandidates(candidates, runnerPolicy) maxPlatforms := maxPlatformsForCandidates(candidates, runnerPolicy)
maxFailoverDuration := maxFailoverDurationForCandidates(candidates, runnerPolicy) maxFailoverDuration := maxFailoverDurationForCandidates(candidates, runnerPolicy)
attemptNo := 0 attemptNo := task.AttemptCount
var lastErr error var lastErr error
for index, candidate := range candidates { for index, candidate := range candidates {
if index >= maxPlatforms { if index >= maxPlatforms {
@ -251,6 +274,20 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
if lastErr != nil { if lastErr != nil {
message = lastErr.Error() message = lastErr.Error()
} }
if task.AsyncMode && ctx.Err() != nil {
queued, queueErr := s.requeueInterruptedAsyncTask(context.WithoutCancel(ctx), task)
if queueErr != nil {
return Result{}, queueErr
}
return Result{Task: queued, Output: queued.Result}, &TaskQueuedError{Delay: 0}
}
if task.AsyncMode && errors.Is(lastErr, store.ErrRateLimited) && store.RateLimitRetryable(lastErr) {
queued, delay, queueErr := s.requeueRateLimitedTask(ctx, task, lastErr)
if queueErr != nil {
return Result{}, queueErr
}
return Result{Task: queued, Output: queued.Result}, &TaskQueuedError{Delay: delay}
}
failed, err := s.failTask(ctx, task.ID, code, message, task.RunMode == "simulation", lastErr) failed, err := s.failTask(ctx, task.ID, code, message, task.RunMode == "simulation", lastErr)
if err != nil { if err != nil {
return Result{}, err return Result{}, err
@ -261,7 +298,7 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user *auth.User, body map[string]any, candidate store.RuntimeModelCandidate, attemptNo int, onDelta clients.StreamDelta) (clients.Response, error) { func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user *auth.User, body map[string]any, candidate store.RuntimeModelCandidate, attemptNo int, onDelta clients.StreamDelta) (clients.Response, error) {
simulated := isSimulation(task, candidate) simulated := isSimulation(task, candidate)
if err := s.emit(ctx, task.ID, "task.attempt.started", "running", "submitting", 0.25, "client attempt started", map[string]any{"attempt": attemptNo, "clientId": candidate.ClientID}, simulated); err != nil { if err := s.emit(ctx, task.ID, "task.attempt.started", "running", "submitting", 0.25, "client attempt started", map[string]any{"attempt": attemptNo, "clientId": candidate.ClientID}, simulated); err != nil {
return clients.Response{}, err return clients.Response{}, fmt.Errorf("emit attempt started: %w", err)
} }
attemptID, err := s.store.CreateTaskAttempt(ctx, store.CreateTaskAttemptInput{ attemptID, err := s.store.CreateTaskAttempt(ctx, store.CreateTaskAttemptInput{
TaskID: task.ID, TaskID: task.ID,
@ -276,21 +313,22 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user
Metrics: attemptMetrics(candidate, attemptNo, simulated), Metrics: attemptMetrics(candidate, attemptNo, simulated),
}) })
if err != nil { if err != nil {
return clients.Response{}, err return clients.Response{}, fmt.Errorf("create task attempt: %w", err)
} }
reservations := s.rateLimitReservations(ctx, user, candidate, body) reservations := s.rateLimitReservations(ctx, user, candidate, body)
limitResult, err := s.store.ReserveRateLimits(ctx, task.ID, attemptID, reservations) limitResult, err := s.store.ReserveRateLimits(ctx, task.ID, attemptID, reservations)
if err != nil { if err != nil {
clientErr := &clients.ClientError{Code: "rate_limit", Message: err.Error(), Retryable: false} retryable := store.RateLimitRetryable(err)
clientErr := &clients.ClientError{Code: "rate_limit", Message: err.Error(), Retryable: retryable}
_ = s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{ _ = s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{
AttemptID: attemptID, AttemptID: attemptID,
Status: "failed", Status: "failed",
Retryable: false, Retryable: retryable,
Metrics: mergeMetrics(attemptMetrics(candidate, attemptNo, simulated), map[string]any{"error": err.Error(), "retryable": false, "trace": []any{failureTraceEntry(clientErr, false)}}), Metrics: mergeMetrics(attemptMetrics(candidate, attemptNo, simulated), map[string]any{"error": err.Error(), "retryable": retryable, "retryAfterMs": localRateLimitRetryAfter(err).Milliseconds(), "trace": []any{failureTraceEntry(clientErr, retryable)}}),
ErrorCode: "rate_limit", ErrorCode: "rate_limit",
ErrorMessage: err.Error(), ErrorMessage: err.Error(),
}) })
return clients.Response{}, clientErr return clients.Response{}, &localRateLimitError{clientErr: clientErr, cause: err, retryAfter: localRateLimitRetryAfter(err)}
} }
rateReservationsFinalized := false rateReservationsFinalized := false
defer func() { defer func() {
@ -301,7 +339,7 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user
defer s.store.ReleaseConcurrencyLeases(context.WithoutCancel(ctx), limitResult.LeaseIDs) defer s.store.ReleaseConcurrencyLeases(context.WithoutCancel(ctx), limitResult.LeaseIDs)
if err := s.store.RecordClientAssignment(ctx, candidate); err != nil { if err := s.store.RecordClientAssignment(ctx, candidate); err != nil {
return clients.Response{}, err return clients.Response{}, fmt.Errorf("record client assignment: %w", err)
} }
defer s.store.RecordClientRelease(context.WithoutCancel(ctx), candidate.ClientID, "") defer s.store.RecordClientRelease(context.WithoutCancel(ctx), candidate.ClientID, "")
@ -315,17 +353,25 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user
ErrorCode: clients.ErrorCode(err), ErrorCode: clients.ErrorCode(err),
ErrorMessage: err.Error(), ErrorMessage: err.Error(),
}) })
return clients.Response{}, err return clients.Response{}, fmt.Errorf("prepare http client: %w", err)
} }
client := s.clientFor(candidate, simulated) client := s.clientFor(candidate, simulated)
callStartedAt := time.Now() callStartedAt := time.Now()
response, err := client.Run(ctx, clients.Request{ response, err := client.Run(ctx, clients.Request{
Kind: task.Kind, Kind: task.Kind,
ModelType: candidate.ModelType, ModelType: candidate.ModelType,
Model: task.Model, Model: task.Model,
Body: body, Body: body,
Candidate: candidate, Candidate: candidate,
HTTPClient: requestHTTPClient, HTTPClient: requestHTTPClient,
RemoteTaskID: task.RemoteTaskID,
RemoteTaskPayload: task.RemoteTaskPayload,
OnRemoteTaskSubmitted: func(remoteTaskID string, payload map[string]any) error {
if strings.TrimSpace(remoteTaskID) == "" {
return nil
}
return s.store.SetTaskRemoteTask(context.WithoutCancel(ctx), task.ID, attemptID, remoteTaskID, payload)
},
Stream: boolFromMap(body, "stream"), Stream: boolFromMap(body, "stream"),
StreamDelta: onDelta, StreamDelta: onDelta,
}) })
@ -400,11 +446,11 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user
response.Result = uploadedResult response.Result = uploadedResult
for _, progress := range response.Progress { for _, progress := range response.Progress {
if err := s.emit(ctx, task.ID, "task.progress", "running", progress.Phase, progress.Progress, progress.Message, progress.Payload, simulated); err != nil { if err := s.emit(ctx, task.ID, "task.progress", "running", progress.Phase, progress.Progress, progress.Message, progress.Payload, simulated); err != nil {
return clients.Response{}, err return clients.Response{}, fmt.Errorf("emit task progress: %w", err)
} }
} }
if err := s.store.CommitRateLimitReservations(ctx, limitResult.Reservations, tokenUsageAmounts(response.Usage)); err != nil { if err := s.store.CommitRateLimitReservations(ctx, limitResult.Reservations, tokenUsageAmounts(response.Usage)); err != nil {
return clients.Response{}, err return clients.Response{}, fmt.Errorf("commit rate limit reservations: %w", err)
} }
rateReservationsFinalized = true rateReservationsFinalized = true
if err := s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{ if err := s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{
@ -418,7 +464,7 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user
ResponseFinishedAt: response.ResponseFinishedAt, ResponseFinishedAt: response.ResponseFinishedAt,
ResponseDurationMS: response.ResponseDurationMS, ResponseDurationMS: response.ResponseDurationMS,
}); err != nil { }); err != nil {
return clients.Response{}, err return clients.Response{}, fmt.Errorf("finish task attempt: %w", err)
} }
return response, nil return response, nil
} }
@ -459,6 +505,41 @@ func (s *Service) failTask(ctx context.Context, taskID string, code string, mess
return failed, nil return failed, nil
} }
// requeueRateLimitedTask pushes a task back onto the queue after a local
// rate-limit rejection and emits a "task.queued" event describing the delay.
// It returns the requeued task and the backoff that was applied.
func (s *Service) requeueRateLimitedTask(ctx context.Context, task store.GatewayTask, cause error) (store.GatewayTask, time.Duration, error) {
	// Derive the backoff from the rate-limit error; fall back to 5s when the
	// error carries no usable retry-after hint.
	backoff := localRateLimitRetryAfter(cause)
	if backoff <= 0 {
		backoff = 5 * time.Second
	}
	requeued, err := s.store.RequeueTask(ctx, task.ID, backoff)
	if err != nil {
		return store.GatewayTask{}, 0, err
	}
	eventPayload := map[string]any{
		"code":         "rate_limit",
		"message":      cause.Error(),
		"retryAfterMs": backoff.Milliseconds(),
	}
	simulated := task.RunMode == "simulation"
	if eventErr := s.emit(ctx, task.ID, "task.queued", "queued", "rate_limited", 0.2, "task queued by local rate limit", eventPayload, simulated); eventErr != nil {
		return store.GatewayTask{}, 0, eventErr
	}
	return requeued, backoff, nil
}
// requeueInterruptedAsyncTask returns an async task to the queue with no delay
// after a worker interruption and records a "task.queued" event for it.
func (s *Service) requeueInterruptedAsyncTask(ctx context.Context, task store.GatewayTask) (store.GatewayTask, error) {
	requeued, err := s.store.RequeueTask(ctx, task.ID, 0)
	if err != nil {
		return store.GatewayTask{}, err
	}
	eventPayload := map[string]any{"code": "worker_interrupted"}
	// Surface the upstream task id when one was already assigned so the
	// recovered run can be correlated with the remote provider job.
	if task.RemoteTaskID != "" {
		eventPayload["remoteTaskId"] = task.RemoteTaskID
	}
	simulated := task.RunMode == "simulation"
	if eventErr := s.emit(ctx, task.ID, "task.queued", "queued", "worker_interrupted", 0.2, "async task queued after worker interruption", eventPayload, simulated); eventErr != nil {
		return store.GatewayTask{}, eventErr
	}
	return requeued, nil
}
func (s *Service) withAttemptHistory(ctx context.Context, taskID string, metrics map[string]any) map[string]any { func (s *Service) withAttemptHistory(ctx context.Context, taskID string, metrics map[string]any) map[string]any {
attempts, err := s.store.ListTaskAttempts(ctx, taskID) attempts, err := s.store.ListTaskAttempts(ctx, taskID)
if err != nil { if err != nil {

View File

@ -19,7 +19,7 @@ SELECT p.id::text, p.platform_key, p.name, p.provider,
COALESCE(p.dynamic_priority, p.priority) AS effective_priority, COALESCE(p.dynamic_priority, p.priority) AS effective_priority,
m.id::text, COALESCE(m.base_model_id::text, ''), COALESCE(b.canonical_model_key, ''), m.id::text, COALESCE(m.base_model_id::text, ''), COALESCE(b.canonical_model_key, ''),
COALESCE(NULLIF(m.provider_model_name, ''), m.model_name), m.model_name, COALESCE(m.model_alias, ''), COALESCE(NULLIF(m.provider_model_name, ''), m.model_name), m.model_name, COALESCE(m.model_alias, ''),
$2 AS requested_model_type, m.display_name, m.capabilities, m.capability_override, $2::text AS requested_model_type, m.display_name, m.capabilities, m.capability_override,
COALESCE(b.base_billing_config, '{}'::jsonb), m.billing_config, m.billing_config_override, COALESCE(b.base_billing_config, '{}'::jsonb), m.billing_config, m.billing_config_override,
m.pricing_mode, COALESCE(m.discount_factor, 0)::float8, COALESCE(m.pricing_rule_set_id::text, ''), m.pricing_mode, COALESCE(m.discount_factor, 0)::float8, COALESCE(m.pricing_rule_set_id::text, ''),
COALESCE(b.pricing_rule_set_id::text, ''), COALESCE(b.pricing_rule_set_id::text, ''),
@ -33,21 +33,21 @@ LEFT JOIN model_catalog_providers cp ON cp.provider_key = p.provider OR cp.provi
LEFT JOIN base_model_catalog b ON b.id = m.base_model_id LEFT JOIN base_model_catalog b ON b.id = m.base_model_id
LEFT JOIN model_runtime_policy_sets rp ON rp.id = COALESCE(m.runtime_policy_set_id, b.runtime_policy_set_id) LEFT JOIN model_runtime_policy_sets rp ON rp.id = COALESCE(m.runtime_policy_set_id, b.runtime_policy_set_id)
LEFT JOIN runtime_client_states s LEFT JOIN runtime_client_states s
ON s.client_id = p.platform_key || ':' || $2 || ':' || COALESCE(NULLIF(m.provider_model_name, ''), m.model_name) ON s.client_id = p.platform_key || ':' || $2::text || ':' || COALESCE(NULLIF(m.provider_model_name, ''), m.model_name)
WHERE p.status = 'enabled' WHERE p.status = 'enabled'
AND p.deleted_at IS NULL AND p.deleted_at IS NULL
AND m.enabled = true AND m.enabled = true
AND m.model_type @> jsonb_build_array($2) AND m.model_type @> jsonb_build_array($2::text)
AND (p.cooldown_until IS NULL OR p.cooldown_until <= now()) AND (p.cooldown_until IS NULL OR p.cooldown_until <= now())
AND (m.cooldown_until IS NULL OR m.cooldown_until <= now()) AND (m.cooldown_until IS NULL OR m.cooldown_until <= now())
AND ( AND (
(COALESCE(m.model_alias, '') <> '' AND m.model_alias = $1) (COALESCE(m.model_alias, '') <> '' AND m.model_alias = $1::text)
OR ( OR (
COALESCE(m.model_alias, '') = '' COALESCE(m.model_alias, '') = ''
AND ( AND (
m.model_name = $1 m.model_name = $1::text
OR b.canonical_model_key = $1 OR b.canonical_model_key = $1::text
OR b.provider_model_name = $1 OR b.provider_model_name = $1::text
) )
) )
) )
@ -184,15 +184,15 @@ LEFT JOIN base_model_catalog b ON b.id = m.base_model_id
WHERE p.status = 'enabled' WHERE p.status = 'enabled'
AND p.deleted_at IS NULL AND p.deleted_at IS NULL
AND m.enabled = true AND m.enabled = true
AND m.model_type @> jsonb_build_array($2) AND m.model_type @> jsonb_build_array($2::text)
AND ( AND (
(COALESCE(m.model_alias, '') <> '' AND m.model_alias = $1) (COALESCE(m.model_alias, '') <> '' AND m.model_alias = $1::text)
OR ( OR (
COALESCE(m.model_alias, '') = '' COALESCE(m.model_alias, '') = ''
AND ( AND (
m.model_name = $1 m.model_name = $1::text
OR b.canonical_model_key = $1 OR b.canonical_model_key = $1::text
OR b.provider_model_name = $1 OR b.provider_model_name = $1::text
) )
) )
) )

View File

@ -53,6 +53,10 @@ func (s *Store) Ping(ctx context.Context) error {
return s.pool.Ping(ctx) return s.pool.Ping(ctx)
} }
// Pool exposes the store's underlying pgx connection pool so external
// components can share its database connections — presumably for the
// river queue driver added in this change; confirm against callers.
func (s *Store) Pool() *pgxpool.Pool {
	return s.pool
}
type Platform struct { type Platform struct {
ID string `json:"id"` ID string `json:"id"`
Provider string `json:"provider"` Provider string `json:"provider"`
@ -374,6 +378,7 @@ type CreateTaskInput struct {
Kind string `json:"kind"` Kind string `json:"kind"`
Model string `json:"model"` Model string `json:"model"`
RunMode string `json:"runMode"` RunMode string `json:"runMode"`
Async bool `json:"async"`
Request map[string]any `json:"request"` Request map[string]any `json:"request"`
} }
@ -398,7 +403,12 @@ type GatewayTask struct {
ResolvedModel string `json:"resolvedModel,omitempty"` ResolvedModel string `json:"resolvedModel,omitempty"`
RequestID string `json:"requestId,omitempty"` RequestID string `json:"requestId,omitempty"`
Request map[string]any `json:"request,omitempty"` Request map[string]any `json:"request,omitempty"`
AsyncMode bool `json:"asyncMode"`
RiverJobID int64 `json:"riverJobId,omitempty"`
Status string `json:"status"` Status string `json:"status"`
AttemptCount int `json:"attemptCount"`
RemoteTaskID string `json:"remoteTaskId,omitempty"`
RemoteTaskPayload map[string]any `json:"remoteTaskPayload,omitempty"`
Result map[string]any `json:"result,omitempty"` Result map[string]any `json:"result,omitempty"`
Billings []any `json:"billings,omitempty"` Billings []any `json:"billings,omitempty"`
Usage map[string]any `json:"usage"` Usage map[string]any `json:"usage"`
@ -423,7 +433,9 @@ COALESCE(gateway_tenant_id::text, ''), COALESCE(tenant_id, ''), COALESCE(tenant_
COALESCE(api_key_id, ''), COALESCE(api_key_name, ''), COALESCE(api_key_prefix, ''), COALESCE(api_key_id, ''), COALESCE(api_key_name, ''), COALESCE(api_key_prefix, ''),
COALESCE(user_group_id::text, ''), COALESCE(user_group_key, ''), model, COALESCE(user_group_id::text, ''), COALESCE(user_group_key, ''), model,
COALESCE(model_type, ''), COALESCE(requested_model, ''), COALESCE(resolved_model, ''), COALESCE(request_id, ''), COALESCE(model_type, ''), COALESCE(requested_model, ''), COALESCE(resolved_model, ''), COALESCE(request_id, ''),
request, status, COALESCE(result, '{}'::jsonb), COALESCE(billings, '[]'::jsonb), request, COALESCE(async_mode, false), COALESCE(river_job_id, 0), status, COALESCE(attempt_count, 0),
COALESCE(remote_task_id, ''), COALESCE(remote_task_payload, '{}'::jsonb),
COALESCE(result, '{}'::jsonb), COALESCE(billings, '[]'::jsonb),
COALESCE(usage, '{}'::jsonb), COALESCE(metrics, '{}'::jsonb), COALESCE(billing_summary, '{}'::jsonb), COALESCE(usage, '{}'::jsonb), COALESCE(metrics, '{}'::jsonb), COALESCE(billing_summary, '{}'::jsonb),
COALESCE(final_charge_amount, 0)::float8, COALESCE(response_started_at::text, ''), COALESCE(final_charge_amount, 0)::float8, COALESCE(response_started_at::text, ''),
COALESCE(response_finished_at::text, ''), COALESCE(response_duration_ms, 0), COALESCE(error, ''), COALESCE(response_finished_at::text, ''), COALESCE(response_duration_ms, 0), COALESCE(error, ''),
@ -1675,11 +1687,11 @@ func (s *Store) CreateTask(ctx context.Context, input CreateTaskInput, user *aut
INSERT INTO gateway_tasks ( INSERT INTO gateway_tasks (
kind, run_mode, user_id, gateway_user_id, user_source, gateway_tenant_id, tenant_id, tenant_key, kind, run_mode, user_id, gateway_user_id, user_source, gateway_tenant_id, tenant_id, tenant_key,
api_key_id, api_key_name, api_key_prefix, user_group_id, user_group_key, api_key_id, api_key_name, api_key_prefix, user_group_id, user_group_key,
model, requested_model, request, status, result, billings, finished_at model, requested_model, request, async_mode, status, result, billings, finished_at
) )
VALUES ($1, $2, $3, NULLIF($4, '')::uuid, COALESCE(NULLIF($5, ''), 'gateway'), NULLIF($6, '')::uuid, NULLIF($7, ''), NULLIF($8, ''), NULLIF($9, ''), NULLIF($10, ''), NULLIF($11, ''), NULLIF($12, '')::uuid, NULLIF($13, ''), $14, $14, $15, $16, $17::jsonb, $18::jsonb, CASE WHEN $19 THEN now() ELSE NULL END) VALUES ($1, $2, $3, NULLIF($4, '')::uuid, COALESCE(NULLIF($5, ''), 'gateway'), NULLIF($6, '')::uuid, NULLIF($7, ''), NULLIF($8, ''), NULLIF($9, ''), NULLIF($10, ''), NULLIF($11, ''), NULLIF($12, '')::uuid, NULLIF($13, ''), $14, $14, $15, $16, $17, $18::jsonb, $19::jsonb, CASE WHEN $20 THEN now() ELSE NULL END)
RETURNING `+gatewayTaskColumns, RETURNING `+gatewayTaskColumns,
input.Kind, runMode, user.ID, user.GatewayUserID, user.Source, user.GatewayTenantID, user.TenantID, user.TenantKey, user.APIKeyID, user.APIKeyName, user.APIKeyPrefix, user.UserGroupID, user.UserGroupKey, input.Model, requestBody, status, resultBody, billingsBody, false, input.Kind, runMode, user.ID, user.GatewayUserID, user.Source, user.GatewayTenantID, user.TenantID, user.TenantKey, user.APIKeyID, user.APIKeyName, user.APIKeyPrefix, user.UserGroupID, user.UserGroupKey, input.Model, requestBody, input.Async, status, resultBody, billingsBody, false,
)) ))
if err != nil { if err != nil {
return GatewayTask{}, err return GatewayTask{}, err
@ -1689,7 +1701,7 @@ func (s *Store) CreateTask(ctx context.Context, input CreateTaskInput, user *aut
payload, _ := json.Marshal(event.Payload) payload, _ := json.Marshal(event.Payload)
if _, err := tx.Exec(ctx, ` if _, err := tx.Exec(ctx, `
INSERT INTO gateway_task_events (task_id, seq, event_type, status, phase, progress, message, payload, simulated) INSERT INTO gateway_task_events (task_id, seq, event_type, status, phase, progress, message, payload, simulated)
VALUES ($1::uuid, $2, $3, NULLIF($4, ''), NULLIF($5, ''), $6, NULLIF($7, ''), $8::jsonb, $9)`, VALUES ($1::uuid, $2, $3::text, NULLIF($4::text, ''), NULLIF($5::text, ''), $6, NULLIF($7::text, ''), $8::jsonb, $9)`,
task.ID, event.Seq, event.EventType, event.Status, event.Phase, event.Progress, event.Message, string(payload), event.Simulated, task.ID, event.Seq, event.EventType, event.Status, event.Phase, event.Progress, event.Message, string(payload), event.Simulated,
); err != nil { ); err != nil {
return GatewayTask{}, err return GatewayTask{}, err
@ -1730,6 +1742,7 @@ func scanGatewayTask(scanner taskScanner) (GatewayTask, error) {
var usageBytes []byte var usageBytes []byte
var metricsBytes []byte var metricsBytes []byte
var billingSummaryBytes []byte var billingSummaryBytes []byte
var remoteTaskPayloadBytes []byte
if err := scanner.Scan( if err := scanner.Scan(
&task.ID, &task.ID,
&task.Kind, &task.Kind,
@ -1751,7 +1764,12 @@ func scanGatewayTask(scanner taskScanner) (GatewayTask, error) {
&task.ResolvedModel, &task.ResolvedModel,
&task.RequestID, &task.RequestID,
&requestBytes, &requestBytes,
&task.AsyncMode,
&task.RiverJobID,
&task.Status, &task.Status,
&task.AttemptCount,
&task.RemoteTaskID,
&remoteTaskPayloadBytes,
&resultBytes, &resultBytes,
&billingsBytes, &billingsBytes,
&usageBytes, &usageBytes,
@ -1771,6 +1789,7 @@ func scanGatewayTask(scanner taskScanner) (GatewayTask, error) {
return GatewayTask{}, err return GatewayTask{}, err
} }
task.Request = decodeObject(requestBytes) task.Request = decodeObject(requestBytes)
task.RemoteTaskPayload = decodeObject(remoteTaskPayloadBytes)
task.Result = decodeObject(resultBytes) task.Result = decodeObject(resultBytes)
task.Billings = decodeArray(billingsBytes) task.Billings = decodeArray(billingsBytes)
task.Usage = decodeObject(usageBytes) task.Usage = decodeObject(usageBytes)

View File

@ -31,6 +31,7 @@ type ModelRateLimitStatus struct {
PlatformCooldownUntil string `json:"platformCooldownUntil,omitempty"` PlatformCooldownUntil string `json:"platformCooldownUntil,omitempty"`
ModelCooldownUntil string `json:"modelCooldownUntil,omitempty"` ModelCooldownUntil string `json:"modelCooldownUntil,omitempty"`
Concurrent RateLimitMetricStatus `json:"concurrent"` Concurrent RateLimitMetricStatus `json:"concurrent"`
QueuedTasks float64 `json:"queuedTasks"`
RPM RateLimitMetricStatus `json:"rpm"` RPM RateLimitMetricStatus `json:"rpm"`
TPM RateLimitMetricStatus `json:"tpm"` TPM RateLimitMetricStatus `json:"tpm"`
LoadRatio float64 `json:"loadRatio"` LoadRatio float64 `json:"loadRatio"`
@ -38,15 +39,16 @@ type ModelRateLimitStatus struct {
func (s *Store) ListModelRateLimitStatuses(ctx context.Context) ([]ModelRateLimitStatus, error) { func (s *Store) ListModelRateLimitStatuses(ctx context.Context) ([]ModelRateLimitStatus, error) {
rows, err := s.pool.Query(ctx, ` rows, err := s.pool.Query(ctx, `
SELECT m.id::text, m.platform_id::text, p.name, p.provider, SELECT m.id::text, m.platform_id::text, p.name, p.provider,
m.model_name, COALESCE(NULLIF(m.provider_model_name, ''), m.model_name), COALESCE(m.model_alias, ''), m.model_name, COALESCE(NULLIF(m.provider_model_name, ''), m.model_name), COALESCE(m.model_alias, ''),
m.model_type, m.display_name, m.enabled, m.model_type, m.display_name, m.enabled,
p.rate_limit_policy, COALESCE(rp.rate_limit_policy, '{}'::jsonb), COALESCE(NULLIF(m.runtime_policy_override, '{}'::jsonb), b.runtime_policy_override, '{}'::jsonb), m.rate_limit_policy, p.rate_limit_policy, COALESCE(rp.rate_limit_policy, '{}'::jsonb), COALESCE(NULLIF(m.runtime_policy_override, '{}'::jsonb), b.runtime_policy_override, '{}'::jsonb), m.rate_limit_policy,
COALESCE(to_char(p.cooldown_until AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.MS"Z"'), ''), COALESCE(to_char(p.cooldown_until AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.MS"Z"'), ''),
COALESCE(to_char(m.cooldown_until AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.MS"Z"'), ''), COALESCE(to_char(m.cooldown_until AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.MS"Z"'), ''),
COALESCE(con.active, 0)::float8, COALESCE(con.active, 0)::float8,
COALESCE(rpm.used_value, 0)::float8, COALESCE(rpm.reserved_value, 0)::float8, COALESCE(rpm.reset_at::text, ''), COALESCE(queued.waiting, 0)::float8,
COALESCE(tpm.used_value, 0)::float8, COALESCE(tpm.reserved_value, 0)::float8, COALESCE(tpm.reset_at::text, '') COALESCE(rpm.used_value, 0)::float8, COALESCE(rpm.reserved_value, 0)::float8, COALESCE(rpm.reset_at::text, ''),
COALESCE(tpm.used_value, 0)::float8, COALESCE(tpm.reserved_value, 0)::float8, COALESCE(tpm.reset_at::text, '')
FROM platform_models m FROM platform_models m
JOIN integration_platforms p ON p.id = m.platform_id JOIN integration_platforms p ON p.id = m.platform_id
LEFT JOIN base_model_catalog b ON b.id = m.base_model_id LEFT JOIN base_model_catalog b ON b.id = m.base_model_id
@ -59,6 +61,19 @@ LEFT JOIN (
AND expires_at > now() AND expires_at > now()
GROUP BY scope_key GROUP BY scope_key
) con ON con.scope_key = m.id::text ) con ON con.scope_key = m.id::text
LEFT JOIN (
SELECT latest.platform_model_id, COUNT(*) AS waiting
FROM (
SELECT DISTINCT ON (a.task_id) a.task_id, a.platform_model_id::text AS platform_model_id
FROM gateway_tasks t
JOIN gateway_task_attempts a ON a.task_id = t.id
WHERE t.async_mode = true
AND t.status = 'queued'
AND a.platform_model_id IS NOT NULL
ORDER BY a.task_id, a.attempt_no DESC, a.started_at DESC
) latest
GROUP BY latest.platform_model_id
) queued ON queued.platform_model_id = m.id::text
LEFT JOIN ( LEFT JOIN (
SELECT DISTINCT ON (scope_key) scope_key, used_value, reserved_value, reset_at SELECT DISTINCT ON (scope_key) scope_key, used_value, reserved_value, reset_at
FROM gateway_rate_limit_counters FROM gateway_rate_limit_counters
@ -93,6 +108,7 @@ ORDER BY p.priority ASC, m.model_name ASC`)
var platformCooldownUntil string var platformCooldownUntil string
var modelCooldownUntil string var modelCooldownUntil string
var concurrentCurrent float64 var concurrentCurrent float64
var queuedTasks float64
var rpmUsed float64 var rpmUsed float64
var rpmReserved float64 var rpmReserved float64
var rpmResetAt string var rpmResetAt string
@ -117,6 +133,7 @@ ORDER BY p.priority ASC, m.model_name ASC`)
&platformCooldownUntil, &platformCooldownUntil,
&modelCooldownUntil, &modelCooldownUntil,
&concurrentCurrent, &concurrentCurrent,
&queuedTasks,
&rpmUsed, &rpmUsed,
&rpmReserved, &rpmReserved,
&rpmResetAt, &rpmResetAt,
@ -136,6 +153,7 @@ ORDER BY p.priority ASC, m.model_name ASC`)
item.PlatformCooldownUntil = platformCooldownUntil item.PlatformCooldownUntil = platformCooldownUntil
item.ModelCooldownUntil = modelCooldownUntil item.ModelCooldownUntil = modelCooldownUntil
item.RateLimitPolicy = policy item.RateLimitPolicy = policy
item.QueuedTasks = queuedTasks
item.Concurrent = metricStatus(concurrentCurrent, concurrentCurrent, 0, rateLimitForMetric(policy, "concurrent"), "") item.Concurrent = metricStatus(concurrentCurrent, concurrentCurrent, 0, rateLimitForMetric(policy, "concurrent"), "")
item.RPM = metricStatus(rpmUsed+rpmReserved, rpmUsed, rpmReserved, rateLimitForMetric(policy, "rpm"), rpmResetAt) item.RPM = metricStatus(rpmUsed+rpmReserved, rpmUsed, rpmReserved, rateLimitForMetric(policy, "rpm"), rpmResetAt)
item.TPM = metricStatus(tpmUsed+tpmReserved, tpmUsed, tpmReserved, tpmLimit(policy), tpmResetAt) item.TPM = metricStatus(tpmUsed+tpmReserved, tpmUsed, tpmReserved, tpmLimit(policy), tpmResetAt)

View File

@ -3,6 +3,7 @@ package store
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"time" "time"
"github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5"
@ -13,6 +14,7 @@ type RuntimeRecoveryResult struct {
ReleasedRateReservations int64 `json:"releasedRateReservations"` ReleasedRateReservations int64 `json:"releasedRateReservations"`
FailedAttempts int64 `json:"failedAttempts"` FailedAttempts int64 `json:"failedAttempts"`
FailedTasks int64 `json:"failedTasks"` FailedTasks int64 `json:"failedTasks"`
RequeuedAsyncTasks int64 `json:"requeuedAsyncTasks"`
} }
func (s *Store) ReserveRateLimits(ctx context.Context, taskID string, attemptID string, reservations []RateLimitReservation) (RateLimitResult, error) { func (s *Store) ReserveRateLimits(ctx context.Context, taskID string, attemptID string, reservations []RateLimitReservation) (RateLimitResult, error) {
@ -28,7 +30,11 @@ func (s *Store) ReserveRateLimits(ctx context.Context, taskID string, attemptID
continue continue
} }
if reservation.Metric == "" || reservation.Amount > reservation.Limit { if reservation.Metric == "" || reservation.Amount > reservation.Limit {
return RateLimitResult{}, ErrRateLimited return RateLimitResult{}, &RateLimitExceededError{
Metric: reservation.Metric,
Message: fmt.Sprintf("rate limit exceeded: %s request amount %.0f is greater than limit %.0f", reservation.Metric, reservation.Amount, reservation.Limit),
Retryable: false,
}
} }
if reservation.WindowSeconds <= 0 { if reservation.WindowSeconds <= 0 {
reservation.WindowSeconds = 60 reservation.WindowSeconds = 60
@ -55,8 +61,10 @@ func reserveConcurrencyLease(ctx context.Context, tx pgx.Tx, taskID string, atte
reservation.LeaseTTLSeconds = 120 reservation.LeaseTTLSeconds = 120
} }
var active float64 var active float64
var nextAvailableAt time.Time
if err := tx.QueryRow(ctx, ` if err := tx.QueryRow(ctx, `
SELECT COALESCE(SUM(lease_value), 0)::float8 SELECT COALESCE(SUM(lease_value), 0)::float8,
COALESCE(MIN(expires_at), now() + ($3::int * interval '1 second'))
FROM gateway_concurrency_leases FROM gateway_concurrency_leases
WHERE scope_type = $1 WHERE scope_type = $1
AND scope_key = $2 AND scope_key = $2
@ -64,11 +72,17 @@ WHERE scope_type = $1
AND expires_at > now()`, AND expires_at > now()`,
reservation.ScopeType, reservation.ScopeType,
reservation.ScopeKey, reservation.ScopeKey,
).Scan(&active); err != nil { reservation.LeaseTTLSeconds,
).Scan(&active, &nextAvailableAt); err != nil {
return "", err return "", err
} }
if active+reservation.Amount > reservation.Limit { if active+reservation.Amount > reservation.Limit {
return "", ErrRateLimited return "", &RateLimitExceededError{
Metric: reservation.Metric,
Message: fmt.Sprintf("rate limit exceeded: concurrent active %.0f plus request %.0f is greater than limit %.0f", active, reservation.Amount, reservation.Limit),
RetryAfter: concurrencyRetryAfter(nextAvailableAt),
Retryable: true,
}
} }
var leaseID string var leaseID string
if err := tx.QueryRow(ctx, ` if err := tx.QueryRow(ctx, `
@ -92,13 +106,16 @@ func reserveCounterWindow(ctx context.Context, tx pgx.Tx, taskID string, attempt
reservedAmount := reservation.Amount reservedAmount := reservation.Amount
var windowStart time.Time var windowStart time.Time
err := tx.QueryRow(ctx, ` err := tx.QueryRow(ctx, `
WITH bounds AS (
SELECT
to_timestamp(floor(extract(epoch FROM now()) / $7::int) * $7::int) AS window_start,
to_timestamp(floor(extract(epoch FROM now()) / $7::int) * $7::int) + ($7::int * interval '1 second') AS reset_at
)
INSERT INTO gateway_rate_limit_counters ( INSERT INTO gateway_rate_limit_counters (
scope_type, scope_key, metric, window_start, limit_value, used_value, reserved_value, reset_at scope_type, scope_key, metric, window_start, limit_value, used_value, reserved_value, reset_at
) )
VALUES ( SELECT $1, $2, $3, bounds.window_start, $4, $5, $6, bounds.reset_at
$1, $2, $3, date_trunc('minute', now()), $4, $5, $6, FROM bounds
date_trunc('minute', now()) + ($7::int * interval '1 second')
)
ON CONFLICT (scope_type, scope_key, metric, window_start) DO UPDATE ON CONFLICT (scope_type, scope_key, metric, window_start) DO UPDATE
SET limit_value = EXCLUDED.limit_value, SET limit_value = EXCLUDED.limit_value,
used_value = gateway_rate_limit_counters.used_value + EXCLUDED.used_value, used_value = gateway_rate_limit_counters.used_value + EXCLUDED.used_value,
@ -117,7 +134,28 @@ RETURNING window_start`,
).Scan(&windowStart) ).Scan(&windowStart)
if err != nil { if err != nil {
if errors.Is(err, pgx.ErrNoRows) { if errors.Is(err, pgx.ErrNoRows) {
return RateLimitReservation{}, ErrRateLimited resetAt := time.Now().Add(time.Duration(reservation.WindowSeconds) * time.Second)
_ = tx.QueryRow(ctx, `
WITH bounds AS (
SELECT to_timestamp(floor(extract(epoch FROM now()) / $4::int) * $4::int) AS window_start
)
SELECT counters.reset_at
FROM gateway_rate_limit_counters counters
JOIN bounds ON counters.window_start = bounds.window_start
WHERE scope_type = $1
AND scope_key = $2
AND metric = $3`,
reservation.ScopeType,
reservation.ScopeKey,
reservation.Metric,
reservation.WindowSeconds,
).Scan(&resetAt)
return RateLimitReservation{}, &RateLimitExceededError{
Metric: reservation.Metric,
Message: fmt.Sprintf("rate limit exceeded: %s window has no remaining capacity", reservation.Metric),
RetryAfter: retryAfterUntil(resetAt),
Retryable: true,
}
} }
return RateLimitReservation{}, err return RateLimitReservation{}, err
} }
@ -144,6 +182,28 @@ RETURNING id::text`,
return reservation, nil return reservation, nil
} }
func retryAfterUntil(when time.Time) time.Duration {
if when.IsZero() {
return 0
}
duration := time.Until(when)
if duration < time.Second {
return time.Second
}
return duration
}
func concurrencyRetryAfter(leaseExpiresAt time.Time) time.Duration {
if leaseExpiresAt.IsZero() {
return time.Second
}
duration := time.Until(leaseExpiresAt)
if duration <= time.Second {
return time.Second
}
return time.Second
}
func (s *Store) CommitRateLimitReservations(ctx context.Context, reservations []RateLimitReservation, actualByMetric map[string]float64) error { func (s *Store) CommitRateLimitReservations(ctx context.Context, reservations []RateLimitReservation, actualByMetric map[string]float64) error {
return s.finishRateLimitReservations(ctx, reservations, actualByMetric, "committed", "success") return s.finishRateLimitReservations(ctx, reservations, actualByMetric, "committed", "success")
} }
@ -189,23 +249,26 @@ RETURNING scope_type, scope_key, metric, window_start, reserved_amount::float8`)
if err != nil { if err != nil {
return RuntimeRecoveryResult{}, err return RuntimeRecoveryResult{}, err
} }
reservations := make([]RateLimitReservation, 0)
for rows.Next() { for rows.Next() {
var reservation RateLimitReservation var reservation RateLimitReservation
if err := rows.Scan(&reservation.ScopeType, &reservation.ScopeKey, &reservation.Metric, &reservation.WindowStart, &reservation.Amount); err != nil { if err := rows.Scan(&reservation.ScopeType, &reservation.ScopeKey, &reservation.Metric, &reservation.WindowStart, &reservation.Amount); err != nil {
rows.Close() rows.Close()
return RuntimeRecoveryResult{}, err return RuntimeRecoveryResult{}, err
} }
if err := releaseCounterReservation(ctx, tx, reservation.ScopeType, reservation.ScopeKey, reservation.Metric, reservation.WindowStart, reservation.Amount); err != nil { reservations = append(reservations, reservation)
rows.Close()
return RuntimeRecoveryResult{}, err
}
result.ReleasedRateReservations++
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
rows.Close() rows.Close()
return RuntimeRecoveryResult{}, err return RuntimeRecoveryResult{}, err
} }
rows.Close() rows.Close()
for _, reservation := range reservations {
if err := releaseCounterReservation(ctx, tx, reservation.ScopeType, reservation.ScopeKey, reservation.Metric, reservation.WindowStart, reservation.Amount); err != nil {
return RuntimeRecoveryResult{}, err
}
}
result.ReleasedRateReservations = int64(len(reservations))
tag, err := tx.Exec(ctx, ` tag, err := tx.Exec(ctx, `
UPDATE gateway_concurrency_leases UPDATE gateway_concurrency_leases
@ -220,7 +283,7 @@ WHERE released_at IS NULL
tag, err = tx.Exec(ctx, ` tag, err = tx.Exec(ctx, `
UPDATE gateway_task_attempts UPDATE gateway_task_attempts
SET status = 'failed', SET status = 'failed',
retryable = false, retryable = true,
error_code = 'server_restarted', error_code = 'server_restarted',
error_message = 'attempt interrupted by service restart', error_message = 'attempt interrupted by service restart',
finished_at = now() finished_at = now()
@ -230,6 +293,57 @@ WHERE status = 'running'`)
} }
result.FailedAttempts = tag.RowsAffected() result.FailedAttempts = tag.RowsAffected()
asyncTaskRows, err := tx.Query(ctx, `
UPDATE gateway_tasks
SET status = 'queued',
error = NULL,
error_code = NULL,
error_message = NULL,
locked_by = NULL,
locked_at = NULL,
heartbeat_at = NULL,
next_run_at = now(),
finished_at = NULL,
updated_at = now()
WHERE async_mode = true
AND status = 'running'
RETURNING id::text`)
if err != nil {
return RuntimeRecoveryResult{}, err
}
asyncTaskIDs := make([]string, 0)
for asyncTaskRows.Next() {
var taskID string
if err := asyncTaskRows.Scan(&taskID); err != nil {
asyncTaskRows.Close()
return RuntimeRecoveryResult{}, err
}
asyncTaskIDs = append(asyncTaskIDs, taskID)
}
if err := asyncTaskRows.Err(); err != nil {
asyncTaskRows.Close()
return RuntimeRecoveryResult{}, err
}
asyncTaskRows.Close()
for _, taskID := range asyncTaskIDs {
if _, err := tx.Exec(ctx, `
INSERT INTO gateway_task_events (task_id, seq, event_type, status, phase, progress, message, payload, simulated)
VALUES (
$1::uuid,
COALESCE((SELECT MAX(seq) + 1 FROM gateway_task_events WHERE task_id = $1::uuid), 1),
'task.recovered',
'queued',
'recovered',
0.2,
'async task recovered after service restart',
'{"code":"server_restarted"}'::jsonb,
false
)`, taskID); err != nil {
return RuntimeRecoveryResult{}, err
}
}
result.RequeuedAsyncTasks = int64(len(asyncTaskIDs))
taskRows, err := tx.Query(ctx, ` taskRows, err := tx.Query(ctx, `
UPDATE gateway_tasks UPDATE gateway_tasks
SET status = 'failed', SET status = 'failed',
@ -238,7 +352,8 @@ SET status = 'failed',
error_message = 'task interrupted by service restart', error_message = 'task interrupted by service restart',
finished_at = now(), finished_at = now(),
updated_at = now() updated_at = now()
WHERE status IN ('queued', 'running') WHERE async_mode = false
AND status = 'running'
RETURNING id::text`) RETURNING id::text`)
if err != nil { if err != nil {
return RuntimeRecoveryResult{}, err return RuntimeRecoveryResult{}, err
@ -301,9 +416,9 @@ func (s *Store) finishRateLimitReservations(ctx context.Context, reservations []
var stored RateLimitReservation var stored RateLimitReservation
err := tx.QueryRow(ctx, ` err := tx.QueryRow(ctx, `
UPDATE gateway_rate_limit_reservations UPDATE gateway_rate_limit_reservations
SET status = $2, SET status = $2::text,
reason = NULLIF($3, ''), reason = NULLIF($3::text, ''),
actual_amount = CASE WHEN $2 = 'committed' THEN $4 ELSE actual_amount END, actual_amount = CASE WHEN $2::text = 'committed' THEN $4 ELSE actual_amount END,
finalized_at = now(), finalized_at = now(),
updated_at = now() updated_at = now()
WHERE id = $1::uuid WHERE id = $1::uuid

View File

@ -32,6 +32,43 @@ func ModelCandidateErrorCode(err error) string {
return "no_model_candidate" return "no_model_candidate"
} }
// RateLimitExceededError reports that a request was rejected by a rate
// limiter, carrying the tripped metric and retry guidance for callers.
// It wraps ErrRateLimited (see Unwrap), so errors.Is still matches.
type RateLimitExceededError struct {
	Metric     string        // name of the limited metric (e.g. "rpm", "tpm"); may be empty
	Message    string        // optional human-readable detail; preferred by Error() when set
	RetryAfter time.Duration // suggested wait before retrying; zero when unknown
	Retryable  bool          // whether the caller may retry the request at all
}
// Error returns the most specific description available: the explicit
// message first, then a metric-based sentence, finally the generic
// ErrRateLimited text.
func (e *RateLimitExceededError) Error() string {
	switch {
	case strings.TrimSpace(e.Message) != "":
		return e.Message
	case strings.TrimSpace(e.Metric) != "":
		return "rate limit exceeded for " + e.Metric
	default:
		return ErrRateLimited.Error()
	}
}
// Unwrap exposes ErrRateLimited so errors.Is(err, ErrRateLimited) matches
// any *RateLimitExceededError without callers needing errors.As.
func (e *RateLimitExceededError) Unwrap() error {
	return ErrRateLimited
}
// RateLimitRetryAfter extracts the suggested backoff from err.
// It returns zero when err is not a *RateLimitExceededError or carries no
// positive RetryAfter hint.
func RateLimitRetryAfter(err error) time.Duration {
	var rl *RateLimitExceededError
	if !errors.As(err, &rl) {
		return 0
	}
	if rl.RetryAfter <= 0 {
		return 0
	}
	return rl.RetryAfter
}
// RateLimitRetryable reports whether err represents a rate-limit rejection
// the caller may retry. A typed *RateLimitExceededError decides explicitly
// via its Retryable flag; a bare ErrRateLimited defaults to retryable.
func RateLimitRetryable(err error) bool {
	var rl *RateLimitExceededError
	if !errors.As(err, &rl) {
		return errors.Is(err, ErrRateLimited)
	}
	return rl.Retryable
}
type CreatePlatformModelInput struct { type CreatePlatformModelInput struct {
PlatformID string `json:"platformId"` PlatformID string `json:"platformId"`
BaseModelID string `json:"baseModelId"` BaseModelID string `json:"baseModelId"`
@ -132,6 +169,11 @@ type CreateTaskAttemptInput struct {
Metrics map[string]any Metrics map[string]any
} }
// AsyncTaskQueueItem identifies an async task awaiting (re)enqueue into the
// background job queue.
type AsyncTaskQueueItem struct {
	ID       string // gateway task id (UUID rendered as text)
	Priority int    // scheduling priority; lower values are claimed first (ORDER BY priority ASC)
}
type FinishTaskAttemptInput struct { type FinishTaskAttemptInput struct {
AttemptID string AttemptID string
Status string Status string

View File

@ -157,7 +157,7 @@ func (s *Store) MarkTaskRunning(ctx context.Context, taskID string, modelType st
_, err := s.pool.Exec(ctx, ` _, err := s.pool.Exec(ctx, `
UPDATE gateway_tasks UPDATE gateway_tasks
SET status = 'running', SET status = 'running',
model_type = NULLIF($2, ''), model_type = NULLIF($2::text, ''),
normalized_request = $3::jsonb, normalized_request = $3::jsonb,
locked_at = now(), locked_at = now(),
heartbeat_at = now(), heartbeat_at = now(),
@ -166,6 +166,124 @@ WHERE id = $1::uuid`, taskID, modelType, string(normalizedJSON))
return err return err
} }
// ClaimAsyncQueuedTask atomically claims one due async task for workerID:
// the oldest queued task with the lowest priority value whose next_run_at
// has passed is flipped to 'running' with fresh lock and heartbeat stamps.
// FOR UPDATE SKIP LOCKED lets concurrent workers claim disjoint tasks
// without blocking each other.
// NOTE(review): when no row qualifies, QueryRow yields pgx.ErrNoRows —
// confirm scanGatewayTask propagates that unchanged to callers.
func (s *Store) ClaimAsyncQueuedTask(ctx context.Context, workerID string) (GatewayTask, error) {
	return scanGatewayTask(s.pool.QueryRow(ctx, `
WITH picked AS (
    SELECT id AS task_id
    FROM gateway_tasks
    WHERE async_mode = true
    AND status = 'queued'
    AND next_run_at <= now()
    ORDER BY priority ASC, created_at ASC
    LIMIT 1
    FOR UPDATE SKIP LOCKED
)
UPDATE gateway_tasks t
SET status = 'running',
    locked_by = NULLIF($1::text, ''),
    locked_at = now(),
    heartbeat_at = now(),
    updated_at = now()
FROM picked
WHERE t.id = picked.task_id
RETURNING `+gatewayTaskColumns, workerID))
}
// RequeueTask returns a task to the 'queued' state after a transient
// failure: lock, heartbeat, and error fields are cleared and the next run is
// scheduled after delay. The delay is clamped to [1s, 10m] so retries
// neither spin hot nor stall indefinitely.
func (s *Store) RequeueTask(ctx context.Context, taskID string, delay time.Duration) (GatewayTask, error) {
	const (
		minDelay = time.Second
		maxDelay = 10 * time.Minute
	)
	runAt := time.Now().Add(min(max(delay, minDelay), maxDelay))
	return scanGatewayTask(s.pool.QueryRow(ctx, `
UPDATE gateway_tasks
SET status = 'queued',
    locked_by = NULL,
    locked_at = NULL,
    heartbeat_at = NULL,
    next_run_at = $2::timestamptz,
    error = NULL,
    error_code = NULL,
    error_message = NULL,
    updated_at = now()
WHERE id = $1::uuid
RETURNING `+gatewayTaskColumns, taskID, runAt))
}
// SetTaskRiverJobID links a gateway task to the River job executing it.
// Non-positive job IDs mean "no job" and are ignored without error.
func (s *Store) SetTaskRiverJobID(ctx context.Context, taskID string, riverJobID int64) error {
	if riverJobID < 1 {
		return nil
	}
	if _, err := s.pool.Exec(ctx, `
UPDATE gateway_tasks
SET river_job_id = $2,
    updated_at = now()
WHERE id = $1::uuid`, taskID, riverJobID); err != nil {
		return err
	}
	return nil
}
// SetTaskRemoteTask records the upstream provider's task identity (remote id
// plus raw payload) on the gateway task and, when attemptID is non-blank,
// mirrors it onto the matching attempt row, merging the payload under the
// 'remote_task_payload' key of response_snapshot. Both updates run inside a
// single transaction so the task and attempt views cannot diverge.
func (s *Store) SetTaskRemoteTask(ctx context.Context, taskID string, attemptID string, remoteTaskID string, payload map[string]any) error {
	// Marshal error deliberately ignored, matching the file-wide convention;
	// NOTE(review): assumes payload values are JSON-encodable — confirm.
	payloadJSON, _ := json.Marshal(emptyObjectIfNil(payload))
	return pgx.BeginFunc(ctx, s.pool, func(tx pgx.Tx) error {
		if _, err := tx.Exec(ctx, `
UPDATE gateway_tasks
SET remote_task_id = NULLIF($2::text, ''),
    remote_task_payload = $3::jsonb,
    updated_at = now()
WHERE id = $1::uuid`,
			taskID,
			remoteTaskID,
			string(payloadJSON),
		); err != nil {
			return err
		}
		// No attempt to annotate: the task-level update is enough.
		if strings.TrimSpace(attemptID) == "" {
			return nil
		}
		_, err := tx.Exec(ctx, `
UPDATE gateway_task_attempts
SET remote_task_id = NULLIF($2::text, ''),
    response_snapshot = COALESCE(response_snapshot, '{}'::jsonb) || jsonb_build_object('remote_task_payload', $3::jsonb)
WHERE id = $1::uuid`,
			attemptID,
			remoteTaskID,
			string(payloadJSON),
		)
		return err
	})
}
// ListRecoverableAsyncTasks returns up to limit async tasks still marked
// queued or running — i.e. tasks that must be re-submitted to the job queue
// after a restart — most urgent first (lowest priority value, then oldest).
// A non-positive limit falls back to 500. The result is never nil: an empty
// (but allocated) slice is returned when nothing is recoverable.
func (s *Store) ListRecoverableAsyncTasks(ctx context.Context, limit int) ([]AsyncTaskQueueItem, error) {
	const defaultLimit = 500
	if limit < 1 {
		limit = defaultLimit
	}
	rows, err := s.pool.Query(ctx, `
SELECT id::text, priority
FROM gateway_tasks
WHERE async_mode = true
AND status IN ('queued', 'running')
ORDER BY priority ASC, created_at ASC
LIMIT $1`, limit)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	items := make([]AsyncTaskQueueItem, 0)
	for rows.Next() {
		var next AsyncTaskQueueItem
		if scanErr := rows.Scan(&next.ID, &next.Priority); scanErr != nil {
			return nil, scanErr
		}
		items = append(items, next)
	}
	if rowsErr := rows.Err(); rowsErr != nil {
		return nil, rowsErr
	}
	return items, nil
}
func (s *Store) CreateTaskAttempt(ctx context.Context, input CreateTaskAttemptInput) (string, error) { func (s *Store) CreateTaskAttempt(ctx context.Context, input CreateTaskAttemptInput) (string, error) {
requestJSON, _ := json.Marshal(emptyObjectIfNil(input.RequestSnapshot)) requestJSON, _ := json.Marshal(emptyObjectIfNil(input.RequestSnapshot))
metricsJSON, _ := json.Marshal(emptyObjectIfNil(input.Metrics)) metricsJSON, _ := json.Marshal(emptyObjectIfNil(input.Metrics))
@ -182,7 +300,7 @@ INSERT INTO gateway_task_attempts (
status, simulated, request_snapshot, metrics status, simulated, request_snapshot, metrics
) )
VALUES ( VALUES (
$1::uuid, $2, NULLIF($3, '')::uuid, NULLIF($4, '')::uuid, NULLIF($5, ''), $6, $1::uuid, $2::int, NULLIF($3::text, '')::uuid, NULLIF($4::text, '')::uuid, NULLIF($5::text, ''), $6,
$7, $8, $9::jsonb, $10::jsonb $7, $8, $9::jsonb, $10::jsonb
) )
RETURNING id::text`, RETURNING id::text`,
@ -202,7 +320,7 @@ RETURNING id::text`,
} }
if _, err := tx.Exec(ctx, ` if _, err := tx.Exec(ctx, `
UPDATE gateway_tasks UPDATE gateway_tasks
SET attempt_count = GREATEST(attempt_count, $2), updated_at = now() SET attempt_count = GREATEST(attempt_count, $2::int), updated_at = now()
WHERE id = $1::uuid`, input.TaskID, input.AttemptNo); err != nil { WHERE id = $1::uuid`, input.TaskID, input.AttemptNo); err != nil {
return "", err return "", err
} }
@ -252,7 +370,7 @@ SET metrics = jsonb_set(
true true
) )
WHERE task_id = $1::uuid WHERE task_id = $1::uuid
AND attempt_no = $2`, taskID, attemptNo, string(entryJSON)) AND attempt_no = $2::int`, taskID, attemptNo, string(entryJSON))
return err return err
} }
@ -386,17 +504,17 @@ func (s *Store) FinishTaskAttempt(ctx context.Context, input FinishTaskAttemptIn
metricsJSON, _ := json.Marshal(emptyObjectIfNil(input.Metrics)) metricsJSON, _ := json.Marshal(emptyObjectIfNil(input.Metrics))
_, err := s.pool.Exec(ctx, ` _, err := s.pool.Exec(ctx, `
UPDATE gateway_task_attempts UPDATE gateway_task_attempts
SET status = $2, SET status = $2::text,
retryable = $3, retryable = $3,
request_id = NULLIF($4, ''), request_id = NULLIF($4::text, ''),
usage = $5::jsonb, usage = $5::jsonb,
metrics = $6::jsonb, metrics = $6::jsonb,
response_snapshot = $7::jsonb, response_snapshot = $7::jsonb,
response_started_at = $8::timestamptz, response_started_at = $8::timestamptz,
response_finished_at = $9::timestamptz, response_finished_at = $9::timestamptz,
response_duration_ms = $10, response_duration_ms = $10,
error_code = NULLIF($11, ''), error_code = NULLIF($11::text, ''),
error_message = NULLIF($12, ''), error_message = NULLIF($12::text, ''),
finished_at = now() finished_at = now()
WHERE id = $1::uuid`, WHERE id = $1::uuid`,
input.AttemptID, input.AttemptID,
@ -438,6 +556,9 @@ SET status = 'succeeded',
error = NULL, error = NULL,
error_code = NULL, error_code = NULL,
error_message = NULL, error_message = NULL,
locked_by = NULL,
locked_at = NULL,
heartbeat_at = NULL,
finished_at = now(), finished_at = now(),
updated_at = now() updated_at = now()
WHERE id = $1::uuid`, WHERE id = $1::uuid`,
@ -561,14 +682,17 @@ func (s *Store) FinishTaskFailure(ctx context.Context, input FinishTaskFailureIn
if _, err := s.pool.Exec(ctx, ` if _, err := s.pool.Exec(ctx, `
UPDATE gateway_tasks UPDATE gateway_tasks
SET status = 'failed', SET status = 'failed',
error = NULLIF($2, ''), error = NULLIF($2::text, ''),
error_code = NULLIF($3, ''), error_code = NULLIF($3::text, ''),
error_message = NULLIF($2, ''), error_message = NULLIF($2::text, ''),
request_id = NULLIF($4, ''), request_id = NULLIF($4::text, ''),
metrics = $5::jsonb, metrics = $5::jsonb,
response_started_at = $6::timestamptz, response_started_at = $6::timestamptz,
response_finished_at = $7::timestamptz, response_finished_at = $7::timestamptz,
response_duration_ms = $8, response_duration_ms = $8,
locked_by = NULL,
locked_at = NULL,
heartbeat_at = NULL,
finished_at = now(), finished_at = now(),
updated_at = now() updated_at = now()
WHERE id = $1::uuid`, WHERE id = $1::uuid`,
@ -604,7 +728,7 @@ WITH next_seq AS (
WHERE task_id = $1::uuid WHERE task_id = $1::uuid
) )
INSERT INTO gateway_task_events (task_id, seq, event_type, status, phase, progress, message, payload, simulated) INSERT INTO gateway_task_events (task_id, seq, event_type, status, phase, progress, message, payload, simulated)
SELECT $1::uuid, next_seq.seq, $2, NULLIF($3, ''), NULLIF($4, ''), $5, NULLIF($6, ''), $7::jsonb, $8 SELECT $1::uuid, next_seq.seq, $2::text, NULLIF($3::text, ''), NULLIF($4::text, ''), $5, NULLIF($6::text, ''), $7::jsonb, $8
FROM next_seq FROM next_seq
RETURNING id::text, task_id::text, seq, event_type, COALESCE(status, ''), COALESCE(phase, ''), RETURNING id::text, task_id::text, seq, event_type, COALESCE(status, ''), COALESCE(phase, ''),
COALESCE(progress, 0)::float8, COALESCE(message, ''), payload, simulated, created_at`, COALESCE(progress, 0)::float8, COALESCE(message, ''), payload, simulated, created_at`,
@ -688,7 +812,7 @@ func (s *Store) RecordClientRelease(ctx context.Context, clientID string, lastEr
_, err := s.pool.Exec(ctx, ` _, err := s.pool.Exec(ctx, `
UPDATE runtime_client_states UPDATE runtime_client_states
SET running_count = GREATEST(running_count - 1, 0), SET running_count = GREATEST(running_count - 1, 0),
last_error = NULLIF($2, ''), last_error = NULLIF($2::text, ''),
updated_at = now() updated_at = now()
WHERE client_id = $1`, clientID, lastError) WHERE client_id = $1`, clientID, lastError)
return err return err

View File

@ -0,0 +1,6 @@
-- Adds the per-task async execution flag; existing rows default to the
-- synchronous path (false).
ALTER TABLE IF EXISTS gateway_tasks
ADD COLUMN IF NOT EXISTS async_mode boolean NOT NULL DEFAULT false;

-- Partial index covering the async claim query: filter on
-- async_mode/status/next_run_at, order by priority, created_at.
-- Only async rows participate, keeping the index small.
CREATE INDEX IF NOT EXISTS idx_gateway_tasks_async_queue
ON gateway_tasks(async_mode, status, next_run_at, priority, created_at)
WHERE async_mode = true;

View File

@ -0,0 +1,19 @@
-- Links gateway tasks to their River job and to the upstream provider's
-- remote task (id + raw payload).
ALTER TABLE IF EXISTS gateway_tasks
ADD COLUMN IF NOT EXISTS river_job_id bigint,
ADD COLUMN IF NOT EXISTS remote_task_id text,
ADD COLUMN IF NOT EXISTS remote_task_payload jsonb;

-- Backfill before attaching the default so existing rows match new inserts.
UPDATE gateway_tasks
SET remote_task_payload = '{}'::jsonb
WHERE remote_task_payload IS NULL;

ALTER TABLE IF EXISTS gateway_tasks
ALTER COLUMN remote_task_payload SET DEFAULT '{}'::jsonb;

-- Partial index for reverse lookups from a River job to its task.
CREATE INDEX IF NOT EXISTS idx_gateway_tasks_river_job
ON gateway_tasks(river_job_id)
WHERE river_job_id IS NOT NULL;

-- Partial index backing the restart-recovery scan of unfinished async tasks.
CREATE INDEX IF NOT EXISTS idx_gateway_tasks_async_recover
ON gateway_tasks(async_mode, status, priority, created_at)
WHERE async_mode = true AND status IN ('queued', 'running');

View File

@ -46,7 +46,6 @@ import {
getHealth, getHealth,
getNetworkProxyConfig, getNetworkProxyConfig,
getRunnerPolicy, getRunnerPolicy,
getTask,
getWalletSummary, getWalletSummary,
listAccessRules, listAccessRules,
listAuditLogs, listAuditLogs,
@ -72,6 +71,7 @@ import {
listUserGroups, listUserGroups,
listUsers, listUsers,
loginLocalAccount, loginLocalAccount,
pollTaskUntilSettled,
registerLocalAccount, registerLocalAccount,
replacePlatformModels, replacePlatformModels,
setUserWalletBalance, setUserWalletBalance,
@ -788,7 +788,11 @@ export function App() {
setCoreMessage(''); setCoreMessage('');
try { try {
const response = await runTask(credential, taskForm); const response = await runTask(credential, taskForm);
const detail = await getTask(credential, response.task.id); const syncTask = (detail: GatewayTask) => {
setTaskResult(detail);
setTasks((current) => [detail, ...current.filter((item) => item.id !== detail.id)]);
};
const detail = await pollTaskUntilSettled(credential, response.task, { onUpdate: syncTask });
setTaskResult(detail); setTaskResult(detail);
setTasks((current) => [detail, ...current.filter((item) => item.id !== detail.id)]); setTasks((current) => [detail, ...current.filter((item) => item.id !== detail.id)]);
invalidateDataKeys('tasks', 'wallet', 'walletTransactions'); invalidateDataKeys('tasks', 'wallet', 'walletTransactions');

View File

@ -521,6 +521,7 @@ export async function createChatTask(
): Promise<{ task: GatewayTask; next: Record<string, string> }> { ): Promise<{ task: GatewayTask; next: Record<string, string> }> {
return request<{ task: GatewayTask; next: Record<string, string> }>('/api/v1/chat/completions', { return request<{ task: GatewayTask; next: Record<string, string> }>('/api/v1/chat/completions', {
body: input, body: input,
headers: { 'X-Async': 'true' },
method: 'POST', method: 'POST',
token, token,
}); });
@ -598,6 +599,7 @@ export async function createImageGenerationTask(
): Promise<{ task: GatewayTask; next: Record<string, string> }> { ): Promise<{ task: GatewayTask; next: Record<string, string> }> {
return request<{ task: GatewayTask; next: Record<string, string> }>('/api/v1/images/generations', { return request<{ task: GatewayTask; next: Record<string, string> }>('/api/v1/images/generations', {
body: input, body: input,
headers: { 'X-Async': 'true' },
method: 'POST', method: 'POST',
token, token,
}); });
@ -609,6 +611,7 @@ export async function createImageEditTask(
): Promise<{ task: GatewayTask; next: Record<string, string> }> { ): Promise<{ task: GatewayTask; next: Record<string, string> }> {
return request<{ task: GatewayTask; next: Record<string, string> }>('/api/v1/images/edits', { return request<{ task: GatewayTask; next: Record<string, string> }>('/api/v1/images/edits', {
body: input, body: input,
headers: { 'X-Async': 'true' },
method: 'POST', method: 'POST',
token, token,
}); });
@ -636,6 +639,7 @@ export async function createVideoGenerationTask(
): Promise<{ task: GatewayTask; next: Record<string, string> }> { ): Promise<{ task: GatewayTask; next: Record<string, string> }> {
return request<{ task: GatewayTask; next: Record<string, string> }>('/api/v1/videos/generations', { return request<{ task: GatewayTask; next: Record<string, string> }>('/api/v1/videos/generations', {
body: input, body: input,
headers: { 'X-Async': 'true' },
method: 'POST', method: 'POST',
token, token,
}); });
@ -656,6 +660,33 @@ export async function getTask(token: string, taskId: string): Promise<GatewayTas
return request<GatewayTask>(`/api/workspace/tasks/${taskId}`, { token }); return request<GatewayTask>(`/api/workspace/tasks/${taskId}`, { token });
} }
/**
 * Polls the task detail endpoint until the task leaves a pending state.
 * Transient fetch failures (backend restarts, short network gaps) are
 * swallowed so a durable task is never surfaced as failed in the UI; only an
 * explicit terminal status from the detail endpoint settles the run.
 * Returns the last snapshot seen, which may still be pending when a finite
 * maxAttempts budget runs out.
 */
export async function pollTaskUntilSettled(
  token: string,
  task: GatewayTask,
  options: { intervalMs?: number; maxAttempts?: number | null; onUpdate?: (task: GatewayTask) => void } = {},
): Promise<GatewayTask> {
  const intervalMs = options.intervalMs ?? 1200;
  const maxAttempts = options.maxAttempts ?? Number.POSITIVE_INFINITY;
  let latest = task;
  let attempt = 0;
  while (attempt < maxAttempts) {
    if (!taskIsPending(latest.status)) return latest;
    try {
      latest = await getTask(token, latest.id);
      options.onUpdate?.(latest);
      if (!taskIsPending(latest.status)) return latest;
    } catch {
      // Swallow transient errors; only a terminal status settles the run.
    }
    await delay(intervalMs);
    attempt += 1;
  }
  return latest;
}
/** True while a task has not yet reached a terminal status. */
export function taskIsPending(status: string) {
  const pendingStatuses = ['queued', 'running', 'submitting'];
  return pendingStatuses.includes(status);
}
export async function listTasks(token: string, query: WorkspaceTaskQuery): Promise<ListResponse<GatewayTask>> { export async function listTasks(token: string, query: WorkspaceTaskQuery): Promise<ListResponse<GatewayTask>> {
const search = new URLSearchParams({ const search = new URLSearchParams({
page: String(query.page), page: String(query.page),
@ -707,9 +738,9 @@ export async function getNetworkProxyConfig(token: string): Promise<GatewayNetwo
async function request<T>( async function request<T>(
path: string, path: string,
options: { token?: string; auth?: boolean; method?: string; body?: unknown } = {}, options: { token?: string; auth?: boolean; method?: string; body?: unknown; headers?: Record<string, string> } = {},
): Promise<T> { ): Promise<T> {
const headers: Record<string, string> = {}; const headers: Record<string, string> = { ...(options.headers ?? {}) };
if (options.auth !== false && options.token) { if (options.auth !== false && options.token) {
headers.Authorization = `Bearer ${options.token}`; headers.Authorization = `Bearer ${options.token}`;
} }
@ -731,6 +762,10 @@ async function request<T>(
return response.json() as Promise<T>; return response.json() as Promise<T>;
} }
// Resolves after ms milliseconds; paces the task polling loop.
// NOTE(review): relies on window.setTimeout, so this module assumes a
// browser environment — confirm it is never imported server-side.
function delay(ms: number) {
  return new Promise((resolve) => window.setTimeout(resolve, ms));
}
function parseErrorMessage(body: string) { function parseErrorMessage(body: string) {
return formatGatewayErrorDetails(parseErrorDetails(body)); return formatGatewayErrorDetails(parseErrorDetails(body));
} }

View File

@ -21,7 +21,7 @@ import { mermaid } from '@streamdown/mermaid';
import type { GatewayApiKey, GatewayTask, PlatformModel } from '@easyai-ai-gateway/contracts'; import type { GatewayApiKey, GatewayTask, PlatformModel } from '@easyai-ai-gateway/contracts';
import { Bot, ChevronDown, Image as ImageIcon, MessageSquarePlus, Paperclip, Send, Sparkles, Video } from 'lucide-react'; import { Bot, ChevronDown, Image as ImageIcon, MessageSquarePlus, Paperclip, Send, Sparkles, Video } from 'lucide-react';
import { Badge, Button, Select, Textarea } from '../components/ui'; import { Badge, Button, Select, Textarea } from '../components/ui';
import { GatewayApiError, createImageGenerationTask, createVideoGenerationTask, getTask, streamChatCompletionText } from '../api'; import { GatewayApiError, createImageGenerationTask, createVideoGenerationTask, pollTaskUntilSettled, streamChatCompletionText, taskIsPending } from '../api';
import type { PlaygroundMode } from '../types'; import type { PlaygroundMode } from '../types';
import { import {
defaultMediaGenerationSettings, defaultMediaGenerationSettings,
@ -170,14 +170,12 @@ export function PlaygroundPage(props: {
resumableRuns.forEach((run) => { resumableRuns.forEach((run) => {
if (!run.task?.id) return; if (!run.task?.id) return;
resumedTaskIdsRef.current.add(run.task.id); resumedTaskIdsRef.current.add(run.task.id);
void pollTaskUntilSettled(credential, run.task) void pollTaskUntilSettled(credential, run.task, {
onUpdate: (detail) => updateMediaRunFromTask(run.localId, detail),
})
.then((detail) => { .then((detail) => {
if (!isMountedRef.current) return; if (!isMountedRef.current) return;
setMediaRuns((current) => updateMediaRun(current, run.localId, { updateMediaRunFromTask(run.localId, detail);
error: gatewayTaskErrorText(detail, '任务执行失败'),
status: detail.status,
task: detail,
}));
}) })
.catch((err) => { .catch((err) => {
if (!isMountedRef.current) return; if (!isMountedRef.current) return;
@ -249,13 +247,11 @@ export function PlaygroundPage(props: {
async function pollMediaRunUntilSettled(credential: string, localId: string, task: GatewayTask) { async function pollMediaRunUntilSettled(credential: string, localId: string, task: GatewayTask) {
try { try {
const detail = await pollTaskUntilSettled(credential, task); const detail = await pollTaskUntilSettled(credential, task, {
onUpdate: (nextTask) => updateMediaRunFromTask(localId, nextTask),
});
if (!isMountedRef.current) return; if (!isMountedRef.current) return;
setMediaRuns((current) => updateMediaRun(current, localId, { updateMediaRunFromTask(localId, detail);
error: gatewayTaskErrorText(detail, '任务执行失败'),
status: detail.status,
task: detail,
}));
} catch (err) { } catch (err) {
if (!isMountedRef.current) return; if (!isMountedRef.current) return;
const errorMessage = err instanceof Error ? err.message : '任务状态同步失败'; const errorMessage = err instanceof Error ? err.message : '任务状态同步失败';
@ -263,6 +259,15 @@ export function PlaygroundPage(props: {
} }
} }
  // Syncs one media run card with the latest task snapshot. While the task
  // is still pending the error text is blanked so stale failure text from a
  // previous attempt does not flash; once terminal, the error is derived
  // from the task. No-op after unmount to avoid setState on a dead tree.
  function updateMediaRunFromTask(localId: string, task: GatewayTask) {
    if (!isMountedRef.current) return;
    setMediaRuns((current) => updateMediaRun(current, localId, {
      error: taskIsPending(task.status) ? '' : gatewayTaskErrorText(task, '任务执行失败'),
      status: task.status,
      task,
    }));
  }
function editMediaRun(run: MediaGenerationRun) { function editMediaRun(run: MediaGenerationRun) {
setPrompt(run.prompt); setPrompt(run.prompt);
setMediaSettings(run.settings); setMediaSettings(run.settings);
@ -1042,20 +1047,6 @@ function updateMediaRun(runs: MediaGenerationRun[], localId: string, patch: Part
return runs.map((run) => run.localId === localId ? { ...run, ...patch } : run); return runs.map((run) => run.localId === localId ? { ...run, ...patch } : run);
} }
async function pollTaskUntilSettled(token: string, task: GatewayTask) {
let detail = task;
for (let attempt = 0; attempt < 20; attempt += 1) {
detail = await getTask(token, detail.id);
if (!taskIsPending(detail.status)) return detail;
await delay(1200);
}
return detail;
}
function taskIsPending(status: string) {
return status === 'queued' || status === 'running' || status === 'submitting';
}
function readStoredMediaRuns(): MediaGenerationRun[] { function readStoredMediaRuns(): MediaGenerationRun[] {
if (typeof window === 'undefined') return []; if (typeof window === 'undefined') return [];
try { try {
@ -1182,10 +1173,6 @@ function booleanFromUnknown(value: unknown, fallback: boolean) {
return fallback; return fallback;
} }
function delay(ms: number) {
return new Promise((resolve) => window.setTimeout(resolve, ms));
}
function newLocalId() { function newLocalId() {
return typeof crypto !== 'undefined' && 'randomUUID' in crypto return typeof crypto !== 'undefined' && 'randomUUID' in crypto
? crypto.randomUUID() ? crypto.randomUUID()

View File

@ -593,11 +593,14 @@ function RateLimitStatusTable(props: { statuses: ModelRateLimitStatus[]; platfor
<TableRow className="shTableHeader"> <TableRow className="shTableHeader">
<TableHead></TableHead> <TableHead></TableHead>
<TableHead></TableHead> <TableHead></TableHead>
<TableHead></TableHead> <TableHead className="platformLimitMetricHead platformLimitNumberHead" title="正在执行 / 并发上限 / 排队任务">
<TableHead>TPM</TableHead> <span></span>
<TableHead>RPM</TableHead> <small> / / </small>
<TableHead></TableHead> </TableHead>
<TableHead></TableHead> <TableHead className="platformLimitNumberHead">TPM</TableHead>
<TableHead className="platformLimitNumberHead">RPM</TableHead>
<TableHead className="platformLimitStatusHead"></TableHead>
<TableHead className="platformLimitNumberHead"></TableHead>
</TableRow> </TableRow>
{props.statuses.map((status) => { {props.statuses.map((status) => {
const platform = props.platformMap.get(status.platformId); const platform = props.platformMap.get(status.platformId);
@ -615,12 +618,12 @@ function RateLimitStatusTable(props: { statuses: ModelRateLimitStatus[]; platfor
<small>{status.provider}</small> <small>{status.provider}</small>
</span> </span>
</TableCell> </TableCell>
<TableCell>{metricCell(status.concurrent)}</TableCell> <TableCell className="platformLimitNumberCell">{concurrencyMetricCell(status)}</TableCell>
<TableCell>{metricCell(status.tpm, true)}</TableCell> <TableCell className="platformLimitNumberCell">{metricCell(status.tpm, true)}</TableCell>
<TableCell>{metricCell(status.rpm)}</TableCell> <TableCell className="platformLimitNumberCell">{metricCell(status.rpm)}</TableCell>
<TableCell>{modelRuntimeStatusCell(status, props.now)}</TableCell> <TableCell className="platformLimitStatusCell">{modelRuntimeStatusCell(status, props.now)}</TableCell>
<TableCell> <TableCell className="platformLimitNumberCell">
<span className="rateLoadCell"> <span className="rateLoadCell" data-overloaded={status.loadRatio > 0.8 ? 'true' : undefined}>
<strong>{formatPercent(status.loadRatio)}</strong> <strong>{formatPercent(status.loadRatio)}</strong>
<span className="rateLoadTrack"><i style={{ width: `${Math.min(status.loadRatio * 100, 100)}%` }} /></span> <span className="rateLoadTrack"><i style={{ width: `${Math.min(status.loadRatio * 100, 100)}%` }} /></span>
</span> </span>
@ -1210,6 +1213,16 @@ function metricCell(metric: ModelRateLimitStatus['rpm'], includeReserved = false
); );
} }
// Renders the concurrency column as "running / limit / queued".
// An unlimited concurrency cap renders as 不限; a missing queuedTasks
// field (older API payloads) is treated as zero.
function concurrencyMetricCell(status: ModelRateLimitStatus) {
  const queuedTasks = status.queuedTasks ?? 0;
  const limitText = status.concurrent.limited ? formatLimit(status.concurrent.limitValue) : '不限';
  return (
    <span className="rateMetricCell" title="正在执行 / 并发上限 / 排队任务">
      <strong>{formatLimit(status.concurrent.currentValue)} / {limitText} / {formatLimit(queuedTasks)}</strong>
    </span>
  );
}
function reservedMetricText(metric: ModelRateLimitStatus['rpm']) { function reservedMetricText(metric: ModelRateLimitStatus['rpm']) {
return `已结算 ${formatLimit(metric.usedValue)} + 预占 ${formatLimit(metric.reservedValue)}`; return `已结算 ${formatLimit(metric.usedValue)} + 预占 ${formatLimit(metric.reservedValue)}`;
} }

View File

@ -1016,27 +1016,65 @@
} }
.platformLimitTable .shTableRow { .platformLimitTable .shTableRow {
grid-template-columns: clamp(150px, 16vw, 220px) minmax(148px, 1fr) minmax(104px, max-content) minmax(136px, max-content) minmax(122px, max-content) minmax(132px, max-content) minmax(128px, max-content); grid-template-columns: minmax(180px, 1.15fr) minmax(160px, 0.95fr) 150px 170px 140px 132px 132px;
min-width: 920px; min-width: 1064px;
} }
.platformLimitTable .shTableHead, .platformLimitTable .shTableHead,
.platformLimitTable .shTableCell { .platformLimitTable .shTableCell {
display: grid;
align-content: center;
min-height: 68px;
padding-right: 10px; padding-right: 10px;
padding-left: 10px; padding-left: 10px;
} }
.platformLimitTable .shTableHead {
min-height: 74px;
}
.platformLimitMetricHead {
display: grid;
align-content: center;
gap: 3px;
white-space: normal;
}
.platformLimitMetricHead small {
overflow: hidden;
color: var(--muted-foreground);
font-size: var(--font-size-xs);
font-weight: var(--font-weight-medium);
line-height: 1.2;
text-overflow: ellipsis;
white-space: nowrap;
}
.platformLimitNumberHead,
.platformLimitNumberCell {
justify-items: start;
}
.platformLimitStatusHead,
.platformLimitStatusCell {
justify-items: start;
}
.rateMetricCell, .rateMetricCell,
.rateLoadCell { .rateLoadCell {
display: grid; display: grid;
min-width: 0; min-width: 0;
gap: 4px; gap: 4px;
width: 100%;
align-content: start;
font-variant-numeric: tabular-nums;
} }
.rateMetricCell strong, .rateMetricCell strong,
.rateLoadCell strong { .rateLoadCell strong {
color: var(--text-strong); color: var(--text-strong);
font-size: var(--font-size-sm); font-size: var(--font-size-sm);
line-height: 1.25;
} }
.rateMetricCell small { .rateMetricCell small {
@ -1050,6 +1088,7 @@
.rateLoadTrack { .rateLoadTrack {
display: block; display: block;
height: 6px; height: 6px;
width: min(112px, 100%);
overflow: hidden; overflow: hidden;
border-radius: 999px; border-radius: 999px;
background: #eef2f6; background: #eef2f6;
@ -1062,6 +1101,18 @@
background: #0f766e; background: #0f766e;
} }
.rateLoadCell[data-overloaded="true"] strong {
color: var(--destructive);
}
.rateLoadCell[data-overloaded="true"] .rateLoadTrack {
background: #fee2e2;
}
.rateLoadCell[data-overloaded="true"] .rateLoadTrack i {
background: var(--destructive);
}
.platformModelToolbar { .platformModelToolbar {
display: grid; display: grid;
grid-template-columns: minmax(220px, 0.6fr) minmax(260px, 1fr); grid-template-columns: minmax(220px, 0.6fr) minmax(260px, 1fr);

View File

@ -1,5 +1,3 @@
go 1.23 go 1.23
use ( use ./apps/api
./apps/api
)

View File

@ -1,16 +1,12 @@
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/jackc/pgx/v5 v5.7.2 h1:mLoDLV6sonKlvjIEsV56SkWNCnuNv531l94GaIzO+XI= golang.org/x/mod v0.35.0/go.mod h1:+GwiRhIInF8wPm+4AoT6L0FA1QWAad3OMdTRx4tFYlU=
github.com/jackc/pgx/v5 v5.7.2/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=

View File

@ -775,6 +775,7 @@ export interface ModelRateLimitStatus {
platformCooldownUntil?: string; platformCooldownUntil?: string;
modelCooldownUntil?: string; modelCooldownUntil?: string;
concurrent: RateLimitMetricStatus; concurrent: RateLimitMetricStatus;
queuedTasks: number;
rpm: RateLimitMetricStatus; rpm: RateLimitMetricStatus;
tpm: RateLimitMetricStatus; tpm: RateLimitMetricStatus;
loadRatio: number; loadRatio: number;
@ -808,6 +809,7 @@ export interface GatewayTask {
resolvedModel?: string; resolvedModel?: string;
requestId?: string; requestId?: string;
request?: Record<string, unknown>; request?: Record<string, unknown>;
asyncMode?: boolean;
status: 'queued' | 'running' | 'succeeded' | 'failed' | 'cancelled' | string; status: 'queued' | 'running' | 'succeeded' | 'failed' | 'cancelled' | string;
result?: Record<string, unknown>; result?: Record<string, unknown>;
billings?: unknown[]; billings?: unknown[];