diff --git a/apps/api/cmd/gateway/main.go b/apps/api/cmd/gateway/main.go index e3f329a..9bb935b 100644 --- a/apps/api/cmd/gateway/main.go +++ b/apps/api/cmd/gateway/main.go @@ -33,18 +33,19 @@ func main() { if recovery, err := db.RecoverInterruptedRuntimeState(ctx); err != nil { logger.Error("recover interrupted runtime state failed", "error", err) os.Exit(1) - } else if recovery.ReleasedConcurrencyLeases > 0 || recovery.ReleasedRateReservations > 0 || recovery.FailedAttempts > 0 || recovery.FailedTasks > 0 { + } else if recovery.ReleasedConcurrencyLeases > 0 || recovery.ReleasedRateReservations > 0 || recovery.FailedAttempts > 0 || recovery.FailedTasks > 0 || recovery.RequeuedAsyncTasks > 0 { logger.Warn("interrupted runtime state recovered", "releasedConcurrencyLeases", recovery.ReleasedConcurrencyLeases, "releasedRateReservations", recovery.ReleasedRateReservations, "failedAttempts", recovery.FailedAttempts, "failedTasks", recovery.FailedTasks, + "requeuedAsyncTasks", recovery.RequeuedAsyncTasks, ) } server := &http.Server{ Addr: cfg.HTTPAddr, - Handler: httpapi.NewServer(cfg, db, logger), + Handler: httpapi.NewServerWithContext(ctx, cfg, db, logger), ReadHeaderTimeout: 10 * time.Second, } diff --git a/apps/api/go.mod b/apps/api/go.mod index 6b7a1fa..fe401f1 100644 --- a/apps/api/go.mod +++ b/apps/api/go.mod @@ -4,14 +4,28 @@ go 1.23 require ( github.com/golang-jwt/jwt/v5 v5.2.2 - github.com/jackc/pgx/v5 v5.7.2 - golang.org/x/crypto v0.31.0 + github.com/jackc/pgx/v5 v5.9.2 + github.com/riverqueue/river v0.24.0 + github.com/riverqueue/river/riverdriver/riverpgxv5 v0.24.0 + github.com/riverqueue/river/rivertype v0.24.0 + golang.org/x/crypto v0.37.0 ) require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect - golang.org/x/sync v0.10.0 // indirect - golang.org/x/text v0.21.0 // indirect + 
github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/riverqueue/river/riverdriver v0.24.0 // indirect + github.com/riverqueue/river/rivershared v0.24.0 // indirect + github.com/stretchr/testify v1.11.1 // indirect + github.com/tidwall/gjson v1.18.0 // indirect + github.com/tidwall/match v1.2.0 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect + go.uber.org/goleak v1.3.0 // indirect + golang.org/x/sync v0.20.0 // indirect + golang.org/x/text v0.36.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/apps/api/go.sum b/apps/api/go.sum index 5a44fea..26c9282 100644 --- a/apps/api/go.sum +++ b/apps/api/go.sum @@ -3,28 +3,63 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 h1:Dj0L5fhJ9F82ZJyVOmBx6msDp/kfd1t9GRfny/mfJA0= +github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.7.2 h1:mLoDLV6sonKlvjIEsV56SkWNCnuNv531l94GaIzO+XI= -github.com/jackc/pgx/v5 v5.7.2/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ= +github.com/jackc/pgx/v5 v5.9.2 h1:3ZhOzMWnR4yJ+RW1XImIPsD1aNSz4T4fyP7zlQb56hw= +github.com/jackc/pgx/v5 v5.9.2/go.mod 
h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/riverqueue/river v0.24.0 h1:CesL6vymWgz0d+zNwtnSGRWaB+E8Dax+o9cxD7sUmKc= +github.com/riverqueue/river v0.24.0/go.mod h1:UZ3AxU5t6WtyqNssaea/AkRS8h/kJ+E9ImSB3xyb3ns= +github.com/riverqueue/river/riverdriver v0.24.0 h1:HqGgGkls11u+YKDA7cKOdYKlQwRNJyHuGa3UtOvpdT0= +github.com/riverqueue/river/riverdriver v0.24.0/go.mod h1:dEew9DDIKenNvzpm8Edw8+PkqP3c0zl1fKjiQTq2n/w= +github.com/riverqueue/river/riverdriver/riverpgxv5 v0.24.0 h1:yV37OIbRrhRwIiGeRT7P4D3szhAemu87BgCf8gTCoU4= +github.com/riverqueue/river/riverdriver/riverpgxv5 v0.24.0/go.mod h1:QfznySVKC4ljx53syd/bA/LRSsydAyuD3Q9/EbSniKA= +github.com/riverqueue/river/rivershared v0.24.0 h1:KysokksW75pug2a5RTOc6WESOupWmsylVc6VWvAx+4Y= +github.com/riverqueue/river/rivershared v0.24.0/go.mod h1:UIBfSdai0oWFlwFcoqG4DZX83iA/fLWTEBGrj7Oe1ho= +github.com/riverqueue/river/rivertype v0.24.0 h1:xrQZm/h6U8TBPyTsQPYD5leOapuoBAcdz30bdBwTqOg= +github.com/riverqueue/river/rivertype v0.24.0/go.mod h1:lmdl3vLNDfchDWbYdW2uAocIuwIN+ZaXqAukdSCFqWs= +github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= +github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= 
+github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= -golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= -golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= -golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= -golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM= +github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 
h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= +golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= +golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/apps/api/internal/clients/clients_test.go b/apps/api/internal/clients/clients_test.go index da8563c..6e6ca0b 100644 --- a/apps/api/internal/clients/clients_test.go +++ b/apps/api/internal/clients/clients_test.go @@ -329,6 +329,7 @@ func TestVolcesClientVideoSubmitsAndPollsTask(t *testing.T) { var gotModel string var gotText string var gotFirstFrameRole string + var submittedRemoteTaskID string server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { gotAuth = r.Header.Get("Authorization") switch r.Method + " " + r.URL.Path { @@ -385,6 +386,13 @@ func TestVolcesClientVideoSubmitsAndPollsTask(t 
*testing.T) { "volcesPollTimeoutSeconds": 1, }, }, + OnRemoteTaskSubmitted: func(remoteTaskID string, payload map[string]any) error { + submittedRemoteTaskID = remoteTaskID + if payload["id"] != "cgt-test" { + t.Fatalf("unexpected submitted payload: %+v", payload) + } + return nil + }, }) if err != nil { t.Fatalf("run volces video: %v", err) @@ -392,6 +400,9 @@ func TestVolcesClientVideoSubmitsAndPollsTask(t *testing.T) { if submitPath != "/contents/generations/tasks" || pollPath != "/contents/generations/tasks/cgt-test" || gotAuth != "Bearer volces-key" { t.Fatalf("unexpected paths/auth submit=%s poll=%s auth=%s", submitPath, pollPath, gotAuth) } + if submittedRemoteTaskID != "cgt-test" { + t.Fatalf("remote task submit callback did not receive task id, got %q", submittedRemoteTaskID) + } if gotModel != "doubao-seedance-2-0-260128" || gotFirstFrameRole != "first_frame" { t.Fatalf("unexpected submitted model=%s role=%s", gotModel, gotFirstFrameRole) } @@ -407,6 +418,53 @@ func TestVolcesClientVideoSubmitsAndPollsTask(t *testing.T) { } } +func TestVolcesClientVideoResumePollsExistingTaskID(t *testing.T) { + var submitCalled bool + var pollPath string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method + " " + r.URL.Path { + case "POST /contents/generations/tasks": + submitCalled = true + t.Fatalf("resume should skip upstream submit when remote task id exists") + case "GET /contents/generations/tasks/cgt-existing": + pollPath = r.URL.Path + _ = json.NewEncoder(w).Encode(map[string]any{ + "id": "cgt-existing", + "status": "succeeded", + "created_at": 789, + "content": map[string]any{"video_url": "https://example.com/resumed.mp4"}, + }) + default: + t.Fatalf("unexpected request %s %s", r.Method, r.URL.Path) + } + })) + defer server.Close() + + response, err := (VolcesClient{HTTPClient: server.Client()}).Run(context.Background(), Request{ + Kind: "videos.generations", + ModelType: "video_generate", + Model: 
"豆包Seedance-2.0", + Body: map[string]any{"prompt": "resume polling", "pollIntervalMs": 100, "pollTimeoutSeconds": 1}, + RemoteTaskID: "cgt-existing", + Candidate: store.RuntimeModelCandidate{ + BaseURL: server.URL, + ModelName: "豆包Seedance-2.0", + Credentials: map[string]any{"apiKey": "volces-key"}, + }, + }) + if err != nil { + t.Fatalf("resume volces video: %v", err) + } + if submitCalled || pollPath != "/contents/generations/tasks/cgt-existing" { + t.Fatalf("resume should poll existing task only, submit=%v poll=%s", submitCalled, pollPath) + } + data, _ := response.Result["data"].([]any) + item, _ := data[0].(map[string]any) + if response.Result["upstream_task_id"] != "cgt-existing" || item["url"] != "https://example.com/resumed.mp4" { + t.Fatalf("unexpected resumed response: %+v", response.Result) + } +} + func extractText(result map[string]any) string { choices, _ := result["choices"].([]any) choice, _ := choices[0].(map[string]any) diff --git a/apps/api/internal/clients/types.go b/apps/api/internal/clients/types.go index 09b4e4f..778ae92 100644 --- a/apps/api/internal/clients/types.go +++ b/apps/api/internal/clients/types.go @@ -11,14 +11,17 @@ import ( ) type Request struct { - Kind string - ModelType string - Model string - Body map[string]any - Candidate store.RuntimeModelCandidate - HTTPClient *http.Client - Stream bool - StreamDelta StreamDelta + Kind string + ModelType string + Model string + Body map[string]any + Candidate store.RuntimeModelCandidate + HTTPClient *http.Client + RemoteTaskID string + RemoteTaskPayload map[string]any + OnRemoteTaskSubmitted func(remoteTaskID string, payload map[string]any) error + Stream bool + StreamDelta StreamDelta } type Response struct { diff --git a/apps/api/internal/clients/volces.go b/apps/api/internal/clients/volces.go index 5deb9fc..020d134 100644 --- a/apps/api/internal/clients/volces.go +++ b/apps/api/internal/clients/volces.go @@ -67,16 +67,25 @@ func (c VolcesClient) runImage(ctx context.Context, request 
Request, apiKey stri } func (c VolcesClient) runVideo(ctx context.Context, request Request, apiKey string) (Response, error) { - body := volcesVideoBody(request) submitStartedAt := time.Now() - submitResult, submitRequestID, err := c.postJSON(ctx, request, request.Candidate.BaseURL, "/contents/generations/tasks", apiKey, body) - submitFinishedAt := time.Now() - if err != nil { - return Response{}, annotateResponseError(err, submitRequestID, submitStartedAt, submitFinishedAt) - } - upstreamTaskID := strings.TrimSpace(stringFromAny(submitResult["id"])) + submitRequestID := strings.TrimSpace(request.RemoteTaskID) + upstreamTaskID := strings.TrimSpace(request.RemoteTaskID) if upstreamTaskID == "" { - return Response{}, &ClientError{Code: "invalid_response", Message: "volces video task id is missing", RequestID: submitRequestID, Retryable: false} + body := volcesVideoBody(request) + submitResult, requestID, err := c.postJSON(ctx, request, request.Candidate.BaseURL, "/contents/generations/tasks", apiKey, body) + submitRequestID = requestID + if err != nil { + return Response{}, annotateResponseError(err, submitRequestID, submitStartedAt, time.Now()) + } + upstreamTaskID = strings.TrimSpace(stringFromAny(submitResult["id"])) + if upstreamTaskID == "" { + return Response{}, &ClientError{Code: "invalid_response", Message: "volces video task id is missing", RequestID: submitRequestID, Retryable: false} + } + if request.OnRemoteTaskSubmitted != nil { + if err := request.OnRemoteTaskSubmitted(upstreamTaskID, submitResult); err != nil { + return Response{}, err + } + } } interval := volcesPollInterval(request) diff --git a/apps/api/internal/httpapi/core_flow_integration_test.go b/apps/api/internal/httpapi/core_flow_integration_test.go index 3257ee3..61ad82d 100644 --- a/apps/api/internal/httpapi/core_flow_integration_test.go +++ b/apps/api/internal/httpapi/core_flow_integration_test.go @@ -18,6 +18,7 @@ import ( "testing" "time" + 
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth" "github.com/easyai/easyai-ai-gateway/apps/api/internal/config" "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" "github.com/jackc/pgx/v5/pgxpool" @@ -37,7 +38,11 @@ func TestCoreLocalFlow(t *testing.T) { } defer db.Close() - handler := NewServer(config.Config{ + assertRuntimeRecoveryReleasesPendingRateReservations(t, ctx, db) + + serverCtx, cancelServer := context.WithCancel(ctx) + defer cancelServer() + handler := NewServerWithContext(serverCtx, config.Config{ AppEnv: "test", HTTPAddr: ":0", DatabaseURL: databaseURL, @@ -169,13 +174,14 @@ func TestCoreLocalFlow(t *testing.T) { if err := testPool.QueryRow(ctx, `SELECT count(*) FROM gateway_tasks`).Scan(&taskCountBefore); err != nil { t.Fatalf("count tasks before scoped request: %v", err) } + defaultImageModel := "openai:gpt-image-1" doJSON(t, server.URL, http.MethodPost, "/api/v1/images/generations", chatOnlyAPIKeyResponse.Secret, map[string]any{ - "model": "gpt-image-1", + "model": defaultImageModel, "prompt": "scope should block this", }, http.StatusForbidden, nil) doJSON(t, server.URL, http.MethodPost, "/api/v1/pricing/estimate", chatOnlyAPIKeyResponse.Secret, map[string]any{ "kind": "images.generations", - "model": "gpt-image-1", + "model": defaultImageModel, "prompt": "scope should block this estimate", }, http.StatusForbidden, nil) var taskCountAfter int @@ -309,8 +315,9 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil { Result map[string]any `json:"result"` } `json:"task"` } + defaultTextModel := "openai:gpt-4o-mini" doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{ - "model": "gpt-4o-mini", + "model": defaultTextModel, "runMode": "simulation", "simulation": true, "simulationDurationMs": 5, @@ -334,7 +341,7 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil { var compatChat map[string]any doJSON(t, server.URL, 
http.MethodPost, "/v1/chat/completions", apiKeyResponse.Secret, map[string]any{ - "model": "gpt-4o-mini", + "model": defaultTextModel, "runMode": "simulation", "messages": []map[string]any{{"role": "user", "content": "ping"}}, "simulation": true, @@ -352,7 +359,7 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil { } `json:"task"` } doJSON(t, server.URL, http.MethodPost, "/api/v1/images/generations", apiKeyResponse.Secret, map[string]any{ - "model": "gpt-image-1", + "model": defaultImageModel, "runMode": "simulation", "prompt": "a tiny gateway console", "size": "1024x1024", @@ -372,7 +379,7 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil { } `json:"task"` } doJSON(t, server.URL, http.MethodPost, "/api/v1/images/edits", apiKeyResponse.Secret, map[string]any{ - "model": "gpt-image-1", + "model": defaultImageModel, "runMode": "simulation", "prompt": "replace background with clean studio light", "image": "https://example.com/source.png", @@ -384,6 +391,42 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil { t.Fatalf("unexpected image edit task: %+v", imageEditResponse.Task) } + doubaoLiteImageEditModel := "doubao-5.0-lite图像编辑" + var doubaoLitePlatformModel struct { + ID string `json:"id"` + } + doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms/"+platform.ID+"/models", loginResponse.AccessToken, map[string]any{ + "canonicalModelKey": "easyai:doubao-5.0-lite图像编辑", + "modelName": doubaoLiteImageEditModel, + "modelAlias": doubaoLiteImageEditModel, + "modelType": []string{"image_edit", "image_generate"}, + "displayName": doubaoLiteImageEditModel, + }, http.StatusCreated, &doubaoLitePlatformModel) + var doubaoLiteAsyncEdit struct { + TaskID string `json:"taskId"` + Task struct { + ID string `json:"id"` + Status string `json:"status"` + AsyncMode bool `json:"asyncMode"` + } `json:"task"` + } + doJSONWithHeaders(t, server.URL, http.MethodPost, "/api/v1/images/edits", 
apiKeyResponse.Secret, map[string]any{ + "model": doubaoLiteImageEditModel, + "runMode": "simulation", + "prompt": "turn the attached bright desktop object into a clean product-style render", + "image": "https://example.com/doubao-lite-source.png", + "size": "2048x2048", + "simulation": true, + "simulationDurationMs": 5, + }, map[string]string{"X-Async": "true"}, http.StatusAccepted, &doubaoLiteAsyncEdit) + if doubaoLiteAsyncEdit.TaskID == "" || !doubaoLiteAsyncEdit.Task.AsyncMode { + t.Fatalf("doubao-5.0-lite image edit async task should be accepted: %+v", doubaoLiteAsyncEdit) + } + doubaoLiteAsyncEditDone := waitForTaskStatus(t, server.URL, apiKeyResponse.Secret, doubaoLiteAsyncEdit.TaskID, []string{"succeeded"}, 5*time.Second) + if doubaoLiteAsyncEditDone.Status != "succeeded" { + t.Fatalf("doubao-5.0-lite image edit async task should succeed through river queue, got %+v", doubaoLiteAsyncEditDone) + } + var gptImageModelTypesRaw []byte if err := testPool.QueryRow(ctx, ` SELECT model_type @@ -641,6 +684,7 @@ WHERE reference_type = 'gateway_task' } rateLimitedModel := "rate-limit-smoke-" + suffixText + rateLimitWindowSeconds := 3 var rateLimitPolicySet struct { ID string `json:"id"` } @@ -652,7 +696,7 @@ WHERE reference_type = 'gateway_task' "maxAttempts": 1, }, "rateLimitPolicy": map[string]any{ - "rules": []map[string]any{{"metric": "rpm", "limit": 1, "windowSeconds": 60}}, + "rules": []map[string]any{{"metric": "rpm", "limit": 1, "windowSeconds": rateLimitWindowSeconds}}, }, }, http.StatusCreated, &rateLimitPolicySet) var rateLimitPlatformModel map[string]any @@ -682,6 +726,7 @@ WHERE reference_type = 'gateway_task' if rateLimitFailedTask.Task.Status != "failed" || rateLimitFailedTask.Task.ErrorCode != "bad_request" { t.Fatalf("failed rate-limited task should fail before consuming rpm: %+v", rateLimitFailedTask.Task) } + waitForRateLimitWindowHead(t, rateLimitWindowSeconds) var rateLimitTaskOne struct { Task struct { Status string `json:"status"` @@ -713,6 
+758,38 @@ WHERE reference_type = 'gateway_task' if rateLimitTaskTwo.Task.Status != "failed" || rateLimitTaskTwo.Task.ErrorCode != "rate_limit" { t.Fatalf("runtime policy rate limit should fail second task with rate_limit: %+v", rateLimitTaskTwo.Task) } + var asyncRateLimitTask struct { + TaskID string `json:"taskId"` + Task struct { + ID string `json:"id"` + Status string `json:"status"` + AsyncMode bool `json:"asyncMode"` + } `json:"task"` + } + doJSONWithHeaders(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{ + "model": rateLimitedModel, + "runMode": "simulation", + "simulation": true, + "simulationDurationMs": 5, + "messages": []map[string]any{{"role": "user", "content": "async queued"}}, + }, map[string]string{"X-Async": "true"}, http.StatusAccepted, &asyncRateLimitTask) + if asyncRateLimitTask.TaskID == "" || asyncRateLimitTask.Task.ID != asyncRateLimitTask.TaskID || !asyncRateLimitTask.Task.AsyncMode { + t.Fatalf("async task response should expose task id and async mode: %+v", asyncRateLimitTask) + } + asyncRateLimitDetail := waitForTaskStatus(t, server.URL, apiKeyResponse.Secret, asyncRateLimitTask.TaskID, []string{"queued"}, 2*time.Second) + if asyncRateLimitDetail.Status != "queued" { + t.Fatalf("async rate-limited task should return to queued state, got %+v", asyncRateLimitDetail) + } + if len(asyncRateLimitDetail.Attempts) == 0 || asyncRateLimitDetail.Attempts[0].ErrorCode != "rate_limit" { + t.Fatalf("async rate-limited task should record a rate_limit attempt before requeue: %+v", asyncRateLimitDetail) + } + asyncRateLimitCompleted := waitForTaskStatus(t, server.URL, apiKeyResponse.Secret, asyncRateLimitTask.TaskID, []string{"succeeded"}, time.Duration(rateLimitWindowSeconds+3)*time.Second) + if asyncRateLimitCompleted.Status != "succeeded" { + t.Fatalf("async rate-limited task should be pulled from queue after the limit window resets, got %+v", asyncRateLimitCompleted) + } + if 
len(asyncRateLimitCompleted.Attempts) < 2 || asyncRateLimitCompleted.Attempts[len(asyncRateLimitCompleted.Attempts)-1].Status != "succeeded" { + t.Fatalf("async rate-limited task should create a new successful attempt after requeue: %+v", asyncRateLimitCompleted) + } videoRouteModel := "video-route-smoke-" + suffixText var videoRoutePlatformModel map[string]any @@ -823,16 +900,19 @@ WHERE reference_type = 'gateway_task' Metrics map[string]any `json:"metrics"` } doJSON(t, server.URL, http.MethodGet, "/api/v1/tasks/"+failoverTask.Task.ID, apiKeyResponse.Secret, nil, http.StatusOK, &failoverDetail) - if len(failoverDetail.Attempts) != 2 { - t.Fatalf("failover task history should include two attempts, got %+v", failoverDetail.Attempts) + if len(failoverDetail.Attempts) != 3 { + t.Fatalf("failover task history should include two failed retries plus one successful failover attempt, got %+v", failoverDetail.Attempts) } if failoverDetail.Attempts[0].PlatformName != "OpenAI Retryable Failure" || failoverDetail.Attempts[0].Status != "failed" || !failoverDetail.Attempts[0].Retryable || failoverDetail.Attempts[0].ErrorCode == "" { t.Fatalf("first failover attempt should preserve failed platform and reason: %+v", failoverDetail.Attempts[0]) } - if failoverDetail.Attempts[1].PlatformName != "OpenAI Retry Success" || failoverDetail.Attempts[1].Status != "succeeded" || failoverDetail.Attempts[1].ResponseMS <= 0 { - t.Fatalf("second failover attempt should preserve successful platform: %+v", failoverDetail.Attempts[1]) + if failoverDetail.Attempts[1].PlatformName != "OpenAI Retryable Failure" || failoverDetail.Attempts[1].Status != "failed" || !failoverDetail.Attempts[1].Retryable { + t.Fatalf("second failover attempt should preserve the same-client retry failure: %+v", failoverDetail.Attempts[1]) } - if summary, ok := failoverDetail.Metrics["attempts"].([]any); !ok || len(summary) != 2 { + if failoverDetail.Attempts[2].PlatformName != "OpenAI Retry Success" || 
failoverDetail.Attempts[2].Status != "succeeded" || failoverDetail.Attempts[2].ResponseMS <= 0 { + t.Fatalf("third failover attempt should preserve successful platform: %+v", failoverDetail.Attempts[2]) + } + if summary, ok := failoverDetail.Metrics["attempts"].([]any); !ok || len(summary) != 3 { t.Fatalf("task metrics should keep attempt-chain summary, got %+v", failoverDetail.Metrics) } @@ -1045,6 +1125,9 @@ WHERE m.platform_id = $1::uuid if resp.StatusCode != http.StatusOK || !bytes.Contains(body, []byte("task.completed")) { t.Fatalf("unexpected events response status=%d body=%s", resp.StatusCode, string(body)) } + if !bytes.Contains(body, []byte("task.running")) { + t.Fatalf("events response should include running transition event body=%s", string(body)) + } if !bytes.Contains(body, []byte("task.progress")) { t.Fatalf("events response should include progress events body=%s", string(body)) } @@ -1074,6 +1157,47 @@ WHERE m.platform_id = $1::uuid if callbackRows == 0 { t.Fatal("task progress callback outbox should receive events") } + + var restartAsyncTask struct { + TaskID string `json:"taskId"` + Task struct { + ID string `json:"id"` + Status string `json:"status"` + AsyncMode bool `json:"asyncMode"` + } `json:"task"` + } + doJSONWithHeaders(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{ + "model": defaultTextModel, + "runMode": "simulation", + "simulation": true, + "simulationDurationMs": 2000, + "messages": []map[string]any{{"role": "user", "content": "river worker restart"}}, + }, map[string]string{"X-Async": "true"}, http.StatusAccepted, &restartAsyncTask) + if restartAsyncTask.TaskID == "" || !restartAsyncTask.Task.AsyncMode { + t.Fatalf("restart async task should be accepted as async: %+v", restartAsyncTask) + } + restartRunning := waitForTaskStatus(t, server.URL, apiKeyResponse.Secret, restartAsyncTask.TaskID, []string{"running"}, 3*time.Second) + if restartRunning.Status != "running" { + 
t.Fatalf("restart async task should be running before worker restart, got %+v", restartRunning) + } + cancelServer() + serverCtx2, cancelServer2 := context.WithCancel(ctx) + defer cancelServer2() + server2 := httptest.NewServer(NewServerWithContext(serverCtx2, config.Config{ + AppEnv: "test", + HTTPAddr: ":0", + DatabaseURL: databaseURL, + IdentityMode: "hybrid", + JWTSecret: "test-secret", + TaskProgressCallbackEnabled: true, + TaskProgressCallbackURL: "http://callback.local/task-progress", + CORSAllowedOrigin: "*", + }, db, slog.New(slog.NewTextHandler(io.Discard, nil)))) + defer server2.Close() + restartRecovered := waitForTaskStatus(t, server2.URL, apiKeyResponse.Secret, restartAsyncTask.TaskID, []string{"succeeded"}, 8*time.Second) + if restartRecovered.Status != "succeeded" { + t.Fatalf("river worker restart should recover and finish async task, got %+v", restartRecovered) + } } func TestOriginAllowedSupportsCommaSeparatedOrigins(t *testing.T) { @@ -1089,6 +1213,51 @@ func TestOriginAllowedSupportsCommaSeparatedOrigins(t *testing.T) { } } +func assertRuntimeRecoveryReleasesPendingRateReservations(t *testing.T, ctx context.Context, db *store.Store) { + t.Helper() + suffix := strconv.FormatInt(time.Now().UnixNano(), 10) + task, err := db.CreateTask(ctx, store.CreateTaskInput{ + Kind: "chat.completions", + Model: "recovery-reservation-smoke-" + suffix, + RunMode: "test", + Async: true, + Request: map[string]any{"messages": []any{}}, + }, &auth.User{ID: "recovery-user-" + suffix}) + if err != nil { + t.Fatalf("create recovery reservation task: %v", err) + } + scopeKey := "recovery-client-" + suffix + if _, err := db.ReserveRateLimits(ctx, task.ID, "", []store.RateLimitReservation{{ + ScopeType: "client", + ScopeKey: scopeKey, + Metric: "rpm", + Limit: 10, + Amount: 3, + WindowSeconds: 60, + }}); err != nil { + t.Fatalf("reserve recovery rate limit: %v", err) + } + recovery, err := db.RecoverInterruptedRuntimeState(ctx) + if err != nil { + t.Fatalf("recover 
interrupted runtime state with pending reservation: %v", err) + } + if recovery.ReleasedRateReservations == 0 { + t.Fatalf("recovery should release pending rate reservation, got %+v", recovery) + } + var reservedValue float64 + if err := db.Pool().QueryRow(ctx, ` +SELECT COALESCE(MAX(reserved_value), 0)::float8 +FROM gateway_rate_limit_counters +WHERE scope_type = 'client' + AND scope_key = $1 + AND metric = 'rpm'`, scopeKey).Scan(&reservedValue); err != nil { + t.Fatalf("query recovered rate limit counter: %v", err) + } + if reservedValue != 0 { + t.Fatalf("recovery should release reserved counter value, got %f", reservedValue) + } +} + func applyMigration(t *testing.T, ctx context.Context, databaseURL string) { t.Helper() _, filename, _, _ := runtime.Caller(0) @@ -1114,6 +1283,11 @@ func applyMigration(t *testing.T, ctx context.Context, databaseURL string) { } func doJSON(t *testing.T, baseURL string, method string, path string, token string, payload any, expectedStatus int, out any) { + t.Helper() + doJSONWithHeaders(t, baseURL, method, path, token, payload, nil, expectedStatus, out) +} + +func doJSONWithHeaders(t *testing.T, baseURL string, method string, path string, token string, payload any, headers map[string]string, expectedStatus int, out any) { t.Helper() var body io.Reader if payload != nil { @@ -1133,6 +1307,9 @@ func doJSON(t *testing.T, baseURL string, method string, path string, token stri if token != "" { req.Header.Set("Authorization", "Bearer "+token) } + for key, value := range headers { + req.Header.Set(key, value) + } resp, err := http.DefaultClient.Do(req) if err != nil { t.Fatalf("%s %s: %v", method, path, err) @@ -1149,6 +1326,50 @@ func doJSON(t *testing.T, baseURL string, method string, path string, token stri } } +type taskWaitDetail struct { + ID string `json:"id"` + Status string `json:"status"` + Attempts []struct { + Status string `json:"status"` + ErrorCode string `json:"errorCode"` + } `json:"attempts"` +} + +func 
waitForTaskStatus(t *testing.T, baseURL string, token string, taskID string, statuses []string, timeout time.Duration) taskWaitDetail { + t.Helper() + wanted := map[string]bool{} + for _, status := range statuses { + wanted[status] = true + } + var detail taskWaitDetail + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + doJSON(t, baseURL, http.MethodGet, "/api/v1/tasks/"+taskID, token, nil, http.StatusOK, &detail) + if wanted[detail.Status] { + return detail + } + time.Sleep(100 * time.Millisecond) + } + return detail +} + +func waitForRateLimitWindowHead(t *testing.T, windowSeconds int) { + t.Helper() + if windowSeconds <= 0 { + return + } + window := time.Duration(windowSeconds) * time.Second + deadline := time.Now().Add(window + time.Second) + for time.Now().Before(deadline) { + elapsed := time.Duration(time.Now().UnixNano() % int64(window)) + if elapsed < 300*time.Millisecond { + return + } + time.Sleep(20 * time.Millisecond) + } + t.Fatalf("timed out waiting for %ds rate limit window head", windowSeconds) +} + func stringSliceContains(values []string, target string) bool { for _, value := range values { if value == target { diff --git a/apps/api/internal/httpapi/handlers.go b/apps/api/internal/httpapi/handlers.go index 0ab9d15..7cf2f00 100644 --- a/apps/api/internal/httpapi/handlers.go +++ b/apps/api/internal/httpapi/handlers.go @@ -542,11 +542,13 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler { writeError(w, http.StatusForbidden, "api key scope does not allow this capability") return } + asyncMode := asyncRequest(r) task, err := s.store.CreateTask(r.Context(), store.CreateTaskInput{ Kind: kind, Model: model, RunMode: runModeFromRequest(body), + Async: asyncMode, Request: body, }, user) if err != nil { @@ -554,6 +556,14 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler { writeError(w, http.StatusInternalServerError, "create task failed") return } + if asyncMode { + if err := 
s.runner.EnqueueAsyncTask(r.Context(), task); err != nil { + writeError(w, http.StatusInternalServerError, err.Error(), "enqueue_failed") + return + } + writeTaskAccepted(w, task) + return + } if compatible { if boolValue(body, "stream") { flusher := prepareCompatibleStream(w) @@ -602,13 +612,23 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler { s.logger.Warn("task completed with failure", "kind", kind, "taskId", task.ID, "error", runErr) } - writeJSON(w, http.StatusAccepted, map[string]any{ - "task": result.Task, - "next": map[string]string{ - "events": fmt.Sprintf("/api/v1/tasks/%s/events", task.ID), - "detail": fmt.Sprintf("/api/v1/tasks/%s", task.ID), - }, - }) + writeTaskAccepted(w, result.Task) + }) +} + +func asyncRequest(r *http.Request) bool { + value := strings.TrimSpace(strings.ToLower(r.Header.Get("x-async"))) + return value == "1" || value == "true" || value == "yes" || value == "on" +} + +func writeTaskAccepted(w http.ResponseWriter, task store.GatewayTask) { + writeJSON(w, http.StatusAccepted, map[string]any{ + "taskId": task.ID, + "task": task, + "next": map[string]string{ + "events": fmt.Sprintf("/api/v1/tasks/%s/events", task.ID), + "detail": fmt.Sprintf("/api/v1/tasks/%s", task.ID), + }, }) } diff --git a/apps/api/internal/httpapi/server.go b/apps/api/internal/httpapi/server.go index 8b2158a..8eb5a95 100644 --- a/apps/api/internal/httpapi/server.go +++ b/apps/api/internal/httpapi/server.go @@ -1,6 +1,7 @@ package httpapi import ( + "context" "log/slog" "net/http" "strings" @@ -12,6 +13,7 @@ import ( ) type Server struct { + ctx context.Context cfg config.Config store *store.Store auth *auth.Authenticator @@ -20,7 +22,12 @@ type Server struct { } func NewServer(cfg config.Config, db *store.Store, logger *slog.Logger) http.Handler { + return NewServerWithContext(context.Background(), cfg, db, logger) +} + +func NewServerWithContext(ctx context.Context, cfg config.Config, db *store.Store, logger *slog.Logger) http.Handler { 
server := &Server{ + ctx: ctx, cfg: cfg, store: db, auth: auth.New(cfg.JWTSecret, cfg.ServerMainBaseURL, cfg.ServerMainInternalToken), @@ -28,6 +35,7 @@ func NewServer(cfg config.Config, db *store.Store, logger *slog.Logger) http.Han logger: logger, } server.auth.LocalAPIKeyVerifier = db.VerifyLocalAPIKey + server.runner.StartAsyncQueueWorker(ctx) mux := http.NewServeMux() mux.HandleFunc("GET /healthz", server.health) @@ -146,7 +154,7 @@ func (s *Server) cors(next http.Handler) http.Handler { w.Header().Set("Access-Control-Allow-Origin", origin) w.Header().Set("Vary", "Origin") w.Header().Set("Access-Control-Allow-Credentials", "true") - w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Comfy-Api-Key") + w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Comfy-Api-Key, X-Async") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") } if r.Method == http.MethodOptions { diff --git a/apps/api/internal/runner/limits.go b/apps/api/internal/runner/limits.go index 1cc0744..94f7f21 100644 --- a/apps/api/internal/runner/limits.go +++ b/apps/api/internal/runner/limits.go @@ -2,13 +2,49 @@ package runner import ( "context" + "errors" "strings" + "time" "github.com/easyai/easyai-ai-gateway/apps/api/internal/auth" "github.com/easyai/easyai-ai-gateway/apps/api/internal/clients" "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" ) +type localRateLimitError struct { + clientErr *clients.ClientError + cause error + retryAfter time.Duration +} + +func (e *localRateLimitError) Error() string { + if e == nil || e.clientErr == nil { + return store.ErrRateLimited.Error() + } + return e.clientErr.Error() +} + +func (e *localRateLimitError) Unwrap() []error { + if e == nil || e.clientErr == nil { + if e != nil && e.cause != nil { + return []error{e.cause} + } + return []error{store.ErrRateLimited} + } + if e.cause != nil { + return []error{e.clientErr, e.cause} + } + return 
[]error{e.clientErr, store.ErrRateLimited} +} + +func localRateLimitRetryAfter(err error) time.Duration { + var limitErr *localRateLimitError + if errors.As(err, &limitErr) && limitErr.retryAfter > 0 { + return limitErr.retryAfter + } + return store.RateLimitRetryAfter(err) +} + func (s *Service) rateLimitReservations(ctx context.Context, user *auth.User, candidate store.RuntimeModelCandidate, body map[string]any) []store.RateLimitReservation { out := make([]store.RateLimitReservation, 0) out = append(out, reservationsFromPolicy("platform_model", candidate.PlatformModelID, effectiveRateLimitPolicy(candidate), body)...) diff --git a/apps/api/internal/runner/queue_worker.go b/apps/api/internal/runner/queue_worker.go new file mode 100644 index 0000000..040d3ba --- /dev/null +++ b/apps/api/internal/runner/queue_worker.go @@ -0,0 +1,216 @@ +package runner + +import ( + "context" + "errors" + "fmt" + "os" + "strings" + "time" + + "github.com/easyai/easyai-ai-gateway/apps/api/internal/auth" + "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" + "github.com/riverqueue/river" + "github.com/riverqueue/river/riverdriver/riverpgxv5" + "github.com/riverqueue/river/rivermigrate" + "github.com/riverqueue/river/rivertype" +) + +const asyncTaskQueueName = "gateway_tasks" + +type asyncTaskArgs struct { + TaskID string `json:"task_id" river:"unique"` +} + +func (asyncTaskArgs) Kind() string { return "gateway_task_run" } + +type asyncTaskWorker struct { + river.WorkerDefaults[asyncTaskArgs] + + service *Service +} + +func (w *asyncTaskWorker) Work(ctx context.Context, job *river.Job[asyncTaskArgs]) error { + task, err := w.service.store.GetTask(ctx, job.Args.TaskID) + if err != nil { + return err + } + if task.Status == "succeeded" || task.Status == "failed" || task.Status == "cancelled" { + return nil + } + result, runErr := w.service.Execute(ctx, task, authUserFromTask(task)) + if runErr == nil { + w.service.logger.Debug("river async task completed", "taskID", task.ID, 
"status", result.Task.Status, "riverJobID", job.ID) + return nil + } + var queuedErr *TaskQueuedError + if errors.As(runErr, &queuedErr) { + return river.JobSnooze(queuedErr.Delay) + } + if ctx.Err() != nil { + queued, queueErr := w.service.requeueInterruptedAsyncTask(context.WithoutCancel(ctx), task) + if queueErr != nil { + return queueErr + } + w.service.logger.Debug("river async task interrupted and requeued", "taskID", task.ID, "status", queued.Status, "riverJobID", job.ID) + return river.JobSnooze(0) + } + w.service.logger.Warn("river async task completed with failure", "taskID", task.ID, "error", runErr, "riverJobID", job.ID) + return nil +} + +func (s *Service) StartAsyncQueueWorker(ctx context.Context) { + if err := s.startRiverQueue(ctx); err != nil { + s.logger.Error("start river async queue failed", "error", err) + panic(err) + } +} + +func (s *Service) startRiverQueue(ctx context.Context) error { + driver := riverpgxv5.New(s.store.Pool()) + migrator, err := rivermigrate.New(driver, nil) + if err != nil { + return err + } + if _, err := migrator.Migrate(ctx, rivermigrate.DirectionUp, nil); err != nil { + return err + } + + workers := river.NewWorkers() + if err := river.AddWorkerSafely(workers, &asyncTaskWorker{service: s}); err != nil { + return err + } + riverClient, err := river.NewClient(driver, &river.Config{ + ID: asyncWorkerID(), + JobTimeout: -1, + Logger: s.logger, + CompletedJobRetentionPeriod: 24 * time.Hour, + Queues: map[string]river.QueueConfig{ + asyncTaskQueueName: {MaxWorkers: 32}, + }, + RescueStuckJobsAfter: 30 * time.Second, + TestOnly: s.cfg.AppEnv == "test", + Workers: workers, + }) + if err != nil { + return err + } + s.riverClient = riverClient + if err := riverClient.Start(ctx); err != nil { + return err + } + if err := s.recoverAsyncRiverJobs(ctx); err != nil { + return err + } + go func() { + <-ctx.Done() + stopCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := 
riverClient.StopAndCancel(stopCtx); err != nil { + s.logger.Warn("stop river async queue failed", "error", err) + } + }() + return nil +} + +func (s *Service) EnqueueAsyncTask(ctx context.Context, task store.GatewayTask) error { + if s.riverClient == nil { + return errors.New("river async queue is not started") + } + result, err := s.riverClient.Insert(ctx, asyncTaskArgs{TaskID: task.ID}, asyncTaskInsertOpts(task)) + if err != nil { + return err + } + if result.Job != nil { + return s.store.SetTaskRiverJobID(ctx, task.ID, result.Job.ID) + } + return nil +} + +func (s *Service) WakeAsyncQueueAfter(ctx context.Context, delay time.Duration) { +} + +func (s *Service) RunAsyncTask(ctx context.Context, task store.GatewayTask, user *auth.User) { + if err := s.EnqueueAsyncTask(ctx, task); err != nil { + s.logger.Warn("enqueue river async task failed", "taskID", task.ID, "error", err) + } +} + +func (s *Service) recoverAsyncRiverJobs(ctx context.Context) error { + items, err := s.store.ListRecoverableAsyncTasks(ctx, 1000) + if err != nil { + return err + } + for _, item := range items { + task := store.GatewayTask{ID: item.ID} + result, err := s.riverClient.Insert(ctx, asyncTaskArgs{TaskID: item.ID}, asyncTaskInsertOpts(task)) + if err != nil { + return err + } + if result.Job != nil { + if err := s.store.SetTaskRiverJobID(ctx, item.ID, result.Job.ID); err != nil { + return err + } + } + } + if len(items) > 0 { + s.logger.Info("river async queue recovered persisted tasks", "count", len(items)) + } + return nil +} + +func asyncTaskInsertOpts(task store.GatewayTask) *river.InsertOpts { + priority := 2 + if task.ID == "" { + priority = 3 + } + return &river.InsertOpts{ + MaxAttempts: 1000, + Priority: priority, + Queue: asyncTaskQueueName, + Tags: []string{"gateway-task"}, + UniqueOpts: river.UniqueOpts{ + ByArgs: true, + ByQueue: true, + ByState: []rivertype.JobState{ + rivertype.JobStateAvailable, + rivertype.JobStatePending, + rivertype.JobStateRetryable, + 
rivertype.JobStateRunning, + rivertype.JobStateScheduled, + }, + }, + } +} + +func authUserFromTask(task store.GatewayTask) *auth.User { + roles := []string{"user"} + if strings.TrimSpace(task.UserID) == "" { + roles = nil + } + return &auth.User{ + ID: firstNonEmptyString(task.GatewayUserID, task.UserID), + Roles: roles, + TenantID: task.TenantID, + GatewayTenantID: task.GatewayTenantID, + TenantKey: task.TenantKey, + Source: firstNonEmptyString(task.UserSource, "gateway"), + GatewayUserID: task.GatewayUserID, + UserGroupID: task.UserGroupID, + UserGroupKey: task.UserGroupKey, + APIKeyID: task.APIKeyID, + APIKeyName: task.APIKeyName, + APIKeyPrefix: task.APIKeyPrefix, + } +} + +func asyncWorkerID() string { + host, _ := os.Hostname() + host = strings.TrimSpace(host) + if host == "" { + host = "localhost" + } + return fmt.Sprintf("%s:%d:%d", host, os.Getpid(), time.Now().UnixNano()) +} + +var _ river.Worker[asyncTaskArgs] = (*asyncTaskWorker)(nil) diff --git a/apps/api/internal/runner/retry_decision.go b/apps/api/internal/runner/retry_decision.go index b32daaa..b9ac960 100644 --- a/apps/api/internal/runner/retry_decision.go +++ b/apps/api/internal/runner/retry_decision.go @@ -1,6 +1,7 @@ package runner import ( + "errors" "fmt" "strings" @@ -54,6 +55,9 @@ func shouldRetrySameClient(candidate store.RuntimeModelCandidate, err error) boo func retryDecisionForCandidate(candidate store.RuntimeModelCandidate, err error) retryDecision { policy := effectiveRetryPolicy(candidate) info := failureInfoFromError(err) + if errors.Is(err, store.ErrRateLimited) { + return retryDecision{Retry: false, Reason: "local_rate_limit_wait_queue", Match: policyRuleMatch{Source: "gateway_rate_limits", Policy: "rateLimitPolicy", Rule: "localCapacity", Value: "exceeded"}, Info: info} + } if !boolFromPolicy(policy, "enabled", true) { return retryDecision{Retry: false, Reason: "retry_disabled", Match: policyRuleMatch{Source: "model_runtime_policy_sets.retry_policy", Policy: "retryPolicy", Rule: 
"enabled", Value: "false"}, Info: info} } @@ -94,6 +98,9 @@ func failoverDecisionForCandidate(runnerPolicy store.RunnerPolicy, candidate sto if cooldownSeconds <= 0 { cooldownSeconds = 300 } + if errors.Is(err, store.ErrRateLimited) && store.RateLimitRetryable(err) { + return failoverDecision{Retry: true, Action: "next", Reason: "local_rate_limit_try_next_candidate", CooldownSeconds: cooldownSeconds, Match: policyRuleMatch{Source: "gateway_rate_limits", Policy: "rateLimitPolicy", Rule: "localCapacity", Value: "exceeded"}, Info: info} + } if match, ok := failoverAllowMatchWithSources(runnerPolicy.FailoverPolicy, overridePolicy, info); ok { return failoverDecision{Retry: true, Action: action, Reason: "failover_allow_policy", CooldownSeconds: cooldownSeconds, Match: match, Info: info} } diff --git a/apps/api/internal/runner/runtime_policy.go b/apps/api/internal/runner/runtime_policy.go index 0e0dc71..58ce14e 100644 --- a/apps/api/internal/runner/runtime_policy.go +++ b/apps/api/internal/runner/runtime_policy.go @@ -2,6 +2,7 @@ package runner import ( "context" + "errors" "strings" "github.com/easyai/easyai-ai-gateway/apps/api/internal/clients" @@ -65,6 +66,9 @@ func (s *Service) applyFailoverAction(ctx context.Context, taskID string, candid } func (s *Service) applyPriorityDemotePolicy(ctx context.Context, taskID string, attemptNo int, runnerPolicy store.RunnerPolicy, candidate store.RuntimeModelCandidate, cause error, simulated bool) { + if errors.Is(cause, store.ErrRateLimited) { + return + } decision := priorityDemoteDecisionForCandidate(runnerPolicy, cause) if !decision.Demote { return diff --git a/apps/api/internal/runner/service.go b/apps/api/internal/runner/service.go index b468e72..dffba70 100644 --- a/apps/api/internal/runner/service.go +++ b/apps/api/internal/runner/service.go @@ -3,6 +3,7 @@ package runner import ( "context" "errors" + "fmt" "log/slog" "strconv" "strings" @@ -12,6 +13,8 @@ import ( 
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients" "github.com/easyai/easyai-ai-gateway/apps/api/internal/config" "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" + "github.com/jackc/pgx/v5" + "github.com/riverqueue/river" ) type Service struct { @@ -20,6 +23,7 @@ type Service struct { logger *slog.Logger clients map[string]clients.Client httpClients *httpClientCache + riverClient *river.Client[pgx.Tx] } type Result struct { @@ -27,6 +31,20 @@ type Result struct { Output map[string]any } +var ErrTaskQueued = errors.New("task queued") + +type TaskQueuedError struct { + Delay time.Duration +} + +func (e *TaskQueuedError) Error() string { + return ErrTaskQueued.Error() +} + +func (e *TaskQueuedError) Is(target error) bool { + return target == ErrTaskQueued +} + func New(cfg config.Config, db *store.Store, logger *slog.Logger) *Service { httpClients := newHTTPClientCache() return &Service{ @@ -55,6 +73,14 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut executeStartedAt := time.Now() body := normalizeRequest(task.Kind, task.Request) modelType := modelTypeFromKind(task.Kind, body) + if err := s.store.MarkTaskRunning(ctx, task.ID, modelType, body); err != nil { + return Result{}, err + } + if task.Status != "running" { + if err := s.emit(ctx, task.ID, "task.running", "running", "starting", 0.12, "task pulled from queue and started", map[string]any{"modelType": modelType}, task.RunMode == "simulation"); err != nil { + return Result{}, err + } + } if err := validateRequest(task.Kind, body); err != nil { failed, finishErr := s.failTask(ctx, task.ID, "bad_request", err.Error(), task.RunMode == "simulation", err) if finishErr != nil { @@ -83,9 +109,6 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut return Result{}, err } } - if err := s.store.MarkTaskRunning(ctx, task.ID, modelType, body); err != nil { - return Result{}, err - } if err := s.emit(ctx, task.ID, "task.progress", 
"running", "normalizing", 0.15, "request normalized", map[string]any{"modelType": modelType}, task.RunMode == "simulation"); err != nil { return Result{}, err } @@ -96,7 +119,7 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut } maxPlatforms := maxPlatformsForCandidates(candidates, runnerPolicy) maxFailoverDuration := maxFailoverDurationForCandidates(candidates, runnerPolicy) - attemptNo := 0 + attemptNo := task.AttemptCount var lastErr error for index, candidate := range candidates { if index >= maxPlatforms { @@ -251,6 +274,20 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut if lastErr != nil { message = lastErr.Error() } + if task.AsyncMode && ctx.Err() != nil { + queued, queueErr := s.requeueInterruptedAsyncTask(context.WithoutCancel(ctx), task) + if queueErr != nil { + return Result{}, queueErr + } + return Result{Task: queued, Output: queued.Result}, &TaskQueuedError{Delay: 0} + } + if task.AsyncMode && errors.Is(lastErr, store.ErrRateLimited) && store.RateLimitRetryable(lastErr) { + queued, delay, queueErr := s.requeueRateLimitedTask(ctx, task, lastErr) + if queueErr != nil { + return Result{}, queueErr + } + return Result{Task: queued, Output: queued.Result}, &TaskQueuedError{Delay: delay} + } failed, err := s.failTask(ctx, task.ID, code, message, task.RunMode == "simulation", lastErr) if err != nil { return Result{}, err @@ -261,7 +298,7 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user *auth.User, body map[string]any, candidate store.RuntimeModelCandidate, attemptNo int, onDelta clients.StreamDelta) (clients.Response, error) { simulated := isSimulation(task, candidate) if err := s.emit(ctx, task.ID, "task.attempt.started", "running", "submitting", 0.25, "client attempt started", map[string]any{"attempt": attemptNo, "clientId": candidate.ClientID}, simulated); err != nil { - 
return clients.Response{}, err + return clients.Response{}, fmt.Errorf("emit attempt started: %w", err) } attemptID, err := s.store.CreateTaskAttempt(ctx, store.CreateTaskAttemptInput{ TaskID: task.ID, @@ -276,21 +313,22 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user Metrics: attemptMetrics(candidate, attemptNo, simulated), }) if err != nil { - return clients.Response{}, err + return clients.Response{}, fmt.Errorf("create task attempt: %w", err) } reservations := s.rateLimitReservations(ctx, user, candidate, body) limitResult, err := s.store.ReserveRateLimits(ctx, task.ID, attemptID, reservations) if err != nil { - clientErr := &clients.ClientError{Code: "rate_limit", Message: err.Error(), Retryable: false} + retryable := store.RateLimitRetryable(err) + clientErr := &clients.ClientError{Code: "rate_limit", Message: err.Error(), Retryable: retryable} _ = s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{ AttemptID: attemptID, Status: "failed", - Retryable: false, - Metrics: mergeMetrics(attemptMetrics(candidate, attemptNo, simulated), map[string]any{"error": err.Error(), "retryable": false, "trace": []any{failureTraceEntry(clientErr, false)}}), + Retryable: retryable, + Metrics: mergeMetrics(attemptMetrics(candidate, attemptNo, simulated), map[string]any{"error": err.Error(), "retryable": retryable, "retryAfterMs": localRateLimitRetryAfter(err).Milliseconds(), "trace": []any{failureTraceEntry(clientErr, retryable)}}), ErrorCode: "rate_limit", ErrorMessage: err.Error(), }) - return clients.Response{}, clientErr + return clients.Response{}, &localRateLimitError{clientErr: clientErr, cause: err, retryAfter: localRateLimitRetryAfter(err)} } rateReservationsFinalized := false defer func() { @@ -301,7 +339,7 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user defer s.store.ReleaseConcurrencyLeases(context.WithoutCancel(ctx), limitResult.LeaseIDs) if err := s.store.RecordClientAssignment(ctx, 
candidate); err != nil { - return clients.Response{}, err + return clients.Response{}, fmt.Errorf("record client assignment: %w", err) } defer s.store.RecordClientRelease(context.WithoutCancel(ctx), candidate.ClientID, "") @@ -315,17 +353,25 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user ErrorCode: clients.ErrorCode(err), ErrorMessage: err.Error(), }) - return clients.Response{}, err + return clients.Response{}, fmt.Errorf("prepare http client: %w", err) } client := s.clientFor(candidate, simulated) callStartedAt := time.Now() response, err := client.Run(ctx, clients.Request{ - Kind: task.Kind, - ModelType: candidate.ModelType, - Model: task.Model, - Body: body, - Candidate: candidate, - HTTPClient: requestHTTPClient, + Kind: task.Kind, + ModelType: candidate.ModelType, + Model: task.Model, + Body: body, + Candidate: candidate, + HTTPClient: requestHTTPClient, + RemoteTaskID: task.RemoteTaskID, + RemoteTaskPayload: task.RemoteTaskPayload, + OnRemoteTaskSubmitted: func(remoteTaskID string, payload map[string]any) error { + if strings.TrimSpace(remoteTaskID) == "" { + return nil + } + return s.store.SetTaskRemoteTask(context.WithoutCancel(ctx), task.ID, attemptID, remoteTaskID, payload) + }, Stream: boolFromMap(body, "stream"), StreamDelta: onDelta, }) @@ -400,11 +446,11 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user response.Result = uploadedResult for _, progress := range response.Progress { if err := s.emit(ctx, task.ID, "task.progress", "running", progress.Phase, progress.Progress, progress.Message, progress.Payload, simulated); err != nil { - return clients.Response{}, err + return clients.Response{}, fmt.Errorf("emit task progress: %w", err) } } if err := s.store.CommitRateLimitReservations(ctx, limitResult.Reservations, tokenUsageAmounts(response.Usage)); err != nil { - return clients.Response{}, err + return clients.Response{}, fmt.Errorf("commit rate limit reservations: %w", err) } 
rateReservationsFinalized = true if err := s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{ @@ -418,7 +464,7 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user ResponseFinishedAt: response.ResponseFinishedAt, ResponseDurationMS: response.ResponseDurationMS, }); err != nil { - return clients.Response{}, err + return clients.Response{}, fmt.Errorf("finish task attempt: %w", err) } return response, nil } @@ -459,6 +505,41 @@ func (s *Service) failTask(ctx context.Context, taskID string, code string, mess return failed, nil } +func (s *Service) requeueRateLimitedTask(ctx context.Context, task store.GatewayTask, cause error) (store.GatewayTask, time.Duration, error) { + delay := localRateLimitRetryAfter(cause) + if delay <= 0 { + delay = 5 * time.Second + } + queued, err := s.store.RequeueTask(ctx, task.ID, delay) + if err != nil { + return store.GatewayTask{}, 0, err + } + payload := map[string]any{ + "code": "rate_limit", + "message": cause.Error(), + "retryAfterMs": delay.Milliseconds(), + } + if eventErr := s.emit(ctx, task.ID, "task.queued", "queued", "rate_limited", 0.2, "task queued by local rate limit", payload, task.RunMode == "simulation"); eventErr != nil { + return store.GatewayTask{}, 0, eventErr + } + return queued, delay, nil +} + +func (s *Service) requeueInterruptedAsyncTask(ctx context.Context, task store.GatewayTask) (store.GatewayTask, error) { + queued, err := s.store.RequeueTask(ctx, task.ID, 0) + if err != nil { + return store.GatewayTask{}, err + } + payload := map[string]any{"code": "worker_interrupted"} + if task.RemoteTaskID != "" { + payload["remoteTaskId"] = task.RemoteTaskID + } + if eventErr := s.emit(ctx, task.ID, "task.queued", "queued", "worker_interrupted", 0.2, "async task queued after worker interruption", payload, task.RunMode == "simulation"); eventErr != nil { + return store.GatewayTask{}, eventErr + } + return queued, nil +} + func (s *Service) withAttemptHistory(ctx context.Context, 
taskID string, metrics map[string]any) map[string]any { attempts, err := s.store.ListTaskAttempts(ctx, taskID) if err != nil { diff --git a/apps/api/internal/store/candidates.go b/apps/api/internal/store/candidates.go index ff56b12..a01409b 100644 --- a/apps/api/internal/store/candidates.go +++ b/apps/api/internal/store/candidates.go @@ -19,7 +19,7 @@ SELECT p.id::text, p.platform_key, p.name, p.provider, COALESCE(p.dynamic_priority, p.priority) AS effective_priority, m.id::text, COALESCE(m.base_model_id::text, ''), COALESCE(b.canonical_model_key, ''), COALESCE(NULLIF(m.provider_model_name, ''), m.model_name), m.model_name, COALESCE(m.model_alias, ''), - $2 AS requested_model_type, m.display_name, m.capabilities, m.capability_override, + $2::text AS requested_model_type, m.display_name, m.capabilities, m.capability_override, COALESCE(b.base_billing_config, '{}'::jsonb), m.billing_config, m.billing_config_override, m.pricing_mode, COALESCE(m.discount_factor, 0)::float8, COALESCE(m.pricing_rule_set_id::text, ''), COALESCE(b.pricing_rule_set_id::text, ''), @@ -33,21 +33,21 @@ LEFT JOIN model_catalog_providers cp ON cp.provider_key = p.provider OR cp.provi LEFT JOIN base_model_catalog b ON b.id = m.base_model_id LEFT JOIN model_runtime_policy_sets rp ON rp.id = COALESCE(m.runtime_policy_set_id, b.runtime_policy_set_id) LEFT JOIN runtime_client_states s - ON s.client_id = p.platform_key || ':' || $2 || ':' || COALESCE(NULLIF(m.provider_model_name, ''), m.model_name) + ON s.client_id = p.platform_key || ':' || $2::text || ':' || COALESCE(NULLIF(m.provider_model_name, ''), m.model_name) WHERE p.status = 'enabled' AND p.deleted_at IS NULL AND m.enabled = true - AND m.model_type @> jsonb_build_array($2) + AND m.model_type @> jsonb_build_array($2::text) AND (p.cooldown_until IS NULL OR p.cooldown_until <= now()) AND (m.cooldown_until IS NULL OR m.cooldown_until <= now()) AND ( - (COALESCE(m.model_alias, '') <> '' AND m.model_alias = $1) + (COALESCE(m.model_alias, '') <> '' 
AND m.model_alias = $1::text) OR ( COALESCE(m.model_alias, '') = '' AND ( - m.model_name = $1 - OR b.canonical_model_key = $1 - OR b.provider_model_name = $1 + m.model_name = $1::text + OR b.canonical_model_key = $1::text + OR b.provider_model_name = $1::text ) ) ) @@ -184,15 +184,15 @@ LEFT JOIN base_model_catalog b ON b.id = m.base_model_id WHERE p.status = 'enabled' AND p.deleted_at IS NULL AND m.enabled = true - AND m.model_type @> jsonb_build_array($2) + AND m.model_type @> jsonb_build_array($2::text) AND ( - (COALESCE(m.model_alias, '') <> '' AND m.model_alias = $1) + (COALESCE(m.model_alias, '') <> '' AND m.model_alias = $1::text) OR ( COALESCE(m.model_alias, '') = '' AND ( - m.model_name = $1 - OR b.canonical_model_key = $1 - OR b.provider_model_name = $1 + m.model_name = $1::text + OR b.canonical_model_key = $1::text + OR b.provider_model_name = $1::text ) ) ) diff --git a/apps/api/internal/store/postgres.go b/apps/api/internal/store/postgres.go index c308a10..73b1858 100644 --- a/apps/api/internal/store/postgres.go +++ b/apps/api/internal/store/postgres.go @@ -53,6 +53,10 @@ func (s *Store) Ping(ctx context.Context) error { return s.pool.Ping(ctx) } +func (s *Store) Pool() *pgxpool.Pool { + return s.pool +} + type Platform struct { ID string `json:"id"` Provider string `json:"provider"` @@ -374,6 +378,7 @@ type CreateTaskInput struct { Kind string `json:"kind"` Model string `json:"model"` RunMode string `json:"runMode"` + Async bool `json:"async"` Request map[string]any `json:"request"` } @@ -398,7 +403,12 @@ type GatewayTask struct { ResolvedModel string `json:"resolvedModel,omitempty"` RequestID string `json:"requestId,omitempty"` Request map[string]any `json:"request,omitempty"` + AsyncMode bool `json:"asyncMode"` + RiverJobID int64 `json:"riverJobId,omitempty"` Status string `json:"status"` + AttemptCount int `json:"attemptCount"` + RemoteTaskID string `json:"remoteTaskId,omitempty"` + RemoteTaskPayload map[string]any 
`json:"remoteTaskPayload,omitempty"` Result map[string]any `json:"result,omitempty"` Billings []any `json:"billings,omitempty"` Usage map[string]any `json:"usage"` @@ -423,7 +433,9 @@ COALESCE(gateway_tenant_id::text, ''), COALESCE(tenant_id, ''), COALESCE(tenant_ COALESCE(api_key_id, ''), COALESCE(api_key_name, ''), COALESCE(api_key_prefix, ''), COALESCE(user_group_id::text, ''), COALESCE(user_group_key, ''), model, COALESCE(model_type, ''), COALESCE(requested_model, ''), COALESCE(resolved_model, ''), COALESCE(request_id, ''), -request, status, COALESCE(result, '{}'::jsonb), COALESCE(billings, '[]'::jsonb), +request, COALESCE(async_mode, false), COALESCE(river_job_id, 0), status, COALESCE(attempt_count, 0), +COALESCE(remote_task_id, ''), COALESCE(remote_task_payload, '{}'::jsonb), +COALESCE(result, '{}'::jsonb), COALESCE(billings, '[]'::jsonb), COALESCE(usage, '{}'::jsonb), COALESCE(metrics, '{}'::jsonb), COALESCE(billing_summary, '{}'::jsonb), COALESCE(final_charge_amount, 0)::float8, COALESCE(response_started_at::text, ''), COALESCE(response_finished_at::text, ''), COALESCE(response_duration_ms, 0), COALESCE(error, ''), @@ -1675,11 +1687,11 @@ func (s *Store) CreateTask(ctx context.Context, input CreateTaskInput, user *aut INSERT INTO gateway_tasks ( kind, run_mode, user_id, gateway_user_id, user_source, gateway_tenant_id, tenant_id, tenant_key, api_key_id, api_key_name, api_key_prefix, user_group_id, user_group_key, - model, requested_model, request, status, result, billings, finished_at + model, requested_model, request, async_mode, status, result, billings, finished_at ) - VALUES ($1, $2, $3, NULLIF($4, '')::uuid, COALESCE(NULLIF($5, ''), 'gateway'), NULLIF($6, '')::uuid, NULLIF($7, ''), NULLIF($8, ''), NULLIF($9, ''), NULLIF($10, ''), NULLIF($11, ''), NULLIF($12, '')::uuid, NULLIF($13, ''), $14, $14, $15, $16, $17::jsonb, $18::jsonb, CASE WHEN $19 THEN now() ELSE NULL END) + VALUES ($1, $2, $3, NULLIF($4, '')::uuid, COALESCE(NULLIF($5, ''), 'gateway'), 
NULLIF($6, '')::uuid, NULLIF($7, ''), NULLIF($8, ''), NULLIF($9, ''), NULLIF($10, ''), NULLIF($11, ''), NULLIF($12, '')::uuid, NULLIF($13, ''), $14, $14, $15, $16, $17, $18::jsonb, $19::jsonb, CASE WHEN $20 THEN now() ELSE NULL END) RETURNING `+gatewayTaskColumns, - input.Kind, runMode, user.ID, user.GatewayUserID, user.Source, user.GatewayTenantID, user.TenantID, user.TenantKey, user.APIKeyID, user.APIKeyName, user.APIKeyPrefix, user.UserGroupID, user.UserGroupKey, input.Model, requestBody, status, resultBody, billingsBody, false, + input.Kind, runMode, user.ID, user.GatewayUserID, user.Source, user.GatewayTenantID, user.TenantID, user.TenantKey, user.APIKeyID, user.APIKeyName, user.APIKeyPrefix, user.UserGroupID, user.UserGroupKey, input.Model, requestBody, input.Async, status, resultBody, billingsBody, false, )) if err != nil { return GatewayTask{}, err @@ -1689,7 +1701,7 @@ func (s *Store) CreateTask(ctx context.Context, input CreateTaskInput, user *aut payload, _ := json.Marshal(event.Payload) if _, err := tx.Exec(ctx, ` INSERT INTO gateway_task_events (task_id, seq, event_type, status, phase, progress, message, payload, simulated) -VALUES ($1::uuid, $2, $3, NULLIF($4, ''), NULLIF($5, ''), $6, NULLIF($7, ''), $8::jsonb, $9)`, +VALUES ($1::uuid, $2, $3::text, NULLIF($4::text, ''), NULLIF($5::text, ''), $6, NULLIF($7::text, ''), $8::jsonb, $9)`, task.ID, event.Seq, event.EventType, event.Status, event.Phase, event.Progress, event.Message, string(payload), event.Simulated, ); err != nil { return GatewayTask{}, err @@ -1730,6 +1742,7 @@ func scanGatewayTask(scanner taskScanner) (GatewayTask, error) { var usageBytes []byte var metricsBytes []byte var billingSummaryBytes []byte + var remoteTaskPayloadBytes []byte if err := scanner.Scan( &task.ID, &task.Kind, @@ -1751,7 +1764,12 @@ func scanGatewayTask(scanner taskScanner) (GatewayTask, error) { &task.ResolvedModel, &task.RequestID, &requestBytes, + &task.AsyncMode, + &task.RiverJobID, &task.Status, + 
&task.AttemptCount, + &task.RemoteTaskID, + &remoteTaskPayloadBytes, &resultBytes, &billingsBytes, &usageBytes, @@ -1771,6 +1789,7 @@ func scanGatewayTask(scanner taskScanner) (GatewayTask, error) { return GatewayTask{}, err } task.Request = decodeObject(requestBytes) + task.RemoteTaskPayload = decodeObject(remoteTaskPayloadBytes) task.Result = decodeObject(resultBytes) task.Billings = decodeArray(billingsBytes) task.Usage = decodeObject(usageBytes) diff --git a/apps/api/internal/store/rate_limit_status.go b/apps/api/internal/store/rate_limit_status.go index ca177df..2bb0152 100644 --- a/apps/api/internal/store/rate_limit_status.go +++ b/apps/api/internal/store/rate_limit_status.go @@ -31,6 +31,7 @@ type ModelRateLimitStatus struct { PlatformCooldownUntil string `json:"platformCooldownUntil,omitempty"` ModelCooldownUntil string `json:"modelCooldownUntil,omitempty"` Concurrent RateLimitMetricStatus `json:"concurrent"` + QueuedTasks float64 `json:"queuedTasks"` RPM RateLimitMetricStatus `json:"rpm"` TPM RateLimitMetricStatus `json:"tpm"` LoadRatio float64 `json:"loadRatio"` @@ -38,15 +39,16 @@ type ModelRateLimitStatus struct { func (s *Store) ListModelRateLimitStatuses(ctx context.Context) ([]ModelRateLimitStatus, error) { rows, err := s.pool.Query(ctx, ` -SELECT m.id::text, m.platform_id::text, p.name, p.provider, - m.model_name, COALESCE(NULLIF(m.provider_model_name, ''), m.model_name), COALESCE(m.model_alias, ''), - m.model_type, m.display_name, m.enabled, - p.rate_limit_policy, COALESCE(rp.rate_limit_policy, '{}'::jsonb), COALESCE(NULLIF(m.runtime_policy_override, '{}'::jsonb), b.runtime_policy_override, '{}'::jsonb), m.rate_limit_policy, - COALESCE(to_char(p.cooldown_until AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.MS"Z"'), ''), - COALESCE(to_char(m.cooldown_until AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.MS"Z"'), ''), - COALESCE(con.active, 0)::float8, - COALESCE(rpm.used_value, 0)::float8, COALESCE(rpm.reserved_value, 0)::float8, 
COALESCE(rpm.reset_at::text, ''), - COALESCE(tpm.used_value, 0)::float8, COALESCE(tpm.reserved_value, 0)::float8, COALESCE(tpm.reset_at::text, '') + SELECT m.id::text, m.platform_id::text, p.name, p.provider, + m.model_name, COALESCE(NULLIF(m.provider_model_name, ''), m.model_name), COALESCE(m.model_alias, ''), + m.model_type, m.display_name, m.enabled, + p.rate_limit_policy, COALESCE(rp.rate_limit_policy, '{}'::jsonb), COALESCE(NULLIF(m.runtime_policy_override, '{}'::jsonb), b.runtime_policy_override, '{}'::jsonb), m.rate_limit_policy, + COALESCE(to_char(p.cooldown_until AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.MS"Z"'), ''), + COALESCE(to_char(m.cooldown_until AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.MS"Z"'), ''), + COALESCE(con.active, 0)::float8, + COALESCE(queued.waiting, 0)::float8, + COALESCE(rpm.used_value, 0)::float8, COALESCE(rpm.reserved_value, 0)::float8, COALESCE(rpm.reset_at::text, ''), + COALESCE(tpm.used_value, 0)::float8, COALESCE(tpm.reserved_value, 0)::float8, COALESCE(tpm.reset_at::text, '') FROM platform_models m JOIN integration_platforms p ON p.id = m.platform_id LEFT JOIN base_model_catalog b ON b.id = m.base_model_id @@ -59,6 +61,19 @@ LEFT JOIN ( AND expires_at > now() GROUP BY scope_key ) con ON con.scope_key = m.id::text +LEFT JOIN ( + SELECT latest.platform_model_id, COUNT(*) AS waiting + FROM ( + SELECT DISTINCT ON (a.task_id) a.task_id, a.platform_model_id::text AS platform_model_id + FROM gateway_tasks t + JOIN gateway_task_attempts a ON a.task_id = t.id + WHERE t.async_mode = true + AND t.status = 'queued' + AND a.platform_model_id IS NOT NULL + ORDER BY a.task_id, a.attempt_no DESC, a.started_at DESC + ) latest + GROUP BY latest.platform_model_id +) queued ON queued.platform_model_id = m.id::text LEFT JOIN ( SELECT DISTINCT ON (scope_key) scope_key, used_value, reserved_value, reset_at FROM gateway_rate_limit_counters @@ -93,6 +108,7 @@ ORDER BY p.priority ASC, m.model_name ASC`) var platformCooldownUntil string var 
modelCooldownUntil string var concurrentCurrent float64 + var queuedTasks float64 var rpmUsed float64 var rpmReserved float64 var rpmResetAt string @@ -117,6 +133,7 @@ ORDER BY p.priority ASC, m.model_name ASC`) &platformCooldownUntil, &modelCooldownUntil, &concurrentCurrent, + &queuedTasks, &rpmUsed, &rpmReserved, &rpmResetAt, @@ -136,6 +153,7 @@ ORDER BY p.priority ASC, m.model_name ASC`) item.PlatformCooldownUntil = platformCooldownUntil item.ModelCooldownUntil = modelCooldownUntil item.RateLimitPolicy = policy + item.QueuedTasks = queuedTasks item.Concurrent = metricStatus(concurrentCurrent, concurrentCurrent, 0, rateLimitForMetric(policy, "concurrent"), "") item.RPM = metricStatus(rpmUsed+rpmReserved, rpmUsed, rpmReserved, rateLimitForMetric(policy, "rpm"), rpmResetAt) item.TPM = metricStatus(tpmUsed+tpmReserved, tpmUsed, tpmReserved, tpmLimit(policy), tpmResetAt) diff --git a/apps/api/internal/store/rate_limits.go b/apps/api/internal/store/rate_limits.go index 7bea2cf..d00e9a7 100644 --- a/apps/api/internal/store/rate_limits.go +++ b/apps/api/internal/store/rate_limits.go @@ -3,6 +3,7 @@ package store import ( "context" "errors" + "fmt" "time" "github.com/jackc/pgx/v5" @@ -13,6 +14,7 @@ type RuntimeRecoveryResult struct { ReleasedRateReservations int64 `json:"releasedRateReservations"` FailedAttempts int64 `json:"failedAttempts"` FailedTasks int64 `json:"failedTasks"` + RequeuedAsyncTasks int64 `json:"requeuedAsyncTasks"` } func (s *Store) ReserveRateLimits(ctx context.Context, taskID string, attemptID string, reservations []RateLimitReservation) (RateLimitResult, error) { @@ -28,7 +30,11 @@ func (s *Store) ReserveRateLimits(ctx context.Context, taskID string, attemptID continue } if reservation.Metric == "" || reservation.Amount > reservation.Limit { - return RateLimitResult{}, ErrRateLimited + return RateLimitResult{}, &RateLimitExceededError{ + Metric: reservation.Metric, + Message: fmt.Sprintf("rate limit exceeded: %s request amount %.0f is greater than 
limit %.0f", reservation.Metric, reservation.Amount, reservation.Limit), + Retryable: false, + } } if reservation.WindowSeconds <= 0 { reservation.WindowSeconds = 60 @@ -55,8 +61,10 @@ func reserveConcurrencyLease(ctx context.Context, tx pgx.Tx, taskID string, atte reservation.LeaseTTLSeconds = 120 } var active float64 + var nextAvailableAt time.Time if err := tx.QueryRow(ctx, ` -SELECT COALESCE(SUM(lease_value), 0)::float8 +SELECT COALESCE(SUM(lease_value), 0)::float8, + COALESCE(MIN(expires_at), now() + ($3::int * interval '1 second')) FROM gateway_concurrency_leases WHERE scope_type = $1 AND scope_key = $2 @@ -64,11 +72,17 @@ WHERE scope_type = $1 AND expires_at > now()`, reservation.ScopeType, reservation.ScopeKey, - ).Scan(&active); err != nil { + reservation.LeaseTTLSeconds, + ).Scan(&active, &nextAvailableAt); err != nil { return "", err } if active+reservation.Amount > reservation.Limit { - return "", ErrRateLimited + return "", &RateLimitExceededError{ + Metric: reservation.Metric, + Message: fmt.Sprintf("rate limit exceeded: concurrent active %.0f plus request %.0f is greater than limit %.0f", active, reservation.Amount, reservation.Limit), + RetryAfter: concurrencyRetryAfter(nextAvailableAt), + Retryable: true, + } } var leaseID string if err := tx.QueryRow(ctx, ` @@ -92,13 +106,16 @@ func reserveCounterWindow(ctx context.Context, tx pgx.Tx, taskID string, attempt reservedAmount := reservation.Amount var windowStart time.Time err := tx.QueryRow(ctx, ` +WITH bounds AS ( + SELECT + to_timestamp(floor(extract(epoch FROM now()) / $7::int) * $7::int) AS window_start, + to_timestamp(floor(extract(epoch FROM now()) / $7::int) * $7::int) + ($7::int * interval '1 second') AS reset_at +) INSERT INTO gateway_rate_limit_counters ( scope_type, scope_key, metric, window_start, limit_value, used_value, reserved_value, reset_at ) -VALUES ( - $1, $2, $3, date_trunc('minute', now()), $4, $5, $6, - date_trunc('minute', now()) + ($7::int * interval '1 second') -) +SELECT 
$1, $2, $3, bounds.window_start, $4, $5, $6, bounds.reset_at +FROM bounds ON CONFLICT (scope_type, scope_key, metric, window_start) DO UPDATE SET limit_value = EXCLUDED.limit_value, used_value = gateway_rate_limit_counters.used_value + EXCLUDED.used_value, @@ -117,7 +134,28 @@ RETURNING window_start`, ).Scan(&windowStart) if err != nil { if errors.Is(err, pgx.ErrNoRows) { - return RateLimitReservation{}, ErrRateLimited + resetAt := time.Now().Add(time.Duration(reservation.WindowSeconds) * time.Second) + _ = tx.QueryRow(ctx, ` +WITH bounds AS ( + SELECT to_timestamp(floor(extract(epoch FROM now()) / $4::int) * $4::int) AS window_start +) +SELECT counters.reset_at +FROM gateway_rate_limit_counters counters +JOIN bounds ON counters.window_start = bounds.window_start +WHERE scope_type = $1 + AND scope_key = $2 + AND metric = $3`, + reservation.ScopeType, + reservation.ScopeKey, + reservation.Metric, + reservation.WindowSeconds, + ).Scan(&resetAt) + return RateLimitReservation{}, &RateLimitExceededError{ + Metric: reservation.Metric, + Message: fmt.Sprintf("rate limit exceeded: %s window has no remaining capacity", reservation.Metric), + RetryAfter: retryAfterUntil(resetAt), + Retryable: true, + } } return RateLimitReservation{}, err } @@ -144,6 +182,28 @@ RETURNING id::text`, return reservation, nil } +func retryAfterUntil(when time.Time) time.Duration { + if when.IsZero() { + return 0 + } + duration := time.Until(when) + if duration < time.Second { + return time.Second + } + return duration +} + +func concurrencyRetryAfter(leaseExpiresAt time.Time) time.Duration { + if leaseExpiresAt.IsZero() { + return time.Second + } + duration := time.Until(leaseExpiresAt) + if duration <= time.Second { + return time.Second + } + return duration +} + func (s *Store) CommitRateLimitReservations(ctx context.Context, reservations []RateLimitReservation, actualByMetric map[string]float64) error { return s.finishRateLimitReservations(ctx, reservations, actualByMetric, "committed", 
"success") } @@ -189,23 +249,26 @@ RETURNING scope_type, scope_key, metric, window_start, reserved_amount::float8`) if err != nil { return RuntimeRecoveryResult{}, err } + reservations := make([]RateLimitReservation, 0) for rows.Next() { var reservation RateLimitReservation if err := rows.Scan(&reservation.ScopeType, &reservation.ScopeKey, &reservation.Metric, &reservation.WindowStart, &reservation.Amount); err != nil { rows.Close() return RuntimeRecoveryResult{}, err } - if err := releaseCounterReservation(ctx, tx, reservation.ScopeType, reservation.ScopeKey, reservation.Metric, reservation.WindowStart, reservation.Amount); err != nil { - rows.Close() - return RuntimeRecoveryResult{}, err - } - result.ReleasedRateReservations++ + reservations = append(reservations, reservation) } if err := rows.Err(); err != nil { rows.Close() return RuntimeRecoveryResult{}, err } rows.Close() + for _, reservation := range reservations { + if err := releaseCounterReservation(ctx, tx, reservation.ScopeType, reservation.ScopeKey, reservation.Metric, reservation.WindowStart, reservation.Amount); err != nil { + return RuntimeRecoveryResult{}, err + } + } + result.ReleasedRateReservations = int64(len(reservations)) tag, err := tx.Exec(ctx, ` UPDATE gateway_concurrency_leases @@ -220,7 +283,7 @@ WHERE released_at IS NULL tag, err = tx.Exec(ctx, ` UPDATE gateway_task_attempts SET status = 'failed', - retryable = false, + retryable = true, error_code = 'server_restarted', error_message = 'attempt interrupted by service restart', finished_at = now() @@ -230,6 +293,57 @@ WHERE status = 'running'`) } result.FailedAttempts = tag.RowsAffected() + asyncTaskRows, err := tx.Query(ctx, ` +UPDATE gateway_tasks +SET status = 'queued', + error = NULL, + error_code = NULL, + error_message = NULL, + locked_by = NULL, + locked_at = NULL, + heartbeat_at = NULL, + next_run_at = now(), + finished_at = NULL, + updated_at = now() +WHERE async_mode = true + AND status = 'running' +RETURNING id::text`) + if 
err != nil { + return RuntimeRecoveryResult{}, err + } + asyncTaskIDs := make([]string, 0) + for asyncTaskRows.Next() { + var taskID string + if err := asyncTaskRows.Scan(&taskID); err != nil { + asyncTaskRows.Close() + return RuntimeRecoveryResult{}, err + } + asyncTaskIDs = append(asyncTaskIDs, taskID) + } + if err := asyncTaskRows.Err(); err != nil { + asyncTaskRows.Close() + return RuntimeRecoveryResult{}, err + } + asyncTaskRows.Close() + for _, taskID := range asyncTaskIDs { + if _, err := tx.Exec(ctx, ` +INSERT INTO gateway_task_events (task_id, seq, event_type, status, phase, progress, message, payload, simulated) +VALUES ( + $1::uuid, + COALESCE((SELECT MAX(seq) + 1 FROM gateway_task_events WHERE task_id = $1::uuid), 1), + 'task.recovered', + 'queued', + 'recovered', + 0.2, + 'async task recovered after service restart', + '{"code":"server_restarted"}'::jsonb, + false +)`, taskID); err != nil { + return RuntimeRecoveryResult{}, err + } + } + result.RequeuedAsyncTasks = int64(len(asyncTaskIDs)) + taskRows, err := tx.Query(ctx, ` UPDATE gateway_tasks SET status = 'failed', @@ -238,7 +352,8 @@ SET status = 'failed', error_message = 'task interrupted by service restart', finished_at = now(), updated_at = now() -WHERE status IN ('queued', 'running') +WHERE async_mode = false + AND status = 'running' RETURNING id::text`) if err != nil { return RuntimeRecoveryResult{}, err @@ -301,9 +416,9 @@ func (s *Store) finishRateLimitReservations(ctx context.Context, reservations [] var stored RateLimitReservation err := tx.QueryRow(ctx, ` UPDATE gateway_rate_limit_reservations -SET status = $2, - reason = NULLIF($3, ''), - actual_amount = CASE WHEN $2 = 'committed' THEN $4 ELSE actual_amount END, +SET status = $2::text, + reason = NULLIF($3::text, ''), + actual_amount = CASE WHEN $2::text = 'committed' THEN $4 ELSE actual_amount END, finalized_at = now(), updated_at = now() WHERE id = $1::uuid diff --git a/apps/api/internal/store/runtime_types.go 
b/apps/api/internal/store/runtime_types.go index f08b848..72d4ffc 100644 --- a/apps/api/internal/store/runtime_types.go +++ b/apps/api/internal/store/runtime_types.go @@ -32,6 +32,43 @@ func ModelCandidateErrorCode(err error) string { return "no_model_candidate" } +type RateLimitExceededError struct { + Metric string + Message string + RetryAfter time.Duration + Retryable bool +} + +func (e *RateLimitExceededError) Error() string { + if strings.TrimSpace(e.Message) != "" { + return e.Message + } + if strings.TrimSpace(e.Metric) != "" { + return "rate limit exceeded for " + e.Metric + } + return ErrRateLimited.Error() +} + +func (e *RateLimitExceededError) Unwrap() error { + return ErrRateLimited +} + +func RateLimitRetryAfter(err error) time.Duration { + var limitErr *RateLimitExceededError + if errors.As(err, &limitErr) && limitErr.RetryAfter > 0 { + return limitErr.RetryAfter + } + return 0 +} + +func RateLimitRetryable(err error) bool { + var limitErr *RateLimitExceededError + if errors.As(err, &limitErr) { + return limitErr.Retryable + } + return errors.Is(err, ErrRateLimited) +} + type CreatePlatformModelInput struct { PlatformID string `json:"platformId"` BaseModelID string `json:"baseModelId"` @@ -132,6 +169,11 @@ type CreateTaskAttemptInput struct { Metrics map[string]any } +type AsyncTaskQueueItem struct { + ID string + Priority int +} + type FinishTaskAttemptInput struct { AttemptID string Status string diff --git a/apps/api/internal/store/tasks_runtime.go b/apps/api/internal/store/tasks_runtime.go index ec88d2d..a3a1f42 100644 --- a/apps/api/internal/store/tasks_runtime.go +++ b/apps/api/internal/store/tasks_runtime.go @@ -157,7 +157,7 @@ func (s *Store) MarkTaskRunning(ctx context.Context, taskID string, modelType st _, err := s.pool.Exec(ctx, ` UPDATE gateway_tasks SET status = 'running', - model_type = NULLIF($2, ''), + model_type = NULLIF($2::text, ''), normalized_request = $3::jsonb, locked_at = now(), heartbeat_at = now(), @@ -166,6 +166,124 @@ 
WHERE id = $1::uuid`, taskID, modelType, string(normalizedJSON)) return err } +func (s *Store) ClaimAsyncQueuedTask(ctx context.Context, workerID string) (GatewayTask, error) { + return scanGatewayTask(s.pool.QueryRow(ctx, ` +WITH picked AS ( + SELECT id AS task_id + FROM gateway_tasks + WHERE async_mode = true + AND status = 'queued' + AND next_run_at <= now() + ORDER BY priority ASC, created_at ASC + LIMIT 1 + FOR UPDATE SKIP LOCKED +) +UPDATE gateway_tasks t +SET status = 'running', + locked_by = NULLIF($1::text, ''), + locked_at = now(), + heartbeat_at = now(), + updated_at = now() +FROM picked +WHERE t.id = picked.task_id +RETURNING `+gatewayTaskColumns, workerID)) +} + +func (s *Store) RequeueTask(ctx context.Context, taskID string, delay time.Duration) (GatewayTask, error) { + if delay < time.Second { + delay = time.Second + } + if delay > 10*time.Minute { + delay = 10 * time.Minute + } + nextRunAt := time.Now().Add(delay) + return scanGatewayTask(s.pool.QueryRow(ctx, ` +UPDATE gateway_tasks +SET status = 'queued', + locked_by = NULL, + locked_at = NULL, + heartbeat_at = NULL, + next_run_at = $2::timestamptz, + error = NULL, + error_code = NULL, + error_message = NULL, + updated_at = now() +WHERE id = $1::uuid +RETURNING `+gatewayTaskColumns, taskID, nextRunAt)) +} + +func (s *Store) SetTaskRiverJobID(ctx context.Context, taskID string, riverJobID int64) error { + if riverJobID <= 0 { + return nil + } + _, err := s.pool.Exec(ctx, ` +UPDATE gateway_tasks +SET river_job_id = $2, + updated_at = now() +WHERE id = $1::uuid`, taskID, riverJobID) + return err +} + +func (s *Store) SetTaskRemoteTask(ctx context.Context, taskID string, attemptID string, remoteTaskID string, payload map[string]any) error { + payloadJSON, _ := json.Marshal(emptyObjectIfNil(payload)) + return pgx.BeginFunc(ctx, s.pool, func(tx pgx.Tx) error { + if _, err := tx.Exec(ctx, ` +UPDATE gateway_tasks +SET remote_task_id = NULLIF($2::text, ''), + remote_task_payload = $3::jsonb, + updated_at = 
now() +WHERE id = $1::uuid`, + taskID, + remoteTaskID, + string(payloadJSON), + ); err != nil { + return err + } + if strings.TrimSpace(attemptID) == "" { + return nil + } + _, err := tx.Exec(ctx, ` +UPDATE gateway_task_attempts +SET remote_task_id = NULLIF($2::text, ''), + response_snapshot = COALESCE(response_snapshot, '{}'::jsonb) || jsonb_build_object('remote_task_payload', $3::jsonb) +WHERE id = $1::uuid`, + attemptID, + remoteTaskID, + string(payloadJSON), + ) + return err + }) +} + +func (s *Store) ListRecoverableAsyncTasks(ctx context.Context, limit int) ([]AsyncTaskQueueItem, error) { + if limit <= 0 { + limit = 500 + } + rows, err := s.pool.Query(ctx, ` +SELECT id::text, priority +FROM gateway_tasks +WHERE async_mode = true + AND status IN ('queued', 'running') +ORDER BY priority ASC, created_at ASC +LIMIT $1`, limit) + if err != nil { + return nil, err + } + defer rows.Close() + items := make([]AsyncTaskQueueItem, 0) + for rows.Next() { + var item AsyncTaskQueueItem + if err := rows.Scan(&item.ID, &item.Priority); err != nil { + return nil, err + } + items = append(items, item) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + func (s *Store) CreateTaskAttempt(ctx context.Context, input CreateTaskAttemptInput) (string, error) { requestJSON, _ := json.Marshal(emptyObjectIfNil(input.RequestSnapshot)) metricsJSON, _ := json.Marshal(emptyObjectIfNil(input.Metrics)) @@ -182,7 +300,7 @@ INSERT INTO gateway_task_attempts ( status, simulated, request_snapshot, metrics ) VALUES ( - $1::uuid, $2, NULLIF($3, '')::uuid, NULLIF($4, '')::uuid, NULLIF($5, ''), $6, + $1::uuid, $2::int, NULLIF($3::text, '')::uuid, NULLIF($4::text, '')::uuid, NULLIF($5::text, ''), $6, $7, $8, $9::jsonb, $10::jsonb ) RETURNING id::text`, @@ -202,7 +320,7 @@ RETURNING id::text`, } if _, err := tx.Exec(ctx, ` UPDATE gateway_tasks -SET attempt_count = GREATEST(attempt_count, $2), updated_at = now() +SET attempt_count = GREATEST(attempt_count, $2::int), 
updated_at = now() WHERE id = $1::uuid`, input.TaskID, input.AttemptNo); err != nil { return "", err } @@ -252,7 +370,7 @@ SET metrics = jsonb_set( true ) WHERE task_id = $1::uuid - AND attempt_no = $2`, taskID, attemptNo, string(entryJSON)) + AND attempt_no = $2::int`, taskID, attemptNo, string(entryJSON)) return err } @@ -386,17 +504,17 @@ func (s *Store) FinishTaskAttempt(ctx context.Context, input FinishTaskAttemptIn metricsJSON, _ := json.Marshal(emptyObjectIfNil(input.Metrics)) _, err := s.pool.Exec(ctx, ` UPDATE gateway_task_attempts -SET status = $2, +SET status = $2::text, retryable = $3, - request_id = NULLIF($4, ''), + request_id = NULLIF($4::text, ''), usage = $5::jsonb, metrics = $6::jsonb, response_snapshot = $7::jsonb, response_started_at = $8::timestamptz, response_finished_at = $9::timestamptz, response_duration_ms = $10, - error_code = NULLIF($11, ''), - error_message = NULLIF($12, ''), + error_code = NULLIF($11::text, ''), + error_message = NULLIF($12::text, ''), finished_at = now() WHERE id = $1::uuid`, input.AttemptID, @@ -438,6 +556,9 @@ SET status = 'succeeded', error = NULL, error_code = NULL, error_message = NULL, + locked_by = NULL, + locked_at = NULL, + heartbeat_at = NULL, finished_at = now(), updated_at = now() WHERE id = $1::uuid`, @@ -561,14 +682,17 @@ func (s *Store) FinishTaskFailure(ctx context.Context, input FinishTaskFailureIn if _, err := s.pool.Exec(ctx, ` UPDATE gateway_tasks SET status = 'failed', - error = NULLIF($2, ''), - error_code = NULLIF($3, ''), - error_message = NULLIF($2, ''), - request_id = NULLIF($4, ''), + error = NULLIF($2::text, ''), + error_code = NULLIF($3::text, ''), + error_message = NULLIF($2::text, ''), + request_id = NULLIF($4::text, ''), metrics = $5::jsonb, response_started_at = $6::timestamptz, response_finished_at = $7::timestamptz, response_duration_ms = $8, + locked_by = NULL, + locked_at = NULL, + heartbeat_at = NULL, finished_at = now(), updated_at = now() WHERE id = $1::uuid`, @@ -604,7 +728,7 
@@ WITH next_seq AS ( WHERE task_id = $1::uuid ) INSERT INTO gateway_task_events (task_id, seq, event_type, status, phase, progress, message, payload, simulated) -SELECT $1::uuid, next_seq.seq, $2, NULLIF($3, ''), NULLIF($4, ''), $5, NULLIF($6, ''), $7::jsonb, $8 +SELECT $1::uuid, next_seq.seq, $2::text, NULLIF($3::text, ''), NULLIF($4::text, ''), $5, NULLIF($6::text, ''), $7::jsonb, $8 FROM next_seq RETURNING id::text, task_id::text, seq, event_type, COALESCE(status, ''), COALESCE(phase, ''), COALESCE(progress, 0)::float8, COALESCE(message, ''), payload, simulated, created_at`, @@ -688,7 +812,7 @@ func (s *Store) RecordClientRelease(ctx context.Context, clientID string, lastEr _, err := s.pool.Exec(ctx, ` UPDATE runtime_client_states SET running_count = GREATEST(running_count - 1, 0), - last_error = NULLIF($2, ''), + last_error = NULLIF($2::text, ''), updated_at = now() WHERE client_id = $1`, clientID, lastError) return err diff --git a/apps/api/migrations/0031_async_task_queue.sql b/apps/api/migrations/0031_async_task_queue.sql new file mode 100644 index 0000000..2abfd41 --- /dev/null +++ b/apps/api/migrations/0031_async_task_queue.sql @@ -0,0 +1,6 @@ +ALTER TABLE IF EXISTS gateway_tasks + ADD COLUMN IF NOT EXISTS async_mode boolean NOT NULL DEFAULT false; + +CREATE INDEX IF NOT EXISTS idx_gateway_tasks_async_queue + ON gateway_tasks(async_mode, status, next_run_at, priority, created_at) + WHERE async_mode = true; diff --git a/apps/api/migrations/0032_river_async_queue.sql b/apps/api/migrations/0032_river_async_queue.sql new file mode 100644 index 0000000..f1d307d --- /dev/null +++ b/apps/api/migrations/0032_river_async_queue.sql @@ -0,0 +1,19 @@ +ALTER TABLE IF EXISTS gateway_tasks + ADD COLUMN IF NOT EXISTS river_job_id bigint, + ADD COLUMN IF NOT EXISTS remote_task_id text, + ADD COLUMN IF NOT EXISTS remote_task_payload jsonb; + +UPDATE gateway_tasks +SET remote_task_payload = '{}'::jsonb +WHERE remote_task_payload IS NULL; + +ALTER TABLE IF EXISTS 
gateway_tasks + ALTER COLUMN remote_task_payload SET DEFAULT '{}'::jsonb; + +CREATE INDEX IF NOT EXISTS idx_gateway_tasks_river_job + ON gateway_tasks(river_job_id) + WHERE river_job_id IS NOT NULL; + +CREATE INDEX IF NOT EXISTS idx_gateway_tasks_async_recover + ON gateway_tasks(async_mode, status, priority, created_at) + WHERE async_mode = true AND status IN ('queued', 'running'); diff --git a/apps/web/src/App.tsx b/apps/web/src/App.tsx index 6d7d557..0b2bcc4 100644 --- a/apps/web/src/App.tsx +++ b/apps/web/src/App.tsx @@ -46,7 +46,6 @@ import { getHealth, getNetworkProxyConfig, getRunnerPolicy, - getTask, getWalletSummary, listAccessRules, listAuditLogs, @@ -72,6 +71,7 @@ import { listUserGroups, listUsers, loginLocalAccount, + pollTaskUntilSettled, registerLocalAccount, replacePlatformModels, setUserWalletBalance, @@ -788,7 +788,11 @@ export function App() { setCoreMessage(''); try { const response = await runTask(credential, taskForm); - const detail = await getTask(credential, response.task.id); + const syncTask = (detail: GatewayTask) => { + setTaskResult(detail); + setTasks((current) => [detail, ...current.filter((item) => item.id !== detail.id)]); + }; + const detail = await pollTaskUntilSettled(credential, response.task, { onUpdate: syncTask }); setTaskResult(detail); setTasks((current) => [detail, ...current.filter((item) => item.id !== detail.id)]); invalidateDataKeys('tasks', 'wallet', 'walletTransactions'); diff --git a/apps/web/src/api.ts b/apps/web/src/api.ts index 0d96d30..7d256f9 100644 --- a/apps/web/src/api.ts +++ b/apps/web/src/api.ts @@ -521,6 +521,7 @@ export async function createChatTask( ): Promise<{ task: GatewayTask; next: Record }> { return request<{ task: GatewayTask; next: Record }>('/api/v1/chat/completions', { body: input, + headers: { 'X-Async': 'true' }, method: 'POST', token, }); @@ -598,6 +599,7 @@ export async function createImageGenerationTask( ): Promise<{ task: GatewayTask; next: Record }> { return request<{ task: GatewayTask; 
next: Record }>('/api/v1/images/generations', { body: input, + headers: { 'X-Async': 'true' }, method: 'POST', token, }); @@ -609,6 +611,7 @@ export async function createImageEditTask( ): Promise<{ task: GatewayTask; next: Record }> { return request<{ task: GatewayTask; next: Record }>('/api/v1/images/edits', { body: input, + headers: { 'X-Async': 'true' }, method: 'POST', token, }); @@ -636,6 +639,7 @@ export async function createVideoGenerationTask( ): Promise<{ task: GatewayTask; next: Record }> { return request<{ task: GatewayTask; next: Record }>('/api/v1/videos/generations', { body: input, + headers: { 'X-Async': 'true' }, method: 'POST', token, }); @@ -656,6 +660,33 @@ export async function getTask(token: string, taskId: string): Promise(`/api/workspace/tasks/${taskId}`, { token }); } +export async function pollTaskUntilSettled( + token: string, + task: GatewayTask, + options: { intervalMs?: number; maxAttempts?: number | null; onUpdate?: (task: GatewayTask) => void } = {}, +): Promise { + let detail = task; + const intervalMs = options.intervalMs ?? 1200; + const maxAttempts = options.maxAttempts ?? Number.POSITIVE_INFINITY; + for (let attempt = 0; attempt < maxAttempts; attempt += 1) { + if (!taskIsPending(detail.status)) return detail; + try { + detail = await getTask(token, detail.id); + options.onUpdate?.(detail); + if (!taskIsPending(detail.status)) return detail; + } catch { + // Backend restarts or short network gaps should not turn a durable task into a failed UI run. + // Only an explicit terminal task status from the task detail endpoint settles the run. 
+ } + await delay(intervalMs); + } + return detail; +} + +export function taskIsPending(status: string) { + return status === 'queued' || status === 'running' || status === 'submitting'; +} + export async function listTasks(token: string, query: WorkspaceTaskQuery): Promise> { const search = new URLSearchParams({ page: String(query.page), @@ -707,9 +738,9 @@ export async function getNetworkProxyConfig(token: string): Promise( path: string, - options: { token?: string; auth?: boolean; method?: string; body?: unknown } = {}, + options: { token?: string; auth?: boolean; method?: string; body?: unknown; headers?: Record } = {}, ): Promise { - const headers: Record = {}; + const headers: Record = { ...(options.headers ?? {}) }; if (options.auth !== false && options.token) { headers.Authorization = `Bearer ${options.token}`; } @@ -731,6 +762,10 @@ async function request( return response.json() as Promise; } +function delay(ms: number) { + return new Promise((resolve) => window.setTimeout(resolve, ms)); +} + function parseErrorMessage(body: string) { return formatGatewayErrorDetails(parseErrorDetails(body)); } diff --git a/apps/web/src/pages/PlaygroundPage.tsx b/apps/web/src/pages/PlaygroundPage.tsx index 2c93f12..8ef1a34 100644 --- a/apps/web/src/pages/PlaygroundPage.tsx +++ b/apps/web/src/pages/PlaygroundPage.tsx @@ -21,7 +21,7 @@ import { mermaid } from '@streamdown/mermaid'; import type { GatewayApiKey, GatewayTask, PlatformModel } from '@easyai-ai-gateway/contracts'; import { Bot, ChevronDown, Image as ImageIcon, MessageSquarePlus, Paperclip, Send, Sparkles, Video } from 'lucide-react'; import { Badge, Button, Select, Textarea } from '../components/ui'; -import { GatewayApiError, createImageGenerationTask, createVideoGenerationTask, getTask, streamChatCompletionText } from '../api'; +import { GatewayApiError, createImageGenerationTask, createVideoGenerationTask, pollTaskUntilSettled, streamChatCompletionText, taskIsPending } from '../api'; import type { 
PlaygroundMode } from '../types'; import { defaultMediaGenerationSettings, @@ -170,14 +170,12 @@ export function PlaygroundPage(props: { resumableRuns.forEach((run) => { if (!run.task?.id) return; resumedTaskIdsRef.current.add(run.task.id); - void pollTaskUntilSettled(credential, run.task) + void pollTaskUntilSettled(credential, run.task, { + onUpdate: (detail) => updateMediaRunFromTask(run.localId, detail), + }) .then((detail) => { if (!isMountedRef.current) return; - setMediaRuns((current) => updateMediaRun(current, run.localId, { - error: gatewayTaskErrorText(detail, '任务执行失败'), - status: detail.status, - task: detail, - })); + updateMediaRunFromTask(run.localId, detail); }) .catch((err) => { if (!isMountedRef.current) return; @@ -249,13 +247,11 @@ export function PlaygroundPage(props: { async function pollMediaRunUntilSettled(credential: string, localId: string, task: GatewayTask) { try { - const detail = await pollTaskUntilSettled(credential, task); + const detail = await pollTaskUntilSettled(credential, task, { + onUpdate: (nextTask) => updateMediaRunFromTask(localId, nextTask), + }); if (!isMountedRef.current) return; - setMediaRuns((current) => updateMediaRun(current, localId, { - error: gatewayTaskErrorText(detail, '任务执行失败'), - status: detail.status, - task: detail, - })); + updateMediaRunFromTask(localId, detail); } catch (err) { if (!isMountedRef.current) return; const errorMessage = err instanceof Error ? err.message : '任务状态同步失败'; @@ -263,6 +259,15 @@ export function PlaygroundPage(props: { } } + function updateMediaRunFromTask(localId: string, task: GatewayTask) { + if (!isMountedRef.current) return; + setMediaRuns((current) => updateMediaRun(current, localId, { + error: taskIsPending(task.status) ? 
'' : gatewayTaskErrorText(task, '任务执行失败'), + status: task.status, + task, + })); + } + function editMediaRun(run: MediaGenerationRun) { setPrompt(run.prompt); setMediaSettings(run.settings); @@ -1042,20 +1047,6 @@ function updateMediaRun(runs: MediaGenerationRun[], localId: string, patch: Part return runs.map((run) => run.localId === localId ? { ...run, ...patch } : run); } -async function pollTaskUntilSettled(token: string, task: GatewayTask) { - let detail = task; - for (let attempt = 0; attempt < 20; attempt += 1) { - detail = await getTask(token, detail.id); - if (!taskIsPending(detail.status)) return detail; - await delay(1200); - } - return detail; -} - -function taskIsPending(status: string) { - return status === 'queued' || status === 'running' || status === 'submitting'; -} - function readStoredMediaRuns(): MediaGenerationRun[] { if (typeof window === 'undefined') return []; try { @@ -1182,10 +1173,6 @@ function booleanFromUnknown(value: unknown, fallback: boolean) { return fallback; } -function delay(ms: number) { - return new Promise((resolve) => window.setTimeout(resolve, ms)); -} - function newLocalId() { return typeof crypto !== 'undefined' && 'randomUUID' in crypto ? 
crypto.randomUUID() diff --git a/apps/web/src/pages/admin/PlatformManagementPanel.tsx b/apps/web/src/pages/admin/PlatformManagementPanel.tsx index ede72f9..c01a13e 100644 --- a/apps/web/src/pages/admin/PlatformManagementPanel.tsx +++ b/apps/web/src/pages/admin/PlatformManagementPanel.tsx @@ -593,11 +593,14 @@ function RateLimitStatusTable(props: { statuses: ModelRateLimitStatus[]; platfor 模型 平台 - 并发 - TPM - RPM - 状态 - 满载率 + + 并发 + 正在执行 / 并发 / 排队 + + TPM + RPM + 状态 + 满载率 {props.statuses.map((status) => { const platform = props.platformMap.get(status.platformId); @@ -615,12 +618,12 @@ function RateLimitStatusTable(props: { statuses: ModelRateLimitStatus[]; platfor {status.provider} - {metricCell(status.concurrent)} - {metricCell(status.tpm, true)} - {metricCell(status.rpm)} - {modelRuntimeStatusCell(status, props.now)} - - + {concurrencyMetricCell(status)} + {metricCell(status.tpm, true)} + {metricCell(status.rpm)} + {modelRuntimeStatusCell(status, props.now)} + + 0.8 ? 'true' : undefined}> {formatPercent(status.loadRatio)} @@ -1210,6 +1213,16 @@ function metricCell(metric: ModelRateLimitStatus['rpm'], includeReserved = false ); } +function concurrencyMetricCell(status: ModelRateLimitStatus) { + const queuedTasks = status.queuedTasks ?? 0; + const limitText = status.concurrent.limited ? 
formatLimit(status.concurrent.limitValue) : '不限'; + return ( + + {formatLimit(status.concurrent.currentValue)} / {limitText} / {formatLimit(queuedTasks)} + + ); +} + function reservedMetricText(metric: ModelRateLimitStatus['rpm']) { return `已结算 ${formatLimit(metric.usedValue)} + 预占 ${formatLimit(metric.reservedValue)}`; } diff --git a/apps/web/src/styles/pages.css b/apps/web/src/styles/pages.css index ed795ae..2a242c0 100644 --- a/apps/web/src/styles/pages.css +++ b/apps/web/src/styles/pages.css @@ -1016,27 +1016,65 @@ } .platformLimitTable .shTableRow { - grid-template-columns: clamp(150px, 16vw, 220px) minmax(148px, 1fr) minmax(104px, max-content) minmax(136px, max-content) minmax(122px, max-content) minmax(132px, max-content) minmax(128px, max-content); - min-width: 920px; + grid-template-columns: minmax(180px, 1.15fr) minmax(160px, 0.95fr) 150px 170px 140px 132px 132px; + min-width: 1064px; } .platformLimitTable .shTableHead, .platformLimitTable .shTableCell { + display: grid; + align-content: center; + min-height: 68px; padding-right: 10px; padding-left: 10px; } +.platformLimitTable .shTableHead { + min-height: 74px; +} + +.platformLimitMetricHead { + display: grid; + align-content: center; + gap: 3px; + white-space: normal; +} + +.platformLimitMetricHead small { + overflow: hidden; + color: var(--muted-foreground); + font-size: var(--font-size-xs); + font-weight: var(--font-weight-medium); + line-height: 1.2; + text-overflow: ellipsis; + white-space: nowrap; +} + +.platformLimitNumberHead, +.platformLimitNumberCell { + justify-items: start; +} + +.platformLimitStatusHead, +.platformLimitStatusCell { + justify-items: start; +} + .rateMetricCell, .rateLoadCell { display: grid; min-width: 0; gap: 4px; + width: 100%; + align-content: start; + font-variant-numeric: tabular-nums; } .rateMetricCell strong, .rateLoadCell strong { color: var(--text-strong); font-size: var(--font-size-sm); + line-height: 1.25; } .rateMetricCell small { @@ -1050,6 +1088,7 @@ 
.rateLoadTrack { display: block; height: 6px; + width: min(112px, 100%); overflow: hidden; border-radius: 999px; background: #eef2f6; @@ -1062,6 +1101,18 @@ background: #0f766e; } +.rateLoadCell[data-overloaded="true"] strong { + color: var(--destructive); +} + +.rateLoadCell[data-overloaded="true"] .rateLoadTrack { + background: #fee2e2; +} + +.rateLoadCell[data-overloaded="true"] .rateLoadTrack i { + background: var(--destructive); +} + .platformModelToolbar { display: grid; grid-template-columns: minmax(220px, 0.6fr) minmax(260px, 1fr); diff --git a/go.work b/go.work index 35bdae7..12fa579 100644 --- a/go.work +++ b/go.work @@ -1,5 +1,3 @@ go 1.23 -use ( - ./apps/api -) +use ./apps/api diff --git a/go.work.sum b/go.work.sum index 7033131..ed3515e 100644 --- a/go.work.sum +++ b/go.work.sum @@ -1,16 +1,12 @@ -github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= -github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= -github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= -github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= -github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.7.2 h1:mLoDLV6sonKlvjIEsV56SkWNCnuNv531l94GaIzO+XI= -github.com/jackc/pgx/v5 v5.7.2/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ= -github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= -github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= -golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= -golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= -golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= 
-golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= -golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +golang.org/x/mod v0.35.0/go.mod h1:+GwiRhIInF8wPm+4AoT6L0FA1QWAad3OMdTRx4tFYlU= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= +golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/packages/contracts/src/index.ts b/packages/contracts/src/index.ts index 2f47344..fdf801d 100644 --- a/packages/contracts/src/index.ts +++ b/packages/contracts/src/index.ts @@ -775,6 +775,7 @@ export interface ModelRateLimitStatus { platformCooldownUntil?: string; modelCooldownUntil?: string; concurrent: RateLimitMetricStatus; + queuedTasks: number; rpm: RateLimitMetricStatus; tpm: RateLimitMetricStatus; loadRatio: number; @@ -808,6 +809,7 @@ export interface GatewayTask { resolvedModel?: string; requestId?: string; request?: Record; + asyncMode?: boolean; status: 'queued' | 'running' | 'succeeded' | 'failed' | 'cancelled' | string; result?: Record; billings?: 
unknown[];